@@ -34,7 +34,14 @@ class ConvPosEnc(nn.Module):
     def __init__(self, dim: int, k: int = 3, act: bool = False):
         super(ConvPosEnc, self).__init__()
 
-        self.proj = nn.Conv2d(dim, dim, k, 1, k // 2, groups=dim)
+        self.proj = nn.Conv2d(
+            dim,
+            dim,
+            kernel_size=k,
+            stride=1,
+            padding=k // 2,
+            groups=dim,
+        )
         self.act = nn.GELU() if act else nn.Identity()
 
     def forward(self, x: Tensor):
@@ -84,30 +91,65 @@ def __init__(
             self,
             in_chs,
             out_chs,
+            kernel_size=3,
             norm_layer=LayerNorm2d,
     ):
         super().__init__()
         self.in_chs = in_chs
         self.out_chs = out_chs
 
         self.norm = norm_layer(in_chs)
+        self.even_k = kernel_size % 2 == 0
         self.conv = nn.Conv2d(
             in_chs,
             out_chs,
-            kernel_size=2,
+            kernel_size=kernel_size,
             stride=2,
-            padding=0,
+            padding=0 if self.even_k else kernel_size // 2,
         )
 
     def forward(self, x: Tensor):
         B, C, H, W = x.shape
         x = self.norm(x)
-        x = F.pad(x, (0, (2 - W % 2) % 2))
-        x = F.pad(x, (0, 0, 0, (2 - H % 2) % 2))
+        if self.even_k:
+            k_h, k_w = self.conv.kernel_size
+            x = F.pad(x, (0, (k_w - W % k_w) % k_w))
+            x = F.pad(x, (0, 0, 0, (k_h - H % k_h) % k_h))
         x = self.conv(x)
         return x
 
 
+class ChannelAttentionV2(nn.Module):
+
+    def __init__(self, dim, num_heads=8, qkv_bias=True, dynamic_scale=True):
+        super().__init__()
+        self.groups = num_heads
+        self.head_dim = dim // num_heads
+        self.dynamic_scale = dynamic_scale
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.proj = nn.Linear(dim, dim)
+
+    def forward(self, x):
+        B, N, C = x.shape
+
+        qkv = self.qkv(x).reshape(B, N, 3, self.groups, C // self.groups).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv.unbind(0)
+
+        if self.dynamic_scale:
+            q = q * float(N) ** -0.5
+        else:
+            q = q * self.head_dim ** -0.5
+        attn = q.transpose(-1, -2) @ k
+        attn = attn.softmax(dim=-1)
+        x = (attn @ v.transpose(-1, -2)).transpose(-1, -2)
+
+        x = x.transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        return x
+
+
+
 class ChannelAttention(nn.Module):
 
     def __init__(self, dim, num_heads=8, qkv_bias=False):
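A minimal standalone sketch of what the added ChannelAttentionV2 computes on the dynamic_scale=True path (toy shapes, plain tensors standing in for the projected q/k/v): the softmax runs over a head_dim x head_dim channel-channel matrix rather than over the N tokens, and queries are scaled by the token count.

import torch

B, N, num_heads, head_dim = 2, 16, 4, 8                 # arbitrary toy sizes
q = torch.randn(B, num_heads, N, head_dim)
k = torch.randn(B, num_heads, N, head_dim)
v = torch.randn(B, num_heads, N, head_dim)

q = q * float(N) ** -0.5                                # dynamic_scale=True: scale by token count N
attn = (q.transpose(-1, -2) @ k).softmax(dim=-1)        # (B, heads, head_dim, head_dim) channel attention
out = (attn @ v.transpose(-1, -2)).transpose(-1, -2)    # back to (B, heads, N, head_dim)
print(attn.shape, out.shape)                            # torch.Size([2, 4, 8, 8]) torch.Size([2, 4, 16, 8])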
@@ -147,13 +189,19 @@ def __init__(
             norm_layer=nn.LayerNorm,
             ffn=True,
             cpe_act=False,
+            v2=False,
     ):
         super().__init__()
 
         self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
         self.ffn = ffn
         self.norm1 = norm_layer(dim)
-        self.attn = ChannelAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias)
+        attn_layer = ChannelAttentionV2 if v2 else ChannelAttention
+        self.attn = attn_layer(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+        )
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
         self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
 
@@ -372,21 +420,24 @@ def __init__(
             attn_types=('spatial', 'channel'),
             num_heads=3,
             window_size=7,
-            mlp_ratio=4,
+            mlp_ratio=4.,
             qkv_bias=True,
             drop_path_rates=(0, 0),
             norm_layer=LayerNorm2d,
             norm_layer_cl=nn.LayerNorm,
             ffn=True,
-            cpe_act=False
+            cpe_act=False,
+            down_kernel_size=2,
+            named_blocks=False,
+            channel_attn_v2=False,
     ):
         super().__init__()
 
         self.grad_checkpointing = False
 
         # downsample embedding layer at the beginning of each stage
         if downsample:
-            self.downsample = Downsample(in_chs, out_chs, norm_layer=norm_layer)
+            self.downsample = Downsample(in_chs, out_chs, kernel_size=down_kernel_size, norm_layer=norm_layer)
         else:
             self.downsample = nn.Identity()
 
@@ -399,10 +450,11 @@ def __init__(
         '''
+        from collections import OrderedDict
         stage_blocks = []
         for block_idx in range(depth):
             dual_attention_block = []
             for attn_idx, attn_type in enumerate(attn_types):
                 if attn_type == 'spatial':
-                    dual_attention_block.append(SpatialBlock(
+                    dual_attention_block.append(('spatial_block', SpatialBlock(
                         dim=out_chs,
                         num_heads=num_heads,
                         mlp_ratio=mlp_ratio,
@@ -412,19 +464,23 @@ def __init__(
                         ffn=ffn,
                         cpe_act=cpe_act,
                         window_size=window_size,
-                    ))
+                    )))
                 elif attn_type == 'channel':
-                    dual_attention_block.append(ChannelBlock(
+                    dual_attention_block.append(('channel_block', ChannelBlock(
                         dim=out_chs,
                         num_heads=num_heads,
                         mlp_ratio=mlp_ratio,
                         qkv_bias=qkv_bias,
                         drop_path=drop_path_rates[block_idx],
                         norm_layer=norm_layer_cl,
                         ffn=ffn,
-                        cpe_act=cpe_act
-                    ))
-            stage_blocks.append(nn.Sequential(*dual_attention_block))
+                        cpe_act=cpe_act,
+                        v2=channel_attn_v2,
+                    )))
+            if named_blocks:
+                stage_blocks.append(nn.Sequential(OrderedDict(dual_attention_block)))
+            else:
+                stage_blocks.append(nn.Sequential(*[b[1] for b in dual_attention_block]))
         self.blocks = nn.Sequential(*stage_blocks)
 
     @torch.jit.ignore
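A small illustration of the named_blocks path (nn.Identity stand-ins for SpatialBlock/ChannelBlock): building the per-depth pair from an OrderedDict gives the children the 'spatial_block' / 'channel_block' names, so parameter keys carry those names instead of '0' / '1', matching the Florence-2 style layout that _convert_florence2 below preserves.

from collections import OrderedDict
import torch.nn as nn

pair = nn.Sequential(OrderedDict([
    ('spatial_block', nn.Identity()),   # stand-in for SpatialBlock
    ('channel_block', nn.Identity()),   # stand-in for ChannelBlock
]))
print([name for name, _ in pair.named_children()])  # ['spatial_block', 'channel_block']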
@@ -473,6 +529,9 @@ def __init__(
             attn_types=('spatial', 'channel'),
             ffn=True,
             cpe_act=False,
+            down_kernel_size=2,
+            channel_attn_v2=False,
+            named_blocks=False,
             drop_rate=0.,
             drop_path_rate=0.,
             num_classes=1000,
@@ -512,6 +571,9 @@ def __init__(
                 norm_layer_cl=norm_layer_cl,
                 ffn=ffn,
                 cpe_act=cpe_act,
+                down_kernel_size=down_kernel_size,
+                channel_attn_v2=channel_attn_v2,
+                named_blocks=named_blocks,
             )
             in_chs = out_chs
             stages.append(stage)
@@ -589,6 +651,34 @@ def forward(self, x):
         return x
 
 
+def _convert_florence2(state_dict, model, prefix='vision_tower.'):
+    import re
+    out_dict = {}
+
+    for k, v in state_dict.items():
+        if k.startswith(prefix):
+            k = k.replace(prefix, '')
+        else:
+            continue
+        k = re.sub(r'convs.([0-9]+)', r'stages.\1.downsample', k)
+        k = re.sub(r'blocks.([0-9]+)', r'stages.\1.blocks', k)
+        k = k.replace('downsample.proj', 'downsample.conv')
+        k = k.replace('stages.0.downsample', 'stem')
+        #k = k.replace('head.', 'head.fc.')
+        #k = k.replace('norms.', 'head.norm.')
+        k = k.replace('window_attn.norm.', 'norm1.')
+        k = k.replace('window_attn.fn.', 'attn.')
+        k = k.replace('channel_attn.norm.', 'norm1.')
+        k = k.replace('channel_attn.fn.', 'attn.')
+        k = k.replace('ffn.norm.', 'norm2.')
+        k = k.replace('ffn.fn.net.', 'mlp.')
+        k = k.replace('conv1.fn.dw', 'cpe1.proj')
+        k = k.replace('conv2.fn.dw', 'cpe2.proj')
+        out_dict[k] = v
+
+    return out_dict
+
+
 def checkpoint_filter_fn(state_dict, model):
     """ Remap MSFT checkpoints -> timm """
     if 'head.fc.weight' in state_dict:
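A worked example of the key remapping (placeholder tensor value, assuming this revision is importable as timm.models.davit): the Florence-2 stem key that checkpoint_filter_fn below uses for dispatch comes out as 'stem.conv.weight'.

import torch
from timm.models.davit import _convert_florence2

sd = {'vision_tower.convs.0.proj.weight': torch.zeros(1)}  # placeholder value; only the key matters here
print(list(_convert_florence2(sd, model=None).keys()))     # ['stem.conv.weight']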
@@ -597,6 +687,9 @@ def checkpoint_filter_fn(state_dict, model):
     if 'state_dict' in state_dict:
         state_dict = state_dict['state_dict']
 
+    if 'vision_tower.convs.0.proj.weight' in state_dict:
+        return _convert_florence2(state_dict, model)
+
     import re
     out_dict = {}
     for k, v in state_dict.items():
@@ -615,13 +708,17 @@ def checkpoint_filter_fn(state_dict, model):
 def _create_davit(variant, pretrained=False, **kwargs):
     default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
     out_indices = kwargs.pop('out_indices', default_out_indices)
-
+    strict = True
+    if variant.endswith('_fl'):
+        # FIXME cleaner approach to missing head norm?
+        strict = False
     model = build_model_with_cfg(
         DaVit,
         variant,
         pretrained,
         pretrained_filter_fn=checkpoint_filter_fn,
         feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
+        pretrained_strict=strict,
         **kwargs)
 
     return model
@@ -650,6 +747,12 @@ def _cfg(url='', **kwargs):
     'davit_large': _cfg(),
     'davit_huge': _cfg(),
     'davit_giant': _cfg(),
+    'davit_base_fl.msft_florence2': _cfg(
+        hf_hub_id='microsoft/Florence-2-base',
+        num_classes=0, input_size=(3, 768, 768)),
+    'davit_huge_fl.msft_florence2': _cfg(
+        hf_hub_id='microsoft/Florence-2-large',
+        num_classes=0, input_size=(3, 768, 768)),
 })
 
 
@@ -687,3 +790,23 @@ def davit_huge(pretrained=False, **kwargs) -> DaVit:
 def davit_giant(pretrained=False, **kwargs) -> DaVit:
     model_args = dict(depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072), num_heads=(12, 24, 48, 96))
     return _create_davit('davit_giant', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+
+@register_model
+def davit_base_fl(pretrained=False, **kwargs) -> DaVit:
+    model_args = dict(
+        depths=(1, 1, 9, 1), embed_dims=(128, 256, 512, 1024), num_heads=(4, 8, 16, 32),
+        window_size=12, down_kernel_size=3, channel_attn_v2=True, named_blocks=True,
+    )
+    return _create_davit('davit_base_fl', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def davit_huge_fl(pretrained=False, **kwargs) -> DaVit:
+    # NOTE: huge image tower used in 'large' Florence2 model
+    model_args = dict(
+        depths=(1, 1, 9, 1), embed_dims=(256, 512, 1024, 2048), num_heads=(8, 16, 32, 64),
+        window_size=12, down_kernel_size=3, channel_attn_v2=True, named_blocks=True,
+    )
+    return _create_davit('davit_huge_fl', pretrained=pretrained, **dict(model_args, **kwargs))
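A minimal usage sketch of the new registrations, assuming a timm install that includes this change; pretrained=True would additionally pull the Hugging Face Hub weights referenced in the davit_*_fl cfgs above.

import torch
import timm

# Build the Florence-2 base image tower as a head-less feature extractor.
model = timm.create_model('davit_base_fl', pretrained=False, num_classes=0)
model.eval()

with torch.no_grad():
    feats = model(torch.randn(1, 3, 768, 768))   # default cfg input size is (3, 768, 768)
print(feats.shape)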