
Commit 0638708

rename naflexvit mask to attn_mask
1 parent c9f9c30 commit 0638708

1 file changed: +13 -9 lines


timm/models/naflexvit.py

Lines changed: 13 additions & 9 deletions
```diff
@@ -20,28 +20,27 @@
 import math
 from dataclasses import dataclass, fields, replace
 from functools import partial
-from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union, Final, Any, Literal
+from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union, Any
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
+from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from timm.layers import (
     AttentionPoolLatent,
     Mlp,
     to_2tuple,
     get_act_layer,
     get_norm_layer,
     LayerNorm,
-    LayerType,
     _assert,
 )
 from timm.models._builder import build_model_with_cfg
 from timm.models._features import feature_take_indices
 from timm.models._features_fx import register_notrace_function, register_notrace_module
 from timm.models._registry import register_model, generate_default_cfgs
-from timm.models._manipulate import checkpoint_seq, named_apply
+from timm.models._manipulate import checkpoint, checkpoint_seq, named_apply
 
 from .vision_transformer import Block, global_pool_nlc
 
```
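The newly imported `checkpoint` is what the block loop below uses for per-block gradient checkpointing. As a rough sketch only (an assumption, not timm's actual implementation), such a helper can be a thin wrapper over `torch.utils.checkpoint`:

```python
# Hedged sketch: assumes timm.models._manipulate.checkpoint behaves like a thin
# wrapper around torch.utils.checkpoint.checkpoint.
import torch.utils.checkpoint


def checkpoint(function, *args, **kwargs):
    # Recompute the block's activations during backward instead of storing
    # them, trading extra compute for lower peak memory.
    return torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=False, **kwargs)
```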

```diff
@@ -1054,7 +1053,7 @@ def forward_intermediates(
             output_dict: bool = False,
             patch_coord: Optional[torch.Tensor] = None,
             patch_valid: Optional[torch.Tensor] = None,
-            mask: Optional[torch.Tensor] = None,
+            attn_mask: Optional[torch.Tensor] = None,
     ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]], Dict[str, Any]]:
         """ Forward features that returns intermediates.
```
```diff
@@ -1069,7 +1068,7 @@ def forward_intermediates(
             output_dict: Return outputs as a dictionary with 'image_features' and 'image_intermediates' keys
             patch_coord: Optional patch coordinates [B, N, 2] for NaFlex mode
             patch_valid: Optional patch type indicators (1=patch, 0=padding) for NaFlex
-            mask: Optional attention mask
+            attn_mask: Optional attention mask for masked attention
         Returns:
             A tuple with (final_features, intermediates), a list of intermediate features, or a dictionary containing
             'image_features' and 'image_intermediates' (and optionally 'image_intermediates_prefix')
```
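For callers, the visible change is the keyword name. A hedged call-site sketch follows; the model name, the plain-image input, and the `(final, intermediates)` return convention are assumptions for illustration, not taken from this diff:

```python
# Hedged usage sketch: 'naflexvit_base_patch16_gap' is an assumed model name,
# and (final, intermediates) is the usual timm forward_intermediates return.
import timm
import torch

model = timm.create_model('naflexvit_base_patch16_gap', pretrained=False)
x = torch.randn(1, 3, 224, 224)

# After this commit the keyword is attn_mask, not mask; passing None takes the
# unmasked path, which can also use gradient checkpointing (see below).
final, intermediates = model.forward_intermediates(x, indices=2, attn_mask=None)
```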
```diff
@@ -1093,8 +1092,8 @@ def forward_intermediates(
         H, W = self.embeds.dynamic_feat_size((height, width))
 
         # Create attention mask if patch_type is provided and mask is not
-        if mask is None and patch_valid is not None:
-            mask = create_attention_mask(patch_valid, self.num_prefix_tokens, patches.dtype)
+        if attn_mask is None and patch_valid is not None:
+            attn_mask = create_attention_mask(patch_valid, self.num_prefix_tokens, patches.dtype)
 
         # Forward pass through embedding
         x = self.embeds(patches, patch_coord=patch_coord)
```
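`create_attention_mask` itself is not shown in this diff. The following is a hedged sketch of what such a helper plausibly produces given how it is called above (an additive float mask keyed off patch validity, with prefix tokens always valid); it is not timm's actual implementation:

```python
# Hedged sketch: assumes an additive [B, 1, P+N, P+N] mask with 0 at valid key
# positions and a large negative value at padded keys (P = num_prefix_tokens).
import torch


def sketch_attention_mask(
        patch_valid: torch.Tensor,   # [B, N], 1 = real patch, 0 = padding
        num_prefix_tokens: int,      # class / register tokens, always valid
        dtype: torch.dtype,          # floating dtype of the patch embeddings
) -> torch.Tensor:
    B, N = patch_valid.shape
    prefix_valid = patch_valid.new_ones(B, num_prefix_tokens, dtype=torch.bool)
    valid = torch.cat([prefix_valid, patch_valid.bool()], dim=1)  # [B, P+N]
    L = valid.shape[1]
    mask = torch.zeros(B, 1, L, L, dtype=dtype)
    # Block attention *to* padding keys; attention among valid tokens stays 0.
    mask.masked_fill_(~valid[:, None, None, :], torch.finfo(dtype).min)
    return mask
```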
```diff
@@ -1107,7 +1106,12 @@ def forward_intermediates(
         blocks = self.blocks[:max_index + 1]
 
         for i, blk in enumerate(blocks):
-            x = blk(x, attn_mask=mask)
+            if attn_mask is not None:
+                x = blk(x, attn_mask=attn_mask)
+            elif self.grad_checkpointing and not torch.jit.is_scripting():
+                x = checkpoint(blk, x)
+            else:
+                x = blk(x)
             if i in take_indices:
                 # normalize intermediates with final norm layer if enabled
                 intermediates.append(self.norm(x) if norm else x)
```
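Note that with the new dispatch, a provided `attn_mask` bypasses gradient checkpointing entirely. A minimal, self-contained reproduction of the same control flow, using stand-in modules rather than timm's `Block`:

```python
# Toy version of the dispatch above; nn.Linear stands in for a transformer
# Block (a real Block would also accept attn_mask=...).
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

blocks = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])
x = torch.randn(2, 16, 8, requires_grad=True)
attn_mask = None            # when a mask is present, checkpointing is skipped
grad_checkpointing = True

for i, blk in enumerate(blocks):
    if attn_mask is not None:
        x = blk(x)          # the real code calls blk(x, attn_mask=attn_mask)
    elif grad_checkpointing and not torch.jit.is_scripting():
        # recompute this block's activations in backward instead of caching
        x = checkpoint(blk, x, use_reentrant=False)
    else:
        x = blk(x)
```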
