@@ -148,32 +148,35 @@ def test_sample_does_not_rely_on_external_global_seeding(self):
148148 assert np .all (idata12 ["x" ] != idata22 ["x" ])
149149 assert np .all (idata13 ["x" ] != idata23 ["x" ])
150150
151- def test_sample_init (self ):
151+ @pytest .mark .parametrize (
152+ "init" ,
153+ (
154+ "advi" ,
155+ "advi_map" ,
156+ "map" ,
157+ "adapt_diag" ,
158+ "jitter+adapt_diag" ,
159+ "jitter+adapt_diag_grad" ,
160+ "adapt_full" ,
161+ "jitter+adapt_full" ,
162+ ),
163+ )
164+ def test_sample_init (self , init ):
152165 with self .model :
153- for init in (
154- "advi" ,
155- "advi_map" ,
156- "map" ,
157- "adapt_diag" ,
158- "jitter+adapt_diag" ,
159- "jitter+adapt_diag_grad" ,
160- "adapt_full" ,
161- "jitter+adapt_full" ,
162- ):
163- kwargs = {
164- "init" : init ,
165- "tune" : 120 ,
166- "n_init" : 1000 ,
167- "draws" : 50 ,
168- "random_seed" : 20160911 ,
169- }
170- with warnings .catch_warnings (record = True ) as rec :
171- warnings .filterwarnings ("ignore" , ".*number of samples.*" , UserWarning )
172- if init .endswith ("adapt_full" ):
173- with pytest .warns (UserWarning , match = "experimental feature" ):
174- pm .sample (** kwargs )
175- else :
176- pm .sample (** kwargs )
166+ kwargs = {
167+ "init" : init ,
168+ "tune" : 120 ,
169+ "n_init" : 1000 ,
170+ "draws" : 50 ,
171+ "random_seed" : 20160911 ,
172+ }
173+ with warnings .catch_warnings (record = True ) as rec :
174+ warnings .filterwarnings ("ignore" , ".*number of samples.*" , UserWarning )
175+ if init .endswith ("adapt_full" ):
176+ with pytest .warns (UserWarning , match = "experimental feature" ):
177+ pm .sample (** kwargs , cores = 1 )
178+ else :
179+ pm .sample (** kwargs , cores = 1 )
177180
178181 def test_sample_args (self ):
179182 with self .model :
0 commit comments