Skip to content

Commit c96d33b

Browse files
Merge pull request #46 from shahriariravanian/main
rules updated and Optim.jl based minimizer implemented (experimental)
2 parents 6133194 + c218865 commit c96d33b

File tree

8 files changed

+304
-204
lines changed

8 files changed

+304
-204
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
name = "SymbolicNumericIntegration"
22
uuid = "78aadeae-fbc0-11eb-17b6-c7ec0477ba9e"
33
authors = ["Shahriar Iravanian <siravan@svtsim.com>"]
4-
version = "1.0.1"
4+
version = "1.0.2"
55

66
[deps]
77
DataDrivenDiffEq = "2445eb08-9709-466a-b3fc-47e12bd697a2"
88
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
99
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
10+
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
1011
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1112
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
1213
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
@@ -27,4 +28,3 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2728

2829
[targets]
2930
test = ["PyCall", "SymPy", "Test"]
30-

docs/src/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ julia> integrate(exp(x^2))
8484
- `opt` (default `STLSQ(exp.(-10:1:0))`): the optimizer passed to `sparse_regression!`.
8585
- `max_basis` (default `110`): the maximum number of expression in the basis.
8686
- `complex_plane` (default `true`): random test points are generated on the complex plane (only over the real axis if `complex_plane` is `false`).
87+
- `use_optim` (default `false`): use Optim.jl `minimize` function instead of the STLSQ algorithm (**experimental**)
8788

8889
## Testing
8990

src/SymbolicNumericIntegration.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ include("candidates.jl")
1717
include("homotopy.jl")
1818

1919
include("numeric_utils.jl")
20+
include("sparse.jl")
21+
include("optim.jl")
2022
include("integral.jl")
2123

2224
export integrate, generate_basis

src/cache.jl

Lines changed: 12 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
const DEBUG_CACHE = true
2-
31
mutable struct ExprCache
42
eq::Any# the primary expression
53
f::Any# compiled eq
@@ -13,46 +11,37 @@ cache(eq) = ExprCache(eq, nothing, nothing, nothing)
1311
expr(c::ExprCache) = c.eq
1412
expr(c) = c
1513

16-
function deriv!(c::ExprCache, x)
14+
function deriv!(c::ExprCache, xs...)
1715
if c.δeq == nothing
18-
c.δeq = expand_derivatives(Differential(x)(expr(c)))
16+
c.δeq = expand_derivatives(Differential(xs[1])(expr(c)))
1917
end
2018
return c.δeq
2119
end
2220

23-
function deriv!(c, x)
24-
if DEBUG_CACHE
25-
error("ExprCache object expected")
26-
end
27-
return expand_derivatives(Differential(x)(c))
21+
function deriv!(c, xs...)
22+
return expand_derivatives(Differential(xs[1])(c))
2823
end
2924

30-
function fun!(c::ExprCache, x)
25+
function fun!(c::ExprCache, xs...)
3126
if c.f == nothing
32-
c.f = build_function(expr(c), x; expression = false)
27+
c.f = build_function(expr(c), xs...; expression = false)
3328
end
3429
return c.f
3530
end
3631

37-
function fun!(c, x)
38-
if DEBUG_CACHE
39-
error("ExprCache object expected")
40-
end
41-
return build_function(c, x; expression = false)
32+
function fun!(c, xs...)
33+
return build_function(c, xs...; expression = false)
4234
end
4335

44-
function deriv_fun!(c::ExprCache, x)
36+
function deriv_fun!(c::ExprCache, xs...)
4537
if c.δf == nothing
46-
c.δf = build_function(deriv!(c, x), x; expression = false)
38+
c.δf = build_function(deriv!(c, xs...), xs...; expression = false)
4739
end
4840
return c.δf
4941
end
5042

51-
function deriv_fun!(c, x)
52-
if DEBUG_CACHE
53-
error("ExprCache object expected")
54-
end
55-
return build_function(deriv!(c, x), x; expression = false)
43+
function deriv_fun!(c, xs...)
44+
return build_function(deriv!(c, xs), xs...; expression = false)
5645
end
5746

5847
Base.show(io::IO, c::ExprCache) = show(expr(c))

src/homotopy.jl

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ function generate_homotopy(eq, x)
7878
S += expand((1 + h₁) * (1 + h₂))
7979
end
8080

81+
S = simplify(S)
82+
8183
unique([one(x); [equivalent(t, x) for t in terms(S)]])
8284
end
8385

@@ -88,70 +90,69 @@ function ∂(x)
8890
return isequal(d, 0) ? 1 : d
8991
end
9092

91-
partial_int_rules = [@rule 𝛷(sin(~x)) => (cos(~x), (~x))
93+
partial_int_rules = [
94+
# trigonometric functions
95+
@rule 𝛷(sin(~x)) => (cos(~x), (~x))
9296
@rule 𝛷(cos(~x)) => (sin(~x), (~x))
9397
@rule 𝛷(tan(~x)) => (log(cos(~x)), (~x))
9498
@rule 𝛷(csc(~x)) => (log(csc(~x) + cot(~x)), (~x))
9599
@rule 𝛷(sec(~x)) => (log(sec(~x) + tan(~x)), (~x))
96100
@rule 𝛷(cot(~x)) => (log(sin(~x)), (~x))
101+
# hyperbolic functions
97102
@rule 𝛷(sinh(~x)) => (cosh(~x), (~x))
98103
@rule 𝛷(cosh(~x)) => (sinh(~x), (~x))
99104
@rule 𝛷(tanh(~x)) => (log(cosh(~x)), (~x))
100105
@rule 𝛷(csch(~x)) => (log(tanh(~x / 2)), (~x))
101106
@rule 𝛷(sech(~x)) => (atan(sinh(~x)), (~x))
102107
@rule 𝛷(coth(~x)) => (log(sinh(~x)), (~x))
108+
# 1/trigonometric functions
103109
@rule 𝛷(^(sin(~x), -1)) => (log(csc(~x) + cot(~x)), (~x))
104110
@rule 𝛷(^(cos(~x), -1)) => (log(sec(~x) + tan(~x)), (~x))
105111
@rule 𝛷(^(tan(~x), -1)) => (log(sin(~x)), (~x))
106112
@rule 𝛷(^(csc(~x), -1)) => (cos(~x), (~x))
107113
@rule 𝛷(^(sec(~x), -1)) => (sin(~x), (~x))
108114
@rule 𝛷(^(cot(~x), -1)) => (log(cos(~x)), (~x))
115+
# 1/hyperbolic functions
109116
@rule 𝛷(^(sinh(~x), -1)) => (log(tanh(~x / 2)), (~x))
110117
@rule 𝛷(^(cosh(~x), -1)) => (atan(sinh(~x)), (~x))
111118
@rule 𝛷(^(tanh(~x), -1)) => (log(sinh(~x)), (~x))
112119
@rule 𝛷(^(csch(~x), -1)) => (cosh(~x), (~x))
113120
@rule 𝛷(^(sech(~x), -1)) => (sinh(~x), (~x))
114121
@rule 𝛷(^(coth(~x), -1)) => (log(cosh(~x)), (~x))
115-
116-
# @rule 𝛷(^(sin(~x), ~k::is_neg)) => 𝛷(^(csc(~x), -~k))
117-
# @rule 𝛷(^(cos(~x), ~k::is_neg)) => 𝛷(^(sec(~x), -~k))
118-
# @rule 𝛷(^(tan(~x), ~k::is_neg)) => 𝛷(^(cot(~x), -~k))
119-
# @rule 𝛷(^(csc(~x), ~k::is_neg)) => 𝛷(^(sin(~x), -~k))
120-
# @rule 𝛷(^(sec(~x), ~k::is_neg)) => 𝛷(^(cos(~x), -~k))
121-
# @rule 𝛷(^(cot(~x), ~k::is_neg)) => 𝛷(^(tan(~x), -~k))
122-
# @rule 𝛷(^(sinh(~x), ~k::is_neg)) => 𝛷(^(csch(~x), -~k))
123-
# @rule 𝛷(^(cosh(~x), ~k::is_neg)) => 𝛷(^(sech(~x), -~k))
124-
# @rule 𝛷(^(tanh(~x), ~k::is_neg)) => 𝛷(^(coth(~x), -~k))
125-
# @rule 𝛷(^(csch(~x), ~k::is_neg)) => 𝛷(^(sinh(~x), -~k))
126-
# @rule 𝛷(^(sech(~x), ~k::is_neg)) => 𝛷(^(cosh(~x), -~k))
127-
# @rule 𝛷(^(coth(~x), ~k::is_neg)) => 𝛷(^(tanh(~x), -~k))
128-
122+
# inverse trigonometric functions
129123
@rule 𝛷(asin(~x)) => (~x * asin(~x) + sqrt(1 - ~x * ~x), (~x))
130124
@rule 𝛷(acos(~x)) => (~x * acos(~x) + sqrt(1 - ~x * ~x), (~x))
131125
@rule 𝛷(atan(~x)) => (~x * atan(~x) + log(~x * ~x + 1), (~x))
132-
@rule 𝛷(acsc(~x)) => (~x * acsc(~x) + acosh(~x), (~x)) # needs an abs inside acosh
133-
@rule 𝛷(asec(~x)) => (~x * asec(~x) + acosh(~x), (~x)) # needs an abs inside acosh
126+
@rule 𝛷(acsc(~x)) => (~x * acsc(~x) + atanh(1 - ^(~x, -2)), (~x))
127+
@rule 𝛷(asec(~x)) => (~x * asec(~x) + acosh(~x), (~x))
134128
@rule 𝛷(acot(~x)) => (~x * acot(~x) + log(~x * ~x + 1), (~x))
129+
# inverse hyperbolic functions
135130
@rule 𝛷(asinh(~x)) => (~x * asinh(~x) + sqrt(~x * ~x + 1), (~x))
136131
@rule 𝛷(acosh(~x)) => (~x * acosh(~x) + sqrt(~x * ~x - 1), (~x))
137132
@rule 𝛷(atanh(~x)) => (~x * atanh(~x) + log(~x + 1), (~x))
138133
@rule 𝛷(acsch(~x)) => (acsch(~x), (~x))
139134
@rule 𝛷(asech(~x)) => (asech(~x), (~x))
140135
@rule 𝛷(acoth(~x)) => (~x * acot(~x) + log(~x + 1), (~x))
141-
@rule 𝛷(log(~x)) => (~x + ~x * log(~x), (~x))
136+
# logarithmic and exponential functions
137+
@rule 𝛷(log(~x)) => (~x + ~x * log(~x) +
138+
sum(candidate_pow_minus(~x, -1); init = one(~x)),
139+
(~x))
140+
@rule 𝛷(exp(~x)) => (exp(~x), (~x))
141+
@rule 𝛷(^(exp(~x), ~k::is_neg)) => (^(exp(-~x), -~k), (~x))
142+
# square-root functions
142143
@rule 𝛷(^(~x, ~k::is_abs_half)) => (sum(candidate_sqrt(~x, ~k);
143144
init = one(~x)), 1);
145+
@rule 𝛷(sqrt(~x)) => (sum(candidate_sqrt(~x, 0.5); init = one(~x)), 1);
146+
@rule 𝛷(^(sqrt(~x), -1)) => 𝛷(^(~x, -0.5))
147+
# rational functions
144148
@rule 𝛷(^(~x::is_poly, ~k::is_neg)) => (sum(candidate_pow_minus(~x,
145149
~k);
146150
init = one(~x)), 1)
147-
@rule 𝛷(sqrt(~x)) => (sum(candidate_sqrt(~x, 0.5); init = one(~x)), 1);
148-
@rule 𝛷(^(sqrt(~x), -1)) => 𝛷(^(~x, -0.5))
149151
@rule 𝛷(^(~x, -1)) => (log(~x), (~x))
150152
@rule 𝛷(^(~x, ~k::is_neg_int)) => (sum(^(~x, i) for i in (~k + 1):-1),
151153
(~x))
152154
@rule 𝛷(1 / ~x) => 𝛷(^(~x, -1))
153155
@rule 𝛷(^(~x, ~k)) => (^(~x, ~k + 1), (~x))
154-
@rule 𝛷(exp(~x)) => (exp(~x), (~x))
155156
@rule 𝛷(1) => (𝑥, 1)
156157
@rule 𝛷(~x) => ((~x + ^(~x, 2)), (~x))]
157158

src/integral.jl

Lines changed: 7 additions & 158 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ Base.signbit(x::SymbolicUtils.Sym{Number, Nothing}) = false
2424
verbose: print a detailed report
2525
complex_plane: generate random test points on the complex plane (if false, the points will be on real axis)
2626
homotopy: use the homotopy algorithm to generate the basis
27+
use_optim: use Optim.jl `minimize` function instead of the STLSQ algorithm (**experimental**)
2728
2829
output:
2930
-------
@@ -36,7 +37,7 @@ function integrate(eq, x = nothing; abstol = 1e-6, num_steps = 2, num_trials = 5
3637
radius = 1.0,
3738
show_basis = false, opt = STLSQ(exp.(-10:1:0)), bypass = false,
3839
symbolic = true, max_basis = 100, verbose = false, complex_plane = true,
39-
homotopy = true)
40+
homotopy = true, use_optim = false)
4041
eq = expand(eq)
4142
eq = apply_div_rule(eq)
4243

@@ -61,7 +62,7 @@ function integrate(eq, x = nothing; abstol = 1e-6, num_steps = 2, num_trials = 5
6162

6263
s, u, ϵ = integrate_sum(eq, x, l; bypass, abstol, num_trials, num_steps,
6364
radius, show_basis, opt, symbolic,
64-
max_basis, verbose, complex_plane, homotopy)
65+
max_basis, verbose, complex_plane, homotopy, use_optim)
6566
return simplify(s), u, ϵ
6667
end
6768

@@ -235,164 +236,12 @@ end
235236
"""
236237
function try_integrate(T, eq, x, basis, radius; kwargs...)
237238
args = Dict(kwargs)
238-
abstol, opt, complex_plane, verbose = args[:abstol], args[:opt], args[:complex_plane],
239-
args[:verbose]
240-
239+
use_optim = args[:use_optim]
241240
basis = basis[2:end] # remove 1 from the beginning
242-
n = length(basis)
243-
244-
# A is an nxn matrix holding the values of the fragments at n random points
245-
A = zeros(Complex{T}, (n, n))
246-
X = zeros(Complex{T}, n)
247-
248-
init_basis_matrix!(T, A, X, x, eq, basis, radius, complex_plane; abstol)
249-
250-
y₁, ϵ₁ = sparse_fit(T, A, x, basis, opt; abstol)
251-
if ϵ₁ < abstol
252-
return y₁, ϵ₁
253-
end
254-
255-
y₂, ϵ₂ = find_singlet(T, A, basis; abstol)
256-
if ϵ₂ < abstol
257-
return y₂, ϵ₂
258-
end
259-
260-
if n < 8 # here, 8 is arbitrary and signifies a small basis
261-
y₃, ϵ₃ = find_dense(T, A, basis; abstol)
262-
if ϵ₃ < abstol
263-
return y₃, ϵ₃
264-
end
265-
end
266-
267-
# moving toward the poles
268-
modify_basis_matrix!(T, A, X, x, eq, basis, radius; abstol)
269-
y₄, ϵ₄ = sparse_fit(T, A, x, basis, opt; abstol)
270-
271-
if ϵ₄ < abstol || ϵ₄ < ϵ₁
272-
return y₄, ϵ₄
273-
else
274-
return y₁, ϵ₁
275-
end
276-
end
277-
278-
function init_basis_matrix!(T, A, X, x, eq, basis, radius, complex_plane; abstol = 1e-6)
279-
n = size(A, 1)
280-
k = 1
281-
i = 1
282-
283-
eq_fun = fun!(eq, x)
284-
Δbasis_fun = deriv_fun!.(basis, x)
285-
286-
while k <= n
287-
try
288-
x₀ = test_point(complex_plane, radius)
289-
X[k] = x₀ # move_toward_roots_poles(x₀, x, eq)
290-
b₀ = eq_fun(X[k])
291-
292-
if is_proper(b₀)
293-
for j in 1:n
294-
A[k, j] = Δbasis_fun[j](X[k]) / b₀
295-
end
296-
if all(is_proper, A[k, :])
297-
k += 1
298-
end
299-
end
300-
catch e
301-
println("Error from init_basis_matrix!: ", e)
302-
end
303-
end
304-
end
305-
306-
function move_toward_roots_poles(z, x, eq; n = 1, max_r = 100.0)
307-
eq_fun = fun!(eq, x)
308-
Δeq_fun = deriv_fun!(eq, x)
309-
is_root = rand() < 0.5
310-
z₀ = z
311-
for i in 1:n
312-
dz = eq_fun(z) / Δeq_fun(z)
313-
if is_root
314-
z -= dz
315-
else
316-
z += dz
317-
end
318-
if abs(z) > max_r
319-
return z₀
320-
end
321-
end
322-
return z
323-
end
324-
325-
function modify_basis_matrix!(T, A, X, x, eq, basis, radius; abstol = 1e-6)
326-
n = size(A, 1)
327-
eq_fun = fun!(eq, x)
328-
Δeq_fun = deriv_fun!(eq, x)
329-
Δbasis_fun = deriv_fun!.(basis, x)
330-
331-
for k in 1:n
332-
# One Newton iteration toward the poles
333-
# note the + sign instead of the usual - in Newton-Raphson's method. This is
334-
# because we are moving toward the poles and not zeros.
335-
336-
X[k] += eq_fun(X[k]) / Δeq_fun(X[k])
337-
b₀ = eq_fun(X[k])
338-
for j in 1:n
339-
A[k, j] = Δbasis_fun[j](X[k]) / b₀
340-
end
341-
end
342-
end
343-
344-
function sparse_fit(T, A, x, basis, opt; abstol = 1e-6)
345-
n = length(basis)
346-
# find a linearly independent subset of the basis
347-
l = find_independent_subset(A; abstol)
348-
A, basis, n = A[l, l], basis[l], sum(l)
349-
350-
try
351-
b = ones(n)
352-
# q₀ = A \ b
353-
q₀ = DataDrivenDiffEq.init(opt, A, b)
354-
@views sparse_regression!(q₀, A, permutedims(b)', opt, maxiter = 1000)
355-
ϵ = rms(A * q₀ - b)
356-
q = nice_parameter.(q₀)
357-
if sum(iscomplex.(q)) > 2
358-
return nothing, Inf
359-
end # eliminating complex coefficients
360-
return sum(q[i] * expr(basis[i]) for i in 1:length(basis) if q[i] != 0;
361-
init = zero(x)),
362-
abs(ϵ)
363-
catch e
364-
println("Error from sparse_fit", e)
365-
return nothing, Inf
366-
end
367-
end
368241

369-
function find_singlet(T, A, basis; abstol)
370-
σ = vec(std(A; dims = 1))
371-
μ = vec(mean(A; dims = 1))
372-
l =.< abstol) .* (abs.(μ) .> abstol)
373-
if sum(l) == 1
374-
k = findfirst(l)
375-
return nice_parameter(1 / μ[k]) * expr(basis[k]), σ[k]
242+
if use_optim
243+
return solve_optim(T, eq, x, basis, radius; kwargs...)
376244
else
377-
return nothing, Inf
378-
end
379-
end
380-
381-
function find_dense(T, A, basis; abstol = 1e-6)
382-
n = size(A, 1)
383-
b = ones(T, n)
384-
385-
try
386-
q = A \ b
387-
if minimum(abs.(q)) > abstol
388-
ϵ = maximum(abs.(A * q .- b))
389-
if ϵ < abstol
390-
y = sum(nice_parameter.(q) .* expr.(basis))
391-
return y, ϵ
392-
end
393-
end
394-
catch e
395-
#
245+
return solve_sparse(T, eq, x, basis, radius; kwargs...)
396246
end
397-
return nothing, Inf
398247
end

0 commit comments

Comments
 (0)