Skip to content

Commit c077b8d

Browse files
committed
div_groups
1 parent eac691b commit c077b8d

File tree

3 files changed

+92
-25
lines changed

3 files changed

+92
-25
lines changed

Nbs/04_YaResNet.ipynb

Lines changed: 80 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
},
1212
{
1313
"cell_type": "code",
14-
"execution_count": 5,
14+
"execution_count": 1,
1515
"source": [
1616
"#hide\n",
1717
"# from nbdev.showdoc import *\n",
@@ -36,7 +36,7 @@
3636
},
3737
{
3838
"cell_type": "code",
39-
"execution_count": 1,
39+
"execution_count": 2,
4040
"source": [
4141
"#hide\n",
4242
"from model_constructor.yaresnet import YaResBlock"
@@ -46,7 +46,7 @@
4646
},
4747
{
4848
"cell_type": "code",
49-
"execution_count": 2,
49+
"execution_count": 3,
5050
"source": [
5151
"#collapse_output\n",
5252
"bl = YaResBlock(1,64,64,sa=True)\n",
@@ -77,14 +77,14 @@
7777
]
7878
},
7979
"metadata": {},
80-
"execution_count": 2
80+
"execution_count": 3
8181
}
8282
],
8383
"metadata": {}
8484
},
8585
{
8686
"cell_type": "code",
87-
"execution_count": 6,
87+
"execution_count": 4,
8888
"source": [
8989
"#hide\n",
9090
"bs_test = 16\n",
@@ -106,7 +106,7 @@
106106
},
107107
{
108108
"cell_type": "code",
109-
"execution_count": 7,
109+
"execution_count": 5,
110110
"source": [
111111
"#collapse_output\n",
112112
"bl = YaResBlock(1,64,64,se=True)\n",
@@ -131,9 +131,9 @@
131131
" (se): SEBlock(\n",
132132
" (squeeze): AdaptiveAvgPool2d(output_size=1)\n",
133133
" (excitation): Sequential(\n",
134-
" (fc_reduce): Linear(in_features=64, out_features=4, bias=False)\n",
134+
" (fc_reduce): Linear(in_features=64, out_features=4, bias=True)\n",
135135
" (se_act): ReLU(inplace=True)\n",
136-
" (fc_expand): Linear(in_features=4, out_features=64, bias=False)\n",
136+
" (fc_expand): Linear(in_features=4, out_features=64, bias=True)\n",
137137
" (sigmoid): Sigmoid()\n",
138138
" )\n",
139139
" )\n",
@@ -143,14 +143,14 @@
143143
]
144144
},
145145
"metadata": {},
146-
"execution_count": 7
146+
"execution_count": 5
147147
}
148148
],
149149
"metadata": {}
150150
},
151151
{
152152
"cell_type": "code",
153-
"execution_count": 8,
153+
"execution_count": 6,
154154
"source": [
155155
"#hide\n",
156156
"bs_test = 16\n",
@@ -172,7 +172,7 @@
172172
},
173173
{
174174
"cell_type": "code",
175-
"execution_count": 9,
175+
"execution_count": 7,
176176
"source": [
177177
"#collapse_output\n",
178178
"bl = YaResBlock(4,64,128,stride=2,act_fn=nn.LeakyReLU(), bn_1st=False)\n",
@@ -210,6 +210,73 @@
210210
]
211211
},
212212
"metadata": {},
213+
"execution_count": 7
214+
}
215+
],
216+
"metadata": {}
217+
},
218+
{
219+
"cell_type": "code",
220+
"execution_count": 8,
221+
"source": [
222+
"#hide\n",
223+
"bs_test = 16\n",
224+
"xb = torch.randn(bs_test, 256, 32, 32)\n",
225+
"y = bl(xb)\n",
226+
"print(y.shape)\n",
227+
"assert y.shape == torch.Size([bs_test, 512, 16, 16]), f\"size\""
228+
],
229+
"outputs": [
230+
{
231+
"output_type": "stream",
232+
"name": "stdout",
233+
"text": [
234+
"torch.Size([16, 512, 16, 16])\n"
235+
]
236+
}
237+
],
238+
"metadata": {}
239+
},
240+
{
241+
"cell_type": "code",
242+
"execution_count": 9,
243+
"source": [
244+
"#collapse_output\n",
245+
"bl = YaResBlock(4,64,128,stride=2,act_fn=nn.LeakyReLU(), bn_1st=False, groups=4)\n",
246+
"bl"
247+
],
248+
"outputs": [
249+
{
250+
"output_type": "execute_result",
251+
"data": {
252+
"text/plain": [
253+
"YaResBlock(\n",
254+
" (reduce): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
255+
" (convs): Sequential(\n",
256+
" (conv_0): ConvLayer(\n",
257+
" (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
258+
" (act_fn): LeakyReLU(negative_slope=0.01)\n",
259+
" (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
260+
" )\n",
261+
" (conv_1): ConvLayer(\n",
262+
" (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=4, bias=False)\n",
263+
" (act_fn): LeakyReLU(negative_slope=0.01)\n",
264+
" (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
265+
" )\n",
266+
" (conv_2): ConvLayer(\n",
267+
" (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
268+
" (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
269+
" )\n",
270+
" )\n",
271+
" (idconv): ConvLayer(\n",
272+
" (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
273+
" (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
274+
" )\n",
275+
" (merge): LeakyReLU(negative_slope=0.01)\n",
276+
")"
277+
]
278+
},
279+
"metadata": {},
213280
"execution_count": 9
214281
}
215282
],
@@ -242,7 +309,7 @@
242309
"execution_count": 11,
243310
"source": [
244311
"#collapse_output\n",
245-
"bl = YaResBlock(4,64,128,stride=2,act_fn=nn.LeakyReLU(), bn_1st=False, groups=4)\n",
312+
"bl = YaResBlock(4,64,128,stride=2,act_fn=nn.LeakyReLU(), bn_1st=False, div_groups=4)\n",
246313
"bl"
247314
],
248315
"outputs": [
@@ -259,7 +326,7 @@
259326
" (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
260327
" )\n",
261328
" (conv_1): ConvLayer(\n",
262-
" (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=4, bias=False)\n",
329+
" (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)\n",
263330
" (act_fn): LeakyReLU(negative_slope=0.01)\n",
264331
" (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
265332
" )\n",

model_constructor/net.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,11 @@ class ResBlock(nn.Module):
2929
def __init__(self, expansion, ni, nh, stride=1,
3030
conv_layer=ConvLayer, act_fn=act_fn, zero_bn=True, bn_1st=True,
3131
pool=nn.AvgPool2d(2, ceil_mode=True), sa=False, sym=False, se=False, se_reduction=16,
32-
groups=1, dw=False):
32+
groups=1, dw=False, div_groups=None):
3333
super().__init__()
3434
nf, ni = nh * expansion, ni * expansion
35-
# if groups != 1:
36-
# groups = int(nh / groups)
35+
if div_groups is not None: # check if groups != 1 and div_groups
36+
groups = int(nh / div_groups)
3737
if expansion == 1:
3838
layers = [("conv_0", conv_layer(ni, nh, 3, stride=stride, act_fn=act_fn, bn_1st=bn_1st,
3939
groups=nh if dw else groups)),
@@ -68,11 +68,11 @@ class NewResBlock(nn.Module):
6868
def __init__(self, expansion, ni, nh, stride=1,
6969
conv_layer=ConvLayer, act_fn=act_fn, zero_bn=True, bn_1st=True,
7070
pool=nn.AvgPool2d(2, ceil_mode=True), sa=False, sym=False, se=False, se_reduction=16,
71-
groups=1, dw=False):
71+
groups=1, dw=False, div_groups=None):
7272
super().__init__()
7373
nf, ni = nh * expansion, ni * expansion
74-
# if groups != 1:
75-
# groups = int(nh / groups)
74+
if div_groups is not None: # check if groups != 1 and div_groups
75+
groups = int(nh / div_groups)
7676
self.reduce = noop if stride == 1 else pool
7777
if expansion == 1:
7878
layers = [("conv_0", conv_layer(ni, nh, 3, stride=1, act_fn=act_fn, bn_1st=bn_1st,
@@ -114,7 +114,7 @@ def _make_layer(self, expansion, ni, nf, blocks, stride, sa):
114114
layers = [(f"bl_{i}", self.block(expansion, ni if i == 0 else nf, nf,
115115
stride if i == 0 else 1, sa=sa if i == blocks - 1 else False,
116116
conv_layer=self.conv_layer, act_fn=self.act_fn, pool=self.pool,
117-
zero_bn=self.zero_bn, bn_1st=self.bn_1st, groups=self.groups,
117+
zero_bn=self.zero_bn, bn_1st=self.bn_1st, groups=self.groups, div_groups=self.div_groups,
118118
dw=self.dw, se=self.se))
119119
for i in range(blocks)]
120120
return nn.Sequential(OrderedDict(layers))
@@ -144,7 +144,7 @@ def __init__(self, name='Net', c_in=3, c_out=1000,
144144
norm=nn.BatchNorm2d,
145145
act_fn=nn.ReLU(inplace=True),
146146
pool=nn.AvgPool2d(2, ceil_mode=True),
147-
expansion=1, groups=1, dw=False,
147+
expansion=1, groups=1, dw=False, div_groups=None,
148148
sa=False, se=False, se_reduction=16,
149149
bn_1st=True,
150150
zero_bn=True,
@@ -195,7 +195,7 @@ def __call__(self):
195195
def __repr__(self):
196196
return (f"{self.name} constructor\n"
197197
f" c_in: {self.c_in}, c_out: {self.c_out}\n"
198-
f" expansion: {self.expansion}, groups: {self.groups}, dw: {self.dw}\n"
198+
f" expansion: {self.expansion}, groups: {self.groups}, dw: {self.dw}\n, div_groups: {self.div_groups}"
199199
f" sa: {self.sa}, se: {self.se}\n"
200200
f" stem sizes: {self.stem_sizes}, stide on {self.stem_stride_on}\n"
201201
f" body sizes {self._block_sizes}\n"

model_constructor/yaresnet.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@ class YaResBlock(nn.Module):
1919
def __init__(self, expansion, ni, nh, stride=1,
2020
conv_layer=ConvLayer, act_fn=act_fn, zero_bn=True, bn_1st=True,
2121
pool=nn.AvgPool2d(2, ceil_mode=True), sa=False, sym=False, se=False,
22-
groups=1, dw=False):
22+
groups=1, dw=False, div_groups=None):
2323
super().__init__()
2424
nf, ni = nh * expansion, ni * expansion
25-
# if groups != 1:
26-
# groups = int(nh / groups)
25+
if div_groups is not None: # check if groups != 1 and div_groups
26+
groups = int(nh / div_groups)
2727
self.reduce = noop if stride == 1 else pool
2828
layers = [("conv_0", conv_layer(ni, nh, 3, stride=1, act_fn=act_fn, bn_1st=bn_1st,
2929
groups=nh if dw else groups)),

0 commit comments

Comments
 (0)