@@ -125,31 +125,33 @@ def build( # type: ignore[override]
125125 common_prefix_len : int ,
126126 common_attn_metadata : CommonAttentionMetadata ,
127127 num_accepted_tokens : Optional [torch .Tensor ] = None ,
128- num_draft_tokens : Optional [torch .Tensor ] = None ,
128+ num_decode_draft_tokens_cpu : Optional [torch .Tensor ] = None ,
129129 fast_build : bool = False ,
130130 ) -> GDNAttentionMetadata :
131131 m = common_attn_metadata
132132
133133 query_start_loc = m .query_start_loc
134134 context_lens = m .num_computed_tokens_cpu
135135 context_lens_tensor = context_lens .to (query_start_loc .device )
136- seq_lens_tensor = m .seq_lens
137136 nums_dict , batch_ptr , token_chunk_offset_ptr = None , None , None
138137
139- if (not self .use_spec_decode or num_draft_tokens is None
140- or num_draft_tokens .sum ().item () == 0 ):
138+ if (not self .use_spec_decode or num_decode_draft_tokens_cpu is None
139+ or num_decode_draft_tokens_cpu [num_decode_draft_tokens_cpu >=
140+ 0 ].sum ().item () == 0 ):
141141 spec_sequence_masks = None
142+ num_spec_decodes = 0
142143 else :
143- spec_sequence_masks = (num_draft_tokens > 0 ) & (
144- context_lens_tensor +
145- (num_draft_tokens + 1 ) == seq_lens_tensor )
146- if spec_sequence_masks .sum ().item () == 0 :
144+ spec_sequence_masks = num_decode_draft_tokens_cpu >= 0
145+ num_spec_decodes = spec_sequence_masks .sum ().item ()
146+ if num_spec_decodes == 0 :
147147 spec_sequence_masks = None
148+ else :
149+ spec_sequence_masks = spec_sequence_masks .to (
150+ query_start_loc .device , non_blocking = True )
148151
149152 if spec_sequence_masks is None :
150153 num_decodes , num_prefills , num_decode_tokens , num_prefill_tokens = (
151154 split_decodes_and_prefills (m , decode_threshold = 1 ))
152- num_spec_decodes = 0
153155 num_spec_decode_tokens = 0
154156 spec_token_masks = None
155157 spec_state_indices_tensor = None
@@ -158,7 +160,6 @@ def build( # type: ignore[override]
158160 non_spec_query_start_loc = query_start_loc
159161 num_accepted_tokens = None
160162 else :
161- num_spec_decodes = spec_sequence_masks .sum ().item ()
162163 query_lens = query_start_loc [1 :] - query_start_loc [:- 1 ]
163164
164165 non_spec_query_lens = query_lens [~ spec_sequence_masks ]
@@ -314,28 +315,18 @@ def build_for_cudagraph_capture(
314315 """
315316 m = common_attn_metadata
316317
317- assert (m .num_reqs * (self .num_spec + 1 ) <= m .num_actual_tokens
318- and ((m .num_reqs + 1 ) * (self .num_spec + 1 )
319- >= m .num_actual_tokens )), \
320- "GDN only supports decode-only full CUDAGraph capture. " \
321- "Make sure all cudagraph capture sizes <= max_num_seq."
322-
323- num_accepted_tokens = torch .full ((m .num_reqs , ),
324- m .max_query_len ,
325- dtype = torch .int32 ,
326- device = m .query_start_loc .device )
327- num_drafted_tokens = torch .full ((m .num_reqs , ),
328- self .num_spec ,
329- dtype = torch .int32 ,
330- device = m .query_start_loc .device )
331-
332- # Fixes query-start loc for spec-sequence-indices.
333- m .query_start_loc = torch .arange (0 ,
334- m .num_actual_tokens + 1 ,
335- step = m .max_query_len ,
336- device = m .query_start_loc .device ,
337- dtype = torch .int32 )
338- m .num_computed_tokens_cpu = (m .seq_lens_cpu - torch .full (
339- (m .num_reqs , ), m .max_query_len , dtype = torch .int32 , device = 'cpu' ))
340-
341- return self .build (0 , m , num_accepted_tokens , num_drafted_tokens )
318+ assert (
319+ m .num_reqs <= self .decode_cudagraph_max_bs
320+ and m .num_actual_tokens <= self .decode_cudagraph_max_bs ), (
321+ f"GDN only supports decode-only full CUDAGraph capture. "
322+ f"Make sure batch size ({ m .num_reqs } ) <= "
323+ f"cudagraph capture sizes ({ self .decode_cudagraph_max_bs } ), "
324+ f"and number of tokens ({ m .num_actual_tokens } ) <= "
325+ f"cudagraph capture sizes ({ self .decode_cudagraph_max_bs } )." )
326+
327+ num_accepted_tokens = torch .diff (m .query_start_loc )
328+ num_decode_draft_tokens_cpu = (num_accepted_tokens - 1 ).cpu ()
329+ m .num_computed_tokens_cpu = m .seq_lens_cpu - num_accepted_tokens .cpu ()
330+
331+ return self .build (0 , m , num_accepted_tokens ,
332+ num_decode_draft_tokens_cpu )