|
80 | 80 | Builds the `model_info` dictionary from the model's expression. |
81 | 81 | """ |
82 | 82 | function build_model_info(input_expr) |
83 | | - # Extract model name (:name), arguments (:args), (:kwargs) and definition (:body) |
84 | | - modeldef = MacroTools.splitdef(input_expr) |
85 | | - # Function body of the model is empty |
| 83 | + # Break up the model definition and extract its name, arguments, and function body |
| 84 | + modeldef = ExprTools.splitdef(input_expr) |
| 85 | + |
| 86 | + # Print a warning if function body of the model is empty |
86 | 87 | warn_empty(modeldef[:body]) |
87 | | - # Construct model_info dictionary |
| 88 | + |
| 89 | + ## Construct model_info dictionary |
| 90 | + |
| 91 | + # Shortcut if the model does not have any arguments |
| 92 | + if !haskey(modeldef, :args) |
| 93 | + modelinfo = Dict( |
| 94 | + :name => modeldef[:name], |
| 95 | + :main_body => modeldef[:body], |
| 96 | + :arg_syms => [], |
| 97 | + :args_nt => NamedTuple(), |
| 98 | + :defaults_nt => NamedTuple(), |
| 99 | + :args => [], |
| 100 | + :modeldef => modeldef, |
| 101 | + ) |
| 102 | + return modelinfo |
| 103 | + end |
88 | 104 |
|
89 | 105 | # Extracting the argument symbols from the model definition |
90 | 106 | arg_syms = map(modeldef[:args]) do arg |
@@ -158,7 +174,7 @@ function build_model_info(input_expr) |
158 | 174 | :args_nt => args_nt, |
159 | 175 | :defaults_nt => defaults_nt, |
160 | 176 | :args => args, |
161 | | - :whereparams => modeldef[:whereparams] |
| 177 | + :modeldef => modeldef, |
162 | 178 | ) |
163 | 179 |
|
164 | 180 | return model_info |
@@ -318,45 +334,60 @@ hasmissing(T::Type) = false |
318 | 334 | Builds the output expression. |
319 | 335 | """ |
320 | 336 | function build_output(model_info) |
321 | | - # Arguments with default values |
| 337 | + ## Build the anonymous evaluator from the user-provided model definition |
| 338 | + |
| 339 | + # Remove the name and use `function (....)` syntax |
| 340 | + modeldef = model_info[:modeldef] |
| 341 | + delete!(modeldef, :name) |
| 342 | + modeldef[:head] = :function |
| 343 | + |
| 344 | + # Define the input arguments (positional + keyword arguments), without default values |
| 345 | + origargs = map(vcat(get(modeldef, :args, Any[]), get(modeldef, :kwargs, Any[]))) do arg |
| 346 | + Meta.isexpr(arg, :kw) && length(arg.args) >= 1 ? arg.args[1] : arg |
| 347 | + end |
| 348 | + |
| 349 | + # Add our own arguments |
| 350 | + newargs = Any[:(_rng::$(Random.AbstractRNG)), |
| 351 | + :(_model::$(DynamicPPL.Model)), |
| 352 | + :(_varinfo::$(DynamicPPL.AbstractVarInfo)), |
| 353 | + :(_sampler::$(DynamicPPL.AbstractSampler)), |
| 354 | + :(_context::$(DynamicPPL.AbstractContext))] |
| 355 | + combinedargs = vcat(newargs, origargs) |
| 356 | + |
| 357 | + # Delete keyword arguments and update positional arguments |
| 358 | + delete!(modeldef, :kwargs) |
| 359 | + modeldef[:args] = combinedargs |
| 360 | + |
| 361 | + # Replace function body |
| 362 | + modeldef[:body] = model_info[:main_body] |
| 363 | + |
| 364 | + ## Extract other relevant information |
| 365 | + |
| 366 | + # All arguments with default values (if existent) |
322 | 367 | args = model_info[:args] |
323 | | - # Argument symbols without default values |
324 | | - arg_syms = model_info[:arg_syms] |
325 | | - # Arguments namedtuple |
| 368 | + # Named tuple of all arguments |
326 | 369 | args_nt = model_info[:args_nt] |
327 | | - # Default values of the arguments |
328 | | - # Arguments namedtuple |
| 370 | + |
| 371 | + # Named tuple of the default values of the arguments |
329 | 372 | defaults_nt = model_info[:defaults_nt] |
330 | | - # Where parameters |
331 | | - whereparams = model_info[:whereparams] |
332 | | - # Model generator name |
| 373 | + |
| 374 | + # Model name |
333 | 375 | model = model_info[:name] |
334 | | - # Main body of the model |
335 | | - main_body = model_info[:main_body] |
336 | 376 |
|
337 | | - unwrap_data_expr = Expr(:block) |
338 | | - for var in arg_syms |
339 | | - push!(unwrap_data_expr.args, |
340 | | - :($var = $(DynamicPPL.matchingvalue)(_sampler, _varinfo, _model.args.$var))) |
| 377 | + # Define model definition with only keyword arguments |
| 378 | + if isempty(args) |
| 379 | + model_kwform = () |
| 380 | + else |
| 381 | + # All arguments without default values (i.e., only symbols) |
| 382 | + arg_syms = model_info[:arg_syms] |
| 383 | + |
| 384 | + model_kwform = (:($model(; $(args...)) = $model($(arg_syms...))),) |
341 | 385 | end |
342 | 386 |
|
343 | | - model_kwform = isempty(args) ? () : (:($model(;$(args...)) = $model($(arg_syms...))),) |
344 | 387 | @gensym(evaluator) |
345 | | - |
346 | 388 | return quote |
347 | 389 | $(Base).@__doc__ function $model($(args...)) |
348 | | - $evaluator = let |
349 | | - function ( |
350 | | - _rng::$(Random.AbstractRNG), |
351 | | - _model::$(DynamicPPL.Model), |
352 | | - _varinfo::$(DynamicPPL.AbstractVarInfo), |
353 | | - _sampler::$(DynamicPPL.AbstractSampler), |
354 | | - _context::$(DynamicPPL.AbstractContext), |
355 | | - ) |
356 | | - $unwrap_data_expr |
357 | | - $main_body |
358 | | - end |
359 | | - end |
| 390 | + $evaluator = $(ExprTools.combinedef(modeldef)) |
360 | 391 | return $(DynamicPPL.Model)($evaluator, $args_nt, $defaults_nt) |
361 | 392 | end |
362 | 393 | $(model_kwform...) |
|
0 commit comments