From 568426f5efb3f88b889faee6000de740c2437966 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Wed, 20 Aug 2025 12:44:48 +0000 Subject: [PATCH] [Eagle3] Add Eagle3 verifier support to Qwen2ForCausalLM MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Implement SupportsEagle3 interface for Qwen2ForCausalLM - Add set_aux_hidden_state_layers() and get_eagle3_aux_hidden_state_layers() methods - Qwen2 models now support Eagle3 speculative decoding Changes: - Import SupportsEagle3 interface - Update class declaration to inherit from SupportsEagle3 - Add Eagle3 auxiliary hidden state layer management methods - Use standard layer selection pattern: (2, num_layers // 2, num_layers - 3) Tested with: ./local/validate_eagle3_support.sh qwen2 Qwen2ForCausalLM qwen All validation checks passed ✅ 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude Signed-off-by: Rahul Tuli --- vllm/model_executor/models/qwen2.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 7304fbf120cc..05cd564cbf37 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -51,7 +51,7 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import is_interleaved -from .interfaces import SupportsLoRA, SupportsPP +from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, @@ -439,7 +439,7 @@ def load_weights(self, weights: Iterable[tuple[str, return loaded_params -class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): +class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -485,6 +485,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) + def set_aux_hidden_state_layers(self, layers: tuple[int]) -> None: + """Set auxiliary hidden state layers for Eagle3 speculation.""" + self.model.aux_hidden_state_layers = layers + + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]: + """Get the layer indices for Eagle3 auxiliary hidden states.""" + num_layers = len(self.model.layers) + return (2, num_layers // 2, num_layers - 3) + def forward( self, input_ids: torch.Tensor,