1+ from typing import List , Optional
12import torch .nn as nn
23import torch
3- from torch .nn .utils import spectral_norm
4+ from torch .nn .utils . spectral_norm import spectral_norm
45from collections import OrderedDict
56
67
@@ -39,16 +40,28 @@ class ConvBnAct(nn.Sequential):
3940 convolution_module = nn .Conv2d # can be changed in models like twist.
4041 batchnorm_module = nn .BatchNorm2d
4142
42- def __init__ (self , in_channels , out_channels , kernel_size = 3 , stride = 1 ,
43- padding = None , bias = False , groups = 1 ,
44- act_fn = act_fn , pre_act = False ,
45- bn_layer = True , bn_1st = True , zero_bn = False ,
46- ):
43+ def __init__ (
44+ self ,
45+ in_channels : int ,
46+ out_channels : int ,
47+ kernel_size : int = 3 ,
48+ stride : int = 1 ,
49+ padding : Optional [int ] = None ,
50+ bias : bool = False ,
51+ groups : int = 1 ,
52+ act_fn : Optional [nn .Module ] = act_fn ,
53+ pre_act : bool = False ,
54+ bn_layer : bool = True ,
55+ bn_1st : bool = True ,
56+ zero_bn : bool = False ,
57+ ):
4758
4859 if padding is None :
4960 padding = kernel_size // 2
50- layers = [('conv' , self .convolution_module (in_channels , out_channels , kernel_size , stride = stride ,
51- padding = padding , bias = bias , groups = groups ))] # if no bn - bias True?
61+ layers : List [tuple [str , nn .Module ]] = [
62+ ('conv' , self .convolution_module (
63+ in_channels , out_channels , kernel_size , stride = stride , padding = padding , bias = bias , groups = groups ))
64+ ] # if no bn - bias True?
5265 if bn_layer :
5366 bn = self .batchnorm_module (out_channels )
5467 nn .init .constant_ (bn .weight , 0. if zero_bn else 1. )
@@ -133,7 +146,7 @@ def forward(self, x):
133146 return o .view (* size ).contiguous ()
134147
135148
136- class SEBlock (nn .Module ): # todo: deprecation worning .
149+ class SEBlock (nn .Module ): # todo: deprecation warning .
137150 "se block"
138151 se_layer = nn .Linear
139152 act_fn = nn .ReLU (inplace = True )
@@ -157,7 +170,7 @@ def forward(self, x):
157170 return x * y .expand_as (x )
158171
159172
160- class SEBlockConv (nn .Module ): # todo: deprecation worning .
173+ class SEBlockConv (nn .Module ): # todo: deprecation warning .
161174 "se block with conv on excitation"
162175 se_layer = nn .Conv2d
163176 act_fn = nn .ReLU (inplace = True )
0 commit comments