Skip to content

Commit dbc7a22

Browse files
committed
workaround Mooncake segfault
1 parent 6d70a9e commit dbc7a22

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

src/contexts/init.jl

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,6 @@ abstract type AbstractInitStrategy end
1212
"""
1313
InitValue{T,F}
1414
15-
A wrapper type representing a value of type `T`. The function `F` indicates what transform
16-
to apply to the value to convert it back to the unlinked space. If `value` is already in
17-
unlinked space, then `transform` can be `identity`.
1815
"""
1916
struct InitValue{T,F}
2017
value::T
@@ -26,7 +23,11 @@ end
2623
2724
Generate a new value for a random variable with the given distribution.
2825
29-
This function must return a `InitValue`.
26+
This function must return a tuple of:
27+
28+
- the generated value
29+
- a function that transforms the generated value back to the unlinked space. If the value is
30+
already in unlinked space, then this should be `identity`.
3031
"""
3132
function init end
3233

@@ -37,7 +38,7 @@ Obtain new values by sampling from the prior distribution.
3738
"""
3839
struct InitFromPrior <: AbstractInitStrategy end
3940
function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, ::InitFromPrior)
40-
return InitValue(rand(rng, dist), identity)
41+
return rand(rng, dist), identity
4142
end
4243

4344
"""
@@ -77,7 +78,7 @@ function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::InitFro
7778
if x isa Array{<:Any,0}
7879
x = x[]
7980
end
80-
return InitValue(x, identity)
81+
return x, identity
8182
end
8283

8384
"""
@@ -128,7 +129,7 @@ function init(
128129
else
129130
# TODO(penelopeysm): Since x is user-supplied, maybe we could also
130131
# check here that the type / size of x matches the dist?
131-
InitValue(x, identity)
132+
x, identity
132133
end
133134
else
134135
p.fallback === nothing && error("No value was provided for the variable `$(vn)`.")
@@ -201,7 +202,7 @@ function init(
201202
else
202203
from_vec_transform(dist)
203204
end
204-
return InitValue((@view vr.vect[range_and_linked.range]), transform)
205+
return (@view vr.vect[range_and_linked.range]), transform
205206
end
206207

207208
"""
@@ -233,8 +234,8 @@ function tilde_assume!!(
233234
ctx::InitContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo
234235
)
235236
in_varinfo = haskey(vi, vn)
236-
init_val = init(ctx.rng, vn, dist, ctx.strategy)
237-
x, inv_logjac = with_logabsdet_jacobian(init_val.transform, init_val.value)
237+
val, transform = init(ctx.rng, vn, dist, ctx.strategy)
238+
x, inv_logjac = with_logabsdet_jacobian(transform, val)
238239
# Determine whether to insert a transformed value into the VarInfo.
239240
# If the VarInfo alrady had a value for this variable, we will
240241
# keep the same linked status as in the original VarInfo. If not, we

0 commit comments

Comments
 (0)