1414from pytensor .graph import (
1515 Constant ,
1616 FunctionGraph ,
17+ Op ,
1718 RewriteDatabaseQuery ,
1819 Type ,
1920 rewrite_graph ,
2324from pytensor .printing import debugprint
2425from pytensor .tensor import (
2526 add ,
27+ dvector ,
2628 exp ,
2729 iscalar ,
2830 iscalars ,
3941from pytensor .tensor .basic import MakeVector , concatenate , expand_dims , make_vector
4042from pytensor .tensor .blas import Dot22 , Gemv
4143from pytensor .tensor .blas_c import CGemv
44+ from pytensor .tensor .blockwise import Blockwise
4245from pytensor .tensor .elemwise import DimShuffle , Elemwise
4346from pytensor .tensor .math import sum as pt_sum
4447from pytensor .tensor .rewriting .subtensor_lift import (
4548 local_subtensor_make_vector ,
46- local_subtensor_of_elemwise ,
49+ local_subtensor_of_batch_dims ,
4750 local_subtensor_shape_constant ,
4851)
4952from pytensor .tensor .shape import SpecifyShape , _shape
6063NO_OPTIMIZATION_MODE = Mode (linker = "py" , optimizer = None )
6164
6265
63- class TestLocalSubtensorOfElemwise :
66+ class TestLocalSubtensorOfBatchDims :
6467 def test_unary_multiple_clients (self ):
6568 # as test0, but we reuse the output of the elemwise
6669 # So we should not lift the subtensor
@@ -146,7 +149,7 @@ def test_multinary_multiple_clients(self):
146149 ),
147150 ],
148151 )
149- def test_local_subtensor_of_elemwise (self , original_fn , expected_fn ):
152+ def test_elemwise (self , original_fn , expected_fn ):
150153 rng = np .random .default_rng (257 )
151154 x = pt .matrix ("x" , shape = (5 , 3 ))
152155 y = pt .matrix ("y" , shape = (5 , 3 ))
@@ -165,19 +168,56 @@ def test_local_subtensor_of_elemwise(self, original_fn, expected_fn):
165168 out .eval ({x : x_test , y : y_test }, ** eval_kwargs ),
166169 )
167170
168- def test_local_subtensor_of_elemwise_multiple_clients (self ):
171+ def test_elemwise_multiple_clients (self ):
169172 x = pt .matrix ("x" , shape = (5 , 3 ))
170173 y = pt .matrix ("y" , shape = (5 , 3 ))
171174 out1 = add (x , y )
172175 out2 = out1 [0 ]
173176
174177 # Rewrite should fail when another node uses out1 directly (in this case it's an extra output)
175178 fgraph = FunctionGraph ([x , y ], [out1 , out2 ], clone = False )
176- assert local_subtensor_of_elemwise .transform (fgraph , out2 .owner ) is None
179+ assert local_subtensor_of_batch_dims .transform (fgraph , out2 .owner ) is None
177180
178181 # Otherwise it should work
179182 fgraph .remove_output (0 )
180- assert local_subtensor_of_elemwise .transform (fgraph , out2 .owner ) is not None
183+ assert local_subtensor_of_batch_dims .transform (fgraph , out2 .owner ) is not None
184+
185+ def test_blockwise (self ):
186+ class CoreTestOp (Op ):
187+ itypes = [dvector , dvector ]
188+ otypes = [dvector ]
189+
190+ def perform (self , node , inputs , output_storage ):
191+ output_storage [0 ][0 ] = np .convolve (* inputs , mode = "valid" )
192+
193+ core_test_op = CoreTestOp ()
194+ block_test_op = Blockwise (core_test_op , signature = "(a),(b)->(c)" )
195+
196+ x = tensor3 ("x" , shape = (7 , 5 , 11 ), dtype = "float64" )
197+ y = tensor ("y" , shape = (7 , 33 ), dtype = "float64" )
198+ out = block_test_op (x , y [:, None , :])
199+ assert isinstance (out .owner .op , Blockwise )
200+
201+ out_sliced = out [2 :][:, 3 :]
202+ rewritten_out_sliced = rewrite_graph (out_sliced )
203+ expected_out_sliced = block_test_op (x [2 :, 3 :], y [2 :][:, None , :])
204+ assert equal_computations ([rewritten_out_sliced ], [expected_out_sliced ])
205+
206+ rng = np .random .default_rng (191 )
207+ x_test = rng .normal (size = x .type .shape ).astype (x .type .dtype )
208+ y_test = rng .normal (size = y .type .shape ).astype (y .type .dtype )
209+ np .testing .assert_allclose (
210+ rewritten_out_sliced .eval (
211+ {x : x_test , y : y_test }, mode = NO_OPTIMIZATION_MODE
212+ ),
213+ out_sliced .eval ({x : x_test , y : y_test }, mode = NO_OPTIMIZATION_MODE ),
214+ )
215+
216+ # Check slice on core dims
217+ out_sliced = out [2 :][:, 0 ][:, 4 :]
218+ rewritten_out_sliced = rewrite_graph (out_sliced )
219+ expected_out_sliced = block_test_op (x [2 :, 0 ], y [2 :])[:, 4 :]
220+ assert equal_computations ([rewritten_out_sliced ], [expected_out_sliced ])
181221
182222
183223def test_local_subtensor_of_dot ():
0 commit comments