Skip to content

Commit 467fc60

Browse files
authored
Merge pull request #38 from arhik/main
[docs] update ops/matmul.jl
2 parents 5621e39 + 090f0c0 commit 467fc60

File tree

1 file changed

+34
-2
lines changed

1 file changed

+34
-2
lines changed

src/ops/matmul.jl

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
export naive_matmul_kernel, matmul
22

3+
"""
4+
matmul_heuristics(x, y)
5+
This function computes workgroup size and workgroup count heuristics for a given input.
6+
This is used by `naive_matmul_kernel`.
7+
"""
38
function matmul_heuristics(x, y)
49
aSize = size(x)
510
bSize = size(y)
@@ -9,6 +14,12 @@ function matmul_heuristics(x, y)
914
return (outSize, outSize, (1, 1))
1015
end
1116

17+
"""
18+
naive_matmul_kernel(x::WgpuArray{T, N}, y::WgpuArray{T, N}, out::WgpuArray{T, N}) where {T, N}
19+
This is a naive matrix multiplication kernel. It is not supposed to be used as a regular
20+
Julia function. It needs to be passed to `@wgpukernel` to undergo transformation into `WGSL`-compatible
21+
shader code.
22+
"""
1223
function naive_matmul_kernel(x::WgpuArray{T, N}, y::WgpuArray{T, N}, out::WgpuArray{T, N}) where {T, N}
1324
gIdx = globalId.x
1425
gIdy = globalId.y
@@ -23,14 +34,24 @@ function naive_matmul_kernel(x::WgpuArray{T, N}, y::WgpuArray{T, N}, out::WgpuAr
2334
out[gId] = sum
2435
end
2536

37+
"""
38+
matmul(x::WgpuArray{T, N}, y::WgpuArray{T, N}) where {T, N}
39+
This is a wrapper function for end users which uses the naive implementation of matrix multiplication
40+
`naive_matmul_kernel` kernel for matrix computation.
41+
"""
2642
function matmul(x::WgpuArray{T, N}, y::WgpuArray{T, N}) where {T, N}
2743
# `matmul_heuristics` supplies the output size together with the workgroup size
# and workgroup count used for the kernel launch below.
(outSize, wgSize, wgCount) = matmul_heuristics(x, y)
2844
# Allocate the result with the element type and dimensionality of `x`; it is
# created with `undef`, so its contents are expected to be written by the kernel.
out = WgpuArray{eltype(x), ndims(x)}(undef, outSize)
2945
# Hand `naive_matmul_kernel` to `@wgpukernel` with `launch=true` and an empty
# `shmem` tuple (presumably meaning no shared memory is requested - TODO confirm).
@wgpukernel launch=true workgroupSizes=wgSize workgroupCount=wgCount shmem=() naive_matmul_kernel(x, y, out)
3046
return out
3147
end
3248

33-
49+
"""
50+
tiled_matmul_kernel(x::WgpuArray{T, N}, y::WgpuArray{T, N}, out::WgpuArray{T, N}) where {T, N}
51+
This is a compute kernel which carries out tiled matrix multiplication of the input `WgpuArray`s. It is
52+
not supposed to be used as a regular Julia function. Instead, it needs to be passed to the `@wgpukernel` macro
53+
inside a wrapper function.
54+
"""
3455
function tiled_matmul_kernel(x::WgpuArray{T, N}, y::WgpuArray{T, N}, out::WgpuArray{T, N}) where {T, N}
3556
#set out matrix to zero
3657
gId = xDims.x*globalId.y + globalId.x
@@ -61,18 +82,29 @@ function tiled_matmul_kernel(x::WgpuArray{T, N}, y::WgpuArray{T, N}, out::WgpuAr
6182

6283
out[gId] = sum
6384
end
64-
# For now valid only for square matrices of size powers of 2 and base size 16.
85+
86+
"""
87+
tiled_matmul_heuristics(x::WgpuArray{T, N}, y::WgpuArray{T, N}) where {T, N}
88+
This function computes workgroup size and workgroup count for a given input for
89+
the `tiled_matmul_kernel` kernel function.
90+
"""
6591
function tiled_matmul_heuristics(x::WgpuArray{T, N}, y::WgpuArray{T, N}) where {T, N}
6692
aSize = size(x)
6793
bSize = size(y)
6894
# The inner dimensions must agree: last dimension of `x` against first of `y`.
@assert last(aSize) == first(bSize)
6995
# The output takes the first dimension of `x` and the last dimension of `y`.
outSize = (first(aSize), last(bSize))
7096
# Both inputs must share one element type.
@assert eltype(x) == eltype(y)
97+
# For now this is valid only for square matrices whose size is a power of 2,
# with a base (tile) size of 16.
7198
wgSize = (16, 16) # Workgroup size is hard-coded to 16x16 for now
7299
# Number of workgroups per output dimension: output extent divided by 16, rounded up.
wgCount = div.((outSize[1], outSize[2]), 16, RoundUp)
73100
# Returned as (output size, workgroup size, workgroup count).
return (outSize, wgSize, wgCount)
74101
end
75102

103+
"""
104+
tiled_matmul(x::WgpuArray{T, N}, y::WgpuArray{T, N}) where {T, N}
105+
This is the user-facing matrix multiplication function which carries out tiled matrix multiplication of
106+
input `WgpuArray` arguments.
107+
"""
76108
function tiled_matmul(x::WgpuArray{T, N}, y::WgpuArray{T, N}) where {T, N}
77109
(outSize, wgSize, wgCount) = tiled_matmul_heuristics(x, y)
78110
out = WgpuArray{eltype(x), ndims(x)}(undef, outSize)

0 commit comments

Comments
 (0)