From 8281539c7f32e4db85ea401e02b7198fbd5ba780 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 5 Aug 2021 08:42:11 -0400 Subject: [PATCH 1/6] unique, take 1 --- src/rulesets/Base/sort.jl | 54 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/src/rulesets/Base/sort.jl b/src/rulesets/Base/sort.jl index cf3eadc4b..c60223798 100644 --- a/src/rulesets/Base/sort.jl +++ b/src/rulesets/Base/sort.jl @@ -1,3 +1,7 @@ +##### +##### `sort` +##### + function rrule(::typeof(partialsort), xs::AbstractVector, k::Union{Integer,OrdinalRange}; kwargs...) inds = partialsortperm(xs, k; kwargs...) ys = xs[inds] @@ -33,3 +37,53 @@ function rrule(::typeof(sort), xs::AbstractVector; kwargs...) end return ys, sort_pullback end + +##### +##### `unique` +##### + +function rrule(::typeof(unique), x::AbstractArray{<:Number}; dims=:) + axes_x = axes(x) + project = ProjectTo(x) + y = unique(x; dims=dims) # accepts only dims=: or dims::Integer + if dims isa Colon + xs, ys = vec(x), y + else + xs, ys = collect(eachslice(x; dims=dims)), collect(eachslice(y; dims=dims)) + end + mask = isequal.(permutedims(ys), xs) # unique([0.0, -0.0, NaN, NaN]) + mask .= (mask .== cumsum(mask, dims=1) .== true) + keep = map(I -> I[1], findall(mask)) + function unique_pullback(dy_raw) + dy = unthunk(dy_raw) + if dims isa Colon + # The function `_zerolike_writeat` is defined near `maximum`, allows + # second derivatives. Should perhaps eventually be shared with `getindex`. + dx = reshape(_zerolike_writeat(vec(x), vec(dy), (), keep), axes_x) + else + inds = ntuple(d -> d==dims ? keep : (:), length(axes_x)) + dx = _zerolike_writeat(x, dy, (), inds...) + end + return (NoTangent(), project(dx)) + end + return y, unique_pullback +end + +function _zerolike_writeat(x, dy, dims, ind...) + # It's unfortunate to close over `x`, but `similar(typeof(x), axes(x))` doesn't + # allow `eltype(dy)`, nor does it work for many structured matrices. + dx = fill!(similar(x, eltype(dy), axes(x)), false) + view(dx, ind...) .= dy # possibly 0-dim view, allows dy::Number and dy::Array, and dx::CuArray + dx +end + +#= + +rrule(unique, [1,1,2,3])[2]([10,20,30]) == (NoTangent(), [10, 0, 20, 30]) +rrule(unique, [1 2; 1 4])[2]([10,20,30]) == (NoTangent(), [10 20; 0 30]) + +rrule(unique, [1 2 1 2; 1 2 1 4], dims=2)[2]([10 20 30; 40 50 60])[2] == [10 20 0 30; 40 50 0 60] + +rrule(unique, Diagonal([1,2,3]))[2]([10 20 30 40])[2] == [10.0 0.0 0.0; 0.0 30.0 0.0; 0.0 0.0 40.0] + +=# From 912ae5c5873ceb0f8dc1ab436f0c3791036ce1a0 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 27 Aug 2021 16:27:30 -0400 Subject: [PATCH 2/6] add shortcut, and tests --- src/rulesets/Base/sort.jl | 37 +++++++++++++++---------------------- test/rulesets/Base/sort.jl | 21 +++++++++++++++++++++ 2 files changed, 36 insertions(+), 22 deletions(-) diff --git a/src/rulesets/Base/sort.jl b/src/rulesets/Base/sort.jl index c60223798..a077abb03 100644 --- a/src/rulesets/Base/sort.jl +++ b/src/rulesets/Base/sort.jl @@ -44,18 +44,22 @@ end function rrule(::typeof(unique), x::AbstractArray{<:Number}; dims=:) axes_x = axes(x) - project = ProjectTo(x) y = unique(x; dims=dims) # accepts only dims=: or dims::Integer - if dims isa Colon - xs, ys = vec(x), y - else - xs, ys = collect(eachslice(x; dims=dims)), collect(eachslice(y; dims=dims)) - end - mask = isequal.(permutedims(ys), xs) # unique([0.0, -0.0, NaN, NaN]) - mask .= (mask .== cumsum(mask, dims=1) .== true) - keep = map(I -> I[1], findall(mask)) function unique_pullback(dy_raw) dy = unthunk(dy_raw) + if length(x) == length(y) + # Short-circuit for the case of all unique, since `mask` is fairly expensive: + dx = reshape(dy, axes_x) + return (NoTangent(), ProjectTo(x)(dx)) + end + if dims isa Colon + xs, ys = vec(x), y + else + xs, ys = collect(eachslice(x; dims=dims)), collect(eachslice(y; dims=dims)) + end + mask = isequal.(permutedims(ys), xs) # unique([0.0, -0.0, NaN, NaN]) + mask .= (mask .== cumsum(mask, dims=1) .== true) + keep = map(I -> I[1], findall(mask)) if dims isa Colon # The function `_zerolike_writeat` is defined near `maximum`, allows # second derivatives. Should perhaps eventually be shared with `getindex`. @@ -63,8 +67,8 @@ function rrule(::typeof(unique), x::AbstractArray{<:Number}; dims=:) else inds = ntuple(d -> d==dims ? keep : (:), length(axes_x)) dx = _zerolike_writeat(x, dy, (), inds...) - end - return (NoTangent(), project(dx)) + end + return (NoTangent(), ProjectTo(x)(dx)) end return y, unique_pullback end @@ -76,14 +80,3 @@ function _zerolike_writeat(x, dy, dims, ind...) view(dx, ind...) .= dy # possibly 0-dim view, allows dy::Number and dy::Array, and dx::CuArray dx end - -#= - -rrule(unique, [1,1,2,3])[2]([10,20,30]) == (NoTangent(), [10, 0, 20, 30]) -rrule(unique, [1 2; 1 4])[2]([10,20,30]) == (NoTangent(), [10 20; 0 30]) - -rrule(unique, [1 2 1 2; 1 2 1 4], dims=2)[2]([10 20 30; 40 50 60])[2] == [10 20 0 30; 40 50 0 60] - -rrule(unique, Diagonal([1,2,3]))[2]([10 20 30 40])[2] == [10.0 0.0 0.0; 0.0 30.0 0.0; 0.0 0.0 40.0] - -=# diff --git a/test/rulesets/Base/sort.jl b/test/rulesets/Base/sort.jl index 5f3bc4213..cf5b40b85 100644 --- a/test/rulesets/Base/sort.jl +++ b/test/rulesets/Base/sort.jl @@ -12,4 +12,25 @@ test_rrule(partialsort, a, 4, fkwargs=(;rev=true)) end + + @testset "unique" begin + # Trivial case, all unique: + test_rrule(unique, rand(5)) + test_rrule(unique, rand(3,4)) + test_rrule(unique, rand(3,4); fkwargs=(; dims=2)) + + # Not all unique: + @test rrule(unique, [1,1,2,3])[1] == [1,2,3] + @test rrule(unique, [1,1,2,3])[2]([10,20,30]) == (NoTangent(), [10, 0, 20, 30]) + + @test rrule(unique, [1 2; 1 4])[1] == [1,2,4] + @test rrule(unique, [1 2; 1 4])[2]([10,20,30]) == (NoTangent(), [10 20; 0 30]) + + @test rrule(unique, [1 2 1 2; 1 2 1 4], dims=2)[1] == [1 2 2; 1 2 4] + @test rrule(unique, [1 2 1 2; 1 2 1 4], dims=2)[2]([10 20 30; 40 50 60])[2] == [10 20 0 30; 40 50 0 60] + + @test rrule(unique, Diagonal([1,2,3]))[1] == [1,0,2,3] + @test rrule(unique, Diagonal([1,2,3]))[2]([10 20 30 40])[2] == [10.0 0.0 0.0; 0.0 30.0 0.0; 0.0 0.0 40.0] + @test rrule(unique, Diagonal([1,2,3]))[2]([10 20 30 40])[2] isa Diagonal + end end From 6c959917d7e64de2af0bf77e3b1d7ee8c04d3f2f Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 17 Nov 2021 09:45:25 -0500 Subject: [PATCH 3/6] sortslices, too --- src/rulesets/Base/sort.jl | 16 ++++++++++++++++ test/rulesets/Base/sort.jl | 8 ++++++++ 2 files changed, 24 insertions(+) diff --git a/src/rulesets/Base/sort.jl b/src/rulesets/Base/sort.jl index a077abb03..ec38782d2 100644 --- a/src/rulesets/Base/sort.jl +++ b/src/rulesets/Base/sort.jl @@ -38,6 +38,22 @@ function rrule(::typeof(sort), xs::AbstractVector; kwargs...) return ys, sort_pullback end +##### +##### `sortslices` +##### + +function rrule(::typeof(sortslices), x::AbstractArray{<:Number}; dims::Integer, kw...) + p = sortperm(collect(eachslice(x; dims=dims)); kw...) + inds = ntuple(d -> d == dims ? p : (:), ndims(x)) + function sortslices_pullback(dy) + # No actual need to zero this, and if you didn't, then you could widen eltype + # Also, you could use similar(dy) here not x, same size? + dx = _zerolike_writeat(x, unthunk(dy), (), inds...) + return (NoTangent(), ProjectTo(x)(dx)) + end + return x[inds...], sortslices_pullback +end + ##### ##### `unique` ##### diff --git a/test/rulesets/Base/sort.jl b/test/rulesets/Base/sort.jl index cf5b40b85..40ab8d46a 100644 --- a/test/rulesets/Base/sort.jl +++ b/test/rulesets/Base/sort.jl @@ -13,6 +13,14 @@ test_rrule(partialsort, a, 4, fkwargs=(;rev=true)) end + @testset "sortslices" begin + test_rrule(sortslices, rand(3,4); fkwargs=(; dims=2)) + test_rrule(sortslices, rand(5,4); fkwargs=(; dims=1, rev=true, by=last)) + test_rrule(sortslices, rand(3,4,5); fkwargs=(; dims=3, by=sum)) + + @test_throws Exception sortslices(Diagonal(1:3), dims=1) + end + @testset "unique" begin # Trivial case, all unique: test_rrule(unique, rand(5)) From cc9bd436a3fcc3391aa5b0a4915b3996ddf41bd1 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 24 Nov 2021 11:43:34 -0500 Subject: [PATCH 4/6] fixup --- src/rulesets/Base/sort.jl | 8 -------- test/rulesets/Base/sort.jl | 2 +- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/src/rulesets/Base/sort.jl b/src/rulesets/Base/sort.jl index ec38782d2..f377dde4a 100644 --- a/src/rulesets/Base/sort.jl +++ b/src/rulesets/Base/sort.jl @@ -88,11 +88,3 @@ function rrule(::typeof(unique), x::AbstractArray{<:Number}; dims=:) end return y, unique_pullback end - -function _zerolike_writeat(x, dy, dims, ind...) - # It's unfortunate to close over `x`, but `similar(typeof(x), axes(x))` doesn't - # allow `eltype(dy)`, nor does it work for many structured matrices. - dx = fill!(similar(x, eltype(dy), axes(x)), false) - view(dx, ind...) .= dy # possibly 0-dim view, allows dy::Number and dy::Array, and dx::CuArray - dx -end diff --git a/test/rulesets/Base/sort.jl b/test/rulesets/Base/sort.jl index 40ab8d46a..f76109586 100644 --- a/test/rulesets/Base/sort.jl +++ b/test/rulesets/Base/sort.jl @@ -16,7 +16,7 @@ @testset "sortslices" begin test_rrule(sortslices, rand(3,4); fkwargs=(; dims=2)) test_rrule(sortslices, rand(5,4); fkwargs=(; dims=1, rev=true, by=last)) - test_rrule(sortslices, rand(3,4,5); fkwargs=(; dims=3, by=sum)) + test_rrule(sortslices, rand(3,4,5); fkwargs=(; dims=3, by=sum), check_inferred=false) @test_throws Exception sortslices(Diagonal(1:3), dims=1) end From 3f4481015a097abc9d3d139562f270afd691c7fd Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 30 Nov 2021 16:00:14 -0500 Subject: [PATCH 5/6] Apply 3 suggestions Co-authored-by: Lyndon White --- src/rulesets/Base/sort.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/rulesets/Base/sort.jl b/src/rulesets/Base/sort.jl index f377dde4a..75e2eb0e2 100644 --- a/src/rulesets/Base/sort.jl +++ b/src/rulesets/Base/sort.jl @@ -42,7 +42,7 @@ end ##### `sortslices` ##### -function rrule(::typeof(sortslices), x::AbstractArray{<:Number}; dims::Integer, kw...) +function rrule(::typeof(sortslices), x::AbstractArray; dims::Integer, kw...) p = sortperm(collect(eachslice(x; dims=dims)); kw...) inds = ntuple(d -> d == dims ? p : (:), ndims(x)) function sortslices_pullback(dy) @@ -68,13 +68,14 @@ function rrule(::typeof(unique), x::AbstractArray{<:Number}; dims=:) dx = reshape(dy, axes_x) return (NoTangent(), ProjectTo(x)(dx)) end + if dims isa Colon xs, ys = vec(x), y else xs, ys = collect(eachslice(x; dims=dims)), collect(eachslice(y; dims=dims)) end mask = isequal.(permutedims(ys), xs) # unique([0.0, -0.0, NaN, NaN]) - mask .= (mask .== cumsum(mask, dims=1) .== true) + mask .= (mask .== cumsum(mask, dims=1) .== true) # this implements findfirst(mask; dims=1) keep = map(I -> I[1], findall(mask)) if dims isa Colon # The function `_zerolike_writeat` is defined near `maximum`, allows From 213987ebb0567b4c3d26bf4ac1cf4ab2f20f1664 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 30 Nov 2021 17:04:44 -0500 Subject: [PATCH 6/6] comment Co-authored-by: Lyndon White --- src/rulesets/Base/sort.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rulesets/Base/sort.jl b/src/rulesets/Base/sort.jl index 75e2eb0e2..be7840c8c 100644 --- a/src/rulesets/Base/sort.jl +++ b/src/rulesets/Base/sort.jl @@ -78,8 +78,8 @@ function rrule(::typeof(unique), x::AbstractArray{<:Number}; dims=:) mask .= (mask .== cumsum(mask, dims=1) .== true) # this implements findfirst(mask; dims=1) keep = map(I -> I[1], findall(mask)) if dims isa Colon - # The function `_zerolike_writeat` is defined near `maximum`, allows - # second derivatives. Should perhaps eventually be shared with `getindex`. + # The function `_zerolike_writeat` allows second derivatives. + # Should perhaps eventually be shared with `getindex`. dx = reshape(_zerolike_writeat(vec(x), vec(dy), (), keep), axes_x) else inds = ntuple(d -> d==dims ? keep : (:), length(axes_x))