Skip to content

Commit 986259b

Browse files
Merge pull request #522 from SciML/ap/lux
feat: update to support Lux 1.0
2 parents 08d630d + c2e033e commit 986259b

31 files changed

+556
-741
lines changed

.JuliaFormatter.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
style = "sciml"
22
format_markdown = true
3-
format_docstrings = true
3+
format_docstrings = true
4+
annotate_untyped_fields_with_any = false

.github/workflows/CompatHelper.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,4 @@ jobs:
2323
- name: CompatHelper.main()
2424
env:
2525
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
26-
run: julia -e 'using CompatHelper; CompatHelper.main(;subdirs=["", "docs", "lib/DataDrivenDMD", "lib/DataDrivenSparse", "lib/DataDrivenSR"])'
26+
run: julia -e 'using CompatHelper; CompatHelper.main(;subdirs=["", "docs", "lib/DataDrivenDMD", "lib/DataDrivenSparse", "lib/DataDrivenSR", "lib/DataDrivenLux"])'

lib/DataDrivenDMD/src/solve.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,8 @@ end
9595

9696
function (algorithm::AbstractKoopmanAlgorithm)(prob::InternalDataDrivenProblem;
9797
control_input = nothing, kwargs...)
98-
@unpack traindata, testdata, control_idx, options = prob
99-
@unpack abstol = options
98+
(; traindata, testdata, control_idx, options) = prob
99+
(; abstol) = options
100100
# Preprocess control idx, indicates if any control is active in a single basis atom
101101
control_idx = map(any, eachrow(control_idx))
102102
no_controls = .!control_idx

lib/DataDrivenLux/Project.toml

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,58 @@
11
name = "DataDrivenLux"
22
uuid = "47881146-99d0-492a-8425-8f2f33327637"
33
authors = ["JuliusMartensen <julius.martensen@gmail.com>"]
4-
version = "0.1.1"
4+
version = "0.2.0"
55

66
[deps]
77
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
88
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
9+
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
910
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
11+
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
1012
DataDrivenDiffEq = "2445eb08-9709-466a-b3fc-47e12bd697a2"
1113
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
1214
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1315
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
16+
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1417
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1518
IntervalArithmetic = "d1acc4aa-44c8-5952-acd4-ba5d80a2a253"
1619
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
1720
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1821
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
1922
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
20-
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
23+
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
2124
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
2225
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
2326
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
2427
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
25-
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
28+
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
29+
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
2630
TransformVariables = "84d833dd-6860-57f9-a1a7-6da5db126cff"
31+
WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
2732

2833
[compat]
29-
AbstractDifferentiation = "0.4"
34+
AbstractDifferentiation = "0.6"
3035
ChainRulesCore = "1.15"
31-
ComponentArrays = "0.13"
36+
CommonSolve = "0.2.4"
37+
ComponentArrays = "0.15"
38+
ConcreteStructs = "0.2.3"
3239
DataDrivenDiffEq = "1"
3340
Distributions = "0.25"
3441
DistributionsAD = "0.6"
42+
DocStringExtensions = "0.9.3"
3543
ForwardDiff = "0.10"
36-
IntervalArithmetic = "0.20"
44+
IntervalArithmetic = "0.22"
3745
InverseFunctions = "0.1"
38-
Lux = "0.4"
39-
NNlib = "0.8"
46+
Lux = "1"
47+
LuxCore = "1"
4048
Optim = "1.7"
41-
Optimisers = "0.2"
49+
Optimisers = "0.3"
4250
ProgressMeter = "1.7"
43-
Reexport = "1.2"
44-
TransformVariables = "0.7"
45-
julia = "1.6"
51+
Setfield = "1"
52+
StatsBase = "0.34.3"
53+
TransformVariables = "0.8"
54+
WeightInitializers = "1"
55+
julia = "1.10"
4656

4757
[extras]
4858
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"

lib/DataDrivenLux/src/DataDrivenLux.jl

Lines changed: 48 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -3,41 +3,43 @@ module DataDrivenLux
33
using DataDrivenDiffEq
44

55
# Load specific (abstract) types
6-
using DataDrivenDiffEq: AbstractBasis
7-
using DataDrivenDiffEq: AbstractDataDrivenAlgorithm
8-
using DataDrivenDiffEq: AbstractDataDrivenResult
9-
using DataDrivenDiffEq: AbstractDataDrivenProblem
10-
using DataDrivenDiffEq: DDReturnCode, ABSTRACT_CONT_PROB, ABSTRACT_DISCRETE_PROB
11-
using DataDrivenDiffEq: InternalDataDrivenProblem
12-
using DataDrivenDiffEq: is_implicit, is_controlled
13-
14-
using DataDrivenDiffEq.DocStringExtensions
15-
using DataDrivenDiffEq.CommonSolve
16-
using DataDrivenDiffEq.CommonSolve: solve!
17-
using DataDrivenDiffEq.StatsBase
18-
using DataDrivenDiffEq.Parameters
19-
using DataDrivenDiffEq.Setfield
20-
21-
using Reexport
22-
@reexport using Optim
23-
using Lux
24-
25-
using InverseFunctions
26-
using TransformVariables
27-
using NNlib
28-
using Distributions
29-
using DistributionsAD
30-
31-
using ChainRulesCore
32-
using ComponentArrays
33-
34-
using IntervalArithmetic
35-
using Random
36-
using Distributed
37-
using ProgressMeter
38-
using Logging
39-
using AbstractDifferentiation, ForwardDiff
40-
using Optimisers
6+
using DataDrivenDiffEq: AbstractBasis, AbstractDataDrivenAlgorithm,
7+
AbstractDataDrivenResult, AbstractDataDrivenProblem, DDReturnCode,
8+
ABSTRACT_CONT_PROB, ABSTRACT_DISCRETE_PROB,
9+
InternalDataDrivenProblem, is_implicit, is_controlled
10+
11+
using DocStringExtensions: DocStringExtensions, FIELDS, TYPEDEF, SIGNATURES
12+
using CommonSolve: CommonSolve, solve!
13+
using ConcreteStructs: @concrete
14+
using Setfield: Setfield, @set!
15+
16+
using Optim: Optim, LBFGS
17+
using Optimisers: Optimisers, Adam
18+
19+
using Lux: Lux, logsoftmax, softmax!
20+
using LuxCore: LuxCore, AbstractLuxLayer, AbstractLuxWrapperLayer
21+
using WeightInitializers: WeightInitializers, ones32, zeros32
22+
23+
using InverseFunctions: InverseFunctions, NoInverse
24+
using TransformVariables: TransformVariables, as, transform_logdensity
25+
using Distributions: Distributions, Distribution, Normal, Uniform, Univariate, dof,
26+
loglikelihood, logpdf, mean, mode, quantile, scale, truncated
27+
using DistributionsAD: DistributionsAD
28+
using StatsBase: StatsBase, aicc, nobs, nullloglikelihood, r2, rss, sum, weights
29+
30+
using ChainRulesCore: @ignore_derivatives
31+
using ComponentArrays: ComponentArrays, ComponentVector
32+
33+
using IntervalArithmetic: IntervalArithmetic, Interval, interval, isempty
34+
using ProgressMeter: ProgressMeter
35+
using AbstractDifferentiation: AbstractDifferentiation
36+
using ForwardDiff: ForwardDiff
37+
38+
using Logging: Logging, NullLogger, with_logger
39+
using Random: Random, AbstractRNG
40+
using Distributed: Distributed, pmap
41+
42+
const AD = AbstractDifferentiation
4143

4244
abstract type AbstractAlgorithmCache <: AbstractDataDrivenResult end
4345
abstract type AbstractDAGSRAlgorithm <: AbstractDataDrivenAlgorithm end
@@ -62,17 +64,20 @@ export AdditiveError, MultiplicativeError
6264
export ObservedModel
6365

6466
# Simplex
65-
include("./lux/simplex.jl")
67+
include("lux/simplex.jl")
6668
export Softmax, GumbelSoftmax, DirectSimplex
6769

6870
# Nodes and Layers
69-
include("./lux/path_state.jl")
71+
include("lux/path_state.jl")
7072
export PathState
71-
include("./lux/node.jl")
73+
74+
include("lux/node.jl")
7275
export FunctionNode
73-
include("./lux/layer.jl")
76+
77+
include("lux/layer.jl")
7478
export FunctionLayer
75-
include("./lux/graph.jl")
79+
80+
include("lux/graph.jl")
7681
export LayeredDAG
7782

7883
include("caches/dataset.jl")
@@ -87,6 +92,8 @@ export SearchCache
8792
include("algorithms/rewards.jl")
8893
export RelativeReward, AbsoluteReward
8994

95+
include("algorithms/common.jl")
96+
9097
include("algorithms/randomsearch.jl")
9198
export RandomSearch
9299

@@ -98,4 +105,4 @@ export CrossEntropy
98105

99106
include("solve.jl")
100107

101-
end # module DataDrivenLux
108+
end
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
@kwdef @concrete struct CommonAlgOptions
2+
populationsize::Int = 100
3+
functions = (sin, exp, cos, log, +, -, /, *)
4+
arities = (1, 1, 1, 1, 2, 2, 2, 2)
5+
n_layers::Int = 1
6+
skip::Bool = true
7+
simplex <: AbstractSimplex = Softmax()
8+
loss = aicc
9+
keep <: Union{Real, Int} = 0.1
10+
use_protected::Bool = true
11+
distributed::Bool = false
12+
threaded::Bool = false
13+
rng <: AbstractRNG = Random.default_rng()
14+
optimizer = LBFGS()
15+
optim_options <: Optim.Options = Optim.Options()
16+
optimiser <: Union{Nothing, Optimisers.AbstractRule} = nothing
17+
observed <: Union{ObservedModel, Nothing} = nothing
18+
alpha::Real = 0.999f0
19+
end
Lines changed: 25 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,59 +1,34 @@
1-
"""
2-
$(TYPEDEF)
1+
@concrete struct CrossEntropy <: AbstractDAGSRAlgorithm
2+
options <: CommonAlgOptions
3+
end
34

4-
Uses the crossentropy method for discrete optimization to search the space of possible solutions.
5+
"""
6+
$(SIGNATURES)
57
6-
# Fields
7-
$(FIELDS)
8+
Uses the crossentropy method for discrete optimization to search the space of possible
9+
solutions.
810
"""
9-
@with_kw struct CrossEntropy{F, A, L, O} <: AbstractDAGSRAlgorithm
10-
"The number of candidates to track"
11-
populationsize::Int = 100
12-
"The functions to include in the search"
13-
functions::F = (sin, exp, cos, log, +, -, /, *)
14-
"The arities of the functions"
15-
arities::A = (1, 1, 1, 1, 2, 2, 2, 2)
16-
"The number of layers"
17-
n_layers::Int = 1
18-
"Include skip layers"
19-
skip::Bool = true
20-
"Evaluation function to sort the samples"
21-
loss::L = aicc
22-
"The number of candidates to keep in each iteration"
23-
keep::Union{Real, Int} = 0.1
24-
"Use protected operators"
25-
use_protected::Bool = true
26-
"Use distributed optimization and resampling"
27-
distributed::Bool = false
28-
"Use threaded optimization and resampling - not implemented right now."
29-
threaded::Bool = false
30-
"Random seed"
31-
rng::Random.AbstractRNG = Random.default_rng()
32-
"Optim optimiser"
33-
optimizer::O = LBFGS()
34-
"Optim options"
35-
optim_options::Optim.Options = Optim.Options()
36-
"Observed model - if `nothing`is used, a normal distributed additive error with fixed variance is assumed."
37-
observed::Union{ObservedModel, Nothing} = nothing
38-
"Field for possible optimiser - no use for CrossEntropy"
39-
optimiser::Nothing = nothing
40-
"Update parameter for smoothness"
41-
alpha::Real = 0.999f0
11+
function CrossEntropy(; populationsize = 100, functions = (sin, exp, cos, log, +, -, /, *),
12+
arities = (1, 1, 1, 1, 2, 2, 2, 2), n_layers = 1, skip = true, loss = aicc,
13+
keep = 0.1, use_protected = true, distributed = false, threaded = false,
14+
rng = Random.default_rng(), optimizer = LBFGS(), optim_options = Optim.Options(),
15+
observed = nothing, alpha = 0.999f0)
16+
return CrossEntropy(CommonAlgOptions(;
17+
populationsize, functions, arities, n_layers, skip, simplex = DirectSimplex(), loss,
18+
keep, use_protected, distributed, threaded, rng, optimizer,
19+
optim_options, optimiser = nothing, observed, alpha))
4220
end
4321

44-
Base.print(io::IO, ::CrossEntropy) = print(io, "CrossEntropy")
22+
Base.print(io::IO, ::CrossEntropy) = print(io, "CrossEntropy()")
4523
Base.summary(io::IO, x::CrossEntropy) = print(io, x)
4624

4725
function init_model(x::CrossEntropy, basis::Basis, dataset::Dataset, intervals)
48-
@unpack n_layers, arities, functions, use_protected, skip = x
49-
50-
# We enforce the direct simplex here!
51-
simplex = DirectSimplex()
26+
(; n_layers, arities, functions, use_protected, skip) = x.options
5227

5328
# Get the parameter mapping
5429
variable_mask = map(enumerate(equations(basis))) do (i, eq)
55-
any(ModelingToolkit.isvariable, ModelingToolkit.get_variables(eq.rhs)) &&
56-
IntervalArithmetic.iscommon(intervals[i])
30+
return any(ModelingToolkit.isvariable, ModelingToolkit.get_variables(eq.rhs)) &&
31+
IntervalArithmetic.iscommon(intervals[i])
5732
end
5833

5934
variable_mask = Any[variable_mask...]
@@ -63,15 +38,14 @@ function init_model(x::CrossEntropy, basis::Basis, dataset::Dataset, intervals)
6338
end
6439

6540
return LayeredDAG(length(basis), size(dataset.y, 1), n_layers, arities, functions;
66-
skip = skip, input_functions = variable_mask, simplex = simplex)
41+
skip, input_functions = variable_mask, x.options.simplex)
6742
end
6843

6944
function update_parameters!(cache::SearchCache{<:CrossEntropy})
70-
@unpack candidates, keeps, p, alg = cache
71-
@unpack alpha = alg
72-
= mean(map(candidates[keeps]) do candidate
73-
ComponentVector(get_configuration(candidate.model.model, p, candidate.st))
45+
= mean(map(cache.candidates[cache.keeps]) do candidate
46+
return ComponentVector(get_configuration(candidate.model.model, cache.p, candidate.st))
7447
end)
75-
cache.p .= alpha * p + (one(alpha) - alpha) .*
48+
alpha = cache.alg.options.alpha
49+
@. cache.p = alpha * cache.p + (true - alpha) *
7650
return
7751
end

0 commit comments

Comments
 (0)