Commit 9e40f4c

cleaned up conditionals in paged generate

Signed-off-by: Joshua Rosenkranz <jmrosenk@us.ibm.com>
Parent: 4ccacc3

File tree: 1 file changed (+32, -34 lines)


aiu_fms_testing_utils/utils/paged.py

Lines changed: 32 additions & 34 deletions
@@ -235,39 +235,6 @@ def generate(
     for i in range(max_new_tokens):
         input_ids = next_input[:, -max_possible_context_length:]
 
-        # prepare any padding keyword arguments
-        # iteration 0 is the prefill step (cache has not been filled yet), so no need to extend the mask/position_ids
-        if i > 0:
-            kwargs["mask"] = None
-            kwargs["position_ids"] = kwargs["position_ids"][:, -1:] + 1
-
-            # we no longer have a global pos_i, each sequence has its own pos_i
-            slot_mapping = []
-            for seq_i, pos_i in enumerate(current_tkv_mask):
-                if pos_i % BLOCK_SIZE == 0:
-                    block_number = block_numbers.pop(0)
-                    block_table[seq_i].append(block_number)
-
-                block_offset = pos_i % BLOCK_SIZE
-                slot = block_table[seq_i][-1] * BLOCK_SIZE + block_offset
-                slot_mapping.append([slot])
-
-            kwargs["block_table"] = torch.tensor(
-                [
-                    (
-                        [b_seq[0]]
-                        * (max(2, max([len(b) for b in block_table])) - len(b_seq))
-                    )
-                    + b_seq
-                    for b_seq in block_table
-                ],
-                dtype=torch.int64,
-            )
-            kwargs["left_padded_prompt_mask"] = left_padded_prompt_mask
-            current_tkv_mask = current_tkv_mask + 1
-            kwargs["current_tkv_mask"] = current_tkv_mask
-            kwargs["slot_mapping"] = torch.tensor(slot_mapping, dtype=torch.int64)
-
         # prefill
         if i == 0:
             kwargs["mask"] = kwargs["mask"].unsqueeze(1)
@@ -354,10 +321,41 @@ def generate(
                 outputs_list.append(output[0].squeeze(0))
 
             output = (torch.stack(outputs_list), current_kv_cache)
-
         # decode
         else:
+            # prepare any padding keyword arguments
+            # iteration 0 is the prefill step (cache has not been filled yet), so no need to extend the mask/position_ids
+
             # mask is no longer used here
+            kwargs["mask"] = None
+            kwargs["position_ids"] = kwargs["position_ids"][:, -1:] + 1
+
+            # we no longer have a global pos_i, each sequence has its own pos_i
+            slot_mapping = []
+            for seq_i, pos_i in enumerate(current_tkv_mask):
+                if pos_i % BLOCK_SIZE == 0:
+                    block_number = block_numbers.pop(0)
+                    block_table[seq_i].append(block_number)
+
+                block_offset = pos_i % BLOCK_SIZE
+                slot = block_table[seq_i][-1] * BLOCK_SIZE + block_offset
+                slot_mapping.append([slot])
+
+            kwargs["block_table"] = torch.tensor(
+                [
+                    (
+                        [b_seq[0]]
+                        * (max(2, max([len(b) for b in block_table])) - len(b_seq))
+                    )
+                    + b_seq
+                    for b_seq in block_table
+                ],
+                dtype=torch.int64,
+            )
+            kwargs["left_padded_prompt_mask"] = left_padded_prompt_mask
+            current_tkv_mask = current_tkv_mask + 1
+            kwargs["current_tkv_mask"] = current_tkv_mask
+            kwargs["slot_mapping"] = torch.tensor(slot_mapping, dtype=torch.int64)
 
         # batch
         torch._dynamo.mark_dynamic(input_ids, 0)
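
The heart of the moved block is paged attention's slot-mapping arithmetic: each sequence carries its own KV length in current_tkv_mask, a fresh block is claimed from the free list whenever that length crosses a BLOCK_SIZE boundary, and the new token's cache entry lands at the flat slot block_id * BLOCK_SIZE + block_offset. A self-contained sketch of just that arithmetic (the BLOCK_SIZE value, the plain-list block table, and the free_blocks argument are illustrative stand-ins, not the repository's API):

    import torch

    BLOCK_SIZE = 64  # assumed for illustration; the real value is configured elsewhere

    def decode_step_slots(current_tkv, block_table, free_blocks):
        # One decode step, mirroring the loop in the diff: allocate a new
        # block for any sequence whose KV length just filled its last block,
        # then map the incoming token to a flat cache slot.
        slot_mapping = []
        for seq_i, pos_i in enumerate(current_tkv):
            if pos_i % BLOCK_SIZE == 0:  # last block is exactly full
                block_table[seq_i].append(free_blocks.pop(0))
            block_offset = pos_i % BLOCK_SIZE
            slot = block_table[seq_i][-1] * BLOCK_SIZE + block_offset
            slot_mapping.append([slot])
        return torch.tensor(slot_mapping, dtype=torch.int64)

    # two sequences: one mid-block (5 tokens cached), one at a boundary (64)
    table = [[0], [1]]
    print(decode_step_slots([5, 64], table, free_blocks=[2, 3]))
    # tensor([[  5],
    #         [128]])  <- sequence 1 claimed block 2: slot 2 * 64 + 0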

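The torch.tensor(...) call over block_table has to yield a rectangular tensor even though sequences own different numbers of blocks, so the list comprehension left-pads each row to the widest row (minimum width 2) by repeating that row's own first block id. The same expression in isolation, with made-up block ids:

    import torch

    block_table = [[7], [1, 2, 3]]
    width = max(2, max(len(b) for b in block_table))  # widest row, at least 2
    padded = [[b[0]] * (width - len(b)) + b for b in block_table]
    print(torch.tensor(padded, dtype=torch.int64))
    # tensor([[7, 7, 7],
    #         [1, 2, 3]])

Padding with an id the sequence already owns (rather than a sentinel such as -1) presumably keeps any lookup through the table in-bounds; which positions are actually live is tracked separately via left_padded_prompt_mask and current_tkv_mask.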