
Commit bd1090e

Merge pull request #91 from foundation-model-stack/homogeneous_tkv_left_padded_optimization_only_last_token
Homogeneous tkv left-padded optimization: prefill compute and KV-cache block id re-use
2 parents 59333c2 + 9e40f4c commit bd1090e

6 files changed: +152 -66 lines changed

aiu_fms_testing_utils/utils/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ def warmup_model(
     import torch_sendnn
 
     attention_specific_kwargs = {}
-    attn_name = extra_kwargs["attn_name"]
+    attn_name = extra_kwargs.get("attn_name", "sdpa")
     if "paged" in attn_name:
         from aiu_fms_testing_utils.utils.paged import generate, adjust_inputs_to_batch
     else:

aiu_fms_testing_utils/utils/paged.py

Lines changed: 101 additions & 56 deletions
@@ -106,21 +106,17 @@ def generate(
 
     result = input_ids
     next_input = input_ids
+    # this includes empty pages and max_new_tokens
+    max_possible_context_length = input_ids.size(1) + max_new_tokens
+
     BLOCK_SIZE = 64
-    _MAX_BATCH = int(
-        os.environ.setdefault("VLLM_DT_MAX_BATCH_SIZE", str(input_ids.size(0)))
-    )
-    _MAX_CONTEXT_LENGTH = int(
-        os.environ.setdefault(
-            "VLLM_DT_MAX_CONTEXT_LEN",
-            str(
-                (((input_ids.size(1) + max_new_tokens - 1) // BLOCK_SIZE) + 1)
-                * BLOCK_SIZE
-            ),
-        )
-    )
+
+    # these variables are guaranteed to be set in another location (inference.py, test_decoders.py, etc.)
+    # if we set these variables here, we run the risk of warming up and generating with different sizes
+    _MAX_BATCH = int(os.environ["VLLM_DT_MAX_BATCH_SIZE"])
+    _MAX_CONTEXT_LENGTH = int(os.environ["VLLM_DT_MAX_CONTEXT_LEN"])
     NUM_BLOCKS = (_MAX_BATCH * _MAX_CONTEXT_LENGTH) // BLOCK_SIZE
-    max_seq_len = input_ids.size(1) + max_new_tokens
+
     if hasattr(model, "head"):
         model_dtype = model.head.weight.dtype
     elif hasattr(model, "shared"):
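
The cache sizing now comes straight from the environment, so warmup and generation are guaranteed to see the same limits. A minimal sketch of the resulting block budget (the values below are hypothetical; in practice inference.py or test_decoders.py sets them):

import os

# hypothetical sizes for illustration only
os.environ["VLLM_DT_MAX_BATCH_SIZE"] = "2"
os.environ["VLLM_DT_MAX_CONTEXT_LEN"] = "256"

BLOCK_SIZE = 64
_MAX_BATCH = int(os.environ["VLLM_DT_MAX_BATCH_SIZE"])
_MAX_CONTEXT_LENGTH = int(os.environ["VLLM_DT_MAX_CONTEXT_LEN"])

# total number of 64-token KV-cache blocks available to the whole batch
NUM_BLOCKS = (_MAX_BATCH * _MAX_CONTEXT_LENGTH) // BLOCK_SIZE
print(NUM_BLOCKS)  # (2 * 256) // 64 == 8
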
@@ -194,27 +190,41 @@ def generate(
     block_numbers = [i for i in range(NUM_BLOCKS)]
     # this will ensure we don't have contiguous blocks
     random.shuffle(block_numbers)
+
+    # this is the true number of left pads when computing paged attention using a paged kv-cache
+    # it may include whole empty pages
     left_padded_prompt_mask = (kwargs["position_ids"] == 0).sum(dim=1) - 1
-    current_context_lengths = (kwargs["position_ids"] != 0).sum(dim=1) + 1
-    current_tkv_mask = left_padded_prompt_mask + current_context_lengths
+
+    # this is the context length for each sequence without pads
+    context_lengths_without_pads = (kwargs["position_ids"] != 0).sum(dim=1) + 1
+
+    # this is the context length for each sequence with no empty pages (padded to multiple of 64)
+    context_lengths = BLOCK_SIZE * (
+        (context_lengths_without_pads + BLOCK_SIZE - 1) // BLOCK_SIZE
+    )
+
+    # left_padded_prompt_mask - empty_slots + context_lengths
+    current_tkv_mask = torch.fill(context_lengths, torch.max(context_lengths))
+
     slot_mapping = []
     block_table = []
-    for seq_i in input_ids:
-        block_table_i = []
+    # each sequence has the possibility of a different tkv, so loop over that
+    for seq_tkv in context_lengths:
+        block_table_i = [block_numbers.pop(0) for _ in range(seq_tkv // BLOCK_SIZE)]
         slot_mapping_i = []
-        for pos_i in range(seq_i.size(0)):
-            if pos_i % BLOCK_SIZE == 0:
-                block_number = block_numbers.pop(0)
-                block_table_i.append(block_number)
+        for pos_i in range(seq_tkv):
+            # we may have already popped a block, so index to the proper block
+            block_number = block_table_i[pos_i // BLOCK_SIZE]
+
             block_offset = pos_i % BLOCK_SIZE
             slot = block_number * BLOCK_SIZE + block_offset
             slot_mapping_i.append(slot)
         slot_mapping.append(slot_mapping_i)
         block_table.append(block_table_i)
-    kwargs["slot_mapping"] = torch.tensor(slot_mapping, dtype=torch.int64)
     kwargs["current_tkv_mask"] = None
     kwargs["left_padded_prompt_mask"] = None
     kwargs["use_cache"] = use_cache
+    only_last_token = kwargs.get("only_last_token", False)
 
     prompt_length = input_ids.shape[1]
 
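
To make the new per-sequence setup concrete, here is a small standalone sketch of the block-table and slot-mapping construction from the hunk above; the pool size and the 70-token prompt are made-up numbers:

BLOCK_SIZE = 64
block_numbers = list(range(8))  # pretend pool of free KV-cache block ids

# a 70-token prompt is padded up to the next multiple of BLOCK_SIZE
context_length_without_pads = 70
seq_tkv = BLOCK_SIZE * ((context_length_without_pads + BLOCK_SIZE - 1) // BLOCK_SIZE)  # 128

# one block per 64 positions, then slot = block id * 64 + offset within the block
block_table_i = [block_numbers.pop(0) for _ in range(seq_tkv // BLOCK_SIZE)]  # [0, 1]
slot_mapping_i = [
    block_table_i[pos_i // BLOCK_SIZE] * BLOCK_SIZE + (pos_i % BLOCK_SIZE)
    for pos_i in range(seq_tkv)
]
print(slot_mapping_i[63], slot_mapping_i[64])  # 63 64 (end of block 0, start of block 1)
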
@@ -223,45 +233,40 @@ def generate(
     start_time = time.time()
 
     for i in range(max_new_tokens):
-        input_ids = next_input[:, -max_seq_len:]
-
-        # prepare any padding keyword arguments
-        # iteration 0 is the prefill step (cache has not been filled yet), so no need to extend the mask/position_ids
-        if i > 0:
-            kwargs["mask"] = None
-            kwargs["position_ids"] = kwargs["position_ids"][:, -1:] + 1
-            pos_i = result.size(1) - 1
-            if pos_i % BLOCK_SIZE == 0:
-                for block_table_i in block_table:
-                    block_number = block_numbers.pop(0)
-                    block_table_i.append(block_number)
-            block_offset = pos_i % BLOCK_SIZE
-
-            slot_mapping = []
-            for block_table_i in block_table:
-                slot = block_table_i[-1] * BLOCK_SIZE + block_offset
-                slot_mapping.append([slot])
-            kwargs["block_table"] = torch.tensor(block_table, dtype=torch.int64)
-            kwargs["slot_mapping"] = torch.tensor(slot_mapping, dtype=torch.int64)
-            current_tkv_mask = current_tkv_mask + 1
-            kwargs["current_tkv_mask"] = current_tkv_mask
-            kwargs["left_padded_prompt_mask"] = left_padded_prompt_mask
+        input_ids = next_input[:, -max_possible_context_length:]
 
         # prefill
         if i == 0:
            kwargs["mask"] = kwargs["mask"].unsqueeze(1)
 
            outputs_list = []
            current_kv_cache = kwargs["past_key_value_states"]
+
            if "fp8" in kwargs["attn_name"]:
                current_kv_scales = [
                    (t1._scale, t2._scale) for t1, t2 in kwargs["past_key_value_states"]
                ]
-            for seq_i in range(input_ids.size(0)):
-                input_ids_i = input_ids[seq_i].unsqueeze(0)
-                slot_mapping_i = kwargs["slot_mapping"][seq_i].unsqueeze(0)
-                position_ids_i = kwargs["position_ids"][seq_i].unsqueeze(0)
-                mask_i = kwargs["mask"][seq_i].unsqueeze(0)
+            for seq_i, current_tkv in enumerate(context_lengths):
+                # remove extra pads from the input_ids, slot_mapping, position_ids, mask to account for empty pages
+                # each input should be padded to its smallest multiple of BLOCK_SIZE (64)
+                # we need to clone these tensors to ensure the pointer offset is 0
+                input_ids_i = input_ids[seq_i][-current_tkv:].unsqueeze(0).clone()
+                slot_mapping_i = (
+                    torch.tensor(slot_mapping[seq_i][-current_tkv:], dtype=torch.int64)
+                    .unsqueeze(0)
+                    .clone()
+                )
+                position_ids_i = (
+                    kwargs["position_ids"][seq_i][-current_tkv:].unsqueeze(0).clone()
+                )
+
+                # This view will result in a discontiguous tensor (creates a new graph during compile)
+                # For this reason, we must explicitly make contiguous
+                mask_i = (
+                    kwargs["mask"][seq_i][:, -current_tkv:, -current_tkv:]
+                    .unsqueeze(0)
+                    .contiguous()
+                )
 
                 # batch dynamic
                 torch._dynamo.mark_static(input_ids_i, 0)
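
Each prefill pass now runs on a single sequence trimmed to its own padded length; a hedged illustration of the slicing (the shapes are invented):

import torch

current_tkv = 128  # this sequence's padded context length
input_ids_row = torch.zeros(256, dtype=torch.int64)  # one left-padded row of a wider batch

# keep only the last current_tkv positions, restore the batch dim, and clone so the
# storage offset is 0 (the hunk above notes this avoids extra graphs during compile)
input_ids_i = input_ids_row[-current_tkv:].unsqueeze(0).clone()
print(input_ids_i.shape)  # torch.Size([1, 128])
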
@@ -283,7 +288,6 @@ def generate(
                         t2._scale = current_kv_scales[layer_idx][1][seq_i].reshape(-1)
 
                 only_last_token = kwargs.get("only_last_token", False)
-
                 output, current_kv_cache = model(
                     input_ids_i,
                     slot_mapping=slot_mapping_i,
@@ -295,6 +299,10 @@ def generate(
                     attn_name=kwargs["attn_name"],
                 )
 
+                # only last token must be handled here to properly stack the tensors
+                if not only_last_token:
+                    output = output[:, -1, :]
+
                 # TODO: Figure out how to do this cleanly
                 if "fp8" in kwargs["attn_name"]:
                     for layer_idx, (t1, t2) in enumerate(current_kv_cache):
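
The last-token slice now happens inside the prefill loop so that the per-sequence outputs, which may have different padded lengths, can be stacked afterwards; a toy shape check with invented dimensions:

import torch

only_last_token = False
output = torch.randn(1, 128, 32000)  # (batch=1, padded seq_len, vocab) from one prefill pass

if not only_last_token:
    # keep only the final position, matching what the stacked result expects
    output = output[:, -1, :]
print(output.shape)  # torch.Size([1, 32000])
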
@@ -313,10 +321,41 @@ def generate(
                 outputs_list.append(output[0].squeeze(0))
 
             output = (torch.stack(outputs_list), current_kv_cache)
-
         # decode
         else:
+            # prepare any padding keyword arguments
+            # iteration 0 is the prefill step (cache has not been filled yet), so no need to extend the mask/position_ids
+
             # mask is no longer used here
+            kwargs["mask"] = None
+            kwargs["position_ids"] = kwargs["position_ids"][:, -1:] + 1
+
+            # we no longer have a global pos_i, each sequence has its own pos_i
+            slot_mapping = []
+            for seq_i, pos_i in enumerate(current_tkv_mask):
+                if pos_i % BLOCK_SIZE == 0:
+                    block_number = block_numbers.pop(0)
+                    block_table[seq_i].append(block_number)
+
+                block_offset = pos_i % BLOCK_SIZE
+                slot = block_table[seq_i][-1] * BLOCK_SIZE + block_offset
+                slot_mapping.append([slot])
+
+            kwargs["block_table"] = torch.tensor(
+                [
+                    (
+                        [b_seq[0]]
+                        * (max(2, max([len(b) for b in block_table])) - len(b_seq))
+                    )
+                    + b_seq
+                    for b_seq in block_table
+                ],
+                dtype=torch.int64,
+            )
+            kwargs["left_padded_prompt_mask"] = left_padded_prompt_mask
+            current_tkv_mask = current_tkv_mask + 1
+            kwargs["current_tkv_mask"] = current_tkv_mask
+            kwargs["slot_mapping"] = torch.tensor(slot_mapping, dtype=torch.int64)
 
             # batch
             torch._dynamo.mark_dynamic(input_ids, 0)
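
The decode path pads every sequence's block list to a common length (at least 2) by re-using that sequence's first block id; this is the KV-cache block id re-use mentioned in the commit title. A self-contained sketch with made-up block ids:

import torch

block_table = [[5], [7, 3, 9]]  # hypothetical per-sequence block tables

max_len = max(2, max(len(b) for b in block_table))  # 3 here
padded = [
    [b_seq[0]] * (max_len - len(b_seq)) + b_seq  # left-pad by repeating the first block id
    for b_seq in block_table
]
print(torch.tensor(padded, dtype=torch.int64).tolist())  # [[5, 5, 5], [7, 3, 9]]
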
@@ -336,7 +375,16 @@ def generate(
             torch._dynamo.mark_static(kwargs["slot_mapping"], 1)  # always 1
             torch._dynamo.mark_static(kwargs["position_ids"], 1)  # always 1
 
-            output = model(input_ids, **kwargs)
+            logits, past_key_value_states = model(input_ids, **kwargs)
+
+            # typically this is done outside of prefill/decode logic, but since this logic already exists as part of the
+            # conditional for prefill (since prefill does this within a loop for each batch size 1 prefill), we also provide
+            # this same logic as part of the decode conditional
+            if not only_last_token:
+                logits = logits[:, -1, :]
+
+            output = (logits, past_key_value_states)
+
         if use_cache:
             logits, past_key_value_states = output
             # TODO: this should go away when reduce-overhead issues are fixed, or
@@ -345,9 +393,6 @@ def generate(
         else:
             logits = output
 
-        if not kwargs.get("only_last_token", False):
-            logits = logits[:, -1, :]
-
         if do_sample:
             # get logits from last value in sequence nad scale
             logits = logits / temperature

scripts/inference.py

Lines changed: 18 additions & 0 deletions
@@ -690,6 +690,24 @@ def truncate_prompts_to_max_length(prompts, max_len, max_allowed_length):
 
 extra_generation_kwargs["attn_name"] = attn_name
 
+if "paged" in attn_name:
+    import bisect
+
+    # the compiler supports certain max context lengths (VLLM_DT_MAX_CONTEXT_LEN)
+    # this will ensure that we select smallest supported VLLM_DT_MAX_CONTEXT_LEN that fits the largest possible context (prompt size + max_new_tokens)
+    # if the user provides their own VLLM_DT_MAX_CONTEXT_LEN, use this value instead
+    __largest_context = ids.shape[1] + args.max_new_tokens
+    __supported_context_lengths = [64, 128, 256, 512, 1024, 2048, 4096, 8192]
+    os.environ.setdefault(
+        "VLLM_DT_MAX_CONTEXT_LEN",
+        str(
+            __supported_context_lengths[
+                bisect.bisect_left(__supported_context_lengths, __largest_context)
+            ]
+        ),
+    )
+    os.environ.setdefault("VLLM_DT_MAX_BATCH_SIZE", str(max(ids.shape[0], 2)))
+
 
 def print_result(result, result_idx: int):
     if local_rank != 0:
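
The bisect lookup above selects the smallest compiler-supported context length that still fits the largest possible context; a quick worked example with an invented prompt size:

import bisect

supported_context_lengths = [64, 128, 256, 512, 1024, 2048, 4096, 8192]

largest_context = 300  # e.g. a 270-token prompt plus 30 new tokens
idx = bisect.bisect_left(supported_context_lengths, largest_context)
print(supported_context_lengths[idx])  # 512, the smallest supported length >= 300
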

tests/models/test_decoders.py

Lines changed: 11 additions & 3 deletions
@@ -162,8 +162,16 @@
 compile_dynamic_sendnn = ATTN_TYPE == "paged"
 
 if compile_dynamic_sendnn:
+    import bisect
+
+    # the compiler supports certain max context lengths (VLLM_DT_MAX_CONTEXT_LEN)
+    # this will ensure that we select smallest supported VLLM_DT_MAX_CONTEXT_LEN that fits the largest possible context (prompt size + max_new_tokens)
+    __largest_context = max(common_seq_lengths) + max(common_max_new_tokens)
+    __supported_context_lengths = [256, 512, 1024, 2048, 4096, 8192]
     os.environ["VLLM_DT_MAX_CONTEXT_LEN"] = str(
-        (((max(common_seq_lengths) + max(common_max_new_tokens)) // 64) + 1) * 64
+        __supported_context_lengths[
+            bisect.bisect_left(__supported_context_lengths, __largest_context)
+        ]
     )
     os.environ["VLLM_DT_MAX_BATCH_SIZE"] = str(max(max(common_batch_sizes), 2))
 
@@ -290,7 +298,7 @@ def __prepare_inputs(batch_size, seq_length, tokenizer, seed=0):
         SHARE_GPT_DATASET_PATH,
         batch_size,
         tokenizer,
-        int(seq_length / 2),
+        seq_length // 2,
         seq_length,
         seed,
     )
@@ -417,7 +425,7 @@ def test_common_shapes(
     os.environ["COMPILATION_MODE"] = "offline_decoder"
 
     dprint(
-        f"testing model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}"
+        f"testing model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}, attn_type={ATTN_TYPE}"
     )
 
     # we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured

tests/models/test_scripts.py

Lines changed: 17 additions & 5 deletions
@@ -20,13 +20,15 @@
 common_batch_sizes = [1, 8]
 common_seq_lengths = [64]
 common_max_new_tokens = [12]
+common_attn_types = ["sdpa", "paged"]
 
 common_params = list(
     itertools.product(
         common_model_paths,
         common_batch_sizes,
         common_seq_lengths,
         common_max_new_tokens,
+        common_attn_types,
     )
 )
 
@@ -51,7 +53,12 @@ def execute_script(execute_cmd):
         raise Exception(error)
 
 
-def execute_inference(model_path, max_new_tokens, batch_size, seq_length):
+def execute_inference(model_path, batch_size, seq_length, max_new_tokens, attn_type):
+    extra_args = []
+    if attn_type == "paged":
+        extra_args.append("--compile_dynamic_sendnn")
+        extra_args.append("--attention_type=paged")
+
     execute_cmd = [
         "python3",
         INFERENCE_FILE_PATH,
@@ -68,7 +75,7 @@ def execute_inference(model_path, max_new_tokens, batch_size, seq_length):
         "--device_type=aiu",
         "--default_dtype=fp16",
     ]
-    return execute_script(execute_cmd)
+    return execute_script(execute_cmd + extra_args)
 
 
 common_asserts = [
@@ -93,10 +100,15 @@ def __repeat_batch_asserts(bs: int) -> list[str]:
 
 
 @pytest.mark.parametrize(
-    "model_path,batch_size,seq_length,max_new_tokens,asserts", common_inference_params
+    "model_path,batch_size,seq_length,max_new_tokens,attn_type,asserts",
+    common_inference_params,
 )
-def test_inference_script(model_path, max_new_tokens, seq_length, batch_size, asserts):
-    result_text = execute_inference(model_path, max_new_tokens, batch_size, seq_length)
+def test_inference_script(
+    model_path, batch_size, seq_length, max_new_tokens, attn_type, asserts
+):
+    result_text = execute_inference(
+        model_path, batch_size, seq_length, max_new_tokens, attn_type
+    )
 
     for common_assert in asserts:
         assert common_assert in result_text
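
For the paged parametrization, execute_inference simply appends the two flags shown above to the usual command; roughly, with paths and the remaining flags abbreviated, the assembled invocation looks like this sketch:

# illustrative only; the real helper builds more flags than shown here
execute_cmd = ["python3", "scripts/inference.py", "--device_type=aiu", "--default_dtype=fp16"]

attn_type = "paged"
extra_args = []
if attn_type == "paged":
    extra_args.append("--compile_dynamic_sendnn")
    extra_args.append("--attention_type=paged")

print(" ".join(execute_cmd + extra_args))
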

tests/utils/test_paged.py

Lines changed: 4 additions & 1 deletion
@@ -3,6 +3,7 @@
 from fms.utils.generation import pad_input_ids, generate
 from aiu_fms_testing_utils.utils.paged import generate as paged_generate
 from fms.utils.tokenizers import get_tokenizer
+import os
 
 
 def test_paged_equivalence():
@@ -32,13 +33,15 @@ def test_paged_equivalence():
         use_cache=True,
         extra_kwargs=padding_kwargs,
     )
+    os.environ["VLLM_DT_MAX_BATCH_SIZE"] = "2"
+    os.environ["VLLM_DT_MAX_CONTEXT_LEN"] = "256"
 
     result_paged = paged_generate(
         _model_mock,
         ids,
         max_new_tokens=5,
         do_sample=False,
         use_cache=True,
-        extra_kwargs=padding_kwargs,
+        extra_kwargs={"attn_name": "spyre_paged_attn", **padding_kwargs},
     )
     torch.testing.assert_close(result, result_paged)
