@@ -22,17 +22,16 @@ StatsBase.nullloglikelihood(stats::PathStatistics) = getfield(stats, :nullloglik
2222StatsBase. dof (stats:: PathStatistics ) = getfield (stats, :dof )
2323StatsBase. r2 (c:: PathStatistics ) = r2 (c, :CoxSnell )
2424
25- struct ComponentModel{B, M}
26- basis:: B
27- model:: M
25+ @concrete struct ComponentModel
26+ basis
27+ model
2828end
2929
30- function (c:: ComponentModel )(dataset:: Dataset{T} , ps, st:: NamedTuple{fieldnames} ,
31- p:: AbstractVector{T} ) where {T, fieldnames }
30+ function (c:: ComponentModel )(dataset:: Dataset{T} , ps, st:: NamedTuple ,
31+ p:: AbstractVector{T} ) where {T}
3232 return first (c. model (c. basis (dataset, p), ps, st))
3333end
34- function (c:: ComponentModel )(ps, st:: NamedTuple{fieldnames} ,
35- paths:: Vector{<:AbstractPathState} ) where {fieldnames}
34+ function (c:: ComponentModel )(ps, st:: NamedTuple , paths:: Vector{<:AbstractPathState} )
3635 return get_loglikelihood (c. model, ps, st, paths)
3736end
3837
@@ -45,29 +44,29 @@ to the symbolic regression problem.
4544# Fields
4645$(FIELDS)
4746"""
48- struct Candidate{S <: NamedTuple } <: StatsBase.StatisticalModel
47+ @concrete struct Candidate <: StatsBase.StatisticalModel
4948 " Random seed"
50- rng:: AbstractRNG
49+ rng <: AbstractRNG
5150 " The current state"
52- st:: S
51+ st <: NamedTuple
5352 " The current parameters"
54- ps:: AbstractVector
53+ ps <: AbstractVector
5554 " Incoming paths"
56- incoming_path:: Vector{AbstractPathState}
55+ incoming_path <: Vector{<: AbstractPathState}
5756 " Outgoing path"
58- outgoing_path:: Vector{AbstractPathState}
57+ outgoing_path <: Vector{<: AbstractPathState}
5958 " Statistics"
60- statistics:: PathStatistics
59+ statistics <: PathStatistics
6160 " The observed model"
62- observed:: ObservedModel
61+ observed <: ObservedModel
6362 " The parameter distribution"
64- parameterdist:: ParameterDistributions
63+ parameterdist <: ParameterDistributions
6564 " The optimal scales"
66- scales:: AbstractVector
65+ scales <: AbstractVector
6766 " The optimal parameters"
68- parameters:: AbstractVector
67+ parameters <: AbstractVector
6968 " The component model"
70- model:: ComponentModel
69+ model <: ComponentModel
7170end
7271
7372function (c:: Candidate )(dataset:: Dataset{T} , ps = c. ps, p = c. parameters) where {T}
@@ -89,12 +88,9 @@ StatsBase.r2(c::Candidate) = r2(c, :CoxSnell)
8988get_parameters (c:: Candidate ) = transform_parameter (c. parameterdist, c. parameters)
9089get_scales (c:: Candidate ) = transform_scales (c. observed, c. scales)
9190
92- function Candidate (rng, model, basis, dataset; observed = ObservedModel (dataset. y),
93- parameterdist = ParameterDistributions (basis), ptype = Float32)
94- (; y, x) = dataset
95-
96- T = eltype (dataset)
97-
91+ function Candidate (
92+ rng, model, basis, dataset:: Dataset{T} ; observed = ObservedModel (dataset. y),
93+ parameterdist = ParameterDistributions (basis), ptype = Float32) where {T}
9894 # Create the initial state and path
9995 dataset_intervals = interval_eval (basis, dataset, get_interval (parameterdist))
10096
@@ -110,21 +106,21 @@ function Candidate(rng, model, basis, dataset; observed = ObservedModel(dataset.
110106
111107 ŷ, _ = model (basis (dataset, transform_parameter (parameterdist, parameters)), ps, st)
112108
113- lls = logpdf (observed, y, ŷ, scales)
109+ lls = logpdf (observed, dataset . y, ŷ, scales)
114110 lls += logpdf (parameterdist, parameters)
115111
116- rss = sum (abs2, y .- ŷ)
112+ rss = sum (abs2, dataset . y .- ŷ)
117113 dof_ = get_dof (outgoing_path)
118114
119- ȳ = vec (mean (y, dims = 2 ))
115+ ȳ = vec (mean (dataset . y; dims = 2 ))
120116
121- null_ll = logpdf (observed, y, ȳ, scales) + logpdf (parameterdist, parameters)
117+ null_ll = logpdf (observed, dataset . y, ȳ, scales) + logpdf (parameterdist, parameters)
122118
123- stats = PathStatistics (rss, lls, null_ll, dof_, prod (size (y)))
119+ stats = PathStatistics (rss, lls, null_ll, dof_, prod (size (dataset . y)))
124120
125- return Candidate {typeof(st)} (
126- Lux . replicate (rng), st, ComponentVector (ps), incoming_path, outgoing_path, stats ,
127- observed, parameterdist, scales, parameters, ComponentModel (basis, model))
121+ return Candidate (Lux . replicate (rng), st, ComponentVector (ps), incoming_path,
122+ outgoing_path, stats, observed, parameterdist, scales, parameters ,
123+ ComponentModel (basis, model))
128124end
129125
130126function update_values! (c:: Candidate , ps, dataset)
@@ -136,34 +132,24 @@ function update_values!(c::Candidate, ps, dataset)
136132 dataloglikelihood = logpdf (observed, y, ŷ, scales) + logpdf (parameterdist, parameters)
137133 rss = sum (abs2, y .- ŷ)
138134 dof = get_dof (outgoing_path)
139- ȳ = vec (mean (y, dims = 2 ))
135+ ȳ = vec (mean (y; dims = 2 ))
140136 nullloglikelihood = logpdf (observed, y, ȳ, scales) + logpdf (parameterdist, parameters)
141137 update_stats! (statistics, rss, dataloglikelihood, nullloglikelihood, dof)
142138 return
143139end
144140
145141@views function Distributions. logpdf (
146142 c:: Candidate , p:: ComponentVector , dataset:: Dataset{T} , ps = c. ps) where {T}
147- (; observed, parameterdist) = c
148- (; scales, parameters) = p
149- (; y) = dataset
150-
151- ŷ = c (dataset, ps, parameters)
152- return logpdf (c, p, y, ŷ)
143+ ŷ = c (dataset, ps, p. parameters)
144+ return logpdf (c, p, dataset. y, ŷ)
153145end
154146
155147function Distributions. logpdf (c:: Candidate , p:: AbstractVector , y:: AbstractMatrix{T} ,
156148 ŷ:: AbstractMatrix{T} ) where {T}
157- (; scales, parameters) = p
158- (; observed, parameterdist) = c
159-
160- return logpdf (observed, y, ŷ, scales) + logpdf (parameterdist, parameters)
149+ return logpdf (c. observed, y, ŷ, p. scales) + logpdf (c. parameterdist, p. parameters)
161150end
162151
163- function initial_values (c:: Candidate )
164- (; scales, parameters) = c
165- return ComponentVector ((; scales = scales, parameters = parameters))
166- end
152+ initial_values (c:: Candidate ) = ComponentVector (; c. scales, c. parameters)
167153
168154function optimize_candidate! (
169155 c:: Candidate , dataset:: Dataset{T} , ps = c. ps; optimizer = Optim. LBFGS (),
@@ -195,16 +181,10 @@ function optimize_candidate!(
195181 return
196182end
197183
198- function check_intervals (paths:: AbstractArray{<:AbstractPathState} ):: Bool
199- @inbounds for path in paths
200- check_intervals (path) || return false
201- end
202- return true
203- end
184+ check_intervals (paths:: AbstractArray{<:AbstractPathState} ) = all (check_intervals, paths)
204185
205186function sample (c:: Candidate , ps, i = 0 , max_sample = 10 )
206- (; incoming_path, st) = c
207- return sample (c. model. model, incoming_path, ps, st, i, max_sample)
187+ return sample (c. model. model, c. incoming_path, ps, c. st, i, max_sample)
208188end
209189
210190function sample (model, incoming, ps, st, i = 0 , max_sample = 10 )
0 commit comments