From 2a3ce8cefca5fb446192e86ec7feba90ffffa761 Mon Sep 17 00:00:00 2001
From: zhenwenqi2024
Date: Mon, 1 Dec 2025 10:41:52 +0800
Subject: [PATCH] eagle reconstruct

Drop the SpeculativeConfig.__post_init__ monkey patch (patch_config.py),
make NPUModelRunner inherit from GPUModelRunner so model loading reuses the
upstream load_model(), and split the ACL graph wrapping out of load_model()
into a separate aclgraph_wrapper() step that the worker invokes after the
model is loaded.

Signed-off-by: zhenwenqi2024
---
 vllm_ascend/patch/platform/patch_config.py | 228 ---------------------
 vllm_ascend/worker/model_runner_v1.py      |  37 +---
 vllm_ascend/worker/worker_v1.py            |   1 +
 3 files changed, 6 insertions(+), 260 deletions(-)
 delete mode 100644 vllm_ascend/patch/platform/patch_config.py

diff --git a/vllm_ascend/patch/platform/patch_config.py b/vllm_ascend/patch/platform/patch_config.py
deleted file mode 100644
index 0e8642d1cea..00000000000
--- a/vllm_ascend/patch/platform/patch_config.py
+++ /dev/null
@@ -1,228 +0,0 @@
-import ast
-
-from vllm.config.speculative import SpeculativeConfig
-from vllm.logger import logger
-
-
-def __post_init__(self):
-
-    # Note: "method" is a new parameter that helps to extend the
-    # configuration of non-model-based proposers, and the "model" parameter
-    # will be used to set the draft model, eagle head, or additional weight
-    # when needed. If users do not specify "method", the speculative method
-    # will be detected automatically if possible. If the speculative method
-    # can not be detected, it will be considered as the "draft_model" by
-    # default.
-
-    if self.model is None and self.num_speculative_tokens is not None:
-        # TODO(Shangming): Refactor mtp configuration logic when supporting
-        if (self.target_model_config
-                and self.target_model_config.hf_text_config.model_type
-                in ("deepseek_v3", "deepseek_v32", "mimo", "ernie4_5_moe",
-                    "qwen3_next")):
-            # use the draft model from the same model:
-            self.model = self.target_model_config.model
-            # Align the quantization of draft model for cases such as
-            # --quantization fp8 with a bf16 checkpoint.
-            if not self.quantization:
-                self.quantization = self.target_model_config.quantization
-        elif self.method in ("ngram", "[ngram]"):
-            self.model = "ngram"
-        else:
-            raise ValueError("num_speculative_tokens was provided but without "
-                             "speculative model.")
-
-    # Automatically configure the method for ngram when "model" is used
-    # instead of "method"
-    if self.method is None and (self.model is not None
-                                and self.model in ("ngram", "[ngram]")):
-        self.method = "ngram"
-
-    if self.method in ("ngram", "[ngram]"):
-        # Unified to "ngram" internally
-        self.method = "ngram"
-        # Set default values if not provided
-        if (self.prompt_lookup_min is None and self.prompt_lookup_max is None):
-            # TODO(woosuk): Tune these values. They are arbitrarily chosen.
-            self.prompt_lookup_min = 5
-            self.prompt_lookup_max = 5
-        elif self.prompt_lookup_min is None:
-            assert self.prompt_lookup_max is not None
-            self.prompt_lookup_min = self.prompt_lookup_max
-        elif self.prompt_lookup_max is None:
-            assert self.prompt_lookup_min is not None
-            self.prompt_lookup_max = self.prompt_lookup_min
-
-        # Validate values
-        if self.prompt_lookup_min < 1:
-            raise ValueError(
-                f"prompt_lookup_min={self.prompt_lookup_min} must be > 0")
-        if self.prompt_lookup_max < 1:
-            raise ValueError(
-                f"prompt_lookup_max={self.prompt_lookup_max} must be > 0")
-        if self.prompt_lookup_min > self.prompt_lookup_max:
-            raise ValueError(
-                f"prompt_lookup_min={self.prompt_lookup_min} must "
-                f"be <= prompt_lookup_max={self.prompt_lookup_max}")
-
-        # TODO: current we still need extract vocab_size from target model
-        # config, in future, we may try refactor it out, and set
-        # draft related config as None here.
-        self.draft_model_config = self.target_model_config
-        self.draft_parallel_config = self.target_parallel_config
-    else:
-        self.prompt_lookup_max = 0
-        self.prompt_lookup_min = 0
-
-        if self.model is not None:
-            # TODO: Move this import to the top once `ModelConfig`
-            # lives in `vllm.config.model`.
-            from vllm.config import ModelConfig
-            self.draft_model_config = ModelConfig(
-                model=self.model,
-                runner="draft",
-                tokenizer=self.target_model_config.tokenizer,
-                tokenizer_mode=self.target_model_config.tokenizer_mode,
-                trust_remote_code=self.target_model_config.trust_remote_code,
-                allowed_local_media_path=self.target_model_config.
-                allowed_local_media_path,
-                allowed_media_domains=self.target_model_config.
-                allowed_media_domains,
-                dtype=self.target_model_config.dtype,
-                seed=self.target_model_config.seed,
-                revision=self.revision,
-                code_revision=self.code_revision,
-                tokenizer_revision=self.target_model_config.tokenizer_revision,
-                spec_target_max_model_len=self.target_model_config.
-                max_model_len,
-                quantization=self.quantization,
-                enforce_eager=self.target_model_config.enforce_eager,
-                max_logprobs=self.target_model_config.max_logprobs,
-                hf_overrides=SpeculativeConfig.hf_config_override,
-            )
-
-            # Automatically detect the method
-            if self.method in ('eagle', 'eagle3'):
-                pass
-            # examples:
-            # yuhuili/EAGLE-LLaMA3-Instruct-8B
-            # yuhuili/EAGLE3-LLaMA3.1-Instruct-8B
-            # AngelSlim/Qwen3-8B_eagle3
-            elif "eagle-" in self.draft_model_config.model.lower():
-                self.method = "eagle"
-            elif "eagle3" in self.draft_model_config.model.lower():
-                self.method = "eagle3"
-            elif self.draft_model_config.hf_config.model_type == "medusa":
-                self.method = "medusa"
-            elif (self.draft_model_config.hf_config.model_type ==
-                  "mlp_speculator"):
-                self.method = "mlp_speculator"
-            elif (self.draft_model_config.hf_config.model_type
-                  in ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp")):
-                self.method = "deepseek_mtp"
-                if self.num_speculative_tokens > 1:
-                    logger.warning(
-                        "All Deepseek MTP models only have " \
-                        "one layer. Might need some code changes " \
-                        "to support multiple layers."
-                    )
-            elif (self.draft_model_config.hf_config.model_type == "ernie_mtp"):
-                self.method = "ernie_mtp"
-                if self.num_speculative_tokens > 1:
-                    logger.warning(
-                        "All Ernie MTP models only have " \
-                        "one layer. Might need some code changes " \
-                        "to support multiple layers."
-                    )
-            elif (self.draft_model_config.hf_config.model_type ==
-                  "qwen3_next_mtp"):
-                self.method = "qwen3_next_mtp"
-                if self.num_speculative_tokens > 1:
-                    logger.warning(
-                        "All Qwen3Next MTP models only have " \
-                        "one layer. Might need some code changes " \
-                        "to support multiple layers."
-                    )
-            elif (self.draft_model_config.hf_config.model_type
-                  in ("longcat_flash_mtp")):
-                self.method = "longcat_flash_mtp"
-                if self.num_speculative_tokens > 1:
-                    logger.warning(
-                        "LongCat MTP models only have " \
-                        "one layer. Might need some code changes " \
-                        "to support multiple layers."
-                    )
-            else:
-                self.method = "draft_model"
-                raise NotImplementedError(
-                    "Speculative decoding with draft model is not "
-                    "supported yet. Please consider using other "
-                    "speculative decoding methods such as ngram, medusa, "
-                    "eagle, or deepseek_mtp.")
-
-            # Replace hf_config for EAGLE draft_model
-            if self.method in ("eagle", "eagle3"):
-                from vllm.transformers_utils.configs import SpeculatorsConfig
-                from vllm.transformers_utils.configs.eagle import EAGLEConfig
-
-                if isinstance(self.draft_model_config.hf_config,
-                              (EAGLEConfig, SpeculatorsConfig)):
-                    pass
-                else:
-                    eagle_config = EAGLEConfig(
-                        self.draft_model_config.hf_config,
-                        method=self.method,
-                        model_type="eagle")
-                    self.draft_model_config.hf_config = eagle_config
-
-                if (self.num_speculative_tokens is not None
-                        and hasattr(self.draft_model_config.hf_config,
-                                    "num_lookahead_tokens")):
-                    self.draft_model_config.hf_config.num_lookahead_tokens = \
-                        self.num_speculative_tokens
-
-            n_predict = getattr(self.draft_model_config.hf_config, "n_predict",
-                                None)
-            if n_predict is not None:
-                if self.num_speculative_tokens is None:
-                    # Default to max value defined in draft model config.
-                    self.num_speculative_tokens = n_predict
-                elif self.num_speculative_tokens > n_predict and \
-                        self.num_speculative_tokens % n_predict != 0:
-                    # Ensure divisibility for MTP module reuse.
-                    raise ValueError(
-                        f"num_speculative_tokens:{self.num_speculative_tokens}"
-                        f" must be divisible by {n_predict=}")
-
-            if self.speculative_token_tree is None:
-                # Generate chain of tokens.
-                self.speculative_token_tree = str([
-                    (i + 1) * (0, ) for i in range(self.num_speculative_tokens)
-                ])
-            else:
-                # Sort the token tree breadth-first.
-                tree_choices = ast.literal_eval(self.speculative_token_tree)
-                self.speculative_token_tree = str(
-                    sorted(tree_choices, key=lambda t: (len(t), t)))
-
-            self.draft_tensor_parallel_size = \
-                SpeculativeConfig._verify_and_get_draft_tp(
-                    self.target_parallel_config,
-                    self.draft_tensor_parallel_size,
-                    self.draft_model_config.hf_config
-                )
-
-            self.draft_model_config.max_model_len = (
-                SpeculativeConfig._maybe_override_draft_max_model_len(
-                    self.max_model_len,
-                    self.draft_model_config.max_model_len,
-                    self.target_model_config.max_model_len,
-                ))
-
-            self.draft_parallel_config = (
-                SpeculativeConfig.create_draft_parallel_config(
-                    self.target_parallel_config,
-                    self.draft_tensor_parallel_size))
-
-
-SpeculativeConfig.__post_init__ = __post_init__
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index ce5848b3495..4872c04dbd1 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -58,7 +58,7 @@
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
-from vllm.model_executor.model_loader import get_model
+from vllm.model_executor.model_loader import get_model_loader
 from vllm.model_executor.models.interfaces import (SupportsMultiModal,
                                                    supports_mrope,
                                                    supports_transcription)
@@ -103,6 +103,7 @@
                                    gather_mm_placeholders,
                                    sanity_check_mm_encoder_outputs,
                                    scatter_mm_placeholders)
+from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
@@ -269,7 +270,7 @@ class ExecuteModelState(NamedTuple):
     positions: torch.Tensor
 
 
-class NPUModelRunner(LoRAModelRunnerMixin):
+class NPUModelRunner(GPUModelRunner, LoRAModelRunnerMixin):
 
     def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.vllm_config = vllm_config
@@ -386,6 +387,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
                                            self.device)
         self._set_up_drafter()
+        self.use_aux_hidden_state_outputs = False
 
         # kv role
         self.is_kv_producer = False
@@ -3175,36 +3177,7 @@ def eplb_warmup(self):
         self.eplb_updator.set_adaptor(self.eplb_adaptor)
         self.eplb_updator.warm_up_eplb()
 
-    def load_model(self) -> None:
-        logger.info("Starting to load model %s...", self.model_config.model)
-
-        with DeviceMemoryProfiler() as m:  # noqa: SIM117
-            self.model = get_model(vllm_config=self.vllm_config)
-            if self.dynamic_eplb:
-                model_register(self.model, self.model_config)
-            if get_ascend_device_type() == AscendDeviceType._310P:
-                from vllm.model_executor.layers.linear import (
-                    MergedColumnParallelLinear, QKVParallelLinear,
-                    RowParallelLinear)
-                for module in self.model.modules():
-                    if isinstance(module,
-                                  (MergedColumnParallelLinear,
-                                   QKVParallelLinear, RowParallelLinear)):
-                        module.weight.data = self._convert_torch_format(
-                            module.weight.data)
-            if self.drafter:
-                logger.info("Loading drafter model...")
-                self.drafter.load_model(self.model)
-                if self.drafter.name == SpecDcodeType.EAGLE3:
-                    self.model.set_aux_hidden_state_layers(
-                        self.model.get_eagle3_aux_hidden_state_layers())
-
-            if self.lora_config:
-                self.model = self.load_lora_model(self.model, self.vllm_config,
-                                                  self.device)
-        logger.info("Loading model weights took %.4f GB",
-                    m.consumed_memory / float(2**30))
-
+    def aclgraph_wrapper(self):
         # wrap the model with full graph wrapper if needed.
         if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
             self.update_stream: torch.npu.Stream = torch.npu.Stream()
diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py
index ef3f2e49cb3..d7a9be59de6 100644
--- a/vllm_ascend/worker/worker_v1.py
+++ b/vllm_ascend/worker/worker_v1.py
@@ -337,6 +337,7 @@ def load_model(self) -> None:
         context = nullcontext()  # type: ignore
         with context:
             self.model_runner.load_model()
+            self.model_runner.aclgraph_wrapper()
 
     def compile_or_warm_up_model(self) -> None:
         # Note: need to adapt for graph mode.
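
Not part of the patch: a minimal, illustrative sketch of the load flow this change assumes. The class names _RunnerSketch and _WorkerSketch are hypothetical stand-ins for NPUModelRunner and the NPU worker; in the real code load_model() is inherited from GPUModelRunner, and aclgraph_wrapper() holds the ACL graph wrapping that previously sat at the tail of the old load_model().

class _RunnerSketch:
    """Hypothetical stand-in for NPUModelRunner after this patch."""

    def load_model(self) -> None:
        # In the real code this is inherited from GPUModelRunner and loads
        # the target model (and any drafter) weights onto the device.
        print("load model weights")

    def aclgraph_wrapper(self) -> None:
        # In the real code this wraps the loaded model for full-graph (ACL
        # graph) execution when the compilation config enables full graphs.
        print("wrap model for ACL graph execution")


class _WorkerSketch:
    """Hypothetical stand-in for the worker in worker_v1.py."""

    def __init__(self) -> None:
        self.model_runner = _RunnerSketch()

    def load_model(self) -> None:
        # Mirrors the patched worker_v1.load_model(): weight loading first,
        # then the separate graph-wrapping step.
        self.model_runner.load_model()
        self.model_runner.aclgraph_wrapper()


if __name__ == "__main__":
    _WorkerSketch().load_model()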