Skip to content

Commit a676e66

Browse files
authored
[Bugfix] fix apply_temperature to avoid nan in probs (vllm-project#24734)
Signed-off-by: courage17340 <courage17340@163.com>
1 parent c85be1f commit a676e66

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

vllm/v1/sample/sampler.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -128,8 +128,12 @@ def apply_temperature(
128128
self,
129129
logits: torch.Tensor,
130130
temp: torch.Tensor,
131+
all_random: bool,
131132
) -> torch.Tensor:
132133
# Use in-place division to avoid creating a new tensor.
134+
# Avoid division by zero if there are greedy requests.
135+
if not all_random:
136+
temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
133137
return logits.div_(temp.unsqueeze(dim=1))
134138

135139
def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
@@ -164,7 +168,8 @@ def sample(
164168
assert sampling_metadata.temperature is not None
165169

166170
# Apply temperature.
167-
logits = self.apply_temperature(logits, sampling_metadata.temperature)
171+
logits = self.apply_temperature(logits, sampling_metadata.temperature,
172+
sampling_metadata.all_random)
168173

169174
# Apply logits processors that only apply to random sampling
170175
# (argmax invariant)

vllm/v1/worker/gpu_input_batch.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -354,8 +354,8 @@ def add_request(
354354
and is_spec_decode_unsupported(sampling_params)):
355355
self.spec_decode_unsupported_reqs.add(req_id)
356356
if sampling_params.sampling_type == SamplingType.GREEDY:
357-
# Avoid later division by zero.
358-
self.temperature_cpu[req_index] = -1.0
357+
# Should avoid division by zero later when apply_temperature.
358+
self.temperature_cpu[req_index] = 0.0
359359
self.greedy_reqs.add(req_id)
360360
else:
361361
self.temperature_cpu[req_index] = sampling_params.temperature

0 commit comments

Comments
 (0)