@@ -98,52 +98,64 @@ def __init__(self) -> None:
 def generate_dummy_past_key_values(config, input_bs):
     """Generate the dummy past_key_values."""
     from optimum.utils import NormalizedConfigManager
-
-    normalized_config = NormalizedConfigManager.get_normalized_config_class(
-        config.model_type
-    )(config)
-    nb_pkv = 2
-    num_layers = normalized_config.num_layers
-    num_attention_heads = normalized_config.num_attention_heads
-    hidden_size = normalized_config.hidden_size
-    d_k = hidden_size // num_attention_heads
-    num_key_value_heads = num_attention_heads
-    if hasattr(normalized_config, "num_key_value_heads"):
-        num_key_value_heads = normalized_config.num_key_value_heads
-    if hasattr(normalized_config, "multi_query_group_num"):
-        num_key_value_heads = normalized_config.multi_query_group_num
-
-    if config.model_type == "bloom":
-        shape_key = (input_bs * num_attention_heads, d_k, 1)
-        shape_value = (input_bs * num_attention_heads, 1, d_k)
-        key = torch.ones(size=shape_key)
-        value = torch.ones(size=shape_value)
-        past_key_values = tuple(
-            tuple(key if idx % 2 == 0 else value for idx in range(nb_pkv))
-            for _ in range(num_layers)
-        )
-        return past_key_values
-    elif config.model_type == "gpt_bigcode":
-        new_shape = [input_bs, 0, d_k * 2]
-        dummy_tensor = torch.zeros(size=new_shape)
-        past_key_values = tuple([dummy_tensor] * num_layers)
-        return past_key_values
-    elif config.model_type == "qwen":
-        new_shape = [input_bs, 1, num_key_value_heads, d_k]
-        past_key_values = [
-            (
-                torch.ones(size=new_shape).contiguous(),
-                torch.ones(size=new_shape).contiguous(),
-            )
-            for _ in range(num_layers)
+    if config.model_type == "qwen":
+        new_shape = [
+            input_bs,
+            0,
+            config.num_attention_heads,
+            config.hidden_size // config.num_attention_heads,
         ]
-        return tuple(past_key_values)
+        num_layers = config.num_hidden_layers
+    elif config.model_type == "baichuan":
+        new_shape = [
+            input_bs,
+            config.num_attention_heads,
+            0,
+            config.hidden_size // config.num_attention_heads,
+        ]
+        num_layers = config.num_hidden_layers
     elif config.model_type == "chatglm":
-        new_shape = [0, input_bs, num_key_value_heads, d_k]
-    elif config.model_type == "falcon":
-        new_shape = [input_bs, 1, 0, d_k]
+        new_shape = [
+            0,
+            input_bs,
+            config.num_attention_heads,
+            config.hidden_size // config.num_attention_heads,
+        ]
+        num_layers = config.num_layers
     else:
-        new_shape = [input_bs, num_key_value_heads, 0, d_k]
+        normalized_config = NormalizedConfigManager.get_normalized_config_class(
+            config.model_type
+        )(config)
+        nb_pkv = 2
+        num_layers = normalized_config.num_layers
+        num_attention_heads = normalized_config.num_attention_heads
+        hidden_size = normalized_config.hidden_size
+        d_k = hidden_size // num_attention_heads
+        num_key_value_heads = num_attention_heads
+        if hasattr(normalized_config, "num_key_value_heads"):
+            num_key_value_heads = normalized_config.num_key_value_heads
+        if hasattr(normalized_config, "multi_query_group_num"):
+            num_key_value_heads = normalized_config.multi_query_group_num
+
+        if config.model_type == "bloom":
+            shape_key = (input_bs * num_attention_heads, d_k, 1)
+            shape_value = (input_bs * num_attention_heads, 1, d_k)
+            key = torch.ones(size=shape_key)
+            value = torch.ones(size=shape_value)
+            past_key_values = tuple(
+                tuple(key if idx % 2 == 0 else value for idx in range(nb_pkv))
+                for _ in range(num_layers)
+            )
+            return past_key_values
+        elif config.model_type == "gpt_bigcode":
+            new_shape = [input_bs, 0, d_k * 2]
+            dummy_tensor = torch.zeros(size=new_shape)
+            past_key_values = tuple([dummy_tensor] * num_layers)
+            return past_key_values
+        elif config.model_type == "falcon":
+            new_shape = [input_bs, 1, 0, d_k]
+        else:
+            new_shape = [input_bs, num_key_value_heads, 0, d_k]
     past_key_values = [
         (
             torch.zeros(size=new_shape).contiguous(),
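For context only (not part of the commit): a minimal sketch of what the new qwen branch produces, assuming the shared tail of the function, only partly visible in this hunk, keeps building one (key, value) pair of zero tensors per layer from new_shape. The config values below are made up for illustration.

from types import SimpleNamespace

import torch

config = SimpleNamespace(          # hypothetical qwen-style config
    model_type="qwen",
    num_attention_heads=32,
    hidden_size=4096,
    num_hidden_layers=2,           # kept tiny for the example
)
input_bs = 1

# Mirror of the new qwen branch: an "empty" cache with a zero-length
# sequence axis, one (key, value) pair per layer.
new_shape = [
    input_bs,
    0,
    config.num_attention_heads,
    config.hidden_size // config.num_attention_heads,
]
past_key_values = tuple(
    (
        torch.zeros(size=new_shape).contiguous(),
        torch.zeros(size=new_shape).contiguous(),
    )
    for _ in range(config.num_hidden_layers)
)
print(past_key_values[0][0].shape)   # torch.Size([1, 0, 32, 128])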
@@ -156,44 +168,64 @@ def generate_dummy_past_key_values(config, input_bs):
 def generate_dummy_past_key_values_for_inference(config, input_bs):
     """Generate the dummy past_key_values."""
     from optimum.utils import NormalizedConfigManager
-
-    normalized_config = NormalizedConfigManager.get_normalized_config_class(
-        config.model_type
-    )(config)
-    nb_pkv = 2
-    num_layers = normalized_config.num_layers
-    num_attention_heads = normalized_config.num_attention_heads
-    hidden_size = normalized_config.hidden_size
-    d_k = hidden_size // num_attention_heads
-    num_key_value_heads = num_attention_heads
-    if hasattr(normalized_config, "num_key_value_heads"):
-        num_key_value_heads = normalized_config.num_key_value_heads
-    if hasattr(normalized_config, "multi_query_group_num"):
-        num_key_value_heads = normalized_config.multi_query_group_num
-
-    if config.model_type == "bloom":
-        shape_key = (input_bs * num_attention_heads, d_k, 0)
-        shape_value = (input_bs * num_attention_heads, 0, d_k)
-        key = torch.empty(size=shape_key)
-        value = torch.empty(size=shape_value)
-        past_key_values = tuple(
-            tuple(key if idx % 2 == 0 else value for idx in range(nb_pkv))
-            for _ in range(num_layers)
-        )
-        return past_key_values
-    elif config.model_type == "gpt_bigcode":
-        new_shape = [input_bs, 0, d_k * 2]
-        dummy_tensor = torch.zeros(size=new_shape)
-        past_key_values = tuple([dummy_tensor] * num_layers)
-        return past_key_values
-    elif config.model_type == "qwen":
-        new_shape = [input_bs, 0, num_key_value_heads, d_k]
+    if config.model_type == "qwen":
+        new_shape = [
+            input_bs,
+            0,
+            config.num_attention_heads,
+            config.hidden_size // config.num_attention_heads,
+        ]
+        num_layers = config.num_hidden_layers
+    elif config.model_type == "baichuan":
+        new_shape = [
+            input_bs,
+            config.num_attention_heads,
+            0,
+            config.hidden_size // config.num_attention_heads,
+        ]
+        num_layers = config.num_hidden_layers
     elif config.model_type == "chatglm":
-        new_shape = [0, input_bs, num_key_value_heads, d_k]
-    elif config.model_type == "falcon":
-        new_shape = [input_bs, 1, 0, d_k]
+        new_shape = [
+            0,
+            input_bs,
+            config.num_attention_heads,
+            config.hidden_size // config.num_attention_heads,
+        ]
+        num_layers = config.num_layers
     else:
-        new_shape = [input_bs, num_key_value_heads, 0, d_k]
+        normalized_config = NormalizedConfigManager.get_normalized_config_class(
+            config.model_type
+        )(config)
+        nb_pkv = 2
+        num_layers = normalized_config.num_layers
+        num_attention_heads = normalized_config.num_attention_heads
+        hidden_size = normalized_config.hidden_size
+        d_k = hidden_size // num_attention_heads
+        num_key_value_heads = num_attention_heads
+        if hasattr(normalized_config, "num_key_value_heads"):
+            num_key_value_heads = normalized_config.num_key_value_heads
+        if hasattr(normalized_config, "multi_query_group_num"):
+            num_key_value_heads = normalized_config.multi_query_group_num
+
+        if config.model_type == "bloom":
+            shape_key = (input_bs * num_attention_heads, d_k, 0)
+            shape_value = (input_bs * num_attention_heads, 0, d_k)
+            key = torch.empty(size=shape_key)
+            value = torch.empty(size=shape_value)
+            past_key_values = tuple(
+                tuple(key if idx % 2 == 0 else value for idx in range(nb_pkv))
+                for _ in range(num_layers)
+            )
+            return past_key_values
+        elif config.model_type == "gpt_bigcode":
+            new_shape = [input_bs, 0, d_k * 2]
+            dummy_tensor = torch.zeros(size=new_shape)
+            past_key_values = tuple([dummy_tensor] * num_layers)
+            return past_key_values
+        elif config.model_type == "falcon":
+            new_shape = [input_bs, 1, 0, d_k]
+        else:
+            new_shape = [input_bs, num_key_value_heads, 0, d_k]
     past_key_values = [
         (
             torch.zeros(size=new_shape).contiguous(),
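Again for context only (not part of the commit): the three layouts hard-coded above differ only in where the initially empty sequence axis sits. A quick sketch with made-up head counts and head size:

import torch

input_bs, heads, head_dim = 1, 32, 128   # hypothetical sizes

layouts = {
    "qwen":     [input_bs, 0, heads, head_dim],   # [batch, seq, heads, head_dim]
    "baichuan": [input_bs, heads, 0, head_dim],   # [batch, heads, seq, head_dim]
    "chatglm":  [0, input_bs, heads, head_dim],   # [seq, batch, heads, head_dim]
}
for model_type, shape in layouts.items():
    kv = torch.zeros(size=shape).contiguous()
    print(model_type, tuple(kv.shape))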
@@ -206,32 +238,53 @@ def generate_dummy_past_key_values_for_inference(config, input_bs):
 def generate_dummy_past_key_values_for_opt_llm(config, input_bs, num_beams=1):
     """Generate the dummy past_key_values."""
     from optimum.utils import NormalizedConfigManager
-
-    normalized_config = NormalizedConfigManager.get_normalized_config_class(
-        config.model_type
-    )(config)
-    num_layers = normalized_config.num_layers
-    num_attention_heads = normalized_config.num_attention_heads
-    hidden_size = normalized_config.hidden_size
-    d_k = hidden_size // num_attention_heads
-    num_key_value_heads = num_attention_heads
-    nb_pkv = 2
-    if hasattr(normalized_config, "num_key_value_heads"):
-        num_key_value_heads = normalized_config.num_key_value_heads
-    if hasattr(normalized_config, "multi_query_group_num"):
-        num_key_value_heads = normalized_config.multi_query_group_num
-    if config.model_type == "bloom":
-        for nb_pkv in range(nb_pkv):
-            if nb_pkv % 2 == 0:
-                new_shape = [input_bs * num_key_value_heads, d_k, 1]
-            else:
-                new_shape = [input_bs * num_key_value_heads, 1, d_k]
-    elif config.model_type == "qwen":
-        new_shape = [input_bs, 1, num_key_value_heads, d_k]
+    if config.model_type == "qwen":
+        new_shape = [
+            input_bs,
+            1,
+            config.num_attention_heads,
+            config.hidden_size // config.num_attention_heads,
+        ]
+        num_layers = config.num_hidden_layers
+    elif config.model_type == "baichuan":
+        new_shape = [
+            input_bs,
+            config.num_attention_heads,
+            1,
+            config.hidden_size // config.num_attention_heads,
+        ]
+        num_layers = config.num_hidden_layers
     elif config.model_type == "chatglm":
-        new_shape = [1, input_bs, num_key_value_heads, d_k]
+        new_shape = [
+            1,
+            input_bs,
+            config.num_attention_heads,
+            config.hidden_size // config.num_attention_heads,
+        ]
+        num_layers = config.num_layers
     else:
-        new_shape = [input_bs, num_key_value_heads, 1, d_k]
+        normalized_config = NormalizedConfigManager.get_normalized_config_class(
+            config.model_type
+        )(config)
+        num_layers = normalized_config.num_layers
+        num_attention_heads = normalized_config.num_attention_heads
+        hidden_size = normalized_config.hidden_size
+        d_k = hidden_size // num_attention_heads
+        num_key_value_heads = num_attention_heads
+        nb_pkv = 2
+        if hasattr(normalized_config, "num_key_value_heads"):
+            num_key_value_heads = normalized_config.num_key_value_heads
+        if hasattr(normalized_config, "multi_query_group_num"):
+            num_key_value_heads = normalized_config.multi_query_group_num
+        if config.model_type == "bloom":
+            for nb_pkv in range(nb_pkv):
+                if nb_pkv % 2 == 0:
+                    new_shape = [input_bs * num_key_value_heads, d_k, 1]
+                else:
+                    new_shape = [input_bs * num_key_value_heads, 1, d_k]
+
+        else:
+            new_shape = [input_bs, num_key_value_heads, 1, d_k]

     beam_idx_tmp = torch.zeros(
         (2048, int(input_bs * num_beams)), dtype=torch.long
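For context only (not part of the commit): unlike the two functions above, the *_for_opt_llm variant seeds the dummy cache with a sequence length of 1 rather than 0, and the unchanged tail visible here allocates a beam-index buffer sized for up to 2048 decoding steps. A sketch of the resulting shapes with made-up values:

import torch

input_bs, num_beams = 1, 4          # hypothetical values
heads, head_dim = 32, 128           # hypothetical head count and head size

new_shape = [input_bs, 1, heads, head_dim]        # qwen layout with seq = 1
dummy_kv = torch.zeros(size=new_shape).contiguous()
beam_idx_tmp = torch.zeros((2048, int(input_bs * num_beams)), dtype=torch.long)
print(tuple(dummy_kv.shape), tuple(beam_idx_tmp.shape))   # (1, 1, 32, 128) (2048, 4)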