
Commit 6cedad2

vjanfaza authored and quic-rishinr committed

Adding Compute-Context-Length (CCL)
Signed-off-by: Vahid Janfaza <vjanfaza@qti.qualcomm.com>
1 parent 0e9c851 commit 6cedad2

File tree

18 files changed: +365 −104 lines


QEfficient/transformers/models/codegen/modeling_codegen.py

Lines changed: 10 additions & 1 deletion
@@ -72,6 +72,7 @@ def forward(
         self,
         hidden_states: Optional[torch.FloatTensor],
         layer_past: Optional[Tuple[torch.Tensor]] = None,
+        comp_ctx_lengths: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
@@ -123,7 +124,9 @@ def forward(
         query = query.permute(0, 2, 1, 3)
 
         if layer_past is not None:
-            cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index}
+            if comp_ctx_lengths is not None:
+                attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
+            cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index, "CCL": attention_mask.shape[-1]}
             key, value = layer_past.update(key.to(hidden_states.dtype), value, self.layer_idx, cache_kwargs)
 
         # compute self-attention: V x Softmax(QK^T)
@@ -147,6 +150,7 @@ def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Union[Cache, tuple[tuple[torch.Tensor]]]] = None,
+        comp_ctx_lengths: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         token_type_ids: Optional[torch.LongTensor] = None,
         batch_index: Optional[torch.LongTensor] = None,
@@ -245,6 +249,7 @@ def forward(
         outputs = block(
             hidden_states,
             layer_past=past_key_values,
+            comp_ctx_lengths=comp_ctx_lengths,
             batch_index=batch_index,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -294,6 +299,7 @@ def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        comp_ctx_lengths: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         token_type_ids: Optional[torch.LongTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
@@ -312,6 +318,7 @@ def forward(
         transformer_outputs = self.transformer(
             input_ids,
             past_key_values=past_key_values,
+            comp_ctx_lengths=comp_ctx_lengths,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             batch_index=batch_index,
@@ -348,6 +355,7 @@ def forward(
         self,
         hidden_states: Optional[torch.FloatTensor],
         layer_past: Optional[Cache] = None,
+        comp_ctx_lengths: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         batch_index: Optional[torch.LongTensor] = None,
@@ -361,6 +369,7 @@ def forward(
         attn_outputs, attn_weights = self.attn(
             hidden_states=hidden_states,
             layer_past=layer_past,
+            comp_ctx_lengths=comp_ctx_lengths,
             attention_mask=attention_mask,
             position_ids=position_ids,
             batch_index=batch_index,
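
Note: most of the model files below repeat the same two-step pattern seen here: truncate the causal mask to the compute-context-length before the cache update, then pass the truncated width to the cache as "CCL". A minimal, self-contained sketch of the effect on the attention computation follows; the tensor names and shapes are illustrative, not taken from the repo.

import torch

def ccl_attention(query, key_cache, value_cache, attention_mask, comp_ctx_lengths=None):
    # query: (B, H, Q, D); key_cache/value_cache: (B, H, ctx_len, D)
    # attention_mask: (B, 1, Q, ctx_len), additive causal mask
    if comp_ctx_lengths is not None:
        ccl = comp_ctx_lengths.shape[-1]  # only .shape[-1] is read, as in the diffs
        attention_mask = attention_mask[:, :, :, :ccl]
        key_cache = key_cache[:, :, :ccl, :]    # attend over the first ccl slots only
        value_cache = value_cache[:, :, :ccl, :]
    scores = query @ key_cache.transpose(-1, -2) / key_cache.shape[-1] ** 0.5
    return torch.softmax(scores + attention_mask, dim=-1) @ value_cache

The score matrix shrinks from (B, H, Q, ctx_len) to (B, H, Q, ccl), which is where the compute saving comes from.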

QEfficient/transformers/models/falcon/modeling_falcon.py

Lines changed: 10 additions & 1 deletion
@@ -117,6 +117,7 @@ def forward(
         attention_mask: torch.Tensor,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
+        comp_ctx_lengths: Optional[torch.LongTensor] = None,
         batch_index: Optional[torch.LongTensor] = None,
         layer_past: Optional[Cache] = None,
         head_mask: Optional[torch.Tensor] = None,
@@ -140,7 +141,9 @@ def forward(
         query_layer, key_layer = qeff_apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids)
 
         if layer_past is not None:
-            cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
+            if comp_ctx_lengths is not None:
+                attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
+            cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids, "CCL": attention_mask.shape[-1]}
             key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs)
 
         if attention_mask is not None:
@@ -172,6 +175,7 @@ def forward(
         attention_mask: torch.Tensor,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
+        comp_ctx_lengths: Optional[torch.LongTensor] = None,
         batch_index: Optional[torch.LongTensor] = None,
         layer_past: Optional[Union[Cache, Tuple[torch.Tensor, torch.Tensor]]] = None,
         head_mask: Optional[torch.Tensor] = None,
@@ -195,6 +199,7 @@ def forward(
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
+            comp_ctx_lengths=comp_ctx_lengths,
             batch_index=batch_index,
             alibi=alibi,
             head_mask=head_mask,
@@ -245,6 +250,7 @@ def forward(
         position_ids: Optional[torch.LongTensor] = None,
         batch_index: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None,
+        comp_ctx_lengths: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
@@ -307,6 +313,7 @@ def forward(
                 attention_mask=causal_mask,
                 position_ids=position_ids,
                 past_key_value=past_key_values,
+                comp_ctx_lengths=comp_ctx_lengths,
                 batch_index=batch_index,
                 head_mask=head_mask[i],
                 use_cache=use_cache,
@@ -352,6 +359,7 @@ def forward(
         position_ids: Optional[torch.LongTensor] = None,
         batch_index: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+        comp_ctx_lengths: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.Tensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = None,
@@ -368,6 +376,7 @@ def forward(
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            comp_ctx_lengths=comp_ctx_lengths,
             batch_index=batch_index,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
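
The cache-side consumer of the "CCL" entry in cache_kwargs is not shown in this section of the diff; presumably the cache update uses it to bound how much of the key/value buffers it hands back. A toy stand-in, clearly not QEfficient's real cache class, to make that contract concrete:

import torch

class ToyKVCache:
    # Hypothetical stand-in for the object the diffs call layer_past /
    # past_key_value; only the "CCL" handling is modeled here.
    def __init__(self, batch, heads, ctx_len, head_dim):
        self.k = torch.zeros(batch, heads, ctx_len, head_dim)
        self.v = torch.zeros(batch, heads, ctx_len, head_dim)

    def update(self, key, value, layer_idx, cache_kwargs):
        pos = cache_kwargs["position_ids"]           # (B, Q) slots to write
        ccl = cache_kwargs.get("CCL", self.k.shape[2])
        b = torch.arange(self.k.shape[0]).unsqueeze(-1)
        self.k[b, :, pos] = key.transpose(1, 2)      # scatter new keys/values
        self.v[b, :, pos] = value.transpose(1, 2)
        # Return only the first `ccl` cache slots so downstream attention
        # never reads past the compute-context-length window.
        return self.k[:, :, :ccl], self.v[:, :, :ccl]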

QEfficient/transformers/models/gemma3/modeling_gemma3.py

Lines changed: 18 additions & 6 deletions
@@ -603,7 +603,13 @@ def __init__(self, model):
         self.lm_head = self.model.lm_head
 
     def forward(
-        self, input_ids, vision_embeds, position_ids, image_idx, past_key_values, comp_ctx_lengths: List[int] = None
+        self,
+        input_ids,
+        vision_embeds,
+        position_ids,
+        image_idx,
+        past_key_values,
+        comp_ctx_lengths: Optional[List[int]] = None,
     ):
         inputs_embeds = self.model.get_input_embeddings()(input_ids)
         B, N, C = inputs_embeds.shape
@@ -637,7 +643,13 @@ def get_qeff_language_decoder(self):
         return QEffGemma3DecoderWrapper(self)
 
     def forward(
-        self, input_ids, position_ids, pixel_values, image_idx, past_key_values, comp_ctx_lengths: List[int] = None
+        self,
+        input_ids,
+        position_ids,
+        pixel_values,
+        image_idx,
+        past_key_values,
+        comp_ctx_lengths: Optional[List[int]] = None,
     ):
         image_features = self.get_image_features(pixel_values=pixel_values)
         inputs_embeds = self.get_input_embeddings()(input_ids)
@@ -669,8 +681,8 @@ def get_specializations(
         prefill_seq_len: int,
         ctx_len: int,
         img_size: int,
-        comp_ctx_lengths_prefill: List[int] = None,
-        comp_ctx_lengths_decode: List[int] = None,
+        comp_ctx_lengths_prefill: Optional[List[int]] = None,
+        comp_ctx_lengths_decode: Optional[List[int]] = None,
         kv_offload: bool = False,
         **compiler_options,
     ):
@@ -749,7 +761,7 @@ def get_specializations(
         else:
             return lang, compiler_options
 
-    def get_onnx_dynamic_axes(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False):
+    def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False):
         # Define dynamic axes
         vision_dynamic_axes = {}
         lang_dynamic_axes = {}
@@ -825,7 +837,7 @@ def get_dummy_pkv_cache(self, config, batch_size, seq_len):
             past_key_values.append(pkv)
         return past_key_values
 
-    def get_dummy_inputs(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False):
+    def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False):
         if vis_cfg := getattr(self.config, "vision_config", None):
             img_size = getattr(vis_cfg, "image_size", 896)
         else:
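
get_specializations now takes separate CCL lists for prefill and decode, which suggests one compiled specialization per bucket. How a bucket is selected at runtime is not part of this section of the diff; a plausible helper, with a hypothetical name, might look like:

from typing import List

def pick_ccl_bucket(cache_position: int, comp_ctx_lengths: List[int], ctx_len: int) -> int:
    # Smallest compute-context-length bucket that still covers every
    # cached position the current step can attend to; fall back to the
    # full context window once generation outgrows all buckets.
    for ccl in sorted(comp_ctx_lengths):
        if cache_position < ccl:
            return ccl
    return ctx_len

print(pick_ccl_bucket(700, [512, 1024, 2048], ctx_len=4096))  # 1024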

QEfficient/transformers/models/gpt2/modeling_gpt2.py

Lines changed: 10 additions & 1 deletion
@@ -65,6 +65,7 @@ def forward(
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
         past_key_value: Optional[Cache] = None,
+        comp_ctx_lengths: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         batch_index: Optional[torch.LongTensor] = None,
@@ -118,9 +119,11 @@ def forward(
         if (past_key_value is not None and not is_cross_attention) or (
             past_key_value is not None and is_cross_attention and not is_updated
         ):
+            if comp_ctx_lengths is not None:
+                attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
             # save all key/value_layer to cache to be re-used for fast auto-regressive generation
             # Update the cache_kwargs with position_ids for Cloud AI 100
-            cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index}
+            cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index, "CCL": attention_mask.shape[-1]}
             key_states, value_states = curr_past_key_value.update(
                 key_states, value_states, self.layer_idx, cache_kwargs
             )
@@ -156,6 +159,7 @@ def forward(
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
         past_key_value: Optional[Cache] = None,
+        comp_ctx_lengths: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         batch_index: Optional[torch.LongTensor] = None,
@@ -174,6 +178,7 @@ def forward(
             hidden_states,
             past_key_value=past_key_value,
             attention_mask=attention_mask,
+            comp_ctx_lengths=comp_ctx_lengths,
             position_ids=position_ids,
             batch_index=batch_index,
             head_mask=head_mask,
@@ -232,6 +237,7 @@ def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        comp_ctx_lengths: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         token_type_ids: Optional[torch.LongTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
@@ -341,6 +347,7 @@ def forward(
         outputs = block(
             hidden_states,
             past_key_value=past_key_values,
+            comp_ctx_lengths=comp_ctx_lengths,
             attention_mask=attention_mask,
             position_ids=position_ids,
             batch_index=batch_index,
@@ -392,6 +399,7 @@ def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        comp_ctx_lengths: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         token_type_ids: Optional[torch.LongTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
@@ -418,6 +426,7 @@ def forward(
         transformer_outputs = self.transformer(
             input_ids,
             past_key_values=past_key_values,
+            comp_ctx_lengths=comp_ctx_lengths,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,

QEfficient/transformers/models/gptj/modeling_gptj.py

Lines changed: 10 additions & 1 deletion
@@ -83,6 +83,7 @@ def forward(
         self,
         hidden_states: torch.FloatTensor,
         layer_past: Optional[Tuple[torch.Tensor]] = None,
+        comp_ctx_lengths: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         batch_index: Optional[torch.LongTensor] = None,
@@ -134,7 +135,9 @@ def forward(
         query = query.permute(0, 2, 1, 3)
 
         if layer_past is not None:
-            cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index}
+            if comp_ctx_lengths is not None:
+                attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
+            cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index, "CCL": attention_mask.shape[-1]}
             key, value = layer_past.update(key, value, self.layer_idx, cache_kwargs)
 
         # compute self-attention: V x Softmax(QK^T)
@@ -151,6 +154,7 @@ def forward(
         self,
         hidden_states: Optional[torch.FloatTensor],
         layer_past: Optional[Cache] = None,
+        comp_ctx_lengths: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         batch_index: Optional[torch.LongTensor] = None,
@@ -164,6 +168,7 @@ def forward(
         attn_outputs, attn_weights = self.attn(
             hidden_states=hidden_states,
             layer_past=layer_past,
+            comp_ctx_lengths=comp_ctx_lengths,
             attention_mask=attention_mask,
             position_ids=position_ids,
             batch_index=batch_index,
@@ -191,6 +196,7 @@ def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        comp_ctx_lengths: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         token_type_ids: Optional[torch.LongTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
@@ -270,6 +276,7 @@ def forward(
         outputs = block(
             hidden_states=hidden_states,
             layer_past=past_key_values,
+            comp_ctx_lengths=comp_ctx_lengths,
             attention_mask=causal_mask,
             position_ids=position_ids,
             batch_index=batch_index,
@@ -314,6 +321,7 @@ def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        comp_ctx_lengths: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         token_type_ids: Optional[torch.LongTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
@@ -339,6 +347,7 @@ def forward(
         transformer_outputs = self.transformer(
             input_ids,
             past_key_values=past_key_values,
+            comp_ctx_lengths=comp_ctx_lengths,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,

QEfficient/transformers/models/grok_1/modeling_grok1.py

Lines changed: 10 additions & 1 deletion
@@ -55,6 +55,7 @@ def forward(
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        comp_ctx_lengths: Optional[torch.LongTensor] = None,
         batch_index: Optional[torch.LongTensor] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
@@ -93,7 +94,9 @@ def forward(
         query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:
-            cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
+            if comp_ctx_lengths is not None:
+                attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
+            cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids, "CCL": attention_mask.shape[-1]}
             key_states, value_states = past_key_value.update(key_states, value_states, layer_idx, cache_kwargs)
 
         # repeat k/v heads if n_kv_heads < n_heads
@@ -205,6 +208,7 @@ def forward(
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        comp_ctx_lengths: Optional[torch.LongTensor] = None,
         batch_index: Optional[torch.LongTensor] = None,
         output_attentions: Optional[bool] = False,
         output_router_logits: Optional[bool] = False,
@@ -235,6 +239,7 @@ def forward(
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
+            comp_ctx_lengths=comp_ctx_lengths,
             batch_index=batch_index,
             output_attentions=output_attentions,
             use_cache=use_cache,
@@ -277,6 +282,7 @@ def forward(
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[List[torch.FloatTensor]] = None,
+        comp_ctx_lengths: Optional[torch.LongTensor] = None,
         batch_index: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
@@ -351,6 +357,7 @@ def forward(
                 attention_mask=attention_mask,
                 position_ids=position_ids,
                 past_key_value=past_key_values,
+                comp_ctx_lengths=comp_ctx_lengths,
                 batch_index=batch_index,
                 output_attentions=output_attentions,
                 use_cache=use_cache,
@@ -395,6 +402,7 @@ def forward(
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[List[torch.FloatTensor]] = None,
+        comp_ctx_lengths: Optional[torch.LongTensor] = None,
         batch_index: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
@@ -441,6 +449,7 @@ def forward(
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            comp_ctx_lengths=comp_ctx_lengths,
             batch_index=batch_index,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
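
The payoff: decode-step attention cost scales with the mask width, so until the sequence outgrows a bucket, each step reads only that bucket's worth of KV cache instead of the full ctx_len. A back-of-envelope check, using the same bucket rule as the hypothetical pick_ccl_bucket sketch above:

def decode_mask_width(pos: int, ctx_len: int, ccl_buckets) -> int:
    # Width of the attention computation at decode step `pos`: the
    # smallest CCL bucket covering `pos`, else the full window.
    return min((c for c in ccl_buckets if pos < c), default=ctx_len)

# With ctx_len = 4096 and buckets [1024, 2048], steps before 1024 read a
# quarter of the cache, steps before 2048 read half, and only the tail
# pays for the full 4096-wide window.
print([decode_mask_width(p, 4096, [1024, 2048]) for p in (10, 1500, 3000)])
# [1024, 2048, 4096]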
