@@ -410,15 +410,15 @@ def test_not_supported_marginalized_deterministic_and_potential():
410410 (None , does_not_warn ()),
411411 (UNSET , does_not_warn ()),
412412 (transforms .log , does_not_warn ()),
413- (transforms .Chain ([transforms .log , transforms .logodds ]), does_not_warn ()),
413+ (transforms .Chain ([transforms .logodds , transforms .log ]), does_not_warn ()),
414414 (
415- transforms .Interval (0 , 1 ),
415+ transforms .Interval (0 , 2 ),
416416 pytest .warns (
417417 UserWarning , match = "which depends on the marginalized idx may no longer work"
418418 ),
419419 ),
420420 (
421- transforms .Chain ([transforms .log , transforms .Interval (0 , 1 )]),
421+ transforms .Chain ([transforms .log , transforms .Interval (- 1 , 1 )]),
422422 pytest .warns (
423423 UserWarning , match = "which depends on the marginalized idx may no longer work"
424424 ),
@@ -428,7 +428,7 @@ def test_not_supported_marginalized_deterministic_and_potential():
428428def test_marginalized_transforms (transform , expected_warning ):
429429 w = [0.1 , 0.3 , 0.6 ]
430430 data = [0 , 5 , 10 ]
431- initval = 0.5 # Value that will be negative on the unconstrained space
431+ initval = 0.7 # Value that will be negative on the unconstrained space
432432
433433 with pm .Model () as m_ref :
434434 sigma = pm .Mixture (
@@ -467,7 +467,7 @@ def test_marginalized_transforms(transform, expected_warning):
467467 transform_name = "log"
468468 else :
469469 transform_name = transform .name
470- assert f"sigma_{ transform_name } __" in ip
470+ assert - np . inf < ip [ f"sigma_{ transform_name } __" ] < 0.0
471471 np .testing .assert_allclose (m .compile_logp ()(ip ), m_ref .compile_logp ()(ip ))
472472
473473
0 commit comments