diff --git a/sim/common/tensor_cfg.h b/sim/common/tensor_cfg.h index c7f1fcdb7a..9622ba922a 100644 --- a/sim/common/tensor_cfg.h +++ b/sim/common/tensor_cfg.h @@ -53,7 +53,7 @@ struct int32 { struct int8 { using dtype = int8_t; static constexpr uint32_t id = 9; - static constexpr uint32_t bits = 16; + static constexpr uint32_t bits = 8; static constexpr const char* name = "i8"; }; diff --git a/tests/regression/sgemm_tcu/common.h b/tests/regression/sgemm_tcu/common.h deleted file mode 100644 index a762a4fb2e..0000000000 --- a/tests/regression/sgemm_tcu/common.h +++ /dev/null @@ -1,27 +0,0 @@ -#ifndef _COMMON_H_ -#define _COMMON_H_ - -#include - -#ifndef NUM_THREADS -#define NUM_THREADS 4 -#endif - -#ifndef ITYPE -#define ITYPE fp16 -#endif - -#ifndef OTYPE -#define OTYPE fp32 -#endif - -typedef struct { - uint32_t grid_dim[2]; - uint32_t block_dim[2]; - uint32_t M, N, K; - uint64_t A_addr; - uint64_t B_addr; - uint64_t C_addr; -} kernel_arg_t; - -#endif diff --git a/tests/regression/sgemm_tcu/kernel.cpp b/tests/regression/sgemm_tcu/kernel.cpp deleted file mode 100644 index 1d722cbac7..0000000000 --- a/tests/regression/sgemm_tcu/kernel.cpp +++ /dev/null @@ -1,56 +0,0 @@ -#include "common.h" -#include -#include - -namespace vt = vortex::tensor; -using ctx = vt::wmma_context; - -void kernel_body(kernel_arg_t *__UNIFORM__ arg) { - auto pA = reinterpret_cast(arg->A_addr); - auto pB = reinterpret_cast(arg->B_addr); - auto pC = reinterpret_cast(arg->C_addr); - - uint32_t M = arg->M; - uint32_t N = arg->N; - uint32_t K = arg->K; - - ctx::fragment_a fragA; - ctx::fragment_b fragB; - ctx::fragment_acc fragC; - - // calculate tile row & column based on block index - uint32_t tile_row = blockIdx.y * ctx::tileM; - uint32_t tile_col = blockIdx.x * ctx::tileN; - - // Initialize accumulator tile to zero - ctx::fill_fragment(fragC, 0); - - for (int i = 0; i < K; i += ctx::tileK) { - auto pTileA = pA + tile_row * K + i; - - // Load A tile - ctx::load_matrix_sync(fragA, pTileA, K); - - // Load B tile - if constexpr (vt::ITYPE::bits < 8) { - // For sub-byte matrix B must be in col-major format - auto pTileB = pB + tile_col * K + i; - ctx::load_matrix_sync(fragB, pTileB, K); - } else { - auto pTileB = pB + i * N + tile_col; - ctx::load_matrix_sync(fragB, pTileB, N); - } - - // Matrix multiply-accumulate: c += a * b - ctx::mma_sync(fragC, fragA, fragB, fragC); - } - - // Store the computed C tile - auto pTileC = pC + tile_row * N + tile_col; - ctx::store_matrix_sync(pTileC, fragC, N); -} - -int main() { - auto arg = (kernel_arg_t *)csr_read(VX_CSR_MSCRATCH); - return vx_spawn_threads(2, arg->grid_dim, arg->block_dim, (vx_kernel_func_cb)kernel_body, arg); -} diff --git a/tests/regression/sgemm_tcu/main.cpp b/tests/regression/sgemm_tcu/main.cpp deleted file mode 100644 index ac1c77fc17..0000000000 --- a/tests/regression/sgemm_tcu/main.cpp +++ /dev/null @@ -1,710 +0,0 @@ -#include "common.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#define FLOAT_ULP 6 -#define MAX_ERRORS 100 - -#define RT_CHECK(_expr) \ - do { \ - int _ret = _expr; \ - if (0 == _ret) \ - break; \ - printf("Error: '%s' returned %d!\n", #_expr, (int)_ret); \ - cleanup(); \ - exit(-1); \ - } while (false) - -using namespace vortex; -namespace vt = tensor; - -static bool g_enable_sparse = false; -/////////////////////////////////////////////////////////////////////////////// - -static void convert_row_to_col_major_4bit(uint8_t *dst, uint32_t width, uint32_t height, const uint8_t *src) { - // Calculate output size and stride - uint32_t out_bytes = (width * height + 1) / 2; - memset(dst, 0, out_bytes); - uint32_t dst_stride = (height + 1) / 2; // Bytes per column in output - - // For each column in source (which becomes row in destination) - for (uint32_t c = 0; c < width; ++c) { - uint32_t base = c * dst_stride; - - // For each row in source (which becomes column in destination) - for (uint32_t r = 0; r < height; r += 2) { - // Calculate source indices (row-major) - uint32_t idx_even = r * width + c; - uint32_t idx_odd = (r + 1) * width + c; - - // Extract nibbles - consistent with data_accessor_t - uint8_t b_even = src[idx_even / 2]; - uint8_t b_odd = (r + 1 < height) ? src[idx_odd / 2] : 0; - - uint8_t nib_even = (idx_even & 1) ? (b_even >> 4) : (b_even & 0x0F); - uint8_t nib_odd = (r + 1 < height) - ? ((idx_odd & 1) ? (b_odd >> 4) : (b_odd & 0x0F)) - : 0; - - // Pack into destination: even row in low nibble, odd row in high nibble - dst[base + r / 2] = (nib_odd << 4) | nib_even; - } - } -} - -/////////////////////////////////////////////////////////////////////////////// - -template -struct data_accessor_t { - using Type = typename T::dtype; - static Type read(const Type *ptr, uint32_t offset) { - return ptr[offset]; - } - static void write(Type *ptr, uint32_t offset, Type value) { - ptr[offset] = value; - } -}; - -template <> -struct data_accessor_t { - static uint8_t read(const uint8_t *ptr, uint32_t offset) { - uint32_t row_off = offset / 2; - bool odd = offset & 0x1; - uint8_t value8 = ptr[row_off]; - return odd ? (value8 >> 4) : (value8 & 0x0f); // to nibble - } - static void write(uint8_t *ptr, uint32_t offset, int32_t value) { - uint32_t row_off = offset / 2; - bool odd = offset & 0x1; - uint8_t old_value = ptr[row_off]; - uint8_t new_value = odd ? ((old_value & 0x0f) | (value << 4)) - : ((old_value & 0xf0) | (value & 0x0f)); - ptr[offset / 2] = new_value; - } -}; - -template <> -struct data_accessor_t { - static uint8_t read(const uint8_t *ptr, uint32_t offset) { - uint32_t row_off = offset / 2; - bool odd = offset & 0x1; - uint8_t value8 = ptr[row_off]; - return odd ? (value8 >> 4) : (value8 & 0x0f); // to nibble - } - static void write(uint8_t *ptr, uint32_t offset, int32_t value) { - uint32_t row_off = offset / 2; - bool odd = offset & 0x1; - uint8_t old_value = ptr[row_off]; - uint8_t new_value = odd ? ((old_value & 0x0f) | (value << 4)) - : ((old_value & 0xf0) | (value & 0x0f)); - ptr[offset / 2] = new_value; - } -}; - -/////////////////////////////////////////////////////////////////////////////// - -template -class Comparator {}; - -template <> -class Comparator { -public: - static int8_t generate() { - return (int8_t)rand(); - } - static bool compare(int8_t a, int8_t b, int index, int errors) { - if (a != b) { - if (errors < MAX_ERRORS) { - printf("*** error: [%d] expected=0x%x, actual=0x%x\n", index, b, a); - } - return false; - } - return true; - } -}; - -template <> -class Comparator { -public: - static uint8_t generate() { - return (uint8_t)rand(); - } - static bool compare(uint8_t a, uint8_t b, int index, int errors) { - if (a != b) { - if (errors < MAX_ERRORS) { - printf("*** error: [%d] expected=0x%x, actual=0x%x\n", index, b, a); - } - return false; - } - return true; - } -}; - -template <> -class Comparator { -public: - static uint8_t generate() { - return (uint8_t)rand(); // store 2 nibbles in a byte - } - static bool compare(uint8_t a, uint8_t b, int index, int errors) { - if (a != b) { - if (errors < MAX_ERRORS) { - printf("*** error: [%d] expected=0x%x, actual=0x%x\n", index, b, a); - } - return false; - } - return true; - } -}; - -template <> -class Comparator { -public: - static uint8_t generate() { - return (uint8_t)rand(); // store 2 nibbles in a byte - } - static bool compare(uint8_t a, uint8_t b, int index, int errors) { - if (a != b) { - if (errors < MAX_ERRORS) { - printf("*** error: [%d] expected=0x%x, actual=0x%x\n", index, b, a); - } - return false; - } - return true; - } -}; - -template <> -class Comparator { -public: - static int32_t generate() { - return (int32_t)rand(); - } - static bool compare(int32_t a, int32_t b, int index, int errors) { - if (a != b) { - if (errors < MAX_ERRORS) { - printf("*** error: [%d] expected=0x%x, actual=0x%x\n", index, b, a); - } - return false; - } - return true; - } -}; - -template <> -class Comparator { -public: - static uint16_t generate() { - auto fvalue = float(rand()) / RAND_MAX; - return rv_ftoh_s(bit_cast(fvalue), 0, nullptr); - } - static bool compare(uint16_t a, uint16_t b, int index, int errors) { - if (a != b) { - if (errors < MAX_ERRORS) { - printf("*** error: [%d] expected=0x%x, actual=0x%x\n", index, b, a); - } - return false; - } - return true; - } -}; - -template <> -class Comparator { -public: - static uint16_t generate() { - auto fvalue = float(rand()) / RAND_MAX; - return rv_ftob_s(bit_cast(fvalue), 0, nullptr); - } - static bool compare(uint16_t a, uint16_t b, int index, int errors) { - if (a != b) { - if (errors < MAX_ERRORS) { - printf("*** error: [%d] expected=0x%x, actual=0x%x\n", index, b, a); - } - return false; - } - return true; - } -}; - -template <> -class Comparator { -public: - static float generate() { - return static_cast(rand()) / RAND_MAX; - } - static bool compare(float a, float b, int index, int errors) { - union fi_t { - float f; - int32_t i; - }; - fi_t fa, fb; - fa.f = a; - fb.f = b; - auto d = std::abs(fa.i - fb.i); - if (d > FLOAT_ULP) { - if (errors < MAX_ERRORS) { - printf("*** error: [%d] expected=%f, actual=%f\n", index, fb.f, fa.f); - } - return false; - } - return true; - } -}; - -/////////////////////////////////////////////////////////////////////////////// - -template -struct muladd_t { - using stype = typename S::dtype; - using dtype = typename D::dtype; - static dtype eval(stype a, stype b, dtype c) { - return static_cast(a) * static_cast(b) + c; - } -}; - -template <> -struct muladd_t { - static float eval(uint16_t a, uint16_t b, float c) { - auto fa = bit_cast(rv_htof_s(a, 0, nullptr)); - auto fb = bit_cast(rv_htof_s(b, 0, nullptr)); - return fa * fb + c; - } -}; - -template <> -struct muladd_t { - static uint16_t eval(uint16_t a, uint16_t b, uint16_t c) { - auto fa = bit_cast(rv_htof_s(a, 0, nullptr)); - auto fb = bit_cast(rv_htof_s(b, 0, nullptr)); - auto fc = bit_cast(rv_htof_s(c, 0, nullptr)); - auto fd = fa * fb + fc; - return rv_ftoh_s(bit_cast(fd), 0, nullptr); - } -}; - -template <> -struct muladd_t { - static float eval(uint16_t a, uint16_t b, float c) { - auto fa = bit_cast(rv_btof_s(a, 0, nullptr)); - auto fb = bit_cast(rv_btof_s(b, 0, nullptr)); - return fa * fb + c; - } -}; - -template <> -struct muladd_t { - static uint16_t eval(uint16_t a, uint16_t b, uint16_t c) { - auto fa = bit_cast(rv_btof_s(a, 0, nullptr)); - auto fb = bit_cast(rv_btof_s(b, 0, nullptr)); - auto fc = bit_cast(rv_btof_s(c, 0, nullptr)); - auto fd = fa * fb + fc; - return rv_ftob_s(bit_cast(fd), 0, nullptr); - } -}; - -template <> -struct muladd_t { - static int32_t eval(uint8_t a, uint8_t b, int32_t c) { - int32_t a_val = a & 0xF; - if (a & 0x8) { - a_val |= 0xFFFFFFF0; // sign extend - } - int32_t b_val = b & 0xF; - if (b & 0x8) { - b_val |= 0xFFFFFFF0; // sign extend - } - return a_val * b_val + c; - } -}; - -template <> -struct muladd_t { - static int32_t eval(uint8_t a, uint8_t b, int32_t c) { - int32_t a_val = a & 0xF; - int32_t b_val = b & 0xF; - return a_val * b_val + c; - } -}; - -/////////////////////////////////////////////////////////////////////////////// - -using cfg = vt::wmma_config_t; - -using itype_t = typename vt::ITYPE::dtype; -using otype_t = typename vt::OTYPE::dtype; - -struct SparseMat { - std::vector values; // non-zeros - std::vector meta; // Array of row-masks: 1 byte marks the columns - // of the 4 elements in the block that are non-zero. - // e.g. 0b0101 means 2nd and 4th elements are non-zero. - - uint32_t rows, cols; // original A dims (M × K) -}; - -static void matmul_cpu(otype_t *C, const itype_t *A, const itype_t *B, uint32_t M, uint32_t N, uint32_t K) { - uint32_t subbytes = 8 / vt::ITYPE::bits; - uint32_t KS = subbytes ? (K * subbytes) : K; - for (uint32_t m = 0; m < M; ++m) { - for (uint32_t n = 0; n < N; ++n) { - otype_t sum(0); - for (uint32_t k = 0; k < KS; ++k) { - auto a = data_accessor_t::read(A, m * KS + k); - auto b = data_accessor_t::read(B, k * N + n); - sum = muladd_t::eval(a, b, sum); - } - data_accessor_t::write(C, m * N + n, sum); - } - } -} - -/* -static void matmul_cpu_sparseA( - otype_t* C, // [M × N] output - const SparseMat& A, // sparse-A - const itype_t* B, // [K × N] dense-B - uint32_t N) // number of columns of B/C -{ - const uint32_t M = A.rows; - const uint32_t K = A.cols; - - const uint32_t subbytes = 8 / vt::ITYPE::bits; - - // --- helper lambdas to index sparse arrays by row --- - auto row_values = [&](uint32_t m) { - return A.values.data() + m * (K / 2); // two kept per block - }; - auto row_meta = [&](uint32_t m) { - return A.meta .data() + m * (K / 4); - }; - - for (uint32_t m = 0; m < M; ++m) { - - const itype_t* Avals = row_values(m); - const uint8_t* Ameta = row_meta (m); - size_t v_idx = 0; // cursor inside values[] - - for (uint32_t n = 0; n < N; ++n) { - otype_t sum(0); - for (uint32_t blk = 0; blk < K; blk += 4) { - uint8_t mask = *(Ameta++); - assert(mask); - for (uint32_t i = 0; i < 4; ++i) { - if (mask & (1u << i)) { - auto a_val = Avals[v_idx++]; - uint32_t k = blk + i; // logical K index - uint32_t kk = subbytes ? k * subbytes // packed-layout idx - : k; - auto b_val = data_accessor_t::read( - B, kk * N + n); - sum = muladd_t::eval(a_val, b_val, sum); - } - } - } - data_accessor_t::write(C, m * N + n, sum); - } - } -}*/ - -/////////////////////////////////////////////////////////////////////////////// - -const char *kernel_file = "kernel.vxbin"; - -uint32_t xm = 4; -uint32_t xn = 8; -uint32_t xk = 2; - -vx_device_h device = nullptr; -vx_buffer_h A_buffer = nullptr; -vx_buffer_h B_buffer = nullptr; -vx_buffer_h C_buffer = nullptr; -vx_buffer_h krnl_buffer = nullptr; -vx_buffer_h args_buffer = nullptr; -kernel_arg_t kernel_arg = {}; - -std::string last_build_options; - -static void show_usage() { - std::cout << "Vortex Sgemm TCU Test." << std::endl; - std::cout << "Usage: [-m: m] [-n N] [-k: K] [-s] [-h: help]" << std::endl; - std::cout << " -s Enable 2:4 structured sparsity " << std::endl; -} - -static void parse_args(int argc, char **argv) { - int c; - while ((c = getopt(argc, argv, "m:n:k:i:o:hs")) != -1) { - switch (c) { - case 'm': - xm = atoi(optarg); - break; - case 'n': - xn = atoi(optarg); - break; - case 'k': - xk = atoi(optarg); - break; - case 's': - g_enable_sparse = true; - std::cout << "Sparse mode enabled (-s)" << std::endl; - break; - case 'h': - show_usage(); - exit(0); - break; - default: - show_usage(); - exit(-1); - } - } -} - -void cleanup() { - if (device) { - vx_mem_free(A_buffer); - vx_mem_free(B_buffer); - vx_mem_free(C_buffer); - vx_mem_free(krnl_buffer); - vx_mem_free(args_buffer); - vx_dev_close(device); - } -} - - -static SparseMat pruneAndCompressMatrixA(const std::vector& denseA, - uint32_t M, uint32_t K) { - SparseMat out; - out.rows = M; - out.cols = K; - out.values.reserve(M * K / 2); // Select 2 values every 4 values - out.meta.reserve(M * K / 4); // 1 byte for every 4 values - - const itype_t* src = denseA.data(); - - for (uint32_t r = 0; r < M; ++r) { - for (uint32_t c = 0; c < K; c += 4) { - itype_t blk[4] = {src[r * K + c], - src[r * K + c + 1], - src[r * K + c + 2], - src[r * K + c + 3]}; - - uint32_t idx[4] = {0, 1, 2, 3}; - std::sort(idx, idx + 4, - [&](uint32_t a, uint32_t b) { - return std::abs((int)blk[a]) < std::abs((int)blk[b]); - }); //Sort the 4 elements by absolute value, ascending order - - uint8_t keep0 = idx[3]; - uint8_t keep1 = idx[2]; //idx of largest 2 elements - - out.values.push_back(blk[keep0]); - out.values.push_back(blk[keep1]); - - uint8_t m = (1u << keep0) | (1u << keep1); // e.g. 0b0101 - out.meta.push_back(m); - } - } - return out; -} - -void test_pruneA() { - const uint32_t M = 4, K = 8; - std::vector denseA(M * K); - for (auto& v : denseA) v = Comparator::generate(); - - auto spA = pruneAndCompressMatrixA(denseA, M, K); - - std::vector recovered(M * K, 0); - size_t v_idx = 0, m_idx = 0; - for (uint32_t r = 0; r < M; ++r) - for (uint32_t c = 0; c < K; c += 4) { - uint8_t m = spA.meta[m_idx++]; - for (uint32_t i = 0; i < 4; ++i) - if (m & (1u << i)) - recovered[r * K + c + i] = spA.values[v_idx++]; - } - - for (uint32_t i = 0; i < M * K; ++i) - assert(recovered[i] == denseA[i] || recovered[i] == 0); //Either the value is preserved or pruned - std::cout << "pruneAndCompressMatrixA passed\n"; -} - - -int main(int argc, char *argv[]) { - // parse command arguments - parse_args(argc, argv); - - if(g_enable_sparse) { - test_pruneA(); // Test the pruning function - } - - std::srand(50); - - // open device connection - std::cout << "open device connection" << std::endl; - RT_CHECK(vx_dev_open(&device)); - - uint64_t isa_flags; - RT_CHECK(vx_dev_caps(device, VX_CAPS_ISA_FLAGS, &isa_flags)); - bool has_ext = (isa_flags & VX_ISA_EXT_TCU) != 0; - if (!has_ext) { - std::cout << "TCU extension not supported!" << std::endl; - cleanup(); - return -1; - } - - uint64_t NT; - RT_CHECK(vx_dev_caps(device, VX_CAPS_NUM_THREADS, &NT)); - if (NT != NUM_THREADS) { - std::cout << "Error: device warp size (" << NT << ") must match NUM_THREADS=" << NUM_THREADS << "!" << std::endl; - return -1; - } - - uint32_t M = xm * cfg::tileM; - uint32_t N = xn * cfg::tileN; - uint32_t K = xk = cfg::tileK; - - if ((M % cfg::tileM) != 0) { - std::cout << "Error: M must be a multiple of tensor tileM!" << std::endl; - return -1; - } - - if ((N % cfg::tileN) != 0) { - std::cout << "Error: M must be a multiple of tensor tileN!" << std::endl; - return -1; - } - - if ((K % cfg::tileK) != 0) { - std::cout << "Error: M must be a multiple of tensor tileK!" << std::endl; - return -1; - } - - size_t sizeA = M * K; - size_t sizeB = K * N; - size_t sizeC = M * N; - - std::cout << "input data type: " << vt::ITYPE::name << " (id=" << vt::ITYPE::id << ")" << std::endl; - std::cout << "output data type: " << vt::OTYPE::name << " (id=" << vt::OTYPE::id << ")" << std::endl; - std::cout << "WMMA Core Dimension: M=" << cfg::tcM << ", N=" << cfg::tcN << ", K=" << cfg::tcK << std::endl; - std::cout << "WMMA Tile Dimension: M=" << cfg::tileM << ", N=" << cfg::tileN << ", K=" << cfg::tileK << std::endl; - std::cout << "matrix A: " << M << "x" << K << std::endl; - std::cout << "matrix B: " << K << "x" << N << std::endl; - std::cout << "matrix C: " << M << "x" << N << std::endl; - - // set block size to warp size - kernel_arg.grid_dim[0] = N / cfg::tileN; - kernel_arg.grid_dim[1] = M / cfg::tileM; - kernel_arg.block_dim[0] = NT; // warp sizeb - kernel_arg.block_dim[1] = 1; - - // set matrix dimensions - kernel_arg.M = M; - kernel_arg.N = N; - kernel_arg.K = K; - - // allocate device memory - std::cout << "allocate device memory" << std::endl; - RT_CHECK(vx_mem_alloc(device, sizeA * sizeof(itype_t), VX_MEM_READ, &A_buffer)); - RT_CHECK(vx_mem_address(A_buffer, &kernel_arg.A_addr)); - RT_CHECK(vx_mem_alloc(device, sizeB * sizeof(itype_t), VX_MEM_READ, &B_buffer)); - RT_CHECK(vx_mem_address(B_buffer, &kernel_arg.B_addr)); - RT_CHECK(vx_mem_alloc(device, sizeC * sizeof(otype_t), VX_MEM_WRITE, &C_buffer)); - RT_CHECK(vx_mem_address(C_buffer, &kernel_arg.C_addr)); - - std::cout << "A_addr=0x" << std::hex << kernel_arg.A_addr << std::endl; - std::cout << "B_addr=0x" << std::hex << kernel_arg.B_addr << std::endl; - std::cout << "C_addr=0x" << std::hex << kernel_arg.C_addr << std::endl; - - // generate source data - std::vector h_A(sizeA); - std::vector h_B(sizeB); - for (uint32_t i = 0; i < sizeA; ++i) { - h_A[i] = Comparator::generate(); - } - for (uint32_t i = 0; i < sizeB; ++i) { - h_B[i] = Comparator::generate(); - } - - // upload matrix A buffer - { - std::cout << "upload matrix A buffer" << std::endl; - RT_CHECK(vx_copy_to_dev(A_buffer, h_A.data(), 0, sizeA * sizeof(itype_t))); - } - - // upload matrix B buffer - { - std::cout << "upload matrix B buffer" << std::endl; - if constexpr (std::is_same::value || std::is_same::value) { - // sub-byte matrix B must be in col-major format - // we convert the 4-bit row-major to col-major here - std::vector h_B_col(sizeB); - convert_row_to_col_major_4bit(h_B_col.data(), N, 2 * K, (uint8_t*)h_B.data()); - RT_CHECK(vx_copy_to_dev(B_buffer, h_B_col.data(), 0, sizeB)); - } else { - RT_CHECK(vx_copy_to_dev(B_buffer, h_B.data(), 0, sizeB * sizeof(itype_t))); - } - } - - // upload program - std::cout << "upload program" << std::endl; - RT_CHECK(vx_upload_kernel_file(device, kernel_file, &krnl_buffer)); - - // upload kernel argument - std::cout << "upload kernel argument" << std::endl; - RT_CHECK(vx_upload_bytes(device, &kernel_arg, sizeof(kernel_arg_t), &args_buffer)); - - auto time_start = std::chrono::high_resolution_clock::now(); - - // start device - std::cout << "start device" << std::endl; - RT_CHECK(vx_start(device, krnl_buffer, args_buffer)); - - // wait for completion - std::cout << "wait for completion" << std::endl; - RT_CHECK(vx_ready_wait(device, VX_MAX_TIMEOUT)); - - auto time_end = std::chrono::high_resolution_clock::now(); - double elapsed = std::chrono::duration_cast(time_end - time_start).count(); - printf("Elapsed time: %lg ms\n", elapsed); - - // download destination buffer - std::vector h_C(sizeC); - std::cout << "download destination buffer" << std::endl; - RT_CHECK(vx_copy_from_dev(h_C.data(), C_buffer, 0, sizeC * sizeof(otype_t))); - - // verify result - std::cout << "verify result" << std::endl; - int errors = 0; - { - std::vector h_ref(sizeC); - matmul_cpu(h_ref.data(), h_A.data(), h_B.data(), M, N, K); - - for (uint32_t i = 0; i < h_ref.size(); ++i) { - if (!Comparator::compare(h_C[i], h_ref[i], i, errors)) { - ++errors; - } - } - } - - // cleanup - std::cout << "cleanup" << std::endl; - cleanup(); - - if (errors != 0) { - std::cout << "Found " << std::dec << errors << " / " << sizeC << " errors!" << std::endl; - std::cout << "FAILED!" << std::endl; - return errors; - } - - std::cout << "PASSED!" << std::endl; - - return 0; -} \ No newline at end of file