|
11 | 11 | }, |
12 | 12 | { |
13 | 13 | "cell_type": "code", |
14 | | - "execution_count": 5, |
| 14 | + "execution_count": 1, |
15 | 15 | "source": [ |
16 | 16 | "#hide\n", |
17 | 17 | "# from nbdev.showdoc import *\n", |
|
36 | 36 | }, |
37 | 37 | { |
38 | 38 | "cell_type": "code", |
39 | | - "execution_count": 1, |
| 39 | + "execution_count": 2, |
40 | 40 | "source": [ |
41 | 41 | "#hide\n", |
42 | 42 | "from model_constructor.yaresnet import YaResBlock" |
|
46 | 46 | }, |
47 | 47 | { |
48 | 48 | "cell_type": "code", |
49 | | - "execution_count": 2, |
| 49 | + "execution_count": 3, |
50 | 50 | "source": [ |
51 | 51 | "#collapse_output\n", |
52 | 52 | "bl = YaResBlock(1,64,64,sa=True)\n", |
|
77 | 77 | ] |
78 | 78 | }, |
79 | 79 | "metadata": {}, |
80 | | - "execution_count": 2 |
| 80 | + "execution_count": 3 |
81 | 81 | } |
82 | 82 | ], |
83 | 83 | "metadata": {} |
84 | 84 | }, |
85 | 85 | { |
86 | 86 | "cell_type": "code", |
87 | | - "execution_count": 6, |
| 87 | + "execution_count": 4, |
88 | 88 | "source": [ |
89 | 89 | "#hide\n", |
90 | 90 | "bs_test = 16\n", |
|
106 | 106 | }, |
107 | 107 | { |
108 | 108 | "cell_type": "code", |
109 | | - "execution_count": 7, |
| 109 | + "execution_count": 5, |
110 | 110 | "source": [ |
111 | 111 | "#collapse_output\n", |
112 | 112 | "bl = YaResBlock(1,64,64,se=True)\n", |
|
131 | 131 | " (se): SEBlock(\n", |
132 | 132 | " (squeeze): AdaptiveAvgPool2d(output_size=1)\n", |
133 | 133 | " (excitation): Sequential(\n", |
134 | | - " (fc_reduce): Linear(in_features=64, out_features=4, bias=False)\n", |
| 134 | + " (fc_reduce): Linear(in_features=64, out_features=4, bias=True)\n", |
135 | 135 | " (se_act): ReLU(inplace=True)\n", |
136 | | - " (fc_expand): Linear(in_features=4, out_features=64, bias=False)\n", |
| 136 | + " (fc_expand): Linear(in_features=4, out_features=64, bias=True)\n", |
137 | 137 | " (sigmoid): Sigmoid()\n", |
138 | 138 | " )\n", |
139 | 139 | " )\n", |
|
143 | 143 | ] |
144 | 144 | }, |
145 | 145 | "metadata": {}, |
146 | | - "execution_count": 7 |
| 146 | + "execution_count": 5 |
147 | 147 | } |
148 | 148 | ], |
149 | 149 | "metadata": {} |
150 | 150 | }, |
151 | 151 | { |
152 | 152 | "cell_type": "code", |
153 | | - "execution_count": 8, |
| 153 | + "execution_count": 6, |
154 | 154 | "source": [ |
155 | 155 | "#hide\n", |
156 | 156 | "bs_test = 16\n", |
|
172 | 172 | }, |
173 | 173 | { |
174 | 174 | "cell_type": "code", |
175 | | - "execution_count": 9, |
| 175 | + "execution_count": 7, |
176 | 176 | "source": [ |
177 | 177 | "#collapse_output\n", |
178 | 178 | "bl = YaResBlock(4,64,128,stride=2,act_fn=nn.LeakyReLU(), bn_1st=False)\n", |
|
210 | 210 | ] |
211 | 211 | }, |
212 | 212 | "metadata": {}, |
| 213 | + "execution_count": 7 |
| 214 | + } |
| 215 | + ], |
| 216 | + "metadata": {} |
| 217 | + }, |
| 218 | + { |
| 219 | + "cell_type": "code", |
| 220 | + "execution_count": 8, |
| 221 | + "source": [ |
| 222 | + "#hide\n", |
| 223 | + "bs_test = 16\n", |
| 224 | + "xb = torch.randn(bs_test, 256, 32, 32)\n", |
| 225 | + "y = bl(xb)\n", |
| 226 | + "print(y.shape)\n", |
| 227 | + "assert y.shape == torch.Size([bs_test, 512, 16, 16]), f\"size\"" |
| 228 | + ], |
| 229 | + "outputs": [ |
| 230 | + { |
| 231 | + "output_type": "stream", |
| 232 | + "name": "stdout", |
| 233 | + "text": [ |
| 234 | + "torch.Size([16, 512, 16, 16])\n" |
| 235 | + ] |
| 236 | + } |
| 237 | + ], |
| 238 | + "metadata": {} |
| 239 | + }, |
| 240 | + { |
| 241 | + "cell_type": "code", |
| 242 | + "execution_count": 9, |
| 243 | + "source": [ |
| 244 | + "#collapse_output\n", |
| 245 | + "bl = YaResBlock(4,64,128,stride=2,act_fn=nn.LeakyReLU(), bn_1st=False, groups=4)\n", |
| 246 | + "bl" |
| 247 | + ], |
| 248 | + "outputs": [ |
| 249 | + { |
| 250 | + "output_type": "execute_result", |
| 251 | + "data": { |
| 252 | + "text/plain": [ |
| 253 | + "YaResBlock(\n", |
| 254 | + " (reduce): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", |
| 255 | + " (convs): Sequential(\n", |
| 256 | + " (conv_0): ConvLayer(\n", |
| 257 | + " (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", |
| 258 | + " (act_fn): LeakyReLU(negative_slope=0.01)\n", |
| 259 | + " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", |
| 260 | + " )\n", |
| 261 | + " (conv_1): ConvLayer(\n", |
| 262 | + " (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=4, bias=False)\n", |
| 263 | + " (act_fn): LeakyReLU(negative_slope=0.01)\n", |
| 264 | + " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", |
| 265 | + " )\n", |
| 266 | + " (conv_2): ConvLayer(\n", |
| 267 | + " (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", |
| 268 | + " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", |
| 269 | + " )\n", |
| 270 | + " )\n", |
| 271 | + " (idconv): ConvLayer(\n", |
| 272 | + " (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", |
| 273 | + " (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", |
| 274 | + " )\n", |
| 275 | + " (merge): LeakyReLU(negative_slope=0.01)\n", |
| 276 | + ")" |
| 277 | + ] |
| 278 | + }, |
| 279 | + "metadata": {}, |
213 | 280 | "execution_count": 9 |
214 | 281 | } |
215 | 282 | ], |
|
242 | 309 | "execution_count": 11, |
243 | 310 | "source": [ |
244 | 311 | "#collapse_output\n", |
245 | | - "bl = YaResBlock(4,64,128,stride=2,act_fn=nn.LeakyReLU(), bn_1st=False, groups=4)\n", |
| 312 | + "bl = YaResBlock(4,64,128,stride=2,act_fn=nn.LeakyReLU(), bn_1st=False, div_groups=4)\n", |
246 | 313 | "bl" |
247 | 314 | ], |
248 | 315 | "outputs": [ |
|
259 | 326 | " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", |
260 | 327 | " )\n", |
261 | 328 | " (conv_1): ConvLayer(\n", |
262 | | - " (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=4, bias=False)\n", |
| 329 | + " (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)\n", |
263 | 330 | " (act_fn): LeakyReLU(negative_slope=0.01)\n", |
264 | 331 | " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", |
265 | 332 | " )\n", |
|
0 commit comments