@@ -1049,7 +1049,7 @@ def test_car_moment(self, mu, size, expected):
10491049 )
10501050 def test_mvstudentt_moment (self , nu , mu , cov , size , expected ):
10511051 with pm .Model () as model :
1052- x = pm .MvStudentT ("x" , nu = nu , mu = mu , cov = cov , size = size )
1052+ x = pm .MvStudentT ("x" , nu = nu , mu = mu , scale = cov , size = size )
10531053
10541054 # MvStudentT logp is only impemented for up to 2D variables
10551055 assert_moment_is_expected (model , expected , check_finite_logp = x .ndim < 3 )
@@ -1369,28 +1369,28 @@ def test_issue_3706(self):
13691369
13701370
13711371class TestMvStudentTCov (BaseTestDistributionRandom ):
1372- def mvstudentt_rng_fn (self , size , nu , mu , cov , rng ):
1373- mv_samples = rng .multivariate_normal (np .zeros_like (mu ), cov , size = size )
1372+ def mvstudentt_rng_fn (self , size , nu , mu , scale , rng ):
1373+ mv_samples = rng .multivariate_normal (np .zeros_like (mu ), scale , size = size )
13741374 chi2_samples = rng .chisquare (nu , size = size )
13751375 return (mv_samples / np .sqrt (chi2_samples [:, None ] / nu )) + mu
13761376
13771377 pymc_dist = pm .MvStudentT
13781378 pymc_dist_params = {
13791379 "nu" : 5 ,
13801380 "mu" : np .array ([1.0 , 2.0 ]),
1381- "cov " : np .array ([[2.0 , 0.0 ], [0.0 , 3.5 ]]),
1381+ "scale " : np .array ([[2.0 , 0.0 ], [0.0 , 3.5 ]]),
13821382 }
13831383 expected_rv_op_params = {
13841384 "nu" : 5 ,
13851385 "mu" : np .array ([1.0 , 2.0 ]),
1386- "cov " : np .array ([[2.0 , 0.0 ], [0.0 , 3.5 ]]),
1386+ "scale " : np .array ([[2.0 , 0.0 ], [0.0 , 3.5 ]]),
13871387 }
13881388 sizes_to_check = [None , (1 ), (2 , 3 )]
13891389 sizes_expected = [(2 ,), (1 , 2 ), (2 , 3 , 2 )]
13901390 reference_dist_params = {
13911391 "nu" : 5 ,
13921392 "mu" : np .array ([1.0 , 2.0 ]),
1393- "cov " : np .array ([[2.0 , 0.0 ], [0.0 , 3.5 ]]),
1393+ "scale " : np .array ([[2.0 , 0.0 ], [0.0 , 3.5 ]]),
13941394 }
13951395 reference_dist = lambda self : ft .partial (self .mvstudentt_rng_fn , rng = self .get_random_state ())
13961396 checks_to_run = [
@@ -1409,29 +1409,29 @@ def check_errors(self):
14091409 "mvstudentt" ,
14101410 nu = np .array ([1 , 2 ]),
14111411 mu = np .ones (2 ),
1412- cov = np .full ((2 , 2 ), np .ones (2 )),
1412+ scale = np .full ((2 , 2 ), np .ones (2 )),
14131413 )
14141414
14151415 def check_mu_broadcast_helper (self ):
14161416 """Test that mu is broadcasted to the shape of cov"""
1417- x = pm .MvStudentT .dist (nu = 4 , mu = 1 , cov = np .eye (3 ))
1417+ x = pm .MvStudentT .dist (nu = 4 , mu = 1 , scale = np .eye (3 ))
14181418 mu = x .owner .inputs [4 ]
14191419 assert mu .eval ().shape == (3 ,)
14201420
1421- x = pm .MvStudentT .dist (nu = 4 , mu = np .ones (1 ), cov = np .eye (3 ))
1421+ x = pm .MvStudentT .dist (nu = 4 , mu = np .ones (1 ), scale = np .eye (3 ))
14221422 mu = x .owner .inputs [4 ]
14231423 assert mu .eval ().shape == (3 ,)
14241424
1425- x = pm .MvStudentT .dist (nu = 4 , mu = np .ones ((1 , 1 )), cov = np .eye (3 ))
1425+ x = pm .MvStudentT .dist (nu = 4 , mu = np .ones ((1 , 1 )), scale = np .eye (3 ))
14261426 mu = x .owner .inputs [4 ]
14271427 assert mu .eval ().shape == (1 , 3 )
14281428
1429- x = pm .MvStudentT .dist (nu = 4 , mu = np .ones ((10 , 1 )), cov = np .eye (3 ))
1429+ x = pm .MvStudentT .dist (nu = 4 , mu = np .ones ((10 , 1 )), scale = np .eye (3 ))
14301430 mu = x .owner .inputs [4 ]
14311431 assert mu .eval ().shape == (10 , 3 )
14321432
14331433 # Cov is artificually limited to being 2D
1434- # x = pm.MvStudentT.dist(nu=4, mu=np.ones((10, 1)), cov =np.full((2, 3, 3), np.eye(3)))
1434+ # x = pm.MvStudentT.dist(nu=4, mu=np.ones((10, 1)), scale =np.full((2, 3, 3), np.eye(3)))
14351435 # mu = x.owner.inputs[4]
14361436 # assert mu.eval().shape == (10, 2, 3)
14371437
@@ -1446,7 +1446,7 @@ class TestMvStudentTChol(BaseTestDistributionRandom):
14461446 expected_rv_op_params = {
14471447 "nu" : 5 ,
14481448 "mu" : np .array ([1.0 , 2.0 ]),
1449- "cov " : quaddist_matrix (chol = pymc_dist_params ["chol" ]).eval (),
1449+ "scale " : quaddist_matrix (chol = pymc_dist_params ["chol" ]).eval (),
14501450 }
14511451 checks_to_run = ["check_pymc_params_match_rv_op" ]
14521452
0 commit comments