
Commit 8a8eb35

mc init
1 parent 8d97595 commit 8a8eb35

2 files changed: +50 -19 lines changed


model_constructor/layers.py

Lines changed: 3 additions & 3 deletions
@@ -66,7 +66,7 @@ def __init__(
             bn = self.batchnorm_module(out_channels)
             nn.init.constant_(bn.weight, 0. if zero_bn else 1.)
             layers.append(('bn', bn))
-        if act_fn:
+        if isinstance(act_fn, nn.Module):  # act_fn either nn.Module or False
             if pre_act:
                 act_position = 0
             elif not bn_1st:

@@ -111,7 +111,7 @@ def conv1d(ni: int, no: int, ks: int = 1, stride: int = 1, padding: int = 0, bia
     conv = nn.Conv1d(ni, no, ks, stride=stride, padding=padding, bias=bias)
     nn.init.kaiming_normal_(conv.weight)
     if bias:
-        conv.bias.data.zero_()
+        conv.bias.data.zero_()  # type: ignore
     return spectral_norm(conv)


@@ -125,7 +125,7 @@ class SimpleSelfAttention(nn.Module):
     def __init__(self, n_in: int, ks=1, sym=False, use_bias=False):
         super().__init__()
         self.conv = conv1d(n_in, n_in, ks, padding=ks // 2, bias=use_bias)
-        self.gamma = nn.Parameter(torch.tensor([0.]))
+        self.gamma = torch.nn.Parameter(torch.tensor([0.]))  # type: ignore
         self.sym = sym
         self.n_in = n_in

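Note: the first hunk tightens the activation check. "if act_fn:" becomes "isinstance(act_fn, nn.Module)", so an activation layer is inserted only when a real module is passed, and it can be switched off explicitly with act_fn=False (as the block's second conv in model_constructor.py already does). Below is a minimal sketch of both call styles, assuming ConvBnAct is importable from model_constructor.layers and takes (in_channels, out_channels, kernel_size) positionally as in the hunks above:

import torch
import torch.nn as nn

from model_constructor.layers import ConvBnAct  # assumed import path

# conv -> bn -> act: act_fn is an nn.Module instance
conv_act = ConvBnAct(3, 64, 3, stride=2, act_fn=nn.ReLU(inplace=True))

# conv -> bn only: act_fn=False skips the activation,
# which the new isinstance(act_fn, nn.Module) check handles explicitly
conv_plain = ConvBnAct(64, 64, 3, act_fn=False)

out = conv_plain(conv_act(torch.randn(1, 3, 32, 32)))
print(out.shape)  # spatial size depends on ConvBnAct's default padding
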
model_constructor/model_constructor.py

Lines changed: 47 additions & 16 deletions
@@ -1,6 +1,6 @@
 from collections import OrderedDict
 from functools import partial
-from typing import Callable, List, Sequence, Union
+from typing import Callable, List, Type, Union

 import torch.nn as nn

@@ -16,7 +16,7 @@
 def init_cnn(module: nn.Module):
     "Init module - kaiming_normal for Conv2d and 0 for biases."
     if getattr(module, 'bias', None) is not None:
-        nn.init.constant_(module.bias, 0)
+        nn.init.constant_(module.bias, 0)  # type: ignore
     if isinstance(module, (nn.Conv2d, nn.Linear)):
         nn.init.kaiming_normal_(module.weight)
     for layer in module.children():

@@ -32,7 +32,7 @@ def __init__(
             in_channels: int,
             mid_channels: int,
             stride: int = 1,
-            conv_layer: Union[nn.Module, nn.Sequential] = ConvBnAct,
+            conv_layer=ConvBnAct,
             act_fn: nn.Module = act_fn,
             zero_bn: bool = True,
             bn_1st: bool = True,

@@ -49,7 +49,7 @@ def __init__(
         if div_groups is not None:  # check if groups != 1 and div_groups
             groups = int(mid_channels / div_groups)
         if expansion == 1:
-            layers = [("conv_0", conv_layer(in_channels, mid_channels, 3, stride=stride,
+            layers = [("conv_0", conv_layer(in_channels, mid_channels, 3, stride=stride,  # type: ignore
                        act_fn=act_fn, bn_1st=bn_1st, groups=in_channels if dw else groups)),
                       ("conv_1", conv_layer(mid_channels, out_channels, 3, zero_bn=zero_bn,
                        act_fn=False, bn_1st=bn_1st, groups=mid_channels if dw else groups))

@@ -99,7 +99,8 @@ def _make_stem(self):

 def _make_layer(self, layer_num: int) -> nn.Module:
     # expansion, in_channels, out_channels, blocks, stride, sa):
-    stride = 1 if self.stem_pool and layer_num == 0 else 2  # if no pool on stem - stride = 2 for first layer block in body
+    # if no pool on stem - stride = 2 for first layer block in body
+    stride = 1 if self.stem_pool and layer_num == 0 else 2
     num_blocks = self.layers[layer_num]
     return nn.Sequential(OrderedDict([
         (f"bl_{block_num}", self.block(

@@ -144,22 +145,22 @@ def __init__(
             conv_layer=ConvBnAct,
             block_sizes: List[int] = [64, 128, 256, 512],
             layers: List[int] = [2, 2, 2, 2],
-            norm: nn.Module = nn.BatchNorm2d,
+            norm: Type[nn.Module] = nn.BatchNorm2d,
             act_fn: nn.Module = nn.ReLU(inplace=True),
             pool: nn.Module = nn.AvgPool2d(2, ceil_mode=True),
             expansion: int = 1,
             groups: int = 1,
             dw: bool = False,
-            div_groups: Union[int, None]=None,
-            sa: Union[bool, int, Callable] = False,
-            se: Union[bool, int, Callable] = False,
+            div_groups: Union[int, None] = None,
+            sa: Union[bool, int, Type[nn.Module]] = False,
+            se: Union[bool, int, Type[nn.Module]] = False,
             se_module=None,
             se_reduction=None,
             bn_1st: bool = True,
             zero_bn: bool = True,
             stem_stride_on: int = 0,
             stem_sizes: List[int] = [32, 32, 64],
-            stem_pool: Union[nn.Module, None] =nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
+            stem_pool: Union[Type[nn.Module], None] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1),  # type: ignore
             stem_bn_end: bool = False,
             _init_cnn: Callable = init_cnn,
             _make_stem: Callable = _make_stem,

@@ -172,24 +173,54 @@ def __init__(
         # se_module - deprecated. Leaved for warning and checks.
         # if stem_pool is False - no pool at stem

-        params = locals()
-        del params['self']
-        self.__dict__ = params
-
-        self._block_sizes = params['block_sizes']
+        self.name = name
+        self.in_chans = in_chans
+        self.num_classes = num_classes
+        self.block = block
+        self.conv_layer = conv_layer
+        self._block_sizes = block_sizes
+        self.layers = layers
+        self.norm = norm
+        self.act_fn = act_fn
+        self.pool = pool
+        self.expansion = expansion
+        self.groups = groups
+        self.dw = dw
+        self.div_groups = div_groups
+        # se_module
+        # se_reduction
+        self.bn_1st = bn_1st
+        self.zero_bn = zero_bn
+        self.stem_stride_on = stem_stride_on
+        self.stem_pool = stem_pool
+        self.stem_bn_end = stem_bn_end
+        self._init_cnn = _init_cnn
+        self._make_stem = _make_stem
+        self._make_layer = _make_layer
+        self._make_body = _make_body
+        self._make_head = _make_head
+
+        # params = locals()
+        # del params['self']
+        # self.__dict__ = params
+
+        # self._block_sizes = params['block_sizes']
+        self.stem_sizes = stem_sizes
         if self.stem_sizes[0] != self.in_chans:
             self.stem_sizes = [self.in_chans] + self.stem_sizes
+        self.se = se
         if self.se:
             if type(self.se) in (bool, int):  # if se=1 or se=True
                 self.se = SEModule
             else:
                 self.se = se  # TODO add check issubclass or isinstance of nn.Module
+        self.sa = sa
         if self.sa:  # if sa=1 or sa=True
             if type(self.sa) in (bool, int):
                 self.sa = SimpleSelfAttention  # default: ks=1, sym=sym
             else:
                 self.sa = sa
-        if self.se_module or se_reduction:  # pragma: no cover
+        if se_module or se_reduction:  # pragma: no cover
             print("Deprecated. Pass se_module as se argument, se_reduction as arg to se.")  # add deprecation warning.

     @property