Skip to content

Commit a0b761a

Browse files
sparse functions simplified
1 parent c10ca98 commit a0b761a

File tree

2 files changed

+31
-58
lines changed

2 files changed

+31
-58
lines changed

src/integral.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ function integrate_term(eq, x, l; kwargs...)
189189
for j in 1:num_trials
190190
basis = isodd(j) ? basis1 : basis2
191191
r = radius #*sqrt(2)^j
192-
y, ϵ = try_integrate(Float64, eq, x, basis, r; kwargs...)
192+
y, ϵ = try_integrate(eq, x, basis, r; kwargs...)
193193

194194
ϵ = accept_solution(eq, x, y, r)
195195
if ϵ < abstol
@@ -240,29 +240,28 @@ end
240240
-------
241241
integral, error
242242
"""
243-
function try_integrate(T, eq, x, basis, radius; kwargs...)
243+
function try_integrate(eq, x, basis, radius; kwargs...)
244244
args = Dict(kwargs)
245245
use_optim = args[:use_optim]
246246
basis = basis[2:end] # remove 1 from the beginning
247247

248248
if use_optim
249-
return solve_optim(T, eq, x, basis, radius; kwargs...)
249+
return solve_optim(eq, x, basis, radius; kwargs...)
250250
else
251-
return solve_sparse(T, eq, x, basis, radius; kwargs...)
251+
return solve_sparse(eq, x, basis, radius; kwargs...)
252252
end
253253
end
254254

255255
#################################################################################
256256

257+
# integrate_basis is used for debugging and should not be called in the course of normal execution
257258
function integrate_basis(eq, x = var(eq); abstol = 1e-6, radius = 1.0, complex_plane = true)
258259
eq = expand(eq)
259260
eq = apply_div_rule(eq)
260261
eq = cache(eq)
261262
basis = generate_basis(eq, x, false)
262263
n = length(basis)
263-
A = zeros(Complex{Float64}, (n, n))
264-
X = zeros(Complex{Float64}, n)
265-
init_basis_matrix!(Float64, A, X, x, eq, basis, radius, complex_plane; abstol)
264+
A, X = init_basis_matrix(eq, x, basis, radius, complex_plane; abstol)
266265
return basis, A, X
267266
end
268267

src/sparse.jl

Lines changed: 25 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,37 @@
11

2-
function solve_sparse(T, eq, x, basis, radius; kwargs...)
2+
function solve_sparse(eq, x, basis, radius; kwargs...)
33
args = Dict(kwargs)
44
abstol, opt, complex_plane, verbose = args[:abstol], args[:opt], args[:complex_plane],
55
args[:verbose]
66

7-
n = length(basis)
8-
9-
# A is an nxn matrix holding the values of the fragments at n random points
10-
A = zeros(Complex{T}, (n, n))
11-
X = zeros(Complex{T}, n)
12-
13-
init_basis_matrix!(T, A, X, x, eq, basis, radius, complex_plane; abstol)
7+
A, X = init_basis_matrix(eq, x, basis, radius, complex_plane; abstol)
148

159
# find a linearly independent subset of the basis
1610
l = find_independent_subset(A; abstol)
1711
A, basis = A[l, l], basis[l]
1812

19-
y₁, ϵ₁ = sparse_fit(T, A, x, basis, opt; abstol)
13+
y₁, ϵ₁ = sparse_fit(A, basis, opt; abstol)
2014
if ϵ₁ < abstol
2115
return y₁, ϵ₁
2216
end
2317

2418
rank = sum(l)
2519

2620
if rank == 1
27-
y₂, ϵ₂ = find_singlet(T, A, basis; abstol)
21+
y₂, ϵ₂ = find_singlet(A, basis; abstol)
2822
if ϵ₂ < abstol
2923
return y₂, ϵ₂
3024
end
3125
elseif rank < 8
32-
y₃, ϵ₃ = find_dense(T, A, basis; abstol)
26+
y₃, ϵ₃ = find_dense(A, basis; abstol)
3327
if ϵ₃ < abstol
3428
return y₃, ϵ₃
3529
end
3630
end
3731

3832
# moving toward the poles
39-
modify_basis_matrix!(T, A, X, x, eq, basis, radius; abstol)
40-
y₄, ϵ₄ = sparse_fit(T, A, x, basis, opt; abstol)
33+
modify_basis_matrix!(A, X, eq, x, basis, radius; abstol)
34+
y₄, ϵ₄ = sparse_fit(A, basis, opt; abstol)
4135

4236
if ϵ₄ < abstol || ϵ₄ < ϵ₁
4337
return y₄, ϵ₄
@@ -46,24 +40,27 @@ function solve_sparse(T, eq, x, basis, radius; kwargs...)
4640
end
4741
end
4842

49-
function init_basis_matrix!(T, A, X, x, eq, basis, radius, complex_plane; abstol = 1e-6)
50-
n, m = size(A)
51-
k = 1
52-
i = 1
43+
function init_basis_matrix(eq, x, basis, radius, complex_plane; abstol = 1e-6)
44+
n = length(basis)
45+
46+
# A is an nxn matrix holding the values of the fragments at n random points
47+
A = zeros(Complex{Float64}, (n, n))
48+
X = zeros(Complex{Float64}, n)
5349

5450
eq_fun = fun!(eq, x)
5551
Δbasis_fun = deriv_fun!.(basis, x)
5652

53+
k = 1
5754
l = 10*n # max attempt
5855

59-
while k <= n
56+
while k <= n && l > 0
6057
try
6158
x₀ = test_point(complex_plane, radius)
62-
X[k] = x₀ # move_toward_roots_poles(x₀, x, eq)
59+
X[k] = x₀
6360
b₀ = eq_fun(X[k])
6461

6562
if is_proper(b₀)
66-
for j in 1:m
63+
for j in 1:n
6764
A[k, j] = Δbasis_fun[j](X[k]) / b₀
6865
end
6966
if all(is_proper, A[k, :])
@@ -73,33 +70,13 @@ function init_basis_matrix!(T, A, X, x, eq, basis, radius, complex_plane; abstol
7370
catch e
7471
println("Error from init_basis_matrix!: ", e)
7572
end
76-
if l == 0
77-
return
78-
end
7973
l -= 1
8074
end
75+
76+
return A, X
8177
end
8278

83-
function move_toward_roots_poles(z, x, eq; n = 1, max_r = 100.0)
84-
eq_fun = fun!(eq, x)
85-
Δeq_fun = deriv_fun!(eq, x)
86-
is_root = rand() < 0.5
87-
z₀ = z
88-
for i in 1:n
89-
dz = eq_fun(z) / Δeq_fun(z)
90-
if is_root
91-
z -= dz
92-
else
93-
z += dz
94-
end
95-
if abs(z) > max_r
96-
return z₀
97-
end
98-
end
99-
return z
100-
end
101-
102-
function modify_basis_matrix!(T, A, X, x, eq, basis, radius; abstol = 1e-6)
79+
function modify_basis_matrix!(A, X, eq, x, basis, radius; abstol = 1e-6)
10380
n, m = size(A)
10481
eq_fun = fun!(eq, x)
10582
Δeq_fun = deriv_fun!(eq, x)
@@ -122,14 +99,11 @@ DataDrivenSparse.active_set!(idx::BitMatrix, p::SoftThreshold, x::Matrix{Complex
12299
DataDrivenSparse.active_set!(idx, p, abs.(x), λ)
123100

124101

125-
function sparse_fit(T, A, x, basis, opt; abstol = 1e-6)
102+
function sparse_fit(A, basis, opt; abstol = 1e-6)
126103
n, m = size(A)
127104

128105
try
129106
b = ones((1, n))
130-
# q₀ = DataDrivenDiffEq.init(opt, A, b)
131-
# @views sparse_regression!(q₀, A, permutedims(b)', opt, maxiter = 1000)
132-
# @views sparse_regression!(q₀, A, b, opt, maxiter = 1000)
133107
solver = SparseLinearSolver(opt, options = DataDrivenCommonOptions(verbose=false, maxiters=1000))
134108
res, _... = solver(A', b)
135109
q₀ = DataDrivenSparse.coef(first(res))
@@ -140,15 +114,15 @@ function sparse_fit(T, A, x, basis, opt; abstol = 1e-6)
140114
return nothing, Inf
141115
end # eliminating complex coefficients
142116
return sum(q[i] * expr(basis[i]) for i in 1:length(basis) if q[i] != 0;
143-
init = zero(x)),
117+
init = 0),
144118
abs(ϵ)
145119
catch e
146120
println("Error from sparse_fit", e)
147121
return nothing, Inf
148122
end
149123
end
150124

151-
function find_singlet(T, A, basis; abstol)
125+
function find_singlet(A, basis; abstol)
152126
σ = vec(std(A; dims = 1))
153127
μ = vec(mean(A; dims = 1))
154128
l =.< abstol) .* (abs.(μ) .> abstol)
@@ -160,9 +134,9 @@ function find_singlet(T, A, basis; abstol)
160134
end
161135
end
162136

163-
function find_dense(T, A, basis; abstol = 1e-6)
137+
function find_dense(A, basis; abstol = 1e-6)
164138
n = size(A, 1)
165-
b = ones(T, n)
139+
b = ones(n)
166140

167141
try
168142
q = A \ b

0 commit comments

Comments
 (0)