@@ -50,8 +50,11 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
5050 raise AssertionError ()
5151
5252
53- def OpKeyPatternNodeRewriter (p1 , p2 , ign = False ):
54- return OpKeyGraphRewriter (PatternNodeRewriter (p1 , p2 ), ignore_newtrees = ign )
53+ def OpKeyPatternNodeRewriter (p1 , p2 , allow_multiple_clients = False , ign = False ):
54+ return OpKeyGraphRewriter (
55+ PatternNodeRewriter (p1 , p2 , allow_multiple_clients = allow_multiple_clients ),
56+ ignore_newtrees = ign ,
57+ )
5558
5659
5760def WalkingPatternNodeRewriter (p1 , p2 , ign = True ):
@@ -207,13 +210,70 @@ def constraint(r):
207210 assert str (g ) == "FunctionGraph(Op2(Op1(x, x), Op3(x, y)))"
208211
209212 def test_allow_multiple_clients (self ):
210- x , y , z = MyVariable ("x" ), MyVariable ("y" ), MyVariable ("z" )
211- e0 = op1 (x , y )
212- # `e0` has multiple clients (i.e. the `op4` and `op3` nodes)
213- e = op3 (op4 (e0 ), e0 )
214- g = FunctionGraph ([x , y , z ], [e ])
215- OpKeyPatternNodeRewriter ((op4 , (op1 , "x" , "y" )), (op3 , "x" , "y" )).rewrite (g )
216- assert str (g ) == "FunctionGraph(Op3(Op4(*1 -> Op1(x, y)), *1))"
213+ x , y , z = inputs = MyVariable ("x" ), MyVariable ("y" ), MyVariable ("z" )
214+ w = op1 (x , y )
215+ # `w` has multiple clients (i.e. the `op4` and `op3` nodes)
216+ e = op3 (op4 (w ), w )
217+
218+ # By default, allow_multiple_clients is False
219+ # So the replacement should fail
220+ outputs = [e ]
221+ g = FunctionGraph (inputs , outputs , copy_inputs = False )
222+ OpKeyPatternNodeRewriter (
223+ (op4 , (op1 , "x" , "y" )),
224+ (op3 , "x" , "y" ),
225+ ).rewrite (g )
226+ assert equal_computations (g .outputs , outputs )
227+
228+ # Now it should be fine
229+ g = FunctionGraph (inputs , outputs , copy_inputs = False )
230+ OpKeyPatternNodeRewriter (
231+ (op4 , (op1 , "x" , "y" )),
232+ (op3 , "x" , "y" ),
233+ allow_multiple_clients = True ,
234+ ).rewrite (g )
235+ assert equal_computations (g .outputs , [op3 (op3 (x , y ), w )])
236+
237+ # The fact that the inputs of the pattern have multiple clients should not matter
238+ g = FunctionGraph (inputs , outputs , copy_inputs = False )
239+ OpKeyPatternNodeRewriter (
240+ (op3 , (op4 , "w" ), "w" ),
241+ (op3 , "w" , "w" ),
242+ allow_multiple_clients = False ,
243+ ).rewrite (g )
244+ assert equal_computations (g .outputs , [op3 (w , w )])
245+
246+ # The fact that are multiple clients above the inputs of the pattern should not matter
247+ v = op4 (e )
248+ e1 = op4 (v )
249+ e2 = op1 (x , x ) # Irrelevant reuse of x that should not block rewrite either
250+ e3 = op1 (v , v ) # Relevant reuse of v that should block rewrite
251+
252+ outputs = [e1 , e2 ]
253+ g = FunctionGraph (inputs , outputs , copy_inputs = False )
254+ OpKeyPatternNodeRewriter (
255+ (op4 , (op4 , "e" )),
256+ "e" ,
257+ allow_multiple_clients = False ,
258+ ).rewrite (g )
259+ assert equal_computations (g .outputs , [e , e2 ])
260+
261+ outputs = [e1 , e3 ]
262+ g = FunctionGraph ([x , y , z ], outputs , copy_inputs = False )
263+ OpKeyPatternNodeRewriter (
264+ (op4 , (op4 , "e" )),
265+ "e" ,
266+ allow_multiple_clients = False ,
267+ ).rewrite (g )
268+ assert equal_computations (g .outputs , outputs )
269+
270+ g = FunctionGraph (inputs , outputs , copy_inputs = False )
271+ OpKeyPatternNodeRewriter (
272+ (op4 , (op4 , "e" )),
273+ "e" ,
274+ allow_multiple_clients = True ,
275+ ).rewrite (g )
276+ assert equal_computations (g .outputs , [e , e3 ])
217277
218278 def test_eq (self ):
219279 # replacing the whole graph
0 commit comments