From b8562878080334395644f6fb1046e029983e521b Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Wed, 30 Aug 2023 16:08:08 +0300 Subject: [PATCH 1/2] Allow N-dimensional arrays in sorting rules --- src/rulesets/Base/sort.jl | 4 ++-- test/rulesets/Base/sort.jl | 10 ++++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/rulesets/Base/sort.jl b/src/rulesets/Base/sort.jl index 0805da91f..2ebccbd84 100644 --- a/src/rulesets/Base/sort.jl +++ b/src/rulesets/Base/sort.jl @@ -25,12 +25,12 @@ function rrule(::typeof(partialsort), xs::AbstractVector, k::Union{Integer,Ordin return ys, partialsort_pullback end -function frule((_, ẋs), ::typeof(sort), xs::AbstractVector; kw...) +function frule((_, ẋs), ::typeof(sort), xs::AbstractArray; kw...) inds = sortperm(xs; kw...) return xs[inds], ẋs[inds] end -function rrule(::typeof(sort), xs::AbstractVector; kwargs...) +function rrule(::typeof(sort), xs::AbstractArray; kwargs...) inds = sortperm(xs; kwargs...) ys = xs[inds] diff --git a/test/rulesets/Base/sort.jl b/test/rulesets/Base/sort.jl index 052045d1e..564817feb 100644 --- a/test/rulesets/Base/sort.jl +++ b/test/rulesets/Base/sort.jl @@ -7,6 +7,16 @@ # rev test_rrule(sort, a) test_rrule(sort, a; fkwargs=(;rev=true)) + + a = rand(5, 4) + for dims in (1, 2) + # fwd + test_frule(sort, a; fkwargs=(;dims)) + test_frule(sort, a; fkwargs=(;dims, rev=true)) + # rev + test_rrule(sort, a; fkwargs=(;dims)) + test_rrule(sort, a; fkwargs=(;dims, rev=true)) + end end @testset "partialsort" begin a = rand(10) From 5f4814b8e255752884b85dd4fb29ae8c2253b37b Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Wed, 30 Aug 2023 20:13:11 +0300 Subject: [PATCH 2/2] Restrict Nd tests to 1.9 --- test/rulesets/Base/sort.jl | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/test/rulesets/Base/sort.jl b/test/rulesets/Base/sort.jl index 564817feb..d06067bd2 100644 --- a/test/rulesets/Base/sort.jl +++ b/test/rulesets/Base/sort.jl @@ -8,14 +8,16 @@ test_rrule(sort, a) test_rrule(sort, a; fkwargs=(;rev=true)) - a = rand(5, 4) - for dims in (1, 2) - # fwd - test_frule(sort, a; fkwargs=(;dims)) - test_frule(sort, a; fkwargs=(;dims, rev=true)) - # rev - test_rrule(sort, a; fkwargs=(;dims)) - test_rrule(sort, a; fkwargs=(;dims, rev=true)) + if VERSION ≥ v"1.9" + a = rand(5, 4) + for dims in (1, 2) + # fwd + test_frule(sort, a; fkwargs=(;dims)) + test_frule(sort, a; fkwargs=(;dims, rev=true)) + # rev + test_rrule(sort, a; fkwargs=(;dims)) + test_rrule(sort, a; fkwargs=(;dims, rev=true)) + end end end @testset "partialsort" begin