Skip to content

Commit b912264

Browse files
authored
Merge pull request #11 from foundation-model-stack/encoder_shape_testing
added a shapes test for encoders
2 parents b9068c6 + 642e49d commit b912264

File tree

6 files changed

+211
-30
lines changed

6 files changed

+211
-30
lines changed

aiu_fms_testing_utils/utils/__init__.py

Lines changed: 52 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -46,39 +46,26 @@ def __download_file(url, filename):
4646
except requests.exceptions.RequestException as e:
4747
print(f"An error occurred: {e}")
4848

49-
def sample_sharegpt_requests(
50-
dataset_path: str,
49+
def __sample_requests(
50+
prompt_list: List[str],
5151
num_requests: int,
5252
tokenizer: BaseTokenizer,
5353
prompt_length_min: int = 32,
5454
prompt_length_max: int = 64,
5555
seed: Optional[int] = None
56-
) -> List[Tuple[str, int]]:
57-
if not os.path.exists(dataset_path):
58-
print("downloading share-gpt dataset as it does not exist")
59-
__download_file("https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json", dataset_path)
60-
61-
# Load the dataset.
62-
with open(dataset_path, encoding='utf-8') as f:
63-
dataset = json.load(f)
64-
# Filter out the conversations with less than 2 turns.
65-
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
66-
# Only keep the first two turns of each conversation.
67-
dataset = [(data["conversations"][0]["value"],
68-
data["conversations"][1]["value"]) for data in dataset]
69-
56+
):
7057
# Shuffle the dataset.
7158
if seed is not None:
72-
random.Random(seed).shuffle(dataset)
59+
random.Random(seed).shuffle(prompt_list)
7360

7461
# Filter out sequences that are too long or too short
7562
filtered_dataset: List[Tuple[str, int, int]] = []
76-
for i in range(len(dataset)):
63+
for i in range(len(prompt_list)):
7764
if len(filtered_dataset) == num_requests:
7865
break
7966

8067
# Tokenize the prompts and completions.
81-
prompt = dataset[i][0]
68+
prompt = prompt_list[i]
8269
prompt_token_ids = ids_for_prompt(prompt, tokenizer)
8370

8471
prompt_len = len(prompt_token_ids)
@@ -87,4 +74,49 @@ def sample_sharegpt_requests(
8774
continue
8875
filtered_dataset.append((prompt, prompt_len))
8976

90-
return filtered_dataset
77+
return filtered_dataset
78+
79+
80+
81+
def sample_sharegpt_requests(
82+
dataset_path: str,
83+
num_requests: int,
84+
tokenizer: BaseTokenizer,
85+
prompt_length_min: int = 32,
86+
prompt_length_max: int = 64,
87+
seed: Optional[int] = None
88+
) -> List[Tuple[str, int]]:
89+
if not os.path.exists(dataset_path):
90+
print("downloading share-gpt dataset as it does not exist")
91+
__download_file("https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json", dataset_path)
92+
93+
# Load the dataset.
94+
with open(dataset_path, encoding='utf-8') as f:
95+
dataset = json.load(f)
96+
# Filter out the conversations with less than 2 turns.
97+
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
98+
dataset = [data["conversations"][0]["value"] for data in dataset]
99+
100+
return __sample_requests(dataset, num_requests, tokenizer, prompt_length_min, prompt_length_max, seed)
101+
102+
def sample_squad_v2_qa_requests(
103+
dataset_path: str,
104+
num_requests: int,
105+
tokenizer: BaseTokenizer,
106+
prompt_length_min: int = 32,
107+
prompt_length_max: int = 64,
108+
seed: Optional[int] = None
109+
) -> List[Tuple[str, int]]:
110+
from datasets import load_dataset
111+
112+
if os.path.exists(dataset_path):
113+
ds = load_dataset(dataset_path)['train']
114+
else:
115+
ds = load_dataset("rajpurkar/squad_v2", cache_dir=dataset_path)['train']
116+
117+
118+
ds = [f"{data['context']}\n{data['question']}" for data in ds]
119+
120+
return __sample_requests(ds, num_requests, tokenizer, prompt_length_min, prompt_length_max, seed)
121+
122+

tests/models/test_shapes.py renamed to tests/models/test_decoders.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@
99
from aiu_fms_testing_utils.utils.aiu_setup import dprint
1010
import os
1111

12-
if "HF_HOME" not in os.environ:
13-
os.environ["HF_HOME"] = "/tmp/models/hf_cache"
12+
ORIGINAL_HF_HOME = os.environ.get("HF_HOME", None)
1413

1514
# Add models to test here
1615
LLAMA_3p1_8B_INSTRUCT = "meta-llama/Llama-3.1-8B-Instruct"
@@ -72,6 +71,11 @@ def reset_compiler():
7271
yield # run the test
7372
torch.compiler.reset()
7473
torch._dynamo.reset()
74+
os.environ.pop('COMPILATION_MODE', None)
75+
if ORIGINAL_HF_HOME is None:
76+
os.environ.pop('HF_HOME', None)
77+
else:
78+
os.environ['HF_HOME'] = ORIGINAL_HF_HOME
7579

7680
def __prepare_inputs(batch_size, seq_length, tokenizer, seed=0):
7781
prompts_and_sizes = sample_sharegpt_requests(SHARE_GPT_DATASET_PATH, batch_size, tokenizer, int(seq_length / 2), seq_length, seed)
@@ -113,9 +117,13 @@ def __load_validation_info(model_path, batch_size, seq_length, max_new_tokens, t
113117
else:
114118
return None
115119

116-
117120
@pytest.mark.parametrize("model_path,batch_size,seq_length,max_new_tokens", common_shapes)
118121
def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens):
122+
os.environ["COMPILATION_MODE"] = "offline_decoder"
123+
124+
if "HF_HOME" not in os.environ:
125+
os.environ["HF_HOME"] = "/tmp/models/hf_cache"
126+
119127
dprint(f"testing model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}")
120128

121129
if USE_MICRO_MODELS:

tests/models/test_encoders.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
from fms.testing.comparison import ModelSignatureParams, compare_model_signatures, get_signature
2+
from fms.utils import tokenizers
3+
import pytest
4+
from fms.models import get_model
5+
from fms.utils.generation import pad_input_ids
6+
import itertools
7+
import torch
8+
from aiu_fms_testing_utils.utils import ids_for_prompt, sample_squad_v2_qa_requests
9+
from aiu_fms_testing_utils.utils.aiu_setup import dprint
10+
import os
11+
12+
ORIGINAL_HF_HOME = os.environ.get("HF_HOME", None)
13+
14+
# Add models to test here
15+
ROBERTA_SQUAD_V2 = "deepset/roberta-base-squad2"
16+
17+
SQUAD_V2_DATASET_PATH = os.environ.get("SQUAD_V2_DATASET_PATH", os.path.expanduser("~/squad_v2"))
18+
common_model_paths = os.environ.get("FMS_TEST_SHAPES_COMMON_MODEL_PATHS", [ROBERTA_SQUAD_V2])
19+
common_batch_sizes = os.environ.get("FMS_TEST_SHAPES_COMMON_BATCH_SIZES", [1, 2, 4, 8])
20+
common_seq_lengths = os.environ.get("FMS_TEST_SHAPES_COMMON_SEQ_LENGTHS", [64, 512])
21+
22+
# pass custom model path list for eg: EXPORT FMS_TESTING_COMMON_MODEL_PATHS="/tmp/models/roberta,/tmp/models/roberta-base-squad2"
23+
if isinstance(common_model_paths, str):
24+
common_model_paths = common_model_paths.split(",")
25+
26+
# pass custom common batch sizes as a comma separated str of ints
27+
if isinstance(common_batch_sizes, str):
28+
common_batch_sizes = [int(bs) for bs in common_batch_sizes.split(",")]
29+
30+
# pass custom common seq lengths as a comma separated str of ints
31+
if isinstance(common_seq_lengths, str):
32+
common_seq_lengths = [int(sl) for sl in common_seq_lengths.split(",")]
33+
34+
common_shapes = list(itertools.product(common_model_paths, common_batch_sizes, common_seq_lengths))
35+
36+
37+
def __prepare_inputs(batch_size, seq_length, tokenizer, seed=0):
38+
prompts_and_sizes = sample_squad_v2_qa_requests(SQUAD_V2_DATASET_PATH, batch_size, tokenizer, int(seq_length / 2), seq_length, seed)
39+
prompt_list = []
40+
for prompt, _ in prompts_and_sizes:
41+
prompt_list.append(ids_for_prompt(prompt, tokenizer))
42+
43+
input_ids, padding_kwargs = pad_input_ids(prompt_list, min_pad_length=seq_length, is_causal_mask=False)
44+
return input_ids, padding_kwargs
45+
46+
@pytest.fixture(autouse=True)
47+
def reset_compiler():
48+
yield # run the test
49+
torch.compiler.reset()
50+
torch._dynamo.reset()
51+
os.environ.pop('COMPILATION_MODE', None)
52+
if ORIGINAL_HF_HOME is None:
53+
os.environ.pop('HF_HOME', None)
54+
else:
55+
os.environ['HF_HOME'] = ORIGINAL_HF_HOME
56+
57+
encoder_paths = ["deepset/roberta-base-squad2"]
58+
common_encoder_shapes = list(itertools.product(encoder_paths, common_batch_sizes, common_seq_lengths))
59+
60+
@pytest.mark.parametrize("model_path,batch_size,seq_length", common_encoder_shapes)
61+
def test_common_shapes(model_path, batch_size, seq_length):
62+
os.environ["COMPILATION_MODE"] = "offline"
63+
64+
if "HF_HOME" not in os.environ:
65+
os.environ["HF_HOME"] = "/tmp/models/hf_cache"
66+
67+
dprint(f"testing model={model_path}, batch_size={batch_size}, seq_length={seq_length}")
68+
69+
tokenizer = tokenizers.get_tokenizer(model_path)
70+
71+
if os.path.exists(model_path):
72+
model_path_kwargs = {"model_path": model_path}
73+
else:
74+
model_path_kwargs = {"variant": model_path}
75+
76+
# prepare the AIU model
77+
model = get_model(
78+
architecture="hf_pretrained",
79+
device_type="cpu",
80+
fused_weights=False,
81+
**model_path_kwargs
82+
)
83+
84+
model.eval()
85+
torch.set_grad_enabled(False)
86+
model.compile(backend="sendnn")
87+
88+
# prepare the cpu model
89+
validation_model = get_model(
90+
architecture="hf_pretrained",
91+
device_type="cpu",
92+
data_type=torch.float32,
93+
fused_weights=False,
94+
**model_path_kwargs
95+
)
96+
97+
# prepare input_ids
98+
input_ids, padding_kwargs = __prepare_inputs(batch_size, seq_length, tokenizer)
99+
100+
# warmup model
101+
logits_getter_fn = lambda x: x if isinstance(x, torch.Tensor) else torch.cat(list(x), dim=-1)
102+
aiu_msp = ModelSignatureParams(model, ["x"], logits_getter_fn=logits_getter_fn, inp=input_ids, other_params=padding_kwargs)
103+
get_signature(aiu_msp.model, aiu_msp.params, aiu_msp.inp, aiu_msp.other_params, aiu_msp.logits_getter_fn)
104+
105+
cpu_msp = ModelSignatureParams(validation_model, ["x"], logits_getter_fn=logits_getter_fn, inp=input_ids, other_params=padding_kwargs)
106+
# FIXME: Compute GPU atol/rtol
107+
compare_model_signatures(cpu_msp, aiu_msp, atol=0.1, rtol=.05)

tests/models/test_model_expectations.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,39 +8,71 @@
88
)
99
import os
1010

11+
if "HF_HOME" not in os.environ:
12+
os.environ["HF_HOME"] = "/tmp/models/hf_cache"
13+
1114
model_dir = os.environ.get("FMS_TESTING_MODEL_DIR", "/tmp/models")
1215
LLAMA_194M = f"{model_dir}/llama-194m"
1316
GRANITE_7B_BASE = f"{model_dir}/granite-7b-base"
1417
GRANITE_8B_CODE_BASE = f"{model_dir}/granite-8b-code-base"
1518
GRANITE_3_8B_CODE_BASE = f"{model_dir}/granite-3-8b-base"
1619

1720
models = [LLAMA_194M, GRANITE_7B_BASE, GRANITE_8B_CODE_BASE, GRANITE_3_8B_CODE_BASE]
21+
mini_models = {LLAMA_194M, GRANITE_7B_BASE, GRANITE_8B_CODE_BASE, GRANITE_3_8B_CODE_BASE}
1822

1923
class AIUModelFixtureMixin(ModelFixtureMixin):
2024

2125
@pytest.fixture(scope="class", autouse=True)
2226
def uninitialized_model(self, model_id):
27+
if model_id in mini_models:
28+
get_model_kwargs = {"architecture": "hf_configured", "nlayers": 3}
29+
else:
30+
get_model_kwargs = {"architecture": "hf_pretrained"}
31+
2332
aiu_model = get_model(
24-
"hf_configured",
25-
model_id,
33+
variant=model_id,
2634
device_type="cpu",
2735
unfuse_weights=True,
28-
nlayers=3
36+
**get_model_kwargs
2937
)
3038
torch.compile(aiu_model, backend="sendnn")
3139
return aiu_model
32-
40+
41+
class TestAIUModels(
42+
ModelConsistencyTestSuite,
43+
AIUModelFixtureMixin,
44+
):
45+
46+
# x is the main parameter for this model which is the input tensor
47+
_get_signature_params = ["x"]
48+
3349
@pytest.fixture(scope="class", autouse=True, params=models)
3450
def model_id(self, request):
3551
return request.param
3652

37-
class TestAIUModels(
53+
def test_model_unfused(self, model, signature):
54+
pytest.skip("All AIU models are already unfused")
55+
56+
57+
ROBERTA_SQUAD_v2 = "deepset/roberta-base-squad2"
58+
tuple_output_models = [ROBERTA_SQUAD_v2]
59+
60+
class TestAIUModelsTupleOutput(
3861
ModelConsistencyTestSuite,
3962
AIUModelFixtureMixin,
4063
):
41-
64+
4265
# x is the main parameter for this model which is the input tensor
4366
_get_signature_params = ["x"]
4467

68+
@pytest.fixture(scope="class", autouse=True, params=tuple_output_models)
69+
def model_id(self, request):
70+
return request.param
71+
72+
@staticmethod
73+
def _get_signature_logits_getter_fn(f_out) -> torch.Tensor:
74+
return torch.cat([f_out[0], f_out[1]], dim=-1)
75+
4576
def test_model_unfused(self, model, signature):
46-
pytest.skip("All AIU models are already unfused")
77+
pytest.skip("All AIU models are already unfused")
78+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
9.834766387939453e-07,3.5762786865234375e-07,8.940696716308594e-07,6.258487701416016e-07,8.344650268554688e-07,1.1324882507324219e-06,6.556510925292969e-07,1.2516975402832031e-06,1.6391277313232422e-06,0.0,2.384185791015625e-07,1.1324882507324219e-06,4.172325134277344e-07,9.238719940185547e-07,4.76837158203125e-07,1.1622905731201172e-06,0.2104383111000061,0.2104375958442688,0.21043795347213745,0.21043813228607178,0.21043753623962402,0.21043819189071655,0.2104378342628479,0.21043813228607178,0.21043860912322998,0.21043741703033447,0.21043741703033447,0.21043819189071655,0.21043717861175537,0.21043848991394043,0.21043795347213745,0.21043837070465088
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
base_model.embedding.weight,base_model.enc_norm.bias,base_model.enc_norm.weight,base_model.layers.0.attn.dense.bias,base_model.layers.0.attn.dense.weight,base_model.layers.0.attn.in_proj.qkv_fused.bias,base_model.layers.0.attn.in_proj.qkv_fused.weight,base_model.layers.0.ff_ln.bias,base_model.layers.0.ff_ln.weight,base_model.layers.0.ff_sub_layer.w1.bias,base_model.layers.0.ff_sub_layer.w1.weight,base_model.layers.0.ff_sub_layer.w2.bias,base_model.layers.0.ff_sub_layer.w2.weight,base_model.layers.0.ln.bias,base_model.layers.0.ln.weight,base_model.layers.1.attn.dense.bias,base_model.layers.1.attn.dense.weight,base_model.layers.1.attn.in_proj.qkv_fused.bias,base_model.layers.1.attn.in_proj.qkv_fused.weight,base_model.layers.1.ff_ln.bias,base_model.layers.1.ff_ln.weight,base_model.layers.1.ff_sub_layer.w1.bias,base_model.layers.1.ff_sub_layer.w1.weight,base_model.layers.1.ff_sub_layer.w2.bias,base_model.layers.1.ff_sub_layer.w2.weight,base_model.layers.1.ln.bias,base_model.layers.1.ln.weight,base_model.layers.10.attn.dense.bias,base_model.layers.10.attn.dense.weight,base_model.layers.10.attn.in_proj.qkv_fused.bias,base_model.layers.10.attn.in_proj.qkv_fused.weight,base_model.layers.10.ff_ln.bias,base_model.layers.10.ff_ln.weight,base_model.layers.10.ff_sub_layer.w1.bias,base_model.layers.10.ff_sub_layer.w1.weight,base_model.layers.10.ff_sub_layer.w2.bias,base_model.layers.10.ff_sub_layer.w2.weight,base_model.layers.10.ln.bias,base_model.layers.10.ln.weight,base_model.layers.11.attn.dense.bias,base_model.layers.11.attn.dense.weight,base_model.layers.11.attn.in_proj.qkv_fused.bias,base_model.layers.11.attn.in_proj.qkv_fused.weight,base_model.layers.11.ff_ln.bias,base_model.layers.11.ff_ln.weight,base_model.layers.11.ff_sub_layer.w1.bias,base_model.layers.11.ff_sub_layer.w1.weight,base_model.layers.11.ff_sub_layer.w2.bias,base_model.layers.11.ff_sub_layer.w2.weight,base_model.layers.11.ln.bias,base_model.layers.11.ln.weight,base_model.layers.2.attn.dense.bias,base_model.layers.2.attn.dense.weight,base_model.layers.2.attn.in_proj.qkv_fused.bias,base_model.layers.2.attn.in_proj.qkv_fused.weight,base_model.layers.2.ff_ln.bias,base_model.layers.2.ff_ln.weight,base_model.layers.2.ff_sub_layer.w1.bias,base_model.layers.2.ff_sub_layer.w1.weight,base_model.layers.2.ff_sub_layer.w2.bias,base_model.layers.2.ff_sub_layer.w2.weight,base_model.layers.2.ln.bias,base_model.layers.2.ln.weight,base_model.layers.3.attn.dense.bias,base_model.layers.3.attn.dense.weight,base_model.layers.3.attn.in_proj.qkv_fused.bias,base_model.layers.3.attn.in_proj.qkv_fused.weight,base_model.layers.3.ff_ln.bias,base_model.layers.3.ff_ln.weight,base_model.layers.3.ff_sub_layer.w1.bias,base_model.layers.3.ff_sub_layer.w1.weight,base_model.layers.3.ff_sub_layer.w2.bias,base_model.layers.3.ff_sub_layer.w2.weight,base_model.layers.3.ln.bias,base_model.layers.3.ln.weight,base_model.layers.4.attn.dense.bias,base_model.layers.4.attn.dense.weight,base_model.layers.4.attn.in_proj.qkv_fused.bias,base_model.layers.4.attn.in_proj.qkv_fused.weight,base_model.layers.4.ff_ln.bias,base_model.layers.4.ff_ln.weight,base_model.layers.4.ff_sub_layer.w1.bias,base_model.layers.4.ff_sub_layer.w1.weight,base_model.layers.4.ff_sub_layer.w2.bias,base_model.layers.4.ff_sub_layer.w2.weight,base_model.layers.4.ln.bias,base_model.layers.4.ln.weight,base_model.layers.5.attn.dense.bias,base_model.layers.5.attn.dense.weight,base_model.layers.5.attn.in_proj.qkv_fused.bias,base_model.layers.5.attn.in_proj.qkv_fused.weight,base_model.layers.5.ff_ln.bias,base_model.layers.5.ff_ln.weight,base_model.layers.5.ff_sub_layer.w1.bias,base_model.layers.5.ff_sub_layer.w1.weight,base_model.layers.5.ff_sub_layer.w2.bias,base_model.layers.5.ff_sub_layer.w2.weight,base_model.layers.5.ln.bias,base_model.layers.5.ln.weight,base_model.layers.6.attn.dense.bias,base_model.layers.6.attn.dense.weight,base_model.layers.6.attn.in_proj.qkv_fused.bias,base_model.layers.6.attn.in_proj.qkv_fused.weight,base_model.layers.6.ff_ln.bias,base_model.layers.6.ff_ln.weight,base_model.layers.6.ff_sub_layer.w1.bias,base_model.layers.6.ff_sub_layer.w1.weight,base_model.layers.6.ff_sub_layer.w2.bias,base_model.layers.6.ff_sub_layer.w2.weight,base_model.layers.6.ln.bias,base_model.layers.6.ln.weight,base_model.layers.7.attn.dense.bias,base_model.layers.7.attn.dense.weight,base_model.layers.7.attn.in_proj.qkv_fused.bias,base_model.layers.7.attn.in_proj.qkv_fused.weight,base_model.layers.7.ff_ln.bias,base_model.layers.7.ff_ln.weight,base_model.layers.7.ff_sub_layer.w1.bias,base_model.layers.7.ff_sub_layer.w1.weight,base_model.layers.7.ff_sub_layer.w2.bias,base_model.layers.7.ff_sub_layer.w2.weight,base_model.layers.7.ln.bias,base_model.layers.7.ln.weight,base_model.layers.8.attn.dense.bias,base_model.layers.8.attn.dense.weight,base_model.layers.8.attn.in_proj.qkv_fused.bias,base_model.layers.8.attn.in_proj.qkv_fused.weight,base_model.layers.8.ff_ln.bias,base_model.layers.8.ff_ln.weight,base_model.layers.8.ff_sub_layer.w1.bias,base_model.layers.8.ff_sub_layer.w1.weight,base_model.layers.8.ff_sub_layer.w2.bias,base_model.layers.8.ff_sub_layer.w2.weight,base_model.layers.8.ln.bias,base_model.layers.8.ln.weight,base_model.layers.9.attn.dense.bias,base_model.layers.9.attn.dense.weight,base_model.layers.9.attn.in_proj.qkv_fused.bias,base_model.layers.9.attn.in_proj.qkv_fused.weight,base_model.layers.9.ff_ln.bias,base_model.layers.9.ff_ln.weight,base_model.layers.9.ff_sub_layer.w1.bias,base_model.layers.9.ff_sub_layer.w1.weight,base_model.layers.9.ff_sub_layer.w2.bias,base_model.layers.9.ff_sub_layer.w2.weight,base_model.layers.9.ln.bias,base_model.layers.9.ln.weight,base_model.position_embedding.weight,qa_head.bias,qa_head.weight

0 commit comments

Comments
 (0)