Skip to content

Commit 960beb6

Browse files
committed
matmul muladd and improved order.
1 parent ad583c9 commit 960beb6

File tree

1 file changed

+104
-23
lines changed

1 file changed

+104
-23
lines changed

src/matrix_multiply.jl

Lines changed: 104 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,24 @@ import LinearAlgebra: BlasFloat, matprod, mul!
1616

1717
# Implementations
1818

19+
function matrix_vector_quote(sa)
20+
q = Expr(:block)
21+
exprs = [Symbol(:x_, k) for k in 1:sa[1]]
22+
for j in 1:sa[2]
23+
for k in 1:sa[1]
24+
call = isone(j) ? :(a[$(LinearIndices(sa)[k, j])]*b[$j]) : :(muladd(a[$(LinearIndices(sa)[k, j])], b[$j], $(exprs[k])))
25+
push!(q.args, :($(exprs[k]) = $call))
26+
end
27+
end
28+
q, exprs
29+
end
30+
1931
@generated function _mul(::Size{sa}, a::StaticMatrix{<:Any, <:Any, Ta}, b::AbstractVector{Tb}) where {sa, Ta, Tb}
2032
if sa[2] != 0
21-
exprs = [reduce((ex1,ex2) -> :(+($ex1,$ex2)), [:(a[$(LinearIndices(sa)[k, j])]*b[$j]) for j = 1:sa[2]]) for k = 1:sa[1]]
33+
# [reduce((ex1,ex2) -> :(+($ex1,$ex2)), [:(a[$(LinearIndices(sa)[k, j])]*b[$j]) for j = 1:sa[2]]) for k = 1:sa[1]]
34+
q, exprs = matrix_vector_quote(sa)
2235
else
36+
q = nothing
2337
exprs = [:(zero(T)) for k = 1:sa[1]]
2438
end
2539

@@ -29,6 +43,7 @@ import LinearAlgebra: BlasFloat, matprod, mul!
2943
throw(DimensionMismatch("Tried to multiply arrays of size $sa and $(size(b))"))
3044
end
3145
T = promote_op(matprod,Ta,Tb)
46+
$q
3247
@inbounds return similar_type(b, T, Size(sa[1]))(tuple($(exprs...)))
3348
end
3449
end
@@ -39,14 +54,17 @@ end
3954
end
4055

4156
if sa[2] != 0
42-
exprs = [reduce((ex1,ex2) -> :(+($ex1,$ex2)), [:(a[$(LinearIndices(sa)[k, j])]*b[$j]) for j = 1:sa[2]]) for k = 1:sa[1]]
57+
q, exprs = matrix_vector_quote(sa)
58+
# exprs = [reduce((ex1,ex2) -> :(+($ex1,$ex2)), [:(a[$(LinearIndices(sa)[k, j])]*b[$j]) for j = 1:sa[2]]) for k = 1:sa[1]]
4359
else
60+
q = nothing
4461
exprs = [:(zero(T)) for k = 1:sa[1]]
4562
end
4663

4764
return quote
4865
@_inline_meta
4966
T = promote_op(matprod,Ta,Tb)
67+
$q
5068
@inbounds return similar_type(b, T, Size(sa[1]))(tuple($(exprs...)))
5169
end
5270
end
@@ -125,14 +143,30 @@ end
125143
S = Size(sa[1], sb[2])
126144

127145
if sa[2] != 0
128-
exprs = [reduce((ex1,ex2) -> :(+($ex1,$ex2)), [:(a[$(LinearIndices(sa)[k1, j])]*b[$(LinearIndices(sb)[j, k2])]) for j = 1:sa[2]]) for k1 = 1:sa[1], k2 = 1:sb[2]]
146+
# exprs = [reduce((ex1,ex2) -> :(+($ex1,$ex2)), [:(a[$(LinearIndices(sa)[k1, j])]*b[$(LinearIndices(sb)[j, k2])]) for j = 1:sa[2]]) for k1 = 1:sa[1], k2 = 1:sb[2]]
147+
exprs = [Symbol(:C_,k1,:_,k2) for k1 = 1:sa[1], k2 = 1:sb[2]]
148+
q = Expr(:block)
149+
for k2 in 1:sb[2]
150+
for k1 in 1:sa[1]
151+
push!(q.args, :($(exprs[k1,k2]) = a[$(LinearIndices(sa)[k1, 1])]*b[$(LinearIndices(sb)[1, k2])]))
152+
end
153+
end
154+
for j in 2:sb[1]
155+
for k2 in 1:sb[2]
156+
for k1 in 1:sa[1]
157+
push!(q.args, :($(exprs[k1,k2]) = muladd(a[$(LinearIndices(sa)[k1, j])], b[$(LinearIndices(sb)[j, k2])], $(exprs[k1,k2]))))
158+
end
159+
end
160+
end
129161
else
162+
q = nothing
130163
exprs = [:(zero(T)) for k1 = 1:sa[1], k2 = 1:sb[2]]
131164
end
132165

133166
return quote
134167
@_inline_meta
135168
T = promote_op(matprod,Ta,Tb)
169+
$q
136170
@inbounds return similar_type(a, T, $S)(tuple($(exprs...)))
137171
end
138172
end
@@ -145,18 +179,44 @@ end
145179

146180
S = Size(sa[1], sb[2])
147181

148-
tmps = [Symbol("tmp_$(k1)_$(k2)") for k1 = 1:sa[1], k2 = 1:sb[2]]
149-
exprs_init = [:($(tmps[k1,k2]) = a[$k1] * b[1 + $((k2-1) * sb[1])]) for k1 = 1:sa[1], k2 = 1:sb[2]]
150-
exprs_loop = [:($(tmps[k1,k2]) += a[$(k1-sa[1]) + $(sa[1])*j] * b[j + $((k2-1) * sb[1])]) for k1 = 1:sa[1], k2 = 1:sb[2]]
151-
182+
# optimal for AVX2 with `Float64`
183+
# AVX512 would want something more like 16x14 or 24x9 with `Float64`
184+
M_r, N_r = 8, 6
185+
n = 0
186+
M, K = sa
187+
N = sb[2]
188+
q = Expr(:block)
189+
atemps = [Symbol(:a_, k1) for k1 = 1:M]
190+
tmps = [Symbol("tmp_$(k1)_$(k2)") for k1 = 1:M, k2 = 1:N]
191+
while n < N
192+
nu = min(N, n + N_r)
193+
nrange = n+1:nu
194+
m = 0
195+
while m < M
196+
mu = min(M, m + M_r)
197+
mrange = m+1:mu
198+
199+
atemps_init = [:($(atemps[k1]) = a[$k1]) for k1 = mrange]
200+
exprs_init = [:($(tmps[k1,k2]) = $(atemps[k1]) * b[$(1 + (k2-1) * sb[1])]) for k1 = mrange, k2 = nrange]
201+
atemps_loop_init = [:($(atemps[k1]) = a[$(k1-sa[1]) + $(sa[1])*j]) for k1 = mrange]
202+
exprs_loop = [:($(tmps[k1,k2]) = muladd($(atemps[k1]), b[j + $((k2-1) * sb[1])], $(tmps[k1,k2]))) for k1 = mrange, k2 = nrange]
203+
qblock = quote
204+
@inbounds $(Expr(:block, atemps_init...))
205+
@inbounds $(Expr(:block, exprs_init...))
206+
for j = 2:$(sa[2])
207+
@inbounds $(Expr(:block, atemps_loop_init...))
208+
@inbounds $(Expr(:block, exprs_loop...))
209+
end
210+
end
211+
push!(q.args, qblock)
212+
m = mu
213+
end
214+
n = nu
215+
end
152216
return quote
153217
@_inline_meta
154218
T = promote_op(matprod,Ta,Tb)
155-
156-
@inbounds $(Expr(:block, exprs_init...))
157-
for j = 2:$(sa[2])
158-
@inbounds $(Expr(:block, exprs_loop...))
159-
end
219+
$q
160220
@inbounds return similar_type(a, T, $S)(tuple($(tmps...)))
161221
end
162222
end
@@ -170,22 +230,43 @@ end
170230

171231
S = Size(sa[1], sb[2])
172232

173-
# Do a custom b[:, k2] to return a SVector (an isbitstype type) rather than (possibly) a mutable type. Avoids allocation == faster
174-
tmp_type_in = :(SVector{$(sb[1]), T})
175-
tmp_type_out = :(SVector{$(sa[1]), T})
176-
vect_exprs = [:($(Symbol("tmp_$k2"))::$tmp_type_out = partly_unrolled_multiply(TSize(a), TSize($(sb[1])), a,
177-
$(Expr(:call, tmp_type_in, [Expr(:ref, :b, LinearIndices(sb)[i, k2]) for i = 1:sb[1]]...)))::$tmp_type_out)
178-
for k2 = 1:sb[2]]
233+
# optimal for AVX2 with `Float64`
234+
# AVX512 would want something more like 16x14 or 24x9 with `Float64`
235+
M_r, N_r = 8, 6
236+
n = 0
237+
M, K = sa
238+
N = sb[2]
239+
q = Expr(:block)
240+
atemps = [Symbol(:a_, k1) for k1 = 1:M]
241+
tmps = [Symbol("tmp_$(k1)_$(k2)") for k1 = 1:M, k2 = 1:N]
242+
while n < N
243+
nu = min(N, n + N_r)
244+
nrange = n+1:nu
245+
m = 0
246+
while m < M
247+
mu = min(M, m + M_r)
248+
mrange = m+1:mu
179249

180-
exprs = [:($(Symbol("tmp_$k2"))[$k1]) for k1 = 1:sa[1], k2 = 1:sb[2]]
250+
atemps_init = [:($(atemps[k1]) = a[$k1]) for k1 = mrange]
251+
exprs_init = [:($(tmps[k1,k2]) = $(atemps[k1]) * b[$(1 + (k2-1) * sb[1])]) for k1 = mrange, k2 = nrange]
252+
push!(q.args, :(@inbounds $(Expr(:block, atemps_init...))))
253+
push!(q.args, :(@inbounds $(Expr(:block, exprs_init...))))
181254

255+
for j in 2:K
256+
atemps_loop_init = [:($(atemps[k1]) = a[$(LinearIndices(sa)[k1,j])]) for k1 = mrange]
257+
exprs_loop = [:($(tmps[k1,k2]) = muladd($(atemps[k1]), b[$(LinearIndices(sb)[j,k2])], $(tmps[k1,k2]))) for k1 = mrange, k2 = nrange]
258+
push!(q.args, :(@inbounds $(Expr(:block, atemps_loop_init...))))
259+
push!(q.args, :(@inbounds $(Expr(:block, exprs_loop...))))
260+
end
261+
m = mu
262+
end
263+
n = nu
264+
end
182265
return quote
183266
@_inline_meta
184267
T = promote_op(matprod,Ta,Tb)
185-
$(Expr(:block,
186-
vect_exprs...,
187-
:(@inbounds return similar_type(a, T, $S)(tuple($(exprs...))))
188-
))
268+
$q
269+
@inbounds return similar_type(a, T, $S)(tuple($(tmps...)))
189270
end
190271
end
191272

0 commit comments

Comments
 (0)