1+ # AUTOGENERATED! DO NOT EDIT! File to edit: nbs/05_Twist.ipynb (unless otherwise specified).
2+
3+ __all__ = ['nn' , 'F' , 'ConvTwist' , 'ConvLayerTwist' , 'NewResBlockTwist' , 'ResBlockTwist' ]
4+
5+ # Cell
6+ from functools import partial
7+ from collections import OrderedDict
8+ from .layers import *
9+ from .net import *
10+
11+ # Cell
12+ import sys , torch
13+ nn = torch .nn
14+ F = torch .nn .functional
15+
16+ # Cell
class ConvTwist(nn.Module):
    '''Replacement for Conv2d (kernelsize 3x3) with an optional position-dependent "twist".

    The module always applies a regular 3x3 `nn.Conv2d`.  When `twist` is truthy, two extra
    3x3 convolutions are added whose outputs are multiplied by the normalised horizontal (XX)
    and vertical (YY) pixel coordinates of the output feature map:

        out = conv(x) + XX * conv_x(x) + YY * conv_y(x)

    The kernel of `conv_x` is a learned linear combination of two fixed Sobel-like stencils
    (coefficients `coeff_Ax`, `coeff_Ay`); the kernel of `conv_y` is derived from it by a
    transpose followed by a flip.

    Parameters
    ----------
    ni, nf : input / output channel counts.
    ks : accepted for signature compatibility with Conv2d; the kernel size is always 3.
    stride, padding, bias : forwarded to the main `nn.Conv2d`.  The twist convolutions
        always use `padding=1`, so other `padding` values give mismatching shapes when
        `twist` is enabled.
    groups : number of groups; silently reset to 1 when it does not divide both `ni` and `nf`.
    iters : number of iterations of the residual refinement in `forward` (only used when
        `ni == nf`, `stride == 1` and `iters > 1`).
    init_max : accepted but currently unused.
    twist : enable the coordinate-weighted convolutions.
    permute : when truthy, grouped twist kernels are expanded to a dense kernel whose
        group blocks are moved to a different input-channel block (see `full_kernel`).
    '''
    def __init__(self, ni, nf,
                 ks=3, stride=1, padding=1, bias=False,
                 groups=1, iters=1, init_max=0.7, twist=False, permute=True):
        super().__init__()
        self.twist = twist
        self.permute = permute
        # `same`: input and output tensors have identical shapes, so they can be added.
        self.same = ni == nf and stride == 1
        # Fall back to an ungrouped convolution when the channels are not divisible.
        if not (ni % groups == 0 and nf % groups == 0): groups = 1
        self.conv = nn.Conv2d(ni, nf, kernel_size=3, stride=stride, padding=padding, bias=bias, groups=groups)
        if self.twist:
            # Initialise the twist coefficients on the same scale as the main weights.
            std = self.conv.weight.std().item()
            self.coeff_Ax = nn.Parameter(torch.empty((nf, ni // groups)).normal_(0, std), requires_grad=True)
            self.coeff_Ay = nn.Parameter(torch.empty((nf, ni // groups)).normal_(0, std), requires_grad=True)
        self.iters = iters
        self.stride = stride
        self.groups = groups
        self.DD = self.derivatives()

    def derivatives(self):
        '''Return a dict of fixed stencils built from Sobel-like kernels.

        'x' and 'y' are the 3x3 stencils scaled by 1/10 (shape (1, 1, 3, 3)); 'xx', 'yy'
        and 'xy' are the 5x5 results of convolving (I + D) pairs, where I is the 3x3
        identity stencil.
        '''
        I = torch.Tensor([[0, 0, 0], [0, 1, 0], [0, 0, 0]]).view(1, 1, 3, 3)
        D_x = torch.Tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).view(1, 1, 3, 3) / 10
        D_y = torch.Tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]).view(1, 1, 3, 3) / 10

        def convolution(K1, K2):
            # F.conv2d is a cross-correlation; flipping K2 turns it into a convolution.
            return F.conv2d(K1, K2.flip(2).flip(3), padding=2)

        D_xx = convolution(I + D_x, I + D_x).view(5, 5)
        D_yy = convolution(I + D_y, I + D_y).view(5, 5)
        D_xy = convolution(I + D_x, I + D_y).view(5, 5)
        return {'x': D_x, 'y': D_y, 'xx': D_xx, 'yy': D_yy, 'xy': D_xy}

    def kernel(self, coeff_x, coeff_y):
        '''Build 3x3 kernels as `coeff_x * D_x + coeff_y * D_y`.

        `coeff_x` / `coeff_y` are 2-D (out, in) coefficient tensors; the result has shape
        (out, in, 3, 3) and the dtype/device of `coeff_x`.
        '''
        # new_tensor keeps dtype and device in sync with the coefficients, so the kernel
        # stays usable after `.double()` / `.half()` / `.to(device)`.
        D_x = coeff_x.new_tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]])
        D_y = coeff_x.new_tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]])
        return coeff_x[:, :, None, None] * D_x + coeff_y[:, :, None, None] * D_y

    def full_kernel(self, kernel):  # permuting the groups
        '''Expand a grouped kernel to a dense one with permuted group blocks.

        With `groups == 1` the kernel is returned unchanged.  Otherwise a zero kernel of
        shape (out, in_per_group * groups, 3, 3) is filled so that, inside every run of
        four groups, output block 0 reads input block 3 and output blocks 1..3 read input
        blocks 0..2.

        Raises
        ------
        RuntimeError : if `groups` is not a multiple of 4 (the block pattern would index
            past the last input block).
        '''
        if self.groups == 1: return kernel
        n = self.groups
        if n % 4 != 0:
            raise RuntimeError(
                f"ConvTwist.full_kernel: permuting groups requires groups to be a multiple of 4, got {n}")
        a, b, _, _ = kernel.size()
        a = a // n
        # new_zeros matches the dtype and device of `kernel`.
        KK = kernel.new_zeros((a * n, b * n, 3, 3))
        for i in range(n):
            if i % 4 == 0:
                KK[a * i:a * (i + 1), b * (i + 3):b * (i + 4)] = kernel[a * i:a * (i + 1)]
            else:
                KK[a * i:a * (i + 1), b * (i - 1):b * i] = kernel[a * i:a * (i + 1)]
        return KK

    def _conv(self, inpt, kernel=None):
        '''Convolve `inpt` with `kernel` (default: the main conv weight), padding 1.

        Uses the dense permuted kernel when `self.permute` is truthy, otherwise a regular
        grouped convolution.  No bias is added.
        '''
        if kernel is None:
            kernel = self.conv.weight
        if not self.permute:
            return F.conv2d(inpt, kernel, padding=1, stride=self.stride, groups=self.groups)
        else:
            return F.conv2d(inpt, self.full_kernel(kernel), padding=1, stride=self.stride, groups=1)

    def symmetrize(self, conv_wt):
        '''In place, replace each square group block of `conv_wt` by the average of itself
        and its transpose over the (out, in) dimensions.  Does nothing unless `self.same`.

        NOTE(review): the update is written through `.data`, so autograd does not record
        the averaging; gradients are computed as if the kernel had not been symmetrized.
        Confirm that this is intended.
        '''
        if self.same:
            n = conv_wt.size()[1]
            for i in range(self.groups):
                conv_wt.data[n * i:n * (i + 1)] = (conv_wt[n * i:n * (i + 1)] + torch.transpose(conv_wt[n * i:n * (i + 1)], 0, 1)) / 2

    def forward(self, inpt):
        '''Apply the convolution and, when enabled, the coordinate-weighted twist terms.'''
        out = self.conv(inpt)
        # Use truthiness, matching the check in __init__ that creates the coefficients.
        if not self.twist:
            return out
        _, _, h, w = out.size()
        # Normalised coordinates: XX[j] = 2*j/w - 1 along the width, YY[i] = 2*i/h - 1 along
        # the height.  Computed in float64 and then cast; 1-D ramps broadcast over (N, C, h, w).
        XX = (torch.arange(w, dtype=torch.float64) * 2 / w - 1).to(device=out.device, dtype=out.dtype).view(1, 1, 1, w)
        YY = (torch.arange(h, dtype=torch.float64) * 2 / h - 1).to(device=out.device, dtype=out.dtype).view(1, 1, h, 1)
        kernel_x = self.kernel(self.coeff_Ax, self.coeff_Ay)
        self.symmetrize(kernel_x)
        kernel_y = kernel_x.transpose(2, 3).flip(3)  # make conv_y a 90 degree rotation of conv_x
        out = out + XX * self._conv(inpt, kernel_x) + YY * self._conv(inpt, kernel_y)
        if self.same and self.iters > 1:
            # Residual refinement: repeatedly apply the twisted operator scaled by 1/iters,
            # starting from the input, and return the accumulated change.
            out = inpt + out / self.iters
            for _ in range(self.iters - 1):
                out = out + (self._conv(out) + XX * self._conv(out, kernel_x) + YY * self._conv(out, kernel_y)) / self.iters
            out = out - inpt
        return out

    def extra_repr(self):
        '''Extra fields shown in the module's repr.'''
        return f"twist: {self.twist} , permute: {self.permute} , same: {self.same} "
127+
128+ # Cell
class ConvLayerTwist(ConvLayer):
    '''`ConvLayer` variant whose `Conv2d` class attribute is `ConvTwist`.

    NOTE(review): presumably `ConvLayer` instantiates `self.Conv2d` to build its
    convolution, so this swaps in the twist convolution -- confirm against `ConvLayer`.
    '''
    Conv2d = ConvTwist
131+
132+ # Cell
class NewResBlockTwist(nn.Module):
    '''Residual block that applies `pool` to its input first (when `stride != 1`) and
    feeds the result to both the convolution branch and the identity branch.

    With `expansion == 1` the convolution branch is two 3x3 `conv_layer`s; otherwise it
    is a 1x1 `conv_layer`, a 3x3 `ConvLayerTwist` and a final 1x1 `conv_layer`.
    `ni` and `nh` are multiplied by `expansion` to obtain the block's input and output
    channel counts.  `stride` only selects whether `pool` is used; it is not passed to
    the convolution layers.
    '''
    def __init__(self, expansion, ni, nh, stride=1,
                 conv_layer=ConvLayer, act_fn=act_fn, bn_1st=True,
                 pool=nn.AvgPool2d(2, ceil_mode=True), sa=False, sym=False, zero_bn=True):
        super().__init__()
        nf = nh * expansion
        ni = ni * expansion
        # Spatial reduction happens once, ahead of both branches.
        if stride == 1:
            self.reduce = noop
        else:
            self.reduce = pool
        if expansion == 1:
            branch = [
                ("conv_0", conv_layer(ni, nh, 3, act_fn=act_fn, bn_1st=bn_1st)),
                ("conv_1", conv_layer(nh, nf, 3, zero_bn=zero_bn, act=False, bn_1st=bn_1st)),
            ]
        else:
            # Bottleneck layout: the 3x3 layer in the middle is always the twist variant.
            branch = [
                ("conv_0", conv_layer(ni, nh, 1, act_fn=act_fn, bn_1st=bn_1st)),
                ("conv_1_twist", ConvLayerTwist(nh, nh, 3, act_fn=act_fn, bn_1st=bn_1st)),
                ("conv_2", conv_layer(nh, nf, 1, zero_bn=zero_bn, act=False, bn_1st=bn_1st)),
            ]
        if sa:
            branch.append(('sa', SimpleSelfAttention(nf, ks=1, sym=sym)))
        self.convs = nn.Sequential(OrderedDict(branch))
        # A 1x1 projection is only needed when the channel counts differ.
        if ni == nf:
            self.idconv = noop
        else:
            self.idconv = conv_layer(ni, nf, 1, act=False, bn_1st=bn_1st)
        self.merge = act_fn

    def forward(self, x):
        '''Return `merge(convs(r) + idconv(r))` where `r = reduce(x)`.'''
        reduced = self.reduce(x)
        main_path = self.convs(reduced)
        shortcut = self.idconv(reduced)
        return self.merge(main_path + shortcut)
157+
158+ # Cell
class ResBlockTwist(nn.Module):
    '''Residual block whose strided convolution sits inside the convolution branch while
    the identity branch is reduced with `pool` (when `stride != 1`).

    With `expansion == 1` the convolution branch is two 3x3 `conv_layer`s, the first one
    strided; otherwise it is a 1x1 `conv_layer`, a strided 3x3 `ConvLayerTwist` and a
    final 1x1 `conv_layer`.  `ni` and `nh` are multiplied by `expansion` to obtain the
    block's input and output channel counts.
    '''
    def __init__(self, expansion, ni, nh, stride=1,
                 conv_layer=ConvLayer, act_fn=act_fn, zero_bn=True, bn_1st=True,
                 pool=nn.AvgPool2d(2, ceil_mode=True), sa=False, sym=False):
        super().__init__()
        nf = nh * expansion
        ni = ni * expansion
        if expansion == 1:
            branch = [
                ("conv_0", conv_layer(ni, nh, 3, stride=stride, act_fn=act_fn, bn_1st=bn_1st)),
                ("conv_1", conv_layer(nh, nf, 3, zero_bn=zero_bn, act=False, bn_1st=bn_1st)),
            ]
        else:
            # Bottleneck layout: the strided 3x3 layer in the middle is always the twist variant.
            branch = [
                ("conv_0", conv_layer(ni, nh, 1, act_fn=act_fn, bn_1st=bn_1st)),
                ("conv_1_twist", ConvLayerTwist(nh, nh, 3, stride=stride, act_fn=act_fn, bn_1st=bn_1st)),
                ("conv_2", conv_layer(nh, nf, 1, zero_bn=zero_bn, act=False, bn_1st=bn_1st)),
            ]
        if sa:
            branch.append(('sa', SimpleSelfAttention(nf, ks=1, sym=sym)))
        self.convs = nn.Sequential(OrderedDict(branch))
        # The shortcut is pooled only when the convolution branch is strided.
        if stride == 1:
            self.pool = noop
        else:
            self.pool = pool
        # A 1x1 projection is only needed when the channel counts differ.
        if ni == nf:
            self.idconv = noop
        else:
            self.idconv = conv_layer(ni, nf, 1, act=False)
        self.act_fn = act_fn

    def forward(self, x):
        '''Return `act_fn(convs(x) + idconv(pool(x)))`.'''
        main_path = self.convs(x)
        shortcut = self.idconv(self.pool(x))
        return self.act_fn(main_path + shortcut)
0 commit comments