1+ """This module contains test related to compilation operation"""

# Standard
import itertools
import os
import time

# Third Party
from torch import distributed as dist
import pytest
import torch

# First Party
from fms.models import get_model
from fms.utils import generation, tokenizers
from fms.utils.generation import pad_input_ids

# Local
from aiu_fms_testing_utils.utils import (
    ids_for_prompt,
    sample_sharegpt_requests,
    warmup_model,
)
from aiu_fms_testing_utils.utils.aiu_setup import dprint
GRANITE_3p3_8B_INSTRUCT = "ibm-granite/granite-3.3-8b-instruct"
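# Path to a local copy of the ShareGPT dataset used to build test prompts;
# defaults to ~/share_gpt.json but can be overridden via the environment.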
SHARE_GPT_DATASET_PATH = os.environ.get(
    "SHARE_GPT_DATASET_PATH", os.path.expanduser("~/share_gpt.json")
)

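# Attention implementation requested from the model; the name suggests this
# selects the paged-attention path on the Spyre (AIU) backend.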
ATTN_NAME = "spyre_paged_attn"

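# Module-level default: compile with dynamic shapes (sendnn.dynamic). The test
# itself recomputes this per parametrization from shape_type below.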
compile_dynamic_sendnn = True

common_model_paths = [GRANITE_3p3_8B_INSTRUCT]
common_batch_sizes = [1]
common_seq_lengths = [256]
common_shape_types = ["dynamic"]
common_max_new_tokens = [128]
common_expected_comp_time = [10]  # In minutes

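# When compiling with dynamic shapes, the backend appears to size its maximum
# context and batch from these variables: VLLM_DT_MAX_CONTEXT_LEN is the
# worst case (max prompt length + max new tokens) bumped up to the next
# multiple of 64, and VLLM_DT_MAX_BATCH_SIZE is at least 2.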
if compile_dynamic_sendnn:
    os.environ["VLLM_DT_MAX_CONTEXT_LEN"] = str(
        (((max(common_seq_lengths) + max(common_max_new_tokens)) // 64) + 1) * 64
    )
    os.environ["VLLM_DT_MAX_BATCH_SIZE"] = str(max(max(common_batch_sizes), 2))

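# Each shape is one full test configuration:
# (model_path, shape_type, batch_size, seq_length, max_new_tokens, expected_comp_time)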
common_shapes = list(
    zip(
        common_model_paths,
        common_shape_types,
        common_batch_sizes,
        common_seq_lengths,
        common_max_new_tokens,
        common_expected_comp_time,
    )
)


# TODO: This is copied from test_decoders.py; it would be good to consolidate
def __prepare_inputs(batch_size, seq_length, tokenizer, seed=0):
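    """Sample batch_size prompts from ShareGPT whose tokenized length falls
    between seq_length // 2 and seq_length, tokenize them, and pad each one to
    seq_length. Returns (input_ids, extra_kwargs) ready for generation."""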
    prompts_and_sizes = sample_sharegpt_requests(
        SHARE_GPT_DATASET_PATH,
        batch_size,
        tokenizer,
        seq_length // 2,
        seq_length,
        seed,
    )
    prompt_list = [ids_for_prompt(prompt, tokenizer) for prompt, _ in prompts_and_sizes]

    input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=seq_length)
    return input_ids, extra_kwargs


@pytest.fixture(autouse=True)
def reset_compiler():
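    """Reset compiler state between tests. Skipped when compiling with dynamic
    shapes, presumably so the compiled program can be reused across tests."""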
    yield  # run the test
    if not compile_dynamic_sendnn:
        torch.compiler.reset()
        torch._dynamo.reset()
        os.environ.pop("COMPILATION_MODE", None)


@pytest.mark.parametrize(
    "model_path,shape_type,batch_size,seq_length,max_new_tokens,expected_comp_time",
    common_shapes,
)
def test_compilation_time(
    model_path, shape_type, batch_size, seq_length, max_new_tokens, expected_comp_time
):
    """Validate that model compilation (plus a warmup generation pass)
    completes within the expected time budget."""
    torch.manual_seed(42)
    torch.set_default_dtype(torch.float16)
    os.environ["COMPILATION_MODE"] = "offline_decoder"

    dprint(
        f"testing model={model_path}, batch_size={batch_size}, seq_length={seq_length}"
    )

    if os.path.exists(model_path):
        model_path_kwargs = {"model_path": model_path}
    else:
        model_path_kwargs = {"variant": model_path}

    tokenizer = tokenizers.get_tokenizer(model_path)

    # prepare the AIU model
    model = get_model(
        architecture="hf_pretrained",
        device_type="cpu",
        data_type=torch.float16,
        fused_weights=False,
        **model_path_kwargs,
    )

    model.eval()
    torch.set_grad_enabled(False)

    # prepare input_ids
    input_ids, extra_kwargs = __prepare_inputs(batch_size, seq_length, tokenizer)
    extra_kwargs["attn_name"] = ATTN_NAME

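    # Time the full compile + warmup path: compilation is triggered lazily, so
    # the warmup generation pass is included in the measured window.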
    start_time = time.perf_counter()
    compile_dynamic_sendnn = shape_type == "dynamic"

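    # sendnn is the Spyre (AIU) torch.compile backend; the "sendnn.dynamic"
    # option toggles dynamic-shape compilation for this run.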
    model.compile(
        backend="sendnn", options={"sendnn.dynamic": compile_dynamic_sendnn}
    )
    warmup_model(
        model,
        input_ids,
        max_new_tokens,
        compile_dynamic_sendnn,
        use_cache=False,
        **extra_kwargs,
    )
    end_time = time.perf_counter()

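    # expected_comp_time is in minutes; perf_counter deltas are in seconds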
    assert (end_time - start_time) < expected_comp_time * 60