
Conversation

Contributor

@ashors1 ashors1 commented Nov 6, 2025

Purpose

Refer to the issue for context: #27722. vLLM's implementation of YaRN does not match OpenAI's for GPT-OSS. This PR provides a fix.

Test Plan

I tested this change on GPT-OSS and validated that the YaRN correction range is as expected.

Test Result

The YaRN correction range is no longer rounded to integers after this fix:

low=8.092779115512402, high=17.39802450158856
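For reference, these bounds match a direct evaluation of the standard YaRN correction-range formula with truncation disabled. The sketch below is illustrative rather than the actual vLLM code; the GPT-OSS-style parameters (beta_fast=32, beta_slow=1, head_dim=64, rope_theta=150000, original context length 4096) are assumptions chosen to reproduce the values above.

```python
import math

def yarn_find_correction_dim(num_rotations: float, dim: int, base: float,
                             max_position_embeddings: int) -> float:
    # Dimension index at which `num_rotations` full rotations fit
    # inside the original context window (inverse wavelength formula).
    return (dim * math.log(max_position_embeddings /
                           (num_rotations * 2 * math.pi))) / (2 * math.log(base))

def yarn_find_correction_range(low_rot: float, high_rot: float, dim: int,
                               base: float, max_position_embeddings: int,
                               truncate: bool = True) -> tuple[float, float]:
    low = yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
    high = yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
    if truncate:
        # Pre-fix behavior: round the bounds to integers.
        low, high = math.floor(low), math.ceil(high)
    return max(low, 0), min(high, dim - 1)

# Assumed GPT-OSS-style settings: beta_fast=32, beta_slow=1, head_dim=64,
# rope_theta=150000, original context length 4096.
low, high = yarn_find_correction_range(32, 1, 64, 150000, 4096, truncate=False)
print(f"low={low}, high={high}")  # ≈ low=8.0928, high=17.3980
```

With truncate=True the same call yields (8, 18), the rounding behavior this PR makes optional.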

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@github-actions

github-actions bot commented Nov 6, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, which executes a small and essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a truncate argument to the YaRN scaling implementation to align with OpenAI's GPT-OSS, propagating the change from the model configuration down to the rotary embedding calculation. My review identifies a couple of areas for improvement: an incorrect type hint that could affect static analysis, and a potential KeyError that could impact model loading with older configurations. The proposed suggestions aim to improve correctness and robustness.

base: float = 10000,
max_position_embeddings: int = 2048,
truncate: bool = True,
) -> tuple[int, int]:

Severity: high

The function's return type annotation tuple[int, int] is now incorrect. When truncate is False, the function returns a tuple of floats because yarn_find_correction_dim returns a float and no truncation is applied. This can lead to issues with static type checkers. To ensure type consistency for both truncate=True and truncate=False scenarios, the return type should be tuple[float, float]. In Python's type system, int values are compatible where float types are expected, making tuple[float, float] the correct annotation for both return paths.

Suggested change
) -> tuple[int, int]:
) -> tuple[float, float]:
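As a toy check of the numeric-tower point (the function below is illustrative, not the vLLM code): a function annotated tuple[float, float] may return ints from the truncating branch, since int is acceptable where float is expected.

```python
import math

def correction_range(truncate: bool) -> tuple[float, float]:
    low, high = 8.0928, 17.398  # illustrative float bounds
    if truncate:
        # floor/ceil return int; int is valid where float is annotated.
        return math.floor(low), math.ceil(high)
    return low, high

print(correction_range(True))   # (8, 18)
print(correction_range(False))  # (8.0928, 17.398)
```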

],
"beta_fast": config.rope_scaling["beta_fast"],
"beta_slow": config.rope_scaling["beta_slow"],
"truncate": config.rope_scaling["truncate"],

Severity: high

Directly accessing config.rope_scaling["truncate"] will raise a KeyError if the key is not present in the configuration, which could happen with older model configs. This would cause model loading to fail. To improve robustness and maintain backward compatibility, it's safer to use the .get() method with a default value. Since this change is a bug fix to align with the GPT-OSS implementation (which should not truncate), a default of False is appropriate. This ensures that older configurations without this key will adopt the correct behavior.

Suggested change
"truncate": config.rope_scaling["truncate"],
"truncate": config.rope_scaling.get("truncate", False),


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines 77 to 82
],
"beta_fast": config.rope_scaling["beta_fast"],
"beta_slow": config.rope_scaling["beta_slow"],
"truncate": config.rope_scaling["truncate"],
},
is_neox_style=True,


P1: Guard against missing rope_scaling.truncate

The new YaRN path now unconditionally reads config.rope_scaling["truncate"] when constructing the rotary embedding. Older GPT‑OSS configs (including those in prior releases) do not carry this key because truncation used to be implicit. In that case, model initialization will raise KeyError before any generation runs, whereas before the change the model still worked (albeit with rounded correction bounds). Consider using config.rope_scaling.get("truncate", True) or another default so that existing configs continue to load while newer configs can opt out of truncation.


@heheda12345
Collaborator

Can you fix the pre-commit failure and run the accuracy benchmark following the tutorial here? https://docs.vllm.ai/projects/recipes/en/latest/OpenAI/GPT-OSS.html#accuracy-evaluation-panels
At least the 20B model with low reasoning effort.

@ashors1
Contributor Author

ashors1 commented Nov 6, 2025

@heheda12345 I fixed the pre-commit failure and am happy to fix any other issues that arise. Would it be possible to get help running the accuracy benchmark? I typically run vLLM through another framework, so running the standalone accuracy benchmark would require some setup on my end.

@heheda12345
Collaborator

Thanks! You can follow the instructions in the above link.

@ashors1
Contributor Author

ashors1 commented Nov 7, 2025

@heheda12345 I am trying to get this working, but I'm having trouble running the vLLM server on our cluster. This might require some ramp-up on my end. If you are able to help out with running these evals, it would be greatly appreciated.

@heheda12345
Collaborator

The GPQA eval looks good to me on H100 (20B; low reasoning effort: 0.56, medium reasoning effort: 0.66)

OPENAI_API_KEY=empty python -m gpt_oss.evals --model openai/gpt-oss-20b --eval gpqa --n-threads 128 --reasoning-effort low,medium
Running with args Namespace(model='openai/gpt-oss-20b', reasoning_effort='low,medium', sampler='responses', base_url='http://localhost:8000/v1', eval='gpqa', temperature=1.0, n_threads=128, debug=False, examples=None)

Running the following evals: {'gpqa': <gpt_oss.evals.gpqa_eval.GPQAEval object at 0x7f9494660260>}
Running evals for the following models: {'openai/gpt-oss-20b-low': <gpt_oss.evals.responses_sampler.ResponsesSampler object at 0x7f949476f9b0>, 'openai/gpt-oss-20b-medium': <gpt_oss.evals.responses_sampler.ResponsesSampler object at 0x7f9494a27440>}
100%|████████████████████████| 1584/1584 [02:10<00:00, 12.12it/s]
Writing report to /tmp/gpqa_openai__gpt-oss-20b-low_temp1.0_20251111_235935.html
{'chars': np.float64(74.37436868686869), 'chars:std': np.float64(277.1982035750917), 'score': np.float64(0.5618686868686869), 'score:std': np.float64(0.4961575007849265)}
Writing results to /tmp/gpqa_openai__gpt-oss-20b-low_temp1.0_20251111_235935.json
Writing all results to /tmp/gpqa_openai__gpt-oss-20b-low_temp1.0_20251111_235935_allresults.json
 30%|██████▉                | 480/1584 [19:58<1:02:52,  3.42s/it]Bad Request Error Error code: 400 - {'error': {'message': "'python'", 'type': 'BadRequestError', 'param': None, 'code': 400}}
100%|██████████████████████| 1584/1584 [1:49:15<00:00,  4.14s/it]
Writing report to /tmp/gpqa_openai__gpt-oss-20b-medium_temp1.0_20251111_235935.html
{'chars': np.float64(64.36679292929293), 'chars:std': np.float64(239.31678570995177), 'score': np.float64(0.6641414141414141), 'score:std': np.float64(0.4722897375167671)}
Writing results to /tmp/gpqa_openai__gpt-oss-20b-medium_temp1.0_20251111_235935.json
Writing all results to /tmp/gpqa_openai__gpt-oss-20b-medium_temp1.0_20251111_235935_allresults.json
[{'eval_name': 'gpqa', 'model_name': 'openai__gpt-oss-20b-low_temp1.0_20251111_235935', 'metric': 0.5618686868686869}, {'eval_name': 'gpqa', 'model_name': 'openai__gpt-oss-20b-medium_temp1.0_20251111_235935', 'metric': 0.6641414141414141}]

heheda12345
heheda12345 previously approved these changes Nov 12, 2025
@github-project-automation github-project-automation bot moved this from To Triage to Ready in gpt-oss Issues & Enhancements Nov 12, 2025
@heheda12345 heheda12345 dismissed their stale review November 12, 2025 18:54

wait for the signature change

@ashors1
Contributor Author

ashors1 commented Nov 12, 2025

@heheda12345 Just FYI, I also found a small issue with BF16 + EP (fixed in my latest commit). ep_rank_start and ep_rank_end were being passed in the wrong order when loading the weights.

@ashors1
Contributor Author

ashors1 commented Nov 14, 2025

@heheda12345 are there any action items required from me before merge?

@heheda12345
Collaborator

Can you revert the EP bug fix and put it in another new PR?

@ashors1
Contributor Author

ashors1 commented Nov 15, 2025

Done. Here's the new PR: #28765

@heheda12345
Collaborator

Can you fix the DCO?

Signed-off-by: ashors1 <ashors@nvidia.com>
Signed-off-by: ashors1 <ashors@nvidia.com>
Signed-off-by: ashors1 <ashors@nvidia.com>
Signed-off-by: ashors1 <ashors@nvidia.com>
This reverts commit e470c8f.

Signed-off-by: ashors1 <ashors@nvidia.com>
@ashors1 ashors1 force-pushed the ashors/gpt-oss-rope branch from cf1cb1c to 15db281 Compare November 15, 2025 01:26
@ashors1
Contributor Author

ashors1 commented Nov 15, 2025

@heheda12345 Done

Collaborator

@heheda12345 heheda12345 left a comment


LGTM!

@heheda12345 heheda12345 enabled auto-merge (squash) November 16, 2025 07:56
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 16, 2025
@ashors1
Copy link
Contributor Author

ashors1 commented Nov 18, 2025

@heheda12345 CI is failing, but the failures seem unrelated to my changes. Is there anything I can do to fix them?

@mergify

mergify bot commented Nov 19, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ashors1.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 19, 2025
@heheda12345
Copy link
Collaborator

@ashors1 can you help to rebase the PR? And please ping me again if the CI finishes and there are some failures that you feel unrelated.

auto-merge was automatically disabled November 20, 2025 04:03

Head branch was pushed to by a user without write access

@mergify mergify bot removed the needs-rebase label Nov 20, 2025
@DarkLight1337 DarkLight1337 merged commit 6eb745d into vllm-project:main Nov 20, 2025
53 checks passed
LuminolT pushed a commit to LuminolT/vllm that referenced this pull request Nov 21, 2025
…llm-project#28244)

Signed-off-by: ashors1 <ashors@nvidia.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: LuminolT <lumischen01@gmail.com>
lpapavassiliou pushed a commit to lpapavassiliou/vllm that referenced this pull request Nov 24, 2025
…llm-project#28244)

Signed-off-by: ashors1 <ashors@nvidia.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
RunkaiTao pushed a commit to RunkaiTao/vllm that referenced this pull request Nov 24, 2025
…llm-project#28244)

Signed-off-by: ashors1 <ashors@nvidia.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: Runkai Tao <rt572@physics.rutgers.edu>
bringlein pushed a commit to bringlein/vllm that referenced this pull request Nov 26, 2025
…llm-project#28244)

Signed-off-by: ashors1 <ashors@nvidia.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025
…llm-project#28244)

Signed-off-by: ashors1 <ashors@nvidia.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
kitaekatt pushed a commit to kitaekatt/vllm that referenced this pull request Dec 1, 2025
…llm-project#28244)

Signed-off-by: ashors1 <ashors@nvidia.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>

Labels

gpt-oss: Related to GPT-OSS models
ready: ONLY add when PR is ready to merge/full CI is needed

Projects

Status: Done

3 participants