Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions vllm/model_executor/layers/rotary_embedding/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ def get_rope(
 "beta_fast",
 "beta_slow",
 "apply_yarn_scaling",
+"truncate",
 )
 }
 if "mrope_section" in rope_scaling:
Expand Down
14 changes: 7 additions & 7 deletions vllm/model_executor/layers/rotary_embedding/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,13 @@ def yarn_find_correction_range(
     dim: int,
     base: float = 10000,
     max_position_embeddings: int = 2048,
-) -> tuple[int, int]:
-    low = math.floor(
-        yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
-    )
-    high = math.ceil(
-        yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
-    )
+    truncate: bool = True,
+) -> tuple[float | int, float | int]:
+    low = yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
+    high = yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
+    if truncate:
+        low = math.floor(low)
+        high = math.ceil(high)
     return max(low, 0), min(high, dim - 1)  # Clamp values just in case


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@ def __init__(
         beta_fast: int = 32,
         beta_slow: int = 1,
         apply_yarn_scaling: bool = True,
+        truncate: bool = True,
     ) -> None:
         self.scaling_factor = scaling_factor
         self.extrapolation_factor = extrapolation_factor
         self.attn_factor = attn_factor
         self.beta_fast = beta_fast
         self.beta_slow = beta_slow
+        self.truncate = truncate
         # Get n-d magnitude scaling corrected for interpolation
         self.mscale = (
             float(yarn_get_mscale(self.scaling_factor) * attn_factor)
Expand All @@ -57,6 +59,7 @@ def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
             self.rotary_dim,
             self.base,
             self.max_position_embeddings,
+            self.truncate,
         )
         # Get n-d rotational scaling corrected for extrapolation
         inv_freq_mask = (
Expand Down
1 change: 1 addition & 0 deletions vllm/model_executor/models/gpt_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def __init__(
 ],
 "beta_fast": config.rope_scaling["beta_fast"],
 "beta_slow": config.rope_scaling["beta_slow"],
+"truncate": config.rope_scaling.get("truncate", True),
 },
 is_neox_style=True,
 )
Expand Down