@@ -20,7 +20,7 @@ function 🦋generate_random!(A, ::Val{SEED} = Val(888)) where {SEED}
2020 (uv,)
2121end
2222
23- function 🦋workspace (A, B:: Matrix{T} , U:: Adjoint{T, Matrix{T}} , V:: Matrix{T} , :: Val{SEED} = Val (888 )) where {T, SEED}
23+ function 🦋workspace (A, b, B:: Matrix{T} , U:: Adjoint{T, Matrix{T}} , V:: Matrix{T} , thread , :: Val{SEED} = Val (888 )) where {T, SEED}
2424 M = size (A, 1 )
2525 if (M % 4 != 0 )
2626 A = pad! (A)
@@ -29,9 +29,10 @@ function 🦋workspace(A, B::Matrix{T}, U::Adjoint{T, Matrix{T}}, V::Matrix{T},
2929 ws = 🦋generate_random! (copyto! (B, A))
3030 🦋mul! (copyto! (B, A), ws)
3131 U, V = materializeUV (B, ws)
32- F = RecursiveFactorization. lu! (B, Val (false ))
32+ F = RecursiveFactorization. lu! (B, thread)
33+ out = similar (b, M)
3334
34- U, V, F
35+ U, V, F, out
3536end
3637
3738const butterfly_workspace = 🦋workspace;
@@ -41,14 +42,12 @@ function 🦋mul_level!(A, u, v)
4142 @assert M == length (u) && N == length (v)
4243 Mh = M >>> 1
4344 Nh = N >>> 1
44- M2 = M - Mh
45- N2 = N - Nh
4645 @turbo for n in 1 : Nh
4746 for m in 1 : Mh
4847 A11 = A[m, n]
49- A21 = A[m + M2 , n]
50- A12 = A[m, n + N2 ]
51- A22 = A[m + M2 , n + N2 ]
48+ A21 = A[m + Mh , n]
49+ A12 = A[m, n + Nh ]
50+ A22 = A[m + Mh , n + Nh ]
5251
5352 T1 = A11 + A12
5453 T2 = A21 + A22
@@ -60,36 +59,16 @@ function 🦋mul_level!(A, u, v)
6059 C22 = T3 - T4
6160
6261 u1 = u[m]
63- u2 = u[m + M2 ]
62+ u2 = u[m + Mh ]
6463 v1 = v[n]
65- v2 = v[n + N2 ]
64+ v2 = v[n + Nh ]
6665
6766 A[m, n] = u1 * C11 * v1
68- A[m + M2 , n] = u2 * C21 * v1
69- A[m, n + N2 ] = u1 * C12 * v2
70- A[m + M2 , n + N2 ] = u2 * C22 * v2
67+ A[m + Mh , n] = u2 * C21 * v1
68+ A[m, n + Nh ] = u1 * C12 * v2
69+ A[m + Mh , n + Nh ] = u2 * C22 * v2
7170 end
7271 end
73- #=
74- if (N % 2 == 1) # N odd
75- n = N2
76- for m in 1:M
77- A[m, n] = u[m] * A[m, n] * v[n]
78- end
79- end
80-
81- if (M % 2 == 1) # M odd
82- m = M2
83- for n in 1:N
84- A[m, n] = u[m] * A[m, n] * v[n]
85- end
86- end
87-
88- if (M % 2 == 1) && (N % 2 == 1)
89- m = M2
90- n = N2
91- A[m, n] /= (u[m] * v[n])
92- end =#
9372end
9473
9574function 🦋mul! (A, (uv,))
@@ -98,8 +77,8 @@ function 🦋mul!(A, (uv,))
9877 Mh = M >>> 1
9978
10079 U₁ = @view (uv[1 : Mh])
101- V₁ = @view (uv[(Mh + 1 ): (2 * Mh )])
102- U₂ = @view (uv[(1 + 2 * Mh ): (M + Mh)])
80+ V₁ = @view (uv[(Mh + 1 ): (M )])
81+ U₂ = @view (uv[(1 + M ): (M + Mh)])
10382 V₂ = @view (uv[(1 + M + Mh): (2 * M)])
10483
10584 🦋mul_level! (@view (A[1 : Mh, 1 : Mh]), U₁, V₁)
0 commit comments