Skip to content

Commit 1273856

Browse files
leios authored and maleadt committed
Adapt to GPUArrays.jl transition to KernelAbstractions.jl.
1 parent 100f831 commit 1273856

File tree

3 files changed

+2
-57
lines changed

3 files changed

+2
-57
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -40,7 +40,7 @@ BFloat16s = "0.5"
4040
CEnum = "0.4, 0.5"
4141
CodecBzip2 = "0.8"
4242
ExprTools = "0.1"
43-
GPUArrays = "10.1"
43+
GPUArrays = "11"
4444
GPUCompiler = "0.26, 0.27, 1"
4545
KernelAbstractions = "0.9.1"
4646
LLVM = "7.2, 8, 9"

src/gpuarrays.jl

Lines changed: 0 additions & 54 deletions
Original file line number · Diff line number · Diff line change
@@ -1,59 +1,5 @@
11
## GPUArrays interfaces
22

3-
## execution
4-
5-
struct mtlArrayBackend <: AbstractGPUBackend end
6-
7-
struct mtlKernelContext <: AbstractKernelContext end
8-
9-
@inline function GPUArrays.launch_heuristic(::mtlArrayBackend, f::F, args::Vararg{Any,N};
10-
elements::Int, elements_per_thread::Int) where {F,N}
11-
kernel = @metal launch=false f(mtlKernelContext(), args...)
12-
13-
# The pipeline state automatically computes occupancy stats
14-
threads = min(elements, kernel.pipeline.maxTotalThreadsPerThreadgroup)
15-
blocks = cld(elements, threads)
16-
17-
return (; threads=Int(threads), blocks=Int(blocks))
18-
end
19-
20-
function GPUArrays.gpu_call(::mtlArrayBackend, f, args, threads::Int, groups::Int;
21-
name::Union{String,Nothing})
22-
@metal threads groups name f(mtlKernelContext(), args...)
23-
end
24-
25-
26-
## on-device
27-
28-
# indexing
29-
GPUArrays.blockidx(ctx::mtlKernelContext) = threadgroup_position_in_grid_1d()
30-
GPUArrays.blockdim(ctx::mtlKernelContext) = threads_per_threadgroup_1d()
31-
GPUArrays.threadidx(ctx::mtlKernelContext) = thread_position_in_threadgroup_1d()
32-
GPUArrays.griddim(ctx::mtlKernelContext) = threadgroups_per_grid_1d()
33-
GPUArrays.global_index(ctx::mtlKernelContext) = thread_position_in_grid_1d()
34-
GPUArrays.global_size(ctx::mtlKernelContext) = threads_per_grid_1d()
35-
36-
# memory
37-
38-
@inline function GPUArrays.LocalMemory(::mtlKernelContext, ::Type{T}, ::Val{dims}, ::Val{id}
39-
) where {T, dims, id}
40-
ptr = emit_threadgroup_memory(T, Val(prod(dims)))
41-
MtlDeviceArray(dims, ptr)
42-
end
43-
44-
# synchronization
45-
46-
@inline GPUArrays.synchronize_threads(::mtlKernelContext) =
47-
threadgroup_barrier(MemoryFlagThreadGroup)
48-
49-
50-
51-
#
52-
# Host abstractions
53-
#
54-
55-
GPUArrays.backend(::Type{<:MtlArray}) = mtlArrayBackend()
56-
573
const GLOBAL_RNGs = Dict{MTLDevice,GPUArrays.RNG}()
584
function GPUArrays.default_rng(::Type{<:MtlArray})
595
dev = device()

test/random.jl

Lines changed: 1 addition & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -246,8 +246,7 @@ const OOPLACE_TUPLES = [[(Metal.rand, rand, T) for T in RAND_TYPES];
246246
a = f(T, d)
247247
Metal.seed!(1)
248248
b = f(T, d)
249-
# TODO: Remove broken parameter once https://github.com/JuliaGPU/GPUArrays.jl/issues/530 is fixed
250-
@test Array(a) == Array(b) broken = (T == Float16 && d == (1000,1000))
249+
@test Array(a) == Array(b)
251250
end
252251
end
253252
end # testset

0 commit comments

Comments
 (0)