Skip to content

Commit 7c29498

Browse files
committed
Revert "Adapt to GPUArrays.jl transition to KernelAbstractions.jl. (JuliaGPU#461)"
This reverts commit 711758d.
1 parent f733812 commit 7c29498

File tree

3 files changed

+57
-2
lines changed

3 files changed

+57
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ BFloat16s = "0.5"
4040
CEnum = "0.4, 0.5"
4141
CodecBzip2 = "0.8"
4242
ExprTools = "0.1"
43-
GPUArrays = "11"
43+
GPUArrays = "10.1"
4444
GPUCompiler = "0.26, 0.27, 1"
4545
KernelAbstractions = "0.9.1"
4646
LLVM = "7.2, 8, 9"

src/gpuarrays.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,59 @@
11
## GPUArrays interfaces
22

3+
## execution
4+
5+
struct mtlArrayBackend <: AbstractGPUBackend end
6+
7+
struct mtlKernelContext <: AbstractKernelContext end
8+
9+
@inline function GPUArrays.launch_heuristic(::mtlArrayBackend, f::F, args::Vararg{Any,N};
10+
elements::Int, elements_per_thread::Int) where {F,N}
11+
kernel = @metal launch=false f(mtlKernelContext(), args...)
12+
13+
# The pipeline state automatically computes occupancy stats
14+
threads = min(elements, kernel.pipeline.maxTotalThreadsPerThreadgroup)
15+
blocks = cld(elements, threads)
16+
17+
return (; threads=Int(threads), blocks=Int(blocks))
18+
end
19+
20+
function GPUArrays.gpu_call(::mtlArrayBackend, f, args, threads::Int, groups::Int;
21+
name::Union{String,Nothing})
22+
@metal threads groups name f(mtlKernelContext(), args...)
23+
end
24+
25+
26+
## on-device
27+
28+
# indexing
29+
GPUArrays.blockidx(ctx::mtlKernelContext) = threadgroup_position_in_grid_1d()
30+
GPUArrays.blockdim(ctx::mtlKernelContext) = threads_per_threadgroup_1d()
31+
GPUArrays.threadidx(ctx::mtlKernelContext) = thread_position_in_threadgroup_1d()
32+
GPUArrays.griddim(ctx::mtlKernelContext) = threadgroups_per_grid_1d()
33+
GPUArrays.global_index(ctx::mtlKernelContext) = thread_position_in_grid_1d()
34+
GPUArrays.global_size(ctx::mtlKernelContext) = threads_per_grid_1d()
35+
36+
# memory
37+
38+
@inline function GPUArrays.LocalMemory(::mtlKernelContext, ::Type{T}, ::Val{dims}, ::Val{id}
39+
) where {T, dims, id}
40+
ptr = emit_threadgroup_memory(T, Val(prod(dims)))
41+
MtlDeviceArray(dims, ptr)
42+
end
43+
44+
# synchronization
45+
46+
@inline GPUArrays.synchronize_threads(::mtlKernelContext) =
47+
threadgroup_barrier(MemoryFlagThreadGroup)
48+
49+
50+
51+
#
52+
# Host abstractions
53+
#
54+
55+
GPUArrays.backend(::Type{<:MtlArray}) = mtlArrayBackend()
56+
357
const GLOBAL_RNGs = Dict{MTLDevice,GPUArrays.RNG}()
458
function GPUArrays.default_rng(::Type{<:MtlArray})
559
dev = device()

test/random.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,8 @@ const OOPLACE_TUPLES = [[(Metal.rand, rand, T) for T in RAND_TYPES];
246246
a = f(T, d)
247247
Metal.seed!(1)
248248
b = f(T, d)
249-
@test Array(a) == Array(b)
249+
# TODO: Remove broken parameter once https://github.com/JuliaGPU/GPUArrays.jl/issues/530 is fixed
250+
@test Array(a) == Array(b) broken = (T == Float16 && d == (1000,1000))
250251
end
251252
end
252253
end # testset

0 commit comments

Comments
 (0)