Support cross attention kv cache #187
base: main
Conversation
    self.cross_attention_cache = StaticCache(
        config=self.config,
        max_batch_size=batch_size,
        max_cache_len=getattr(self.config, "max_source_positions", max_static_cache_length),  # This is fixed in whisper
Pull this outside into a var like the other arguments
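A minimal sketch of that suggestion, mirroring the hunk above; the variable name `cross_attention_cache_len` is illustrative, not taken from the PR:

```python
# Illustrative refactor only: hoist the cache length into its own variable,
# like the other StaticCache arguments, instead of computing it inline.
cross_attention_cache_len = getattr(
    self.config, "max_source_positions", max_static_cache_length
)  # fixed at 1500 for Whisper checkpoints
self.cross_attention_cache = StaticCache(
    config=self.config,
    max_batch_size=batch_size,
    max_cache_len=cross_attention_cache_len,
)
```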
Also, what do you mean by "this is fixed in whisper"? Will this work for T5?
Basically they always have 1500 for `max_source_positions`, and that translates to 30 seconds of audio, so we should use that for the cache length. For T5 I don't know, and that's why I named this class `WhisperCrossAttention`.
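For context, the 1500 value follows from Whisper's fixed 30-second input window. A quick sanity check using Whisper's standard feature-extraction constants (general Whisper facts, not values taken from this PR):

```python
# Whisper log-mel features: 16 kHz audio with hop length 160 -> 100 frames/sec.
# The encoder's second conv layer has stride 2, halving the sequence length.
chunk_seconds = 30
sampling_rate = 16_000
hop_length = 160
conv_stride = 2

mel_frames = chunk_seconds * sampling_rate // hop_length  # 3000
encoder_positions = mel_frames // conv_stride             # 1500
print(encoder_positions)  # matches config.max_source_positions for Whisper
```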
    )

    # Update the KV cache outside of torch.cond.
    past_key_values.update(
Why not put this inside the `recompute_kv` branch?
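One plausible answer (my reading, not stated in the thread): `torch.cond` branches have to be pure functions of their operands for export, so an in-place cache write cannot live inside the `recompute_kv` branch; only the k/v computation can. A self-contained toy showing that pattern, with made-up names and a recent PyTorch assumed:

```python
import torch

def recompute(x, cached):
    # Stand-in for running k_proj/v_proj on the encoder output.
    return x * 2.0

def reuse(x, cached):
    # Stand-in for reading the previously cached k/v back out.
    return cached

x = torch.ones(4)
cache = torch.zeros(4)
first_step = torch.tensor(True)

out = torch.cond(first_step, recompute, reuse, (x, cache))
cache.copy_(out)  # the cache update happens outside torch.cond
print(cache)      # tensor([2., 2., 2., 2.])
```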
jackzhxng left a comment:
Oh, also run `make style` for formatting.
Force-pushed from b6e172d to 6ca7dd0.
Force-pushed from 14a9a0e to 128b100.

To avoid excessive computation we want to support a KV cache for cross attention in Whisper.

Fundamentally we only need to run `k_proj` and `v_proj` once on the encoder output hidden state, at the first token generation; after that we should keep the resulting `key_states` and `value_states` and reuse them for all subsequent token generations.

For whisper-large-v3-turbo, which has 4 decoder layers: without a KV cache in `encoder_attn`, we do 2 1280x1280 MMs per layer, so 8 1280x1280 MMs in total for every token generated. This significantly hurts the tokens/sec number.

This PR replaces `encoder_attn` with a `WhisperCrossAttention` class, where we replace the `if` condition with `torch.cond`; a sketch of the resulting logic is given below. Notice that we still pay 1 extra cache read and 1 extra cache write per step, but that should be much faster than the MMs.
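A sketch of what the `torch.cond`-based forward logic could look like. The class and projection names follow the description above, but the body is my own illustrative reconstruction under those assumptions, not the PR's actual code:

```python
import torch


class WhisperCrossAttentionSketch(torch.nn.Module):
    """Illustrative k/v caching for cross attention via torch.cond."""

    def __init__(self, embed_dim: int = 1280):
        super().__init__()
        # Whisper's attention uses no bias on k_proj.
        self.k_proj = torch.nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = torch.nn.Linear(embed_dim, embed_dim)

    def forward(self, encoder_hidden_states, cached_key, cached_value, is_first_step):
        def recompute_kv(hidden, k_cache, v_cache):
            # First generated token: project the encoder output once.
            return self.k_proj(hidden), self.v_proj(hidden)

        def reuse_kv(hidden, k_cache, v_cache):
            # Subsequent tokens: return the cached tensors (same shape/dtype
            # as the other branch, which torch.cond requires).
            return k_cache, v_cache

        key_states, value_states = torch.cond(
            is_first_step, recompute_kv, reuse_kv,
            (encoder_hidden_states, cached_key, cached_value),
        )

        # The cache write stays outside torch.cond because branches must be
        # side-effect free; this is the extra read/write mentioned above.
        cached_key.copy_(key_states)
        cached_value.copy_(value_states)

        # Attention over key_states/value_states would follow here.
        return key_states, value_states
```

Writing the cache on every step, rather than only on the first, keeps all mutation outside of `torch.cond`, which is why the extra read and write are accepted instead of guarding the update.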