-
Notifications
You must be signed in to change notification settings - Fork 44
Add and test shuffle and fill intrinsics #555
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
I think I mentioned it in a previous PR, but it would be good to add a compile-time error to GPUCompiler so that we can emit one here. Lacking that, it's probably fine to drop a TODO in the code, document the restriction, and only activate the tests on supported devices. |
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.

diff --git a/src/device/intrinsics/simd.jl b/src/device/intrinsics/simd.jl
index 83386aee..5f098af8 100644
--- a/src/device/intrinsics/simd.jl
+++ b/src/device/intrinsics/simd.jl
@@ -1,5 +1,5 @@
export simdgroup_load, simdgroup_store, simdgroup_multiply, simdgroup_multiply_accumulate,
- simd_shuffle_down, simd_shuffle_up, simd_shuffle_and_fill_down, simd_shuffle_and_fill_up
+ simd_shuffle_down, simd_shuffle_up, simd_shuffle_and_fill_down, simd_shuffle_and_fill_up
using Core: LLVMPtr
@@ -106,13 +106,17 @@ for (jltype, suffix) in simd_shuffle_map
llvmcall, $jltype, ($jltype, Int16), data, delta)
# TODO: Emulate or disallow on M1 (Apple7)
- @device_function simd_shuffle_and_fill_down(data::$jltype, filling_data::$jltype, delta::Integer, modulo::Integer=threads_per_simdgroup()) =
- ccall($"extern air.simd_shuffle_and_fill_down.$suffix",
- llvmcall, $jltype, ($jltype, $jltype, Int16, Int16), data, filling_data, delta, modulo)
-
- @device_function simd_shuffle_and_fill_up(data::$jltype, filling_data::$jltype, delta::Integer, modulo::Integer=threads_per_simdgroup()) =
- ccall($"extern air.simd_shuffle_and_fill_up.$suffix",
- llvmcall, $jltype, ($jltype, $jltype, Int16, Int16), data, filling_data, delta, modulo)
+ @device_function simd_shuffle_and_fill_down(data::$jltype, filling_data::$jltype, delta::Integer, modulo::Integer = threads_per_simdgroup()) =
+ ccall(
+ $"extern air.simd_shuffle_and_fill_down.$suffix",
+ llvmcall, $jltype, ($jltype, $jltype, Int16, Int16), data, filling_data, delta, modulo
+ )
+
+ @device_function simd_shuffle_and_fill_up(data::$jltype, filling_data::$jltype, delta::Integer, modulo::Integer = threads_per_simdgroup()) =
+ ccall(
+ $"extern air.simd_shuffle_and_fill_up.$suffix",
+ llvmcall, $jltype, ($jltype, $jltype, Int16, Int16), data, filling_data, delta, modulo
+ )
end
end
diff --git a/test/device/intrinsics.jl b/test/device/intrinsics.jl
index 625ed21c..fb0fd36f 100644
--- a/test/device/intrinsics.jl
+++ b/test/device/intrinsics.jl
@@ -659,44 +659,44 @@ end
@test sum(a) ≈ b[res_idx]
end
-@testset "$f($typ)" for typ in [Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, UInt8], (f,nshift) in [(simd_shuffle_and_fill_down, -4), (simd_shuffle_and_fill_up, 2)]
- function kernel_mod(data::MtlDeviceVector{T}, filling_data::MtlDeviceVector{T}, modulo) where T
- idx = thread_position_in_grid_1d()
- idx_in_simd = thread_index_in_simdgroup() #simd_lane_id
- simd_idx = simdgroup_index_in_threadgroup() #simd_group_id
+ @testset "$f($typ)" for typ in [Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, UInt8], (f, nshift) in [(simd_shuffle_and_fill_down, -4), (simd_shuffle_and_fill_up, 2)]
+ function kernel_mod(data::MtlDeviceVector{T}, filling_data::MtlDeviceVector{T}, modulo) where {T}
+ idx = thread_position_in_grid_1d()
+ idx_in_simd = thread_index_in_simdgroup() #simd_lane_id
+ simd_idx = simdgroup_index_in_threadgroup() #simd_group_id
- temp_data = MtlThreadGroupArray(T, 16)
- temp_data[idx] = data[idx]
- temp_filling_data = MtlThreadGroupArray(T, 16)
- temp_filling_data[idx] = filling_data[idx]
- simdgroup_barrier(Metal.MemoryFlagThreadGroup)
+ temp_data = MtlThreadGroupArray(T, 16)
+ temp_data[idx] = data[idx]
+ temp_filling_data = MtlThreadGroupArray(T, 16)
+ temp_filling_data[idx] = filling_data[idx]
+ simdgroup_barrier(Metal.MemoryFlagThreadGroup)
- if simd_idx == 1
- dat_value = temp_data[idx_in_simd]
- dat_fil_value = temp_filling_data[idx_in_simd]
+ if simd_idx == 1
+ dat_value = temp_data[idx_in_simd]
+ dat_fil_value = temp_filling_data[idx_in_simd]
- value = f(dat_value, dat_fil_value, abs(nshift), modulo)
+ value = f(dat_value, dat_fil_value, abs(nshift), modulo)
- data[idx] = value
+ data[idx] = value
+ end
+ return
end
- return
- end
- N = 16
- midN = N ÷ 2
+ N = 16
+ midN = N ÷ 2
- a = Array{typ}(1:N)
- mtla = MtlArray(a)
- mtlb = MtlArray(a)
+ a = Array{typ}(1:N)
+ mtla = MtlArray(a)
+ mtlb = MtlArray(a)
- Metal.@sync @metal threads=N kernel_mod(mtla, mtlb, N)
- @test Array(mtla) == circshift(a,nshift)
+ Metal.@sync @metal threads = N kernel_mod(mtla, mtlb, N)
+ @test Array(mtla) == circshift(a, nshift)
- mtlc = MtlArray(a)
+ mtlc = MtlArray(a)
- Metal.@sync @metal threads=N kernel_mod(mtlc, mtlb, midN)
- @test Array(mtlc) == [circshift(a[1:midN],nshift);circshift(a[midN+1:end],nshift)]
-end
+ Metal.@sync @metal threads = N kernel_mod(mtlc, mtlb, midN)
+ @test Array(mtlc) == [circshift(a[1:midN], nshift);circshift(a[(midN + 1):end], nshift)]
+ end
@testset "matrix functions" begin
@testset "load_store($typ)" for typ in [Float16, Float32]
function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T}, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Metal Benchmarks
Benchmark suite | Current: 003f119 | Previous: c5b425d | Ratio |
---|---|---|---|
private array/construct |
27256.833333333332 ns |
25996.5 ns |
1.05 |
private array/broadcast |
460042 ns |
454083 ns |
1.01 |
private array/random/randn/Float32 |
926895.5 ns |
816583 ns |
1.14 |
private array/random/randn!/Float32 |
594625 ns |
654333 ns |
0.91 |
private array/random/rand!/Int64 |
541958 ns |
581334 ns |
0.93 |
private array/random/rand!/Float32 |
550792 ns |
617834 ns |
0.89 |
private array/random/rand/Int64 |
830000 ns |
755437.5 ns |
1.10 |
private array/random/rand/Float32 |
787625 ns |
621500 ns |
1.27 |
private array/copyto!/gpu_to_gpu |
568729.5 ns |
671750 ns |
0.85 |
private array/copyto!/cpu_to_gpu |
600792 ns |
824500 ns |
0.73 |
private array/copyto!/gpu_to_cpu |
627375 ns |
666750 ns |
0.94 |
private array/accumulate/1d |
1417791 ns |
1345792 ns |
1.05 |
private array/accumulate/2d |
1538584 ns |
1388792 ns |
1.11 |
private array/iteration/findall/int |
2333333 ns |
2075334 ns |
1.12 |
private array/iteration/findall/bool |
2068167 ns |
1842667 ns |
1.12 |
private array/iteration/findfirst/int |
1834187.5 ns |
1681104 ns |
1.09 |
private array/iteration/findfirst/bool |
1776709 ns |
1666937 ns |
1.07 |
private array/iteration/scalar |
2434833 ns |
3620500 ns |
0.67 |
private array/iteration/logical |
3627042 ns |
3205125 ns |
1.13 |
private array/iteration/findmin/1d |
1869750 ns |
1760646 ns |
1.06 |
private array/iteration/findmin/2d |
1455167 ns |
1349500.5 ns |
1.08 |
private array/reductions/reduce/1d |
1036604 ns |
1038250 ns |
1.00 |
private array/reductions/reduce/2d |
709042 ns |
659791 ns |
1.07 |
private array/reductions/mapreduce/1d |
1006250 ns |
1022833 ns |
0.98 |
private array/reductions/mapreduce/2d |
703791 ns |
660750 ns |
1.07 |
private array/permutedims/4d |
2626542 ns |
2543458 ns |
1.03 |
private array/permutedims/2d |
1119353.5 ns |
1022104 ns |
1.10 |
private array/permutedims/3d |
1840020.5 ns |
1717542 ns |
1.07 |
private array/copy |
835375 ns |
567521 ns |
1.47 |
latency/precompile |
9198154000 ns |
9072804750 ns |
1.01 |
latency/ttfp |
3726296084 ns |
3675978250 ns |
1.01 |
latency/import |
1262449979 ns |
1242875750 ns |
1.02 |
integration/metaldevrt |
762459 ns |
713042 ns |
1.07 |
integration/byval/slices=1 |
1661104.5 ns |
1568958 ns |
1.06 |
integration/byval/slices=3 |
20233833 ns |
9347292 ns |
2.16 |
integration/byval/reference |
1651041 ns |
1561458.5 ns |
1.06 |
integration/byval/slices=2 |
2842708 ns |
2616271 ns |
1.09 |
kernel/indexing |
467042 ns |
470084 ns |
0.99 |
kernel/indexing_checked |
464709 ns |
470000 ns |
0.99 |
kernel/launch |
7917 ns |
36947.75 ns |
0.21 |
metal/synchronization/stream |
15000 ns |
14667 ns |
1.02 |
metal/synchronization/context |
15375 ns |
15167 ns |
1.01 |
shared array/construct |
26434 ns |
24225.75 ns |
1.09 |
shared array/broadcast |
465541.5 ns |
460333 ns |
1.01 |
shared array/random/randn/Float32 |
892208 ns |
802562.5 ns |
1.11 |
shared array/random/randn!/Float32 |
586083 ns |
634792 ns |
0.92 |
shared array/random/rand!/Int64 |
550042 ns |
581167 ns |
0.95 |
shared array/random/rand!/Float32 |
548937.5 ns |
607417 ns |
0.90 |
shared array/random/rand/Int64 |
893958 ns |
782875 ns |
1.14 |
shared array/random/rand/Float32 |
805229 ns |
603583 ns |
1.33 |
shared array/copyto!/gpu_to_gpu |
80166 ns |
83209 ns |
0.96 |
shared array/copyto!/cpu_to_gpu |
85916 ns |
82792 ns |
1.04 |
shared array/copyto!/gpu_to_cpu |
80937.5 ns |
83000 ns |
0.98 |
shared array/accumulate/1d |
1465084 ns |
1356250 ns |
1.08 |
shared array/accumulate/2d |
1526917 ns |
1392520.5 ns |
1.10 |
shared array/iteration/findall/int |
2037167 ns |
1792562.5 ns |
1.14 |
shared array/iteration/findall/bool |
1762333 ns |
1606375 ns |
1.10 |
shared array/iteration/findfirst/int |
1523021 ns |
1393125 ns |
1.09 |
shared array/iteration/findfirst/bool |
1464166 ns |
1378083.5 ns |
1.06 |
shared array/iteration/scalar |
160875 ns |
158000 ns |
1.02 |
shared array/iteration/logical |
3362083 ns |
2996041.5 ns |
1.12 |
shared array/iteration/findmin/1d |
1591916 ns |
1456500 ns |
1.09 |
shared array/iteration/findmin/2d |
1465541.5 ns |
1371375 ns |
1.07 |
shared array/reductions/reduce/1d |
718208 ns |
729667 ns |
0.98 |
shared array/reductions/reduce/2d |
719916 ns |
673417 ns |
1.07 |
shared array/reductions/mapreduce/1d |
715062.5 ns |
733562 ns |
0.97 |
shared array/reductions/mapreduce/2d |
704583 ns |
676771 ns |
1.04 |
shared array/permutedims/4d |
2671396 ns |
2545500 ns |
1.05 |
shared array/permutedims/2d |
1109958 ns |
1024875 ns |
1.08 |
shared array/permutedims/3d |
1836000 ns |
1581604 ns |
1.16 |
shared array/copy |
209459 ns |
239270.5 ns |
0.88 |
This comment was automatically generated by workflow using github-action-benchmark.
Not actually ready for review; I marked it as such to see if the CI checks would fail. Not sure why they don't. |
486ecc7
to
71db25a
Compare
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #555 +/- ##
==========================================
- Coverage 80.66% 80.47% -0.19%
==========================================
Files 61 61
Lines 2679 2679
==========================================
- Hits 2161 2156 -5
- Misses 518 523 +5 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
Not supported on M1, which this draft conveniently ignores.
Anyone have any idea how and where this compatibility checking should be handled?