1 change: 1 addition & 0 deletions tensorrt_llm/_torch/attention_backend/interface.py
@@ -571,6 +571,7 @@ class PositionalEmbeddingParams:

     # mRoPE params (currently, Qwen2/2.5-VL uses it)
     mrope_section: Optional[List[int]] = None
+    mrope_interleaved: bool = False

     def __post_init__(self) -> None:
         if self.type.is_deferred():
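For readers new to mRoPE: mrope_section splits the rotary dimension pairs across the temporal/height/width position streams, and the new mrope_interleaved flag selects an interleaved layout instead of contiguous blocks. A minimal sketch of the difference, assuming the interleaved variant simply cycles the t/h/w streams per dimension pair (illustrative only, not the attention backend's actual kernel logic):

import torch

# Assign each rotary dim pair to one of the t/h/w position streams.
# mrope_section=[2, 3, 3] means 2 pairs temporal, 3 height, 3 width.
def mrope_dim_assignment(mrope_section, interleaved):
    total = sum(mrope_section)
    if not interleaved:
        # contiguous blocks: t t h h h w w w
        return torch.repeat_interleave(
            torch.arange(len(mrope_section)), torch.tensor(mrope_section))
    # interleaved: cycle t, h, w until each section is exhausted
    counts, out = list(mrope_section), []
    while len(out) < total:
        for stream in range(len(counts)):
            if counts[stream] > 0:
                out.append(stream)
                counts[stream] -= 1
    return torch.tensor(out)

print(mrope_dim_assignment([2, 3, 3], interleaved=False))  # t t h h h w w w
print(mrope_dim_assignment([2, 3, 3], interleaved=True))   # t h w t h w h w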
48 changes: 13 additions & 35 deletions tensorrt_llm/_torch/models/__init__.py
@@ -27,6 +27,7 @@
 from .modeling_qwen3 import Qwen3ForCausalLM
 from .modeling_qwen3_moe import Qwen3MoeForCausalLM
 from .modeling_qwen3_next import Qwen3NextForCausalLM
+from .modeling_qwen3vl import Qwen3VLModelTRT
 from .modeling_qwen_moe import Qwen2MoeForCausalLM
 from .modeling_seedoss import SeedOssForCausalLM
 from .modeling_siglip import SiglipVisionModel
@@ -35,41 +36,18 @@

 # Note: for better readability, this should have same order as imports above
 __all__ = [
-    "AutoModelForCausalLM",
-    "BertForSequenceClassification",
-    "CLIPVisionModel",
-    "DeepseekV3ForCausalLM",
-    "Exaone4ForCausalLM",
-    "Gemma3ForCausalLM",
-    "Gemma3VLM",
-    "HCXVisionForCausalLM",
-    "HunYuanDenseV1ForCausalLM",
-    "HunYuanMoEV1ForCausalLM",
-    "LlamaForCausalLM",
-    "LlavaNextModel",
-    "Mistral3VLM",
-    "MistralForCausalLM",
-    "MixtralForCausalLM",
-    "NemotronH_Nano_VL_V2",
-    "NemotronForCausalLM",
-    "NemotronHForCausalLM",
-    "NemotronNASForCausalLM",
-    "Phi3ForCausalLM",
-    "Phi4MMForCausalLM",
-    "Qwen2ForCausalLM",
-    "Qwen2ForProcessRewardModel",
-    "Qwen2ForRewardModel",
-    "Qwen2MoeForCausalLM",
-    "SiglipVisionModel",
-    "get_model_architecture",
-    "VilaModel",
-    "Qwen2VLModel",
-    "Qwen2_5_VLModel",
-    "Qwen3ForCausalLM",
-    "Qwen3MoeForCausalLM",
-    "Qwen3NextForCausalLM",
-    "GptOssForCausalLM",
-    "SeedOssForCausalLM",
+    "AutoModelForCausalLM", "BertForSequenceClassification", "CLIPVisionModel",
+    "DeepseekV3ForCausalLM", "Exaone4ForCausalLM", "Gemma3ForCausalLM",
+    "Gemma3VLM", "HCXVisionForCausalLM", "HunYuanDenseV1ForCausalLM",
+    "HunYuanMoEV1ForCausalLM", "LlamaForCausalLM", "LlavaNextModel",
+    "Mistral3VLM", "MistralForCausalLM", "MixtralForCausalLM",
+    "NemotronH_Nano_VL_V2", "NemotronForCausalLM", "NemotronHForCausalLM",
+    "NemotronNASForCausalLM", "Phi3ForCausalLM", "Phi4MMForCausalLM",
+    "Qwen2ForCausalLM", "Qwen2ForProcessRewardModel", "Qwen2ForRewardModel",
+    "Qwen2MoeForCausalLM", "SiglipVisionModel", "get_model_architecture",
+    "VilaModel", "Qwen2VLModel", "Qwen2_5_VLModel", "Qwen3ForCausalLM",
+    "Qwen3MoeForCausalLM", "Qwen3NextForCausalLM", "GptOssForCausalLM",
+    "SeedOssForCausalLM", "Qwen3VLModelTRT"
 ]

 if transformers.__version__ >= "4.45.1":
19 changes: 17 additions & 2 deletions tensorrt_llm/_torch/models/modeling_qwen3.py
@@ -47,7 +47,11 @@ def __init__(
             pos_embd_params = PositionalEmbeddingParams(
                 type=PositionEmbeddingType.from_string(pos_type),
                 rope=RopeParams.from_config(config),
-            )
+                mrope_section=config.rope_scaling.get("mrope_section", None),
+                mrope_interleaved=config.rope_scaling.get(
+                    "mrope_interleaved", False))
+            if config.rope_scaling.get("mrope_interleaved", False):
+                fuse_qk_norm_rope = False
         else:
             pos_embd_params = PositionalEmbeddingParams(
                 type=PositionEmbeddingType.rope_gpt_neox,
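For reference, the rope_scaling block those .get() calls read comes straight from the checkpoint's Hugging Face config. A sketch of its shape (values illustrative, following the Qwen2/2.5-VL convention; the actual Qwen3-VL numbers are not shown in this diff):

# Roughly what config.rope_scaling contains for an mRoPE checkpoint:
rope_scaling = {
    "rope_type": "mrope",
    "mrope_section": [16, 24, 24],  # rotary dim pairs assigned to t/h/w
    "mrope_interleaved": True,      # new in Qwen3-VL; absent key => False
}
mrope_section = rope_scaling.get("mrope_section", None)
mrope_interleaved = rope_scaling.get("mrope_interleaved", False)

Disabling fuse_qk_norm_rope when the layout is interleaved presumably means the fused QK-norm+RoPE kernel only handles the contiguous-section layout, so the unfused path is the safe fallback.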
@@ -114,6 +118,7 @@ def forward(
         attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
         spec_metadata: Optional[SpecMetadata] = None,
+        mrope_config: Optional[dict] = None,
         **kwargs,
     ) -> torch.Tensor:
         if residual is None:
@@ -130,6 +135,7 @@
                 attn_metadata=attn_metadata,
                 all_reduce_params=AllReduceParams(
                     enable_allreduce=not self.disable_allreduce),
+                mrope_config=mrope_config,
                 **kwargs,
             )

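The mrope_config dict threaded through these forwards is the same plumbing the existing Qwen2/2.5-VL path uses. A rough sketch of its contents (key names taken from that path and stated here as an assumption; shapes are dummies, the real tensors come from the multimodal input processor):

import torch

mrope_config = {
    # precomputed cos/sin table for the multimodal rotary embedding
    "mrope_rotary_cos_sin": torch.zeros(1, 2, 128),
    # per-sequence position offsets applied during decode
    "mrope_position_deltas": torch.zeros(1, 1, dtype=torch.int32),
}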
@@ -184,6 +190,9 @@ def forward(
         position_ids: Optional[torch.IntTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         spec_metadata: Optional[SpecMetadata] = None,
+        mrope_config: Optional[dict] = None,
+        # args for deepstack
+        deepstack_visual_embeds: Optional[list[torch.Tensor]] = None,
         **kwargs,
     ) -> torch.Tensor:
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -197,14 +206,20 @@
             hidden_states = inputs_embeds

         residual = None
-        for decoder_layer in self.layers:
+        for layer_idx, decoder_layer in enumerate(self.layers):
             hidden_states, residual = decoder_layer(
                 position_ids=position_ids,
                 hidden_states=hidden_states,
                 attn_metadata=attn_metadata,
                 residual=residual,
                 spec_metadata=spec_metadata,
+                mrope_config=mrope_config,
             )
+            # add visual features to the hidden states of first several layers
+            if deepstack_visual_embeds is not None and layer_idx in range(
+                    len(deepstack_visual_embeds)):
+                hidden_states = hidden_states + deepstack_visual_embeds[
+                    layer_idx]
Comment on lines +209 to +222
⚠️ Potential issue | 🟠 Major

Deepstack visual embeds must be applied before speculative capture.

Qwen3DecoderLayer pushes its hidden/residual pair into SpecMetadata before it returns. Because the deepstack addition happens out here after the layer call, the tensors cached for speculative decoding miss the visual contribution, so replaying the layer during speculation diverges from the runtime path as soon as deepstack is enabled. Please pass the per-layer visual tensor down into the decoder layer and add it before maybe_capture_hidden_states() so both execution paths see the same activations.

Apply this diff:

@@
-        spec_metadata: Optional[SpecMetadata] = None,
-        mrope_config: Optional[dict] = None,
+        spec_metadata: Optional[SpecMetadata] = None,
+        mrope_config: Optional[dict] = None,
+        deepstack_visual_embed: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
@@
-        hidden_states = self.mlp(
+        hidden_states = self.mlp(
             hidden_states,
             all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
             final_all_reduce_params=AllReduceParams(
                 enable_allreduce=not self.disable_allreduce),
             cutlass_min_latency_mode=False,
         )
 
+        if deepstack_visual_embed is not None:
+            hidden_states = hidden_states + deepstack_visual_embed
+
         if spec_metadata is not None:
             spec_metadata.maybe_capture_hidden_states(self.layer_idx,
                                                       hidden_states, residual)
@@
-        for layer_idx, decoder_layer in enumerate(self.layers):
-            hidden_states, residual = decoder_layer(
+        for layer_idx, decoder_layer in enumerate(self.layers):
+            visual_embed = None
+            if deepstack_visual_embeds is not None and layer_idx < len(deepstack_visual_embeds):
+                visual_embed = deepstack_visual_embeds[layer_idx]
+
+            hidden_states, residual = decoder_layer(
                 position_ids=position_ids,
                 hidden_states=hidden_states,
                 attn_metadata=attn_metadata,
                 residual=residual,
                 spec_metadata=spec_metadata,
-                mrope_config=mrope_config,
+                mrope_config=mrope_config,
+                deepstack_visual_embed=visual_embed,
             )
-            # add visual features to the hidden states of first several layers
-            if deepstack_visual_embeds is not None and layer_idx in range(
-                    len(deepstack_visual_embeds)):
-                hidden_states = hidden_states + deepstack_visual_embeds[
-                    layer_idx]

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/models/modeling_qwen3.py around lines 210 to 223, the
deepstack visual embeddings are added to hidden_states after the decoder_layer
returns, so SpecMetadata captured inside Qwen3DecoderLayer misses the visual
contribution; modify the call to pass the per-layer visual tensor (e.g.,
deepstack_visual_embeds[layer_idx] or None) into decoder_layer and change
Qwen3DecoderLayer to add that visual tensor to its hidden_states before calling
maybe_capture_hidden_states()/pushing into SpecMetadata so both normal and
speculative paths see identical activations.


         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
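A side note on shapes: the bare hidden_states + deepstack_visual_embeds[layer_idx] add only works if each entry already matches the packed sequence layout, i.e. the caller scattered the visual features to the image-token positions and left zeros elsewhere. A self-contained sketch of that precondition (helper name, shapes, and mask all hypothetical):

import torch

def scatter_visual_embeds(seq_len, hidden, visual_feats, visual_token_mask):
    # visual_feats: [num_visual_tokens, hidden] from one ViT deepstack layer
    # visual_token_mask: [seq_len] bool, True at image-token positions
    out = torch.zeros(seq_len, hidden, dtype=visual_feats.dtype)
    out[visual_token_mask] = visual_feats
    return out

# e.g. three deepstack feature maps -> added to decoder layers 0..2
seq_len, hidden = 16, 8
mask = torch.zeros(seq_len, dtype=torch.bool)
mask[4:12] = True  # 8 image tokens
deepstack_visual_embeds = [
    scatter_visual_embeds(seq_len, hidden, torch.randn(8, hidden), mask)
    for _ in range(3)
]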
4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/models/modeling_qwen3_next.py
@@ -23,7 +23,6 @@
 import triton
 import triton.language as tl
 from torch import nn
-from transformers import AutoConfig
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_rope_utils import rope_config_validation

@@ -319,7 +318,8 @@ def __init__(
         self.mlp_only_layers = mlp_only_layers


-AutoConfig.register("qwen3_next", Qwen3NextConfig)
+# Since transformers was updated to 4.57.0, we no longer need to register it with AutoConfig.
+# AutoConfig.register("qwen3_next", Qwen3NextConfig)


 class Qwen3NextGate(nn.Module):
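For anyone who still needs to run against a transformers release older than 4.57.0 (where "qwen3_next" is not a built-in model type), a guarded registration is a possible middle ground — a sketch, assuming the Qwen3NextConfig class defined above:

from transformers import AutoConfig

try:
    AutoConfig.register("qwen3_next", Qwen3NextConfig)
except ValueError:
    # transformers >= 4.57.0 already ships qwen3_next; registering the
    # same model type again raises ValueError.
    pass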