
Commit a51f418

ZJY0516 and ywang96 authored
[Bugfix] fix dots.llm1.inst (#29687)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com> Co-authored-by: Roger Wang <hey@rogerw.io>
1 parent: 7675ba3

File tree

1 file changed: +6 -5 lines changed


vllm/model_executor/models/dots1.py

Lines changed: 6 additions & 5 deletions
@@ -181,13 +181,14 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = hidden_states.view(-1, hidden_dim)
 
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = (
-            self.experts(hidden_states=hidden_states, router_logits=router_logits)
-            * self.routed_scaling_factor
-        )
 
+        shared_out, routed_out = self.experts(
+            hidden_states=hidden_states, router_logits=router_logits
+        )
         if self.shared_experts is not None:
-            final_hidden_states = final_hidden_states[0] + final_hidden_states[1]
+            final_hidden_states = (routed_out + shared_out) * self.routed_scaling_factor
+        else:
+            final_hidden_states = routed_out * self.routed_scaling_factor
 
         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
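The behavioral change can be sketched as a small standalone helper (`combine_expert_outputs` is a hypothetical name for illustration, not a vLLM API): `self.experts` now returns separate shared and routed outputs, and `routed_scaling_factor` is applied once, after the shared-expert contribution (if any) is added, rather than to the raw experts output.

```python
def combine_expert_outputs(routed_out, shared_out, routed_scaling_factor):
    """Combine routed- and shared-expert outputs as in the fixed dots1 forward.

    Hypothetical standalone sketch: works on torch tensors or plain numbers,
    since it only uses + and *.
    """
    if shared_out is not None:
        # Shared-expert path: add first, then scale the combined output once.
        return (routed_out + shared_out) * routed_scaling_factor
    # No shared experts configured: scale the routed output alone.
    return routed_out * routed_scaling_factor
```

For example, with `routed_out=2.0`, `shared_out=1.0`, and a scaling factor of `0.5`, the helper returns `1.5`; with no shared experts it returns `1.0`.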

0 commit comments
