Skip to content

Commit 018940f

Browse files
ashors1, heheda12345
authored and committed
Add truncate arg to yarn to match openai implementation of gpt-oss (vllm-project#28244)
Signed-off-by: ashors1 <ashors@nvidia.com> Co-authored-by: Chen Zhang <zhangch99@outlook.com>
1 parent a0b3422 commit 018940f

File tree

4 files changed

+12
-7
lines changed

4 files changed

+12
-7
lines changed

vllm/model_executor/layers/rotary_embedding/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ def get_rope(
197197
"beta_fast",
198198
"beta_slow",
199199
"apply_yarn_scaling",
200+
"truncate",
200201
)
201202
}
202203
if "mrope_section" in rope_parameters:

vllm/model_executor/layers/rotary_embedding/common.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -117,13 +117,13 @@ def yarn_find_correction_range(
117117
dim: int,
118118
base: float = 10000,
119119
max_position_embeddings: int = 2048,
120-
) -> tuple[int, int]:
121-
low = math.floor(
122-
yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
123-
)
124-
high = math.ceil(
125-
yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
126-
)
120+
truncate: bool = True,
121+
) -> tuple[float | int, float | int]:
122+
low = yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
123+
high = yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
124+
if truncate:
125+
low = math.floor(low)
126+
high = math.ceil(high)
127127
return max(low, 0), min(high, dim - 1) # Clamp values just in case
128128

129129

vllm/model_executor/layers/rotary_embedding/yarn_scaling_rope.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,14 @@ def __init__(
2828
beta_fast: int = 32,
2929
beta_slow: int = 1,
3030
apply_yarn_scaling: bool = True,
31+
truncate: bool = True,
3132
) -> None:
3233
self.scaling_factor = scaling_factor
3334
self.extrapolation_factor = extrapolation_factor
3435
self.attn_factor = attn_factor
3536
self.beta_fast = beta_fast
3637
self.beta_slow = beta_slow
38+
self.truncate = truncate
3739
# Get n-d magnitude scaling corrected for interpolation
3840
self.mscale = (
3941
float(yarn_get_mscale(self.scaling_factor) * attn_factor)
@@ -57,6 +59,7 @@ def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
5759
self.rotary_dim,
5860
self.base,
5961
self.max_position_embeddings,
62+
self.truncate,
6063
)
6164
# Get n-d rotational scaling corrected for extrapolation
6265
inv_freq_mask = (

vllm/model_executor/models/gpt_oss.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def __init__(
7878
],
7979
"beta_fast": config.rope_parameters["beta_fast"],
8080
"beta_slow": config.rope_parameters["beta_slow"],
81+
"truncate": config.rope_parameters.get("truncate", True),
8182
},
8283
is_neox_style=True,
8384
)

0 commit comments

Comments
 (0)