@@ -117,6 +117,7 @@ def forward(
         attention_mask: torch.Tensor,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
+        comp_ctx_lengths: Optional[torch.LongTensor] = None,
         batch_index: Optional[torch.LongTensor] = None,
         layer_past: Optional[Cache] = None,
         head_mask: Optional[torch.Tensor] = None,
@@ -140,7 +141,9 @@ def forward(
         query_layer, key_layer = qeff_apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids)
 
         if layer_past is not None:
-            cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
+            if comp_ctx_lengths is not None:
+                attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
+            cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids, "CCL": attention_mask.shape[-1]}
             key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs)
 
         if attention_mask is not None:
@@ -172,6 +175,7 @@ def forward(
         attention_mask: torch.Tensor,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
+        comp_ctx_lengths: Optional[torch.LongTensor] = None,
         batch_index: Optional[torch.LongTensor] = None,
         layer_past: Optional[Union[Cache, Tuple[torch.Tensor, torch.Tensor]]] = None,
         head_mask: Optional[torch.Tensor] = None,
@@ -195,6 +199,7 @@ def forward(
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
+            comp_ctx_lengths=comp_ctx_lengths,
             batch_index=batch_index,
             alibi=alibi,
             head_mask=head_mask,
@@ -245,6 +250,7 @@ def forward(
         position_ids: Optional[torch.LongTensor] = None,
         batch_index: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None,
+        comp_ctx_lengths: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
@@ -307,6 +313,7 @@ def forward(
                 attention_mask=causal_mask,
                 position_ids=position_ids,
                 past_key_value=past_key_values,
+                comp_ctx_lengths=comp_ctx_lengths,
                 batch_index=batch_index,
                 head_mask=head_mask[i],
                 use_cache=use_cache,
@@ -352,6 +359,7 @@ def forward(
         position_ids: Optional[torch.LongTensor] = None,
         batch_index: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+        comp_ctx_lengths: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = None,
@@ -368,6 +376,7 @@ def forward(
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            comp_ctx_lengths=comp_ctx_lengths,
             batch_index=batch_index,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
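
The only behavioral change in this diff is in the attention hunk: when comp_ctx_lengths is provided, the causal mask is sliced down to the compute-context window before the KV-cache update, and the resulting width is handed to the cache as "CCL". The remaining hunks are plumbing that threads the new comp_ctx_lengths parameter from the causal-LM forward down through the decoder layers to the attention module. Below is a minimal, self-contained sketch of the slicing effect; it is not the library's code, and the mask layout, the ccl value, and the assumption that a cache honoring "CCL" truncates its keys/values to the same width are all illustrative.

import torch

batch, heads, q_len, ctx_len, head_dim = 1, 4, 1, 512, 64
ccl = 128  # hypothetical compute-context length, i.e. comp_ctx_lengths.shape[-1]

# Boolean mask over the full cached context (True = masked out).
attention_mask = torch.zeros(batch, heads, q_len, ctx_len, dtype=torch.bool)
key_cache = torch.randn(batch, heads, ctx_len, head_dim)
value_cache = torch.randn(batch, heads, ctx_len, head_dim)
query = torch.randn(batch, heads, q_len, head_dim)

# Mirror of the hunk above: restrict the mask to the first ccl positions.
attention_mask = attention_mask[:, :, :, :ccl]

# Emulate a cache that honors cache_kwargs["CCL"] by returning truncated
# keys/values (assumption about layer_past.update's behavior).
key_states = key_cache[:, :, :ccl, :]
value_states = value_cache[:, :, :ccl, :]

# Attention now scores the query against only ccl cached positions
# instead of the full ctx_len.
scores = query @ key_states.transpose(-1, -2) / head_dim**0.5
scores = scores.masked_fill(attention_mask, torch.finfo(scores.dtype).min)
out = torch.softmax(scores, dim=-1) @ value_states  # shape (1, 4, 1, 64)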