# Test-suite configuration, overridable via FMS_TEST_SHAPES_* environment
# variables. (Diff-extraction artifacts — fused line numbers, +/- markers,
# spaces injected inside tokens and path literals — removed; this is the
# applied "new side" form of the hunk.)
GPTQ_ENABLED = False

ORIGINAL_HF_HOME = os.environ.get("HF_HOME", None)
MODELS_HOME = os.environ.get("FMS_TEST_SHAPES_MODELS_HOME", "/home/senuser/models")

# Add models to test here
LLAMA_3p1_8B_INSTRUCT = "meta-llama/Llama-3.1-8B-Instruct"
GRANITE_3p2_8B_INSTRUCT = "ibm-granite/granite-3.2-8b-instruct"
GRANITE_20B_CODE_INSTRUCT_8K = "ibm-granite/granite-20b-code-instruct-8k"
LLAMA_3p1_70B_INSTRUCT = "meta-llama/Llama-3.1-70B-Instruct"

# Trained micro-model checkpoints keyed by model id; models without an entry
# fall back to a randomly initialized "hf_configured" micro model (see the
# micro_model_path lookup in test_common_shapes).
micro_model_mapping = {
    LLAMA_3p1_8B_INSTRUCT: os.path.join(MODELS_HOME, "llama-8b-layers-3-step-24000"),
}

SHARE_GPT_DATASET_PATH = os.environ.get(
    "SHARE_GPT_DATASET_PATH", os.path.expanduser("~/share_gpt.json")
)
USE_MICRO_MODELS = os.environ.get("FMS_TEST_SHAPES_USE_MICRO_MODELS", "1") == "1"
USE_DISTRIBUTED = os.environ.get("FMS_TEST_SHAPES_DISTRIBUTED", "0") == "1"
FORCE_VALIDATION_LEVEL_1 = (
    os.environ.get("FMS_TEST_SHAPES_FORCE_VALIDATION_LEVEL_1", "0") == "1"
)
skip_assertions = os.environ.get("FMS_TEST_SHAPES_SKIP_ASSERTIONS", {})
validation_info_dir = os.environ.get(
    "FMS_TEST_SHAPES_VALIDATION_INFO_DIR", "/home/senuser/models/validation_info"
)
common_model_paths = os.environ.get(
    "FMS_TEST_SHAPES_COMMON_MODEL_PATHS",
    [
        LLAMA_3p1_8B_INSTRUCT,
        GRANITE_3p2_8B_INSTRUCT,
        GRANITE_20B_CODE_INSTRUCT_8K,
        LLAMA_3p1_70B_INSTRUCT,
    ],
)
# for validation level 1, the default is a failure rate of 1%
# set this environment variable if you would like to relax that threshold
failure_rate_threshold = os.environ.get("FMS_TEST_SHAPES_FAILURE_THRESHOLD", 0.01)
# pair of (cross-entropy threshold, mean-diff threshold); may also arrive as a
# comma-separated string from the environment and is parsed further below
default_metrics_threshold = os.environ.get(
    "FMS_TEST_SHAPES_METRICS_THRESHOLD", (3.0, 0.001)
)
save_validation_info_outputs = (
    os.environ.get("FMS_TEST_SHAPES_SAVE_VALIDATION_INFO_OUTPUTS", "0") == "1"
)  # NOTE(review): closing paren reconstructed — the original line sits in a gap of this diff view; confirm against the full file
8698
# pass custom default metrics threshold as a comma separated str of floats <cross-entropy threshold>,<mean diff threshold>
if isinstance(default_metrics_threshold, str):
    # env override arrives as e.g. "3.0,0.001" — parse into a float tuple;
    # generator feeds tuple() directly (no intermediate list needed)
    default_metrics_threshold = tuple(
        float(m) for m in default_metrics_threshold.split(",")
    )
90104
91105# pass custom common batch sizes as a comma separated str of ints
92106if isinstance (common_batch_sizes , str ):
126140fail_thresholds = {
127141 (LLAMA_3p1_8B_INSTRUCT , True ): (
128142 3.7392955756187423 ,
129- .001 , # FIXME: compute
143+ 0 .001 , # FIXME: compute
130144 ),
131145 (GRANITE_3p2_8B_INSTRUCT , True ): (
132146 2.996668996810913 ,
133- .001 , # FIXME: compute
147+ 0 .001 , # FIXME: compute
134148 ),
135149 (GRANITE_20B_CODE_INSTRUCT_8K , True ): (
136- 3.7392955756187423 , # FIXME: compute -- setting to micro llama 3.1 8b instruct
137- .001 , # FIXME: compute
150+ 3.7392955756187423 , # FIXME: compute -- setting to micro llama 3.1 8b instruct
151+ 0 .001 , # FIXME: compute
138152 ),
139153 (LLAMA_3p1_70B_INSTRUCT , True ): (
140154 3.8235735702514626 ,
141- .001 , # FIXME: compute
155+ 0 .001 , # FIXME: compute
142156 ),
143157 (LLAMA_3p1_8B_INSTRUCT , False ): (
144158 2.6994638133048965 ,
@@ -316,7 +330,7 @@ def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens):
316330 os .environ ["COMPILATION_MODE" ] = "offline_decoder"
317331
318332 if "HF_HOME" not in os .environ :
319- os .environ ["HF_HOME" ] = "/tmp /models/hf_cache"
333+ os .environ ["HF_HOME" ] = "/home/senuser /models/hf_cache"
320334
321335 dprint (
322336 f"testing model={ model_path } , batch_size={ batch_size } , seq_length={ seq_length } , max_new_tokens={ max_new_tokens } , micro_model={ USE_MICRO_MODELS } "
@@ -326,13 +340,18 @@ def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens):
326340 gptq_kwargs_aiu , gptq_kwargs_cpu = __maybe_get_gptq_kwargs (model_path )
327341 is_gptq = len (gptq_kwargs_aiu ) != 0
328342
329- if USE_MICRO_MODELS :
343+ micro_model_path = micro_model_mapping .get (model_path , None )
344+ if USE_MICRO_MODELS and micro_model_path is None :
345+ dprint ("using randomly initialized model" )
330346 micro_model_kwargs = {"architecture" : "hf_configured" , "nlayers" : 3 }
331347 else :
348+ dprint ("using trained model" )
332349 micro_model_kwargs = {"architecture" : "hf_pretrained" }
333350
334351 if not USE_MICRO_MODELS and os .path .exists (model_path ):
335352 model_path_kwargs = {"model_path" : model_path }
353+ elif USE_MICRO_MODELS and micro_model_path is not None :
354+ model_path_kwargs = {"model_path" : micro_model_path }
336355 else :
337356 model_path_kwargs = {"variant" : model_path }
338357
@@ -428,7 +447,6 @@ def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens):
428447
429448 # if level 0 fails validation, validate level 1
430449 if FORCE_VALIDATION_LEVEL_1 or failed_validation_level_0 :
431-
432450 if failed_validation_level_0 :
433451 dprint ("failed validation level 0, testing validation level 1" )
434452 else :
@@ -439,10 +457,12 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
439457 cross_entropy = torch .nn .CrossEntropyLoss ()(
440458 r , t .softmax (dim = 1 ).to (dtype = torch .float32 )
441459 )
442- diff = torch .mean (torch .abs (
443- r .softmax (dim = 1 ).to (dtype = torch .float32 )
444- - t .softmax (dim = 1 ).to (dtype = torch .float32 )
445- ))
460+ diff = torch .mean (
461+ torch .abs (
462+ r .softmax (dim = 1 ).to (dtype = torch .float32 )
463+ - t .softmax (dim = 1 ).to (dtype = torch .float32 )
464+ )
465+ )
446466 return (cross_entropy , diff )
447467
448468 iters = 1024 // max_new_tokens
0 commit comments