@@ -629,3 +629,47 @@ def test_data_container():
629629
630630 ip = marginal_m .initial_point ()
631631 np .testing .assert_allclose (logp_fn (ip ), ref_logp_fn (ip ))
632+
633+
@pytest.mark.parametrize("univariate", (True, False))
def test_vector_univariate_mixture(univariate):
    """Marginalizing Bernoulli indicator(s) over a switch-based CustomDist
    must reproduce the logp of the equivalent explicit pymc mixture.

    univariate=True: ``idx`` has shape (2,) — one indicator per entry, so the
    reference is an independent ``NormalMixture`` of length 2.
    univariate=False: ``idx`` is a scalar shared by both entries, so the
    reference is a length-2 ``Mixture`` of two ``MvNormal`` components.
    """
    with MarginalModel() as m:
        idx = pm.Bernoulli("idx", p=0.5, shape=(2,) if univariate else ())

        def dist(idx, size):
            # Component means at +/-10 keep the two mixture modes well
            # separated, so logp differences between components are large.
            return pm.math.switch(
                pm.math.eq(idx, 0),
                pm.Normal.dist([-10, -10], 1),
                pm.Normal.dist([10, 10], 1),
            )

        pm.CustomDist("norm", idx, dist=dist)

    m.marginalize(idx)
    logp_fn = m.compile_logp()

    if univariate:
        with pm.Model() as ref_m:
            pm.NormalMixture("norm", w=[0.5, 0.5], mu=[[-10, 10], [-10, 10]], shape=(2,))
    else:
        with pm.Model() as ref_m:
            pm.Mixture(
                "norm",
                w=[0.5, 0.5],
                comp_dists=[
                    pm.MvNormal.dist([-10, -10], np.eye(2)),
                    pm.MvNormal.dist([10, 10], np.eye(2)),
                ],
                shape=(2,),
            )
    ref_logp_fn = ref_m.compile_logp()

    # All four sign combinations: the original listed [-10, 10] twice and
    # never exercised [10, -10].
    for test_value in (
        [-10, -10],
        [10, 10],
        [-10, 10],
        [10, -10],
    ):
        pt = {"norm": test_value}
        np.testing.assert_allclose(logp_fn(pt), ref_logp_fn(pt))
0 commit comments