Changes from all commits
25 commits
051c18f
[dpp] store enforce_sizes in log name and added generic kwargs to get…
kcirred Sep 26, 2025
f3613ba
[utils] added doc string, refactor sample_key, added return_key to sq…
kcirred Sep 29, 2025
2f69fa3
[dpp/validation] restore sample_key in logic after rebase of main
kcirred Sep 30, 2025
34587fe
[validation] Modified final file string to hash due to OSError name t…
kcirred Oct 3, 2025
a4a89b9
[test_validation] remove unused line
kcirred Oct 3, 2025
525b006
[validation] removed enforce_sizes from find_validation_info_path
kcirred Oct 3, 2025
b3d0f9b
paged head_size getattr(model.config, "head_dim", model.config.emb_di…
Oct 6, 2025
e56c686
Revert "paged head_size getattr(model.config, "head_dim", model.confi…
Oct 6, 2025
1eac366
[dpp] added handling of return_key for __custom_line_sampler
kcirred Oct 7, 2025
992e612
Merge pull request #136 from kcirred/enforce_log
JRosenkranz Oct 7, 2025
2b3028f
updated llama model expectation tests using v1.0.0 aiu software stack…
JRosenkranz Oct 7, 2025
abe35d3
Merge pull request #150 from foundation-model-stack/update_llama_mode…
ani300 Oct 7, 2025
fcf950f
[testing] changed get_default_validation_prefix to generic kwargs, fi…
kcirred Sep 30, 2025
99e6bd1
Merge pull request #148 from kcirred/prefix_rewrite
JRosenkranz Oct 8, 2025
adc276e
fixed test_scripts program assertion
JRosenkranz Oct 8, 2025
f6c9a8b
Merge pull request #151 from foundation-model-stack/fix_test_scripts_…
JRosenkranz Oct 8, 2025
281ff22
Merge pull request #93 from alex-jw-brooks/test_cache_refactor
JRosenkranz Oct 9, 2025
18bbf01
Add cache test
alex-jw-brooks Oct 13, 2025
42305bb
use tmp_path fixture for cache test
alex-jw-brooks Oct 13, 2025
fe8c61f
fix cache_dir in cache checks
alex-jw-brooks Oct 13, 2025
630fe39
only warmup on cache tests
alex-jw-brooks Oct 13, 2025
fd1c20f
parametrize use_cache
alex-jw-brooks Oct 13, 2025
1566583
use request param for setting up use_cache
alex-jw-brooks Oct 13, 2025
c33072e
reuse aiu/cpu models from cache miss fixture
alex-jw-brooks Oct 26, 2025
b8181ac
remove duplicate code
alex-jw-brooks Nov 2, 2025
24 changes: 24 additions & 0 deletions aiu_fms_testing_utils/testing/utils.py
@@ -0,0 +1,24 @@
from collections.abc import Iterable


def format_kwargs_to_string(**kwargs):
"""
Turns kwargs into a string: underscores in key names become `-`, key/value pairs are joined by `_`, iterable values are comma-joined, `/` in values becomes `--`, and falsy values are dropped.
"""
formatted_pairs = []
for key, value in sorted(kwargs.items()):
formatted_value = None
if isinstance(value, str):
formatted_value = value
elif isinstance(value, Iterable):
formatted_value = ",".join(map(str, value))
elif value:
formatted_value = str(value)
# only append if formatted_value exists
if formatted_value:
# Keep previous convention of variable names with `-` instead of `_`
formatted_pairs.append(
f"{key.replace('_', '-')}-{formatted_value.replace('/', '--')}"
)

return "_".join(formatted_pairs)
31 changes: 21 additions & 10 deletions aiu_fms_testing_utils/testing/validation.py
@@ -5,6 +5,9 @@
from aiu_fms_testing_utils.utils.aiu_setup import dprint
from aiu_fms_testing_utils._version import version_tuple
import os
from aiu_fms_testing_utils.testing.utils import format_kwargs_to_string

import hashlib


class LogitsExtractorHook(
@@ -125,13 +128,7 @@ def __len__(self):


def get_default_validation_prefix(
model_id: str,
max_new_tokens: int,
batch_size: int,
seq_length: int,
dtype: str,
attn_type: str,
aftu_version: str,
**kwargs,
):
"""
Args:
@@ -144,9 +141,17 @@
aftu_version (str): introduced in v0.3.0 to track changes in the log

Returns:
str: A prefix that will be prepended to the file name
str: A hashed prefix that will be prepended to the file name
"""
return f"{model_id.replace('/', '--')}_max-new-tokens-{max_new_tokens}_batch-size-{batch_size}_seq-length-{seq_length}_dtype-{dtype}_attn-type-{attn_type}.{aftu_version}"
aftu_version = kwargs.pop(
"aftu_version", ".".join([str(_) for _ in version_tuple[:3]])
)
kwargs_str = format_kwargs_to_string(**kwargs)

filename = kwargs_str
hash_object = hashlib.sha256(filename.encode("utf-8"))
hex_digest = hash_object.hexdigest()
return f"{hex_digest}_{aftu_version}"


def load_validation_information(
@@ -416,11 +421,14 @@ def get_validation_info_path(
aftu_version: Optional[Tuple[int, int, int]] = None,
device_type: str = "cpu",
dtype: str = "fp16",
**kwargs,
):
if aftu_version is None:
aftu_version = version_tuple

validation_file_name = f"{get_default_validation_prefix(model_variant, max_new_tokens, batch_size, seq_length, dtype, attn_type, '.'.join([str(_) for _ in aftu_version[:3]]))}.{device_type}_validation_info.{seed}.out"
sample_key = kwargs.get("sample_key", None)

validation_file_name = f"{get_default_validation_prefix(aftu_version='.'.join([str(_) for _ in aftu_version[:3]]), model_id=model_variant, max_new_tokens=max_new_tokens, batch_size=batch_size, seq_length=seq_length, dtype=dtype, attn_type=attn_type, sample_key=sample_key)}.{device_type}_validation_info.{seed}.out"
full_path = os.path.join(validation_info_dir, validation_file_name)
return full_path
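
Composed with the suffix above, the on-disk file name therefore follows a fixed-width pattern (hypothetical prefix, device, and seed):

import hashlib

prefix = f"{hashlib.sha256(b'example-kwargs').hexdigest()}_1.0.0"  # hypothetical
validation_file_name = f"{prefix}.cpu_validation_info.0.out"
# -> "<64 hex chars>_1.0.0.cpu_validation_info.0.out"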

@@ -452,10 +460,12 @@ def find_validation_info_path(
version_allow_decrement: bool = False,
device_type: str = "cpu",
dtype: str = "fp16",
**kwargs,
):
"""
Find the validation info path if it exists, otherwise return None
"""
sample_key = kwargs.get("sample_key", None)

if aftu_version is None:
loc_version_tuple = version_tuple[:3]
@@ -476,6 +486,7 @@
loc_version_tuple,
device_type,
dtype,
sample_key=sample_key,
)
# if the path is found, we are done searching and can return
if os.path.exists(full_path):
62 changes: 59 additions & 3 deletions aiu_fms_testing_utils/utils/__init__.py
@@ -11,6 +11,7 @@

from aiu_fms_testing_utils.utils.aiu_setup import dprint, rank, world_size
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from aiu_fms_testing_utils.testing.utils import format_kwargs_to_string

from fms.utils.generation import pad_input_ids
import torch
@@ -482,6 +483,7 @@ def sample_rag_factoid_requests(
enforce_sizes: List[int] = [],
truncation: bool = False,
pad_multiple: int = 64,
return_key: bool = False,
) -> List[Tuple[str, int]]:
if not os.path.exists(dataset_path):
print("error dataset does not exist")
@@ -492,7 +494,7 @@
for line in f:
dataset.append(line)

return __sample_requests(
sample_request = __sample_requests(
dataset,
num_requests,
tokenizer,
@@ -506,6 +508,24 @@
_cached_dataset_key=dataset_path,
)

if return_key:
sample_key: str = format_kwargs_to_string(
dataset="rag_factoid",
num_requests=num_requests,
tokenizer=tokenizer.name_or_path.replace("/", "--"),
prompt_length_min=prompt_length_min,
prompt_length_max=prompt_length_max,
seed=seed,
enforce_heterogeneous=enforce_heterogeneous,
enforce_sizes=enforce_sizes,
truncate=truncation,
pad_multiple=pad_multiple,
)

return sample_request, sample_key
else:
return sample_request
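
A sketch of how a caller consumes the new flag (dataset path and model are hypothetical; the prompt_length_* keyword names are assumed from the sample_key construction above):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-3.3-8b-instruct")  # hypothetical
prompts_and_sizes, sample_key = sample_rag_factoid_requests(
    "/tmp/rag_factoid.txt",  # hypothetical dataset path
    num_requests=4,
    tokenizer=tokenizer,
    prompt_length_min=64,
    prompt_length_max=2048,
    seed=0,
    return_key=True,
)
# sample_key identifies this exact sampling configuration and can be forwarded
# to find_validation_info_path(..., sample_key=sample_key)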


def sample_sharegpt_requests(
dataset_path: str,
@@ -518,6 +538,7 @@
enforce_sizes: List[int] | None = None,
truncation: bool = False,
pad_multiple: int = 64,
return_key: bool = False,
) -> List[Tuple[str, int]]:
if not os.path.exists(dataset_path):
print("downloading share-gpt dataset as it does not exist")
@@ -543,7 +564,7 @@
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
dataset: List[str] = [data["conversations"][0]["value"] for data in dataset]

return __sample_requests(
sample_request = __sample_requests(
dataset,
num_requests,
tokenizer,
@@ -557,6 +578,23 @@
_cached_dataset_key=dataset_path,
)

if return_key:
sample_key: str = format_kwargs_to_string(
dataset="sharegpt",
num_requests=num_requests,
tokenizer=tokenizer.name_or_path.replace("/", "--"),
prompt_length_min=prompt_length_min,
prompt_length_max=prompt_length_max,
seed=seed,
enforce_heterogeneous=enforce_heterogeneous,
enforce_sizes=enforce_sizes,
truncate=truncation,
pad_multiple=pad_multiple,
)
return sample_request, sample_key
else:
return sample_request


def sample_squad_v2_qa_requests(
dataset_path: str,
@@ -569,6 +607,7 @@
enforce_sizes: List[int] | None = None,
truncation: bool = False,
pad_multiple: int = 64,
return_key: bool = False,
) -> List[Tuple[str, int]]:
from datasets import load_dataset

@@ -582,7 +621,7 @@

ds = [f"{data['context']}\n{data['question']}" for data in ds]

return __sample_requests(
sample_request = __sample_requests(
ds,
num_requests,
tokenizer,
@@ -595,6 +634,23 @@
pad_multiple,
)

if return_key:
sample_key: str = format_kwargs_to_string(
dataset="squad_v2",
num_requests=num_requests,
tokenizer=tokenizer.name_or_path.replace("/", "--"),
prompt_length_min=prompt_length_min,
prompt_length_max=prompt_length_max,
seed=seed,
enforce_heterogeneous=enforce_heterogeneous,
enforce_sizes=enforce_sizes,
truncate=truncation,
pad_multiple=pad_multiple,
)
return sample_request, sample_key
else:
return sample_request


def prepare_inputs(
batch_size, seq_length, tokenizer, ds_path, seed=0, ds_type="sharegpt"
26 changes: 20 additions & 6 deletions scripts/drive_paged_programs.py
@@ -40,6 +40,7 @@
get_programs_prompts,
KVCACHE_NUM_BLOCKS_HINT,
)
from aiu_fms_testing_utils.testing.utils import format_kwargs_to_string

parser = argparse.ArgumentParser(
description="Script which will drive paged programs for debugging"
@@ -195,6 +196,10 @@
custom_shape = (len(result), max([_[1] for _ in result]))

def __custom_line_sampler(*args, **kwargs):
return_key = kwargs.get("return_key", False)
sample_key = format_kwargs_to_string(**kwargs)
if return_key:
return result, sample_key
return result

sampler = __custom_line_sampler
@@ -245,7 +250,7 @@

def __prepare_inputs(batch_size, seq_length, tokenizer, enforce_sizes=[], seed=0):
start = time.time()
prompts_and_sizes = sampler(
prompts_and_sizes, sample_key = sampler(
DATASET_PATH,
batch_size,
tokenizer,
@@ -254,6 +259,7 @@ def __prepare_inputs(batch_size, seq_length, tokenizer, enforce_sizes=[], seed=0):
seed,
enforce_sizes=enforce_sizes,
truncation=allow_truncation,
return_key=True,
)
end = time.time()
if local_rank == 0:
@@ -274,7 +280,7 @@ def __prepare_inputs(batch_size, seq_length, tokenizer, enforce_sizes=[], seed=0):

input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=seq_length)
extra_kwargs["mask"] = extra_kwargs["mask"].to(torch.float16)
return input_ids, extra_kwargs
return input_ids, extra_kwargs, sample_key


def __maybe_prepare_fp8_weights(model_in, is_fp8):
Expand All @@ -296,7 +302,9 @@ def __load_validation_info(
tokenizer,
seed,
attn_type: str,
**kwargs,
):
sample_key = kwargs.get("sample_key", None)
full_path = find_validation_info_path(
args.validation_info_outputs_dir,
model_variant,
@@ -307,6 +315,7 @@
attn_type,
version_allow_decrement=True,
dtype=CPU_DTYPE,
sample_key=sample_key,
)
if full_path is not None:
dprint(f"cpu validation info found for seed={seed} -- loading it")
@@ -367,13 +376,14 @@

# warmup with any input so compiler produces criteria json
# TODO: Swap this with __prepare_inputs once fix for shape_id is available
# input_ids, extra_kwargs = __prepare_inputs(2, max_tkv, tokenizer)
# input_ids, extra_kwargs, sample_key = __prepare_inputs(2, max_tkv, tokenizer)
prompt_list = [torch.arange(0, 64, dtype=torch.int64)]
# matching vllm warmup to pad to 2 on fp8, and no pad for fp16
if is_fp8:
prompt_list = prompt_list * 2
input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=64)
extra_kwargs["mask"] = extra_kwargs["mask"].to(torch.float16)

extra_kwargs["attn_name"] = ATTN_NAME
if (
"granite-3.3-8b-instruct" in model_variant
@@ -494,7 +504,7 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]:
for valid_prompt_shape in valid_prompt_shapes:
if valid_prompt_shape == custom_shape:
enforce_sizes = [valid_prompt_shape[1]]
input_ids, extra_kwargs = __prepare_inputs(
input_ids, extra_kwargs, sample_key = __prepare_inputs(
valid_prompt_shape[0],
valid_prompt_shape[1],
tokenizer,
@@ -506,6 +516,7 @@
custom_shape,
input_ids,
extra_kwargs,
sample_key,
)
]
break
@@ -566,7 +577,7 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]:
)
)
try:
input_ids, extra_kwargs = __prepare_inputs(
input_ids, extra_kwargs, sample_key = __prepare_inputs(
valid_prompt_shape[0],
valid_prompt_shape[1],
tokenizer,
@@ -578,6 +589,7 @@
valid_prompt_shape,
input_ids,
extra_kwargs,
sample_key,
)
)
used_keys.add(program_seq_key[0])
@@ -609,7 +621,7 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):

failed_cases = []
# for each program and valid prompt (batch size, sequence length)
for program_id, valid_prompt, input_ids, extra_kwargs in valid_prompts:
for program_id, valid_prompt, input_ids, extra_kwargs, sample_key in valid_prompts:
extra_kwargs["attn_name"] = ATTN_NAME
if (
"granite-3.3-8b-instruct" in model_variant
@@ -634,6 +646,7 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
tokenizer,
seed=0,
attn_type=ATTN_NAME,
sample_key=sample_key,
)
# if the cpu validation info is not yet computed, compute it
if cpu_validation_info is None:
@@ -657,6 +670,7 @@
0,
ATTN_NAME,
dtype=CPU_DTYPE,
sample_key=sample_key,
)
)

6 changes: 5 additions & 1 deletion scripts/generate_layers_metrics.py
@@ -473,7 +473,11 @@ def generate_layers_metrics(model_path, batch_size, seq_length, max_new_tokens):
cos_sim = tensor_cos_sim(tensor_cpu_out, cuda_output)

prefix = get_default_validation_prefix(
model_path, max_new_token, batch_size, seq_length, "float16"
model_id=model_path,
max_new_tokens=max_new_token,
batch_size=batch_size,
seq_length=seq_length,
dtype="float16",
)
layer_name = str(layer_key).replace("[", "").replace("]", "")
