@@ -5,6 +5,16 @@ const NO_DEFAULT = NoDefault()
55# A short-hand for a type commonly used in type signatures for VarInfo methods.
66VarNameTuple = NTuple{N,VarName} where {N}
77
8+ # TODO (mhauru) This is currently used in the transformation functions of NoDist,
9+ # ReshapeTransform, and UnwrapSingletonTransform, and in VarInfo. We should also use it in
10+ # SimpleVarInfo and maybe other places.
11+ """
12+ The type for all log probability variables.
13+
14+ This is Float64 on 64-bit systems and Float32 on 32-bit systems.
15+ """
16+ const LogProbType = float (Real)
17+
818"""
919 @addlogprob!(ex)
1020
@@ -252,12 +262,16 @@ function (f::UnwrapSingletonTransform)(x)
252262 return only (x)
253263end
254264
255- Bijectors. with_logabsdet_jacobian (f:: UnwrapSingletonTransform , x) = (f (x), 0 )
265+ function Bijectors. with_logabsdet_jacobian (f:: UnwrapSingletonTransform , x)
266+ return f (x), zero (LogProbType)
267+ end
268+
256269function Bijectors. with_logabsdet_jacobian (
257270 inv_f:: Bijectors.Inverse{<:UnwrapSingletonTransform} , x
258271)
259272 f = inv_f. orig
260- return (reshape ([x], f. input_size), 0 )
273+ result = reshape ([x], f. input_size)
274+ return result, zero (LogProbType)
261275end
262276
263277"""
@@ -306,18 +320,26 @@ function (inv_f::Bijectors.Inverse{<:ReshapeTransform})(x)
306320 return inverse (x)
307321end
308322
309- Bijectors. with_logabsdet_jacobian (f:: ReshapeTransform , x) = (f (x), 0 )
323+ function Bijectors. with_logabsdet_jacobian (f:: ReshapeTransform , x)
324+ return f (x), zero (LogProbType)
325+ end
310326
311327function Bijectors. with_logabsdet_jacobian (inv_f:: Bijectors.Inverse{<:ReshapeTransform} , x)
312- return ( inv_f (x), 0 )
328+ return inv_f (x), zero (LogProbType )
313329end
314330
315331struct ToChol <: Bijectors.Bijector
316332 uplo:: Char
317333end
318334
319- Bijectors. with_logabsdet_jacobian (f:: ToChol , x) = (Cholesky (Matrix (x), f. uplo, 0 ), 0 )
320- Bijectors. with_logabsdet_jacobian (:: Bijectors.Inverse{<:ToChol} , y:: Cholesky ) = (y. UL, 0 )
335+ function Bijectors. with_logabsdet_jacobian (f:: ToChol , x)
336+ return Cholesky (Matrix (x), f. uplo, 0 ), zero (LogProbType)
337+ end
338+
339+ function Bijectors. with_logabsdet_jacobian (:: Bijectors.Inverse{<:ToChol} , y:: Cholesky )
340+ return y. UL, zero (LogProbType)
341+ end
342+
321343function Bijectors. with_logabsdet_jacobian (:: Bijectors.Inverse{<:ToChol} , y)
322344 return error (
323345 " Inverse{ToChol} is only defined for Cholesky factorizations. " *
0 commit comments