diff --git a/CMakeLists.txt b/CMakeLists.txt
index ce6b348c3ea3..9479b8f34d03 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -185,11 +185,17 @@ else(MSVC)
   endif(NOT APPLE)
 endif(MSVC)
 
-if(NOT CMAKE_SYSTEM_PROCESSOR MATCHES "(x86)|(X86)|(amd64)|(AMD64)")
+if(NOT CMAKE_SYSTEM_PROCESSOR MATCHES "(x86)|(X86)|(amd64)|(AMD64)|(aarch64)|(AARCH64)")
   message(STATUS "Disabling LIBXSMM on ${CMAKE_SYSTEM_PROCESSOR}.")
   set(USE_LIBXSMM OFF)
 endif()
 
+# Flag for arm specific optimization
+if(CMAKE_SYSTEM_PROCESSOR MATCHES "(aarch64)|(AARCH64)")
+  message(STATUS "setting flag for arm specific optimizations = ${CMAKE_SYSTEM_PROCESSOR} to ON.")
+  add_definitions(-DAARCH64)
+endif()
+
 # Source file lists
 file(GLOB DGL_SRC
   src/*.cc
diff --git a/src/array/cpu/spmm.h b/src/array/cpu/spmm.h
index aea4a0a5895a..e2760da28d55 100644
--- a/src/array/cpu/spmm.h
+++ b/src/array/cpu/spmm.h
@@ -141,10 +141,21 @@ void SpMMSumCsr(
   }
 #if !defined(_WIN32)
 #ifdef USE_LIBXSMM
-  int cpu_id = libxsmm_cpuid_x86();
+  int cpu_id, limit;
+#ifdef AARCH64
+  static int arm_cpu_id = -1;
+  if (arm_cpu_id == -1){
+    arm_cpu_id = libxsmm_cpuid_arm();
+  }
+  cpu_id = arm_cpu_id;
+  limit = LIBXSMM_AARCH64_A64FX;
+#else //x86
+  cpu_id = libxsmm_cpuid_x86();
+  limit = LIBXSMM_X86_AVX512;
+#endif//AARCH64
   const bool no_libxsmm =
       bcast.use_bcast || std::is_same::value ||
-      (std::is_same::value && cpu_id < LIBXSMM_X86_AVX512) ||
+      (std::is_same::value && cpu_id < limit) ||
       !dgl::runtime::Config::Global()->IsLibxsmmAvailable();
   if (!no_libxsmm) {
     SpMMSumCsrLibxsmm(bcast, csr, ufeat, efeat, out);
@@ -266,10 +277,23 @@ void SpMMCmpCsr(
   }
 #if !defined(_WIN32)
 #ifdef USE_LIBXSMM
-  int cpu_id = libxsmm_cpuid_x86();
+#ifdef AARCH64
+  static int arm_cpu_id = -1;
+  if (arm_cpu_id == -1){
+    arm_cpu_id = libxsmm_cpuid_arm();
+  }
+#endif//AARCH64
+  int cpu_id, limit;
+#ifdef AARCH64
+  cpu_id = arm_cpu_id;
+  limit = LIBXSMM_AARCH64_A64FX;
+#else //x86
+  cpu_id = libxsmm_cpuid_x86();
+  limit = LIBXSMM_X86_AVX512;
+#endif//AARCH64
   const bool no_libxsmm = bcast.use_bcast ||
                           std::is_same::value ||
-                          cpu_id < LIBXSMM_X86_AVX512 ||
+                          cpu_id < limit ||
                           !dgl::runtime::Config::Global()->IsLibxsmmAvailable();
   if (!no_libxsmm) {
     SpMMCmpCsrLibxsmm(
diff --git a/src/array/cpu/spmm_blocking_libxsmm.h b/src/array/cpu/spmm_blocking_libxsmm.h
index de3579fbf304..29344dfb6164 100644
--- a/src/array/cpu/spmm_blocking_libxsmm.h
+++ b/src/array/cpu/spmm_blocking_libxsmm.h
@@ -44,7 +44,7 @@ struct CSRMatrixInternal {
 int32_t GetLLCSize() {
 #ifdef _SC_LEVEL3_CACHE_SIZE
   int32_t cache_size = sysconf(_SC_LEVEL3_CACHE_SIZE);
-  if (cache_size < 0) cache_size = DGL_CPU_LLC_SIZE;
+  if (cache_size <= 0) cache_size = DGL_CPU_LLC_SIZE;
 #else
   int32_t cache_size = DGL_CPU_LLC_SIZE;
 #endif
diff --git a/src/runtime/config.cc b/src/runtime/config.cc
index 70d0d6d56659..25b7c0819cd9 100644
--- a/src/runtime/config.cc
+++ b/src/runtime/config.cc
@@ -17,9 +17,21 @@ namespace runtime {
 
 Config::Config() {
 #if !defined(_WIN32) && defined(USE_LIBXSMM)
-  int cpu_id = libxsmm_cpuid_x86();
+  int cpu_id;
+  #if defined(AARCH64)
+  static int arm_cpu_id = -1;
+  if (arm_cpu_id == -1){
+    arm_cpu_id = libxsmm_cpuid_arm();
+  }
+  cpu_id = arm_cpu_id;
+  // Enable libxsmm on ARM machines by default
+  libxsmm_ = LIBXSMM_AARCH64_SVE128 <= cpu_id && cpu_id <= LIBXSMM_AARCH64_ALLFEAT;
+  #else
+  cpu_id = libxsmm_cpuid_x86();
+
   // Enable libxsmm on AVX machines by default
   libxsmm_ = LIBXSMM_X86_AVX2 <= cpu_id && cpu_id <= LIBXSMM_X86_ALLFEAT;
+  #endif //AARCH64
 #else
   libxsmm_ = false;
 #endif