1+ struct SparseKernel1d{T,S}
2+ k:: Int
3+ conv_blk:: S
4+ out_weight:: T
5+ end
6+
7+ function SparseKernel1d (k:: Int , c:: Int = 1 ; init= Flux. glorot_uniform)
8+ input_dim = c* k
9+ emb_dim = 128
10+ conv = Conv ((3 ,), input_dim=> emb_dim, relu; stride= 1 , pad= 1 , init= init)
11+ W_out = Dense (emb_dim, input_dim; init= init)
12+ return SparseKernel1d (k, conv, W_out)
13+ end
14+
15+ function (l:: SparseKernel1d )(X:: AbstractArray )
16+ X_ = l. conv_blk (batched_transpose (X))
17+ Y = l. out_weight (batched_transpose (X_))
18+ return Y
19+ end
20+
21+
22+ # class MWT_CZ1d(nn.Module):
23+ # def __init__(self,
24+ # k = 3, alpha = 5,
25+ # L = 0, c = 1,
26+ # base = 'legendre',
27+ # initializer = None,
28+ # **kwargs):
29+ # super(MWT_CZ1d, self).__init__()
30+
31+ # self.k = k
32+ # self.L = L
33+ # H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k)
34+ # H0r = H0@PHI0
35+ # G0r = G0@PHI0
36+ # H1r = H1@PHI1
37+ # G1r = G1@PHI1
38+
39+ # H0r[np.abs(H0r)<1e-8]=0
40+ # H1r[np.abs(H1r)<1e-8]=0
41+ # G0r[np.abs(G0r)<1e-8]=0
42+ # G1r[np.abs(G1r)<1e-8]=0
43+
44+ # self.A = sparseKernelFT1d(k, alpha, c)
45+ # self.B = sparseKernelFT1d(k, alpha, c)
46+ # self.C = sparseKernelFT1d(k, alpha, c)
47+
48+ # self.T0 = nn.Linear(k, k)
49+
50+ # self.register_buffer('ec_s', torch.Tensor(
51+ # np.concatenate((H0.T, H1.T), axis=0)))
52+ # self.register_buffer('ec_d', torch.Tensor(
53+ # np.concatenate((G0.T, G1.T), axis=0)))
54+
55+ # self.register_buffer('rc_e', torch.Tensor(
56+ # np.concatenate((H0r, G0r), axis=0)))
57+ # self.register_buffer('rc_o', torch.Tensor(
58+ # np.concatenate((H1r, G1r), axis=0)))
59+
60+
61+ # def forward(self, x):
62+
63+ # B, N, c, ich = x.shape # (B, N, k)
64+ # ns = math.floor(np.log2(N))
65+
66+ # Ud = torch.jit.annotate(List[Tensor], [])
67+ # Us = torch.jit.annotate(List[Tensor], [])
68+ # # decompose
69+ # for i in range(ns-self.L):
70+ # d, x = self.wavelet_transform(x)
71+ # Ud += [self.A(d) + self.B(x)]
72+ # Us += [self.C(d)]
73+ # x = self.T0(x) # coarsest scale transform
74+
75+ # # reconstruct
76+ # for i in range(ns-1-self.L,-1,-1):
77+ # x = x + Us[i]
78+ # x = torch.cat((x, Ud[i]), -1)
79+ # x = self.evenOdd(x)
80+ # return x
81+
82+
83+ # def wavelet_transform(self, x):
84+ # xa = torch.cat([x[:, ::2, :, :],
85+ # x[:, 1::2, :, :],
86+ # ], -1)
87+ # d = torch.matmul(xa, self.ec_d)
88+ # s = torch.matmul(xa, self.ec_s)
89+ # return d, s
90+
91+
92+ # def evenOdd(self, x):
93+
94+ # B, N, c, ich = x.shape # (B, N, c, k)
95+ # assert ich == 2*self.k
96+ # x_e = torch.matmul(x, self.rc_e)
97+ # x_o = torch.matmul(x, self.rc_o)
98+
99+ # x = torch.zeros(B, N*2, c, self.k,
100+ # device = x.device)
101+ # x[..., ::2, :, :] = x_e
102+ # x[..., 1::2, :, :] = x_o
103+ # return x
0 commit comments