@@ -26,7 +26,6 @@
    Qwen2_5_VLTextModel,
    Qwen2_5_VLVisionAttention,
    apply_rotary_pos_emb_vision,
    repeat_kv,
    rotate_half,
)

@@ -360,6 +359,44 @@ def forward(self, x, seq_len=None):
)


def repeat_kv_text(
    hidden_states: torch.Tensor, n_rep: int = 2, num_key_value_heads=16, num_attention_heads=40, orig_kv_heads=8
) -> torch.Tensor:
    """
    Expand the key/value hidden states from (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_attention_heads, seqlen, head_dim). Each head is repeated n_rep times, as in
    torch.repeat_interleave(x, dim=1, repeats=n_rep); when num_attention_heads is not a multiple of
    num_key_value_heads * n_rep, the remaining heads are filled by reusing one replica per original KV head.
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    rows_to_fill = num_attention_heads % (num_key_value_heads * n_rep)  # -> 8
    if rows_to_fill == 0:
        return hidden_states
    if rows_to_fill != 0:
        old_repeats = num_key_value_heads // orig_kv_heads  # -> 2
        required_repeats = rows_to_fill // orig_kv_heads  # -> 1
        remaining_expansion_data = hidden_states[
            :, [i for i in range(0, num_key_value_heads, old_repeats)], :, :
        ]  # 1, 8
        remaining_expansion_data = torch.repeat_interleave(
            remaining_expansion_data, repeats=required_repeats, dim=1
        )  # 1x8
        chunk_fill_size = n_rep * old_repeats  # -> 4

    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    hidden_states = hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

    if rows_to_fill != 0:
        tensors_to_cat = []
        for k in range(orig_kv_heads):
            tensors_to_cat.extend(
                [
                    hidden_states[:, k * chunk_fill_size : (k + 1) * chunk_fill_size, :, :],
                    remaining_expansion_data[:, k * required_repeats : (k + 1) * required_repeats, :, :],
                ]
            )
        hidden_states = torch.cat(tensors_to_cat, dim=1)
    return hidden_states
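
A minimal usage sketch, assuming the default values annotated above; the batch size, sequence length, and head_dim are arbitrary dummy values for illustration:

import torch

# Dummy shapes: (batch=1, num_key_value_heads=16, seqlen=7, head_dim=128)
key = torch.randn(1, 16, 7, 128)
expanded = repeat_kv_text(key, n_rep=2, num_key_value_heads=16, num_attention_heads=40, orig_kv_heads=8)
# 16 heads * 2 repeats = 32, plus 8 filled rows (one per original KV head) -> 40 heads
print(expanded.shape)  # torch.Size([1, 40, 7, 128])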


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
@@ -368,8 +405,22 @@ def eager_attention_forward(
    attention_mask: Optional[torch.Tensor],
    **kwargs,
):
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)
    # key_states = repeat_kv(key, module.num_key_value_groups)
    # value_states = repeat_kv(value, module.num_key_value_groups)
    key_states = repeat_kv_text(
        key,
        module.num_key_value_groups,
        num_key_value_heads=module.num_key_value_heads,
        num_attention_heads=module.num_heads,
        orig_kv_heads=module.orig_kv_heads,
    )
    value_states = repeat_kv_text(
        value,
        module.num_key_value_groups,
        num_key_value_heads=module.num_key_value_heads,
        num_attention_heads=module.num_heads,
        orig_kv_heads=module.orig_kv_heads,
    )

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) / math.sqrt(module.head_dim)
    if attention_mask is not None:
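
For context, the patched eager_attention_forward reads num_heads, num_key_value_heads, num_key_value_groups, and orig_kv_heads from the attention module; the replication script below is what sets them. A rough shape sketch under that assumption, using a hypothetical stand-in for the module and dummy tensors:

import math
from types import SimpleNamespace

import torch

# Hypothetical stand-in for the attention module attributes read above;
# the real attributes are set per layer in replicate_kv_heads.py below.
module = SimpleNamespace(num_heads=40, num_key_value_heads=16, num_key_value_groups=2, orig_kv_heads=8, head_dim=128)

query = torch.randn(1, module.num_heads, 7, module.head_dim)
key = torch.randn(1, module.num_key_value_heads, 7, module.head_dim)
key_states = repeat_kv_text(
    key,
    module.num_key_value_groups,
    num_key_value_heads=module.num_key_value_heads,
    num_attention_heads=module.num_heads,
    orig_kv_heads=module.orig_kv_heads,
)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) / math.sqrt(module.head_dim)
print(attn_weights.shape)  # torch.Size([1, 40, 7, 7]): query heads now line up with the repeated KV heads
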
53 changes: 30 additions & 23 deletions scripts/replicate_kv_head/replicate_kv_heads.py
@@ -9,9 +9,9 @@
from typing import Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoModelForImageTextToText, AutoTokenizer

from QEfficient import QEFFAutoModelForCausalLM, export
from QEfficient import QEFFAutoModelForImageTextToText
from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers, undo_transformers_quantizers
from QEfficient.transformers.quantizers.awq import WQLinear_GEMM
from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ
@@ -99,17 +99,17 @@ def replicate_kv_heads(

"""
# Load the model and tokenizer
model_base_name = model_name.split("/")[-1]
# model_base_name = model_name.split("/")[-1]
# Replace quantizers for loading Quantized AWQ/GPTQ models on CPU.
replace_transformers_quantizers()
# Prepare kwargs for model loading
model_kwargs = {"attn_implementation": "eager"}

if num_hidden_layers:
model_kwargs["num_hidden_layers"] = num_hidden_layers

# breakpoint()
pretrained_model_name_or_path = login_and_download_hf_lm(model_name)
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **model_kwargs)
model = AutoModelForImageTextToText.from_pretrained(pretrained_model_name_or_path, **model_kwargs)

# Undo the effect of replace_transformers_quantizers
undo_transformers_quantizers()
@@ -137,36 +137,43 @@ def replicate_kv_heads(
    hidden_size = model.config.hidden_size

    # Update the model's attention layers with new key-value heads
    for block in model.model.layers:
    for block in model.model.language_model.layers:
        attn = block.self_attn
        setattr(attn, "orig_kv_heads", orig_kv_heads)
        attn.num_key_value_heads = new_kv_heads
        attn.num_key_value_groups = num_attention_heads // new_kv_heads
        duplicate_weights_for_linear_layer(attn.k_proj, orig_kv_heads, repeat, attn.head_dim, hidden_size)
        duplicate_weights_for_linear_layer(attn.v_proj, orig_kv_heads, repeat, attn.head_dim, hidden_size)
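
The head-count arithmetic set up here is why the stock repeat_kv path no longer applies and why the sanity check further down is skipped: a small sketch, assuming the 40-head / 8-KV-head text configuration this script targets.

# Assumed example values; adjust for other configs.
num_attention_heads, orig_kv_heads, repeat = 40, 8, 2
new_kv_heads = orig_kv_heads * repeat  # 16
num_key_value_groups = num_attention_heads // new_kv_heads  # 2 (floor division)
rows_to_fill = num_attention_heads % (new_kv_heads * num_key_value_groups)  # 40 - 32 = 8 heads left to fill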

    # Generate modified outputs and tokens
    with torch.inference_mode():
        _ = model(**inputs) # Modified output
        mod_tokens = model.generate(**inputs, max_new_tokens=10, num_beams=1, do_sample=False)
    ## This won't work as attention_heads isn't divisible by num_kv_heads for repeat > 1, so we skip this inference run.
    # # Generate modified outputs and tokens
    # with torch.inference_mode():
    # _ = model(**inputs) # Modified output
    # mod_tokens = model.generate(**inputs, max_new_tokens=10, num_beams=1, do_sample=False)

    # Print the original and modified token outputs
    # # Print the original and modified token outputs
    print("Original:", tokenizer.batch_decode(orig_tokens))
    print("Modified:", tokenizer.batch_decode(mod_tokens))
    # print("Modified:", tokenizer.batch_decode(mod_tokens))

    if not torch.all(orig_tokens == mod_tokens):
        raise RuntimeError(
            "Something went wrong while duplicating KV heads weights, output token don't match after modification"
        )
    # if not torch.all(orig_tokens == mod_tokens):
    # raise RuntimeError(
    # "Something went wrong while duplicating KV heads weights, output token don't match after modification"
    # )

    # Export the modified model
    q_model = QEFFAutoModelForCausalLM(model, continuous_batching=(True if full_batch_size else False))
    export(
        model_name,
        q_model,
        tokenizer=tokenizer,
        onnx_dir_path=f"{model_base_name}-{new_kv_heads}kvheads",
        full_batch_size=(full_batch_size if full_batch_size else None),
    q_model = QEFFAutoModelForImageTextToText(
        model,
        continuous_batching=(True if full_batch_size else False),
    )
    # export(
    # model_name,
    # q_model,
    # tokenizer=tokenizer,
    # onnx_dir_path=f"{model_base_name}-{new_kv_heads}kvheads",
    # full_batch_size=(full_batch_size if full_batch_size else None),
    # )
    # Exported in cache by default.
    q_model.export()


if __name__ == "__main__":