Prefill-related logic in input preparation for generation #42088
base: main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Another worm of cans, assisted decoding has no prefill separated out and is causing issues now 😢
worm of cans?? 🤣 haha love it
Sooo this already arose on my PR. The main gist is that assisted generate does not prefill with the prompt tokens, but waits for the first batch of candidates and then prefills. Thus, we could not apply the standard prefill. But surely assisted_gen can pass the prefill flag on the first call, or we could also call _prefill with the first batch of candidates.
yeah, this seemed to be the easiest option. The only issue with VLMs is that we should not be passing certain inputs (pixels/etc.) after the prefill phase. But with the assistant model calling
Support for
I don't want us to multiply the number of input args for
@bot /style
Style fix runs successfully without any file modified. |
if is_first_iteration is None:
    generation_args = self.assistant_model._get_initial_cache_position(
        input_ids.shape[1], input_ids.device, self.assistant_kwargs
    )
    generation_args = self.assistant_model.prepare_inputs_for_generation(
        input_ids, is_first_iteration=True, **generation_args
    )
    generation_args[self.input_ids_key] = input_ids
    for model_input_name in ["position_ids", "token_type_ids", "decoder_position_ids"]:
        generation_args.pop(model_input_name, None)
else:
this is needed for specific models which prepare inputs differently depending on first vs subsequent iterations. For example, in multimodal models we pass the multimodal data only in the first iteration and then rely on cached inputs.
Assisted generation however calls generate() internally many times and would technically trigger the first-iteration path many times. This way we can call prefill only once per assistant model
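For illustration, a minimal sketch (not the actual transformers code; all names are made up) of the pattern described above: a hypothetical VLM forwards pixel values only on the first iteration and relies on the KV cache afterwards.

from typing import Any


class ToyVLM:
    def prepare_inputs_for_generation(
        self,
        input_ids,
        pixel_values=None,
        is_first_iteration: bool = False,
        **kwargs: Any,
    ) -> dict:
        model_inputs = {"input_ids": input_ids, **kwargs}
        # Multimodal data is only needed while prefilling; afterwards the image
        # key/value states already live in the cache.
        if is_first_iteration:
            model_inputs["pixel_values"] = pixel_values
        return model_inputs


# First call forwards pixels, later calls do not.
toy = ToyVLM()
print("pixel_values" in toy.prepare_inputs_for_generation([1, 2], pixel_values="img", is_first_iteration=True))    # True
print("pixel_values" in toy.prepare_inputs_for_generation([1, 2, 3], pixel_values="img", is_first_iteration=False))  # False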
Can we also add this to the comments? This is nice to know. Possibly into the docstring directly; I think the scope is worth covering properly.
Seeing this is explained in utils directly, maybe my order of reviewing was just bad then... We can keep it this way.
# Assisted generation completes the prefill stage in candidate generator so that
# we don't have several `prefill` calls in one generation loop. Skip `_prefill` for assistants
if not generation_config.is_assistant:
    model_outputs = self._prefill(input_ids, generation_config, model_kwargs)
    prefill_consumed = False
else:
    model_kwargs = self._get_initial_cache_position(input_ids.shape[1], input_ids.device, model_kwargs)
    prefill_consumed = True
same as above: since we already called prefill on the assistant, we should not call it a second time
LGTM overall, the logic is good
I mainly left some smaller comments and things that might've been missed. We should be a bit careful here and run-slow on a few models, e.g. gemma3, mamba2, etc
# It is safe to assume that `length!=1` means we're in pre-fill because compiled
# models currently cannot do assisted decoding
- if cache_position[0] == 0 or self.model.rope_deltas is None:
+ if (cache_position[0] == 0 or not use_cache) or self.model.rope_deltas is None:
Can we not simplify here? Same for the other related models
I wanted to use is_first_iteration for all models at first, but realized the concepts of prefill and first iteration can be different.
Specifically in mRoPE, the deltas are computed once with the first prompt and should not be computed again if the user wants to reuse the cache and continue generation from where it left off.
This caused me so much headache tbh 🙃
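To make the distinction concrete, a rough toy sketch (assumptions only, mirroring the condition in the diff above, not the real Qwen code): the deltas are computed once at a true prefill and survive a continue-from-cache call, which is again a first iteration but not at cache position 0.

import torch


class ToyMRoPEModel:
    def __init__(self):
        self.rope_deltas = None  # computed once per prompt

    def get_position_ids(self, cache_position: torch.Tensor, use_cache: bool = True) -> torch.Tensor:
        # Recompute only at a true prefill (position 0 or no cache) or if the
        # deltas were never computed; otherwise reuse the cached value.
        if (cache_position[0] == 0 or not use_cache) or self.rope_deltas is None:
            self.rope_deltas = torch.zeros(1)  # placeholder for the real mRoPE delta math
        return cache_position + self.rope_deltas


model = ToyMRoPEModel()
model.get_position_ids(torch.arange(0, 5))    # prefill: deltas computed
model.get_position_ids(torch.arange(5, 6))    # decoding: deltas reused
model.get_position_ids(torch.arange(6, 10))   # continue-from-cache: still reused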
Makes sense, mRoPE strikes again 😢 Can't wait until we have a standardized way to handle this outside of modeling.py
- expected_slice = torch.tensor([-0.8805, -0.8803, -0.8799], device=torch_device)
+ expected_slice = torch.tensor([-0.8433, -0.8432, -0.8429], device=torch_device)
Was this failing before?
nope, my small refactor changed the logits slightly due to a difference in how caching is done. I will trigger slow tests and check how big of a diff I caused
Okay, I see why there is a difference. Previously we would always compute attention for the image embeddings (around 900 tokens), even though they are cached. After this PR, caching works as expected and we have a single token at each decoding step.
IMO the current version is more correct, and it's expected that caching results in tiny numerical differences that can add up.
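For context, a hedged sketch of how such expected-value checks behave: a logits slice is compared to hardcoded values with a small tolerance that absorbs numerical noise but not the caching-related shift in this PR (the numbers below are taken from the diff above, nothing new is generated).

import torch

old_expected = torch.tensor([-0.8805, -0.8803, -0.8799])
new_expected = torch.tensor([-0.8433, -0.8432, -0.8429])

# The shift caused by the caching change is far larger than the usual 1e-3
# tolerance, hence the expected values had to be updated rather than loosened.
print(torch.allclose(old_expected, new_expected, rtol=1e-3, atol=1e-3))  # False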
run-slow: git
This comment contains models: ["models/git"]
CI Results
Model CI Report: ❌ Failed tests
@vasqu requesting one last review. One tiny thing left to do is to make sure the slow GIT tests pass; some expected values are hardware-dependent. Otherwise this should be ready: addressed a few comments and answered the question above.
run-slow: git
This comment contains models: ["models/git"]
CI Results
Model CI Report: ❌ Failed tests
[For maintainers] Suggested jobs to run (before merge): run-slow: aria, aya_vision, bamba, bloom, chameleon, clvp, cohere2_vision, csm, ctrl, deepseek_vl, deepseek_vl_hybrid, emu3, falcon_h1, falcon_mamba, fast_vlm, florence2
LGTM overall, left smaller comments/nits but I think this looks pretty ready.
Trusting you on fixing up git and other CI stuff 👁️ Lots of potential for follow-up PRs to clean up more, but this is already big enough as is and solves the biggest issue(s)
# Generate candidates. Run prefill-specific logic in first generation and prepare model kwargs.
# Some models prepare inputs differently depending on first vs subsequent iterations.(e.g. VLMs)
# Assisted generation however calls internally `self.generate()` many times and technically will
# lead to many `ifirst_iteration's`. This way we can call prefill only once per assistant model
- # lead to many `ifirst_iteration's`. This way we can call prefill only once per assistant model
+ # lead to many `first_iteration's`. This way we can call prefill only once per assistant model
typo
# lead to many `ifirst_iteration's`. This way we can call prefill only once per assistant model
if is_first_iteration:
    generation_args = self.assistant_model._get_initial_cache_position(
        input_ids.shape[1], input_ids.device, self.assistant_kwargs.copy()
Is there a specific reason we copy the kwargs here? Any risk this could be None?
I suspect some in-place ops, but just checking
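For what it's worth, a toy example (not the actual candidate-generator code) of the in-place risk being hinted at: if the callee pops or overwrites entries of a kwargs dict that is reused across calls, a defensive `.copy()` keeps the original intact.

def prepare(kwargs: dict) -> dict:
    # Mutates the dict it was handed, e.g. dropping position ids.
    kwargs.pop("position_ids", None)
    return kwargs


assistant_kwargs = {"position_ids": [0, 1, 2], "use_cache": True}

prepare(assistant_kwargs.copy())                 # original left untouched
assert "position_ids" in assistant_kwargs

prepare(assistant_kwargs)                        # without the copy, the entry is gone
assert "position_ids" not in assistant_kwargs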
def _generate_candidates(self, generation_args: dict) -> tuple[torch.LongTensor, torch.FloatTensor | None]:
    """Generate candidate sequences using the assistant model."""
-   assistant_output = self.assistant_model.generate(**generation_args, **self.assistant_kwargs)
+   assistant_output = self.assistant_model.generate(**generation_args)
So we now directly write into the generation args instead (from the prep)
attention_mask: torch.LongTensor | None = None,
inputs_embeds: torch.FloatTensor | None = None,
cache_position: torch.LongTensor | None = None,
is_first_iteration: Optional[bool] = False,
- is_first_iteration: Optional[bool] = False,
+ is_first_iteration: bool | None = False,
I also often forget this, but we really should move to the new typing
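Tiny illustration of what the suggestion switches to, assuming the usual PEP 604 style on Python 3.10+: built-in unions instead of `typing.Optional`.

from typing import Optional


def old_style(is_first_iteration: Optional[bool] = False) -> None: ...


def new_style(is_first_iteration: bool | None = False) -> None: ...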
Important model, so just double check with run-slow or similar
- pixel_values=kwargs.get("pixel_values"),
+ is_first_iteration,
Should we not keep pixel values here? I just fear that the arg order won't match anymore, so is_first_iteration would be used as pixel_values.
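A toy demonstration of that concern (hypothetical signature, not the actual model code): a value passed positionally after dropping a keyword argument can silently bind to the wrong parameter.

def forward(input_ids, pixel_values=None, is_first_iteration=False):
    return {"pixel_values": pixel_values, "is_first_iteration": is_first_iteration}


# Dropping `pixel_values=...` and passing `is_first_iteration` positionally
# makes it land in the `pixel_values` slot.
bound = forward([1, 2, 3], True)
assert bound["pixel_values"] is True
assert bound["is_first_iteration"] is False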
model_inputs = super().prepare_inputs_for_generation(*args, **kwargs)

- if cache_position is not None and cache_position[0] == 0:
+ if is_first_iteration or not kwargs.get("use_cache", True):
Is this `or` intentional? Not sure if this happened elsewhere
torch.testing.assert_close(transition_scores_sum, outputs.sequences_scores, rtol=1e-3, atol=1e-3)

@slow
def test_generate_inputs_embeds_one_token(self):
Nice 🙏
What does this PR do?
Fixes #41863 and fixes #40910
We have always had an imperfect way to infer whether we're in the prefill or decoding stage, which has caused many bugs in the past. The most reliable way is to check cache position values, but that is not compile-compatible and also has an edge case.
Recently Manuel merged a PR that split prefill into its own function, so now we can benefit from it and know with 100% certainty which stage we're in. This PR adds an `is_first_iteration` flag to generation input preparation and replaces the existing logic with the flag.
Note that in some models we have to keep checking `if cache_position[0] == 0`, because the first iteration does not mean the first in total. We might get a cached system prompt, and we don't want to call some methods a second time (e.g. Qwen mRoPE).
It also adds a test case for the issue linked above.
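To illustrate that edge case (illustrative numbers only): when generation continues from an already-filled cache, e.g. a cached system prompt, the first iteration of the new generate() call is not at cache position 0, so `is_first_iteration` and `cache_position[0] == 0` answer different questions.

import torch

past_length = 12       # tokens of a previously cached system prompt
new_prompt_length = 4  # fresh user tokens in this call

cache_position = torch.arange(past_length, past_length + new_prompt_length)
is_first_iteration = True  # first call of this generation loop

is_prefill_from_scratch = bool(cache_position[0] == 0)
print(is_first_iteration, is_prefill_from_scratch)  # True False -> the two concepts differ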