@@ -51,93 +51,80 @@ function (l::SparseKernel)(X::AbstractArray)
5151end
5252
5353
54- # struct MWT_CZ1d
55-
56- # end
57-
58- # function MWT_CZ1d(k::Int=3, c::Int=1; init=Flux.glorot_uniform)
59-
60- # end
61-
62- # class MWT_CZ1d(nn.Module):
63- # def __init__(self,
64- # k = 3, alpha = 5,
65- # L = 0, c = 1,
66- # base = 'legendre',
67- # initializer = None,
68- # **kwargs):
69- # super(MWT_CZ1d, self).__init__()
70-
71- # self.k = k
72- # self.L = L
73- # H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k)
74- # H0r = H0@PHI0
75- # G0r = G0@PHI0
76- # H1r = H1@PHI1
77- # G1r = G1@PHI1
78-
79- # H0r[np.abs(H0r)<1e-8]=0
80- # H1r[np.abs(H1r)<1e-8]=0
81- # G0r[np.abs(G0r)<1e-8]=0
82- # G1r[np.abs(G1r)<1e-8]=0
83-
84- # self.A = sparseKernelFT1d(k, alpha, c)
85- # self.B = sparseKernelFT1d(k, alpha, c)
86- # self.C = sparseKernelFT1d(k, alpha, c)
87-
88- # self.T0 = nn.Linear(k, k)
89-
90- # self.register_buffer('ec_s', torch.Tensor(
91- # np.concatenate((H0.T, H1.T), axis=0)))
92- # self.register_buffer('ec_d', torch.Tensor(
93- # np.concatenate((G0.T, G1.T), axis=0)))
94-
95- # self.register_buffer('rc_e', torch.Tensor(
96- # np.concatenate((H0r, G0r), axis=0)))
97- # self.register_buffer('rc_o', torch.Tensor(
98- # np.concatenate((H1r, G1r), axis=0)))
99-
100-
101- # def forward(self, x):
102-
103- # B, N, c, ich = x.shape # (B, N, k)
104- # ns = math.floor(np.log2(N))
105-
106- # Ud = torch.jit.annotate(List[Tensor], [])
107- # Us = torch.jit.annotate(List[Tensor], [])
108- # # decompose
109- # for i in range(ns-self.L):
110- # d, x = self.wavelet_transform(x)
111- # Ud += [self.A(d) + self.B(x)]
112- # Us += [self.C(d)]
113- # x = self.T0(x) # coarsest scale transform
114-
115- # # reconstruct
116- # for i in range(ns-1-self.L,-1,-1):
117- # x = x + Us[i]
118- # x = torch.cat((x, Ud[i]), -1)
119- # x = self.evenOdd(x)
120- # return x
121-
122-
123- # def wavelet_transform(self, x):
124- # xa = torch.cat([x[:, ::2, :, :],
125- # x[:, 1::2, :, :],
126- # ], -1)
127- # d = torch.matmul(xa, self.ec_d)
128- # s = torch.matmul(xa, self.ec_s)
129- # return d, s
130-
131-
132- # def evenOdd(self, x):
133-
134- # B, N, c, ich = x.shape # (B, N, c, k)
135- # assert ich == 2*self.k
136- # x_e = torch.matmul(x, self.rc_e)
137- # x_o = torch.matmul(x, self.rc_o)
138-
54+ struct MWT_CZ1d{T,S,R,Q,P}
55+ k:: Int
56+ L:: Int
57+ A:: T
58+ B:: S
59+ C:: R
60+ T0:: Q
61+ ec_s:: P
62+ ec_d:: P
63+ rc_e:: P
64+ rc_o:: P
65+ end
66+
67+ function MWT_CZ1d (k:: Int = 3 , α:: Int = 5 , L:: Int = 0 , c:: Int = 1 ; base:: Symbol = :legendre , init= Flux. glorot_uniform)
68+ H0, H1, G0, G1, ϕ0, ϕ1 = get_filter (base, k)
69+ H0r = zero_out! (H0 * ϕ0)
70+ G0r = zero_out! (G0 * ϕ0)
71+ H1r = zero_out! (H1 * ϕ1)
72+ G1r = zero_out! (G1 * ϕ1)
73+
74+ dim = c* k
75+ A = SpectralConv (dim=> dim, (α,); init= init)
76+ B = SpectralConv (dim=> dim, (α,); init= init)
77+ C = SpectralConv (dim=> dim, (α,); init= init)
78+ T0 = Dense (k, k)
79+
80+ ec_s = vcat (H0' , H1' )
81+ ec_d = vcat (G0' , G1' )
82+ rc_e = vcat (H0r, G0r)
83+ rc_o = vcat (H1r, G1r)
84+ return MWT_CZ1d (k, L, A, B, C, T0, ec_s, ec_d, rc_e, rc_o)
85+ end
86+
87+ function wavelet_transform (l:: MWT_CZ1d , X:: AbstractArray{T,4} ) where {T}
88+ N = size (X, 3 )
89+ Xa = vcat (view (X, :, :, 1 : 2 : N, :), view (X, :, :, 2 : 2 : N, :))
90+ d = NNlib. batched_mul (Xa, l. ec_d)
91+ s = NNlib. batched_mul (Xa, l. ec_s)
92+ return d, s
93+ end
94+
95+ function even_odd (l:: MWT_CZ1d , X:: AbstractArray{T,4} ) where {T}
96+ bch_sz, N, dims_r... = reverse (size (X))
97+ dims = reverse (dims_r)
98+ @assert dims[1 ] == 2 * l. k
99+ Xₑ = NNlib. batched_mul (X, l. rc_e)
100+ Xₒ = NNlib. batched_mul (X, l. rc_o)
139101# x = torch.zeros(B, N*2, c, self.k,
140102# device = x.device)
141103# x[..., ::2, :, :] = x_e
142104# x[..., 1::2, :, :] = x_o
143- # return x
105+ return X
106+ end
107+
108+ function (l:: MWT_CZ1d )(X:: T ) where {T<: AbstractArray }
109+ bch_sz, N, dims_r... = reverse (size (X))
110+ ns = floor (log2 (N))
111+ stop = ns - l. L
112+
113+ # decompose
114+ Ud = T[]
115+ Us = T[]
116+ for i in 1 : stop
117+ d, X = wavelet_transform (l, X)
118+ push! (Ud, l. A (d)+ l. B (d))
119+ push! (Us, l. C (d))
120+ end
121+ X = l. T0 (X)
122+
123+ # reconstruct
124+ for i in stop: - 1 : 1
125+ X += Us[i]
126+ X = vcat (X, Ud[i]) # x = torch.cat((x, Ud[i]), -1)
127+ X = even_odd (l, X)
128+ end
129+ return X
130+ end
0 commit comments