+@concrete struct Reinforce <: AbstractDAGSRAlgorithm
+    reward
+    ad_backend <: AD.AbstractBackend
+    options <: CommonAlgOptions
+end
+
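For context on the new layout: `@concrete`, from ConcreteStructs.jl, generates one type parameter per field so that field access stays type-stable, and the `<:` annotations bound those generated parameters. A sketch of the roughly equivalent hand-written struct, with illustrative parameter names rather than the macro's actual expansion:

    struct Reinforce{R, B <: AD.AbstractBackend, O <: CommonAlgOptions} <: AbstractDAGSRAlgorithm
        reward::R      # reward transform applied to the candidate losses
        ad_backend::B  # AbstractDifferentiation backend for the gradient estimate
        options::O     # shared search options (population size, operators, ...)
    end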
1 | 7 | """ |
2 | | -$(TYPEDEF) |
| 8 | +$(SIGNATURES) |
3 | 9 |
|
4 | | -Uses the REINFORCE algorithm to search over the space of possible solutions to the |
| 10 | +Uses the REINFORCE algorithm to search over the space of possible solutions to the |
5 | 11 | symbolic regression problem. |
6 | | -
|
7 | | -# Fields |
8 | | -$(FIELDS) |
9 | 12 | """ |
-@kwdef struct Reinforce{F, A, L, O, R} <: AbstractDAGSRAlgorithm
-    "Reward function which should convert the loss to a reward."
-    reward::R = RelativeReward(false)
-    # "The number of candidates to track"
-    # populationsize::Int = 100
-    # "The functions to include in the search"
-    # functions::F = (sin, exp, cos, log, +, -, /, *)
-    # "The arities of the functions"
-    # arities::A = (1, 1, 1, 1, 2, 2, 2, 2)
-    # "The number of layers"
-    # n_layers::Int = 1
-    # "Include skip layers"
-    # skip::Bool = true
-    # "Simplex mapping"
-    # simplex::AbstractSimplex = Softmax()
-    # "Evaluation function to sort the samples"
-    # loss::L = aicc
-    # "The number of candidates to keep in each iteration"
-    # keep::Union{Real, Int} = 0.1
-    # "Use protected operators"
-    # use_protected::Bool = true
-    # "Use distributed optimization and resampling"
-    # distributed::Bool = false
-    # "Use threaded optimization and resampling - not implemented right now."
-    # threaded::Bool = false
-    # "Random seed"
-    # rng::AbstractRNG = Random.default_rng()
-    # "Optim optimiser"
-    # optimizer::O = LBFGS()
-    # "Optim options"
-    # optim_options::Optim.Options = Optim.Options()
-    # "Observed model - if `nothing` is used, a normally distributed additive error with fixed variance is assumed."
-    # observed::Union{ObservedModel, Nothing} = nothing
-    # "AD Backend"
-    # ad_backend::AD.AbstractBackend = AD.ForwardDiffBackend()
-    # "Optimiser"
-    # optimiser::Optimisers.AbstractRule = ADAM()
+function Reinforce(reward = RelativeReward(false); populationsize = 100,
+        functions = (sin, exp, cos, log, +, -, /, *), arities = (1, 1, 1, 1, 2, 2, 2, 2),
+        n_layers = 1, skip = true, loss = aicc, keep = 0.1, use_protected = true,
+        distributed = false, threaded = false, rng = Random.default_rng(),
+        optimizer = LBFGS(), optim_options = Optim.Options(), observed = nothing,
+        alpha = 0.999f0, optimiser = Adam(), ad_backend = AD.ForwardDiffBackend())
+    return Reinforce(reward, ad_backend, CommonAlgOptions(;
+        populationsize, functions, arities, n_layers, skip, simplex = Softmax(), loss,
+        keep, use_protected, distributed, threaded, rng, optimizer,
+        optim_options, optimiser, observed, alpha))
 end
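A minimal construction sketch for the signature above, assuming `Reinforce` and `RelativeReward` are exported alongside an `Adam` re-export from Optimisers.jl; every keyword is optional and falls back to the defaults shown in the function definition:

    # Shrink the search space and tune the sampler; unset keywords keep their defaults.
    alg = Reinforce(RelativeReward(false);
        populationsize = 50,      # candidates sampled per iteration
        functions = (sin, +, *),  # operators available to the search
        arities = (1, 2, 2),      # arity of each operator above
        n_layers = 2,             # depth of the layered DAG
        optimiser = Adam(0.01))   # Optimisers.jl rule for the REINFORCE update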
 
 Base.print(io::IO, ::Reinforce) = print(io, "Reinforce")
 Base.summary(io::IO, x::Reinforce) = print(io, x)
 
 function reinforce_loss(candidates, p, alg)
-    (; loss, reward) = alg
-    losses = map(loss, candidates)
-    rewards = reward(losses)
+    losses = map(alg.options.loss, candidates)
+    rewards = alg.reward(losses)
     # ∇U(θ) = E[∇log(p)*R(t)]
-    mean(map(enumerate(candidates)) do (i, candidate)
+    return mean(map(enumerate(candidates)) do (i, candidate)
         return rewards[i] * -candidate(p)
     end)
 end
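The estimator in `reinforce_loss` is the score-function (REINFORCE) trick: assuming each `candidate(p)` returns the log-probability of sampling that candidate under the parameters `p`, as the `∇log(p)` comment suggests, the gradient of the returned mean is a Monte-Carlo estimate of `-∇U(θ) = -E[R · ∇log p]`, so minimizing it through the AD backend ascends the expected reward. A self-contained toy version on a categorical policy, mirroring the structure above rather than the package API (all names hypothetical):

    using Statistics

    function softmax(z)
        e = exp.(z .- maximum(z))
        return e ./ sum(e)
    end

    # Monte-Carlo estimate of ∇U(θ) = E[R * ∇log p] for a categorical policy.
    function reinforce_grad(θ, samples, rewards)
        probs = softmax(θ)
        return mean(map(zip(samples, rewards)) do (a, R)
            ∇logp = -probs   # ∇_θ log softmax(θ)[a] = onehot(a) .- softmax(θ)
            ∇logp[a] += 1
            R .* ∇logp
        end)
    end

    θ = zeros(3)
    g = reinforce_grad(θ, [1, 2, 2, 3], [1.0, 0.5, 0.5, 0.0])
    # θ .+= 0.1 .* g   # one gradient-ascent step on the expected reward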