1313import pickle
1414import shutil
1515import sys
16- import timeit
1716from collections import OrderedDict
1817from tempfile import mkdtemp
1918
@@ -2179,15 +2178,13 @@ def scan_fn():
21792178@pytest .mark .skipif (
21802179 not config .cxx , reason = "G++ not available, so we need to skip this test."
21812180)
2182- def test_cython_performance ():
2181+ def test_cython_performance (benchmark ):
21832182
21842183 # This implicitly confirms that the Cython version is being used
21852184 from pytensor .scan import scan_perform_ext # noqa: F401
21862185
21872186 # Python usually out-performs PyTensor below 100 iterations
21882187 N = 200
2189- n_timeit = 50
2190-
21912188 M = - 1 / np .arange (1 , 11 ).astype (config .floatX )
21922189 r = np .arange (N * 10 ).astype (config .floatX ).reshape (N , 10 )
21932190
@@ -2216,17 +2213,11 @@ def f_py():
22162213 # Make sure we're actually computing a `Scan`
22172214 assert any (isinstance (node .op , Scan ) for node in f_cvm .maker .fgraph .apply_nodes )
22182215
2219- cvm_res = f_cvm ( )
2216+ cvm_res = benchmark ( f_cvm )
22202217
22212218 # Make sure the results are the same between the two implementations
22222219 assert np .allclose (cvm_res , py_res )
22232220
2224- python_duration = timeit .timeit (lambda : f_py (), number = n_timeit )
2225- cvm_duration = timeit .timeit (lambda : f_cvm (), number = n_timeit )
2226- print (f"python={ python_duration } , cvm={ cvm_duration } " )
2227-
2228- assert cvm_duration <= python_duration
2229-
22302221
22312222@config .change_flags (mode = "FAST_COMPILE" , compute_test_value = "raise" )
22322223def test_compute_test_values ():
@@ -2662,7 +2653,7 @@ def numpy_implementation(vsample):
26622653 n_result = numpy_implementation (v_vsample )
26632654 utt .assert_allclose (t_result , n_result )
26642655
2665- def test_reordering (self ):
2656+ def test_reordering (self , benchmark ):
26662657 """Test re-ordering of inputs.
26672658
26682659 some rnn with multiple outputs and multiple inputs; other
@@ -2722,14 +2713,14 @@ def f_rnn_cmpl(u1_t, u2_t, x_tm1, y_tm1, y_tm3, W_in1):
27222713 v_x [i ] = np .dot (v_u1 [i ], vW_in1 ) + v_u2 [i ] * vW_in2 + np .dot (v_x [i - 1 ], vW )
27232714 v_y [i ] = np .dot (v_x [i - 1 ], vWout ) + v_y [i - 1 ]
27242715
2725- (pytensor_dump1 , pytensor_dump2 , pytensor_x , pytensor_y ) = f4 (
2726- v_u1 , v_u2 , v_x0 , v_y0 , vW_in1
2716+ (pytensor_dump1 , pytensor_dump2 , pytensor_x , pytensor_y ) = benchmark (
2717+ f4 , v_u1 , v_u2 , v_x0 , v_y0 , vW_in1
27272718 )
27282719
27292720 utt .assert_allclose (pytensor_x , v_x )
27302721 utt .assert_allclose (pytensor_y , v_y )
27312722
2732- def test_scan_as_tensor_on_gradients (self ):
2723+ def test_scan_as_tensor_on_gradients (self , benchmark ):
27332724 to_scan = dvector ("to_scan" )
27342725 seq = dmatrix ("seq" )
27352726 f1 = dscalar ("f1" )
@@ -2743,7 +2734,12 @@ def scanStep(prev, seq, f1):
27432734 function (inputs = [to_scan , seq , f1 ], outputs = scanned , allow_input_downcast = True )
27442735
27452736 t_grad = grad (scanned .sum (), wrt = [to_scan , f1 ], consider_constant = [seq ])
2746- function (inputs = [to_scan , seq , f1 ], outputs = t_grad , allow_input_downcast = True )
2737+ benchmark (
2738+ function ,
2739+ inputs = [to_scan , seq , f1 ],
2740+ outputs = t_grad ,
2741+ allow_input_downcast = True ,
2742+ )
27472743
27482744 def caching_nsteps_by_scan_op (self ):
27492745 W = matrix ("weights" )
@@ -3060,7 +3056,7 @@ def inner_fn(tap_m3, tap_m2, tap_m1):
30603056 utt .assert_allclose (outputs , expected_outputs )
30613057
30623058 @pytest .mark .slow
3063- def test_hessian_bug_grad_grad_two_scans (self ):
3059+ def test_hessian_bug_grad_grad_two_scans (self , benchmark ):
30643060 # Bug reported by Bitton Tenessi
30653061 # NOTE : The test to reproduce the bug reported by Bitton Tenessi
30663062 # was modified from its original version to be faster to run.
@@ -3094,7 +3090,7 @@ def loss_inner(sum_inner, W):
30943090 H = hessian (cost , W )
30953091 print ("." , file = sys .stderr )
30963092 f = function ([W , n_steps ], H )
3097- f ( np .ones ((8 ,), dtype = "float32" ), 1 )
3093+ benchmark ( f , np .ones ((8 ,), dtype = "float32" ), 1 )
30983094
30993095 def test_grad_connectivity_matrix (self ):
31003096 def inner_fn (x_tm1 , y_tm1 , z_tm1 ):
@@ -3710,7 +3706,7 @@ def f_rnn_cmpl(u1_t, u2_t, x_tm1, y_tm1, W_in1):
37103706 utt .assert_allclose (pytensor_x , v_x )
37113707 utt .assert_allclose (pytensor_y , v_y )
37123708
3713- def test_multiple_outs_taps (self ):
3709+ def test_multiple_outs_taps (self , benchmark ):
37143710 l = 5
37153711 rng = np .random .default_rng (utt .fetch_seed ())
37163712
@@ -3805,6 +3801,8 @@ def f_rnn_cmpl(u1_t, u2_tm1, u2_t, u2_tp1, x_tm1, y_tm1, y_tm3, W_in1):
38053801 np .testing .assert_almost_equal (res [1 ], ny1 )
38063802 np .testing .assert_almost_equal (res [2 ], ny2 )
38073803
3804+ benchmark (f , v_u1 , v_u2 , v_x0 , v_y0 , vW_in1 )
3805+
38083806 def _grad_mout_helper (self , n_iters , mode ):
38093807 rng = np .random .default_rng (utt .fetch_seed ())
38103808 n_hid = 3
0 commit comments