Skip to content

Commit 599791d

Browse files
committed
Fix missing `allreduce_strategy` parameter: pass `sharding_config.allreduce_strategy` through to the four all-reduce sharding call sites in `sharding.py`
Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
1 parent e040f90 commit 599791d

File tree

1 file changed

+4
-0
lines changed
  • tensorrt_llm/_torch/auto_deploy/transform/library

1 file changed

+4
-0
lines changed

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,7 @@ def _process_ssm_sharding(
434434
rank=rank,
435435
world_size=world_size,
436436
dist_op="all_reduce",
437+
allreduce_strategy=sharding_config.allreduce_strategy,
437438
)
438439
)
439440
return 1
@@ -605,6 +606,7 @@ def detect_sharding_from_factory_config(
605606
world_size=world_size,
606607
dist_op="all_reduce",
607608
min_local_shape=min_local_shape,
609+
allreduce_strategy=sharding_config.allreduce_strategy,
608610
)
609611
)
610612
num_row_col_shards += 1
@@ -649,6 +651,7 @@ def detect_sharding_from_factory_config(
649651
world_size=world_size,
650652
dist_op="all_reduce",
651653
min_local_shape=min_local_shape,
654+
allreduce_strategy=sharding_config.allreduce_strategy,
652655
)
653656
)
654657
num_row_col_shards += 1
@@ -963,6 +966,7 @@ def detect_column_row_shard(
963966
world_size=world_size,
964967
dist_op="all_reduce",
965968
min_local_shape=min_local_shape,
969+
allreduce_strategy=sharding_config.allreduce_strategy,
966970
)
967971
)
968972

0 commit comments

Comments
 (0)