1212from pytensor .compile .mode import Mode , get_default_mode
1313from pytensor .configdefaults import config
1414from pytensor .gradient import grad
15- from pytensor .graph .basic import Constant
15+ from pytensor .graph .basic import Constant , equal_computations
1616from pytensor .graph .fg import FunctionGraph
1717from pytensor .graph .rewriting .basic import check_stack_trace , out2in
1818from pytensor .graph .rewriting .db import RewriteDatabaseQuery
@@ -86,113 +86,66 @@ def inputs(xbc=(0, 0), ybc=(0, 0), zbc=(0, 0)):
8686
8787class TestDimshuffleLift :
8888 def test_double_transpose (self ):
89- x , y , z = inputs ()
89+ x , * _ = inputs ()
9090 e = ds (ds (x , (1 , 0 )), (1 , 0 ))
91- g = FunctionGraph ([x ], [e ])
92- # TODO FIXME: Construct these graphs and compare them.
93- assert (
94- str (g ) == "FunctionGraph(InplaceDimShuffle{1,0}(InplaceDimShuffle{1,0}(x)))"
95- )
91+ g = FunctionGraph ([x ], [e ], clone = False )
92+ assert isinstance (g .outputs [0 ].owner .op , DimShuffle )
9693 dimshuffle_lift .rewrite (g )
97- assert str ( g ) == "FunctionGraph(x)"
94+ assert g . outputs [ 0 ] is x
9895 # no need to check_stack_trace as graph is supposed to be empty
9996
10097 def test_merge2 (self ):
101- x , y , z = inputs ()
98+ x , * _ = inputs ()
10299 e = ds (ds (x , (1 , "x" , 0 )), (2 , 0 , "x" , 1 ))
103- g = FunctionGraph ([x ], [e ])
104- # TODO FIXME: Construct these graphs and compare them.
105- assert (
106- str (g )
107- == "FunctionGraph(InplaceDimShuffle{2,0,x,1}(InplaceDimShuffle{1,x,0}(x)))"
108- ), str (g )
100+ g = FunctionGraph ([x ], [e ], clone = False )
101+ assert len (g .apply_nodes ) == 2
109102 dimshuffle_lift .rewrite (g )
110- assert str ( g ) == "FunctionGraph(InplaceDimShuffle{0,1,x,x}(x)) " , str ( g )
103+ assert equal_computations ( g . outputs , [ x . dimshuffle ( 0 , 1 , "x " , "x" )] )
111104 # Check stacktrace was copied over correctly after rewrite was applied
112105 assert check_stack_trace (g , ops_to_check = "all" )
113106
114107 def test_elim3 (self ):
115108 x , y , z = inputs ()
116109 e = ds (ds (ds (x , (0 , "x" , 1 )), (2 , 0 , "x" , 1 )), (1 , 0 ))
117- g = FunctionGraph ([x ], [e ])
118- # TODO FIXME: Construct these graphs and compare them.
119- assert str (g ) == (
120- "FunctionGraph(InplaceDimShuffle{1,0}(InplaceDimShuffle{2,0,x,1}"
121- "(InplaceDimShuffle{0,x,1}(x))))"
122- ), str (g )
110+ g = FunctionGraph ([x ], [e ], clone = False )
111+ assert isinstance (g .outputs [0 ].owner .op , DimShuffle )
123112 dimshuffle_lift .rewrite (g )
124- assert str ( g ) == "FunctionGraph(x)" , str ( g )
113+ assert g . outputs [ 0 ] is x
125114 # no need to check_stack_trace as graph is supposed to be empty
126115
127116 def test_lift (self ):
128117 x , y , z = inputs ([False ] * 1 , [False ] * 2 , [False ] * 3 )
129118 e = x + y + z
130- g = FunctionGraph ([x , y , z ], [e ])
131-
132- # TODO FIXME: Construct these graphs and compare them.
133- # It does not really matter if the DimShuffles are inplace
134- # or not.
135- init_str_g_inplace = (
136- "FunctionGraph(Elemwise{add,no_inplace}(InplaceDimShuffle{x,0,1}"
137- "(Elemwise{add,no_inplace}(InplaceDimShuffle{x,0}(x), y)), z))"
138- )
139- init_str_g_noinplace = (
140- "FunctionGraph(Elemwise{add,no_inplace}(DimShuffle{x,0,1}"
141- "(Elemwise{add,no_inplace}(DimShuffle{x,0}(x), y)), z))"
142- )
143- assert str (g ) in (init_str_g_inplace , init_str_g_noinplace ), str (g )
144-
145- rewrite_str_g_inplace = (
146- "FunctionGraph(Elemwise{add,no_inplace}(Elemwise{add,no_inplace}"
147- "(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle{x,0,1}(y)), z))"
148- )
149- rewrite_str_g_noinplace = (
150- "FunctionGraph(Elemwise{add,no_inplace}(Elemwise{add,no_inplace}"
151- "(DimShuffle{x,x,0}(x), DimShuffle{x,0,1}(y)), z))"
152- )
119+ g = FunctionGraph ([x , y , z ], [e ], clone = False )
153120 dimshuffle_lift .rewrite (g )
154- assert str (g ) in (rewrite_str_g_inplace , rewrite_str_g_noinplace ), str (g )
121+ assert equal_computations (
122+ g .outputs ,
123+ [(x .dimshuffle ("x" , "x" , 0 ) + y .dimshuffle ("x" , 0 , 1 )) + z ],
124+ )
155125 # Check stacktrace was copied over correctly after rewrite was applied
156126 assert check_stack_trace (g , ops_to_check = "all" )
157127
158128 def test_recursive_lift (self ):
159- v = vector (dtype = "float64" )
160- m = matrix (dtype = "float64" )
129+ v = vector ("v" , dtype = "float64" )
130+ m = matrix ("m" , dtype = "float64" )
161131 out = ((v + 42 ) * (m + 84 )).T
162- g = FunctionGraph ([v , m ], [out ])
163- # TODO FIXME: Construct these graphs and compare them.
164- init_str_g = (
165- "FunctionGraph(InplaceDimShuffle{1,0}(Elemwise{mul,no_inplace}"
166- "(InplaceDimShuffle{x,0}(Elemwise{add,no_inplace}"
167- "(<TensorType(float64, (?,))>, "
168- "InplaceDimShuffle{x}(TensorConstant{42}))), "
169- "Elemwise{add,no_inplace}"
170- "(<TensorType(float64, (?, ?))>, "
171- "InplaceDimShuffle{x,x}(TensorConstant{84})))))"
172- )
173- assert str (g ) == init_str_g
174- new_out = local_dimshuffle_lift .transform (g , g .outputs [0 ].owner )[0 ]
175- new_g = FunctionGraph (g .inputs , [new_out ])
176- rewrite_str_g = (
177- "FunctionGraph(Elemwise{mul,no_inplace}(Elemwise{add,no_inplace}"
178- "(InplaceDimShuffle{0,x}(<TensorType(float64, (?,))>), "
179- "InplaceDimShuffle{x,x}(TensorConstant{42})), "
180- "Elemwise{add,no_inplace}(InplaceDimShuffle{1,0}"
181- "(<TensorType(float64, (?, ?))>), "
182- "InplaceDimShuffle{x,x}(TensorConstant{84}))))"
132+ g = FunctionGraph ([v , m ], [out ], clone = False )
133+ new_out = local_dimshuffle_lift .transform (g , g .outputs [0 ].owner )
134+ assert equal_computations (
135+ new_out ,
136+ [(v .dimshuffle (0 , "x" ) + 42 ) * (m .T + 84 )],
183137 )
184- assert str (new_g ) == rewrite_str_g
185138 # Check stacktrace was copied over correctly after rewrite was applied
139+ new_g = FunctionGraph (g .inputs , new_out , clone = False )
186140 assert check_stack_trace (new_g , ops_to_check = "all" )
187141
188142 def test_useless_dimshuffle (self ):
189- x , _ , _ = inputs ()
143+ x , * _ = inputs ()
190144 e = ds (x , (0 , 1 ))
191- g = FunctionGraph ([x ], [e ])
192- # TODO FIXME: Construct these graphs and compare them.
193- assert str (g ) == "FunctionGraph(InplaceDimShuffle{0,1}(x))"
145+ g = FunctionGraph ([x ], [e ], clone = False )
146+ assert isinstance (g .outputs [0 ].owner .op , DimShuffle )
194147 dimshuffle_lift .rewrite (g )
195- assert str ( g ) == "FunctionGraph(x)"
148+ assert g . outputs [ 0 ] is x
196149 # Check stacktrace was copied over correctly after rewrite was applied
197150 assert hasattr (g .outputs [0 ].tag , "trace" )
198151
@@ -203,17 +156,10 @@ def test_dimshuffle_on_broadcastable(self):
203156 ds_y = ds (y , (2 , 1 , 0 )) # useless
204157 ds_z = ds (z , (2 , 1 , 0 )) # useful
205158 ds_u = ds (u , ("x" )) # useful
206- g = FunctionGraph ([x , y , z , u ], [ds_x , ds_y , ds_z , ds_u ])
207- # TODO FIXME: Construct these graphs and compare them.
208- assert (
209- str (g )
210- == "FunctionGraph(InplaceDimShuffle{0,x}(x), InplaceDimShuffle{2,1,0}(y), InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1}))"
211- )
159+ g = FunctionGraph ([x , y , z , u ], [ds_x , ds_y , ds_z , ds_u ], clone = False )
160+ assert len (g .apply_nodes ) == 4
212161 dimshuffle_lift .rewrite (g )
213- assert (
214- str (g )
215- == "FunctionGraph(x, y, InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1}))"
216- )
162+ assert equal_computations (g .outputs , [x , y , z .T , u .dimshuffle ("x" )])
217163 # Check stacktrace was copied over correctly after rewrite was applied
218164 assert hasattr (g .outputs [0 ].tag , "trace" )
219165
@@ -237,34 +183,32 @@ def test_local_useless_dimshuffle_in_reshape():
237183 reshape_dimshuffle_row ,
238184 reshape_dimshuffle_col ,
239185 ],
186+ clone = False ,
240187 )
241-
242- # TODO FIXME: Construct these graphs and compare them.
243- assert str (g ) == (
244- "FunctionGraph(Reshape{1}(InplaceDimShuffle{x,0}(vector), Shape(vector)), "
245- "Reshape{2}(InplaceDimShuffle{x,0,x,1}(mat), Shape(mat)), "
246- "Reshape{2}(InplaceDimShuffle{1,x}(row), Shape(row)), "
247- "Reshape{2}(InplaceDimShuffle{0}(col), Shape(col)))"
248- )
188+ assert len (g .apply_nodes ) == 4 * 3
249189 useless_dimshuffle_in_reshape = out2in (local_useless_dimshuffle_in_reshape )
250190 useless_dimshuffle_in_reshape .rewrite (g )
251- assert str (g ) == (
252- "FunctionGraph(Reshape{1}(vector, Shape(vector)), "
253- "Reshape{2}(mat, Shape(mat)), "
254- "Reshape{2}(row, Shape(row)), "
255- "Reshape{2}(col, Shape(col)))"
191+ assert equal_computations (
192+ g .outputs ,
193+ [
194+ reshape (vec , vec .shape ),
195+ reshape (mat , mat .shape ),
196+ reshape (row , row .shape ),
197+ reshape (col , col .shape ),
198+ ],
256199 )
257-
258200 # Check stacktrace was copied over correctly after rewrite was applied
259201 assert check_stack_trace (g , ops_to_check = "all" )
260202
261203 # Check that the rewrite does not get applied when the order
262204 # of dimensions has changed.
263205 reshape_dimshuffle_mat2 = reshape (mat .dimshuffle ("x" , 1 , "x" , 0 ), mat .shape )
264- h = FunctionGraph ([mat ], [reshape_dimshuffle_mat2 ])
265- str_h = str ( h )
206+ h = FunctionGraph ([mat ], [reshape_dimshuffle_mat2 ], clone = False )
207+ assert len ( h . apply_nodes ) == 3
266208 useless_dimshuffle_in_reshape .rewrite (h )
267- assert str (h ) == str_h
209+ assert equal_computations (
210+ h .outputs , [reshape (mat .dimshuffle ("x" , 1 , "x" , 0 ), mat .shape )]
211+ )
268212
269213
270214class TestFusion :
0 commit comments