Skip to content

Commit 1b94778

Browse files
Structured matrix multiplication (#814)
* structured matrix multiplication pt 1 * mul! doesn't always take parent; allocations on Julia 1.5 * updating triangular matrix multiplication to the new scheme * more tests for multiplication * adjoint and transpose wrappers for multiplication; more documentation * partical unification of in-placed and out-of-place matrix multiplication * more matrix types for multiplication * fixed code for symmetric and hermitian multiplication and small cleanup * some work on in-place structured multiplication * minor fixes * optimized multiplication by triangular matrices * blas mul! fix and matrix multiplication benchmarks * small matmul benchmark fix * slightly relaxing allocation tests * adding Diagonal to the new matrix multiplication scheme * slight adjustments to matrix multiplication * modified matrix multiplication heuristics * matmul muladd and improved order. * 14 -> 12 for loopmul decisions. * BLAS decision should be for larger than 14x14, if anything. * fixing matmul benchmark * by default use a reduced set of matmul benchmarks * muladd in combine_products * formatting fix * small cleanup in matmul benchmarks Co-authored-by: Chris Elrod <[email protected]>
1 parent 84e0f54 commit 1b94778

7 files changed

+967
-636
lines changed

benchmark/bench_mat_mul.jl

+212
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
module BenchmarkMatMul
2+
3+
using StaticArrays
4+
using BenchmarkTools
5+
using LinearAlgebra
6+
using Printf
7+
8+
suite = BenchmarkGroup()
9+
10+
mul_wrappers = [
11+
(m -> m, "ident "),
12+
(m -> Symmetric(m, :U), "sym-u "),
13+
(m -> Hermitian(m, :U), "herm-u "),
14+
(m -> UpperTriangular(m), "up-tri "),
15+
(m -> LowerTriangular(m), "lo-tri "),
16+
(m -> UnitUpperTriangular(m), "uup-tri"),
17+
(m -> UnitLowerTriangular(m), "ulo-tri"),
18+
(m -> Adjoint(m), "adjoint"),
19+
(m -> Transpose(m), "transpo"),
20+
(m -> Diagonal(m), "diag ")]
21+
22+
mul_wrappers_reduced = [
23+
(m -> m, "ident "),
24+
(m -> Symmetric(m, :U), "sym-u "),
25+
(m -> UpperTriangular(m), "up-tri "),
26+
(m -> Transpose(m), "transpo"),
27+
(m -> Diagonal(m), "diag ")]
28+
29+
for N in [2, 4, 8, 10, 16]
30+
31+
matvecstr = @sprintf("mat-vec %2d", N)
32+
matmatstr = @sprintf("mat-mat %2d", N)
33+
matvec_mut_str = @sprintf("mat-vec! %2d", N)
34+
matmat_mut_str = @sprintf("mat-mat! %2d", N)
35+
36+
suite[matvecstr] = BenchmarkGroup()
37+
suite[matmatstr] = BenchmarkGroup()
38+
suite[matvec_mut_str] = BenchmarkGroup()
39+
suite[matmat_mut_str] = BenchmarkGroup()
40+
41+
42+
A = randn(SMatrix{N,N,Float64})
43+
B = randn(SMatrix{N,N,Float64})
44+
bv = randn(SVector{N,Float64})
45+
for (wrapper_a, wrapper_name) in mul_wrappers_reduced
46+
thrown = false
47+
try
48+
wrapper_a(A) * bv
49+
catch e
50+
thrown = true
51+
end
52+
if !thrown
53+
suite[matvecstr][wrapper_name] = @benchmarkable $(Ref(wrapper_a(A)))[] * $(Ref(bv))[]
54+
end
55+
end
56+
57+
for (wrapper_a, wrapper_a_name) in mul_wrappers, (wrapper_b, wrapper_b_name) in mul_wrappers
58+
thrown = false
59+
try
60+
wrapper_a(A) * wrapper_b(B)
61+
catch e
62+
thrown = true
63+
end
64+
if !thrown
65+
suite[matmatstr][wrapper_a_name * " * " * wrapper_b_name] = @benchmarkable $(Ref(wrapper_a(A)))[] * $(Ref(wrapper_b(B)))[]
66+
end
67+
end
68+
69+
C = randn(MMatrix{N,N,Float64})
70+
cv = randn(MVector{N,Float64})
71+
72+
for (wrapper_a, wrapper_name) in mul_wrappers
73+
thrown = false
74+
try
75+
mul!(cv, wrapper_a(A), bv)
76+
catch e
77+
thrown = true
78+
end
79+
if !thrown
80+
suite[matvec_mut_str][wrapper_name] = @benchmarkable mul!($cv, $(Ref(wrapper_a(A)))[], $(Ref(bv))[])
81+
end
82+
end
83+
84+
for (wrapper_a, wrapper_a_name) in mul_wrappers, (wrapper_b, wrapper_b_name) in mul_wrappers
85+
thrown = false
86+
try
87+
mul!(C, wrapper_a(A), wrapper_b(B))
88+
catch e
89+
thrown = true
90+
end
91+
if !thrown
92+
suite[matmat_mut_str][wrapper_a_name * " * " * wrapper_b_name] = @benchmarkable mul!($C, $(Ref(wrapper_a(A)))[], $(Ref(wrapper_b(B)))[])
93+
end
94+
end
95+
end
96+
97+
function run_and_save(fname, make_params = true)
98+
if make_params
99+
tune!(suite)
100+
BenchmarkTools.save("params.json", params(suite))
101+
else
102+
loadparams!(suite, BenchmarkTools.load("params.json")[1], :evals, :samples)
103+
end
104+
results = run(suite, verbose = true)
105+
BenchmarkTools.save(fname, results)
106+
end
107+
108+
function judge_results(m1, m2)
109+
results = Any[]
110+
for key1 in keys(m1)
111+
if !haskey(m2, key1)
112+
continue
113+
end
114+
for key2 in keys(m1[key1])
115+
if !haskey(m2[key1], key2)
116+
continue
117+
end
118+
push!(results, (key1, key2, judge(median(m1[key1][key2]), median(m2[key1][key2]))))
119+
end
120+
end
121+
return results
122+
end
123+
124+
function full_benchmark(mul_wrappers, size_iter = 1:4, T = Float64)
125+
suite_full = BenchmarkGroup()
126+
for N in size_iter
127+
for M in size_iter
128+
a = randn(SMatrix{N,M,T})
129+
wrappers_a = N == M ? mul_wrappers : [mul_wrappers[1]]
130+
sa = Size(a)
131+
for K in size_iter
132+
b = randn(SMatrix{M,K,T})
133+
wrappers_b = M == K ? mul_wrappers : [mul_wrappers[1]]
134+
sb = Size(b)
135+
for (w_a, w_a_name) in wrappers_a
136+
for (w_b, w_b_name) in wrappers_b
137+
cur_str = @sprintf("mat-mat %s %s generic (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K)
138+
suite_full[cur_str] = @benchmarkable StaticArrays.mul_generic($sa, $sb, $(Ref(w_a(a)))[], $(Ref(w_b(b)))[])
139+
cur_str = @sprintf("mat-mat %s %s default (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K)
140+
suite_full[cur_str] = @benchmarkable StaticArrays._mul($sa, $sb, $(Ref(w_a(a)))[], $(Ref(w_b(b)))[])
141+
cur_str = @sprintf("mat-mat %s %s unrolled (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K)
142+
suite_full[cur_str] = @benchmarkable StaticArrays.mul_unrolled($sa, $sb, $(Ref(w_a(a)))[], $(Ref(w_b(b)))[])
143+
if w_a_name != "diag " && w_b_name != "diag "
144+
cur_str = @sprintf("mat-mat %s %s chunks (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K)
145+
suite_full[cur_str] = @benchmarkable StaticArrays.mul_unrolled_chunks($sa, $sb, $(Ref(w_a(a)))[], $(Ref(w_b(b)))[])
146+
end
147+
if w_a_name == "ident " && w_b_name == "ident "
148+
cur_str = @sprintf("mat-mat %s %s loop (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K)
149+
suite_full[cur_str] = @benchmarkable StaticArrays.mul_loop($sa, $sb, $(Ref(w_a(a)))[], $(Ref(w_b(b)))[])
150+
end
151+
end
152+
end
153+
end
154+
end
155+
end
156+
results = run(suite_full, verbose = true)
157+
results_median = map(collect(results)) do res
158+
return (res[1], median(res[2]).time)
159+
end
160+
return results_median
161+
end
162+
163+
function judge_this(new_time, old_time, tol, w_a_name, w_b_name, N, M, K, which)
164+
if new_time*tol < old_time
165+
msg = @sprintf("better for %s %s (%2d, %2d) x (%2d, %2d): %s", w_a_name, w_b_name, N, M, M, K, which)
166+
println(msg)
167+
println(">> ", new_time, " | ", old_time)
168+
end
169+
end
170+
171+
function pick_best(results, mul_wrappers, size_iter; tol = 1.2)
172+
for N in size_iter
173+
for M in size_iter
174+
wrappers_a = N == M ? mul_wrappers : [mul_wrappers[1]]
175+
for K in size_iter
176+
wrappers_b = M == K ? mul_wrappers : [mul_wrappers[1]]
177+
for (w_a, w_a_name) in wrappers_a
178+
for (w_b, w_b_name) in wrappers_b
179+
cur_default = @sprintf("mat-mat %s %s default (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K)
180+
default_time = results[cur_default]
181+
182+
cur_generic = @sprintf("mat-mat %s %s generic (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K)
183+
generic_time = results[cur_generic]
184+
judge_this(generic_time, default_time, tol, w_a_name, w_b_name, N, M, K, "generic")
185+
186+
cur_unrolled = @sprintf("mat-mat %s %s unrolled (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K)
187+
unrolled_time = results[cur_unrolled]
188+
judge_this(unrolled_time, default_time, tol, w_a_name, w_b_name, N, M, K, "unrolled")
189+
190+
if w_a_name != "diag " && w_b_name != "diag "
191+
cur_chunks = @sprintf("mat-mat %s %s chunks (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K)
192+
chunk_time = results[cur_chunks]
193+
judge_this(chunk_time, default_time, tol, w_a_name, w_b_name, N, M, K, "chunks")
194+
end
195+
if w_a_name == "ident " && w_b_name == "ident "
196+
cur_loop = @sprintf("mat-mat %s %s loop (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K)
197+
loop_time = results[cur_loop]
198+
judge_this(loop_time, default_time, tol, w_a_name, w_b_name, N, M, K, "loop")
199+
end
200+
end
201+
end
202+
end
203+
end
204+
end
205+
end
206+
207+
function run_1()
208+
return full_benchmark(mul_wrappers_reduced, [2, 3, 4, 5, 8, 9, 14, 16])
209+
end
210+
211+
end #module
212+
BenchmarkMatMul.suite

src/SDiagonal.jl

-2
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@ size(::Type{<:SDiagonal{N}}) where {N} = (N,N)
1818
size(::Type{<:SDiagonal{N}}, d::Int) where {N} = d > 2 ? 1 : N
1919

2020
# define specific methods to avoid allocating mutable arrays
21-
*(A::StaticMatrix, D::SDiagonal) = A .* transpose(D.diag)
22-
*(D::SDiagonal, A::StaticMatrix) = D.diag .* A
2321
\(D::SDiagonal, b::AbstractVector) = D.diag .\ b
2422
\(D::SDiagonal, b::StaticVector) = D.diag .\ b # catch ambiguity
2523

0 commit comments

Comments
 (0)