Add and test shuffle and fill intrinsics #555


Draft · wants to merge 1 commit into main

Conversation

christiangnrd (Member) commented:

Not supported on M1, which this draft conveniently ignores.

Anyone have any idea how and where this compatibility checking should be handled?

maleadt (Member) commented Mar 3, 2025:

> Anyone have any idea how and where this compatibility checking should be handled?

I think I mentioned it in a previous PR, but it would be good to add a compile-time error mechanism to GPUCompiler so that we can emit one here (from an sv""-based branch). The error would then get optimized out on supported platforms, only erroring when the intrinsic is used statically on an unsupported device.

Lacking that, it's probably fine to drop a TODO in the code, document the restriction, and only activate the tests on supported devices.
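Pending such a GPUCompiler mechanism, the interim gating could look something like the sketch below. Note that `supports_shuffle_and_fill`, `apple_gpu_family`, and the Apple8 family cutoff are illustrative assumptions for this discussion, not an existing Metal.jl API.

```julia
# Hypothetical capability check (assumption: simd_shuffle_and_fill_* is
# unavailable on Apple7 / M1 and available from Apple8 / M2 onward).
# A real implementation would query the Metal device's GPU family.
supports_shuffle_and_fill(apple_family::Integer) = apple_family >= 8

# The tests would then only be activated on supported devices, e.g.:
#
#   if supports_shuffle_and_fill(apple_gpu_family(device()))
#       @testset "simd_shuffle_and_fill" begin ... end
#   end
```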

@christiangnrd christiangnrd marked this pull request as ready for review March 3, 2025 22:19
github-actions bot (Contributor) commented Mar 3, 2025:

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic main) to apply these changes.

Suggested changes:
diff --git a/src/device/intrinsics/simd.jl b/src/device/intrinsics/simd.jl
index 83386aee..5f098af8 100644
--- a/src/device/intrinsics/simd.jl
+++ b/src/device/intrinsics/simd.jl
@@ -1,5 +1,5 @@
 export simdgroup_load, simdgroup_store, simdgroup_multiply, simdgroup_multiply_accumulate,
-        simd_shuffle_down, simd_shuffle_up, simd_shuffle_and_fill_down, simd_shuffle_and_fill_up
+    simd_shuffle_down, simd_shuffle_up, simd_shuffle_and_fill_down, simd_shuffle_and_fill_up
 
 using Core: LLVMPtr
 
@@ -106,13 +106,17 @@ for (jltype, suffix) in simd_shuffle_map
                 llvmcall, $jltype, ($jltype, Int16), data, delta)
 
         # TODO: Emulate or disallow on M1 (Apple7)
-        @device_function simd_shuffle_and_fill_down(data::$jltype, filling_data::$jltype, delta::Integer, modulo::Integer=threads_per_simdgroup()) =
-            ccall($"extern air.simd_shuffle_and_fill_down.$suffix",
-                llvmcall, $jltype, ($jltype, $jltype, Int16, Int16), data, filling_data, delta, modulo)
-
-        @device_function simd_shuffle_and_fill_up(data::$jltype, filling_data::$jltype, delta::Integer, modulo::Integer=threads_per_simdgroup()) =
-            ccall($"extern air.simd_shuffle_and_fill_up.$suffix",
-                llvmcall, $jltype, ($jltype, $jltype, Int16, Int16), data, filling_data, delta, modulo)
+        @device_function simd_shuffle_and_fill_down(data::$jltype, filling_data::$jltype, delta::Integer, modulo::Integer = threads_per_simdgroup()) =
+            ccall(
+            $"extern air.simd_shuffle_and_fill_down.$suffix",
+            llvmcall, $jltype, ($jltype, $jltype, Int16, Int16), data, filling_data, delta, modulo
+        )
+
+        @device_function simd_shuffle_and_fill_up(data::$jltype, filling_data::$jltype, delta::Integer, modulo::Integer = threads_per_simdgroup()) =
+            ccall(
+            $"extern air.simd_shuffle_and_fill_up.$suffix",
+            llvmcall, $jltype, ($jltype, $jltype, Int16, Int16), data, filling_data, delta, modulo
+        )
     end
 end
 
diff --git a/test/device/intrinsics.jl b/test/device/intrinsics.jl
index 625ed21c..fb0fd36f 100644
--- a/test/device/intrinsics.jl
+++ b/test/device/intrinsics.jl
@@ -659,44 +659,44 @@ end
     @test sum(a) ≈ b[res_idx]
 end
 
-@testset "$f($typ)" for typ in [Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, UInt8], (f,nshift) in [(simd_shuffle_and_fill_down, -4), (simd_shuffle_and_fill_up, 2)]
-   function kernel_mod(data::MtlDeviceVector{T}, filling_data::MtlDeviceVector{T}, modulo) where T
-        idx = thread_position_in_grid_1d()
-        idx_in_simd = thread_index_in_simdgroup() #simd_lane_id
-        simd_idx = simdgroup_index_in_threadgroup() #simd_group_id
+    @testset "$f($typ)" for typ in [Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, UInt8], (f, nshift) in [(simd_shuffle_and_fill_down, -4), (simd_shuffle_and_fill_up, 2)]
+        function kernel_mod(data::MtlDeviceVector{T}, filling_data::MtlDeviceVector{T}, modulo) where {T}
+            idx = thread_position_in_grid_1d()
+            idx_in_simd = thread_index_in_simdgroup() #simd_lane_id
+            simd_idx = simdgroup_index_in_threadgroup() #simd_group_id
 
-        temp_data = MtlThreadGroupArray(T, 16)
-        temp_data[idx] = data[idx]
-        temp_filling_data = MtlThreadGroupArray(T, 16)
-        temp_filling_data[idx] = filling_data[idx]
-        simdgroup_barrier(Metal.MemoryFlagThreadGroup)
+            temp_data = MtlThreadGroupArray(T, 16)
+            temp_data[idx] = data[idx]
+            temp_filling_data = MtlThreadGroupArray(T, 16)
+            temp_filling_data[idx] = filling_data[idx]
+            simdgroup_barrier(Metal.MemoryFlagThreadGroup)
 
-        if simd_idx == 1
-            dat_value = temp_data[idx_in_simd]
-            dat_fil_value = temp_filling_data[idx_in_simd]
+            if simd_idx == 1
+                dat_value = temp_data[idx_in_simd]
+                dat_fil_value = temp_filling_data[idx_in_simd]
 
-            value = f(dat_value, dat_fil_value, abs(nshift), modulo)
+                value = f(dat_value, dat_fil_value, abs(nshift), modulo)
 
-            data[idx] = value
+                data[idx] = value
+            end
+            return
         end
-        return
-    end
 
-    N = 16
-    midN = N ÷ 2
+        N = 16
+        midN = N ÷ 2
 
-    a = Array{typ}(1:N)
-    mtla = MtlArray(a)
-    mtlb = MtlArray(a)
+        a = Array{typ}(1:N)
+        mtla = MtlArray(a)
+        mtlb = MtlArray(a)
 
-    Metal.@sync @metal threads=N kernel_mod(mtla, mtlb, N)
-    @test Array(mtla) == circshift(a,nshift)
+        Metal.@sync @metal threads = N kernel_mod(mtla, mtlb, N)
+        @test Array(mtla) == circshift(a, nshift)
 
-    mtlc = MtlArray(a)
+        mtlc = MtlArray(a)
 
-    Metal.@sync @metal threads=N kernel_mod(mtlc, mtlb, midN)
-    @test Array(mtlc) == [circshift(a[1:midN],nshift);circshift(a[midN+1:end],nshift)]
-end
+        Metal.@sync @metal threads = N kernel_mod(mtlc, mtlb, midN)
+        @test Array(mtlc) == [circshift(a[1:midN], nshift);circshift(a[(midN + 1):end], nshift)]
+    end
 @testset "matrix functions" begin
     @testset "load_store($typ)" for typ in [Float16, Float32]
         function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T},

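For reference, the shuffle-and-fill semantics these tests expect can be modeled on the CPU. The following is a plain-Julia sketch of the intended behavior as understood from the test expectations (not the Metal intrinsic itself), assuming 1-based lanes and `modulo`-sized segments within a 16-lane vector:

```julia
# CPU reference sketch of simd_shuffle_and_fill_down/up. Lane l (1-based,
# within a segment of `modulo` lanes) reads from lane l + delta (down) or
# l - delta (up); lanes that would read out of range are served from
# filling_data instead. An illustration only, not the real intrinsic.
function ref_shuffle_and_fill(data::Vector{T}, filling::Vector{T},
                              delta::Integer, modulo::Integer;
                              down::Bool = true) where {T}
    out = similar(data)
    for i in eachindex(data)
        g = (i - 1) ÷ modulo          # which modulo-sized segment
        l = (i - 1) % modulo + 1      # lane index within the segment
        j = down ? l + delta : l - delta
        if 1 <= j <= modulo
            out[i] = data[g * modulo + j]            # plain shuffle
        else
            wrapped = down ? j - modulo : j + modulo
            out[i] = filling[g * modulo + wrapped]   # filled value
        end
    end
    return out
end

# With data == filling_data this reproduces the circshift expectations
# used in the tests above:
a = collect(1:16)
ref_shuffle_and_fill(a, a, 4, 16) == circshift(a, -4)              # fill_down
ref_shuffle_and_fill(a, a, 2, 16; down = false) == circshift(a, 2) # fill_up
```

With `modulo = 8`, each half of the vector shifts independently, matching the `[circshift(a[1:midN], nshift); circshift(a[midN+1:end], nshift)]` expectation in the test.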
github-actions bot (Contributor) left a comment:

Metal Benchmarks

| Benchmark suite | Current: 003f119 | Previous: c5b425d | Ratio |
|---|---|---|---|
| private array/construct | 27256.833333333332 ns | 25996.5 ns | 1.05 |
| private array/broadcast | 460042 ns | 454083 ns | 1.01 |
| private array/random/randn/Float32 | 926895.5 ns | 816583 ns | 1.14 |
| private array/random/randn!/Float32 | 594625 ns | 654333 ns | 0.91 |
| private array/random/rand!/Int64 | 541958 ns | 581334 ns | 0.93 |
| private array/random/rand!/Float32 | 550792 ns | 617834 ns | 0.89 |
| private array/random/rand/Int64 | 830000 ns | 755437.5 ns | 1.10 |
| private array/random/rand/Float32 | 787625 ns | 621500 ns | 1.27 |
| private array/copyto!/gpu_to_gpu | 568729.5 ns | 671750 ns | 0.85 |
| private array/copyto!/cpu_to_gpu | 600792 ns | 824500 ns | 0.73 |
| private array/copyto!/gpu_to_cpu | 627375 ns | 666750 ns | 0.94 |
| private array/accumulate/1d | 1417791 ns | 1345792 ns | 1.05 |
| private array/accumulate/2d | 1538584 ns | 1388792 ns | 1.11 |
| private array/iteration/findall/int | 2333333 ns | 2075334 ns | 1.12 |
| private array/iteration/findall/bool | 2068167 ns | 1842667 ns | 1.12 |
| private array/iteration/findfirst/int | 1834187.5 ns | 1681104 ns | 1.09 |
| private array/iteration/findfirst/bool | 1776709 ns | 1666937 ns | 1.07 |
| private array/iteration/scalar | 2434833 ns | 3620500 ns | 0.67 |
| private array/iteration/logical | 3627042 ns | 3205125 ns | 1.13 |
| private array/iteration/findmin/1d | 1869750 ns | 1760646 ns | 1.06 |
| private array/iteration/findmin/2d | 1455167 ns | 1349500.5 ns | 1.08 |
| private array/reductions/reduce/1d | 1036604 ns | 1038250 ns | 1.00 |
| private array/reductions/reduce/2d | 709042 ns | 659791 ns | 1.07 |
| private array/reductions/mapreduce/1d | 1006250 ns | 1022833 ns | 0.98 |
| private array/reductions/mapreduce/2d | 703791 ns | 660750 ns | 1.07 |
| private array/permutedims/4d | 2626542 ns | 2543458 ns | 1.03 |
| private array/permutedims/2d | 1119353.5 ns | 1022104 ns | 1.10 |
| private array/permutedims/3d | 1840020.5 ns | 1717542 ns | 1.07 |
| private array/copy | 835375 ns | 567521 ns | 1.47 |
| latency/precompile | 9198154000 ns | 9072804750 ns | 1.01 |
| latency/ttfp | 3726296084 ns | 3675978250 ns | 1.01 |
| latency/import | 1262449979 ns | 1242875750 ns | 1.02 |
| integration/metaldevrt | 762459 ns | 713042 ns | 1.07 |
| integration/byval/slices=1 | 1661104.5 ns | 1568958 ns | 1.06 |
| integration/byval/slices=3 | 20233833 ns | 9347292 ns | 2.16 |
| integration/byval/reference | 1651041 ns | 1561458.5 ns | 1.06 |
| integration/byval/slices=2 | 2842708 ns | 2616271 ns | 1.09 |
| kernel/indexing | 467042 ns | 470084 ns | 0.99 |
| kernel/indexing_checked | 464709 ns | 470000 ns | 0.99 |
| kernel/launch | 7917 ns | 36947.75 ns | 0.21 |
| metal/synchronization/stream | 15000 ns | 14667 ns | 1.02 |
| metal/synchronization/context | 15375 ns | 15167 ns | 1.01 |
| shared array/construct | 26434 ns | 24225.75 ns | 1.09 |
| shared array/broadcast | 465541.5 ns | 460333 ns | 1.01 |
| shared array/random/randn/Float32 | 892208 ns | 802562.5 ns | 1.11 |
| shared array/random/randn!/Float32 | 586083 ns | 634792 ns | 0.92 |
| shared array/random/rand!/Int64 | 550042 ns | 581167 ns | 0.95 |
| shared array/random/rand!/Float32 | 548937.5 ns | 607417 ns | 0.90 |
| shared array/random/rand/Int64 | 893958 ns | 782875 ns | 1.14 |
| shared array/random/rand/Float32 | 805229 ns | 603583 ns | 1.33 |
| shared array/copyto!/gpu_to_gpu | 80166 ns | 83209 ns | 0.96 |
| shared array/copyto!/cpu_to_gpu | 85916 ns | 82792 ns | 1.04 |
| shared array/copyto!/gpu_to_cpu | 80937.5 ns | 83000 ns | 0.98 |
| shared array/accumulate/1d | 1465084 ns | 1356250 ns | 1.08 |
| shared array/accumulate/2d | 1526917 ns | 1392520.5 ns | 1.10 |
| shared array/iteration/findall/int | 2037167 ns | 1792562.5 ns | 1.14 |
| shared array/iteration/findall/bool | 1762333 ns | 1606375 ns | 1.10 |
| shared array/iteration/findfirst/int | 1523021 ns | 1393125 ns | 1.09 |
| shared array/iteration/findfirst/bool | 1464166 ns | 1378083.5 ns | 1.06 |
| shared array/iteration/scalar | 160875 ns | 158000 ns | 1.02 |
| shared array/iteration/logical | 3362083 ns | 2996041.5 ns | 1.12 |
| shared array/iteration/findmin/1d | 1591916 ns | 1456500 ns | 1.09 |
| shared array/iteration/findmin/2d | 1465541.5 ns | 1371375 ns | 1.07 |
| shared array/reductions/reduce/1d | 718208 ns | 729667 ns | 0.98 |
| shared array/reductions/reduce/2d | 719916 ns | 673417 ns | 1.07 |
| shared array/reductions/mapreduce/1d | 715062.5 ns | 733562 ns | 0.97 |
| shared array/reductions/mapreduce/2d | 704583 ns | 676771 ns | 1.04 |
| shared array/permutedims/4d | 2671396 ns | 2545500 ns | 1.05 |
| shared array/permutedims/2d | 1109958 ns | 1024875 ns | 1.08 |
| shared array/permutedims/3d | 1836000 ns | 1581604 ns | 1.16 |
| shared array/copy | 209459 ns | 239270.5 ns | 0.88 |

This comment was automatically generated by workflow using github-action-benchmark.

christiangnrd (Member, Author) commented:
Not actually ready for review; I marked it as ready to see if the checks would fail. Not sure why they don't.

@christiangnrd christiangnrd marked this pull request as draft March 4, 2025 00:03
@christiangnrd christiangnrd force-pushed the simd branch 3 times, most recently from 486ecc7 to 71db25a Compare March 15, 2025 17:53
codecov bot commented Apr 2, 2025:

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 80.47%. Comparing base (145dc33) to head (2d80405).

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #555      +/-   ##
==========================================
- Coverage   80.66%   80.47%   -0.19%     
==========================================
  Files          61       61              
  Lines        2679     2679              
==========================================
- Hits         2161     2156       -5     
- Misses        518      523       +5     
