Skip to content

Commit e7c18f8

Browse files
committed
🎨 Fix formatting
Signed-off-by: gkumbhat <kumbhat.gaurav@gmail.com>
1 parent 6c5e5ef commit e7c18f8

File tree

2 files changed

+57
-32
lines changed

2 files changed

+57
-32
lines changed

aiu_fms_testing_utils/utils/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,19 @@ def warmup_model(
6767
dprint(f"PT compile complete, took {pt_compile_model_time:.3f}s")
6868

6969

70+
def get_env_to_int_list(env_var_name, default):
    """Read a comma-separated environment variable as a list of integers.

    Args:
        env_var_name: Name of the environment variable to look up.
        default: Value used when the variable is not set. May be a list
            (returned unchanged) or a comma-separated string (parsed like the
            environment value).

    Returns:
        A list of integers parsed from the variable, the ``default`` list
        itself when the variable is unset and ``default`` is a list, or an
        empty list when the resolved value is empty or ``None``.

    Raises:
        ValueError: If a non-blank item cannot be converted with ``int``.
    """
    raw_value = os.environ.get(env_var_name, default)
    if not raw_value:
        return []
    if isinstance(raw_value, list):
        # Only the caller-supplied default can be a list; environment values
        # are always strings. It is handed back as-is, without conversion.
        return raw_value

    # Skip blank items so stray or trailing commas ("1,2,") do not make int()
    # raise; int() itself tolerates surrounding whitespace ("1, 2").
    return [int(item) for item in raw_value.split(",") if item.strip()]
81+
82+
7083
def ids_for_prompt(prompt, tokenizer):
7184
tokens = tokenizer.tokenize(prompt)
7285
ids = tokenizer.convert_tokens_to_ids(tokens)

tests/testing/test_compilation.py

Lines changed: 44 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,25 @@
11
"""This module contains tests related to the compilation operation."""
22

33
# Standard
4-
import itertools
54
import os
65
import pytest
76
import time
87

98
# Third Party
10-
from torch import distributed as dist
119
import torch
1210

1311
# First Party
1412
from fms.models import get_model
15-
from fms.utils import generation, tokenizers
13+
from fms.utils import tokenizers
1614
from fms.utils.generation import pad_input_ids
1715

1816
# Local
19-
from aiu_fms_testing_utils.utils import ids_for_prompt, sample_sharegpt_requests, warmup_model
17+
from aiu_fms_testing_utils.utils import (
18+
ids_for_prompt,
19+
get_env_to_int_list,
20+
sample_sharegpt_requests,
21+
warmup_model,
22+
)
2023
from aiu_fms_testing_utils.utils.aiu_setup import dprint
2124

2225
GRANITE_3p3_8B_INSTRUCT = "ibm-granite/granite-3.3-8b-instruct"
@@ -26,25 +29,39 @@
2629

2730
ATTN_NAME = "spyre_paged_attn"
2831

29-
compile_dynamic_sendnn = True
32+
COMPILE_DYNAMIC_SHAPE = True
33+
34+
35+
common_model_paths = get_env_to_int_list("COMMON_MODEL_NAME", [GRANITE_3p3_8B_INSTRUCT])
36+
common_batch_sizes = get_env_to_int_list("FMS_TEST_SHAPES_COMMON_BATCH_SIZES", [1])
37+
common_seq_lengths = get_env_to_int_list("FMS_TEST_SHAPES_COMMON_SEQ_LENGTHS", [64])
38+
common_max_new_tokens = get_env_to_int_list(
39+
"FMS_TEST_SHAPES_COMMON_MAX_NEW_TOKENS", [64]
40+
)
41+
common_expected_comp_time = get_env_to_int_list(
42+
"COMMON_COMPILATION_EXPECTED_TIME", [10]
43+
) # In minutes
3044

31-
common_model_paths = [GRANITE_3p3_8B_INSTRUCT]
32-
common_batch_sizes = [1]
33-
common_seq_lengths = [256]
34-
common_shape_types = ["dynamic"]
35-
common_max_new_tokens = [128]
36-
common_expected_comp_time = [10] # In minutes
45+
COMMON_SHAPE_TYPE = "dynamic"
3746

38-
if compile_dynamic_sendnn:
47+
48+
if COMPILE_DYNAMIC_SHAPE:
49+
import bisect
50+
51+
# the compiler supports certain max context lengths (VLLM_DT_MAX_CONTEXT_LEN)
52+
# this ensures that we select the smallest supported VLLM_DT_MAX_CONTEXT_LEN that fits the largest possible context (prompt size + max_new_tokens)
53+
__largest_context = max(common_seq_lengths) + max(common_max_new_tokens)
54+
__supported_context_lengths = [256, 512, 1024, 2048, 4096, 8192]
3955
os.environ["VLLM_DT_MAX_CONTEXT_LEN"] = str(
40-
(((max(common_seq_lengths) + max(common_max_new_tokens)) // 64) + 1) * 64
56+
__supported_context_lengths[
57+
bisect.bisect_left(__supported_context_lengths, __largest_context)
58+
]
4159
)
4260
os.environ["VLLM_DT_MAX_BATCH_SIZE"] = str(max(max(common_batch_sizes), 2))
4361

44-
common_shapes = list(
62+
COMMON_SHAPES = list(
4563
zip(
4664
common_model_paths,
47-
common_shape_types,
4865
common_batch_sizes,
4966
common_seq_lengths,
5067
common_max_new_tokens,
@@ -59,7 +76,7 @@ def __prepare_inputs(batch_size, seq_length, tokenizer, seed=0):
5976
SHARE_GPT_DATASET_PATH,
6077
batch_size,
6178
tokenizer,
62-
int(seq_length / 2),
79+
seq_length // 2,
6380
seq_length,
6481
seed,
6582
)
@@ -74,16 +91,18 @@ def __prepare_inputs(batch_size, seq_length, tokenizer, seed=0):
7491
@pytest.fixture(autouse=True)
7592
def reset_compiler():
7693
yield # run the test
77-
if not compile_dynamic_sendnn:
94+
if not COMPILE_DYNAMIC_SHAPE:
7895
torch.compiler.reset()
7996
torch._dynamo.reset()
8097
os.environ.pop("COMPILATION_MODE", None)
8198

8299

83100
@pytest.mark.parametrize(
84-
"model_path,shape_type,batch_size,seq_length,max_new_tokens,expected_comp_time", common_shapes
101+
"model_path,batch_size,seq_length,max_new_tokens,expected_comp_time", COMMON_SHAPES
85102
)
86-
def test_compilation_time(model_path, shape_type, batch_size, seq_length, max_new_tokens, expected_comp_time):
103+
def test_compilation_time(
104+
model_path, batch_size, seq_length, max_new_tokens, expected_comp_time
105+
):
87106
"""Test to validate time taken for model compilation."""
88107
torch.manual_seed(42)
89108
torch.set_default_dtype(torch.float16)
@@ -104,7 +123,7 @@ def test_compilation_time(model_path, shape_type, batch_size, seq_length, max_ne
104123
model = get_model(
105124
architecture="hf_pretrained",
106125
device_type="cpu",
107-
data_type= torch.float16,
126+
data_type=torch.float16,
108127
fused_weights=False,
109128
**model_path_kwargs,
110129
)
@@ -117,21 +136,14 @@ def test_compilation_time(model_path, shape_type, batch_size, seq_length, max_ne
117136
extra_kwargs["attn_name"] = ATTN_NAME
118137

119138
start_time = time.perf_counter()
120-
if shape_type == "dynamic":
121-
compile_dynamic_sendnn = True
139+
if COMMON_SHAPE_TYPE == "dynamic":
140+
COMPILE_DYNAMIC_SHAPE = True
122141
else:
123-
compile_dynamic_sendnn = False
142+
COMPILE_DYNAMIC_SHAPE = False
124143

125-
model.compile(
126-
backend="sendnn", options={"sendnn.dynamic": compile_dynamic_sendnn}
127-
)
144+
model.compile(backend="sendnn", options={"sendnn.dynamic": COMPILE_DYNAMIC_SHAPE})
128145
warmup_model(
129-
model,
130-
input_ids,
131-
max_new_tokens,
132-
compile_dynamic_sendnn,
133-
use_cache=False,
134-
**extra_kwargs
146+
model, input_ids, max_new_tokens, COMPILE_DYNAMIC_SHAPE, **extra_kwargs
135147
)
136148
end_time = time.perf_counter()
137149

0 commit comments

Comments
 (0)