From cd3b13ab657a9ae272ef7b5acf192ed9438a4221 Mon Sep 17 00:00:00 2001
From: Vinayak Baddi
Date: Sat, 18 Oct 2025 06:42:07 +0000
Subject: [PATCH 1/2] [WIP]: Add early support for KV replication in VLMs

Signed-off-by: vbaddi
---
 .../models/qwen2_5_vl/modeling_qwen2_5_vl.py |   6 +-
 scripts/replicate_kv_head/README.md          |   5 +
 .../replicate_kv_heads_vlm.py                | 197 ++++++++++++++++++
 3 files changed, 205 insertions(+), 3 deletions(-)
 create mode 100644 scripts/replicate_kv_head/replicate_kv_heads_vlm.py

diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
index 030dd7a56..4013d6d0c 100644
--- a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
+++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
@@ -35,7 +35,7 @@
 # from transformers import Qw
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
 from QEfficient.utils import constants
-from QEfficient.utils._utils import IOInfo, get_padding_shape_from_config
+from QEfficient.utils._utils import IOInfo, get_padding_shape_vlm
 from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
 from QEfficient.utils.logging_utils import logger

@@ -746,10 +746,10 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs):
         )
         lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64)
         # Add data for KV
-        kv_cache_shape = get_padding_shape_from_config(
+        kv_cache_shape = get_padding_shape_vlm(
             config=self.model.config,
+            ctx_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
             batch_size=constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
-            seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
         )
         lang_inputs["past_key_values"] = [[] for _ in range(self.model.config.num_hidden_layers)]

diff --git a/scripts/replicate_kv_head/README.md b/scripts/replicate_kv_head/README.md
index 9a1ac9c1e..c9ead989d 100644
--- a/scripts/replicate_kv_head/README.md
+++ b/scripts/replicate_kv_head/README.md
@@ -25,6 +25,11 @@ You can run the script with different parameters using the command line. Below i
    python script.py --model_name "meta-llama/Meta-Llama-3-8B-Instruct" --prompt "Hello, world!" --repeat 3
    ```
+3. **Run the script** for Vision Language Models (VLM) (Still WIP):
+   ```sh
+   python -W ignore replicate_kv_heads_vlm.py --model_name "Qwen/Qwen2.5-VL-32B-Instruct" --prompt "Hello, world" --repeat 5
+   ```
+

 Replace `` with your actual token.

 ### Arguments
diff --git a/scripts/replicate_kv_head/replicate_kv_heads_vlm.py b/scripts/replicate_kv_head/replicate_kv_heads_vlm.py
new file mode 100644
index 000000000..27d92f680
--- /dev/null
+++ b/scripts/replicate_kv_head/replicate_kv_heads_vlm.py
@@ -0,0 +1,197 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+import argparse
+from typing import Optional
+
+import torch
+from transformers import AutoConfig, AutoProcessor, Qwen2_5_VLForConditionalGeneration
+
+from QEfficient.transformers.models.modeling_auto import (
+    _QEffAutoModelForImageTextToTextDualQPC,
+)
+from QEfficient.utils._utils import login_and_download_hf_lm
+
+
+def duplicate_weights_for_linear_layer(
+    layer: torch.nn.Module, orig_kv_heads: int, repeat: int, head_dim: int, hidden_size: int
+):
+    new_kv_heads = repeat * orig_kv_heads
+    layer.weight.data = torch.repeat_interleave(
+        layer.weight.data.view(orig_kv_heads, head_dim, hidden_size), repeat, 0
+    ).view(new_kv_heads * head_dim, hidden_size)
+    if layer.bias is not None:
+        layer.bias.data = torch.repeat_interleave(layer.bias.data.view(orig_kv_heads, head_dim), repeat, 0).view(
+            new_kv_heads * head_dim
+        )
+
+
+def replicate_kv_heads_vlm(
+    model_name: str = "Qwen/Qwen2.5-VL-7B-Instruct",
+    prompt: str = "Describe this image.",
+    repeat: int = 2,
+    num_hidden_layers: Optional[int] = None,
+    num_attention_heads: Optional[int] = None,
+    hidden_size: Optional[int] = None,
+):
+    """
+    Replicate the KV heads for Vision Language Models (language component only).
+
+    The script performs the following steps:
+    1. Loads the VLM model using both transformers and QEFFAutoModelForImageTextToText.
+    2. Extracts the language model component from both.
+    3. Replicates the KV heads in the QEfficient language model.
+    4. Validates the changes by comparing outputs with the original transformers model.
+    5. Exports the modified model to ONNX format.
+
+    ``Mandatory`` Args:
+        :model_name (str): Model card name to use (e.g., "Qwen/Qwen2.5-VL-7B-Instruct").
+        :prompt (str): Prompt to use for validation.
+        :repeat (int): Factor to repeat key-value heads.
+
+    ``Optional`` Args:
+        :num_hidden_layers (int): Number of hidden layers to use; defaults to None (use all layers from the config).
+        :num_attention_heads (int): Number of attention heads; if not passed explicitly, it is picked from the config.
+        :hidden_size (int): Hidden size to use; if not passed explicitly, it is picked from the config.
+ """ + # Load the model configuration + model_base_name = model_name.split("/")[-1] + + # Prepare kwargs for model loading + model_kwargs = {"attn_implementation": "eager"} + + # Load config + config = AutoConfig.from_pretrained(model_name) + if num_hidden_layers: + config.text_config.num_hidden_layers = num_hidden_layers + model_kwargs["config"] = config + + pretrained_model_name_or_path = login_and_download_hf_lm(model_name) + + # Load the original transformers model for validation + print("Loading original transformers model...") + orig_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(pretrained_model_name_or_path, **model_kwargs) + + # Load processor for tokenization + processor = AutoProcessor.from_pretrained(model_name) + orig_lang_model = orig_model.language_model + lang_config = config.text_config + inputs = processor(text=prompt, return_tensors="pt", add_special_tokens=True) + + # Generate original outputs and tokens from transformers model + print("\nGenerating with original transformers model...") + with torch.inference_mode(): + orig_tokens = orig_lang_model(**inputs).last_hidden_state + + # Modify the number of key-value heads in QEfficient model + orig_kv_heads = lang_config.num_key_value_heads + new_kv_heads = repeat * orig_kv_heads + orig_lang_model.config.num_key_value_heads = new_kv_heads + + print(f"\nOriginal KV heads: {orig_kv_heads}") + print(f"Modified KV heads: {new_kv_heads}") + + # Check if hidden size and number of attention heads are explicitly passed + if num_attention_heads is None: + num_attention_heads = lang_config.num_attention_heads + if hidden_size is None: + hidden_size = lang_config.hidden_size + + # Update the model's attention layers with new key-value heads + print(f"\nReplicating KV heads in {len(orig_lang_model.layers)} layers...") + for block in orig_lang_model.layers: + attn = block.self_attn + attn.num_key_value_heads = new_kv_heads + attn.num_key_value_groups = num_attention_heads // new_kv_heads + + duplicate_weights_for_linear_layer(attn.k_proj, orig_kv_heads, repeat, attn.head_dim, hidden_size) + duplicate_weights_for_linear_layer(attn.v_proj, orig_kv_heads, repeat, attn.head_dim, hidden_size) + + # Generate modified outputs and tokens from QEfficient model + print("\nGenerating with modified QEfficient model...") + with torch.inference_mode(): + mod_tokens = orig_lang_model(**inputs).last_hidden_state + + # Print the original and modified token outputs + print("\n" + "=" * 80) + print("VALIDATION RESULTS:") + print("=" * 80) + # print(f"Original (transformers): {processor.batch_decode(orig_tokens, skip_special_tokens=True)}") + # print(f"Modified (QEfficient): {processor.batch_decode(mod_tokens, skip_special_tokens=True)}") + print("=" * 80) + + if not torch.all(orig_tokens == mod_tokens): + raise RuntimeError( + "Something went wrong while duplicating KV heads weights, output tokens don't match after modification" + ) + + print("\n✓ Validation successful! 
+
+    # Export the modified model
+    export_dir = f"{model_base_name}-{new_kv_heads}kvheads"
+    print(f"\nExporting modified model to {export_dir}...")
+
+    # Export using the qeff_model's export method
+    # qeff_model.export(export_dir=export_dir)
+    # qeff_model = QEffCausalLMForTextImageToTextModel(orig_model)
+    qeff_model = _QEffAutoModelForImageTextToTextDualQPC(orig_model)
+    inputs = qeff_model.model.get_dummy_inputs(kv_offload=True)
+    dynamic_axes = qeff_model.model.get_onnx_dynamic_axes(kv_offload=True)
+    output_names = qeff_model.model.get_output_names(kv_offload=True)
+    qeff_model.lang_model.export(
+        inputs["lang"], output_names["lang"], dynamic_axes["lang"], export_dir=export_dir, offload_pt_weights=True
+    )
+
+    # print(f"\n✓ Export completed successfully!")
+    # print(f"ONNX paths: {qeff_model.onnx_path}")
+
+
+if __name__ == "__main__":
+    # Set up argument parser
+    parser = argparse.ArgumentParser(
+        description="Modify and export KV heads for Vision Language Models (language component)."
+    )
+    parser.add_argument(
+        "--model_name",
+        "--model-name",
+        type=str,
+        default="Qwen/Qwen2.5-VL-3B-Instruct",
+        help="Name of the VLM model to use.",
+    )
+    parser.add_argument("--prompt", type=str, default="Describe this image.", help="Prompt to use for validation.")
+    parser.add_argument("--repeat", type=int, default=2, help="Factor to repeat key-value heads.")
+    parser.add_argument(
+        "--num_hidden_layers",
+        "--num-hidden-layers",
+        type=int,
+        default=None,
+        help="Number of hidden layers to use; defaults to None (use all layers from the config)",
+    )
+    parser.add_argument(
+        "--num_attention_heads",
+        "--num-attention-heads",
+        type=int,
+        default=None,
+        help="Number of attention heads; if not passed explicitly, it is picked from the config",
+    )
+    parser.add_argument(
+        "--hidden_size",
+        "--hidden-size",
+        type=int,
+        default=None,
+        help="Hidden size to use; if not passed explicitly, it is picked from the config",
+    )
+
+    args = parser.parse_args()
+
+    replicate_kv_heads_vlm(
+        model_name=args.model_name,
+        prompt=args.prompt,
+        repeat=args.repeat,
+        num_hidden_layers=args.num_hidden_layers,
+        num_attention_heads=args.num_attention_heads,
+        hidden_size=args.hidden_size,
+    )

From de72e1777f3e11ddb1a8d0803dc9552bdb58cf94 Mon Sep 17 00:00:00 2001
From: Vinayak Baddi
Date: Sat, 18 Oct 2025 06:44:42 +0000
Subject: [PATCH 2/2] nit: update replicate kv vlm

Signed-off-by: vbaddi
---
 .../replicate_kv_heads_vlm.py | 22 +---------------------
 1 file changed, 1 insertion(+), 21 deletions(-)

diff --git a/scripts/replicate_kv_head/replicate_kv_heads_vlm.py b/scripts/replicate_kv_head/replicate_kv_heads_vlm.py
index 27d92f680..bd14d96da 100644
--- a/scripts/replicate_kv_head/replicate_kv_heads_vlm.py
+++ b/scripts/replicate_kv_head/replicate_kv_heads_vlm.py
@@ -40,13 +40,6 @@ def replicate_kv_heads_vlm(
     """
     Replicate the KV heads for Vision Language Models (language component only).

-    The script performs the following steps:
-    1. Loads the VLM model using both transformers and QEFFAutoModelForImageTextToText.
-    2. Extracts the language model component from both.
-    3. Replicates the KV heads in the QEfficient language model.
-    4. Validates the changes by comparing outputs with the original transformers model.
-    5. Exports the modified model to ONNX format.
-
     ``Mandatory`` Args:
         :model_name (str): Model card name to use (e.g., "Qwen/Qwen2.5-VL-7B-Instruct").
         :prompt (str): Prompt to use for validation.
@@ -115,28 +108,18 @@ def replicate_kv_heads_vlm(
     with torch.inference_mode():
         mod_tokens = orig_lang_model(**inputs).last_hidden_state

-    # Print the original and modified token outputs
-    print("\n" + "=" * 80)
-    print("VALIDATION RESULTS:")
-    print("=" * 80)
-    # print(f"Original (transformers): {processor.batch_decode(orig_tokens, skip_special_tokens=True)}")
-    # print(f"Modified (QEfficient): {processor.batch_decode(mod_tokens, skip_special_tokens=True)}")
-    print("=" * 80)
-
     if not torch.all(orig_tokens == mod_tokens):
         raise RuntimeError(
             "Something went wrong while duplicating KV head weights; output tokens don't match after modification"
         )

-    print("\n✓ Validation successful! Output tokens match.")
+    print("\nValidation successful! Output tokens match.")

     # Export the modified model
     export_dir = f"{model_base_name}-{new_kv_heads}kvheads"
     print(f"\nExporting modified model to {export_dir}...")

     # Export using the qeff_model's export method
-    # qeff_model.export(export_dir=export_dir)
-    # qeff_model = QEffCausalLMForTextImageToTextModel(orig_model)
     qeff_model = _QEffAutoModelForImageTextToTextDualQPC(orig_model)
     inputs = qeff_model.model.get_dummy_inputs(kv_offload=True)
     dynamic_axes = qeff_model.model.get_onnx_dynamic_axes(kv_offload=True)
     output_names = qeff_model.model.get_output_names(kv_offload=True)
     qeff_model.lang_model.export(
         inputs["lang"], output_names["lang"], dynamic_axes["lang"], export_dir=export_dir, offload_pt_weights=True
     )

-    # print(f"\n✓ Export completed successfully!")
-    # print(f"ONNX paths: {qeff_model.onnx_path}")
-

 if __name__ == "__main__":
     # Set up argument parser