
Commit 9f0cc6c

QiJune authored and suyoggupta committed
[TRTLLM-8521][chore] remove circular dependency between model engine and cuda graph runner (NVIDIA#7572)
Signed-off-by: junq <22017000+QiJune@users.noreply.github.com>
1 parent 2f08aa3 · commit 9f0cc6c

16 files changed: +222, -221 lines
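The core of the change: CUDAGraphRunner no longer receives the model engine itself, so the back-reference that created the circular dependency goes away. A minimal before/after sketch of the wiring, condensed from the model_engine.py hunk below (names are taken from the diff; the elided config fields are listed there in full):

# Before: the runner was handed the whole engine, which in turn held the runner.
self.cuda_graph_runner = CUDAGraphRunner(self)

# After: the engine builds a plain config object and passes only that.
cuda_graph_runner_config = CUDAGraphRunnerConfig(
    use_cuda_graph=self.cuda_graph_config is not None,
    # ... remaining fields copied from engine attributes, see the hunk below ...
)
self.cuda_graph_runner = CUDAGraphRunner(cuda_graph_runner_config)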

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py

Lines changed: 120 additions & 101 deletions
Large diffs are not rendered by default.
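Since this diff is not rendered, here is a minimal sketch of what the new CUDAGraphRunnerConfig plausibly looks like. The field names are taken from its two construction sites below (model_engine.py and tests/unittest/_torch/helpers.py); the dataclass form and the types are assumptions for illustration, not the actual definition.

from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class CUDAGraphRunnerConfig:
    # Field names mirror the keyword arguments passed at the call sites below;
    # the types are guesses, not taken from the source.
    use_cuda_graph: bool
    cuda_graph_padding_enabled: bool
    cuda_graph_batch_sizes: List[int]
    max_cuda_graph_batch_size: int
    max_beam_width: int
    spec_config: Optional[Any]
    cuda_graph_mem_pool: Optional[Any]
    max_num_tokens: int
    use_mrope: bool
    original_max_draft_len: int
    original_max_total_draft_tokens: int
    is_draft_model: bool
    enable_attention_dp: bool
    batch_size: int
    mapping: Any
    dist: Optional[Any]
    kv_cache_manager_key: Any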

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 47 additions & 12 deletions
@@ -55,7 +55,7 @@
                           set_per_request_piecewise_cuda_graph_flag,
                           set_torch_compiling, with_model_extra_attrs)
 from .config_utils import is_mla
-from .cuda_graph_runner import CUDAGraphRunner
+from .cuda_graph_runner import CUDAGraphRunner, CUDAGraphRunnerConfig
 from .guided_decoder import CapturableGuidedDecoder
 from .layerwise_nvtx_marker import LayerwiseNvtxMarker
 from .llm_request import get_draft_token_length

@@ -370,9 +370,31 @@ def __init__(
         # We look up this key in resource_manager during forward to find the
         # kv cache manager. Can be changed to support multiple model engines
         # with different KV cache managers.
-        self.kv_cache_manager_key = ResourceManagerType.KV_CACHE_MANAGER
+        self.kv_cache_manager_key = ResourceManagerType.DRAFT_KV_CACHE_MANAGER if is_draft_model else ResourceManagerType.KV_CACHE_MANAGER
         self.lora_model_config: Optional[LoraModelConfig] = None
-        self.cuda_graph_runner = CUDAGraphRunner(self)
+
+        # Create config and runner
+        cuda_graph_runner_config = CUDAGraphRunnerConfig(
+            use_cuda_graph=self.cuda_graph_config is not None,
+            cuda_graph_padding_enabled=self._cuda_graph_padding_enabled,
+            cuda_graph_batch_sizes=self._cuda_graph_batch_sizes,
+            max_cuda_graph_batch_size=self._max_cuda_graph_batch_size,
+            max_beam_width=self.max_beam_width,
+            spec_config=self.spec_config,
+            cuda_graph_mem_pool=self._cuda_graph_mem_pool,
+            max_num_tokens=self.max_num_tokens,
+            use_mrope=self.use_mrope,
+            original_max_draft_len=self.original_max_draft_len,
+            original_max_total_draft_tokens=self.
+            original_max_total_draft_tokens,
+            is_draft_model=self.is_draft_model,
+            enable_attention_dp=self.enable_attention_dp,
+            batch_size=self.batch_size,
+            mapping=self.mapping,
+            dist=self.dist,
+            kv_cache_manager_key=self.kv_cache_manager_key,
+        )
+        self.cuda_graph_runner = CUDAGraphRunner(cuda_graph_runner_config)

         # Setup the local cache indirection buffer only once and reuse it.
         # This way it can also be used for CUDA graphs.

@@ -2319,11 +2341,21 @@ def forward(
             return self._forward_step(inputs, gather_ids,
                                       gather_context_logits)
         with self.cuda_graph_runner.pad_batch(
-                scheduled_requests, resource_manager) as padded_requests:
-
-            maybe_graph, maybe_attn_metadata, maybe_spec_metadata, key = self.cuda_graph_runner.maybe_get_cuda_graph(
-                padded_requests, spec_resource_manager)
-            if maybe_graph:
+                scheduled_requests, resource_manager,
+                self.runtime_draft_len) as padded_requests:
+
+            maybe_attn_metadata, maybe_spec_metadata, key = self.cuda_graph_runner.maybe_get_cuda_graph(
+                padded_requests,
+                iter_counter=self.iter_counter,
+                enable_spec_decode=self.enable_spec_decode,
+                attn_metadata=attn_metadata,
+                spec_metadata=spec_metadata,
+                draft_tokens_cuda=self.draft_tokens_cuda
+                if self.is_spec_decode else None,
+                spec_resource_manager=spec_resource_manager,
+            )
+            can_run_graph = key is not None
+            if can_run_graph:
                 attn_metadata = maybe_attn_metadata
                 spec_metadata = maybe_spec_metadata
             else:

@@ -2339,7 +2371,7 @@ def forward(

         self.iter_counter += 1
         with with_shared_pool(self.cuda_graph_runner.get_graph_pool()):
-            if not maybe_graph:
+            if not can_run_graph:
                 # Fallback to eager execution if graph was not used
                 with MoeLoadBalancerIterContext(moe_load_balancer):
                     outputs = self._forward_step(inputs, gather_ids,

@@ -2357,9 +2389,12 @@ def capture_forward_fn(inputs: Dict[str, Any]):
                 def capture_postprocess_fn(inputs: Dict[str, Any]):
                     self._postprocess_inputs(inputs)

-                self.cuda_graph_runner.capture(key, capture_forward_fn,
-                                               inputs,
-                                               capture_postprocess_fn)
+                self.cuda_graph_runner.capture(
+                    key,
+                    capture_forward_fn,
+                    inputs,
+                    enable_spec_decode=self.enable_spec_decode,
+                    postprocess_fn=capture_postprocess_fn)

                 # here we don't need to use context since cuda graph capture didn't run kernel.
                 # maybe we need a cleaner way to do this.
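From the call sites above one can infer the rough shape of the decoupled runner API. The sketch below is an assumption based only on those calls; the real implementation lives in cuda_graph_runner.py, whose diff is not rendered here.

# Hypothetical signatures inferred from the engine-side calls; not the actual code.
class CUDAGraphRunner:

    def __init__(self, config):
        # Only a plain CUDAGraphRunnerConfig is stored; there is no longer a
        # back-reference to the model engine.
        self.config = config

    def maybe_get_cuda_graph(self, padded_requests, *, iter_counter,
                             enable_spec_decode, attn_metadata, spec_metadata,
                             draft_tokens_cuda, spec_resource_manager=None):
        # Returns (attn_metadata, spec_metadata, key); key is None when the
        # batch cannot be served by a captured graph, which the engine turns
        # into `can_run_graph = key is not None`.
        ...

    def capture(self, key, forward_fn, inputs, *, enable_spec_decode,
                postprocess_fn):
        # Captures forward_fn(inputs) into a CUDA graph stored under `key`.
        ...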

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 0 additions & 1 deletion
@@ -384,7 +384,6 @@ def drafting_loop_wrapper(model):
         # For DeepseekV3 MTP, we need to set the num_hidden_layers to 1 for the draft model
         if spec_config.spec_dec_mode.is_mtp_eagle():
             draft_model_engine.model.model_config.pretrained_config.num_hidden_layers = 1
-        draft_model_engine.kv_cache_manager_key = ResourceManagerType.DRAFT_KV_CACHE_MANAGER
         draft_model_engine.load_weights_from_target_model(
             model_engine.model)
     else:
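This one-line removal works because the engine now selects the draft key itself in its constructor, keyed on is_draft_model, instead of being patched from the outside. The line in question, quoted from the model_engine.py hunk above:

# From model_engine.py above; no longer set by py_executor_creator.py.
self.kv_cache_manager_key = ResourceManagerType.DRAFT_KV_CACHE_MANAGER if is_draft_model else ResourceManagerType.KV_CACHE_MANAGER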

tests/unittest/_torch/helpers.py

Lines changed: 22 additions & 38 deletions
@@ -3,7 +3,10 @@
 import torch
 import torch.nn.functional as F

-from tensorrt_llm.llmapi.llm_args import TorchLlmArgs
+from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import (
+    CUDAGraphRunner, CUDAGraphRunnerConfig)
+from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
+from tensorrt_llm.mapping import Mapping


 def ceil_div(x: int, y: int) -> int:

@@ -166,42 +169,23 @@ def block_scale_gemm(mat_a: torch.Tensor, mat_scale_a: torch.Tensor,
     return results.view_as(x)


-class MockPytorchBackendConfig:
-
-    def __init__(self, use_cuda_graph, cuda_graph_padding_enabled):
-        self.use_cuda_graph = use_cuda_graph
-        self.cuda_graph_padding_enabled = cuda_graph_padding_enabled
-
-
-class MockEngine:
-    """A replacement for SimpleNamespace that supports weak references."""
-
-    def __init__(self, **kwargs):
-        self.__dict__.update(kwargs)
-
-
-def create_mock_engine(batch_size: int):
-
-    class MockSpecConfig:
-
-        class SpecDecMode:
-
-            def needs_kv_cache_recompute(self):
-                return False
-
-        spec_dec_mode = SpecDecMode()
-
-    return MockEngine(
-        llm_args=TorchLlmArgs(model="dummy"),
-        _cuda_graph_padding_enabled=True,
-        _cuda_graph_batch_sizes=[batch_size],
-        _max_cuda_graph_batch_size=batch_size,
+def create_mock_cuda_graph_runner(batch_size: int, use_mrope: bool = False):
+    config = CUDAGraphRunnerConfig(
+        use_cuda_graph=True,
+        cuda_graph_padding_enabled=False,
+        cuda_graph_batch_sizes=[batch_size],
+        max_cuda_graph_batch_size=batch_size,
+        batch_size=batch_size,
         max_beam_width=1,
-        max_num_tokens=8192,
-        is_spec_decode=False,
-        enable_spec_decode=False,
-        spec_config=MockSpecConfig(),
+        max_num_tokens=1,
+        use_mrope=use_mrope,
+        spec_config=None,
+        cuda_graph_mem_pool=None,
+        enable_attention_dp=False,
+        original_max_draft_len=0,
+        original_max_total_draft_tokens=0,
         is_draft_model=False,
-        _cuda_graph_mem_pool=None,
-        use_mrope=False,
-    )
+        mapping=Mapping(),
+        dist=None,
+        kv_cache_manager_key=ResourceManagerType.KV_CACHE_MANAGER)
+    return CUDAGraphRunner(config)
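The helper now returns a fully constructed runner rather than a mock engine to wrap, so a test needs only one call. A usage sketch mirroring the updated modeling tests below (the use_cuda_graph flag stands in for the per-test scenario field):

from _torch.helpers import create_mock_cuda_graph_runner

# Build a runner for batch size 1 when the scenario enables CUDA graphs,
# otherwise run eagerly with no runner at all.
use_cuda_graph = True
graph_runner = create_mock_cuda_graph_runner(1) if use_cuda_graph else None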

tests/unittest/_torch/modeling/test_modeling_exaone4.py

Lines changed: 3 additions & 6 deletions
@@ -22,7 +22,7 @@ class Exaone4Config(PretrainedConfig):
 # TODO: Remove this once we have a proper config for Exaone4
 SKIP_EXAONE4_HF_ACCURACY_TEST = True

-from _torch.helpers import create_mock_engine
+from _torch.helpers import create_mock_cuda_graph_runner
 from transformers.cache_utils import HybridCache
 from utils.util import getSMVersion


@@ -31,7 +31,6 @@ class Exaone4Config(PretrainedConfig):
 from tensorrt_llm._torch.metadata import KVCacheParams
 from tensorrt_llm._torch.model_config import ModelConfig
 from tensorrt_llm._torch.models.modeling_exaone4 import Exaone4ForCausalLM
-from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
 from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
 from tensorrt_llm.bindings.executor import KvCacheConfig
 from tensorrt_llm.mapping import Mapping

@@ -338,10 +337,8 @@ def test_exaone4_allclose_to_hf(self, scenario: Scenario) -> None:
         ]
         gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()

-        graph_runner = None
-        if scenario.use_cuda_graph:
-            mock_engine = create_mock_engine(1)
-            graph_runner = CUDAGraphRunner(mock_engine)
+        graph_runner = create_mock_cuda_graph_runner(
+            1) if scenario.use_cuda_graph else None

         def run_forward(input_ids, position_ids, attn_metadata):
             attn_metadata.prepare()

tests/unittest/_torch/modeling/test_modeling_llama.py

Lines changed: 3 additions & 6 deletions
@@ -4,7 +4,7 @@
 from typing import Any

 import torch
-from _torch.helpers import create_mock_engine
+from _torch.helpers import create_mock_cuda_graph_runner
 from parameterized import parameterized
 from transformers import LlamaConfig
 from transformers import LlamaForCausalLM as HFLlamaForCausalLM

@@ -16,7 +16,6 @@
 from tensorrt_llm._torch.metadata import KVCacheParams
 from tensorrt_llm._torch.model_config import ModelConfig
 from tensorrt_llm._torch.models.modeling_llama import LlamaForCausalLM
-from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
 from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequestState
 from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
 from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests

@@ -331,10 +330,8 @@ def test_llama_allclose_to_hf(self, scenario: Scenario) -> None:
         ]
         gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()

-        graph_runner = None
-        if scenario.use_cuda_graph:
-            mock_engine = create_mock_engine(1)
-            graph_runner = CUDAGraphRunner(mock_engine)
+        graph_runner = create_mock_cuda_graph_runner(
+            1) if scenario.use_cuda_graph else None

         def run_forward(input_ids, position_ids, attn_metadata):
             attn_metadata.prepare()

tests/unittest/_torch/modeling/test_modeling_llama_min_latency.py

Lines changed: 3 additions & 6 deletions
@@ -4,7 +4,7 @@

 import torch
 import transformers
-from _torch.helpers import create_mock_engine
+from _torch.helpers import create_mock_cuda_graph_runner
 from parameterized import parameterized
 from transformers import Llama4Config
 from transformers import \

@@ -20,7 +20,6 @@
     Llama4HfWeightMapper
 from tensorrt_llm._torch.models.modeling_llama import \
     Llama4ForConditionalGeneration
-from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
 from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
 from tensorrt_llm.bindings.executor import KvCacheConfig
 from tensorrt_llm.mapping import Mapping

@@ -406,10 +405,8 @@ def test_llama_allclose_to_hf(self, scenario: AllCloseScenario) -> None:
             input_ids.size(-1) + gen_input_ids.size(-1))
         ]
         gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()
-        graph_runner = None
-        if scenario.use_cuda_graph:
-            mock_engine = create_mock_engine(1)
-            graph_runner = CUDAGraphRunner(mock_engine)
+        graph_runner = create_mock_cuda_graph_runner(
+            1) if scenario.use_cuda_graph else None

         def run_forward(input_ids, position_ids, attn_metadata):
             attn_metadata.prepare()

tests/unittest/_torch/modeling/test_modeling_mistral.py

Lines changed: 2 additions & 6 deletions
@@ -8,7 +8,7 @@
 import torch
 import transformers
 import transformers.models.mistral3
-from _torch.helpers import create_mock_engine
+from _torch.helpers import create_mock_cuda_graph_runner
 from PIL import Image
 from utils.util import getSMVersion


@@ -19,7 +19,6 @@
 from tensorrt_llm._torch.attention_backend import utils as attention_utils
 from tensorrt_llm._torch.models import modeling_mistral
 from tensorrt_llm._torch.pyexecutor import resource_manager
-from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
 from tensorrt_llm.bindings import executor as executor_lib
 from tensorrt_llm.models import modeling_utils


@@ -404,10 +403,7 @@ def test_mistral_3_vlm_allclose_to_hf(mistral_small_3_1_24b_config, backend, use
     ]
     gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()

-    graph_runner = None
-    if use_cuda_graph:
-        mock_engine = create_mock_engine(1)
-        graph_runner = CUDAGraphRunner(mock_engine)
+    graph_runner = create_mock_cuda_graph_runner(1) if use_cuda_graph else None

     def run_forward(input_ids, position_ids, attn_metadata):
         attn_metadata.prepare()

tests/unittest/_torch/modeling/test_modeling_mixtral.py

Lines changed: 3 additions & 6 deletions
@@ -3,7 +3,7 @@
 from dataclasses import dataclass

 import torch
-from _torch.helpers import create_mock_engine
+from _torch.helpers import create_mock_cuda_graph_runner
 from parameterized import parameterized
 from transformers import MixtralConfig
 from transformers import MixtralForCausalLM as HFMixtralForCausalLM

@@ -16,7 +16,6 @@
 from tensorrt_llm._torch.models.checkpoints.hf.mixtral_weight_mapper import \
     MixtralHfWeightMapper
 from tensorrt_llm._torch.models.modeling_mixtral import MixtralForCausalLM
-from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
 from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
 from tensorrt_llm.bindings.executor import KvCacheConfig
 from tensorrt_llm.mapping import Mapping

@@ -310,10 +309,8 @@ def test_mixtral_allclose_to_hf(self, scenario: Scenario):
         ]
         gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()

-        graph_runner = None
-        if scenario.use_cuda_graph:
-            mock_engine = create_mock_engine(1)
-            graph_runner = CUDAGraphRunner(mock_engine)
+        graph_runner = create_mock_cuda_graph_runner(
+            1) if scenario.use_cuda_graph else None

         def run_forward(input_ids, position_ids, attn_metadata):
             attn_metadata.prepare()

tests/unittest/_torch/modeling/test_modeling_mllama.py

Lines changed: 3 additions & 6 deletions
@@ -4,7 +4,7 @@

 import pytest
 import torch
-from _torch.helpers import create_mock_engine
+from _torch.helpers import create_mock_cuda_graph_runner
 from parameterized import parameterized
 from test_modeling_llama import Scenario, reduce_llama_config
 from transformers import MllamaConfig

@@ -17,7 +17,6 @@
 from tensorrt_llm._torch.model_config import ModelConfig
 from tensorrt_llm._torch.models.modeling_mllama import \
     MllamaForConditionalGeneration
-from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
 from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
 from tensorrt_llm.bindings.executor import KvCacheConfig
 from tensorrt_llm.mapping import Mapping

@@ -420,10 +419,8 @@ def test_mllama_allclose_to_hf_text_only(self, scenario: Scenario) -> None:
         ]
         gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()

-        graph_runner = None
-        if scenario.use_cuda_graph:
-            mock_engine = create_mock_engine(1)
-            graph_runner = CUDAGraphRunner(mock_engine)
+        graph_runner = create_mock_cuda_graph_runner(
+            1) if scenario.use_cuda_graph else None

         def run_forward(input_ids, position_ids, attn_metadata):
             attn_metadata.prepare()
