Commit cded66e

committed
Address comments
1 parent 6ca7dd0 commit cded66e

2 files changed: +26 -30 lines changed

optimum/exporters/executorch/whisper_attention.py renamed to optimum/executorch/attentions/whisper_attention.py

Lines changed: 23 additions & 28 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 # Export friendly cross attention implementation for Whisper. Adopted
-# from https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py#L241
+# from https://github.com/huggingface/transformers/blob/454c0a7ccf33f7fc13e3e2eb9b188a5c09ab708b/src/transformers/models/whisper/modeling_whisper.py#L241
 # Rewritten to replace if branches with torch.cond. Note that unlike
 # the original WhisperAttention, this implementation only works for
 # cross attention (where `key_value_states` is not None).
@@ -22,7 +22,7 @@
 
 import torch
 from torch import Tensor, nn
-from transformers.cache_utils import Cache, EncoderDecoderCache
+from transformers.cache_utils import EncoderDecoderCache
 from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
 from transformers.models.whisper.configuration_whisper import WhisperConfig
@@ -81,7 +81,7 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         key_value_states: torch.Tensor,
-        past_key_values: Optional[Cache] = None,
+        past_key_values: EncoderDecoderCache,
         attention_mask: Optional[torch.Tensor] = None,
         output_attentions: bool = False,
         cache_position: Optional[torch.Tensor] = None,
@@ -90,7 +90,10 @@ def forward(
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
-
+        torch._assert(
+            isinstance(past_key_values, EncoderDecoderCache),
+            f"past_key_values must be an EncoderDecoderCache, got {type(past_key_values)}",
+        )
         # determine input shapes
         bsz, tgt_len = hidden_states.shape[:-1]
         q_input_shape = (bsz, tgt_len, -1, self.head_dim)
@@ -131,30 +134,22 @@ def recompute_kv(
             v = torch.ops.executorch.update_cross_attn_cache(value_states, cached_values)
             return k, v
 
-        if past_key_values is not None and self.layer_idx is not None:
-            # Grab cached tensors (these are Tensors, so they are OK for export)
-            cached_keys = past_key_values.layers[self.layer_idx].keys
-            cached_values = past_key_values.layers[self.layer_idx].values
-
-            # Tensor predicate: True if any element is non-zero
-            # Result is a 0-dim bool tensor suitable for torch.cond
-            cache_is_initialized = (cached_keys != 0).any()
-
-            # Use torch.cond to select branch in a traceable way.
-            # All operands must be (nested) tensors or simple Python values.
-            key_states, value_states = torch.cond(
-                cache_is_initialized,
-                use_cached_kv,
-                recompute_kv,
-                operands=(cached_keys, cached_values, key_value_states),
-            )
-
-        else:
-            # No cache available: always compute fresh K/V
-            key_states = self.k_proj(key_value_states).view(bsz, -1, self.num_heads, self.head_dim)
-            value_states = self.v_proj(key_value_states).view(bsz, -1, self.num_heads, self.head_dim)
-            key_states = key_states.transpose(1, 2).contiguous()
-            value_states = value_states.transpose(1, 2).contiguous()
+        # Grab cached tensors (these are Tensors, so they are OK for export)
+        cached_keys = past_key_values.layers[self.layer_idx].keys
+        cached_values = past_key_values.layers[self.layer_idx].values
+
+        # Tensor predicate: True if any element is non-zero
+        # Result is a 0-dim bool tensor suitable for torch.cond
+        cache_is_initialized = (cached_keys != 0).any()
+
+        # Use torch.cond to select branch in a traceable way.
+        # All operands must be (nested) tensors or simple Python values.
+        key_states, value_states = torch.cond(
+            cache_is_initialized,
+            use_cached_kv,
+            recompute_kv,
+            operands=(cached_keys, cached_values, key_value_states),
+        )
 
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
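
For context on the pattern this file relies on: torch.export cannot trace a Python `if` whose condition depends on tensor values, so the cache check above is expressed with torch.cond, which captures both branches in the exported graph. Below is a minimal, self-contained sketch of that pattern, not the PR's code; the names PickKV, cache, and fresh are illustrative, and both branch functions must take the same operands and return tensors of matching shape and dtype.

import torch


class PickKV(torch.nn.Module):
    """Return the cached tensor when the cache looks initialized, else the fresh one."""

    def forward(self, cache: torch.Tensor, fresh: torch.Tensor) -> torch.Tensor:
        def use_cached(cache, fresh):
            return cache.clone()

        def recompute(cache, fresh):
            return fresh.clone()

        # 0-dim bool tensor predicate: a plain Python `if` on this value would
        # hit a data-dependent control-flow error under torch.export.
        cache_is_initialized = (cache != 0).any()
        return torch.cond(cache_is_initialized, use_cached, recompute, (cache, fresh))


cache = torch.zeros(2, 4)   # all-zero cache stands in for "not yet written"
fresh = torch.randn(2, 4)   # stands in for freshly projected K/V

module = PickKV()
print(module(cache, fresh))             # eager: takes the `recompute` branch
print(module(torch.ones(2, 4), fresh))  # eager: takes the `use_cached` branch

# Both branches end up in the exported graph, so one exported program serves
# both the first decode step (cache empty) and later steps (cache populated).
exported = torch.export.export(module, (cache, fresh))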

optimum/exporters/executorch/integrations.py

Lines changed: 3 additions & 2 deletions
@@ -36,9 +36,9 @@
 from transformers.modeling_utils import AttentionInterface
 
 from optimum.executorch.attentions.custom_sdpa import get_custom_sdpa_for_ring_kv_cache
+from optimum.executorch.attentions.whisper_attention import WhisperCrossAttention
 
 from .utils import apply_chat_template_with_fallback, save_config_to_constant_methods
-from .whisper_attention import WhisperCrossAttention
 
 
 class VisionExportableModule(torch.nn.Module):
@@ -723,7 +723,7 @@ def __init__(self, model, max_static_cache_length, batch_size):
                 f"cross_attention_value_cache_{i}", self.cross_attention_cache.layers[i].values, persistent=False
             )
 
-        # Massage decoder to use cross attention.
+        # Use custom cross attention for Whisper.
         if isinstance(model, WhisperForConditionalGeneration):
             for layer in self.decoder.layers:
                 cross_attn = WhisperCrossAttention(
@@ -734,6 +734,7 @@ def __init__(self, model, max_static_cache_length, batch_size):
                     layer_idx=layer.encoder_attn.layer_idx,
                     config=layer.encoder_attn.config,
                 ).to(dtype=model.dtype, device=model.device)
+                cross_attn.q_proj = layer.encoder_attn.q_proj
                 cross_attn.k_proj = layer.encoder_attn.k_proj
                 cross_attn.v_proj = layer.encoder_attn.v_proj
                 cross_attn.out_proj = layer.encoder_attn.out_proj
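
A note on the added `cross_attn.q_proj` line: the custom attention module reuses all four of the original layer's projection modules rather than its own freshly initialized ones. A minimal sketch of that module-swap pattern follows; the class names OriginalAttn and CustomCrossAttn are illustrative stand-ins, not classes from this PR.

import torch
from torch import nn


class OriginalAttn(nn.Module):
    """Stand-in for a pretrained decoder layer's cross-attention."""

    def __init__(self, embed_dim: int):
        super().__init__()
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)


class CustomCrossAttn(nn.Module):
    """Stand-in for the export-friendly replacement module."""

    def __init__(self, embed_dim: int):
        super().__init__()
        # Randomly initialized; each of these must be overwritten below,
        # otherwise the swapped-in module silently runs with untrained weights.
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)


pretrained = OriginalAttn(embed_dim=8)
replacement = CustomCrossAttn(embed_dim=8)

# Reuse the pretrained projections; omitting q_proj (what the added line fixes)
# would leave the query projection at its random initialization.
replacement.q_proj = pretrained.q_proj
replacement.k_proj = pretrained.k_proj
replacement.v_proj = pretrained.v_proj
replacement.out_proj = pretrained.out_proj

x = torch.randn(1, 3, 8)
assert torch.equal(replacement.q_proj(x), pretrained.q_proj(x))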
