@@ -77,6 +77,9 @@ def __init__(
         self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
 
+        # Force this boolean to be on CPU
+        self.is_cache_initialized = torch.tensor(False, device="cpu")
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -138,19 +141,17 @@ def recompute_kv(
         cached_keys = past_key_values.layers[self.layer_idx].keys
         cached_values = past_key_values.layers[self.layer_idx].values
 
-        # Tensor predicate: True if any element is non-zero
-        # Result is a 0-dim bool tensor suitable for torch.cond
-        cache_is_initialized = (cached_keys != 0).any()
-
         # Use torch.cond to select branch in a traceable way.
         # All operands must be (nested) tensors or simple Python values.
         key_states, value_states = torch.cond(
-            cache_is_initialized,
+            self.is_cache_initialized,
             use_cached_kv,
             recompute_kv,
             operands=(cached_keys, cached_values, key_value_states),
         )
 
+        self.is_cache_initialized = torch.tensor(True, device="cpu")
+
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
             attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
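For readers unfamiliar with `torch.cond`, here is a minimal standalone sketch of the pattern this patch adopts: a 0-dim CPU boolean tensor selects the branch, and both branch functions take identical operands and return tensors of matching shape and dtype, as `torch.cond` requires. The names `use_cached_kv`, `recompute_kv`, and `is_cache_initialized` mirror the diff, but the bodies and shapes below are illustrative stand-ins, not the PR's actual implementation:

```python
import torch

def use_cached_kv(cached_keys, cached_values, fresh_kv):
    # True branch: reuse the cached tensors. clone() avoids the
    # input/output aliasing that torch.cond branches disallow.
    return cached_keys.clone(), cached_values.clone()

def recompute_kv(cached_keys, cached_values, fresh_kv):
    # False branch: fall back to the freshly projected tensor
    # (a stand-in for the real key/value projections).
    return fresh_kv.clone(), fresh_kv.clone()

# A 0-dim boolean tensor kept on CPU serves as the predicate, so the
# branch selection remains a tensor operation that torch.export can trace.
is_cache_initialized = torch.tensor(False, device="cpu")

cached_k = torch.zeros(1, 4, 8)
cached_v = torch.zeros(1, 4, 8)
fresh = torch.randn(1, 4, 8)

keys, values = torch.cond(
    is_cache_initialized,
    use_cached_kv,
    recompute_kv,
    operands=(cached_k, cached_v, fresh),
)
# The first call takes the recompute branch; flipping the flag afterwards
# mirrors what the patch does at the end of forward().
is_cache_initialized = torch.tensor(True, device="cpu")
```

Keeping the flag on CPU rather than on the model's device likely avoids a device-to-host synchronization each time the predicate is read in eager mode; that rationale is an inference from the diff's "Force this boolean to be on CPU" comment, not something the patch states explicitly.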