
Commit 3640691

strategy configurable, added test
Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
1 parent 278d5fa · commit 3640691

4 files changed: +234 -4 lines changed


tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py

Lines changed: 22 additions & 4 deletions
@@ -13,6 +13,21 @@
     # warmup causes hangs due to workspace allocation with CPU synchronization
     _allreduce_cache = {}

+    # Global allreduce strategy configuration
+    # Can be set via set_allreduce_strategy() to override the default AUTO strategy
+    _global_allreduce_strategy = AllReduceStrategy.AUTO
+
+    def set_allreduce_strategy(strategy: AllReduceStrategy):
+        """Set the global allreduce strategy for distributed operations.
+
+        Args:
+            strategy: AllReduceStrategy enum value (AUTO, NCCL, ONESHOT, TWOSHOT, etc.)
+        """
+        global _global_allreduce_strategy
+        _global_allreduce_strategy = strategy
+        # Clear cache when strategy changes to force recreation with new strategy
+        _allreduce_cache.clear()
+
     def trtllm_allgather(tensor, dim, sizes=None):
         rank, world_size = get_rank_world_size()
         p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)

@@ -22,13 +37,13 @@ def trtllm_allreduce(tensor, op, all_reduce_params=None):
         rank, world_size = get_rank_world_size()
         assert op == ReduceOp.SUM, "TRT-LLM all reduce only supports SUM op."

-        # Cache key includes rank, world_size, and dtype to handle different configurations
-        cache_key = (rank, world_size, tensor.dtype)
+        # Cache key includes rank, world_size, dtype, and strategy to handle different configurations
+        cache_key = (rank, world_size, tensor.dtype, _global_allreduce_strategy)
         if cache_key not in _allreduce_cache:
             p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
-            # Use Strategy.AUTO for optimal performance
+            # Use the configured global strategy
             _allreduce_cache[cache_key] = AllReduce(
-                mapping=p_config, strategy=AllReduceStrategy.AUTO, dtype=tensor.dtype
+                mapping=p_config, strategy=_global_allreduce_strategy, dtype=tensor.dtype
             )

         torch_op = _allreduce_cache[cache_key]

@@ -59,6 +74,9 @@ def fused_allreduce_residual_rmsnorm_fake(
     TRTLLM_OP_AVAILABLE = True
 except ImportError:

+    def set_allreduce_strategy(strategy):
+        raise ImportError("TRT-LLM is not available.")
+
     def trtllm_allgather(tensor, dim, sizes=None):
         raise ImportError("TRT-LLM is not available.")
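For reference, a minimal usage sketch of the new setter (assuming a build where the TRT-LLM custom ops import successfully; otherwise the stub above raises ImportError):

from tensorrt_llm._torch.auto_deploy.distributed.trtllm import set_allreduce_strategy
from tensorrt_llm.functional import AllReduceStrategy

# Override the default AUTO strategy. This also clears the allreduce cache, so
# subsequent trtllm_allreduce() calls rebuild their AllReduce op under the new
# (rank, world_size, dtype, strategy) cache key.
set_allreduce_strategy(AllReduceStrategy.NCCL)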

tensorrt_llm/_torch/auto_deploy/llm_args.py

Lines changed: 19 additions & 0 deletions
@@ -123,6 +123,25 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):

     device: str = Field(default="cuda", description="The device to use for the model.", frozen=True)

+    allreduce_strategy: Literal[
+        "AUTO",
+        "NCCL",
+        "ONESHOT",
+        "TWOSHOT",
+        "MIN_LATENCY",
+        "LOWPRECISION",
+        "UB",
+        "MNNVL",
+        "NCCL_SYMMETRIC",
+    ] = Field(
+        default="AUTO",
+        description="AllReduce strategy for distributed inference. Options: AUTO (automatic selection), "
+        "NCCL (NCCL-based), ONESHOT (single-phase fusion kernel), TWOSHOT (two-phase fusion kernel), "
+        "MIN_LATENCY (minimum latency heuristic), LOWPRECISION (low precision allreduce), "
+        "UB (unified buffer), MNNVL (multi-node NVLINK), NCCL_SYMMETRIC (NCCL symmetric). "
+        "AUTO is recommended for most use cases.",
+    )
+
     # TODO: see if we can just remove this field and use kv_cache_config.dtype instead?
     kv_cache_dtype: str = Field(
         default="auto",
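A minimal sketch of passing the new field through an extra-options YAML, mirroring what the new test below writes and hands to the bench CLI via --extra_llm_api_options (the file name here is arbitrary):

import yaml

extra_options = {
    # Any of the Literal values above; "AUTO" is the default.
    "allreduce_strategy": "ONESHOT",
}
with open("extra_llm_api_options.yaml", "w") as f:
    yaml.dump(extra_options, f)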

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 11 additions & 0 deletions
@@ -325,6 +325,17 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
     port = mpi_dist.broadcast(dist.get_free_port())  # use MPI broadcast to pick a free port
     dist.initialize_or_skip(rank, world_size, port)

+    # Configure allreduce strategy if specified
+    if hasattr(ad_config, "allreduce_strategy") and ad_config.allreduce_strategy != "AUTO":
+        from tensorrt_llm.functional import AllReduceStrategy
+
+        from ..distributed.trtllm import TRTLLM_OP_AVAILABLE, set_allreduce_strategy
+
+        if TRTLLM_OP_AVAILABLE:
+            strategy = getattr(AllReduceStrategy, ad_config.allreduce_strategy)
+            set_allreduce_strategy(strategy)
+            ad_logger.info(f"Using allreduce strategy: {ad_config.allreduce_strategy}")
+
     # some config
     assert ad_config.max_beam_width <= 1, "_autodeploy + beam_search is not supported"
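The same wiring can be reproduced outside the executor. The sketch below reads the strategy name from a hypothetical AD_ALLREDUCE_STRATEGY environment variable (not part of this commit) and keeps the TRTLLM_OP_AVAILABLE guard so builds without the TRT-LLM ops never hit the ImportError-raising stub:

import os

from tensorrt_llm._torch.auto_deploy.distributed.trtllm import (
    TRTLLM_OP_AVAILABLE,
    set_allreduce_strategy,
)
from tensorrt_llm.functional import AllReduceStrategy

# Hypothetical env var for illustration only; the commit itself reads the value
# from ad_config.allreduce_strategy instead.
requested = os.environ.get("AD_ALLREDUCE_STRATEGY", "AUTO")
if TRTLLM_OP_AVAILABLE and requested != "AUTO":
    set_allreduce_strategy(getattr(AllReduceStrategy, requested))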

Lines changed: 182 additions & 0 deletions
@@ -0,0 +1,182 @@
import signal
import subprocess
import tempfile
from contextlib import contextmanager
from pathlib import Path

import pytest
import yaml
from click.testing import CliRunner

from tensorrt_llm.commands.bench import main


class TimeoutError(Exception):
    """Exception raised when a test times out."""

    pass


@contextmanager
def timeout(seconds):
    """Context manager that raises TimeoutError if code block exceeds time limit.

    Args:
        seconds: Maximum time in seconds to allow the code block to run

    Raises:
        TimeoutError: If the code block execution exceeds the time limit
    """

    def timeout_handler(signum, frame):
        raise TimeoutError(f"Test execution exceeded {seconds} seconds timeout")

    # Set the signal handler and alarm
    old_handler = signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(seconds)
    try:
        yield
    finally:
        # Restore the old signal handler and cancel the alarm
        signal.alarm(0)
        signal.signal(signal.SIGALRM, old_handler)


@pytest.fixture(scope="module")
def shared_dataset(llm_root):
    """Prepare dataset once for all tests in this module."""
    model_name = "meta-llama/Llama-3.1-8B"
    with tempfile.TemporaryDirectory() as temp_dir:
        dataset_path = _prepare_dataset(llm_root, temp_dir, model_name, num_requests=10)
        # Read dataset content to return it (temp_dir will be deleted)
        with open(dataset_path, "r") as f:
            dataset_content = f.read()
        yield dataset_content


def _prepare_dataset(root_dir: str, temp_dir: str, model_path_or_name: str, num_requests: int = 10):
    """Prepare a synthetic dataset for benchmarking."""
    _DATASET_NAME = "synthetic_128_128.txt"
    dataset_path = Path(temp_dir, _DATASET_NAME)
    dataset_tool = Path(root_dir, "benchmarks", "cpp", "prepare_dataset.py")
    script_dir = Path(root_dir, "benchmarks", "cpp")

    # Generate a small dataset to run a test - matching workload configuration
    command = [
        "python3",
        f"{dataset_tool}",
        "--stdout",
        "--tokenizer",
        model_path_or_name,
        "token-norm-dist",
        "--input-mean",
        "128",
        "--output-mean",
        "128",
        "--input-stdev",
        "0",
        "--output-stdev",
        "0",
        "--num-requests",
        str(num_requests),
    ]
    print(f"Running command: {' '.join(command)}")
    result = subprocess.run(
        command, cwd=str(script_dir), capture_output=True, text=True, timeout=300
    )
    if result.returncode != 0:
        raise RuntimeError(f"Failed to prepare dataset: {result.stderr}")
    # Grab the stdout and write it to a dataset file for passing to suite.
    with open(dataset_path, "w") as dataset:
        dataset.write(result.stdout)
    return dataset_path


@pytest.mark.parametrize(
    "allreduce_strategy",
    [
        "AUTO",
        "ONESHOT",
        "TWOSHOT",
        "MIN_LATENCY",
        "NCCL",
    ],
)
def test_allreduce_strategies(llm_root, shared_dataset, allreduce_strategy):
    """Test different allreduce strategies with multi-GPU configuration.

    This test validates that all allreduce strategies work correctly with TP=4.
    Note: TWOSHOT strategy will automatically fall back to ONESHOT when sequence
    length is smaller than TP size during initialization.

    Test has a 300 second timeout to prevent indefinite hangs.

    Args:
        llm_root: Root directory fixture
        shared_dataset: Shared dataset fixture (prepared once for all test runs)
        allreduce_strategy: Strategy to test (AUTO, ONESHOT, TWOSHOT, MIN_LATENCY, NCCL)
    """
    # Fixed timeout for all strategies (5 minutes should be enough)
    TEST_TIMEOUT_SECONDS = 300

    model_name = "meta-llama/Llama-3.1-8B"
    tp_size = 4
    max_batch_size = 256
    max_num_tokens = 8192

    with tempfile.TemporaryDirectory() as temp_dir:
        # Write shared dataset to temp location
        dataset_path = Path(temp_dir, "synthetic_128_128.txt")
        with open(dataset_path, "w") as f:
            f.write(shared_dataset)

        # Create configuration with specified allreduce strategy
        extra_llm_api_options_path = f"{temp_dir}/extra_llm_api_options.yaml"
        with open(extra_llm_api_options_path, "w") as f:
            yaml.dump(
                {
                    "model": model_name,
                    "allreduce_strategy": allreduce_strategy,
                    "max_batch_size": max_batch_size,
                    "max_num_tokens": max_num_tokens,
                    "max_seq_len": 256,
                    "transforms": {
                        "compile_model": {
                            "stage": "compile",
                            "backend": "torch-cudagraph",
                            "cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128, 256],
                        }
                    },
                },
                f,
            )

        # Run benchmark with specified allreduce strategy with timeout protection
        runner = CliRunner()
        args = [
            "--model",
            model_name,
            "throughput",
            "--backend",
            "_autodeploy",
            "--dataset",
            str(dataset_path),
            "--extra_llm_api_options",
            extra_llm_api_options_path,
            "--tp",
            str(tp_size),
            "--max_batch_size",
            str(max_batch_size),
            "--max_num_tokens",
            str(max_num_tokens),
        ]

        try:
            with timeout(TEST_TIMEOUT_SECONDS):
                result = runner.invoke(main, args, catch_exceptions=False)
            assert result.exit_code == 0, f"Benchmark failed with output: {result.output}"
        except TimeoutError as e:
            pytest.fail(
                f"Test timed out after {TEST_TIMEOUT_SECONDS}s for strategy {allreduce_strategy}. "
                f"This might indicate a hang (e.g., TWOSHOT without C++ fix). Error: {e}"
            )
