@@ -235,39 +235,6 @@ def generate(
     for i in range(max_new_tokens):
         input_ids = next_input[:, -max_possible_context_length:]
 
-        # prepare any padding keyword arguments
-        # iteration 0 is the prefill step (cache has not been filled yet), so no need to extend the mask/position_ids
-        if i > 0:
-            kwargs["mask"] = None
-            kwargs["position_ids"] = kwargs["position_ids"][:, -1:] + 1
-
-            # we no longer have a global pos_i, each sequence has its own pos_i
-            slot_mapping = []
-            for seq_i, pos_i in enumerate(current_tkv_mask):
-                if pos_i % BLOCK_SIZE == 0:
-                    block_number = block_numbers.pop(0)
-                    block_table[seq_i].append(block_number)
-
-                block_offset = pos_i % BLOCK_SIZE
-                slot = block_table[seq_i][-1] * BLOCK_SIZE + block_offset
-                slot_mapping.append([slot])
-
-            kwargs["block_table"] = torch.tensor(
-                [
-                    (
-                        [b_seq[0]]
-                        * (max(2, max([len(b) for b in block_table])) - len(b_seq))
-                    )
-                    + b_seq
-                    for b_seq in block_table
-                ],
-                dtype=torch.int64,
-            )
-            kwargs["left_padded_prompt_mask"] = left_padded_prompt_mask
-            current_tkv_mask = current_tkv_mask + 1
-            kwargs["current_tkv_mask"] = current_tkv_mask
-            kwargs["slot_mapping"] = torch.tensor(slot_mapping, dtype=torch.int64)
-
         # prefill
         if i == 0:
             kwargs["mask"] = kwargs["mask"].unsqueeze(1)
@@ -354,10 +321,41 @@ def generate(
                 outputs_list.append(output[0].squeeze(0))
 
             output = (torch.stack(outputs_list), current_kv_cache)
-
         # decode
         else:
+            # prepare any padding keyword arguments
+            # iteration 0 is the prefill step (cache has not been filled yet), so no need to extend the mask/position_ids
+
             # mask is no longer used here
+            kwargs["mask"] = None
+            kwargs["position_ids"] = kwargs["position_ids"][:, -1:] + 1
+
+            # we no longer have a global pos_i, each sequence has its own pos_i
+            slot_mapping = []
+            for seq_i, pos_i in enumerate(current_tkv_mask):
+                if pos_i % BLOCK_SIZE == 0:
+                    block_number = block_numbers.pop(0)
+                    block_table[seq_i].append(block_number)
+
+                block_offset = pos_i % BLOCK_SIZE
+                slot = block_table[seq_i][-1] * BLOCK_SIZE + block_offset
+                slot_mapping.append([slot])
+
+            kwargs["block_table"] = torch.tensor(
+                [
+                    (
+                        [b_seq[0]]
+                        * (max(2, max([len(b) for b in block_table])) - len(b_seq))
+                    )
+                    + b_seq
+                    for b_seq in block_table
+                ],
+                dtype=torch.int64,
+            )
+            kwargs["left_padded_prompt_mask"] = left_padded_prompt_mask
+            current_tkv_mask = current_tkv_mask + 1
+            kwargs["current_tkv_mask"] = current_tkv_mask
+            kwargs["slot_mapping"] = torch.tensor(slot_mapping, dtype=torch.int64)
 
             # batch
             torch._dynamo.mark_dynamic(input_ids, 0)
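
For illustration, here is a minimal, self-contained sketch of the per-sequence slot-mapping arithmetic that the decode branch above performs on each step. BLOCK_SIZE, the free-block pool, and the positions below are made-up example values, not the ones generate() actually uses:

# Hypothetical values for illustration only; generate() derives these from the real cache state.
BLOCK_SIZE = 64
free_blocks = list(range(100))                              # pool of unused block ids
block_table = [[free_blocks.pop(0)], [free_blocks.pop(0)]]  # one block per sequence after prefill
positions = [64, 7]                                         # current token position of each sequence

slot_mapping = []
for seq_i, pos_i in enumerate(positions):
    if pos_i % BLOCK_SIZE == 0:                  # current block is full: allocate a fresh one
        block_table[seq_i].append(free_blocks.pop(0))
    block_offset = pos_i % BLOCK_SIZE
    slot = block_table[seq_i][-1] * BLOCK_SIZE + block_offset
    slot_mapping.append([slot])

print(block_table)   # [[0, 2], [1]] -- sequence 0 spilled into a second block
print(slot_mapping)  # [[128], [71]] -- flat cache slot each sequence writes this step

In the diff itself, the per-sequence block lists are additionally left-padded (by repeating each sequence's first block id) to a common width of at least 2 so they can be stacked into the rectangular block_table tensor.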