Skip to content

Commit 3892d1b

Browse files
committed
Test latest KernelAbstractions
1 parent 1499c12 commit 3892d1b

File tree

2 files changed

+15 additions, -8 deletions (lines changed)

2 files changed

+15 additions, -8 deletions (lines changed)

src/host/linalg.jl

Lines changed: 12 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -364,12 +364,15 @@ function generic_matmatmul!(C::AbstractGPUMatrix{R}, A::AbstractGPUMatrix{T}, B:
364364
@inbounds outval[1] = -zero(R)
365365

366366
# number of tiles depends on inner dimension
367-
@uniform NUM_TILES = div(Q + TILE_DIM - 1, TILE_DIM)
367+
@uniform NUM_TILES = cld(Q, TILE_DIM)
368+
369+
I = (grow - 1) * TILE_DIM + tile_row
370+
J = (gcol - 1) * TILE_DIM + tile_col
368371

369372
# loop over all tiles needed for this calculation
370373
for t in 0:(NUM_TILES - 1)
371-
I = (grow - 1) * TILE_DIM + tile_row
372-
J = (gcol - 1) * TILE_DIM + tile_col
374+
# I = (grow - 1) * TILE_DIM + tile_row
375+
# J = (gcol - 1) * TILE_DIM + tile_col
373376

374377
# load inputs into tiles, with bounds checking for non-square matrices
375378
if I <= N && t * TILE_DIM + tile_col <= Q
@@ -386,8 +389,8 @@ function generic_matmatmul!(C::AbstractGPUMatrix{R}, A::AbstractGPUMatrix{T}, B:
386389
# wait for all tiles to be loaded
387390
@synchronize
388391

389-
I = (grow - 1) * TILE_DIM + tile_row
390-
J = (gcol - 1) * TILE_DIM + tile_col
392+
# I = (grow - 1) * TILE_DIM + tile_row
393+
# J = (gcol - 1) * TILE_DIM + tile_col
391394

392395
# calculate value of spot in output, use temporary value to allow for vectorization
393396
out = zero(R)
@@ -399,16 +402,17 @@ function generic_matmatmul!(C::AbstractGPUMatrix{R}, A::AbstractGPUMatrix{T}, B:
399402
@synchronize
400403
end
401404

402-
I = (grow - 1) * TILE_DIM + tile_row
403-
J = (gcol - 1) * TILE_DIM + tile_col
405+
# I = (grow - 1) * TILE_DIM + tile_row
406+
# J = (gcol - 1) * TILE_DIM + tile_col
404407

405408
# save if inbounds
406409
if I <= N && J <= M
407410
@inbounds output[I, J] = add(outval[1], output[I, J])
408411
end
409412
end
410413

411-
coalesced_matmul_kernel!(get_backend(C), (MAX_TILE_DIM, MAX_TILE_DIM))(C, A, B, N, Q, M;ndrange=map(x -> ceil(Int,x/MAX_TILE_DIM)*MAX_TILE_DIM, size(C)))
414+
# coalesced_matmul_kernel!(get_backend(C), (MAX_TILE_DIM, MAX_TILE_DIM))(C, A, B, N, Q, M;ndrange=map(x -> ceil(Int,x/MAX_TILE_DIM)*MAX_TILE_DIM, size(C)))
415+
coalesced_matmul_kernel!(get_backend(C), (MAX_TILE_DIM, MAX_TILE_DIM))(C, A, B, N, Q, M;ndrange=size(C))
412416
C
413417
end
414418
function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::AbstractArray{S}, add::MulAddMul) where {T,S,R}

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,9 @@ using Dates
33
import REPL
44
using Printf: @sprintf
55

6+
using Pkg
7+
Pkg.add(url="https://github.com/JuliaGPU/KernelAbstractions.jl", rev="main")
8+
69
# parse some command-line arguments
710
function extract_flag!(args, flag, default=nothing)
811
for f in args

0 commit comments

Comments
 (0)