Skip to content

Commit 4efee64

Browse files
committed
move blocks to module
1 parent d620b81 commit 4efee64

File tree

6 files changed

+200
-191
lines changed

6 files changed

+200
-191
lines changed

src/model_constructor/blocks.py

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
from typing import Callable, Union
2+
3+
import torch
4+
from torch import nn
5+
6+
from .helpers import ListStrMod, nn_seq
7+
from .layers import ConvBnAct, get_act
8+
9+
10+
class BasicBlock(nn.Module):
    """Basic ResNet block.

    Two 3x3 convolutions plus an identity path. Configurable: the identity
    path can reduce via `pool` instead of a strided conv, activation is
    swappable, and depthwise/grouped convolutions and SE/SA modules are
    supported.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        conv_layer: type[ConvBnAct] = ConvBnAct,
        act_fn: type[nn.Module] = nn.ReLU,
        zero_bn: bool = True,
        bn_1st: bool = True,
        groups: int = 1,
        dw: bool = False,
        div_groups: Union[None, int] = None,
        pool: Union[Callable[[], nn.Module], None] = None,
        # se / sa are module classes (instantiated as se(out_channels)),
        # not instances — hence type[nn.Module].
        se: Union[type[nn.Module], None] = None,
        sa: Union[type[nn.Module], None] = None,
    ):
        super().__init__()
        # pool defined at ModelConstructor.
        if div_groups is not None:  # if div_groups is given, it overrides groups
            groups = out_channels // div_groups
        layers: ListStrMod = [
            (
                "conv_0",
                conv_layer(
                    in_channels,
                    out_channels,
                    3,
                    stride=stride,
                    act_fn=act_fn,
                    bn_1st=bn_1st,
                    groups=in_channels if dw else groups,
                ),
            ),
            (
                "conv_1",
                conv_layer(
                    out_channels,
                    out_channels,
                    3,
                    zero_bn=zero_bn,
                    act_fn=False,  # no act here: applied after the residual add
                    bn_1st=bn_1st,
                    groups=out_channels if dw else groups,
                ),
            ),
        ]
        if se:
            layers.append(("se", se(out_channels)))
        if sa:
            layers.append(("sa", sa(out_channels)))
        self.convs = nn_seq(layers)
        if stride != 1 or in_channels != out_channels:
            id_layers: ListStrMod = []
            if (
                stride != 1 and pool is not None
            ):  # if pool - reduce by pool, else stride 2 at id_conv
                id_layers.append(("pool", pool()))
            if in_channels != out_channels or (stride != 1 and pool is None):
                id_layers.append(
                    (
                        "id_conv",
                        conv_layer(
                            in_channels,
                            out_channels,
                            1,
                            stride=1 if pool else stride,
                            act_fn=False,
                        ),
                    )
                )
            self.id_conv = nn_seq(id_layers)
        else:
            self.id_conv = None  # channels and resolution match: plain identity
        self.act_fn = get_act(act_fn)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore
        """Residual forward: act_fn(convs(x) + identity(x))."""
        identity = self.id_conv(x) if self.id_conv is not None else x
        return self.act_fn(self.convs(x) + identity)
92+
93+
94+
class BottleneckBlock(nn.Module):
    """Bottleneck ResNet block.

    1x1 reduce -> 3x3 -> 1x1 expand, plus an identity path. Configurable:
    the identity path can reduce via `pool` instead of a strided conv,
    activation is swappable, and depthwise/grouped convolutions and SE/SA
    modules are supported.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        expansion: int = 4,
        conv_layer: type[ConvBnAct] = ConvBnAct,
        act_fn: type[nn.Module] = nn.ReLU,
        zero_bn: bool = True,
        bn_1st: bool = True,
        groups: int = 1,
        dw: bool = False,
        div_groups: Union[None, int] = None,
        pool: Union[Callable[[], nn.Module], None] = None,
        # se / sa are module classes (instantiated as se(out_channels)),
        # not instances — hence type[nn.Module].
        se: Union[type[nn.Module], None] = None,
        sa: Union[type[nn.Module], None] = None,
    ):
        super().__init__()
        # pool defined at ModelConstructor.
        mid_channels = out_channels // expansion  # bottleneck width
        if div_groups is not None:  # if div_groups is given, it overrides groups
            groups = mid_channels // div_groups
        layers: ListStrMod = [
            (
                "conv_0",
                conv_layer(
                    in_channels,
                    mid_channels,
                    1,
                    act_fn=act_fn,
                    bn_1st=bn_1st,
                ),
            ),
            (
                "conv_1",
                conv_layer(
                    mid_channels,
                    mid_channels,
                    3,
                    stride=stride,
                    act_fn=act_fn,
                    bn_1st=bn_1st,
                    groups=mid_channels if dw else groups,
                ),
            ),
            (
                "conv_2",
                conv_layer(
                    mid_channels,
                    out_channels,
                    1,
                    zero_bn=zero_bn,
                    act_fn=False,  # no act here: applied after the residual add
                    bn_1st=bn_1st,
                ),
            ),
        ]
        if se:
            layers.append(("se", se(out_channels)))
        if sa:
            layers.append(("sa", sa(out_channels)))
        self.convs = nn_seq(layers)
        if stride != 1 or in_channels != out_channels:
            id_layers: ListStrMod = []
            if (
                stride != 1 and pool is not None
            ):  # if pool - reduce by pool, else stride 2 at id_conv
                id_layers.append(("pool", pool()))
            if in_channels != out_channels or (stride != 1 and pool is None):
                id_layers.append(
                    (
                        "id_conv",
                        conv_layer(
                            in_channels,
                            out_channels,
                            1,
                            stride=1 if pool else stride,
                            act_fn=False,
                        ),
                    )
                )
            self.id_conv = nn_seq(id_layers)
        else:
            self.id_conv = None  # channels and resolution match: plain identity
        self.act_fn = get_act(act_fn)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore
        """Residual forward: act_fn(convs(x) + identity(x))."""
        identity = self.id_conv(x) if self.id_conv is not None else x
        return self.act_fn(self.convs(x) + identity)

src/model_constructor/helpers.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
from torch import nn
77

88

9+
# Alias for a list of (name, module) pairs, the input format for nn_seq.
ListStrMod = list[tuple[str, nn.Module]]


def nn_seq(list_of_tuples: Iterable[tuple[str, nn.Module]]) -> nn.Sequential:
    """Build an nn.Sequential with named children from (name, module) pairs."""
    named_modules = OrderedDict(list_of_tuples)
    return nn.Sequential(named_modules)

0 commit comments

Comments
 (0)