diff --git a/benchmark/bench_sort.jl b/benchmark/bench_sort.jl new file mode 100644 index 00000000..a8327897 --- /dev/null +++ b/benchmark/bench_sort.jl @@ -0,0 +1,89 @@ +module BenchSort + +using BenchmarkTools +using Random: rand! +using StaticArrays +using StaticArrays: BitonicSort + +const SUITE = BenchmarkGroup() + +# 1 second is sufficient for reasonably consistent timings. +BenchmarkTools.DEFAULT_PARAMETERS.seconds = 1 + +const LEN = 1000 + +const Floats = (Float16, Float32, Float64) +const Ints = (Int8, Int16, Int32, Int64, Int128) +const UInts = (UInt8, UInt16, UInt32, UInt64, UInt128) + +map_sort!(vs; kwargs...) = map!(v -> sort(v; kwargs...), vs, vs) + +addgroup!(SUITE, "BitonicSort") + +g = addgroup!(SUITE["BitonicSort"], "SVector") +for lt in (isless, <) + n = 1 + while (n = nextprod([2, 3], n + 1)) <= 24 + for T in (Floats..., Ints..., UInts...) + (lt === <) && (T <: Integer) && continue # For Integers, isless is <. + vs = Vector{SVector{n, T}}(undef, LEN) + g[lt, n, T] = @benchmarkable( + map_sort!($vs; alg=BitonicSort, lt=$lt), + evals=1, # Redundant on @benchmarkable as of BenchmarkTools 1.1.3. + # We need evals=1 so that setup runs before every eval. But PkgBenchmark + # always `tune!`s benchmarks before running, which overrides this.
As a + # workaround, use the unhygienic symbol `__params` to set evals just before + # execution at + # https://github.com/JuliaCI/BenchmarkTools.jl/blob/v1.1.3/src/execution.jl#L482 + # See also: https://github.com/JuliaCI/PkgBenchmark.jl/issues/120 + setup=(__params.evals = 1; rand!($vs)), + ) + end + end +end + +g = addgroup!(SUITE["BitonicSort"], "MVector") +for (lt, n, T) in ((isless, 16, Int64), (isless, 16, Float64), (<, 16, Float64)) + vs = Vector{MVector{n, T}}(undef, LEN) + g[lt, n, T] = @benchmarkable( + map_sort!($vs; alg=BitonicSort, lt=$lt), + evals=1, + setup=(__params.evals = 1; rand!($vs)), + ) +end + +g = addgroup!(SUITE["BitonicSort"], "SizedVector") +for (lt, n, T) in ((isless, 16, Int64), (isless, 16, Float64), (<, 16, Float64)) + vs = Vector{SizedVector{n, T, Vector{T}}}(undef, LEN) + g[lt, n, T] = @benchmarkable( + map_sort!($vs; alg=BitonicSort, lt=$lt), + evals=1, + setup=(__params.evals = 1; rand!($vs)), + ) +end + +function map_floats_nans!(vs::Vector{SVector{N, T}}, p) where {N, T} + @inline _rand(_) = ifelse(rand(Float32) < p, T(NaN), rand(T)) + for i in eachindex(vs) + @inbounds vs[i] = SVector(ntuple(_rand, Val(N))) + end + return vs +end + +g = addgroup!(SUITE["BitonicSort"], "NaNs") +for p in (0.001, 0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 1.0) + (lt, n, T) = (isless, 16, Float64) + vs = Vector{SVector{n, T}}(undef, LEN) + g[lt, n, T, p] = @benchmarkable( + map_sort!($vs; alg=BitonicSort, lt=$lt), + evals=1, + setup=(__params.evals = 1; map_floats_nans!($vs, $p)), + ) +end + +end # module BenchSort + +# Allow PkgBenchmark.benchmarkpkg to call this file directly.
+@isdefined(SUITE) || (SUITE = BenchSort.SUITE) + +BenchSort.SUITE diff --git a/src/StaticArrays.jl b/src/StaticArrays.jl index 568a1be0..1f34e77b 100644 --- a/src/StaticArrays.jl +++ b/src/StaticArrays.jl @@ -124,6 +124,7 @@ include("indexing.jl") include("broadcast.jl") include("mapreduce.jl") include("sort.jl") +using .Sort include("arraymath.jl") include("linalg.jl") include("matrix_multiply_add.jl") diff --git a/src/sort.jl b/src/sort.jl index 477bca83..bf985054 100644 --- a/src/sort.jl +++ b/src/sort.jl @@ -1,9 +1,30 @@ -import Base.Order: Forward, Ordering, Perm, ord -import Base.Sort: Algorithm, lt, sort, sortperm +module Sort +import Base: sort, sortperm + +using ..StaticArrays +using Base: @_inline_meta +using Base.Order: Forward, Ordering, Perm, Reverse, ord +using Base.Sort: Algorithm, lt + +export BitonicSort struct BitonicSortAlg <: Algorithm end +# For consistency with Julia Base, track their *Sort docstring text in base/sort.jl. +""" + StaticArrays.BitonicSort + +Indicate that a sorting function should use a bitonic sorting network, which is *not* +stable. By default, `StaticVector`s with at most 20 elements are sorted with `BitonicSort`. + +Characteristics: + * *not stable*: does not preserve the ordering of elements which compare equal (e.g. "a" + and "A" in a sort of letters which ignores case). + * *in-place* in memory. + * *good performance* for small collections. + * compilation time increases dramatically with the number of elements.
+""" const BitonicSort = BitonicSortAlg() @@ -19,8 +40,7 @@ defalg(a::StaticVector) = rev::Union{Bool,Nothing} = nothing, order::Ordering = Forward) length(a) <= 1 && return a - ordr = ord(lt, by, rev, order) - return _sort(a, alg, ordr) + return _sort(a, alg, lt, by, rev, order) end @inline function sortperm(a::StaticVector; @@ -33,21 +53,73 @@ end length(a) <= 1 && return SVector{length(a),Int}(p) ordr = Perm(ord(lt, by, rev, order), a) - return SVector{length(a),Int}(_sort(p, alg, ordr)) + return SVector{length(a),Int}(_sort(p, alg, isless, identity, nothing, ordr)) end +@inline _sort(a::StaticVector, alg, lt, by, rev, order) = + similar_type(a)(sort!(Base.copymutable(a); alg=alg, lt=lt, by=by, rev=rev, order=order)) + +@inline _sort(a::StaticVector, alg::BitonicSortAlg, lt, by, rev, order) = + similar_type(a)(_sort(Tuple(a), alg, lt, by, rev, order)) + +@inline _sort(a::NTuple, alg, lt, by, rev, order) = + sort!(Base.copymutable(a); alg=alg, lt=lt, by=by, rev=rev, order=order) + +@inline _sort(a::NTuple, ::BitonicSortAlg, lt, by, rev, order) = + _bitonic_sort(a, ord(lt, by, rev, order)) + +# For better performance sorting floats under the isless relation, apply an order-preserving +# bijection to sort them as integers. +@inline function _sort( + a::NTuple{N, <:Base.IEEEFloat}, + ::BitonicSortAlg, + lt::typeof(isless), + by::Union{typeof.((identity, +, -))...}, + rev::Union{Bool, Nothing}, + order, +) where N + # Exclude N == 2 to avoid a performance regression on AArch64.
+ if N > 2 && (order === Forward || order === Reverse) + _rev = xor(by === -, rev === true, order === Reverse) + return _intfp.(_bitonic_sort(_fpint.(a), ord(isless, identity, _rev, Forward))) + end + return _bitonic_sort(a, ord(lt, by, rev, order)) +end -@inline _sort(a::StaticVector, alg, order) = - similar_type(a)(sort!(Base.copymutable(a); alg=alg, order=order)) - -@inline _sort(a::StaticVector, alg::BitonicSortAlg, order) = - similar_type(a)(_sort(Tuple(a), alg, order)) +_inttype(::Type{Float64}) = Int64 +_inttype(::Type{Float32}) = Int32 +_inttype(::Type{Float16}) = Int16 + +_floattype(::Type{Int64}) = Float64 +_floattype(::Type{Int32}) = Float32 +_floattype(::Type{Int16}) = Float16 + +# Modified from the _fpint function added to base/float.jl in Julia 1.7. This is a strictly +# increasing function with respect to the isless relation. `isless` is trichotomous with the +# isequal relation and treats every NaN as identical. This function on the other hand +# distinguishes between NaNs with different payloads and signs, but this difference is +# inconsequential for unstable sorting. The `offset` is necessary because NaNs (in +# particular, those with the sign bit set) must be mapped to the greatest Ints, which is +# specific to Julia's isless ordering. +@inline function _fpint(x::F) where F + I = _inttype(F) + offset = Base.significand_mask(F) % I + n = reinterpret(I, x) + return ifelse(n < zero(I), n ⊻ typemax(I), n) - offset +end -_sort(a::NTuple, alg, order) = sort!(Base.copymutable(a); alg=alg, order=order) +# Inverse of _fpint: undo the offset and sign fix-up, then reinterpret the bits as a float.
+@inline function _intfp(n::I) where I + F = _floattype(I) + offset = Base.significand_mask(F) % I + n += offset + n = ifelse(n < zero(I), n ⊻ typemax(I), n) + return reinterpret(F, n) +end # Implementation loosely following # https://www.inf.hs-flensburg.de/lang/algorithmen/sortieren/bitonic/oddn.htm -@generated function _sort(a::NTuple{N}, ::BitonicSortAlg, order) where N +@generated function _bitonic_sort(a::NTuple{N}, order) where N function swap_expr(i, j, rev) ai = Symbol('a', i) aj = Symbol('a', j) @@ -87,3 +159,5 @@ _sort(a::NTuple, alg, order) = sort!(Base.copymutable(a); alg=alg, order=order) return ($(symlist...),) end end + +end # module Sort diff --git a/test/sort.jl b/test/sort.jl index ded0703e..36db9d6a 100644 --- a/test/sort.jl +++ b/test/sort.jl @@ -1,4 +1,8 @@ +module SortTests + using StaticArrays, Test +using StaticArrays.Sort: _inttype +using Base.Order: Forward, Reverse @testset "sort" begin @@ -30,4 +34,88 @@ using StaticArrays, Test @test sortperm(SA[1, 1, 1, 0]) == SA[4, 1, 2, 3] end -end + @testset "NaNs" begin + # Return an SVector with floats and NaNs that have random sign and payload bits. + function floats_randnans(::Type{SVector{N, T}}, p) where {N, T} + float_or(x, y) = reinterpret(T, |(reinterpret.(_inttype(T), (x, y))...)) + @inline function _rand(_) + r = rand(T) + # The bitwise or of any T with T(Inf) is either ±T(Inf) or a NaN. + ifelse(rand(Float32) < p, float_or(typemax(T), r - T(0.5)), r) + end + return SVector(ntuple(_rand, Val(N))) + end + + # Sort floats and arbitrary NaNs; Base's sort! on a Vector copy is the reference. + for T in (Float16, Float32, Float64) + buffer = Vector{T}(undef, 16) + @test all(floats_randnans(SVector{16, T}, 0.5) for _ in 1:10_000) do a + copyto!(buffer, a) + isequal(sort(a), sort!(buffer)) + end + end + + # Sort signed Infs, signed zeros, and signed NaNs with extremal payloads.
+ for T in (Float16, Float32, Float64) + U = _inttype(T) + small_nan = reinterpret(T, reinterpret(U, typemax(T)) + one(U)) + large_nan = reinterpret(T, typemax(U)) + nans = (small_nan, large_nan, T(NaN), -small_nan, -large_nan, -T(NaN)) + (a, b, c, d) = (-T(Inf), -zero(T), zero(T), T(Inf)) + sorted = [a, b, c, d, nans..., nans...] + @test isequal(sorted, sort(SA[nans..., d, c, b, a, nans...])) + @test isequal(sorted, sort(SA[d, c, nans..., nans..., b, a])) + end + end + + # These tests are selected and modified from Julia's test/ordering.jl and test/sorting.jl. + @testset "Base tests" begin + # This testset partially fails on Julia versions < 1.5 because order could be + # discarded: https://github.com/JuliaLang/julia/pull/34719 + if VERSION >= v"1.5" + @testset "ordering" begin + for T in (Int, Float64) + for (s1, rev) in enumerate([nothing, true, false]) + for (s2, lt) in enumerate([>, <, (a, b) -> a - b > 0, (a, b) -> a - b < 0]) + for (s3, by) in enumerate([-, +]) + for (s4, order) in enumerate([Reverse, Forward]) + if isodd(s1 + s2 + s3 + s4) + target = T.(SA[1, 2, 3]) + else + target = T.(SA[3, 2, 1]) + end + @test target == sort(T.(SA[2, 3, 1]), rev=rev, lt=lt, by=by, order=order) + end + end + end + end + end + + @test SA[1 => 3, 2 => 5, 3 => 1] == + sort(SA[1 => 3, 2 => 5, 3 => 1]) == + sort(SA[1 => 3, 2 => 5, 3 => 1], by=first) == + sort(SA[1 => 3, 2 => 5, 3 => 1], rev=true, order=Reverse) == + sort(SA[1 => 3, 2 => 5, 3 => 1], lt= >, order=Reverse) + + @test SA[3 => 1, 1 => 3, 2 => 5] == + sort(SA[1 => 3, 2 => 5, 3 => 1], by=last) == + sort(SA[1 => 3, 2 => 5, 3 => 1], by=last, rev=true, order=Reverse) == + sort(SA[1 => 3, 2 => 5, 3 => 1], by=last, lt= >, order=Reverse) + end + end + + @testset "sort" begin + for T in (Int, Float64) + @test sort(T.(SA[2,3,1])) == T.(SA[1,2,3]) == sort(T.(SA[2,3,1]); order=Forward) + @test sort(T.(SA[2,3,1]), rev=true) == T.(SA[3,2,1]) == sort(T.(SA[2,3,1]), order=Reverse) + end + @test sort(SA['z':-1:'a'...]) ==
SA['a':'z'...] + @test sort(SA['a':'z'...], rev=true) == SA['z':-1:'a'...] + end + + @test sortperm(SA[2,3,1]) == SA[3,1,2] + end + +end # @testset "sort" + +end # module SortTests