@@ -72,30 +72,74 @@ function rrule(
7272end
7373
# Reverse-mode rule for `sum(f, xs; dims)`, for AD systems whose `config` can call back
# into AD (`RuleConfig{>:HasReverseMode}`).
#
# Returns `(y, pullback)`. The pullback maps the output cotangent `dy` to a 3-tuple of
# tangents for `(sum, f, xs)`; the tangent for `xs` is always passed through `ProjectTo(xs)`.
#
# Two strategies are used:
# 1. Fast path, when `_uses_input_only(f, T)` holds: the forward pass is a plain
#    `sum(f, xs; dims)` and only `xs` is captured. The derivative of `f` at each element is
#    recomputed in the pullback from `derivatives_given_output(nothing, f, xᵢ)`.
# 2. General path: `rrule_via_ad` is called on every element and all the resulting
#    pullbacks are stored for use in the reverse pass.
function rrule(
    config::RuleConfig{>:HasReverseMode},
    ::typeof(sum),
    f::F,
    xs::AbstractArray{T};
    dims=:,
) where {F,T}
    project = ProjectTo(xs)

    if _uses_input_only(f, T)
        # Then we can compute the forward pass as usual, save nothing but `xs`:
        function sum_pullback_f1(dy)
            dxs = broadcast(unthunk(dy), xs) do dyₖ, xᵢ
                # `nothing` stands in for the output, which this path never stored.
                ∂yₖ∂xᵢ = only(only(derivatives_given_output(nothing, f, xᵢ)))
                dyₖ * conj(∂yₖ∂xᵢ)
            end
            return (NoTangent(), NoTangent(), project(dxs))
        end
        return sum(f, xs; dims), sum_pullback_f1
    end

    # (There is an intermediate case, where `derivatives_given_output` needs to
    # see `f.(xs)` but we don't need the pullbacks. Not implemented at present.)

    # In the general case, we need to save all the pullbacks:
    fx_and_pullbacks = map(xᵢ -> rrule_via_ad(config, f, xᵢ), xs)
    y = sum(first, fx_and_pullbacks; dims)

    function sum_pullback_f2(dy)
        # For arrays of arrays, we ought to protect the element against broadcasting:
        broadcast_dy = dims isa Colon ? Ref(unthunk(dy)) : unthunk(dy)
        if Base.issingletontype(F)
            # Then at least `f` has no gradient.
            # Broadcasting here gets the shape right with or without `dims` keyword.
            dxs = broadcast(fx_and_pullbacks, broadcast_dy) do (_, pbᵢ), dyₖ
                unthunk(last(pbᵢ(dyₖ)))
            end
            return (NoTangent(), NoTangent(), project(dxs))
        else
            # Most general case. If `f` were stateful, we would need to reverse the order
            # of iteration here, but since this function makes no guarantee, even the primal
            # result is then ill-defined.
            df_and_dxs = broadcast(fx_and_pullbacks, broadcast_dy) do (_, pbᵢ), dyₖ
                pbᵢ(dyₖ)
            end
            # Each element pullback returns `(df, dx)`: accumulate the former, keep the latter.
            df = sum(first, df_and_dxs)
            dxs = map(unthunk ∘ last, df_and_dxs)
            return (NoTangent(), df, project(dxs))
        end
    end
    return y, sum_pullback_f2
end
127+
"""
    _uses_input_only(f, xT::Type)

Return `true` if it can be proven that `derivatives_given_output` will work using only
the input of the given type. In that case there is no need to store the output
`y = f(x::xT)`, which allows a fast path in the `rrule` for `sum(f, xs)`.

Works by checking whether the result of `derivatives_given_output(nothing, f, x)` can be
inferred. The method of `derivatives_given_output` usually comes from `@scalar_rule`.
"""
function _uses_input_only(f::F, ::Type{xT}) where {F,xT}
    gT = Core.Compiler._return_type(derivatives_given_output, Tuple{Nothing, F, xT})
    # We must check `<: Number` to exclude the one rule which can return its first
    # argument (here `nothing`) as the derivative:
    # ChainRules.derivatives_given_output("anything", exp, 1) == (("anything",),)
    return isconcretetype(gT) && gT <: Tuple{Tuple{Number}}
end
100144
101145# https://github.com/JuliaDiff/ChainRules.jl/issues/522
@@ -228,6 +272,7 @@ function ∇prod_dims(vald::Val{dims}, x, dy, y=prod(x; dims=dims)) where {dims}
228272 ∇prod_dims! (dx, vald, x, dy, y)
229273 return dx
230274end
# An `AbstractZero` cotangent is returned as-is; `y` is unused, so it defaults to `0`
# rather than computing `prod(x; dims)`.
∇prod_dims(::Val, x, dy::AbstractZero, y=0) = dy
231276
232277function ∇prod_dims! (dx, :: Val{dims} , x, dy, y) where {dims}
233278 iters = ntuple (d -> d in dims ? tuple (:) : axes (x,d), ndims (x)) # Without Val(dims) this is a serious type instability
@@ -244,6 +289,7 @@ function ∇prod(x, dy::Number=1, y::Number=prod(x))
244289 ∇prod! (dx, x, dy, y)
245290 return dx
246291end
# An `AbstractZero` cotangent is returned as-is; `y` is unused, so it defaults to `0`
# rather than computing `prod(x)`.
∇prod(x, dy::AbstractZero, y::Number=0) = dy
247293
248294function ∇prod! (dx, x, dy:: Number = 1 , y:: Number = prod (x))
249295 numzero = iszero (y) ? count (iszero, x) : 0
@@ -326,7 +372,8 @@ function ∇cumprod_dim(vald::Val{dim}, x::AbstractArray, dy=fill!(zero(x),1), y
326372 dx = fill! (similar (x, T, axes (x)), zero (T))
327373 ∇cumprod_dim! (dx, vald, x, dy, y)
328374 return dx
329- end
375+ end
# An `AbstractZero` cotangent is returned as-is; `vald`, `x` and `y` are unused.
∇cumprod_dim(vald::Val, x::AbstractArray, dy::AbstractZero, y=0) = dy
330377
331378@inline function ∇cumprod_dim! (dx:: AbstractArray , :: Val{dim} , x:: AbstractArray , dy, y) where {dim}
332379 iters = ntuple (k -> k== dim ? Ref (:) : axes (x,k), ndims (x))
@@ -342,6 +389,7 @@ function ∇cumprod(x::AbstractVector, dy=one(x), y=cumprod(x))
342389 ∇cumprod! (dx, x, dy, y)
343390 return dx
344391end
# An `AbstractZero` cotangent is returned as-is; `y` is unused, so it defaults to `0`
# rather than computing `cumprod(x)`.
∇cumprod(x::AbstractVector, dy::AbstractZero, y=0) = dy
345393
346394@inline function ∇cumprod! (dx:: AbstractVector , x:: AbstractVector , dy, y)
347395 lo, hi = firstindex (x), lastindex (x)
0 commit comments