@@ -837,245 +837,6 @@ end
837837# Handle `AbstractDict` differently since `eltype` results in a `Pair`.
838838infer_nested_eltype (:: Type{<:AbstractDict{<:Any,ET}} ) where {ET} = infer_nested_eltype (ET)
839839
840- """
841- varname_leaves(vn::VarName, val)
842-
843- Return an iterator over all varnames that are represented by `vn` on `val`.
844-
845- # Examples
846- ```jldoctest
847- julia> using DynamicPPL: varname_leaves
848-
849- julia> foreach(println, varname_leaves(@varname(x), rand(2)))
850- x[1]
851- x[2]
852-
853- julia> foreach(println, varname_leaves(@varname(x[1:2]), rand(2)))
854- x[1:2][1]
855- x[1:2][2]
856-
857- julia> x = (y = 1, z = [[2.0], [3.0]]);
858-
859- julia> foreach(println, varname_leaves(@varname(x), x))
860- x.y
861- x.z[1][1]
862- x.z[2][1]
863- ```
864- """
865- varname_leaves (vn:: VarName , :: Real ) = [vn]
866- function varname_leaves (vn:: VarName , val:: AbstractArray{<:Union{Real,Missing}} )
867- return (
868- VarName {getsym(vn)} (Accessors. IndexLens (Tuple (I)) ∘ getoptic (vn)) for
869- I in CartesianIndices (val)
870- )
871- end
872- function varname_leaves (vn:: VarName , val:: AbstractArray )
873- return Iterators. flatten (
874- varname_leaves (
875- VarName {getsym(vn)} (Accessors. IndexLens (Tuple (I)) ∘ getoptic (vn)), val[I]
876- ) for I in CartesianIndices (val)
877- )
878- end
879- function varname_leaves (vn:: VarName , val:: NamedTuple )
880- iter = Iterators. map (keys (val)) do k
881- optic = Accessors. PropertyLens {k} ()
882- varname_leaves (VarName {getsym(vn)} (optic ∘ getoptic (vn)), optic (val))
883- end
884- return Iterators. flatten (iter)
885- end
886-
887- """
888- varname_and_value_leaves(vn::VarName, val)
889-
890- Return an iterator over all varname-value pairs that are represented by `vn` on `val`.
891-
892- # Examples
893- ```jldoctest varname-and-value-leaves
894- julia> using DynamicPPL: varname_and_value_leaves
895-
896- julia> foreach(println, varname_and_value_leaves(@varname(x), 1:2))
897- (x[1], 1)
898- (x[2], 2)
899-
900- julia> foreach(println, varname_and_value_leaves(@varname(x[1:2]), 1:2))
901- (x[1:2][1], 1)
902- (x[1:2][2], 2)
903-
904- julia> x = (y = 1, z = [[2.0], [3.0]]);
905-
906- julia> foreach(println, varname_and_value_leaves(@varname(x), x))
907- (x.y, 1)
908- (x.z[1][1], 2.0)
909- (x.z[2][1], 3.0)
910- ```
911-
912- There is also some special handling for certain types:
913-
914- ```jldoctest varname-and-value-leaves
915- julia> using LinearAlgebra
916-
917- julia> x = reshape(1:4, 2, 2);
918-
919- julia> # `LowerTriangular`
920- foreach(println, varname_and_value_leaves(@varname(x), LowerTriangular(x)))
921- (x[1, 1], 1)
922- (x[2, 1], 2)
923- (x[2, 2], 4)
924-
925- julia> # `UpperTriangular`
926- foreach(println, varname_and_value_leaves(@varname(x), UpperTriangular(x)))
927- (x[1, 1], 1)
928- (x[1, 2], 3)
929- (x[2, 2], 4)
930-
931- julia> # `Cholesky` with lower-triangular
932- foreach(println, varname_and_value_leaves(@varname(x), Cholesky([1.0 0.0; 0.0 1.0], 'L', 0)))
933- (x.L[1, 1], 1.0)
934- (x.L[2, 1], 0.0)
935- (x.L[2, 2], 1.0)
936-
937- julia> # `Cholesky` with upper-triangular
938- foreach(println, varname_and_value_leaves(@varname(x), Cholesky([1.0 0.0; 0.0 1.0], 'U', 0)))
939- (x.U[1, 1], 1.0)
940- (x.U[1, 2], 0.0)
941- (x.U[2, 2], 1.0)
942- ```
943- """
944- function varname_and_value_leaves (vn:: VarName , x)
945- return Iterators. map (value, Iterators. flatten (varname_and_value_leaves_inner (vn, x)))
946- end
947-
948- """
949- varname_and_value_leaves(container)
950-
951- Return an iterator over all varname-value pairs that are represented by `container`.
952-
953- This is the same as [`varname_and_value_leaves(vn::VarName, x)`](@ref) but over a container
954- containing multiple varnames.
955-
956- See also: [`varname_and_value_leaves(vn::VarName, x)`](@ref).
957-
958- # Examples
959- ```jldoctest varname-and-value-leaves-container
960- julia> using DynamicPPL: varname_and_value_leaves
961-
962- julia> # With an `OrderedDict`
963- dict = OrderedDict(@varname(y) => 1, @varname(z) => [[2.0], [3.0]]);
964-
965- julia> foreach(println, varname_and_value_leaves(dict))
966- (y, 1)
967- (z[1][1], 2.0)
968- (z[2][1], 3.0)
969-
970- julia> # With a `NamedTuple`
971- nt = (y = 1, z = [[2.0], [3.0]]);
972-
973- julia> foreach(println, varname_and_value_leaves(nt))
974- (y, 1)
975- (z[1][1], 2.0)
976- (z[2][1], 3.0)
977- ```
978- """
979- function varname_and_value_leaves (container:: OrderedDict )
980- return Iterators. flatten (varname_and_value_leaves (k, v) for (k, v) in container)
981- end
982- function varname_and_value_leaves (container:: NamedTuple )
983- return Iterators. flatten (
984- varname_and_value_leaves (VarName {k} (), v) for (k, v) in pairs (container)
985- )
986- end
987-
988- """
989- Leaf{T}
990-
991- A container that represents the leaf of a nested structure, implementing
992- `iterate` to return itself.
993-
994- This is particularly useful in conjunction with `Iterators.flatten` to
995- prevent flattening of nested structures.
996- """
997- struct Leaf{T}
998- value:: T
999- end
1000-
1001- Leaf (xs... ) = Leaf (xs)
1002-
1003- # Allow us to treat `Leaf` as an iterator containing a single element.
1004- # Something like an `[x]` would also be an iterator with a single element,
1005- # but when we call `flatten` on this, it would also iterate over `x`,
1006- # unflattening that too. By making `Leaf` a single-element iterator, which
1007- # returns itself, we can call `iterate` on this as many times as we like
1008- # without causing any change. The result is that `Iterators.flatten`
1009- # will _not_ unflatten `Leaf`s.
1010- # Note that this is similar to how `Base.iterate` is implemented for `Real`::
1011- #
1012- # julia> iterate(1)
1013- # (1, nothing)
1014- #
1015- # One immediate example where this becomes in our scenario is that we might
1016- # have `missing` values in our data, which does _not_ have an `iterate`
1017- # implemented. Calling `Iterators.flatten` on this would cause an error.
1018- Base. iterate (leaf:: Leaf ) = leaf, nothing
1019- Base. iterate (:: Leaf , _) = nothing
1020-
1021- # Convenience.
1022- value (leaf:: Leaf ) = leaf. value
1023-
1024- # Leaf-types.
1025- varname_and_value_leaves_inner (vn:: VarName , x:: Real ) = [Leaf (vn, x)]
1026- function varname_and_value_leaves_inner (
1027- vn:: VarName , val:: AbstractArray{<:Union{Real,Missing}}
1028- )
1029- return (
1030- Leaf (
1031- VarName {getsym(vn)} (Accessors. IndexLens (Tuple (I)) ∘ AbstractPPL. getoptic (vn)),
1032- val[I],
1033- ) for I in CartesianIndices (val)
1034- )
1035- end
1036- # Containers.
1037- function varname_and_value_leaves_inner (vn:: VarName , val:: AbstractArray )
1038- return Iterators. flatten (
1039- varname_and_value_leaves_inner (
1040- VarName {getsym(vn)} (Accessors. IndexLens (Tuple (I)) ∘ AbstractPPL. getoptic (vn)),
1041- val[I],
1042- ) for I in CartesianIndices (val)
1043- )
1044- end
1045- function varname_and_value_leaves_inner (vn:: VarName , val:: NamedTuple )
1046- iter = Iterators. map (keys (val)) do k
1047- optic = Accessors. PropertyLens {k} ()
1048- varname_and_value_leaves_inner (
1049- VarName {getsym(vn)} (optic ∘ getoptic (vn)), optic (val)
1050- )
1051- end
1052-
1053- return Iterators. flatten (iter)
1054- end
1055- # Special types.
1056- function varname_and_value_leaves_inner (vn:: VarName , x:: Cholesky )
1057- # TODO : Or do we use `PDMat` here?
1058- return if x. uplo == ' L'
1059- varname_and_value_leaves_inner (Accessors. PropertyLens {:L} () ∘ vn, x. L)
1060- else
1061- varname_and_value_leaves_inner (Accessors. PropertyLens {:U} () ∘ vn, x. U)
1062- end
1063- end
1064- function varname_and_value_leaves_inner (vn:: VarName , x:: LinearAlgebra.LowerTriangular )
1065- return (
1066- Leaf (VarName {getsym(vn)} (Accessors. IndexLens (Tuple (I)) ∘ getoptic (vn)), x[I])
1067- # Iteration over the lower-triangular indices.
1068- for I in CartesianIndices (x) if I[1 ] >= I[2 ]
1069- )
1070- end
1071- function varname_and_value_leaves_inner (vn:: VarName , x:: LinearAlgebra.UpperTriangular )
1072- return (
1073- Leaf (VarName {getsym(vn)} (Accessors. IndexLens (Tuple (I)) ∘ getoptic (vn)), x[I])
1074- # Iteration over the upper-triangular indices.
1075- for I in CartesianIndices (x) if I[1 ] <= I[2 ]
1076- )
1077- end
1078-
1079840broadcast_safe (x) = x
1080841broadcast_safe (x:: Distribution ) = (x,)
1081842broadcast_safe (x:: AbstractContext ) = (x,)
0 commit comments