Skip to content

Commit 8b67444

Browse files
committed
🎨 Fix formatting
Signed-off-by: gkumbhat <kumbhat.gaurav@gmail.com>
1 parent a3ed81c commit 8b67444

File tree

2 files changed

+15
-18
lines changed

2 files changed

+15
-18
lines changed

aiu_fms_testing_utils/utils/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ def warmup_model(
6767
dprint(f"PT compile complete, took {pt_compile_model_time:.3f}s")
6868

6969

70-
7170
def get_env_to_int_list(env_var_name, default):
7271
"""Utility function to convert list of strings passed as given environment variable to
7372
list of integers
@@ -78,7 +77,7 @@ def get_env_to_int_list(env_var_name, default):
7877
if isinstance(env_var_string, list):
7978
return env_var_string
8079

81-
return [int(v) for v in env_var_string.split(",") if not isinstance(v, int)]
80+
return [int(v) for v in env_var_string.split(",") if not isinstance(v, int)]
8281

8382

8483
def ids_for_prompt(prompt, tokenizer):

tests/testing/test_compilation.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
"""This module contains test related to compilation operation"""
22

33
# Standard
4-
import itertools
54
import os
65
import pytest
76
import time
87

98
# Third Party
10-
from torch import distributed as dist
119
import torch
1210

1311
# First Party
@@ -20,7 +18,7 @@
2018
ids_for_prompt,
2119
get_env_to_int_list,
2220
sample_sharegpt_requests,
23-
warmup_model
21+
warmup_model,
2422
)
2523
from aiu_fms_testing_utils.utils.aiu_setup import dprint
2624

@@ -37,8 +35,12 @@
3735
common_model_paths = get_env_to_int_list("COMMON_MODEL_NAME", [GRANITE_3p3_8B_INSTRUCT])
3836
common_batch_sizes = get_env_to_int_list("FMS_TEST_SHAPES_COMMON_BATCH_SIZES", [1])
3937
common_seq_lengths = get_env_to_int_list("FMS_TEST_SHAPES_COMMON_SEQ_LENGTHS", [64])
40-
common_max_new_tokens = get_env_to_int_list("FMS_TEST_SHAPES_COMMON_MAX_NEW_TOKENS", [64])
41-
common_expected_comp_time = get_env_to_int_list("COMMON_COMPILATION_EXPECTED_TIME", [10]) # In minutes
38+
common_max_new_tokens = get_env_to_int_list(
39+
"FMS_TEST_SHAPES_COMMON_MAX_NEW_TOKENS", [64]
40+
)
41+
common_expected_comp_time = get_env_to_int_list(
42+
"COMMON_COMPILATION_EXPECTED_TIME", [10]
43+
) # In minutes
4244

4345
COMMON_SHAPE_TYPE = "dynamic"
4446

@@ -98,7 +100,9 @@ def reset_compiler():
98100
@pytest.mark.parametrize(
99101
"model_path,batch_size,seq_length,max_new_tokens,expected_comp_time", COMMON_SHAPES
100102
)
101-
def test_compilation_time(model_path, batch_size, seq_length, max_new_tokens, expected_comp_time):
103+
def test_compilation_time(
104+
model_path, batch_size, seq_length, max_new_tokens, expected_comp_time
105+
):
102106
"""Test to validate time taken for model compilation."""
103107
torch.manual_seed(42)
104108
torch.set_default_dtype(torch.float16)
@@ -119,7 +123,7 @@ def test_compilation_time(model_path, batch_size, seq_length, max_new_tokens, ex
119123
model = get_model(
120124
architecture="hf_pretrained",
121125
device_type="cpu",
122-
data_type= torch.float16,
126+
data_type=torch.float16,
123127
fused_weights=False,
124128
**model_path_kwargs,
125129
)
@@ -137,16 +141,10 @@ def test_compilation_time(model_path, batch_size, seq_length, max_new_tokens, ex
137141
else:
138142
COMPILE_DYNAMIC_SHAPE = False
139143

140-
model.compile(
141-
backend="sendnn", options={"sendnn.dynamic": COMPILE_DYNAMIC_SHAPE}
142-
)
144+
model.compile(backend="sendnn", options={"sendnn.dynamic": COMPILE_DYNAMIC_SHAPE})
143145
warmup_model(
144-
model,
145-
input_ids,
146-
max_new_tokens,
147-
COMPILE_DYNAMIC_SHAPE,
148-
**extra_kwargs
146+
model, input_ids, max_new_tokens, COMPILE_DYNAMIC_SHAPE, **extra_kwargs
149147
)
150148
end_time = time.perf_counter()
151149

152-
assert (end_time - start_time) < expected_comp_time * 60
150+
assert (end_time - start_time) < expected_comp_time * 60

0 commit comments

Comments
 (0)