Commit e0c9ef4

fix: CrossEntropy is now functional 🎉

1 parent c373425

File tree: 9 files changed, +127 -133 lines


lib/DataDrivenLux/src/DataDrivenLux.jl

Lines changed: 4 additions & 3 deletions
```diff
@@ -8,12 +8,11 @@ using DataDrivenDiffEq: AbstractBasis, AbstractDataDrivenAlgorithm,
                         ABSTRACT_CONT_PROB, ABSTRACT_DISCRETE_PROB,
                         InternalDataDrivenProblem, is_implicit, is_controlled
 
-using DocStringExtensions: DocStringExtensions, FIELDS, TYPEDEF
+using DocStringExtensions: DocStringExtensions, FIELDS, TYPEDEF, SIGNATURES
 using CommonSolve: CommonSolve, solve!
 using ConcreteStructs: @concrete
 using Setfield: Setfield, @set!
 
-# TODO: Get rid of Optim and Optimisers in favor of Optimization.jl
 using Optim: Optim, LBFGS
 using Optimisers: Optimisers, ADAM
@@ -93,6 +92,8 @@ export SearchCache
 include("algorithms/rewards.jl")
 export RelativeReward, AbsoluteReward
 
+include("algorithms/common.jl")
+
 include("algorithms/randomsearch.jl")
 export RandomSearch
 
@@ -104,4 +105,4 @@ export CrossEntropy
 
 include("solve.jl")
 
-end # module DataDrivenLux
+end
```
lib/DataDrivenLux/src/algorithms/common.jl

Lines changed: 19 additions & 0 deletions
```diff
@@ -0,0 +1,19 @@
+@kwdef @concrete struct CommonAlgOptions
+    populationsize::Int = 100
+    functions = (sin, exp, cos, log, +, -, /, *)
+    arities = (1, 1, 1, 1, 2, 2, 2, 2)
+    n_layers::Int = 1
+    skip::Bool = true
+    simplex <: AbstractSimplex = Softmax()
+    loss = aicc
+    keep <: Union{Real, Int} = 0.1
+    use_protected::Bool = true
+    distributed::Bool = false
+    threaded::Bool = false
+    rng <: AbstractRNG = Random.default_rng()
+    optimizer = LBFGS()
+    optim_options <: Optim.Options = Optim.Options()
+    optimiser <: Union{Nothing, Optimisers.AbstractRule} = nothing
+    observed <: Union{ObservedModel, Nothing} = nothing
+    alpha::Real = 0.999f0
+end
```
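The new `CommonAlgOptions` struct pools the keyword options that were previously duplicated across the algorithm structs. It combines `Base.@kwdef` (keyword constructor with defaults) with ConcreteStructs' `@concrete`, which turns each `field <: AbstractType` annotation into a concrete type parameter instead of an abstractly typed field. A minimal, self-contained sketch of that pattern (toy struct and field names, not the package code):

```julia
using ConcreteStructs: @concrete
using Random: AbstractRNG, default_rng

# @kwdef supplies the keyword constructor with defaults; @concrete
# replaces each `field <: AbstractType` annotation with a concrete
# type parameter, so field accesses stay type-stable.
Base.@kwdef @concrete struct ToyOptions
    populationsize::Int = 100
    rng <: AbstractRNG = default_rng()
    alpha::Real = 0.999f0
end

opts = ToyOptions(populationsize = 50)  # rng and alpha keep their defaults
```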
lib/DataDrivenLux/src/algorithms/crossentropy.jl

Lines changed: 23 additions & 49 deletions
```diff
@@ -1,54 +1,29 @@
-"""
-$(TYPEDEF)
+@concrete struct CrossEntropy <: AbstractDAGSRAlgorithm
+    options <: CommonAlgOptions
+end
 
-Uses the crossentropy method for discrete optimization to search the space of possible solutions.
+"""
+$(SIGNATURES)
 
-# Fields
-$(FIELDS)
+Uses the crossentropy method for discrete optimization to search the space of possible
+solutions.
 """
-@kwdef struct CrossEntropy{F, A, L, O} <: AbstractDAGSRAlgorithm
-    "The number of candidates to track"
-    populationsize::Int = 100
-    "The functions to include in the search"
-    functions::F = (sin, exp, cos, log, +, -, /, *)
-    "The arities of the functions"
-    arities::A = (1, 1, 1, 1, 2, 2, 2, 2)
-    "The number of layers"
-    n_layers::Int = 1
-    "Include skip layers"
-    skip::Bool = true
-    "Evaluation function to sort the samples"
-    loss::L = aicc
-    "The number of candidates to keep in each iteration"
-    keep::Union{Real, Int} = 0.1
-    "Use protected operators"
-    use_protected::Bool = true
-    "Use distributed optimization and resampling"
-    distributed::Bool = false
-    "Use threaded optimization and resampling - not implemented right now."
-    threaded::Bool = false
-    "Random seed"
-    rng::AbstractRNG = Random.default_rng()
-    "Optim optimiser"
-    optimizer::O = LBFGS()
-    "Optim options"
-    optim_options::Optim.Options = Optim.Options()
-    "Observed model - if `nothing` is used, a normally distributed additive error with fixed variance is assumed."
-    observed::Union{ObservedModel, Nothing} = nothing
-    "Field for possible optimiser - no use for CrossEntropy"
-    optimiser::Nothing = nothing
-    "Update parameter for smoothness"
-    alpha::Real = 0.999f0
+function CrossEntropy(; populationsize = 100, functions = (sin, exp, cos, log, +, -, /, *),
+        arities = (1, 1, 1, 1, 2, 2, 2, 2), n_layers = 1, skip = true, loss = aicc,
+        keep = 0.1, use_protected = true, distributed = false, threaded = false,
+        rng = Random.default_rng(), optimizer = LBFGS(), optim_options = Optim.Options(),
+        observed = nothing, alpha = 0.999f0)
+    return CrossEntropy(CommonAlgOptions(;
+        populationsize, functions, arities, n_layers, skip, simplex = DirectSimplex(), loss,
+        keep, use_protected, distributed, threaded, rng, optimizer,
+        optim_options, optimiser = nothing, observed, alpha))
 end
 
-Base.print(io::IO, ::CrossEntropy) = print(io, "CrossEntropy")
+Base.print(io::IO, ::CrossEntropy) = print(io, "CrossEntropy()")
 Base.summary(io::IO, x::CrossEntropy) = print(io, x)
 
 function init_model(x::CrossEntropy, basis::Basis, dataset::Dataset, intervals)
-    (; n_layers, arities, functions, use_protected, skip) = x
-
-    # We enforce the direct simplex here!
-    simplex = DirectSimplex()
+    (; n_layers, arities, functions, use_protected, skip) = x.options
 
     # Get the parameter mapping
     variable_mask = map(enumerate(equations(basis))) do (i, eq)
@@ -63,15 +38,14 @@ function init_model(x::CrossEntropy, basis::Basis, dataset::Dataset, intervals)
     end
 
     return LayeredDAG(length(basis), size(dataset.y, 1), n_layers, arities, functions;
-        skip = skip, input_functions = variable_mask, simplex = simplex)
+        skip, input_functions = variable_mask, x.options.simplex)
 end
 
 function update_parameters!(cache::SearchCache{<:CrossEntropy})
-    (; candidates, keeps, p, alg) = cache
-    (; alpha) = alg
-    p̄ = mean(map(candidates[keeps]) do candidate
-        return ComponentVector(get_configuration(candidate.model.model, p, candidate.st))
+    p̄ = mean(map(cache.candidates[cache.keeps]) do candidate
+        return ComponentVector(get_configuration(candidate.model.model, cache.p, candidate.st))
     end)
-    cache.p .= alpha * p + (one(alpha) - alpha) .* p̄
+    alpha = cache.alg.options.alpha
+    @. cache.p = alpha * cache.p + (true - alpha) * p̄
    return
 end
```
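The rewritten `update_parameters!` is the cross-entropy smoothing step: the sampling parameters `cache.p` move toward p̄, the mean configuration of the kept (elite) candidates, with mixing weight `alpha`. A numeric sketch with plain vectors in place of the package's `ComponentVector`s (names here are illustrative):

```julia
using Statistics: mean

p     = [0.2, 0.5, 0.3]                     # current sampling parameters
elite = [[0.1, 0.6, 0.3], [0.3, 0.4, 0.3]]  # configurations of the kept candidates
p̄     = mean(elite)                         # elementwise elite mean
α     = 0.999f0
# `true - α` equals 1 - α but, as a Bool, avoids promoting the eltype
@. p = α * p + (true - α) * p̄
```

With `alpha` close to one, the distribution drifts slowly toward the elite mean, which stabilizes the search at the cost of slower convergence.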

lib/DataDrivenLux/src/algorithms/randomsearch.jl

Lines changed: 32 additions & 32 deletions
```diff
@@ -8,38 +8,38 @@ symbolic regression problem.
 $(FIELDS)
 """
 @kwdef struct RandomSearch{F, A, L, O} <: AbstractDAGSRAlgorithm
-    "The number of candidates to track"
-    populationsize::Int = 100
-    "The functions to include in the search"
-    functions::F = (sin, exp, cos, log, +, -, /, *)
-    "The arities of the functions"
-    arities::A = (1, 1, 1, 1, 2, 2, 2, 2)
-    "The number of layers"
-    n_layers::Int = 1
-    "Include skip layers"
-    skip::Bool = true
-    "Simplex mapping"
-    simplex::AbstractSimplex = Softmax()
-    "Evaluation function to sort the samples"
-    loss::L = aicc
-    "The number of candidates to keep in each iteration"
-    keep::Union{Real, Int} = 0.1
-    "Use protected operators"
-    use_protected::Bool = true
-    "Use distributed optimization and resampling"
-    distributed::Bool = false
-    "Use threaded optimization and resampling - not implemented right now."
-    threaded::Bool = false
-    "Random seed"
-    rng::AbstractRNG = Random.default_rng()
-    "Optim optimiser"
-    optimizer::O = LBFGS()
-    "Optim options"
-    optim_options::Optim.Options = Optim.Options()
-    "Observed model - if `nothing` is used, a normally distributed additive error with fixed variance is assumed."
-    observed::Union{ObservedModel, Nothing} = nothing
-    "Field for possible optimiser - no use for RandomSearch"
-    optimiser::Nothing = nothing
+    # "The number of candidates to track"
+    # populationsize::Int = 100
+    # "The functions to include in the search"
+    # functions::F = (sin, exp, cos, log, +, -, /, *)
+    # "The arities of the functions"
+    # arities::A = (1, 1, 1, 1, 2, 2, 2, 2)
+    # "The number of layers"
+    # n_layers::Int = 1
+    # "Include skip layers"
+    # skip::Bool = true
+    # "Simplex mapping"
+    # simplex::AbstractSimplex = Softmax()
+    # "Evaluation function to sort the samples"
+    # loss::L = aicc
+    # "The number of candidates to keep in each iteration"
+    # keep::Union{Real, Int} = 0.1
+    # "Use protected operators"
+    # use_protected::Bool = true
+    # "Use distributed optimization and resampling"
+    # distributed::Bool = false
+    # "Use threaded optimization and resampling - not implemented right now."
+    # threaded::Bool = false
+    # "Random seed"
+    # rng::AbstractRNG = Random.default_rng()
+    # "Optim optimiser"
+    # optimizer::O = LBFGS()
+    # "Optim options"
+    # optim_options::Optim.Options = Optim.Options()
+    # "Observed model - if `nothing` is used, a normally distributed additive error with fixed variance is assumed."
+    # observed::Union{ObservedModel, Nothing} = nothing
+    # "Field for possible optimiser - no use for RandomSearch"
+    # optimiser::Nothing = nothing
 end
 
 Base.print(io::IO, ::RandomSearch) = print(io, "RandomSearch")
```

lib/DataDrivenLux/src/algorithms/reinforce.jl

Lines changed: 34 additions & 34 deletions
```diff
@@ -10,40 +10,40 @@ $(FIELDS)
 @kwdef struct Reinforce{F, A, L, O, R} <: AbstractDAGSRAlgorithm
     "Reward function which should convert the loss to a reward."
     reward::R = RelativeReward(false)
-    "The number of candidates to track"
-    populationsize::Int = 100
-    "The functions to include in the search"
-    functions::F = (sin, exp, cos, log, +, -, /, *)
-    "The arities of the functions"
-    arities::A = (1, 1, 1, 1, 2, 2, 2, 2)
-    "The number of layers"
-    n_layers::Int = 1
-    "Include skip layers"
-    skip::Bool = true
-    "Simplex mapping"
-    simplex::AbstractSimplex = Softmax()
-    "Evaluation function to sort the samples"
-    loss::L = aicc
-    "The number of candidates to keep in each iteration"
-    keep::Union{Real, Int} = 0.1
-    "Use protected operators"
-    use_protected::Bool = true
-    "Use distributed optimization and resampling"
-    distributed::Bool = false
-    "Use threaded optimization and resampling - not implemented right now."
-    threaded::Bool = false
-    "Random seed"
-    rng::AbstractRNG = Random.default_rng()
-    "Optim optimiser"
-    optimizer::O = LBFGS()
-    "Optim options"
-    optim_options::Optim.Options = Optim.Options()
-    "Observed model - if `nothing` is used, a normally distributed additive error with fixed variance is assumed."
-    observed::Union{ObservedModel, Nothing} = nothing
-    "AD Backend"
-    ad_backend::AD.AbstractBackend = AD.ForwardDiffBackend()
-    "Optimiser"
-    optimiser::Optimisers.AbstractRule = ADAM()
+    # "The number of candidates to track"
+    # populationsize::Int = 100
+    # "The functions to include in the search"
+    # functions::F = (sin, exp, cos, log, +, -, /, *)
+    # "The arities of the functions"
+    # arities::A = (1, 1, 1, 1, 2, 2, 2, 2)
+    # "The number of layers"
+    # n_layers::Int = 1
+    # "Include skip layers"
+    # skip::Bool = true
+    # "Simplex mapping"
+    # simplex::AbstractSimplex = Softmax()
+    # "Evaluation function to sort the samples"
+    # loss::L = aicc
+    # "The number of candidates to keep in each iteration"
+    # keep::Union{Real, Int} = 0.1
+    # "Use protected operators"
+    # use_protected::Bool = true
+    # "Use distributed optimization and resampling"
+    # distributed::Bool = false
+    # "Use threaded optimization and resampling - not implemented right now."
+    # threaded::Bool = false
+    # "Random seed"
+    # rng::AbstractRNG = Random.default_rng()
+    # "Optim optimiser"
+    # optimizer::O = LBFGS()
+    # "Optim options"
+    # optim_options::Optim.Options = Optim.Options()
+    # "Observed model - if `nothing` is used, a normally distributed additive error with fixed variance is assumed."
+    # observed::Union{ObservedModel, Nothing} = nothing
+    # "AD Backend"
+    # ad_backend::AD.AbstractBackend = AD.ForwardDiffBackend()
+    # "Optimiser"
+    # optimiser::Optimisers.AbstractRule = ADAM()
 end
 
 Base.print(io::IO, ::Reinforce) = print(io, "Reinforce")
```

lib/DataDrivenLux/src/algorithms/rewards.jl

Lines changed: 1 addition & 3 deletions
```diff
@@ -25,9 +25,7 @@ struct AbsoluteReward{risk} <: AbstractRewardScale{risk} end
 
 AbsoluteReward(risk_seeking = true) = AbsoluteReward{risk_seeking}()
 
-function (::AbsoluteReward)(losses::Vector{T}) where {T <: Number}
-    return exp.(-losses)
-end
+(::AbsoluteReward)(losses::Vector{T}) where {T <: Number} = exp.(-losses)
 
 function (::AbsoluteReward{true})(losses::Vector{T}) where {T <: Number}
     r = exp.(-losses)
```
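The refactor to a one-liner keeps the behavior unchanged: losses map to rewards through `exp(-loss)`, so smaller losses give larger rewards in (0, 1]. A quick numeric check:

```julia
losses = [0.0, 1.0, 2.0]
rewards = exp.(-losses)  # ≈ [1.0, 0.368, 0.135]: lower loss, higher reward
```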

lib/DataDrivenLux/src/caches/cache.jl

Lines changed: 8 additions & 7 deletions
```diff
@@ -32,7 +32,7 @@ end
 
 function init_cache(x::X where {X <: AbstractDAGSRAlgorithm},
         basis::Basis, problem::DataDrivenProblem; kwargs...)
-    (; rng, keep, observed, populationsize, optimizer, optim_options, optimiser, loss) = x
+    (; rng, keep, observed, populationsize, optimizer, optim_options, optimiser, loss) = x.options
     # Derive the model
     dataset = Dataset(problem)
     TData = eltype(dataset)
@@ -75,9 +75,9 @@ function init_cache(x::X where {X <: AbstractDAGSRAlgorithm},
     end
 
     # Distributed always goes first here
-    if x.distributed
+    if x.options.distributed
         ptype = __PROCESSUSE(3)
-    elseif x.threaded
+    elseif x.options.threaded
         ptype = __PROCESSUSE(2)
     else
         ptype = __PROCESSUSE(1)
@@ -94,7 +94,7 @@ function init_cache(x::X where {X <: AbstractDAGSRAlgorithm},
 end
 
 function update_cache!(cache::SearchCache)
-    (; keep, loss, optimizer, optim_options) = cache.alg
+    (; keep, loss) = cache.alg.options
 
     # Update the parameters based on the current results
     update_parameters!(cache)
@@ -109,6 +109,7 @@ function update_cache!(cache::SearchCache)
         cache.keeps[1:keep] .= true
     else
         losses = map(loss, cache.candidates)
+        @. losses = ifelse(isnan(losses), Inf, losses)
         # TODO Maybe weight by age or loss here
         sortperm!(cache.sorting, cache.candidates, by = loss)
         permute!(cache.candidates, cache.sorting)
@@ -123,7 +124,7 @@ end
 
 # Serial
 function optimize_cache!(cache::SearchCache{<:Any, __PROCESSUSE(1)}, p = cache.p)
-    (; optimizer, optim_options) = cache.alg
+    (; optimizer, optim_options) = cache.alg.options
     map(enumerate(cache.candidates)) do (i, candidate)
         if cache.keeps[i]
             cache.ages[i] += 1
@@ -140,7 +141,7 @@ end
 
 # Threaded
 function optimize_cache!(cache::SearchCache{<:Any, __PROCESSUSE(2)}, p = cache.p)
-    (; optimizer, optim_options) = cache.alg
+    (; optimizer, optim_options) = cache.alg.options
     # Update all
     Threads.@threads for i in 1:length(cache.keeps)
         if cache.keeps[i]
@@ -156,7 +157,7 @@ end
 
 # Distributed
 function optimize_cache!(cache::SearchCache{<:Any, __PROCESSUSE(3)}, p = cache.p)
-    (; optimizer, optim_options) = cache.alg
+    (; optimizer, optim_options) = cache.alg.options
 
     successes = pmap(1:length(cache.keeps)) do i
         if cache.keeps[i]
```
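The added line in `update_cache!` is one of the fixes behind the commit message: candidates whose loss evaluates to NaN are mapped to Inf before the candidates are ranked. NaN propagates through reductions such as `extrema` and makes every ordered comparison false, while Inf behaves like an ordinary worst-case loss, so failed candidates compare predictably and always rank behind every finite loss. A minimal sketch:

```julia
losses = [0.3, NaN, 0.1]
@. losses = ifelse(isnan(losses), Inf, losses)  # [0.3, Inf, 0.1]
sort(losses)  # [0.1, 0.3, Inf]: the failed candidate sorts last
```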

lib/DataDrivenLux/src/solve.jl

Lines changed: 5 additions & 5 deletions
```diff
@@ -3,9 +3,9 @@ function DataDrivenDiffEq.get_fit_targets(::A, prob::AbstractDataDrivenProblem,
     return prob.X, DataDrivenDiffEq.get_implicit_data(prob)
 end
 
-struct DataDrivenLuxResult <: DataDrivenDiffEq.AbstractDataDrivenResult
-    candidate::Candidate
-    retcode::DDReturnCode
+@concrete struct DataDrivenLuxResult <: DataDrivenDiffEq.AbstractDataDrivenResult
+    candidate <: Candidate
+    retcode <: DDReturnCode
 end
 
 function CommonSolve.solve!(prob::InternalDataDrivenProblem{A}) where {A <:
@@ -19,7 +19,7 @@ function CommonSolve.solve!(prob::InternalDataDrivenProblem{A}) where {A <:
     _showvalues = let cache = cache
         (iter) -> begin
             shows = min(5, sum(cache.keeps))
-            losses = map(alg.loss, cache.candidates[cache.keeps])
+            losses = map(alg.options.loss, cache.candidates[cache.keeps])
             min_, max_ = extrema(losses)
             [(:Iterations, iter),
              (:RSS, map(StatsBase.rss, cache.candidates[cache.keeps][1:shows])),
@@ -43,7 +43,7 @@ function CommonSolve.solve!(prob::InternalDataDrivenProblem{A}) where {A <:
     end
 
     # Create the optimal basis
-    sort!(cache.candidates, by = alg.loss)
+    sort!(cache.candidates, by = alg.options.loss)
     best_cache = first(cache.candidates)
 
     new_basis = convert_to_basis(best_cache, cache.p, options)
```
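With these changes, `CrossEntropy` runs through the standard DataDrivenDiffEq entry point. A hedged end-to-end sketch, not taken from this commit: the problem, basis, and option names follow DataDrivenDiffEq conventions, and the exact keywords may differ between versions:

```julia
using DataDrivenDiffEq, DataDrivenLux
using ModelingToolkit: @variables
using Random

# Toy task: recover y = sin(x) from samples
rng = Random.Xoshiro(42)
X = randn(rng, 1, 200)
Y = sin.(X)
prob = DirectDataDrivenProblem(X, Y)

@variables x
basis = Basis([x], [x])

alg = CrossEntropy(populationsize = 50, n_layers = 2, rng = rng)
res = solve(prob, basis, alg, options = DataDrivenCommonOptions(maxiters = 50))
```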
