@@ -432,7 +432,10 @@ def __post_init__(self):
             dtype=torch.int32,
             capture_graph=capture_graph,
         )
-        # TODO: remove these expanded buffers when fp8_paged_mqa_logits supports MTP > 1.
+        self.create_expanded_buffers(capture_graph=capture_graph)
+
+    # TODO: remove these expanded buffers when fp8_paged_mqa_logits supports MTP > 1.
+    def create_expanded_buffers(self, capture_graph=False):
         self.kv_lens_expanded_cuda = self.get_empty(
             self.cuda_graph_buffers,
             (self.max_num_sequences * (1 + self.max_draft_tokens), ),
@@ -468,6 +471,25 @@ def __post_init__(self):
             capture_graph=capture_graph,
         )

+    # This function is only used to re-create the expanded buffers when max_draft_tokens changes.
+    # TODO: remove this function when fp8_paged_mqa_logits can support MTP > 1.
+    def update_spec_dec_param(
+        self,
+        is_spec_decoding_enabled,
+        is_spec_dec_tree,
+        is_spec_dec_dynamic_tree,
+        max_draft_tokens,
+        spec_decoding_tensor: Optional['SpecDecodingTensor'] = None,
+    ):
+        super().update_spec_dec_param(is_spec_decoding_enabled,
+                                      is_spec_dec_tree,
+                                      is_spec_dec_dynamic_tree,
+                                      max_draft_tokens, spec_decoding_tensor)
+        init_shape = self.kv_lens_expanded_host.shape[0]
+        if self.max_num_sequences * (1 + self.max_draft_tokens) != init_shape:
+            capture_graph = torch.cuda.is_current_stream_capturing()
+            self.create_expanded_buffers(capture_graph=capture_graph)
+
     def prepare(self):
         super().prepare()
         if self.kv_cache_manager is not None:
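For context, a minimal standalone sketch of the resize pattern this diff introduces (the class and buffer names below are hypothetical simplifications, not the actual TensorRT-LLM types): the expanded kv-lens buffer holds `max_num_sequences * (1 + max_draft_tokens)` entries, one per sequence per token slot (one target token plus the MTP draft tokens), so `update_spec_dec_param` must re-create it whenever a new `max_draft_tokens` changes that required length.

```python
import torch


class ExpandedBufferSketch:
    """Hypothetical stand-in for the attention metadata class in the diff."""

    def __init__(self, max_num_sequences: int, max_draft_tokens: int):
        self.max_num_sequences = max_num_sequences
        self.max_draft_tokens = max_draft_tokens
        self.create_expanded_buffers()

    def create_expanded_buffers(self):
        # One entry per sequence per token slot: 1 target token plus
        # max_draft_tokens speculative (MTP) draft tokens.
        num_tokens = self.max_num_sequences * (1 + self.max_draft_tokens)
        self.kv_lens_expanded_host = torch.zeros((num_tokens, ),
                                                 dtype=torch.int32)

    def update_spec_dec_param(self, max_draft_tokens: int):
        self.max_draft_tokens = max_draft_tokens
        # Mirror the shape check in the diff: re-allocate only when the
        # required length no longer matches the existing buffer.
        required = self.max_num_sequences * (1 + self.max_draft_tokens)
        if required != self.kv_lens_expanded_host.shape[0]:
            self.create_expanded_buffers()


meta = ExpandedBufferSketch(max_num_sequences=4, max_draft_tokens=1)
assert meta.kv_lens_expanded_host.shape[0] == 8
meta.update_spec_dec_param(max_draft_tokens=3)  # MTP depth grows
assert meta.kv_lens_expanded_host.shape[0] == 16  # buffer was re-created
```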