2 files changed: 8 additions, 3 deletions

@@ -128,8 +128,12 @@ def apply_temperature(
         self,
         logits: torch.Tensor,
         temp: torch.Tensor,
+        all_random: bool,
     ) -> torch.Tensor:
         # Use in-place division to avoid creating a new tensor.
+        # Avoid division by zero if there are greedy requests.
+        if not all_random:
+            temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
         return logits.div_(temp.unsqueeze(dim=1))

     def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
@@ -164,7 +168,8 @@ def sample(
         assert sampling_metadata.temperature is not None

         # Apply temperature.
-        logits = self.apply_temperature(logits, sampling_metadata.temperature)
+        logits = self.apply_temperature(logits, sampling_metadata.temperature,
+                                        sampling_metadata.all_random)

         # Apply logits processors that only apply to random sampling
         # (argmax invariant)
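The change above masks near-zero temperatures before the in-place division. A minimal standalone sketch of the technique, assuming _SAMPLING_EPS is a small threshold such as 1e-5 (the constant defined in the actual module may differ):

import torch

_SAMPLING_EPS = 1e-5  # assumed threshold; vLLM defines its own constant

def apply_temperature(logits: torch.Tensor,
                      temp: torch.Tensor,
                      all_random: bool) -> torch.Tensor:
    # Greedy requests carry temperature 0.0; dividing by zero would turn their
    # logits into inf/nan. Replacing those entries with 1.0 leaves the logits
    # unscaled, which is harmless because greedy sampling takes argmax and
    # argmax is invariant to scaling by a positive constant.
    if not all_random:
        temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
    # In-place division to avoid allocating a new tensor.
    return logits.div_(temp.unsqueeze(dim=1))

# Usage: batch of two requests, the first greedy (temp 0.0), the second random.
logits = torch.randn(2, 8)
temp = torch.tensor([0.0, 0.8])
out = apply_temperature(logits, temp, all_random=False)
assert torch.isfinite(out).all()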
@@ -354,8 +354,8 @@ def add_request(
                 and is_spec_decode_unsupported(sampling_params)):
             self.spec_decode_unsupported_reqs.add(req_id)
         if sampling_params.sampling_type == SamplingType.GREEDY:
-            # Avoid later division by zero.
-            self.temperature_cpu[req_index] = -1.0
+            # Avoid division by zero later in apply_temperature.
+            self.temperature_cpu[req_index] = 0.0
             self.greedy_reqs.add(req_id)
         else:
             self.temperature_cpu[req_index] = sampling_params.temperature
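For context, a hypothetical sketch of how the per-request bookkeeping above could feed the all_random flag consumed by apply_temperature; the class name and structure are illustrative only and the real input batch in vLLM may differ:

class InputBatchSketch:
    """Illustrative only; not the actual vLLM input batch class."""

    def __init__(self) -> None:
        self.greedy_reqs: set[str] = set()
        self.random_reqs: set[str] = set()
        self.temperature_cpu: dict[int, float] = {}

    def add_request(self, req_id: str, req_index: int, temperature: float) -> None:
        if temperature == 0.0:  # greedy sampling
            # Store 0.0; apply_temperature masks it to 1.0 before dividing.
            self.temperature_cpu[req_index] = 0.0
            self.greedy_reqs.add(req_id)
        else:
            self.temperature_cpu[req_index] = temperature
            self.random_reqs.add(req_id)

    @property
    def all_random(self) -> bool:
        # No greedy requests in the batch, so apply_temperature can skip the mask.
        return not self.greedy_reqs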