Closed
Description
In Turing, StatsBase.predict is overloaded to dispatch on DynamicPPL.Model and MCMCChains.Chains (https://github.com/TuringLang/Turing.jl/blob/d76d914231db0198b99e5ca5d69d80934ee016b3/src/inference/Inference.jl#L532-L564). This effectively does batch prediction: it conditions the model on each draw in the chains and calls rand on the conditioned model. We also want to do the same thing for InferenceData (see #465).
It would be convenient if StatsBase.predict were added to the DynamicPPL API. StatsBase is already an indirect dependency of this package. As suggested by @devmotion in #465 (comment), the default implementation could simply call rand on a conditioned model:
StatsBase.predict(rng::AbstractRNG, model::DynamicPPL.Model, x) = rand(rng, condition(model, x))
StatsBase.predict(model::DynamicPPL.Model, x) = predict(Random.default_rng(), model, x)
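To illustrate how the proposed default would give batch prediction over a set of draws, here is a minimal sketch. The model `demo`, the helper `predict_sketch`, and the hand-written `draws` vector are hypothetical names made up for this example; it assumes that `rand(rng, model)` on a conditioned model returns a NamedTuple of the remaining (unconditioned) variables, which may differ between DynamicPPL versions. It does not overload StatsBase.predict itself, to avoid type piracy outside the package.

```julia
using DynamicPPL, Distributions, Random

# Hypothetical toy model: `mu` is a parameter, `y` is the quantity to predict.
@model function demo()
    mu ~ Normal(0, 1)
    y ~ Normal(mu, 1)
end

# Same body as the proposed default, under a hypothetical name.
predict_sketch(rng::Random.AbstractRNG, model::DynamicPPL.Model, x) =
    rand(rng, condition(model, x))

rng = Random.MersenneTwister(1)

# Stand-in for posterior draws, e.g. one NamedTuple per row of a chain.
draws = [(mu = 0.0,), (mu = 1.0,)]

# Batch prediction: condition on each draw, then sample the rest.
preds = [predict_sketch(rng, demo(), d) for d in draws]
```

A Chains or InferenceData method could then, under the same assumptions, reduce to iterating over draws and calling the single-draw method.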