From 3b9d55fced239feeb43ec50ddf18f4e051a6d8a4 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Thu, 6 Nov 2025 11:49:17 -0800 Subject: [PATCH 1/6] add truncate arg to yarn to match openai implementation of gpt-oss Signed-off-by: ashors1 --- .../layers/rotary_embedding/__init__.py | 1 + .../model_executor/layers/rotary_embedding/common.py | 12 ++++++------ .../layers/rotary_embedding/yarn_scaling_rope.py | 3 +++ vllm/model_executor/models/gpt_oss.py | 1 + 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/rotary_embedding/__init__.py b/vllm/model_executor/layers/rotary_embedding/__init__.py index 56c165f9c041..47131d79ef18 100644 --- a/vllm/model_executor/layers/rotary_embedding/__init__.py +++ b/vllm/model_executor/layers/rotary_embedding/__init__.py @@ -197,6 +197,7 @@ def get_rope( "beta_fast", "beta_slow", "apply_yarn_scaling", + "truncate", ) } if "mrope_section" in rope_scaling: diff --git a/vllm/model_executor/layers/rotary_embedding/common.py b/vllm/model_executor/layers/rotary_embedding/common.py index 196533b61795..55d45cd8d466 100644 --- a/vllm/model_executor/layers/rotary_embedding/common.py +++ b/vllm/model_executor/layers/rotary_embedding/common.py @@ -117,13 +117,13 @@ def yarn_find_correction_range( dim: int, base: float = 10000, max_position_embeddings: int = 2048, + truncate: bool = True, ) -> tuple[int, int]: - low = math.floor( - yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) - ) - high = math.ceil( - yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) - ) + low = yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + high = yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + if truncate: + low = math.floor(low) + high = math.ceil(high) return max(low, 0), min(high, dim - 1) # Clamp values just in case diff --git a/vllm/model_executor/layers/rotary_embedding/yarn_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/yarn_scaling_rope.py index ff46ad74b302..f01ca1e23121 100644 --- a/vllm/model_executor/layers/rotary_embedding/yarn_scaling_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/yarn_scaling_rope.py @@ -28,12 +28,14 @@ def __init__( beta_fast: int = 32, beta_slow: int = 1, apply_yarn_scaling: bool = True, + truncate: bool = True, ) -> None: self.scaling_factor = scaling_factor self.extrapolation_factor = extrapolation_factor self.attn_factor = attn_factor self.beta_fast = beta_fast self.beta_slow = beta_slow + self.truncate = truncate # Get n-d magnitude scaling corrected for interpolation self.mscale = ( float(yarn_get_mscale(self.scaling_factor) * attn_factor) @@ -57,6 +59,7 @@ def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: self.rotary_dim, self.base, self.max_position_embeddings, + self.truncate, ) # Get n-d rotational scaling corrected for extrapolation inv_freq_mask = ( diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index 04038ae74882..35ed94c0d268 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -77,6 +77,7 @@ def __init__( ], "beta_fast": config.rope_scaling["beta_fast"], "beta_slow": config.rope_scaling["beta_slow"], + "truncate": config.rope_scaling["truncate"], }, is_neox_style=True, ) From 9b6bbd9c046e5df89603509b13ce5d749024a233 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Thu, 6 Nov 2025 13:30:26 -0800 Subject: [PATCH 2/6] address comments Signed-off-by: ashors1 --- vllm/model_executor/layers/rotary_embedding/common.py | 2 +- vllm/model_executor/models/gpt_oss.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/rotary_embedding/common.py b/vllm/model_executor/layers/rotary_embedding/common.py index 55d45cd8d466..ed939c6a63e8 100644 --- a/vllm/model_executor/layers/rotary_embedding/common.py +++ b/vllm/model_executor/layers/rotary_embedding/common.py @@ -118,7 +118,7 @@ def yarn_find_correction_range( base: float = 10000, max_position_embeddings: int = 2048, truncate: bool = True, -) -> tuple[int, int]: +) -> tuple[float, float]: low = yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) high = yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) if truncate: diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index 35ed94c0d268..db25f427fce0 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -77,7 +77,7 @@ def __init__( ], "beta_fast": config.rope_scaling["beta_fast"], "beta_slow": config.rope_scaling["beta_slow"], - "truncate": config.rope_scaling["truncate"], + "truncate": config.rope_scaling.get("truncate", True), }, is_neox_style=True, ) From 71e1dfce910d7733cb54219640b873e166c9b645 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 12 Nov 2025 13:19:35 -0800 Subject: [PATCH 3/6] fix gpt oss EP with bf16 Signed-off-by: ashors1 --- vllm/model_executor/models/gpt_oss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index db25f427fce0..8e986845922b 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -489,8 +489,8 @@ def _load_weights_mxfp4( def _load_weights_other( self, - ep_rank_start: int, ep_rank_end: int, + ep_rank_start: int, heads_per_rank: int, head_start: int, weights: Iterable[tuple[str, torch.Tensor]], From df82ee9af5b766600a2d0a87aa92f71df5742921 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 12 Nov 2025 13:22:58 -0800 Subject: [PATCH 4/6] fix signature Signed-off-by: ashors1 --- vllm/model_executor/layers/rotary_embedding/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/rotary_embedding/common.py b/vllm/model_executor/layers/rotary_embedding/common.py index ed939c6a63e8..946c390f568a 100644 --- a/vllm/model_executor/layers/rotary_embedding/common.py +++ b/vllm/model_executor/layers/rotary_embedding/common.py @@ -118,7 +118,7 @@ def yarn_find_correction_range( base: float = 10000, max_position_embeddings: int = 2048, truncate: bool = True, -) -> tuple[float, float]: +) -> tuple[float|int, float|int]: low = yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) high = yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) if truncate: From a0ac7e8ad0b1ad251d57427dfe6f4f55937a3fd6 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 12 Nov 2025 14:16:46 -0800 Subject: [PATCH 5/6] lint Signed-off-by: ashors1 --- vllm/model_executor/layers/rotary_embedding/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/rotary_embedding/common.py b/vllm/model_executor/layers/rotary_embedding/common.py index 946c390f568a..13f8d15cc0f7 100644 --- a/vllm/model_executor/layers/rotary_embedding/common.py +++ b/vllm/model_executor/layers/rotary_embedding/common.py @@ -118,7 +118,7 @@ def yarn_find_correction_range( base: float = 10000, max_position_embeddings: int = 2048, truncate: bool = True, -) -> tuple[float|int, float|int]: +) -> tuple[float | int, float | int]: low = yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) high = yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) if truncate: From 15db28175cd1fc1d9f4250c9d7855e0439c69009 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Fri, 14 Nov 2025 16:49:59 -0800 Subject: [PATCH 6/6] Revert "fix gpt oss EP with bf16" This reverts commit e470c8f91d1491643dd5a0165a17135339f6b0d3. Signed-off-by: ashors1 --- vllm/model_executor/models/gpt_oss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index 8e986845922b..db25f427fce0 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -489,8 +489,8 @@ def _load_weights_mxfp4( def _load_weights_other( self, - ep_rank_end: int, ep_rank_start: int, + ep_rank_end: int, heads_per_rank: int, head_start: int, weights: Iterable[tuple[str, torch.Tensor]],