Skip to content

Commit df26b9b

Browse files
author
ayasyrev
committed
twist - fix permute param
1 parent 1572ccf commit df26b9b

File tree

9 files changed

+14
-26
lines changed

9 files changed

+14
-26
lines changed

docs/Twist.html

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ <h1 id="ConvLayerTwist">ConvLayerTwist<a class="anchor-link" href="#ConvLayerTwi
234234

235235

236236
<div class="output_markdown rendered_html output_subarea ">
237-
<h2 id="ConvLayerTwist" class="doc_header"><code>class</code> <code>ConvLayerTwist</code><a href="https://github.com/ayasyrev/model_constructor/tree/master/model_constructor/twist.py#L105" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>ConvLayerTwist</code>(<strong><code>ni</code></strong>, <strong><code>nf</code></strong>, <strong><code>ks</code></strong>=<em><code>3</code></em>, <strong><code>stride</code></strong>=<em><code>1</code></em>, <strong><code>act</code></strong>=<em><code>True</code></em>, <strong><code>act_fn</code></strong>=<em><code>ReLU(inplace=True)</code></em>, <strong><code>bn_layer</code></strong>=<em><code>True</code></em>, <strong><code>bn_1st</code></strong>=<em><code>True</code></em>, <strong><code>zero_bn</code></strong>=<em><code>False</code></em>, <strong><code>padding</code></strong>=<em><code>None</code></em>, <strong><code>bias</code></strong>=<em><code>False</code></em>, <strong><code>groups</code></strong>=<em><code>1</code></em>, <strong>**<code>kwargs</code></strong>) :: <a href="/model_constructor/layers#ConvLayer"><code>ConvLayer</code></a></p>
237+
<h2 id="ConvLayerTwist" class="doc_header"><code>class</code> <code>ConvLayerTwist</code><a href="https://github.com/ayasyrev/model_constructor/tree/master/model_constructor/twist.py#L104" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>ConvLayerTwist</code>(<strong><code>ni</code></strong>, <strong><code>nf</code></strong>, <strong><code>ks</code></strong>=<em><code>3</code></em>, <strong><code>stride</code></strong>=<em><code>1</code></em>, <strong><code>act</code></strong>=<em><code>True</code></em>, <strong><code>act_fn</code></strong>=<em><code>ReLU(inplace=True)</code></em>, <strong><code>bn_layer</code></strong>=<em><code>True</code></em>, <strong><code>bn_1st</code></strong>=<em><code>True</code></em>, <strong><code>zero_bn</code></strong>=<em><code>False</code></em>, <strong><code>padding</code></strong>=<em><code>None</code></em>, <strong><code>bias</code></strong>=<em><code>False</code></em>, <strong><code>groups</code></strong>=<em><code>1</code></em>, <strong>**<code>kwargs</code></strong>) :: <a href="/model_constructor/layers#ConvLayer"><code>ConvLayer</code></a></p>
238238
</blockquote>
239239
<p>Basic conv layers block</p>
240240

@@ -728,7 +728,7 @@ <h1 id="NewResBlockTwist">NewResBlockTwist<a class="anchor-link" href="#NewResBl
728728

729729

730730
<div class="output_markdown rendered_html output_subarea ">
731-
<h2 id="NewResBlockTwist" class="doc_header"><code>class</code> <code>NewResBlockTwist</code><a href="https://github.com/ayasyrev/model_constructor/tree/master/model_constructor/twist.py#L109" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>NewResBlockTwist</code>(<strong><code>expansion</code></strong>, <strong><code>ni</code></strong>, <strong><code>nh</code></strong>, <strong><code>stride</code></strong>=<em><code>1</code></em>, <strong><code>conv_layer</code></strong>=<em><code>'ConvLayer'</code></em>, <strong><code>act_fn</code></strong>=<em><code>ReLU(inplace=True)</code></em>, <strong><code>bn_1st</code></strong>=<em><code>True</code></em>, <strong><code>pool</code></strong>=<em><code>AvgPool2d(kernel_size=2, stride=2, padding=0)</code></em>, <strong><code>sa</code></strong>=<em><code>False</code></em>, <strong><code>sym</code></strong>=<em><code>False</code></em>, <strong><code>zero_bn</code></strong>=<em><code>True</code></em>) :: <code>Module</code></p>
731+
<h2 id="NewResBlockTwist" class="doc_header"><code>class</code> <code>NewResBlockTwist</code><a href="https://github.com/ayasyrev/model_constructor/tree/master/model_constructor/twist.py#L108" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>NewResBlockTwist</code>(<strong><code>expansion</code></strong>, <strong><code>ni</code></strong>, <strong><code>nh</code></strong>, <strong><code>stride</code></strong>=<em><code>1</code></em>, <strong><code>conv_layer</code></strong>=<em><code>'ConvLayer'</code></em>, <strong><code>act_fn</code></strong>=<em><code>ReLU(inplace=True)</code></em>, <strong><code>bn_1st</code></strong>=<em><code>True</code></em>, <strong><code>pool</code></strong>=<em><code>AvgPool2d(kernel_size=2, stride=2, padding=0)</code></em>, <strong><code>sa</code></strong>=<em><code>False</code></em>, <strong><code>sym</code></strong>=<em><code>False</code></em>, <strong><code>zero_bn</code></strong>=<em><code>True</code></em>) :: <code>Module</code></p>
732732
</blockquote>
733733
<p>Base class for all neural network modules.</p>
734734
<p>Your models should also subclass this class.</p>
@@ -940,7 +940,7 @@ <h1 id="ResBlockTwist">ResBlockTwist<a class="anchor-link" href="#ResBlockTwist"
940940

941941

942942
<div class="output_markdown rendered_html output_subarea ">
943-
<h2 id="ResBlockTwist" class="doc_header"><code>class</code> <code>ResBlockTwist</code><a href="https://github.com/ayasyrev/model_constructor/tree/master/model_constructor/twist.py#L135" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>ResBlockTwist</code>(<strong><code>expansion</code></strong>, <strong><code>ni</code></strong>, <strong><code>nh</code></strong>, <strong><code>stride</code></strong>=<em><code>1</code></em>, <strong><code>conv_layer</code></strong>=<em><code>'ConvLayer'</code></em>, <strong><code>act_fn</code></strong>=<em><code>ReLU(inplace=True)</code></em>, <strong><code>zero_bn</code></strong>=<em><code>True</code></em>, <strong><code>bn_1st</code></strong>=<em><code>True</code></em>, <strong><code>pool</code></strong>=<em><code>AvgPool2d(kernel_size=2, stride=2, padding=0)</code></em>, <strong><code>sa</code></strong>=<em><code>False</code></em>, <strong><code>sym</code></strong>=<em><code>False</code></em>) :: <code>Module</code></p>
943+
<h2 id="ResBlockTwist" class="doc_header"><code>class</code> <code>ResBlockTwist</code><a href="https://github.com/ayasyrev/model_constructor/tree/master/model_constructor/twist.py#L134" class="source_link" style="float:right">[source]</a></h2><blockquote><p><code>ResBlockTwist</code>(<strong><code>expansion</code></strong>, <strong><code>ni</code></strong>, <strong><code>nh</code></strong>, <strong><code>stride</code></strong>=<em><code>1</code></em>, <strong><code>conv_layer</code></strong>=<em><code>'ConvLayer'</code></em>, <strong><code>act_fn</code></strong>=<em><code>ReLU(inplace=True)</code></em>, <strong><code>zero_bn</code></strong>=<em><code>True</code></em>, <strong><code>bn_1st</code></strong>=<em><code>True</code></em>, <strong><code>pool</code></strong>=<em><code>AvgPool2d(kernel_size=2, stride=2, padding=0)</code></em>, <strong><code>sa</code></strong>=<em><code>False</code></em>, <strong><code>sym</code></strong>=<em><code>False</code></em>) :: <code>Module</code></p>
944944
</blockquote>
945945
<p>Base class for all neural network modules.</p>
946946
<p>Your models should also subclass this class.</p>

docs/_data/sidebars/home_sidebar.yml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,6 @@ entries:
2727
- output: web,pdf
2828
title: Twist.
2929
url: /Twist
30-
- output: web,pdf
31-
title: Title
32-
url: /test_xresnet
3330
output: web
3431
title: model_constructor
3532
output: web

docs/sidebar.json

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
"resnet": "/resnet",
77
"xresnet": "/xresnet",
88
"Net.": "/Net",
9-
"Twist.": "/Twist",
10-
"Title": "/test_xresnet"
9+
"Twist.": "/Twist"
1110
}
1211
}

model_constructor/_nbdev.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
"Body": "00_constructor.ipynb",
88
"Head": "00_constructor.ipynb",
99
"init_model": "00_constructor.ipynb",
10-
"Net": "81_Net.ipynb",
10+
"Net": "04_Net.ipynb",
1111
"ConvLayer": "01_layers.ipynb",
12-
"act_fn": "81_Net.ipynb",
12+
"act_fn": "04_Net.ipynb",
1313
"Flatten": "01_layers.ipynb",
1414
"noop": "01_layers.ipynb",
1515
"Noop": "01_layers.ipynb",
@@ -22,33 +22,30 @@
2222
"SimpleSelfAttention": "01_layers.ipynb",
2323
"ConvBlockBasic": "01_layers.ipynb",
2424
"ConvBlockBottle": "01_layers.ipynb",
25-
"ResBlock": "81_Net.ipynb",
25+
"ResBlock": "04_Net.ipynb",
2626
"resnet18": "02_resnet.ipynb",
2727
"resnet34": "02_resnet.ipynb",
2828
"resnet50": "02_resnet.ipynb",
2929
"xresnet18": "03_xresnet.ipynb",
3030
"xresnet34": "03_xresnet.ipynb",
31-
"xresnet50": "81_Net.ipynb",
32-
"init_cnn": "81_Net.ipynb",
33-
"NewResBlock": "81_Net.ipynb",
31+
"xresnet50": "03_xresnet.ipynb",
32+
"init_cnn": "04_Net.ipynb",
33+
"NewResBlock": "04_Net.ipynb",
3434
"net34": "04_Net.ipynb",
3535
"net50": "04_Net.ipynb",
3636
"nn": "05_Twist.ipynb",
3737
"F": "05_Twist.ipynb",
3838
"ConvTwist": "05_Twist.ipynb",
3939
"ConvLayerTwist": "05_Twist.ipynb",
4040
"NewResBlockTwist": "05_Twist.ipynb",
41-
"ResBlockTwist": "05_Twist.ipynb",
42-
"NewConvLayer": "81_Net.ipynb",
43-
"me": "81_Net.ipynb"}
41+
"ResBlockTwist": "05_Twist.ipynb"}
4442

4543
modules = ["constructor.py",
4644
"layers.py",
4745
"resnet.py",
4846
"xresnet.py",
4947
"net.py",
50-
"twist.py",
51-
"tst_net_2.py"]
48+
"twist.py"]
5249

5350
doc_url = "https://ayasyrev.github.io/model_constructor/"
5451

model_constructor/twist.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,9 @@ def full_kernel(self, kernel): # permuting the groups
6666
return KK
6767

6868
def _conv(self, inpt, kernel=None):
69-
permute = True
7069
if kernel is None:
7170
kernel = self.conv.weight
72-
if permute is False:
71+
if self.permute is False:
7372
return F.conv2d(inpt, kernel, padding=1, stride=self.stride, groups=self.groups)
7473
else:
7574
return F.conv2d(inpt, self.full_kernel(kernel), padding=1, stride=self.stride, groups=1)

nbs/05_Twist.ipynb

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,10 +120,9 @@
120120
" return KK\n",
121121
" \n",
122122
" def _conv(self, inpt, kernel=None):\n",
123-
" permute = True\n",
124123
" if kernel is None:\n",
125124
" kernel = self.conv.weight\n",
126-
" if permute is False:\n",
125+
" if self.permute is False:\n",
127126
" return F.conv2d(inpt, kernel, padding=1, stride=self.stride, groups=self.groups)\n",
128127
" else:\n",
129128
" return F.conv2d(inpt, self.full_kernel(kernel), padding=1, stride=self.stride, groups=1)\n",
@@ -2737,9 +2736,6 @@
27372736
"Converted 03_xresnet.ipynb.\n",
27382737
"Converted 04_Net.ipynb.\n",
27392738
"Converted 05_Twist.ipynb.\n",
2740-
"Converted 80_test_net.ipynb.\n",
2741-
"Converted 81_Net.ipynb.\n",
2742-
"Converted 81_test_xresnet.ipynb.\n",
27432739
"Converted index.ipynb.\n"
27442740
]
27452741
}
File renamed without changes.
File renamed without changes.
File renamed without changes.

0 commit comments

Comments
 (0)