
Commit e3c13c7

fix: other algorithms are now functional 🎉

1 parent e0c9ef4
7 files changed: +49 −98 lines changed

lib/DataDrivenLux/src/DataDrivenLux.jl
Lines changed: 1 addition & 1 deletion

@@ -14,7 +14,7 @@ using ConcreteStructs: @concrete
 using Setfield: Setfield, @set!
 
 using Optim: Optim, LBFGS
-using Optimisers: Optimisers, ADAM
+using Optimisers: Optimisers, Adam
 
 using Lux: Lux, logsoftmax, softmax!
 using LuxCore: LuxCore, AbstractLuxLayer, AbstractLuxWrapperLayer
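For context: the one-line change above tracks Optimisers.jl, which renamed the ADAM rule to Adam in its current releases. A minimal, illustrative sketch of the rule as used through the Optimisers.jl API (values are placeholders, not from this commit):

using Optimisers

rule = Optimisers.Adam(1.0f-3)          # current name; older releases exported ADAM
model = (; w = zeros(Float32, 2))       # any Functors-compatible container works
state = Optimisers.setup(rule, model)   # builds per-leaf optimiser state
grad = (; w = ones(Float32, 2))
state, model = Optimisers.update(state, model, grad)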
lib/DataDrivenLux/src/algorithms/randomsearch.jl
Lines changed: 17 additions & 42 deletions

@@ -1,51 +1,26 @@
-"""
-$(TYPEDEF)
+@concrete struct RandomSearch <: AbstractDAGSRAlgorithm
+    options <: CommonAlgOptions
+end
 
-Performs a random search over the space of possible solutions to the
-symbolic regression problem.
+"""
+$(SIGNATURES)
 
-# Fields
-$(FIELDS)
+Performs a random search over the space of possible solutions to the symbolic regression
+problem.
 """
-@kwdef struct RandomSearch{F, A, L, O} <: AbstractDAGSRAlgorithm
-    # "The number of candidates to track"
-    # populationsize::Int = 100
-    # "The functions to include in the search"
-    # functions::F = (sin, exp, cos, log, +, -, /, *)
-    # "The arities of the functions"
-    # arities::A = (1, 1, 1, 1, 2, 2, 2, 2)
-    # "The number of layers"
-    # n_layers::Int = 1
-    # "Include skip layers"
-    # skip::Bool = true
-    # "Simplex mapping"
-    # simplex::AbstractSimplex = Softmax()
-    # "Evaluation function to sort the samples"
-    # loss::L = aicc
-    # "The number of candidates to keep in each iteration"
-    # keep::Union{Real, Int} = 0.1
-    # "Use protected operators"
-    # use_protected::Bool = true
-    # "Use distributed optimization and resampling"
-    # distributed::Bool = false
-    # "Use threaded optimization and resampling - not implemented right now."
-    # threaded::Bool = false
-    # "Random seed"
-    # rng::AbstractRNG = Random.default_rng()
-    # "Optim optimiser"
-    # optimizer::O = LBFGS()
-    # "Optim options"
-    # optim_options::Optim.Options = Optim.Options()
-    # "Observed model - if `nothing` is used, a normal distributed additive error with fixed variance is assumed."
-    # observed::Union{ObservedModel, Nothing} = nothing
-    # "Field for possible optimiser - no use for Randomsearch"
-    # optimiser::Nothing = nothing
+function RandomSearch(; populationsize = 100, functions = (sin, exp, cos, log, +, -, /, *),
+        arities = (1, 1, 1, 1, 2, 2, 2, 2), n_layers = 1, skip = true, loss = aicc,
+        keep = 0.1, use_protected = true, distributed = false, threaded = false,
+        rng = Random.default_rng(), optimizer = LBFGS(), optim_options = Optim.Options(),
+        observed = nothing, alpha = 0.999f0)
+    return RandomSearch(CommonAlgOptions(;
+        populationsize, functions, arities, n_layers, skip, simplex = Softmax(), loss,
+        keep, use_protected, distributed, threaded, rng, optimizer,
+        optim_options, optimiser = nothing, observed, alpha))
 end
 
 Base.print(io::IO, ::RandomSearch) = print(io, "RandomSearch")
 Base.summary(io::IO, x::RandomSearch) = print(io, x)
 
 # Randomsearch does not do anything
-function update_parameters!(::SearchCache)
-    return
-end
+update_parameters!(::SearchCache) = nothing
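A minimal sketch of how the refactored constructor might be called, based only on the keyword list in the diff above (argument values are illustrative):

using DataDrivenLux

# All hyperparameters now funnel into a single CommonAlgOptions field.
alg = RandomSearch(; populationsize = 200, n_layers = 2, keep = 0.2)
alg.options.populationsize   # == 200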

lib/DataDrivenLux/src/algorithms/reinforce.jl
Lines changed: 21 additions & 46 deletions

@@ -1,60 +1,35 @@
+@concrete struct Reinforce <: AbstractDAGSRAlgorithm
+    reward
+    ad_backend <: AD.AbstractBackend
+    options <: CommonAlgOptions
+end
+
 """
-$(TYPEDEF)
+$(SIGNATURES)
 
-Uses the REINFORCE algorithm to search over the space of possible solutions to the
+Uses the REINFORCE algorithm to search over the space of possible solutions to the
 symbolic regression problem.
-
-# Fields
-$(FIELDS)
 """
-@kwdef struct Reinforce{F, A, L, O, R} <: AbstractDAGSRAlgorithm
-    "Reward function which should convert the loss to a reward."
-    reward::R = RelativeReward(false)
-    # "The number of candidates to track"
-    # populationsize::Int = 100
-    # "The functions to include in the search"
-    # functions::F = (sin, exp, cos, log, +, -, /, *)
-    # "The arities of the functions"
-    # arities::A = (1, 1, 1, 1, 2, 2, 2, 2)
-    # "The number of layers"
-    # n_layers::Int = 1
-    # "Include skip layers"
-    # skip::Bool = true
-    # "Simplex mapping"
-    # simplex::AbstractSimplex = Softmax()
-    # "Evaluation function to sort the samples"
-    # loss::L = aicc
-    # "The number of candidates to keep in each iteration"
-    # keep::Union{Real, Int} = 0.1
-    # "Use protected operators"
-    # use_protected::Bool = true
-    # "Use distributed optimization and resampling"
-    # distributed::Bool = false
-    # "Use threaded optimization and resampling - not implemented right now."
-    # threaded::Bool = false
-    # "Random seed"
-    # rng::AbstractRNG = Random.default_rng()
-    # "Optim optimiser"
-    # optimizer::O = LBFGS()
-    # "Optim options"
-    # optim_options::Optim.Options = Optim.Options()
-    # "Observed model - if `nothing` is used, a normal distributed additive error with fixed variance is assumed."
-    # observed::Union{ObservedModel, Nothing} = nothing
-    # "AD Backend"
-    # ad_backend::AD.AbstractBackend = AD.ForwardDiffBackend()
-    # "Optimiser"
-    # optimiser::Optimisers.AbstractRule = ADAM()
+function Reinforce(reward = RelativeReward(false); populationsize = 100,
+        functions = (sin, exp, cos, log, +, -, /, *), arities = (1, 1, 1, 1, 2, 2, 2, 2),
+        n_layers = 1, skip = true, loss = aicc, keep = 0.1, use_protected = true,
+        distributed = false, threaded = false, rng = Random.default_rng(),
+        optimizer = LBFGS(), optim_options = Optim.Options(), observed = nothing,
+        alpha = 0.999f0, optimiser = Adam(), ad_backend = AD.ForwardDiffBackend())
+    return Reinforce(reward, ad_backend, CommonAlgOptions(;
+        populationsize, functions, arities, n_layers, skip, simplex = Softmax(), loss,
+        keep, use_protected, distributed, threaded, rng, optimizer,
+        optim_options, optimiser, observed, alpha))
 end
 
 Base.print(io::IO, ::Reinforce) = print(io, "Reinforce")
 Base.summary(io::IO, x::Reinforce) = print(io, x)
 
 function reinforce_loss(candidates, p, alg)
-    (; loss, reward) = alg
-    losses = map(loss, candidates)
-    rewards = reward(losses)
+    losses = map(alg.options.loss, candidates)
+    rewards = alg.reward(losses)
     # ∇U(θ) = E[∇log(p)*R(t)]
-    mean(map(enumerate(candidates)) do (i, candidate)
+    return mean(map(enumerate(candidates)) do (i, candidate)
         return rewards[i] * -candidate(p)
     end)
 end
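The comment ∇U(θ) = E[∇log(p)*R(t)] in reinforce_loss is the score-function (REINFORCE) identity: the gradient of the expected reward equals the expectation of the reward-weighted gradient of the log-probability. A self-contained toy sketch of that estimator in plain Julia (all names here are illustrative and independent of DataDrivenLux):

using Statistics

# Bernoulli policy with logit θ: p(a = 1) = σ(θ).
logp_grad(θ, a) = a - 1 / (1 + exp(-θ))   # ∂/∂θ of log p(a; θ) for a ∈ {0, 1}

# Monte Carlo REINFORCE estimate: mean of R(a) * ∇log p(a; θ).
reinforce_grad(θ, actions, rewards) = mean(rewards .* logp_grad.(θ, actions))

actions = rand(0:1, 1_000)
rewards = Float64.(actions)               # toy reward: action 1 pays 1, action 0 pays 0
reinforce_grad(0.0, actions, rewards)     # ≈ 0.25, i.e. increase θ to favour action 1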

lib/DataDrivenLux/src/caches/cache.jl
Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@ end
 Base.show(io::IO, cache::SearchCache) = print(io, "SearchCache : $(cache.alg)")
 
 function init_model(x::AbstractDAGSRAlgorithm, basis::Basis, dataset::Dataset, intervals)
-    (; simplex, n_layers, arities, functions, use_protected, skip) = x
+    (; simplex, n_layers, arities, functions, use_protected, skip) = x.options
 
     # Get the parameter mapping
     variable_mask = map(enumerate(equations(basis))) do (i, eq)
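The fix reads the settings from x.options, since they now live in CommonAlgOptions rather than on the algorithm struct itself. For reference, a minimal sketch of the property-destructuring syntax used on that line (plain Julia ≥ 1.7, illustrative names):

opts = (; n_layers = 1, skip = true)   # any object with matching properties works
(; n_layers, skip) = opts              # binds locals n_layers = 1 and skip = true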

lib/DataDrivenLux/test/randomsearch_solve.jl
Lines changed: 2 additions & 2 deletions

@@ -27,9 +27,9 @@ dummy_dataset = DataDrivenLux.Dataset(dummy_problem)
 
 @test isempty(dummy_dataset.u_intervals)
 
-for (data, interval) in zip((X, Y, 1:size(X, 2)),
+for (data, _interval) in zip((X, Y, 1:size(X, 2)),
     (dummy_dataset.x_intervals[1], dummy_dataset.y_intervals[1], dummy_dataset.t_interval))
-    @test (interval.lo, interval.hi) == extrema(data)
+    @test isequal_interval(_interval, interval(extrema(data)))
 end
 
 # We have 1 Choices in the first layer, 2 in the last
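The updated test compares intervals via IntervalArithmetic.jl instead of reading the .lo/.hi fields directly. A minimal sketch of that pattern (assuming a recent IntervalArithmetic.jl; the exact version bound is not stated in this diff):

using IntervalArithmetic

data = [0.0, 0.5, 1.0]
a = interval(extrema(data)...)            # interval(0.0, 1.0)
isequal_interval(a, interval(0.0, 1.0))   # true; equality check for intervals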

lib/DataDrivenLux/test/reinforce_solve.jl
Lines changed: 1 addition & 0 deletions

@@ -6,6 +6,7 @@ using Random
 using Distributions
 using Test
 using Optimisers
+using Optim
 using StableRNGs
 
 rng = StableRNG(1234)

lib/DataDrivenLux/test/runtests.jl
Lines changed: 6 additions & 6 deletions

@@ -9,21 +9,21 @@ const GROUP = get(ENV, "GROUP", "All")
 
 @time begin
     if GROUP == "All" || GROUP == "DataDrivenLux"
-        @safetestset "Lux" begin
+        @testset "Lux" begin
             @safetestset "Nodes" include("nodes.jl")
             @safetestset "Layers" include("layers.jl")
             @safetestset "Graphs" include("graphs.jl")
         end
 
-        @safetestset "Caches" begin
+        @testset "Caches" begin
             @safetestset "Candidate" include("candidate.jl") # FIXME
             @safetestset "Cache" include("cache.jl")
         end
 
-        @safetestset "Algorithms" begin
-            @safetestset "RandomSearch" include("randomsearch_solve.jl") # FIXME
-            @safetestset "Reinforce" include("reinforce_solve.jl") # FIXME
-            @safetestset "CrossEntropy" include("crossentropy_solve.jl") # FIXME
+        @testset "Algorithms" begin
+            @safetestset "RandomSearch" include("randomsearch_solve.jl")
+            @safetestset "Reinforce" include("reinforce_solve.jl")
+            @safetestset "CrossEntropy" include("crossentropy_solve.jl")
         end
     end
 end
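The group-level blocks switch from @safetestset to @testset: @safetestset (from SafeTestsets.jl) evaluates its body in a fresh module, which suits file-level isolation, while plain @testset is sufficient for grouping those isolated runs. A minimal sketch of the pattern (hypothetical file name):

using Test, SafeTestsets

@testset "Group" begin
    # The included file runs inside a new, isolated module.
    @safetestset "Leaf" include("leaf_tests.jl")
end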
