Merged
5 changes: 5 additions & 0 deletions docs/source/models/supported_models.md
@@ -477,6 +477,11 @@ See [this page](#generative-models) for more information on how to use generativ
* `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc.
* ✅︎
* ✅︎
- * `Zamba2ForCausalLM`
* Zamba2
* `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc.
*
*
:::

:::{note}
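With the table entry above in place, a minimal offline-inference sketch for the new model family (assumes vLLM's standard `LLM` API and access to the `Zyphra/Zamba2-1.2B-instruct` checkpoint; `transformers >= 4.49` is required on the HF side):

```python
# Sketch only: sampling values and the checkpoint choice are illustrative.
from vllm import LLM, SamplingParams

llm = LLM(model="Zyphra/Zamba2-1.2B-instruct", dtype="half")
params = SamplingParams(temperature=0.0, max_tokens=64)

outputs = llm.generate(["In one sentence, what is a hybrid Mamba/attention model?"],
                       params)
print(outputs[0].outputs[0].text)
```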
51 changes: 29 additions & 22 deletions tests/models/decoder_only/language/test_hybrid.py
@@ -9,7 +9,7 @@
from ...utils import check_outputs_equal

# This test is for the hybrid models
MODELS = ["ai21labs/Jamba-tiny-dev"]
MODELS = ["ai21labs/Jamba-tiny-dev", "Zyphra/Zamba2-1.2B-instruct"]
# Bamba at Fp32 is too big for the CI (L4 GPU).
# MODELS = ["ai21labs/Jamba-tiny-dev", "ibm-ai-platform/Bamba-9B"]

@@ -27,17 +27,19 @@ def test_models(
) -> None:

# numeric error produces different generation
if 'Bamba' in model:
if "Bamba" in model:
example_prompts.pop(3)

with hf_runner(
model,
dtype=dtype,
model_kwargs={
"use_mamba_kernels":
False, # mamba kernels are not installed so HF
# don't use them
}) as hf_model:
model_kwargs = {
"use_mamba_kernels": False, # mamba kernels are not installed so HF
# don't use them
}
if "Zamba2" in model:
# Zamba2 HF implementation automatically checks if mamba kernels are
# installed
model_kwargs = {}

with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

with vllm_runner(model, dtype=dtype) as vllm_model:
@@ -112,26 +114,31 @@ def test_mamba_prefill_chunking_with_parallel_sampling(
def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts,
model: str, dtype: str,
max_tokens: int) -> None:
# numeric error during prefill chucking produces different generation
# numeric error during prefill chunking produces different generation
# compared to w/o prefill chunking for those examples, removed them for now
if 'Jamba' in model:
if "Jamba" in model:
example_prompts.pop(7)
example_prompts.pop(2)
example_prompts.pop(1)
elif 'Bamba' in model:
elif "Bamba" in model:
example_prompts.pop(6)
example_prompts.pop(3)
example_prompts.pop(2)
dtype = "half" # use a different dtype for Bamba

with hf_runner(
model,
dtype=dtype,
model_kwargs={
"use_mamba_kernels":
False, # mamba kernels are not installed so HF
# don't use them
}) as hf_model:
elif "Zamba2" in model:
example_prompts.pop(7)
dtype = "half"

model_kwargs = {
"use_mamba_kernels": False, # mamba kernels are not installed so HF
# don't use them
}
if "Zamba2" in model:
# Zamba2 HF implementation automatically checks if mamba kernels are
# installed
model_kwargs = {}

with hf_runner(model, dtype=dtype, model_kwargs=model_kwargs) as hf_model:
non_chunked = hf_model.generate_greedy(example_prompts, max_tokens)

with vllm_runner(model,
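A condensed sketch of the HF-runner kwargs selection used in both tests above (the helper name is illustrative and only mirrors the test logic; it is not a vLLM or test-suite API):

```python
def hf_model_kwargs(model: str) -> dict:
    """Choose HF model_kwargs the way the hybrid-model tests above do."""
    if "Zamba2" in model:
        # The Zamba2 HF implementation checks by itself whether the optimized
        # mamba kernels are installed, so no override is needed.
        return {}
    # Jamba/Bamba must be told explicitly not to use the mamba kernels,
    # which are not installed in the CI environment.
    return {"use_mamba_kernels": False}
```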
2 changes: 2 additions & 0 deletions tests/models/registry.py
@@ -195,6 +195,8 @@ def check_available_online(
"XverseForCausalLM": _HfExamplesInfo("xverse/XVERSE-7B-Chat",
is_available_online=False,
trust_remote_code=True),
"Zamba2ForCausalLM": _HfExamplesInfo("Zyphra/Zamba2-7B-instruct",
min_transformers_version="4.49"),
# [Encoder-decoder]
"BartModel": _HfExamplesInfo("facebook/bart-base"),
"BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"),
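The `min_transformers_version="4.49"` entry above encodes a version requirement; a rough illustration of the gate it implies (hypothetical check, not the registry's actual implementation):

```python
from packaging.version import Version

import transformers

# Zamba2 classes are only available in transformers >= 4.49, so older
# installs should skip the HF example rather than fail at import time.
if Version(transformers.__version__) < Version("4.49"):
    print("transformers >= 4.49 is required for Zyphra/Zamba2-7B-instruct")
```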
14 changes: 14 additions & 0 deletions vllm/config.py
@@ -821,6 +821,11 @@ def get_head_size(self) -> int:
if qk_rope_head_dim and qk_nope_head_dim:
return qk_rope_head_dim + qk_nope_head_dim

if hasattr(self.hf_text_config,
"model_type") and (self.hf_text_config.model_type
== "zamba2"):
return self.hf_text_config.attention_head_dim

if self.is_attention_free:
return 0

@@ -942,6 +947,15 @@ def get_num_layers_by_block_type(
"cannot determine the num of "
f"{block_type.value} layers")

if hasattr(self.hf_text_config,
"model_type") and (self.hf_text_config.model_type
== "zamba2"):
if attn_block_type:
return sum(t == "hybrid"
for t in layers_block_type_value[start:end])
else:
return self.get_num_layers(parallel_config)

return sum(t == block_type.value
for t in layers_block_type_value[start:end])

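Both config.py branches above read fields of the Zamba2 HF config; a small illustration of how they are consumed (attribute names come from the diff, all values below are made up):

```python
# Illustrative only: mirrors the two zamba2 branches added to config.py.

# Zamba2 exposes the attention head size directly on its config, so
# get_head_size() returns it as-is instead of hidden_size // num_heads.
attention_head_dim = 128            # hypothetical hf_config.attention_head_dim
head_size = attention_head_dim

# Every Zamba2 layer carries a mamba block; layers that additionally route
# through the shared attention block are labelled "hybrid".
layers_block_type = ["mamba", "mamba", "hybrid", "mamba", "mamba", "hybrid"]

num_attention_layers = sum(t == "hybrid" for t in layers_block_type)  # -> 2
num_mamba_layers = len(layers_block_type)                             # -> 6 (all layers)
```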
1 change: 0 additions & 1 deletion vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -245,7 +245,6 @@ def __init__(self,
assert num_heads % self.tp_size == 0, \
"Tensor parallel world size must divide num heads."


assert (n_groups % self.tp_size) == 0 or n_groups == 1, \
(
"If tensor parallel world size does not divide num_heads, "
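The assertions visible in this hunk encode the tensor-parallel constraints that MambaMixer2 places on hybrid models; a trivial pre-flight check of a candidate TP size (helper and numbers are hypothetical):

```python
def tp_size_ok(num_heads: int, n_groups: int, tp_size: int) -> bool:
    # Mirrors the asserts above: heads must shard evenly across ranks, and
    # groups must either shard evenly or be replicated (n_groups == 1).
    return num_heads % tp_size == 0 and (n_groups % tp_size == 0 or n_groups == 1)

print(tp_size_ok(num_heads=64, n_groups=8, tp_size=4))   # True
print(tp_size_ok(num_heads=64, n_groups=8, tp_size=16))  # False: 8 groups don't split 16 ways
```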
2 changes: 0 additions & 2 deletions vllm/model_executor/models/bamba.py
@@ -38,8 +38,6 @@
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)

KVCache = Tuple[torch.Tensor, torch.Tensor]


class BambaMLP(nn.Module):

2 changes: 0 additions & 2 deletions vllm/model_executor/models/jamba.py
@@ -36,8 +36,6 @@
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)

KVCache = Tuple[torch.Tensor, torch.Tensor]


class JambaMoE(nn.Module):

1 change: 1 addition & 0 deletions vllm/model_executor/models/registry.py
@@ -105,6 +105,7 @@
"SolarForCausalLM": ("solar", "SolarForCausalLM"),
"TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
"XverseForCausalLM": ("llama", "LlamaForCausalLM"),
"Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
# [Encoder-decoder]
"BartModel": ("bart", "BartForConditionalGeneration"),
"BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
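Finally, a quick way to confirm the new architecture name is registered (a sketch; `ModelRegistry` is importable from `vllm` in recent releases, and `get_supported_archs()` is assumed to list registered architecture names):

```python
from vllm import ModelRegistry

assert "Zamba2ForCausalLM" in ModelRegistry.get_supported_archs()
print("Zamba2ForCausalLM is registered")
```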