|
22 | 22 | from typing import DefaultDict, Dict, List, Set, Tuple, Type |
23 | 23 |
|
24 | 24 | import torch |
25 | | -from pydantic import Field |
| 25 | +from pydantic import Field, field_validator |
26 | 26 | from torch.fx import GraphModule, Node |
27 | 27 |
|
| 28 | +from .....functional import AllReduceStrategy |
28 | 29 | from ...models.factory import ModelFactory, ShardingConfigSource |
29 | 30 | from ...shim.interface import CachedSequenceInterface |
30 | 31 | from ...utils.logger import ad_logger |
@@ -149,6 +150,32 @@ class ShardingTransformConfig(TransformConfig): |
149 | 150 | sharding_dims: List[ShardingDim] = Field( |
150 | 151 | default_factory=lambda: [ShardingDim.SSM, ShardingDim.TP, ShardingDim.EP, ShardingDim.BMM] |
151 | 152 | ) |
| 153 | + allreduce_strategy: AllReduceStrategy = Field( |
| 154 | + default=AllReduceStrategy.AUTO, |
| 155 | + description="AllReduce strategy for distributed operations. Options: AUTO (automatic selection), " |
| 156 | + "NCCL (NCCL-based), ONESHOT (single-phase fusion kernel), TWOSHOT (two-phase fusion kernel), " |
| 157 | + "MIN_LATENCY (minimum latency heuristic), LOWPRECISION (low precision allreduce), " |
| 158 | + "UB (unified buffer), MNNVL (multi-node NVLINK), NCCL_SYMMETRIC (NCCL symmetric). " |
| 159 | + "This is set as a global variable during transform application.", |
| 160 | + ) |
| 161 | + |
| 162 | + @field_validator("allreduce_strategy", mode="before") |
| 163 | + @classmethod |
| 164 | + def _validate_allreduce_strategy(cls, v): |
| 165 | +        """Coerce string names like 'AUTO' or integer values to the AllReduceStrategy enum."""
| 166 | + if isinstance(v, AllReduceStrategy): |
| 167 | + return v |
| 168 | + if isinstance(v, str): |
| 169 | + try: |
| 170 | + return AllReduceStrategy[v] |
| 171 | + except KeyError: |
| 172 | +                raise ValueError(
| 173 | +                    f"Invalid allreduce strategy: {v}. "
| 174 | +                    f"Valid options: {', '.join(s.name for s in AllReduceStrategy)}"
| 175 | +                ) from None
| 176 | + if isinstance(v, int): |
| 177 | + return AllReduceStrategy(v) |
| 178 | + return v |
152 | 179 |
|
153 | 180 |
|
154 | 181 | @TransformRegistry.register("detect_sharding") |
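
For reference, the new `field_validator` lets `allreduce_strategy` be supplied as an enum member, its string name, or its integer value. Below is a minimal, self-contained sketch of that behavior using a stand-in pydantic model rather than the real `ShardingTransformConfig` (whose remaining fields and import path are omitted here); it is an illustration of the coercion, not the patch itself:

```python
from pydantic import BaseModel, Field, field_validator

from tensorrt_llm.functional import AllReduceStrategy


class DemoConfig(BaseModel):
    """Stand-in mirroring the allreduce_strategy field added above (illustration only)."""

    allreduce_strategy: AllReduceStrategy = Field(default=AllReduceStrategy.AUTO)

    @field_validator("allreduce_strategy", mode="before")
    @classmethod
    def _coerce(cls, v):
        # Mirror the patch's coercion: turn a name like "ONESHOT" into the enum member;
        # enum members and raw integer values are handled by pydantic itself.
        if isinstance(v, str):
            return AllReduceStrategy[v]
        return v


print(DemoConfig().allreduce_strategy)                              # AllReduceStrategy.AUTO
print(DemoConfig(allreduce_strategy="ONESHOT").allreduce_strategy)  # AllReduceStrategy.ONESHOT
print(DemoConfig(allreduce_strategy=AllReduceStrategy.NCCL).allreduce_strategy)
```
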
@@ -186,6 +213,23 @@ def _apply( |
186 | 213 | local_rank, world_size = shared_config.local_rank, shared_config.world_size |
187 | 214 | # world_size = 2 |
188 | 215 |
|
| 216 | + # Configure global allreduce strategy from transform config |
| 217 | + # This is set once during sharding transform and used by all distributed operations |
| 218 | + if hasattr(self.config, "allreduce_strategy"): |
| 219 | + try: |
| 220 | + from ...distributed.trtllm import TRTLLM_OP_AVAILABLE, set_allreduce_strategy |
| 221 | + |
| 222 | + if TRTLLM_OP_AVAILABLE: |
| 223 | + # config.allreduce_strategy is already an AllReduceStrategy enum |
| 224 | + set_allreduce_strategy(self.config.allreduce_strategy) |
| 225 | + if self.config.allreduce_strategy != AllReduceStrategy.AUTO: |
| 226 | + ad_logger.info( |
| 227 | + f"Global allreduce strategy configured from transform: " |
| 228 | + f"{self.config.allreduce_strategy.name}" |
| 229 | + ) |
| 230 | + except (ImportError, AttributeError) as e: |
| 231 | + ad_logger.warning(f"Failed to set allreduce strategy: {e}") |
| 232 | + |
189 | 233 | if world_size < 2: |
190 | 234 | ad_logger.info("Skipping sharding for single device") |
191 | 235 | return gm, TransformInfo( |
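
The call to `set_allreduce_strategy` records the chosen strategy as process-global state in the `...distributed.trtllm` shim so later fused allreduce ops can pick it up. A rough sketch of what such a module-level setter could look like; this is an assumption for illustration only (including the `get_allreduce_strategy` helper), not the actual contents of `distributed/trtllm`:

```python
from tensorrt_llm.functional import AllReduceStrategy

# Illustrative module-level state; the real ...distributed.trtllm shim may
# store and consume the strategy differently.
_ALLREDUCE_STRATEGY: AllReduceStrategy = AllReduceStrategy.AUTO


def set_allreduce_strategy(strategy: AllReduceStrategy) -> None:
    """Record the process-wide strategy used when building fused allreduce ops."""
    global _ALLREDUCE_STRATEGY
    _ALLREDUCE_STRATEGY = strategy


def get_allreduce_strategy() -> AllReduceStrategy:
    """Hypothetical accessor: return the configured strategy (AUTO unless overridden)."""
    return _ALLREDUCE_STRATEGY
```
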
|