169169OptimLogDensity(model; adtype=adtype)
170170```
171171
172+ Here, `ctx` must be a context that contains an `OptimizationContext` as its
173+ leaf.
174+
172175If not specified, `adtype` defaults to `AutoForwardDiff()`.
173176
174177An OptimLogDensity does not, in itself, obey the LogDensityProblems interface.
@@ -189,24 +192,40 @@ optim_ld(z) # returns -logp
189192```
190193"""
191194struct OptimLogDensity{
192- M<: DynamicPPL.Model ,
193- F<: Function ,
194- V<: DynamicPPL.AbstractVarInfo ,
195- C<: DynamicPPL.AbstractContext ,
196- AD<: ADTypes.AbstractADType ,
195+ M<: DynamicPPL.Model ,F<: Function ,V<: DynamicPPL.AbstractVarInfo ,AD<: ADTypes.AbstractADType
197196}
198- ldf:: DynamicPPL.LogDensityFunction{M,F,V,C,AD}
199- end
197+ ldf:: DynamicPPL.LogDensityFunction{M,F,V,AD}
200198
201- function OptimLogDensity (
202- model :: DynamicPPL.Model ,
203- getlogdensity :: Function ,
204- vi :: DynamicPPL.AbstractVarInfo = DynamicPPL . ldf_default_varinfo (model, getlogdensity);
205- adtype = AutoForwardDiff () ,
206- )
207- return OptimLogDensity (
208- DynamicPPL . LogDensityFunction (model, getlogdensity, vi; adtype= adtype)
199+ # Inner constructors enforce that the model has an OptimizationContext as
200+ # its leaf context.
201+ function OptimLogDensity (
202+ model :: DynamicPPL.Model ,
203+ getlogdensity :: Function ,
204+ vi :: DynamicPPL.VarInfo ,
205+ ctx :: OptimizationContext ;
206+ adtype:: ADTypes.AbstractADType = Turing . DEFAULT_ADTYPE,
209207 )
208+ new_context = DynamicPPL. setleafcontext (model, ctx)
209+ new_model = contextualize (model, new_context)
210+ return new {typeof(new_model),typeof(getlogdensity),typeof(vi),typeof(adtype)} (
211+ DynamicPPL. LogDensityFunction (new_model, getlogdensity, vi; adtype= adtype)
212+ )
213+ end
214+ function OptimLogDensity (
215+ model:: DynamicPPL.Model ,
216+ getlogdensity:: Function ,
217+ ctx:: OptimizationContext ;
218+ adtype:: ADTypes.AbstractADType = Turing. DEFAULT_ADTYPE,
219+ )
220+ # No varinfo
221+ return OptimLogDensity (
222+ model,
223+ getlogdensity,
224+ DynamicPPL. ldf_default_varinfo (model, getlogdensity),
225+ ctx;
226+ adtype= adtype,
227+ )
228+ end
210229end
211230
212231"""
0 commit comments