Skip to content

Commit 6b2880b

Browse files
committed
Add and test shuffle and fill intrinsics
1 parent c5b425d commit 6b2880b

File tree

2 files changed

+81
-1
lines changed

2 files changed

+81
-1
lines changed

src/device/intrinsics/simd.jl

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
export simdgroup_load, simdgroup_store, simdgroup_multiply, simdgroup_multiply_accumulate,
2-
simd_shuffle_down, simd_shuffle_up
2+
simd_shuffle_down, simd_shuffle_up, simd_shuffle_and_fill_down, simd_shuffle_and_fill_up
33

44
using Core: LLVMPtr
55

@@ -104,6 +104,15 @@ for (jltype, suffix) in simd_shuffle_map
104104
@device_function simd_shuffle_up(data::$jltype, delta::Integer) =
105105
ccall($"extern air.simd_shuffle_up.$suffix",
106106
llvmcall, $jltype, ($jltype, Int16), data, delta)
107+
108+
# TODO: Emulate or disallow on M1 (Apple7)
109+
@device_function simd_shuffle_and_fill_down(data::$jltype, filling_data::$jltype, delta::Integer, modulo::Integer=threads_per_simdgroup()) =
110+
ccall($"extern air.simd_shuffle_and_fill_down.$suffix",
111+
llvmcall, $jltype, ($jltype, $jltype, Int16, Int16), data, filling_data, delta, modulo)
112+
113+
@device_function simd_shuffle_and_fill_up(data::$jltype, filling_data::$jltype, delta::Integer, modulo::Integer=threads_per_simdgroup()) =
114+
ccall($"extern air.simd_shuffle_and_fill_up.$suffix",
115+
llvmcall, $jltype, ($jltype, $jltype, Int16, Int16), data, filling_data, delta, modulo)
107116
end
108117
end
109118

@@ -134,3 +143,39 @@ modify the lower `delta` lanes of `data` because it doesn't wrap values around t
134143
T must be one of the following: Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, or UInt8
135144
"""
136145
simd_shuffle_up
146+
147+
@doc """
148+
simd_shuffle_and_fill_down(data::T, filling_data::T, delta::Integer, [modulo::Integer])
149+
150+
Returns `data` or `filling_data` for each vector from the thread whose SIMD lane ID is the
151+
difference from the caller's SIMD lane ID minus `delta`.
152+
153+
If the difference is negative, the operation copies values from the upper `delta` lanes of
154+
`filling_data` to the lower `delta` lanes of `data`.
155+
156+
The value of `delta` needs to be the same for all threads in a SIMD-group.
157+
158+
The `modulo` parameter defines the vector width that splits the SIMD-group into separate vectors
159+
and must be 2, 4, 8, 16, or 32.
160+
161+
T must be one of the following: Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, or UInt8
162+
"""
163+
simd_shuffle_and_fill_down
164+
165+
@doc """
166+
simd_shuffle_and_fill_up(data::T, filling_data::T, delta::Integer, [modulo::Integer])
167+
168+
Returns `data` or `filling_data` for each vector from the thread whose SIMD lane ID is the
169+
sum of the caller's SIMD lane ID and `delta`.
170+
171+
If the sum is greater than `modulo`, the function copies values from the lower `delta` lanes of
172+
`filling_data` into the upper `delta` lanes of `data`.
173+
174+
The value of `delta` needs to be the same for all threads in a SIMD-group.
175+
176+
The `modulo` parameter defines the vector width that splits the SIMD-group into separate vectors
177+
and must be 2, 4, 8, 16, or 32.
178+
179+
T must be one of the following: Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, or UInt8
180+
"""
181+
simd_shuffle_and_fill_up

test/device/intrinsics.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -659,6 +659,41 @@ end
659659
@test sum(a) b[res_idx]
660660
end
661661

662+
@testset "$f($typ)" for typ in [Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, UInt8], (f,nshift) in [(simd_shuffle_and_fill_down, -4), (simd_shuffle_and_fill_up, 2)]
663+
function kernel_mod(data::MtlDeviceVector{T}, filling_data::MtlDeviceVector{T}) where T
664+
idx = thread_position_in_grid_1d()
665+
idx_in_simd = thread_index_in_simdgroup() #simd_lane_id
666+
simd_idx = simdgroup_index_in_threadgroup() #simd_group_id
667+
668+
temp_data = MtlThreadGroupArray(T, 16)
669+
temp_data[idx] = data[idx]
670+
temp_filling_data = MtlThreadGroupArray(T, 16)
671+
temp_filling_data[idx] = filling_data[idx]
672+
simdgroup_barrier(Metal.MemoryFlagThreadGroup)
673+
674+
if simd_idx == 1
675+
dat_value = temp_data[idx_in_simd]
676+
dat_fil_value = temp_filling_data[idx_in_simd]
677+
678+
value = f(dat_value, dat_fil_value, abs(nshift), length(data))
679+
680+
data[idx] = value
681+
end
682+
return
683+
end
684+
685+
dev_a = Metal.zeros(typ, 16; storage=Metal.SharedStorage)
686+
dev_b = Metal.zeros(typ, 16; storage=Metal.SharedStorage)
687+
# GC.@preserve dev_a dev_b begin
688+
a = unsafe_wrap(Array{typ}, dev_a, 16)
689+
b = unsafe_wrap(Array{typ}, dev_b, 16)
690+
691+
a .= 1:16
692+
b .= 1:16
693+
694+
Metal.@sync @metal threads=16 kernel_mod(dev_a, dev_b)
695+
@test a == circshift(b,nshift)
696+
end
662697
@testset "matrix functions" begin
663698
@testset "load_store($typ)" for typ in [Float16, Float32]
664699
function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T},

0 commit comments

Comments
 (0)