@@ -342,3 +342,165 @@ function rrule(::typeof(fill), x::Any, dims...)
342342 fill_pullback (Ȳ) = (NoTangent (), project (sum (Ȳ)), nots... )
343343 return fill (x, dims... ), fill_pullback
344344end
345+
#####
##### `findmax`, `maximum`, etc.
#####
349+
# Generate matching `frule` / `rrule` methods for both `findmin` and `findmax`.
# The reverse rule's pullback is named `findmin_pullback` / `findmax_pullback`.
for findm in (:findmin, :findmax)
    pullback_name = Symbol(findm, :_pullback)

    @eval begin
        # Forward rule: the tangent of the extreme value is `xdot` read at the winning
        # index; the index itself is not differentiable.
        function frule((_, xdot), ::typeof($findm), x; dims=:)
            val, idx = $findm(x; dims=dims)
            primal = (val, idx)
            return primal, Tangent{typeof(primal)}(xdot[idx], NoTangent())
        end

        # Reverse rule: scatter the value's cotangent back to the winning position(s).
        function rrule(::typeof($findm), x::AbstractArray; dims=:)
            val, idx = $findm(x; dims=dims)
            to_x = ProjectTo(x)
            # This pullback is a lot like the one for getindex.
            # Ideally they would probably be combined?
            # It accepts e.g. Tangent{Tuple{Float64, Int64}}(4.0, nothing); the second
            # component (cotangent of the index) is ignored.
            function $pullback_name((dval, _))
                if dval isa AbstractZero
                    return (NoTangent(), NoTangent())
                end
                # Out-of-place gradient, computed lazily.
                lazy_dx = @thunk to_x(_zerolike_writeat(x, unthunk(dval), dims, idx))
                # In-place accumulation into an existing gradient array `acc`.
                function add_to!(acc)
                    hit = view(acc, idx)
                    step = unthunk(dval)
                    if dims isa Colon
                        # `Ref` stops broadcasting from iterating into `step`.
                        hit .= hit .+ Ref(step)
                    else
                        hit .= hit .+ step  # this could be .+=, but not on Julia 1.0
                    end
                    return acc
                end
                return (NoTangent(), InplaceableThunk(add_to!, lazy_dx))
            end
            return (val, idx), $pullback_name
        end
    end
end
378+
# This function is roughly `setindex!(zero(x), dy, inds...)`:
# it builds an array with the axes of `x`, zero everywhere except at `inds`, where `dy`
# is written. `dims` is not used by this method.
function _zerolike_writeat(x::AbstractArray{<:Number}, dy, dims, inds...)
    # It's unfortunate to need `x` itself, but `similar(typeof(x), axes(x))` doesn't
    # allow `eltype(dy)`, nor does it work for many structured matrices.
    out = similar(x, eltype(dy), axes(x))
    fill!(out, 0)
    # Possibly a 0-dim view; allows dy::Number and dy::Array, and out::CuArray.
    target = view(out, inds...)
    target .= dy
    return out
end
# Generic-eltype method: since we have `x`, we can also handle arrays of arrays by
# zeroing each element with `zero`.
function _zerolike_writeat(x::AbstractArray, dy, dims, inds...)
    out = map(zero, x)
    # With `dims = :` there is a single target slot, so `dy` is wrapped in `Ref` to be
    # stored whole rather than broadcast over; otherwise `dy` is written elementwise.
    payload = dims isa Colon ? Ref(dy) : dy
    view(out, inds...) .= payload
    return out
end
398+
# Allow for second derivatives, by writing rules for `_zerolike_writeat`;
# these rules are the reason it takes a `dims` argument.

# Forward rule: only the tangent of `dy` (third entry of the tangent tuple) is used;
# it is written at the same indices as `dy` itself.
function frule((_, _, dydot), ::typeof(_zerolike_writeat), x, dy, dims, inds...)
    primal = _zerolike_writeat(x, dy, dims, inds...)
    tangent = _zerolike_writeat(x, dydot, dims, inds...)
    return primal, tangent
end
405+
# Reverse rule for `_zerolike_writeat`: only `dy` receives a gradient, obtained by
# reading the output cotangent back at `inds` and summing over `dims`.
function rrule(::typeof(_zerolike_writeat), x, dy, dims, inds...)
    function _zerolike_writeat_pullback(dz)
        ddy = sum(view(unthunk(dz), inds...); dims=dims)
        # One `NoTangent` per index argument.
        ind_tangents = ntuple(_ -> NoTangent(), length(inds))
        # Order: function, `x`, `dy`, `dims`, then each of `inds`.
        return (NoTangent(), NoTangent(), ddy, NoTangent(), ind_tangents...)
    end
    return _zerolike_writeat(x, dy, dims, inds...), _zerolike_writeat_pullback
end
415+
# These rules for `maximum` pick the same subgradient as `findmax`:

# Forward rule: locate the maximum with `findmax` and read `xdot` at that index.
function frule((_, xdot), ::typeof(maximum), x; dims=:)
    val, idx = findmax(x; dims=dims)
    dval = xdot[idx]
    return val, dval
end
422+
# Reverse rule for `maximum`, delegating to the `findmax` rule and discarding the index.
function rrule(::typeof(maximum), x::AbstractArray; dims=:)
    primal, findmax_back = rrule(findmax, x; dims=dims)
    # The `findmax` pullback expects a (value, index) cotangent pair; the index slot
    # is filled with `nothing`.
    function maximum_pullback(dy)
        return findmax_back((dy, nothing))
    end
    return first(primal), maximum_pullback
end
428+
# Forward rule for `minimum`: locate the minimum with `findmin` and read `xdot` there.
function frule((_, xdot), ::typeof(minimum), x; dims=:)
    val, idx = findmin(x; dims=dims)
    dval = xdot[idx]
    return val, dval
end
433+
# Reverse rule for `minimum`, delegating to the `findmin` rule and discarding the index.
function rrule(::typeof(minimum), x::AbstractArray; dims=:)
    primal, findmin_back = rrule(findmin, x; dims=dims)
    # The `findmin` pullback expects a (value, index) cotangent pair; the index slot
    # is filled with `nothing`.
    function minimum_pullback(dy)
        return findmin_back((dy, nothing))
    end
    return first(primal), minimum_pullback
end
439+
#####
##### `extrema`
#####
443+
# Reverse rule for `extrema` on numeric arrays: whole-array reductions (`dims = :`) and
# reductions along `dims` are handled by separate helpers.
function rrule(::typeof(extrema), x::AbstractArray{<:Number}; dims=:)
    return dims isa Colon ? _extrema_colon(x) : _extrema_dims(x, dims)
end
451+
# Primal and pullback for `extrema(x)` over the whole array.
#
# Returns `((ylo, yhi), extrema_pullback)`. The pullback takes the cotangent of the
# `(lo, hi)` pair (a `Tuple` or `Tangent`, possibly thunked, or an `AbstractZero`) and
# returns `(NoTangent(), dx)`, where `dx` is zero except at the positions of the minimum
# and the maximum.
function _extrema_colon(x)
    ylo, ilo = findmin(x)
    yhi, ihi = findmax(x)
    project = ProjectTo(x)
    function extrema_pullback(dy_raw)
        dy = unthunk(dy_raw)
        # A zero cotangent for the whole pair carries no gradient.
        dy isa AbstractZero && return (NoTangent(), NoTangent())
        dylo_raw, dyhi_raw = dy
        # Components may themselves be thunks; `typeof` below needs the real values.
        dylo, dyhi = unthunk(dylo_raw), unthunk(dyhi_raw)
        if dylo isa AbstractZero && dyhi isa AbstractZero
            return (NoTangent(), NoTangent())
        end
        # One argument may be AbstractZero here. Use promote_op because
        # promote_type allows for * as well as +, hence gives Any.
        T = Base.promote_op(+, typeof(dylo), typeof(dyhi))
        # Computed eagerly: a `@thunk` / `InplaceableThunk` variant did not infer.
        dx = fill!(similar(x, T, axes(x)), false)
        view(dx, ilo) .= dylo
        # Accumulate rather than overwrite, in case `ilo == ihi`.
        view(dx, ihi) .= view(dx, ihi) .+ dyhi
        return (NoTangent(), project(dx))
    end
    return (ylo, yhi), extrema_pullback
end
479+
# Primal and pullback for `extrema(x; dims=dims)`.
#
# Returns `(y, extrema_pullback_dims)`, where `y` is an array of `(lo, hi)` tuples. The
# pullback takes an array of 2-tuples (possibly thunked) or an `AbstractZero`, and
# returns `(NoTangent(), dx)` with the cotangents scattered to the positions of the
# minima and maxima.
function _extrema_dims(x, dims)
    ylo, ilo = findmin(x; dims=dims)
    yhi, ihi = findmax(x; dims=dims)
    y = similar(ylo, Tuple{eltype(ylo),eltype(yhi)})
    map!(tuple, y, ylo, yhi)  # this is a GPU-friendly version of collect(zip(ylo, yhi))
    project = ProjectTo(x)
    function extrema_pullback_dims(dy_raw)
        dy = unthunk(dy_raw)
        # A zero cotangent carries no gradient; without this it would fail the check below.
        dy isa AbstractZero && return (NoTangent(), NoTangent())
        if !(dy isa AbstractArray{<:Tuple{Any,Any}})
            throw(ArgumentError(
                "extrema pullback expected an array of 2-tuples, got $(typeof(dy))"
            ))
        end
        # Can we actually get Array{Tuple{Float64,ZeroTangent}} here? Not sure.
        T = Base.promote_op(+, eltype(dy).parameters...)
        # Computed eagerly: a `@thunk` / `InplaceableThunk` variant did not infer.
        dx = fill!(similar(x, T, axes(x)), false)
        view(dx, ilo) .= first.(dy)
        # Accumulate rather than overwrite, in case a min and max share a position.
        view(dx, ihi) .= view(dx, ihi) .+ last.(dy)
        return (NoTangent(), project(dx))
    end
    return y, extrema_pullback_dims
end
0 commit comments