6161
6262function Model (;kwargs... )
6363 for (i, node) in enumerate (values (kwargs))
64- @assert typeof ( node) <: Tuple{Union{Array{Float64}, Float64}, Function, Symbol} " Check input order for node $(i) matches Tuple(value, function, kind)"
64+ @assert node isa Tuple{Union{Array{Float64}, Float64}, Function, Symbol} " Check input order for node $(i) matches Tuple(value, function, kind)"
6565 end
66- vals = getvals (NamedTuple (kwargs))
66+ node_keys = keys (kwargs)
67+ vals = [getvals (NamedTuple (kwargs))... ]
68+ vals[1 ] = Tuple ([Ref (val) for val in vals[1 ]])
6769 args = [argnames (f) for f in vals[2 ]]
68- A, sorted_vertices = dag (NamedTuple {keys(kwargs)} (args))
69- modelinputs = NamedTuple {Tuple(sorted_vertices)} .([Tuple .(args), vals... ])
70- Model (GraphInfo (modelinputs... , A, sorted_vertices))
70+ A, sorted_inds = dag (NamedTuple {node_keys} (args))
71+ sorted_vertices = node_keys[sorted_inds]
72+ model_inputs = NamedTuple {node_keys} .([Tuple .(args), vals... ])
73+ sorted_model_inputs = [NamedTuple {sorted_vertices} (m) for m in model_inputs]
74+ Model (GraphInfo (sorted_model_inputs... , A, [sorted_vertices... ]))
7175end
7276
7377"""
@@ -78,11 +82,10 @@ and returns the implied adjacency matrix and topologically ordered
7882vertex list.
7983"""
8084function dag (inputs)
81- input_names = Symbol[keys (inputs)... ]
8285 A = adjacency_matrix (inputs)
8386 sorted_vertices = topological_sort_by_dfs (A)
8487 sorted_A = permute (A, collect (1 : length (inputs)), sorted_vertices)
85- sorted_A, input_names[ sorted_vertices]
88+ sorted_A, sorted_vertices
8689end
8790
8891"""
@@ -95,7 +98,7 @@ input, eval and kind, as required by the GraphInfo type.
9598@generated function getvals (nt:: NamedTuple{T} ) where T
9699 values = [:(nt[$ i][$ j]) for i in 1 : length (T), j in 1 : 3 ]
97100 m = [:($ (values[:,i]. .. ), ) for i in 1 : 3 ]
98- return Expr (:tuple , m... ) # :($(m...),)
101+ return Expr (:tuple , m... )
99102end
100103
101104"""
@@ -180,6 +183,7 @@ function topological_sort_by_dfs(A)
180183 return reverse (verts)
181184end
182185
186+ # getters and setters
183187"""
184188 Base.getindex(m::Model, vn::VarName{p})
185189
@@ -217,39 +221,116 @@ function Base.getindex(m::Model, vn::VarName)
217221 return m. g[vn]
218222end
219223
220- function Base. show (io:: IO , m:: Model )
221- print (io, " Nodes: \n " )
222- for node in nodes (m)
223- print (io, " $node = " , m[VarName {node} ()], " \n " )
224- end
224+ """
225+ set_node_value!(m::Model, ind::VarName, value::T) where Takes
226+
227+ Change the value of the node.
228+
229+ # Examples
230+
231+ ```jl-doctest
232+ julia> m = Model( s2 = (0.0, () -> InverseGamma(2.0,3.0), :Stochastic),
233+ μ = (1.0, () -> 1.0, :Logical),
234+ y = (0.0, (μ, s2) -> MvNormal(μ, sqrt(s2)), :Stochastic))
235+ Nodes:
236+ μ = (input = (), value = Base.RefValue{Float64}(1.0), eval = var"#38#41"(), kind = :Logical)
237+ s2 = (input = (), value = Base.RefValue{Float64}(0.0), eval = var"#37#40"(), kind = :Stochastic)
238+ y = (input = (:μ, :s2), value = Base.RefValue{Float64}(0.0), eval = var"#39#42"(), kind = :Stochastic)
239+
240+
241+ julia> set_node_value!(m, @varname(s2), 1.0)
242+ 1.0
243+
244+ julia> get_node_value(m, @varname s2)
245+ 1.0
246+ ```
247+ """
248+ function set_node_value! (m:: Model , ind:: VarName , value:: T ) where T
249+ @assert typeof (m[ind]. value[]) == T
250+ m[ind]. value[] = value
225251end
226252
253+ """
254+ get_node_value(m::Model, ind::VarName)
227255
228- function Base. iterate (m:: Model , state= 1 )
229- state > length (nodes (m)) ? nothing : (m[VarName {m.g.sorted_vertices[state]} ()], state+ 1 )
256+ Retrieve the value of a particular node, indexed by a VarName.
257+
258+ # Examples
259+
260+ julia> m = Model( s2 = (0.0, () -> InverseGamma(2.0,3.0), :Stochastic),
261+ μ = (1.0, () -> 1.0, :Logical),
262+ y = (0.0, (μ, s2) -> MvNormal(μ, sqrt(s2)), :Stochastic))
263+ Nodes:
264+ μ = (input = (), value = Base.RefValue{Float64}(1.0), eval = var"#44#47"(), kind = :Logical)
265+ s2 = (input = (), value = Base.RefValue{Float64}(0.0), eval = var"#43#46"(), kind = :Stochastic)
266+ y = (input = (:μ, :s2), value = Base.RefValue{Float64}(0.0), eval = var"#45#48"(), kind = :Stochastic)
267+
268+
269+ julia> get_node_value(m, @varname s2)
270+ 0.0
271+ """
272+
273+ function get_node_value (m:: Model , ind:: VarName )
274+ v = getproperty (m[ind], :value )
275+ v[]
230276end
277+ # Base.get(m::Model, ind::VarName, field::Symbol) = field==:value ? getvalue(m, ind) : getproperty(m[ind],field)
231278
232- Base . eltype (m :: Model ) = NamedTuple{ fieldnames (GraphInfo)[ 1 : 4 ]}
233- Base . IteratorEltype (m:: Model ) = HasEltype ( )
279+ """
280+ get_node_input (m::Model, ind::VarName )
234281
235- Base. keys (m:: Model ) = (VarName {n} () for n in m. g. sorted_vertices)
236- Base. values (m:: Model ) = Base. Generator (identity, m)
237- Base. length (m:: Model ) = length (nodes (m))
238- Base. keytype (m:: Model ) = eltype (keys (m))
239- Base. valtype (m:: Model ) = eltype (m)
282+ Retrieve the inputs/parents of a node, as given by model DAG.
283+ """
284+ get_node_input (m:: Model , ind:: VarName ) = getproperty (m[ind], :input )
240285
286+ """
287+ get_node_input(m::Model, ind::VarName)
241288
289+ Retrieve the evaluation function for a node.
242290"""
243- dag(m::Model)
291+ get_node_eval (m:: Model , ind:: VarName ) = getproperty (m[ind], :eval )
292+
293+ """
294+ get_nodekind(m::Model, ind::VarName)
295+
296+ Retrieve the type of the node, i.e. stochastic or logical.
297+ """
298+ get_nodekind (m:: Model , ind:: VarName ) = getproperty (m[ind], :kind )
299+
300+ """
301+ get_dag(m::Model)
244302
245303Returns the adjacency matrix of the model as a SparseArray.
246304"""
247305get_dag (m:: Model ) = m. g. A
248306
249307"""
250- nodes (m::Model)
308+ get_sorted_vertices (m::Model)
251309
252310Returns a `Vector{Symbol}` containing the sorted vertices
253311of the DAG.
254312"""
255- nodes (m:: Model ) = m. g. sorted_vertices
313+ get_sorted_vertices (m:: Model ) = getproperty (m. g, :sorted_vertices )
314+
315+ # iterators
316+
317+ function Base. iterate (m:: Model , state= 1 )
318+ state > length (get_sorted_vertices (m)) ? nothing : (m[VarName {m.g.sorted_vertices[state]} ()], state+ 1 )
319+ end
320+
321+ Base. eltype (m:: Model ) = NamedTuple{fieldnames (GraphInfo)[1 : 4 ]}
322+ Base. IteratorEltype (m:: Model ) = HasEltype ()
323+
324+ Base. keys (m:: Model ) = (VarName {n} () for n in m. g. sorted_vertices)
325+ Base. values (m:: Model ) = Base. Generator (identity, m)
326+ Base. length (m:: Model ) = length (get_sorted_vertices (m))
327+ Base. keytype (m:: Model ) = eltype (keys (m))
328+ Base. valtype (m:: Model ) = eltype (m)
329+
330+ # show methods
331+ function Base. show (io:: IO , m:: Model )
332+ print (io, " Nodes: \n " )
333+ for node in get_sorted_vertices (m)
334+ print (io, " $node = " , m[VarName {node} ()], " \n " )
335+ end
336+ end
0 commit comments