diff --git a/.github/scripts/utils_build.bash b/.github/scripts/utils_build.bash index 82fa3e26a2..709e7b62f4 100644 --- a/.github/scripts/utils_build.bash +++ b/.github/scripts/utils_build.bash @@ -370,6 +370,7 @@ install_build_tools () { patchelf \ rhash \ scikit-build \ + tbb-devel \ tbb \ wheel \ xz \ diff --git a/cmake/modules/CppLibrary.cmake b/cmake/modules/CppLibrary.cmake index 92a93a60b6..388d3ac779 100644 --- a/cmake/modules/CppLibrary.cmake +++ b/cmake/modules/CppLibrary.cmake @@ -168,6 +168,18 @@ function(cpp_library) target_link_libraries(${lib_name} PUBLIC OpenMP::OpenMP_CXX) endif() + if(NOT TARGET TBB::tbb) + find_package(TBB QUIET) + endif() + if(TBB_FOUND) + target_link_libraries(${lib_name} PUBLIC TBB::tbb) + else() + find_library(TBB_LIB NAMES tbb tbb12 HINTS $ENV{CONDA_PREFIX}/lib /usr/lib/x86_64-linux-gnu /usr/local/lib /lib/x86_64-linux-gnu) + if(TBB_LIB) + target_link_libraries(${lib_name} PUBLIC ${TBB_LIB}) + endif() + endif() + # Add sanitizer options if needed if(args_SANITIZER_OPTIONS) target_link_options(${lib_name} PUBLIC diff --git a/cmake/modules/GpuCppLibrary.cmake b/cmake/modules/GpuCppLibrary.cmake index 51c30df750..e662848348 100644 --- a/cmake/modules/GpuCppLibrary.cmake +++ b/cmake/modules/GpuCppLibrary.cmake @@ -302,6 +302,18 @@ function(gpu_cpp_library) list(APPEND library_dependencies ${NVML_LIB_PATH}) endif() + if(NOT TARGET TBB::tbb) + find_package(TBB QUIET) + endif() + if(TBB_FOUND) + list(APPEND library_dependencies TBB::tbb) + else() + find_library(TBB_LIB NAMES tbb tbb12 HINTS $ENV{CONDA_PREFIX}/lib /usr/lib/x86_64-linux-gnu /usr/local/lib /lib/x86_64-linux-gnu) + if(TBB_LIB) + list(APPEND library_dependencies ${TBB_LIB}) + endif() + endif() + # Link against the external libraries as needed target_link_libraries(${lib_name} PRIVATE ${library_dependencies}) diff --git a/fbgemm_gpu/cmake/tbe_sources.py b/fbgemm_gpu/cmake/tbe_sources.py index 82092cc173..b38f862564 100644 --- a/fbgemm_gpu/cmake/tbe_sources.py +++ b/fbgemm_gpu/cmake/tbe_sources.py @@ -176,7 +176,6 @@ "_nobag" if nobag else "", ) for nobag in [ - True, False, ] for weighted in ( @@ -495,7 +494,6 @@ "_nobag" if nobag else "", ) for nobag in [ - True, False, ] for weighted in ( diff --git a/fbgemm_gpu/codegen/genscript/generate_backward_split.py b/fbgemm_gpu/codegen/genscript/generate_backward_split.py index a5277a906a..50506decb1 100644 --- a/fbgemm_gpu/codegen/genscript/generate_backward_split.py +++ b/fbgemm_gpu/codegen/genscript/generate_backward_split.py @@ -52,7 +52,11 @@ def render_backward_templates( return weighted_options = [True, False] - nobag_options = [True, False] if (not is_gwd) else [False] + nobag_options = ( + [True, False] + if (not (is_gwd or kwargs.get("is_hip_optimized_backward"))) + else [False] + ) vbe_options = [True, False] if (kwargs.get("has_vbe_support")) else [False] ssd_options = [True, False] if kwargs.get("has_ssd_support") else [False] template = CodeTemplate.load(template_filepath) @@ -327,8 +331,7 @@ def generate_backward_indices() -> None: @staticmethod def generate_rocm_backward_split(**kwargs: Any) -> None: - # Generate backward device kernels based on weighted (True/False), VBE - # (True/False), no bag (True/False) + # Generate backward device kernels based on weighted (True/False) template_filepath = ( "training/backward/rocm/embedding_backward_split_device_kernel_template.hip" ) @@ -343,6 +346,7 @@ def generate_rocm_backward_split(**kwargs: Any) -> None: "has_ssd_support": False, "dense": False, "gen_once": False, + 
"is_hip_optimized_backward": True, }, ) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp index 626838e930..0bc3c5f254 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp @@ -172,7 +172,7 @@ Tensor split_embedding_codegen_lookup_dense_function( c10::SymInt /* max_B = -1 */, c10::SymInt /* max_B_feature_rank = -1 */, c10::SymInt /* vbe_output_size = -1 */, - bool /* mixed_D = true */) { + bool /* mixed_D = false */) { return SplitLookupFunction_Dense_Op::apply( host_weights, weights_offsets, diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp index 134a03b983..3fe516891f 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp @@ -960,7 +960,7 @@ class {{ autograd_func }} : #ifdef USE_ROCM constexpr int32_t BT_block_size = 64; - constexpr int32_t max_segment_length_per_warp = 64; + constexpr int32_t max_segment_length_per_warp = 16384; #else constexpr int32_t BT_block_size = 32; constexpr int32_t max_segment_length_per_warp = 32; @@ -1116,7 +1116,7 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function( {%- else %} const c10::SymInt vbe_output_size = -1, {%- endif %} - const bool mixed_D = true + const bool mixed_D = false ) { // TODO: refactor into macro {%- if has_gpu_support %} diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu old mode 100644 new mode 100755 index 6d38d1d99a..9ffaea3a67 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu @@ -23,6 +23,10 @@ #include "fbgemm_gpu/utils/assert_macros.h" #include "fbgemm_gpu/utils/kernel_launcher.cuh" +{%- if is_rocm %} +#include "fbgemm_gpu/rocm/cdna_guard.h" +{%- endif %} + using Tensor = at::Tensor; using namespace fbgemm_gpu; @@ -209,8 +213,127 @@ __global__ __launch_bounds__(kForwardMaxThreads) void 2, offset_idx + D_emb <= weights_numel, offset_idx ) {%- endif %} + int32_t j = 0; + {%- if is_rocm and not ssd and not dense and not use_vec_blocking and not vbe %} + // Currently for split_embedding_codegen_grad_indice_weights_kernel only + if (placement != PlacementType::MANAGED_CACHING) { + for (; j < kWarpSize && l_start + j + 3 < L; j += 4) { + const auto offset_idx_j0 = shfl_sync(offset_idx, j); + const auto offset_idx_j1 = shfl_sync(offset_idx, j+1); + const auto offset_idx_j2 = shfl_sync(offset_idx, j+2); + const auto offset_idx_j3 = shfl_sync(offset_idx, j+3); + + at::acc_type grad_indice_weight0 = 0.0; + at::acc_type grad_indice_weight1 = 0.0; + at::acc_type grad_indice_weight2 = 0.0; + at::acc_type grad_indice_weight3 = 0.0; + + const auto weight_row0 = WeightRowAccessor>(&weights[offset_idx_j0], D); + const auto weight_row1 = WeightRowAccessor>(&weights[offset_idx_j1], D); + const auto weight_row2 = WeightRowAccessor>(&weights[offset_idx_j2], D); + const auto weight_row3 = WeightRowAccessor>(&weights[offset_idx_j3], D); + + #pragma unroll kFixedMaxVecsPerThread + 
for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && (kWarpSize * vec + threadIdx.x) * kVecWidth < D; ++vec) { + const int32_t d = (kWarpSize * vec + threadIdx.x) * kVecWidth; + + Vec4T> weight0, weight1, weight2, weight3; + weight0 = weight_row0.load(d); + weight1 = weight_row1.load(d); + weight2 = weight_row2.load(d); + weight3 = weight_row3.load(d); + + grad_indice_weight0 += weight0.acc.x * grad_out[vec].acc.x + weight0.acc.y * grad_out[vec].acc.y + + weight0.acc.z * grad_out[vec].acc.z + weight0.acc.w * grad_out[vec].acc.w; + grad_indice_weight1 += weight1.acc.x * grad_out[vec].acc.x + weight1.acc.y * grad_out[vec].acc.y + + weight1.acc.z * grad_out[vec].acc.z + weight1.acc.w * grad_out[vec].acc.w; + grad_indice_weight2 += weight2.acc.x * grad_out[vec].acc.x + weight2.acc.y * grad_out[vec].acc.y + + weight2.acc.z * grad_out[vec].acc.z + weight2.acc.w * grad_out[vec].acc.w; + grad_indice_weight3 += weight3.acc.x * grad_out[vec].acc.x + weight3.acc.y * grad_out[vec].acc.y + + weight3.acc.z * grad_out[vec].acc.z + weight3.acc.w * grad_out[vec].acc.w; + } + + grad_indice_weight0 = warpReduceAllSum>(grad_indice_weight0); + grad_indice_weight1 = warpReduceAllSum>(grad_indice_weight1); + grad_indice_weight2 = warpReduceAllSum>(grad_indice_weight2); + grad_indice_weight3 = warpReduceAllSum>(grad_indice_weight3); + + if (threadIdx.x == 0) { + grad_indice_weights[indices_start + l_start + j] = grad_indice_weight0; + grad_indice_weights[indices_start + l_start + j+1] = grad_indice_weight1; + grad_indice_weights[indices_start + l_start + j+2] = grad_indice_weight2; + grad_indice_weights[indices_start + l_start + j+3] = grad_indice_weight3; + } + } + } else { + for (; j < kWarpSize && l_start + j + 3 < L; j += 4) { + const auto offset_idx_j0 = shfl_sync(offset_idx, j); + const auto offset_idx_j1 = shfl_sync(offset_idx, j+1); + const auto offset_idx_j2 = shfl_sync(offset_idx, j+2); + const auto offset_idx_j3 = shfl_sync(offset_idx, j+3); + + const auto cache_idx_j0 = shfl_sync(cache_idx, j); + const auto cache_idx_j1 = shfl_sync(cache_idx, j+1); + const auto cache_idx_j2 = shfl_sync(cache_idx, j+2); + const auto cache_idx_j3 = shfl_sync(cache_idx, j+3); + + at::acc_type grad_indice_weight0 = 0.0; + at::acc_type grad_indice_weight1 = 0.0; + at::acc_type grad_indice_weight2 = 0.0; + at::acc_type grad_indice_weight3 = 0.0; + + const auto weight_row0 = WeightRowAccessor>(&weights[offset_idx_j0], D); + const auto weight_row1 = WeightRowAccessor>(&weights[offset_idx_j1], D); + const auto weight_row2 = WeightRowAccessor>(&weights[offset_idx_j2], D); + const auto weight_row3 = WeightRowAccessor>(&weights[offset_idx_j3], D); + + #pragma unroll kFixedMaxVecsPerThread + for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && (kWarpSize * vec + threadIdx.x) * kVecWidth < D; ++vec) { + const int32_t d = (kWarpSize * vec + threadIdx.x) * kVecWidth; + + Vec4T> weight0, weight1, weight2, weight3; + weight0 = (cache_idx_j0 != kCacheLocationMissing) ? + Vec4T>(&lxu_cache_weights[cache_idx_j0][d]) : + weight_row0.load(d); + + weight1 = (cache_idx_j1 != kCacheLocationMissing) ? + Vec4T>(&lxu_cache_weights[cache_idx_j1][d]) : + weight_row1.load(d); + + weight2 = (cache_idx_j2 != kCacheLocationMissing) ? + Vec4T>(&lxu_cache_weights[cache_idx_j2][d]) : + weight_row2.load(d); + + weight3 = (cache_idx_j3 != kCacheLocationMissing) ? 
+ Vec4T>(&lxu_cache_weights[cache_idx_j3][d]) : + weight_row3.load(d); + + + grad_indice_weight0 += weight0.acc.x * grad_out[vec].acc.x + weight0.acc.y * grad_out[vec].acc.y + + weight0.acc.z * grad_out[vec].acc.z + weight0.acc.w * grad_out[vec].acc.w; + grad_indice_weight1 += weight1.acc.x * grad_out[vec].acc.x + weight1.acc.y * grad_out[vec].acc.y + + weight1.acc.z * grad_out[vec].acc.z + weight1.acc.w * grad_out[vec].acc.w; + grad_indice_weight2 += weight2.acc.x * grad_out[vec].acc.x + weight2.acc.y * grad_out[vec].acc.y + + weight2.acc.z * grad_out[vec].acc.z + weight2.acc.w * grad_out[vec].acc.w; + grad_indice_weight3 += weight3.acc.x * grad_out[vec].acc.x + weight3.acc.y * grad_out[vec].acc.y + + weight3.acc.z * grad_out[vec].acc.z + weight3.acc.w * grad_out[vec].acc.w; + } + + grad_indice_weight0 = warpReduceAllSum>(grad_indice_weight0); + grad_indice_weight1 = warpReduceAllSum>(grad_indice_weight1); + grad_indice_weight2 = warpReduceAllSum>(grad_indice_weight2); + grad_indice_weight3 = warpReduceAllSum>(grad_indice_weight3); - for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) { + if (threadIdx.x == 0) { + grad_indice_weights[indices_start + l_start + j] = grad_indice_weight0; + grad_indice_weights[indices_start + l_start + j+1] = grad_indice_weight1; + grad_indice_weights[indices_start + l_start + j+2] = grad_indice_weight2; + grad_indice_weights[indices_start + l_start + j+3] = grad_indice_weight3; + } + } + } + {%- endif %}{#-/* if is_rocm and not ssd and not dense and not use_vec_blocking and not vbe */#} + for (; j < kWarpSize && l_start + j < L; ++j) { const auto offset_idx_j = shfl_sync(offset_idx, j); {%- if not dense %} const auto {{ locs_or_addrs_idx }}_j = shfl_sync({{ locs_or_addrs_idx }}, j); @@ -359,6 +482,15 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( auto aligned_grad_output = aligned_grad_output_tensor_for_cuda_backwards(grad_output); CUDA_DEVICE_GUARD(dev_weights); + #ifdef USE_ROCM + if (!rocm::is_supported_cdna()) { + TORCH_WARN_ONCE("Running on non-CDNA architecture. 
Performance may be suboptimal."); + } + else { + // Ensure we're running on a supported CDNA architecture (including MI350) + TORCH_WARN_ONCE("Running on CDNA architecture"); + } + #endif const auto T = D_offsets.size(0) - 1; TORCH_CHECK_GT(T, 0); diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index 5137b5766c..7b3b5b653a 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -32,6 +32,14 @@ {%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %} {%- set locs_or_addrs_type = "int64_t" if ssd else "int32_t" %} +{%- set is_optimized_hip_kernel_supported_mode = is_rocm and + optimizer == "rowwise_adagrad" and + not dense and + not nobag and + not is_index_select and + not is_gwd_kernel and + not vbe and + not ssd %} #include "fbgemm_gpu/embedding_backward_template_helpers.cuh" #include "fbgemm_gpu/utils/tensor_accessor_builder.h" @@ -538,7 +546,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row {%- endif %} -{%- if is_rocm and not is_index_select and optimizer == "rowwise_adagrad" and not dense and not is_gwd_kernel and not vbe and not ssd %} +{%- if is_optimized_hip_kernel_supported_mode %} #include #include #include "fbgemm_gpu/rocm/split_embeddings_common.h" @@ -612,12 +620,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd {{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }} {%- endif %} ) { - {%- if not nobag %} int32_t T = D_offsets.size(0) - 1; - {%- else %} - int32_t T = weights_offsets.size(0); - {%- endif %} - auto p_output_grad = grad_output.data(); auto p_emb_table = dev_weights.data(); auto p_hash_size_cumsum = hash_size_cumsum.data(); @@ -632,8 +635,6 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd constexpr int32_t segment_prefetch = 2; constexpr int32_t segment_unroll = 8; constexpr int32_t segment_split = 0; - auto batch = grad_output.size(0); - auto num_rows = dev_weights.size(0) / T / max_D; {%- if weighted %} constexpr bool is_weighted = true; {%- else %} @@ -646,24 +647,9 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd // weight_decay(_mode) is supplied as args.split_function_args_no_defaults opt_karg.weight_decay_mode = weight_decay_mode_v; opt_karg.weight_decay = weight_decay; - auto batch_mdiv = [](uint32_t d) -> rocm::magic_div_u32_t { - assert(d >= 1 && d <= INT32_MAX); - uint8_t shift; - for(shift = 0; shift < 32; shift++) - if((1U << shift) >= d) - break; - - uint64_t one = 1; - uint64_t magic = ((one << 32) * ((one << shift) - d)) / d + 1; - assert(magic <= 0xffffffffUL); - - rocm::magic_div_u32_t result; - result.magic = magic; - result.shift = shift; - return result; - }(batch); + rocm::split_tbe_backward_hip_kernel_{{kdesc}}< - rocm::{{optimizer}}_optimizer_t, + rocm::{{optimizer}}_optimizer_t, rocm::{{optimizer}}_kernel_arg_t, emb_t, cache_t, @@ -680,16 +666,11 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd p_sorted_linear_indices_run, p_sorted_linear_indices_cumulative_run_lengths, p_sorted_linear_indices_num_runs, - {%- if not nobag %} info_B_num_bits, info_B_mask, - {%- endif %} p_sorted_infos, - batch_mdiv, max_segment_length_per_warp, emb_dim, - batch, - num_rows, T, opt_karg {%- 
if weighted %} @@ -784,7 +765,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd {%- for emb_type in (['float', 'at::Half'] + (['at::Float8_e4m3fnuz'] if is_rocm else ['at::Float8_e4m3fn'])) %} {%- for cache_type in ['float', 'at::Half'] %} {%- for index_type in ['int32_t', 'int64_t'] %} - {%- for kEmbeddingDim in [64, 128, 160, 192, 256] %} + {%- for kEmbeddingDim in [64, 128, 160, 192, 256, 320] %} {%- for kWeighDecayMode in [0, 1, 2] %} {{ hip_template_instantiation( emb_type, diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu old mode 100644 new mode 100755 index 76eba64c99..f07ef5830e --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -48,6 +48,15 @@ using namespace fbgemm_gpu; has_global_weight_decay_support, ssd) %} {%- set desc_suffix = get_desc_suffix(is_gwd_kernel) %} +{%- set is_optimized_hip_kernel_supported_mode = is_rocm and + optimizer == "rowwise_adagrad" and + not dense and + not nobag and + not is_index_select and + not is_gwd_kernel and + not vbe and + not ssd %} + template < typename emb_t, typename grad_t, @@ -227,8 +236,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( {%- endif %} ); -{%- if is_rocm and optimizer == "rowwise_adagrad" and not dense and not is_index_select - and not is_gwd_kernel and not vbe and not ssd %} +{%- if is_optimized_hip_kernel_supported_mode %} #include "fbgemm_gpu/rocm/split_embeddings_common.h" template < typename emb_t, @@ -652,6 +660,16 @@ Tensor {{ embedding_cuda_op }}( CUDA_DEVICE_GUARD(dev_weights); + #ifdef USE_ROCM + if (!rocm::is_supported_cdna()) { + TORCH_WARN_ONCE("Running on non-CDNA architecture. Performance may be suboptimal."); + } + else { + // Ensure we're running on a supported CDNA architecture (including MI350) + TORCH_WARN_ONCE("Running on CDNA architecture"); + } + #endif + {%- if nobag and not is_index_select %} auto max_D = D; {%- endif %} @@ -852,8 +870,7 @@ Tensor {{ embedding_cuda_op }}( } {%- endif %} - {%- if is_rocm and optimizer == "rowwise_adagrad" and not dense and not is_index_select - and not is_gwd_kernel and not vbe and not ssd %} + {%- if is_optimized_hip_kernel_supported_mode %} {%- set hip_kernel = "hip_split_embedding{}_backward_codegen_{}_{}{}_kernel_warp_per_row_1".format( ndesc, optimizer, @@ -970,8 +987,11 @@ Tensor {{ embedding_cuda_op }}( auto num_long_run_ids = at::zeros({1}, indices.options().dtype(at::kInt)); const bool use_deterministic_algorithms = at::globalContext().deterministicAlgorithms(); - const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 1024; - + {% if is_rocm %} + const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 4096; + {% else %} + const int max_segment_length_per_cta = use_deterministic_algorithms ? 
INT_MAX : 1024; + {%- endif %} Tensor long_run_id_to_really_long_run_ids; if (use_deterministic_algorithms) { long_run_id_to_really_long_run_ids = @@ -1042,7 +1062,22 @@ Tensor {{ embedding_cuda_op }}( // Compute shared memory size for cta_per_row constexpr auto kCacheAccBytes = sizeof(at::acc_type); + {% if is_rocm %} + int32_t total_L = indices.numel(); + int32_t num_cta_per_row_groups; + int32_t work_group_size; + if (total_L/total_B > 1) { + num_cta_per_row_groups = (kMaxThreads/4) / kWarpSize; + work_group_size = (kMaxThreads/4); + } + else { + num_cta_per_row_groups = kMaxThreads / kWarpSize; + work_group_size = kMaxThreads; + } + {%- else %} int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize; + const int32_t work_group_size = kMaxThreads; + {%- endif %} const size_t cta_per_row_smem_bytes = compute_num_groups_and_dynamic_smem_bytes( &num_cta_per_row_groups, [&] (int num_groups) { @@ -1053,7 +1088,7 @@ Tensor {{ embedding_cuda_op }}( ); const int32_t cta_per_row_grid_size = std::min( - div_round_up(total_unique_indices, kMaxThreads), + div_round_up(total_unique_indices, work_group_size), get_max_thread_blocks_()); FBGEMM_LAUNCH_KERNEL( @@ -1162,7 +1197,17 @@ Tensor {{ embedding_cuda_op }}( kUseVecBlocking>; // Compute shared memory size for warp_per_row - int32_t num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; + {%- if is_rocm %} + int32_t num_warp_per_row_groups; + if (total_L/total_B > 1){ + num_warp_per_row_groups = (kBackwardMaxThreads/2) / kThreadGroupSize; + } + else{ + num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; + } + {%- else %} + int32_t num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; + {%- endif %} int32_t warp_per_row_smem_bytes = 0; if constexpr (kUseVecBlocking) { @@ -1185,18 +1230,17 @@ Tensor {{ embedding_cuda_op }}( get_max_thread_blocks_()); #ifdef USE_ROCM - {%- if is_rocm and not is_index_select and optimizer == "rowwise_adagrad" and - not dense and not is_gwd_kernel and not vbe and not ssd and not nobag %} + {%- if is_optimized_hip_kernel_supported_mode %} const static auto use_hip_kernel = fbgemm_gpu::config::is_feature_enabled(fbgemm_gpu::config::FeatureGateName::TBE_ROCM_HIP_BACKWARD_KERNEL); - const auto supported_weights_type = dev_weights.scalar_type() == at::ScalarType::Half - || dev_weights.scalar_type() == at::ScalarType::Float; + constexpr bool supported_weights_type = std::is_same_v || std::is_same_v; + constexpr bool supported_grad_type = std::is_same_v || std::is_same_v; - if (use_hip_kernel && supported_weights_type && !mixed_D && rocm::is_supported_cdna()) + if (use_hip_kernel && !mixed_D && supported_weights_type && supported_grad_type && rocm::is_supported_cdna()) { constexpr int segments_per_workgroup = 4; - {%- for kDimSize in [64, 128, 160, 192, 256] %} + {%- for kDimSize in [64, 128, 160, 192, 256, 320] %} {%- for kWeightDecayMode in [0, 1, 2] %} if (max_D == {{ kDimSize }} && weight_decay_mode == {{ kWeightDecayMode }}) { diff --git a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip index 2fcbba395e..cd3d645775 100644 --- a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip +++ b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip @@ -27,7 +27,7 @@ #include "fbgemm_gpu/rocm/split_embeddings_common.h" namespace fbgemm_gpu::rocm { -template +template struct 
rowwise_adagrad_optimizer_t { __device__ rowwise_adagrad_optimizer_t(const rowwise_adagrad_kernel_arg_t& karg_) @@ -36,7 +36,7 @@ struct rowwise_adagrad_optimizer_t } template - __device__ void update(cache_t* acc, emb_t* weight, uint32_t row_index) + __device__ void update(cache_t* acc, emb_t* weight, index_t row_index) { if constexpr(segment_split == 0) { @@ -122,20 +122,11 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( const index_t* p_sorted_linear_indices_run, const int32_t* p_sorted_linear_indices_cumulative_run_lengths, const int32_t* p_sorted_linear_indices_num_runs, - {%- if not nobag %} const int32_t info_B_num_bits, const uint32_t info_B_mask, - {%- endif %} - {%- if not nobag %} const int32_t* p_sorted_infos, - {%- else %} - const int64_t* p_sorted_infos, - {%- endif %} - magic_div_u32_t batch_mdiv, uint32_t max_segment_length_per_warp, uint32_t emb_dim, - uint32_t batch, - uint32_t num_rows, uint32_t num_tables, optimizer_karg_t opt_karg, const float * p_sorted_indice_weights = nullptr) @@ -157,13 +148,9 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( const int32_t segment_start = p_sorted_linear_indices_cumulative_run_lengths[run_id]; const int32_t segment_end = p_sorted_linear_indices_cumulative_run_lengths[run_id + 1]; - {%- if nobag %} - const auto info_0 = p_sorted_infos[segment_start]; - int32_t t_0 = info_0 % num_tables; - {%- else %} const auto info_0 = reinterpret_cast(&p_sorted_infos[0])[segment_start]; const auto t_0 = info_0 >> info_B_num_bits; - {%- endif %} + int64_t hash_size = p_hash_size_cumsum[t_0]; const int64_t emb_idx = linear_index - hash_size; @@ -179,7 +166,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( const int32_t segment_length_mod = segment_length & length_mask; cache_t grad_acc[dword_per_row]; - int32_t infos[segment_unroll]; + uint32_t infos[segment_unroll]; grad_t grad_data[dword_per_row * segment_prefetch]; emb_t emb_data[dword_per_row]; float indice_weights[segment_unroll]; @@ -221,22 +208,16 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( // LOOP for(; itr < segment_length_mod; itr += segment_unroll) { - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); - {%- else %} table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[1], batch, table_index, bag_index); - {%- else %} table_index = infos[1] >> info_B_num_bits; bag_index = infos[1] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); if constexpr (!weighted){ #pragma unroll @@ -244,24 +225,20 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); - {%- else %} + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id); - {%- if nobag %} - 
magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); - {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -284,24 +261,20 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id, indice_weights[j-2]); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); - {%- else %} + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[j-1]); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); - {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -322,22 +295,16 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( } // LAST - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); - {%- else %} table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[1], batch, table_index, bag_index); - {%- else %} table_index = infos[1] >> info_B_num_bits; bag_index = infos[1] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); if constexpr (!weighted) { @@ -346,24 +313,20 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); - {%- else %} + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); - {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -377,24 +340,20 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id, indice_weights[j-2]); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); - {%- else %} + table_index = infos[j] >> info_B_num_bits; bag_index = 
infos[j] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[j-1]); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); - {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -414,13 +373,10 @@ L_tail_grad_acc: infos[0] = p_sorted_infos[segment_start]; p_sorted_infos++; - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); - {%- else %} table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id); @@ -435,13 +391,10 @@ L_tail_grad_acc: p_sorted_infos++; p_sorted_indice_weights++; - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); - {%- else %} table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id, indice_weights[0]); @@ -452,11 +405,11 @@ L_tail_grad_acc: } // load the old emb weight data - load_row_per_warp::run( + load_row_per_warp::run( &emb_data[0], emb_idx, p_emb_table, lane_id); optimizer_t optimizer(opt_karg); optimizer.template update(grad_acc, emb_data, emb_idx); - store_row_per_warp::run(&emb_data[0], p_emb_table + emb_idx * embedding_dim, lane_id); + store_row_per_warp::run(&emb_data[0], p_emb_table + emb_idx * embedding_dim, lane_id); } } // namespace fbgemm_gpu::rocm diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu old mode 100644 new mode 100755 diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu old mode 100644 new mode 100755 index 37e774bb49..a3edb6b965 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu @@ -31,6 +31,10 @@ #include "fbgemm_gpu/utils/dispatch_macros.h" {%- endif %} +{%- if is_rocm %} +#include "fbgemm_gpu/rocm/cdna_guard.h" +{%- endif %} + {%- if not is_index_select %} //////////////////////////////////////////////////////////////////////////////// // Required for op registrations @@ -459,6 +463,16 @@ batch_index_select_dim0_codegen_forward_cuda( CUDA_DEVICE_GUARD(dev_weights); + {% if is_rocm %} + if (!rocm::is_supported_cdna()) { + TORCH_WARN_ONCE("Running on non-CDNA architecture. 
Performance may be suboptimal."); + } + else { + // Ensure we're running on a supported CDNA architecture (including MI350) + TORCH_WARN_ONCE("Running on CDNA architecture"); + } + {%- endif %} + {%- if not nobag %} int32_t T = D_offsets.numel() - 1; {%- else %} diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index 0f4814fb41..a2304b3fb3 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -743,6 +743,7 @@ class {{ autograd_func }} : TORCH_CHECK(aux_tensor[IDX_LXU_CACHE_LOCATIONS].has_value(), "lxu_cache_locations should have value."); const auto lxu_cache_locations = aux_tensor[IDX_LXU_CACHE_LOCATIONS].value(); const auto is_experimental = aux_bool[IDX_IS_EXPERIMENTAL_TBE]; + const auto mixed_D = static_cast(aux_bool[IDX_MIXED_D]); {%- endif %} // Default values for Dynamo tracing @@ -857,7 +858,7 @@ class {{ autograd_func }} : {%- if not nobag %} ctx->saved_data["max_D"] = max_D; - ctx->saved_data["mixed_D"] = static_cast(aux_bool[IDX_MIXED_D]); + ctx->saved_data["mixed_D"] = mixed_D; ctx->saved_data["pooling_mode"] = pooling_mode; {%- else %} ctx->saved_data["D"] = D; @@ -1059,7 +1060,25 @@ static torch::autograd::variable_list backward( #ifdef USE_ROCM constexpr int32_t BT_block_size = 64; - constexpr int32_t max_segment_length_per_warp = 64; + int32_t max_segment_length_per_warp = 64; + int32_t total_L = indices.numel(); + {%- if (not nobag) and + (optimizer == "rowwise_adagrad") and + (not vbe) and + (not is_gwd) and + (not ssd) and + (not is_index_select) and + (not dense) %} + const auto T = weights_offsets.sym_numel(); + auto total_B = (offsets.size(0) - 1); + const auto B = total_B / T; + {%- for kDimSize in [64, 128, 160, 192, 256, 320] %} + if(!mixed_D && total_L / total_B > 1 && (max_D == {{ kDimSize }})) + { + max_segment_length_per_warp = 16384; + } + {%- endfor %} + {%- endif %} #else constexpr int32_t BT_block_size = 32; constexpr int32_t max_segment_length_per_warp = 32; diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index 0a2a7ab0a1..4f1741b3dc 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -820,7 +820,7 @@ def __init__( # noqa C901 assert ( self.pooling_mode != PoolingMode.NONE ), "Mixed dimension tables only supported for pooling tables." - + self.mixed_D = mixed_D assert all( cd == compute_devices[0] for cd in compute_devices ), "Heterogenous compute_devices are NOT supported!" @@ -2517,6 +2517,7 @@ def forward( # noqa: C901 row_counter, iter_int, self.max_counter.item(), + mixed_D=self.mixed_D, ), ) elif self._used_rowwise_adagrad_with_global_weight_decay: @@ -2535,6 +2536,7 @@ def forward( # noqa: C901 # `Optional[Tensor]` but got `Union[Module, Tensor]`. 
prev_iter_dev=self.prev_iter_dev, gwd_lower_bound=self.gwd_lower_bound, + mixed_D=self.mixed_D, ), ) else: @@ -2544,6 +2546,7 @@ def forward( # noqa: C901 common_args, self.optimizer_args, momentum1, + mixed_D=self.mixed_D, ), ) diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h index b55fd72fce..447613c5fc 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h @@ -38,7 +38,7 @@ namespace fbgemm_gpu::rocm { [[nodiscard]] inline bool is_supported_cdna() { - const std::set supported_archs{"gfx942", "gfx90a"}; + const std::set supported_archs{"gfx942", "gfx90a", "gfx950"}; int device_id = 0; HIP_CHECK(hipGetDevice(&device_id)); hipDeviceProp_t dev_props; diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index b3a56c4b52..5475f74ddd 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -21,9 +21,14 @@ * ******************************************************************************/ #pragma once + +#include #include +#include + #include #include +#include /******************************************************************************/ typedef int32_t int32x4_t __attribute__((ext_vector_type(4))); @@ -46,10 +51,10 @@ union amdgcn_buffer_resource { }; template -__device__ int32x4_t amdgcn_make_buffer_resource(const T* addr) { +__device__ int32x4_t amdgcn_make_buffer_resource(const T* addr, const int32_t size_in_bytes = 0xFFFFFFFF) { amdgcn_buffer_resource buffer_resource; buffer_resource.address = const_cast(addr); - buffer_resource.range = 0xffffffff; + buffer_resource.range = size_in_bytes; buffer_resource.config = AMDGCN_BUFFER_RES_3; // for gfx9 return buffer_resource.content; @@ -59,34 +64,70 @@ __device__ int32x4_t amdgcn_make_buffer_resource(const T* addr) { __device__ half llvm_amdgcn_raw_buffer_load_fp16( int32x4_t srsrc, int32_t voffset, - int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16"); + int32_t soffset = 0, + int32_t glc_slc = 0) +#if ROCM_VERSION_MAJOR >= 7 + __asm("llvm.amdgcn.raw.buffer.load.i16"); +#else + __asm("llvm.amdgcn.raw.buffer.load.f16"); +#endif __device__ float llvm_amdgcn_raw_buffer_load_fp32( int32x4_t srsrc, int32_t voffset, - int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32"); + int32_t soffset = 0, + int32_t glc_slc = 0) __asm("llvm.amdgcn.raw.buffer.load.f32"); __device__ half2 llvm_amdgcn_raw_buffer_load_fp16x2( int32x4_t srsrc, int32_t voffset, - int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16"); + int32_t soffset = 0, + int32_t glc_slc = 0) +#if ROCM_VERSION_MAJOR >= 7 + __asm("llvm.amdgcn.raw.buffer.load.i32"); +#else + __asm("llvm.amdgcn.raw.buffer.load.v2f16"); +#endif + +__device__ void llvm_amdgcn_raw_buffer_store_fp16( + const half vdata, + int32x4_t rsrc, + int32_t voffset, + int32_t soffset = 0, + int32_t glc_slc = 0 +) +#if ROCM_VERSION_MAJOR >= 7 + __asm("llvm.amdgcn.raw.buffer.store.i16"); +#else + __asm("llvm.amdgcn.raw.buffer.store.f16"); +#endif + +__device__ void llvm_amdgcn_raw_buffer_store_fp16x2( + const half2 vdata, + int32x4_t rsrc, + int32_t voffset, + int32_t soffset = 0, + int32_t glc_slc = 0 +) +#if ROCM_VERSION_MAJOR >= 7 + __asm("llvm.amdgcn.raw.buffer.store.i32"); +#else + __asm("llvm.amdgcn.raw.buffer.store.v2f16"); +#endif __device__ void 
llvm_amdgcn_raw_buffer_store_fp32( float vdata, int32x4_t rsrc, int32_t voffset, - int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32"); + int32_t soffset = 0, + int32_t glc_slc = 0) __asm("llvm.amdgcn.raw.buffer.store.f32"); __device__ void llvm_amdgcn_raw_buffer_store_fp32x2( floatx2_t vdata, int32x4_t rsrc, int32_t voffset, - int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32"); + int32_t soffset = 0, + int32_t glc_slc = 0) __asm("llvm.amdgcn.raw.buffer.store.v2f32"); /******************************************************************************/ @@ -96,35 +137,15 @@ struct load_row_per_warp { emb_t* emb_data, index_t row_index, const emb_t* p_emb_table, - int lane_id) {} -}; - -template -struct load_row_per_warp { - static constexpr int dword_per_row = - (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW; - static __device__ void run( - float* emb_data, - index_t row_index, - const float* p_emb_table, int lane_id) { - int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * embedding_dim); -#pragma unroll - for (int i = 0; i < dword_per_row; i++) { - if constexpr (embedding_dim == 160) { - if ((lane_id + i * THREADS_PER_ROW) % 192 < 160) { - emb_data[i] = llvm_amdgcn_raw_buffer_load_fp32( - emb_res, (lane_id + i * THREADS_PER_ROW) * sizeof(float), 0, 0); + // Types are not supported, but we need an instance of run method to avoid run-time .so symbol + // failure. Currently, the kernel dispatch for unsupported type is guarded on host side + if constexpr (std::is_same_v || std::is_same_v) { + __builtin_trap(); } else { - emb_data[i] = 0.f; + static_assert(false, "HIP: Optimized load operation is not supported yet"); } - } else { - emb_data[i] = llvm_amdgcn_raw_buffer_load_fp32( - emb_res, (lane_id + i * THREADS_PER_ROW) * sizeof(float), 0, 0); } - } - } }; template @@ -134,7 +155,7 @@ struct load_row_per_warp { int32x4_t emb_res = amdgcn_make_buffer_resource(p_emb_table + row_index * 64); emb_data[0] = - llvm_amdgcn_raw_buffer_load_fp16(emb_res, lane_id * sizeof(half), 0, 0); + llvm_amdgcn_raw_buffer_load_fp16(emb_res, lane_id * sizeof(half)); } }; @@ -145,7 +166,7 @@ struct load_row_per_warp { int32x4_t emb_res = amdgcn_make_buffer_resource(p_emb_table + row_index * 128); *reinterpret_cast(emb_data) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); + emb_res, lane_id * sizeof(half2)); } }; @@ -154,15 +175,11 @@ struct load_row_per_warp { static __device__ void run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 192); + amdgcn_make_buffer_resource(p_emb_table + row_index * 160, sizeof(half) * 160); *reinterpret_cast(emb_data) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); - if ((lane_id + 128) % 192 < 160) { + emb_res, lane_id * sizeof(half2)); emb_data[2] = llvm_amdgcn_raw_buffer_load_fp16( - emb_res, (lane_id + 128) * sizeof(half), 0, 0); - } else { - emb_data[2] = __float2half(0.0); - } + emb_res, (lane_id + 128) * sizeof(half)); } }; @@ -173,9 +190,9 @@ struct load_row_per_warp { int32x4_t emb_res = amdgcn_make_buffer_resource(p_emb_table + row_index * 192); *reinterpret_cast(emb_data) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); + emb_res, lane_id * sizeof(half2)); emb_data[2] = llvm_amdgcn_raw_buffer_load_fp16( - emb_res, (lane_id + 128) * sizeof(half), 0, 0); + emb_res, (lane_id + 128) * sizeof(half)); } 
}; @@ -187,31 +204,133 @@ struct load_row_per_warp { amdgcn_make_buffer_resource(p_emb_table + row_index * 256); *reinterpret_cast(&emb_data[0]) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); + emb_res, lane_id * sizeof(half2)); *reinterpret_cast(&emb_data[2]) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, (lane_id + 64) * sizeof(half2), 0, 0); + emb_res, (lane_id + 64) * sizeof(half2)); } }; template -struct load_row_per_warp { +struct load_row_per_warp { static __device__ void run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 512); + amdgcn_make_buffer_resource(p_emb_table + row_index * 320, sizeof(half) * 320); *reinterpret_cast(&emb_data[0]) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); + emb_res, lane_id * sizeof(half2)); *reinterpret_cast(&emb_data[2]) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, (lane_id + 64) * sizeof(half2), 0, 0); - *reinterpret_cast(&emb_data[4]) = - llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, (lane_id + 64 * 2) * sizeof(half2), 0, 0); - *reinterpret_cast(&emb_data[6]) = - llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, (lane_id + 64 * 3) * sizeof(half2), 0, 0); + emb_res, (lane_id + 64) * sizeof(half2)); + emb_data[4] = llvm_amdgcn_raw_buffer_load_fp16( + emb_res, (lane_id + 256) * sizeof(half)); + } +}; + +template +struct load_row_per_warp { + static __device__ void run( + c10::Half* emb_data, + index_t row_index, + const c10::Half* p_emb_table, + int lane_id) { + load_row_per_warp::run( + reinterpret_cast(emb_data), + row_index, + reinterpret_cast(p_emb_table), + lane_id + ); + } +}; + +template +struct load_row_per_warp { + static __device__ void + run(float* emb_data, index_t row_index, const float* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 64); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + } +}; + +template +struct load_row_per_warp { + static __device__ void + run(float* emb_data, index_t row_index, const float* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 128); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + emb_data[1] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 64) * sizeof(float)); + } +}; + +template +struct load_row_per_warp { + static __device__ void + run(float* emb_data, index_t row_index, const float* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 160, sizeof(float) * 160); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + emb_data[1] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 64) * sizeof(float)); + emb_data[2] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 128) * sizeof(float)); + } +}; + +template +struct load_row_per_warp { + static __device__ void + run(float* emb_data, index_t row_index, const float* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 192); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + emb_data[1] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 64) * sizeof(float)); + emb_data[2] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 128) * sizeof(float)); + } +}; + +template +struct load_row_per_warp {
+ static __device__ void + run(float* emb_data, index_t row_index, const float* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 256); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + emb_data[1] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 64) * sizeof(float)); + emb_data[2] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 128) * sizeof(float)); + emb_data[3] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 192) * sizeof(float)); + } +}; + +template +struct load_row_per_warp { + static __device__ void + run(float* emb_data, index_t row_index, const float* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 320, sizeof(float) * 320); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + emb_data[1] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 64) * sizeof(float)); + emb_data[2] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 128) * sizeof(float)); + emb_data[3] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 192) * sizeof(float)); + emb_data[4] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 256) * sizeof(float)); } }; @@ -233,93 +352,156 @@ struct accumulate_row_per_warp { } else { #pragma unroll for (int i = 0; i < dword_per_row; i++) { - acc[i] += static_cast((float)emb_data[i] * row_weight); + if constexpr (std::is_same_v) + { + acc[i] += static_cast(__half2float(emb_data[i]) * row_weight); + } + else + { + acc[i] += static_cast(static_cast(emb_data[i]) * row_weight); + } } } } }; -template +template struct store_row_per_warp { - static constexpr int dword_per_row = - (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW; - static __device__ void run(output_t* acc, output_t* p_output, int lane_id) { - if constexpr (embedding_dim == 160) { - for (int i = 0; i < dword_per_row; i++) { - if ((lane_id + i * THREADS_PER_ROW) % 192 < 160) { - p_output[lane_id + i * THREADS_PER_ROW] = acc[i]; - } - } + static __device__ void run(const emb_t* acc, emb_t* p_output, int lane_id) { + // Types are not supported, but we need an instance of run method to avoid run-time .so symbol + // failure. 
Currently, the kernel dispatch for unsupported type is guarded on host function + if constexpr (std::is_same_v || std::is_same_v) { + __builtin_trap(); } else { -#pragma unroll - for (int i = 0; i < dword_per_row; i++) { - p_output[lane_id + i * THREADS_PER_ROW] = acc[i]; - } + static_assert(false, "HIP: Optimized load operation is not supported yet"); } } }; template <> -struct store_row_per_warp { - static __device__ void run(float* acc, float* p_output, int lane_id) { +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), - out_res, - lane_id * sizeof(floatx2_t), - 0, - 0); + llvm_amdgcn_raw_buffer_store_fp16(acc[0], out_res, lane_id * sizeof(half)); } }; template <> -struct store_row_per_warp { - static __device__ void run(float* acc, float* p_output, int lane_id) { +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), - out_res, - lane_id * sizeof(floatx2_t), - 0, - 0); - if ((lane_id + 128) % 192 < 160) { - llvm_amdgcn_raw_buffer_store_fp32( - acc[2], out_res, (lane_id + 128) * sizeof(float), 0, 0); - } + llvm_amdgcn_raw_buffer_store_fp16x2(*reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output, 160 * sizeof(half)); + llvm_amdgcn_raw_buffer_store_fp16x2(*reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); + llvm_amdgcn_raw_buffer_store_fp16(acc[2], out_res, (lane_id + 128) * sizeof(half)); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output); + llvm_amdgcn_raw_buffer_store_fp16x2(*reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); + llvm_amdgcn_raw_buffer_store_fp16(acc[2], out_res, (lane_id + 128) * sizeof(half)); } }; template <> -struct store_row_per_warp { - static __device__ void run(float* acc, float* p_output, int lane_id) { +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), - out_res, - lane_id * sizeof(floatx2_t), - 0, - 0); - llvm_amdgcn_raw_buffer_store_fp32( - acc[2], out_res, (lane_id + 128) * sizeof(float), 0, 0); + llvm_amdgcn_raw_buffer_store_fp16x2(*reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); + llvm_amdgcn_raw_buffer_store_fp16x2(*reinterpret_cast(acc + 2), out_res, (lane_id + 64) * sizeof(half2)); } }; template <> -struct store_row_per_warp { - static __device__ void run(float* acc, float* p_output, int lane_id) { +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output, 320 * sizeof(half)); + llvm_amdgcn_raw_buffer_store_fp16x2(*reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); + llvm_amdgcn_raw_buffer_store_fp16x2(*reinterpret_cast(acc + 2), out_res, (lane_id + 64) * sizeof(half2)); + llvm_amdgcn_raw_buffer_store_fp16(acc[4], out_res, (lane_id + 256) * sizeof(half)); + } +}; + 
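For context on the addressing used by the load_row_per_warp / store_row_per_warp specializations in this header: each 64-lane wavefront owns one embedding row, element k of the row is handled by lane k % 64 at index lane_id + 64 * (k / 64), and for dimensions whose last chunk would run past the end of the row (for example D = 160) the buffer resource is now created with the row size so the hardware bounds check handles the overshooting lanes instead of the old explicit modulo tests. The host-side sketch below is illustrative only and not part of the patch; it simply prints the highest index each supported dimension generates under this scheme.

#include <cstdio>
#include <initializer_list>

int main() {
  constexpr int kLanes = 64;  // wavefront width assumed by these specializations
  for (int D : {64, 128, 160, 192, 256, 320}) {
    const int elems_per_lane = (D + kLanes - 1) / kLanes;
    // Highest row index produced by lane 63 in its last chunk.
    const int last_index = (elems_per_lane - 1) * kLanes + (kLanes - 1);
    std::printf("D=%3d  elems/lane=%d  last index=%3d  %s\n",
                D, elems_per_lane, last_index,
                last_index < D ? "inside the row"
                               : "relies on the row-sized buffer bound");
  }
  return 0;
}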
+template +struct store_row_per_warp { + static __device__ void run( + const c10::Half* emb_data, + c10::Half* p_emb_table, + int lane_id) { + store_row_per_warp::run( + reinterpret_cast(emb_data), + reinterpret_cast(p_emb_table), + lane_id + ); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const float* acc, float* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const float* acc, float* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), - out_res, - lane_id * sizeof(floatx2_t), - 0, - 0); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[1], out_res, (lane_id + 64) * sizeof(float)); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const float* acc, float* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output, sizeof(float) * 160); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[1], out_res, (lane_id + 64) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[2], out_res, (lane_id + 128) * sizeof(float)); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const float* acc, float* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), - out_res, - lane_id * sizeof(floatx2_t), - 0, - 0); - llvm_amdgcn_raw_buffer_store_fp32( - acc[2], out_res, (lane_id + 128) * sizeof(float), 0, 0); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[1], out_res, (lane_id + 64) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[2], out_res, (lane_id + 128) * sizeof(float)); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const float* acc, float* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), - out_res, - lane_id * sizeof(floatx2_t), - 0, - 0); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(&acc[2]), - out_res, - (lane_id + 64) * sizeof(floatx2_t), - 0, - 0); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[1], out_res, (lane_id + 64) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[2], out_res, (lane_id + 128) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[3], out_res, (lane_id + 192) * sizeof(float)); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const float* acc, float* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output, sizeof(float) * 320); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[1], out_res, (lane_id + 64) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[2], out_res, (lane_id + 128) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[3], out_res, (lane_id + 192) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[4], out_res, (lane_id + 256) * sizeof(float)); + } +}; @@ -471,7 +653,7 @@ __device__ __forceinline__ void generic_dpp_reduction(data_t& result) { // of trivial operation with an option to use custom operation template __device__ __forceinline__ void dpp_reduction(data_t& result) { -#if defined(__gfx942__) || defined(__gfx90a__) +#if defined(__gfx942__) || defined(__gfx90a__) || defined(__gfx950__) if 
constexpr (std::is_same_v) { DPP_REDUCE_F16_F32(add); return; diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh b/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh old mode 100644 new mode 100755 index 0d65c4798a..d51e3fa475 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh @@ -21,7 +21,9 @@ #include #endif #include - +#ifdef USE_ROCM +#include "fbgemm_gpu/rocm/split_embeddings_common.h" +#endif namespace { inline int get_device_sm_cnt_() { @@ -138,11 +140,19 @@ template DEVICE_INLINE T warpReduceAllSum( T val, unsigned shfl_sync_mask = static_cast(kFullWarpMask)) { -#pragma unroll - for (int mask = ReduceWidth / 2; mask > 0; mask >>= 1) { - val += shfl_xor(val, mask, ReduceWidth, shfl_sync_mask); - } - return val; + #ifdef USE_ROCM + return rocm::wave_reduce< + rocm::reduce_op::sum, // Sum reduction + T, // Data type + ReduceWidth // Wave/Warp size + >(val); + #else + #pragma unroll + for (int mask = ReduceWidth / 2; mask > 0; mask >>= 1) { + val += shfl_xor(val, mask, ReduceWidth, shfl_sync_mask); + } + return val; + #endif } DEVICE_INLINE void syncwarp() { diff --git a/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp b/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp old mode 100644 new mode 100755
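As a closing note on the kernel-side simplification above, which drops batch_mdiv, batch and num_rows from split_tbe_backward_hip_kernel_*: the sorted info words are now decoded with the same bit layout the bagged path already used, with the table index in the high bits and the bag index in the low info_B_num_bits bits, so the magic-division helper is no longer needed. The sketch below is a minimal host-side illustration of that decode; the concrete values (26 bits, the sample table/bag pair) are assumptions for the example only, not taken from the patch.

#include <cstdint>
#include <cstdio>

int main() {
  // Assumed layout: bag index in the low bits, table index in the high bits.
  const int32_t info_B_num_bits = 26;                       // example value only
  const uint32_t info_B_mask = (1u << info_B_num_bits) - 1;

  // Pack a hypothetical (table, bag) pair the way the kernel expects it.
  const uint32_t table_index = 3;
  const uint32_t bag_index = 12345;
  const uint32_t info = (table_index << info_B_num_bits) | bag_index;

  // Decode as in the updated kernel, replacing the magic_div_u32_run_with_mod path.
  const uint32_t t = info >> info_B_num_bits;
  const uint32_t b = info & info_B_mask;
  std::printf("table_index=%u bag_index=%u\n", t, b);
  return 0;
}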