Skip to content

Commit c373425

Browse files
committed
test: more tests are now working
1 parent 9af1a5f commit c373425

File tree

8 files changed

+92
-129
lines changed

8 files changed

+92
-129
lines changed

lib/DataDrivenLux/src/DataDrivenLux.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ using CommonSolve: CommonSolve, solve!
1313
using ConcreteStructs: @concrete
1414
using Setfield: Setfield, @set!
1515

16+
# TODO: Get rid of Optim and Optimisers in favor of Optimization.jl
1617
using Optim: Optim, LBFGS
1718
using Optimisers: Optimisers, ADAM
1819

@@ -64,17 +65,20 @@ export AdditiveError, MultiplicativeError
6465
export ObservedModel
6566

6667
# Simplex
67-
include("./lux/simplex.jl")
68+
include("lux/simplex.jl")
6869
export Softmax, GumbelSoftmax, DirectSimplex
6970

7071
# Nodes and Layers
71-
include("./lux/path_state.jl")
72+
include("lux/path_state.jl")
7273
export PathState
73-
include("./lux/node.jl")
74+
75+
include("lux/node.jl")
7476
export FunctionNode
75-
include("./lux/layer.jl")
77+
78+
include("lux/layer.jl")
7679
export FunctionLayer
77-
include("./lux/graph.jl")
80+
81+
include("lux/graph.jl")
7882
export LayeredDAG
7983

8084
include("caches/dataset.jl")

lib/DataDrivenLux/src/caches/cache.jl

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,7 @@ struct SearchCache{ALG, PTYPE, O} <: AbstractAlgorithmCache
99
optimiser_state::O
1010
end
1111

12-
function Base.show(io::IO, cache::SearchCache)
13-
print(io, "SearchCache : $(cache.alg)")
14-
return
15-
end
12+
Base.show(io::IO, cache::SearchCache) = print(io, "SearchCache : $(cache.alg)")
1613

1714
function init_model(x::AbstractDAGSRAlgorithm, basis::Basis, dataset::Dataset, intervals)
1815
(; simplex, n_layers, arities, functions, use_protected, skip) = x
@@ -116,7 +113,7 @@ function update_cache!(cache::SearchCache)
116113
sortperm!(cache.sorting, cache.candidates, by = loss)
117114
permute!(cache.candidates, cache.sorting)
118115
loss_quantile = quantile(losses, keep, sorted = true)
119-
cache.keeps .= (losses .<= loss_quantile)
116+
@. cache.keeps = losses ≤ loss_quantile
120117
end
121118

122119
return
@@ -158,7 +155,6 @@ function optimize_cache!(cache::SearchCache{<:Any, __PROCESSUSE(2)}, p = cache.p
158155
end
159156

160157
# Distributed
161-
162158
function optimize_cache!(cache::SearchCache{<:Any, __PROCESSUSE(3)}, p = cache.p)
163159
(; optimizer, optim_options) = cache.alg
164160

@@ -176,4 +172,4 @@ function optimize_cache!(cache::SearchCache{<:Any, __PROCESSUSE(3)}, p = cache.p
176172
return
177173
end
178174

179-
function convert_to_basis(cache::SearchCache) end
175+
function convert_to_basis(::SearchCache) end

lib/DataDrivenLux/src/caches/candidate.jl

Lines changed: 36 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,16 @@ StatsBase.nullloglikelihood(stats::PathStatistics) = getfield(stats, :nullloglik
2222
StatsBase.dof(stats::PathStatistics) = getfield(stats, :dof)
2323
StatsBase.r2(c::PathStatistics) = r2(c, :CoxSnell)
2424

25-
struct ComponentModel{B, M}
26-
basis::B
27-
model::M
25+
@concrete struct ComponentModel
26+
basis
27+
model
2828
end
2929

30-
function (c::ComponentModel)(dataset::Dataset{T}, ps, st::NamedTuple{fieldnames},
31-
p::AbstractVector{T}) where {T, fieldnames}
30+
function (c::ComponentModel)(dataset::Dataset{T}, ps, st::NamedTuple,
31+
p::AbstractVector{T}) where {T}
3232
return first(c.model(c.basis(dataset, p), ps, st))
3333
end
34-
function (c::ComponentModel)(ps, st::NamedTuple{fieldnames},
35-
paths::Vector{<:AbstractPathState}) where {fieldnames}
34+
function (c::ComponentModel)(ps, st::NamedTuple, paths::Vector{<:AbstractPathState})
3635
return get_loglikelihood(c.model, ps, st, paths)
3736
end
3837

@@ -45,29 +44,29 @@ to the symbolic regression problem.
4544
# Fields
4645
$(FIELDS)
4746
"""
48-
struct Candidate{S <: NamedTuple} <: StatsBase.StatisticalModel
47+
@concrete struct Candidate <: StatsBase.StatisticalModel
4948
"Random seed"
50-
rng::AbstractRNG
49+
rng <: AbstractRNG
5150
"The current state"
52-
st::S
51+
st <: NamedTuple
5352
"The current parameters"
54-
ps::AbstractVector
53+
ps <: AbstractVector
5554
"Incoming paths"
56-
incoming_path::Vector{AbstractPathState}
55+
incoming_path <: Vector{<:AbstractPathState}
5756
"Outgoing path"
58-
outgoing_path::Vector{AbstractPathState}
57+
outgoing_path <: Vector{<:AbstractPathState}
5958
"Statistics"
60-
statistics::PathStatistics
59+
statistics <: PathStatistics
6160
"The observed model"
62-
observed::ObservedModel
61+
observed <: ObservedModel
6362
"The parameter distribution"
64-
parameterdist::ParameterDistributions
63+
parameterdist <: ParameterDistributions
6564
"The optimal scales"
66-
scales::AbstractVector
65+
scales <: AbstractVector
6766
"The optimal parameters"
68-
parameters::AbstractVector
67+
parameters <: AbstractVector
6968
"The component model"
70-
model::ComponentModel
69+
model <: ComponentModel
7170
end
7271

7372
function (c::Candidate)(dataset::Dataset{T}, ps = c.ps, p = c.parameters) where {T}
@@ -89,12 +88,9 @@ StatsBase.r2(c::Candidate) = r2(c, :CoxSnell)
8988
get_parameters(c::Candidate) = transform_parameter(c.parameterdist, c.parameters)
9089
get_scales(c::Candidate) = transform_scales(c.observed, c.scales)
9190

92-
function Candidate(rng, model, basis, dataset; observed = ObservedModel(dataset.y),
93-
parameterdist = ParameterDistributions(basis), ptype = Float32)
94-
(; y, x) = dataset
95-
96-
T = eltype(dataset)
97-
91+
function Candidate(
92+
rng, model, basis, dataset::Dataset{T}; observed = ObservedModel(dataset.y),
93+
parameterdist = ParameterDistributions(basis), ptype = Float32) where {T}
9894
# Create the initial state and path
9995
dataset_intervals = interval_eval(basis, dataset, get_interval(parameterdist))
10096

@@ -110,21 +106,21 @@ function Candidate(rng, model, basis, dataset; observed = ObservedModel(dataset.
110106

111107
ŷ, _ = model(basis(dataset, transform_parameter(parameterdist, parameters)), ps, st)
112108

113-
lls = logpdf(observed, y, ŷ, scales)
109+
lls = logpdf(observed, dataset.y, ŷ, scales)
114110
lls += logpdf(parameterdist, parameters)
115111

116-
rss = sum(abs2, y .- ŷ)
112+
rss = sum(abs2, dataset.y .- ŷ)
117113
dof_ = get_dof(outgoing_path)
118114

119-
ȳ = vec(mean(y, dims = 2))
115+
ȳ = vec(mean(dataset.y; dims = 2))
120116

121-
null_ll = logpdf(observed, y, ȳ, scales) + logpdf(parameterdist, parameters)
117+
null_ll = logpdf(observed, dataset.y, ȳ, scales) + logpdf(parameterdist, parameters)
122118

123-
stats = PathStatistics(rss, lls, null_ll, dof_, prod(size(y)))
119+
stats = PathStatistics(rss, lls, null_ll, dof_, prod(size(dataset.y)))
124120

125-
return Candidate{typeof(st)}(
126-
Lux.replicate(rng), st, ComponentVector(ps), incoming_path, outgoing_path, stats,
127-
observed, parameterdist, scales, parameters, ComponentModel(basis, model))
121+
return Candidate(Lux.replicate(rng), st, ComponentVector(ps), incoming_path,
122+
outgoing_path, stats, observed, parameterdist, scales, parameters,
123+
ComponentModel(basis, model))
128124
end
129125

130126
function update_values!(c::Candidate, ps, dataset)
@@ -136,34 +132,24 @@ function update_values!(c::Candidate, ps, dataset)
136132
dataloglikelihood = logpdf(observed, y, ŷ, scales) + logpdf(parameterdist, parameters)
137133
rss = sum(abs2, y .- ŷ)
138134
dof = get_dof(outgoing_path)
139-
ȳ = vec(mean(y, dims = 2))
135+
ȳ = vec(mean(y; dims = 2))
140136
nullloglikelihood = logpdf(observed, y, ȳ, scales) + logpdf(parameterdist, parameters)
141137
update_stats!(statistics, rss, dataloglikelihood, nullloglikelihood, dof)
142138
return
143139
end
144140

145141
@views function Distributions.logpdf(
146142
c::Candidate, p::ComponentVector, dataset::Dataset{T}, ps = c.ps) where {T}
147-
(; observed, parameterdist) = c
148-
(; scales, parameters) = p
149-
(; y) = dataset
150-
151-
ŷ = c(dataset, ps, parameters)
152-
return logpdf(c, p, y, ŷ)
143+
ŷ = c(dataset, ps, p.parameters)
144+
return logpdf(c, p, dataset.y, ŷ)
153145
end
154146

155147
function Distributions.logpdf(c::Candidate, p::AbstractVector, y::AbstractMatrix{T},
156148
ŷ::AbstractMatrix{T}) where {T}
157-
(; scales, parameters) = p
158-
(; observed, parameterdist) = c
159-
160-
return logpdf(observed, y, ŷ, scales) + logpdf(parameterdist, parameters)
149+
return logpdf(c.observed, y, ŷ, p.scales) + logpdf(c.parameterdist, p.parameters)
161150
end
162151

163-
function initial_values(c::Candidate)
164-
(; scales, parameters) = c
165-
return ComponentVector((; scales = scales, parameters = parameters))
166-
end
152+
initial_values(c::Candidate) = ComponentVector(; c.scales, c.parameters)
167153

168154
function optimize_candidate!(
169155
c::Candidate, dataset::Dataset{T}, ps = c.ps; optimizer = Optim.LBFGS(),
@@ -195,16 +181,10 @@ function optimize_candidate!(
195181
return
196182
end
197183

198-
function check_intervals(paths::AbstractArray{<:AbstractPathState})::Bool
199-
@inbounds for path in paths
200-
check_intervals(path) || return false
201-
end
202-
return true
203-
end
184+
check_intervals(paths::AbstractArray{<:AbstractPathState}) = all(check_intervals, paths)
204185

205186
function sample(c::Candidate, ps, i = 0, max_sample = 10)
206-
(; incoming_path, st) = c
207-
return sample(c.model.model, incoming_path, ps, st, i, max_sample)
187+
return sample(c.model.model, c.incoming_path, ps, c.st, i, max_sample)
208188
end
209189

210190
function sample(model, incoming, ps, st, i = 0, max_sample = 10)
Lines changed: 21 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
struct Dataset{T}
2-
x::AbstractMatrix{T}
3-
y::AbstractMatrix{T}
4-
u::AbstractMatrix{T}
5-
t::AbstractVector{T}
6-
x_intervals::AbstractVector{Interval{T}}
7-
y_intervals::AbstractVector{Interval{T}}
8-
u_intervals::AbstractVector{Interval{T}}
9-
t_interval::Interval{T}
1+
@concrete struct Dataset{T}
2+
x <: AbstractMatrix{T}
3+
y <: AbstractMatrix{T}
4+
u <: AbstractMatrix{T}
5+
t <: AbstractVector{T}
6+
x_intervals <: AbstractVector{Interval{T}}
7+
y_intervals <: AbstractVector{Interval{T}}
8+
u_intervals <: AbstractVector{Interval{T}}
9+
t_interval <: Interval{T}
1010
end
1111

1212
Base.eltype(::Dataset{T}) where {T} = T
@@ -20,10 +20,10 @@ function Dataset(X::AbstractMatrix, Y::AbstractMatrix,
2020
U = convert.(T, U)
2121
t = convert.(T, t)
2222
t = isempty(t) ? convert.(T, LinRange(0, size(Y, 2) - 1, size(Y, 2))) : convert.(T, t)
23-
x_intervals = Interval.(map(extrema, eachrow(X)))
24-
y_intervals = Interval.(map(extrema, eachrow(Y)))
25-
u_intervals = Interval.(map(extrema, eachrow(U)))
26-
t_intervals = isempty(t) ? Interval{T}(zero(T), zero(T)) : Interval(extrema(t))
23+
x_intervals = interval.(map(extrema, eachrow(X)))
24+
y_intervals = interval.(map(extrema, eachrow(Y)))
25+
u_intervals = interval.(map(extrema, eachrow(U)))
26+
t_intervals = isempty(t) ? Interval{T}(zero(T), zero(T)) : interval(extrema(t))
2727
return Dataset{T}(X, Y, U, t, x_intervals, y_intervals, u_intervals, t_intervals)
2828
end
2929

@@ -35,50 +35,40 @@ end
3535

3636
function (b::Basis{false, false})(d::Dataset{T}, p::P) where {T, P}
3737
f = DataDrivenDiffEq.get_f(b)
38-
(; x, t) = d
39-
return f(x, p, t)
38+
return f(d.x, p, d.t)
4039
end
4140

4241
function (b::Basis{false, true})(d::Dataset{T}, p::P) where {T, P}
4342
f = DataDrivenDiffEq.get_f(b)
44-
(; x, t, u) = d
45-
return f(x, p, t, u)
43+
return f(d.x, p, d.t, d.u)
4644
end
4745

4846
function (b::Basis{true, false})(d::Dataset{T}, p::P) where {T, P}
4947
f = DataDrivenDiffEq.get_f(b)
50-
(; y, x, t) = d
51-
return f(y, x, p, t)
48+
return f(d.y, d.x, p, d.t)
5249
end
5350

5451
function (b::Basis{true, true})(d::Dataset{T}, p::P) where {T, P}
5552
f = DataDrivenDiffEq.get_f(b)
56-
(; y, x, t, u) = d
57-
return f(y, x, p, t, u)
53+
return f(d.y, d.x, p, d.t, d.u)
5854
end
5955

60-
##
61-
6256
function interval_eval(b::Basis{false, false}, d::Dataset{T}, p::P) where {T, P}
6357
f = DataDrivenDiffEq.get_f(b)
64-
(; x_intervals, t_interval) = d
65-
return f(x_intervals, p, t_interval)
58+
return f(d.x_intervals, p, d.t_interval)
6659
end
6760

6861
function interval_eval(b::Basis{false, true}, d::Dataset{T}, p::P) where {T, P}
6962
f = DataDrivenDiffEq.get_f(b)
70-
(; x_intervals, t_interval, u_intervals) = d
71-
return f(x_intervals, p, t_interval, u_intervals)
63+
return f(d.x_intervals, p, d.t_interval, d.u_intervals)
7264
end
7365

7466
function interval_eval(b::Basis{true, false}, d::Dataset{T}, p::P) where {T, P}
7567
f = DataDrivenDiffEq.get_f(b)
76-
(; y_intervals, x_intervals, t_interval) = d
77-
return f(y_intervals, x_intervals, p, t_interval)
68+
return f(d.y_intervals, d.x_intervals, p, d.t_interval)
7869
end
7970

8071
function interval_eval(b::Basis{true, true}, d::Dataset{T}, p::P) where {T, P}
8172
f = DataDrivenDiffEq.get_f(b)
82-
(; y_intervals, x_intervals, t_interval, u_intervals) = d
83-
return f(y_intervals, x_intervals, p, t_interval, u_intervals)
73+
return f(d.y_intervals, d.x_intervals, p, d.t_interval, d.u_intervals)
8474
end

lib/DataDrivenLux/src/custom_priors.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ function ParameterDistribution(
147147
upper_t = isinf(upper) ? TransformVariables.∞ : upper
148148
transform = as(Real, lower_t, upper_t)
149149
init = convert.(T, TransformVariables.inverse(transform, init))
150-
return ParameterDistribution(d, Interval(lower, upper), transform, init)
150+
return ParameterDistribution(d, interval(lower, upper), transform, init)
151151
end
152152

153153
function Base.summary(io::IO, p::ParameterDistribution)

lib/DataDrivenLux/src/lux/path_state.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,21 @@
11
abstract type AbstractPathState end
22

3-
@concrete struct PathState{T} <: AbstractPathState
3+
struct PathState{T, PO <: Tuple, PI <: Tuple} <: AbstractPathState
44
"Accumulated loglikelihood of the state"
55
path_interval::Interval{T}
66
"All the operators of the path"
7-
path_operators <: Tuple
7+
path_operators::PO
88
"The unique identifier of nodes in the path"
9-
path_ids <: Tuple
9+
path_ids::PI
10+
11+
function PathState{T}(
12+
interval::Interval{T}, path_operators::PO, path_ids::PI) where {T, PO, PI}
13+
return new{T, PO, PI}(interval, path_operators, path_ids)
14+
end
15+
function PathState{T}(
16+
interval::Interval, path_operators::PO, path_ids::PI) where {T, PO, PI}
17+
return new{T, PO, PI}(Interval{T}(interval), path_operators, path_ids)
18+
end
1019
end
1120

1221
function PathState(interval::Interval{T}, id::Tuple{Int, Int} = (1, 1)) where {T}

lib/DataDrivenLux/test/candidate.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@ using StableRNGs
2626

2727
@test DataDrivenLux.get_scales(candidate) ≈ ones(Float64, 1)
2828
@test isempty(DataDrivenLux.get_parameters(candidate))
29-
@test_nowarn DataDrivenLux.optimize_candidate!(
30-
candidate, dataset; options = Optim.Options())
29+
@test_nowarn DataDrivenLux.optimize_candidate!(candidate, dataset)
3130
end
3231

3332
@testset "Candidate with parametes" begin

0 commit comments

Comments
 (0)