|
428 | 428 | "# export\n", |
429 | 429 | "# v8\n", |
430 | 430 | "class Net():\n", |
431 | | - " def __init__(self, expansion=1, layers=[2,2,2,2], c_in=3, c_out=1000, name='Net'):\n", |
| 431 | + " def __init__(self, expansion=1, layers=[2,2,2,2], c_in=3, c_out=1000, name='Net',\n", |
| 432 | + " act_fn=nn.ReLU(inplace=True), pool = nn.AvgPool2d(2, ceil_mode=True), sa=0):\n", |
432 | 433 | " super().__init__()\n", |
433 | 434 | " self.name = name\n", |
434 | 435 | " self.c_in, self.c_out,self.expansion,self.layers = c_in,c_out,expansion,layers # todo setter for expansion\n", |
| 436 | + " self.act_fn, self.pool, self.sa = act_fn, pool, sa\n", |
| 437 | + " \n", |
| 438 | + " \n", |
435 | 439 | " self.stem_sizes = [c_in,32,32,64]\n", |
436 | 440 | " self.stem_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n", |
437 | 441 | " self.stem_bn_end = False\n", |
438 | 442 | " self.block = ResBlock\n", |
439 | 443 | " self.norm = nn.BatchNorm2d\n", |
440 | | - " self.act_fn=nn.ReLU(inplace=True)\n", |
441 | | - " self.pool = nn.AvgPool2d(2, ceil_mode=True)\n", |
442 | | - " self.sa=False\n", |
443 | 444 | " self.bn_1st = True\n", |
444 | 445 | " self.zero_bn=True\n", |
445 | 446 | " self.conv_layer = ConvLayer\n", |
|
467 | 468 | " def body(self):\n", |
468 | 469 | " return self._make_body(self)\n", |
469 | 470 | " \n", |
470 | | - "# def _make_stem(self):\n", |
471 | | - "# stem = [(f\"conv_{i}\", self.conv_layer(self.stem_sizes[i], self.stem_sizes[i+1], \n", |
472 | | - "# stride=2 if i==0 else 1, \n", |
473 | | - "# bn_layer=(not self.stem_bn_end) if i==(len(self.stem_sizes)-2) else True,\n", |
474 | | - "# act_fn=self.act_fn, bn_1st=self.bn_1st))\n", |
475 | | - "# for i in range(len(self.stem_sizes)-1)]\n", |
476 | | - "# stem.append(('stem_pool', self.stem_pool))\n", |
477 | | - "# if self.stem_bn_end: stem.append(('norm', self.norm(self.stem_sizes[-1])))\n", |
478 | | - "# return nn.Sequential(OrderedDict(stem))\n", |
479 | | - " \n", |
480 | | - "# def _make_head(self):\n", |
481 | | - "# head = [('pool', nn.AdaptiveAvgPool2d(1)),\n", |
482 | | - "# ('flat', Flatten()),\n", |
483 | | - "# ('fc', nn.Linear(self.block_szs[-1]*self.expansion, self.c_out))]\n", |
484 | | - "# return nn.Sequential(OrderedDict(head))\n", |
485 | | - " \n", |
486 | | - "# def _make_body(self):\n", |
487 | | - "# blocks = [(f\"l_{i}\", self._make_layer(self.expansion, \n", |
488 | | - "# self.block_szs[i], self.block_szs[i+1], l, \n", |
489 | | - "# 1 if i==0 else 2, self.sa if i==0 else False))\n", |
490 | | - "# for i,l in enumerate(self.layers)]\n", |
491 | | - "# return nn.Sequential(OrderedDict(blocks))\n", |
492 | | - " \n", |
493 | | - "# def _make_layer(self,expansion,ni,nf,blocks,stride,sa):\n", |
494 | | - "# return nn.Sequential(OrderedDict(\n", |
495 | | - "# [(f\"bl_{i}\", self.block(expansion, ni if i==0 else nf, nf, \n", |
496 | | - "# stride if i==0 else 1, sa=sa if i==blocks-1 else False,\n", |
497 | | - "# conv_layer=self.conv_layer, act_fn=self.act_fn, pool=self.pool,\n", |
498 | | - "# zero_bn=self.zero_bn, bn_1st=self.bn_1st))\n", |
499 | | - "# for i in range(blocks)]))\n", |
500 | | - " \n", |
501 | 471 | " def __call__(self):\n", |
502 | 472 | " model = nn.Sequential(OrderedDict([\n", |
503 | 473 | " ('stem', self.stem),\n", |
|
1266 | 1236 | "outputs": [], |
1267 | 1237 | "source": [ |
1268 | 1238 | "# export\n", |
1269 | | - "me = sys.modules[__name__]\n", |
1270 | | - "for n,e,l in [[ 18 , 1, [2,2,2 ,2] ],\n", |
1271 | | - " [ 34 , 1, [3,4,6 ,3] ],\n", |
1272 | | - " [ 50 , 4, [3,4,6 ,3] ],\n", |
1273 | | - " [ 101, 4, [3,4,23,3] ],\n", |
1274 | | - " [ 152, 4, [3,8,36,3] ],]:\n", |
1275 | | - " name = f'net{n}'\n", |
1276 | | - " setattr(me, name, partial(Net, expansion=e, layers=l, name=name))\n", |
1277 | | - "xresnet34 = partial(Net, expansion=1, layers=[3, 4, 6, 3], name='xresnet34')\n", |
1278 | | - "xresnet50 = partial(Net, expansion=4, layers=[3, 4, 6, 3], name='xresnet50')" |
| 1239 | + "# me = sys.modules[__name__]\n", |
| 1240 | + "# for n,e,l in [[ 18 , 1, [2,2,2 ,2] ],\n", |
| 1241 | + "# [ 34 , 1, [3,4,6 ,3] ],\n", |
| 1242 | + "# [ 50 , 4, [3,4,6 ,3] ],\n", |
| 1243 | + "# [ 101, 4, [3,4,23,3] ],\n", |
| 1244 | + "# [ 152, 4, [3,8,36,3] ],]:\n", |
| 1245 | + "# name = f'net{n}'\n", |
| 1246 | + "# setattr(me, name, partial(Net, expansion=e, layers=l, name=name))\n", |
| 1247 | + "net34 = partial(Net, expansion=1, layers=[3, 4, 6, 3], name='xresnet34')\n", |
| 1248 | + "net50 = partial(Net, expansion=4, layers=[3, 4, 6, 3], name='xresnet50')" |
1279 | 1249 | ] |
1280 | 1250 | }, |
1281 | 1251 | { |
|
1284 | 1254 | "metadata": {}, |
1285 | 1255 | "outputs": [], |
1286 | 1256 | "source": [ |
1287 | | - "m = xresnet50(c_out=10)" |
| 1257 | + "m = net50(c_out=10)" |
1288 | 1258 | ] |
1289 | 1259 | }, |
1290 | 1260 | { |
|
0 commit comments