Skip to content

Commit ca682a3

Browse files
optim working!
1 parent 55de24c commit ca682a3

File tree

4 files changed

+76
-4
lines changed

4 files changed

+76
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ version = "1.0.1"
77
DataDrivenDiffEq = "2445eb08-9709-466a-b3fc-47e12bd697a2"
88
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
99
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
10+
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
1011
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1112
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
1213
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
@@ -27,4 +28,3 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2728

2829
[targets]
2930
test = ["PyCall", "SymPy", "Test"]
30-

src/SymbolicNumericIntegration.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ include("homotopy.jl")
1818

1919
include("numeric_utils.jl")
2020
include("sparse.jl")
21+
include("optim.jl")
2122
include("integral.jl")
2223

2324
export integrate, generate_basis

src/integral.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ function integrate(eq, x = nothing; abstol = 1e-6, num_steps = 2, num_trials = 5
3636
radius = 1.0,
3737
show_basis = false, opt = STLSQ(exp.(-10:1:0)), bypass = false,
3838
symbolic = true, max_basis = 100, verbose = false, complex_plane = true,
39-
homotopy = true)
39+
homotopy = true, use_optim=false)
4040
eq = expand(eq)
4141
eq = apply_div_rule(eq)
4242

@@ -61,7 +61,7 @@ function integrate(eq, x = nothing; abstol = 1e-6, num_steps = 2, num_trials = 5
6161

6262
s, u, ϵ = integrate_sum(eq, x, l; bypass, abstol, num_trials, num_steps,
6363
radius, show_basis, opt, symbolic,
64-
max_basis, verbose, complex_plane, homotopy)
64+
max_basis, verbose, complex_plane, homotopy, use_optim)
6565
return simplify(s), u, ϵ
6666
end
6767

@@ -234,7 +234,14 @@ end
234234
integral, error
235235
"""
236236
function try_integrate(T, eq, x, basis, radius; kwargs...)
237+
args = Dict(kwargs)
238+
use_optim = args[:use_optim]
237239
basis = basis[2:end] # remove 1 from the beginning
238-
return solve_sparse(T, eq, x, basis, radius; kwargs...)
240+
241+
if use_optim
242+
return solve_optim(T, eq, x, basis, radius; kwargs...)
243+
else
244+
return solve_sparse(T, eq, x, basis, radius; kwargs...)
245+
end
239246
end
240247

src/optim.jl

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
using Optim
2+
3+
function solve_optim(T, eq, x, basis, radius, rounds=2; kwargs...)
4+
args = Dict(kwargs)
5+
abstol, complex_plane, verbose = args[:abstol], args[:complex_plane], args[:verbose]
6+
n = length(basis)
7+
λ = 1e-9
8+
9+
A = zeros(Complex{T}, (2n, n))
10+
X = zeros(Complex{T}, 2n)
11+
init_basis_matrix!(T, A, X, x, eq, basis, radius, complex_plane; abstol)
12+
# modify_basis_matrix!(T, A, X, x, eq, basis, radius; abstol)
13+
14+
l = find_independent_subset(A; abstol)
15+
# l .&= rand(length(l)) .> 0.5
16+
A, basis = A[:, l], basis[l]
17+
q, ϵ = find_minimizer(A, λ)
18+
19+
if ϵ > abstol
20+
return 0, ϵ
21+
end
22+
23+
qa = q
24+
μ = maximum(abs.(qa))
25+
26+
for ρ in exp10.(-1:-1:-8)
27+
l = abs.(qa) .> ρ * μ
28+
q, ϵ = find_minimizer(A[:, l], λ)
29+
if ϵ < abstol
30+
q = nice_parameter.(q)
31+
basis = basis[l]
32+
y = sum(q[i] * expr(basis[i]) for i=1:length(basis))
33+
return y, ϵ
34+
end
35+
end
36+
37+
return 0, ϵ
38+
end
39+
40+
41+
function find_minimizer(A, λ)
42+
n, m = size(A)
43+
B = real.(A' * A)
44+
b = real.(A' * ones(n))
45+
46+
f = function(q)
47+
l = λ * sum(abs.(q)) # L1 norm
48+
# l = λ * sqrt(sum(q .^ 2)) # L2 norm
49+
l += sum(abs2.(A * q .- 1))
50+
return l
51+
end
52+
53+
g! = function(G, q)
54+
G .= 2 * (B * q .- b) .+ λ * sign.(q)
55+
end
56+
57+
q0 = randn(m)
58+
res = Optim.optimize(f, g!, q0, LBFGS())
59+
q = Optim.minimizer(res)
60+
ϵ = Optim.minimum(res)
61+
62+
return q, ϵ
63+
end
64+

0 commit comments

Comments
 (0)