Skip to content

Commit ff47411

Browse files
optim and rules improved
1 parent e048f4f commit ff47411

File tree

7 files changed

+103
-84
lines changed

7 files changed

+103
-84
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SymbolicNumericIntegration"
22
uuid = "78aadeae-fbc0-11eb-17b6-c7ec0477ba9e"
33
authors = ["Shahriar Iravanian <siravan@svtsim.com>"]
4-
version = "1.0.1"
4+
version = "1.0.2"
55

66
[deps]
77
DataDrivenDiffEq = "2445eb08-9709-466a-b3fc-47e12bd697a2"

docs/src/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ julia> integrate(exp(x^2))
8484
- `opt` (default `STLSQ(exp.(-10:1:0))`): the optimizer passed to `sparse_regression!`.
8585
- `max_basis` (default `110`): the maximum number of expression in the basis.
8686
- `complex_plane` (default `true`): random test points are generated on the complex plane (only over the real axis if `complex_plane` is `false`).
87+
- `use_optim` (default `false`): use Optim.jl `minimize` function instead of the STLSQ algorithm (**experimental**)
8788

8889
## Testing
8990

src/cache.jl

Lines changed: 12 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
const DEBUG_CACHE = true
2-
31
mutable struct ExprCache
42
eq::Any# the primary expression
53
f::Any# compiled eq
@@ -13,46 +11,37 @@ cache(eq) = ExprCache(eq, nothing, nothing, nothing)
1311
expr(c::ExprCache) = c.eq
1412
expr(c) = c
1513

16-
function deriv!(c::ExprCache, x)
14+
function deriv!(c::ExprCache, xs...)
1715
if c.δeq == nothing
18-
c.δeq = expand_derivatives(Differential(x)(expr(c)))
16+
c.δeq = expand_derivatives(Differential(xs[1])(expr(c)))
1917
end
2018
return c.δeq
2119
end
2220

23-
function deriv!(c, x)
24-
if DEBUG_CACHE
25-
error("ExprCache object expected")
26-
end
27-
return expand_derivatives(Differential(x)(c))
21+
function deriv!(c, xs...)
22+
return expand_derivatives(Differential(xs[1])(c))
2823
end
2924

30-
function fun!(c::ExprCache, x)
25+
function fun!(c::ExprCache, xs...)
3126
if c.f == nothing
32-
c.f = build_function(expr(c), x; expression = false)
27+
c.f = build_function(expr(c), xs...; expression = false)
3328
end
3429
return c.f
3530
end
3631

37-
function fun!(c, x)
38-
if DEBUG_CACHE
39-
error("ExprCache object expected")
40-
end
41-
return build_function(c, x; expression = false)
32+
function fun!(c, xs...)
33+
return build_function(c, xs...; expression = false)
4234
end
4335

44-
function deriv_fun!(c::ExprCache, x)
36+
function deriv_fun!(c::ExprCache, xs...)
4537
if c.δf == nothing
46-
c.δf = build_function(deriv!(c, x), x; expression = false)
38+
c.δf = build_function(deriv!(c, xs...), xs...; expression = false)
4739
end
4840
return c.δf
4941
end
5042

51-
function deriv_fun!(c, x)
52-
if DEBUG_CACHE
53-
error("ExprCache object expected")
54-
end
55-
return build_function(deriv!(c, x), x; expression = false)
43+
function deriv_fun!(c, xs...)
44+
return build_function(deriv!(c, xs), xs...; expression = false)
5645
end
5746

5847
Base.show(io::IO, c::ExprCache) = show(expr(c))

src/homotopy.jl

Lines changed: 23 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ function generate_homotopy(eq, x)
7777

7878
S += expand((1 + h₁) * (1 + h₂))
7979
end
80+
81+
S = simplify(S)
8082

8183
unique([one(x); [equivalent(t, x) for t in terms(S)]])
8284
end
@@ -88,70 +90,63 @@ function ∂(x)
8890
return isequal(d, 0) ? 1 : d
8991
end
9092

91-
partial_int_rules = [@rule 𝛷(sin(~x)) => (cos(~x), (~x))
93+
partial_int_rules = [
94+
# trigonometric functions
95+
@rule 𝛷(sin(~x)) => (cos(~x), (~x))
9296
@rule 𝛷(cos(~x)) => (sin(~x), (~x))
9397
@rule 𝛷(tan(~x)) => (log(cos(~x)), (~x))
9498
@rule 𝛷(csc(~x)) => (log(csc(~x) + cot(~x)), (~x))
9599
@rule 𝛷(sec(~x)) => (log(sec(~x) + tan(~x)), (~x))
96100
@rule 𝛷(cot(~x)) => (log(sin(~x)), (~x))
101+
# hyperbolic functions
97102
@rule 𝛷(sinh(~x)) => (cosh(~x), (~x))
98103
@rule 𝛷(cosh(~x)) => (sinh(~x), (~x))
99104
@rule 𝛷(tanh(~x)) => (log(cosh(~x)), (~x))
100105
@rule 𝛷(csch(~x)) => (log(tanh(~x / 2)), (~x))
101106
@rule 𝛷(sech(~x)) => (atan(sinh(~x)), (~x))
102107
@rule 𝛷(coth(~x)) => (log(sinh(~x)), (~x))
108+
# 1/trigonometric functions
103109
@rule 𝛷(^(sin(~x), -1)) => (log(csc(~x) + cot(~x)), (~x))
104110
@rule 𝛷(^(cos(~x), -1)) => (log(sec(~x) + tan(~x)), (~x))
105111
@rule 𝛷(^(tan(~x), -1)) => (log(sin(~x)), (~x))
106112
@rule 𝛷(^(csc(~x), -1)) => (cos(~x), (~x))
107113
@rule 𝛷(^(sec(~x), -1)) => (sin(~x), (~x))
108114
@rule 𝛷(^(cot(~x), -1)) => (log(cos(~x)), (~x))
115+
# 1/hyperbolic functions
109116
@rule 𝛷(^(sinh(~x), -1)) => (log(tanh(~x / 2)), (~x))
110117
@rule 𝛷(^(cosh(~x), -1)) => (atan(sinh(~x)), (~x))
111118
@rule 𝛷(^(tanh(~x), -1)) => (log(sinh(~x)), (~x))
112119
@rule 𝛷(^(csch(~x), -1)) => (cosh(~x), (~x))
113120
@rule 𝛷(^(sech(~x), -1)) => (sinh(~x), (~x))
114121
@rule 𝛷(^(coth(~x), -1)) => (log(cosh(~x)), (~x))
115-
116-
# @rule 𝛷(^(sin(~x), ~k::is_neg)) => 𝛷(^(csc(~x), -~k))
117-
# @rule 𝛷(^(cos(~x), ~k::is_neg)) => 𝛷(^(sec(~x), -~k))
118-
# @rule 𝛷(^(tan(~x), ~k::is_neg)) => 𝛷(^(cot(~x), -~k))
119-
# @rule 𝛷(^(csc(~x), ~k::is_neg)) => 𝛷(^(sin(~x), -~k))
120-
# @rule 𝛷(^(sec(~x), ~k::is_neg)) => 𝛷(^(cos(~x), -~k))
121-
# @rule 𝛷(^(cot(~x), ~k::is_neg)) => 𝛷(^(tan(~x), -~k))
122-
# @rule 𝛷(^(sinh(~x), ~k::is_neg)) => 𝛷(^(csch(~x), -~k))
123-
# @rule 𝛷(^(cosh(~x), ~k::is_neg)) => 𝛷(^(sech(~x), -~k))
124-
# @rule 𝛷(^(tanh(~x), ~k::is_neg)) => 𝛷(^(coth(~x), -~k))
125-
# @rule 𝛷(^(csch(~x), ~k::is_neg)) => 𝛷(^(sinh(~x), -~k))
126-
# @rule 𝛷(^(sech(~x), ~k::is_neg)) => 𝛷(^(cosh(~x), -~k))
127-
# @rule 𝛷(^(coth(~x), ~k::is_neg)) => 𝛷(^(tanh(~x), -~k))
128-
122+
# inverse trigonometric functions
129123
@rule 𝛷(asin(~x)) => (~x * asin(~x) + sqrt(1 - ~x * ~x), (~x))
130124
@rule 𝛷(acos(~x)) => (~x * acos(~x) + sqrt(1 - ~x * ~x), (~x))
131125
@rule 𝛷(atan(~x)) => (~x * atan(~x) + log(~x * ~x + 1), (~x))
132-
@rule 𝛷(acsc(~x)) => (~x * acsc(~x) + acosh(~x), (~x)) # needs an abs inside acosh
133-
@rule 𝛷(asec(~x)) => (~x * asec(~x) + acosh(~x), (~x)) # needs an abs inside acosh
126+
@rule 𝛷(acsc(~x)) => (~x * acsc(~x) + atanh(1 - ^(~x,-2)), (~x))
127+
@rule 𝛷(asec(~x)) => (~x * asec(~x) + acosh(~x), (~x))
134128
@rule 𝛷(acot(~x)) => (~x * acot(~x) + log(~x * ~x + 1), (~x))
129+
# inverse hyperbolic functions
135130
@rule 𝛷(asinh(~x)) => (~x * asinh(~x) + sqrt(~x * ~x + 1), (~x))
136131
@rule 𝛷(acosh(~x)) => (~x * acosh(~x) + sqrt(~x * ~x - 1), (~x))
137132
@rule 𝛷(atanh(~x)) => (~x * atanh(~x) + log(~x + 1), (~x))
138133
@rule 𝛷(acsch(~x)) => (acsch(~x), (~x))
139134
@rule 𝛷(asech(~x)) => (asech(~x), (~x))
140135
@rule 𝛷(acoth(~x)) => (~x * acot(~x) + log(~x + 1), (~x))
141-
@rule 𝛷(log(~x)) => (~x + ~x * log(~x), (~x))
142-
@rule 𝛷(^(~x, ~k::is_abs_half)) => (sum(candidate_sqrt(~x, ~k);
143-
init = one(~x)), 1);
144-
@rule 𝛷(^(~x::is_poly, ~k::is_neg)) => (sum(candidate_pow_minus(~x,
145-
~k);
146-
init = one(~x)), 1)
136+
# logarithmic and exponential functions
137+
@rule 𝛷(log(~x)) => (~x + ~x * log(~x) + sum(candidate_pow_minus(~x, -1); init = one(~x)), (~x))
138+
@rule 𝛷(exp(~x)) => (exp(~x), (~x))
139+
@rule 𝛷(^(exp(~x), ~k::is_neg)) => (^(exp(-~x), -~k), (~x))
140+
# square-root functions
141+
@rule 𝛷(^(~x, ~k::is_abs_half)) => (sum(candidate_sqrt(~x, ~k); init = one(~x)), 1);
147142
@rule 𝛷(sqrt(~x)) => (sum(candidate_sqrt(~x, 0.5); init = one(~x)), 1);
148-
@rule 𝛷(^(sqrt(~x), -1)) => 𝛷(^(~x, -0.5))
143+
@rule 𝛷(^(sqrt(~x), -1)) => 𝛷(^(~x, -0.5))
144+
# rational functions
145+
@rule 𝛷(^(~x::is_poly, ~k::is_neg)) => (sum(candidate_pow_minus(~x, ~k); init = one(~x)), 1)
149146
@rule 𝛷(^(~x, -1)) => (log(~x), (~x))
150-
@rule 𝛷(^(~x, ~k::is_neg_int)) => (sum(^(~x, i) for i in (~k + 1):-1),
151-
(~x))
147+
@rule 𝛷(^(~x, ~k::is_neg_int)) => (sum(^(~x, i) for i in (~k + 1):-1), (~x))
152148
@rule 𝛷(1 / ~x) => 𝛷(^(~x, -1))
153-
@rule 𝛷(^(~x, ~k)) => (^(~x, ~k + 1), (~x))
154-
@rule 𝛷(exp(~x)) => (exp(~x), (~x))
149+
@rule 𝛷(^(~x, ~k)) => (^(~x, ~k + 1), (~x))
155150
@rule 𝛷(1) => (𝑥, 1)
156151
@rule 𝛷(~x) => ((~x + ^(~x, 2)), (~x))]
157152

src/integral.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ Base.signbit(x::SymbolicUtils.Sym{Number, Nothing}) = false
2424
verbose: print a detailed report
2525
complex_plane: generate random test points on the complex plane (if false, the points will be on real axis)
2626
homotopy: use the homotopy algorithm to generate the basis
27+
use_optim: use Optim.jl `minimize` function instead of the STLSQ algorithm (**experimental**)
2728
2829
output:
2930
-------

src/optim.jl

Lines changed: 58 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using Optim
33
function solve_optim(T, eq, x, basis, radius, rounds=2; kwargs...)
44
args = Dict(kwargs)
55
abstol, complex_plane, verbose = args[:abstol], args[:complex_plane], args[:verbose]
6+
67
n = length(basis)
78
λ = 1e-9
89

@@ -12,49 +13,80 @@ function solve_optim(T, eq, x, basis, radius, rounds=2; kwargs...)
1213
# modify_basis_matrix!(T, A, X, x, eq, basis, radius; abstol)
1314

1415
l = find_independent_subset(A; abstol)
15-
# l .&= rand(length(l)) .> 0.5
16-
A, basis = A[:, l], basis[l]
17-
q, ϵ = find_minimizer(A, λ)
16+
A, basis, n = A[:, l], basis[l], sum(l)
17+
p = rank_basis(A, basis)
18+
19+
# q, ϵ = find_minimizer(A, λ)
1820

19-
if ϵ > abstol
20-
return 0, ϵ
21-
end
21+
# if ϵ > abstol
22+
# return 0, ϵ
23+
# end
2224

23-
qa = q
24-
μ = maximum(abs.(qa))
25+
# qa = q
26+
# μ = maximum(abs.(qa))
2527

26-
for ρ in exp10.(-1:-1:-8)
27-
l = abs.(qa) .> ρ * μ
28-
q, ϵ = find_minimizer(A[:, l], λ)
29-
if ϵ < abstol
30-
q = nice_parameter.(q)
31-
basis = basis[l]
32-
y = sum(q[i] * expr(basis[i]) for i=1:length(basis))
33-
return y, ϵ
28+
# for ρ in exp10.(-1:-1:-8)
29+
# l = abs.(qa) .> ρ * μ
30+
31+
qm = zeros(n)
32+
ϵm = 1e6
33+
lm = qm .> 0
34+
35+
for i = 1:n
36+
l = (1:n .<= i)
37+
q, ϵ = find_minimizer(A[:, l], λ)
38+
nz = sum(abs.(q) .> abstol)
39+
40+
# println(i, '\t', ϵ, '\t', nz)
41+
42+
if ϵ*nz < ϵm
43+
ϵm = ϵ*nz
44+
qm = q
45+
lm = l
3446
end
35-
end
47+
end
3648

37-
return 0, ϵ
49+
if ϵm < abstol
50+
return reconstruct(qm, basis[lm]), ϵm
51+
else
52+
return 0, ϵm
53+
end
54+
end
55+
56+
function reconstruct(q, basis)
57+
q = nice_parameter.(q)
58+
y = sum(q[i] * expr(basis[i]) for i=1:length(basis))
59+
return y
60+
end
61+
62+
# returns a vector of indices of basis elems from the most important to the least
63+
function rank_basis(A, basis)
64+
n, m = size(A)
65+
q = A \ ones(n)
66+
w = [abs(q[i])*norm(A[:,i]) for i=1:m]
67+
p = sortperm(w; rev=true)
68+
return p
3869
end
3970

71+
clamp_down(x, η) = abs(x) < η ? 0 : x
4072

4173
function find_minimizer(A, λ)
4274
n, m = size(A)
43-
B = real.(A' * A)
44-
b = real.(A' * ones(n))
4575

4676
f = function(q)
47-
l = λ * sum(abs.(q)) # L1 norm
48-
# l = λ * sqrt(sum(q .^ 2)) # L2 norm
49-
l += sum(abs2.(A * q .- 1))
77+
q .= clamp_down.(q, maximum(abs.(q))*1e-6)
78+
t = A*q .- 1
79+
l = sum(t' * t)
80+
l += λ * norm(q, 1)
5081
return l
5182
end
5283

53-
g! = function(G, q)
54-
G .= 2 * (B * q .- b) .+ λ * sign.(q)
84+
g! = function(G, q)
85+
t = A*q .- 1
86+
G .= 2 * real(A' * t)
5587
end
5688

57-
q0 = randn(m)
89+
q0 = A \ ones(n) # randn(m)
5890
res = Optim.optimize(f, g!, q0, LBFGS())
5991
q = Optim.minimizer(res)
6092
ϵ = Optim.minimum(res)

src/sparse.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ function solve_sparse(T, eq, x, basis, radius; kwargs...)
44
args = Dict(kwargs)
55
abstol, opt, complex_plane, verbose = args[:abstol], args[:opt], args[:complex_plane],
66
args[:verbose]
7+
78
n = length(basis)
89

910
# A is an nxn matrix holding the values of the fragments at n random points
@@ -108,17 +109,17 @@ end
108109

109110

110111
function sparse_fit(T, A, x, basis, opt; abstol = 1e-6)
111-
n = length(basis)
112112
# find a linearly independent subset of the basis
113113
l = find_independent_subset(A; abstol)
114-
A, basis, n = A[l, l], basis[l], sum(l)
114+
A, basis = A[l, l], basis[l]
115+
n, m = size(A)
115116

116117
try
117-
b = ones(n)
118-
# q₀ = A \ b
118+
b = ones((n,1))
119119
q₀ = DataDrivenDiffEq.init(opt, A, b)
120-
@views sparse_regression!(q₀, A, permutedims(b)', opt, maxiter = 1000)
121-
ϵ = rms(A * q₀ - b)
120+
# @views sparse_regression!(q₀, A, permutedims(b)', opt, maxiter = 1000)
121+
@views sparse_regression!(q₀, A, b, opt, maxiter = 1000)
122+
ϵ = rms(A * q₀ .- b)
122123
q = nice_parameter.(q₀)
123124
if sum(iscomplex.(q)) > 2
124125
return nothing, Inf

0 commit comments

Comments
 (0)