|
| 1 | +diff --git before/hptt/include/hptt_types.h after/hptt/include/hptt_types.h |
| 2 | +index 170288e..ebc5796 100644 |
| 3 | +--- before/hptt/include/hptt_types.h |
| 4 | ++++ after/hptt/include/hptt_types.h |
| 5 | +@@ -1,7 +1,6 @@ |
| 6 | + #pragma once |
| 7 | + |
| 8 | + #include <complex> |
| 9 | +-#include <complex.h> |
| 10 | + |
| 11 | + #define REGISTER_BITS 256 // AVX |
| 12 | + #ifdef HPTT_ARCH_ARM |
| 13 | +diff --git before/hptt/src/hptt.cpp after/hptt/src/hptt.cpp |
| 14 | +index 82d4e73..3018664 100644 |
| 15 | +--- before/hptt/src/hptt.cpp |
| 16 | ++++ after/hptt/src/hptt.cpp |
| 17 | +@@ -180,8 +180,10 @@ void cTensorTranspose( const int *perm, const int dim, |
| 18 | + const float _Complex beta, float _Complex *B, const int *outerSizeB, |
| 19 | + const int numThreads, const int useRowMajor) |
| 20 | + { |
| 21 | ++ const hptt::FloatComplex* calpha = reinterpret_cast<const hptt::FloatComplex*>(&alpha); |
| 22 | ++ const hptt::FloatComplex* cbeta = reinterpret_cast<const hptt::FloatComplex*>(&beta); |
| 23 | + auto plan(std::make_shared<hptt::Transpose<hptt::FloatComplex> >(sizeA, perm, outerSizeA, outerSizeB, dim, |
| 24 | +- (const hptt::FloatComplex*) A, (hptt::FloatComplex) alpha, (hptt::FloatComplex*) B, (hptt::FloatComplex) beta, hptt::ESTIMATE, numThreads, nullptr, useRowMajor)); |
| 25 | ++ (const hptt::FloatComplex*) A, *calpha, (hptt::FloatComplex*) B, *cbeta, hptt::ESTIMATE, numThreads, nullptr, useRowMajor)); |
| 26 | + plan->setConjA(conjA); |
| 27 | + plan->execute(); |
| 28 | + } |
| 29 | +@@ -191,8 +193,10 @@ void zTensorTranspose( const int *perm, const int dim, |
| 30 | + const double _Complex beta, double _Complex *B, const int *outerSizeB, |
| 31 | + const int numThreads, const int useRowMajor) |
| 32 | + { |
| 33 | ++ const hptt::DoubleComplex* calpha = reinterpret_cast<const hptt::DoubleComplex*>(&alpha); |
| 34 | ++ const hptt::DoubleComplex* cbeta = reinterpret_cast<const hptt::DoubleComplex*>(&beta); |
| 35 | + auto plan(std::make_shared<hptt::Transpose<hptt::DoubleComplex> >(sizeA, perm, outerSizeA, outerSizeB, dim, |
| 36 | +- (const hptt::DoubleComplex*) A, (hptt::DoubleComplex) alpha, (hptt::DoubleComplex*) B, (hptt::DoubleComplex) beta, hptt::ESTIMATE, numThreads, nullptr, useRowMajor)); |
| 37 | ++ (const hptt::DoubleComplex*) A, *calpha, (hptt::DoubleComplex*) B, *cbeta, hptt::ESTIMATE, numThreads, nullptr, useRowMajor)); |
| 38 | + plan->setConjA(conjA); |
| 39 | + plan->execute(); |
| 40 | + } |
0 commit comments