diff --git a/src/special/misc.jl b/src/special/misc.jl index 21a27b7d..40fe26cc 100644 --- a/src/special/misc.jl +++ b/src/special/misc.jl @@ -1,4 +1,32 @@ @inline Base.:^(v::AbstractSIMD{W,T}, i::Integer) where {W,T} = Base.power_by_squaring(v, i) @inline Base.:^(v::AbstractSIMD{W,T}, i::Integer) where {W,T<:Union{Float32,Float64}} = Base.power_by_squaring(v, i) -@inline relu(x) = (y = zero(x); IfElse.ifelse(x > y, x, y)) +@inline relu(x) = (y = zero(x); ifelse(x > y, x, y)) +@inline Base.fld(x::AbstractSIMD, y::AbstractSIMD) = div(promote(x,y)..., RoundDown) + +@inline function Base.div(x::AbstractSIMD{W,T}, y::AbstractSIMD{W,T}, ::RoundingMode{:Down}) where {W,T<:Integer} + d = div(x, y) + d - (signbit(x ⊻ y) & (d * y != x)) +end + +@inline Base.mod(x::AbstractSIMD{W,T}, y::AbstractSIMD{W,T}) where {W,T<:Integer} = + ifelse(y == -1, zero(x), x - fld(x, y) * y) + +@inline Base.mod(x::AbstractSIMD{W,T}, y::AbstractSIMD{W,T}) where {W,T<:Unsigned} = + rem(x, y) + +@inline Base.mod(i::AbstractSIMD{<:Any,<:Integer}, r::AbstractUnitRange{<:Integer}) = + mod(i-first(r), length(r)) + first(r) + +# avoid ambiguity with clamp(::Missing, lo, hi) in Base.Math at math.jl:1258 +# but who knows what would happen if you called it +for (X,L,H) in Iterators.product(fill([:Any, :Missing, :AbstractSIMD], 3)...) + any(==(:AbstractSIMD), (X,L,H)) || continue + @eval @inline function Base.clamp(x::$X, lo::$L, hi::$H) + x_, lo_, hi_ = promote(x, lo, hi) + ifelse(x_ > hi_, hi_, ifelse(x_ < lo_, lo_, x_)) + end +end + +@inline Base.clamp(x::AbstractSIMD{<:Any,<:Integer}, r::AbstractUnitRange{<:Integer}) = + clamp(x, first(r), last(r)) diff --git a/test/runtests.jl b/test/runtests.jl index d159444a..e078d418 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -45,7 +45,7 @@ end W = VectorizationBase.pick_vector_width(Float64) @test @inferred(VectorizationBase.pick_integer(Val(W))) == (VectorizationBase.AVX512DQ ? Int64 : Int32) - + @test first(A) === A[1] @test W64S == W64 @testset "Struct-Wrapped Vec" begin @@ -175,7 +175,7 @@ end @test !VectorizationBase.vall(Mask{4}(0xfc)) @test VectorizationBase.vall(Mask{8}(0xff)) @test VectorizationBase.vall(Mask{4}(0xcf)) - + @test VectorizationBase.vany(Mask{8}(0xfc)) @test VectorizationBase.vany(Mask{4}(0xfc)) @test !VectorizationBase.vany(Mask{8}(0x00)) @@ -211,7 +211,7 @@ end @test (Mask{8}(0xac) ⊻ true) === Mask{8}(0x53) @test (false ⊻ Mask{8}(0xac)) === Mask{8}(0xac) @test (true ⊻ Mask{8}(0xac)) === Mask{8}(0x53) - + @test (Mask{4}(0x05) | true) === Mask{4}(0x0f) @test (Mask{4}(0x05) | false) === Mask{4}(0x05) @test (true | Mask{4}(0x05)) === Mask{4}(0x0f) @@ -239,7 +239,7 @@ end # @test VectorizationBase.size_loads(A,2, Val(8)) == eval(VectorizationBase.num_vector_load_expr(@__MODULE__, :((() -> 17)()), 8)) == eval(VectorizationBase.num_vector_load_expr(@__MODULE__, 17, 8)) == divrem(size(A,2), 8) # end - + @testset "vector_width.jl" begin @test all(VectorizationBase.ispow2, 0:1) @test all(i -> !any(VectorizationBase.ispow2, 1+(1 << (i-1)):(1 << i)-1 ) && VectorizationBase.ispow2(1 << i), 2:9) @@ -282,7 +282,7 @@ end @test [vload(stridedpointer(C), (1+w, 2+w, 3)) for w ∈ 1:W64] == getindex.(Ref(C), 1 .+ (1:W64), 2 .+ (1:W64), 3) vstore!(stridedpointer(C), !mtest, ((MM{16})(17), 3, 4)) @test .!v1 == C[17:32,3,4] == tovector(vload(stridedpointer(C), ((MM{16})(17), 3, 4))) - + dims = (41,42,43) .* 3; # dims = (41,42,43); A = reshape(collect(Float64(0):Float64(prod(dims)-1)), dims); @@ -345,7 +345,7 @@ end @test v1 === vu.data[1] @test v2 === vu.data[2] @test v3 === vu.data[3] - + ir = 0:(AV == 1 ? W64-1 : 0); jr = 0:(AV == 2 ? W64-1 : 0); kr = 0:(AV == 3 ? W64-1 : 0) x1 = getindex.(Ref(B), i .+ ir, j .+ jr, k .+ kr) if AU == 1 @@ -364,7 +364,7 @@ end kr = kr .+ length(kr) end x3 = getindex.(Ref(B), i .+ ir, j .+ jr, k .+ kr) - + @test x1 == tovector(vu.data[1]) @test x2 == tovector(vu.data[2]) @test x3 == tovector(vu.data[3]) @@ -398,7 +398,7 @@ end end @test x == 1:100 end - + @testset "Grouped Strided Pointers" begin M, K, N = 4, 5, 6 A = rand(M, K); B = rand(K, N); C = rand(M, N); @@ -426,7 +426,7 @@ end Vec(ntuple(_ -> (randn()), Val(W64))...) )) x = tovector(v) - for f ∈ [-, abs, inv, floor, ceil, trunc, round, sqrt ∘ abs] + for f ∈ [-, abs, inv, floor, ceil, trunc, round, sqrt ∘ abs, VectorizationBase.relu] @test tovector(@inferred(f(v))) == map(f, x) end invtol = VectorizationBase.AVX512F ? 2^-14 : 1.5*2^-12 # moreaccurate with AVX512 @@ -470,7 +470,7 @@ end xi1 = tovector(vi1); xi2 = tovector(vi2); xi3 = mapreduce(tovector, vcat, m1.data); xi4 = mapreduce(tovector, vcat, m2.data); - for f ∈ [+, -, *, ÷, /, %, <<, >>, >>>, ⊻, &, |, VectorizationBase.rotate_left, VectorizationBase.rotate_right, copysign, max, min] + for f ∈ [+, -, *, div, ÷, /, rem, %, <<, >>, >>>, ⊻, &, |, fld, mod, VectorizationBase.rotate_left, VectorizationBase.rotate_right, copysign, max, min] # @show f check_within_limits(tovector(@inferred(f(vi1, vi2))), f.(xi1, xi2)) check_within_limits(tovector(@inferred(f(j, vi2))), f.(j, xi2)) @@ -504,7 +504,24 @@ end @test tovector(@inferred(f(vf1, a))) ≈ f.(xf1, a) @test tovector(@inferred(f(vf2, a))) ≈ f.(xf2, a) end - + + vones, vi2f, vtwos = promote(1.0, vi2, 2f0); # promotes a binary function, right? Even when used with three args? + @test vones === VectorizationBase.VecUnroll((vbroadcast(Val(W64), 1.0),vbroadcast(Val(W64), 1.0),vbroadcast(Val(W64), 1.0),vbroadcast(Val(W64), 1.0))); + @test vtwos === VectorizationBase.VecUnroll((vbroadcast(Val(W64), 2.0),vbroadcast(Val(W64), 2.0),vbroadcast(Val(W64), 2.0),vbroadcast(Val(W64), 2.0))); + @test VectorizationBase.vall(vi2f == vi2) + W32 = StaticInt(W64)*StaticInt(2) + vf2 = VectorizationBase.VecUnroll(( + Vec(ntuple(_ -> Core.VecElement(randn(Float32)), W32)), + Vec(ntuple(_ -> Core.VecElement(randn(Float32)), W32)) + )) + vones32, v2f32, vtwos32 = promote(1.0, vf2, 2f0); # promotes a binary function, right? Even when used with three args? + @test vones32 === VectorizationBase.VecUnroll((vbroadcast(W32, 1f0),vbroadcast(W32, 1f0))) + @test vtwos32 === VectorizationBase.VecUnroll((vbroadcast(W32, 2f0),vbroadcast(W32, 2f0))) + @test vf2 === v2f32 + + @test tovector(clamp(m1, 2:i)) == clamp.(tovector(m1), 2, i) + @test tovector(mod(m1, 1:i)) == mod1.(tovector(m1), i) + end @testset "Ternary Functions" begin v1 = Vec(ntuple(_ -> Core.VecElement(randn()), Val(W64))) @@ -515,7 +532,7 @@ end m = Mask{W64}(0xce) mv = tovector(m) for f ∈ [ - muladd, fma, + muladd, fma, clamp, VectorizationBase.vfmadd, VectorizationBase.vfnmadd, VectorizationBase.vfmsub, VectorizationBase.vfnmsub, VectorizationBase.vfmadd231, VectorizationBase.vfnmadd231, VectorizationBase.vfmsub231, VectorizationBase.vfnmsub231 ] @@ -560,7 +577,7 @@ end @test VectorizationBase.vprod(v2) * 3 == VectorizationBase.vprod(VectorizationBase.mulscalar(3, v2)) @test VectorizationBase.vall(v1 + v2 == VectorizationBase.addscalar(v1, v2)) @test 4.0 == VectorizationBase.addscalar(2.0, 2.0) - + v3 = Vec(0, 1, 2, 3); vu3 = VectorizationBase.VecUnroll((v3, v3 - 1)) v4 = Vec(0.0, 1.0, 2.0, 3.0) v5 = Vec(0f0, 1f0, 2f0, 3f0, 4f0, 5f0, 6f0, 7f0) @@ -591,7 +608,7 @@ end @test VectorizationBase.vzero() === VectorizationBase.vzero(W64S, Float64) @test VectorizationBase.vbroadcast(StaticInt(2)*W64S, one(Int64)) === VectorizationBase.vbroadcast(StaticInt(2)*W64S, one(Int32)) @test VectorizationBase.vbroadcast(StaticInt(2)*W64S, one(UInt64)) === VectorizationBase.vbroadcast(StaticInt(2)*W64S, one(UInt32)) - + @test VectorizationBase.vall(VectorizationBase.vbroadcast(W64S, pointer(A)) == vbroadcast(W64S, first(A))) @test VectorizationBase.vbroadcast(W64S, pointer(A,2)) === Vec{W64}(A[2]) === Vec(A[2]) @@ -630,7 +647,7 @@ end @test vtwos32 === VectorizationBase.VecUnroll((vbroadcast(StaticInt(W32), 2f0),vbroadcast(StaticInt(W32), 2f0))) @test vf2 === v2f32 - + vm = if VectorizationBase.AVX512DQ VectorizationBase.VecUnroll(( MM{W64}(rand(Int)),MM{W64}(rand(Int)),MM{W64}(rand(Int)),MM{W64}(rand(Int))