Commit 3124116

made strategy mandatory, fixed missing param
Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
1 parent 599791d commit 3124116

File tree

4 files changed, +47 -14 lines changed


tensorrt_llm/_torch/auto_deploy/custom_ops/dist.py

Lines changed: 8 additions & 0 deletions
@@ -37,6 +37,14 @@ def all_reduce(t: torch.Tensor, strategy: str = "AUTO") -> torch.Tensor:
     efficient all_reduce ops one should write/replace it with a fused op.
     """
     if trtllm_dist.is_trtllm_op_available():
+        # Debug logging to see what strategy is actually passed
+        if not hasattr(all_reduce, "_logged_strategies"):
+            all_reduce._logged_strategies = set()
+        if strategy not in all_reduce._logged_strategies:
+            from tensorrt_llm.logger import logger
+
+            logger.info(f"[DEBUG] torch_dist_all_reduce called with strategy='{strategy}'")
+            all_reduce._logged_strategies.add(strategy)
         return trtllm_dist.trtllm_allreduce(t, op=dist.ReduceOp.SUM, strategy=strategy)
     t_res = t.clone()
     dist.all_reduce(t_res, op=dist.ReduceOp.SUM)
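Note: the added block logs each distinct strategy string only once per process by caching the values it has already seen on the function object. A minimal standalone sketch of the same log-once pattern, using Python's stdlib logging instead of tensorrt_llm.logger (the helper name below is illustrative, not part of TensorRT-LLM):

import logging

logger = logging.getLogger(__name__)


def log_once_per_value(tag: str, value: str) -> None:
    """Log `value` the first time it is seen for `tag`, then stay silent."""
    # Cache seen values on the function object, mirroring all_reduce._logged_strategies.
    if not hasattr(log_once_per_value, "_seen"):
        log_once_per_value._seen = set()
    if (tag, value) not in log_once_per_value._seen:
        logger.info("[DEBUG] %s called with value='%s'", tag, value)
        log_once_per_value._seen.add((tag, value))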

tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py

Lines changed: 21 additions & 8 deletions
@@ -17,7 +17,11 @@


 def _allreduce_residual_rmsnorm_pattern(
-    x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 0.1253
+    x: torch.Tensor,
+    residual: torch.Tensor,
+    weight: torch.Tensor,
+    eps: float = 0.1253,
+    strategy: str = "AUTO",
 ):
     """
     Reference PyTorch composition of:
@@ -28,7 +32,7 @@ def _allreduce_residual_rmsnorm_pattern(
     """

     input_dtype = x.dtype
-    hidden_states = torch.ops.auto_deploy.torch_dist_all_reduce(x)
+    hidden_states = torch.ops.auto_deploy.torch_dist_all_reduce(x, strategy)
     add = residual + hidden_states

     hidden_states = add.to(torch.float32)
@@ -41,7 +45,11 @@ def _allreduce_residual_rmsnorm_pattern(


 def _allreduce_residual_rmsnorm_pattern2(
-    x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 0.1253
+    x: torch.Tensor,
+    residual: torch.Tensor,
+    weight: torch.Tensor,
+    eps: float = 0.1253,
+    strategy: str = "AUTO",
 ):
     """
     Reference PyTorch composition of:
@@ -52,7 +60,7 @@ def _allreduce_residual_rmsnorm_pattern2(
     """

     input_dtype = x.dtype
-    hidden_states = torch.ops.auto_deploy.torch_dist_all_reduce(x)
+    hidden_states = torch.ops.auto_deploy.torch_dist_all_reduce(x, strategy)
     add = hidden_states + residual

     hidden_states = add.to(torch.float32)
@@ -65,9 +73,13 @@ def _allreduce_residual_rmsnorm_pattern2(


 def _allreduce_residual_rmsnorm_repl(
-    x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float
+    x: torch.Tensor,
+    residual: torch.Tensor,
+    weight: torch.Tensor,
+    eps: float,
+    strategy: str = "AUTO",
 ):
-    return torch.ops.dist.fused_allreduce_residual_rmsnorm(x, residual, weight, eps)
+    return torch.ops.dist.fused_allreduce_residual_rmsnorm(x, residual, weight, eps, strategy)


 @TransformRegistry.register("fuse_allreduce_residual_rmsnorm")
@@ -90,6 +102,7 @@ def _apply(
             torch.randn(bsz, hidden, device="meta", dtype=torch.bfloat16),  # residual
             torch.randn(hidden, device="meta", dtype=torch.bfloat16),  # weight
             0.1253,  # eps
+            "AUTO",  # strategy
         ]

         register_ad_pattern(
@@ -98,15 +111,15 @@ def _apply(
             patterns=patterns,
             dummy_args=dummy_args,
             op_ignore_types={torch.ops.aten.to.dtype: (torch.dtype,)},
-            scalar_workaround={"eps": 0.1253},
+            scalar_workaround={"eps": 0.1253, "strategy": "AUTO"},
         )
         register_ad_pattern(
             search_fn=_allreduce_residual_rmsnorm_pattern2,
             replace_fn=_allreduce_residual_rmsnorm_repl,
             patterns=patterns,
             dummy_args=dummy_args,
             op_ignore_types={torch.ops.aten.to.dtype: (torch.dtype,)},
-            scalar_workaround={"eps": 0.1253},
+            scalar_workaround={"eps": 0.1253, "strategy": "AUTO"},
         )

         num_matches = patterns.apply(gm.graph)
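Note: the pattern and replacement functions keep matching signatures so the matcher can forward every captured input, including the new strategy scalar, from the matched subgraph to the fused op; dummy_args and scalar_workaround are updated to carry the extra argument. A minimal, self-contained illustration of pattern-based graph rewriting with PyTorch's stock torch.fx subgraph rewriter (the toy ops below are placeholders, not the TensorRT-LLM pattern or its register_ad_pattern helper):

import torch
import torch.fx as fx


def pattern(x: torch.Tensor, y: torch.Tensor):
    # Subgraph to search for: add followed by relu.
    return torch.relu(torch.add(x, y))


def replacement(x: torch.Tensor, y: torch.Tensor):
    # Same signature as `pattern`; stands in for a fused op.
    return torch.clamp(torch.add(x, y), min=0.0)


class Toy(torch.nn.Module):
    def forward(self, a, b):
        return torch.relu(torch.add(a, b))


gm = fx.symbolic_trace(Toy())
matches = fx.replace_pattern(gm, pattern, replacement)
print(len(matches))  # 1: the add+relu subgraph was rewritten

a, b = torch.randn(4), torch.randn(4)
print(torch.allclose(gm(a, b), torch.clamp(a + b, min=0.0)))  # True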

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py

Lines changed: 10 additions & 0 deletions
@@ -133,6 +133,7 @@ def _process_simple_shard(
                world_size=world_size,
                dist_op="all_gather",
                min_local_shape=1,
+               allreduce_strategy=sharding_config.allreduce_strategy,
            )
        )
    )
@@ -360,6 +361,7 @@ def _process_ssm_sharding(
                dist_op=None,
                min_local_shape=min_local_shape,
                fused_weight_dims=fused_weight_dims["in_proj"],
+               allreduce_strategy=sharding_config.allreduce_strategy,
            )
        )

@@ -398,6 +400,7 @@ def _process_ssm_sharding(
                dist_op=None,
                min_local_shape=min_local_shape,
                fused_weight_dims=fused_dims,
+               allreduce_strategy=sharding_config.allreduce_strategy,
            )
        )

@@ -461,6 +464,7 @@ def _process_column_sharding(
                world_size=world_size,
                dist_op=None,  # for column sharding, no dist op is performed
                min_local_shape=min_local_shape,
+               allreduce_strategy=sharding_config.allreduce_strategy,
            )
        )

@@ -594,6 +598,7 @@ def detect_sharding_from_factory_config(
                world_size=world_size,
                dist_op=None,
                min_local_shape=min_local_shape,
+               allreduce_strategy=sharding_config.allreduce_strategy,
            )
        )
        num_row_col_shards += 1
@@ -620,6 +625,7 @@ def detect_sharding_from_factory_config(
                dist_op=None,
                min_local_shape=min_local_shape,
                layer_type=LayerType.MAMBA,
+               allreduce_strategy=sharding_config.allreduce_strategy,
            )
        )
        num_row_col_shards += 1
@@ -640,6 +646,7 @@ def detect_sharding_from_factory_config(
                world_size=world_size,
                dist_op=None,
                min_local_shape=min_local_shape,
+               allreduce_strategy=sharding_config.allreduce_strategy,
            )
        )
    elif col_row_action == "rowwise":
@@ -671,6 +678,7 @@ def detect_sharding_from_factory_config(
                world_size=world_size,
                dist_op="all_gather",
                min_local_shape=1,
+               allreduce_strategy=sharding_config.allreduce_strategy,
            )
        )
        num_simple_shards += 1
@@ -686,6 +694,7 @@ def detect_sharding_from_factory_config(
                world_size=world_size,
                dist_op="all_gather",
                min_local_shape=1,
+               allreduce_strategy=sharding_config.allreduce_strategy,
            )
        )
        # after successful match, break the loop
@@ -1085,6 +1094,7 @@ def detect_ep_shard(gm: GraphModule, sharding_config: ShardingConfig) -> Transfo
                node,
                rank=rank,
                world_size=world_size,
+               allreduce_strategy=sharding_config.allreduce_strategy,
            )
        )
        num_moe_patterns += 1
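Note: every site that constructs a sharding transform now forwards sharding_config.allreduce_strategy explicitly, since the field no longer has a default (see sharding_utils.py below). A stripped-down sketch of that threading, with class and field names simplified from the real ShardingConfig / ShardingTransformInfo:

from dataclasses import dataclass
from enum import Enum


class AllReduceStrategy(Enum):
    # Illustrative members; the real enum comes from TensorRT-LLM.
    AUTO = "AUTO"
    NCCL = "NCCL"


@dataclass
class ShardingConfigSketch:
    allreduce_strategy: AllReduceStrategy


@dataclass
class ShardInfoSketch:
    target_node: str
    rank: int
    world_size: int
    allreduce_strategy: AllReduceStrategy  # required: no default


def make_shard_info(cfg: ShardingConfigSketch, node: str, rank: int, world_size: int) -> ShardInfoSketch:
    # The strategy always flows from the single place it is configured.
    return ShardInfoSketch(
        target_node=node,
        rank=rank,
        world_size=world_size,
        allreduce_strategy=cfg.allreduce_strategy,
    )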

tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py

Lines changed: 8 additions & 6 deletions
@@ -566,7 +566,7 @@ class ShardingTransformInfo(BaseModel, ABC):
     target_node: str
     rank: int
     world_size: int
-    allreduce_strategy: AllReduceStrategy = AllReduceStrategy.AUTO
+    allreduce_strategy: AllReduceStrategy  # REQUIRED: must be explicitly passed

     @field_validator("allreduce_strategy", mode="before")
     @classmethod
@@ -696,6 +696,8 @@ class ParameterUpdateInfo(ShardingTransformInfo):
     rank: int
     world_size: int
     args: tuple
+    # ParameterUpdateInfo doesn't insert distributed ops, so strategy doesn't matter
+    allreduce_strategy: AllReduceStrategy = AllReduceStrategy.AUTO

     def validate(self, gm: GraphModule = None, node: Node = None) -> bool:
         """Validate the transformation configuration."""
@@ -984,8 +986,8 @@ def _insert_sharded_moe(
     node: Node,
     rank: int,
     world_size: int,
+    allreduce_strategy: AllReduceStrategy,  # REQUIRED: must be explicitly passed
     scale_names: Sequence[str] = (),
-    allreduce_strategy: AllReduceStrategy = AllReduceStrategy.AUTO,
 ):
     """Update the torch_moe node with sharded weight lists,
     sharded `selected_experts` and `final_scales(router_logics)`.
@@ -1091,7 +1093,7 @@ def _insert_sharded_mxfp4_mlp_ep(
     node: Node,
     rank: int,
     world_size: int,
-    allreduce_strategy: AllReduceStrategy = AllReduceStrategy.AUTO,
+    allreduce_strategy: AllReduceStrategy,  # REQUIRED: must be explicitly passed
 ):
     """
     Transform a call to auto_deploy::triton_mxfp4_moe into:
@@ -1165,7 +1167,7 @@ def validate(self, gm: GraphModule = None, node: Node = None) -> bool:

     def apply(self, gm: GraphModule, node: Node) -> None:
         """Apply EP sharding transformation to the graph module."""
-        _insert_sharded_moe(gm, node, self.rank, self.world_size, [], self.allreduce_strategy)
+        _insert_sharded_moe(gm, node, self.rank, self.world_size, self.allreduce_strategy, [])


 class MXFP4EPShardingInfo(EPShardingInfo):
@@ -1196,7 +1198,7 @@ def scale_names(self) -> List[str]:

     def apply(self, gm: GraphModule, node: Node) -> None:
         _insert_sharded_moe(
-            gm, node, self.rank, self.world_size, self.scale_names(), self.allreduce_strategy
+            gm, node, self.rank, self.world_size, self.allreduce_strategy, self.scale_names()
         )


@@ -1214,7 +1216,7 @@ def scale_names(self) -> List[str]:

     def apply(self, gm: GraphModule, node: Node) -> None:
         _insert_sharded_moe(
-            gm, node, self.rank, self.world_size, self.scale_names(), self.allreduce_strategy
+            gm, node, self.rank, self.world_size, self.allreduce_strategy, self.scale_names()
         )