Skip to content

Commit d691872

Browse files
committed
typing
1 parent e444c68 commit d691872

File tree

2 files changed

+57
-36
lines changed

2 files changed

+57
-36
lines changed

model_constructor/layers.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
from typing import List, Optional
2-
import torch.nn as nn
3-
import torch
4-
from torch.nn.utils.spectral_norm import spectral_norm
51
from collections import OrderedDict
2+
from typing import List, Optional, Union
63

4+
import torch
5+
import torch.nn as nn
6+
from torch.nn.utils.spectral_norm import spectral_norm
77

88
__all__ = ['Flatten', 'noop', 'Noop', 'ConvLayer', 'act_fn',
99
'conv1d', 'SimpleSelfAttention', 'SEBlock', 'SEBlockConv']
@@ -49,7 +49,7 @@ def __init__(
4949
padding: Optional[int] = None,
5050
bias: bool = False,
5151
groups: int = 1,
52-
act_fn: Optional[nn.Module] = act_fn,
52+
act_fn: Union[nn.Module, bool] = act_fn,
5353
pre_act: bool = False,
5454
bn_layer: bool = True,
5555
bn_1st: bool = True,

model_constructor/model_constructor.py

Lines changed: 52 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from collections import OrderedDict
22
from functools import partial
3-
from typing import Callable, Union
3+
from typing import Callable, List, Sequence, Union
44

55
import torch.nn as nn
66

@@ -24,14 +24,25 @@ def init_cnn(module: nn.Module):
2424

2525

2626
class ResBlock(nn.Module):
27-
'''Resnet block'''
28-
29-
def __init__(self, expansion, in_channels, mid_channels, stride=1,
30-
conv_layer=ConvBnAct, act_fn=act_fn, zero_bn=True, bn_1st=True,
31-
groups=1, dw=False, div_groups=None,
32-
pool=None,
33-
se=None, sa=None
34-
):
27+
'''Universal Resnet block. Basic block if expansion is 1, otherwise is Bottleneck.'''
28+
29+
def __init__(
30+
self,
31+
expansion: int,
32+
in_channels: int,
33+
mid_channels: int,
34+
stride: int = 1,
35+
conv_layer: Union[nn.Module, nn.Sequential] = ConvBnAct,
36+
act_fn: nn.Module = act_fn,
37+
zero_bn: bool = True,
38+
bn_1st: bool = True,
39+
groups: int = 1,
40+
dw: bool = False,
41+
div_groups: Union[None, int] = None,
42+
pool: Union[nn.Module, None] = None,
43+
se: Union[nn.Module, None] = None,
44+
sa: Union[nn.Module, None] = None,
45+
):
3546
super().__init__()
3647
# pool defined at ModelConstructor.
3748
out_channels, in_channels = mid_channels * expansion, in_channels * expansion
@@ -124,28 +135,38 @@ def _make_head(self):
124135

125136
class ModelConstructor():
126137
"""Model constructor. As default - xresnet18"""
127-
def __init__(self, name='MC', in_chans=3, num_classes=1000,
128-
block=ResBlock, conv_layer=ConvBnAct,
129-
block_sizes=[64, 128, 256, 512], layers=[2, 2, 2, 2],
130-
norm=nn.BatchNorm2d,
131-
act_fn=nn.ReLU(inplace=True),
132-
pool=nn.AvgPool2d(2, ceil_mode=True),
133-
expansion=1, groups=1, dw=False, div_groups=None,
134-
sa: Union[bool, int, Callable] = False,
135-
se: Union[bool, int, Callable] = False,
136-
se_module=None, se_reduction=None,
137-
bn_1st=True,
138-
zero_bn=True,
139-
stem_stride_on=0,
140-
stem_sizes=[32, 32, 64],
141-
stem_pool=nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
142-
stem_bn_end=False,
143-
_init_cnn=init_cnn,
144-
_make_stem=_make_stem,
145-
_make_layer=_make_layer,
146-
_make_body=_make_body,
147-
_make_head=_make_head,
148-
):
138+
def __init__(
139+
self,
140+
name: str = 'MC',
141+
in_chans: int = 3,
142+
num_classes: int = 1000,
143+
block=ResBlock,
144+
conv_layer=ConvBnAct,
145+
block_sizes: List[int] = [64, 128, 256, 512],
146+
layers: List[int] = [2, 2, 2, 2],
147+
norm: nn.Module = nn.BatchNorm2d,
148+
act_fn: nn.Module = nn.ReLU(inplace=True),
149+
pool: nn.Module = nn.AvgPool2d(2, ceil_mode=True),
150+
expansion: int = 1,
151+
groups: int = 1,
152+
dw: bool = False,
153+
div_groups=None,
154+
sa: Union[bool, int, Callable] = False,
155+
se: Union[bool, int, Callable] = False,
156+
se_module=None,
157+
se_reduction=None,
158+
bn_1st=True,
159+
zero_bn=True,
160+
stem_stride_on=0,
161+
stem_sizes=[32, 32, 64],
162+
stem_pool=nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
163+
stem_bn_end=False,
164+
_init_cnn=init_cnn,
165+
_make_stem=_make_stem,
166+
_make_layer=_make_layer,
167+
_make_body=_make_body,
168+
_make_head=_make_head,
169+
):
149170
super().__init__()
150171
# se can be bool, int (0, 1) or nn.Module
151172
# se_module - deprecated. Leaved for warning and checks.

0 commit comments

Comments
 (0)