From b3252654f9650ab8b7260ec209329fdc42b61021 Mon Sep 17 00:00:00 2001
From: Clement25 <935974082@qq.com>
Date: Sat, 9 Aug 2025 09:41:47 +0000
Subject: [PATCH] add SnapKV SDPA implementation with flash_attn availability
 check

---
 snapkv/monkeypatch/llama_hijack_4_37.py | 83 +++++++++++++++++++++++++
 snapkv/monkeypatch/monkeypatch.py       |  8 ++-
 2 files changed, 90 insertions(+), 1 deletion(-)

diff --git a/snapkv/monkeypatch/llama_hijack_4_37.py b/snapkv/monkeypatch/llama_hijack_4_37.py
index 586a21d..56e226f 100644
--- a/snapkv/monkeypatch/llama_hijack_4_37.py
+++ b/snapkv/monkeypatch/llama_hijack_4_37.py
@@ -135,6 +135,89 @@ def llama_flash_attn2_forward(
 
     return attn_output, attn_weights, past_key_value
 
+def llama_sdpa_attn_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Cache] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    # [SnapKV] register kv_cluster
+    init_snapkv(self)
+    bsz, q_len, _ = hidden_states.size()
+
+    query_states = self.q_proj(hidden_states)
+    key_states = self.k_proj(hidden_states)
+    value_states = self.v_proj(hidden_states)
+
+    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+    kv_seq_len = key_states.shape[-2]
+    # [SnapKV] the cache length is added in the block below; adding it here as well would double-count it
+    # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
+    if past_key_value is not None:
+        if self.layer_idx is None:
+            raise ValueError(
+                f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                "with a layer index."
+            )
+        if hasattr(self, "kv_seq_len"): #[SnapKV] add kv_seq_len
+            if self.kv_seq_len != 0:
+                kv_seq_len += self.kv_seq_len
+            else:
+                kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+        else:
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+    # [SnapKV] repeat_kv is moved ahead of the cache update
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+    if past_key_value is not None:
+        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+        # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+        # [SnapKV] cache the compressed KV at prefill, append the raw KV during decoding
+        if key_states.shape[-2] == kv_seq_len: # [SnapKV] add kv_cluster
+            self.kv_seq_len = kv_seq_len # [SnapKV] register kv_seq_len
+            key_states_compress, value_states_compress = self.kv_cluster.update_kv(key_states, query_states, value_states, attention_mask, self.num_key_value_groups)
+            past_key_value.update(key_states_compress, value_states_compress, self.layer_idx, cache_kwargs)
+        else:
+            self.kv_seq_len += q_len
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+
+    # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+    # Reference: https://github.com/pytorch/pytorch/issues/112577.
+    if query_states.device.type == "cuda" and attention_mask is not None:
+        query_states = query_states.contiguous()
+        key_states = key_states.contiguous()
+        value_states = value_states.contiguous()
+
+    attn_output = torch.nn.functional.scaled_dot_product_attention(
+        query_states,
+        key_states,
+        value_states,
+        attn_mask=attention_mask,
+        dropout_p=self.attention_dropout if self.training else 0.0,
+        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+        is_causal=self.is_causal and attention_mask is None and q_len > 1,
+    )
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+    attn_output = self.o_proj(attn_output)
+
+    return attn_output, None, past_key_value
+
 def prepare_inputs_for_generation_llama(
     self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
 ):
diff --git a/snapkv/monkeypatch/monkeypatch.py b/snapkv/monkeypatch/monkeypatch.py
index 0dfcfe4..426fed2 100644
--- a/snapkv/monkeypatch/monkeypatch.py
+++ b/snapkv/monkeypatch/monkeypatch.py
@@ -1,9 +1,11 @@
 from importlib.metadata import version
+from transformers.utils import is_flash_attn_2_available
 import warnings
 import transformers
 from snapkv.monkeypatch.llama_hijack_4_37 import llama_flash_attn2_forward as llama_flash_attn2_forward_4_37, prepare_inputs_for_generation_llama as prepare_inputs_for_generation_llama_4_37
 from snapkv.monkeypatch.mistral_hijack_4_37 import mistral_flash_attn2_forward as mistral_flash_attn2_forward_4_37, prepare_inputs_for_generation_mistral as prepare_inputs_for_generation_mistral_4_37
 from snapkv.monkeypatch.mixtral_hijack_4_37 import mixtral_flash_attn2_forward as mixtral_flash_attn2_forward_4_37, prepare_inputs_for_generation_mixtral as prepare_inputs_for_generation_mixtral_4_37
+from snapkv.monkeypatch.llama_hijack_4_37 import llama_sdpa_attn_forward as llama_sdpa_attn_forward_4_37
 
 def check_version():
     try:
@@ -23,7 +25,11 @@ def replace_llama():
     if warning_flag:
         warnings.warn(f"Transformers version {transformers_version} might not be compatible with SnapKV. SnapKV is tested with Transformers version {version_list}.")
     transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation = prepare_inputs_for_generation_llama_4_37
-    transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward = llama_flash_attn2_forward_4_37
+
+    if is_flash_attn_2_available():
+        transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward = llama_flash_attn2_forward_4_37
+    else:  # flash-attn not installed: patch the PyTorch SDPA implementation instead
+        transformers.models.llama.modeling_llama.LlamaSdpaAttention.forward = llama_sdpa_attn_forward_4_37
 
 def replace_mistral():
     transformers_version = check_version()
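
A minimal usage sketch of the new SDPA path, assuming flash-attn is not installed; the checkpoint name, dtype, and generation arguments below are illustrative and not taken from this patch. replace_llama() must run before the model is instantiated so the hijacked LlamaSdpaAttention.forward is the one bound when transformers selects the SDPA attention class.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from snapkv.monkeypatch.monkeypatch import replace_llama

    # Apply the SnapKV hijack first; with flash-attn unavailable,
    # replace_llama() now patches LlamaSdpaAttention.forward.
    replace_llama()

    model_name = "meta-llama/Llama-2-7b-hf"  # illustrative checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        attn_implementation="sdpa",  # request the PyTorch SDPA attention classes
    ).cuda()

    inputs = tokenizer("A very long prompt ...", return_tensors="pt").to(model.device)
    output = model.generate(**inputs, max_new_tokens=64)
    print(tokenizer.decode(output[0], skip_special_tokens=True))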