44from fms .models import get_model
55from fms .utils .generation import pad_input_ids
66import itertools
7- import warnings
8- import re
97import torch
108from torch import distributed as dist
119from aiu_fms_testing_utils .testing .validation import (
1412 GoldenTokenHook ,
1513 capture_level_1_metrics ,
1614 filter_failed_level_1_cases ,
17- get_default_validation_prefix ,
15+ get_validation_info_path ,
1816 load_validation_information ,
1917 validate_level_0 ,
2018 top_k_loss_calculator ,
19+ find_validation_info_path ,
2120)
2221from aiu_fms_testing_utils .utils import (
2322 warmup_model ,
8079}
# Canonical attention implementation name resolved from the requested ATTN_TYPE.
ATTN_NAME = attention_map[ATTN_TYPE]

# Dtype tag used when naming/looking up CPU-side validation artifacts:
# fp8 attention variants use "fp8" validation info, everything else "fp32".
CPU_DTYPE = "fp8" if "fp8" in ATTN_TYPE else "fp32"

# NOTE(review): presumably forces level-1 validation even when level-0 passes
# — confirm against the validation flow below.
FORCE_VALIDATION_LEVEL_1 = (
    os.environ.get("FMS_TEST_SHAPES_FORCE_VALIDATION_LEVEL_1", "0") == "1"
)
@@ -356,45 +357,26 @@ def __filter_before_eos(metrics, filter_indexes):
356357 return [item for sublist in filtered_results for item in sublist ]
357358
358359
359- def __get_validation_info_full_path (
360- model_path ,
361- batch_size ,
362- seq_length ,
363- max_new_tokens ,
364- seed ,
365- attn_type : str ,
366- device_type = "cpu" ,
367- ):
368- validation_file_name = f"{ get_default_validation_prefix (model_path , max_new_tokens , batch_size , seq_length , 'fp16' , attn_type )} .{ device_type } _validation_info.{ seed } .out"
369- full_path = os .path .join (validation_info_dir , validation_file_name )
370- return full_path
371-
372-
def __load_validation_info(
    model_path, batch_size, seq_length, max_new_tokens, tokenizer, seed, attn_type: str
):
    """Load previously saved CPU validation info (logits) for the given shape/seed.

    Looks up the validation file in ``validation_info_dir`` via
    ``find_validation_info_path`` (called with ``version_allow_decrement=True``,
    which per its name also accepts an older-versioned file name, and with
    ``dtype=CPU_DTYPE``). Returns the loaded validation information, or ``None``
    when no matching file exists.

    Note: the previous fallback that stripped ``attn_type`` from the file name
    and re-checked the path has been removed; ``find_validation_info_path`` is
    now the single source of truth for path resolution.
    """
    full_path = find_validation_info_path(
        validation_info_dir,
        model_path,
        batch_size,
        seq_length,
        max_new_tokens,
        seed,
        attn_type,
        version_allow_decrement=True,
        dtype=CPU_DTYPE,
    )
    # Guard clause: no saved validation info for this configuration.
    if full_path is None:
        return None
    dprint(f"cpu validation info found for seed={seed} -- loading it")
    return load_validation_information(full_path, "logits", batch_size, tokenizer)
398380
399381
400382class PersistentModel :
@@ -568,8 +550,15 @@ def test_common_shapes(
568550
569551 if save_validation_info_outputs :
570552 cpu_validation_info .save (
571- __get_validation_info_full_path (
572- model_path , batch_size , seq_length , max_new_tokens , 0 , ATTN_NAME
553+ get_validation_info_path (
554+ validation_info_dir ,
555+ model_path ,
556+ batch_size ,
557+ seq_length ,
558+ max_new_tokens ,
559+ 0 ,
560+ ATTN_NAME ,
561+ dtype = CPU_DTYPE ,
573562 )
574563 )
575564 cpu_static_tokens = cpu_validation_info .get_info ("tokens" )
@@ -654,13 +643,15 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
654643 )
655644 if save_validation_info_outputs :
656645 cpu_validation_info .save (
657- __get_validation_info_full_path (
646+ get_validation_info_path (
647+ validation_info_dir ,
658648 model_path ,
659649 batch_size ,
660650 seq_length ,
661651 max_new_tokens ,
662652 i ,
663653 ATTN_NAME ,
654+ dtype = CPU_DTYPE ,
664655 )
665656 )
666657 cpu_static_tokens = cpu_validation_info .get_info ("tokens" )
@@ -684,14 +675,15 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
684675 dprint (f"aiu validation info extracted for validation level 1 - iter={ i } " )
685676 if save_validation_info_outputs :
686677 aiu_validation_info .save (
687- __get_validation_info_full_path (
678+ get_validation_info_path (
679+ validation_info_dir ,
688680 model_path ,
689681 batch_size ,
690682 seq_length ,
691683 max_new_tokens ,
692684 i ,
693685 ATTN_NAME ,
694- "aiu" ,
686+ device_type = "aiu" ,
695687 )
696688 )
697689
0 commit comments