@@ -488,7 +488,11 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
488488
489489 # Construct a chain with 'sampled values' of β
490490 ground_truth_β = 2
491- β_chain = MCMCChains. Chains (rand (Normal (ground_truth_β, 0.002 ), 1000 ), [:β ])
491+ β_chain = MCMCChains. Chains (
492+ rand (Normal (ground_truth_β, 0.002 ), 1000 ),
493+ [:β ];
494+ info= (; varname_to_symbol= Dict (@varname (β) => :β )),
495+ )
492496
493497 # Generate predictions from that chain
494498 xs_test = [10 + 0.1 , 10 + 2 * 0.1 ]
@@ -534,7 +538,9 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
534538 @testset " prediction from multiple chains" begin
535539 # Normal linreg model
536540 multiple_β_chain = MCMCChains. Chains (
537- reshape (rand (Normal (ground_truth_β, 0.002 ), 1000 , 2 ), 1000 , 1 , 2 ), [:β ]
541+ reshape (rand (Normal (ground_truth_β, 0.002 ), 1000 , 2 ), 1000 , 1 , 2 ),
542+ [:β ];
543+ info= (; varname_to_symbol= Dict (@varname (β) => :β )),
538544 )
539545 predictions = DynamicPPL. predict (m_lin_reg_test, multiple_β_chain)
540546 @test size (multiple_β_chain, 3 ) == size (predictions, 3 )
0 commit comments