
Conversation

@laithsakka (Contributor) commented Oct 3, 2025

Purpose

Dynamic shape guards are dropped unsoundly in vLLM; such guards can be added during both Dynamo and Inductor compilation. The proper way to compile a sound graph is to use unbacked dynamic shapes, or possibly "backed size oblivious" (see the note in the diff content). This PR adds a config that allows choosing between backed, unbacked, and backed_size_oblivious, as explained in the comment in the PR content.

When using unbacked, users might want to provide invariants about the model's input shapes. A lambda argument
is added to the support_torch_compile decorator where users can provide a function that asserts invariants
on the model input shapes. These are needed to avoid data-dependent errors (DDE) and to be able to trace the model with unbacked shapes.
Example:


```python
def llama_model_invariants(input_ids,
                           positions,
                           intermediate_tensors=None,
                           inputs_embeds=None):
    """Shape invariants for Llama model compilation."""
    if input_ids is not None:
        # positions and input_ids always share the same leading
        # (num_tokens) dimension.
        torch._check(positions.size()[0] == input_ids.size()[0])


@support_torch_compile(shape_invariants=llama_model_invariants)
..
```
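
Inside the invariants function, torch._check records the asserted relation in the symbolic shape environment, so later tracing steps that depend on it (here, that positions and input_ids share their leading dimension) can be resolved without raising a data-dependent error.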

See the Qwen example in the code as well.

Test Plan

Added a unit test. It only works on torch 2.10+.

Perf testing

Model: Qwen/Qwen2-1.5B-Instruct

I will look into some of these perf issues in the future, but the long-term vision with pre_compile is that this will be a fallback mode. The results are still better than eager anyway.

Command:

```bash
CUDA_VISIBLE_DEVICES=1 vllm bench throughput --model Qwen/Qwen2-1.5B-Instruct --input-len 512 --output-len 128 --num-prompts 100 --gpu-memory-utilization 0.8
```

| Mode | Requests/s | Total tokens/s | Output tokens/s |
|---|---|---|---|
| backed | 88.94 | 56738.83 | 11383.87 |
| backed size oblivious | 88.78 | n/a | n/a |
| unbacked | 82.98 | 51860.18 | 10405.04 |
| eager | 63.29 | 40376.23 | 8100.94 |

@mergify mergify bot added llama Related to Llama models qwen Related to Qwen models v1 tpu Related to Google TPUs labels Oct 3, 2025

mergify bot commented Oct 3, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @laithsakka.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 3, 2025
@laithsakka changed the title from "Hu" to "Add option to use unbacked dynamic shapes for more sounds compilation." Oct 3, 2025
@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces a significant and valuable refactoring of the torch.compile integration, particularly around handling dynamic shapes. The new TorchCompileGuardsStripWrapper and the introduction of shape invariants provide a much cleaner and more robust approach to compilation. The changes are well-structured and the new tests are a great addition.

I've found a few issues that need to be addressed: a critical bug in the new DynamicShapesConfig that would cause a runtime error, a minor bug in hash computation, and an opportunity to strengthen the new dynamic shapes test.

@laithsakka changed the title from "Add option to use unbacked dynamic shapes for more sounds compilation." to "Add option to use unbacked, and backed size obl dynamic shapes for more sounds compilation." Oct 3, 2025
@hmellor (Member) left a comment

The conflicts are caused by our migration to ruff. Please see https://vllm-dev.slack.com/archives/C07R5Q1Q2BB/p1759663228844749 which contains detailed instructions to make updating your branch as painless as possible.

@ProExpertProg ProExpertProg removed the ready ONLY add when PR is ready to merge/full CI is needed label Nov 20, 2025
@ProExpertProg (Collaborator) commented

@laithsakka removed ready, ping when you want CI back on

Review thread on this documentation line:

> These modes are stricter and reduce or eliminate guarding, which can help isolate issues:
A Collaborator commented:

There is some confusion between dynamic shape guards and Dynamo guards in general. There is also some confusion because vLLM drops all Dynamo guards -- so why do these guards still matter?

The author replied:

OK, let me try phrasing it differently. Is there any specific phrasing you'd like there?

@laithsakka (Contributor) commented Nov 20, 2025

@ProExpertProg
Why was ready removed? I addressed all comments this morning.
I also kind of need the ready label to know whether things break.

@ProExpertProg (Collaborator) commented

Sorry, I thought you were still pushing. But the CI already ran anyway.

@laithsakka (Contributor) commented

Just addressed Richard's comment on the debug section. @zou3519 looking good?

@laithsakka laithsakka requested a review from zou3519 November 20, 2025 23:47
@laithsakka laithsakka force-pushed the hu branch 4 times, most recently from 898d8b0 to 0e849f0 Compare November 21, 2025 00:07
Review thread on this documentation passage (a sketch of the custom-operator technique follows below):

> 2. wrap the branching logic into a custom operator. TorchDynamo does not
>    trace into custom operators.
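
As an illustration of point 2, here is a minimal sketch (not code from this PR; the operator name mylib::masked_scale and its branch logic are hypothetical) of hiding a shape-dependent branch inside a custom operator so that tracing never sees it:

```python
import torch

@torch.library.custom_op("mylib::masked_scale", mutates_args=())
def masked_scale(x: torch.Tensor) -> torch.Tensor:
    # This shape-dependent branch would raise a data-dependent error if
    # traced with unbacked symbols; inside a custom op, Dynamo treats the
    # call as opaque and never traces the body.
    if x.shape[0] % 2 == 0:
        return x * 2.0
    return x * 0.5

@masked_scale.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    # Meta implementation: both branches produce the same shape and dtype,
    # so no symbolic-size branching is needed here.
    return torch.empty_like(x)
```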

> ## Debugging constraint violations and dynamic shapes guards issues
@zou3519 (Collaborator) commented Nov 21, 2025

This is a bit of a repeat of the previous section. Let's resolve it in a future PR; figuring out what to write is a bit annoying.

@laithsakka (Author) replied Nov 21, 2025

Ah, I see. I agree we should probably merge the two sections:

> ## Debugging Dynamic Shape full graph capture

@zou3519 zou3519 added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 21, 2025
@laithsakka (Contributor) commented

@ProExpertProg the failing job is a timeout in building the docs; is there a way to retry?

@ProExpertProg (Collaborator) commented

@hmellor

Signed-off-by: Laith Sakka <lsakka@meta.com>
@laithsakka (Contributor) commented

Dummy update to force running all tests.

Comment on lines +223 to +227:

> - BACKED: Default PyTorch behavior with potential guards ignored.
> - UNBACKED: No guards guaranteed (most sound) but may throw
>   data dependent errors.
> - BACKED_SIZE_OBLIVIOUS: Experimental safer alternative to
>   backed/unbacked.
@hmellor (Member) commented Nov 24, 2025

Tiny nit to get the help text to render nicely. Should be done in a follow-up to save CI.

Suggested change (append `\n` so each bullet renders on its own line):

```diff
-- BACKED: Default PyTorch behavior with potential guards ignored.
+- BACKED: Default PyTorch behavior with potential guards ignored.\n
 - UNBACKED: No guards guaranteed (most sound) but may throw
-data dependent errors.
+data dependent errors.\n
 - BACKED_SIZE_OBLIVIOUS: Experimental safer alternative to
 backed/unbacked.
```


@hmellor (Member) commented Nov 24, 2025

Docs failure/timeout was likely related to the Python docs (which our docs reference) being down. It appears to have resolved itself.

@zou3519 zou3519 merged commit 7a228b5 into vllm-project:main Nov 24, 2025
55 checks passed
@github-project-automation github-project-automation bot moved this from In review to Done in torch.compile integration Nov 24, 2025
lpapavassiliou pushed a commit to lpapavassiliou/vllm that referenced this pull request Nov 24, 2025
…re sounds compilation. (vllm-project#26199)

Signed-off-by: Laith Sakka <lsakka@meta.com>
RunkaiTao pushed a commit to RunkaiTao/vllm that referenced this pull request Nov 24, 2025
…re sounds compilation. (vllm-project#26199)

Signed-off-by: Laith Sakka <lsakka@meta.com>
Signed-off-by: Runkai Tao <rt572@physics.rutgers.edu>
MatthewBonanni pushed a commit to MatthewBonanni/vllm that referenced this pull request Nov 24, 2025
…re sounds compilation. (vllm-project#26199)

Signed-off-by: Laith Sakka <lsakka@meta.com>
bringlein pushed a commit to bringlein/vllm that referenced this pull request Nov 26, 2025
…re sounds compilation. (vllm-project#26199)

Signed-off-by: Laith Sakka <lsakka@meta.com>

Labels: ci/build, documentation, llama, qwen, ready, torch.compile, v1

Projects: torch.compile integration (Status: Done)

6 participants