
Commit 2a42354

Merge branch 'main' into main
2 parents: b38c4ba + 37d48bb

File tree

15 files changed (+426, -553 lines)

src/transformers/image_transforms.py

Lines changed: 41 additions & 43 deletions
@@ -821,14 +821,26 @@ def split_to_tiles(images: "torch.Tensor", num_tiles_height: int, num_tiles_widt
     return image


-def _cast_tensor_to_float(x):
-    if x.is_floating_point():
-        return x
-    return x.float()
-
-
 def _group_images_by_shape(nested_images, *paired_inputs, is_nested: bool = False):
-    """Helper function to flatten a single level of nested image and batch structures and group by shape."""
+    """
+    Helper function to flatten a single level of nested image and batch structures and group by shape.
+    Args:
+        nested_images (list):
+            A list of images or a single tensor
+        paired_inputs (Any, *optional*):
+            Zero or more lists that mirror the structure of `nested_images` (flat list, or list of lists when
+            `is_nested=True`). Each element is paired 1:1 with the corresponding image so it can be grouped by the
+            same shape key. These paired values are grouped alongside `nested_images` but are not stacked in the output, so
+            they do not need to be tensors.
+        is_nested (bool, *optional*, defaults to False):
+            Whether the images are nested.
+    Returns:
+        tuple[dict, ...]:
+            - A dictionary with shape as key and list of images with that shape as value
+            - A dictionary with shape as key and list of paired values with that shape as value
+            - A dictionary mapping original indices to (shape, index) tuples
+            - A dictionary mapping original indices to (shape, index) tuples for each paired input
+    """
     grouped_images = defaultdict(list)
     grouped_images_index = {}
     paired_grouped_values = [defaultdict(list) for _ in paired_inputs]
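The new docstring above pins down the grouping contract. As a quick illustration (a minimal standalone sketch, not the library implementation), grouping by shape amounts to bucketing images on a shape key (here the trailing (height, width) of each tensor; the exact key the library uses may differ) and recording, for each original position, which bucket and row it landed in:

from collections import defaultdict

import torch

images = [torch.rand(3, 224, 224), torch.rand(3, 336, 336), torch.rand(3, 224, 224)]

grouped = defaultdict(list)
index_map = {}
for i, img in enumerate(images):
    shape = tuple(img.shape[-2:])                    # group key: (height, width)
    index_map[i] = (shape, len(grouped[shape]))      # (shape, position within that group)
    grouped[shape].append(img)

print({k: len(v) for k, v in grouped.items()})  # {(224, 224): 2, (336, 336): 1}
print(index_map[2])                             # ((224, 224), 1)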
@@ -880,27 +892,20 @@ def _reconstruct_nested_structure(indices, processed_images):
     return result


-def _disable_grouping_output_nested(images, *paired_inputs):
-    """Build the disable_grouping output tuple for a single-level nested structure."""
-    outer_range = range(len(images))
-    inner_ranges = [range(len(images[i])) for i in outer_range]
-
-    # Precompute all (i, j) pairs
-    ij_pairs = [(i, j) for i in outer_range for j in inner_ranges[i]]
-
-    images_dict = {(i, j): images[i][j].unsqueeze(0) for (i, j) in ij_pairs}
-    paired_dicts = [{(i, j): paired_list[i][j].unsqueeze(0) for (i, j) in ij_pairs} for paired_list in paired_inputs]
-    index_map = {(i, j): ((i, j), 0) for (i, j) in ij_pairs}
-    return images_dict, *paired_dicts, index_map
-
+def _iterate_items(items, is_nested: bool):
+    """
+    Helper function to iterate over items yielding (key, item) pairs.

-def _disable_grouping_output_flat(images, *paired_inputs):
-    """Build the disable_grouping output tuple for a flat list structure."""
-    idx_range = range(len(images))
-    images_dict = {i: images[i].unsqueeze(0) for i in idx_range}
-    paired_dicts = [{i: paired_list[i].unsqueeze(0) for i in idx_range} for paired_list in paired_inputs]
-    index_map = {i: (i, 0) for i in idx_range}
-    return images_dict, *paired_dicts, index_map
+    For nested structures, yields ((row_index, col_index), item).
+    For flat structures, yields (index, item).
+    """
+    if is_nested:
+        for i, row in enumerate(items):
+            for j, item in enumerate(row):
+                yield (i, j), item
+    else:
+        for i, item in enumerate(items):
+            yield i, item


 def group_images_by_shape(
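A short usage sketch of the `_iterate_items` generator introduced above (its body is copied here, minus the docstring, so the snippet runs standalone): the single helper now serves both code paths by yielding integer keys for flat lists and (row, column) keys for nested ones.

def _iterate_items(items, is_nested: bool):
    if is_nested:
        for i, row in enumerate(items):
            for j, item in enumerate(row):
                yield (i, j), item
    else:
        for i, item in enumerate(items):
            yield i, item


flat = ["a", "b", "c"]
nested = [["a", "b"], ["c"]]

print(list(_iterate_items(flat, is_nested=False)))   # [(0, 'a'), (1, 'b'), (2, 'c')]
print(list(_iterate_items(nested, is_nested=True)))  # [((0, 0), 'a'), ((0, 1), 'b'), ((1, 0), 'c')]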
@@ -920,7 +925,7 @@ def group_images_by_shape(
     Args:
         images (Union[list["torch.Tensor"], "torch.Tensor"]):
             A list of images or a single tensor
-        *paired_inputs (Any):
+        paired_inputs (Any, *optional*):
             Zero or more lists that mirror the structure of `images` (flat list, or list of lists when
             `is_nested=True`). Each element is paired 1:1 with the corresponding image so it can be grouped by the
             same shape key. These paired values are grouped alongside `images` but are not stacked in the output, so
@@ -944,10 +949,14 @@ def group_images_by_shape(
         disable_grouping = device == "cpu"

     if disable_grouping:
-        if is_nested:
-            return _disable_grouping_output_nested(images, *paired_inputs)
-        else:
-            return _disable_grouping_output_flat(images, *paired_inputs)
+        return (
+            {key: img.unsqueeze(0) for key, img in _iterate_items(images, is_nested)},
+            *[
+                {key: item.unsqueeze(0) for key, item in _iterate_items(paired_list, is_nested)}
+                for paired_list in paired_inputs
+            ],
+            {key: (key, 0) for key, _ in _iterate_items(images, is_nested)},
+        )

     # Handle single level nested structure
     grouped_images, *paired_grouped_values, grouped_images_index = _group_images_by_shape(
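For reference, a toy version of what the new `disable_grouping` branch produces for a flat list (illustrative only, written with plain `enumerate` rather than the private helper): each image becomes a batch of one via `unsqueeze(0)`, keyed by its position, and the index map sends every key to row 0 of its singleton batch.

import torch

images = [torch.rand(3, 224, 224), torch.rand(3, 336, 336)]

images_dict = {i: img.unsqueeze(0) for i, img in enumerate(images)}
index_map = {i: (i, 0) for i in range(len(images))}

print(images_dict[0].shape)  # torch.Size([1, 3, 224, 224])
print(index_map[1])          # (1, 0)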
@@ -990,14 +999,3 @@ def reorder_images(
         ]

     return _reconstruct_nested_structure(grouped_images_index, processed_images)
-
-
-class NumpyToTensor:
-    """
-    Convert a numpy array to a PyTorch tensor.
-    """
-
-    def __call__(self, image: np.ndarray):
-        # Same as in PyTorch, we assume incoming numpy images are in HWC format
-        # c.f. https://github.com/pytorch/vision/blob/61d97f41bc209e1407dcfbd685d2ee2da9c1cdad/torchvision/transforms/functional.py#L154
-        return torch.from_numpy(image.transpose(2, 0, 1)).contiguous()
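The removed `NumpyToTensor` class reduced to a single conversion; for anyone who relied on it, the same transform written out directly (a standalone sketch based on the removed line) is:

import numpy as np
import torch

image_hwc = np.random.rand(224, 224, 3).astype(np.float32)               # HWC, as PyTorch assumes for numpy images
image_chw = torch.from_numpy(image_hwc.transpose(2, 0, 1)).contiguous()  # contiguous CHW tensor
print(image_chw.shape)  # torch.Size([3, 224, 224])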

src/transformers/integrations/executorch.py

Lines changed: 1 addition & 111 deletions
@@ -11,7 +11,6 @@
 # specific language governing permissions and limitations under the License.

 import logging
-from collections.abc import Callable
 from typing import Optional

 import torch
@@ -24,13 +23,7 @@
     StaticCache,
 )
 from ..generation.configuration_utils import GenerationConfig
-from ..masking_utils import (
-    ALL_MASK_ATTENTION_FUNCTIONS,
-    _ignore_causal_mask_sdpa,
-    _is_torch_greater_or_equal_than_2_5,
-    prepare_padding_mask,
-)
-from ..modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ..modeling_utils import PreTrainedModel
 from ..pytorch_utils import (
     is_torch_greater_or_equal,
     is_torch_greater_or_equal_than_2_3,
@@ -229,10 +222,6 @@ def __init__(
                 "Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config."
             )
             self.model = TorchExportableModuleWithStaticCache(model, batch_size, max_cache_len, device)
-        # This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
-        ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
-        ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
-        self.model.model.config._attn_implementation = "sdpa_without_vmap"

     def forward(
         self,
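The four lines removed above registered a vmap-free mask builder next to the stock `sdpa` kernel and pointed the model config at it. A minimal sketch of that registration pattern, grounded only in the calls shown in those removed lines (the `my_attn` name and `my_mask_function` are hypothetical placeholders; the import paths follow the imports removed earlier in this file):

from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


def my_mask_function(batch_size, cache_position, kv_length, **kwargs):
    # Hypothetical placeholder: a real mask builder returns a boolean
    # (batch_size, 1, query_length, kv_length) tensor, or None to skip masking.
    return None


ALL_MASK_ATTENTION_FUNCTIONS.register("my_attn", my_mask_function)
ALL_ATTENTION_FUNCTIONS.register("my_attn", ALL_ATTENTION_FUNCTIONS["sdpa"])  # reuse the sdpa kernel
# A loaded model would then opt in with: model.config._attn_implementation = "my_attn"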
@@ -768,11 +757,6 @@ def convert_and_export_with_cache(

     import torch.export._trace

-    # This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
-    ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
-    ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
-    model.config._attn_implementation = "sdpa_without_vmap"
-
     with torch.no_grad():
         # TODO: The default inputs only work for text models. We need to add support for vision/audio models.
         example_input_ids = (
@@ -1036,11 +1020,6 @@ def export_with_dynamic_cache(
     if not is_torch_greater_or_equal_than_2_3:
         raise ImportError("torch >= 2.3 is required.")

-    # This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
-    ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
-    ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
-    model.config._attn_implementation = "sdpa_without_vmap"
-
     register_dynamic_cache_export_support()

     with torch.no_grad():
@@ -1109,92 +1088,3 @@ def _unflatten_dynamic_cache(values, context: torch.utils._pytree.Context):
         value = value_list[idx] if idx < len(value_list) else None
         cache.update(key, value, idx)
     return cache
-
-
-def sdpa_mask_without_vmap(
-    batch_size: int,
-    cache_position: torch.Tensor,
-    kv_length: int,
-    kv_offset: int = 0,
-    mask_function: Optional[Callable] = None,
-    attention_mask: Optional[torch.Tensor] = None,
-    local_size: Optional[int] = None,
-    allow_is_causal_skip: bool = True,
-    allow_torch_fix: bool = True,
-    **kwargs,
-) -> Optional[torch.Tensor]:
-    """
-    Create a 4D boolean mask of shape `(batch_size, 1, query_length, kv_length)` where a value of True indicates that
-    the element should take part in the attention computation, and False that it should not.
-
-    This is similar to `masking_utils.sdpa_mask` but does not use `vmap` which is incompatible with export.
-
-    Args:
-        batch_size (`int`):
-            The batch size of the input sequence.
-        cache_position (`torch.Tensor`):
-            A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
-        kv_length (`int`):
-            The size that the key and value states will have during the attention computation.
-        kv_offset (`int`, optional):
-            An optional offset to indicate at which first position the key and values states will refer to.
-        mask_function (`Callable`):
-            The mask factory function describing the mask pattern.
-        attention_mask (`torch.Tensor`, optional):
-            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
-        local_size (`int`, optional):
-            The size of the local attention, if we do not use full attention. This is used only if `allow_is_causal_skip=True`
-            to try to skip mask creation if possible.
-        allow_is_causal_skip (`bool`, optional):
-            Whether to allow to return `None` for the mask under conditions where we can use the `is_causal` argument in
-            `torch.sdpa` instead. Default to `True`.
-        allow_torch_fix (`bool`, optional):
-            Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older
-            versions. We need an arg to skip it when using eager. By default `True`.
-
-    """
-
-    q_length = cache_position.shape[0]
-    # Potentially pad the 2D mask, and slice it correctly
-    padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset)
-
-    # Under specific conditions, we can avoid materializing the mask, instead relying on the `is_causal` argument
-    if allow_is_causal_skip and _ignore_causal_mask_sdpa(padding_mask, q_length, kv_length, local_size):
-        return None
-
-    # Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`
-    # but without data-dependent slicing (i.e. torch.compile friendly)
-    kv_arange = torch.arange(kv_length, device=cache_position.device)
-    kv_arange += kv_offset
-    reshaped_cache_position = cache_position.view(-1, 1)
-
-    # This is a bit hacky to know what pattern we are using, but all mask creation function actually forward
-    # the config through kwargs anyway, so it allows to rely on it
-    # Usually, the `mask_function` is the only entry-point to define the pattern - we could do for loops over it,
-    # but this is more efficient
-    sliding_window = getattr(kwargs["config"], "sliding_window", None)
-    chunk_size = getattr(kwargs["config"], "attention_chunk_size", None)
-
-    if sliding_window is not None and chunk_size is not None:
-        raise ValueError("Cannot use both `sliding_window` and `attention_chunk_size`")
-
-    # Simplest and most efficient way to obtain a causal mask
-    causal_mask = kv_arange <= reshaped_cache_position
-    # If using sliding window, add the sliding mask
-    if sliding_window is not None:
-        sliding_mask_overlay = kv_arange > reshaped_cache_position - sliding_window
-        causal_mask *= sliding_mask_overlay
-    # If using chunk attention, add the chunked mask
-    elif chunk_size is not None:
-        chunked_mask_overlay = kv_arange // chunk_size == reshaped_cache_position // chunk_size
-        causal_mask *= chunked_mask_overlay
-
-    causal_mask = causal_mask[None, None, :, :].expand(batch_size, -1, -1, -1)
-    if padding_mask is not None:
-        causal_mask = causal_mask * padding_mask[:, None, None, :]
-
-    # Due to a bug in some older torch version, we need to update the mask in case a query is not attending to any
-    # tokens (due to padding). See details in https://github.com/pytorch/pytorch/issues/110213
-    if not _is_torch_greater_or_equal_than_2_5 and allow_torch_fix:
-        causal_mask |= torch.all(~causal_mask, dim=-1, keepdim=True)
-    return causal_mask
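The removed `sdpa_mask_without_vmap` depended on private helpers (`prepare_padding_mask`, `_ignore_causal_mask_sdpa`). Below is a condensed, self-contained sketch of just its mask arithmetic (same broadcasting idea; padding handling and the is_causal skip are omitted, and the function name here is made up for illustration):

from typing import Optional

import torch


def simple_mask_without_vmap(
    batch_size: int,
    cache_position: torch.Tensor,  # shape (query_length,)
    kv_length: int,
    kv_offset: int = 0,
    sliding_window: Optional[int] = None,
    chunk_size: Optional[int] = None,
) -> torch.Tensor:
    # Broadcast comparisons instead of vmap: (query_length, 1) against (kv_length,)
    kv_arange = torch.arange(kv_length, device=cache_position.device) + kv_offset
    q_positions = cache_position.view(-1, 1)

    causal_mask = kv_arange <= q_positions  # attend to keys up to the query position
    if sliding_window is not None:
        causal_mask &= kv_arange > q_positions - sliding_window
    elif chunk_size is not None:
        causal_mask &= kv_arange // chunk_size == q_positions // chunk_size

    # (batch_size, 1, query_length, kv_length), True = take part in attention
    return causal_mask[None, None, :, :].expand(batch_size, -1, -1, -1)


mask = simple_mask_without_vmap(1, torch.arange(4), kv_length=4, sliding_window=2)
print(mask[0, 0].long())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [0, 1, 1, 0],
#         [0, 0, 1, 1]])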
