
Commit 6c5e5ef

✨ Add compilation time test

Signed-off-by: gkumbhat <kumbhat.gaurav@gmail.com>

1 parent bd1090e commit 6c5e5ef

File tree

1 file changed: +138 −0 lines changed

tests/testing/test_compilation.py

Lines changed: 138 additions & 0 deletions

@@ -0,0 +1,138 @@
"""This module contains tests related to the compilation operation"""

# Standard
import itertools
import os
import time

# Third Party
from torch import distributed as dist
import pytest
import torch

# First Party
from fms.models import get_model
from fms.utils import generation, tokenizers
from fms.utils.generation import pad_input_ids

# Local
from aiu_fms_testing_utils.utils import (
    ids_for_prompt,
    sample_sharegpt_requests,
    warmup_model,
)
from aiu_fms_testing_utils.utils.aiu_setup import dprint

GRANITE_3p3_8B_INSTRUCT = "ibm-granite/granite-3.3-8b-instruct"
SHARE_GPT_DATASET_PATH = os.environ.get(
    "SHARE_GPT_DATASET_PATH", os.path.expanduser("~/share_gpt.json")
)

ATTN_NAME = "spyre_paged_attn"

compile_dynamic_sendnn = True

common_model_paths = [GRANITE_3p3_8B_INSTRUCT]
common_batch_sizes = [1]
common_seq_lengths = [256]
common_shape_types = ["dynamic"]
common_max_new_tokens = [128]
common_expected_comp_time = [10]  # in minutes

if compile_dynamic_sendnn:
    # Round the maximum context length (prompt + new tokens) up to the
    # next multiple of 64.
    os.environ["VLLM_DT_MAX_CONTEXT_LEN"] = str(
        (((max(common_seq_lengths) + max(common_max_new_tokens)) // 64) + 1) * 64
    )
    os.environ["VLLM_DT_MAX_BATCH_SIZE"] = str(max(max(common_batch_sizes), 2))
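# For the defaults above this evaluates to: 256 + 128 = 384 tokens, rounded
# up to the next multiple of 64, so VLLM_DT_MAX_CONTEXT_LEN=448 and
# VLLM_DT_MAX_BATCH_SIZE=max(1, 2)=2.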

common_shapes = list(
    zip(
        common_model_paths,
        common_shape_types,
        common_batch_sizes,
        common_seq_lengths,
        common_max_new_tokens,
        common_expected_comp_time,
    )
)
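# With the single-element lists above, common_shapes holds exactly one tuple:
# [("ibm-granite/granite-3.3-8b-instruct", "dynamic", 1, 256, 128, 10)]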


# TODO: This is copied from test_decoders.py; it would be good to consolidate
def __prepare_inputs(batch_size, seq_length, tokenizer, seed=0):
    prompts_and_sizes = sample_sharegpt_requests(
        SHARE_GPT_DATASET_PATH,
        batch_size,
        tokenizer,
        int(seq_length / 2),
        seq_length,
        seed,
    )
    prompt_list = []
    for prompt, _ in prompts_and_sizes:
        prompt_list.append(ids_for_prompt(prompt, tokenizer))

    input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=seq_length)
    return input_ids, extra_kwargs


@pytest.fixture(autouse=True)
def reset_compiler():
    yield  # run the test
    if not compile_dynamic_sendnn:
        torch.compiler.reset()
        torch._dynamo.reset()
        os.environ.pop("COMPILATION_MODE", None)


@pytest.mark.parametrize(
    "model_path,shape_type,batch_size,seq_length,max_new_tokens,expected_comp_time",
    common_shapes,
)
def test_compilation_time(
    model_path, shape_type, batch_size, seq_length, max_new_tokens, expected_comp_time
):
    """Test to validate the time taken for model compilation."""
    torch.manual_seed(42)
    torch.set_default_dtype(torch.float16)
    os.environ["COMPILATION_MODE"] = "offline_decoder"

    dprint(
        f"testing model={model_path}, batch_size={batch_size}, seq_length={seq_length}"
    )

    if os.path.exists(model_path):
        model_path_kwargs = {"model_path": model_path}
    else:
        model_path_kwargs = {"variant": model_path}

    tokenizer = tokenizers.get_tokenizer(model_path)

    # prepare the AIU model
    model = get_model(
        architecture="hf_pretrained",
        device_type="cpu",
        data_type=torch.float16,
        fused_weights=False,
        **model_path_kwargs,
    )

    model.eval()
    torch.set_grad_enabled(False)

    # prepare input_ids
    input_ids, extra_kwargs = __prepare_inputs(batch_size, seq_length, tokenizer)
    extra_kwargs["attn_name"] = ATTN_NAME

    start_time = time.perf_counter()
    # Local flag shadows the module-level default; only the parametrized
    # shape_type decides whether the sendnn backend compiles dynamically.
    if shape_type == "dynamic":
        compile_dynamic_sendnn = True
    else:
        compile_dynamic_sendnn = False

    model.compile(
        backend="sendnn", options={"sendnn.dynamic": compile_dynamic_sendnn}
    )
    warmup_model(
        model,
        input_ids,
        max_new_tokens,
        compile_dynamic_sendnn,
        use_cache=False,
        **extra_kwargs,
    )
    end_time = time.perf_counter()

    # expected_comp_time is in minutes; perf_counter measures seconds
    assert (end_time - start_time) < expected_comp_time * 60
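
As a rough sketch (not part of this commit), the new test could be run on its own with pytest; this assumes the ShareGPT dataset JSON is available at ~/share_gpt.json or wherever SHARE_GPT_DATASET_PATH points, and that the sendnn backend is installed:

    # Hypothetical invocation, e.g. from the repository root
    import pytest
    pytest.main(["tests/testing/test_compilation.py", "-v"])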
