File tree Expand file tree Collapse file tree 1 file changed +2
-5
lines changed Expand file tree Collapse file tree 1 file changed +2
-5
lines changed Original file line number Diff line number Diff line change @@ -164,12 +164,12 @@ def rejection_sample(
164164 assert target_probs .shape == (num_tokens , vocab_size )
165165
166166 # Create output buffer.
167- output_token_ids = torch .empty (
167+ output_token_ids = torch .full (
168168 (batch_size , max_spec_len + 1 ),
169+ PLACEHOLDER_TOKEN_ID ,
169170 dtype = torch .int32 , # Consistent with SamplerOutput.sampled_token_ids.
170171 device = device ,
171172 )
172- output_token_ids .fill_ (PLACEHOLDER_TOKEN_ID )
173173
174174 if sampling_metadata .all_greedy :
175175 is_greedy = None
@@ -186,7 +186,6 @@ def rejection_sample(
186186 bonus_token_ids ,
187187 is_greedy ,
188188 max_spec_len ,
189- num_warps = 1 ,
190189 )
191190 if sampling_metadata .all_greedy :
192191 return output_token_ids
@@ -227,7 +226,6 @@ def rejection_sample(
227226 max_spec_len ,
228227 vocab_size ,
229228 NO_DRAFT_PROBS = draft_probs is None ,
230- num_warps = 1 ,
231229 )
232230 return output_token_ids
233231
@@ -329,7 +327,6 @@ def expand_batch_to_tokens(
329327 replace_from ,
330328 replace_to ,
331329 MAX_NUM_TOKENS = MAX_SPEC_LEN , # To avoid recompilation.
332- num_warps = 1 ,
333330 )
334331 return expanded_x
335332
You can’t perform that action at this time.
0 commit comments