@@ -9,13 +9,14 @@
 from pytensor.compile.mode import get_default_mode
 from pytensor.configdefaults import config
 from pytensor.gradient import grad, jacobian
-from pytensor.graph.basic import Constant, equal_computations
+from pytensor.graph.basic import Constant, ancestors, equal_computations
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.replace import clone_replace
 from pytensor.scan.op import Scan
 from pytensor.scan.rewriting import ScanInplaceOptimizer, ScanMerge
 from pytensor.scan.utils import until
 from pytensor.tensor import stack
+from pytensor.tensor.basic import AllocEmpty
 from pytensor.tensor.blas import Dot22
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.math import Dot, dot, sigmoid, tanh
@@ -1207,7 +1208,7 @@ def test_inplace3(self):
 
 
 class TestSaveMem:
-    mode = get_default_mode().including("scan_save_mem")
+    mode = get_default_mode().including("scan_save_mem").excluding("scan_pushout")
 
     def test_save_mem(self):
         rng = np.random.default_rng(utt.fetch_seed())
@@ -1371,7 +1372,7 @@ def test_save_mem_cannot_reduce_constant_number_of_steps(self):
         )
 
     def test_save_mem_store_steps(self):
-        def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
+        def step(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
             return (
                 u_t + 1.0,
                 u_t + 2.0,
@@ -1388,7 +1389,7 @@ def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
         x30 = vector("x30")
         x40 = scalar("x40")
         [x1, x2, x3, x4, x5, x6, x7], updates = scan(
-            f_rnn,
+            step,
             u,
             [
                 None,
@@ -1404,7 +1405,7 @@ def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
             go_backwards=False,
         )
 
-        f2 = function(
+        f = function(
             [u, x10, x20, x30, x40],
             [x1[-7], x2[-3:-1], x3[-6:], x4[-1], x5[-1]],
             updates=updates,
@@ -1417,13 +1418,51 @@ def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
         v_u = rng.uniform(-5.0, 5.0, size=(20,))
 
         # compute the output in numpy
-        tx1, tx2, tx3, tx4, tx5 = f2(v_u, [0, 0], 0, [0, 0], 0)
-
-        utt.assert_allclose(tx1, v_u[-7] + 1.0)
-        utt.assert_allclose(tx2, v_u[-3:-1] + 2.0)
-        utt.assert_allclose(tx3, v_u[-6:] + 3.0)
-        utt.assert_allclose(tx4, v_u[-1] + 4.0)
-        utt.assert_allclose(tx5, v_u[-1] + 5.0)
+        tx1, tx2, tx3, tx4, tx5 = f(v_u, [0, 0], 0, [0, 0], 0)
+        rtol = 1e-7 if config.floatX == "float64" else 1e-6
+        np.testing.assert_allclose(tx1, v_u[-7] + 1.0, rtol=rtol)
+        np.testing.assert_allclose(tx2, v_u[-3:-1] + 2.0, rtol=rtol)
+        np.testing.assert_allclose(tx3, v_u[-6:] + 3.0, rtol=rtol)
+        np.testing.assert_allclose(tx4, v_u[-1] + 4.0, rtol=rtol)
+        np.testing.assert_allclose(tx5, v_u[-1] + 5.0, rtol=rtol)
+
+        # Confirm reduction in buffer sizes
+        [scan_node] = [
+            node for node in f.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
+        ]
+        # x6 and x7 are dropped because they are not used
+        [n_steps, seq, x4_buffer, x5_buffer, x1_len, x2_len, x3_len] = scan_node.inputs
+        [x4_underlying_alloc] = [
+            var
+            for var in ancestors([x4_buffer])
+            if var.owner and isinstance(var.owner.op, AllocEmpty)
+        ]
+        [x5_underlying_alloc] = [
+            var
+            for var in ancestors([x5_buffer])
+            if var.owner and isinstance(var.owner.op, AllocEmpty)
+        ]
+        buffer_lengths = pytensor.function(
+            [u, x10, x20, x30, x40],
+            [
+                x1_len,
+                x2_len,
+                x3_len,
+                x4_underlying_alloc.shape[0],
+                x5_underlying_alloc.shape[0],
+            ],
+            accept_inplace=True,
+            on_unused_input="ignore",
+            allow_input_downcast=True,
+        )(v_u, [0, 0], 0, [0, 0], 0)
+        # ScanSaveMem keeps +1 entries to handle taps with preallocated outputs
+        assert [int(i) for i in buffer_lengths] == [
+            7,  # entry -7 of a map variable is kept, we need at least that many
+            3,  # entries [-3, -2] of a map variable are kept, we need at least 3
+            6,  # last six entries of a map variable are kept
+            2 + 1,  # last entry of a double tap variable is kept
+            1 + 1,  # last entry of a single tap variable is kept
+        ]
 
     def test_savemem_does_not_duplicate_number_of_scan_nodes(self):
         var = pt.ones(())
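
A note on the new assertions (not part of the commit): they check that when only the last few entries of each scan output are requested, the `scan_save_mem` rewrite shrinks the Scan buffers instead of storing all 20 steps. The sketch below is a minimal, self-contained illustration of the same property; the toy graph and the `+ 1.0` step function are invented for this example, not taken from the test suite.

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_default_mode
from pytensor.scan.op import Scan

# Toy map over a vector where only the final step is requested, so the
# scan_save_mem rewrite can truncate the output buffer.
x = pt.vector("x")
ys, _ = pytensor.scan(lambda x_t: x_t + 1.0, sequences=[x])
f = pytensor.function(
    [x],
    ys[-1],
    mode=get_default_mode().including("scan_save_mem"),
    allow_input_downcast=True,
)

# Locate the single Scan node in the compiled graph, the same way the
# test does via f.maker.fgraph.apply_nodes.
[scan_node] = [
    node for node in f.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
]
print(f(np.arange(5.0)))  # 5.0: the last input element plus one
```

From here, the test walks `scan_node.inputs` and the `AllocEmpty` ancestors of the output buffers to read off the allocated lengths, which is why the commit adds `ancestors` and `AllocEmpty` to the imports.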