Skip to content

Commit d3efe1b

Browse files
authored
Merge pull request #15 from foundation-model-stack/multi_aiu
Multi-AIU support to shape testing
2 parents 6bc631c + 27b67c2 commit d3efe1b

File tree

1 file changed

+14
-2
lines changed

tests/models/test_decoders.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44
from fms.utils.generation import pad_input_ids
55
import itertools
66
import torch
7+
from torch import distributed as dist
78
from aiu_fms_testing_utils.testing.validation import extract_validation_information, LogitsExtractorHook, GoldenTokenHook, capture_level_1_metrics, filter_failed_level_1_cases, load_validation_information, validate_level_0, top_k_loss_calculator
89
from aiu_fms_testing_utils.utils import warmup_model, sample_sharegpt_requests, ids_for_prompt
9-
from aiu_fms_testing_utils.utils.aiu_setup import dprint
10+
from aiu_fms_testing_utils.utils.aiu_setup import dprint, aiu_dist_setup
1011
import os
1112

1213
ORIGINAL_HF_HOME = os.environ.get("HF_HOME", None)
@@ -17,6 +18,7 @@
1718

1819
SHARE_GPT_DATASET_PATH = os.environ.get("SHARE_GPT_DATASET_PATH", os.path.expanduser("~/share_gpt.json"))
1920
USE_MICRO_MODELS = os.environ.get("FMS_TEST_SHAPES_USE_MICRO_MODELS", "1") == "1"
21+
USE_DISTRIBUTED = os.environ.get("FMS_TEST_SHAPES_DISTRIBUTED", "0") == "1"
2022
validation_info_dir = os.environ.get("FMS_TEST_SHAPES_VALIDATION_INFO_DIR", "/tmp/models/validation_info")
2123
common_model_paths = os.environ.get("FMS_TEST_SHAPES_COMMON_MODEL_PATHS", [LLAMA_3p1_8B_INSTRUCT, GRANITE_3p2_8B_INSTRUCT])
2224
# for validation level 1, the default is a failure rate of 1%
@@ -28,6 +30,10 @@
2830
common_seq_lengths = os.environ.get("FMS_TEST_SHAPES_COMMON_SEQ_LENGTHS", [64, 2048])
2931
common_max_new_tokens = os.environ.get("FMS_TEST_SHAPES_COMMON_MAX_NEW_TOKENS", [128])
3032

33+
if USE_DISTRIBUTED:
34+
dist.init_process_group()
35+
aiu_dist_setup(dist.get_rank(), dist.get_world_size())
36+
3137
if USE_MICRO_MODELS:
3238
validation_info_dir = os.path.join(validation_info_dir, "tiny_models")
3339

@@ -119,6 +125,7 @@ def __load_validation_info(model_path, batch_size, seq_length, max_new_tokens, t
119125

120126
@pytest.mark.parametrize("model_path,batch_size,seq_length,max_new_tokens", common_shapes)
121127
def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens):
128+
torch.manual_seed(42)
122129
os.environ["COMPILATION_MODE"] = "offline_decoder"
123130

124131
if "HF_HOME" not in os.environ:
@@ -135,8 +142,13 @@ def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens):
135142
model_path_kwargs = {"model_path": model_path}
136143
else:
137144
model_path_kwargs = {"variant": model_path}
145+
146+
distributed_kwargs = {}
147+
if USE_DISTRIBUTED:
148+
distributed_kwargs["distr_param"] = "tp"
149+
distributed_kwargs["group"] = dist.group.WORLD
138150

139-
get_model_kwargs = {**model_path_kwargs, **micro_model_kwargs}
151+
get_model_kwargs = {**model_path_kwargs, **micro_model_kwargs, **distributed_kwargs}
140152

141153
tokenizer = tokenizers.get_tokenizer(model_path)
142154

Comments (0)