1- """
2- $(TYPEDEF)
1+ @concrete struct CrossEntropy <: AbstractDAGSRAlgorithm
2+ options <: CommonAlgOptions
3+ end
34
4- Uses the crossentropy method for discrete optimization to search the space of possible solutions.
5+ """
6+ $(SIGNATURES)
57
6- # Fields
7- $(FIELDS)
8+ Uses the crossentropy method for discrete optimization to search the space of possible
9+ solutions.
810"""
9- @with_kw struct CrossEntropy{F, A, L, O} <: AbstractDAGSRAlgorithm
10- " The number of candidates to track"
11- populationsize:: Int = 100
12- " The functions to include in the search"
13- functions:: F = (sin, exp, cos, log, + , - , / , * )
14- " The arities of the functions"
15- arities:: A = (1 , 1 , 1 , 1 , 2 , 2 , 2 , 2 )
16- " The number of layers"
17- n_layers:: Int = 1
18- " Include skip layers"
19- skip:: Bool = true
20- " Evaluation function to sort the samples"
21- loss:: L = aicc
22- " The number of candidates to keep in each iteration"
23- keep:: Union{Real, Int} = 0.1
24- " Use protected operators"
25- use_protected:: Bool = true
26- " Use distributed optimization and resampling"
27- distributed:: Bool = false
28- " Use threaded optimization and resampling - not implemented right now."
29- threaded:: Bool = false
30- " Random seed"
31- rng:: Random.AbstractRNG = Random. default_rng ()
32- " Optim optimiser"
33- optimizer:: O = LBFGS ()
34- " Optim options"
35- optim_options:: Optim.Options = Optim. Options ()
36- " Observed model - if `nothing`is used, a normal distributed additive error with fixed variance is assumed."
37- observed:: Union{ObservedModel, Nothing} = nothing
38- " Field for possible optimiser - no use for CrossEntropy"
39- optimiser:: Nothing = nothing
40- " Update parameter for smoothness"
41- alpha:: Real = 0.999f0
11+ function CrossEntropy (; populationsize = 100 , functions = (sin, exp, cos, log, + , - , / , * ),
12+ arities = (1 , 1 , 1 , 1 , 2 , 2 , 2 , 2 ), n_layers = 1 , skip = true , loss = aicc,
13+ keep = 0.1 , use_protected = true , distributed = false , threaded = false ,
14+ rng = Random. default_rng (), optimizer = LBFGS (), optim_options = Optim. Options (),
15+ observed = nothing , alpha = 0.999f0 )
16+ return CrossEntropy (CommonAlgOptions (;
17+ populationsize, functions, arities, n_layers, skip, simplex = DirectSimplex (), loss,
18+ keep, use_protected, distributed, threaded, rng, optimizer,
19+ optim_options, optimiser = nothing , observed, alpha))
4220end
4321
44- Base. print (io:: IO , :: CrossEntropy ) = print (io, " CrossEntropy" )
22+ Base. print (io:: IO , :: CrossEntropy ) = print (io, " CrossEntropy() " )
4523Base. summary (io:: IO , x:: CrossEntropy ) = print (io, x)
4624
4725function init_model (x:: CrossEntropy , basis:: Basis , dataset:: Dataset , intervals)
48- @unpack n_layers, arities, functions, use_protected, skip = x
49-
50- # We enforce the direct simplex here!
51- simplex = DirectSimplex ()
26+ (; n_layers, arities, functions, use_protected, skip) = x. options
5227
5328 # Get the parameter mapping
5429 variable_mask = map (enumerate (equations (basis))) do (i, eq)
55- any (ModelingToolkit. isvariable, ModelingToolkit. get_variables (eq. rhs)) &&
56- IntervalArithmetic. iscommon (intervals[i])
30+ return any (ModelingToolkit. isvariable, ModelingToolkit. get_variables (eq. rhs)) &&
31+ IntervalArithmetic. iscommon (intervals[i])
5732 end
5833
5934 variable_mask = Any[variable_mask... ]
@@ -63,15 +38,14 @@ function init_model(x::CrossEntropy, basis::Basis, dataset::Dataset, intervals)
6338 end
6439
6540 return LayeredDAG (length (basis), size (dataset. y, 1 ), n_layers, arities, functions;
66- skip = skip , input_functions = variable_mask, simplex = simplex)
41+ skip, input_functions = variable_mask, x . options . simplex)
6742end
6843
6944function update_parameters! (cache:: SearchCache{<:CrossEntropy} )
70- @unpack candidates, keeps, p, alg = cache
71- @unpack alpha = alg
72- p̄ = mean (map (candidates[keeps]) do candidate
73- ComponentVector (get_configuration (candidate. model. model, p, candidate. st))
45+ p̄ = mean (map (cache. candidates[cache. keeps]) do candidate
46+ return ComponentVector (get_configuration (candidate. model. model, cache. p, candidate. st))
7447 end )
75- cache. p .= alpha * p + (one (alpha) - alpha) .* p̄
48+ alpha = cache. alg. options. alpha
49+ @. cache. p = alpha * cache. p + (true - alpha) * p̄
7650 return
7751end
0 commit comments