 from fms.utils.generation import pad_input_ids
 import itertools
 import torch
+from torch import distributed as dist
 from aiu_fms_testing_utils.testing.validation import extract_validation_information, LogitsExtractorHook, GoldenTokenHook, capture_level_1_metrics, filter_failed_level_1_cases, load_validation_information, validate_level_0, top_k_loss_calculator
 from aiu_fms_testing_utils.utils import warmup_model, sample_sharegpt_requests, ids_for_prompt
-from aiu_fms_testing_utils.utils.aiu_setup import dprint
+from aiu_fms_testing_utils.utils.aiu_setup import dprint, aiu_dist_setup
 import os

 ORIGINAL_HF_HOME = os.environ.get("HF_HOME", None)

 SHARE_GPT_DATASET_PATH = os.environ.get("SHARE_GPT_DATASET_PATH", os.path.expanduser("~/share_gpt.json"))
 USE_MICRO_MODELS = os.environ.get("FMS_TEST_SHAPES_USE_MICRO_MODELS", "1") == "1"
+USE_DISTRIBUTED = os.environ.get("FMS_TEST_SHAPES_DISTRIBUTED", "0") == "1"
 validation_info_dir = os.environ.get("FMS_TEST_SHAPES_VALIDATION_INFO_DIR", "/tmp/models/validation_info")
 common_model_paths = os.environ.get("FMS_TEST_SHAPES_COMMON_MODEL_PATHS", [LLAMA_3p1_8B_INSTRUCT, GRANITE_3p2_8B_INSTRUCT])
 # for validation level 1, the default is a failure rate of 1%
 common_seq_lengths = os.environ.get("FMS_TEST_SHAPES_COMMON_SEQ_LENGTHS", [64, 2048])
 common_max_new_tokens = os.environ.get("FMS_TEST_SHAPES_COMMON_MAX_NEW_TOKENS", [128])

+if USE_DISTRIBUTED:
+    dist.init_process_group()
+    aiu_dist_setup(dist.get_rank(), dist.get_world_size())
+
 if USE_MICRO_MODELS:
     validation_info_dir = os.path.join(validation_info_dir, "tiny_models")

@@ -119,6 +125,7 @@ def __load_validation_info(model_path, batch_size, seq_length, max_new_tokens, t

 @pytest.mark.parametrize("model_path,batch_size,seq_length,max_new_tokens", common_shapes)
 def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens):
+    torch.manual_seed(42)
     os.environ["COMPILATION_MODE"] = "offline_decoder"

     if "HF_HOME" not in os.environ:
@@ -135,8 +142,13 @@ def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens):
         model_path_kwargs = {"model_path": model_path}
     else:
         model_path_kwargs = {"variant": model_path}
+
+    distributed_kwargs = {}
+    if USE_DISTRIBUTED:
+        distributed_kwargs["distr_param"] = "tp"
+        distributed_kwargs["group"] = dist.group.WORLD

-    get_model_kwargs = {**model_path_kwargs, **micro_model_kwargs}
+    get_model_kwargs = {**model_path_kwargs, **micro_model_kwargs, **distributed_kwargs}

     tokenizer = tokenizers.get_tokenizer(model_path)
