Skip to content

Commit 416b28f

Browse files
committed
Faster matmul
1 parent 55a943e commit 416b28f

File tree

1 file changed

+88
-7
lines changed

1 file changed

+88
-7
lines changed

src/host/linalg.jl

Lines changed: 88 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -325,11 +325,92 @@ function LinearAlgebra.ldiv!(B::AbstractGPUVecOrMat,
325325
B
326326
end
327327

328+
# XXX: figure out how to do dynamically
329+
MAX_TILE_DIM = 16
328330

329331
## matrix multiplication
330332
# legacy method
331333
generic_matmatmul!(C::AbstractArray, A::AbstractArray, B::AbstractArray, a::Number, b::Number) =
332334
generic_matmatmul!(C, A, B, MulAddMul(a, b))
335+
function generic_matmatmul!(C::AbstractGPUMatrix{R}, A::AbstractGPUMatrix{T}, B::AbstractGPUMatrix{S}, add::MulAddMul) where {T<:Number,S<:Number,R<:Number}
336+
N = size(A,1)
337+
Q = size(A,2)
338+
M = size(B,2)
339+
if Q != size(B,1)
340+
throw(DimensionMismatch("matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))"))
341+
end
342+
if size(C,1) != N || size(C,2) != M
343+
throw(DimensionMismatch("result C has dimensions $(size(C)), needs $((N,M))"))
344+
end
345+
if isempty(A) || isempty(B)
346+
return fill!(C, zero(R))
347+
end
348+
349+
@kernel unsafe_indices=true function coalesced_matmul_kernel!(
350+
output, @Const(input1), @Const(input2), N, Q, M,
351+
::Val{BANK} = Val(1),
352+
) where {BANK}
353+
grow, gcol = @index(Group, NTuple)
354+
tile_row, tile_col = @index(Local, NTuple)
355+
356+
TILE_DIM = @uniform @groupsize()[1]
357+
358+
# +1 to avoid bank conflicts on shared memory
359+
tile1 = @localmem(R, (TILE_DIM + BANK, TILE_DIM))
360+
tile2 = @localmem(R, (TILE_DIM + BANK, TILE_DIM))
361+
362+
# private variable for tile output
363+
outval = @private R 1
364+
@inbounds outval[1] = -zero(R)
365+
366+
# number of tiles depends on inner dimension
367+
@uniform NUM_TILES = div(Q + TILE_DIM - 1, TILE_DIM)
368+
369+
# loop over all tiles needed for this calculation
370+
for t in 0:(NUM_TILES - 1)
371+
I = (grow - 1) * TILE_DIM + tile_row
372+
J = (gcol - 1) * TILE_DIM + tile_col
373+
374+
# load inputs into tiles, with bounds checking for non-square matrices
375+
if I <= N && t * TILE_DIM + tile_col <= Q
376+
@inbounds tile1[tile_row, tile_col] = input1[I, t * TILE_DIM + tile_col]
377+
else
378+
@inbounds tile1[tile_row, tile_col] = zero(R)
379+
end
380+
if J <= M && t * TILE_DIM + tile_row <= Q
381+
@inbounds tile2[tile_row, tile_col] = input2[t * TILE_DIM + tile_row, J]
382+
else
383+
@inbounds tile2[tile_row, tile_col] = zero(R)
384+
end
385+
386+
# wait for all tiles to be loaded
387+
@synchronize
388+
389+
I = (grow - 1) * TILE_DIM + tile_row
390+
J = (gcol - 1) * TILE_DIM + tile_col
391+
392+
# calculate value of spot in output, use temporary value to allow for vectorization
393+
out = zero(R)
394+
@simd for k in 1:TILE_DIM
395+
@inbounds out += tile1[tile_row, k] * tile2[k, tile_col]
396+
end
397+
outval[1] += out
398+
399+
@synchronize
400+
end
401+
402+
I = (grow - 1) * TILE_DIM + tile_row
403+
J = (gcol - 1) * TILE_DIM + tile_col
404+
405+
# save if inbounds
406+
if I <= N && J <= M
407+
@inbounds output[I, J] = add(outval[1], output[I, J])
408+
end
409+
end
410+
411+
coalesced_matmul_kernel!(get_backend(C), (MAX_TILE_DIM, MAX_TILE_DIM))(C, A, B, N, Q, M;ndrange=map(x -> ceil(Int,x/MAX_TILE_DIM)*MAX_TILE_DIM, size(C)))
412+
C
413+
end
333414
function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::AbstractArray{S}, add::MulAddMul) where {T,S,R}
334415
if size(A,2) != size(B,1)
335416
throw(DimensionMismatch("matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))"))
@@ -744,7 +825,7 @@ function LinearAlgebra.kron!(z::AbstractGPUVector{T1}, x::AbstractGPUVector{T2},
744825

745826
@kernel function kron_kernel!(z, @Const(x), @Const(y))
746827
i, j = @index(Global, NTuple)
747-
828+
748829
@inbounds z[(i - 1) * length(y) + j] = x[i] * y[j]
749830
end
750831

@@ -777,13 +858,13 @@ for (wrapa, transa, unwrapa) in trans_adj_wrappers, (wrapb, transb, unwrapb) in
777858

778859
ta = $transa(T1)
779860
tb = $transb(T2)
780-
861+
781862
@kernel function kron_kernel!(C, @Const(A), @Const(B))
782863
ai, aj = @index(Global, NTuple) # Indices in the result matrix
783-
864+
784865
# lb1, lb2 = size(B) # Dimensions of B
785866
lb1, lb2 = tb == 'N' ? size(B) : reverse(size(B))
786-
867+
787868
# Map global indices (ai, aj) to submatrices of the Kronecker product
788869
i_a = (ai - 1) ÷ lb1 + 1 # Corresponding row index in A
789870
i_b = (ai - 1) % lb1 + 1 # Corresponding row index in B
@@ -797,12 +878,12 @@ for (wrapa, transa, unwrapa) in trans_adj_wrappers, (wrapb, transb, unwrapb) in
797878
C[ai, aj] = a_ij * b_ij
798879
end
799880
end
800-
881+
801882
backend = KernelAbstractions.get_backend(C)
802883
kernel = kron_kernel!(backend)
803-
884+
804885
kernel(C, $(unwrapa(:A)), $(unwrapb(:B)), ndrange=(size(C, 1), size(C, 2)))
805-
886+
806887
return C
807888
end
808889

0 commit comments

Comments
 (0)