Skip to content

Commit 3ecbf27

Browse files
authored
Merge pull request #28 from mcabbott/modclamp
Allow `mod` and `clamp` to accept `Vec`
2 parents 87989d2 + 67eb26b commit 3ecbf27

File tree

2 files changed

+61
-16
lines changed

2 files changed

+61
-16
lines changed

src/special/misc.jl

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,32 @@
11
@inline Base.:^(v::AbstractSIMD{W,T}, i::Integer) where {W,T} = Base.power_by_squaring(v, i)
22
@inline Base.:^(v::AbstractSIMD{W,T}, i::Integer) where {W,T<:Union{Float32,Float64}} = Base.power_by_squaring(v, i)
3-
@inline relu(x) = (y = zero(x); IfElse.ifelse(x > y, x, y))
3+
@inline relu(x) = (y = zero(x); ifelse(x > y, x, y))
44

5+
@inline Base.fld(x::AbstractSIMD, y::AbstractSIMD) = div(promote(x,y)..., RoundDown)
6+
7+
@inline function Base.div(x::AbstractSIMD{W,T}, y::AbstractSIMD{W,T}, ::RoundingMode{:Down}) where {W,T<:Integer}
8+
d = div(x, y)
9+
d - (signbit(x y) & (d * y != x))
10+
end
11+
12+
@inline Base.mod(x::AbstractSIMD{W,T}, y::AbstractSIMD{W,T}) where {W,T<:Integer} =
13+
ifelse(y == -1, zero(x), x - fld(x, y) * y)
14+
15+
@inline Base.mod(x::AbstractSIMD{W,T}, y::AbstractSIMD{W,T}) where {W,T<:Unsigned} =
16+
rem(x, y)
17+
18+
@inline Base.mod(i::AbstractSIMD{<:Any,<:Integer}, r::AbstractUnitRange{<:Integer}) =
19+
mod(i-first(r), length(r)) + first(r)
20+
21+
# avoid ambiguity with clamp(::Missing, lo, hi) in Base.Math at math.jl:1258
22+
# but who knows what would happen if you called it
23+
for (X,L,H) in Iterators.product(fill([:Any, :Missing, :AbstractSIMD], 3)...)
24+
any(==(:AbstractSIMD), (X,L,H)) || continue
25+
@eval @inline function Base.clamp(x::$X, lo::$L, hi::$H)
26+
x_, lo_, hi_ = promote(x, lo, hi)
27+
ifelse(x_ > hi_, hi_, ifelse(x_ < lo_, lo_, x_))
28+
end
29+
end
30+
31+
@inline Base.clamp(x::AbstractSIMD{<:Any,<:Integer}, r::AbstractUnitRange{<:Integer}) =
32+
clamp(x, first(r), last(r))

test/runtests.jl

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ end
4545
W = VectorizationBase.pick_vector_width(Float64)
4646
@test @inferred(VectorizationBase.pick_integer(Val(W))) == (VectorizationBase.AVX512DQ ? Int64 : Int32)
4747

48-
48+
4949
@test first(A) === A[1]
5050
@test W64S == W64
5151
@testset "Struct-Wrapped Vec" begin
@@ -175,7 +175,7 @@ end
175175
@test !VectorizationBase.vall(Mask{4}(0xfc))
176176
@test VectorizationBase.vall(Mask{8}(0xff))
177177
@test VectorizationBase.vall(Mask{4}(0xcf))
178-
178+
179179
@test VectorizationBase.vany(Mask{8}(0xfc))
180180
@test VectorizationBase.vany(Mask{4}(0xfc))
181181
@test !VectorizationBase.vany(Mask{8}(0x00))
@@ -211,7 +211,7 @@ end
211211
@test (Mask{8}(0xac) true) === Mask{8}(0x53)
212212
@test (false Mask{8}(0xac)) === Mask{8}(0xac)
213213
@test (true Mask{8}(0xac)) === Mask{8}(0x53)
214-
214+
215215
@test (Mask{4}(0x05) | true) === Mask{4}(0x0f)
216216
@test (Mask{4}(0x05) | false) === Mask{4}(0x05)
217217
@test (true | Mask{4}(0x05)) === Mask{4}(0x0f)
@@ -239,7 +239,7 @@ end
239239
# @test VectorizationBase.size_loads(A,2, Val(8)) == eval(VectorizationBase.num_vector_load_expr(@__MODULE__, :((() -> 17)()), 8)) == eval(VectorizationBase.num_vector_load_expr(@__MODULE__, 17, 8)) == divrem(size(A,2), 8)
240240
# end
241241

242-
242+
243243
@testset "vector_width.jl" begin
244244
@test all(VectorizationBase.ispow2, 0:1)
245245
@test all(i -> !any(VectorizationBase.ispow2, 1+(1 << (i-1)):(1 << i)-1 ) && VectorizationBase.ispow2(1 << i), 2:9)
@@ -282,7 +282,7 @@ end
282282
@test [vload(stridedpointer(C), (1+w, 2+w, 3)) for w 1:W64] == getindex.(Ref(C), 1 .+ (1:W64), 2 .+ (1:W64), 3)
283283
vstore!(stridedpointer(C), !mtest, ((MM{16})(17), 3, 4))
284284
@test .!v1 == C[17:32,3,4] == tovector(vload(stridedpointer(C), ((MM{16})(17), 3, 4)))
285-
285+
286286
dims = (41,42,43) .* 3;
287287
# dims = (41,42,43);
288288
A = reshape(collect(Float64(0):Float64(prod(dims)-1)), dims);
@@ -345,7 +345,7 @@ end
345345
@test v1 === vu.data[1]
346346
@test v2 === vu.data[2]
347347
@test v3 === vu.data[3]
348-
348+
349349
ir = 0:(AV == 1 ? W64-1 : 0); jr = 0:(AV == 2 ? W64-1 : 0); kr = 0:(AV == 3 ? W64-1 : 0)
350350
x1 = getindex.(Ref(B), i .+ ir, j .+ jr, k .+ kr)
351351
if AU == 1
@@ -364,7 +364,7 @@ end
364364
kr = kr .+ length(kr)
365365
end
366366
x3 = getindex.(Ref(B), i .+ ir, j .+ jr, k .+ kr)
367-
367+
368368
@test x1 == tovector(vu.data[1])
369369
@test x2 == tovector(vu.data[2])
370370
@test x3 == tovector(vu.data[3])
@@ -398,7 +398,7 @@ end
398398
end
399399
@test x == 1:100
400400
end
401-
401+
402402
@testset "Grouped Strided Pointers" begin
403403
M, K, N = 4, 5, 6
404404
A = rand(M, K); B = rand(K, N); C = rand(M, N);
@@ -426,7 +426,7 @@ end
426426
Vec(ntuple(_ -> (randn()), Val(W64))...)
427427
))
428428
x = tovector(v)
429-
for f [-, abs, inv, floor, ceil, trunc, round, sqrt abs]
429+
for f [-, abs, inv, floor, ceil, trunc, round, sqrt abs, VectorizationBase.relu]
430430
@test tovector(@inferred(f(v))) == map(f, x)
431431
end
432432
invtol = VectorizationBase.AVX512F ? 2^-14 : 1.5*2^-12 # moreaccurate with AVX512
@@ -470,7 +470,7 @@ end
470470
xi1 = tovector(vi1); xi2 = tovector(vi2);
471471
xi3 = mapreduce(tovector, vcat, m1.data);
472472
xi4 = mapreduce(tovector, vcat, m2.data);
473-
for f [+, -, *, ÷, /, %, <<, >>, >>>, , &, |, VectorizationBase.rotate_left, VectorizationBase.rotate_right, copysign, max, min]
473+
for f [+, -, *, div, ÷, /, rem, %, <<, >>, >>>, , &, |, fld, mod, VectorizationBase.rotate_left, VectorizationBase.rotate_right, copysign, max, min]
474474
# @show f
475475
check_within_limits(tovector(@inferred(f(vi1, vi2))), f.(xi1, xi2))
476476
check_within_limits(tovector(@inferred(f(j, vi2))), f.(j, xi2))
@@ -504,7 +504,24 @@ end
504504
@test tovector(@inferred(f(vf1, a))) f.(xf1, a)
505505
@test tovector(@inferred(f(vf2, a))) f.(xf2, a)
506506
end
507-
507+
508+
vones, vi2f, vtwos = promote(1.0, vi2, 2f0); # promotes a binary function, right? Even when used with three args?
509+
@test vones === VectorizationBase.VecUnroll((vbroadcast(Val(W64), 1.0),vbroadcast(Val(W64), 1.0),vbroadcast(Val(W64), 1.0),vbroadcast(Val(W64), 1.0)));
510+
@test vtwos === VectorizationBase.VecUnroll((vbroadcast(Val(W64), 2.0),vbroadcast(Val(W64), 2.0),vbroadcast(Val(W64), 2.0),vbroadcast(Val(W64), 2.0)));
511+
@test VectorizationBase.vall(vi2f == vi2)
512+
W32 = StaticInt(W64)*StaticInt(2)
513+
vf2 = VectorizationBase.VecUnroll((
514+
Vec(ntuple(_ -> Core.VecElement(randn(Float32)), W32)),
515+
Vec(ntuple(_ -> Core.VecElement(randn(Float32)), W32))
516+
))
517+
vones32, v2f32, vtwos32 = promote(1.0, vf2, 2f0); # promotes a binary function, right? Even when used with three args?
518+
@test vones32 === VectorizationBase.VecUnroll((vbroadcast(W32, 1f0),vbroadcast(W32, 1f0)))
519+
@test vtwos32 === VectorizationBase.VecUnroll((vbroadcast(W32, 2f0),vbroadcast(W32, 2f0)))
520+
@test vf2 === v2f32
521+
522+
@test tovector(clamp(m1, 2:i)) == clamp.(tovector(m1), 2, i)
523+
@test tovector(mod(m1, 1:i)) == mod1.(tovector(m1), i)
524+
508525
end
509526
@testset "Ternary Functions" begin
510527
v1 = Vec(ntuple(_ -> Core.VecElement(randn()), Val(W64)))
@@ -515,7 +532,7 @@ end
515532
m = Mask{W64}(0xce)
516533
mv = tovector(m)
517534
for f [
518-
muladd, fma,
535+
muladd, fma, clamp,
519536
VectorizationBase.vfmadd, VectorizationBase.vfnmadd, VectorizationBase.vfmsub, VectorizationBase.vfnmsub,
520537
VectorizationBase.vfmadd231, VectorizationBase.vfnmadd231, VectorizationBase.vfmsub231, VectorizationBase.vfnmsub231
521538
]
@@ -560,7 +577,7 @@ end
560577
@test VectorizationBase.vprod(v2) * 3 == VectorizationBase.vprod(VectorizationBase.mulscalar(3, v2))
561578
@test VectorizationBase.vall(v1 + v2 == VectorizationBase.addscalar(v1, v2))
562579
@test 4.0 == VectorizationBase.addscalar(2.0, 2.0)
563-
580+
564581
v3 = Vec(0, 1, 2, 3); vu3 = VectorizationBase.VecUnroll((v3, v3 - 1))
565582
v4 = Vec(0.0, 1.0, 2.0, 3.0)
566583
v5 = Vec(0f0, 1f0, 2f0, 3f0, 4f0, 5f0, 6f0, 7f0)
@@ -591,7 +608,7 @@ end
591608
@test VectorizationBase.vzero() === VectorizationBase.vzero(W64S, Float64)
592609
@test VectorizationBase.vbroadcast(StaticInt(2)*W64S, one(Int64)) === VectorizationBase.vbroadcast(StaticInt(2)*W64S, one(Int32))
593610
@test VectorizationBase.vbroadcast(StaticInt(2)*W64S, one(UInt64)) === VectorizationBase.vbroadcast(StaticInt(2)*W64S, one(UInt32))
594-
611+
595612
@test VectorizationBase.vall(VectorizationBase.vbroadcast(W64S, pointer(A)) == vbroadcast(W64S, first(A)))
596613
@test VectorizationBase.vbroadcast(W64S, pointer(A,2)) === Vec{W64}(A[2]) === Vec(A[2])
597614

@@ -630,7 +647,7 @@ end
630647
@test vtwos32 === VectorizationBase.VecUnroll((vbroadcast(StaticInt(W32), 2f0),vbroadcast(StaticInt(W32), 2f0)))
631648
@test vf2 === v2f32
632649

633-
650+
634651
vm = if VectorizationBase.AVX512DQ
635652
VectorizationBase.VecUnroll((
636653
MM{W64}(rand(Int)),MM{W64}(rand(Int)),MM{W64}(rand(Int)),MM{W64}(rand(Int))

0 commit comments

Comments
 (0)