From 1716efb39866e9680d80184ca8d78d1e089dd3eb Mon Sep 17 00:00:00 2001 From: Yury Tokpanov Date: Mon, 3 Mar 2025 22:51:18 +0000 Subject: [PATCH 01/10] Zamba2 initial commit Signed-off-by: Yury Tokpanov --- vllm/config.py | 14 + vllm/model_executor/models/registry.py | 1 + vllm/model_executor/models/zamba2.py | 740 +++++++++++++++++++++++++ 3 files changed, 755 insertions(+) create mode 100644 vllm/model_executor/models/zamba2.py diff --git a/vllm/config.py b/vllm/config.py index 70cc0affe998..2d64ec4b9108 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -821,6 +821,11 @@ def get_head_size(self) -> int: if qk_rope_head_dim and qk_nope_head_dim: return qk_rope_head_dim + qk_nope_head_dim + if hasattr(self.hf_text_config, + "model_type") and (self.hf_text_config.model_type + == "zamba2"): + return self.hf_text_config.attention_head_dim + if self.is_attention_free: return 0 @@ -942,6 +947,15 @@ def get_num_layers_by_block_type( "cannot determine the num of " f"{block_type.value} layers") + if hasattr(self.hf_text_config, + "model_type") and (self.hf_text_config.model_type + == "zamba2"): + if attn_block_type: + return sum(t == "hybrid" + for t in layers_block_type_value[start:end]) + else: + return self.get_num_layers(parallel_config) + return sum(t == block_type.value for t in layers_block_type_value[start:end]) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 5dd3aa2973cd..8b469132da6d 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -105,6 +105,7 @@ "SolarForCausalLM": ("solar", "SolarForCausalLM"), "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"), "XverseForCausalLM": ("llama", "LlamaForCausalLM"), + "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"), # [Encoder-decoder] "BartModel": ("bart", "BartForConditionalGeneration"), "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"), diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py new file mode 100644 index 000000000000..3926187ebbb7 --- /dev/null +++ b/vllm/model_executor/models/zamba2.py @@ -0,0 +1,740 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Inference-only Zamba2 model.""" +# Added by the Zyphra Technologies, 2025 +from itertools import cycle +from typing import Dict, Iterable, List, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn +from transformers import Zamba2Config + +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.attention.layer import Attention +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.distributed.parallel_state import get_pp_group +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.mamba_mixer2 import ( + MambaMixer2, extra_groups_for_head_shards) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader 
+from vllm.model_executor.models.mamba_cache import (MambaCacheManager, + MambaCacheParams) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import HasInnerState, IsHybrid +from .utils import make_empty_intermediate_tensors_factory, maybe_prefix + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class Zamba2Attention(nn.Module): + + def __init__( + self, + config: Zamba2Config, + bare_block_idx: int, + layer2block_map: Dict[int, int], + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.layer2block_map = layer2block_map + self.num_fwd_mem_blocks = len(layer2block_map) + self.rope_theta = config.rope_theta + + self.attention_hidden_size = config.attention_hidden_size + self.num_attention_heads = config.num_attention_heads + self.attention_head_dim = config.attention_head_dim + self.scale = (self.attention_head_dim / 2)**-0.5 + + if (self.attention_head_dim * + self.num_attention_heads) != self.attention_hidden_size: + raise ValueError( + f"attention_hidden_size must be divisible by" + f" num_attention_heads" + f" (got `attention_hidden_size`: {self.attention_hidden_size}" + f" and `num_heads`: {self.num_attention_heads}).") + + self.qkv_proj = QKVParallelLinear( + self.attention_hidden_size, + self.attention_head_dim, + self.num_attention_heads, + bias=False, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear(self.attention_hidden_size, + config.hidden_size, + bias=False, + quant_config=quant_config) + + # Need to define separate Attention objects, because in recent vLLM + # KV cache tensors are tied to specific Attention objects. 
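+        # A rough sketch of the resulting layout (assuming num_mem_blocks=2
+        # and six hybrid layers): the shared block with bare_block_idx=0
+        # owns real Attention modules at forward positions 0, 2 and 4 and
+        # nn.Identity placeholders elsewhere, while bare_block_idx=1 owns
+        # positions 1, 3 and 5. The running index `j` numbers the Attention
+        # prefixes (attn.0 ... attn.5 in this example) so that each module
+        # is tied to its own KV cache tensor.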
+ self.dpa_list = nn.ModuleList([]) + j = bare_block_idx * (self.num_fwd_mem_blocks + config.num_mem_blocks - + 1) // config.num_mem_blocks + for block_idx in range(self.num_fwd_mem_blocks): + if block_idx % config.num_mem_blocks == bare_block_idx: + dpa = Attention( + self.num_attention_heads, + self.attention_head_dim, + self.scale, + cache_config=cache_config, + prefix=f"{prefix}.attn.{j}", + ) + j += 1 + else: + dpa = nn.Identity() + self.dpa_list.append(dpa) + + if config.use_shared_attention_adapter: + self.linear_q_adapter_list = nn.ModuleList([]) + self.linear_k_adapter_list = nn.ModuleList([]) + self.linear_v_adapter_list = nn.ModuleList([]) + + for block_idx in range(self.num_fwd_mem_blocks): + if block_idx % config.num_mem_blocks == bare_block_idx: + linear_q_adapter = nn.ModuleList([ + ColumnParallelLinear(self.attention_hidden_size, + config.adapter_rank, + bias=False, + quant_config=quant_config), + ColumnParallelLinear(config.adapter_rank, + self.attention_hidden_size, + bias=False, + quant_config=quant_config), + ]) + linear_k_adapter = nn.ModuleList([ + ColumnParallelLinear(self.attention_hidden_size, + config.adapter_rank, + bias=False, + quant_config=quant_config), + ColumnParallelLinear(config.adapter_rank, + self.attention_hidden_size, + bias=False, + quant_config=quant_config), + ]) + linear_v_adapter = nn.ModuleList([ + ColumnParallelLinear(self.attention_hidden_size, + config.adapter_rank, + bias=False, + quant_config=quant_config), + ColumnParallelLinear(config.adapter_rank, + self.attention_hidden_size, + bias=False, + quant_config=quant_config), + ]) + else: + linear_q_adapter = nn.Identity() + linear_k_adapter = nn.Identity() + linear_v_adapter = nn.Identity() + self.linear_q_adapter_list.append(linear_q_adapter) + self.linear_k_adapter_list.append(linear_k_adapter) + self.linear_v_adapter_list.append(linear_v_adapter) + + if config.use_mem_rope: + self.rotary_emb = get_rope( + head_size=self.attention_head_dim, + rotary_dim=self.attention_head_dim, + max_position=config.max_position_embeddings, + base=self.rope_theta, + rope_scaling=None, + is_neox_style=True, + ) + + def forward( + self, + hidden_states, + layer_idx: int, + position_ids: torch.Tensor, + kv_caches: List[KVCache], + attn_metadata: AttentionMetadata, + ): + qkv, _ = self.qkv_proj(hidden_states) + query_states, key_states, value_states = qkv.split([ + self.attention_hidden_size, self.attention_hidden_size, + self.attention_hidden_size + ], + dim=-1) + + block_idx = self.layer2block_map[layer_idx] + if self.config.use_shared_attention_adapter: + q_lora_output = self.linear_q_adapter_list[block_idx][0]( + hidden_states)[0] + q_lora_output = self.linear_q_adapter_list[block_idx][1]( + q_lora_output)[0] + query_states = query_states + q_lora_output + + k_lora_output = self.linear_k_adapter_list[block_idx][0]( + hidden_states)[0] + k_lora_output = self.linear_k_adapter_list[block_idx][1]( + k_lora_output)[0] + key_states = key_states + k_lora_output + + v_lora_output = self.linear_v_adapter_list[block_idx][0]( + hidden_states)[0] + v_lora_output = self.linear_v_adapter_list[block_idx][1]( + v_lora_output)[0] + value_states = value_states + v_lora_output + + if self.config.use_mem_rope: + query_states, key_states = self.rotary_emb(position_ids, + query_states, + key_states) + + # NOTE: No need anymore to pass specific kv_cache tensor, + # but keeping it for API compatibility + y = self.dpa_list[block_idx](query_states, key_states, value_states, + kv_caches[block_idx], attn_metadata) + y, _ = 
self.o_proj(y) + return y + + +class Zamba2MLP(nn.Module): + + def __init__( + self, + config: Zamba2Config, + bare_block_idx: int, + layer2block_map: Dict[int, int], + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.layer2block_map = layer2block_map + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.num_fwd_mem_blocks = len(layer2block_map) + + self.gate_up_proj = ColumnParallelLinear( + self.hidden_size, + 2 * self.intermediate_size, + bias=self.config.add_bias_linear, + quant_config=quant_config) + self.down_proj = RowParallelLinear(self.intermediate_size, + self.hidden_size, + bias=self.config.add_bias_linear, + quant_config=quant_config) + if config.hidden_act != "gelu": + raise ValueError(f"Only gelu activation is supported" + f" (got `hidden_act`: {config.hidden_act})") + self.act_fn = F.gelu + + self.gate_up_proj_adapter_list = nn.ModuleList([]) + for block_idx in range(self.num_fwd_mem_blocks): + if block_idx % config.num_mem_blocks == bare_block_idx: + gate_up_proj_adapter = nn.ModuleList([ + ColumnParallelLinear(config.hidden_size, + self.config.adapter_rank, + bias=False, + quant_config=quant_config), + ColumnParallelLinear(config.adapter_rank, + 2 * self.intermediate_size, + bias=False, + quant_config=quant_config), + ]) + else: + gate_up_proj_adapter = nn.Identity() + self.gate_up_proj_adapter_list.append(gate_up_proj_adapter) + + def forward(self, hidden_states, layer_idx): + gate_up_state, _ = self.gate_up_proj(hidden_states) + block_idx = self.layer2block_map[layer_idx] + lora_output = self.gate_up_proj_adapter_list[block_idx][0]( + hidden_states)[0] + lora_output = self.gate_up_proj_adapter_list[block_idx][1]( + lora_output)[0] + gate_up_state = gate_up_state + lora_output + + gate_up_state = torch.chunk(gate_up_state, 2, dim=-1) + hidden_state = self.act_fn(gate_up_state[0]) * gate_up_state[1] + output, _ = self.down_proj(hidden_state) + return output + + +class Zamba2AttentionDecoderLayer(nn.Module): + + def __init__( + self, + config: Zamba2Config, + bare_block_idx: int, + layer2block_map: Dict[int, int], + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.self_attn = Zamba2Attention( + config, + bare_block_idx=bare_block_idx, + layer2block_map=layer2block_map, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ) + self.feed_forward = Zamba2MLP( + config, + bare_block_idx=bare_block_idx, + layer2block_map=layer2block_map, + quant_config=quant_config, + ) + self.input_layernorm = RMSNorm(2 * config.hidden_size, + eps=config.rms_norm_eps) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + original_hidden_states: torch.Tensor, + layer_idx: int, + positions: torch.Tensor, + kv_caches: List[KVCache], + attn_metadata: AttentionMetadata, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, + torch.FloatTensor]]]: + + # The argument original_hidden_states is concatenated with hidden_states + # (which is the output of the previous (mamba) layer). + # The concatenated tensor is then used as input of the pre-attention + # RMSNorm (see fig. 2 in https://arxiv.org/pdf/2405.16712). 
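+        # Shape note: the concatenation below doubles the last dimension to
+        # 2 * hidden_size, which is why `input_layernorm` above was built as
+        # RMSNorm(2 * config.hidden_size); `o_proj` inside self_attn then
+        # projects the attention output back down to hidden_size.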
+ hidden_states = torch.concatenate( + [hidden_states, original_hidden_states], dim=-1) + + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.self_attn( + hidden_states, + position_ids=positions, + layer_idx=layer_idx, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + ) + + # feed-forward (MLP) + hidden_states = self.pre_ff_layernorm(hidden_states) + hidden_states = self.feed_forward(hidden_states, layer_idx=layer_idx) + + return hidden_states + + +class Zamba2MambaDecoderLayer(nn.Module): + + def __init__( + self, + config: Zamba2Config, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + intermediate_size = config.mamba_expand * config.hidden_size + self.mamba = MambaMixer2( + hidden_size=config.hidden_size, + ssm_state_size=config.mamba_d_state, + conv_kernel_size=config.mamba_d_conv, + intermediate_size=intermediate_size, + use_conv_bias=config.use_conv_bias, + use_bias=config.add_bias_linear, + n_groups=config.mamba_ngroups, + num_heads=config.n_mamba_heads, + head_dim=intermediate_size // config.n_mamba_heads, + rms_norm_eps=config.rms_norm_eps, + activation="silu", + chunk_size=config.chunk_size, + quant_config=quant_config, + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + mamba_cache_params: MambaCacheParams, + sequence_idx: Optional[torch.Tensor] = None, + transformer_hidden_states: Optional[torch.Tensor] = None, + positions: Optional[torch.Tensor] = None, + original_hidden_states: Optional[torch.Tensor] = None, + layer_idx: Optional[int] = None, + kv_caches: Optional[List[KVCache]] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, + torch.FloatTensor]]]: + + residual = hidden_states + + # `transformer_hidden_states` is the output from shared + # transformer + linear layer (see fig. 2 in + # https://arxiv.org/pdf/2405.16712). + # `transformer_hidden_states` is then added to the input to the mamba + # layer below (as described in eq. (6) of + # https://arxiv.org/pdf/2405.16712). 
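+        # In other words, for hybrid layers the Mamba input is roughly
+        #   x_mamba = hidden_states + Linear(SharedAttentionBlock(x)),
+        # while for pure Mamba layers transformer_hidden_states is None and
+        # this addition is skipped.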
+ if transformer_hidden_states is not None: + hidden_states = hidden_states + transformer_hidden_states + + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.mamba( + hidden_states, + attn_metadata=attn_metadata, + mamba_cache_params=mamba_cache_params, + sequence_idx=sequence_idx, + ) + + # residual connection after mamba + hidden_states = residual + hidden_states + + return hidden_states + + +class Zamba2HybridLayer(nn.Module): + + def __init__( + self, + shared_transformer: Zamba2AttentionDecoderLayer, + linear: ColumnParallelLinear, + mamba: Zamba2MambaDecoderLayer, + ): + super().__init__() + self.shared_transformer = shared_transformer + self.linear = linear + self.mamba_decoder = mamba + + def forward( + self, + hidden_states: torch.Tensor, + original_hidden_states: torch.Tensor, + layer_idx: int, + positions: torch.Tensor, + kv_caches: List[KVCache], + attn_metadata: AttentionMetadata, + mamba_cache_params: Optional[MambaCacheParams] = None, + sequence_idx: Optional[torch.Tensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, + torch.FloatTensor]]]: + + transformer_hidden_states = self.shared_transformer( + hidden_states, + original_hidden_states=original_hidden_states, + layer_idx=layer_idx, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + ) + + transformer_hidden_states, _ = self.linear(transformer_hidden_states) + + layer_outputs = self.mamba_decoder( + hidden_states, + transformer_hidden_states=transformer_hidden_states, + attn_metadata=attn_metadata, + mamba_cache_params=mamba_cache_params, + sequence_idx=sequence_idx, + ) + + return layer_outputs + + +class Zamba2Model(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + # Implement PP, need to use make_layers() + + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + is_lora_enabled = bool(lora_config) + assert not is_lora_enabled + + self.config = config + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + layer2block_map = { + layer_idx: block_idx + for block_idx, layer_idx in enumerate(config.hybrid_layer_ids) + } + blocks = cycle([ + Zamba2AttentionDecoderLayer(config, + bare_block_idx=idx, + layer2block_map=layer2block_map, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}") + for idx in range(config.num_mem_blocks) + ]) + layers = [] + for layer_type in config.layers_block_type: + mamba_layer = Zamba2MambaDecoderLayer(config, + quant_config=quant_config) + if layer_type == "hybrid": + block = next(blocks) + linear_layer = ColumnParallelLinear(config.hidden_size, + config.hidden_size, + bias=False, + quant_config=quant_config) + layers.append( + Zamba2HybridLayer(block, linear_layer, mamba_layer)) + else: + layers.append(mamba_layer) + self.layers = nn.ModuleList(layers) + self.final_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return 
self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + attn_metadata: AttentionMetadata, + mamba_cache_params: MambaCacheParams, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ): + # TODO: decide whether we want to implement PP support + if get_pp_group().is_first_rank: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings(input_ids) + hidden_states = inputs_embeds + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + + # pass a sequence index tensor, that is required for + # proper continuous batching computation including + # chunked prefill + seq_idx = None + if attn_metadata.num_prefills > 0: + seq_idx = torch.zeros_like(input_ids, dtype=torch.int32) + for i, (srt, end) in enumerate( + zip( + attn_metadata.query_start_loc, + attn_metadata.query_start_loc[1:], + )): + seq_idx[srt:end] = i + seq_idx.unsqueeze_(0) + + original_hidden_states = torch.clone(hidden_states) + for layer_idx, layer in enumerate(self.layers): + layer_outputs = layer( + hidden_states, + original_hidden_states=original_hidden_states, + layer_idx=layer_idx, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + mamba_cache_params=mamba_cache_params.at_layer_idx(layer_idx), + sequence_idx=seq_idx, + ) + hidden_states = layer_outputs + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + }) + hidden_states = self.final_layernorm(hidden_states) + return hidden_states + + +class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config + assert not cache_config.enable_prefix_caching, \ + "Mamba does not support prefix caching" + + super().__init__() + self.config = config + self.vllm_config = vllm_config + self.scheduler_config = scheduler_config + self.model_config = vllm_config.model_config + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + + self.model = Zamba2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + ) + self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) + + # Used to track and store by the Mamba cache between steps. 
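+        # The cache itself is created lazily on the first forward() call,
+        # using the conv/ssm state shapes computed in
+        # _get_mamba_cache_shape() below.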
+ self.mamba_cache: Optional[MambaCacheManager] = None + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = get_sampler() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + if self.scheduler_config is not None and \ + not self.model_config.enforce_eager: + if self.scheduler_config.max_num_seqs > \ + vllm_config.compilation_config.max_capture_size: + self.max_batch_size = \ + vllm_config.compilation_config.max_capture_size + else: + self.max_batch_size = vllm_config.pad_for_cudagraph( + self.scheduler_config.max_num_seqs) + elif self.scheduler_config is not None: + # For eager just take the scheduler_config if avail + self.max_batch_size = self.scheduler_config.max_num_seqs + else: + self.max_batch_size = 8192 + 2 + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs): + + if self.mamba_cache is None: + num_mamba_layers = self.config.num_hidden_layers + self.mamba_cache = MambaCacheManager( + self.lm_head.weight.dtype, num_mamba_layers, + self.max_batch_size, *self._get_mamba_cache_shape()) + + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + + hidden_states = self.model( + input_ids, + positions, + kv_caches, + attn_metadata, + mamba_cache_params, + intermediate_tensors, + inputs_embeds, + ) + return hidden_states + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + return self.mamba_cache.copy_inputs_before_cuda_graphs( + input_buffers, **kwargs) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) + + def _get_mamba_cache_shape( + self) -> Tuple[Tuple[int, int], Tuple[int, int]]: + world_size = get_tensor_model_parallel_world_size() + + intermediate_size = self.config.mamba_expand * self.config.hidden_size + + # if n_groups is not divisible by world_size, need to extend the shards + # to ensure all groups needed by a head is sharded along with it + n_groups = (self.config.mamba_ngroups + extra_groups_for_head_shards( + self.config.mamba_ngroups, world_size)) + + # - heads and n_groups are TP-ed + conv_dim = (intermediate_size + + 2 * n_groups * self.config.mamba_d_state) + conv_state_shape = ( + divide(conv_dim, world_size), + self.config.mamba_d_conv - 1, + ) + + # These are not TP-ed as they depend on A, dt_bias, D + # - they are typically small + # e.g., (h_heads, d_head, d_state) = (128, 64, 128) + temporal_state_shape = ( + divide(divide(intermediate_size, self.config.mamba_headdim), + world_size), + self.config.mamba_headdim, + self.config.mamba_d_state, + ) + return conv_state_shape, temporal_state_shape + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + 
stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + weights_dict = {} + for key, loaded_weight in weights: + if "A_log" in key: + key = key.replace("A_log", "A") + weights_dict[key] = loaded_weight + + params_dict = dict(self.named_parameters()) + for chkpt_weight_name, loaded_weight in weights_dict.items(): + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in chkpt_weight_name: + continue + chkpt_weight_name = chkpt_weight_name.replace( + weight_name, param_name) + param = params_dict[chkpt_weight_name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if chkpt_weight_name not in params_dict: + continue + param = params_dict[chkpt_weight_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) From c91c9c4d0b5dff7e1759286797503b111d16c8a4 Mon Sep 17 00:00:00 2001 From: Yury Tokpanov Date: Mon, 3 Mar 2025 22:51:32 +0000 Subject: [PATCH 02/10] TP support, unit test, docsrtrings Signed-off-by: Yury Tokpanov --- requirements/common.txt | 2 +- .../decoder_only/language/test_hybrid.py | 4 +- tests/models/registry.py | 1 + .../layers/mamba/mamba_mixer2.py | 1 - vllm/model_executor/models/zamba2.py | 612 ++++++++++++++---- 5 files changed, 479 insertions(+), 141 deletions(-) diff --git a/requirements/common.txt b/requirements/common.txt index d08ef253828b..616382843848 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -6,7 +6,7 @@ requests >= 2.26.0 tqdm blake3 py-cpuinfo -transformers >= 4.48.2 # Required for Bamba model and Transformers backend. +transformers >= 4.49.0 # Required for Bamba and Zamba2 models and Transformers backend. tokenizers >= 0.19.1 # Required for Llama 3. protobuf # Required by LlamaTokenizer. fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint. diff --git a/tests/models/decoder_only/language/test_hybrid.py b/tests/models/decoder_only/language/test_hybrid.py index 1a78b30930e3..9115ef23f4a7 100644 --- a/tests/models/decoder_only/language/test_hybrid.py +++ b/tests/models/decoder_only/language/test_hybrid.py @@ -9,7 +9,7 @@ from ...utils import check_outputs_equal # This test is for the hybrid models -MODELS = ["ai21labs/Jamba-tiny-dev"] +MODELS = ["ai21labs/Jamba-tiny-dev", "Zyphra/Zamba2-1.2B-instruct"] # Bamba at Fp32 is too big for the CI (L4 GPU). 
# MODELS = ["ai21labs/Jamba-tiny-dev", "ibm-ai-platform/Bamba-9B"] @@ -112,7 +112,7 @@ def test_mamba_prefill_chunking_with_parallel_sampling( def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts, model: str, dtype: str, max_tokens: int) -> None: - # numeric error during prefill chucking produces different generation + # numeric error during prefill chunking produces different generation # compared to w/o prefill chunking for those examples, removed them for now if 'Jamba' in model: example_prompts.pop(7) diff --git a/tests/models/registry.py b/tests/models/registry.py index 6b0ac46b0c36..c0796579b4f2 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -195,6 +195,7 @@ def check_available_online( "XverseForCausalLM": _HfExamplesInfo("xverse/XVERSE-7B-Chat", is_available_online=False, trust_remote_code=True), + "Zamba2ForCausalLM": _HfExamplesInfo("Zyphra/Zamba2-7B-instruct"), # [Encoder-decoder] "BartModel": _HfExamplesInfo("facebook/bart-base"), "BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"), diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index b53a540ed662..53d68b60f2fd 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -245,7 +245,6 @@ def __init__(self, assert num_heads % self.tp_size == 0, \ "Tensor parallel world size must divide num heads." - assert (n_groups % self.tp_size) == 0 or n_groups == 1, \ ( "If tensor parallel world size does not divide num_heads, " diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py index 3926187ebbb7..37b9233de6e0 100644 --- a/vllm/model_executor/models/zamba2.py +++ b/vllm/model_executor/models/zamba2.py @@ -1,8 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 -"""Inference-only Zamba2 model.""" -# Added by the Zyphra Technologies, 2025 +"""PyTorch Zamba2 model implementation for vLLM. + +This module implements the Zamba2 architecture from +https://arxiv.org/abs/2411.15242, which combines Mamba and Transformer +architectures in a hybrid model optimized for efficient sequence modeling. The +model alternates between state space model layers and attention-based layers. 
+""" from itertools import cycle -from typing import Dict, Iterable, List, Optional, Tuple +from typing import Dict, Iterable, List, Optional, Set, Tuple, Union import torch import torch.nn.functional as F @@ -12,11 +17,12 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.layer import Attention from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import divide, get_tensor_model_parallel_world_size -from vllm.distributed.parallel_state import get_pp_group +from vllm.distributed import (divide, get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather) from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, + ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer2 import ( @@ -33,12 +39,18 @@ from vllm.sequence import IntermediateTensors from .interfaces import HasInnerState, IsHybrid -from .utils import make_empty_intermediate_tensors_factory, maybe_prefix +from .utils import maybe_prefix KVCache = Tuple[torch.Tensor, torch.Tensor] class Zamba2Attention(nn.Module): + """Multi-head attention mechanism for the Zamba2 model. + + Implements attention with parallel computation, QKV projections, optional + adapters and rotary position embeddings. The attention is computed across + distributed blocks for efficient processing. + """ def __init__( self, @@ -48,20 +60,34 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", - ): + ) -> None: + """Initialize the attention layer. + + Args: + config: The Zamba2 model configuration + bare_block_idx: Index of the bare attention block + layer2block_map: Mapping from layer indices to block indices + cache_config: Configuration for key-value caching + quant_config: Configuration for model quantization + prefix: Optional prefix for parameter names + """ super().__init__() + tp_size = get_tensor_model_parallel_world_size() self.config = config self.layer2block_map = layer2block_map self.num_fwd_mem_blocks = len(layer2block_map) self.rope_theta = config.rope_theta self.attention_hidden_size = config.attention_hidden_size - self.num_attention_heads = config.num_attention_heads + self.total_num_attention_heads = config.num_attention_heads + assert self.total_num_attention_heads % tp_size == 0 + self.num_attention_heads = config.num_attention_heads // tp_size self.attention_head_dim = config.attention_head_dim + self.qkv_size = self.attention_hidden_size // tp_size self.scale = (self.attention_head_dim / 2)**-0.5 if (self.attention_head_dim * - self.num_attention_heads) != self.attention_hidden_size: + self.total_num_attention_heads) != self.attention_hidden_size: raise ValueError( f"attention_hidden_size must be divisible by" f" num_attention_heads" @@ -71,7 +97,7 @@ def __init__( self.qkv_proj = QKVParallelLinear( self.attention_hidden_size, self.attention_head_dim, - self.num_attention_heads, + self.total_num_attention_heads, bias=False, quant_config=quant_config, ) @@ -82,6 +108,8 @@ def __init__( # Need to define separate Attention objects, because in recent vLLM # KV cache tensors are tied to specific Attention objects. 
+ + # Initialize attention blocks with proper indexing self.dpa_list = nn.ModuleList([]) j = bare_block_idx * (self.num_fwd_mem_blocks + config.num_mem_blocks - 1) // config.num_mem_blocks @@ -99,6 +127,7 @@ def __init__( dpa = nn.Identity() self.dpa_list.append(dpa) + # Initialize adapter layers if enabled if config.use_shared_attention_adapter: self.linear_q_adapter_list = nn.ModuleList([]) self.linear_k_adapter_list = nn.ModuleList([]) @@ -110,7 +139,8 @@ def __init__( ColumnParallelLinear(self.attention_hidden_size, config.adapter_rank, bias=False, - quant_config=quant_config), + quant_config=quant_config, + gather_output=True), ColumnParallelLinear(config.adapter_rank, self.attention_hidden_size, bias=False, @@ -120,7 +150,8 @@ def __init__( ColumnParallelLinear(self.attention_hidden_size, config.adapter_rank, bias=False, - quant_config=quant_config), + quant_config=quant_config, + gather_output=True), ColumnParallelLinear(config.adapter_rank, self.attention_hidden_size, bias=False, @@ -130,7 +161,8 @@ def __init__( ColumnParallelLinear(self.attention_hidden_size, config.adapter_rank, bias=False, - quant_config=quant_config), + quant_config=quant_config, + gather_output=True), ColumnParallelLinear(config.adapter_rank, self.attention_hidden_size, bias=False, @@ -140,6 +172,7 @@ def __init__( linear_q_adapter = nn.Identity() linear_k_adapter = nn.Identity() linear_v_adapter = nn.Identity() + self.linear_q_adapter_list.append(linear_q_adapter) self.linear_k_adapter_list.append(linear_k_adapter) self.linear_v_adapter_list.append(linear_v_adapter) @@ -155,34 +188,60 @@ def __init__( ) def forward( - self, - hidden_states, - layer_idx: int, - position_ids: torch.Tensor, - kv_caches: List[KVCache], - attn_metadata: AttentionMetadata, - ): + self, + hidden_states: torch.Tensor, + layer_idx: int, + position_ids: torch.Tensor, + kv_caches: List[KVCache], + attn_metadata: + AttentionMetadata, # See Zamba2Attention.forward for details + ) -> torch.Tensor: + """Forward pass through the attention layer. + + Args: + hidden_states: Input tensor [batch_size, seq_len, hidden_size] + position_ids: Position IDs for positional embeddings + layer_idx: Current layer index + kv_caches: List of key-value cache tuples + attn_metadata: Metadata required for attention computation, + including: + - block_tables: Mapping of sequence blocks to physical storage + - context_lens: Length of context for each sequence + - max_context_len: Maximum context length in the batch + - query_start_loc: Starting positions of queries in the batch + - num_queries: Number of query tokens to process + - num_prefills: Number of tokens being prefilled (vs generated) + Used to handle variable sequence lengths and enable efficient + batched attention computation across multiple sequences. 
+ + Returns: + Output tensor [batch_size, seq_len, hidden_size] + """ qkv, _ = self.qkv_proj(hidden_states) - query_states, key_states, value_states = qkv.split([ - self.attention_hidden_size, self.attention_hidden_size, - self.attention_hidden_size - ], + query_states, key_states, value_states = qkv.split([self.qkv_size] * 3, dim=-1) block_idx = self.layer2block_map[layer_idx] if self.config.use_shared_attention_adapter: + # Apply adapter transformations to Q, K, V if enabled + assert not isinstance(self.linear_q_adapter_list[block_idx], + nn.Identity) q_lora_output = self.linear_q_adapter_list[block_idx][0]( hidden_states)[0] q_lora_output = self.linear_q_adapter_list[block_idx][1]( q_lora_output)[0] query_states = query_states + q_lora_output + assert not isinstance(self.linear_k_adapter_list[block_idx], + nn.Identity) k_lora_output = self.linear_k_adapter_list[block_idx][0]( hidden_states)[0] k_lora_output = self.linear_k_adapter_list[block_idx][1]( k_lora_output)[0] key_states = key_states + k_lora_output + assert not isinstance(self.linear_v_adapter_list[block_idx], + nn.Identity) v_lora_output = self.linear_v_adapter_list[block_idx][0]( hidden_states)[0] v_lora_output = self.linear_v_adapter_list[block_idx][1]( @@ -203,6 +262,12 @@ def forward( class Zamba2MLP(nn.Module): + """Feed-forward MLP layer for the Zamba2 model. + + Implements a gated feed-forward network that projects inputs to a larger + intermediate size, applies GELU activation with gating, then projects back + to the original size. Includes optional adapter layers for model adaptation. + """ def __init__( self, @@ -210,36 +275,51 @@ def __init__( bare_block_idx: int, layer2block_map: Dict[int, int], quant_config: Optional[QuantizationConfig] = None, - ): + ) -> None: + """Initialize the MLP layer. 
+ + Args: + config: The Zamba2 model configuration + bare_block_idx: Index of the bare block in the model + layer2block_map: Mapping from layer indices to block indices + quant_config: Configuration for model quantization + """ super().__init__() self.config = config + self.tp_size = get_tensor_model_parallel_world_size() self.layer2block_map = layer2block_map self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.num_fwd_mem_blocks = len(layer2block_map) + # Main projection layers with gating self.gate_up_proj = ColumnParallelLinear( self.hidden_size, - 2 * self.intermediate_size, + 2 * self.intermediate_size, # 2x for gate and input projections bias=self.config.add_bias_linear, quant_config=quant_config) - self.down_proj = RowParallelLinear(self.intermediate_size, - self.hidden_size, - bias=self.config.add_bias_linear, - quant_config=quant_config) + + self.down_proj = ReplicatedLinear(self.intermediate_size, + self.hidden_size, + bias=self.config.add_bias_linear, + quant_config=quant_config) + + # Only allow GELU activations if config.hidden_act != "gelu": - raise ValueError(f"Only gelu activation is supported" - f" (got `hidden_act`: {config.hidden_act})") + raise ValueError(f"Only gelu activation is supported " + f"(got `hidden_act`: {config.hidden_act})") self.act_fn = F.gelu + # Initialize adapter layers if enabled self.gate_up_proj_adapter_list = nn.ModuleList([]) for block_idx in range(self.num_fwd_mem_blocks): if block_idx % config.num_mem_blocks == bare_block_idx: gate_up_proj_adapter = nn.ModuleList([ ColumnParallelLinear(config.hidden_size, - self.config.adapter_rank, + config.adapter_rank, bias=False, - quant_config=quant_config), + quant_config=quant_config, + gather_output=True), ColumnParallelLinear(config.adapter_rank, 2 * self.intermediate_size, bias=False, @@ -249,22 +329,52 @@ def __init__( gate_up_proj_adapter = nn.Identity() self.gate_up_proj_adapter_list.append(gate_up_proj_adapter) - def forward(self, hidden_states, layer_idx): - gate_up_state, _ = self.gate_up_proj(hidden_states) + def forward(self, hidden_states: torch.Tensor, + layer_idx: int) -> torch.Tensor: + """Forward pass through the MLP layer. 
+ + Args: + hidden_states: Input tensor [batch_size, seq_len, hidden_size] + layer_idx: Current layer index + + Returns: + Output tensor [batch_size, seq_len, hidden_size] after applying + gated feed-forward transformation + """ + # Project input to intermediate size with gating + gate_up_states, _ = self.gate_up_proj(hidden_states) + + # Apply adapter transformation if present block_idx = self.layer2block_map[layer_idx] - lora_output = self.gate_up_proj_adapter_list[block_idx][0]( - hidden_states)[0] - lora_output = self.gate_up_proj_adapter_list[block_idx][1]( - lora_output)[0] - gate_up_state = gate_up_state + lora_output - - gate_up_state = torch.chunk(gate_up_state, 2, dim=-1) - hidden_state = self.act_fn(gate_up_state[0]) * gate_up_state[1] - output, _ = self.down_proj(hidden_state) + assert not isinstance(self.gate_up_proj_adapter_list[block_idx], + nn.Identity) + adapter = self.gate_up_proj_adapter_list[block_idx] + lora_output = adapter[0](hidden_states)[0] + lora_output = adapter[1](lora_output)[0] + gate_up_states = gate_up_states + lora_output + if self.tp_size > 1: + gate_up_states = tensor_model_parallel_all_gather(gate_up_states) + + # Split into gate and input projections + gate_up_states = torch.chunk(gate_up_states, 2, dim=-1) + + # Apply GELU activation with gating + hidden_states = self.act_fn(gate_up_states[0]) * gate_up_states[1] + + # Project back to hidden size + output, _ = self.down_proj(hidden_states) return output class Zamba2AttentionDecoderLayer(nn.Module): + """Single decoder layer combining attention and feed-forward networks. + + This layer implements a standard transformer block with: + - Input layer normalization + - Multi-head self-attention + - Pre-feed-forward layer normalization + - Feed-forward network (MLP) + """ def __init__( self, @@ -274,8 +384,20 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", - ): + ) -> None: + """Initialize the decoder layer. 
+ + Args: + config: The Zamba2 model configuration + bare_block_idx: Index of the bare block + layer2block_map: Mapping from layer indices to block indices + cache_config: Configuration for key-value caching + quant_config: Configuration for model quantization + prefix: Optional prefix for parameter names + """ super().__init__() + + # Initialize attention sublayer self.self_attn = Zamba2Attention( config, bare_block_idx=bare_block_idx, @@ -284,27 +406,49 @@ def __init__( quant_config=quant_config, prefix=prefix, ) + + # Initialize feed-forward sublayer self.feed_forward = Zamba2MLP( config, bare_block_idx=bare_block_idx, layer2block_map=layer2block_map, quant_config=quant_config, ) + + # Initialize layer normalizations + # Input normalization operates on concatenated states self.input_layernorm = RMSNorm(2 * config.hidden_size, eps=config.rms_norm_eps) + # Pre-FF normalization operates on attention output self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( - self, - hidden_states: torch.Tensor, - original_hidden_states: torch.Tensor, - layer_idx: int, - positions: torch.Tensor, - kv_caches: List[KVCache], - attn_metadata: AttentionMetadata, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, - torch.FloatTensor]]]: + self, + hidden_states: torch.Tensor, + original_hidden_states: torch.Tensor, + layer_idx: int, + positions: torch.Tensor, + kv_caches: List[KVCache], + attn_metadata: + AttentionMetadata, # See Zamba2Attention.forward for details + ) -> torch.Tensor: + """Forward pass through the decoder layer. + + Args: + hidden_states: Input tensor from previous layer + original_hidden_states: Original input tensor for residual + connection + layer_idx: Current layer index + positions: IDs for positional embeddings + kv_caches: List of key-value cache tuples + attn_metadata: Metadata for sequence processing and attention + computation + + + Returns: + Transformed hidden states after attention and feed-forward + """ # The argument original_hidden_states is concatenated with hidden_states # (which is the output of the previous (mamba) layer). @@ -313,8 +457,10 @@ def forward( hidden_states = torch.concatenate( [hidden_states, original_hidden_states], dim=-1) + # Layer norm before attention hidden_states = self.input_layernorm(hidden_states) + # Self attention hidden_states = self.self_attn( hidden_states, position_ids=positions, @@ -323,21 +469,37 @@ def forward( attn_metadata=attn_metadata, ) - # feed-forward (MLP) + # Layer norm before feed-forward hidden_states = self.pre_ff_layernorm(hidden_states) + + # Feed-forward network hidden_states = self.feed_forward(hidden_states, layer_idx=layer_idx) return hidden_states class Zamba2MambaDecoderLayer(nn.Module): + """Single Mamba decoder layer with normalization. + + This implements a Mamba block. It includes input normalization + and can process sequences using either chunked or full + computation depending on configuration. + """ def __init__( self, config: Zamba2Config, quant_config: Optional[QuantizationConfig] = None, - ): + ) -> None: + """Initialize the Mamba decoder layer. 
+ + Args: + config: The Zamba2 model configuration + quant_config: Configuration for model quantization + """ super().__init__() + + # Initialize Mamba mixer with expanded intermediate size intermediate_size = config.mamba_expand * config.hidden_size self.mamba = MambaMixer2( hidden_size=config.hidden_size, @@ -354,13 +516,16 @@ def __init__( chunk_size=config.chunk_size, quant_config=quant_config, ) + + # Input normalization self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, + attn_metadata: + AttentionMetadata, # See Zamba2Attention.forward for details mamba_cache_params: MambaCacheParams, sequence_idx: Optional[torch.Tensor] = None, transformer_hidden_states: Optional[torch.Tensor] = None, @@ -368,9 +533,28 @@ def forward( original_hidden_states: Optional[torch.Tensor] = None, layer_idx: Optional[int] = None, kv_caches: Optional[List[KVCache]] = None, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, - torch.FloatTensor]]]: - + ) -> torch.Tensor: + """Forward pass through the Mamba decoder layer. + + Args: + hidden_states: Input tensor [batch_size, seq_len, hidden_size] + attn_metadata: Metadata for sequence processing and attention + computation + mamba_cache_params: Parameters for Mamba's state caches + (one for conv, one for ssm) + sequence_idx: Index tensor for identifying sequences in batch + Required for proper chunked processing in prefill + transformer_hidden_states: Optional output from transformer path + Added to input if provided (used in hybrid architecture) + positions: Optional position IDs (unused in Mamba) + original_hidden_states: Optional original inputs (unused in Mamba) + layer_idx: Optional layer index (unused in Mamba) + kv_caches: Optional KV caches (unused in Mamba) + + Returns: + Transformed hidden states with residual connection applied + """ + # Store input for residual connection residual = hidden_states # `transformer_hidden_states` is the output from shared @@ -382,7 +566,10 @@ def forward( if transformer_hidden_states is not None: hidden_states = hidden_states + transformer_hidden_states + # Apply input normalization hidden_states = self.input_layernorm(hidden_states) + + # Process through Mamba mixer hidden_states = self.mamba( hidden_states, attn_metadata=attn_metadata, @@ -397,17 +584,35 @@ def forward( class Zamba2HybridLayer(nn.Module): + """Hybrid layer combining Transformer and Mamba architectures. + + This layer implements the hybrid architecture described in the Zamba paper, + where a shared transformer pathway processes input in parallel with a Mamba + pathway. The transformer output is projected and added to the Mamba input + for enhanced representation learning. + """ def __init__( self, shared_transformer: Zamba2AttentionDecoderLayer, - linear: ColumnParallelLinear, - mamba: Zamba2MambaDecoderLayer, - ): + config: Zamba2Config, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + """Initialize the hybrid layer. 
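+
+        The linear projection and Mamba decoder sublayers are constructed
+        internally from `config` and `quant_config`; only the shared
+        transformer block is passed in.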
+ + Args: + shared_transformer: Transformer decoder layer for attention pathway + linear: Linear projection for transformer output before Mamba + mamba: Mamba decoder layer for state space pathway + """ super().__init__() self.shared_transformer = shared_transformer - self.linear = linear - self.mamba_decoder = mamba + self.linear = ReplicatedLinear(config.hidden_size, + config.hidden_size, + bias=False, + quant_config=quant_config) + self.mamba_decoder = Zamba2MambaDecoderLayer(config, + quant_config=quant_config) def forward( self, @@ -416,12 +621,37 @@ def forward( layer_idx: int, positions: torch.Tensor, kv_caches: List[KVCache], - attn_metadata: AttentionMetadata, + attn_metadata: + AttentionMetadata, # See Zamba2Attention.forward for details mamba_cache_params: Optional[MambaCacheParams] = None, sequence_idx: Optional[torch.Tensor] = None, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, - torch.FloatTensor]]]: - + ) -> torch.Tensor: + """Forward pass through the hybrid layer. + + Processes input through parallel transformer and Mamba paths: + 1. Transformer path processes input with attention + 2. Transformer output is projected to match hidden size + 3. Projected output is added to Mamba path input + 4. Final output combines both paths' representations + + Args: + hidden_states: Input tensor [batch_size, seq_len, hidden_size] + original_hidden_states: Original input for transformer residual + connection + layer_idx: Current layer index for block mapping + positions: Position IDs for positional embeddings + kv_caches: Key-value caches for attention + attn_metadata: Metadata for sequence processing and attention + computation + mamba_cache_params: Parameters for Mamba's state caches + (one for conv, one for ssm) + sequence_idx: Indices for identifying sequences in batch, + required for proper chunked processing in prefill + + Returns: + Output tensor combining transformer and Mamba representations + """ + # Process through transformer pathway transformer_hidden_states = self.shared_transformer( hidden_states, original_hidden_states=original_hidden_states, @@ -431,8 +661,10 @@ def forward( attn_metadata=attn_metadata, ) + # Project transformer output transformer_hidden_states, _ = self.linear(transformer_hidden_states) + # Process through Mamba pathway with transformer injection layer_outputs = self.mamba_decoder( hidden_states, transformer_hidden_states=transformer_hidden_states, @@ -445,10 +677,20 @@ def forward( class Zamba2Model(nn.Module): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - # Implement PP, need to use make_layers() - + """Core Zamba2 model combining transformer and Mamba architectures. + + The model processes input through a sequence of hybrid and Mamba-only + layers, using token embeddings and final layer normalization. + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + """Initialize the Zamba2 model. 
+ + Args: + vllm_config: Configuration object containing model, cache, + quantization and LoRA settings + prefix: Optional prefix for parameter names in state dict + """ super().__init__() config = vllm_config.model_config.hf_config @@ -464,16 +706,20 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size + # Initialize token embeddings self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, org_num_embeddings=config.vocab_size, ) + # Map hybrid layer indices to block indices layer2block_map = { layer_idx: block_idx for block_idx, layer_idx in enumerate(config.hybrid_layer_ids) } + + # Create cyclic iterator of transformer blocks blocks = cycle([ Zamba2AttentionDecoderLayer(config, bare_block_idx=idx, @@ -483,29 +729,31 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=f"{prefix}") for idx in range(config.num_mem_blocks) ]) + + # Initialize layers according to block type configuration layers = [] for layer_type in config.layers_block_type: - mamba_layer = Zamba2MambaDecoderLayer(config, - quant_config=quant_config) if layer_type == "hybrid": block = next(blocks) - linear_layer = ColumnParallelLinear(config.hidden_size, - config.hidden_size, - bias=False, - quant_config=quant_config) - layers.append( - Zamba2HybridLayer(block, linear_layer, mamba_layer)) + layers.append(Zamba2HybridLayer(block, config, quant_config)) else: - layers.append(mamba_layer) + layers.append( + Zamba2MambaDecoderLayer(config, quant_config=quant_config)) self.layers = nn.ModuleList(layers) + + # Final layer normalization self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - config.hidden_size)) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + """Convert input token IDs to embeddings. + + Args: + input_ids: Tensor of input token IDs + + Returns: + Embedded representation of the input tokens + """ return self.embed_tokens(input_ids) def forward( @@ -513,19 +761,30 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[KVCache], - attn_metadata: AttentionMetadata, + attn_metadata: + AttentionMetadata, # See Zamba2Attention.forward for details mamba_cache_params: MambaCacheParams, - intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, - ): - # TODO: decide whether we want to implement PP support - if get_pp_group().is_first_rank: - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings(input_ids) - hidden_states = inputs_embeds - else: - assert intermediate_tensors is not None - hidden_states = intermediate_tensors["hidden_states"] + ) -> Union[torch.Tensor, IntermediateTensors]: + """Forward pass through the model. 
+ + Args: + input_ids: Input token IDs + positions: Position IDs for embeddings + kv_caches: List of key-value cache tuples + attn_metadata: Metadata for attention computation + mamba_cache_params: Parameters for Mamba's state caches + (one for conv, one for ssm) + inputs_embeds: Optional pre-computed input embeddings + + Returns: + Either final hidden states or intermediate tensors for pipeline + parallelism + """ + # Handle pipeline parallelism for first rank + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings(input_ids) + hidden_states = inputs_embeds # pass a sequence index tensor, that is required for # proper continuous batching computation including @@ -541,6 +800,7 @@ def forward( seq_idx[srt:end] = i seq_idx.unsqueeze_(0) + # Process through layers original_hidden_states = torch.clone(hidden_states) for layer_idx, layer in enumerate(self.layers): layer_outputs = layer( @@ -555,17 +815,32 @@ def forward( ) hidden_states = layer_outputs - if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - }) hidden_states = self.final_layernorm(hidden_states) return hidden_states class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + """Zamba2 model with causal language modeling head. + + This class wraps the core Zamba2 model and adds: + - A language modeling head for next token prediction + - Mamba state caching functionality + - Support for model parallelism and quantization + - Sampling capabilities for text generation + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + """Initialize the Zamba2 model for causal language modeling. + + Args: + vllm_config: Configuration containing model, cache, quantization, + LoRA and scheduler settings + prefix: Optional prefix for parameter names + + Raises: + AssertionError: If prefix caching is enabled (not supported by + Mamba) + """ config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config @@ -582,9 +857,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + # Initialize core model self.model = Zamba2Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) + # Initialize language modeling head self.lm_head = ParallelLMHead( self.unpadded_vocab_size, config.hidden_size, @@ -594,82 +871,119 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # compatibility if not lora_config else lora_config.lora_vocab_padding_size, ) + # Tie weights with input embeddings if using same dimensions self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) # Used to track and store by the Mamba cache between steps. 
self.mamba_cache: Optional[MambaCacheManager] = None + # Initialize logits processing and sampling self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) self.sampler = get_sampler() - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) - - if self.scheduler_config is not None and \ - not self.model_config.enforce_eager: - if self.scheduler_config.max_num_seqs > \ - vllm_config.compilation_config.max_capture_size: - self.max_batch_size = \ - vllm_config.compilation_config.max_capture_size - else: - self.max_batch_size = vllm_config.pad_for_cudagraph( - self.scheduler_config.max_num_seqs) - elif self.scheduler_config is not None: - # For eager just take the scheduler_config if avail - self.max_batch_size = self.scheduler_config.max_num_seqs - else: - self.max_batch_size = 8192 + 2 - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + """Convert input token IDs to embeddings. + Args: + input_ids: Tensor of input token IDs + Returns: + Embedded representation of the input tokens + """ return self.model.get_input_embeddings(input_ids) - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[KVCache], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs): - + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + attn_metadata: + AttentionMetadata, # See Zamba2Attention.forward for details + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs) -> torch.Tensor: + """Forward pass through the model. + + Args: + input_ids: Input token IDs + positions: Position IDs for embeddings + kv_caches: List of key-value cache tuples + attn_metadata: Metadata for attention computation + inputs_embeds: Optional pre-computed input embeddings + **kwargs: Additional arguments passed to cache manager + + Returns: + Output hidden states + """ + # Initialize Mamba cache if needed if self.mamba_cache is None: num_mamba_layers = self.config.num_hidden_layers self.mamba_cache = MambaCacheManager( - self.lm_head.weight.dtype, num_mamba_layers, - self.max_batch_size, *self._get_mamba_cache_shape()) + self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers, + *self._get_mamba_cache_shape()) + # Get cache parameters for current run mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + # Forward pass through model hidden_states = self.model( input_ids, positions, kv_caches, attn_metadata, mamba_cache_params, - intermediate_tensors, inputs_embeds, ) + return hidden_states - def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + def copy_inputs_before_cuda_graphs(self, input_buffers: Dict[str, + torch.Tensor], + **kwargs) -> Dict[str, torch.Tensor]: + """Copy inputs before CUDA graph capture. + + Args: + input_buffers: Dictionary of input tensors + **kwargs: Additional arguments passed to cache manager + + Returns: + Updated input buffers + """ return self.mamba_cache.copy_inputs_before_cuda_graphs( input_buffers, **kwargs) - def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + def get_seqlen_agnostic_capture_inputs( + self, batch_size: int) -> Dict[str, torch.Tensor]: + """Get inputs for sequence-length-agnostic graph capture. 
+ + Args: + batch_size: Size of batch to capture + Returns: + Dictionary of capture inputs + """ return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) def _get_mamba_cache_shape( self) -> Tuple[Tuple[int, int], Tuple[int, int]]: + """Calculate shapes for Mamba's convolutional and state caches. + + Returns: + Tuple containing: + - conv_state_shape: Shape for convolutional state cache + - temporal_state_shape: Shape for state space model cache + """ world_size = get_tensor_model_parallel_world_size() intermediate_size = self.config.mamba_expand * self.config.hidden_size + # Extend groups if needed to ensure all groups needed by a head + # are sharded together + # if n_groups is not divisible by world_size, need to extend the shards # to ensure all groups needed by a head is sharded along with it n_groups = (self.config.mamba_ngroups + extra_groups_for_head_shards( self.config.mamba_ngroups, world_size)) + # Calculate conv state shape (includes groups) # - heads and n_groups are TP-ed conv_dim = (intermediate_size + 2 * n_groups * self.config.mamba_d_state) @@ -678,6 +992,7 @@ def _get_mamba_cache_shape( self.config.mamba_d_conv - 1, ) + # Calculate temporal state shape (per-head states) # These are not TP-ed as they depend on A, dt_bias, D # - they are typically small # e.g., (h_heads, d_head, d_state) = (128, 64, 128) @@ -687,6 +1002,7 @@ def _get_mamba_cache_shape( self.config.mamba_headdim, self.config.mamba_d_state, ) + return conv_state_shape, temporal_state_shape def compute_logits( @@ -694,6 +1010,15 @@ def compute_logits( hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: + """Compute logits for next token prediction. + + Args: + hidden_states: Hidden states from model forward pass + sampling_metadata: Metadata for sampling process + + Returns: + Logits for next token prediction + """ logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits @@ -703,10 +1028,20 @@ def sample( logits: Optional[torch.Tensor], sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: + """Sample next tokens from computed logits. 
+ + Args: + logits: Computed logits for next token prediction + sampling_metadata: Metadata for sampling process + + Returns: + Sampled tokens and related sampling information + """ next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -721,6 +1056,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weights_dict[key] = loaded_weight params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for chkpt_weight_name, loaded_weight in weights_dict.items(): for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in chkpt_weight_name: @@ -738,3 +1074,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(chkpt_weight_name) + return loaded_params From 89dab816e73de3a97ca407dc04270e5e47cdc965 Mon Sep 17 00:00:00 2001 From: Yury Tokpanov Date: Mon, 3 Mar 2025 22:51:46 +0000 Subject: [PATCH 03/10] Fix unit tests Signed-off-by: Yury Tokpanov --- .../decoder_only/language/test_hybrid.py | 47 +++++++++++-------- 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/tests/models/decoder_only/language/test_hybrid.py b/tests/models/decoder_only/language/test_hybrid.py index 9115ef23f4a7..60eb3830c6d8 100644 --- a/tests/models/decoder_only/language/test_hybrid.py +++ b/tests/models/decoder_only/language/test_hybrid.py @@ -27,17 +27,19 @@ def test_models( ) -> None: # numeric error produces different generation - if 'Bamba' in model: + if "Bamba" in model: example_prompts.pop(3) - with hf_runner( - model, - dtype=dtype, - model_kwargs={ - "use_mamba_kernels": - False, # mamba kernels are not installed so HF - # don't use them - }) as hf_model: + model_kwargs = { + "use_mamba_kernels": False, # mamba kernels are not installed so HF + # don't use them + } + if "Zamba2" in model: + # Zamba2 HF implementation automatically checks if mamba kernels are + # installed + model_kwargs = {} + + with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model: hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) with vllm_runner(model, dtype=dtype) as vllm_model: @@ -114,24 +116,29 @@ def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts, max_tokens: int) -> None: # numeric error during prefill chunking produces different generation # compared to w/o prefill chunking for those examples, removed them for now - if 'Jamba' in model: + if "Jamba" in model: example_prompts.pop(7) example_prompts.pop(2) example_prompts.pop(1) - elif 'Bamba' in model: + elif "Bamba" in model: example_prompts.pop(6) example_prompts.pop(3) example_prompts.pop(2) dtype = "half" # use a different dtype for Bamba - - with hf_runner( - model, - dtype=dtype, - model_kwargs={ - "use_mamba_kernels": - False, # mamba kernels are not installed so HF - # don't use them - }) as hf_model: + elif "Zamba2" in model: + example_prompts.pop(7) + dtype = "half" + + model_kwargs = { + "use_mamba_kernels": False, # mamba kernels are not installed so HF + # don't use them + } + if "Zamba2" in model: + # Zamba2 HF implementation automatically checks if mamba kernels are + # installed + model_kwargs = {} + + with hf_runner(model, dtype=dtype, 
model_kwargs=model_kwargs) as hf_model: non_chunked = hf_model.generate_greedy(example_prompts, max_tokens) with vllm_runner(model, From 39105a9f7a0d28a40f5145ee7155e8be5b7711f6 Mon Sep 17 00:00:00 2001 From: Quentin Anthony Date: Mon, 3 Mar 2025 13:43:42 -0800 Subject: [PATCH 04/10] label transformers req as just zamba2 Co-authored-by: Tyler Michael Smith Signed-off-by: Quentin Anthony --- requirements/common.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/common.txt b/requirements/common.txt index 616382843848..9d3249e20dfe 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -6,7 +6,7 @@ requests >= 2.26.0 tqdm blake3 py-cpuinfo -transformers >= 4.49.0 # Required for Bamba and Zamba2 models and Transformers backend. +transformers >= 4.49.0 # Required for Zamba2 models and Transformers backend. tokenizers >= 0.19.1 # Required for Llama 3. protobuf # Required by LlamaTokenizer. fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint. From 0867d1b0d42707433f32bbcb8dfd35f251b546a9 Mon Sep 17 00:00:00 2001 From: Yury Tokpanov Date: Tue, 4 Mar 2025 00:04:16 +0000 Subject: [PATCH 05/10] Rebase + remove kv_cache/attn_metadata args Signed-off-by: Yury Tokpanov --- requirements/test.in | 2 +- requirements/test.txt | 2 +- vllm/model_executor/models/bamba.py | 2 - vllm/model_executor/models/jamba.py | 2 - vllm/model_executor/models/zamba2.py | 101 ++++++--------------------- 5 files changed, 24 insertions(+), 85 deletions(-) diff --git a/requirements/test.in b/requirements/test.in index faa4564eaa39..eb86f6d1a43f 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -30,7 +30,7 @@ matplotlib # required for qwen-vl test mistral_common[opencv] >= 1.5.4 # required for pixtral test datamodel_code_generator # required for minicpm3 test lm-eval[api]==0.4.4 # required for model evaluation test -transformers==4.48.2 +transformers>=4.49.0 # quantization bitsandbytes>=0.45.3 buildkite-test-collector==0.1.9 diff --git a/requirements/test.txt b/requirements/test.txt index c733364fd871..df25b2bb97c0 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -633,7 +633,7 @@ tqdm==4.66.6 # transformers tqdm-multiprocess==0.0.11 # via lm-eval -transformers==4.48.2 +transformers==4.49.0 # via # -r requirements/test.in # genai-perf diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index 61b68125e07e..de0209d0b43b 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -38,8 +38,6 @@ make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) -KVCache = Tuple[torch.Tensor, torch.Tensor] - class BambaMLP(nn.Module): diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 11b863ded45d..6fabc8228e18 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -36,8 +36,6 @@ make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) -KVCache = Tuple[torch.Tensor, torch.Tensor] - class JambaMoE(nn.Module): diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py index 37b9233de6e0..f9c366035c5a 100644 --- a/vllm/model_executor/models/zamba2.py +++ b/vllm/model_executor/models/zamba2.py @@ -7,18 +7,18 @@ model alternates between state space model layers and attention-based layers. 
""" from itertools import cycle -from typing import Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Dict, Iterable, Optional, Set, Tuple, Union import torch import torch.nn.functional as F from torch import nn from transformers import Zamba2Config -from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.layer import Attention from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (divide, get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather) +from vllm.forward_context import get_forward_context from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, @@ -41,8 +41,6 @@ from .interfaces import HasInnerState, IsHybrid from .utils import maybe_prefix -KVCache = Tuple[torch.Tensor, torch.Tensor] - class Zamba2Attention(nn.Module): """Multi-head attention mechanism for the Zamba2 model. @@ -106,8 +104,10 @@ def __init__( bias=False, quant_config=quant_config) - # Need to define separate Attention objects, because in recent vLLM - # KV cache tensors are tied to specific Attention objects. + # Even though in Zamba2 weights are shared between attention layers, KV + # cache is unique for every attention layer. Hence, we need to define + # separate Attention objects, because in recent vLLM KV cache tensors + # are tied to specific Attention objects. # Initialize attention blocks with proper indexing self.dpa_list = nn.ModuleList([]) @@ -188,13 +188,10 @@ def __init__( ) def forward( - self, - hidden_states: torch.Tensor, - layer_idx: int, - position_ids: torch.Tensor, - kv_caches: List[KVCache], - attn_metadata: - AttentionMetadata, # See Zamba2Attention.forward for details + self, + hidden_states: torch.Tensor, + layer_idx: int, + position_ids: torch.Tensor, ) -> torch.Tensor: """Forward pass through the attention layer. @@ -202,17 +199,6 @@ def forward( hidden_states: Input tensor [batch_size, seq_len, hidden_size] position_ids: Position IDs for positional embeddings layer_idx: Current layer index - kv_caches: List of key-value cache tuples - attn_metadata: Metadata required for attention computation, - including: - - block_tables: Mapping of sequence blocks to physical storage - - context_lens: Length of context for each sequence - - max_context_len: Maximum context length in the batch - - query_start_loc: Starting positions of queries in the batch - - num_queries: Number of query tokens to process - - num_prefills: Number of tokens being prefilled (vs generated) - Used to handle variable sequence lengths and enable efficient - batched attention computation across multiple sequences. 
Returns: Output tensor [batch_size, seq_len, hidden_size] @@ -253,10 +239,7 @@ def forward( query_states, key_states) - # NOTE: No need anymore to pass specific kv_cache tensor, - # but keeping it for API compatibility - y = self.dpa_list[block_idx](query_states, key_states, value_states, - kv_caches[block_idx], attn_metadata) + y = self.dpa_list[block_idx](query_states, key_states, value_states) y, _ = self.o_proj(y) return y @@ -424,14 +407,11 @@ def __init__( eps=config.rms_norm_eps) def forward( - self, - hidden_states: torch.Tensor, - original_hidden_states: torch.Tensor, - layer_idx: int, - positions: torch.Tensor, - kv_caches: List[KVCache], - attn_metadata: - AttentionMetadata, # See Zamba2Attention.forward for details + self, + hidden_states: torch.Tensor, + original_hidden_states: torch.Tensor, + layer_idx: int, + positions: torch.Tensor, ) -> torch.Tensor: """Forward pass through the decoder layer. @@ -441,10 +421,6 @@ def forward( connection layer_idx: Current layer index positions: IDs for positional embeddings - kv_caches: List of key-value cache tuples - attn_metadata: Metadata for sequence processing and attention - computation - Returns: Transformed hidden states after attention and feed-forward @@ -465,8 +441,6 @@ def forward( hidden_states, position_ids=positions, layer_idx=layer_idx, - kv_caches=kv_caches, - attn_metadata=attn_metadata, ) # Layer norm before feed-forward @@ -524,22 +498,17 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - attn_metadata: - AttentionMetadata, # See Zamba2Attention.forward for details mamba_cache_params: MambaCacheParams, sequence_idx: Optional[torch.Tensor] = None, transformer_hidden_states: Optional[torch.Tensor] = None, positions: Optional[torch.Tensor] = None, original_hidden_states: Optional[torch.Tensor] = None, layer_idx: Optional[int] = None, - kv_caches: Optional[List[KVCache]] = None, ) -> torch.Tensor: """Forward pass through the Mamba decoder layer. 
Args: hidden_states: Input tensor [batch_size, seq_len, hidden_size] - attn_metadata: Metadata for sequence processing and attention - computation mamba_cache_params: Parameters for Mamba's state caches (one for conv, one for ssm) sequence_idx: Index tensor for identifying sequences in batch @@ -549,7 +518,6 @@ def forward( positions: Optional position IDs (unused in Mamba) original_hidden_states: Optional original inputs (unused in Mamba) layer_idx: Optional layer index (unused in Mamba) - kv_caches: Optional KV caches (unused in Mamba) Returns: Transformed hidden states with residual connection applied @@ -572,7 +540,6 @@ def forward( # Process through Mamba mixer hidden_states = self.mamba( hidden_states, - attn_metadata=attn_metadata, mamba_cache_params=mamba_cache_params, sequence_idx=sequence_idx, ) @@ -620,9 +587,6 @@ def forward( original_hidden_states: torch.Tensor, layer_idx: int, positions: torch.Tensor, - kv_caches: List[KVCache], - attn_metadata: - AttentionMetadata, # See Zamba2Attention.forward for details mamba_cache_params: Optional[MambaCacheParams] = None, sequence_idx: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -640,9 +604,6 @@ def forward( connection layer_idx: Current layer index for block mapping positions: Position IDs for positional embeddings - kv_caches: Key-value caches for attention - attn_metadata: Metadata for sequence processing and attention - computation mamba_cache_params: Parameters for Mamba's state caches (one for conv, one for ssm) sequence_idx: Indices for identifying sequences in batch, @@ -657,8 +618,6 @@ def forward( original_hidden_states=original_hidden_states, layer_idx=layer_idx, positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, ) # Project transformer output @@ -668,7 +627,6 @@ def forward( layer_outputs = self.mamba_decoder( hidden_states, transformer_hidden_states=transformer_hidden_states, - attn_metadata=attn_metadata, mamba_cache_params=mamba_cache_params, sequence_idx=sequence_idx, ) @@ -760,9 +718,6 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], - attn_metadata: - AttentionMetadata, # See Zamba2Attention.forward for details mamba_cache_params: MambaCacheParams, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: @@ -771,8 +726,6 @@ def forward( Args: input_ids: Input token IDs positions: Position IDs for embeddings - kv_caches: List of key-value cache tuples - attn_metadata: Metadata for attention computation mamba_cache_params: Parameters for Mamba's state caches (one for conv, one for ssm) inputs_embeds: Optional pre-computed input embeddings @@ -790,6 +743,7 @@ def forward( # proper continuous batching computation including # chunked prefill seq_idx = None + attn_metadata = get_forward_context().attn_metadata if attn_metadata.num_prefills > 0: seq_idx = torch.zeros_like(input_ids, dtype=torch.int32) for i, (srt, end) in enumerate( @@ -808,8 +762,6 @@ def forward( original_hidden_states=original_hidden_states, layer_idx=layer_idx, positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, mamba_cache_params=mamba_cache_params.at_layer_idx(layer_idx), sequence_idx=seq_idx, ) @@ -891,23 +843,16 @@ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: """ return self.model.get_input_embeddings(input_ids) - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[KVCache], - attn_metadata: - AttentionMetadata, # See Zamba2Attention.forward 
for details - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs) -> torch.Tensor: + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs) -> torch.Tensor: """Forward pass through the model. Args: input_ids: Input token IDs positions: Position IDs for embeddings - kv_caches: List of key-value cache tuples - attn_metadata: Metadata for attention computation inputs_embeds: Optional pre-computed input embeddings **kwargs: Additional arguments passed to cache manager @@ -928,8 +873,6 @@ def forward( hidden_states = self.model( input_ids, positions, - kv_caches, - attn_metadata, mamba_cache_params, inputs_embeds, ) From f3ec9ef20eb8605b29bf9876307d85a0e6a3876d Mon Sep 17 00:00:00 2001 From: Yury Tokpanov Date: Wed, 5 Mar 2025 04:04:16 +0000 Subject: [PATCH 06/10] MergedColumnParallel for MLP block Signed-off-by: Yury Tokpanov --- vllm/model_executor/models/zamba2.py | 37 ++++++++++++---------------- 1 file changed, 16 insertions(+), 21 deletions(-) diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py index f9c366035c5a..97ff7ae5f13f 100644 --- a/vllm/model_executor/models/zamba2.py +++ b/vllm/model_executor/models/zamba2.py @@ -10,17 +10,17 @@ from typing import Dict, Iterable, Optional, Set, Tuple, Union import torch -import torch.nn.functional as F from torch import nn from transformers import Zamba2Config from vllm.attention.layer import Attention from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (divide, get_tensor_model_parallel_world_size, - tensor_model_parallel_all_gather) +from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear) @@ -276,22 +276,22 @@ def __init__( self.num_fwd_mem_blocks = len(layer2block_map) # Main projection layers with gating - self.gate_up_proj = ColumnParallelLinear( + self.gate_up_proj = MergedColumnParallelLinear( self.hidden_size, - 2 * self.intermediate_size, # 2x for gate and input projections + 2 * [self.intermediate_size], # 2x for gate and input projections bias=self.config.add_bias_linear, quant_config=quant_config) - self.down_proj = ReplicatedLinear(self.intermediate_size, - self.hidden_size, - bias=self.config.add_bias_linear, - quant_config=quant_config) + self.down_proj = RowParallelLinear(self.intermediate_size, + self.hidden_size, + bias=self.config.add_bias_linear, + quant_config=quant_config) # Only allow GELU activations if config.hidden_act != "gelu": - raise ValueError(f"Only gelu activation is supported " + raise ValueError(f"Only GELU activation is supported " f"(got `hidden_act`: {config.hidden_act})") - self.act_fn = F.gelu + self.act_fn = GeluAndMul() # Initialize adapter layers if enabled self.gate_up_proj_adapter_list = nn.ModuleList([]) @@ -303,10 +303,10 @@ def __init__( bias=False, quant_config=quant_config, gather_output=True), - ColumnParallelLinear(config.adapter_rank, - 2 * self.intermediate_size, - bias=False, - quant_config=quant_config), + MergedColumnParallelLinear(config.adapter_rank, + 2 * [self.intermediate_size], + bias=False, + quant_config=quant_config), ]) else: 
gate_up_proj_adapter = nn.Identity() @@ -335,14 +335,9 @@ def forward(self, hidden_states: torch.Tensor, lora_output = adapter[0](hidden_states)[0] lora_output = adapter[1](lora_output)[0] gate_up_states = gate_up_states + lora_output - if self.tp_size > 1: - gate_up_states = tensor_model_parallel_all_gather(gate_up_states) - - # Split into gate and input projections - gate_up_states = torch.chunk(gate_up_states, 2, dim=-1) # Apply GELU activation with gating - hidden_states = self.act_fn(gate_up_states[0]) * gate_up_states[1] + hidden_states = self.act_fn(gate_up_states) # Project back to hidden size output, _ = self.down_proj(hidden_states) From 7ad96e081de3bb12942050f3c73a697efdb72f96 Mon Sep 17 00:00:00 2001 From: Yury Tokpanov Date: Thu, 6 Mar 2025 23:37:53 +0000 Subject: [PATCH 07/10] Zamba2LoRA class + block indexing rework Signed-off-by: Yury Tokpanov --- vllm/model_executor/models/zamba2.py | 221 ++++++++++++++------------- 1 file changed, 118 insertions(+), 103 deletions(-) diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py index 97ff7ae5f13f..ec02459985e7 100644 --- a/vllm/model_executor/models/zamba2.py +++ b/vllm/model_executor/models/zamba2.py @@ -7,7 +7,7 @@ model alternates between state space model layers and attention-based layers. """ from itertools import cycle -from typing import Dict, Iterable, Optional, Set, Tuple, Union +from typing import Dict, Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -42,6 +42,54 @@ from .utils import maybe_prefix +class Zamba2LoRA(nn.Module): + """LoRA layer for the Zamba2 model. + + Implements a LoRA layer that is used in shared attention and gated MLP + blocks. + """ + + def __init__( + self, + input_dim: int, + rank: int, + output_dim: Union[int, List[int]], + quant_config: Optional[QuantizationConfig] = None, + ): + """Initialize the attention layer. + + Args: + input_dim: input dimension + rank: LoRA rank + output_dim: output dimension + quant_config: Configuration for model quantization + """ + super().__init__() + + self.A = ColumnParallelLinear(input_dim, + rank, + bias=False, + quant_config=quant_config, + gather_output=True) + + if isinstance(output_dim, list): + B_class = MergedColumnParallelLinear + else: + B_class = ColumnParallelLinear + self.B = B_class(rank, + output_dim, + bias=False, + quant_config=quant_config) + + def forward( + self, + hidden_states: torch.Tensor, + ): + lora_output, _ = self.A(hidden_states) + lora_output, _ = self.B(lora_output) + return lora_output + + class Zamba2Attention(nn.Module): """Multi-head attention mechanism for the Zamba2 model. 
@@ -54,7 +102,7 @@ def __init__( self, config: Zamba2Config, bare_block_idx: int, - layer2block_map: Dict[int, int], + num_hybrid_layers: int, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -64,7 +112,7 @@ def __init__( Args: config: The Zamba2 model configuration bare_block_idx: Index of the bare attention block - layer2block_map: Mapping from layer indices to block indices + num_hybrid_layers: Total number of hybrid layers cache_config: Configuration for key-value caching quant_config: Configuration for model quantization prefix: Optional prefix for parameter names @@ -72,8 +120,7 @@ def __init__( super().__init__() tp_size = get_tensor_model_parallel_world_size() self.config = config - self.layer2block_map = layer2block_map - self.num_fwd_mem_blocks = len(layer2block_map) + self.num_hybrid_layers = num_hybrid_layers self.rope_theta = config.rope_theta self.attention_hidden_size = config.attention_hidden_size @@ -111,9 +158,9 @@ def __init__( # Initialize attention blocks with proper indexing self.dpa_list = nn.ModuleList([]) - j = bare_block_idx * (self.num_fwd_mem_blocks + config.num_mem_blocks - + j = bare_block_idx * (self.num_hybrid_layers + config.num_mem_blocks - 1) // config.num_mem_blocks - for block_idx in range(self.num_fwd_mem_blocks): + for block_idx in range(self.num_hybrid_layers): if block_idx % config.num_mem_blocks == bare_block_idx: dpa = Attention( self.num_attention_heads, @@ -133,41 +180,26 @@ def __init__( self.linear_k_adapter_list = nn.ModuleList([]) self.linear_v_adapter_list = nn.ModuleList([]) - for block_idx in range(self.num_fwd_mem_blocks): + for block_idx in range(self.num_hybrid_layers): if block_idx % config.num_mem_blocks == bare_block_idx: - linear_q_adapter = nn.ModuleList([ - ColumnParallelLinear(self.attention_hidden_size, - config.adapter_rank, - bias=False, - quant_config=quant_config, - gather_output=True), - ColumnParallelLinear(config.adapter_rank, - self.attention_hidden_size, - bias=False, - quant_config=quant_config), - ]) - linear_k_adapter = nn.ModuleList([ - ColumnParallelLinear(self.attention_hidden_size, - config.adapter_rank, - bias=False, - quant_config=quant_config, - gather_output=True), - ColumnParallelLinear(config.adapter_rank, - self.attention_hidden_size, - bias=False, - quant_config=quant_config), - ]) - linear_v_adapter = nn.ModuleList([ - ColumnParallelLinear(self.attention_hidden_size, - config.adapter_rank, - bias=False, - quant_config=quant_config, - gather_output=True), - ColumnParallelLinear(config.adapter_rank, - self.attention_hidden_size, - bias=False, - quant_config=quant_config), - ]) + linear_q_adapter = Zamba2LoRA( + self.attention_hidden_size, + config.adapter_rank, + self.attention_hidden_size, + quant_config=quant_config, + ) + linear_k_adapter = Zamba2LoRA( + self.attention_hidden_size, + config.adapter_rank, + self.attention_hidden_size, + quant_config=quant_config, + ) + linear_v_adapter = Zamba2LoRA( + self.attention_hidden_size, + config.adapter_rank, + self.attention_hidden_size, + quant_config=quant_config, + ) else: linear_q_adapter = nn.Identity() linear_k_adapter = nn.Identity() @@ -190,7 +222,7 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - layer_idx: int, + block_idx: int, position_ids: torch.Tensor, ) -> torch.Tensor: """Forward pass through the attention layer. 
@@ -198,7 +230,7 @@ def forward( Args: hidden_states: Input tensor [batch_size, seq_len, hidden_size] position_ids: Position IDs for positional embeddings - layer_idx: Current layer index + block_idx: Current shared transformer block index Returns: Output tensor [batch_size, seq_len, hidden_size] @@ -207,31 +239,21 @@ def forward( query_states, key_states, value_states = qkv.split([self.qkv_size] * 3, dim=-1) - block_idx = self.layer2block_map[layer_idx] if self.config.use_shared_attention_adapter: # Apply adapter transformations to Q, K, V if enabled - assert not isinstance(self.linear_q_adapter_list[block_idx], - nn.Identity) - q_lora_output = self.linear_q_adapter_list[block_idx][0]( - hidden_states)[0] - q_lora_output = self.linear_q_adapter_list[block_idx][1]( - q_lora_output)[0] + q_adapter = self.linear_q_adapter_list[block_idx] + assert not isinstance(q_adapter, nn.Identity) + q_lora_output = q_adapter(hidden_states) query_states = query_states + q_lora_output - assert not isinstance(self.linear_k_adapter_list[block_idx], - nn.Identity) - k_lora_output = self.linear_k_adapter_list[block_idx][0]( - hidden_states)[0] - k_lora_output = self.linear_k_adapter_list[block_idx][1]( - k_lora_output)[0] + k_adapter = self.linear_k_adapter_list[block_idx] + assert not isinstance(k_adapter, nn.Identity) + k_lora_output = k_adapter(hidden_states) key_states = key_states + k_lora_output - assert not isinstance(self.linear_v_adapter_list[block_idx], - nn.Identity) - v_lora_output = self.linear_v_adapter_list[block_idx][0]( - hidden_states)[0] - v_lora_output = self.linear_v_adapter_list[block_idx][1]( - v_lora_output)[0] + v_adapter = self.linear_v_adapter_list[block_idx] + assert not isinstance(v_adapter, nn.Identity) + v_lora_output = v_adapter(hidden_states) value_states = value_states + v_lora_output if self.config.use_mem_rope: @@ -256,7 +278,7 @@ def __init__( self, config: Zamba2Config, bare_block_idx: int, - layer2block_map: Dict[int, int], + num_hybrid_layers: Dict[int, int], quant_config: Optional[QuantizationConfig] = None, ) -> None: """Initialize the MLP layer. 
@@ -264,16 +286,15 @@ def __init__( Args: config: The Zamba2 model configuration bare_block_idx: Index of the bare block in the model - layer2block_map: Mapping from layer indices to block indices + num_hybrid_layers: Total number of hybrid layers quant_config: Configuration for model quantization """ super().__init__() self.config = config self.tp_size = get_tensor_model_parallel_world_size() - self.layer2block_map = layer2block_map + self.num_hybrid_layers = num_hybrid_layers self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size - self.num_fwd_mem_blocks = len(layer2block_map) # Main projection layers with gating self.gate_up_proj = MergedColumnParallelLinear( @@ -293,32 +314,27 @@ def __init__( f"(got `hidden_act`: {config.hidden_act})") self.act_fn = GeluAndMul() - # Initialize adapter layers if enabled + # Initialize adapter layers self.gate_up_proj_adapter_list = nn.ModuleList([]) - for block_idx in range(self.num_fwd_mem_blocks): + for block_idx in range(self.num_hybrid_layers): if block_idx % config.num_mem_blocks == bare_block_idx: - gate_up_proj_adapter = nn.ModuleList([ - ColumnParallelLinear(config.hidden_size, - config.adapter_rank, - bias=False, - quant_config=quant_config, - gather_output=True), - MergedColumnParallelLinear(config.adapter_rank, - 2 * [self.intermediate_size], - bias=False, - quant_config=quant_config), - ]) + gate_up_proj_adapter = Zamba2LoRA( + config.hidden_size, + config.adapter_rank, + 2 * [self.intermediate_size], + quant_config, + ) else: gate_up_proj_adapter = nn.Identity() self.gate_up_proj_adapter_list.append(gate_up_proj_adapter) def forward(self, hidden_states: torch.Tensor, - layer_idx: int) -> torch.Tensor: + block_idx: int) -> torch.Tensor: """Forward pass through the MLP layer. 
Args: hidden_states: Input tensor [batch_size, seq_len, hidden_size] - layer_idx: Current layer index + block_idx: Current shared transformer block index Returns: Output tensor [batch_size, seq_len, hidden_size] after applying @@ -328,12 +344,9 @@ def forward(self, hidden_states: torch.Tensor, gate_up_states, _ = self.gate_up_proj(hidden_states) # Apply adapter transformation if present - block_idx = self.layer2block_map[layer_idx] - assert not isinstance(self.gate_up_proj_adapter_list[block_idx], - nn.Identity) adapter = self.gate_up_proj_adapter_list[block_idx] - lora_output = adapter[0](hidden_states)[0] - lora_output = adapter[1](lora_output)[0] + assert not isinstance(adapter, nn.Identity) + lora_output = adapter(hidden_states) gate_up_states = gate_up_states + lora_output # Apply GELU activation with gating @@ -358,7 +371,7 @@ def __init__( self, config: Zamba2Config, bare_block_idx: int, - layer2block_map: Dict[int, int], + num_hybrid_layers: int, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -368,7 +381,7 @@ def __init__( Args: config: The Zamba2 model configuration bare_block_idx: Index of the bare block - layer2block_map: Mapping from layer indices to block indices + num_hybrid_layers: Total number of hybrid layers cache_config: Configuration for key-value caching quant_config: Configuration for model quantization prefix: Optional prefix for parameter names @@ -379,7 +392,7 @@ def __init__( self.self_attn = Zamba2Attention( config, bare_block_idx=bare_block_idx, - layer2block_map=layer2block_map, + num_hybrid_layers=num_hybrid_layers, cache_config=cache_config, quant_config=quant_config, prefix=prefix, @@ -389,7 +402,7 @@ def __init__( self.feed_forward = Zamba2MLP( config, bare_block_idx=bare_block_idx, - layer2block_map=layer2block_map, + num_hybrid_layers=num_hybrid_layers, quant_config=quant_config, ) @@ -405,7 +418,7 @@ def forward( self, hidden_states: torch.Tensor, original_hidden_states: torch.Tensor, - layer_idx: int, + block_idx: int, positions: torch.Tensor, ) -> torch.Tensor: """Forward pass through the decoder layer. @@ -414,7 +427,7 @@ def forward( hidden_states: Input tensor from previous layer original_hidden_states: Original input tensor for residual connection - layer_idx: Current layer index + block_idx: Current shared transformer block index positions: IDs for positional embeddings Returns: @@ -435,14 +448,14 @@ def forward( hidden_states = self.self_attn( hidden_states, position_ids=positions, - layer_idx=layer_idx, + block_idx=block_idx, ) # Layer norm before feed-forward hidden_states = self.pre_ff_layernorm(hidden_states) # Feed-forward network - hidden_states = self.feed_forward(hidden_states, layer_idx=layer_idx) + hidden_states = self.feed_forward(hidden_states, block_idx=block_idx) return hidden_states @@ -498,7 +511,6 @@ def forward( transformer_hidden_states: Optional[torch.Tensor] = None, positions: Optional[torch.Tensor] = None, original_hidden_states: Optional[torch.Tensor] = None, - layer_idx: Optional[int] = None, ) -> torch.Tensor: """Forward pass through the Mamba decoder layer. 
@@ -512,7 +524,6 @@ def forward( Added to input if provided (used in hybrid architecture) positions: Optional position IDs (unused in Mamba) original_hidden_states: Optional original inputs (unused in Mamba) - layer_idx: Optional layer index (unused in Mamba) Returns: Transformed hidden states with residual connection applied @@ -558,6 +569,7 @@ def __init__( self, shared_transformer: Zamba2AttentionDecoderLayer, config: Zamba2Config, + block_idx: int, quant_config: Optional[QuantizationConfig] = None, ) -> None: """Initialize the hybrid layer. @@ -568,6 +580,7 @@ def __init__( mamba: Mamba decoder layer for state space pathway """ super().__init__() + self.block_idx = block_idx self.shared_transformer = shared_transformer self.linear = ReplicatedLinear(config.hidden_size, config.hidden_size, @@ -580,7 +593,6 @@ def forward( self, hidden_states: torch.Tensor, original_hidden_states: torch.Tensor, - layer_idx: int, positions: torch.Tensor, mamba_cache_params: Optional[MambaCacheParams] = None, sequence_idx: Optional[torch.Tensor] = None, @@ -597,7 +609,6 @@ def forward( hidden_states: Input tensor [batch_size, seq_len, hidden_size] original_hidden_states: Original input for transformer residual connection - layer_idx: Current layer index for block mapping positions: Position IDs for positional embeddings mamba_cache_params: Parameters for Mamba's state caches (one for conv, one for ssm) @@ -611,7 +622,7 @@ def forward( transformer_hidden_states = self.shared_transformer( hidden_states, original_hidden_states=original_hidden_states, - layer_idx=layer_idx, + block_idx=self.block_idx, positions=positions, ) @@ -676,7 +687,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: blocks = cycle([ Zamba2AttentionDecoderLayer(config, bare_block_idx=idx, - layer2block_map=layer2block_map, + num_hybrid_layers=len(layer2block_map), cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}") @@ -685,10 +696,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: # Initialize layers according to block type configuration layers = [] - for layer_type in config.layers_block_type: + for layer_idx, layer_type in enumerate(config.layers_block_type): if layer_type == "hybrid": block = next(blocks) - layers.append(Zamba2HybridLayer(block, config, quant_config)) + block_idx = layer2block_map[layer_idx] + layers.append( + Zamba2HybridLayer(block, config, block_idx, quant_config)) else: layers.append( Zamba2MambaDecoderLayer(config, quant_config=quant_config)) @@ -755,7 +768,6 @@ def forward( layer_outputs = layer( hidden_states, original_hidden_states=original_hidden_states, - layer_idx=layer_idx, positions=positions, mamba_cache_params=mamba_cache_params.at_layer_idx(layer_idx), sequence_idx=seq_idx, @@ -991,6 +1003,9 @@ def load_weights(self, weights: Iterable[Tuple[str, for key, loaded_weight in weights: if "A_log" in key: key = key.replace("A_log", "A") + elif "adapter_list" in key: + key = key.replace("0.weight", "A.weight") + key = key.replace("1.weight", "B.weight") weights_dict[key] = loaded_weight params_dict = dict(self.named_parameters()) From 54f25ddfed1f133b3a9c6b4f059bd7bcb7149978 Mon Sep 17 00:00:00 2001 From: yury-tokpanov Date: Fri, 14 Mar 2025 16:47:36 -0700 Subject: [PATCH 08/10] Update tests/models/registry.py Co-authored-by: Cyrus Leung Signed-off-by: Yury Tokpanov --- tests/models/registry.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index 
c0796579b4f2..554e28863a7b 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -195,7 +195,8 @@ def check_available_online( "XverseForCausalLM": _HfExamplesInfo("xverse/XVERSE-7B-Chat", is_available_online=False, trust_remote_code=True), - "Zamba2ForCausalLM": _HfExamplesInfo("Zyphra/Zamba2-7B-instruct"), + "Zamba2ForCausalLM": _HfExamplesInfo("Zyphra/Zamba2-7B-instruct", + min_transformers_version="4.49"), # [Encoder-decoder] "BartModel": _HfExamplesInfo("facebook/bart-base"), "BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"), From 68ffb4aaa390ae6b30a36d320978543eb499330d Mon Sep 17 00:00:00 2001 From: Yury Tokpanov Date: Tue, 18 Mar 2025 06:17:54 +0000 Subject: [PATCH 09/10] revert requirements Signed-off-by: Yury Tokpanov --- requirements/common.txt | 2 +- requirements/test.in | 2 +- requirements/test.txt | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/requirements/common.txt b/requirements/common.txt index 9d3249e20dfe..d08ef253828b 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -6,7 +6,7 @@ requests >= 2.26.0 tqdm blake3 py-cpuinfo -transformers >= 4.49.0 # Required for Zamba2 models and Transformers backend. +transformers >= 4.48.2 # Required for Bamba model and Transformers backend. tokenizers >= 0.19.1 # Required for Llama 3. protobuf # Required by LlamaTokenizer. fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint. diff --git a/requirements/test.in b/requirements/test.in index eb86f6d1a43f..faa4564eaa39 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -30,7 +30,7 @@ matplotlib # required for qwen-vl test mistral_common[opencv] >= 1.5.4 # required for pixtral test datamodel_code_generator # required for minicpm3 test lm-eval[api]==0.4.4 # required for model evaluation test -transformers>=4.49.0 +transformers==4.48.2 # quantization bitsandbytes>=0.45.3 buildkite-test-collector==0.1.9 diff --git a/requirements/test.txt b/requirements/test.txt index df25b2bb97c0..c733364fd871 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -633,7 +633,7 @@ tqdm==4.66.6 # transformers tqdm-multiprocess==0.0.11 # via lm-eval -transformers==4.49.0 +transformers==4.48.2 # via # -r requirements/test.in # genai-perf From b34a384451a4f4a351f711adb2a56fef8dce3e84 Mon Sep 17 00:00:00 2001 From: Yury Tokpanov Date: Tue, 18 Mar 2025 07:21:38 +0000 Subject: [PATCH 10/10] Add SupportsV0Only and update list of supported models Signed-off-by: Yury Tokpanov --- docs/source/models/supported_models.md | 5 +++++ vllm/model_executor/models/zamba2.py | 4 ++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 2d7617d9ebab..63372e314f5a 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -477,6 +477,11 @@ See [this page](#generative-models) for more information on how to use generativ * `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc. * ✅︎ * ✅︎ +- * `Zamba2ForCausalLM` + * Zamba2 + * `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc. 
+ * + * ::: :::{note} diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py index ec02459985e7..7e210244f794 100644 --- a/vllm/model_executor/models/zamba2.py +++ b/vllm/model_executor/models/zamba2.py @@ -38,7 +38,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import HasInnerState, IsHybrid +from .interfaces import HasInnerState, IsHybrid, SupportsV0Only from .utils import maybe_prefix @@ -778,7 +778,7 @@ def forward( return hidden_states -class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): +class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only): """Zamba2 model with causal language modeling head. This class wraps the core Zamba2 model and adds:
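
With the registry entry, the documentation row, and the `SupportsV0Only` tag above in place, a minimal offline-inference sketch for the new architecture could look like the following. The checkpoint name comes from the registry diff; the prompt and sampling settings are arbitrary examples, and the snippet assumes a GPU large enough for the 7B checkpoint plus `transformers >= 4.49`, matching the `min_transformers_version` pinned in the test registry.

```python
from vllm import LLM, SamplingParams

# One of the Zamba2 checkpoints mapped to Zamba2ForCausalLM in the registry.
llm = LLM(model="Zyphra/Zamba2-7B-instruct")

sampling_params = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate(["The Zamba2 architecture interleaves"], sampling_params)
print(outputs[0].outputs[0].text)
```

The `Zamba2LoRA` module introduced in the block-indexing rework factors each shared adapter into a rank-`adapter_rank` down-projection `A` followed by an up-projection `B`, and the caller adds the result to its base projection. A single-GPU sketch of that computation, using plain `nn.Linear` in place of the tensor-parallel linear layers and made-up dimensions, is:

```python
import torch
from torch import nn


class LoRASketch(nn.Module):
    """Rank-r adapter: returns B(A(x)); the caller adds this to its base projection."""

    def __init__(self, input_dim: int, rank: int, output_dim: int) -> None:
        super().__init__()
        self.A = nn.Linear(input_dim, rank, bias=False)   # down-projection
        self.B = nn.Linear(rank, output_dim, bias=False)  # up-projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.B(self.A(x))


# Illustrative shapes only, not the real Zamba2 config values.
hidden_states = torch.randn(1, 8, 2048)
query_states = torch.randn(1, 8, 2048)
q_adapter = LoRASketch(input_dim=2048, rank=128, output_dim=2048)
query_states = query_states + q_adapter(hidden_states)  # mirrors the q-adapter path
```

Finally, `_get_mamba_cache_shape` splits the Mamba state into a convolutional cache and a per-head temporal (SSM) cache, with the conv channels and the heads sharded across tensor-parallel ranks. The toy re-statement below uses made-up hyperparameters and omits the group-extension step the real code applies when `n_groups` does not divide evenly across ranks:

```python
def mamba_cache_shapes(hidden_size: int, expand: int, n_groups: int, d_state: int,
                       d_conv: int, num_heads: int, head_dim: int,
                       world_size: int):
    """Toy version of the conv/temporal cache shape arithmetic."""
    intermediate_size = expand * hidden_size
    # Conv cache covers the sharded channels plus the 2 * n_groups * d_state
    # B/C channels, each keeping the last (d_conv - 1) inputs.
    conv_dim = intermediate_size + 2 * n_groups * d_state
    conv_state_shape = (conv_dim // world_size, d_conv - 1)
    # Temporal (SSM) cache is kept per head: (heads_per_rank, head_dim, d_state).
    temporal_state_shape = (num_heads // world_size, head_dim, d_state)
    return conv_state_shape, temporal_state_shape


print(mamba_cache_shapes(hidden_size=1024, expand=2, n_groups=1, d_state=128,
                         d_conv=4, num_heads=32, head_dim=64, world_size=1))
# -> ((2304, 3), (32, 64, 128))
```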