@@ -218,3 +218,96 @@ function ∇prod_one_zero!(dx, x, dy::Number=1) # Assumes exactly one x is zero
218218 dx[i_zero] += p_rest * dy
219219 return
220220end
221+
222+ # ####
223+ # #### `cumprod`
224+ # ####
225+
226+ function rrule (:: typeof (cumprod), x:: AbstractVector{<:Real} ; dims:: Integer = 1 )
227+ y = cumprod (x; dims= dims) # does nothing unless dims == 1
228+ project_x = ProjectTo (x)
229+ function cumprod_pullback_1 (dy_raw)
230+ dy = unthunk (dy_raw)
231+ dx_thunk = InplaceableThunk (
232+ dx -> if dims == 1
233+ ∇cumprod! (dx, x, dy, y)
234+ else
235+ dx .+ = dy
236+ end
237+ ,
238+ @thunk project_x (if dims == 1
239+ ∇cumprod (x, dy, y)
240+ else
241+ dy
242+ end )
243+ )
244+ return (NoTangent (), dx_thunk)
245+ end
246+ return y, cumprod_pullback_1
247+ end
248+
249+ function rrule (:: typeof (cumprod), x:: AbstractArray{<:Real} ; dims:: Integer )
250+ y = cumprod (x; dims= dims)
251+ project_x = ProjectTo (x)
252+ function cumprod_pullback_2 (dy_raw)
253+ dy = unthunk (dy_raw)
254+ dx_thunk = InplaceableThunk (
255+ dx -> if dims <= ndims (x)
256+ vald = Val (Int (dims))
257+ ∇cumprod_dim! (dx, vald, x, dy, y)
258+ else
259+ dx .+ = dy
260+ end
261+ ,
262+ @thunk project_x (if dims <= ndims (x)
263+ vald = Val (Int (dims))
264+ ∇cumprod_dim (vald, x, dy, y)
265+ else
266+ dy
267+ end )
268+ )
269+ return (NoTangent (), dx_thunk)
270+ end
271+ return y, cumprod_pullback_2
272+ end
273+
274+ function ∇cumprod_dim (vald:: Val{dim} , x:: AbstractArray , dy= fill! (zero (x),1 ), y= cumprod (x; dims= dim)) where {dim}
275+ T = promote_type (eltype (x), eltype (dy))
276+ dx = fill! (similar (x, T, axes (x)), zero (T))
277+ ∇cumprod_dim! (dx, vald, x, dy, y)
278+ return dx
279+ end
280+
281+ @inline function ∇cumprod_dim! (dx:: AbstractArray , :: Val{dim} , x:: AbstractArray , dy, y) where {dim}
282+ iters = ntuple (k -> k== dim ? Ref (:) : axes (x,k), ndims (x))
283+ for ind in Iterators. product (iters... )
284+ @views ∇cumprod! (dx[ind... ], x[ind... ], dy[ind... ], y[ind... ])
285+ end
286+ return dx
287+ end
288+
289+ function ∇cumprod (x:: AbstractVector , dy= one (x), y= cumprod (x))
290+ T = promote_type (eltype (x), eltype (dy)) # really needs to allow dy * y / x
291+ dx = fill! (similar (x, T, axes (x)), zero (T)) # axes(x) makes MArray on StaticArrays, Array for structured matrices
292+ ∇cumprod! (dx, x, dy, y)
293+ return dx
294+ end
295+
296+ @inline function ∇cumprod! (dx:: AbstractVector , x:: AbstractVector , dy, y)
297+ lo, hi = firstindex (x), lastindex (x)
298+ z = something (findfirst (iszero, x), hi+ 1 )
299+ acc = zero (eltype (dy))
300+ @inbounds for k in z- 1 : - 1 : lo
301+ acc += y[k] * dy[k]
302+ dx[k] += acc / x[k]
303+ end
304+ @inbounds if z != hi+ 1
305+ yk = z== 1 ? one (eltype (y)) : y[z- 1 ] # will be prod(x[j] for j=1:k if j!=z)
306+ dx[z] += yk * dy[z]
307+ for k in (z+ 1 ): hi
308+ yk *= x[k]
309+ dx[z] += yk * dy[k]
310+ end
311+ end
312+ return dx
313+ end
0 commit comments