# Two-argument `sum` on a pair of arrays, defined as their `dot` product.
# NOTE(review): presumably added so the tests can check that the `sum` rules stay
# unambiguous alongside an extra user-defined method — confirm against the tests that use it.
function Base.sum(xs::AbstractArray, weights::AbstractArray)
    return dot(xs, weights)
end
# Field-less rule configuration type; it subtypes `RuleConfig` parameterised with
# `HasReverseMode` only.
struct SumRuleConfig <: RuleConfig{Union{HasReverseMode}}
end
44
5- @testset " Maps and Reductions" begin
# Shared rule configuration instance, passed as the first argument to `rrule`
# throughout the testsets in this file.
const CFG = ChainRulesTestUtils.ADviaRuleConfig()
6+
7+ @testset " Reductions" begin
# Forward-rule check for `sum` on a tuple of five random numbers.
@testset "sum(::Tuple)" begin
    xs = Tuple(rand(5))
    test_frule(sum, xs)
end
611 @testset " sum(x; dims=$dims )" for dims in (:, 2 , (1 ,3 ))
712 # Forward
813 test_frule (sum, rand (5 ); fkwargs= (;dims= dims))
@@ -79,12 +84,11 @@ struct SumRuleConfig <: RuleConfig{Union{HasReverseMode}} end
7984 test_rrule (sum, inv, transpose (view (x, 1 , :)))
8085
8186 # Make sure we preserve type for StaticArrays
82- ADviaRuleConfig = ChainRulesTestUtils. ADviaRuleConfig
83- _, pb = rrule (ADviaRuleConfig (), sum, abs, @SVector [1.0 , - 3.0 ])
87+ _, pb = rrule (CFG, sum, abs, @SVector [1.0 , - 3.0 ])
8488 @test pb (1.0 ) isa Tuple{NoTangent, NoTangent, SVector{2 , Float64}}
8589
8690 # make sure we preserve type for Diagonal
87- _, pb = rrule (ADviaRuleConfig () , sum, abs, Diagonal ([1.0 , - 3.0 ]))
91+ _, pb = rrule (CFG , sum, abs, Diagonal ([1.0 , - 3.0 ]))
8892 @test pb (1.0 )[3 ] isa Diagonal
8993
9094 # Boolean -- via @non_differentiable, test that this isn't ambiguous
@@ -173,7 +177,64 @@ struct SumRuleConfig <: RuleConfig{Union{HasReverseMode}} end
173177 @test unthunk (rrule (prod, v)[2 ](1f0 )[2 ]) == zeros (4 )
174178 test_rrule (prod, v)
175179 end
176- end # prod
180+ end # prod
181+
# Reverse-rule checks for `foldl` over arrays: primal values, pullback results,
# the order in which a stateful callable is invoked, gradients with respect to a
# callable struct, and finite-difference comparisons via `test_rrule`.
# `CFG`, `Counter` and `Multiplier` are defined outside this block.
@testset "foldl(f, ::Array)" begin
    # Simple
    y1, b1 = rrule(CFG, foldl, *, [1, 2, 3]; init=1)
    @test y1 == 6
    # This comparison used to be evaluated and discarded; it is now asserted.
    @test b1(7) == (NoTangent(), NoTangent(), [42, 21, 14])

    y2, b2 = rrule(CFG, foldl, *, [1 2; 0 4])  # without init, needs vcat
    @test y2 == 0
    # Likewise previously unasserted.
    @test b2(8) == (NoTangent(), NoTangent(), [0 0; 64 0])  # matrix, needs reshape

    # Test execution order
    c5 = Counter()
    y5, b5 = rrule(CFG, foldl, c5, [5, 7, 11])
    @test c5 == Counter(2)
    @test y5 == ((5 + 7) * 1 + 11) * 2 == foldl(Counter(), [5, 7, 11])
    @test b5(1) == (NoTangent(), NoTangent(), [12 * 32, 12 * 42, 22])
    @test c5 == Counter(42)

    c6 = Counter()
    y6, b6 = rrule(CFG, foldl, c6, [5, 7, 11]; init=3)
    @test c6 == Counter(3)
    @test y6 == (((3 + 5) * 1 + 7) * 2 + 11) * 3 == foldl(Counter(), [5, 7, 11]; init=3)
    @test b6(1) == (NoTangent(), NoTangent(), [63 * 33 * 13, 43 * 13, 23])
    @test c6 == Counter(63)

    # Test gradient of function
    y7, b7 = rrule(CFG, foldl, Multiplier(3), [5, 7, 11])
    @test y7 == foldl((x, y) -> x * y * 3, [5, 7, 11])
    @test b7(1) == (NoTangent(), Tangent{Multiplier{Int}}(x = 2310,), [693, 495, 315])

    y8, b8 = rrule(CFG, foldl, Multiplier(13), [5, 7, 11]; init=3)
    @test y8 == 2_537_535 == foldl((x, y) -> x * y * 13, [5, 7, 11]; init=3)
    @test b8(1) ==
        (NoTangent(), Tangent{Multiplier{Int}}(x = 585585,), [507507, 362505, 230685])
    # To find these numbers:
    # ForwardDiff.derivative(z -> foldl((x,y)->x*y*z, [5,7,11], init=3), 13)
    # ForwardDiff.gradient(z -> foldl((x,y)->x*y*13, z, init=3), [5,7,11]) |> string

    # Finite differencing
    test_rrule(foldl, /, 1 .+ rand(3, 4))
    test_rrule(foldl, *, rand(ComplexF64, 3, 4); fkwargs=(; init=rand(ComplexF64)))
    test_rrule(foldl, +, rand(ComplexF64, 7); fkwargs=(; init=rand(ComplexF64)))
    test_rrule(foldl, max, rand(3); fkwargs=(; init=999))
end
# Reverse-rule checks for `foldl` over tuples; the testset only runs when
# `VERSION >= v"1.5"`. `CFG` is defined outside this block.
VERSION >= v"1.5" && @testset "foldl(f, ::Tuple)" begin
    y1, b1 = rrule(CFG, foldl, *, (1, 2, 3); init=1)
    @test y1 == 6
    # This comparison used to be evaluated and discarded; it is now asserted.
    @test b1(7) == (NoTangent(), NoTangent(), Tangent{NTuple{3,Int}}(42, 21, 14))

    y2, b2 = rrule(CFG, foldl, *, (1, 2, 0, 4))
    @test y2 == 0
    # Likewise previously unasserted.
    @test b2(8) == (NoTangent(), NoTangent(), Tangent{NTuple{4,Int}}(0, 0, 64, 0))

    # Finite differencing
    test_rrule(foldl, /, Tuple(1 .+ rand(5)))
    test_rrule(foldl, *, Tuple(rand(ComplexF64, 5)))
end
177238end
178239
179240@testset " Accumulations" begin
@@ -188,14 +249,14 @@ end
188249 @testset " higher dimensions, dims=$dims " for dims in (1 ,2 ,3 )
189250 m = round .(10 .* randn (4 ,5 ), sigdigits= 3 )
190251 test_rrule (cumprod, m; fkwargs= (;dims= dims), atol= 0.1 )
191- m[2 ,2 ] = 0
192- m[2 ,4 ] = 0
252+ m[2 , 2 ] = 0
253+ m[2 , 4 ] = 0
193254 test_rrule (cumprod, m; fkwargs= (;dims= dims))
194255
195256 t = round .(10 .* randn (3 ,3 ,3 ), sigdigits= 3 )
196257 test_rrule (cumprod, t; fkwargs= (;dims= dims))
197- t[2 ,2 , 2 ] = 0
198- t[2 ,3 , 3 ] = 0
258+ t[2 , 2 , 2 ] = 0
259+ t[2 , 3 , 3 ] = 0
199260 test_rrule (cumprod, t; fkwargs= (;dims= dims))
200261 end
201262
211272 back = rrule (cumprod, Diagonal ([1 , 2 ]); dims= 1 )[2 ]
212273 @test unthunk (back (fill (0.5 , 2 , 2 ))[2 ]) ≈ [1 / 2 0 ; 0 0 ] # ProjectTo'd to Diagonal now
213274 end
275+ end # cumprod
276+
# Reverse-rule checks for `accumulate` over arrays: primal values, pullback
# results, invocation order of a stateful callable, gradients with respect to a
# callable struct, and finite-difference comparisons via `test_rrule`.
# `CFG`, `Counter` and `Multiplier` are defined outside this block.
@testset "accumulate(f, ::Array)" begin
    # Running product with an explicit `init`
    scan, scan_back = rrule(CFG, accumulate, *, [1, 2, 3, 4]; init=1)
    @test scan == [1, 2, 6, 24]
    @test scan_back([1, 1, 1, 1]) == (NoTangent(), NoTangent(), [33, 16, 10, 6])

    if VERSION >= v"1.5"
        # Matrix input, no `init`
        quot, quot_back = rrule(CFG, accumulate, /, [1 2; 3 4])
        @test quot ≈ accumulate(/, [1 2; 3 4])
        @test quot_back(ones(2, 2))[3] ≈
            [1.5416666 -0.104166664; -0.18055555 -0.010416667] atol=1e-6
    end

    # Execution order, observed through the state of a `Counter`
    counter_init = Counter()
    y_init, back_init = rrule(CFG, accumulate, counter_init, [5, 7, 11]; init=3)
    @test counter_init == Counter(3)
    @test y_init == [8, 30, 123] == accumulate(Counter(), [5, 7, 11]; init=3)
    # the trailing 23 is the easy one to verify by hand
    @test back_init([1, 1, 1]) == (NoTangent(), NoTangent(), [29169, 602, 23])

    counter_plain = Counter()
    y_plain, back_plain = rrule(CFG, accumulate, counter_plain, [5, 7, 11])
    @test counter_plain == Counter(2)
    @test y_plain == [5, (5 + 7) * 1, ((5 + 7) * 1 + 11) * 2] ==
        accumulate(Counter(), [5, 7, 11])
    @test back_plain([1, 1, 1]) == (NoTangent(), NoTangent(), [417, 42 * (1 + 12), 22])

    # Gradient with respect to the callable itself
    y_mul, back_mul = rrule(CFG, accumulate, Multiplier(3), [5, 7, 11])
    @test y_mul == accumulate((x, y) -> x * y * 3, [5, 7, 11])
    @test back_mul([1, 1, 1]) ==
        (NoTangent(), Tangent{Multiplier{Int}}(x = 2345,), [715, 510, 315])

    y_mul_init, back_mul_init = rrule(CFG, accumulate, Multiplier(13), [5, 7, 11]; init=3)
    @test y_mul_init == [195, 17745, 2537535] ==
        accumulate((x, y) -> x * y * 13, [5, 7, 11]; init=3)
    @test back_mul_init([1, 1, 1]) ==
        (NoTangent(), Tangent{Multiplier{Int}}(x = 588330,), [511095, 365040, 230685])
    # The reference numbers above can be reproduced with:
    # ForwardDiff.derivative(z -> sum(accumulate((x,y)->x*y*z, [5,7,11], init=3)), 13)
    # ForwardDiff.gradient(z -> sum(accumulate((x,y)->x*y*13, z, init=3)), [5,7,11]) |> string

    # Comparison against finite differences
    test_rrule(accumulate, *, randn(5); fkwargs=(; init=rand()))
    if VERSION >= v"1.5"
        test_rrule(accumulate, /, 1 .+ rand(3, 4))
        test_rrule(accumulate, ^, 1 .+ rand(2, 3); fkwargs=(; init=rand()))
    end
end
# Reverse-rule checks for `accumulate` over tuples; the testset only runs when
# `VERSION >= v"1.5"`. `CFG` is defined outside this block.
if VERSION >= v"1.5"
    @testset "accumulate(f, ::Tuple)" begin
        # Running product with an explicit `init`
        scan, scan_back = rrule(CFG, accumulate, *, (1, 2, 3, 4); init=1)
        @test scan == (1, 2, 6, 24)
        @test scan_back((1, 1, 1, 1)) ==
            (NoTangent(), NoTangent(), Tangent{NTuple{4,Int}}(33, 16, 10, 6))

        # Comparison against finite differences
        test_rrule(accumulate, *, Tuple(randn(5)); fkwargs=(; init=rand()))
        test_rrule(accumulate, /, Tuple(1 .+ rand(5)); check_inferred=false)
    end
end
215331end
0 commit comments