Skip to content

Commit b8b48f6

Browse files
Merge branch 'optim'
2 parents b15fb88 + 47512d7 commit b8b48f6

File tree

8 files changed

+303
-252
lines changed

8 files changed

+303
-252
lines changed

Project.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "1.1.0"
55

66
[deps]
77
DataDrivenDiffEq = "2445eb08-9709-466a-b3fc-47e12bd697a2"
8+
DataDrivenSparse = "5b588203-7d8b-4fab-a537-c31a7f73f46b"
89
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
910
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1011
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
@@ -14,11 +15,11 @@ SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
1415
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
1516

1617
[compat]
17-
DataDrivenDiffEq = "0.8"
18+
DataDrivenDiffEq = "1.0"
1819
DataStructures = "0.18"
1920
Optim = "1"
20-
SymbolicUtils = "0.19"
21-
Symbolics = "4"
21+
SymbolicUtils = "1"
22+
Symbolics = "5"
2223
julia = "1.6"
2324

2425
[extras]

src/SymbolicNumericIntegration.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@ using SymbolicUtils: istree, operation, arguments
55
using Symbolics
66
using Symbolics: value, get_variables, expand_derivatives
77
using SymbolicUtils.Rewriters
8+
using SymbolicUtils: exprtype, BasicSymbolic
89

9-
using DataDrivenDiffEq
10+
using DataDrivenDiffEq, DataDrivenSparse
1011

1112
include("utils.jl")
1213
include("special.jl")

src/candidates.jl

Lines changed: 27 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
using DataStructures
22

33
# this is the main heurisctic used to find the test fragments
4-
function generate_basis(eq, x, try_kernel = true; homotopy = true)
5-
if homotopy && !try_kernel
4+
function generate_basis(eq, x, try_kernel = true)
5+
if !try_kernel
66
S = sum(generate_homotopy(expr(eq), x))
77
return cache.(unique([one(x); [equivalent(t, x) for t in terms(S)]]))
88
end
@@ -13,33 +13,35 @@ function generate_basis(eq, x, try_kernel = true; homotopy = true)
1313
for t in terms(eq)
1414
q = equivalent(t, x)
1515
f = kernel(q)
16+
p = q / f
1617

17-
C₁ = closure(f, x)
18-
p = q * inverse(f)
19-
20-
if homotopy
21-
if isdependent(p, x)
22-
C₂ = generate_homotopy(p, x)
23-
else
24-
C₂ = 1
25-
end
18+
if isdependent(p, x)
19+
C₂ = generate_homotopy(p, x)
2620
else
27-
C₂ = find_candidates(p, x)
21+
C₂ = 1
2822
end
29-
23+
24+
C₁ = closure(f, x)
25+
3026
S += sum(c₁ * c₂ for c₁ in C₁ for c₂ in C₂)
3127
end
3228
return cache.(unique([one(x); [equivalent(t, x) for t in terms(S)]]))
3329
end
3430

35-
function expand_basis(basis, x)
36-
b = sum(expr.(basis))
37-
δb = sum(deriv!.(basis, x))
31+
function expand_basis(basis, x; Kmax=1000)
32+
b = sum(expr.(basis))
33+
34+
Kb = complexity(b) # Kolmogorov complexity
35+
if Kb > Kmax
36+
return basis, false
37+
end
38+
39+
δb = sum(deriv!.(basis, x))
3840
eq = (1 + x) * (b + δb)
3941
eq = expand(eq)
4042
S = Set{Any}()
4143
enqueue_expr!(S, eq, x)
42-
return cache.([one(x); [s for s in S]])
44+
return cache.([one(x); [s for s in S]]), true
4345
end
4446

4547
function closure(eq, x; max_terms = 50)
@@ -58,22 +60,6 @@ function closure(eq, x; max_terms = 50)
5860
unique([one(x); [s for s in S]; [s * x for s in S]])
5961
end
6062

61-
function find_candidates(eq, x)
62-
eq = apply_d_rules(eq)
63-
D = Differential(x)
64-
65-
S = Set{Any}()
66-
q = Queue{Any}()
67-
enqueue_expr!(S, q, eq, x)
68-
69-
for y in q
70-
∂y = expand_derivatives(D(y))
71-
enqueue_expr!(S, ∂y, x)
72-
end
73-
74-
return unique([one(x); [s for s in S]])
75-
end
76-
7763
function candidate_pow_minus(p, k; abstol = 1e-8)
7864
if isnan(poly_deg(p))
7965
return [p^k, p^(k + 1), log(p)]
@@ -142,27 +128,31 @@ end
142128

143129
###############################################################################
144130

145-
function enqueue_expr!(S, q, eq::SymbolicUtils.Add, x)
131+
enqueue_expr!(S, q, eq, x) = enqueue_expr!!(S, q, ops(eq)..., x)
132+
133+
function enqueue_expr!!(S, q, ::Add, eq, x)
146134
for t in arguments(eq)
147135
enqueue_expr!(S, q, t, x)
148136
end
149137
end
150138

151-
function enqueue_expr!(S, q, eq, x)
139+
function enqueue_expr!!(S, q, ::Any, eq, x)
152140
y = eq / coef(eq, x)
153141
if y S && isdependent(y, x)
154142
enqueue!(q, y)
155143
push!(S, y)
156144
end
157145
end
158146

159-
function enqueue_expr!(S, eq::SymbolicUtils.Add, x)
147+
enqueue_expr!(S, eq, x) = enqueue_expr!!(S, ops(eq)..., x)
148+
149+
function enqueue_expr!!(S, ::Add, eq, x)
160150
for t in arguments(eq)
161151
enqueue_expr!(S, t, x)
162152
end
163153
end
164154

165-
function enqueue_expr!(S, eq, x)
155+
function enqueue_expr!!(S, ::Any, eq, x)
166156
y = eq / coef(eq, x)
167157
if y S && isdependent(y, x)
168158
push!(S, y)

src/homotopy.jl

Lines changed: 82 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,60 @@
11
@syms 𝑥
22
@syms u[20]
33

4-
mutable struct Transform
5-
k::Int
6-
sub::Dict
7-
end
4+
transformer(eq) = transformer(ops(eq)...)
85

9-
function next_variable!(f, eq)
10-
μ = u[f.k]
11-
f.k += 1
12-
f.sub[μ] = eq
13-
return μ
6+
function transformer(::Mul, eq)
7+
return vcat([transformer(t) for t in arguments(eq)]...)
148
end
159

16-
function transformer(eq::SymbolicUtils.Add, f)
17-
return sum(transformer(t, f) for t in arguments(eq); init = 0)
18-
end
19-
function transformer(eq::SymbolicUtils.Mul, f)
20-
return prod(transformer(t, f) for t in arguments(eq); init = 1)
21-
end
22-
function transformer(eq::SymbolicUtils.Div, f)
23-
return transformer(arguments(eq)[1], f) * transformer(arguments(eq)[2]^-1, f)
10+
function transformer(::Div, eq)
11+
a = transformer(arguments(eq)[1])
12+
b = transformer(arguments(eq)[2])
13+
b = [(1/q, k) for (q, k) in b]
14+
return [a; b]
2415
end
2516

26-
function transformer(eq::SymbolicUtils.Pow, f)
17+
function transformer(::Pow, eq)
2718
y, k = arguments(eq)
28-
29-
if is_pos_int(k)
30-
μ = next_variable!(f, y)
31-
return μ^k
32-
elseif is_neg_int(k)
33-
μ = next_variable!(f, inv(y))
34-
return μ^-k
19+
if is_number(k)
20+
r = nice_parameter(k)
21+
if denominator(r) == 1
22+
return [(y, k)]
23+
else
24+
return [(y^(1/denominator(r)), numerator(r))]
25+
end
3526
else
36-
return next_variable!(f, y^k)
37-
end
27+
return [(eq, 1)]
28+
end
3829
end
3930

40-
function transformer(eq, f)
41-
if isdependent(eq, 𝑥)
42-
return next_variable!(f, eq)
43-
else
44-
return 1
45-
end
31+
function transformer(::Any, eq)
32+
return [(eq, 1)]
4633
end
4734

4835
function transform(eq, x)
4936
eq = substitute(eq, Dict(x => 𝑥))
50-
f = Transform(1, Dict())
51-
q = transformer(eq, f)
52-
if !any(is_poly, values(f.sub))
53-
q *= next_variable!(f, 1)
54-
end
55-
return q, f.sub
37+
p = transformer(eq)
38+
p = p[isdependent.(first.(p), 𝑥)]
39+
40+
return p
41+
end
42+
43+
function rename_factors(p)
44+
n = length(p)
45+
46+
q = 1
47+
sub = Dict()
48+
ks = Int[]
49+
50+
for (i,(y,k)) in enumerate(p)
51+
μ = u[i]
52+
q *= μ ^ k
53+
sub[μ] = y
54+
push!(ks, k)
55+
end
56+
57+
return q, sub, ks
5658
end
5759

5860
##############################################################################
@@ -73,29 +75,34 @@ Symbolics.derivative(::typeof(Li), args::NTuple{1, Any}, ::Val{1}) = 1 / log(arg
7375

7476
function substitute_x(eq, x, sub)
7577
eq = substitute(eq, sub)
76-
substitute(eq, Dict(𝑥 => x))
78+
return substitute(eq, Dict(𝑥 => x))
7779
end
7880

81+
guard_zero(x) = isequal(x, 0) ? one(x) : x
82+
7983
function generate_homotopy(eq, x)
8084
eq = eq isa Num ? eq.val : eq
8185
x = x isa Num ? x.val : x
8286

83-
q, sub = transform(eq, x)
87+
p = transform(eq, x)
88+
q, sub, ks = rename_factors(p)
8489
S = 0
8590

8691
for i in 1:length(sub)
87-
μ = u[i]
88-
h₁, ∂h₁ = apply_partial_int_rules(sub[μ])
89-
h₁ = substitute(h₁, Dict(si => Si, ci => Ci, ei => Ei, li => Li))
90-
h₂ = expand_derivatives(Differential(μ)(q))
91-
92-
h₁ = substitute_x(h₁, x, sub)
93-
h₂ = substitute_x(h₂ * ∂h₁^-1, x, sub)
94-
95-
S += expand((1 + h₁) * (1 + h₂))
96-
end
97-
98-
unique([one(x); [equivalent(t, x) for t in terms(S)]])
92+
μ = u[i]
93+
h₁, ∂h₁ = apply_partial_int_rules(sub[μ])
94+
h₁ = substitute(h₁, Dict(si => Si, ci => Ci, ei => Ei, li => Li))
95+
h₁ = substitute_x(h₁, x, sub)
96+
97+
for j = 1:ks[i]
98+
h₂ = substitute_x((q / μ^j) / ∂h₁, x, sub)
99+
S += expand((1 + h₁) * guard_zero(1 + h₂))
100+
end
101+
end
102+
103+
ζ = [x^k for k=1:maximum(ks)+1]
104+
105+
unique([one(x); ζ; [equivalent(t, x) for t in terms(S)]])
99106
end
100107

101108
##############################################################################
@@ -105,6 +112,8 @@ function ∂(x)
105112
return isequal(d, 0) ? 1 : d
106113
end
107114

115+
@syms 𝛷(x)
116+
108117
partial_int_rules = [
109118
# trigonometric functions
110119
@rule 𝛷(sin(~x)) => (cos(~x) + si(~x), (~x))
@@ -121,19 +130,19 @@ partial_int_rules = [
121130
@rule 𝛷(sech(~x)) => (atan(sinh(~x)), (~x))
122131
@rule 𝛷(coth(~x)) => (log(sinh(~x)), (~x))
123132
# 1/trigonometric functions
124-
@rule 𝛷(^(sin(~x), -1)) => (log(csc(~x) + cot(~x)), (~x))
125-
@rule 𝛷(^(cos(~x), -1)) => (log(sec(~x) + tan(~x)), (~x))
126-
@rule 𝛷(^(tan(~x), -1)) => (log(sin(~x)), (~x))
127-
@rule 𝛷(^(csc(~x), -1)) => (cos(~x), (~x))
128-
@rule 𝛷(^(sec(~x), -1)) => (sin(~x), (~x))
129-
@rule 𝛷(^(cot(~x), -1)) => (log(cos(~x)), (~x))
133+
@rule 𝛷(1 / sin(~x)) => (log(csc(~x) + cot(~x)) + log(sin(~x)), (~x))
134+
@rule 𝛷(1 / cos(~x)) => (log(sec(~x) + tan(~x)) + log(cos(~x)), (~x))
135+
@rule 𝛷(1 / tan(~x)) => (log(sin(~x)) + log(tan(~x)), (~x))
136+
@rule 𝛷(1 / csc(~x)) => (cos(~x) + log(csc(~x)), (~x))
137+
@rule 𝛷(1 / sec(~x)) => (sin(~x) + log(sec(~x)), (~x))
138+
@rule 𝛷(1 / cot(~x)) => (log(cos(~x)) + log(cot(~x)), (~x))
130139
# 1/hyperbolic functions
131-
@rule 𝛷(^(sinh(~x), -1)) => (log(tanh(~x / 2)), (~x))
132-
@rule 𝛷(^(cosh(~x), -1)) => (atan(sinh(~x)), (~x))
133-
@rule 𝛷(^(tanh(~x), -1)) => (log(sinh(~x)), (~x))
134-
@rule 𝛷(^(csch(~x), -1)) => (cosh(~x), (~x))
135-
@rule 𝛷(^(sech(~x), -1)) => (sinh(~x), (~x))
136-
@rule 𝛷(^(coth(~x), -1)) => (log(cosh(~x)), (~x))
140+
@rule 𝛷(1 / sinh(~x)) => (log(tanh(~x / 2)) + log(sinh(~x)), (~x))
141+
@rule 𝛷(1 / cosh(~x)) => (atan(sinh(~x)) + log(cosh(~x)), (~x))
142+
@rule 𝛷(1 / tanh(~x)) => (log(sinh(~x)) + log(tanh(~x)), (~x))
143+
@rule 𝛷(1 / csch(~x)) => (cosh(~x) + log(csch(~x)), (~x))
144+
@rule 𝛷(1 / sech(~x)) => (sinh(~x) + log(sech(~x)), (~x))
145+
@rule 𝛷(1 / coth(~x)) => (log(cosh(~x)) + log(coth(~x)), (~x))
137146
# inverse trigonometric functions
138147
@rule 𝛷(asin(~x)) => (~x * asin(~x) + sqrt(1 - ~x * ~x), (~x))
139148
@rule 𝛷(acos(~x)) => (~x * acos(~x) + sqrt(1 - ~x * ~x), (~x))
@@ -152,23 +161,24 @@ partial_int_rules = [
152161
@rule 𝛷(log(~x)) => (~x + ~x * log(~x) +
153162
sum(candidate_pow_minus(~x, -1); init = one(~x)),
154163
(~x))
155-
@rule 𝛷(^(log(~x), -1)) => (log(log(~x)) + li(~x), (~x))
164+
@rule 𝛷(1 / log(~x)) => (log(log(~x)) + li(~x), (~x))
156165
@rule 𝛷(exp(~x)) => (exp(~x) + ei(~x), (~x))
157166
@rule 𝛷(^(exp(~x), ~k::is_neg)) => (^(exp(-~x), -~k), (~x))
158167
# square-root functions
159168
@rule 𝛷(^(~x, ~k::is_abs_half)) => (sum(candidate_sqrt(~x, ~k);
160169
init = one(~x)), 1);
161-
@rule 𝛷(sqrt(~x)) => (sum(candidate_sqrt(~x, 0.5); init = one(~x)), 1);
162-
@rule 𝛷(^(sqrt(~x), -1)) => 𝛷(^(~x, -0.5))
170+
@rule 𝛷(sqrt(~x)) => (sum(candidate_sqrt(~x, 0.5); init = one(~x)), (~x));
171+
@rule 𝛷(1 / sqrt(~x)) => (sum(candidate_sqrt(~x, -0.5); init = one(~x)), (~x));
163172
# rational functions
164-
@rule 𝛷(^(~x::is_poly, ~k::is_neg)) => (sum(candidate_pow_minus(~x,
165-
~k);
173+
@rule 𝛷(1 / ^(~x::is_poly, ~k::is_pos_int)) => (sum(candidate_pow_minus(~x, -~k);
174+
init = one(~x)), 1)
175+
@rule 𝛷(1 / ~x::is_poly) => (sum(candidate_pow_minus(~x, -1);
166176
init = one(~x)), 1)
167177
@rule 𝛷(^(~x, -1)) => (log(~x), (~x))
168178
@rule 𝛷(^(~x, ~k::is_neg_int)) => (sum(^(~x, i) for i in (~k + 1):-1),
169179
(~x))
170-
@rule 𝛷(1 / ~x) => 𝛷(^(~x, -1))
171-
@rule 𝛷(^(~x, ~k)) => (^(~x, ~k + 1), (~x))
180+
@rule 𝛷(1 / ~x) => (log(~x), (~x))
181+
@rule 𝛷(^(~x, ~k::is_pos_int)) => (sum(^(~x, i+1) for i=1:~k+1), (~x))
172182
@rule 𝛷(1) => (𝑥, 1)
173183
@rule 𝛷(~x) => ((~x + ^(~x, 2)), (~x))]
174184

0 commit comments

Comments
 (0)