@@ -11,7 +11,6 @@
 # specific language governing permissions and limitations under the License.
 
 import logging
-from collections.abc import Callable
 from typing import Optional
 
 import torch
@@ -24,13 +23,7 @@
     StaticCache,
 )
 from ..generation.configuration_utils import GenerationConfig
-from ..masking_utils import (
-    ALL_MASK_ATTENTION_FUNCTIONS,
-    _ignore_causal_mask_sdpa,
-    _is_torch_greater_or_equal_than_2_5,
-    prepare_padding_mask,
-)
-from ..modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ..modeling_utils import PreTrainedModel
 from ..pytorch_utils import (
     is_torch_greater_or_equal,
     is_torch_greater_or_equal_than_2_3,
@@ -229,10 +222,6 @@ def __init__(
                 "Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config."
             )
             self.model = TorchExportableModuleWithStaticCache(model, batch_size, max_cache_len, device)
-        # This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
-        ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
-        ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
-        self.model.model.config._attn_implementation = "sdpa_without_vmap"
 
     def forward(
         self,
@@ -768,11 +757,6 @@ def convert_and_export_with_cache(
 
     import torch.export._trace
 
-    # This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
-    ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
-    ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
-    model.config._attn_implementation = "sdpa_without_vmap"
-
     with torch.no_grad():
         # TODO: The default inputs only work for text models. We need to add support for vision/audio models.
         example_input_ids = (
@@ -1036,11 +1020,6 @@ def export_with_dynamic_cache(
     if not is_torch_greater_or_equal_than_2_3:
         raise ImportError("torch >= 2.3 is required.")
 
-    # This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
-    ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
-    ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
-    model.config._attn_implementation = "sdpa_without_vmap"
-
     register_dynamic_cache_export_support()
 
     with torch.no_grad():
@@ -1109,92 +1088,3 @@ def _unflatten_dynamic_cache(values, context: torch.utils._pytree.Context):
         value = value_list[idx] if idx < len(value_list) else None
         cache.update(key, value, idx)
     return cache
-
-
-def sdpa_mask_without_vmap(
-    batch_size: int,
-    cache_position: torch.Tensor,
-    kv_length: int,
-    kv_offset: int = 0,
-    mask_function: Optional[Callable] = None,
-    attention_mask: Optional[torch.Tensor] = None,
-    local_size: Optional[int] = None,
-    allow_is_causal_skip: bool = True,
-    allow_torch_fix: bool = True,
-    **kwargs,
-) -> Optional[torch.Tensor]:
-    """
-    Create a 4D boolean mask of shape `(batch_size, 1, query_length, kv_length)` where a value of True indicates that
-    the element should take part in the attention computation, and False that it should not.
-
-    This is similar to `masking_utils.sdpa_mask` but does not use `vmap` which is incompatible with export.
-
-    Args:
-        batch_size (`int`):
-            The batch size of the input sequence.
-        cache_position (`torch.Tensor`):
-            A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
-        kv_length (`int`):
-            The size that the key and value states will have during the attention computation.
-        kv_offset (`int`, optional):
-            An optional offset to indicate at which first position the key and values states will refer to.
-        mask_function (`Callable`):
-            The mask factory function describing the mask pattern.
-        attention_mask (`torch.Tensor`, optional):
-            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
-        local_size (`int`, optional):
-            The size of the local attention, if we do not use full attention. This is used only if `allow_is_causal_skip=True`
-            to try to skip mask creation if possible.
-        allow_is_causal_skip (`bool`, optional):
-            Whether to allow to return `None` for the mask under conditions where we can use the `is_causal` argument in
-            `torch.sdpa` instead. Default to `True`.
-        allow_torch_fix (`bool`, optional):
-            Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older
-            versions. We need an arg to skip it when using eager. By default `True`.
-
-    """
-
-    q_length = cache_position.shape[0]
-    # Potentially pad the 2D mask, and slice it correctly
-    padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset)
-
-    # Under specific conditions, we can avoid materializing the mask, instead relying on the `is_causal` argument
-    if allow_is_causal_skip and _ignore_causal_mask_sdpa(padding_mask, q_length, kv_length, local_size):
-        return None
-
-    # Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`
-    # but without data-dependent slicing (i.e. torch.compile friendly)
-    kv_arange = torch.arange(kv_length, device=cache_position.device)
-    kv_arange += kv_offset
-    reshaped_cache_position = cache_position.view(-1, 1)
-
-    # This is a bit hacky to know what pattern we are using, but all mask creation function actually forward
-    # the config through kwargs anyway, so it allows to rely on it
-    # Usually, the `mask_function` is the only entry-point to define the pattern - we could do for loops over it,
-    # but this is more efficient
-    sliding_window = getattr(kwargs["config"], "sliding_window", None)
-    chunk_size = getattr(kwargs["config"], "attention_chunk_size", None)
-
-    if sliding_window is not None and chunk_size is not None:
-        raise ValueError("Cannot use both `sliding_window` and `attention_chunk_size`")
-
-    # Simplest and most efficient way to obtain a causal mask
-    causal_mask = kv_arange <= reshaped_cache_position
-    # If using sliding window, add the sliding mask
-    if sliding_window is not None:
-        sliding_mask_overlay = kv_arange > reshaped_cache_position - sliding_window
-        causal_mask *= sliding_mask_overlay
-    # If using chunk attention, add the chunked mask
-    elif chunk_size is not None:
-        chunked_mask_overlay = kv_arange // chunk_size == reshaped_cache_position // chunk_size
-        causal_mask *= chunked_mask_overlay
-
-    causal_mask = causal_mask[None, None, :, :].expand(batch_size, -1, -1, -1)
-    if padding_mask is not None:
-        causal_mask = causal_mask * padding_mask[:, None, None, :]
-
-    # Due to a bug in some older torch version, we need to update the mask in case a query is not attending to any
-    # tokens (due to padding). See details in https://github.com/pytorch/pytorch/issues/110213
-    if not _is_torch_greater_or_equal_than_2_5 and allow_torch_fix:
-        causal_mask |= torch.all(~causal_mask, dim=-1, keepdim=True)
-    return causal_mask
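
For context on what the removed helper was doing: the core of `sdpa_mask_without_vmap` was an export-friendly causal mask built purely from broadcast comparisons, with no `vmap` and no data-dependent slicing. Below is a minimal standalone sketch of that idea; the function name and the example call are ours (not a transformers API), and the real helper additionally handled padding masks, chunked attention, batch expansion, and an old-torch workaround.

from typing import Optional

import torch


def causal_mask_without_vmap(
    cache_position: torch.Tensor,
    kv_length: int,
    kv_offset: int = 0,
    sliding_window: Optional[int] = None,
) -> torch.Tensor:
    # Illustrative reduction of the deleted helper: build the (q_length, kv_length)
    # boolean mask with broadcasting only, which keeps torch.export happy.
    kv_arange = torch.arange(kv_length, device=cache_position.device) + kv_offset
    query_positions = cache_position.view(-1, 1)
    # Causal pattern: a query attends to key positions at or before its own index
    mask = kv_arange <= query_positions
    if sliding_window is not None:
        # Sliding-window overlay: also require the key to fall inside the window
        mask = mask & (kv_arange > query_positions - sliding_window)
    return mask


# Example: 3 new query tokens appended after 2 already-cached tokens
print(causal_mask_without_vmap(torch.arange(2, 5), kv_length=5).int())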