Commit 463910e

[Bugfix] use module-level import for patched function in Qwen3Next (#4354)
### What this PR does / why we need it?

**Problem**: The Qwen3Next model implementation currently imports `chunk_gated_delta_rule` directly via `from ... import ...`. In frameworks like `verl`, the model file is often imported before `vllm-ascend` initializes and applies its patches, so the model permanently holds a reference to the original (unpatched) vLLM kernel. This results in execution errors on Ascend devices even if the patch runs later.

**Solution**: Change the import style to `from vllm...ops import chunk` and call `chunk.chunk_gated_delta_rule()`. The function lookup then happens at call time (the attribute is resolved on the module rather than bound at import time), so the model picks up the patched function regardless of import order.

- vLLM version: v0.11.0
- vLLM main: vllm-project/vllm@2918c1b

Signed-off-by: zjchenn <zjchenn@gmail.com>
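A minimal, self-contained sketch of the import-order issue (the module and kernel names below are stand-ins for illustration, not the real vLLM/vllm-ascend modules): `from ... import ...` binds whatever object exists at import time, while calling through the module sees a patch applied later. The same reasoning applies to the `causal_conv1d` calls changed in this diff.

```python
import sys
import types

# Stand-in for a kernel module such as vllm.model_executor.layers.fla.ops.chunk
# (hypothetical name, used only for this sketch).
fla_chunk = types.ModuleType("fla_chunk")
fla_chunk.chunk_gated_delta_rule = lambda: "original vLLM kernel"
sys.modules["fla_chunk"] = fla_chunk

# Import style the PR removes: the function object is bound once, at import time.
from fla_chunk import chunk_gated_delta_rule

# A plugin (e.g. vllm-ascend) patches the module attribute afterwards.
sys.modules["fla_chunk"].chunk_gated_delta_rule = lambda: "patched Ascend kernel"

# Import style the PR adds: the attribute is looked up on the module at call time.
import fla_chunk as chunk

print(chunk_gated_delta_rule())        # "original vLLM kernel"  (stale reference)
print(chunk.chunk_gated_delta_rule())  # "patched Ascend kernel" (sees the patch)
```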
1 parent 941d54a commit 463910e

File tree

1 file changed: +6 -8 lines changed


vllm_ascend/models/qwen3_next.py

Lines changed: 6 additions & 8 deletions
@@ -16,8 +16,7 @@
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.forward_context import get_forward_context
-from vllm.model_executor.layers.fla.ops import RMSNormGated
-from vllm.model_executor.layers.fla.ops.chunk import chunk_gated_delta_rule
+from vllm.model_executor.layers.fla.ops import RMSNormGated, chunk
 from vllm.model_executor.layers.fla.ops.fused_recurrent import \
     fused_recurrent_gated_delta_rule
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -35,8 +34,7 @@
     mamba_v2_sharded_weight_loader
 from vllm.model_executor.layers.mamba.mamba_utils import (
     MambaStateDtypeCalculator, MambaStateShapeCalculator)
-from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
-    causal_conv1d_fn, causal_conv1d_update)
+from vllm.model_executor.layers.mamba.ops import causal_conv1d
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
@@ -252,7 +250,7 @@ def _forward(
             mixed_qkv_spec = mixed_qkv_spec.view(
                 attn_metadata.num_spec_decodes, -1, mixed_qkv_spec.size(-1))
             mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b l d -> b d l')
-            mixed_qkv_spec = causal_conv1d_update(
+            mixed_qkv_spec = causal_conv1d.causal_conv1d_update(
                 mixed_qkv_spec,
                 conv_state,
                 conv_weights,
@@ -269,7 +267,7 @@ def _forward(
         if attn_metadata.num_prefills > 0:
             # - "cache_indices" updates the conv_state cache in positions
             #   pointed to by "mamba_cache_params.state_indices_tensor"
-            mixed_qkv_non_spec = causal_conv1d_fn(
+            mixed_qkv_non_spec = causal_conv1d.causal_conv1d_fn(
                 mixed_qkv_non_spec.transpose(0, 1),
                 conv_weights,
                 self.conv1d.bias,
@@ -280,7 +278,7 @@ def _forward(
                 query_start_loc=non_spec_query_start_loc,
             ).transpose(0, 1)
         elif attn_metadata.num_decodes > 0:
-            mixed_qkv_non_spec = causal_conv1d_update(
+            mixed_qkv_non_spec = causal_conv1d.causal_conv1d_update(
                 mixed_qkv_non_spec,
                 conv_state,
                 conv_weights,
@@ -364,7 +362,7 @@ def _forward(
             (
                 cur_core_attn_out_non_spec,
                 cur_last_recurrent_state,
-            ) = chunk_gated_delta_rule(
+            ) = chunk.chunk_gated_delta_rule(
                 query=cur_q,
                 key=cur_k,
                 value=cur_v,
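For context, a sketch of the plugin-side pattern this change enables (this is not vllm-ascend's actual patch code and assumes vLLM is installed): an override assigned to the module attribute after `qwen3_next.py` has been imported is now visible, because the model resolves `causal_conv1d.causal_conv1d_update` at call time.

```python
# Sketch only: a hypothetical override installed after the model module has
# already been imported; vllm-ascend's real patching machinery may differ.
from vllm.model_executor.layers.mamba.ops import causal_conv1d

_orig_update = causal_conv1d.causal_conv1d_update

def _ascend_causal_conv1d_update(*args, **kwargs):
    # Hypothetical device-specific implementation; delegate to the original here.
    return _orig_update(*args, **kwargs)

# Because qwen3_next.py now calls causal_conv1d.causal_conv1d_update(...) through
# the module, this assignment takes effect even if the model was imported first.
causal_conv1d.causal_conv1d_update = _ascend_causal_conv1d_update
```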
