
Commit 14a9a0e

Try adding a flag on CPU
1 parent cded66e commit 14a9a0e

1 file changed (+6, -5 lines)

1 file changed

+6
-5
lines changed

optimum/executorch/attentions/whisper_attention.py
Lines changed: 6 additions & 5 deletions
@@ -77,6 +77,9 @@ def __init__(
         self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
 
+        # Force this boolean to be on CPU
+        self.is_cache_initialized = torch.tensor(False, device="cpu")
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -138,19 +141,17 @@ def recompute_kv(
         cached_keys = past_key_values.layers[self.layer_idx].keys
         cached_values = past_key_values.layers[self.layer_idx].values
 
-        # Tensor predicate: True if any element is non-zero
-        # Result is a 0-dim bool tensor suitable for torch.cond
-        cache_is_initialized = (cached_keys != 0).any()
-
         # Use torch.cond to select branch in a traceable way.
         # All operands must be (nested) tensors or simple Python values.
         key_states, value_states = torch.cond(
-            cache_is_initialized,
+            self.is_cache_initialized,
             use_cached_kv,
             recompute_kv,
             operands=(cached_keys, cached_values, key_value_states),
         )
 
+        self.is_cache_initialized = torch.tensor(True, device="cpu")
+
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
             attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
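
For context: torch.cond expects its predicate to be a Python bool or a one-element boolean tensor, which is what lets this commit trade the data-dependent (cached_keys != 0).any() predicate for a persistent flag tensor pinned to CPU. Below is a minimal, self-contained sketch of that call shape. The branch bodies, tensor shapes, and the final check are illustrative placeholders, not the actual whisper_attention.py code; it assumes a PyTorch build where torch.cond is available.

import torch

def use_cached_kv(cached_keys, cached_values, key_value_states):
    # Cache already populated: reuse it. Clone so the branch does not
    # alias its inputs, which torch.export rejects for cond branches.
    return cached_keys.clone(), cached_values.clone()

def recompute_kv(cached_keys, cached_values, key_value_states):
    # First call: derive fresh K/V from the cross-attention input.
    # (Placeholder; the real branch projects key_value_states to K/V.)
    return key_value_states.clone(), key_value_states.clone()

cached_keys = torch.zeros(2, 4)
cached_values = torch.zeros(2, 4)
key_value_states = torch.randn(2, 4)

# 0-dim bool tensor kept on CPU, mirroring the commit.
is_cache_initialized = torch.tensor(False, device="cpu")

key_states, value_states = torch.cond(
    is_cache_initialized,  # predicate: one-element bool tensor
    use_cached_kv,         # true branch
    recompute_kv,          # false branch
    operands=(cached_keys, cached_values, key_value_states),
)
assert key_states.equal(key_value_states)  # flag is False, so recompute_kv ran

Both branches must return outputs with matching structure and dtype, which is why even the "reuse the cache" branch returns tensors of the same shape as the recompute branch.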
