Skip to content

Commit 84d7673

Browse files
author
ayasyrev
committed
add Twist
1 parent 315e7bd commit 84d7673

File tree

5 files changed

+3007
-3
lines changed

5 files changed

+3007
-3
lines changed

model_constructor/_nbdev.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,12 @@
3333
"NewResBlock": "81_Net.ipynb",
3434
"net34": "04_Net.ipynb",
3535
"net50": "04_Net.ipynb",
36+
"nn": "05_Twist.ipynb",
37+
"F": "05_Twist.ipynb",
38+
"ConvTwist": "05_Twist.ipynb",
39+
"ConvLayerTwist": "05_Twist.ipynb",
40+
"NewResBlockTwist": "05_Twist.ipynb",
41+
"ResBlockTwist": "05_Twist.ipynb",
3642
"NewConvLayer": "81_Net.ipynb",
3743
"me": "81_Net.ipynb"}
3844

@@ -41,6 +47,7 @@
4147
"resnet.py",
4248
"xresnet.py",
4349
"net.py",
50+
"twist.py",
4451
"tst_net_2.py"]
4552

4653
doc_url = "https://ayasyrev.github.io/model_constructor/"

model_constructor/layers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,15 @@
1515

1616
class ConvLayer(nn.Sequential):
1717
"""Basic conv layers block"""
18+
Conv2d = nn.Conv2d
1819
def __init__(self, ni, nf, ks=3, stride=1,
1920
act=True, act_fn=act_fn,
2021
bn_layer=True, bn_1st=True, zero_bn=False,
2122
padding=None, bias=False, groups=1, **kwargs):
2223

2324
# self.act = act
2425
if padding==None: padding = ks//2
25-
layers = [('conv', nn.Conv2d(ni, nf, ks, stride=stride, padding=padding, bias=bias, groups=groups))]
26+
layers = [('conv', self.Conv2d(ni, nf, ks, stride=stride, padding=padding, bias=bias, groups=groups))]
2627
act_bn = [('act_fn', act_fn)] if act else []
2728
if bn_layer:
2829
bn = nn.BatchNorm2d(nf)

model_constructor/twist.py

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/05_Twist.ipynb (unless otherwise specified).
2+
3+
__all__ = ['nn', 'F', 'ConvTwist', 'ConvLayerTwist', 'NewResBlockTwist', 'ResBlockTwist']
4+
5+
# Cell
6+
from functools import partial
7+
from collections import OrderedDict
8+
from .layers import *
9+
from .net import *
10+
11+
# Cell
12+
import sys, torch
13+
nn = torch.nn
14+
F = torch.nn.functional
15+
16+
# Cell
17+
class ConvTwist(nn.Module):
18+
'''Replacement for Conv2d (kernelsize 3x3)'''
19+
def __init__(self, ni, nf,
20+
ks=3, stride=1, padding=1, bias=False,
21+
groups=1, iters=1, init_max=0.7, twist = False, permute=True):
22+
# super(ConvTwist, self).__init__()
23+
super().__init__()
24+
self.twist = twist
25+
self.permute = permute
26+
self.same = ni==nf and stride==1
27+
if not (ni%groups==0 and nf%groups==0): groups = 1
28+
# elif ni%64==0: groups = ni//8
29+
self.conv = nn.Conv2d(ni, nf, kernel_size=3, stride=stride, padding=padding, bias=bias, groups=groups)
30+
if self.twist:
31+
# self.conv_x = nn.Conv2d(ni, nf, kernel_size=3, stride=stride, padding=1, bias=False, groups=groups)
32+
# self.conv_y = nn.Conv2d(ni, nf, kernel_size=3, stride=stride, padding=1, bias=False, groups=groups)
33+
std = self.conv.weight.std().item()
34+
self.coeff_Ax = nn.Parameter(torch.empty((nf,ni//groups)).normal_(0, std), requires_grad=True)
35+
self.coeff_Ay = nn.Parameter(torch.empty((nf,ni//groups)).normal_(0, std), requires_grad=True)
36+
# self.coeff_Bx = nn.Parameter(torch.zeros((nf,ni)).normal_(0, std), requires_grad=True)
37+
# self.coeff_By = nn.Parameter(torch.zeros((nf,ni)).normal_(0, std), requires_grad=True)
38+
# self.center_x = nn.Parameter(torch.Tensor(nf), requires_grad=True)
39+
# self.center_y = nn.Parameter(torch.Tensor(nf), requires_grad=True)
40+
# self.center_x.data.uniform_(-init_max, init_max)
41+
# self.center_y.data.uniform_(-init_max, init_max)
42+
self.iters = iters
43+
self.stride = stride
44+
self.groups = groups
45+
self.DD = self.derivatives()
46+
47+
def derivatives(self):
48+
I = torch.Tensor([[0,0,0],[0,1,0],[0,0,0]]).view(1,1,3,3)
49+
D_x = torch.Tensor([[-1,0,1],[-2,0,2],[-1,0,1]]).view(1,1,3,3) / 10
50+
D_y = torch.Tensor([[1,2,1],[0,0,0],[-1,-2,-1]]).view(1,1,3,3) / 10
51+
def convolution(K1, K2):
52+
return F.conv2d(K1, K2.flip(2).flip(3), padding=2)
53+
D_xx = convolution(I+D_x, I+D_x).view(5,5)
54+
D_yy = convolution(I+D_y, I+D_y).view(5,5)
55+
D_xy = convolution(I+D_x, I+D_y).view(5,5)
56+
return {'x': D_x, 'y': D_y, 'xx': D_xx, 'yy': D_yy, 'xy': D_xy}
57+
58+
# def init_coeff(self):
59+
# self.coeff_Bx.data = self.coeff_Ay
60+
# self.coeff_By.data = -self.coeff_Ax
61+
62+
def kernel(self, coeff_x, coeff_y):
63+
D_x = torch.Tensor([[-1,0,1],[-2,0,2],[-1,0,1]]).to(coeff_x.device)
64+
D_y = torch.Tensor([[1,2,1],[0,0,0],[-1,-2,-1]]).to(coeff_x.device)
65+
return coeff_x[:,:,None,None] * D_x + coeff_y[:,:,None,None] * D_y
66+
67+
def full_kernel(self, kernel): # permuting the groups
68+
if self.groups==1: return kernel
69+
n = self.groups
70+
a,b,_,_ = kernel.size()
71+
a = a//n
72+
KK = torch.zeros((a*n,b*n,3,3)).to(kernel.device)
73+
# KK[:a,-b:] = kernel[:a]
74+
for i in range(n):
75+
if i%4==0:
76+
KK[a*i:a*(i+1),b*(i+3):b*(i+4)] = kernel[a*i:a*(i+1)]
77+
else:
78+
KK[a*i:a*(i+1),b*(i-1):b*i] = kernel[a*i:a*(i+1)]
79+
return KK
80+
81+
def _conv(self, inpt, kernel=None):
82+
# permute = True
83+
if kernel is None:
84+
kernel = self.conv.weight
85+
if not self.permute:
86+
return F.conv2d(inpt, kernel, padding=1, stride=self.stride, groups=self.groups)
87+
else:
88+
return F.conv2d(inpt, self.full_kernel(kernel), padding=1, stride=self.stride, groups=1)
89+
90+
def symmetrize(self, conv_wt):
91+
# conv_wt.data = (conv_wt - conv_wt.flip(2).flip(3)) / 2
92+
if self.same:
93+
n = conv_wt.size()[1]
94+
for i in range(self.groups):
95+
conv_wt.data[n*i:n*(i+1)] = (conv_wt[n*i:n*(i+1)] + torch.transpose(conv_wt[n*i:n*(i+1)],0,1)) / 2
96+
97+
def forward(self, inpt):
98+
# self.symmetrize(self.conv.weight)
99+
out = self.conv(inpt)
100+
if self.twist is False:
101+
return out
102+
_,_,h,w = out.size()
103+
XX = torch.from_numpy(np.indices((1,1,h,w))[3]*2/w-1).type(out.dtype).to(out.device)
104+
YY = torch.from_numpy(np.indices((1,1,h,w))[2]*2/h-1).type(out.dtype).to(out.device)
105+
# self.symmetrize(self.conv_x.weight)
106+
# self.symmetrize(self.conv_y.weight)
107+
# kernel_x = self.conv_x.weight
108+
# kernel_y = self.conv_y.weight
109+
# self.symmetrize(self.coeff_Ax)
110+
# self.symmetrize(self.coeff_Ay)
111+
kernel_x = self.kernel(self.coeff_Ax, self.coeff_Ay)
112+
self.symmetrize(kernel_x)
113+
# self.symmetrize(kernel_y)
114+
kernel_y = kernel_x.transpose(2,3).flip(3) # make conv_y a 90 degree rotation of conv_x
115+
# kernel_y = self.kernel(self.coeff_Bx, self.coeff_By)
116+
out = out + XX * self._conv(inpt, kernel_x) + YY * self._conv(inpt, kernel_y)
117+
# out = out + (XX-self.center_x.view(-1,1,1)) * self.conv_x(inpt) + (YY-self.center_y.view(-1,1,1)) * self.conv_y(inpt)
118+
if self.same and self.iters>1:
119+
out = inpt + out / self.iters
120+
for _ in range(self.iters-1):
121+
out = out + (self._conv(out) + XX * self._conv(out, kernel_x) + YY * self._conv(out, kernel_y)) / self.iters
122+
out = out - inpt
123+
return out
124+
125+
def extra_repr(self):
126+
return f"twist: {self.twist}, permute: {self.permute}, same: {self.same}"
127+
128+
# Cell
129+
class ConvLayerTwist(ConvLayer): # replace Conv2d by Twist
130+
Conv2d = ConvTwist
131+
132+
# Cell
133+
class NewResBlockTwist(nn.Module):
134+
def __init__(self, expansion, ni, nh, stride=1,
135+
conv_layer=ConvLayer, act_fn=act_fn, bn_1st=True,
136+
pool=nn.AvgPool2d(2, ceil_mode=True), sa=False, sym=False, zero_bn=True):
137+
super().__init__()
138+
nf,ni = nh*expansion,ni*expansion
139+
# conv_layer = ConvLayerTwist
140+
self.reduce = noop if stride==1 else pool
141+
layers = [(f"conv_0", conv_layer(ni, nh, 3, act_fn=act_fn, bn_1st=bn_1st)),
142+
(f"conv_1", conv_layer(nh, nf, 3, zero_bn=zero_bn, act=False, bn_1st=bn_1st))
143+
] if expansion == 1 else [
144+
(f"conv_0", conv_layer(ni, nh, 1, act_fn=act_fn, bn_1st=bn_1st)),
145+
# (f"conv_1", conv_layer(nh, nh, 3, act_fn=act_fn, bn_1st=bn_1st)),
146+
(f"conv_1_twist", ConvLayerTwist(nh, nh, 3, act_fn=act_fn, bn_1st=bn_1st)),
147+
(f"conv_2", conv_layer(nh, nf, 1, zero_bn=zero_bn, act=False, bn_1st=bn_1st))
148+
]
149+
if sa: layers.append(('sa', SimpleSelfAttention(nf,ks=1,sym=sym)))
150+
self.convs = nn.Sequential(OrderedDict(layers))
151+
self.idconv = noop if ni==nf else conv_layer(ni, nf, 1, act=False, bn_1st=bn_1st)
152+
self.merge = act_fn
153+
154+
def forward(self, x):
155+
o = self.reduce(x)
156+
return self.merge(self.convs(o) + self.idconv(o))
157+
158+
# Cell
159+
class ResBlockTwist(nn.Module):
160+
def __init__(self, expansion, ni, nh, stride=1,
161+
conv_layer=ConvLayer, act_fn=act_fn, zero_bn=True, bn_1st=True,
162+
pool=nn.AvgPool2d(2, ceil_mode=True), sa=False,sym=False):
163+
super().__init__()
164+
nf,ni = nh*expansion,ni*expansion
165+
# conv_layer = ConvLayerTwist
166+
layers = [(f"conv_0", conv_layer(ni, nh, 3, stride=stride, act_fn=act_fn, bn_1st=bn_1st)),
167+
(f"conv_1", conv_layer(nh, nf, 3, zero_bn=zero_bn, act=False, bn_1st=bn_1st))
168+
] if expansion == 1 else [
169+
(f"conv_0",conv_layer(ni, nh, 1, act_fn=act_fn, bn_1st=bn_1st)),
170+
# (f"conv_1",conv_layer(nh, nh, 3, stride=stride, act_fn=act_fn, bn_1st=bn_1st)),
171+
(f"conv_1_twist",ConvLayerTwist(nh, nh, 3, stride=stride, act_fn=act_fn, bn_1st=bn_1st)),
172+
(f"conv_2",conv_layer(nh, nf, 1, zero_bn=zero_bn, act=False, bn_1st=bn_1st))
173+
]
174+
if sa: layers.append(('sa', SimpleSelfAttention(nf,ks=1,sym=sym)))
175+
self.convs = nn.Sequential(OrderedDict(layers))
176+
self.pool = noop if stride==1 else pool
177+
self.idconv = noop if ni==nf else conv_layer(ni, nf, 1, act=False)
178+
self.act_fn =act_fn
179+
180+
def forward(self, x): return self.act_fn(self.convs(x) + self.idconv(self.pool(x)))

nbs/01_layers.ipynb

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,15 @@
6060
"\n",
6161
"class ConvLayer(nn.Sequential):\n",
6262
" \"\"\"Basic conv layers block\"\"\"\n",
63+
" Conv2d = nn.Conv2d\n",
6364
" def __init__(self, ni, nf, ks=3, stride=1, \n",
6465
" act=True, act_fn=act_fn, \n",
6566
" bn_layer=True, bn_1st=True, zero_bn=False, \n",
6667
" padding=None, bias=False, groups=1, **kwargs):\n",
6768
"\n",
6869
"# self.act = act\n",
6970
" if padding==None: padding = ks//2 \n",
70-
" layers = [('conv', nn.Conv2d(ni, nf, ks, stride=stride, padding=padding, bias=bias, groups=groups))]\n",
71+
" layers = [('conv', self.Conv2d(ni, nf, ks, stride=stride, padding=padding, bias=bias, groups=groups))]\n",
7172
" act_bn = [('act_fn', act_fn)] if act else []\n",
7273
" if bn_layer:\n",
7374
" bn = nn.BatchNorm2d(nf)\n",
@@ -1904,7 +1905,10 @@
19041905
"Converted 01_layers.ipynb.\n",
19051906
"Converted 02_resnet.ipynb.\n",
19061907
"Converted 03_xresnet.ipynb.\n",
1907-
"Converted 80_test_layers.ipynb.\n",
1908+
"Converted 04_Net.ipynb.\n",
1909+
"Converted 05_Twist.ipynb.\n",
1910+
"Converted 80_test_net.ipynb.\n",
1911+
"Converted 81_Net.ipynb.\n",
19081912
"Converted 81_test_xresnet.ipynb.\n",
19091913
"Converted index.ipynb.\n"
19101914
]

0 commit comments

Comments
 (0)