-
Notifications
You must be signed in to change notification settings - Fork 29
Closed
Description
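The fully templated GEMM code (below) reads the input value of C(i, j) even when beta == 0.0: every branch still evaluates `C(i, j) *= beta` or `beta*C(i, j)`. The reference BLAS contract says that when beta is zero, C need not be set on input. However, under IEEE 754 arithmetic NaN * 0 is still NaN. As a result, if C is uninitialized and contains a NaN, that NaN is propagated to the output even when beta == 0.0. When beta == 0, C should be overwritten rather than scaled.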
Lines 166 to 265 in e954a9b
| // alpha != zero | |
| if (transA == Op::NoTrans) { | |
| if (transB == Op::NoTrans) { | |
| for (int64_t j = 0; j < n; ++j) { | |
| for (int64_t i = 0; i < m; ++i) | |
| C(i, j) *= beta; | |
| for (int64_t l = 0; l < k; ++l) { | |
| scalar_t alpha_Blj = alpha*B(l, j); | |
| for (int64_t i = 0; i < m; ++i) | |
| C(i, j) += A(i, l)*alpha_Blj; | |
| } | |
| } | |
| } | |
| else if (transB == Op::Trans) { | |
| for (int64_t j = 0; j < n; ++j) { | |
| for (int64_t i = 0; i < m; ++i) | |
| C(i, j) *= beta; | |
| for (int64_t l = 0; l < k; ++l) { | |
| scalar_t alpha_Bjl = alpha*B(j, l); | |
| for (int64_t i = 0; i < m; ++i) | |
| C(i, j) += A(i, l)*alpha_Bjl; | |
| } | |
| } | |
| } | |
| else { // transB == Op::ConjTrans | |
| for (int64_t j = 0; j < n; ++j) { | |
| for (int64_t i = 0; i < m; ++i) | |
| C(i, j) *= beta; | |
| for (int64_t l = 0; l < k; ++l) { | |
| scalar_t alpha_Bjl = alpha*conj(B(j, l)); | |
| for (int64_t i = 0; i < m; ++i) | |
| C(i, j) += A(i, l)*alpha_Bjl; | |
| } | |
| } | |
| } | |
| } | |
| else if (transA == Op::Trans) { | |
| if (transB == Op::NoTrans) { | |
| for (int64_t j = 0; j < n; ++j) { | |
| for (int64_t i = 0; i < m; ++i) { | |
| scalar_t sum = zero; | |
| for (int64_t l = 0; l < k; ++l) | |
| sum += A(l, i)*B(l, j); | |
| C(i, j) = alpha*sum + beta*C(i, j); | |
| } | |
| } | |
| } | |
| else if (transB == Op::Trans) { | |
| for (int64_t j = 0; j < n; ++j) { | |
| for (int64_t i = 0; i < m; ++i) { | |
| scalar_t sum = zero; | |
| for (int64_t l = 0; l < k; ++l) | |
| sum += A(l, i)*B(j, l); | |
| C(i, j) = alpha*sum + beta*C(i, j); | |
| } | |
| } | |
| } | |
| else { // transB == Op::ConjTrans | |
| for (int64_t j = 0; j < n; ++j) { | |
| for (int64_t i = 0; i < m; ++i) { | |
| scalar_t sum = zero; | |
| for (int64_t l = 0; l < k; ++l) | |
| sum += A(l, i)*conj(B(j, l)); | |
| C(i, j) = alpha*sum + beta*C(i, j); | |
| } | |
| } | |
| } | |
| } | |
| else { // transA == Op::ConjTrans | |
| if (transB == Op::NoTrans) { | |
| for (int64_t j = 0; j < n; ++j) { | |
| for (int64_t i = 0; i < m; ++i) { | |
| scalar_t sum = zero; | |
| for (int64_t l = 0; l < k; ++l) | |
| sum += conj(A(l, i))*B(l, j); | |
| C(i, j) = alpha*sum + beta*C(i, j); | |
| } | |
| } | |
| } | |
| else if (transB == Op::Trans) { | |
| for (int64_t j = 0; j < n; ++j) { | |
| for (int64_t i = 0; i < m; ++i) { | |
| scalar_t sum = zero; | |
| for (int64_t l = 0; l < k; ++l) | |
| sum += conj(A(l, i))*B(j, l); | |
| C(i, j) = alpha*sum + beta*C(i, j); | |
| } | |
| } | |
| } | |
| else { // transB == Op::ConjTrans | |
| for (int64_t j = 0; j < n; ++j) { | |
| for (int64_t i = 0; i < m; ++i) { | |
| scalar_t sum = zero; | |
| for (int64_t l = 0; l < k; ++l) | |
| sum += A(l, i)*B(j, l); // little improvement here | |
| C(i, j) = alpha*conj(sum) + beta*C(i, j); | |
| } | |
| } | |
| } | |
| } |
@weslleyspereira, I recently suggested adding tests for this situation to your BLAS stress tests. Once those tests exist, it would be good to run them against BLAS++. Other fully templated functions in BLAS++ may make similar mistakes.
weslleyspereira
Metadata
Metadata
Assignees
Labels
No labels