Skip to content

Commit 2d80405

Browse files
committed
Add and test shuffle and fill intrinsics
1 parent 145dc33 commit 2d80405

File tree

2 files changed

+89
-1
lines changed

2 files changed

+89
-1
lines changed

src/device/intrinsics/simd.jl

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
export simdgroup_load, simdgroup_store, simdgroup_multiply, simdgroup_multiply_accumulate,
2-
simd_shuffle_down, simd_shuffle_up
2+
simd_shuffle_down, simd_shuffle_up, simd_shuffle_and_fill_down, simd_shuffle_and_fill_up
33

44
using Core: LLVMPtr
55

@@ -104,6 +104,15 @@ for (jltype, suffix) in simd_shuffle_map
104104
@device_function simd_shuffle_up(data::$jltype, delta::Integer) =
105105
ccall($"extern air.simd_shuffle_up.$suffix",
106106
llvmcall, $jltype, ($jltype, Int16), data, delta)
107+
108+
# TODO: Emulate or disallow on M1 (Apple7)
109+
# Binds the `air.simd_shuffle_and_fill_down.$suffix` intrinsic for `$jltype`.
# `data` and `filling_data` are passed through unchanged; `delta` and `modulo` are declared
# as `Int16` in the `ccall` signature, so any `Integer` the caller passes is converted there.
# When `modulo` is omitted it defaults to `threads_per_simdgroup()`.
@device_function simd_shuffle_and_fill_down(data::$jltype, filling_data::$jltype, delta::Integer, modulo::Integer=threads_per_simdgroup()) =
110+
ccall($"extern air.simd_shuffle_and_fill_down.$suffix",
111+
llvmcall, $jltype, ($jltype, $jltype, Int16, Int16), data, filling_data, delta, modulo)
112+
113+
# Binds the `air.simd_shuffle_and_fill_up.$suffix` intrinsic for `$jltype`.
# `data` and `filling_data` are passed through unchanged; `delta` and `modulo` are declared
# as `Int16` in the `ccall` signature, so any `Integer` the caller passes is converted there.
# When `modulo` is omitted it defaults to `threads_per_simdgroup()`.
@device_function simd_shuffle_and_fill_up(data::$jltype, filling_data::$jltype, delta::Integer, modulo::Integer=threads_per_simdgroup()) =
114+
ccall($"extern air.simd_shuffle_and_fill_up.$suffix",
115+
llvmcall, $jltype, ($jltype, $jltype, Int16, Int16), data, filling_data, delta, modulo)
107116
end
108117
end
109118

@@ -134,3 +143,45 @@ modify the lower `delta` lanes of `data` because it doesn't wrap values around t
134143
T must be one of the following: Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, or UInt8
135144
"""
136145
simd_shuffle_up
146+
147+
@doc """
148+
simd_shuffle_and_fill_down(data::T, filling_data::T, delta::Integer, [modulo::Integer])
149+
150+
Returns `data` or `filling_data` for each vector from the thread whose SIMD lane ID is the
151+
sum of the caller's SIMD lane ID and `delta`.
152+
153+
If the sum is greater than `modulo`, the function copies values from the lower `delta` lanes of
154+
`filling_data` into the upper `delta` lanes of `data`.
155+
156+
The value of `delta` needs to be the same for all threads in a SIMD-group.
157+
158+
The `modulo` parameter defines the vector width that splits the SIMD-group into separate vectors
159+
and must be 2, 4, 8, 16, or 32.
160+
161+
T must be one of the following: Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, or UInt8
162+
163+
!!! note
164+
`simd_shuffle_and_fill_down` is only available on Apple8+ GPUs (M2 and newer)
165+
"""
166+
simd_shuffle_and_fill_down
167+
168+
@doc """
169+
simd_shuffle_and_fill_up(data::T, filling_data::T, delta::Integer, [modulo::Integer])
170+
171+
Returns `data` or `filling_data` for each vector from the thread whose SIMD lane ID is the
171+
caller's SIMD lane ID minus `delta`.
172+
173+
If the difference is negative, the operation copies values from the upper `delta` lanes of
174+
`filling_data` to the lower `delta` lanes of `data`.
176+
177+
The value of `delta` needs to be the same for all threads in a SIMD-group.
178+
179+
The `modulo` parameter defines the vector width that splits the SIMD-group into separate vectors
180+
and must be 2, 4, 8, 16, or 32.
181+
182+
T must be one of the following: Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, or UInt8
183+
184+
!!! note
185+
`simd_shuffle_and_fill_up` is only available on Apple8+ GPUs (M2 and newer)
186+
"""
187+
simd_shuffle_and_fill_up

test/device/intrinsics/simd.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,44 @@
3434
Metal.@sync @metal threads=32 kernel(dev_a, dev_b)
3535
@test sum(a) b[res_idx]
3636
end
37+
# Exercises both shuffle-and-fill intrinsics for every listed element type.
# `nshift` is signed so it can be handed straight to `circshift` when building the expected
# result (negative for the "down" variant, positive for the "up" variant); the intrinsic
# itself receives `abs(nshift)` as its `delta`.
@testset "$f($typ)" for typ in [Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, UInt8], (f,nshift) in [(simd_shuffle_and_fill_down, -4), (simd_shuffle_and_fill_up, 2)]
38+
# Kernel: shuffles `data` with `filling_data` via `f` using the given `modulo`, and writes
# the shuffled value back into `data` in place.
function kernel_mod(data::MtlDeviceVector{T}, filling_data::MtlDeviceVector{T}, modulo) where T
39+
idx = thread_position_in_grid_1d()
40+
idx_in_simd = thread_index_in_simdgroup() #simd_lane_id
41+
simd_idx = simdgroup_index_in_threadgroup() #simd_group_id
42+
43+
# Stage both inputs in `MtlThreadGroupArray`s of length 16, one element per thread.
temp_data = MtlThreadGroupArray(T, 16)
44+
temp_data[idx] = data[idx]
45+
temp_filling_data = MtlThreadGroupArray(T, 16)
46+
temp_filling_data[idx] = filling_data[idx]
47+
# Synchronize the SIMD-group (threadgroup memory flag) before reading the staged values back.
simdgroup_barrier(Metal.MemoryFlagThreadGroup)
48+
49+
# Only the SIMD-group whose index is 1 performs the shuffle
# (presumably the first SIMD-group, with 1-based indices; confirm).
if simd_idx == 1
50+
dat_value = temp_data[idx_in_simd]
51+
dat_fil_value = temp_filling_data[idx_in_simd]
52+
53+
# `delta` is the magnitude of `nshift`; the direction is implied by which intrinsic `f` is.
value = f(dat_value, dat_fil_value, abs(nshift), modulo)
3754

55+
data[idx] = value
56+
end
57+
return
58+
end
59+
60+
# Host side: 16 threads; `midN` (half of `N`) is used as the `modulo` for the second launch.
N = 16
61+
midN = N ÷ 2
62+
63+
a = Array{typ}(1:N)
64+
mtla = MtlArray(a)
65+
# The filling data is a copy of the data itself, so the values shifted in from
# `filling_data` make the expected result equal to a circular shift of `a`.
mtlb = MtlArray(a)
66+
67+
# Case 1: `modulo == N`, so the whole input is expected to be circularly shifted as one vector.
Metal.@sync @metal threads=N kernel_mod(mtla, mtlb, N)
68+
@test Array(mtla) == circshift(a,nshift)
69+
70+
mtlc = MtlArray(a)
71+
72+
# Case 2: `modulo == midN`, so each half is expected to be circularly shifted on its own.
Metal.@sync @metal threads=N kernel_mod(mtlc, mtlb, midN)
73+
@test Array(mtlc) == [circshift(a[1:midN],nshift);circshift(a[midN+1:end],nshift)]
74+
end
3875
@testset "matrix functions" begin
3976
@testset "load_store($typ)" for typ in [Float16, Float32]
4077
function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T},

0 commit comments

Comments
 (0)