1313 ParallelLMHead , VocabParallelEmbedding )
1414from vllm .model_executor .model_loader .weight_utils import default_weight_loader
1515
16+ from .utils import maybe_prefix
17+
1618SQRT2 = 2 ** 0.5
1719
1820
@@ -97,8 +99,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
9799 self .proj = nn .ModuleList ([proj_first ] + [proj_tied ] *
98100 (self .max_speculative_tokens - 1 ))
99101
100- head = ParallelLMHead (self .vocab_size , self .inner_dim , bias = False )
101- self .head = nn .ModuleList ([head ] * self .max_speculative_tokens )
102+ self .head = nn .ModuleList ([
103+ ParallelLMHead (self .vocab_size ,
104+ self .inner_dim ,
105+ bias = False ,
106+ prefix = maybe_prefix (prefix , f"head.{ i } " ))
107+ for i in range (self .max_speculative_tokens )
108+ ])
102109
103110 ln = MLPSpeculatorLayerNorm (self .inner_dim ,
104111 elementwise_scale_and_shift = True )
@@ -120,8 +127,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
120127 ])
121128
122129 self .head = nn .ModuleList ([
123- ParallelLMHead (self .vocab_size , self .inner_dim , bias = False )
124- for _ in range (self .max_speculative_tokens )
130+ ParallelLMHead (self .vocab_size ,
131+ self .inner_dim ,
132+ bias = False ,
133+ prefix = maybe_prefix (prefix , f"head.{ i } " ))
134+ for i in range (self .max_speculative_tokens )
125135 ])
126136 self .ln = nn .ModuleList ([
127137 MLPSpeculatorLayerNorm (self .inner_dim ,
0 commit comments