2020import math
2121from dataclasses import dataclass , fields , replace
2222from functools import partial
23- from typing import Callable , Dict , List , Optional , Set , Tuple , Type , Union , Final , Any , Literal
23+ from typing import Callable , Dict , List , Optional , Set , Tuple , Type , Union , Any
2424
2525import torch
2626import torch .nn as nn
2727import torch .nn .functional as F
2828
29- from timm .data import IMAGENET_DEFAULT_MEAN , IMAGENET_DEFAULT_STD , IMAGENET_INCEPTION_MEAN , IMAGENET_INCEPTION_STD
29+ from timm .data import IMAGENET_INCEPTION_MEAN , IMAGENET_INCEPTION_STD
3030from timm .layers import (
3131 AttentionPoolLatent ,
3232 Mlp ,
3333 to_2tuple ,
3434 get_act_layer ,
3535 get_norm_layer ,
3636 LayerNorm ,
37- LayerType ,
3837 _assert ,
3938)
4039from timm .models ._builder import build_model_with_cfg
4140from timm .models ._features import feature_take_indices
4241from timm .models ._features_fx import register_notrace_function , register_notrace_module
4342from timm .models ._registry import register_model , generate_default_cfgs
44- from timm .models ._manipulate import checkpoint_seq , named_apply
43+ from timm .models ._manipulate import checkpoint , checkpoint_seq , named_apply
4544
4645from .vision_transformer import Block , global_pool_nlc
4746
@@ -1054,7 +1053,7 @@ def forward_intermediates(
10541053 output_dict : bool = False ,
10551054 patch_coord : Optional [torch .Tensor ] = None ,
10561055 patch_valid : Optional [torch .Tensor ] = None ,
1057- mask : Optional [torch .Tensor ] = None ,
1056+ attn_mask : Optional [torch .Tensor ] = None ,
10581057 ) -> Union [List [torch .Tensor ], Tuple [torch .Tensor , List [torch .Tensor ]], Dict [str , Any ]]:
10591058 """ Forward features that returns intermediates.
10601059
@@ -1069,7 +1068,7 @@ def forward_intermediates(
10691068 output_dict: Return outputs as a dictionary with 'image_features' and 'image_intermediates' keys
10701069 patch_coord: Optional patch coordinates [B, N, 2] for NaFlex mode
10711070 patch_valid: Optional patch type indicators (1=patch, 0=padding) for NaFlex
1072- mask : Optional attention mask
1071+ attn_mask : Optional attention mask for masked attention
10731072 Returns:
10741073 A tuple with (final_features, intermediates), a list of intermediate features, or a dictionary containing
10751074 'image_features' and 'image_intermediates' (and optionally 'image_intermediates_prefix')
@@ -1093,8 +1092,8 @@ def forward_intermediates(
10931092 H , W = self .embeds .dynamic_feat_size ((height , width ))
10941093
10951094 # Create attention mask if patch_type is provided and mask is not
1096- if mask is None and patch_valid is not None :
1097- mask = create_attention_mask (patch_valid , self .num_prefix_tokens , patches .dtype )
1095+ if attn_mask is None and patch_valid is not None :
1096+ attn_mask = create_attention_mask (patch_valid , self .num_prefix_tokens , patches .dtype )
10981097
10991098 # Forward pass through embedding
11001099 x = self .embeds (patches , patch_coord = patch_coord )
@@ -1107,7 +1106,12 @@ def forward_intermediates(
11071106 blocks = self .blocks [:max_index + 1 ]
11081107
11091108 for i , blk in enumerate (blocks ):
1110- x = blk (x , attn_mask = mask )
1109+ if attn_mask is not None :
1110+ x = blk (x , attn_mask = attn_mask )
1111+ elif self .grad_checkpointing and not torch .jit .is_scripting ():
1112+ x = checkpoint (blk , x )
1113+ else :
1114+ x = blk (x )
11111115 if i in take_indices :
11121116 # normalize intermediates with final norm layer if enabled
11131117 intermediates .append (self .norm (x ) if norm else x )
0 commit comments