|
| 1 | +module BenchmarkMatMul |
| 2 | + |
| 3 | +using StaticArrays |
| 4 | +using BenchmarkTools |
| 5 | +using LinearAlgebra |
| 6 | +using Printf |
| 7 | + |
| 8 | +suite = BenchmarkGroup() |
| 9 | + |
| 10 | +mul_wrappers = [ |
| 11 | + (m -> m, "ident "), |
| 12 | + (m -> Symmetric(m, :U), "sym-u "), |
| 13 | + (m -> Hermitian(m, :U), "herm-u "), |
| 14 | + (m -> UpperTriangular(m), "up-tri "), |
| 15 | + (m -> LowerTriangular(m), "lo-tri "), |
| 16 | + (m -> UnitUpperTriangular(m), "uup-tri"), |
| 17 | + (m -> UnitLowerTriangular(m), "ulo-tri"), |
| 18 | + (m -> Adjoint(m), "adjoint"), |
| 19 | + (m -> Transpose(m), "transpo"), |
| 20 | + (m -> Diagonal(m), "diag ")] |
| 21 | + |
| 22 | +mul_wrappers_reduced = [ |
| 23 | + (m -> m, "ident "), |
| 24 | + (m -> Symmetric(m, :U), "sym-u "), |
| 25 | + (m -> UpperTriangular(m), "up-tri "), |
| 26 | + (m -> Transpose(m), "transpo"), |
| 27 | + (m -> Diagonal(m), "diag ")] |
| 28 | + |
| 29 | +for N in [2, 4, 8, 10, 16] |
| 30 | + |
| 31 | + matvecstr = @sprintf("mat-vec %2d", N) |
| 32 | + matmatstr = @sprintf("mat-mat %2d", N) |
| 33 | + matvec_mut_str = @sprintf("mat-vec! %2d", N) |
| 34 | + matmat_mut_str = @sprintf("mat-mat! %2d", N) |
| 35 | + |
| 36 | + suite[matvecstr] = BenchmarkGroup() |
| 37 | + suite[matmatstr] = BenchmarkGroup() |
| 38 | + suite[matvec_mut_str] = BenchmarkGroup() |
| 39 | + suite[matmat_mut_str] = BenchmarkGroup() |
| 40 | + |
| 41 | + |
| 42 | + A = randn(SMatrix{N,N,Float64}) |
| 43 | + B = randn(SMatrix{N,N,Float64}) |
| 44 | + bv = randn(SVector{N,Float64}) |
| 45 | + for (wrapper_a, wrapper_name) in mul_wrappers_reduced |
| 46 | + thrown = false |
| 47 | + try |
| 48 | + wrapper_a(A) * bv |
| 49 | + catch e |
| 50 | + thrown = true |
| 51 | + end |
| 52 | + if !thrown |
| 53 | + suite[matvecstr][wrapper_name] = @benchmarkable $(Ref(wrapper_a(A)))[] * $(Ref(bv))[] |
| 54 | + end |
| 55 | + end |
| 56 | + |
| 57 | + for (wrapper_a, wrapper_a_name) in mul_wrappers, (wrapper_b, wrapper_b_name) in mul_wrappers |
| 58 | + thrown = false |
| 59 | + try |
| 60 | + wrapper_a(A) * wrapper_b(B) |
| 61 | + catch e |
| 62 | + thrown = true |
| 63 | + end |
| 64 | + if !thrown |
| 65 | + suite[matmatstr][wrapper_a_name * " * " * wrapper_b_name] = @benchmarkable $(Ref(wrapper_a(A)))[] * $(Ref(wrapper_b(B)))[] |
| 66 | + end |
| 67 | + end |
| 68 | + |
| 69 | + C = randn(MMatrix{N,N,Float64}) |
| 70 | + cv = randn(MVector{N,Float64}) |
| 71 | + |
| 72 | + for (wrapper_a, wrapper_name) in mul_wrappers |
| 73 | + thrown = false |
| 74 | + try |
| 75 | + mul!(cv, wrapper_a(A), bv) |
| 76 | + catch e |
| 77 | + thrown = true |
| 78 | + end |
| 79 | + if !thrown |
| 80 | + suite[matvec_mut_str][wrapper_name] = @benchmarkable mul!($cv, $(Ref(wrapper_a(A)))[], $(Ref(bv))[]) |
| 81 | + end |
| 82 | + end |
| 83 | + |
| 84 | + for (wrapper_a, wrapper_a_name) in mul_wrappers, (wrapper_b, wrapper_b_name) in mul_wrappers |
| 85 | + thrown = false |
| 86 | + try |
| 87 | + mul!(C, wrapper_a(A), wrapper_b(B)) |
| 88 | + catch e |
| 89 | + thrown = true |
| 90 | + end |
| 91 | + if !thrown |
| 92 | + suite[matmat_mut_str][wrapper_a_name * " * " * wrapper_b_name] = @benchmarkable mul!($C, $(Ref(wrapper_a(A)))[], $(Ref(wrapper_b(B)))[]) |
| 93 | + end |
| 94 | + end |
| 95 | +end |
| 96 | + |
| 97 | +function run_and_save(fname, make_params = true) |
| 98 | + if make_params |
| 99 | + tune!(suite) |
| 100 | + BenchmarkTools.save("params.json", params(suite)) |
| 101 | + else |
| 102 | + loadparams!(suite, BenchmarkTools.load("params.json")[1], :evals, :samples) |
| 103 | + end |
| 104 | + results = run(suite, verbose = true) |
| 105 | + BenchmarkTools.save(fname, results) |
| 106 | +end |
| 107 | + |
| 108 | +function judge_results(m1, m2) |
| 109 | + results = Any[] |
| 110 | + for key1 in keys(m1) |
| 111 | + if !haskey(m2, key1) |
| 112 | + continue |
| 113 | + end |
| 114 | + for key2 in keys(m1[key1]) |
| 115 | + if !haskey(m2[key1], key2) |
| 116 | + continue |
| 117 | + end |
| 118 | + push!(results, (key1, key2, judge(median(m1[key1][key2]), median(m2[key1][key2])))) |
| 119 | + end |
| 120 | + end |
| 121 | + return results |
| 122 | +end |
| 123 | + |
| 124 | +function full_benchmark(mul_wrappers, size_iter = 1:4, T = Float64) |
| 125 | + suite_full = BenchmarkGroup() |
| 126 | + for N in size_iter |
| 127 | + for M in size_iter |
| 128 | + a = randn(SMatrix{N,M,T}) |
| 129 | + wrappers_a = N == M ? mul_wrappers : [mul_wrappers[1]] |
| 130 | + sa = Size(a) |
| 131 | + for K in size_iter |
| 132 | + b = randn(SMatrix{M,K,T}) |
| 133 | + wrappers_b = M == K ? mul_wrappers : [mul_wrappers[1]] |
| 134 | + sb = Size(b) |
| 135 | + for (w_a, w_a_name) in wrappers_a |
| 136 | + for (w_b, w_b_name) in wrappers_b |
| 137 | + cur_str = @sprintf("mat-mat %s %s generic (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K) |
| 138 | + suite_full[cur_str] = @benchmarkable StaticArrays.mul_generic($sa, $sb, $(Ref(w_a(a)))[], $(Ref(w_b(b)))[]) |
| 139 | + cur_str = @sprintf("mat-mat %s %s default (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K) |
| 140 | + suite_full[cur_str] = @benchmarkable StaticArrays._mul($sa, $sb, $(Ref(w_a(a)))[], $(Ref(w_b(b)))[]) |
| 141 | + cur_str = @sprintf("mat-mat %s %s unrolled (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K) |
| 142 | + suite_full[cur_str] = @benchmarkable StaticArrays.mul_unrolled($sa, $sb, $(Ref(w_a(a)))[], $(Ref(w_b(b)))[]) |
| 143 | + if w_a_name != "diag " && w_b_name != "diag " |
| 144 | + cur_str = @sprintf("mat-mat %s %s chunks (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K) |
| 145 | + suite_full[cur_str] = @benchmarkable StaticArrays.mul_unrolled_chunks($sa, $sb, $(Ref(w_a(a)))[], $(Ref(w_b(b)))[]) |
| 146 | + end |
| 147 | + if w_a_name == "ident " && w_b_name == "ident " |
| 148 | + cur_str = @sprintf("mat-mat %s %s loop (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K) |
| 149 | + suite_full[cur_str] = @benchmarkable StaticArrays.mul_loop($sa, $sb, $(Ref(w_a(a)))[], $(Ref(w_b(b)))[]) |
| 150 | + end |
| 151 | + end |
| 152 | + end |
| 153 | + end |
| 154 | + end |
| 155 | + end |
| 156 | + results = run(suite_full, verbose = true) |
| 157 | + results_median = map(collect(results)) do res |
| 158 | + return (res[1], median(res[2]).time) |
| 159 | + end |
| 160 | + return results_median |
| 161 | +end |
| 162 | + |
| 163 | +function judge_this(new_time, old_time, tol, w_a_name, w_b_name, N, M, K, which) |
| 164 | + if new_time*tol < old_time |
| 165 | + msg = @sprintf("better for %s %s (%2d, %2d) x (%2d, %2d): %s", w_a_name, w_b_name, N, M, M, K, which) |
| 166 | + println(msg) |
| 167 | + println(">> ", new_time, " | ", old_time) |
| 168 | + end |
| 169 | +end |
| 170 | + |
| 171 | +function pick_best(results, mul_wrappers, size_iter; tol = 1.2) |
| 172 | + for N in size_iter |
| 173 | + for M in size_iter |
| 174 | + wrappers_a = N == M ? mul_wrappers : [mul_wrappers[1]] |
| 175 | + for K in size_iter |
| 176 | + wrappers_b = M == K ? mul_wrappers : [mul_wrappers[1]] |
| 177 | + for (w_a, w_a_name) in wrappers_a |
| 178 | + for (w_b, w_b_name) in wrappers_b |
| 179 | + cur_default = @sprintf("mat-mat %s %s default (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K) |
| 180 | + default_time = results[cur_default] |
| 181 | + |
| 182 | + cur_generic = @sprintf("mat-mat %s %s generic (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K) |
| 183 | + generic_time = results[cur_generic] |
| 184 | + judge_this(generic_time, default_time, tol, w_a_name, w_b_name, N, M, K, "generic") |
| 185 | + |
| 186 | + cur_unrolled = @sprintf("mat-mat %s %s unrolled (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K) |
| 187 | + unrolled_time = results[cur_unrolled] |
| 188 | + judge_this(unrolled_time, default_time, tol, w_a_name, w_b_name, N, M, K, "unrolled") |
| 189 | + |
| 190 | + if w_a_name != "diag " && w_b_name != "diag " |
| 191 | + cur_chunks = @sprintf("mat-mat %s %s chunks (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K) |
| 192 | + chunk_time = results[cur_chunks] |
| 193 | + judge_this(chunk_time, default_time, tol, w_a_name, w_b_name, N, M, K, "chunks") |
| 194 | + end |
| 195 | + if w_a_name == "ident " && w_b_name == "ident " |
| 196 | + cur_loop = @sprintf("mat-mat %s %s loop (%2d, %2d) x (%2d, %2d)", w_a_name, w_b_name, N, M, M, K) |
| 197 | + loop_time = results[cur_loop] |
| 198 | + judge_this(loop_time, default_time, tol, w_a_name, w_b_name, N, M, K, "loop") |
| 199 | + end |
| 200 | + end |
| 201 | + end |
| 202 | + end |
| 203 | + end |
| 204 | + end |
| 205 | +end |
| 206 | + |
| 207 | +function run_1() |
| 208 | + return full_benchmark(mul_wrappers_reduced, [2, 3, 4, 5, 8, 9, 14, 16]) |
| 209 | +end |
| 210 | + |
| 211 | +end #module |
| 212 | +BenchmarkMatMul.suite |
0 commit comments