Commit 3ba5ea0

fix test decoders
Signed-off-by: Antoni Viros i Martin <aviros@ibm.com>
1 parent: 4e2bddb

4 files changed: 13 additions & 12 deletions

aiu_fms_testing_utils/testing/validation.py

Lines changed: 3 additions & 3 deletions

@@ -187,10 +187,10 @@ def load_validation_information(validation_path, validation_files_type, batch_si

     return ValidationInfo(validation_info)

-def extract_validation_information(model, input_ids, max_new_tokens, post_iteration_hook, attn_algorithm=None, eos_token_id = None, only_last_token=False, timing="", **padding_kwargs):
+def extract_validation_information(model, input_ids, max_new_tokens, post_iteration_hook, attn_algorithm=None, eos_token_id = None, only_last_token=False, timing="", **extra_kwargs):
     max_seq_len = model.config.max_expected_seq_len
     attention_specific_kwargs = {}
-    if "paged" in padding_kwargs["attn_name"]:
+    if "paged" in extra_kwargs["attn_name"]:
         from aiu_fms_testing_utils.utils.paged import generate
     else:
         # TODO: Add a unified generation dependent on attn_type
@@ -199,7 +199,7 @@ def extract_validation_information(model, input_ids, max_new_tokens, post_iterat
         attention_specific_kwargs["max_seq_len"] = max_seq_len

     # Add only_last_token optimization
-    extra_generation_kwargs = {**padding_kwargs}
+    extra_generation_kwargs = {**extra_kwargs}
     if only_last_token:
         extra_generation_kwargs["only_last_token"] = only_last_token
     if attn_algorithm is not None:
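
With this rename, extract_validation_information unconditionally reads extra_kwargs["attn_name"] to pick a generate() implementation, so every caller must supply that key. A minimal usage sketch, assuming a prepared model and input_ids from the surrounding test utilities (the "sdpa_causal" value is illustrative, not taken from this commit):

    extra_kwargs = {"attn_name": "sdpa_causal"}  # any non-"paged" name uses the default generate()
    validation_info = extract_validation_information(
        model,
        input_ids,
        max_new_tokens=128,
        post_iteration_hook=None,
        only_last_token=True,   # forwarded as a generation kwarg
        **extra_kwargs,         # must contain "attn_name" after this change
    )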

aiu_fms_testing_utils/utils/__init__.py

Lines changed: 5 additions & 5 deletions

@@ -9,10 +9,10 @@
 import json
 import random

-def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int, compile_dynamic_sendnn = False, use_cache: bool = True, **padding_kwargs):
+def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int, compile_dynamic_sendnn = False, use_cache: bool = True, **extra_kwargs):
     import torch_sendnn
     attention_specific_kwargs = {}
-    attn_name = padding_kwargs["attn_name"]
+    attn_name = extra_kwargs["attn_name"]
     if "paged" in attn_name:
         from aiu_fms_testing_utils.utils.paged import generate, adjust_inputs_to_batch
     else:
@@ -25,15 +25,15 @@ def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int,

     # adjust inputs depending on attn_type and dynamic shapes
     _warmup_input_ids = input_ids
-    _padding_kwargs = padding_kwargs
+    _extra_kwargs = extra_kwargs
     _max_new_tokens = max_new_tokens
     if compile_dynamic_sendnn:
         _max_new_tokens = 2
         # always warmup with batch size 2 when using attn_type=paged
         if "paged" in attn_name:
-            _warmup_input_ids, _padding_kwargs = adjust_inputs_to_batch(input_ids, **padding_kwargs)
+            _warmup_input_ids, _extra_kwargs = adjust_inputs_to_batch(input_ids, **extra_kwargs)

-    extra_kwargs = {**_padding_kwargs, "only_last_token": "paged" not in attn_name}
+    extra_kwargs = {**_extra_kwargs, "only_last_token": "paged" not in attn_name}

     with torch_sendnn.warmup_mode():
         generate(model, _warmup_input_ids, max_new_tokens=_max_new_tokens, do_sample=False, use_cache=use_cache, extra_kwargs=extra_kwargs, **attention_specific_kwargs)
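
warmup_model follows the same contract: it indexes extra_kwargs["attn_name"] up front, and when the name contains "paged" it widens the warmup batch to 2 via adjust_inputs_to_batch. A hedged call-site sketch, with the model and inputs assumed and "paged" used as an illustrative attn_name value:

    warmup_model(
        model,
        input_ids,                    # shape (batch, seq_len)
        max_new_tokens=64,
        compile_dynamic_sendnn=True,  # warmup then runs with only 2 new tokens
        attn_name="paged",            # required key, captured into **extra_kwargs
    )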

aiu_fms_testing_utils/utils/paged.py

Lines changed: 4 additions & 4 deletions

@@ -5,7 +5,7 @@
 import torch
 import fms.utils.spyre.paged

-def adjust_inputs_to_batch(input_ids: torch.Tensor, **padding_kwargs):
+def adjust_inputs_to_batch(input_ids: torch.Tensor, **extra_kwargs):
     """
     Adjusts the inputs to a batch. Batch size 1 cannot be handled since we want a symbolic shape for the batch
     and pytorch automatically sets size 1 dimensions as static
@@ -14,11 +14,11 @@ def adjust_inputs_to_batch(input_ids: torch.Tensor, **padding_kwargs):
     """
     input_ids = input_ids[0].repeat(2, 1)
     # ensure we pass along other kwargs
-    kwargs = {**padding_kwargs}
-    mask = padding_kwargs.get("mask", None)
+    kwargs = {**extra_kwargs}
+    mask = extra_kwargs.get("mask", None)
     if mask is not None:
         kwargs["mask"] = torch.stack((mask[0], mask[0]))
-    position_ids = padding_kwargs.get("position_ids", None)
+    position_ids = extra_kwargs.get("position_ids", None)
     if position_ids is not None:
         kwargs["position_ids"] = position_ids[0].repeat(2, 1)
     return input_ids, kwargs
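
As the docstring notes, pytorch sets size-1 dimensions as static, so the helper repeats row 0 to give the warmup batch a symbolic batch dimension of 2, duplicating any mask or position_ids the same way. A small illustrative check with made-up tensors:

    import torch

    input_ids = torch.tensor([[5, 6, 7, 8]])        # batch of 1, 4 tokens
    position_ids = torch.arange(4).unsqueeze(0)

    batched_ids, kwargs = adjust_inputs_to_batch(input_ids, position_ids=position_ids)
    assert batched_ids.shape == (2, 4)              # row 0 repeated into a batch of 2
    assert kwargs["position_ids"].shape == (2, 4)   # kwarg duplicated to match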

tests/models/test_decoders.py

Lines changed: 1 addition & 0 deletions

@@ -499,6 +499,7 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
     input_ids, extra_kwargs = __prepare_inputs(
         batch_size, seq_length, tokenizer, seed=i
     )
+    extra_kwargs["attn_name"] = ATTN_NAME
     cpu_validation_info = __load_validation_info(
         model_path, batch_size, seq_length, max_new_tokens, tokenizer, i
     )
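
This one-line test fix is what the commit message refers to: because warmup_model and extract_validation_information now index extra_kwargs["attn_name"] directly, a test that omits the key would fail with a KeyError. A sketch of the failure point (ATTN_NAME is a module-level constant in test_decoders.py whose value is not shown in this diff):

    input_ids, extra_kwargs = __prepare_inputs(batch_size, seq_length, tokenizer, seed=i)
    extra_kwargs["attn_name"] = ATTN_NAME   # seed the key the renamed helpers now require
    # downstream, e.g. in extract_validation_information:
    #   "paged" in extra_kwargs["attn_name"]   -> KeyError without the line above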
