11"""This module contains test related to compilation operation"""
22
33# Standard
4- import itertools
54import os
65import pytest
76import time
87
98# Third Party
10- from torch import distributed as dist
119import torch
1210
1311# First Party
2018 ids_for_prompt ,
2119 get_env_to_int_list ,
2220 sample_sharegpt_requests ,
23- warmup_model
21+ warmup_model ,
2422)
2523from aiu_fms_testing_utils .utils .aiu_setup import dprint
2624
3735common_model_paths = get_env_to_int_list ("COMMON_MODEL_NAME" , [GRANITE_3p3_8B_INSTRUCT ])
3836common_batch_sizes = get_env_to_int_list ("FMS_TEST_SHAPES_COMMON_BATCH_SIZES" , [1 ])
3937common_seq_lengths = get_env_to_int_list ("FMS_TEST_SHAPES_COMMON_SEQ_LENGTHS" , [64 ])
40- common_max_new_tokens = get_env_to_int_list ("FMS_TEST_SHAPES_COMMON_MAX_NEW_TOKENS" , [64 ])
41- common_expected_comp_time = get_env_to_int_list ("COMMON_COMPILATION_EXPECTED_TIME" , [10 ]) # In minutes
38+ common_max_new_tokens = get_env_to_int_list (
39+ "FMS_TEST_SHAPES_COMMON_MAX_NEW_TOKENS" , [64 ]
40+ )
41+ common_expected_comp_time = get_env_to_int_list (
42+ "COMMON_COMPILATION_EXPECTED_TIME" , [10 ]
43+ ) # In minutes
4244
4345COMMON_SHAPE_TYPE = "dynamic"
4446
@@ -98,7 +100,9 @@ def reset_compiler():
98100@pytest .mark .parametrize (
99101 "model_path,batch_size,seq_length,max_new_tokens,expected_comp_time" , COMMON_SHAPES
100102)
101- def test_compilation_time (model_path , batch_size , seq_length , max_new_tokens , expected_comp_time ):
103+ def test_compilation_time (
104+ model_path , batch_size , seq_length , max_new_tokens , expected_comp_time
105+ ):
102106 """Test to validate time taken for model compilation."""
103107 torch .manual_seed (42 )
104108 torch .set_default_dtype (torch .float16 )
@@ -119,7 +123,7 @@ def test_compilation_time(model_path, batch_size, seq_length, max_new_tokens, ex
119123 model = get_model (
120124 architecture = "hf_pretrained" ,
121125 device_type = "cpu" ,
122- data_type = torch .float16 ,
126+ data_type = torch .float16 ,
123127 fused_weights = False ,
124128 ** model_path_kwargs ,
125129 )
@@ -137,16 +141,10 @@ def test_compilation_time(model_path, batch_size, seq_length, max_new_tokens, ex
137141 else :
138142 COMPILE_DYNAMIC_SHAPE = False
139143
140- model .compile (
141- backend = "sendnn" , options = {"sendnn.dynamic" : COMPILE_DYNAMIC_SHAPE }
142- )
144+ model .compile (backend = "sendnn" , options = {"sendnn.dynamic" : COMPILE_DYNAMIC_SHAPE })
143145 warmup_model (
144- model ,
145- input_ids ,
146- max_new_tokens ,
147- COMPILE_DYNAMIC_SHAPE ,
148- ** extra_kwargs
146+ model , input_ids , max_new_tokens , COMPILE_DYNAMIC_SHAPE , ** extra_kwargs
149147 )
150148 end_time = time .perf_counter ()
151149
152- assert (end_time - start_time ) < expected_comp_time * 60
150+ assert (end_time - start_time ) < expected_comp_time * 60
0 commit comments