Skip to content
Draft

Qwen3VL #3385

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated 122 files
1 change: 1 addition & 0 deletions python/mlc_llm/conversation_template/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
orion,
phi,
qwen2,
qwen3_vl,
redpajama,
rwkv,
stablelm,
Expand Down
20 changes: 20 additions & 0 deletions python/mlc_llm/conversation_template/qwen3_vl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Qwen3-VL default templates"""

from mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders

from .registry import ConvTemplateRegistry

# Qwen3-VL uses the same ChatML-style chat format as the rest of the
# Qwen2/Qwen3 family: <|im_start|>{role}\n{content}<|im_end|>\n turns,
# stopped by either <|im_end|> (151645) or <|endoftext|> (151643).
_qwen3_vl_conversation = Conversation(
    name="qwen3_vl",
    system_template=f"<|im_start|>system\n{MessagePlaceholders.SYSTEM.value}<|im_end|>\n",
    system_message="You are a helpful assistant.",
    roles={"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"},
    seps=["<|im_end|>\n"],
    role_content_sep="\n",
    role_empty_sep="\n",
    stop_str=["<|endoftext|>", "<|im_end|>"],
    stop_token_ids=[151643, 151645],
)
ConvTemplateRegistry.register_conv_template(_qwen3_vl_conversation)
1 change: 1 addition & 0 deletions python/mlc_llm/interface/gen_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b
"chatml",
"chatml_nosystem",
"qwen2",
"qwen3_vl",
"open_hermes_mistral",
"neural_hermes_mistral",
"llama_default",
Expand Down
13 changes: 13 additions & 0 deletions python/mlc_llm/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from .qwen2_moe import qwen2_moe_loader, qwen2_moe_model, qwen2_moe_quantization
from .qwen3 import qwen3_loader, qwen3_model, qwen3_quantization
from .qwen3_moe import qwen3_moe_loader, qwen3_moe_model, qwen3_moe_quantization
from .qwen3_vl import qwen3_vl_loader, qwen3_vl_model, qwen3_vl_quantization
from .rwkv5 import rwkv5_loader, rwkv5_model, rwkv5_quantization
from .rwkv6 import rwkv6_loader, rwkv6_model, rwkv6_quantization
from .stable_lm import stablelm_loader, stablelm_model, stablelm_quantization
Expand Down Expand Up @@ -374,6 +375,18 @@ class Model:
"block-scale-quant": qwen3_moe_quantization.block_scale_quant,
},
),
"qwen3_vl": Model(
name="qwen3_vl",
model=qwen3_vl_model.Qwen3VLForConditionalGeneration,
config=qwen3_vl_model.Qwen3VLConfig,
source={
"huggingface-torch": qwen3_vl_loader.huggingface,
"huggingface-safetensor": qwen3_vl_loader.huggingface,
},
quantize={
"no-quant": qwen3_vl_quantization.no_quant,
},
),
"deepseek_v2": Model(
name="deepseek_v2",
model=deepseek_v2_model.DeepseekV2ForCausalLM,
Expand Down
Empty file.
109 changes: 109 additions & 0 deletions python/mlc_llm/model/qwen3_vl/qwen3_vl_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
"""
Configuration for Qwen3-VL model.
"""

import dataclasses
from typing import Any, Dict, Optional, Tuple

from mlc_llm.model.qwen3.qwen3_model import Qwen3Config
from mlc_llm.support import logging
from mlc_llm.support.config import ConfigBase

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class Qwen3VLVisionConfig(ConfigBase):
    """Configuration for the vision module of Qwen3-VL.

    Field values are taken verbatim from the HuggingFace ``vision_config``
    section of the model's ``config.json``; no validation or derived
    values are computed here.
    """

    depth: int  # number of transformer layers in the vision tower
    hidden_size: int  # embedding width inside the vision tower
    hidden_act: str  # activation function name, as spelled in the HF config
    intermediate_size: int  # MLP inner (feed-forward) width
    num_heads: int  # attention heads per vision layer
    in_channels: int  # input image channels (presumably 3 for RGB — TODO confirm)
    patch_size: int  # spatial patch edge length in pixels
    spatial_merge_size: int  # factor for merging neighboring patches spatially
    temporal_patch_size: int  # frames grouped into one temporal patch (video input)
    out_hidden_size: int  # width of the projection fed into the language model
    num_position_embeddings: int  # size of the position-embedding table
    deepstack_visual_indexes: list[int]  # vision-layer indices tapped for deepstack features
    initializer_range: float = 0.02  # stddev used for weight initialization
    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)  # unrecognized HF fields


@dataclasses.dataclass
class Qwen3VLConfig(ConfigBase):
    """Configuration for Qwen3-VL model.

    Bundles a text (Qwen3) sub-config, a vision sub-config, and the
    special-token ids used to delimit and splice image/video content
    into the token stream. Text-generation limits (vocab size, context
    window, prefill chunk size, tensor-parallel shards) are delegated
    to ``text_config`` via read-only properties.
    """

    text_config: Qwen3Config
    vision_config: Qwen3VLVisionConfig
    # Default token ids below match the Qwen-VL tokenizer family
    # (see the defaults repeated in from_huggingface).
    image_token_id: int = 151655
    video_token_id: int = 151656
    vision_start_token_id: int = 151652
    vision_end_token_id: int = 151653
    tie_word_embeddings: bool = False
    max_batch_size: int = 128
    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)

    @property
    def vocab_size(self) -> int:
        """Vocabulary size, delegated to the text sub-config."""
        return self.text_config.vocab_size

    @property
    def prefill_chunk_size(self) -> int:
        """Prefill chunk size, delegated to the text sub-config."""
        return self.text_config.prefill_chunk_size

    @property
    def context_window_size(self) -> int:
        """Context window size, delegated to the text sub-config."""
        return self.text_config.context_window_size

    @property
    def tensor_parallel_shards(self) -> int:
        """Number of tensor-parallel shards, delegated to the text sub-config."""
        return self.text_config.tensor_parallel_shards

    def __post_init__(self):
        # Sub-configs may arrive as plain dicts (e.g. when this config is
        # built through ConfigBase.from_dict); normalize them into their
        # dataclass forms so the delegating properties above work.
        if isinstance(self.text_config, dict):
            self.text_config = Qwen3Config.from_dict(self.text_config)
        if isinstance(self.vision_config, dict):
            self.vision_config = Qwen3VLVisionConfig.from_dict(self.vision_config)

    @classmethod
    def from_huggingface(cls, config_json: Dict[str, Any]) -> "Qwen3VLConfig":
        """Create Qwen3VLConfig from HuggingFace config.

        Missing sub-configs default to empty dicts and missing token ids
        fall back to the standard Qwen-VL values. NOTE(review): the whole
        ``config_json`` is stored in ``kwargs``, so recognized fields are
        duplicated there — TODO confirm this duplication is intended.
        """
        # Extract text config
        text_config_dict = config_json.get("text_config", {})
        # Ensure model_type is set correctly for Qwen3Config if needed, or just pass as is
        # Qwen3Config might expect certain fields.

        # Extract vision config
        vision_config_dict = config_json.get("vision_config", {})

        # Extract top-level fields (defaults mirror the dataclass defaults)
        image_token_id = config_json.get("image_token_id", 151655)
        video_token_id = config_json.get("video_token_id", 151656)
        vision_start_token_id = config_json.get("vision_start_token_id", 151652)
        vision_end_token_id = config_json.get("vision_end_token_id", 151653)
        tie_word_embeddings = config_json.get("tie_word_embeddings", False)

        return cls(
            text_config=Qwen3Config.from_dict(text_config_dict),
            vision_config=Qwen3VLVisionConfig.from_dict(vision_config_dict),
            image_token_id=image_token_id,
            video_token_id=video_token_id,
            vision_start_token_id=vision_start_token_id,
            vision_end_token_id=vision_end_token_id,
            tie_word_embeddings=tie_word_embeddings,
            kwargs=config_json,
        )

# Testing command
# conda activate tvm-dev
# export LOCAL_MODEL_PATH=../mlc-models/Qwen3-VL-2B-Instruct/
# export MLC_MODEL_PATH=../mlc-models/mlc-qwen/
# export QUANTIZATION=q0f16
# export CONV_TEMPLATE=qwen3_vl
# python -m mlc_llm gen_config $LOCAL_MODEL_PATH \
# --quantization $QUANTIZATION \
# --conv-template $CONV_TEMPLATE \
# -o $MLC_MODEL_PATH
9 changes: 9 additions & 0 deletions python/mlc_llm/model/qwen3_vl/qwen3_vl_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""
Minimal loader for Qwen3-VL.
"""
from typing import Any, Dict

from mlc_llm.loader import Loader

def huggingface(model_config, quantization):
    """Placeholder HuggingFace weight-loader hook for Qwen3-VL.

    NOTE(review): returns ``None`` rather than a parameter-mapping
    object (the imported ``Loader`` is unused); weight conversion will
    presumably fail until a real mapping is implemented — TODO confirm
    this stub is intentional.
    """
    return None
10 changes: 10 additions & 0 deletions python/mlc_llm/model/qwen3_vl/qwen3_vl_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""
Minimal model for Qwen3-VL.
"""
from tvm.relax.frontend import nn
from .qwen3_vl_config import Qwen3VLConfig

class Qwen3VLForConditionalGeneration(nn.Module):
    """Minimal placeholder model for Qwen3-VL.

    Only stores the configuration; no parameters, embed/prefill/decode
    methods, or spec definitions exist yet, so this module cannot be
    compiled or executed as-is.
    """

    def __init__(self, config: Qwen3VLConfig):
        super().__init__()
        # Keep the full config for later graph construction.
        self.config = config
9 changes: 9 additions & 0 deletions python/mlc_llm/model/qwen3_vl/qwen3_vl_quantization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""
Minimal quantization for Qwen3-VL.
"""
from typing import Any, Dict, Tuple
from tvm.relax.frontend import nn
from mlc_llm.loader import QuantizeMapping

def no_quant(model_config, quantization) -> Tuple[nn.Module, QuantizeMapping]:
    """Placeholder "no quantization" hook for Qwen3-VL.

    NOTE(review): annotated to return ``(nn.Module, QuantizeMapping)``
    but currently returns ``(None, None)``; downstream convert/compile
    steps will presumably break until the real model and mapping are
    returned — TODO confirm this stub is intentional.
    """
    return None, None
4 changes: 3 additions & 1 deletion python/mlc_llm/support/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,10 @@ def from_dict(cls: Type[ConfigClass], source: Dict[str, Any]) -> ConfigClass:
An instance of the config object.
"""
field_names = [field.name for field in dataclasses.fields(cls)] # type: ignore[arg-type]
fields = {k: v for k, v in source.items() if k in field_names}
fields = {k: v for k, v in source.items() if k in field_names and k != "kwargs"}
kwargs = {k: v for k, v in source.items() if k not in field_names}
if "kwargs" in source and isinstance(source["kwargs"], dict):
kwargs.update(source["kwargs"])
return cls(**fields, kwargs=kwargs) # type: ignore[call-arg]

@classmethod
Expand Down
Loading