11"""This module contains test related to compilation operation"""
22
33# Standard
4- import itertools
54import os
65import pytest
76import time
87
98# Third Party
10- from torch import distributed as dist
119import torch
1210
1311# First Party
1412from fms .models import get_model
15- from fms .utils import generation , tokenizers
13+ from fms .utils import tokenizers
1614from fms .utils .generation import pad_input_ids
1715
1816# Local
19- from aiu_fms_testing_utils .utils import ids_for_prompt , sample_sharegpt_requests , warmup_model
17+ from aiu_fms_testing_utils .utils import (
18+ ids_for_prompt ,
19+ get_env_to_int_list ,
20+ sample_sharegpt_requests ,
21+ warmup_model ,
22+ )
2023from aiu_fms_testing_utils .utils .aiu_setup import dprint
2124
2225GRANITE_3p3_8B_INSTRUCT = "ibm-granite/granite-3.3-8b-instruct"
2629
2730ATTN_NAME = "spyre_paged_attn"
2831
29- compile_dynamic_sendnn = True
32+ COMPILE_DYNAMIC_SHAPE = True
33+
34+
35+ common_model_paths = get_env_to_int_list ("COMMON_MODEL_NAME" , [GRANITE_3p3_8B_INSTRUCT ])
36+ common_batch_sizes = get_env_to_int_list ("FMS_TEST_SHAPES_COMMON_BATCH_SIZES" , [1 ])
37+ common_seq_lengths = get_env_to_int_list ("FMS_TEST_SHAPES_COMMON_SEQ_LENGTHS" , [64 ])
38+ common_max_new_tokens = get_env_to_int_list (
39+ "FMS_TEST_SHAPES_COMMON_MAX_NEW_TOKENS" , [64 ]
40+ )
41+ common_expected_comp_time = get_env_to_int_list (
42+ "COMMON_COMPILATION_EXPECTED_TIME" , [10 ]
43+ ) # In minutes
3044
31- common_model_paths = [GRANITE_3p3_8B_INSTRUCT ]
32- common_batch_sizes = [1 ]
33- common_seq_lengths = [256 ]
34- common_shape_types = ["dynamic" ]
35- common_max_new_tokens = [128 ]
36- common_expected_comp_time = [10 ] # In minutes
45+ COMMON_SHAPE_TYPE = "dynamic"
3746
38- if compile_dynamic_sendnn :
47+
48+ if COMPILE_DYNAMIC_SHAPE :
49+ import bisect
50+
51+ # the compiler supports certain max context lengths (VLLM_DT_MAX_CONTEXT_LEN)
52+ # this will ensure that we select smallest supported VLLM_DT_MAX_CONTEXT_LEN that fits the largest possible context (prompt size + max_new_tokens)
53+ __largest_context = max (common_seq_lengths ) + max (common_max_new_tokens )
54+ __supported_context_lengths = [256 , 512 , 1024 , 2048 , 4096 , 8192 ]
3955 os .environ ["VLLM_DT_MAX_CONTEXT_LEN" ] = str (
40- (((max (common_seq_lengths ) + max (common_max_new_tokens )) // 64 ) + 1 ) * 64
56+ __supported_context_lengths [
57+ bisect .bisect_left (__supported_context_lengths , __largest_context )
58+ ]
4159 )
4260 os .environ ["VLLM_DT_MAX_BATCH_SIZE" ] = str (max (max (common_batch_sizes ), 2 ))
4361
44- common_shapes = list (
62+ COMMON_SHAPES = list (
4563 zip (
4664 common_model_paths ,
47- common_shape_types ,
4865 common_batch_sizes ,
4966 common_seq_lengths ,
5067 common_max_new_tokens ,
@@ -59,7 +76,7 @@ def __prepare_inputs(batch_size, seq_length, tokenizer, seed=0):
5976 SHARE_GPT_DATASET_PATH ,
6077 batch_size ,
6178 tokenizer ,
62- int ( seq_length / 2 ) ,
79+ seq_length // 2 ,
6380 seq_length ,
6481 seed ,
6582 )
@@ -74,16 +91,18 @@ def __prepare_inputs(batch_size, seq_length, tokenizer, seed=0):
7491@pytest .fixture (autouse = True )
7592def reset_compiler ():
7693 yield # run the test
77- if not compile_dynamic_sendnn :
94+ if not COMPILE_DYNAMIC_SHAPE :
7895 torch .compiler .reset ()
7996 torch ._dynamo .reset ()
8097 os .environ .pop ("COMPILATION_MODE" , None )
8198
8299
83100@pytest .mark .parametrize (
84- "model_path,shape_type, batch_size,seq_length,max_new_tokens,expected_comp_time" , common_shapes
101+ "model_path,batch_size,seq_length,max_new_tokens,expected_comp_time" , COMMON_SHAPES
85102)
86- def test_compilation_time (model_path , shape_type , batch_size , seq_length , max_new_tokens , expected_comp_time ):
103+ def test_compilation_time (
104+ model_path , batch_size , seq_length , max_new_tokens , expected_comp_time
105+ ):
87106 """Test to validate time taken for model compilation."""
88107 torch .manual_seed (42 )
89108 torch .set_default_dtype (torch .float16 )
@@ -104,7 +123,7 @@ def test_compilation_time(model_path, shape_type, batch_size, seq_length, max_ne
104123 model = get_model (
105124 architecture = "hf_pretrained" ,
106125 device_type = "cpu" ,
107- data_type = torch .float16 ,
126+ data_type = torch .float16 ,
108127 fused_weights = False ,
109128 ** model_path_kwargs ,
110129 )
@@ -117,21 +136,14 @@ def test_compilation_time(model_path, shape_type, batch_size, seq_length, max_ne
117136 extra_kwargs ["attn_name" ] = ATTN_NAME
118137
119138 start_time = time .perf_counter ()
120- if shape_type == "dynamic" :
121- compile_dynamic_sendnn = True
139+ if COMMON_SHAPE_TYPE == "dynamic" :
140+ COMPILE_DYNAMIC_SHAPE = True
122141 else :
123- compile_dynamic_sendnn = False
142+ COMPILE_DYNAMIC_SHAPE = False
124143
125- model .compile (
126- backend = "sendnn" , options = {"sendnn.dynamic" : compile_dynamic_sendnn }
127- )
144+ model .compile (backend = "sendnn" , options = {"sendnn.dynamic" : COMPILE_DYNAMIC_SHAPE })
128145 warmup_model (
129- model ,
130- input_ids ,
131- max_new_tokens ,
132- compile_dynamic_sendnn ,
133- use_cache = False ,
134- ** extra_kwargs
146+ model , input_ids , max_new_tokens , COMPILE_DYNAMIC_SHAPE , ** extra_kwargs
135147 )
136148 end_time = time .perf_counter ()
137149
0 commit comments