@@ -215,7 +215,7 @@ <h2 id="ResBlock" class="doc_header"><code>class</code> <code>ResBlock</code><a
215215 (act_fn): ReLU(inplace=True)
216216 )
217217 (conv_1): ConvLayer(
218- (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=4 , bias=False)
218+ (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16 , bias=False)
219219 (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
220220 (act_fn): ReLU(inplace=True)
221221 )
@@ -269,7 +269,7 @@ <h2 id="ResBlock" class="doc_header"><code>class</code> <code>ResBlock</code><a
269269 (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
270270 )
271271 (conv_1): ConvLayer(
272- (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
272+ (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
273273 (act_fn): LeakyReLU(negative_slope=0.01)
274274 (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
275275 )
@@ -314,7 +314,7 @@ <h1 id="NewResBlock">NewResBlock<a class="anchor-link" href="#NewResBlock"> </a>
314314
315315
316316< div class ="output_markdown rendered_html output_subarea ">
317- < h2 id ="NewResBlock " class ="doc_header "> < code > class</ code > < code > NewResBlock</ code > < a href ="https://github.com/ayasyrev/model_constructor/tree/master/model_constructor/net.py#L46 " class ="source_link " style ="float:right "> [source]</ a > </ h2 > < blockquote > < p > < code > NewResBlock</ code > (< strong > < code > expansion</ code > </ strong > , < strong > < code > ni</ code > </ strong > , < strong > < code > nh</ code > </ strong > , < strong > < code > stride</ code > </ strong > =< em > < code > 1</ code > </ em > , < strong > < code > conv_layer</ code > </ strong > =< em > < code > 'ConvLayer'</ code > </ em > , < strong > < code > act_fn</ code > </ strong > =< em > < code > ReLU(inplace=True)</ code > </ em > , < strong > < code > zero_bn</ code > </ strong > =< em > < code > True</ code > </ em > , < strong > < code > bn_1st</ code > </ strong > =< em > < code > True</ code > </ em > , < strong > < code > pool</ code > </ strong > =< em > < code > AvgPool2d(kernel_size=2, stride=2, padding=0)</ code > </ em > , < strong > < code > sa</ code > </ strong > =< em > < code > False</ code > </ em > , < strong > < code > sym</ code > </ strong > =< em > < code > False</ code > </ em > , < strong > < code > groups</ code > </ strong > =< em > < code > 1</ code > </ em > ) :: < code > Module</ code > </ p >
317+ < h2 id ="NewResBlock " class ="doc_header "> < code > class</ code > < code > NewResBlock</ code > < a href ="https://github.com/ayasyrev/model_constructor/tree/master/model_constructor/net.py#L47 " class ="source_link " style ="float:right "> [source]</ a > </ h2 > < blockquote > < p > < code > NewResBlock</ code > (< strong > < code > expansion</ code > </ strong > , < strong > < code > ni</ code > </ strong > , < strong > < code > nh</ code > </ strong > , < strong > < code > stride</ code > </ strong > =< em > < code > 1</ code > </ em > , < strong > < code > conv_layer</ code > </ strong > =< em > < code > 'ConvLayer'</ code > </ em > , < strong > < code > act_fn</ code > </ strong > =< em > < code > ReLU(inplace=True)</ code > </ em > , < strong > < code > zero_bn</ code > </ strong > =< em > < code > True</ code > </ em > , < strong > < code > bn_1st</ code > </ strong > =< em > < code > True</ code > </ em > , < strong > < code > pool</ code > </ strong > =< em > < code > AvgPool2d(kernel_size=2, stride=2, padding=0)</ code > </ em > , < strong > < code > sa</ code > </ strong > =< em > < code > False</ code > </ em > , < strong > < code > sym</ code > </ strong > =< em > < code > False</ code > </ em > , < strong > < code > groups</ code > </ strong > =< em > < code > 1</ code > </ em > ) :: < code > Module</ code > </ p >
318318</ blockquote >
319319< p > Base class for all neural network modules.</ p >
320320< p > Your models should also subclass this class.</ p >
@@ -406,7 +406,7 @@ <h1 id="Net-class.">Net class.<a class="anchor-link" href="#Net-class."> </a></h
406406
407407
408408< div class ="output_markdown rendered_html output_subarea ">
409- < h2 id ="Net " class ="doc_header "> < code > class</ code > < code > Net</ code > < a href ="https://github.com/ayasyrev/model_constructor/tree/master/model_constructor/net.py#L106 " class ="source_link " style ="float:right "> [source]</ a > </ h2 > < blockquote > < p > < code > Net</ code > (< strong > < code > expansion</ code > </ strong > =< em > < code > 1</ code > </ em > , < strong > < code > layers</ code > </ strong > =< em > < code > [2, 2, 2, 2]</ code > </ em > , < strong > < code > c_in</ code > </ strong > =< em > < code > 3</ code > </ em > , < strong > < code > c_out</ code > </ strong > =< em > < code > 1000</ code > </ em > , < strong > < code > name</ code > </ strong > =< em > < code > 'Net'</ code > </ em > , < strong > < code > act_fn</ code > </ strong > =< em > < code > ReLU(inplace=True)</ code > </ em > , < strong > < code > pool</ code > </ strong > =< em > < code > AvgPool2d(kernel_size=2, stride=2, padding=0)</ code > </ em > , < strong > < code > sa</ code > </ strong > =< em > < code > 0</ code > </ em > )</ p >
409+ < h2 id ="Net " class ="doc_header "> < code > class</ code > < code > Net</ code > < a href ="https://github.com/ayasyrev/model_constructor/tree/master/model_constructor/net.py#L107 " class ="source_link " style ="float:right "> [source]</ a > </ h2 > < blockquote > < p > < code > Net</ code > (< strong > < code > expansion</ code > </ strong > =< em > < code > 1</ code > </ em > , < strong > < code > layers</ code > </ strong > =< em > < code > [2, 2, 2, 2]</ code > </ em > , < strong > < code > c_in</ code > </ strong > =< em > < code > 3</ code > </ em > , < strong > < code > c_out</ code > </ strong > =< em > < code > 1000</ code > </ em > , < strong > < code > name</ code > </ strong > =< em > < code > 'Net'</ code > </ em > , < strong > < code > act_fn</ code > </ strong > =< em > < code > ReLU(inplace=True)</ code > </ em > , < strong > < code > pool</ code > </ strong > =< em > < code > AvgPool2d(kernel_size=2, stride=2, padding=0)</ code > </ em > , < strong > < code > sa</ code > </ strong > =< em > < code > 0</ code > </ em > )</ p >
410410</ blockquote >
411411
412412</ div >
@@ -533,18 +533,18 @@ <h2 id="Net" class="doc_header"><code>class</code> <code>Net</code><a href="http
533533< pre > Sequential(
534534 (conv_0): ConvLayer(
535535 (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
536- (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)
537536 (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
537+ (act_fn): ReLU(inplace=True)
538538 )
539539 (conv_1): ConvLayer(
540540 (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
541- (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)
542541 (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
542+ (act_fn): ReLU(inplace=True)
543543 )
544544 (conv_2): ConvLayer(
545545 (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
546- (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)
547546 (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
547+ (act_fn): ReLU(inplace=True)
548548 )
549549 (stem_pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
550550)</ pre >
@@ -771,7 +771,7 @@ <h2 id="Net" class="doc_header"><code>class</code> <code>Net</code><a href="http
771771 (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
772772 )
773773 (conv_1): ConvLayer(
774- (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
774+ (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
775775 (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)
776776 (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
777777 )
@@ -794,14 +794,17 @@ <h2 id="Net" class="doc_header"><code>class</code> <code>Net</code><a href="http
794794 (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
795795 )
796796 (conv_1): ConvLayer(
797- (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
797+ (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
798798 (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)
799799 (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
800800 )
801801 (conv_2): ConvLayer(
802802 (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
803803 (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
804804 )
805+ (sa): SimpleSelfAttention(
806+ (conv): Conv1d(256, 256, kernel_size=(1,), stride=(1,), bias=False)
807+ )
805808 )
806809 (merge): LeakyReLU(negative_slope=0.01, inplace=True)
807810 )
@@ -816,7 +819,7 @@ <h2 id="Net" class="doc_header"><code>class</code> <code>Net</code><a href="http
816819 (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
817820 )
818821 (conv_1): ConvLayer(
819- (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
822+ (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
820823 (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)
821824 (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
822825 )
@@ -839,7 +842,7 @@ <h2 id="Net" class="doc_header"><code>class</code> <code>Net</code><a href="http
839842 (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
840843 )
841844 (conv_1): ConvLayer(
842- (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
845+ (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
843846 (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)
844847 (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
845848 )
@@ -861,7 +864,7 @@ <h2 id="Net" class="doc_header"><code>class</code> <code>Net</code><a href="http
861864 (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
862865 )
863866 (conv_1): ConvLayer(
864- (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
867+ (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)
865868 (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)
866869 (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
867870 )
@@ -884,7 +887,7 @@ <h2 id="Net" class="doc_header"><code>class</code> <code>Net</code><a href="http
884887 (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
885888 )
886889 (conv_1): ConvLayer(
887- (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
890+ (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)
888891 (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)
889892 (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
890893 )
@@ -906,7 +909,7 @@ <h2 id="Net" class="doc_header"><code>class</code> <code>Net</code><a href="http
906909 (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
907910 )
908911 (conv_1): ConvLayer(
909- (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
912+ (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)
910913 (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)
911914 (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
912915 )
@@ -929,7 +932,7 @@ <h2 id="Net" class="doc_header"><code>class</code> <code>Net</code><a href="http
929932 (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
930933 )
931934 (conv_1): ConvLayer(
932- (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
935+ (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512, bias=False)
933936 (act_fn): LeakyReLU(negative_slope=0.01, inplace=True)
934937 (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
935938 )
0 commit comments