     set_per_request_piecewise_cuda_graph_flag,
     set_torch_compiling, with_model_extra_attrs)
 from .config_utils import is_mla
-from .cuda_graph_runner import CUDAGraphRunner
+from .cuda_graph_runner import CUDAGraphRunner, CUDAGraphRunnerConfig
 from .guided_decoder import CapturableGuidedDecoder
 from .layerwise_nvtx_marker import LayerwiseNvtxMarker
 from .llm_request import get_draft_token_length
@@ -370,9 +370,31 @@ def __init__(
         # We look up this key in resource_manager during forward to find the
         # kv cache manager. Can be changed to support multiple model engines
         # with different KV cache managers.
-        self.kv_cache_manager_key = ResourceManagerType.KV_CACHE_MANAGER
+        self.kv_cache_manager_key = ResourceManagerType.DRAFT_KV_CACHE_MANAGER if is_draft_model else ResourceManagerType.KV_CACHE_MANAGER
         self.lora_model_config: Optional[LoraModelConfig] = None
-        self.cuda_graph_runner = CUDAGraphRunner(self)
+
+        # Create config and runner
+        cuda_graph_runner_config = CUDAGraphRunnerConfig(
+            use_cuda_graph=self.cuda_graph_config is not None,
+            cuda_graph_padding_enabled=self._cuda_graph_padding_enabled,
+            cuda_graph_batch_sizes=self._cuda_graph_batch_sizes,
+            max_cuda_graph_batch_size=self._max_cuda_graph_batch_size,
+            max_beam_width=self.max_beam_width,
+            spec_config=self.spec_config,
+            cuda_graph_mem_pool=self._cuda_graph_mem_pool,
+            max_num_tokens=self.max_num_tokens,
+            use_mrope=self.use_mrope,
+            original_max_draft_len=self.original_max_draft_len,
+            original_max_total_draft_tokens=self.
+            original_max_total_draft_tokens,
+            is_draft_model=self.is_draft_model,
+            enable_attention_dp=self.enable_attention_dp,
+            batch_size=self.batch_size,
+            mapping=self.mapping,
+            dist=self.dist,
+            kv_cache_manager_key=self.kv_cache_manager_key,
+        )
+        self.cuda_graph_runner = CUDAGraphRunner(cuda_graph_runner_config)
 
         # Setup the local cache indirection buffer only once and reuse it.
         # This way it can also be used for CUDA graphs.
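This hunk decouples CUDAGraphRunner from the engine: rather than the previous back-reference to the whole engine (`CUDAGraphRunner(self)`), the runner now receives only the values it uses, bundled in a config object. A minimal sketch of what the CUDAGraphRunnerConfig dataclass could look like, inferred purely from the keyword arguments at this call site (the dataclass form and all field types are assumptions, not shown in this diff):

    # Hypothetical reconstruction from the call site above; types are guesses.
    from dataclasses import dataclass
    from typing import Any, List, Optional

    @dataclass
    class CUDAGraphRunnerConfig:
        use_cuda_graph: bool
        cuda_graph_padding_enabled: bool
        cuda_graph_batch_sizes: Optional[List[int]]
        max_cuda_graph_batch_size: int
        max_beam_width: int
        spec_config: Optional[Any]
        cuda_graph_mem_pool: Optional[Any]
        max_num_tokens: int
        use_mrope: bool
        original_max_draft_len: int
        original_max_total_draft_tokens: int
        is_draft_model: bool
        enable_attention_dp: bool
        batch_size: int
        mapping: Any
        dist: Any
        kv_cache_manager_key: Any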
@@ -2319,11 +2341,21 @@ def forward(
             return self._forward_step(inputs, gather_ids,
                                       gather_context_logits)
         with self.cuda_graph_runner.pad_batch(
-                scheduled_requests, resource_manager) as padded_requests:
-
-            maybe_graph, maybe_attn_metadata, maybe_spec_metadata, key = self.cuda_graph_runner.maybe_get_cuda_graph(
-                padded_requests, spec_resource_manager)
-            if maybe_graph:
+                scheduled_requests, resource_manager,
+                self.runtime_draft_len) as padded_requests:
+
+            maybe_attn_metadata, maybe_spec_metadata, key = self.cuda_graph_runner.maybe_get_cuda_graph(
+                padded_requests,
+                iter_counter=self.iter_counter,
+                enable_spec_decode=self.enable_spec_decode,
+                attn_metadata=attn_metadata,
+                spec_metadata=spec_metadata,
+                draft_tokens_cuda=self.draft_tokens_cuda
+                if self.is_spec_decode else None,
+                spec_resource_manager=spec_resource_manager,
+            )
+            can_run_graph = key is not None
+            if can_run_graph:
                 attn_metadata = maybe_attn_metadata
                 spec_metadata = maybe_spec_metadata
             else:
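Note the contract change in this hunk: maybe_get_cuda_graph no longer returns a leading boolean; graph eligibility is now derived from whether the returned key is None, so the flag and the key can never disagree. A small self-contained sketch of this Optional-key pattern (illustrative names and values only, not code from the diff):

    # The key is returned whenever the batch shape is graph-eligible, even if
    # capture has not happened yet; the caller captures on first use.
    from typing import Optional

    eligible_sizes = {1, 2, 4, 8}  # shapes we are willing to run as graphs

    def maybe_get(batch_size: int) -> Optional[int]:
        return batch_size if batch_size in eligible_sizes else None

    key = maybe_get(8)
    can_run_graph = key is not None  # mirrors the check added above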
@@ -2339,7 +2371,7 @@ def forward(
 
         self.iter_counter += 1
         with with_shared_pool(self.cuda_graph_runner.get_graph_pool()):
-            if not maybe_graph:
+            if not can_run_graph:
                 # Fallback to eager execution if graph was not used
                 with MoeLoadBalancerIterContext(moe_load_balancer):
                     outputs = self._forward_step(inputs, gather_ids,
@@ -2357,9 +2389,12 @@ def capture_forward_fn(inputs: Dict[str, Any]):
                 def capture_postprocess_fn(inputs: Dict[str, Any]):
                     self._postprocess_inputs(inputs)
 
-                self.cuda_graph_runner.capture(key, capture_forward_fn,
-                                               inputs,
-                                               capture_postprocess_fn)
+                self.cuda_graph_runner.capture(
+                    key,
+                    capture_forward_fn,
+                    inputs,
+                    enable_spec_decode=self.enable_spec_decode,
+                    postprocess_fn=capture_postprocess_fn)
 
                 # here we don't need to use context since cuda graph capture didn't run kernel.
                 # maybe we need a cleaner way to do this.
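From this call site, capture now takes the key, the forward callable, and the live inputs positionally, with enable_spec_decode and postprocess_fn passed by keyword instead of the postprocess callable being positional. A hedged sketch of the signature this implies for CUDAGraphRunner.capture (parameter names beyond those visible at the call site, plus types and defaults, are assumptions):

    from typing import Any, Callable, Dict, Optional

    def capture(self,
                key: Any,
                forward_fn: Callable[[Dict[str, Any]], Any],
                initial_inputs: Dict[str, Any],
                enable_spec_decode: bool = False,
                postprocess_fn: Optional[Callable[[Dict[str, Any]], None]] = None) -> None:
        # Capture forward_fn into a CUDA graph stored under `key`; the call
        # site above passes a wrapper around self._postprocess_inputs as
        # postprocess_fn, applied to the inputs used during capture.
        ...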