@@ -106,21 +106,17 @@ def generate(
 
     result = input_ids
     next_input = input_ids
+    # this includes empty pages and max_new_tokens
+    max_possible_context_length = input_ids.size(1) + max_new_tokens
+
     BLOCK_SIZE = 64
-    _MAX_BATCH = int(
-        os.environ.setdefault("VLLM_DT_MAX_BATCH_SIZE", str(input_ids.size(0)))
-    )
-    _MAX_CONTEXT_LENGTH = int(
-        os.environ.setdefault(
-            "VLLM_DT_MAX_CONTEXT_LEN",
-            str(
-                (((input_ids.size(1) + max_new_tokens - 1) // BLOCK_SIZE) + 1)
-                * BLOCK_SIZE
-            ),
-        )
-    )
+
+    # these variables are guaranteed to be set in another location (inference.py, test_decoders.py, etc.)
+    # if we set these variables here, we run the risk of warming up and generating with different sizes
+    _MAX_BATCH = int(os.environ["VLLM_DT_MAX_BATCH_SIZE"])
+    _MAX_CONTEXT_LENGTH = int(os.environ["VLLM_DT_MAX_CONTEXT_LEN"])
     NUM_BLOCKS = (_MAX_BATCH * _MAX_CONTEXT_LENGTH) // BLOCK_SIZE
-    max_seq_len = input_ids.size(1) + max_new_tokens
+
     if hasattr(model, "head"):
         model_dtype = model.head.weight.dtype
     elif hasattr(model, "shared"):
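
Since the two environment variables are now read rather than defaulted, the caller must export them before both warmup and generation. A minimal sketch of that setup, assuming placeholder sizes (the real values live in inference.py / test_decoders.py and are not shown here); the rounding simply reuses the formula deleted above:

    import os

    BLOCK_SIZE = 64
    batch_size = 8        # placeholder: real value comes from the caller's config
    prompt_len = 512      # placeholder prompt length
    max_new_tokens = 128  # placeholder generation budget

    # round the worst-case context up to a whole number of 64-token blocks
    max_context = (((prompt_len + max_new_tokens - 1) // BLOCK_SIZE) + 1) * BLOCK_SIZE
    os.environ["VLLM_DT_MAX_BATCH_SIZE"] = str(batch_size)
    os.environ["VLLM_DT_MAX_CONTEXT_LEN"] = str(max_context)
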
@@ -194,27 +190,41 @@ def generate(
     block_numbers = [i for i in range(NUM_BLOCKS)]
     # this will ensure we don't have contiguous blocks
     random.shuffle(block_numbers)
+
+    # this is the true number of left pads when computing paged attention using a paged kv-cache
+    # it may include whole empty pages
     left_padded_prompt_mask = (kwargs["position_ids"] == 0).sum(dim=1) - 1
-    current_context_lengths = (kwargs["position_ids"] != 0).sum(dim=1) + 1
-    current_tkv_mask = left_padded_prompt_mask + current_context_lengths
+
+    # this is the context length for each sequence without pads
+    context_lengths_without_pads = (kwargs["position_ids"] != 0).sum(dim=1) + 1
+
+    # this is the context length for each sequence with no empty pages (padded to multiple of 64)
+    context_lengths = BLOCK_SIZE * (
+        (context_lengths_without_pads + BLOCK_SIZE - 1) // BLOCK_SIZE
+    )
+
+    # left_padded_prompt_mask - empty_slots + context_lengths
+    current_tkv_mask = torch.fill(context_lengths, torch.max(context_lengths))
+
     slot_mapping = []
     block_table = []
-    for seq_i in input_ids:
-        block_table_i = []
+    # each sequence has the possibility of a different tkv, so loop over that
+    for seq_tkv in context_lengths:
+        block_table_i = [block_numbers.pop(0) for _ in range(seq_tkv // BLOCK_SIZE)]
         slot_mapping_i = []
-        for pos_i in range(seq_i.size(0)):
-            if pos_i % BLOCK_SIZE == 0:
-                block_number = block_numbers.pop(0)
-                block_table_i.append(block_number)
+        for pos_i in range(seq_tkv):
+            # we may have already popped a block, so index to the proper block
+            block_number = block_table_i[pos_i // BLOCK_SIZE]
+
             block_offset = pos_i % BLOCK_SIZE
             slot = block_number * BLOCK_SIZE + block_offset
             slot_mapping_i.append(slot)
         slot_mapping.append(slot_mapping_i)
         block_table.append(block_table_i)
-    kwargs["slot_mapping"] = torch.tensor(slot_mapping, dtype=torch.int64)
     kwargs["current_tkv_mask"] = None
     kwargs["left_padded_prompt_mask"] = None
     kwargs["use_cache"] = use_cache
+    only_last_token = kwargs.get("only_last_token", False)
 
     prompt_length = input_ids.shape[1]
 
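
The slot arithmetic in the loop above maps every token position to a (block, offset) pair inside the paged KV-cache. A toy illustration with made-up block numbers and a context already rounded up to two 64-token pages:

    BLOCK_SIZE = 64
    block_table_i = [7, 3]  # two non-contiguous blocks popped from the shuffled pool
    seq_tkv = 128           # context length already padded to a multiple of BLOCK_SIZE

    slot_mapping_i = [
        block_table_i[pos // BLOCK_SIZE] * BLOCK_SIZE + pos % BLOCK_SIZE
        for pos in range(seq_tkv)
    ]
    assert slot_mapping_i[0] == 7 * 64         # first token: block 7, offset 0
    assert slot_mapping_i[64] == 3 * 64        # token 64 rolls over into block 3
    assert slot_mapping_i[127] == 3 * 64 + 63  # last token fills block 3
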
@@ -223,45 +233,40 @@ def generate(
     start_time = time.time()
 
     for i in range(max_new_tokens):
-        input_ids = next_input[:, -max_seq_len:]
-
-        # prepare any padding keyword arguments
-        # iteration 0 is the prefill step (cache has not been filled yet), so no need to extend the mask/position_ids
-        if i > 0:
-            kwargs["mask"] = None
-            kwargs["position_ids"] = kwargs["position_ids"][:, -1:] + 1
-            pos_i = result.size(1) - 1
-            if pos_i % BLOCK_SIZE == 0:
-                for block_table_i in block_table:
-                    block_number = block_numbers.pop(0)
-                    block_table_i.append(block_number)
-            block_offset = pos_i % BLOCK_SIZE
-
-            slot_mapping = []
-            for block_table_i in block_table:
-                slot = block_table_i[-1] * BLOCK_SIZE + block_offset
-                slot_mapping.append([slot])
-            kwargs["block_table"] = torch.tensor(block_table, dtype=torch.int64)
-            kwargs["slot_mapping"] = torch.tensor(slot_mapping, dtype=torch.int64)
-            current_tkv_mask = current_tkv_mask + 1
-            kwargs["current_tkv_mask"] = current_tkv_mask
-            kwargs["left_padded_prompt_mask"] = left_padded_prompt_mask
+        input_ids = next_input[:, -max_possible_context_length:]
 
         # prefill
         if i == 0:
             kwargs["mask"] = kwargs["mask"].unsqueeze(1)
 
             outputs_list = []
             current_kv_cache = kwargs["past_key_value_states"]
+
             if "fp8" in kwargs["attn_name"]:
                 current_kv_scales = [
                     (t1._scale, t2._scale) for t1, t2 in kwargs["past_key_value_states"]
                 ]
-            for seq_i in range(input_ids.size(0)):
-                input_ids_i = input_ids[seq_i].unsqueeze(0)
-                slot_mapping_i = kwargs["slot_mapping"][seq_i].unsqueeze(0)
-                position_ids_i = kwargs["position_ids"][seq_i].unsqueeze(0)
-                mask_i = kwargs["mask"][seq_i].unsqueeze(0)
+            for seq_i, current_tkv in enumerate(context_lengths):
+                # remove extra pads from the input_ids, slot_mapping, position_ids, mask to account for empty pages
+                # each input should be padded to its smallest multiple of BLOCK_SIZE (64)
+                # we need to clone these tensors to ensure the pointer offset is 0
+                input_ids_i = input_ids[seq_i][-current_tkv:].unsqueeze(0).clone()
+                slot_mapping_i = (
+                    torch.tensor(slot_mapping[seq_i][-current_tkv:], dtype=torch.int64)
+                    .unsqueeze(0)
+                    .clone()
+                )
+                position_ids_i = (
+                    kwargs["position_ids"][seq_i][-current_tkv:].unsqueeze(0).clone()
+                )
+
+                # This view will result in a discontiguous tensor (creates a new graph during compile)
+                # For this reason, we must explicitly make contiguous
+                mask_i = (
+                    kwargs["mask"][seq_i][:, -current_tkv:, -current_tkv:]
+                    .unsqueeze(0)
+                    .contiguous()
+                )
 
                 # batch dynamic
                 torch._dynamo.mark_static(input_ids_i, 0)
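
A note on the `.clone()` / `.contiguous()` calls added above: a negative-start slice is only a view into the parent tensor, and (per the in-line comments) a discontiguous input or non-zero storage offset can trigger a fresh graph under torch.compile. A standalone sketch of the effect, with arbitrary shapes:

    import torch

    mask = torch.ones(1, 4, 192, 192)
    view = mask[0][:, -128:, -128:]  # slicing keeps a view into the larger storage
    print(view.is_contiguous())      # False: strides still reflect the 192-wide parent
    print(view.storage_offset())     # non-zero offset into the parent buffer

    fixed = view.unsqueeze(0).contiguous()
    print(fixed.is_contiguous(), fixed.storage_offset())  # True 0
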
@@ -283,7 +288,6 @@ def generate(
                         t2._scale = current_kv_scales[layer_idx][1][seq_i].reshape(-1)
 
                 only_last_token = kwargs.get("only_last_token", False)
-
                 output, current_kv_cache = model(
                     input_ids_i,
                     slot_mapping=slot_mapping_i,
@@ -295,6 +299,10 @@ def generate(
                     attn_name=kwargs["attn_name"],
                 )
 
+                # only last token must be handled here to properly stack the tensors
+                if not only_last_token:
+                    output = output[:, -1, :]
+
                 # TODO: Figure out how to do this cleanly
                 if "fp8" in kwargs["attn_name"]:
                     for layer_idx, (t1, t2) in enumerate(current_kv_cache):
@@ -313,10 +321,41 @@ def generate(
                 outputs_list.append(output[0].squeeze(0))
 
             output = (torch.stack(outputs_list), current_kv_cache)
-
         # decode
         else:
+            # prepare any padding keyword arguments
+            # iteration 0 is the prefill step (cache has not been filled yet), so no need to extend the mask/position_ids
+
             # mask is no longer used here
+            kwargs["mask"] = None
+            kwargs["position_ids"] = kwargs["position_ids"][:, -1:] + 1
+
+            # we no longer have a global pos_i, each sequence has its own pos_i
+            slot_mapping = []
+            for seq_i, pos_i in enumerate(current_tkv_mask):
+                if pos_i % BLOCK_SIZE == 0:
+                    block_number = block_numbers.pop(0)
+                    block_table[seq_i].append(block_number)
+
+                block_offset = pos_i % BLOCK_SIZE
+                slot = block_table[seq_i][-1] * BLOCK_SIZE + block_offset
+                slot_mapping.append([slot])
+
+            kwargs["block_table"] = torch.tensor(
+                [
+                    (
+                        [b_seq[0]]
+                        * (max(2, max([len(b) for b in block_table])) - len(b_seq))
+                    )
+                    + b_seq
+                    for b_seq in block_table
+                ],
+                dtype=torch.int64,
+            )
+            kwargs["left_padded_prompt_mask"] = left_padded_prompt_mask
+            current_tkv_mask = current_tkv_mask + 1
+            kwargs["current_tkv_mask"] = current_tkv_mask
+            kwargs["slot_mapping"] = torch.tensor(slot_mapping, dtype=torch.int64)
 
             # batch
             torch._dynamo.mark_dynamic(input_ids, 0)
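
The block_table padding expression above deserves a gloss: every row is left-padded with repeats of its own first block id until all rows share the same width (at least 2), so the ragged table can become a rectangular tensor; the padded entries presumably stay harmless because they reference a block the sequence already owns and sit outside its live range. A toy example with invented block ids:

    block_table = [[7, 3, 9], [5]]  # ragged: sequences own different numbers of blocks

    width = max(2, max(len(b) for b in block_table))
    padded = [[b_seq[0]] * (width - len(b_seq)) + b_seq for b_seq in block_table]
    print(padded)  # [[7, 3, 9], [5, 5, 5]] -- the short row repeats its first block id
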
@@ -336,7 +375,16 @@ def generate(
             torch._dynamo.mark_static(kwargs["slot_mapping"], 1)  # always 1
             torch._dynamo.mark_static(kwargs["position_ids"], 1)  # always 1
 
-            output = model(input_ids, **kwargs)
+            logits, past_key_value_states = model(input_ids, **kwargs)
+
+            # typically this is done outside of prefill/decode logic, but since this logic already exists as part of the
+            # conditional for prefill (since prefill does this within a loop for each batch size 1 prefill), we also provide
+            # this same logic as part of the decode conditional
+            if not only_last_token:
+                logits = logits[:, -1, :]
+
+            output = (logits, past_key_value_states)
+
         if use_cache:
             logits, past_key_value_states = output
             # TODO: this should go away when reduce-overhead issues are fixed, or
@@ -345,9 +393,6 @@ def generate(
         else:
             logits = output
 
-        if not kwargs.get("only_last_token", False):
-            logits = logits[:, -1, :]
-
         if do_sample:
             # get logits from last value in sequence and scale
             logits = logits / temperature
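
Finally, the `only_last_token` handling that used to follow the prefill/decode branch now lives inside each branch. The shape bookkeeping it performs, sketched with illustrative sizes (and assuming the model returns per-position logits when only_last_token is False):

    import torch

    batch, seq, vocab = 2, 128, 32000
    logits_full = torch.randn(batch, seq, vocab)  # per-position logits when only_last_token is False

    next_token_logits = logits_full[:, -1, :]     # keep only the final position for sampling
    assert next_token_logits.shape == (batch, vocab)
    # with only_last_token=True the model is expected to have done this reduction already,
    # hence the `if not only_last_token` guard in both branches above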