Skip to content

Commit 094ea90

Browse files
committed
Merge branch 'muladdmul' into mbaran/matmul-symmetric
2 parents 59b4a9b + 86eab40 commit 094ea90

File tree

1 file changed

+83
-10
lines changed

1 file changed

+83
-10
lines changed

src/matrix_multiply.jl

Lines changed: 83 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -454,18 +454,44 @@ end
454454

455455
S = Size(sa[1], sb[2])
456456

457-
tmps = [Symbol("tmp_$(k1)_$(k2)") for k1 = 1:sa[1], k2 = 1:sb[2]]
458-
exprs_init = [:($(tmps[k1,k2]) = a[$k1] * b[1 + $((k2-1) * sb[1])]) for k1 = 1:sa[1], k2 = 1:sb[2]]
459-
exprs_loop = [:($(tmps[k1,k2]) += a[$(k1-sa[1]) + $(sa[1])*j] * b[j + $((k2-1) * sb[1])]) for k1 = 1:sa[1], k2 = 1:sb[2]]
460-
457+
# optimal for AVX2 with `Float64
458+
# AVX512 would want something more like 16x14 or 24x9 with `Float64`
459+
M_r, N_r = 8, 6
460+
n = 0
461+
M, K = sa
462+
N = sb[2]
463+
q = Expr(:block)
464+
atemps = [Symbol(:a_, k1) for k1 = 1:M]
465+
tmps = [Symbol("tmp_$(k1)_$(k2)") for k1 = 1:M, k2 = 1:N]
466+
while n < N
467+
nu = min(N, n + N_r)
468+
nrange = n+1:nu
469+
m = 0
470+
while m < M
471+
mu = min(M, m + M_r)
472+
mrange = m+1:mu
473+
474+
atemps_init = [:($(atemps[k1]) = a[$k1]) for k1 = mrange]
475+
exprs_init = [:($(tmps[k1,k2]) = $(atemps[k1]) * b[$(1 + (k2-1) * sb[1])]) for k1 = mrange, k2 = nrange]
476+
atemps_loop_init = [:($(atemps[k1]) = a[$(k1-sa[1]) + $(sa[1])*j]) for k1 = mrange]
477+
exprs_loop = [:($(tmps[k1,k2]) = muladd($(atemps[k1]), b[j + $((k2-1) * sb[1])], $(tmps[k1,k2]))) for k1 = mrange, k2 = nrange]
478+
qblock = quote
479+
@inbounds $(Expr(:block, atemps_init...))
480+
@inbounds $(Expr(:block, exprs_init...))
481+
for j = 2:$(sa[2])
482+
@inbounds $(Expr(:block, atemps_loop_init...))
483+
@inbounds $(Expr(:block, exprs_loop...))
484+
end
485+
end
486+
push!(q.args, qblock)
487+
m = mu
488+
end
489+
n = nu
490+
end
461491
return quote
462492
@_inline_meta
463493
T = promote_op(matprod,Ta,Tb)
464-
465-
@inbounds $(Expr(:block, exprs_init...))
466-
for j = 2:$(sa[2])
467-
@inbounds $(Expr(:block, exprs_loop...))
468-
end
494+
$q
469495
@inbounds return similar_type(a, T, $S)(tuple($(tmps...)))
470496
end
471497
end
@@ -512,7 +538,6 @@ end
512538
))
513539
end
514540
end
515-
516541
return quote
517542
@_inline_meta
518543
T = promote_op(matprod, Ta, Tb)
@@ -522,4 +547,52 @@ end
522547
end
523548
end
524549

550+
# a special version for plain matrices
551+
@generated function mul_unrolled_chunks(::Size{sa}, ::Size{sb}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticMatrix{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb}
552+
if sb[1] != sa[2]
553+
throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb"))
554+
end
555+
556+
S = Size(sa[1], sb[2])
557+
558+
# optimal for AVX2 with `Float64
559+
# AVX512 would want something more like 16x14 or 24x9 with `Float64`
560+
M_r, N_r = 8, 6
561+
n = 0
562+
M, K = sa
563+
N = sb[2]
564+
q = Expr(:block)
565+
atemps = [Symbol(:a_, k1) for k1 = 1:M]
566+
tmps = [Symbol("tmp_$(k1)_$(k2)") for k1 = 1:M, k2 = 1:N]
567+
while n < N
568+
nu = min(N, n + N_r)
569+
nrange = n+1:nu
570+
m = 0
571+
while m < M
572+
mu = min(M, m + M_r)
573+
mrange = m+1:mu
574+
575+
atemps_init = [:($(atemps[k1]) = a[$k1]) for k1 = mrange]
576+
exprs_init = [:($(tmps[k1,k2]) = $(atemps[k1]) * b[$(1 + (k2-1) * sb[1])]) for k1 = mrange, k2 = nrange]
577+
push!(q.args, :(@inbounds $(Expr(:block, atemps_init...))))
578+
push!(q.args, :(@inbounds $(Expr(:block, exprs_init...))))
579+
580+
for j in 2:K
581+
atemps_loop_init = [:($(atemps[k1]) = a[$(LinearIndices(sa)[k1,j])]) for k1 = mrange]
582+
exprs_loop = [:($(tmps[k1,k2]) = muladd($(atemps[k1]), b[$(LinearIndices(sb)[j,k2])], $(tmps[k1,k2]))) for k1 = mrange, k2 = nrange]
583+
push!(q.args, :(@inbounds $(Expr(:block, atemps_loop_init...))))
584+
push!(q.args, :(@inbounds $(Expr(:block, exprs_loop...))))
585+
end
586+
m = mu
587+
end
588+
n = nu
589+
end
590+
return quote
591+
@_inline_meta
592+
T = promote_op(matprod,Ta,Tb)
593+
$q
594+
@inbounds return similar_type(a, T, $S)(tuple($(tmps...)))
595+
end
596+
end
597+
525598
#

0 commit comments

Comments
 (0)