
Conversation

@nv-guomingz
Collaborator

@nv-guomingz nv-guomingz commented Nov 11, 2025

For Qwen3-Next (Q3N), only 12 attention layers, i.e. 1/4 of the total layers, use the KV cache.

Summary by CodeRabbit

  • New Features
    • Added support for Qwen3Next model architecture with specialized attention layer configuration.

@nv-guomingz nv-guomingz requested a review from byshiue November 11, 2025 11:48
@nv-guomingz nv-guomingz force-pushed the user/guomingz/fix_qwen3_next_kv_cache branch from e6ada5b to db4653b Compare November 12, 2025 14:45
@nv-guomingz nv-guomingz marked this pull request as ready for review November 12, 2025 15:19
@nv-guomingz nv-guomingz requested a review from a team as a code owner November 12, 2025 15:19
@nv-guomingz nv-guomingz force-pushed the user/guomingz/fix_qwen3_next_kv_cache branch from db4653b to 98c088c Compare November 12, 2025 15:19
@nv-guomingz
Collaborator Author

/bot run

@coderabbitai
Contributor

coderabbitai bot commented Nov 12, 2025

📝 Walkthrough

Walkthrough

Adds architecture-specific handling for Qwen3NextForCausalLM in the get_num_attention_layers function. The change introduces a conditional branch that calculates full-attention layers by dividing num_hidden_layers by full_attention_interval when the model architecture matches this type.
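Below is a minimal sketch of how such a branch could look, assuming a config object that exposes architectures, num_hidden_layers, and full_attention_interval; only the function name get_num_attention_layers and the architecture string come from this PR, the rest is illustrative and may differ from the actual code in tensorrt_llm/_torch/model_config.py.

```python
# Illustrative sketch only; the real get_num_attention_layers in
# tensorrt_llm/_torch/model_config.py may differ in signature and structure.
def get_num_attention_layers(config) -> int:
    architectures = getattr(config, "architectures", None) or []
    if "Qwen3NextForCausalLM" in architectures:
        # Qwen3-Next interleaves linear-attention layers with full-attention
        # layers; only the full-attention layers use KV cache, so count one
        # such layer per full_attention_interval hidden layers.
        return config.num_hidden_layers // config.full_attention_interval
    # ...existing handling (e.g. the nemotron hybrid branch) goes here...
    return config.num_hidden_layers
```

This matches the PR description: only the full-attention layers (1/4 of the total) use the KV cache, so KV-cache sizing should count those layers rather than all hidden layers.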

Changes

Cohort / File(s): Hybrid attention layer calculation (tensorrt_llm/_torch/model_config.py)
Change Summary: Added a conditional branch in get_num_attention_layers to handle the Qwen3NextForCausalLM architecture, computing the number of full-attention layers as num_hidden_layers // full_attention_interval alongside the existing nemotron hybrid logic.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~8 minutes

  • The change follows an existing pattern (nemotron hybrid handling) with minimal added complexity
  • Localized modification to a single function without side effects on other code paths
  • Verify that the full_attention_interval value is appropriate for Qwen3NextForCausalLM and consistent with the model documentation (see the spot-check sketch after this list)
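One way to do that spot-check is to read the published Hugging Face config directly; the checkpoint name below is only an example and is not referenced by this PR, and a transformers release that ships Qwen3-Next support is assumed.

```python
# Example spot-check; the checkpoint name is an assumption, not part of this PR,
# and a transformers version with Qwen3-Next support is required.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("Qwen/Qwen3-Next-80B-A3B-Instruct")
print("num_hidden_layers:", cfg.num_hidden_layers)
print("full_attention_interval:", getattr(cfg, "full_attention_interval", "not set"))
```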

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 0.00%, which is below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
  • Description check (⚠️ Warning): The PR description is extremely minimal and lacks the required sections from the template, including a detailed explanation, test coverage information, and checklist items. Resolution: complete the PR description following the template by adding a detailed problem/solution explanation in the Description section, listing relevant tests in the Test Coverage section, and reviewing the PR Checklist items.
✅ Passed checks (1 passed)
  • Title check (✅ Passed): The title clearly and specifically describes the main change: updating the attention layer counting logic for the Qwen3-Next architecture.


Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 264d38e and 98c088c.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/model_config.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use only spaces, no tabs; indent with 4 spaces.

Files:

  • tensorrt_llm/_torch/model_config.py
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.

Files:

  • tensorrt_llm/_torch/model_config.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).

Files:

  • tensorrt_llm/_torch/model_config.py
🧠 Learnings (3)
📓 Common learnings
Learnt from: thorjohnsen
Repo: NVIDIA/TensorRT-LLM PR: 6910
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:0-0
Timestamp: 2025-08-14T21:04:50.248Z
Learning: In KV cache onboarding logic during prefill in cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp, when calculating which blocks fall within the attention window, use getTokensPerBlock() to advance token indices rather than block->getUniqueTokens().size(), because the calculation needs to consider the post-prefill state where blocks will be filled to capacity, not their current token count.
📚 Learning: 2025-09-29T15:14:28.503Z
Learnt from: amitz-nv
Repo: NVIDIA/TensorRT-LLM PR: 8063
File: tensorrt_llm/lora_manager.py:1080-1112
Timestamp: 2025-09-29T15:14:28.503Z
Learning: In tensorrt_llm/lora_manager.py, when calculating part_sizes for attn_qkv fused LoRA modules, the sizes are correctly multiplied by tp_size because model_config.num_heads and model_config.num_kv_heads are already divided by tp_size (per-TP-rank values), so multiplication is needed to get the original full concatenated dimension size. The interleave_fused_lora_weights_for_tp function provides proper validation with asserts for total size and TP divisibility.

Applied to files:

  • tensorrt_llm/_torch/model_config.py
📚 Learning: 2025-09-29T15:14:28.503Z
Learnt from: amitz-nv
Repo: NVIDIA/TensorRT-LLM PR: 8063
File: tensorrt_llm/lora_manager.py:1080-1112
Timestamp: 2025-09-29T15:14:28.503Z
Learning: In tensorrt_llm/lora_manager.py, when calculating part_sizes for attn_qkv fused LoRA modules, the sizes are correctly multiplied by tp_size because model_config.num_heads and model_config.num_kv_heads are already divided by tp_size (per-TP-rank values), so multiplication is needed to get the original full concatenated dimension size. The interleave_fused_lora_weights_for_tp function provides proper validation.

Applied to files:

  • tensorrt_llm/_torch/model_config.py
🔇 Additional comments (1)
tensorrt_llm/_torch/model_config.py (1)

650-650: Integer division is correct—no changes needed.

The calculation num_hidden_layers // full_attention_interval properly counts full attention layers in the Qwen3-Next hybrid pattern. With default interval of 4, full attention occurs at positions where (layer_index + 1) % 4 == 0, yielding exactly num_hidden_layers // 4 layers. Incomplete intervals correctly use linear attention, which is the intended design.
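As a quick cross-check of that counting rule, the numbers implied by the PR description (12 full-attention layers being 1/4 of the total, i.e. 48 hidden layers) and the default interval of 4 mentioned above line up as follows; this is purely illustrative arithmetic, not code from the PR.

```python
# Enumerate full-attention positions: (layer_index + 1) % interval == 0.
num_hidden_layers = 48        # implied by the PR description (12 layers = 1/4 of total)
full_attention_interval = 4   # default interval per the review comment

full_attention_layers = [
    i for i in range(num_hidden_layers)
    if (i + 1) % full_attention_interval == 0
]

# Integer division gives the same count as explicit enumeration: 12 layers.
assert len(full_attention_layers) == num_hidden_layers // full_attention_interval == 12
```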

@tensorrt-cicd
Collaborator

PR_Github #24313 [ run ] triggered by Bot. Commit: 98c088c

@tensorrt-cicd
Collaborator

PR_Github #24313 [ run ] completed with state SUCCESS. Commit: 98c088c
/LLM/main/L0_MergeRequest_PR pipeline #18344 completed with status: 'FAILURE'

@nv-guomingz nv-guomingz force-pushed the user/guomingz/fix_qwen3_next_kv_cache branch from 98c088c to a1b0b40 Compare November 13, 2025 01:09
@nv-guomingz
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #24351 [ run ] triggered by Bot. Commit: a1b0b40

@tensorrt-cicd
Collaborator

PR_Github #24351 [ run ] completed with state SUCCESS. Commit: a1b0b40
/LLM/main/L0_MergeRequest_PR pipeline #18379 completed with status: 'FAILURE'

@nv-guomingz
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #24409 [ run ] triggered by Bot. Commit: a1b0b40

@tensorrt-cicd
Collaborator

PR_Github #24409 [ run ] completed with state SUCCESS. Commit: a1b0b40
/LLM/main/L0_MergeRequest_PR pipeline #18417 completed with status: 'FAILURE'

@nv-guomingz nv-guomingz force-pushed the user/guomingz/fix_qwen3_next_kv_cache branch from a1b0b40 to 51a9ccb Compare November 16, 2025 13:28
Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
@nv-guomingz nv-guomingz force-pushed the user/guomingz/fix_qwen3_next_kv_cache branch from 51a9ccb to 10518b8 Compare November 16, 2025 16:25
@nv-guomingz
Collaborator Author

/bot run

@nv-guomingz nv-guomingz enabled auto-merge (squash) November 16, 2025 16:25
@tensorrt-cicd
Collaborator

PR_Github #24685 [ run ] triggered by Bot. Commit: 10518b8

@tensorrt-cicd
Collaborator

PR_Github #24685 [ run ] completed with state SUCCESS. Commit: 10518b8
/LLM/main/L0_MergeRequest_PR pipeline #18640 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

@nv-guomingz nv-guomingz merged commit e0f6965 into NVIDIA:main Nov 16, 2025
5 checks passed