File tree Expand file tree Collapse file tree 1 file changed +4
-0
lines changed
tensorrt_llm/_torch/auto_deploy/transform/library Expand file tree Collapse file tree 1 file changed +4
-0
lines changed Original file line number Diff line number Diff line change @@ -434,6 +434,7 @@ def _process_ssm_sharding(
434434 rank = rank ,
435435 world_size = world_size ,
436436 dist_op = "all_reduce" ,
437+ allreduce_strategy = sharding_config .allreduce_strategy ,
437438 )
438439 )
439440 return 1
@@ -605,6 +606,7 @@ def detect_sharding_from_factory_config(
605606 world_size = world_size ,
606607 dist_op = "all_reduce" ,
607608 min_local_shape = min_local_shape ,
609+ allreduce_strategy = sharding_config .allreduce_strategy ,
608610 )
609611 )
610612 num_row_col_shards += 1
@@ -649,6 +651,7 @@ def detect_sharding_from_factory_config(
649651 world_size = world_size ,
650652 dist_op = "all_reduce" ,
651653 min_local_shape = min_local_shape ,
654+ allreduce_strategy = sharding_config .allreduce_strategy ,
652655 )
653656 )
654657 num_row_col_shards += 1
@@ -963,6 +966,7 @@ def detect_column_row_shard(
963966 world_size = world_size ,
964967 dist_op = "all_reduce" ,
965968 min_local_shape = min_local_shape ,
969+ allreduce_strategy = sharding_config .allreduce_strategy ,
966970 )
967971 )
968972
You can’t perform that action at this time.
0 commit comments