|
28 | 28 | "source": [ |
29 | 29 | "#hide\n", |
30 | 30 | "import torch\n", |
31 | | - "import torch.nn as nn\n", |
32 | | - "\n", |
33 | | - "from nbdev.showdoc import show_doc\n", |
34 | | - "from IPython.display import Markdown, display" |
35 | | - ] |
36 | | - }, |
37 | | - { |
38 | | - "cell_type": "code", |
39 | | - "execution_count": null, |
40 | | - "metadata": {}, |
41 | | - "outputs": [], |
42 | | - "source": [ |
43 | | - "# hide\n", |
44 | | - "def print_doc(func_name):\n", |
45 | | - " doc = show_doc(func_name, title_level=4, disp=False)\n", |
46 | | - " display(Markdown(doc))" |
| 31 | + "import torch.nn as nn" |
47 | 32 | ] |
48 | 33 | }, |
49 | 34 | { |
|
59 | 44 | "metadata": {}, |
60 | 45 | "outputs": [], |
61 | 46 | "source": [ |
62 | | - "#hide\n", |
63 | 47 | "from model_constructor.yaresnet import YaResBlock" |
64 | 48 | ] |
65 | 49 | }, |
66 | | - { |
67 | | - "cell_type": "code", |
68 | | - "execution_count": null, |
69 | | - "metadata": {}, |
70 | | - "outputs": [], |
71 | | - "source": [ |
72 | | - "#hide_input\n", |
73 | | - "# print_doc(YaResBlock)" |
74 | | - ] |
75 | | - }, |
76 | 50 | { |
77 | 51 | "cell_type": "code", |
78 | 52 | "execution_count": null, |
|
341 | 315 | " (se): SEModule(\n", |
342 | 316 | " (squeeze): AdaptiveAvgPool2d(output_size=1)\n", |
343 | 317 | " (excitation): Sequential(\n", |
344 | | - " (fc_reduce): Linear(in_features=512, out_features=32, bias=True)\n", |
| 318 | + " (reduce): Linear(in_features=512, out_features=32, bias=True)\n", |
345 | 319 | " (se_act): ReLU(inplace=True)\n", |
346 | | - " (fc_expand): Linear(in_features=32, out_features=512, bias=True)\n", |
| 320 | + " (expand): Linear(in_features=32, out_features=512, bias=True)\n", |
347 | 321 | " (se_gate): Sigmoid()\n", |
348 | 322 | " )\n", |
349 | 323 | " )\n", |
|
443 | 417 | " (se): SEModule(\n", |
444 | 418 | " (squeeze): AdaptiveAvgPool2d(output_size=1)\n", |
445 | 419 | " (excitation): Sequential(\n", |
446 | | - " (fc_reduce): Linear(in_features=512, out_features=32, bias=True)\n", |
| 420 | + " (reduce): Linear(in_features=512, out_features=32, bias=True)\n", |
447 | 421 | " (se_act): ReLU(inplace=True)\n", |
448 | | - " (fc_expand): Linear(in_features=32, out_features=512, bias=True)\n", |
| 422 | + " (expand): Linear(in_features=32, out_features=512, bias=True)\n", |
449 | 423 | " (se_gate): Sigmoid()\n", |
450 | 424 | " )\n", |
451 | 425 | " )\n", |
|
468 | 442 | ], |
469 | 443 | "source": [ |
470 | 444 | "#collapse_output\n", |
471 | | - "bl = YaResBlock(4, 64, 128, stride=2, pool=pool, act_fn=nn.LeakyReLU(), dw=True,\n", |
472 | | - " se=SEModule, sa=SimpleSelfAttention)\n", |
| 445 | + "bl = YaResBlock(\n", |
| 446 | + " 4, 64, 128,\n", |
| 447 | + " stride=2,\n", |
| 448 | + " pool=pool,\n", |
| 449 | + " act_fn=nn.LeakyReLU(),\n", |
| 450 | + " dw=True,\n", |
| 451 | + " se=SEModule,\n", |
| 452 | + " sa=SimpleSelfAttention)\n", |
473 | 453 | "bl" |
474 | 454 | ] |
475 | 455 | }, |
|
528 | 508 | { |
529 | 509 | "data": { |
530 | 510 | "text/plain": [ |
531 | | - "([64, 64, 128, 256, 512], [2, 2, 2, 2])" |
| 511 | + "([64, 128, 256, 512], [2, 2, 2, 2])" |
532 | 512 | ] |
533 | 513 | }, |
534 | 514 | "execution_count": null, |
|
982 | 962 | " (se): SEModule(\n", |
983 | 963 | " (squeeze): AdaptiveAvgPool2d(output_size=1)\n", |
984 | 964 | " (excitation): Sequential(\n", |
985 | | - " (fc_reduce): Linear(in_features=64, out_features=4, bias=True)\n", |
| 965 | + " (reduce): Linear(in_features=64, out_features=4, bias=True)\n", |
986 | 966 | " (se_act): ReLU(inplace=True)\n", |
987 | | - " (fc_expand): Linear(in_features=4, out_features=64, bias=True)\n", |
| 967 | + " (expand): Linear(in_features=4, out_features=64, bias=True)\n", |
988 | 968 | " (se_gate): Sigmoid()\n", |
989 | 969 | " )\n", |
990 | 970 | " )\n", |
|
1003 | 983 | "yaresnet.body.l_0.bl_0" |
1004 | 984 | ] |
1005 | 985 | }, |
| 986 | + { |
| 987 | + "cell_type": "markdown", |
| 988 | + "metadata": {}, |
| 989 | + "source": [ |
| 990 | + "# YaResnet34, YaResnet50" |
| 991 | + ] |
| 992 | + }, |
| 993 | + { |
| 994 | + "cell_type": "code", |
| 995 | + "execution_count": null, |
| 996 | + "metadata": {}, |
| 997 | + "outputs": [], |
| 998 | + "source": [ |
| 999 | + "from model_constructor.yaresnet import YaResNet34, YaResNet50" |
| 1000 | + ] |
| 1001 | + }, |
| 1002 | + { |
| 1003 | + "cell_type": "code", |
| 1004 | + "execution_count": null, |
| 1005 | + "metadata": {}, |
| 1006 | + "outputs": [ |
| 1007 | + { |
| 1008 | + "data": { |
| 1009 | + "text/plain": [ |
| 1010 | + "YaResnet34 constructor\n", |
| 1011 | + " in_chans: 3, num_classes: 1000\n", |
| 1012 | + " expansion: 1, groups: 1, dw: False, div_groups: None\n", |
| 1013 | + " sa: False, se: False\n", |
| 1014 | + " stem sizes: [3, 32, 32, 64], stride on 0\n", |
| 1015 | + " body sizes [64, 128, 256, 512]\n", |
| 1016 | + " layers: [3, 4, 6, 3]" |
| 1017 | + ] |
| 1018 | + }, |
| 1019 | + "execution_count": null, |
| 1020 | + "metadata": {}, |
| 1021 | + "output_type": "execute_result" |
| 1022 | + } |
| 1023 | + ], |
| 1024 | + "source": [ |
| 1025 | + "yaresnet34 = YaResNet34()\n", |
| 1026 | + "yaresnet34" |
| 1027 | + ] |
| 1028 | + }, |
| 1029 | + { |
| 1030 | + "cell_type": "code", |
| 1031 | + "execution_count": null, |
| 1032 | + "metadata": {}, |
| 1033 | + "outputs": [ |
| 1034 | + { |
| 1035 | + "data": { |
| 1036 | + "text/plain": [ |
| 1037 | + "YaResnet50 constructor\n", |
| 1038 | + " in_chans: 3, num_classes: 1000\n", |
| 1039 | + " expansion: 4, groups: 1, dw: False, div_groups: None\n", |
| 1040 | + " sa: False, se: False\n", |
| 1041 | + " stem sizes: [3, 32, 32, 64], stride on 0\n", |
| 1042 | + " body sizes [64, 128, 256, 512]\n", |
| 1043 | + " layers: [3, 4, 6, 3]" |
| 1044 | + ] |
| 1045 | + }, |
| 1046 | + "execution_count": null, |
| 1047 | + "metadata": {}, |
| 1048 | + "output_type": "execute_result" |
| 1049 | + } |
| 1050 | + ], |
| 1051 | + "source": [ |
| 1052 | + "yaresnet50 = YaResNet50()\n", |
| 1053 | + "yaresnet50" |
| 1054 | + ] |
| 1055 | + }, |
1006 | 1056 | { |
1007 | 1057 | "cell_type": "markdown", |
1008 | 1058 | "metadata": {}, |
|
0 commit comments