Commit e444c68: layers typing

Parent: 7a4f0aa

1 file changed: 23 additions, 10 deletions

model_constructor/layers.py
@@ -1,6 +1,7 @@
+from typing import List, Optional
 import torch.nn as nn
 import torch
-from torch.nn.utils import spectral_norm
+from torch.nn.utils.spectral_norm import spectral_norm
 from collections import OrderedDict
 
 
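Both import forms resolve to the same spectral_norm function; the new path just names the symbol explicitly, which helps type checkers resolve it. A minimal usage sketch (generic torch code, not taken from this file):

    import torch.nn as nn
    from torch.nn.utils.spectral_norm import spectral_norm

    # Wraps the module so its weight is spectrally normalized
    # (divided by its largest singular value) on every forward pass.
    conv = spectral_norm(nn.Conv2d(3, 64, kernel_size=3, padding=1))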
@@ -39,16 +40,28 @@ class ConvBnAct(nn.Sequential):
     convolution_module = nn.Conv2d  # can be changed in models like twist.
     batchnorm_module = nn.BatchNorm2d
 
-    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
-                 padding=None, bias=False, groups=1,
-                 act_fn=act_fn, pre_act=False,
-                 bn_layer=True, bn_1st=True, zero_bn=False,
-                 ):
+    def __init__(
+            self,
+            in_channels: int,
+            out_channels: int,
+            kernel_size: int = 3,
+            stride: int = 1,
+            padding: Optional[int] = None,
+            bias: bool = False,
+            groups: int = 1,
+            act_fn: Optional[nn.Module] = act_fn,
+            pre_act: bool = False,
+            bn_layer: bool = True,
+            bn_1st: bool = True,
+            zero_bn: bool = False,
+    ):
 
         if padding is None:
             padding = kernel_size // 2
-        layers = [('conv', self.convolution_module(in_channels, out_channels, kernel_size, stride=stride,
-                                                   padding=padding, bias=bias, groups=groups))]  # if no bn - bias True?
+        layers: List[tuple[str, nn.Module]] = [
+            ('conv', self.convolution_module(
+                in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias, groups=groups))
+        ]  # if no bn - bias True?
         if bn_layer:
             bn = self.batchnorm_module(out_channels)
             nn.init.constant_(bn.weight, 0. if zero_bn else 1.)
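With the new signature, construction is unchanged for callers but now type-annotated. A minimal usage sketch, assuming ConvBnAct is imported from model_constructor.layers and all defaults (including the module-level act_fn) are left in place:

    import torch
    from model_constructor.layers import ConvBnAct

    # conv -> bn -> act, with padding inferred as kernel_size // 2 == 1
    block = ConvBnAct(in_channels=3, out_channels=64, kernel_size=3, stride=2)
    out = block(torch.randn(1, 3, 32, 32))
    print(out.shape)  # torch.Size([1, 64, 16, 16])

One note on the layers annotation: List comes from typing while the element type uses the builtin tuple[...] generic, which is only subscriptable on Python 3.9+. A local-variable annotation like this one is never evaluated at runtime, so older interpreters will not complain, though some type checkers may.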
@@ -133,7 +146,7 @@ def forward(self, x):
         return o.view(*size).contiguous()
 
 
-class SEBlock(nn.Module):  # todo: deprecation worning.
+class SEBlock(nn.Module):  # todo: deprecation warning.
     "se block"
     se_layer = nn.Linear
     act_fn = nn.ReLU(inplace=True)
@@ -157,7 +170,7 @@ def forward(self, x):
         return x * y.expand_as(x)
 
 
-class SEBlockConv(nn.Module):  # todo: deprecation worning.
+class SEBlockConv(nn.Module):  # todo: deprecation warning.
     "se block with conv on excitation"
     se_layer = nn.Conv2d
     act_fn = nn.ReLU(inplace=True)
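The SEBlock fragments above follow the standard squeeze-and-excitation pattern: global average pooling squeezes each channel to a scalar, a small excitation network produces per-channel gates, and x * y.expand_as(x) rescales the input. An illustrative standalone version of that pattern (not the repository's exact code; the class name and reduction ratio are assumptions):

    import torch
    import torch.nn as nn

    class SqueezeExcite(nn.Module):
        """Illustrative squeeze-and-excitation block."""
        def __init__(self, channels: int, reduction: int = 16):
            super().__init__()
            self.pool = nn.AdaptiveAvgPool2d(1)  # squeeze: (B, C, H, W) -> (B, C, 1, 1)
            self.excite = nn.Sequential(
                nn.Linear(channels, channels // reduction),
                nn.ReLU(inplace=True),
                nn.Linear(channels // reduction, channels),
                nn.Sigmoid(),  # per-channel gates in [0, 1]
            )

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            b, c, _, _ = x.shape
            y = self.excite(self.pool(x).view(b, c)).view(b, c, 1, 1)
            return x * y.expand_as(x)  # rescale channels, as in SEBlock.forward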
