1
1
export naive_matmul_kernel, matmul
2
2
3
+ """
4
+ matmul_heuristics(x, y)
5
+ This function computes workgroup size and workgroup count heuristics for a given input.
6
+ This is used by `naive_matmul_kernel`.
7
+ """
3
8
function matmul_heuristics (x, y)
4
9
aSize = size (x)
5
10
bSize = size (y)
@@ -9,6 +14,12 @@ function matmul_heuristics(x, y)
9
14
return (outSize, outSize, (1 , 1 ))
10
15
end
11
16
17
+ """
18
+ naive_matmul_kernel(x::WgpuArray{T, N}, y::WgpuArray{T, N}, out::WgpuArray{T, N}) where {T, N}
19
+ This is a naive matrix multiplication kernel implementation. It is not supposed to be used as a regular
20
+ Julia function. It needs to be passed to `@wgpukernel` to undergo transformation into `WGSL` compatible
21
+ shader code.
22
+ """
12
23
function naive_matmul_kernel (x:: WgpuArray{T, N} , y:: WgpuArray{T, N} , out:: WgpuArray{T, N} ) where {T, N}
13
24
gIdx = globalId. x
14
25
gIdy = globalId. y
@@ -23,14 +34,24 @@ function naive_matmul_kernel(x::WgpuArray{T, N}, y::WgpuArray{T, N}, out::WgpuAr
23
34
out[gId] = sum
24
35
end
25
36
37
"""
    matmul(x::WgpuArray{T, N}, y::WgpuArray{T, N}) where {T, N}

User-facing wrapper that multiplies `x` by `y` using the naive kernel
`naive_matmul_kernel`.

The output size, workgroup sizes and workgroup count are taken from
`matmul_heuristics(x, y)`. A new `WgpuArray` with the element type and number of
dimensions of `x` is allocated for the result, the kernel is invoked on it through
`@wgpukernel`, and the array is returned.
"""
function matmul(x::WgpuArray{T, N}, y::WgpuArray{T, N}) where {T, N}
    outSize, wgSize, wgCount = matmul_heuristics(x, y)
    # The result buffer is created with `undef`; `naive_matmul_kernel` writes into it.
    OutArray = WgpuArray{eltype(x), ndims(x)}
    out = OutArray(undef, outSize)
    @wgpukernel launch=true workgroupSizes=wgSize workgroupCount=wgCount shmem=() naive_matmul_kernel(x, y, out)
    return out
end
32
48
33
-
49
+ """
50
+ tiled_matmul_kernel(x::WgpuArray{T, N}, y::WgpuArray{T, N}, out::WgpuArray{T, N}) where {T, N}
51
+ This is a compute kernel which carries out tiled matrix multiplication of the input `WgpuArray`s. This is
52
+ not supposed to be used as a regular julia function. This instead needs to be passed to `@wgpukernel` macro
53
+ inside a wrapper function.
54
+ """
34
55
function tiled_matmul_kernel (x:: WgpuArray{T, N} , y:: WgpuArray{T, N} , out:: WgpuArray{T, N} ) where {T, N}
35
56
# set out matrix to zero
36
57
gId = xDims. x* globalId. y + globalId. x
@@ -61,18 +82,29 @@ function tiled_matmul_kernel(x::WgpuArray{T, N}, y::WgpuArray{T, N}, out::WgpuAr
61
82
62
83
out[gId] = sum
63
84
end
64
- # For now valid only for square matrices of size powers of 2 and base size 16.
85
+
86
"""
    tiled_matmul_heuristics(x::WgpuArray{T, N}, y::WgpuArray{T, N}) where {T, N}

Compute the output size, workgroup size and workgroup count for the tiled matrix
multiplication of `x` and `y`. The result is consumed by `tiled_matmul`.

Returns the tuple `(outSize, wgSize, wgCount)` where

- `outSize` is `(first(size(x)), last(size(y)))`,
- `wgSize` is fixed to `(16, 16)`,
- `wgCount` is `outSize` divided element-wise by 16, rounded up.

An `AssertionError` is raised when the last dimension of `x` differs from the first
dimension of `y`, or when the element types of `x` and `y` differ.
"""
function tiled_matmul_heuristics(x::WgpuArray{T, N}, y::WgpuArray{T, N}) where {T, N}
    aSize, bSize = size(x), size(y)
    @assert last(aSize) == first(bSize)
    outSize = (first(aSize), last(bSize))
    @assert eltype(x) == eltype(y)
    # For now valid only for square matrices of size powers of 2 and base size 16.
    wgSize = (16, 16) # This can be fixed for now
    # Ceiling division: number of 16-wide workgroups along each output dimension.
    wgCount = cld.(outSize, 16)
    return (outSize, wgSize, wgCount)
end
75
102
103
+ """
104
+ tiled_matmul(x::WgpuArray{T, N}, y::WgpuArray{T, N}) where {T, N}
105
+ This is the user-facing matrix multiplication function, which carries out tiled matrix multiplication of
106
+ input `WgpuArray` arguments.
107
+ """
76
108
function tiled_matmul (x:: WgpuArray{T, N} , y:: WgpuArray{T, N} ) where {T, N}
77
109
(outSize, wgSize, wgCount) = tiled_matmul_heuristics (x, y)
78
110
out = WgpuArray {eltype(x), ndims(x)} (undef, outSize)
0 commit comments