@@ -325,11 +325,92 @@ function LinearAlgebra.ldiv!(B::AbstractGPUVecOrMat,
325
325
B
326
326
end
327
327
328
# Edge length of the square tiles used by the tiled GPU matmul below; the kernel is
# launched with a workgroup size of (MAX_TILE_DIM, MAX_TILE_DIM).
# Declared `const` so that functions reading it stay type-stable.
# XXX: figure out how to choose this dynamically
const MAX_TILE_DIM = 16
328
330
329
331
## matrix multiplication
330
332
# Legacy entry point: takes the two scaling factors as plain numbers, wraps them in a
# `MulAddMul`, and forwards to the `MulAddMul`-based method.
function generic_matmatmul!(C::AbstractArray, A::AbstractArray, B::AbstractArray,
                            a::Number, b::Number)
    add = MulAddMul(a, b)
    return generic_matmatmul!(C, A, B, add)
end
335
"""
    generic_matmatmul!(C::AbstractGPUMatrix, A::AbstractGPUMatrix, B::AbstractGPUMatrix, add::MulAddMul)

Tiled GPU matrix multiplication. For every entry, store `add(x, C[i, j])` into
`C[i, j]`, where `x` is the corresponding entry of the product `A * B`, and return `C`.

The kernel is launched with workgroups of `MAX_TILE_DIM × MAX_TILE_DIM` work-items; each
workgroup stages one tile of `A` and one tile of `B` in local memory per step along the
inner dimension.

# Throws
- `DimensionMismatch` if `size(A, 2) != size(B, 1)`, or if `size(C)` is not
  `(size(A, 1), size(B, 2))`.
"""
function generic_matmatmul!(C::AbstractGPUMatrix{R}, A::AbstractGPUMatrix{T}, B::AbstractGPUMatrix{S}, add::MulAddMul) where {T<:Number,S<:Number,R<:Number}
    N = size(A, 1)
    Q = size(A, 2)
    M = size(B, 2)
    if Q != size(B, 1)
        throw(DimensionMismatch("matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))"))
    end
    if size(C, 1) != N || size(C, 2) != M
        throw(DimensionMismatch("result C has dimensions $(size(C)), needs $((N,M))"))
    end
    if isempty(A) || isempty(B)
        # The product is all zeros here, so only the part of `add` that depends on the
        # existing contents of `C` remains. This is the same `add(x, C[i, j])` the kernel
        # stores, with `x = 0`, instead of unconditionally zeroing `C`.
        C .= add.(zero(R), C)
        return C
    end

    @kernel unsafe_indices=true function coalesced_matmul_kernel!(
            output, @Const(input1), @Const(input2), N, Q, M,
            ::Val{BANK} = Val(1),
        ) where {BANK}
        grow, gcol = @index(Group, NTuple)
        tile_row, tile_col = @index(Local, NTuple)

        TILE_DIM = @uniform @groupsize()[1]

        # BANK extra rows (+1 by default) to avoid bank conflicts on shared memory
        tile1 = @localmem(R, (TILE_DIM + BANK, TILE_DIM))
        tile2 = @localmem(R, (TILE_DIM + BANK, TILE_DIM))

        # private variable for tile output
        outval = @private R 1
        @inbounds outval[1] = -zero(R)

        # number of tiles depends on inner dimension
        @uniform NUM_TILES = div(Q + TILE_DIM - 1, TILE_DIM)

        # loop over all tiles needed for this calculation
        for t in 0:(NUM_TILES - 1)
            # NOTE(review): I and J are recomputed after every @synchronize below;
            # presumably deliberate so they are available in each synchronized region —
            # confirm before hoisting.
            I = (grow - 1) * TILE_DIM + tile_row
            J = (gcol - 1) * TILE_DIM + tile_col

            # load inputs into tiles, with bounds checking for non-square matrices;
            # out-of-range slots are zero-filled so they do not affect the sum
            if I <= N && t * TILE_DIM + tile_col <= Q
                @inbounds tile1[tile_row, tile_col] = input1[I, t * TILE_DIM + tile_col]
            else
                @inbounds tile1[tile_row, tile_col] = zero(R)
            end
            if J <= M && t * TILE_DIM + tile_row <= Q
                @inbounds tile2[tile_row, tile_col] = input2[t * TILE_DIM + tile_row, J]
            else
                @inbounds tile2[tile_row, tile_col] = zero(R)
            end

            # wait for all tiles to be loaded
            @synchronize

            I = (grow - 1) * TILE_DIM + tile_row
            J = (gcol - 1) * TILE_DIM + tile_col

            # calculate value of spot in output, use temporary value to allow for vectorization
            out = zero(R)
            @simd for k in 1:TILE_DIM
                @inbounds out += tile1[tile_row, k] * tile2[k, tile_col]
            end
            outval[1] += out

            # make sure every work-item is done with the tiles before they are overwritten
            @synchronize
        end

        I = (grow - 1) * TILE_DIM + tile_row
        J = (gcol - 1) * TILE_DIM + tile_col

        # save if inbounds
        if I <= N && J <= M
            @inbounds output[I, J] = add(outval[1], output[I, J])
        end
    end

    # Round each dimension of C up to a whole number of tiles (integer arithmetic only),
    # so the kernel's own bounds checks handle the ragged edge.
    padded = map(d -> cld(d, MAX_TILE_DIM) * MAX_TILE_DIM, size(C))
    kernel! = coalesced_matmul_kernel!(get_backend(C), (MAX_TILE_DIM, MAX_TILE_DIM))
    kernel!(C, A, B, N, Q, M; ndrange=padded)
    return C
end
333
414
function generic_matmatmul! (C:: AbstractArray{R} , A:: AbstractArray{T} , B:: AbstractArray{S} , add:: MulAddMul ) where {T,S,R}
334
415
if size (A,2 ) != size (B,1 )
335
416
throw (DimensionMismatch (" matrix A has dimensions $(size (A)) , matrix B has dimensions $(size (B)) " ))
@@ -744,7 +825,7 @@ function LinearAlgebra.kron!(z::AbstractGPUVector{T1}, x::AbstractGPUVector{T2},
744
825
745
826
# One work-item per (element of x, element of y) pair: writes their product into the
# slot of `z` at linear position (xi - 1) * length(y) + yj.
@kernel function kron_kernel!(z, @Const(x), @Const(y))
    xi, yj = @index(Global, NTuple)
    ylen = length(y)
    @inbounds z[(xi - 1) * ylen + yj] = x[xi] * y[yj]
end
750
831
@@ -777,13 +858,13 @@ for (wrapa, transa, unwrapa) in trans_adj_wrappers, (wrapb, transb, unwrapb) in
777
858
778
859
ta = $ transa (T1)
779
860
tb = $ transb (T2)
780
-
861
+
781
862
@kernel function kron_kernel! (C, @Const (A), @Const (B))
782
863
ai, aj = @index (Global, NTuple) # Indices in the result matrix
783
-
864
+
784
865
# lb1, lb2 = size(B) # Dimensions of B
785
866
lb1, lb2 = tb == ' N' ? size (B) : reverse (size (B))
786
-
867
+
787
868
# Map global indices (ai, aj) to submatrices of the Kronecker product
788
869
i_a = (ai - 1 ) ÷ lb1 + 1 # Corresponding row index in A
789
870
i_b = (ai - 1 ) % lb1 + 1 # Corresponding row index in B
@@ -797,12 +878,12 @@ for (wrapa, transa, unwrapa) in trans_adj_wrappers, (wrapb, transb, unwrapb) in
797
878
C[ai, aj] = a_ij * b_ij
798
879
end
799
880
end
800
-
881
+
801
882
backend = KernelAbstractions. get_backend (C)
802
883
kernel = kron_kernel! (backend)
803
-
884
+
804
885
kernel (C, $ (unwrapa (:A )), $ (unwrapb (:B )), ndrange= (size (C, 1 ), size (C, 2 )))
805
-
886
+
806
887
return C
807
888
end
808
889
0 commit comments