@@ -72,8 +72,10 @@ function _rand_product(
7272 end |> collect
7373end
7474
75- @inline function logdensity_def (d:: AbstractProductMeasure , x)
76- mapreduce (logdensity_def, + , marginals (d), x)
75+ for func in [:logdensityof , :logdensity_def ]
76+ @eval @inline function $func (d:: AbstractProductMeasure , x)
77+ mapreduce ($ func, + , marginals (d), x)
78+ end
7779end
7880
7981struct ProductMeasure{M} <: AbstractProductMeasure
@@ -88,27 +90,37 @@ function Pretty.tile(d::ProductMeasure{T}) where {T<:Tuple}
8890 Pretty. list_layout (Pretty. tile .([marginals (d)... ]), sep = " ⊗ " )
8991end
9092
91- # For tuples, `mapreduce` has trouble with type inference
92- @inline function logdensity_def (d:: ProductMeasure{T} , x) where {T<: Tuple }
93- ℓs = map (logdensity_def, marginals (d), x)
94- sum (ℓs)
95- end
96-
97- @generated function logdensity_def (d:: ProductMeasure{NamedTuple{N,T}} , x) where {N,T}
93+ @eval @generated function _product_gen_impl (
94+ :: Val{func} ,
95+ d:: ProductMeasure{NamedTuple{N,T}} ,
96+ x,
97+ ) where {func,N,T}
9898 k1 = QuoteNode (first (N))
9999 q = quote
100100 m = marginals (d)
101- ℓ = logdensity_def (getproperty (m, $ k1), getproperty (x, $ k1))
101+ ℓ = $ func (getproperty (m, $ k1), getproperty (x, $ k1))
102102 end
103103 for k in Base. tail (N)
104104 k = QuoteNode (k)
105- qk = :(ℓ += logdensity_def (getproperty (m, $ k), getproperty (x, $ k)))
105+ qk = :(ℓ += $ func (getproperty (m, $ k), getproperty (x, $ k)))
106106 push! (q. args, qk)
107107 end
108108
109109 return q
110110end
111111
112+ for func in [:logdensityof , :logdensity_def ]
113+ # For tuples, `mapreduce` has trouble with type inference
114+ @eval @inline function $func (d:: ProductMeasure{T} , x) where {T<: Tuple }
115+ ℓs = map ($ func, marginals (d), x)
116+ sum (ℓs)
117+ end
118+
119+ @eval function $func (d:: ProductMeasure{NamedTuple{N,T}} , x) where {N,T}
120+ _product_gen_impl (Val ($ func), d, x)
121+ end
122+ end
123+
112124# @generated function basemeasure(d::ProductMeasure{NamedTuple{N,T}}, x) where {N,T}
113125# q = quote
114126# m = marginals(d)
0 commit comments