 ]


+def _get_ghost_mode_optimizer(clipping: str, distributed: bool):
+    """Get optimizer class for ghost grad_sample_mode."""
+    if clipping != "flat":
+        raise ValueError(
+            f"Unsupported combination of parameters. Clipping: {clipping} and grad_sample_mode: ghost"
+        )
+    if distributed:
+        return DistributedDPOptimizerFastGradientClipping
+    return DPOptimizerFastGradientClipping
+
+
+def _get_ghost_fsdp_optimizer(clipping: str, distributed: bool):
+    """Get optimizer class for ghost_fsdp grad_sample_mode."""
+    if clipping != "flat" or not distributed:
+        raise ValueError(
+            f"Unsupported combination of parameters. Clipping: {clipping}, "
+            f"distributed: {distributed}, and grad_sample_mode: ghost_fsdp"
+        )
+    return FSDPOptimizerFastGradientClipping
+
+
+def _get_per_layer_distributed_optimizer(grad_sample_mode: str):
+    """Get optimizer class for per_layer distributed case."""
+    if grad_sample_mode not in ("hooks", "ew"):
+        raise ValueError(f"Unexpected grad_sample_mode: {grad_sample_mode}")
+    return SimpleDistributedPerLayerOptimizer
+
+
 def get_optimizer_class(clipping: str, distributed: bool, grad_sample_mode: str = None):
+    # Handle special grad_sample_mode cases first
     if grad_sample_mode == "ghost":
-        if clipping == "flat" and distributed is False:
-            return DPOptimizerFastGradientClipping
-        elif clipping == "flat" and distributed is True:
-            return DistributedDPOptimizerFastGradientClipping
-        else:
-            raise ValueError(
-                f"Unsupported combination of parameters. Clipping: {clipping} and grad_sample_mode: {grad_sample_mode}"
-            )
-    elif grad_sample_mode == "ghost_fsdp":
-        if clipping == "flat" and distributed is True:
-            return FSDPOptimizerFastGradientClipping
-        else:
-            raise ValueError(
-                f"Unsupported combination of parameters. Clipping: {clipping}, distributed: {distributed}, and grad_sample_mode: {grad_sample_mode}"
-            )
-    elif clipping == "flat" and distributed is False:
-        return DPOptimizer
-    elif clipping == "flat" and distributed is True:
-        return DistributedDPOptimizer
-    elif clipping == "per_layer" and distributed is False:
-        return DPPerLayerOptimizer
-    elif clipping == "per_layer" and distributed is True:
-        if grad_sample_mode == "hooks" or grad_sample_mode == "ew":
-            return SimpleDistributedPerLayerOptimizer
-        else:
-            raise ValueError(f"Unexpected grad_sample_mode: {grad_sample_mode}")
-    elif clipping == "automatic" and distributed is False:
-        return DPAutomaticClippingOptimizer
-    elif clipping == "automatic" and distributed is True:
-        return DistributedDPAutomaticClippingOptimizer
-    elif clipping == "automatic_per_layer" and distributed is False:
-        return DPPerLayerAutomaticClippingOptimizer
-    elif clipping == "automatic_per_layer" and distributed is True:
-        return DistributedDPPerLayerAutomaticClippingOptimizer
-    elif clipping == "adaptive" and distributed is False:
-        return AdaClipDPOptimizer
+        return _get_ghost_mode_optimizer(clipping, distributed)
+    if grad_sample_mode == "ghost_fsdp":
+        return _get_ghost_fsdp_optimizer(clipping, distributed)
+
+    # Handle per_layer distributed case with grad_sample_mode check
+    if clipping == "per_layer" and distributed:
+        return _get_per_layer_distributed_optimizer(grad_sample_mode)
+
+    # Standard lookup for common cases
+    optimizer_map = {
+        ("flat", False): DPOptimizer,
+        ("flat", True): DistributedDPOptimizer,
+        ("per_layer", False): DPPerLayerOptimizer,
+        ("automatic", False): DPAutomaticClippingOptimizer,
+        ("automatic", True): DistributedDPAutomaticClippingOptimizer,
+        ("automatic_per_layer", False): DPPerLayerAutomaticClippingOptimizer,
+        ("automatic_per_layer", True): DistributedDPPerLayerAutomaticClippingOptimizer,
+        ("adaptive", False): AdaClipDPOptimizer,
+    }
+
+    key = (clipping, distributed)
+    if key in optimizer_map:
+        return optimizer_map[key]
+
     raise ValueError(
         f"Unexpected optimizer parameters. Clipping: {clipping}, distributed: {distributed}"
     )
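
The refactor preserves the dispatch behaviour of the old if/elif chain: ghost and ghost_fsdp modes are resolved by the new helper functions, the distributed per_layer case still validates grad_sample_mode, and every other combination goes through the (clipping, distributed) lookup table. The sketch below shows how the function could be exercised; the import path opacus.optimizers is an assumption based on where get_optimizer_class lives upstream and is not confirmed by this diff.

# Usage sketch (assumption: get_optimizer_class and the optimizer classes used
# in the diff are exported from opacus.optimizers; the diff does not show the
# imports at the top of the file).
from opacus.optimizers import (
    DPOptimizer,
    DistributedDPOptimizerFastGradientClipping,
    get_optimizer_class,
)

# Common case: flat clipping, single process -> plain DPOptimizer.
assert get_optimizer_class(clipping="flat", distributed=False) is DPOptimizer

# Ghost clipping in a distributed run is dispatched via _get_ghost_mode_optimizer.
cls = get_optimizer_class(clipping="flat", distributed=True, grad_sample_mode="ghost")
assert cls is DistributedDPOptimizerFastGradientClipping

# Unsupported combinations still raise ValueError, exactly as before the refactor.
try:
    get_optimizer_class(clipping="per_layer", distributed=False, grad_sample_mode="ghost")
except ValueError as err:
    print(err)  # Unsupported combination of parameters. Clipping: per_layer ...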