@@ -34,7 +34,14 @@ class ConvPosEnc(nn.Module):
     def __init__(self, dim: int, k: int = 3, act: bool = False):
         super(ConvPosEnc, self).__init__()

-        self.proj = nn.Conv2d(dim, dim, k, 1, k // 2, groups=dim)
+        self.proj = nn.Conv2d(
+            dim,
+            dim,
+            kernel_size=k,
+            stride=1,
+            padding=k // 2,
+            groups=dim,
+        )
         self.act = nn.GELU() if act else nn.Identity()

     def forward(self, x: Tensor):
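The hunk above only rewrites the depthwise conv with keyword arguments; behaviour is unchanged. A minimal sketch (not part of the patch, toy values dim=96, k=3 assumed) confirming the two constructions match and that padding=k // 2 keeps the spatial size for odd k at stride 1:

import torch
import torch.nn as nn

dim, k = 96, 3                         # assumed toy values
old = nn.Conv2d(dim, dim, k, 1, k // 2, groups=dim)                              # old positional form
new = nn.Conv2d(dim, dim, kernel_size=k, stride=1, padding=k // 2, groups=dim)   # new keyword form
assert (old.kernel_size, old.stride, old.padding, old.groups) == \
       (new.kernel_size, new.stride, new.padding, new.groups)
x = torch.randn(1, dim, 14, 14)
assert new(x).shape == x.shape         # 'same'-sized output for odd k, stride 1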
@@ -72,8 +79,9 @@ def __init__(

     def forward(self, x: Tensor):
         B, C, H, W = x.shape
-        x = F.pad(x, (0, (self.stride[1] - W % self.stride[1]) % self.stride[1]))
-        x = F.pad(x, (0, 0, 0, (self.stride[0] - H % self.stride[0]) % self.stride[0]))
+        pad_r = (self.stride[1] - W % self.stride[1]) % self.stride[1]
+        pad_b = (self.stride[0] - H % self.stride[0]) % self.stride[0]
+        x = F.pad(x, (0, pad_r, 0, pad_b))
         x = self.conv(x)
         x = self.norm(x)
         return x
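Worked example (not from the patch) of the consolidated padding above: both spatial dims are padded up to a multiple of the stem stride in a single F.pad call. A stride of (4, 4) and a 10x13 input are assumed for illustration:

import torch
import torch.nn.functional as F

stride = (4, 4)                        # assumed stem stride
x = torch.randn(1, 3, 10, 13)          # B, C, H=10, W=13
B, C, H, W = x.shape
pad_r = (stride[1] - W % stride[1]) % stride[1]   # (4 - 13 % 4) % 4 = 3
pad_b = (stride[0] - H % stride[0]) % stride[0]   # (4 - 10 % 4) % 4 = 2
x = F.pad(x, (0, pad_r, 0, pad_b))     # F.pad order: (left, right, top, bottom)
assert x.shape[-2:] == (12, 16)        # now divisible by the stride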
@@ -84,30 +92,66 @@ def __init__(
             self,
             in_chs,
             out_chs,
+            kernel_size=3,
             norm_layer=LayerNorm2d,
     ):
         super().__init__()
         self.in_chs = in_chs
         self.out_chs = out_chs

         self.norm = norm_layer(in_chs)
+        self.even_k = kernel_size % 2 == 0
         self.conv = nn.Conv2d(
             in_chs,
             out_chs,
-            kernel_size=2,
+            kernel_size=kernel_size,
             stride=2,
-            padding=0,
+            padding=0 if self.even_k else kernel_size // 2,
         )

     def forward(self, x: Tensor):
         B, C, H, W = x.shape
         x = self.norm(x)
-        x = F.pad(x, (0, (2 - W % 2) % 2))
-        x = F.pad(x, (0, 0, 0, (2 - H % 2) % 2))
+        if self.even_k:
+            k_h, k_w = self.conv.kernel_size
+            pad_r = (k_w - W % k_w) % k_w
+            pad_b = (k_h - H % k_h) % k_h
+            x = F.pad(x, (0, pad_r, 0, pad_b))
         x = self.conv(x)
         return x


+class ChannelAttentionV2(nn.Module):
+
+    def __init__(self, dim, num_heads=8, qkv_bias=True, dynamic_scale=True):
+        super().__init__()
+        self.groups = num_heads
+        self.head_dim = dim // num_heads
+        self.dynamic_scale = dynamic_scale
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.proj = nn.Linear(dim, dim)
+
+    def forward(self, x):
+        B, N, C = x.shape
+
+        qkv = self.qkv(x).reshape(B, N, 3, self.groups, C // self.groups).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv.unbind(0)
+
+        if self.dynamic_scale:
+            q = q * N ** -0.5
+        else:
+            q = q * self.head_dim ** -0.5
+        attn = q.transpose(-1, -2) @ k
+        attn = attn.softmax(dim=-1)
+        x = (attn @ v.transpose(-1, -2)).transpose(-1, -2)
+
+        x = x.transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        return x
+
+
+
 class ChannelAttention(nn.Module):

     def __init__(self, dim, num_heads=8, qkv_bias=False):
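Shape sketch (not part of the patch) for the ChannelAttentionV2 added above: attention is taken over channels, so the attention matrix is head_dim x head_dim per head rather than N x N, and with dynamic_scale the query is scaled by N ** -0.5 instead of head_dim ** -0.5. Toy sizes (dim=64, num_heads=8, N=49) are assumed:

import torch

B, N, C, heads = 2, 49, 64, 8
head_dim = C // heads
q, k, v = torch.randn(3, B, heads, N, head_dim).unbind(0)
q = q * N ** -0.5                              # dynamic_scale=True path
attn = q.transpose(-1, -2) @ k                 # channels attend to channels
assert attn.shape == (B, heads, head_dim, head_dim)
out = (attn.softmax(dim=-1) @ v.transpose(-1, -2)).transpose(-1, -2)
assert out.shape == (B, heads, N, head_dim)    # back to token-major layout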
@@ -147,13 +191,19 @@ def __init__(
             norm_layer=nn.LayerNorm,
             ffn=True,
             cpe_act=False,
+            v2=False,
     ):
         super().__init__()

         self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
         self.ffn = ffn
         self.norm1 = norm_layer(dim)
-        self.attn = ChannelAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias)
+        attn_layer = ChannelAttentionV2 if v2 else ChannelAttention
+        self.attn = attn_layer(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+        )
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
         self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act)

@@ -372,21 +422,24 @@ def __init__(
             attn_types=('spatial', 'channel'),
             num_heads=3,
             window_size=7,
-            mlp_ratio=4,
+            mlp_ratio=4.,
             qkv_bias=True,
             drop_path_rates=(0, 0),
             norm_layer=LayerNorm2d,
             norm_layer_cl=nn.LayerNorm,
             ffn=True,
-            cpe_act=False
+            cpe_act=False,
+            down_kernel_size=2,
+            named_blocks=False,
+            channel_attn_v2=False,
     ):
         super().__init__()

         self.grad_checkpointing = False

         # downsample embedding layer at the beginning of each stage
         if downsample:
-            self.downsample = Downsample(in_chs, out_chs, norm_layer=norm_layer)
+            self.downsample = Downsample(in_chs, out_chs, kernel_size=down_kernel_size, norm_layer=norm_layer)
         else:
             self.downsample = nn.Identity()

@@ -399,10 +452,11 @@ def __init__(
         '''
         stage_blocks = []
         for block_idx in range(depth):
+            from collections import OrderedDict
             dual_attention_block = []
             for attn_idx, attn_type in enumerate(attn_types):
                 if attn_type == 'spatial':
-                    dual_attention_block.append(SpatialBlock(
+                    dual_attention_block.append(('spatial_block', SpatialBlock(
                         dim=out_chs,
                         num_heads=num_heads,
                         mlp_ratio=mlp_ratio,
@@ -412,19 +466,23 @@ def __init__(
                         ffn=ffn,
                         cpe_act=cpe_act,
                         window_size=window_size,
-                    ))
+                    )))
                 elif attn_type == 'channel':
-                    dual_attention_block.append(ChannelBlock(
+                    dual_attention_block.append(('channel_block', ChannelBlock(
                         dim=out_chs,
                         num_heads=num_heads,
                         mlp_ratio=mlp_ratio,
                         qkv_bias=qkv_bias,
                         drop_path=drop_path_rates[block_idx],
                         norm_layer=norm_layer_cl,
                         ffn=ffn,
-                        cpe_act=cpe_act
-                    ))
-            stage_blocks.append(nn.Sequential(*dual_attention_block))
+                        cpe_act=cpe_act,
+                        v2=channel_attn_v2,
+                    )))
+            if named_blocks:
+                stage_blocks.append(nn.Sequential(OrderedDict(dual_attention_block)))
+            else:
+                stage_blocks.append(nn.Sequential(*[b[1] for b in dual_attention_block]))
         self.blocks = nn.Sequential(*stage_blocks)

     @torch.jit.ignore
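Illustration (not from the patch) of what named_blocks changes: wrapping the (name, module) pairs in an OrderedDict makes submodule paths read 'spatial_block' / 'channel_block' instead of positional '0' / '1', which is the layout the converted Florence-2 checkpoint keys use. Generic nn.Linear modules stand in for the real blocks:

from collections import OrderedDict
import torch.nn as nn

anon = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
named = nn.Sequential(OrderedDict([
    ('spatial_block', nn.Linear(4, 4)),
    ('channel_block', nn.Linear(4, 4)),
]))
assert '0.weight' in dict(anon.named_parameters())                # positional naming
assert 'spatial_block.weight' in dict(named.named_parameters())   # named blocks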
@@ -473,6 +531,9 @@ def __init__(
             attn_types=('spatial', 'channel'),
             ffn=True,
             cpe_act=False,
+            down_kernel_size=2,
+            channel_attn_v2=False,
+            named_blocks=False,
             drop_rate=0.,
             drop_path_rate=0.,
             num_classes=1000,
@@ -512,6 +573,9 @@ def __init__(
                 norm_layer_cl=norm_layer_cl,
                 ffn=ffn,
                 cpe_act=cpe_act,
+                down_kernel_size=down_kernel_size,
+                channel_attn_v2=channel_attn_v2,
+                named_blocks=named_blocks,
             )
             in_chs = out_chs
             stages.append(stage)
@@ -589,6 +653,34 @@ def forward(self, x):
         return x


+def _convert_florence2(state_dict, model, prefix='vision_tower.'):
+    import re
+    out_dict = {}
+
+    for k, v in state_dict.items():
+        if k.startswith(prefix):
+            k = k.replace(prefix, '')
+        else:
+            continue
+        k = re.sub(r'convs.([0-9]+)', r'stages.\1.downsample', k)
+        k = re.sub(r'blocks.([0-9]+)', r'stages.\1.blocks', k)
+        k = k.replace('downsample.proj', 'downsample.conv')
+        k = k.replace('stages.0.downsample', 'stem')
+        #k = k.replace('head.', 'head.fc.')
+        #k = k.replace('norms.', 'head.norm.')
+        k = k.replace('window_attn.norm.', 'norm1.')
+        k = k.replace('window_attn.fn.', 'attn.')
+        k = k.replace('channel_attn.norm.', 'norm1.')
+        k = k.replace('channel_attn.fn.', 'attn.')
+        k = k.replace('ffn.norm.', 'norm2.')
+        k = k.replace('ffn.fn.net.', 'mlp.')
+        k = k.replace('conv1.fn.dw', 'cpe1.proj')
+        k = k.replace('conv2.fn.dw', 'cpe2.proj')
+        out_dict[k] = v
+
+    return out_dict
+
+
 def checkpoint_filter_fn(state_dict, model):
     """ Remap MSFT checkpoints -> timm """
     if 'head.fc.weight' in state_dict:
@@ -597,6 +689,9 @@ def checkpoint_filter_fn(state_dict, model):
     if 'state_dict' in state_dict:
         state_dict = state_dict['state_dict']

+    if 'vision_tower.convs.0.proj.weight' in state_dict:
+        return _convert_florence2(state_dict, model)
+
     import re
     out_dict = {}
     for k, v in state_dict.items():
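Trace (not from the patch) of the remapping that _convert_florence2 applies, shown on a single hypothetical Florence-2 key; the exact key name is assumed for illustration only:

import re

k = 'vision_tower.blocks.1.0.spatial_block.window_attn.fn.qkv.weight'   # hypothetical key
k = k.replace('vision_tower.', '')
k = re.sub(r'blocks.([0-9]+)', r'stages.\1.blocks', k)
k = k.replace('window_attn.fn.', 'attn.')
assert k == 'stages.1.blocks.0.spatial_block.attn.qkv.weight'           # timm module path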
@@ -615,13 +710,17 @@ def checkpoint_filter_fn(state_dict, model):
 def _create_davit(variant, pretrained=False, **kwargs):
     default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
     out_indices = kwargs.pop('out_indices', default_out_indices)
-
+    strict = True
+    if variant.endswith('_fl'):
+        # FIXME cleaner approach to missing head norm?
+        strict = False
     model = build_model_with_cfg(
         DaVit,
         variant,
         pretrained,
         pretrained_filter_fn=checkpoint_filter_fn,
         feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
+        pretrained_strict=strict,
         **kwargs)

     return model
@@ -650,6 +749,12 @@ def _cfg(url='', **kwargs):
     'davit_large': _cfg(),
     'davit_huge': _cfg(),
     'davit_giant': _cfg(),
+    'davit_base_fl.msft_florence2': _cfg(
+        hf_hub_id='microsoft/Florence-2-base',
+        num_classes=0, input_size=(3, 768, 768)),
+    'davit_huge_fl.msft_florence2': _cfg(
+        hf_hub_id='microsoft/Florence-2-large',
+        num_classes=0, input_size=(3, 768, 768)),
 })

@@ -687,3 +792,23 @@ def davit_huge(pretrained=False, **kwargs) -> DaVit:
 def davit_giant(pretrained=False, **kwargs) -> DaVit:
     model_args = dict(depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072), num_heads=(12, 24, 48, 96))
     return _create_davit('davit_giant', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+
+@register_model
+def davit_base_fl(pretrained=False, **kwargs) -> DaVit:
+    model_args = dict(
+        depths=(1, 1, 9, 1), embed_dims=(128, 256, 512, 1024), num_heads=(4, 8, 16, 32),
+        window_size=12, down_kernel_size=3, channel_attn_v2=True, named_blocks=True,
+    )
+    return _create_davit('davit_base_fl', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def davit_huge_fl(pretrained=False, **kwargs) -> DaVit:
+    # NOTE: huge image tower used in 'large' Florence2 model
+    model_args = dict(
+        depths=(1, 1, 9, 1), embed_dims=(256, 512, 1024, 2048), num_heads=(8, 16, 32, 64),
+        window_size=12, down_kernel_size=3, channel_attn_v2=True, named_blocks=True,
+    )
+    return _create_davit('davit_huge_fl', pretrained=pretrained, **dict(model_args, **kwargs))
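Usage sketch (not part of the patch): once registered, the Florence-2 towers behave like any other timm backbone. num_classes=0 in the default cfg above means no classifier head is expected, and pretrained=True is intended to pull the converted Florence-2 weights from the configured hf_hub_id via checkpoint_filter_fn:

import timm
import torch

# features_only returns the per-stage feature maps; 768x768 follows the default cfg above
model = timm.create_model('davit_base_fl', pretrained=False, features_only=True)
x = torch.randn(1, 3, 768, 768)
features = model(x)
print([f.shape for f in features])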