|
1 | | -##### |
2 | | -##### getindex(::Tuple) |
3 | | -##### |
4 | | - |
5 | | -function frule((_, ẋ), ::typeof(getindex), x::Tuple, i::Integer) |
6 | | - return x[i], ẋ[i] |
7 | | -end |
8 | | - |
9 | | -function frule((_, ẋ), ::typeof(getindex), x::Tuple, i) |
10 | | - y = x[i] |
11 | | - return y, Tangent{typeof(y)}(ẋ[i]...) |
| 1 | +# Int rather than Int64/Integer is intentional |
| 2 | +function frule((_, ẋ), ::typeof(getfield), x::Tuple, i::Int) |
| 3 | + return x.i, ẋ.i |
12 | 4 | end |
13 | 5 |
|
14 | 6 | "for a given tuple type, returns a Val{N} where N is the length of the tuple" |
|
77 | 69 | """ |
78 | 70 | ∇getindex(x, dy, inds...) |
79 | 71 |
|
80 | | -For the `rrule` of `y = x[inds...]`, this function is roughly |
| 72 | +For the `rrule` of `y = x[inds...]`, this function is roughly |
81 | 73 | `setindex(zero(x), dy, inds...)`, returning the array `dx`. |
82 | 74 | Differentiable. Includes `ProjectTo(x)(dx)`. |
83 | 75 | """ |
84 | | -function ∇getindex(x::AbstractArray, dy, inds...) |
| 76 | +function ∇getindex(x::AbstractArray{T,N}, dy, inds...) where {T,N} |
85 | 77 | # `to_indices` removes any logical indexing, colons, CartesianIndex etc, |
86 | 78 | # leaving just Int / AbstractVector of Int |
87 | 79 | plain_inds = Base.to_indices(x, inds) |
88 | | - dx = _setindex_zero(x, dy, plain_inds...) |
89 | | - ∇getindex!(dx, dy, plain_inds...) |
| 80 | + dx = if plain_inds isa NTuple{N, Int} && T<:Number |
| 81 | + # scalar indexing |
| 82 | + OneElement(dy, plain_inds, axes(x)) |
| 83 | + else # some from slicing (potentially noncontigous) |
| 84 | + dx = _setindex_zero(x, dy, plain_inds...) |
| 85 | + ∇getindex!(dx, dy, plain_inds...) |
| 86 | + end |
90 | 87 | return ProjectTo(x)(dx) # since we have x, may as well do this inside, not in rules |
91 | 88 | end |
92 | 89 | ∇getindex(x::AbstractArray, z::AbstractZero, inds...) = z |
93 | 90 |
|
| 91 | +""" |
| 92 | + OneElement(val, ind, axes) <: AbstractArray |
| 93 | +
|
| 94 | +Extremely simple `struct` used for the gradient of scalar `getindex`. |
| 95 | +""" |
| 96 | +struct OneElement{T,N,I,A} <: AbstractArray{T,N} |
| 97 | + val::T |
| 98 | + ind::I |
| 99 | + axes::A |
| 100 | + OneElement(val::T, ind::I, axes::A) where {T<:Number, I<:NTuple{N,Int}, A<:NTuple{N,AbstractUnitRange}} where {N} = new{T,N,I,A}(val, ind, axes) |
| 101 | +end |
| 102 | +Base.size(A::OneElement) = map(length, A.axes) |
| 103 | +Base.axes(A::OneElement) = A.axes |
| 104 | +Base.getindex(A::OneElement{T,N}, i::Vararg{Int,N}) where {T,N} = ifelse(i==A.ind, A.val, zero(T)) |
| 105 | + |
| 106 | +function ChainRulesCore.add!!(xs::AbstractArray{<:Any,N}, oe::OneElement{<:Any,N}) where {N} |
| 107 | + if !ChainRulesCore.is_inplaceable_destination(xs) |
| 108 | + xs = collect(xs) |
| 109 | + end |
| 110 | + xs[oe.ind...] += oe.val |
| 111 | + return xs |
| 112 | +end |
| 113 | + |
| 114 | +Base.:(+)(xs::AbstractArray, oe::OneElement) = add!!(copy(xs), oe) |
| 115 | +Base.:(+)(oe::OneElement, xs::AbstractArray) = +(xs, oe) |
| 116 | +Base.:(+)(oe1::OneElement, oe2::OneElement) = +(collect(oe1), oe2) |
| 117 | + |
94 | 118 | """ |
95 | 119 | _setindex_zero(x, dy, inds...) |
96 | 120 |
|
@@ -159,29 +183,6 @@ function ∇getindex!(dx::AbstractGPUArray, dy, inds...) |
159 | 183 | return dx |
160 | 184 | end |
161 | 185 |
|
162 | | -##### |
163 | | -##### first, tail |
164 | | -##### |
165 | | - |
166 | | -function frule((_, ẋ), ::typeof(first), x::Tuple) |
167 | | - return first(x), first(ẋ) |
168 | | -end |
169 | | - |
170 | | -function rrule(::typeof(first), x::T) where {T<:Tuple} |
171 | | - first_back(dy) = (NoTangent(), Tangent{T}(ntuple(j -> j == 1 ? dy : NoTangent(), _tuple_N(T))...)) |
172 | | - return first(x), first_back |
173 | | -end |
174 | | - |
175 | | -function frule((_, ẋ), ::typeof(Base.tail), x::Tuple) |
176 | | - y = Base.tail(x) |
177 | | - return y, Tangent{typeof(y)}(Base.tail(ẋ)...) |
178 | | -end |
179 | | - |
180 | | -function rrule(::typeof(Base.tail), x::T) where {T<:Tuple} |
181 | | - tail_pullback(dy) = (NoTangent(), Tangent{T}(NoTangent(), dy...)) |
182 | | - return Base.tail(x), tail_pullback |
183 | | -end |
184 | | - |
185 | 186 | ##### |
186 | 187 | ##### view |
187 | 188 | ##### |
|
0 commit comments