Skip to content

Commit cfe7eac

Browse files
committed
Revert "removing heuristic"
This reverts commit 9a7a84a.
1 parent e8ffa59 commit cfe7eac

File tree

1 file changed

+21
-0
lines changed

1 file changed

+21
-0
lines changed

src/gpuarrays.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,27 @@
22

33
# GPUArrays device query for `MtlArray`: returns the array's `dev` field.
function GPUArrays.device(x::MtlArray)
    return x.dev
end
44

5+
import KernelAbstractions
6+
import KernelAbstractions: Backend
7+
8+
# Suggest a launch configuration for running the kernel object `obj` on a Metal backend.
#
# Arguments:
# - `obj`: kernel object; it is handed to `KA.launch_config` and `KA.mkcontext`, and its
#   `f` field is the function that gets turned into a Metal kernel.
# - `args...`: kernel arguments, forwarded after the context in the `@metal` call.
# - `elements`: total number of elements to process.
# - `elements_per_thread`: number of elements a single thread is expected to handle.
#
# Returns a named tuple `(; threads, blocks)` whose entries are both `Int`.
@inline function GPUArrays.launch_heuristic(::MetalBackend, obj::O, args::Vararg{Any,N};
                                            elements::Int,
                                            elements_per_thread::Int) where {O,N}
    # Exact integer ceiling division; a round trip through Float64 loses precision for
    # very large element counts.
    ndrange = cld(elements, elements_per_thread)
    # Only the (possibly normalised) ndrange and the iteration space are needed here.
    ndrange, _, iterspace, _ = KA.launch_config(obj, ndrange, nothing)

    ctx = KA.mkcontext(obj, ndrange, iterspace)

    # `launch=false`: obtain the kernel object without running it, so that its pipeline
    # state can be queried below.
    kernel = @metal launch=false obj.f(ctx, args...)

    # The pipeline state automatically computes occupancy stats
    # Keep at least one thread so that an empty workload (`elements == 0`) does not make
    # the `cld` below divide by zero.
    threads = max(1, min(elements, kernel.pipeline.maxTotalThreadsPerThreadgroup))
    blocks = cld(elements, threads)

    return (; threads=Int(threads), blocks=Int(blocks))
end
25+
526
# Module-level dictionary mapping an `MTLDevice` to a `GPUArrays.RNG`; created empty.
# NOTE(review): presumably filled on demand by `GPUArrays.default_rng` below — confirm.
const GLOBAL_RNGs = Dict{MTLDevice,GPUArrays.RNG}()
627
function GPUArrays.default_rng(::Type{<:MtlArray})
728
dev = device()

0 commit comments

Comments
 (0)