Commit a5a6c33

Commit message: moved to using a global config variable; the transform now sets it

Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>

1 parent: 94aa93d

4 files changed: +69 additions, -17 deletions


tensorrt_llm/_torch/auto_deploy/config/default.yaml

Lines changed: 1 addition & 0 deletions
@@ -80,6 +80,7 @@ transforms:
     sharding_source: ['heuristic']
     support_partial_config: true
     sharding_dims: ['tp', 'ep', 'bmm']
+    allreduce_strategy: 'AUTO'
     requires_shape_prop: true
   # TODO: (hg) need to ensure run_shape_prop after sharding.
   sharding_transform_executor:

tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py

Lines changed: 23 additions & 5 deletions
@@ -13,21 +13,36 @@
     # warmup causes hangs due to workspace allocation with CPU synchronization
     _allreduce_cache = {}

-    # Global allreduce strategy configuration
-    # Can be set via set_allreduce_strategy() to override the default AUTO strategy
+    # Global AllReduce Strategy Configuration
+    # =========================================
+    # This global variable controls which allreduce implementation is used across
+    # all distributed operations in AutoDeploy. It's set once at initialization
+    # time via set_allreduce_strategy() and remains constant during execution.
     _global_allreduce_strategy = AllReduceStrategy.AUTO

     def set_allreduce_strategy(strategy: AllReduceStrategy):
-        """Set the global allreduce strategy for distributed operations.
+        """Set the global allreduce strategy for all distributed operations.

-        Args:
-            strategy: AllReduceStrategy enum value (AUTO, NCCL, ONESHOT, TWOSHOT, etc.)
+        This should be called once during initialization, before any distributed
+        operations are executed. All subsequent allreduce calls will use this strategy.
+
+        Note:
+            This clears the allreduce cache to ensure new operations use the updated strategy.
+            Call this before any model compilation or CUDA graph capture.
         """
         global _global_allreduce_strategy
         _global_allreduce_strategy = strategy
         # Clear cache when strategy changes to force recreation with new strategy
         _allreduce_cache.clear()

+    def get_allreduce_strategy() -> AllReduceStrategy:
+        """Get the current global allreduce strategy.
+
+        Returns:
+            The currently configured AllReduceStrategy enum value.
+        """
+        return _global_allreduce_strategy
+
     def trtllm_allgather(tensor, dim, sizes=None):
         rank, world_size = get_rank_world_size()
         p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)

@@ -77,6 +92,9 @@ def fused_allreduce_residual_rmsnorm_fake(
     def set_allreduce_strategy(strategy):
         raise ImportError("TRT-LLM is not available.")

+    def get_allreduce_strategy():
+        raise ImportError("TRT-LLM is not available.")
+
     def trtllm_allgather(tensor, dim, sizes=None):
         raise ImportError("TRT-LLM is not available.")
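
With this change the strategy is read through one module-level global instead of being passed around. A minimal usage sketch of the new setter/getter pair, not part of the commit, assuming TRT-LLM is installed so the real implementations are imported rather than the ImportError stubs (the absolute import path simply mirrors the file location):

    from tensorrt_llm.functional import AllReduceStrategy
    from tensorrt_llm._torch.auto_deploy.distributed.trtllm import (
        get_allreduce_strategy,
        set_allreduce_strategy,
    )

    # Module default before anything is configured.
    assert get_allreduce_strategy() == AllReduceStrategy.AUTO

    # Switching strategies also clears _allreduce_cache, so any allreduce op
    # created afterwards is rebuilt with the new strategy. Per the docstring,
    # do this before model compilation / CUDA graph capture.
    set_allreduce_strategy(AllReduceStrategy.NCCL)
    assert get_allreduce_strategy() == AllReduceStrategy.NCCL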

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 0 additions & 11 deletions
@@ -325,17 +325,6 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
     port = mpi_dist.broadcast(dist.get_free_port())  # use MPI broadcast to pick a free port
     dist.initialize_or_skip(rank, world_size, port)

-    # Configure allreduce strategy if specified
-    if hasattr(ad_config, "allreduce_strategy") and ad_config.allreduce_strategy != "AUTO":
-        from tensorrt_llm.functional import AllReduceStrategy
-
-        from ..distributed.trtllm import TRTLLM_OP_AVAILABLE, set_allreduce_strategy
-
-        if TRTLLM_OP_AVAILABLE:
-            strategy = getattr(AllReduceStrategy, ad_config.allreduce_strategy)
-            set_allreduce_strategy(strategy)
-            ad_logger.info(f"Using allreduce strategy: {ad_config.allreduce_strategy}")
-
     # some config
     assert ad_config.max_beam_width <= 1, "_autodeploy + beam_search is not supported"

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py

Lines changed: 45 additions & 1 deletion
@@ -22,9 +22,10 @@
 from typing import DefaultDict, Dict, List, Set, Tuple, Type

 import torch
-from pydantic import Field
+from pydantic import Field, field_validator
 from torch.fx import GraphModule, Node

+from .....functional import AllReduceStrategy
 from ...models.factory import ModelFactory, ShardingConfigSource
 from ...shim.interface import CachedSequenceInterface
 from ...utils.logger import ad_logger
@@ -149,6 +150,32 @@ class ShardingTransformConfig(TransformConfig):
     sharding_dims: List[ShardingDim] = Field(
         default_factory=lambda: [ShardingDim.SSM, ShardingDim.TP, ShardingDim.EP, ShardingDim.BMM]
     )
+    allreduce_strategy: AllReduceStrategy = Field(
+        default=AllReduceStrategy.AUTO,
+        description="AllReduce strategy for distributed operations. Options: AUTO (automatic selection), "
+        "NCCL (NCCL-based), ONESHOT (single-phase fusion kernel), TWOSHOT (two-phase fusion kernel), "
+        "MIN_LATENCY (minimum latency heuristic), LOWPRECISION (low precision allreduce), "
+        "UB (unified buffer), MNNVL (multi-node NVLINK), NCCL_SYMMETRIC (NCCL symmetric). "
+        "This is set as a global variable during transform application.",
+    )
+
+    @field_validator("allreduce_strategy", mode="before")
+    @classmethod
+    def _validate_allreduce_strategy(cls, v):
+        """Convert string names like 'AUTO' or 'ONESHOT' to AllReduceStrategy enum."""
+        if isinstance(v, AllReduceStrategy):
+            return v
+        if isinstance(v, str):
+            try:
+                return AllReduceStrategy[v]
+            except KeyError:
+                raise ValueError(
+                    f"Invalid allreduce strategy: {v}. "
+                    f"Valid options: {', '.join(s.name for s in AllReduceStrategy)}"
+                )
+        if isinstance(v, int):
+            return AllReduceStrategy(v)
+        return v


 @TransformRegistry.register("detect_sharding")
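
The mode="before" validator normalizes whatever the config layer passes in. A rough, illustrative sketch of the expected behavior, not part of the commit; it assumes the remaining ShardingTransformConfig / TransformConfig fields all have defaults and that the import path mirrors the file location:

    from pydantic import ValidationError

    from tensorrt_llm.functional import AllReduceStrategy
    from tensorrt_llm._torch.auto_deploy.transform.library.sharding import ShardingTransformConfig

    # YAML-style strings are resolved by enum name; enum members pass through unchanged.
    cfg = ShardingTransformConfig(allreduce_strategy="ONESHOT")
    assert cfg.allreduce_strategy == AllReduceStrategy.ONESHOT

    cfg = ShardingTransformConfig(allreduce_strategy=AllReduceStrategy.NCCL)
    assert cfg.allreduce_strategy == AllReduceStrategy.NCCL

    # Unknown names are rejected with a message listing the valid options.
    try:
        ShardingTransformConfig(allreduce_strategy="FASTEST")
    except ValidationError as err:
        print(err)
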
@@ -186,6 +213,23 @@ def _apply(
         local_rank, world_size = shared_config.local_rank, shared_config.world_size
         # world_size = 2

+        # Configure global allreduce strategy from transform config
+        # This is set once during sharding transform and used by all distributed operations
+        if hasattr(self.config, "allreduce_strategy"):
+            try:
+                from ...distributed.trtllm import TRTLLM_OP_AVAILABLE, set_allreduce_strategy
+
+                if TRTLLM_OP_AVAILABLE:
+                    # config.allreduce_strategy is already an AllReduceStrategy enum
+                    set_allreduce_strategy(self.config.allreduce_strategy)
+                    if self.config.allreduce_strategy != AllReduceStrategy.AUTO:
+                        ad_logger.info(
+                            f"Global allreduce strategy configured from transform: "
+                            f"{self.config.allreduce_strategy.name}"
+                        )
+            except (ImportError, AttributeError) as e:
+                ad_logger.warning(f"Failed to set allreduce strategy: {e}")
+
         if world_size < 2:
             ad_logger.info("Skipping sharding for single device")
             return gm, TransformInfo(
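
End to end, the value now flows from the YAML string through the validated transform config into the process-wide global that the distributed ops read. A condensed sketch of what _apply does above, reduced to the essentials; the helper name and strategy value here are only illustrative:

    from tensorrt_llm.functional import AllReduceStrategy
    from tensorrt_llm._torch.auto_deploy.distributed.trtllm import (
        TRTLLM_OP_AVAILABLE,
        get_allreduce_strategy,
        set_allreduce_strategy,
    )

    def configure_allreduce_strategy(strategy_name: str) -> None:
        """Condensed transform-side logic: resolve the name, publish the global."""
        if not TRTLLM_OP_AVAILABLE:
            return  # no TRT-LLM ops available; the transform skips configuration in this case
        set_allreduce_strategy(AllReduceStrategy[strategy_name])

    configure_allreduce_strategy("TWOSHOT")  # e.g. the value read from default.yaml
    assert get_allreduce_strategy() == AllReduceStrategy.TWOSHOT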
