Skip to content

Commit adafb73

Browse files
committed
Resolve conflicts
Signed-off-by: Rafael Vasquez <rafvasq21@gmail.com>
2 parents dd0d879 + 281ff22 commit adafb73

17 files changed

+1205
-409
lines changed

aiu_fms_testing_utils/scripts/drive_paged_programs.py

Lines changed: 204 additions & 51 deletions
Large diffs are not rendered by default.

aiu_fms_testing_utils/scripts/generate_layers_metrics.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,11 @@ def generate_layers_metrics(model_path, batch_size, seq_length, max_new_tokens):
473473
cos_sim = tensor_cos_sim(tensor_cpu_out, cuda_output)
474474

475475
prefix = get_default_validation_prefix(
476-
model_path, max_new_token, batch_size, seq_length, "float16"
476+
model_id=model_path,
477+
max_new_tokens=max_new_token,
478+
batch_size=batch_size,
479+
seq_length=seq_length,
480+
dtype="float16",
477481
)
478482
layer_name = str(layer_key).replace("[", "").replace("]", "")
479483

aiu_fms_testing_utils/scripts/generate_metrics.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -134,11 +134,11 @@
134134

135135
# this follows the same pattern of naming in test_shapes. This way we can save and re-use for quicker shape testing.
136136
prefix = get_default_validation_prefix(
137-
args.variant,
138-
args.max_new_tokens,
139-
args.batch_size,
140-
args.min_pad_length,
141-
args.default_dtype,
137+
model_id=args.variant,
138+
max_new_tokens=args.max_new_tokens,
139+
batch_size=args.batch_size,
140+
seq_len=args.min_pad_length,
141+
dtype=args.default_dtype,
142142
)
143143
if os.path.exists(os.path.join(args.output_dir, f"{prefix}.prob_mean.csv")):
144144
print("skipping metric generation as it has already been done")
@@ -259,7 +259,7 @@ def write_csv(metrics, path, metric_name):
259259
ids.to("cuda"),
260260
args.max_new_tokens,
261261
None,
262-
only_last_token=True,
262+
last_n_tokens=1,
263263
**{k: v.to("cuda") for k, v in padding_kwargs.items()},
264264
)
265265
cuda_static_tokens = cuda_validation_info.get_info("tokens")
@@ -334,7 +334,7 @@ def write_csv(metrics, path, metric_name):
334334
ids.to("cuda"),
335335
args.max_new_tokens,
336336
GoldenTokenHook(cpu_validation_info.get_info("tokens"), "cuda"),
337-
only_last_token=True,
337+
last_n_tokens=1,
338338
**{k: v.to("cuda") for k, v in padding_kwargs.items()},
339339
)
340340

aiu_fms_testing_utils/scripts/inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -771,7 +771,7 @@ def infer(use_cache, do_sample, warmup):
771771
global extra_generation_kwargs
772772
if extra_generation_kwargs is None:
773773
extra_generation_kwargs = {}
774-
extra_generation_kwargs["only_last_token"] = "paged" not in attn_name
774+
extra_generation_kwargs["last_n_tokens"] = 64 if "paged" in attn_name else 1
775775

776776
if not args.no_early_termination and not warmup:
777777
eos_token_id = tokenizer.eos_token_id

aiu_fms_testing_utils/scripts/validation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -710,7 +710,7 @@ def print_result(result, result_idx: int = 0, file_prefix: str = ""):
710710
args.max_new_tokens,
711711
post_iteration_hook,
712712
eos_token_id=None if args.no_early_termination else tokenizer.eos_token_id,
713-
only_last_token=True,
713+
last_n_tokens=1,
714714
timing=args.timing,
715715
**padding_kwargs,
716716
)
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from collections.abc import Iterable
2+
3+
4+
def format_kwargs_to_string(**kwargs):
5+
"""
6+
Turns kwargs into a str with variable names using `-`, variables separated by `_` and iterable separated by `,`
7+
"""
8+
formatted_pairs = []
9+
for key, value in sorted(kwargs.items()):
10+
formatted_value = None
11+
if isinstance(value, str):
12+
formatted_value = value
13+
elif isinstance(value, Iterable):
14+
formatted_value = ",".join(map(str, value))
15+
elif value:
16+
formatted_value = str(value)
17+
# only append if formatted_value exists
18+
if formatted_value:
19+
# Keep previous convention of variable names with `-` instead of `_`
20+
formatted_pairs.append(
21+
f"{key.replace('_', '-')}-{formatted_value.replace('/', '--')}"
22+
)
23+
24+
return "_".join(formatted_pairs)

aiu_fms_testing_utils/testing/validation.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
from aiu_fms_testing_utils.utils.aiu_setup import dprint
66
from aiu_fms_testing_utils._version import version_tuple
77
import os
8+
from aiu_fms_testing_utils.testing.utils import format_kwargs_to_string
9+
10+
import hashlib
811

912

1013
class LogitsExtractorHook(
@@ -125,13 +128,7 @@ def __len__(self):
125128

126129

127130
def get_default_validation_prefix(
128-
model_id: str,
129-
max_new_tokens: int,
130-
batch_size: int,
131-
seq_length: int,
132-
dtype: str,
133-
attn_type: str,
134-
aftu_version: str,
131+
**kwargs,
135132
):
136133
"""
137134
Args:
@@ -144,9 +141,17 @@ def get_default_validation_prefix(
144141
aftu_version (str): introduced in v0.3.0 to track changed in log
145142
146143
Returns:
147-
str: A prefix that will be prepended to the file name
144+
str: A hashed prefix that will be prepended to the file name
148145
"""
149-
return f"{model_id.replace('/', '--')}_max-new-tokens-{max_new_tokens}_batch-size-{batch_size}_seq-length-{seq_length}_dtype-{dtype}_attn-type-{attn_type}.{aftu_version}"
146+
aftu_version = kwargs.pop(
147+
"aftu_version", ".".join([str(_) for _ in version_tuple[:3]])
148+
)
149+
kwargs_str = format_kwargs_to_string(**kwargs)
150+
151+
filename = f"{kwargs_str}"
152+
hash_object = hashlib.sha256(filename.encode("utf-8"))
153+
hex_digest = hash_object.hexdigest()
154+
return f"{hex_digest}_{aftu_version}"
150155

151156

152157
def load_validation_information(
@@ -256,7 +261,7 @@ def extract_validation_information(
256261
post_iteration_hook,
257262
attn_algorithm=None,
258263
eos_token_id=None,
259-
only_last_token=False,
264+
last_n_tokens=0,
260265
timing="",
261266
**extra_kwargs,
262267
):
@@ -270,10 +275,10 @@ def extract_validation_information(
270275
attention_specific_kwargs["contiguous_cache"] = True
271276
attention_specific_kwargs["max_seq_len"] = input_ids.shape[1] + max_new_tokens
272277

273-
# Add only_last_token optimization
278+
# Add last_n_tokens optimization
274279
extra_generation_kwargs = {**extra_kwargs}
275-
if only_last_token:
276-
extra_generation_kwargs["only_last_token"] = only_last_token
280+
if last_n_tokens != 0:
281+
extra_generation_kwargs["last_n_tokens"] = last_n_tokens
277282
if attn_algorithm is not None:
278283
extra_generation_kwargs["attn_algorithm"] = attn_algorithm
279284

@@ -416,26 +421,29 @@ def get_validation_info_path(
416421
aftu_version: Optional[Tuple[int, int, int]] = None,
417422
device_type: str = "cpu",
418423
dtype: str = "fp16",
424+
**kwargs,
419425
):
420426
if aftu_version is None:
421427
aftu_version = version_tuple
422428

423-
validation_file_name = f"{get_default_validation_prefix(model_variant, max_new_tokens, batch_size, seq_length, dtype, attn_type, '.'.join([str(_) for _ in aftu_version[:3]]))}.{device_type}_validation_info.{seed}.out"
429+
sample_key = kwargs.get("sample_key", None)
430+
431+
validation_file_name = f"{get_default_validation_prefix(aftu_version='.'.join([str(_) for _ in aftu_version[:3]]), model_id=model_variant, max_new_tokens=max_new_tokens, batch_size=batch_size, seq_length=seq_length, dtype=dtype, attn_type=attn_type, sample_key=sample_key)}.{device_type}_validation_info.{seed}.out"
424432
full_path = os.path.join(validation_info_dir, validation_file_name)
425433
return full_path
426434

427435

428-
def __decrement_version(version: Tuple[int, int, int]):
436+
def __decrement_version(version: Tuple[int, int, int], max_minor=25, max_patch=25):
429437
"""
430438
Function designed to prevent triple nested for loop while decrementing version
431439
"""
432440
major, minor, patch = version
433441
if patch > 0:
434442
return (major, minor, patch - 1)
435443
elif minor > 0:
436-
return (major, minor - 1, 0)
444+
return (major, minor - 1, max_patch)
437445
elif major > 0:
438-
return (major - 1, 0, 0)
446+
return (major - 1, max_minor, max_patch)
439447
else:
440448
return None
441449

@@ -452,10 +460,12 @@ def find_validation_info_path(
452460
version_allow_decrement: bool = False,
453461
device_type: str = "cpu",
454462
dtype: str = "fp16",
463+
**kwargs,
455464
):
456465
"""
457466
Find the validation info path if it exists, otherwise return None
458467
"""
468+
sample_key = kwargs.get("sample_key", None)
459469

460470
if aftu_version is None:
461471
loc_version_tuple = version_tuple[:3]
@@ -476,6 +486,7 @@ def find_validation_info_path(
476486
loc_version_tuple,
477487
device_type,
478488
dtype,
489+
sample_key=sample_key,
479490
)
480491
# if the path is found, we are done searching and can return
481492
if os.path.exists(full_path):

aiu_fms_testing_utils/utils/__init__.py

Lines changed: 64 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from aiu_fms_testing_utils.utils.aiu_setup import dprint, rank, world_size
1313
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
14+
from aiu_fms_testing_utils.testing.utils import format_kwargs_to_string
1415

1516
from fms.utils.generation import pad_input_ids
1617
import torch
@@ -85,7 +86,7 @@ def warmup_model(
8586
**extra_kwargs,
8687
)
8788

88-
extra_kwargs = {**_extra_kwargs, "only_last_token": "paged" not in attn_name}
89+
extra_kwargs = {**_extra_kwargs, "last_n_tokens": 64 if "paged" in attn_name else 1}
8990

9091
with stagger_region(stagger_update_lazyhandle):
9192
with torch_sendnn.warmup_mode():
@@ -421,8 +422,11 @@ def __sample_requests(
421422
prompt_token_ids = tokenizer.encode(
422423
prompt, add_special_tokens=False
423424
)
425+
# If we don't set clean_up_tokenization_spaces=False, encoding then decoding text might result in different lengths which would break expected results from the sampler
424426
truncated_prompt = tokenizer.decode(
425-
prompt_token_ids[:truncate_to_size], skip_special_tokens=True
427+
prompt_token_ids[:truncate_to_size],
428+
skip_special_tokens=True,
429+
clean_up_tokenization_spaces=False,
426430
)
427431
enforced_dataset.append((truncated_prompt, truncate_to_size))
428432
enforce_sizes_with_truncation.remove(truncation_found)
@@ -479,6 +483,7 @@ def sample_rag_factoid_requests(
479483
enforce_sizes: List[int] = [],
480484
truncation: bool = False,
481485
pad_multiple: int = 64,
486+
return_key: bool = False,
482487
) -> List[Tuple[str, int]]:
483488
if not os.path.exists(dataset_path):
484489
print("error dataset does not exist")
@@ -489,7 +494,7 @@ def sample_rag_factoid_requests(
489494
for line in f:
490495
dataset.append(line)
491496

492-
return __sample_requests(
497+
sample_request = __sample_requests(
493498
dataset,
494499
num_requests,
495500
tokenizer,
@@ -503,6 +508,24 @@ def sample_rag_factoid_requests(
503508
_cached_dataset_key=dataset_path,
504509
)
505510

511+
if return_key:
512+
sample_key: str = format_kwargs_to_string(
513+
dataset="rag_factoid",
514+
num_requests=num_requests,
515+
tokenizer=tokenizer.name_or_path.replace("/", "--"),
516+
prompt_length_min=prompt_length_min,
517+
prompt_length_max=prompt_length_max,
518+
seed=seed,
519+
enforce_heterogeneous=enforce_heterogeneous,
520+
enforce_sizes=enforce_sizes,
521+
truncate=truncation,
522+
pad_multiple=pad_multiple,
523+
)
524+
525+
return sample_request, sample_key
526+
else:
527+
return sample_request
528+
506529

507530
def sample_sharegpt_requests(
508531
dataset_path: str,
@@ -515,6 +538,7 @@ def sample_sharegpt_requests(
515538
enforce_sizes: List[int] | None = None,
516539
truncation: bool = False,
517540
pad_multiple: int = 64,
541+
return_key: bool = False,
518542
) -> List[Tuple[str, int]]:
519543
if not os.path.exists(dataset_path):
520544
print("downloading share-gpt dataset as it does not exist")
@@ -540,7 +564,7 @@ def sample_sharegpt_requests(
540564
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
541565
dataset: List[str] = [data["conversations"][0]["value"] for data in dataset]
542566

543-
return __sample_requests(
567+
sample_request = __sample_requests(
544568
dataset,
545569
num_requests,
546570
tokenizer,
@@ -554,6 +578,23 @@ def sample_sharegpt_requests(
554578
_cached_dataset_key=dataset_path,
555579
)
556580

581+
if return_key:
582+
sample_key: str = format_kwargs_to_string(
583+
dataset="sharegpt",
584+
num_requests=num_requests,
585+
tokenizer=tokenizer.name_or_path.replace("/", "--"),
586+
prompt_length_min=prompt_length_min,
587+
prompt_length_max=prompt_length_max,
588+
seed=seed,
589+
enforce_heterogeneous=enforce_heterogeneous,
590+
enforce_sizes=enforce_sizes,
591+
truncate=truncation,
592+
pad_multiple=pad_multiple,
593+
)
594+
return sample_request, sample_key
595+
else:
596+
return sample_request
597+
557598

558599
def sample_squad_v2_qa_requests(
559600
dataset_path: str,
@@ -566,6 +607,7 @@ def sample_squad_v2_qa_requests(
566607
enforce_sizes: List[int] | None = None,
567608
truncation: bool = False,
568609
pad_multiple: int = 64,
610+
return_key: bool = False,
569611
) -> List[Tuple[str, int]]:
570612
from datasets import load_dataset
571613

@@ -579,7 +621,7 @@ def sample_squad_v2_qa_requests(
579621

580622
ds = [f"{data['context']}\n{data['question']}" for data in ds]
581623

582-
return __sample_requests(
624+
sample_request = __sample_requests(
583625
ds,
584626
num_requests,
585627
tokenizer,
@@ -592,6 +634,23 @@ def sample_squad_v2_qa_requests(
592634
pad_multiple,
593635
)
594636

637+
if return_key:
638+
sample_key: str = format_kwargs_to_string(
639+
dataset="squad_v2",
640+
num_requests=num_requests,
641+
tokenizer=tokenizer.name_or_path.replace("/", "--"),
642+
prompt_length_min=prompt_length_min,
643+
prompt_length_max=prompt_length_max,
644+
seed=seed,
645+
enforce_heterogeneous=enforce_heterogeneous,
646+
enforce_sizes=enforce_sizes,
647+
truncate=truncation,
648+
pad_multiple=pad_multiple,
649+
)
650+
return sample_request, sample_key
651+
else:
652+
return sample_request
653+
595654

596655
def prepare_inputs(
597656
batch_size, seq_length, tokenizer, ds_path, seed=0, ds_type="sharegpt"

0 commit comments

Comments
 (0)