|
1 | 1 | ## GPUArrays interfaces
|
2 | 2 |
|
3 |
| -## execution |
4 |
| - |
5 |
| -struct mtlArrayBackend <: AbstractGPUBackend end |
6 |
| - |
7 |
| -struct mtlKernelContext <: AbstractKernelContext end |
8 |
| - |
9 |
| -@inline function GPUArrays.launch_heuristic(::mtlArrayBackend, f::F, args::Vararg{Any,N}; |
10 |
| - elements::Int, elements_per_thread::Int) where {F,N} |
11 |
| - kernel = @metal launch=false f(mtlKernelContext(), args...) |
12 |
| - |
13 |
| - # The pipeline state automatically computes occupancy stats |
14 |
| - threads = min(elements, kernel.pipeline.maxTotalThreadsPerThreadgroup) |
15 |
| - blocks = cld(elements, threads) |
16 |
| - |
17 |
| - return (; threads=Int(threads), blocks=Int(blocks)) |
18 |
| -end |
19 |
| - |
20 |
| -function GPUArrays.gpu_call(::mtlArrayBackend, f, args, threads::Int, groups::Int; |
21 |
| - name::Union{String,Nothing}) |
22 |
| - @metal threads groups name f(mtlKernelContext(), args...) |
23 |
| -end |
24 |
| - |
25 |
| - |
26 |
| -## on-device |
27 |
| - |
28 |
| -# indexing |
29 |
| -GPUArrays.blockidx(ctx::mtlKernelContext) = threadgroup_position_in_grid_1d() |
30 |
| -GPUArrays.blockdim(ctx::mtlKernelContext) = threads_per_threadgroup_1d() |
31 |
| -GPUArrays.threadidx(ctx::mtlKernelContext) = thread_position_in_threadgroup_1d() |
32 |
| -GPUArrays.griddim(ctx::mtlKernelContext) = threadgroups_per_grid_1d() |
33 |
| -GPUArrays.global_index(ctx::mtlKernelContext) = thread_position_in_grid_1d() |
34 |
| -GPUArrays.global_size(ctx::mtlKernelContext) = threads_per_grid_1d() |
35 |
| - |
36 |
| -# memory |
37 |
| - |
38 |
| -@inline function GPUArrays.LocalMemory(::mtlKernelContext, ::Type{T}, ::Val{dims}, ::Val{id} |
39 |
| - ) where {T, dims, id} |
40 |
| - ptr = emit_threadgroup_memory(T, Val(prod(dims))) |
41 |
| - MtlDeviceArray(dims, ptr) |
42 |
| -end |
43 |
| - |
44 |
| -# synchronization |
45 |
| - |
46 |
| -@inline GPUArrays.synchronize_threads(::mtlKernelContext) = |
47 |
| - threadgroup_barrier(MemoryFlagThreadGroup) |
48 |
| - |
49 |
| - |
50 |
| - |
51 |
| -# |
52 |
| -# Host abstractions |
53 |
| -# |
54 |
| - |
55 |
| -GPUArrays.backend(::Type{<:MtlArray}) = mtlArrayBackend() |
56 |
| - |
57 | 3 | const GLOBAL_RNGs = Dict{MTLDevice,GPUArrays.RNG}()
|
58 | 4 | function GPUArrays.default_rng(::Type{<:MtlArray})
|
59 | 5 | dev = device()
|
|
0 commit comments