
Commit cbf9221

[Model] Supplement to PR 24862: Pass param prefix to LLMHead (vllm-project#25805)
Signed-off-by: whx-sjtu <2952154980@qq.com>
1 parent 5f42fc5 commit cbf9221
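
Every file in this commit makes the same change: the ParallelLMHead (or the SharedHead wrapping one) is now constructed with a prefix built via maybe_prefix, so the head's parameters get a fully qualified name under their parent module and per-layer quantization settings keyed by parameter name can resolve the head. A minimal sketch of the pattern, outside any particular model (the ExampleForCausalLM wrapper and its config attributes are illustrative, not taken from the diff):

import torch.nn as nn

from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.utils import maybe_prefix


class ExampleForCausalLM(nn.Module):
    """Illustrative wrapper showing how the touched models now build their LM head."""

    def __init__(self, config, quant_config=None, prefix: str = ""):
        super().__init__()
        # The prefix (e.g. "lm_head" or "<parent>.lm_head") lets quantization
        # methods that look layers up by full parameter name find this head.
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config,
                                      prefix=maybe_prefix(prefix, "lm_head"))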

File tree

8 files changed: +35 -12 lines changed


vllm/model_executor/models/deepseek_mtp.py

Lines changed: 6 additions & 2 deletions
@@ -28,13 +28,15 @@ class SharedHead(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        prefix: str,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.head = ParallelLMHead(config.vocab_size,
                                    config.hidden_size,
-                                   quant_config=quant_config)
+                                   quant_config=quant_config,
+                                   prefix=maybe_prefix(prefix, "head"))

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return self.norm(hidden_states)

@@ -64,7 +66,9 @@ def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
                 device="cuda")
         else:
             topk_indices_buffer = None
-        self.shared_head = SharedHead(config=config, quant_config=quant_config)
+        self.shared_head = SharedHead(config=config,
+                                      prefix=prefix,
+                                      quant_config=quant_config)
         self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix,
                                                 topk_indices_buffer)
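
The helper used above, maybe_prefix, simply joins the parent prefix and the child name with a dot and returns the bare name when the prefix is empty. A paraphrased sketch of that behavior (not the vLLM source verbatim):

def maybe_prefix(prefix: str, name: str) -> str:
    # "" + "head" -> "head"; "shared_head" + "head" -> "shared_head.head"
    return name if not prefix else f"{prefix}.{name}"

assert maybe_prefix("", "head") == "head"
assert maybe_prefix("shared_head", "head") == "shared_head.head"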

vllm/model_executor/models/glm4_moe_mtp.py

Lines changed: 6 additions & 2 deletions
@@ -50,13 +50,15 @@ class SharedHead(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        prefix: str,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.head = ParallelLMHead(config.vocab_size,
                                    config.hidden_size,
-                                   quant_config=quant_config)
+                                   quant_config=quant_config,
+                                   prefix=maybe_prefix(prefix, "head"))

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return self.norm(hidden_states)

@@ -77,7 +79,9 @@ def __init__(
         self.eh_proj = nn.Linear(config.hidden_size * 2,
                                  config.hidden_size,
                                  bias=False)
-        self.shared_head = SharedHead(config=config, quant_config=quant_config)
+        self.shared_head = SharedHead(config=config,
+                                      prefix=prefix,
+                                      quant_config=quant_config)
         self.mtp_block = Glm4MoeDecoderLayer(config=config,
                                              cache_config=cache_config,
                                              quant_config=quant_config,

vllm/model_executor/models/gpt_neox.py

Lines changed: 1 addition & 0 deletions
@@ -296,6 +296,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             config.vocab_size,
             config.hidden_size,
             quant_config=quant_config,
+            prefix=maybe_prefix(prefix, "embed_out"),
         )
         if self.config.tie_word_embeddings:
             self.embed_out.weight = self.gpt_neox.embed_in.weight

vllm/model_executor/models/longcat_flash_mtp.py

Lines changed: 1 addition & 0 deletions
@@ -140,6 +140,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.config.vocab_size,
             self.config.hidden_size,
             quant_config=self.quant_config,
+            prefix=maybe_prefix(prefix, "lm_head"),
         )
         self.logits_processor = LogitsProcessor(self.config.vocab_size)

vllm/model_executor/models/medusa.py

Lines changed: 2 additions & 1 deletion
@@ -82,7 +82,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
                 config.hidden_size,
                 org_num_embeddings=self.truncated_vocab_size,
                 padding_size=DEFAULT_VOCAB_PADDING_SIZE,
-            ) for _ in range(self.config.num_heads)
+                prefix=maybe_prefix(prefix, f"lm_heads.{i}"),
+            ) for i in range(self.config.num_heads)
         ])

         logit_scale = getattr(config, "logit_scale", 1.0)

vllm/model_executor/models/mlp_speculator.py

Lines changed: 14 additions & 4 deletions
@@ -13,6 +13,8 @@
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

+from .utils import maybe_prefix
+
 SQRT2 = 2**0.5


@@ -97,8 +99,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         self.proj = nn.ModuleList([proj_first] + [proj_tied] *
                                   (self.max_speculative_tokens - 1))

-        head = ParallelLMHead(self.vocab_size, self.inner_dim, bias=False)
-        self.head = nn.ModuleList([head] * self.max_speculative_tokens)
+        self.head = nn.ModuleList([
+            ParallelLMHead(self.vocab_size,
+                           self.inner_dim,
+                           bias=False,
+                           prefix=maybe_prefix(prefix, f"head.{i}"))
+            for i in range(self.max_speculative_tokens)
+        ])

         ln = MLPSpeculatorLayerNorm(self.inner_dim,
                                     elementwise_scale_and_shift=True)

@@ -120,8 +127,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         ])

         self.head = nn.ModuleList([
-            ParallelLMHead(self.vocab_size, self.inner_dim, bias=False)
-            for _ in range(self.max_speculative_tokens)
+            ParallelLMHead(self.vocab_size,
+                           self.inner_dim,
+                           bias=False,
+                           prefix=maybe_prefix(prefix, f"head.{i}"))
+            for i in range(self.max_speculative_tokens)
         ])
         self.ln = nn.ModuleList([
             MLPSpeculatorLayerNorm(self.inner_dim,
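
The mlp_speculator.py change also replaces nn.ModuleList([head] * n) with a list comprehension: the list-multiplication form repeats one module object, so every slot aliases the same weights and registers under a single name, while the comprehension builds a separate ParallelLMHead per speculative step, which is what lets each one carry its own head.{i} prefix. A small stand-in illustration of that difference (plain nn.Linear, no vLLM dependency):

import torch.nn as nn

shared = nn.Linear(4, 4)
aliased = nn.ModuleList([shared] * 3)                          # one module, three references
distinct = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])  # three independent modules

assert aliased[0] is aliased[2]          # same object: parameters are shared
assert distinct[0] is not distinct[2]    # separate objects: distinct parameter names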

vllm/model_executor/models/qwen3_vl_moe.py

Lines changed: 2 additions & 1 deletion
@@ -296,7 +296,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                                          prefix=maybe_prefix(prefix, "model"))
         self.lm_head = ParallelLMHead(self.config.vocab_size,
                                       self.config.hidden_size,
-                                      quant_config=self.quant_config)
+                                      quant_config=self.quant_config,
+                                      prefix=maybe_prefix(prefix, "lm_head"))
         if self.config.tie_word_embeddings:
             self.lm_head.weight = self.model.embed_tokens.weight
         self.logits_processor = LogitsProcessor(self.config.vocab_size)

vllm/model_executor/models/whisper.py

Lines changed: 3 additions & 2 deletions
@@ -45,7 +45,7 @@
 from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
                          SupportsTranscription)
 from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors,
-                    make_layers)
+                    make_layers, maybe_prefix)

 logger = init_logger(__name__)


@@ -885,7 +885,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.unpadded_vocab_size = config.vocab_size
         self.proj_out = ParallelLMHead(config.vocab_size,
                                        config.d_model,
-                                       quant_config=quant_config)
+                                       quant_config=quant_config,
+                                       prefix=maybe_prefix(prefix, "proj_out"))
         self.proj_out = self.proj_out.tie_weights(
             self.model.decoder.embed_tokens)
         logit_scale = getattr(config, "logit_scale", 1.0)
