@@ -52,38 +52,111 @@ function rrule(::typeof(getindex), x::Tuple, ::Colon)
5252 return x, getindex_back_4
5353end
5454
55-
5655# ####
57- # #### getindex
56+ # #### getindex(::AbstractArray)
5857# ####
5958
# Forward rule for `getindex` on arrays: indexing is linear, so the tangent of the
# result is the tangent array `ẋ` indexed at the same positions as the primal.
function frule((_, ẋ), ::typeof(getindex), x::AbstractArray, inds...)
    primal = getindex(x, inds...)
    tangent = getindex(ẋ, inds...)
    return primal, tangent
end
6362
64- function rrule (:: typeof (getindex), x:: Array{<:Number} , inds... )
65- # removes any logical indexing, CartesianIndex etc
66- # leaving us just with a tuple of Int, Arrays of Int and Ranges of Int
# Reverse rule for `getindex` on any `AbstractArray`.
# The pullback returns `NoTangent()` for the function itself and for every index,
# and delegates the gradient with respect to `x` to `thunked_∇getindex`.
function rrule(::typeof(getindex), x::AbstractArray, inds...)
    function getindex_pullback(dy)
        x̄ = thunked_∇getindex(x, dy, inds...)
        īnds = map(Returns(NoTangent()), inds)
        return (NoTangent(), x̄, īnds...)
    end
    y = getindex(x, inds...)
    return y, getindex_pullback
end
70+
# Wrap the gradient of `x[inds...]` in an `InplaceableThunk`, offering two paths:
#  * in-place: add `dy` into an existing buffer `dx` via `∇getindex!`, after
#    normalising the indices with `Base.to_indices`;
#  * out-of-place: a lazy `@thunk` which calls `∇getindex` only if it is ever needed.
# `dy` is unthunked inside each path, not eagerly here.
function thunked_∇getindex(x, dy, inds...)
    add_to!(dx) = ∇getindex!(dx, unthunk(dy), Base.to_indices(x, inds)...)
    lazy = @thunk(∇getindex(x, unthunk(dy), inds...))
    return InplaceableThunk(add_to!, lazy)
end
77+
"""
    ∇getindex(x, dy, inds...)

Gradient helper for the `rrule` of `y = x[inds...]`.

Return an array `dx` with the axes of `x` which holds `dy` accumulated at the
positions `inds`, roughly `setindex(zero(x), dy, inds...)`.
The result is passed through `ProjectTo(x)` before being returned.

This function is itself differentiable.
"""
function ∇getindex(x::AbstractArray, dy, inds...)
    # Normalise the indices first: `to_indices` removes colons, logical indexing,
    # `CartesianIndex` etc, leaving just Int / AbstractVector of Int.
    canonical = Base.to_indices(x, inds)
    # Allocate a zero-like mutable buffer, then accumulate `dy` into it.
    buffer = _setindex_zero(x, dy, canonical...)
    ∇getindex!(buffer, dy, canonical...)
    # `x` is available here, so project inside rather than in each calling rule.
    project = ProjectTo(x)
    return project(buffer)
end
93+
"""
    _setindex_zero(x, dy, inds...)

Allocate the buffer into which `∇getindex` accumulates `dy`.

The result is roughly `dx = zero(x)`, with two differences: it is created with
`similar`, so it is guaranteed to be mutable, and its element type is chosen wide
enough that `setindex!(dx, dy, inds...)` will work.

Closing over `x` is unfortunate, but `similar(typeof(x), axes(x))` does not allow
`eltype(dy)` to be used, nor does it work for many structured matrices.
"""
function _setindex_zero(x::AbstractArray{<:Number}, dy, inds::Integer...)
    # All-integer indices select one element, so `dy` itself is the element.
    buffer = similar(x, typeof(dy), axes(x))
    return fill!(buffer, ZeroTangent())
end
# Numeric array, general (non-scalar) indices: `dy` is a collection,
# so the buffer's element type is taken from `eltype(dy)`.
function _setindex_zero(x::AbstractArray{<:Number}, dy, inds...)
    buffer = similar(x, eltype(dy), axes(x))
    return fill!(buffer, ZeroTangent())
end
# Non-numeric array, all-integer indices.
# Elements may be of a type which does not define `zero` (like `Vector`), or whose
# zero is special (like `Tangent`), so each slot may hold `dy` or a `ZeroTangent`.
# The price is that the element type is always abstract.
# TODO: make it infer concrete type for e.g. vectors of SVectors
function _setindex_zero(x::AbstractArray, dy, inds::Integer...)
    Slot = Union{typeof(dy), ZeroTangent}
    buffer = similar(x, Slot, axes(x))
    return fill!(buffer, ZeroTangent())
end
# Non-numeric array, general indices: as for the all-integer case,
# but the slot type is built from `eltype(dy)` since `dy` is a collection.
function _setindex_zero(x::AbstractArray, dy, inds...)
    Slot = Union{eltype(dy), ZeroTangent}
    buffer = similar(x, Slot, axes(x))
    return fill!(buffer, ZeroTangent())
end

# Declare `_setindex_zero` non-differentiable for all of its argument combinations.
ChainRules.@non_differentiable _setindex_zero(x::AbstractArray, dy::Any, inds::Any...)
117+
# Accumulate `dy` into the single element of `dx` selected by all-integer indices.
# `dy` is wrapped in `Ref` so that broadcasting treats it as one scalar-like value,
# even when it is itself a collection (e.g. for an array of arrays).
function ∇getindex!(dx::AbstractArray, dy, inds::Integer...)
    slot = view(dx, inds...)
    slot .+= Ref(dy)
    return dx
end
# Accumulate `dy` elementwise into the region of `dx` selected by general indices
# (ranges, vectors, colons, ...). Mutates and returns `dx`.
function ∇getindex!(dx::AbstractArray, dy, inds...)
    region = view(dx, inds...)
    region .+= dy
    return dx
end
126+
127+ # Allow for second derivatives, by writing rules for `∇getindex`:
128+
# Forward rule for `∇getindex`. Only the tangent of `dy` is used: the tangents of
# the function, of `x`, and of the indices are discarded by the destructuring.
# The output tangent is `∇getindex` applied to `dẏ` at the same indices.
function frule((_, _, dẏ), ::typeof(∇getindex), x, dy, inds...)
    primal = ∇getindex(x, dy, inds...)
    tangent = ∇getindex(x, dẏ, inds...)
    return primal, tangent
end
132+
# Reverse rule for `∇getindex`, allowing second derivatives.
# The pullback reads the incoming cotangent `dz` back out at `inds` and projects it
# onto `dy`; the function, `x`, and every index receive `NoTangent()`.
function rrule(::typeof(∇getindex), x, dy, inds...)
    primal = ∇getindex(x, dy, inds...)
    function ∇getindex_pullback(dz)
        picked = getindex(unthunk(dz), inds...)
        d̄y = ProjectTo(dy)(picked)
        īnds = map(Returns(NoTangent()), inds)
        return (NoTangent(), NoTangent(), d̄y, īnds...)
    end
    return primal, ∇getindex_pullback
end
142+
143+ # Indexing with repeated indices on a GPU will lead ∇getindex to have race conditions & wrong answers.
144+ # To avoid this, copy everything back to the CPU.
145+ # But don't do that for indices which are known to be unique, e.g. `A[1, 2:3, :]` the colon gives Base.Slice:
85146
86- return y, getindex_pullback
# GPU array, all-integer indices: one position is selected, so there are no repeated
# indices and the update is done directly on `dx` without a round trip to the CPU.
function ∇getindex!(dx::AbstractGPUArray, dy, inds::Integer...)
    slot = view(dx, inds...)
    slot .+= Ref(dy)
    return dx
end
# GPU array, indices known to be unique (integers, unit ranges, and the `Base.Slice`
# produced by a colon): accumulate directly on `dx`.
function ∇getindex!(
    dx::AbstractGPUArray, dy, inds::Union{Integer, AbstractUnitRange, Base.Slice}...
)
    region = view(dx, inds...)
    region .+= dy
    return dx
end
# GPU array, arbitrary indices (which may repeat): to avoid race conditions, move
# everything to `Array`s with `adapt`, accumulate there, then copy the result back
# into `dx`.
function ∇getindex!(dx::AbstractGPUArray, dy, inds...)
    host_dx = adapt(Array, dx)
    host_inds = adapt(Array, inds)
    host_dy = adapt(Array, dy)
    view(host_dx, host_inds...) .+= host_dy
    copyto!(dx, host_dx)
    return dx
end
88161
89162# ####
@@ -117,6 +190,23 @@ function frule((_, ẋ), ::typeof(view), x::AbstractArray, inds...)
117190 return view (x, inds... ), view (ẋ, inds... )
118191end
119192
# Reverse rule for `view` on arrays. The gradient with respect to `x` comes from
# `thunked_∇getindex`; the function and every index get `NoTangent()`.
function rrule(::typeof(view), x::AbstractArray, inds...)
    function view_pullback(dy)
        x̄ = thunked_∇getindex(x, dy, inds...)
        īnds = map(Returns(NoTangent()), inds)
        return (NoTangent(), x̄, īnds...)
    end
    y = view(x, inds...)
    return y, view_pullback
end
200+
# Reverse rule for `view` with all-integer indices.
# Unlike `getindex`, this returns a zero-dimensional array rather than a scalar, so
# the first index is widened to the range `i:i` to fool `∇getindex` into treating
# `dy` as an array.
function rrule(::typeof(view), x::AbstractArray, i::Integer, jkl::Integer...)
    function view_pullback_0(dy)
        x̄ = thunked_∇getindex(x, dy, i:i, jkl...)
        īnds = map(Returns(NoTangent()), (i, jkl...))
        return (NoTangent(), x̄, īnds...)
    end
    y = view(x, i, jkl...)
    return y, view_pullback_0
end
209+
120210# ####
121211# #### setindex!
122212# ####
@@ -125,6 +215,21 @@ function frule((_, ẋ, v̇), ::typeof(setindex!), x::AbstractArray, v, inds...)
125215 return setindex! (x, v, inds... ), setindex! (ẋ, v̇, inds... )
126216end
127217
218+ # ####
219+ # #### unsafe_getindex
220+ # ####
221+
222+ # This is called by e.g. `iterate(1:0.1:2)`,
223+ # and fixes https://github.com/FluxML/Zygote.jl/issues/1247
224+ # Only needs to accept AbstractRange, but AbstractVector makes testing easier.
225+
# Forward rule for `Base.unsafe_getindex` on vectors: the primal uses
# `unsafe_getindex`, while the tangent is read from `ẋ` with ordinary `getindex`.
function frule((_, ẋ), ::typeof(Base.unsafe_getindex), x::AbstractVector, i::Integer)
    primal = Base.unsafe_getindex(x, i)
    tangent = getindex(ẋ, i)
    return primal, tangent
end
229+
# Reverse rule for `Base.unsafe_getindex`: hand back to the AD system, asking it to
# differentiate ordinary `getindex` on the same arguments.
function rrule(
    cfg::RuleConfig{>:HasReverseMode},
    ::typeof(Base.unsafe_getindex),
    x::AbstractVector,
    i::Integer,
)
    return rrule_via_ad(cfg, getindex, x, i)
end
128233
129234# ####
130235# #### `eachslice` and friends
0 commit comments