
Commit dff0aa2

Merge pull request #119 from foundation-model-stack/log_version_modified
Include version in validation info outputs
2 parents aca31a9 + b5a5c31

3 files changed (+279 -42 lines)


aiu_fms_testing_utils/testing/validation.py

Lines changed: 101 additions & 2 deletions
```diff
@@ -3,6 +3,7 @@
 
 import torch
 from aiu_fms_testing_utils.utils.aiu_setup import dprint
+from aiu_fms_testing_utils._version import version_tuple
 import os
 
 
@@ -130,8 +131,22 @@ def get_default_validation_prefix(
     seq_length: int,
     dtype: str,
     attn_type: str,
+    aftu_version: str,
 ):
-    return f"{model_id.replace('/', '--')}_max-new-tokens-{max_new_tokens}_batch-size-{batch_size}_seq-length-{seq_length}_dtype-{dtype}_attn-type-{attn_type}"
+    """
+    Args:
+        model_id (str): model name used
+        max_new_tokens (int): maximum number of new tokens to generate
+        batch_size (int): batch size used
+        seq_length (int): sequence length used
+        dtype (str): data type
+        attn_type (str): type of attention
+        aftu_version (str): introduced in v0.3.0 to track changes in logs
+
+    Returns:
+        str: A prefix that will be prepended to the file name
+    """
+    return f"{model_id.replace('/', '--')}_max-new-tokens-{max_new_tokens}_batch-size-{batch_size}_seq-length-{seq_length}_dtype-{dtype}_attn-type-{attn_type}.{aftu_version}"
 
 
 def load_validation_information(
```
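
As a quick illustration of the new naming scheme, the sketch below calls the updated helper with hypothetical values (the model id and version string are examples, not values from this commit):

```python
from aiu_fms_testing_utils.testing.validation import get_default_validation_prefix

# Hypothetical inputs, for illustration only
prefix = get_default_validation_prefix(
    model_id="ibm-granite/granite-3.1-8b-instruct",
    max_new_tokens=128,
    batch_size=1,
    seq_length=64,
    dtype="fp16",
    attn_type="sdpa",
    aftu_version="0.3.0",
)
print(prefix)
# ibm-granite--granite-3.1-8b-instruct_max-new-tokens-128_batch-size-1_seq-length-64_dtype-fp16_attn-type-sdpa.0.3.0
```

The only behavioral change is the trailing `.{aftu_version}` component, so validation files written by different releases of aiu-fms-testing-utils no longer collide.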
```diff
@@ -246,7 +261,7 @@ def extract_validation_information(
     **extra_kwargs,
 ):
     attention_specific_kwargs = {}
-    if "paged" in extra_kwargs["attn_name"]:
+    if "paged" in extra_kwargs.get("attn_name", "sdpa"):
         from aiu_fms_testing_utils.utils.paged import generate
     else:
         # TODO: Add a unified generation dependent on attn_type
```
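
The switch to `.get("attn_name", "sdpa")` means callers that never set `attn_name` now fall through to the default generation path instead of raising. A minimal sketch of the difference, assuming an empty kwargs dict:

```python
extra_kwargs = {}

# Before this commit: extra_kwargs["attn_name"] raises KeyError
# After: a missing attn_name defaults to "sdpa", so the paged path is skipped
use_paged_generate = "paged" in extra_kwargs.get("attn_name", "sdpa")
assert use_paged_generate is False
```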
```diff
@@ -388,3 +403,87 @@ def print_failed_cases(failed_cases, aiu_tokens, validation_tokens, tokenizer):
         print(
             f"In sentence {sentence_index + 1}/{len(aiu_tokens)}, token {token_index}, AIU outputs {aiu_token} instead of {validation_token} -- AIU val={aiu_str} -- CPU val={validation_str}"
         )
+
+
+def get_validation_info_path(
+    validation_info_dir: str,
+    model_variant: str,
+    batch_size: int,
+    seq_length: int,
+    max_new_tokens: int,
+    seed: int,
+    attn_type: str,
+    aftu_version: Optional[Tuple[int, int, int]] = None,
+    device_type: str = "cpu",
+    dtype: str = "fp16",
+):
+    if aftu_version is None:
+        aftu_version = version_tuple
+
+    validation_file_name = f"{get_default_validation_prefix(model_variant, max_new_tokens, batch_size, seq_length, dtype, attn_type, '.'.join([str(_) for _ in aftu_version[:3]]))}.{device_type}_validation_info.{seed}.out"
+    full_path = os.path.join(validation_info_dir, validation_file_name)
+    return full_path
+
+
+def __decrement_version(version: Tuple[int, int, int]):
+    """
+    Step a (major, minor, patch) version down by one component, avoiding a triple-nested loop
+    """
+    major, minor, patch = version
+    if patch > 0:
+        return (major, minor, patch - 1)
+    elif minor > 0:
+        return (major, minor - 1, 0)
+    elif major > 0:
+        return (major - 1, 0, 0)
+    else:
+        return None
+
+
+def find_validation_info_path(
+    validation_info_dir: str,
+    model_variant: str,
+    batch_size: int,
+    seq_length: int,
+    max_new_tokens: int,
+    seed: int,
+    attn_type: str,
+    aftu_version: Optional[Tuple[int, int, int]] = None,
+    version_allow_decrement: bool = False,
+    device_type: str = "cpu",
+    dtype: str = "fp16",
+):
+    """
+    Find the validation info path if it exists, otherwise return None
+    """
+
+    if aftu_version is None:
+        loc_version_tuple = version_tuple[:3]
+    else:
+        loc_version_tuple = aftu_version
+
+    result_path: Optional[str] = None
+
+    while result_path is None and loc_version_tuple is not None:
+        full_path = get_validation_info_path(
+            validation_info_dir,
+            model_variant,
+            batch_size,
+            seq_length,
+            max_new_tokens,
+            seed,
+            attn_type,
+            loc_version_tuple,
+            device_type,
+            dtype,
+        )
+        # if the path is found, we are done searching and can return
+        if os.path.exists(full_path):
+            result_path = full_path
+        # if version decrements are allowed, decrement the version and continue
+        elif version_allow_decrement:
+            loc_version_tuple = __decrement_version(loc_version_tuple)
+        # if the path is not found and decrementing is not allowed, finish with no result
+        else:
+            loc_version_tuple = None
+    return result_path
```
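
To make the fallback order concrete, the sketch below replays the search sequence that `find_validation_info_path` walks when `version_allow_decrement=True`; the starting version is a hypothetical example, and `decrement_version` mirrors the private helper above:

```python
from typing import Optional, Tuple

def decrement_version(version: Tuple[int, int, int]) -> Optional[Tuple[int, int, int]]:
    # Same stepping rule as __decrement_version: patch first, then minor, then major
    major, minor, patch = version
    if patch > 0:
        return (major, minor, patch - 1)
    elif minor > 0:
        return (major, minor - 1, 0)
    elif major > 0:
        return (major - 1, 0, 0)
    return None

version: Optional[Tuple[int, int, int]] = (0, 3, 1)  # hypothetical current version
while version is not None:
    print(version)  # each tuple yields one candidate file path to probe
    version = decrement_version(version)
# (0, 3, 1) -> (0, 3, 0) -> (0, 2, 0) -> (0, 1, 0) -> (0, 0, 0), then the search stops
```

Note that stepping down the minor or major component resets the lower components to zero, so the search probes `x.y.0`-style versions rather than every patch release in between.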

tests/models/test_decoders.py

Lines changed: 32 additions & 40 deletions
```diff
@@ -4,8 +4,6 @@
 from fms.models import get_model
 from fms.utils.generation import pad_input_ids
 import itertools
-import warnings
-import re
 import torch
 from torch import distributed as dist
 from aiu_fms_testing_utils.testing.validation import (
@@ -14,10 +12,11 @@
     GoldenTokenHook,
     capture_level_1_metrics,
     filter_failed_level_1_cases,
-    get_default_validation_prefix,
+    get_validation_info_path,
     load_validation_information,
     validate_level_0,
     top_k_loss_calculator,
+    find_validation_info_path,
 )
 from aiu_fms_testing_utils.utils import (
     warmup_model,
```
```diff
@@ -80,6 +79,8 @@
 }
 ATTN_NAME = attention_map[ATTN_TYPE]
 
+CPU_DTYPE = "fp8" if "fp8" in ATTN_TYPE else "fp32"
+
 FORCE_VALIDATION_LEVEL_1 = (
     os.environ.get("FMS_TEST_SHAPES_FORCE_VALIDATION_LEVEL_1", "0") == "1"
 )
```
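
`CPU_DTYPE` keys the CPU reference dtype off the attention type so that fp8 attention variants are compared against fp8 baselines. A tiny sketch of the selection, with hypothetical attention-type strings:

```python
for attn_type in ("sdpa", "paged", "paged_fp8"):  # hypothetical names, for illustration
    print(attn_type, "->", "fp8" if "fp8" in attn_type else "fp32")
# sdpa -> fp32, paged -> fp32, paged_fp8 -> fp8
```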
```diff
@@ -356,45 +357,26 @@ def __filter_before_eos(metrics, filter_indexes):
     return [item for sublist in filtered_results for item in sublist]
 
 
-def __get_validation_info_full_path(
-    model_path,
-    batch_size,
-    seq_length,
-    max_new_tokens,
-    seed,
-    attn_type: str,
-    device_type="cpu",
-):
-    validation_file_name = f"{get_default_validation_prefix(model_path, max_new_tokens, batch_size, seq_length, 'fp16', attn_type)}.{device_type}_validation_info.{seed}.out"
-    full_path = os.path.join(validation_info_dir, validation_file_name)
-    return full_path
-
-
 def __load_validation_info(
     model_path, batch_size, seq_length, max_new_tokens, tokenizer, seed, attn_type: str
 ):
     # if path doesn't exist and paged isn't in the attention name, remove `attn_type` and recheck again, warn that we will no longer in the future have paths without 'attn_type'
-    full_path = __get_validation_info_full_path(
-        model_path, batch_size, seq_length, max_new_tokens, seed, attn_type
+    full_path = find_validation_info_path(
+        validation_info_dir,
+        model_path,
+        batch_size,
+        seq_length,
+        max_new_tokens,
+        seed,
+        attn_type,
+        version_allow_decrement=True,
+        dtype=CPU_DTYPE,
     )
-
-    if os.path.exists(full_path):
+    if full_path is not None:
         dprint(f"cpu validation info found for seed={seed} -- loading it")
         return load_validation_information(full_path, "logits", batch_size, tokenizer)
-    elif "paged" not in attn_type:
-        # This regex applies to a very specific file name format
-        modified_full_path = re.sub(r"_attn-type[^.]*", "", full_path)
-
-        if os.path.exists(modified_full_path):
-            warnings.warn(
-                f"All future paths should contain attn_type prefix information in path name, please modify {full_path=} to {modified_full_path=}",
-                stacklevel=2,
-            )
-            dprint(f"cpu validation info found for seed={seed} -- loading it")
-            return load_validation_information(
-                modified_full_path, "logits", batch_size, tokenizer
-            )
-    return None
+    else:
+        return None
 
 
 class PersistentModel:
@@ -568,8 +550,15 @@ def test_common_shapes(
 
     if save_validation_info_outputs:
         cpu_validation_info.save(
-            __get_validation_info_full_path(
-                model_path, batch_size, seq_length, max_new_tokens, 0, ATTN_NAME
+            get_validation_info_path(
+                validation_info_dir,
+                model_path,
+                batch_size,
+                seq_length,
+                max_new_tokens,
+                0,
+                ATTN_NAME,
+                dtype=CPU_DTYPE,
             )
         )
         cpu_static_tokens = cpu_validation_info.get_info("tokens")
@@ -654,13 +643,15 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
         )
         if save_validation_info_outputs:
             cpu_validation_info.save(
-                __get_validation_info_full_path(
+                get_validation_info_path(
+                    validation_info_dir,
                     model_path,
                     batch_size,
                     seq_length,
                     max_new_tokens,
                     i,
                     ATTN_NAME,
+                    dtype=CPU_DTYPE,
                 )
             )
             cpu_static_tokens = cpu_validation_info.get_info("tokens")
@@ -684,14 +675,15 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
         dprint(f"aiu validation info extracted for validation level 1 - iter={i}")
         if save_validation_info_outputs:
             aiu_validation_info.save(
-                __get_validation_info_full_path(
+                get_validation_info_path(
+                    validation_info_dir,
                     model_path,
                     batch_size,
                     seq_length,
                     max_new_tokens,
                     i,
                     ATTN_NAME,
-                    "aiu",
+                    device_type="aiu",
                 )
             )
```
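
Taken together, the call sites now save with `get_validation_info_path` (an exact, version-stamped name) and load with `find_validation_info_path` (a tolerant lookup). A minimal end-to-end sketch with hypothetical directory, model, and shape values:

```python
from aiu_fms_testing_utils.testing.validation import (
    find_validation_info_path,
    get_validation_info_path,
)

# Hypothetical parameters, for illustration only
params = dict(
    validation_info_dir="/tmp/validation_info",
    model_variant="ibm-granite/granite-3.1-8b-instruct",
    batch_size=1,
    seq_length=64,
    max_new_tokens=128,
    seed=0,
    attn_type="sdpa",
    dtype="fp32",
)

# Writing: an exact path stamped with the current aiu-fms-testing-utils version
save_path = get_validation_info_path(**params)

# Reading: accept a baseline written by this or any older version
load_path = find_validation_info_path(version_allow_decrement=True, **params)
if load_path is None:
    print("no saved validation info found; regenerate the CPU baseline")
```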
