diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
index 030dd7a56..c76a5b473 100644
--- a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
+++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
@@ -26,7 +26,6 @@
     Qwen2_5_VLTextModel,
     Qwen2_5_VLVisionAttention,
     apply_rotary_pos_emb_vision,
-    repeat_kv,
    rotate_half,
 )
 
@@ -360,6 +359,44 @@ def forward(self, x, seq_len=None):
         )
 
 
+def repeat_kv_text(
+    hidden_states: torch.Tensor, n_rep: int = 2, num_key_value_heads=16, num_attention_heads=40, orig_kv_heads=8
+) -> torch.Tensor:
+    """
+    Variant of repeat_kv that also handles num_attention_heads not being an exact multiple of num_key_value_heads * n_rep.
+    The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim).
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    rows_to_fill = num_attention_heads % (num_key_value_heads * n_rep)  # e.g. 40 % (16 * 2) = 8
+    if rows_to_fill == 0 and n_rep == 1:
+        return hidden_states
+    if rows_to_fill != 0:
+        old_repeats = num_key_value_heads // orig_kv_heads  # e.g. 16 // 8 = 2
+        required_repeats = rows_to_fill // orig_kv_heads  # e.g. 8 // 8 = 1
+        remaining_expansion_data = hidden_states[
+            :, [i for i in range(0, num_key_value_heads, old_repeats)], :, :
+        ]  # one replica per original KV head, e.g. (batch, 8, slen, head_dim)
+        remaining_expansion_data = torch.repeat_interleave(
+            remaining_expansion_data, repeats=required_repeats, dim=1
+        )  # e.g. still (batch, 8, slen, head_dim) for required_repeats == 1
+        chunk_fill_size = n_rep * old_repeats  # e.g. 2 * 2 = 4
+
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    hidden_states = hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+    if rows_to_fill != 0:
+        tensors_to_cat = []
+        for k in range(orig_kv_heads):
+            tensors_to_cat.extend(
+                [
+                    hidden_states[:, k * chunk_fill_size : (k + 1) * chunk_fill_size, :, :],
+                    remaining_expansion_data[:, k * required_repeats : (k + 1) * required_repeats, :, :],
+                ]
+            )
+        hidden_states = torch.cat(tensors_to_cat, dim=1)
+    return hidden_states
+
+
 def eager_attention_forward(
     module: nn.Module,
     query: torch.Tensor,
@@ -368,8 +405,22 @@
     attention_mask: Optional[torch.Tensor],
     **kwargs,
 ):
-    key_states = repeat_kv(key, module.num_key_value_groups)
-    value_states = repeat_kv(value, module.num_key_value_groups)
+    # key_states = repeat_kv(key, module.num_key_value_groups)
+    # value_states = repeat_kv(value, module.num_key_value_groups)
+    key_states = repeat_kv_text(
+        key,
+        module.num_key_value_groups,
+        num_key_value_heads=module.num_key_value_heads,
+        num_attention_heads=module.num_heads,
+        orig_kv_heads=module.orig_kv_heads,
+    )
+    value_states = repeat_kv_text(
+        value,
+        module.num_key_value_groups,
+        num_key_value_heads=module.num_key_value_heads,
+        num_attention_heads=module.num_heads,
+        orig_kv_heads=module.orig_kv_heads,
+    )
 
     attn_weights = torch.matmul(query, key_states.transpose(2, 3)) / math.sqrt(module.head_dim)
     if attention_mask is not None:
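The head-expansion scheme introduced by repeat_kv_text above is easiest to check with concrete shapes. The standalone sketch below mirrors the same pick/expand/concatenate steps on a labelled dummy tensor, using the function's default sizes (40 query heads, 8 original KV heads already replicated to 16 by the script); the sizes and head labels are illustrative assumptions, not values read from any checkpoint. It confirms that every original KV head ends up backing five consecutive query heads: four copies from the even n_rep split plus one extra copy.

import torch

batch, slen, head_dim = 1, 4, 128
orig_kv_heads, num_key_value_heads, num_attention_heads = 8, 16, 40
n_rep = num_attention_heads // num_key_value_heads  # num_key_value_groups = 2

# Label each replicated KV head with its index so the final ordering is visible.
kv = torch.arange(num_key_value_heads).view(1, -1, 1, 1)
kv = kv.expand(batch, num_key_value_heads, slen, head_dim)

rows_to_fill = num_attention_heads % (num_key_value_heads * n_rep)  # 8 query heads still uncovered
old_repeats = num_key_value_heads // orig_kv_heads                  # 2
required_repeats = rows_to_fill // orig_kv_heads                    # 1 extra copy per original head
chunk_fill_size = n_rep * old_repeats                               # 4

# Take one replica of each original KV head and duplicate it required_repeats times.
extra = kv[:, list(range(0, num_key_value_heads, old_repeats)), :, :]
extra = torch.repeat_interleave(extra, repeats=required_repeats, dim=1)

# Standard repeat_kv expansion: 16 heads -> 32 heads.
expanded = kv[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
expanded = expanded.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

# Interleave per-original-head chunks with the extra copies: 8 * (4 + 1) = 40 heads.
chunks = []
for k in range(orig_kv_heads):
    chunks.append(expanded[:, k * chunk_fill_size : (k + 1) * chunk_fill_size])
    chunks.append(extra[:, k * required_repeats : (k + 1) * required_repeats])
out = torch.cat(chunks, dim=1)

assert out.shape == (batch, num_attention_heads, slen, head_dim)
print(out[0, :, 0, 0].tolist())  # [0, 0, 1, 1, 0, 2, 2, 3, 3, 2, ...]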
diff --git a/scripts/replicate_kv_head/replicate_kv_heads.py b/scripts/replicate_kv_head/replicate_kv_heads.py
index e2e78105a..ef767590d 100644
--- a/scripts/replicate_kv_head/replicate_kv_heads.py
+++ b/scripts/replicate_kv_head/replicate_kv_heads.py
@@ -9,9 +9,9 @@
 from typing import Optional
 
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForImageTextToText, AutoTokenizer
 
-from QEfficient import QEFFAutoModelForCausalLM, export
+from QEfficient import QEFFAutoModelForImageTextToText
 from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers, undo_transformers_quantizers
 from QEfficient.transformers.quantizers.awq import WQLinear_GEMM
 from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ
@@ -99,7 +99,7 @@
     """
 
     # Load the model and tokenizer
-    model_base_name = model_name.split("/")[-1]
+    # model_base_name = model_name.split("/")[-1]
     # Replace quantizers for loading Quantized AWQ/GPTQ models on CPU.
     replace_transformers_quantizers()
     # Prepare kwargs for model loading
@@ -107,9 +107,9 @@
 
     if num_hidden_layers:
         model_kwargs["num_hidden_layers"] = num_hidden_layers
 
     pretrained_model_name_or_path = login_and_download_hf_lm(model_name)
-    model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **model_kwargs)
+    model = AutoModelForImageTextToText.from_pretrained(pretrained_model_name_or_path, **model_kwargs)
 
     # Undo the effect of replace_transformers_quantizers
     undo_transformers_quantizers()
@@ -137,36 +137,43 @@
     hidden_size = model.config.hidden_size
 
     # Update the model's attention layers with new key-value heads
-    for block in model.model.layers:
+    for block in model.model.language_model.layers:
         attn = block.self_attn
+        setattr(attn, "orig_kv_heads", orig_kv_heads)
         attn.num_key_value_heads = new_kv_heads
         attn.num_key_value_groups = num_attention_heads // new_kv_heads
         duplicate_weights_for_linear_layer(attn.k_proj, orig_kv_heads, repeat, attn.head_dim, hidden_size)
         duplicate_weights_for_linear_layer(attn.v_proj, orig_kv_heads, repeat, attn.head_dim, hidden_size)
 
-    # Generate modified outputs and tokens
-    with torch.inference_mode():
-        _ = model(**inputs)  # Modified output
-        mod_tokens = model.generate(**inputs, max_new_tokens=10, num_beams=1, do_sample=False)
+    # Skipped: num_attention_heads is not divisible by the replicated KV head count, so the stock HuggingFace attention path cannot run this check.
+    # # Generate modified outputs and tokens
+    # with torch.inference_mode():
+    #     _ = model(**inputs)  # Modified output
+    #     mod_tokens = model.generate(**inputs, max_new_tokens=10, num_beams=1, do_sample=False)
 
-    # Print the original and modified token outputs
+    # # Print the original and modified token outputs
     print("Original:", tokenizer.batch_decode(orig_tokens))
-    print("Modified:", tokenizer.batch_decode(mod_tokens))
+    # print("Modified:", tokenizer.batch_decode(mod_tokens))
 
-    if not torch.all(orig_tokens == mod_tokens):
-        raise RuntimeError(
-            "Something went wrong while duplicating KV heads weights, output token don't match after modification"
-        )
+    # if not torch.all(orig_tokens == mod_tokens):
+    #     raise RuntimeError(
+    #         "Something went wrong while duplicating KV heads weights, output token don't match after modification"
+    #     )
 
     # Export the modified model
-    q_model = QEFFAutoModelForCausalLM(model, continuous_batching=(True if full_batch_size else False))
-    export(
-        model_name,
-        q_model,
-        tokenizer=tokenizer,
-        onnx_dir_path=f"{model_base_name}-{new_kv_heads}kvheads",
-        full_batch_size=(full_batch_size if full_batch_size else None),
+    q_model = QEFFAutoModelForImageTextToText(
+        model,
+        continuous_batching=(True if full_batch_size else False),
     )
+    # export(
+    #     model_name,
+    #     q_model,
+    #     tokenizer=tokenizer,
+    #     onnx_dir_path=f"{model_base_name}-{new_kv_heads}kvheads",
+    #     full_batch_size=(full_batch_size if full_batch_size else None),
+    # )
+    # export() writes the ONNX model to the QEfficient cache directory by default.
+    q_model.export()
 
 
 if __name__ == "__main__":
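For completeness, a minimal sketch of driving the modified script from Python. The argument names follow the parameters referenced in this diff (model_name, prompt, repeat, num_hidden_layers); the exact signature of replicate_kv_heads and the example checkpoint (assumed here to have 40 query heads and 8 KV heads) are assumptions, not verified against the repository.

# Hypothetical driver, run from scripts/replicate_kv_head/; names are assumptions based on
# the parameters visible in this diff, not a verified API.
from replicate_kv_heads import replicate_kv_heads

replicate_kv_heads(
    model_name="Qwen/Qwen2.5-VL-32B-Instruct",  # assumed: 40 attention heads, 8 KV heads
    prompt="Describe the image.",
    repeat=2,  # 8 KV heads -> 16, matching repeat_kv_text's defaults
    num_hidden_layers=2,  # optional: shrink the decoder for a quick smoke test
)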