@@ -62,7 +62,12 @@ function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::Uniform
6262 sz = Bijectors. output_size (b, size (dist))
6363 y = rand (rng, Uniform (u. lower, u. upper), sz)
6464 b_inv = Bijectors. inverse (b)
65- return b_inv (y)
65+ x = b_inv (y)
66+ # https://github.com/TuringLang/Bijectors.jl/issues/398
67+ if x isa Array{<: Any ,0 }
68+ x = x[]
69+ end
70+ return x
6671end
6772
6873"""
@@ -134,12 +139,14 @@ function tilde_assume(
134139 # `init()` always returns values in original space, i.e. possibly
135140 # constrained
136141 x = init (ctx. rng, vn, dist, ctx. strategy)
137- # There is a function `to_maybe_linked_internal_transform` that does this,
138- # but unfortunately it uses `istrans(vi, vn)` which fails if vn is not in
139- # vi, so we have to manually check. By default we will insert an unlinked
140- # value into the varinfo.
141- is_transformed = in_varinfo ? istrans (vi, vn) : false
142- f = if is_transformed
142+ # Determine whether to insert a transformed value into the VarInfo.
143+ # If the VarInfo alrady had a value for this variable, we will
144+ # keep the same linked status as in the original VarInfo. If not, we
145+ # check the rest of the VarInfo to see if other variables are linked.
146+ # istrans(vi) returns true if vi is nonempty and all variables in vi
147+ # are linked.
148+ insert_transformed_value = in_varinfo ? istrans (vi, vn) : istrans (vi)
149+ f = if insert_transformed_value
143150 to_linked_internal_transform (vi, vn, dist)
144151 else
145152 to_internal_transform (vi, vn, dist)
@@ -150,7 +157,7 @@ function tilde_assume(
150157 # always converts x to a vector, i.e., if dist is univariate, f(x) will be
151158 # a vector of length 1. It would be nice if we could unify these.
152159 y = f (x)
153- logjac = logabsdetjac (is_transformed ? Bijectors . bijector (dist) : identity, x)
160+ logjac = logabsdetjac (insert_transformed_value ? link_transform (dist) : identity, x)
154161 # Add the new value to the VarInfo. `push!!` errors if the value already
155162 # exists, hence the need for setindex!!
156163 if in_varinfo
0 commit comments