 # ---------------------------------------------------------------
 # Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
 #
-# This work is licensed under the NVIDIA Source Code License
+# Licensed under the NVIDIA Source Code License. For full license
+# terms, please refer to the LICENSE file provided with this code
+# or visit NVIDIA's official repository at
+# https://github.com/NVlabs/SegFormer/tree/master.
+#
+# This code has been modified.
 # ---------------------------------------------------------------
 import math
 import torch
 from timm.layers import DropPath, to_2tuple, trunc_normal_
 
 
+class LayerNorm(nn.LayerNorm):
+    def forward(self, x):
+        if x.ndim == 4:
+            B, C, H, W = x.shape
+            x = x.view(B, C, -1).transpose(1, 2)
+            x = super().forward(x)
+            x = x.transpose(1, 2).view(B, C, H, W)
+        else:
+            x = super().forward(x)
+        return x
+
+
 class Mlp(nn.Module):
     def __init__(
         self,
@@ -36,9 +53,6 @@ def _init_weights(self, m):
             trunc_normal_(m.weight, std=0.02)
             if isinstance(m, nn.Linear) and m.bias is not None:
                 nn.init.constant_(m.bias, 0)
-        elif isinstance(m, nn.LayerNorm):
-            nn.init.constant_(m.bias, 0)
-            nn.init.constant_(m.weight, 1.0)
         elif isinstance(m, nn.Conv2d):
             fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
             fan_out //= m.groups
@@ -86,7 +100,7 @@ def __init__(
         self.sr_ratio = sr_ratio
         if sr_ratio > 1:
             self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
-            self.norm = nn.LayerNorm(dim)
+            self.norm = LayerNorm(dim)
 
         self.apply(self._init_weights)
 
@@ -95,7 +109,7 @@ def _init_weights(self, m):
             trunc_normal_(m.weight, std=0.02)
             if isinstance(m, nn.Linear) and m.bias is not None:
                 nn.init.constant_(m.bias, 0)
-        elif isinstance(m, nn.LayerNorm):
+        elif isinstance(m, LayerNorm):
             nn.init.constant_(m.bias, 0)
             nn.init.constant_(m.weight, 1.0)
         elif isinstance(m, nn.Conv2d):
@@ -153,7 +167,7 @@ def __init__(
         attn_drop=0.0,
         drop_path=0.0,
         act_layer=nn.GELU,
-        norm_layer=nn.LayerNorm,
+        norm_layer=LayerNorm,
         sr_ratio=1,
     ):
         super().__init__()
@@ -185,7 +199,7 @@ def _init_weights(self, m):
             trunc_normal_(m.weight, std=0.02)
             if isinstance(m, nn.Linear) and m.bias is not None:
                 nn.init.constant_(m.bias, 0)
-        elif isinstance(m, nn.LayerNorm):
+        elif isinstance(m, LayerNorm):
             nn.init.constant_(m.bias, 0)
             nn.init.constant_(m.weight, 1.0)
         elif isinstance(m, nn.Conv2d):
@@ -195,10 +209,12 @@ def _init_weights(self, m):
             if m.bias is not None:
                 m.bias.data.zero_()
 
-    def forward(self, x, H, W):
+    def forward(self, x):
+        B, _, H, W = x.shape
+        x = x.flatten(2).transpose(1, 2)
         x = x + self.drop_path(self.attn(self.norm1(x), H, W))
         x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
-
+        x = x.transpose(1, 2).view(B, -1, H, W)
         return x
 
 
@@ -221,7 +237,7 @@ def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=7
             stride=stride,
             padding=(patch_size[0] // 2, patch_size[1] // 2),
         )
-        self.norm = nn.LayerNorm(embed_dim)
+        self.norm = LayerNorm(embed_dim)
 
         self.apply(self._init_weights)
 
@@ -230,7 +246,7 @@ def _init_weights(self, m):
             trunc_normal_(m.weight, std=0.02)
             if isinstance(m, nn.Linear) and m.bias is not None:
                 nn.init.constant_(m.bias, 0)
-        elif isinstance(m, nn.LayerNorm):
+        elif isinstance(m, LayerNorm):
             nn.init.constant_(m.bias, 0)
             nn.init.constant_(m.weight, 1.0)
         elif isinstance(m, nn.Conv2d):
@@ -242,11 +258,8 @@ def _init_weights(self, m):
 
     def forward(self, x):
         x = self.proj(x)
-        _, _, H, W = x.shape
-        x = x.flatten(2).transpose(1, 2)
         x = self.norm(x)
-
-        return x, H, W
+        return x
 
 
 class MixVisionTransformer(nn.Module):
@@ -264,7 +277,7 @@ def __init__(
         drop_rate=0.0,
         attn_drop_rate=0.0,
         drop_path_rate=0.0,
-        norm_layer=nn.LayerNorm,
+        norm_layer=LayerNorm,
         depths=[3, 4, 6, 3],
         sr_ratios=[8, 4, 2, 1],
     ):
@@ -307,8 +320,8 @@ def __init__(
             x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
         ]  # stochastic depth decay rule
         cur = 0
-        self.block1 = nn.ModuleList(
-            [
+        self.block1 = nn.Sequential(
+            *[
                 Block(
                     dim=embed_dims[0],
                     num_heads=num_heads[0],
@@ -327,8 +340,8 @@ def __init__(
         self.norm1 = norm_layer(embed_dims[0])
 
         cur += depths[0]
-        self.block2 = nn.ModuleList(
-            [
+        self.block2 = nn.Sequential(
+            *[
                 Block(
                     dim=embed_dims[1],
                     num_heads=num_heads[1],
@@ -347,8 +360,8 @@ def __init__(
         self.norm2 = norm_layer(embed_dims[1])
 
         cur += depths[1]
-        self.block3 = nn.ModuleList(
-            [
+        self.block3 = nn.Sequential(
+            *[
                 Block(
                     dim=embed_dims[2],
                     num_heads=num_heads[2],
@@ -367,8 +380,8 @@ def __init__(
         self.norm3 = norm_layer(embed_dims[2])
 
         cur += depths[2]
-        self.block4 = nn.ModuleList(
-            [
+        self.block4 = nn.Sequential(
+            *[
                 Block(
                     dim=embed_dims[3],
                     num_heads=num_heads[3],
@@ -396,7 +409,7 @@ def _init_weights(self, m):
             trunc_normal_(m.weight, std=0.02)
             if isinstance(m, nn.Linear) and m.bias is not None:
                 nn.init.constant_(m.bias, 0)
-        elif isinstance(m, nn.LayerNorm):
+        elif isinstance(m, LayerNorm):
             nn.init.constant_(m.bias, 0)
             nn.init.constant_(m.weight, 1.0)
         elif isinstance(m, nn.Conv2d):
@@ -450,39 +463,30 @@ def reset_classifier(self, num_classes, global_pool=""):
         )
 
     def forward_features(self, x):
-        B = x.shape[0]
         outs = []
 
         # stage 1
-        x, H, W = self.patch_embed1(x)
-        for i, blk in enumerate(self.block1):
-            x = blk(x, H, W)
-        x = self.norm1(x)
-        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
+        x = self.patch_embed1(x)
+        x = self.block1(x)
+        x = self.norm1(x).contiguous()
         outs.append(x)
 
         # stage 2
-        x, H, W = self.patch_embed2(x)
-        for i, blk in enumerate(self.block2):
-            x = blk(x, H, W)
-        x = self.norm2(x)
-        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
+        x = self.patch_embed2(x)
+        x = self.block2(x)
+        x = self.norm2(x).contiguous()
         outs.append(x)
 
         # stage 3
-        x, H, W = self.patch_embed3(x)
-        for i, blk in enumerate(self.block3):
-            x = blk(x, H, W)
-        x = self.norm3(x)
-        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
+        x = self.patch_embed3(x)
+        x = self.block3(x)
+        x = self.norm3(x).contiguous()
         outs.append(x)
 
         # stage 4
-        x, H, W = self.patch_embed4(x)
-        for i, blk in enumerate(self.block4):
-            x = blk(x, H, W)
-        x = self.norm4(x)
-        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
+        x = self.patch_embed4(x)
+        x = self.block4(x)
+        x = self.norm4(x).contiguous()
         outs.append(x)
 
         return outs
@@ -500,7 +504,7 @@ def __init__(self, dim=768):
         self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
 
     def forward(self, x, H, W):
-        B, N, C = x.shape
+        B, _, C = x.shape
         x = x.transpose(1, 2).view(B, C, H, W)
         x = self.dwconv(x)
         x = x.flatten(2).transpose(1, 2)
@@ -522,21 +526,31 @@ def __init__(self, out_channels, depth=5, **kwargs):
         self._depth = depth
         self._in_channels = 3
 
-    def make_dilated(self, *args, **kwargs):
-        raise ValueError("MixVisionTransformer encoder does not support dilated mode")
-
-    def set_in_channels(self, in_channels, *args, **kwargs):
-        if in_channels != 3:
-            raise ValueError(
-                "MixVisionTransformer encoder does not support in_channels setting other than 3"
-            )
+    def get_stages(self):
+        return [
+            nn.Identity(),
+            nn.Identity(),
+            nn.Sequential(self.patch_embed1, self.block1, self.norm1),
+            nn.Sequential(self.patch_embed2, self.block2, self.norm2),
+            nn.Sequential(self.patch_embed3, self.block3, self.norm3),
+            nn.Sequential(self.patch_embed4, self.block4, self.norm4),
+        ]
 
     def forward(self, x):
+        stages = self.get_stages()
+
         # create dummy output for the first block
-        B, C, H, W = x.shape
+        B, _, H, W = x.shape
         dummy = torch.empty([B, 0, H // 2, W // 2], dtype=x.dtype, device=x.device)
 
-        return [x, dummy] + self.forward_features(x)[: self._depth - 1]
+        features = []
+        for i in range(self._depth + 1):
+            if i == 1:
+                features.append(dummy)
+            else:
+                x = stages[i](x).contiguous()
+                features.append(x)
+        return features
 
     def load_state_dict(self, state_dict):
         state_dict.pop("head.weight", None)
@@ -568,7 +582,7 @@ def get_pretrained_cfg(name):
             num_heads=[1, 2, 5, 8],
             mlp_ratios=[4, 4, 4, 4],
             qkv_bias=True,
-            norm_layer=partial(nn.LayerNorm, eps=1e-6),
+            norm_layer=partial(LayerNorm, eps=1e-6),
             depths=[2, 2, 2, 2],
             sr_ratios=[8, 4, 2, 1],
             drop_rate=0.0,
@@ -585,7 +599,7 @@ def get_pretrained_cfg(name):
             num_heads=[1, 2, 5, 8],
             mlp_ratios=[4, 4, 4, 4],
             qkv_bias=True,
-            norm_layer=partial(nn.LayerNorm, eps=1e-6),
+            norm_layer=partial(LayerNorm, eps=1e-6),
             depths=[2, 2, 2, 2],
             sr_ratios=[8, 4, 2, 1],
             drop_rate=0.0,
@@ -602,7 +616,7 @@ def get_pretrained_cfg(name):
             num_heads=[1, 2, 5, 8],
             mlp_ratios=[4, 4, 4, 4],
             qkv_bias=True,
-            norm_layer=partial(nn.LayerNorm, eps=1e-6),
+            norm_layer=partial(LayerNorm, eps=1e-6),
             depths=[3, 4, 6, 3],
             sr_ratios=[8, 4, 2, 1],
             drop_rate=0.0,
@@ -619,7 +633,7 @@ def get_pretrained_cfg(name):
             num_heads=[1, 2, 5, 8],
             mlp_ratios=[4, 4, 4, 4],
             qkv_bias=True,
-            norm_layer=partial(nn.LayerNorm, eps=1e-6),
+            norm_layer=partial(LayerNorm, eps=1e-6),
             depths=[3, 4, 18, 3],
             sr_ratios=[8, 4, 2, 1],
             drop_rate=0.0,
@@ -636,7 +650,7 @@ def get_pretrained_cfg(name):
             num_heads=[1, 2, 5, 8],
             mlp_ratios=[4, 4, 4, 4],
             qkv_bias=True,
-            norm_layer=partial(nn.LayerNorm, eps=1e-6),
+            norm_layer=partial(LayerNorm, eps=1e-6),
             depths=[3, 8, 27, 3],
             sr_ratios=[8, 4, 2, 1],
             drop_rate=0.0,
@@ -653,7 +667,7 @@ def get_pretrained_cfg(name):
             num_heads=[1, 2, 5, 8],
             mlp_ratios=[4, 4, 4, 4],
             qkv_bias=True,
-            norm_layer=partial(nn.LayerNorm, eps=1e-6),
+            norm_layer=partial(LayerNorm, eps=1e-6),
             depths=[3, 6, 40, 3],
             sr_ratios=[8, 4, 2, 1],
             drop_rate=0.0,
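
For reference, a minimal, self-contained sketch (not part of the diff) of how the channel-first LayerNorm wrapper introduced above behaves. The class body is copied from the added lines; comparing it against a plain nn.LayerNorm applied to flattened (B, N, C) tokens is an illustrative assumption, not code from the repository.

import torch
import torch.nn as nn


class LayerNorm(nn.LayerNorm):
    # Copied from the change above: accepts either (B, N, C) tokens or a
    # (B, C, H, W) feature map and normalizes over the channel dimension.
    def forward(self, x):
        if x.ndim == 4:
            B, C, H, W = x.shape
            x = x.view(B, C, -1).transpose(1, 2)
            x = super().forward(x)
            x = x.transpose(1, 2).view(B, C, H, W)
        else:
            x = super().forward(x)
        return x


x = torch.randn(2, 64, 16, 16)           # (B, C, H, W) feature map
norm = LayerNorm(64, eps=1e-6)
ref = nn.LayerNorm(64, eps=1e-6)
ref.load_state_dict(norm.state_dict())   # share the affine parameters

# Reference path: flatten to (B, N, C) tokens, normalize, reshape back,
# mirroring what the old code did with explicit H and W bookkeeping.
tokens = x.flatten(2).transpose(1, 2)
expected = ref(tokens).transpose(1, 2).view(2, 64, 16, 16)

assert torch.allclose(norm(x), expected, atol=1e-6)
print(norm(x).shape)  # torch.Size([2, 64, 16, 16])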