@@ -310,6 +310,41 @@ def test_rop(self, cls_ofg):
310310 dvval2 = fn (xval , Wval , duval )
311311 np .testing .assert_array_almost_equal (dvval2 , dvval , 4 )
312312
313+ def test_rop_multiple_outputs (self ):
314+ a = vector ()
315+ M = matrix ()
316+ b = dot (a , M )
317+ op_matmul = OpFromGraph ([a , M ], [b , - b ])
318+
319+ x = vector ()
320+ W = matrix ()
321+ du = vector ()
322+
323+ xval = np .random .random ((16 ,)).astype (config .floatX )
324+ Wval = np .random .random ((16 , 16 )).astype (config .floatX )
325+ duval = np .random .random ((16 ,)).astype (config .floatX )
326+
327+ y = op_matmul (x , W )[0 ]
328+ dv = Rop (y , x , du )
329+ fn = function ([x , W , du ], dv )
330+ result_dvval = fn (xval , Wval , duval )
331+ expected_dvval = np .dot (duval , Wval )
332+ np .testing .assert_array_almost_equal (result_dvval , expected_dvval , 4 )
333+
334+ y = op_matmul (x , W )[1 ]
335+ dv = Rop (y , x , du )
336+ fn = function ([x , W , du ], dv )
337+ result_dvval = fn (xval , Wval , duval )
338+ expected_dvval = - np .dot (duval , Wval )
339+ np .testing .assert_array_almost_equal (result_dvval , expected_dvval , 4 )
340+
341+ y = pt .add (* op_matmul (x , W ))
342+ dv = Rop (y , x , du )
343+ fn = function ([x , W , du ], dv )
344+ result_dvval = fn (xval , Wval , duval )
345+ expected_dvval = np .zeros_like (np .dot (duval , Wval ))
346+ np .testing .assert_array_almost_equal (result_dvval , expected_dvval , 4 )
347+
313348 @pytest .mark .parametrize (
314349 "cls_ofg" , [OpFromGraph , partial (OpFromGraph , inline = True )]
315350 )
0 commit comments