Skip to content

Commit 9578938

Browse files
committed
test: don't divide by 0 if alpha = beta = 0
1 parent 9b03ad1 commit 9578938

File tree

1 file changed

+20 -12 lines changed

1 file changed

+20 -12 lines changed

test/check_gemm.hh

Lines changed: 20 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,7 @@ void check_gemm(
3636

3737
using std::sqrt;
3838
using std::abs;
39+
using blas::max;
3940
using real_t = blas::real_type<T>;
4041

4142
assert( m >= 0 );
@@ -51,22 +52,25 @@ void check_gemm(
5152
}
5253
}
5354

55+
real_t alpha_ = max( abs( alpha ), 1.0 );
56+
real_t beta_ = max( abs( beta ), 1.0 );
57+
5458
real_t work[1], Cout_norm;
5559
Cout_norm = lapack_lange( "f", m, n, C, ldc, work );
5660
error[0] = Cout_norm;
57-
real_t denom = sqrt( real_t( k ) + 2 ) * abs( alpha ) * Anorm * Bnorm
58-
+ 2 * abs( beta ) * Cnorm;
61+
real_t denom = sqrt( real_t( k ) + 2 ) * alpha_ * Anorm * Bnorm
62+
+ 2 * beta_ * Cnorm;
5963
if (denom != 0) {
6064
error[0] /= denom;
6165
}
6266

6367
if (verbose) {
6468
printf( "error: ||Cout||=%.2e, denom = (sqrt(k=%lld + 2)"
65-
" * |alpha|=%.2e * ||A||=%.2e * ||B||=%.2e"
66-
" + 2 * |beta|=%.2e * ||C||=%.2e) = %.2e, error = %.2e\n",
69+
" * max(|alpha|, 1)=%.2e * ||A||=%.2e * ||B||=%.2e"
70+
" + 2 * max(|beta|, 1)=%.2e * ||C||=%.2e) = %.2e, error = %.2e\n",
6771
Cout_norm, llong( k ),
68-
abs( alpha ), Anorm, Bnorm,
69-
abs( beta ), Cnorm, denom, error[0] );
72+
alpha_, Anorm, Bnorm,
73+
beta_, Cnorm, denom, error[0] );
7074
}
7175

7276
// complex needs extra factor; see Higham, 2002, sec. 3.6.
@@ -111,6 +115,7 @@ void check_herk(
111115

112116
using std::sqrt;
113117
using std::abs;
118+
using blas::max;
114119
typedef blas::real_type<T> real_t;
115120

116121
assert( n >= 0 );
@@ -134,26 +139,29 @@ void check_herk(
134139
}
135140
}
136141

142+
real_t alpha_ = max( abs( alpha ), 1.0 );
143+
real_t beta_ = max( abs( beta ), 1.0 );
144+
137145
// For a rank-2k update, this should be
138146
// sqrt(k+3) |alpha| (norm(A)*norm(B^T) + norm(B)*norm(A^T))
139147
// + 3 |beta| norm(C)
140148
// However, so far using the same bound as rank-k works fine.
141149
real_t work[1], Cout_norm;
142150
Cout_norm = lapack_lanhe( "f", to_c_string( uplo ), n, C, ldc, work );
143151
error[0] = Cout_norm;
144-
real_t denom = sqrt( real_t( k ) + 2 ) * abs( alpha ) * Anorm * Bnorm
145-
+ 2 * abs( beta ) * Cnorm;
152+
real_t denom = sqrt( real_t( k ) + 2 ) * alpha_ * Anorm * Bnorm
153+
+ 2 * beta_ * Cnorm;
146154
if (denom != 0) {
147155
error[0] /= denom;
148156
}
149157

150158
if (verbose) {
151159
printf( "error: ||Cout||=%.2e, denom = (sqrt(k=%lld + 2)"
152-
" * |alpha|=%.2e * ||A||=%.2e * ||B||=%.2e"
153-
" + 2 * |beta|=%.2e * ||C||=%.2e) = %.2e, error = %.2e\n",
160+
" * max(|alpha|,1)=%.2e * ||A||=%.2e * ||B||=%.2e"
161+
" + 2 * max(|beta|,1)=%.2e * ||C||=%.2e) = %.2e, error = %.2e\n",
154162
Cout_norm, llong( k ),
155-
abs( alpha ), Anorm, Bnorm,
156-
abs( beta ), Cnorm, denom, error[0] );
163+
alpha_, Anorm, Bnorm,
164+
beta_, Cnorm, denom, error[0] );
157165
}
158166

159167
// complex needs extra factor; see Higham, 2002, sec. 3.6.

0 commit comments

Comments (0)