Commit a42721f

Add ghost mode and per-layer distributed optimizer functions and make get_optimizer_class less complex (flake8 error fix)
1 parent abbcc42 commit a42721f

File tree

1 file changed: +53 / -36 lines changed

opacus/optimizers/__init__.py

Lines changed: 53 additions & 36 deletions
@@ -48,44 +48,61 @@
 ]


+def _get_ghost_mode_optimizer(clipping: str, distributed: bool):
+    """Get optimizer class for ghost grad_sample_mode."""
+    if clipping != "flat":
+        raise ValueError(
+            f"Unsupported combination of parameters. Clipping: {clipping} and grad_sample_mode: ghost"
+        )
+    if distributed:
+        return DistributedDPOptimizerFastGradientClipping
+    return DPOptimizerFastGradientClipping
+
+
+def _get_ghost_fsdp_optimizer(clipping: str, distributed: bool):
+    """Get optimizer class for ghost_fsdp grad_sample_mode."""
+    if clipping != "flat" or not distributed:
+        raise ValueError(
+            f"Unsupported combination of parameters. Clipping: {clipping}, "
+            f"distributed: {distributed}, and grad_sample_mode: ghost_fsdp"
+        )
+    return FSDPOptimizerFastGradientClipping
+
+
+def _get_per_layer_distributed_optimizer(grad_sample_mode: str):
+    """Get optimizer class for per_layer distributed case."""
+    if grad_sample_mode not in ("hooks", "ew"):
+        raise ValueError(f"Unexpected grad_sample_mode: {grad_sample_mode}")
+    return SimpleDistributedPerLayerOptimizer
+
+
 def get_optimizer_class(clipping: str, distributed: bool, grad_sample_mode: str = None):
+    # Handle special grad_sample_mode cases first
     if grad_sample_mode == "ghost":
-        if clipping == "flat" and distributed is False:
-            return DPOptimizerFastGradientClipping
-        elif clipping == "flat" and distributed is True:
-            return DistributedDPOptimizerFastGradientClipping
-        else:
-            raise ValueError(
-                f"Unsupported combination of parameters. Clipping: {clipping} and grad_sample_mode: {grad_sample_mode}"
-            )
-    elif grad_sample_mode == "ghost_fsdp":
-        if clipping == "flat" and distributed is True:
-            return FSDPOptimizerFastGradientClipping
-        else:
-            raise ValueError(
-                f"Unsupported combination of parameters. Clipping: {clipping}, distributed: {distributed}, and grad_sample_mode: {grad_sample_mode}"
-            )
-    elif clipping == "flat" and distributed is False:
-        return DPOptimizer
-    elif clipping == "flat" and distributed is True:
-        return DistributedDPOptimizer
-    elif clipping == "per_layer" and distributed is False:
-        return DPPerLayerOptimizer
-    elif clipping == "per_layer" and distributed is True:
-        if grad_sample_mode == "hooks" or grad_sample_mode == "ew":
-            return SimpleDistributedPerLayerOptimizer
-        else:
-            raise ValueError(f"Unexpected grad_sample_mode: {grad_sample_mode}")
-    elif clipping == "automatic" and distributed is False:
-        return DPAutomaticClippingOptimizer
-    elif clipping == "automatic" and distributed is True:
-        return DistributedDPAutomaticClippingOptimizer
-    elif clipping == "automatic_per_layer" and distributed is False:
-        return DPPerLayerAutomaticClippingOptimizer
-    elif clipping == "automatic_per_layer" and distributed is True:
-        return DistributedDPPerLayerAutomaticClippingOptimizer
-    elif clipping == "adaptive" and distributed is False:
-        return AdaClipDPOptimizer
+        return _get_ghost_mode_optimizer(clipping, distributed)
+    if grad_sample_mode == "ghost_fsdp":
+        return _get_ghost_fsdp_optimizer(clipping, distributed)
+
+    # Handle per_layer distributed case with grad_sample_mode check
+    if clipping == "per_layer" and distributed:
+        return _get_per_layer_distributed_optimizer(grad_sample_mode)
+
+    # Standard lookup for common cases
+    optimizer_map = {
+        ("flat", False): DPOptimizer,
+        ("flat", True): DistributedDPOptimizer,
+        ("per_layer", False): DPPerLayerOptimizer,
+        ("automatic", False): DPAutomaticClippingOptimizer,
+        ("automatic", True): DistributedDPAutomaticClippingOptimizer,
+        ("automatic_per_layer", False): DPPerLayerAutomaticClippingOptimizer,
+        ("automatic_per_layer", True): DistributedDPPerLayerAutomaticClippingOptimizer,
+        ("adaptive", False): AdaClipDPOptimizer,
+    }
+
+    key = (clipping, distributed)
+    if key in optimizer_map:
+        return optimizer_map[key]
+
     raise ValueError(
         f"Unexpected optimizer parameters. Clipping: {clipping}, distributed: {distributed}"
    )
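
For context, here is a minimal, illustrative sketch of how the refactored get_optimizer_class dispatch could be exercised. The parameter names and returned classes are taken directly from the diff above; the snippet itself is not part of the commit, and it assumes the listed classes are importable from opacus.optimizers, as the module path suggests.

# Illustrative usage only -- not part of the commit.
from opacus.optimizers import (
    DPOptimizer,
    DPOptimizerFastGradientClipping,
    SimpleDistributedPerLayerOptimizer,
    get_optimizer_class,
)

# Ghost clipping in a single-process run resolves to the fast gradient clipping optimizer.
cls = get_optimizer_class(clipping="flat", distributed=False, grad_sample_mode="ghost")
assert cls is DPOptimizerFastGradientClipping

# Common cases now go through the (clipping, distributed) lookup table.
assert get_optimizer_class(clipping="flat", distributed=False) is DPOptimizer

# Distributed per-layer clipping still requires a supported grad_sample_mode ("hooks" or "ew").
cls = get_optimizer_class(clipping="per_layer", distributed=True, grad_sample_mode="hooks")
assert cls is SimpleDistributedPerLayerOptimizer

# Unsupported combinations keep raising ValueError, e.g. ghost mode with per-layer clipping.
try:
    get_optimizer_class(clipping="per_layer", distributed=False, grad_sample_mode="ghost")
except ValueError as err:
    print(f"Rejected as expected: {err}")

Splitting the ghost, ghost_fsdp, and distributed per-layer branches into helper functions and replacing the long elif chain with a dict lookup is what reduces the function's branching, which appears to be the flake8 complexity issue the commit message refers to.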
