From 051c18f6749803c93b02a52ab3545d01af64a950 Mon Sep 17 00:00:00 2001 From: kcirred <16872435+kcirred@users.noreply.github.com> Date: Fri, 26 Sep 2025 22:00:28 +0000 Subject: [PATCH 01/20] [dpp] store enforce_sizes in log name and added generic kwargs to get_default_validation_prefix, enable sample_key Signed-off-by: kcirred <16872435+kcirred@users.noreply.github.com> --- aiu_fms_testing_utils/testing/utils.py | 19 +++++ aiu_fms_testing_utils/testing/validation.py | 17 ++++- aiu_fms_testing_utils/utils/__init__.py | 43 ++++++++++- scripts/drive_paged_programs.py | 9 ++- tests/testing/test_validation.py | 85 +++++++++++++++++++++ 5 files changed, 166 insertions(+), 7 deletions(-) create mode 100644 aiu_fms_testing_utils/testing/utils.py diff --git a/aiu_fms_testing_utils/testing/utils.py b/aiu_fms_testing_utils/testing/utils.py new file mode 100644 index 00000000..79ae564f --- /dev/null +++ b/aiu_fms_testing_utils/testing/utils.py @@ -0,0 +1,19 @@ +from collections.abc import Iterable + + +def format_kwargs_to_string(**kwargs): + formatted_pairs = [] + for key, value in sorted(kwargs.items()): + formatted_value = None + if isinstance(value, str): + formatted_value = value + elif isinstance(value, Iterable): + formatted_value = ",".join(map(str, value)) + elif value: + formatted_value = str(value) + # only append if formatted_value exists + if formatted_value: + # Keep previous convention of variable names with `-` instead of `_` + formatted_pairs.append(f"{key.replace('_', '-')}-{formatted_value}") + + return "_".join(formatted_pairs) diff --git a/aiu_fms_testing_utils/testing/validation.py b/aiu_fms_testing_utils/testing/validation.py index 0c655ff5..5749cb89 100644 --- a/aiu_fms_testing_utils/testing/validation.py +++ b/aiu_fms_testing_utils/testing/validation.py @@ -5,6 +5,7 @@ from aiu_fms_testing_utils.utils.aiu_setup import dprint from aiu_fms_testing_utils._version import version_tuple import os +from aiu_fms_testing_utils.testing.utils import format_kwargs_to_string class LogitsExtractorHook( @@ -132,6 +133,7 @@ def get_default_validation_prefix( dtype: str, attn_type: str, aftu_version: str, + **kwargs, ): """ Args: @@ -146,7 +148,12 @@ def get_default_validation_prefix( Returns: str: A prefix that will be prepended to the file name """ - return f"{model_id.replace('/', '--')}_max-new-tokens-{max_new_tokens}_batch-size-{batch_size}_seq-length-{seq_length}_dtype-{dtype}_attn-type-{attn_type}.{aftu_version}" + kwargs_str = format_kwargs_to_string(**kwargs) + + if kwargs_str == "": + return f"{model_id.replace('/', '--')}_max-new-tokens-{max_new_tokens}_batch-size-{batch_size}_seq-length-{seq_length}_dtype-{dtype}_attn-type-{attn_type}.{aftu_version}" + else: + return f"{model_id.replace('/', '--')}_max-new-tokens-{max_new_tokens}_batch-size-{batch_size}_seq-length-{seq_length}_dtype-{dtype}_attn-type-{attn_type}_{kwargs_str}.{aftu_version}" def load_validation_information( @@ -416,11 +423,14 @@ def get_validation_info_path( aftu_version: Optional[Tuple[int, int, int]] = None, device_type: str = "cpu", dtype: str = "fp16", + **kwargs, ): if aftu_version is None: aftu_version = version_tuple - validation_file_name = f"{get_default_validation_prefix(model_variant, max_new_tokens, batch_size, seq_length, dtype, attn_type, '.'.join([str(_) for _ in aftu_version[:3]]))}.{device_type}_validation_info.{seed}.out" + sample_key = kwargs.get("sample_key", None) + + validation_file_name = f"{get_default_validation_prefix(model_variant, max_new_tokens, batch_size, seq_length, dtype, attn_type, 
'.'.join([str(_) for _ in aftu_version[:3]]), sample_key=sample_key)}.{device_type}_validation_info.{seed}.out" full_path = os.path.join(validation_info_dir, validation_file_name) return full_path @@ -452,10 +462,12 @@ def find_validation_info_path( version_allow_decrement: bool = False, device_type: str = "cpu", dtype: str = "fp16", + **kwargs, ): """ Find the validation info path if it exists, otherwise return None """ + enforce_sizes = kwargs.get("enforce_sizes", None) if aftu_version is None: loc_version_tuple = version_tuple[:3] @@ -476,6 +488,7 @@ def find_validation_info_path( loc_version_tuple, device_type, dtype, + enforce_sizes=enforce_sizes, ) # if the path is found, we are done searching and can return if os.path.exists(full_path): diff --git a/aiu_fms_testing_utils/utils/__init__.py b/aiu_fms_testing_utils/utils/__init__.py index 65a0f9ab..796567ee 100644 --- a/aiu_fms_testing_utils/utils/__init__.py +++ b/aiu_fms_testing_utils/utils/__init__.py @@ -11,6 +11,7 @@ from aiu_fms_testing_utils.utils.aiu_setup import dprint, rank, world_size from transformers.tokenization_utils_base import PreTrainedTokenizerBase +from aiu_fms_testing_utils.testing.utils import format_kwargs_to_string from fms.utils.generation import pad_input_ids import torch @@ -482,6 +483,7 @@ def sample_rag_factoid_requests( enforce_sizes: List[int] = [], truncation: bool = False, pad_multiple: int = 64, + return_key: bool = False, ) -> List[Tuple[str, int]]: if not os.path.exists(dataset_path): print("error dataset does not exist") @@ -492,7 +494,7 @@ def sample_rag_factoid_requests( for line in f: dataset.append(line) - return __sample_requests( + sample_request = __sample_requests( dataset, num_requests, tokenizer, @@ -506,6 +508,24 @@ def sample_rag_factoid_requests( _cached_dataset_key=dataset_path, ) + sample_key: str = format_kwargs_to_string( + dataset="rag_factoid", + num_requests=num_requests, + tokenizer=tokenizer.name_or_path.replace("/", "--"), + prompt_length_min=prompt_length_min, + prompt_length_max=prompt_length_max, + seed=seed, + enforce_heterogeneous=enforce_heterogeneous, + enforce_sizes=enforce_sizes, + truncate=truncation, + pad_multiple=pad_multiple, + ) + + if return_key: + return sample_request, sample_key + else: + return sample_request + def sample_sharegpt_requests( dataset_path: str, @@ -518,6 +538,7 @@ def sample_sharegpt_requests( enforce_sizes: List[int] | None = None, truncation: bool = False, pad_multiple: int = 64, + return_key: bool = False, ) -> List[Tuple[str, int]]: if not os.path.exists(dataset_path): print("downloading share-gpt dataset as it does not exist") @@ -543,7 +564,7 @@ def sample_sharegpt_requests( dataset = [data for data in dataset if len(data["conversations"]) >= 2] dataset: List[str] = [data["conversations"][0]["value"] for data in dataset] - return __sample_requests( + sample_request = __sample_requests( dataset, num_requests, tokenizer, @@ -557,6 +578,24 @@ def sample_sharegpt_requests( _cached_dataset_key=dataset_path, ) + sample_key: str = format_kwargs_to_string( + dataset="sharegpt", + num_requests=num_requests, + tokenizer=tokenizer.name_or_path.replace("/", "--"), + prompt_length_min=prompt_length_min, + prompt_length_max=prompt_length_max, + seed=seed, + enforce_heterogeneous=enforce_heterogeneous, + enforce_sizes=enforce_sizes, + truncate=truncation, + pad_multiple=pad_multiple, + ) + + if return_key: + return sample_request, sample_key + else: + return sample_request + def sample_squad_v2_qa_requests( dataset_path: str, diff --git 
a/scripts/drive_paged_programs.py b/scripts/drive_paged_programs.py index ea51bad8..469e3ae0 100644 --- a/scripts/drive_paged_programs.py +++ b/scripts/drive_paged_programs.py @@ -245,7 +245,7 @@ def __custom_line_sampler(*args, **kwargs): def __prepare_inputs(batch_size, seq_length, tokenizer, enforce_sizes=[], seed=0): start = time.time() - prompts_and_sizes = sampler( + prompts_and_sizes, sample_key = sampler( DATASET_PATH, batch_size, tokenizer, @@ -254,6 +254,7 @@ def __prepare_inputs(batch_size, seq_length, tokenizer, enforce_sizes=[], seed=0 seed, enforce_sizes=enforce_sizes, truncation=allow_truncation, + return_key=True, ) end = time.time() if local_rank == 0: @@ -274,7 +275,7 @@ def __prepare_inputs(batch_size, seq_length, tokenizer, enforce_sizes=[], seed=0 input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=seq_length) extra_kwargs["mask"] = extra_kwargs["mask"].to(torch.float16) - return input_ids, extra_kwargs + return input_ids, extra_kwargs, sample_key def __maybe_prepare_fp8_weights(model_in, is_fp8): @@ -367,13 +368,14 @@ def __load_validation_info( # warmup with any input so compiler produces criteria json # TODO: Swap this with __prepare_inputs once fix for shape_id is available -# input_ids, extra_kwargs = __prepare_inputs(2, max_tkv, tokenizer) +# input_ids, extra_kwargs, sample_key = __prepare_inputs(2, max_tkv, tokenizer) prompt_list = [torch.arange(0, 64, dtype=torch.int64)] # matching vllm warmup to pad to 2 on fp8, and no pad for fp16 if is_fp8: prompt_list = prompt_list * 2 input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=64) extra_kwargs["mask"] = extra_kwargs["mask"].to(torch.float16) + extra_kwargs["attn_name"] = ATTN_NAME if ( "granite-3.3-8b-instruct" in model_variant @@ -657,6 +659,7 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor): 0, ATTN_NAME, dtype=CPU_DTYPE, + sample_key=sample_key, ) ) diff --git a/tests/testing/test_validation.py b/tests/testing/test_validation.py index ac3367ae..b02bd19c 100644 --- a/tests/testing/test_validation.py +++ b/tests/testing/test_validation.py @@ -8,7 +8,13 @@ get_validation_info_path, find_validation_info_path, __decrement_version, + get_default_validation_prefix, ) +import os +from aiu_fms_testing_utils.testing.utils import format_kwargs_to_string +from aiu_fms_testing_utils.utils import sample_sharegpt_requests +from transformers import AutoTokenizer + from aiu_fms_testing_utils._version import version_tuple from fms.models import get_model from fms.utils.generation import pad_input_ids @@ -238,3 +244,82 @@ def test_decrement_version(max_minor, max_patch, current_version): + patch + 1 ) +def test_format_kwargs_to_string(): + kwargs = { + "enforce_sizes": [1, 32, 4, 8], + "batch_size": 1, + "model_id": "granite-3.3-8b", + "seq_len": 64, + } + kwargs_str = format_kwargs_to_string(**kwargs) + assert ( + kwargs_str + == "batch-size-1_enforce-sizes-1,32,4,8_model-id-granite-3.3-8b_seq-len-64" + ) + + +DATASET_PATH = os.getenv( + "DATASET_PATH", "/mnt/home/models/ShareGPT_V3_unfiltered_cleaned_split.json" +) +TOKENIZER = os.getenv("TOKENIZER", "ibm-granite/granite-3.3-8b-Instruct") + + +@pytest.mark.parametrize( + "model_variant,max_new_tokens,batch_size,seq_length,dtype,attn_type,device_type,seed,aftu_version", + [("granite-3.3-8b", 64, 2, 64, "fp16", "spda", "cpu", 0, (1, 2, 3))], +) +def test_get_default_validation_prefix( + model_variant, + max_new_tokens, + batch_size, + seq_length, + dtype, + attn_type, + device_type, + seed, + aftu_version, +): + tokenizer = 
AutoTokenizer.from_pretrained(TOKENIZER) + + sample_key = None + # get_default_validation_prefix with sample_key set to None + prefix_sample_key_none = f"{get_default_validation_prefix(model_variant, max_new_tokens, batch_size, seq_length, dtype, attn_type, '.'.join([str(_) for _ in aftu_version[:3]]), sample_key=sample_key)}.{device_type}_validation_info.{seed}.out" + + assert ( + prefix_sample_key_none + == f"{model_variant}_max-new-tokens-{max_new_tokens}_batch-size-{batch_size}_seq-length-{seq_length}_dtype-{dtype}_attn-type-{attn_type}.1.2.3.cpu_validation_info.0.out" + ) + + # get_default_validation_prefix with no kwargs using legacy case + legacy_prefix = f"{get_default_validation_prefix(model_variant, max_new_tokens, batch_size, seq_length, dtype, attn_type, '.'.join([str(_) for _ in aftu_version[:3]]))}.{device_type}_validation_info.{seed}.out" + assert prefix_sample_key_none == legacy_prefix + + # retrieve a sample_key with return_key is True + dataset_1, sample_key = sample_sharegpt_requests( + DATASET_PATH, + batch_size, + tokenizer, + 32, + seq_length * 2, + seed=seed, + enforce_sizes=[], + return_key=True, + ) + prefix_with_sample_key = f"{get_default_validation_prefix(model_variant, max_new_tokens, batch_size, seq_length, dtype, attn_type, '.'.join([str(_) for _ in aftu_version[:3]]), sample_key=sample_key)}.{device_type}_validation_info.{seed}.out" + + # Check sample key sorted by parameter name + assert sample_key.split("_") == sorted(sample_key.split("_")) + # Check sample key included in name as expected + assert "sample-key-" + sample_key in prefix_with_sample_key + + dataset_2 = sample_sharegpt_requests( + DATASET_PATH, + batch_size, + tokenizer, + 32, + seq_length * 2, + seed=seed, + enforce_sizes=[], + ) + + assert dataset_1 == dataset_2 From f3613ba9543abd269df25cdb6baaec2cb7469be0 Mon Sep 17 00:00:00 2001 From: kcirred <16872435+kcirred@users.noreply.github.com> Date: Mon, 29 Sep 2025 15:23:44 +0000 Subject: [PATCH 02/20] [utils] added doc string, refactor sample_key, added return_key to squad_v2 sampler Signed-off-by: kcirred <16872435+kcirred@users.noreply.github.com> --- aiu_fms_testing_utils/testing/utils.py | 3 ++ aiu_fms_testing_utils/utils/__init__.py | 71 +++++++++++++++---------- 2 files changed, 47 insertions(+), 27 deletions(-) diff --git a/aiu_fms_testing_utils/testing/utils.py b/aiu_fms_testing_utils/testing/utils.py index 79ae564f..72fd30b2 100644 --- a/aiu_fms_testing_utils/testing/utils.py +++ b/aiu_fms_testing_utils/testing/utils.py @@ -2,6 +2,9 @@ def format_kwargs_to_string(**kwargs): + """ + Turns kwargs into a str with variable names using `-`, variables separated by `_` and iterable separated by `,` + """ formatted_pairs = [] for key, value in sorted(kwargs.items()): formatted_value = None diff --git a/aiu_fms_testing_utils/utils/__init__.py b/aiu_fms_testing_utils/utils/__init__.py index 796567ee..6615c5c9 100644 --- a/aiu_fms_testing_utils/utils/__init__.py +++ b/aiu_fms_testing_utils/utils/__init__.py @@ -508,20 +508,20 @@ def sample_rag_factoid_requests( _cached_dataset_key=dataset_path, ) - sample_key: str = format_kwargs_to_string( - dataset="rag_factoid", - num_requests=num_requests, - tokenizer=tokenizer.name_or_path.replace("/", "--"), - prompt_length_min=prompt_length_min, - prompt_length_max=prompt_length_max, - seed=seed, - enforce_heterogeneous=enforce_heterogeneous, - enforce_sizes=enforce_sizes, - truncate=truncation, - pad_multiple=pad_multiple, - ) - if return_key: + sample_key: str = format_kwargs_to_string( + 
dataset="rag_factoid", + num_requests=num_requests, + tokenizer=tokenizer.name_or_path.replace("/", "--"), + prompt_length_min=prompt_length_min, + prompt_length_max=prompt_length_max, + seed=seed, + enforce_heterogeneous=enforce_heterogeneous, + enforce_sizes=enforce_sizes, + truncate=truncation, + pad_multiple=pad_multiple, + ) + return sample_request, sample_key else: return sample_request @@ -578,20 +578,19 @@ def sample_sharegpt_requests( _cached_dataset_key=dataset_path, ) - sample_key: str = format_kwargs_to_string( - dataset="sharegpt", - num_requests=num_requests, - tokenizer=tokenizer.name_or_path.replace("/", "--"), - prompt_length_min=prompt_length_min, - prompt_length_max=prompt_length_max, - seed=seed, - enforce_heterogeneous=enforce_heterogeneous, - enforce_sizes=enforce_sizes, - truncate=truncation, - pad_multiple=pad_multiple, - ) - if return_key: + sample_key: str = format_kwargs_to_string( + dataset="sharegpt", + num_requests=num_requests, + tokenizer=tokenizer.name_or_path.replace("/", "--"), + prompt_length_min=prompt_length_min, + prompt_length_max=prompt_length_max, + seed=seed, + enforce_heterogeneous=enforce_heterogeneous, + enforce_sizes=enforce_sizes, + truncate=truncation, + pad_multiple=pad_multiple, + ) return sample_request, sample_key else: return sample_request @@ -608,6 +607,7 @@ def sample_squad_v2_qa_requests( enforce_sizes: List[int] | None = None, truncation: bool = False, pad_multiple: int = 64, + return_key: bool = False, ) -> List[Tuple[str, int]]: from datasets import load_dataset @@ -621,7 +621,7 @@ def sample_squad_v2_qa_requests( ds = [f"{data['context']}\n{data['question']}" for data in ds] - return __sample_requests( + sample_request = __sample_requests( ds, num_requests, tokenizer, @@ -634,6 +634,23 @@ def sample_squad_v2_qa_requests( pad_multiple, ) + if return_key: + sample_key: str = format_kwargs_to_string( + dataset="squad_v2", + num_requests=num_requests, + tokenizer=tokenizer.name_or_path.replace("/", "--"), + prompt_length_min=prompt_length_min, + prompt_length_max=prompt_length_max, + seed=seed, + enforce_heterogeneous=enforce_heterogeneous, + enforce_sizes=enforce_sizes, + truncate=truncation, + pad_multiple=pad_multiple, + ) + return sample_request, sample_key + else: + return sample_request + def prepare_inputs( batch_size, seq_length, tokenizer, ds_path, seed=0, ds_type="sharegpt" From 2f69fa3c900d2de5df5798ab32856a8f49dd3b9b Mon Sep 17 00:00:00 2001 From: kcirred <16872435+kcirred@users.noreply.github.com> Date: Tue, 30 Sep 2025 16:12:19 +0000 Subject: [PATCH 03/20] [dpp/validation] restore sample_key in logic after rebase of main Signed-off-by: kcirred <16872435+kcirred@users.noreply.github.com> --- aiu_fms_testing_utils/testing/validation.py | 2 ++ scripts/drive_paged_programs.py | 12 +++++++++--- tests/models/test_decoders.py | 8 +++++++- tests/testing/test_validation.py | 16 ++++++++++++++++ 4 files changed, 34 insertions(+), 4 deletions(-) diff --git a/aiu_fms_testing_utils/testing/validation.py b/aiu_fms_testing_utils/testing/validation.py index 5749cb89..f62c7def 100644 --- a/aiu_fms_testing_utils/testing/validation.py +++ b/aiu_fms_testing_utils/testing/validation.py @@ -468,6 +468,7 @@ def find_validation_info_path( Find the validation info path if it exists, otherwise return None """ enforce_sizes = kwargs.get("enforce_sizes", None) + sample_key = kwargs.get("sample_key", None) if aftu_version is None: loc_version_tuple = version_tuple[:3] @@ -489,6 +490,7 @@ def find_validation_info_path( device_type, dtype, 
enforce_sizes=enforce_sizes, + sample_key=sample_key, ) # if the path is found, we are done searching and can return if os.path.exists(full_path): diff --git a/scripts/drive_paged_programs.py b/scripts/drive_paged_programs.py index 469e3ae0..e3cb7ca6 100644 --- a/scripts/drive_paged_programs.py +++ b/scripts/drive_paged_programs.py @@ -297,7 +297,9 @@ def __load_validation_info( tokenizer, seed, attn_type: str, + **kwargs, ): + sample_key = kwargs.get("sample_key", None) full_path = find_validation_info_path( args.validation_info_outputs_dir, model_variant, @@ -308,6 +310,7 @@ def __load_validation_info( attn_type, version_allow_decrement=True, dtype=CPU_DTYPE, + sample_key=sample_key, ) if full_path is not None: dprint(f"cpu validation info found for seed={seed} -- loading it") @@ -496,7 +499,7 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]: for valid_prompt_shape in valid_prompt_shapes: if valid_prompt_shape == custom_shape: enforce_sizes = [valid_prompt_shape[1]] - input_ids, extra_kwargs = __prepare_inputs( + input_ids, extra_kwargs, sample_key = __prepare_inputs( valid_prompt_shape[0], valid_prompt_shape[1], tokenizer, @@ -508,6 +511,7 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]: custom_shape, input_ids, extra_kwargs, + sample_key, ) ] break @@ -568,7 +572,7 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]: ) ) try: - input_ids, extra_kwargs = __prepare_inputs( + input_ids, extra_kwargs, sample_key = __prepare_inputs( valid_prompt_shape[0], valid_prompt_shape[1], tokenizer, @@ -580,6 +584,7 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]: valid_prompt_shape, input_ids, extra_kwargs, + sample_key, ) ) used_keys.add(program_seq_key[0]) @@ -611,7 +616,7 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor): failed_cases = [] # for each program and valid prompt (batch size, sequence length) -for program_id, valid_prompt, input_ids, extra_kwargs in valid_prompts: +for program_id, valid_prompt, input_ids, extra_kwargs, sample_key in valid_prompts: extra_kwargs["attn_name"] = ATTN_NAME if ( "granite-3.3-8b-instruct" in model_variant @@ -636,6 +641,7 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor): tokenizer, seed=0, attn_type=ATTN_NAME, + sample_key=sample_key, ) # if the cpu validation info is not yet computed, compute it if cpu_validation_info is None: diff --git a/tests/models/test_decoders.py b/tests/models/test_decoders.py index 122c9664..d257c44c 100644 --- a/tests/models/test_decoders.py +++ b/tests/models/test_decoders.py @@ -364,7 +364,13 @@ def __filter_before_eos(metrics, filter_indexes): def __load_validation_info( - model_path, batch_size, seq_length, max_new_tokens, tokenizer, seed, attn_type: str + model_path, + batch_size, + seq_length, + max_new_tokens, + tokenizer, + seed, + attn_type: str, ): # if path doesn't exist and paged isn't in the attention name, remove `attn_type` and recheck again, warn that we will no longer in the future have paths without 'attn_type' full_path = find_validation_info_path( diff --git a/tests/testing/test_validation.py b/tests/testing/test_validation.py index b02bd19c..19c11b07 100644 --- a/tests/testing/test_validation.py +++ b/tests/testing/test_validation.py @@ -99,6 +99,20 @@ def test_get_validation_info_path(tmp_path): == f"{tmp_path}/ibm-granite--granite-3.3-8b-instruct_max-new-tokens-128_batch-size-4_seq-length-64_dtype-fp16_attn-type-sdpa.1.2.3.cpu_validation_info.0.out" ) + # Check that it is accepting kwargs and handling sample_key + dummy_sample_key = 
"dataset-sharegpt_num-requests-4_pad-multiple-64_prompt-length-max-128_prompt-length-min-32_tokenizer-ibm-granite--granite-3.3-8b-Instruct" + assert "sample_key" and "dataset" in get_validation_info_path( + tmp_path, + "ibm-granite/granite-3.3-8b-instruct", + 4, + 64, + 128, + 0, + "sdpa", + aftu_version=(1, 2, 3), + sample_key=dummy_sample_key, + ) + @pytest.mark.parametrize( "current_version,save_version,expected_version,version_allow_decrement", @@ -244,6 +258,8 @@ def test_decrement_version(max_minor, max_patch, current_version): + patch + 1 ) + + def test_format_kwargs_to_string(): kwargs = { "enforce_sizes": [1, 32, 4, 8], From 34587feed76f9c9ceac75cad41946c2b193e8604 Mon Sep 17 00:00:00 2001 From: kcirred <16872435+kcirred@users.noreply.github.com> Date: Fri, 3 Oct 2025 17:56:49 +0000 Subject: [PATCH 04/20] [validation] Modified final file string to hash due to OSError name too long Signed-off-by: kcirred <16872435+kcirred@users.noreply.github.com> --- aiu_fms_testing_utils/testing/validation.py | 11 ++++-- tests/testing/test_validation.py | 38 +++++++++------------ 2 files changed, 24 insertions(+), 25 deletions(-) diff --git a/aiu_fms_testing_utils/testing/validation.py b/aiu_fms_testing_utils/testing/validation.py index f62c7def..072c9c0a 100644 --- a/aiu_fms_testing_utils/testing/validation.py +++ b/aiu_fms_testing_utils/testing/validation.py @@ -7,6 +7,8 @@ import os from aiu_fms_testing_utils.testing.utils import format_kwargs_to_string +import hashlib + class LogitsExtractorHook( Callable[ @@ -146,14 +148,17 @@ def get_default_validation_prefix( aftu_version (str): introduced in v0.3.0 to track changed in log Returns: - str: A prefix that will be prepended to the file name + str: A hashed prefix that will be prepended to the file name """ kwargs_str = format_kwargs_to_string(**kwargs) if kwargs_str == "": - return f"{model_id.replace('/', '--')}_max-new-tokens-{max_new_tokens}_batch-size-{batch_size}_seq-length-{seq_length}_dtype-{dtype}_attn-type-{attn_type}.{aftu_version}" + filename = f"{model_id.replace('/', '--')}_max-new-tokens-{max_new_tokens}_batch-size-{batch_size}_seq-length-{seq_length}_dtype-{dtype}_attn-type-{attn_type}" else: - return f"{model_id.replace('/', '--')}_max-new-tokens-{max_new_tokens}_batch-size-{batch_size}_seq-length-{seq_length}_dtype-{dtype}_attn-type-{attn_type}_{kwargs_str}.{aftu_version}" + filename = f"{model_id.replace('/', '--')}_max-new-tokens-{max_new_tokens}_batch-size-{batch_size}_seq-length-{seq_length}_dtype-{dtype}_attn-type-{attn_type}_{kwargs_str}" + hash_object = hashlib.sha256(filename.encode("utf-8")) + hex_digest = hash_object.hexdigest() + return f"{hex_digest}_{aftu_version}" def load_validation_information( diff --git a/tests/testing/test_validation.py b/tests/testing/test_validation.py index 19c11b07..74f6403c 100644 --- a/tests/testing/test_validation.py +++ b/tests/testing/test_validation.py @@ -10,6 +10,7 @@ __decrement_version, get_default_validation_prefix, ) +import hashlib import os from aiu_fms_testing_utils.testing.utils import format_kwargs_to_string from aiu_fms_testing_utils.utils import sample_sharegpt_requests @@ -79,12 +80,21 @@ def test_validation_info_round_trip(validation_type, post_iteration_hook): def test_get_validation_info_path(tmp_path): + check_pathname = "ibm-granite--granite-3.3-8b-instruct_max-new-tokens-128_batch-size-4_seq-length-64_dtype-fp16_attn-type-sdpa" + hash_object = hashlib.sha256(check_pathname.encode("utf-8")) + hex_digest = hash_object.hexdigest() + assert ( 
get_validation_info_path( tmp_path, "ibm-granite/granite-3.3-8b-instruct", 4, 64, 128, 0, "sdpa" ) - == f"{tmp_path}/ibm-granite--granite-3.3-8b-instruct_max-new-tokens-128_batch-size-4_seq-length-64_dtype-fp16_attn-type-sdpa.{'.'.join([str(_) for _ in version_tuple[:3]])}.cpu_validation_info.0.out" + == f"{tmp_path}/{hex_digest}_{'.'.join([str(_) for _ in version_tuple[:3]])}.cpu_validation_info.0.out" ) + + check_pathname = "ibm-granite--granite-3.3-8b-instruct_max-new-tokens-128_batch-size-4_seq-length-64_dtype-fp16_attn-type-sdpa" + hash_object = hashlib.sha256(check_pathname.encode("utf-8")) + hex_digest = hash_object.hexdigest() + assert ( get_validation_info_path( tmp_path, @@ -96,21 +106,7 @@ def test_get_validation_info_path(tmp_path): "sdpa", aftu_version=(1, 2, 3), ) - == f"{tmp_path}/ibm-granite--granite-3.3-8b-instruct_max-new-tokens-128_batch-size-4_seq-length-64_dtype-fp16_attn-type-sdpa.1.2.3.cpu_validation_info.0.out" - ) - - # Check that it is accepting kwargs and handling sample_key - dummy_sample_key = "dataset-sharegpt_num-requests-4_pad-multiple-64_prompt-length-max-128_prompt-length-min-32_tokenizer-ibm-granite--granite-3.3-8b-Instruct" - assert "sample_key" and "dataset" in get_validation_info_path( - tmp_path, - "ibm-granite/granite-3.3-8b-instruct", - 4, - 64, - 128, - 0, - "sdpa", - aftu_version=(1, 2, 3), - sample_key=dummy_sample_key, + == f"{tmp_path}/{hex_digest}_1.2.3.cpu_validation_info.0.out" ) @@ -299,12 +295,12 @@ def test_get_default_validation_prefix( sample_key = None # get_default_validation_prefix with sample_key set to None + check_prefix_sample_key_none = f"{model_variant}_max-new-tokens-{max_new_tokens}_batch-size-{batch_size}_seq-length-{seq_length}_dtype-{dtype}_attn-type-{attn_type}" + hash_object = hashlib.sha256(check_prefix_sample_key_none.encode("utf-8")) + hex_digest = hash_object.hexdigest() prefix_sample_key_none = f"{get_default_validation_prefix(model_variant, max_new_tokens, batch_size, seq_length, dtype, attn_type, '.'.join([str(_) for _ in aftu_version[:3]]), sample_key=sample_key)}.{device_type}_validation_info.{seed}.out" - assert ( - prefix_sample_key_none - == f"{model_variant}_max-new-tokens-{max_new_tokens}_batch-size-{batch_size}_seq-length-{seq_length}_dtype-{dtype}_attn-type-{attn_type}.1.2.3.cpu_validation_info.0.out" - ) + assert prefix_sample_key_none == f"{hex_digest}_1.2.3.cpu_validation_info.0.out" # get_default_validation_prefix with no kwargs using legacy case legacy_prefix = f"{get_default_validation_prefix(model_variant, max_new_tokens, batch_size, seq_length, dtype, attn_type, '.'.join([str(_) for _ in aftu_version[:3]]))}.{device_type}_validation_info.{seed}.out" @@ -325,8 +321,6 @@ def test_get_default_validation_prefix( # Check sample key sorted by parameter name assert sample_key.split("_") == sorted(sample_key.split("_")) - # Check sample key included in name as expected - assert "sample-key-" + sample_key in prefix_with_sample_key dataset_2 = sample_sharegpt_requests( DATASET_PATH, From a4a89b91c9867457fbc5effedd422df04d4d2b16 Mon Sep 17 00:00:00 2001 From: kcirred <16872435+kcirred@users.noreply.github.com> Date: Fri, 3 Oct 2025 18:11:01 +0000 Subject: [PATCH 05/20] [test_validation] remove unused line Signed-off-by: kcirred <16872435+kcirred@users.noreply.github.com> --- tests/testing/test_validation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/testing/test_validation.py b/tests/testing/test_validation.py index 74f6403c..220f89e9 100644 --- a/tests/testing/test_validation.py +++ 
b/tests/testing/test_validation.py @@ -317,7 +317,6 @@ def test_get_default_validation_prefix( enforce_sizes=[], return_key=True, ) - prefix_with_sample_key = f"{get_default_validation_prefix(model_variant, max_new_tokens, batch_size, seq_length, dtype, attn_type, '.'.join([str(_) for _ in aftu_version[:3]]), sample_key=sample_key)}.{device_type}_validation_info.{seed}.out" # Check sample key sorted by parameter name assert sample_key.split("_") == sorted(sample_key.split("_")) From 525b00667b0b1278ea74762fd78775b3bbeae64b Mon Sep 17 00:00:00 2001 From: kcirred <16872435+kcirred@users.noreply.github.com> Date: Fri, 3 Oct 2025 19:11:04 +0000 Subject: [PATCH 06/20] [validation] removed enforce_sizes from find_validation_info_path Signed-off-by: kcirred <16872435+kcirred@users.noreply.github.com> --- aiu_fms_testing_utils/testing/validation.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/aiu_fms_testing_utils/testing/validation.py b/aiu_fms_testing_utils/testing/validation.py index 072c9c0a..ad1b5906 100644 --- a/aiu_fms_testing_utils/testing/validation.py +++ b/aiu_fms_testing_utils/testing/validation.py @@ -472,7 +472,6 @@ def find_validation_info_path( """ Find the validation info path if it exists, otherwise return None """ - enforce_sizes = kwargs.get("enforce_sizes", None) sample_key = kwargs.get("sample_key", None) if aftu_version is None: @@ -494,7 +493,6 @@ def find_validation_info_path( loc_version_tuple, device_type, dtype, - enforce_sizes=enforce_sizes, sample_key=sample_key, ) # if the path is found, we are done searching and can return From b3d0f9b859aa289d6c2e2534ea802e9b1e0e13dd Mon Sep 17 00:00:00 2001 From: "Rashed Z. Bhatti, PhD" Date: Mon, 6 Oct 2025 19:17:25 +0000 Subject: [PATCH 07/20] paged head_size getattr(model.config, "head_dim", model.config.emb_dim // model.config.nheads ) Signed-off-by: Rashed Z. Bhatti, PhD --- aiu_fms_testing_utils/utils/paged.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/aiu_fms_testing_utils/utils/paged.py b/aiu_fms_testing_utils/utils/paged.py index edf0c548..1d6bcbc7 100644 --- a/aiu_fms_testing_utils/utils/paged.py +++ b/aiu_fms_testing_utils/utils/paged.py @@ -153,7 +153,9 @@ def generate( raise ValueError("model must have a distributed_strategy") kvheads = kvheads // tensor_parallel_size if kvheads > 1 else kvheads - head_size = model.config.emb_dim // nheads + head_size = getattr( + model.config, "head_dim", model.config.emb_dim // model.config.nheads + ) if "fp8" in kwargs["attn_name"]: from fms_mo.aiu_addons.fp8.fp8_utils import ScaledTensor From e56c686919ef9263f3a673d9af8c5ab892e9c709 Mon Sep 17 00:00:00 2001 From: "Rashed Z. Bhatti, PhD" Date: Mon, 6 Oct 2025 19:47:39 +0000 Subject: [PATCH 08/20] Revert "paged head_size getattr(model.config, "head_dim", model.config.emb_dim // model.config.nheads )" This reverts commit b3d0f9b859aa289d6c2e2534ea802e9b1e0e13dd. 
--- aiu_fms_testing_utils/utils/paged.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/aiu_fms_testing_utils/utils/paged.py b/aiu_fms_testing_utils/utils/paged.py index 1d6bcbc7..edf0c548 100644 --- a/aiu_fms_testing_utils/utils/paged.py +++ b/aiu_fms_testing_utils/utils/paged.py @@ -153,9 +153,7 @@ def generate( raise ValueError("model must have a distributed_strategy") kvheads = kvheads // tensor_parallel_size if kvheads > 1 else kvheads - head_size = getattr( - model.config, "head_dim", model.config.emb_dim // model.config.nheads - ) + head_size = model.config.emb_dim // nheads if "fp8" in kwargs["attn_name"]: from fms_mo.aiu_addons.fp8.fp8_utils import ScaledTensor From 1eac3664db32364c263ae39b50940fabf09a4182 Mon Sep 17 00:00:00 2001 From: kcirred <16872435+kcirred@users.noreply.github.com> Date: Tue, 7 Oct 2025 01:18:06 +0000 Subject: [PATCH 09/20] [dpp] added handling of return_key for __custom_line_sampler Signed-off-by: kcirred <16872435+kcirred@users.noreply.github.com> --- scripts/drive_paged_programs.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/scripts/drive_paged_programs.py b/scripts/drive_paged_programs.py index e3cb7ca6..033a8efe 100644 --- a/scripts/drive_paged_programs.py +++ b/scripts/drive_paged_programs.py @@ -40,6 +40,7 @@ get_programs_prompts, KVCACHE_NUM_BLOCKS_HINT, ) +from aiu_fms_testing_utils.testing.utils import format_kwargs_to_string parser = argparse.ArgumentParser( description="Script which will drive paged programs for debugging" @@ -195,6 +196,10 @@ custom_shape = (len(result), max([_[1] for _ in result])) def __custom_line_sampler(*args, **kwargs): + return_key = kwargs.get("return_key", False) + sample_key = format_kwargs_to_string(**kwargs) + if return_key: + return result, sample_key return result sampler = __custom_line_sampler From 2b3028f5fd7a4075189770e39f1bae6918960729 Mon Sep 17 00:00:00 2001 From: Joshua Rosenkranz Date: Tue, 7 Oct 2025 14:45:11 +0000 Subject: [PATCH 10/20] updated llama model expectation tests using v1.0.0 aiu software stack as modeling code changed Signed-off-by: Joshua Rosenkranz --- ...TestAIUDecoderModels.Llama-3.1-8B-Instruct.test_model_output | 2 +- ...IUDecoderModels.Llama-3.1-8B-Instruct.test_model_weight_keys | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/resources/expectations/models.test_model_expectations.TestAIUDecoderModels.Llama-3.1-8B-Instruct.test_model_output b/tests/resources/expectations/models.test_model_expectations.TestAIUDecoderModels.Llama-3.1-8B-Instruct.test_model_output index 098709e3..bfbcd6b1 100644 --- a/tests/resources/expectations/models.test_model_expectations.TestAIUDecoderModels.Llama-3.1-8B-Instruct.test_model_output +++ b/tests/resources/expectations/models.test_model_expectations.TestAIUDecoderModels.Llama-3.1-8B-Instruct.test_model_output @@ -1 +1 @@ -9.625,9.625,9.6875,9.625,10.53125,37.375,8.65625,14.90625,1.03125,5.875,15.6875,6.0625,9.5,17.5625,37.0,10.34375,6.25,13.125,3.8125,9.21875,21.96875,14.28125,0.0,13.09375,7.6875,6.4375,19.09375,10.6875,23.9375,13.0,11.84375,46.4375,6.59375,0.0,13.0,23.125,16.34375,3.125,12.65625,6.03125,14.375,6.84375,14.9375,20.9375,5.625,37.0,4.875,3.25,7.40625,2.6875,18.9375,4.1875,13.5,8.4375,21.1875,13.21875,35.25,21.78125,8.3125,4.75,12.0625,3.90625,9.34375,4.25 \ No newline at end of file 
+0.18359375,0.18359375,0.181640625,0.189453125,0.2734375,0.544921875,0.607421875,0.365234375,0.30078125,0.25,0.078125,0.302734375,0.0,0.322265625,0.142578125,0.099609375,0.296875,0.28125,0.673828125,0.44921875,0.13671875,0.42578125,1.072265625,0.18359375,0.388671875,0.177734375,0.193359375,0.296875,0.484375,0.3515625,0.826171875,0.349609375,0.296875,0.720703125,0.634765625,0.607421875,0.14453125,0.29296875,0.154296875,0.287109375,0.482421875,0.2421875,0.48046875,0.203125,0.349609375,0.21484375,0.28515625,0.17578125,0.162109375,0.3203125,0.3125,0.54296875,0.287109375,0.361328125,0.390625,0.08984375,0.2109375,0.5,0.18359375,0.228515625,0.314453125,0.291015625,0.248046875,0.5078125 \ No newline at end of file diff --git a/tests/resources/expectations/models.test_model_expectations.TestAIUDecoderModels.Llama-3.1-8B-Instruct.test_model_weight_keys b/tests/resources/expectations/models.test_model_expectations.TestAIUDecoderModels.Llama-3.1-8B-Instruct.test_model_weight_keys index 6329cb98..3fcc470f 100644 --- a/tests/resources/expectations/models.test_model_expectations.TestAIUDecoderModels.Llama-3.1-8B-Instruct.test_model_weight_keys +++ b/tests/resources/expectations/models.test_model_expectations.TestAIUDecoderModels.Llama-3.1-8B-Instruct.test_model_weight_keys @@ -1 +1 @@ -dec_norm.weight,layers.0.attn.dense.weight,layers.0.attn.in_proj.key.weight,layers.0.attn.in_proj.query.weight,layers.0.attn.in_proj.value.weight,layers.0.ff_ln.weight,layers.0.ff_sub_layer.w1.weight,layers.0.ff_sub_layer.w2.weight,layers.0.ff_sub_layer.wg.weight,layers.0.ln.weight,layers.1.attn.dense.weight,layers.1.attn.in_proj.key.weight,layers.1.attn.in_proj.query.weight,layers.1.attn.in_proj.value.weight,layers.1.ff_ln.weight,layers.1.ff_sub_layer.w1.weight,layers.1.ff_sub_layer.w2.weight,layers.1.ff_sub_layer.wg.weight,layers.1.ln.weight,layers.2.attn.dense.weight,layers.2.attn.in_proj.key.weight,layers.2.attn.in_proj.query.weight,layers.2.attn.in_proj.value.weight,layers.2.ff_ln.weight,layers.2.ff_sub_layer.w1.weight,layers.2.ff_sub_layer.w2.weight,layers.2.ff_sub_layer.wg.weight,layers.2.ln.weight,shared.emb.weight,shared.head.weight \ No newline at end of file +base_model.dec_norm.weight,base_model.embedding.weight,base_model.layers.0.attn.dense.weight,base_model.layers.0.attn.in_proj.key.weight,base_model.layers.0.attn.in_proj.query.weight,base_model.layers.0.attn.in_proj.value.weight,base_model.layers.0.ff_ln.weight,base_model.layers.0.ff_sub_layer.w1.weight,base_model.layers.0.ff_sub_layer.w2.weight,base_model.layers.0.ff_sub_layer.wg.weight,base_model.layers.0.ln.weight,base_model.layers.1.attn.dense.weight,base_model.layers.1.attn.in_proj.key.weight,base_model.layers.1.attn.in_proj.query.weight,base_model.layers.1.attn.in_proj.value.weight,base_model.layers.1.ff_ln.weight,base_model.layers.1.ff_sub_layer.w1.weight,base_model.layers.1.ff_sub_layer.w2.weight,base_model.layers.1.ff_sub_layer.wg.weight,base_model.layers.1.ln.weight,base_model.layers.2.attn.dense.weight,base_model.layers.2.attn.in_proj.key.weight,base_model.layers.2.attn.in_proj.query.weight,base_model.layers.2.attn.in_proj.value.weight,base_model.layers.2.ff_ln.weight,base_model.layers.2.ff_sub_layer.w1.weight,base_model.layers.2.ff_sub_layer.w2.weight,base_model.layers.2.ff_sub_layer.wg.weight,base_model.layers.2.ln.weight,head.weight \ No newline at end of file From fcf950f7d6a82a75845ee8b836964f0561a28dee Mon Sep 17 00:00:00 2001 From: kcirred <16872435+kcirred@users.noreply.github.com> Date: Tue, 30 Sep 2025 15:07:14 +0000 Subject: [PATCH 
11/20] [testing] changed get_default_validation_prefix to generic kwargs, file names now sorted, testing of file names modified for new order Signed-off-by: kcirred <16872435+kcirred@users.noreply.github.com> --- aiu_fms_testing_utils/testing/utils.py | 4 +++- aiu_fms_testing_utils/testing/validation.py | 17 +++++------------ scripts/generate_layers_metrics.py | 6 +++++- scripts/generate_metrics.py | 10 +++++----- tests/testing/test_validation.py | 10 +++++----- 5 files changed, 23 insertions(+), 24 deletions(-) diff --git a/aiu_fms_testing_utils/testing/utils.py b/aiu_fms_testing_utils/testing/utils.py index 72fd30b2..cacff899 100644 --- a/aiu_fms_testing_utils/testing/utils.py +++ b/aiu_fms_testing_utils/testing/utils.py @@ -17,6 +17,8 @@ def format_kwargs_to_string(**kwargs): # only append if formatted_value exists if formatted_value: # Keep previous convention of variable names with `-` instead of `_` - formatted_pairs.append(f"{key.replace('_', '-')}-{formatted_value}") + formatted_pairs.append( + f"{key.replace('_', '-')}-{formatted_value.replace('/', '--')}" + ) return "_".join(formatted_pairs) diff --git a/aiu_fms_testing_utils/testing/validation.py b/aiu_fms_testing_utils/testing/validation.py index ad1b5906..5bf120a0 100644 --- a/aiu_fms_testing_utils/testing/validation.py +++ b/aiu_fms_testing_utils/testing/validation.py @@ -128,13 +128,6 @@ def __len__(self): def get_default_validation_prefix( - model_id: str, - max_new_tokens: int, - batch_size: int, - seq_length: int, - dtype: str, - attn_type: str, - aftu_version: str, **kwargs, ): """ @@ -150,12 +143,12 @@ def get_default_validation_prefix( Returns: str: A hashed prefix that will be prepended to the file name """ + aftu_version = kwargs.pop( + "aftu_version", ".".join([str(_) for _ in version_tuple[:3]]) + ) kwargs_str = format_kwargs_to_string(**kwargs) - if kwargs_str == "": - filename = f"{model_id.replace('/', '--')}_max-new-tokens-{max_new_tokens}_batch-size-{batch_size}_seq-length-{seq_length}_dtype-{dtype}_attn-type-{attn_type}" - else: - filename = f"{model_id.replace('/', '--')}_max-new-tokens-{max_new_tokens}_batch-size-{batch_size}_seq-length-{seq_length}_dtype-{dtype}_attn-type-{attn_type}_{kwargs_str}" + filename = f"{kwargs_str}" hash_object = hashlib.sha256(filename.encode("utf-8")) hex_digest = hash_object.hexdigest() return f"{hex_digest}_{aftu_version}" @@ -435,7 +428,7 @@ def get_validation_info_path( sample_key = kwargs.get("sample_key", None) - validation_file_name = f"{get_default_validation_prefix(model_variant, max_new_tokens, batch_size, seq_length, dtype, attn_type, '.'.join([str(_) for _ in aftu_version[:3]]), sample_key=sample_key)}.{device_type}_validation_info.{seed}.out" + validation_file_name = f"{get_default_validation_prefix(aftu_version='.'.join([str(_) for _ in aftu_version[:3]]), model_id=model_variant, max_new_tokens=max_new_tokens, batch_size=batch_size, seq_length=seq_length, dtype=dtype, attn_type=attn_type, sample_key=sample_key)}.{device_type}_validation_info.{seed}.out" full_path = os.path.join(validation_info_dir, validation_file_name) return full_path diff --git a/scripts/generate_layers_metrics.py b/scripts/generate_layers_metrics.py index d3245123..ffc01930 100644 --- a/scripts/generate_layers_metrics.py +++ b/scripts/generate_layers_metrics.py @@ -473,7 +473,11 @@ def generate_layers_metrics(model_path, batch_size, seq_length, max_new_tokens): cos_sim = tensor_cos_sim(tensor_cpu_out, cuda_output) prefix = get_default_validation_prefix( - model_path, max_new_token, batch_size, 
seq_length, "float16" + model_id=model_path, + max_new_tokens=max_new_token, + batch_size=batch_size, + seq_length=seq_length, + dtype="float16", ) layer_name = str(layer_key).replace("[", "").replace("]", "") diff --git a/scripts/generate_metrics.py b/scripts/generate_metrics.py index 8ec3f028..f65149fa 100644 --- a/scripts/generate_metrics.py +++ b/scripts/generate_metrics.py @@ -134,11 +134,11 @@ # this follows the same pattern of naming in test_shapes. This way we can save and re-use for quicker shape testing. prefix = get_default_validation_prefix( - args.variant, - args.max_new_tokens, - args.batch_size, - args.min_pad_length, - args.default_dtype, + model_id=args.variant, + max_new_tokens=args.max_new_tokens, + batch_size=args.batch_size, + seq_len=args.min_pad_length, + dtype=args.default_dtype, ) if os.path.exists(os.path.join(args.output_dir, f"{prefix}.prob_mean.csv")): print("skipping metric generation as it has already been done") diff --git a/tests/testing/test_validation.py b/tests/testing/test_validation.py index 220f89e9..95f2ff4e 100644 --- a/tests/testing/test_validation.py +++ b/tests/testing/test_validation.py @@ -80,7 +80,7 @@ def test_validation_info_round_trip(validation_type, post_iteration_hook): def test_get_validation_info_path(tmp_path): - check_pathname = "ibm-granite--granite-3.3-8b-instruct_max-new-tokens-128_batch-size-4_seq-length-64_dtype-fp16_attn-type-sdpa" + check_pathname = "attn-type-sdpa_batch-size-4_dtype-fp16_max-new-tokens-128_model-id-ibm-granite--granite-3.3-8b-instruct_seq-length-64" hash_object = hashlib.sha256(check_pathname.encode("utf-8")) hex_digest = hash_object.hexdigest() @@ -91,7 +91,7 @@ def test_get_validation_info_path(tmp_path): == f"{tmp_path}/{hex_digest}_{'.'.join([str(_) for _ in version_tuple[:3]])}.cpu_validation_info.0.out" ) - check_pathname = "ibm-granite--granite-3.3-8b-instruct_max-new-tokens-128_batch-size-4_seq-length-64_dtype-fp16_attn-type-sdpa" + check_pathname = "attn-type-sdpa_batch-size-4_dtype-fp16_max-new-tokens-128_model-id-ibm-granite--granite-3.3-8b-instruct_seq-length-64" hash_object = hashlib.sha256(check_pathname.encode("utf-8")) hex_digest = hash_object.hexdigest() @@ -295,15 +295,15 @@ def test_get_default_validation_prefix( sample_key = None # get_default_validation_prefix with sample_key set to None - check_prefix_sample_key_none = f"{model_variant}_max-new-tokens-{max_new_tokens}_batch-size-{batch_size}_seq-length-{seq_length}_dtype-{dtype}_attn-type-{attn_type}" + check_prefix_sample_key_none = f"attn-type-{attn_type}_batch-size-{batch_size}_dtype-{dtype}_max-new-tokens-{max_new_tokens}_model-id-{model_variant}_seq-length-{seq_length}" hash_object = hashlib.sha256(check_prefix_sample_key_none.encode("utf-8")) hex_digest = hash_object.hexdigest() - prefix_sample_key_none = f"{get_default_validation_prefix(model_variant, max_new_tokens, batch_size, seq_length, dtype, attn_type, '.'.join([str(_) for _ in aftu_version[:3]]), sample_key=sample_key)}.{device_type}_validation_info.{seed}.out" + prefix_sample_key_none = f"{get_default_validation_prefix(model_id=model_variant, max_new_tokens=max_new_tokens, batch_size=batch_size, seq_length=seq_length, dtype=dtype, attn_type=attn_type, aftu_version='.'.join([str(_) for _ in aftu_version[:3]]), sample_key=sample_key)}.{device_type}_validation_info.{seed}.out" assert prefix_sample_key_none == f"{hex_digest}_1.2.3.cpu_validation_info.0.out" # get_default_validation_prefix with no kwargs using legacy case - legacy_prefix = 
f"{get_default_validation_prefix(model_variant, max_new_tokens, batch_size, seq_length, dtype, attn_type, '.'.join([str(_) for _ in aftu_version[:3]]))}.{device_type}_validation_info.{seed}.out" + legacy_prefix = f"{get_default_validation_prefix(model_id=model_variant, max_new_tokens=max_new_tokens, batch_size=batch_size, seq_length=seq_length, dtype=dtype, attn_type=attn_type, aftu_version='.'.join([str(_) for _ in aftu_version[:3]]))}.{device_type}_validation_info.{seed}.out" assert prefix_sample_key_none == legacy_prefix # retrieve a sample_key with return_key is True From adc276e7b7b7387a08ce8c10a949ce2e352e5486 Mon Sep 17 00:00:00 2001 From: Joshua Rosenkranz Date: Wed, 8 Oct 2025 03:28:29 +0000 Subject: [PATCH 12/20] fixed test_scripts program assertion Signed-off-by: Joshua Rosenkranz --- tests/models/test_scripts.py | 45 ++++++++++++++++++++++++++---------- 1 file changed, 33 insertions(+), 12 deletions(-) diff --git a/tests/models/test_scripts.py b/tests/models/test_scripts.py index 79bd9952..21b47fb4 100644 --- a/tests/models/test_scripts.py +++ b/tests/models/test_scripts.py @@ -6,6 +6,7 @@ from pathlib import Path import itertools import math +from aiu_fms_testing_utils.utils.paged import get_programs_prompts, ProgramCriteria FMS_DIR = Path(__file__).parent AIU_FMS_DIR = os.path.join(FMS_DIR, "../../../aiu-fms-testing-utils/") @@ -291,28 +292,48 @@ def test_dpp_script( ) print(result_text) with open(os.environ["DT_PROG_CRITERIA_FILEPATH"], "r") as f: - program_criteria_list = json.load(f)["programs"] + program_criteria_json_list = json.load(f)["programs"] + program_criteria_list = [] + for i, d in enumerate(program_criteria_json_list): + program_criteria_list.append( + ProgramCriteria( + i, + d["max_batch"], + d["max_tkv"], + d["batch_granularity"], + d["tkv_granularity"], + ) + ) if programs is None: program_assertions = [i for i in range(len(program_criteria_list))] shape_assertions = [">=0", ">=0"] else: + program_map = get_programs_prompts( + program_criteria_list, + multiple=64, + max_batch_size=2, + max_tkv=512, + program_cycles=max_new_tokens, + ) programs_split = programs.split(":") program_ids_str = programs_split[0] shape_assertions = [ f">={_}" if _.isnumeric() else _ for _ in programs_split[1].split(",") ] - match_number = r"\d+" - valid_program_assertions = [ - f">={re.search(match_number, _).group()}" for _ in shape_assertions - ] - # need to add 1 for tkv as that is the first decode - program_assertions = [ - i - for i, p in enumerate(program_criteria_list) - if eval(f"p['max_batch']{valid_program_assertions[0]}") - and eval(f"p['max_tkv']{valid_program_assertions[1]}+1") - ] + + program_assertions = [] + for program_id_seq, shapes in program_map.items(): + if any( + ( + eval( + f"shape[0]{shape_assertions[0]} and shape[1]{shape_assertions[1]}" + ) + for shape in shapes + ) + ): + program_assertions.append(program_id_seq[0].program_id) + if program_ids_str == "?": program_assertions = program_assertions[:1] elif program_ids_str.isnumeric(): From 18bbf015538f8211920ec43f4237914e1ff2144f Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Mon, 13 Oct 2025 08:25:47 +0000 Subject: [PATCH 13/20] Add cache test Signed-off-by: Alex-Brooks --- tests/models/test_decoders.py | 149 +++++++++++++++++++++++++++++++++- 1 file changed, 148 insertions(+), 1 deletion(-) diff --git a/tests/models/test_decoders.py b/tests/models/test_decoders.py index 4f95e61e..5aa64aa5 100644 --- a/tests/models/test_decoders.py +++ b/tests/models/test_decoders.py @@ -7,6 +7,8 @@ import torch from 
torch import distributed as dist from torch.fx.experimental import _config as fx_config +from torch_sendnn.backends.sendnn_backend import _get_global_state +from torch_sendnn.utils.graph_cache import SpyreGraphCache from aiu_fms_testing_utils.testing.validation import ( extract_validation_information, @@ -29,6 +31,7 @@ from transformers import AutoTokenizer from aiu_fms_testing_utils.utils.aiu_setup import dprint, aiu_dist_setup +import shutil import os try: @@ -132,7 +135,7 @@ if USE_MICRO_MODELS: VALIDATION_INFO_DIR = os.path.join(VALIDATION_INFO_DIR, "tiny_models") -# pass custom model path list for eg: EXPORT FMS_TEST_SHAPES_COMMON_MODEL_PATHS="/tmp/models/granite-3-8b-base,/tmp/models/granite-7b-base" +# pass custom model path list for eg: EXPORT FMS_TESTING_COMMON_MODEL_PATHS="/tmp/models/granite-3-8b-base,/tmp/models/granite-7b-base" if isinstance(COMMON_MODEL_PATHS, str): COMMON_MODEL_PATHS = COMMON_MODEL_PATHS.split(",") @@ -593,6 +596,8 @@ def _get_device_validation_information( token_iter, ATTN_NAME, ) + if cpu_validation_info is not None: + return cpu_validation_info if cpu_validation_info is not None: return cpu_validation_info @@ -830,6 +835,7 @@ def _run_cpu_aiu_validation_test( aiu_model, micro_model_path, record_property, + verify_cache_state=None, ): # Get the tokenizer and AIU / CPU models to compare tokenizer = AutoTokenizer.from_pretrained(model_path) @@ -866,6 +872,12 @@ def _run_cpu_aiu_validation_test( aiu_model, ) + # Used only for cache tests; this is a nonparametric closure that + # should assert the cache for torch sendnn is in the correct state + # for this test + if verify_cache_state is not None: + verify_cache_state() + # if level 0 fails validation, validate level 1 if FORCE_VALIDATION_LEVEL_1 or failed_validation_level_0: if failed_validation_level_0: @@ -888,6 +900,88 @@ def _run_cpu_aiu_validation_test( ) +def _get_cache_test_params(): + # NOTE - currently we always use granite 3.3 for the cache test, + # TODO make this configurable as tests are refactored + model_path = GRANITE_3p3_8B_INSTRUCT + batch_size = COMMON_BATCH_SIZES[0] + seq_length = COMMON_SEQ_LENGTHS[0] + max_new_tokens = COMMON_MAX_NEW_TOKENS[0] + return [model_path, batch_size, seq_length, max_new_tokens] + + +def _reset_cache_settings(purge_cache_dir): + os.environ["TORCH_SENDNN_CACHE_ENABLE"] = "1" + os.environ["COMPILATION_MODE"] = "offline_decoder" + cache_dir = os.environ["TORCH_SENDNN_CACHE_DIR"] + + # Ensure we start in clean state + if purge_cache_dir and os.path.isdir(cache_dir): + shutil.rmtree(cache_dir) + os.mkdir(cache_dir) + + _get_global_state().use_aiu_cache = True + _get_global_state().spyre_graph_cache = SpyreGraphCache() + + +@pytest.fixture +def use_cached_model(persistent_model, record_property): + """Configures the torchsendnn cache and runs the AIU model prior to test execution; + this is computationally expensive and should only be used in situations like testing + cache hit correctness; + """ + torch.manual_seed(42) + torch.set_grad_enabled(False) + _reset_cache_settings(purge_cache_dir=True) + + model_path, batch_size, seq_length, max_new_tokens = _get_cache_test_params() + micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None) + + def verify_cache_miss(): + cache_dir = os.environ.get("TORCH_SENDNN_CACHE_DIR") + updated_cache_len = ( + len(os.listdir(cache_dir)) if os.path.isdir(cache_dir) else 0 + ) + assert updated_cache_len == max_new_tokens, ( + "cache directory not populated on cache miss" + ) + + dprint( + f"Setting up cache [i.e., cache miss check] 
for model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}" + ) + + # we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured + gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path) + is_gptq = len(gptq_kwargs_aiu) != 0 + is_fp8 = "fp8" in ATTN_NAME + model_kwargs = _get_common_model_kwargs(is_gptq, model_path) + + # Get the AIU model w/ the persistent model fixture + model = persistent_model.get_or_create( + is_gptq, is_fp8, **gptq_kwargs_aiu, **model_kwargs + ) + + validation_model = _get_cpu_model( + is_gptq, + is_fp8, + micro_model_state_dict=model.state_dict() if USE_MICRO_MODELS else None, + **gptq_kwargs_cpu, + **model_kwargs, + ) + + _run_cpu_aiu_validation_test( + model_path, + batch_size, + seq_length, + max_new_tokens, + validation_model, + model, + micro_model_path, + record_property, + verify_cache_state=verify_cache_miss, + ) + + @pytest.mark.parametrize( "model_path,batch_size,seq_length,max_new_tokens", COMMON_SHAPES ) @@ -937,3 +1031,56 @@ def test_common_shapes( micro_model_path, record_property, ) + + +def test_cache(use_cached_model, persistent_model, record_property): + torch.manual_seed(42) + torch.set_grad_enabled(False) + _reset_cache_settings(purge_cache_dir=False) + + model_path, batch_size, seq_length, max_new_tokens = _get_cache_test_params() + micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None) + + def verify_cache_hit(): + cache_dir = os.environ.get("TORCH_SENDNN_CACHE_DIR") + updated_cache_len = ( + len(os.listdir(cache_dir)) if os.path.isdir(cache_dir) else 0 + ) + assert updated_cache_len == max_new_tokens, ( + "cache miss occurred when hit was expected" + ) + + dprint( + f"testing: model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}, for cache hit" + ) + + # we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured + gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path) + is_gptq = len(gptq_kwargs_aiu) != 0 + is_fp8 = "fp8" in ATTN_NAME + model_kwargs = _get_common_model_kwargs(is_gptq, model_path) + + # Get the AIU model w/ the persistent model fixture + model = persistent_model.get_or_create( + is_gptq, is_fp8, **gptq_kwargs_aiu, **model_kwargs + ) + + validation_model = _get_cpu_model( + is_gptq, + is_fp8, + micro_model_state_dict=model.state_dict() if USE_MICRO_MODELS else None, + **gptq_kwargs_cpu, + **model_kwargs, + ) + + _run_cpu_aiu_validation_test( + model_path, + batch_size, + seq_length, + max_new_tokens, + validation_model, + model, + micro_model_path, + record_property, + verify_cache_state=verify_cache_hit, + ) From 42305bb7534736bbbf64021bc54ca453b3a6c425 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Mon, 13 Oct 2025 08:41:34 +0000 Subject: [PATCH 14/20] use tmp_path fixture for cache test Signed-off-by: Alex-Brooks --- tests/models/test_decoders.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/tests/models/test_decoders.py b/tests/models/test_decoders.py index 5aa64aa5..575e5317 100644 --- a/tests/models/test_decoders.py +++ b/tests/models/test_decoders.py @@ -910,29 +910,35 @@ def _get_cache_test_params(): return [model_path, batch_size, seq_length, max_new_tokens] -def _reset_cache_settings(purge_cache_dir): +def _reset_cache_settings(purge_cache_dir, cache_dir=None): 
os.environ["TORCH_SENDNN_CACHE_ENABLE"] = "1" os.environ["COMPILATION_MODE"] = "offline_decoder" - cache_dir = os.environ["TORCH_SENDNN_CACHE_DIR"] + if cache_dir is not None: + # Might be a posixpath + cache_dir = str(cache_dir) + os.environ["TORCH_SENDNN_CACHE_DIR"] = cache_dir # Ensure we start in clean state if purge_cache_dir and os.path.isdir(cache_dir): shutil.rmtree(cache_dir) os.mkdir(cache_dir) + # NOTE: currently, the cache dir is pulled from + # TORCH_SENDNN_CACHE_DIR at initialization time, + # so this should correctly use the cache_dir _get_global_state().use_aiu_cache = True _get_global_state().spyre_graph_cache = SpyreGraphCache() @pytest.fixture -def use_cached_model(persistent_model, record_property): +def use_cached_model(persistent_model, record_property, tmp_path): """Configures the torchsendnn cache and runs the AIU model prior to test execution; this is computationally expensive and should only be used in situations like testing cache hit correctness; """ torch.manual_seed(42) torch.set_grad_enabled(False) - _reset_cache_settings(purge_cache_dir=True) + _reset_cache_settings(purge_cache_dir=True, cache_dir=tmp_path) model_path, batch_size, seq_length, max_new_tokens = _get_cache_test_params() micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None) @@ -1033,10 +1039,10 @@ def test_common_shapes( ) -def test_cache(use_cached_model, persistent_model, record_property): +def test_cache(use_cached_model, persistent_model, record_property, tmp_path): torch.manual_seed(42) torch.set_grad_enabled(False) - _reset_cache_settings(purge_cache_dir=False) + _reset_cache_settings(purge_cache_dir=False, cache_dir=tmp_path) model_path, batch_size, seq_length, max_new_tokens = _get_cache_test_params() micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None) From fe8c61f5d865a925b768091ebe874add45a645ac Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Mon, 13 Oct 2025 08:48:35 +0000 Subject: [PATCH 15/20] fix cache_dir in cache checks Signed-off-by: Alex-Brooks --- tests/models/test_decoders.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/test_decoders.py b/tests/models/test_decoders.py index 575e5317..1a5b5f31 100644 --- a/tests/models/test_decoders.py +++ b/tests/models/test_decoders.py @@ -944,7 +944,7 @@ def use_cached_model(persistent_model, record_property, tmp_path): micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None) def verify_cache_miss(): - cache_dir = os.environ.get("TORCH_SENDNN_CACHE_DIR") + cache_dir = str(tmp_path) updated_cache_len = ( len(os.listdir(cache_dir)) if os.path.isdir(cache_dir) else 0 ) @@ -1048,7 +1048,7 @@ def test_cache(use_cached_model, persistent_model, record_property, tmp_path): micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None) def verify_cache_hit(): - cache_dir = os.environ.get("TORCH_SENDNN_CACHE_DIR") + cache_dir = str(tmp_path) updated_cache_len = ( len(os.listdir(cache_dir)) if os.path.isdir(cache_dir) else 0 ) From 630fe39abc500e72dfcb9b1afb8065e802cb0ffa Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Mon, 13 Oct 2025 08:59:06 +0000 Subject: [PATCH 16/20] only warmup on cache tests Signed-off-by: Alex-Brooks --- tests/models/test_decoders.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/tests/models/test_decoders.py b/tests/models/test_decoders.py index 1a5b5f31..5075b888 100644 --- a/tests/models/test_decoders.py +++ b/tests/models/test_decoders.py @@ -836,6 +836,7 @@ def _run_cpu_aiu_validation_test( micro_model_path, record_property, 
verify_cache_state=None, + warmup_only=False, ): # Get the tokenizer and AIU / CPU models to compare tokenizer = AutoTokenizer.from_pretrained(model_path) @@ -859,6 +860,16 @@ def _run_cpu_aiu_validation_test( aiu_model, input_ids, max_new_tokens, COMPILE_DYNAMIC_SENDNN, **extra_kwargs ) + # Used only for cache tests; this is a nonparametric closure that + # should assert the cache for torch sendnn is in the correct state + # for this test + if verify_cache_state is not None: + verify_cache_state() + + # For some tests, e.g., cache checks, we only need to run the warmup + if warmup_only: + return + # Run validation level 0 failed_validation_level_0, validation_zero_info = _run_validation_level_0( model_path, @@ -872,12 +883,6 @@ def _run_cpu_aiu_validation_test( aiu_model, ) - # Used only for cache tests; this is a nonparametric closure that - # should assert the cache for torch sendnn is in the correct state - # for this test - if verify_cache_state is not None: - verify_cache_state() - # if level 0 fails validation, validate level 1 if FORCE_VALIDATION_LEVEL_1 or failed_validation_level_0: if failed_validation_level_0: From fd1c20f581b9b0255b573884f31c37be8c1001ea Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Mon, 13 Oct 2025 09:11:15 +0000 Subject: [PATCH 17/20] parametrize use_cache Signed-off-by: Alex-Brooks --- tests/models/test_decoders.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tests/models/test_decoders.py b/tests/models/test_decoders.py index 5075b888..27178aa3 100644 --- a/tests/models/test_decoders.py +++ b/tests/models/test_decoders.py @@ -936,7 +936,7 @@ def _reset_cache_settings(purge_cache_dir, cache_dir=None): @pytest.fixture -def use_cached_model(persistent_model, record_property, tmp_path): +def use_cached_model(request, persistent_model, record_property, tmp_path): """Configures the torchsendnn cache and runs the AIU model prior to test execution; this is computationally expensive and should only be used in situations like testing cache hit correctness; @@ -990,7 +990,9 @@ def verify_cache_miss(): micro_model_path, record_property, verify_cache_state=verify_cache_miss, + warmup_only=True, ) + return request.param @pytest.mark.parametrize( @@ -1044,12 +1046,19 @@ def test_common_shapes( ) +@pytest.mark.parametrize( + "use_cached_model", + COMMON_SHAPES, + indirect=True, +) def test_cache(use_cached_model, persistent_model, record_property, tmp_path): torch.manual_seed(42) torch.set_grad_enabled(False) _reset_cache_settings(purge_cache_dir=False, cache_dir=tmp_path) - model_path, batch_size, seq_length, max_new_tokens = _get_cache_test_params() + # use_cached_model is an indirectly parametrized fixture, and the returned + # value is an expanded tuple from COMMON_SHAPES, so we unpack it here + model_path, batch_size, seq_length, max_new_tokens = use_cached_model micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None) def verify_cache_hit(): @@ -1094,4 +1103,5 @@ def verify_cache_hit(): micro_model_path, record_property, verify_cache_state=verify_cache_hit, + warmup_only=True, ) From 15665834a038a5633fafdf134e4c0f0b8aa3efd8 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Mon, 13 Oct 2025 09:33:24 +0000 Subject: [PATCH 18/20] use request param for setting up use_cache Signed-off-by: Alex-Brooks --- tests/models/test_decoders.py | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/tests/models/test_decoders.py b/tests/models/test_decoders.py index 27178aa3..ad90dcdb 100644 --- 
a/tests/models/test_decoders.py +++ b/tests/models/test_decoders.py @@ -905,16 +905,6 @@ def _run_cpu_aiu_validation_test( ) -def _get_cache_test_params(): - # NOTE - currently we always use granite 3.3 for the cache test, - # TODO make this configurable as tests are refactored - model_path = GRANITE_3p3_8B_INSTRUCT - batch_size = COMMON_BATCH_SIZES[0] - seq_length = COMMON_SEQ_LENGTHS[0] - max_new_tokens = COMMON_MAX_NEW_TOKENS[0] - return [model_path, batch_size, seq_length, max_new_tokens] - - def _reset_cache_settings(purge_cache_dir, cache_dir=None): os.environ["TORCH_SENDNN_CACHE_ENABLE"] = "1" os.environ["COMPILATION_MODE"] = "offline_decoder" @@ -937,15 +927,15 @@ def _reset_cache_settings(purge_cache_dir, cache_dir=None): @pytest.fixture def use_cached_model(request, persistent_model, record_property, tmp_path): - """Configures the torchsendnn cache and runs the AIU model prior to test execution; - this is computationally expensive and should only be used in situations like testing - cache hit correctness; + """Configures the torchsendnn cache and runs the AIU model (warmup) + prior to test execution; this is computationally expensive and should + only be used in situations like testing cache hit correctness. """ torch.manual_seed(42) torch.set_grad_enabled(False) _reset_cache_settings(purge_cache_dir=True, cache_dir=tmp_path) - model_path, batch_size, seq_length, max_new_tokens = _get_cache_test_params() + model_path, batch_size, seq_length, max_new_tokens = request.param micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None) def verify_cache_miss(): From c33072e696bc15f8804b2c1581da849fa85bbce6 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Sun, 26 Oct 2025 11:57:24 +0000 Subject: [PATCH 19/20] reuse aiu/cpu models from cache miss fixture Signed-off-by: Alex-Brooks --- tests/models/test_decoders.py | 34 +++++++++++----------------------- 1 file changed, 11 insertions(+), 23 deletions(-) diff --git a/tests/models/test_decoders.py b/tests/models/test_decoders.py index ad90dcdb..a336836d 100644 --- a/tests/models/test_decoders.py +++ b/tests/models/test_decoders.py @@ -969,6 +969,8 @@ def verify_cache_miss(): **gptq_kwargs_cpu, **model_kwargs, ) + # We also return the models so that we can reuse them in the cache hit check + models = (model, validation_model) _run_cpu_aiu_validation_test( model_path, @@ -982,7 +984,7 @@ def verify_cache_miss(): verify_cache_state=verify_cache_miss, warmup_only=True, ) - return request.param + return request.param, models @pytest.mark.parametrize( @@ -1041,14 +1043,19 @@ def test_common_shapes( COMMON_SHAPES, indirect=True, ) -def test_cache(use_cached_model, persistent_model, record_property, tmp_path): +def test_cache(use_cached_model, record_property, tmp_path): torch.manual_seed(42) torch.set_grad_enabled(False) _reset_cache_settings(purge_cache_dir=False, cache_dir=tmp_path) # use_cached_model is an indirectly parametrized fixture, and the returned - # value is an expanded tuple from COMMON_SHAPES, so we unpack it here - model_path, batch_size, seq_length, max_new_tokens = use_cached_model + # value is an expanded tuple from COMMON_SHAPES, so we unpack it here. + # In addition, we also pass the model created on AIU in the fixture to + # avoid recreating it. 
+ test_params, models = use_cached_model + model, validation_model = models + model_path, batch_size, seq_length, max_new_tokens = test_params + micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None) def verify_cache_hit(): @@ -1064,25 +1071,6 @@ def verify_cache_hit(): f"testing: model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}, for cache hit" ) - # we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured - gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path) - is_gptq = len(gptq_kwargs_aiu) != 0 - is_fp8 = "fp8" in ATTN_NAME - model_kwargs = _get_common_model_kwargs(is_gptq, model_path) - - # Get the AIU model w/ the persistent model fixture - model = persistent_model.get_or_create( - is_gptq, is_fp8, **gptq_kwargs_aiu, **model_kwargs - ) - - validation_model = _get_cpu_model( - is_gptq, - is_fp8, - micro_model_state_dict=model.state_dict() if USE_MICRO_MODELS else None, - **gptq_kwargs_cpu, - **model_kwargs, - ) - _run_cpu_aiu_validation_test( model_path, batch_size, From b8181ac08f3ebe3a60beef0466bf68faf5e4571b Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Sun, 2 Nov 2025 04:09:25 +0000 Subject: [PATCH 20/20] remove duplicate code Signed-off-by: Alex-Brooks --- tests/models/test_decoders.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/models/test_decoders.py b/tests/models/test_decoders.py index a336836d..dd20cfa7 100644 --- a/tests/models/test_decoders.py +++ b/tests/models/test_decoders.py @@ -599,9 +599,6 @@ def _get_device_validation_information( if cpu_validation_info is not None: return cpu_validation_info - if cpu_validation_info is not None: - return cpu_validation_info - # overrides for validation info that are device specific device_dependent_kwargs = {} if device == "cpu":
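
For context on the indirect-parametrization pattern adopted in PATCH 17-19 above: pytest passes each entry of the parametrize list to the fixture as request.param, and the fixture can hand both those parameters and any expensively prepared objects back to the test, which is why test_cache unpacks use_cached_model into (test_params, models) after the cache-miss warmup has already run inside the fixture. The sketch below is only a minimal, self-contained illustration of that flow under assumed placeholder names (SHAPES, warmed_up_model, test_uses_warmed_up_model, and the fake "model" string); it is not code from this repository.

    # Minimal illustration of indirect parametrization plus fixture reuse.
    # The "model" here is a plain string stand-in for the real AIU/CPU models.
    import pytest

    SHAPES = [("model-a", 1, 64, 8), ("model-b", 2, 128, 8)]


    @pytest.fixture
    def warmed_up_model(request):
        # request.param is one tuple from SHAPES because the test below
        # parametrizes this fixture with indirect=True.
        model_path, batch_size, seq_length, max_new_tokens = request.param
        # Expensive one-time setup (e.g. warmup / cache priming) would happen
        # here; we fake it with a string so the example runs anywhere.
        model = f"warmed-up::{model_path}"
        # Return both the parameters and the prepared object so the test does
        # not have to recreate them.
        return request.param, model


    @pytest.mark.parametrize("warmed_up_model", SHAPES, indirect=True)
    def test_uses_warmed_up_model(warmed_up_model):
        (model_path, batch_size, seq_length, max_new_tokens), model = warmed_up_model
        assert model.endswith(model_path)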