diff --git a/fbgemm_gpu/cmake/Hip.cmake b/fbgemm_gpu/cmake/Hip.cmake index 17640b7254..2011a34c33 100644 --- a/fbgemm_gpu/cmake/Hip.cmake +++ b/fbgemm_gpu/cmake/Hip.cmake @@ -78,6 +78,14 @@ if(HIP_FOUND) list(APPEND HIP_CXX_FLAGS -mf16c) list(APPEND HIP_CXX_FLAGS -mfma) list(APPEND HIP_CXX_FLAGS -std=c++20) + list(APPEND HIP_CXX_FLAGS -g) + list(APPEND HIP_CXX_FLAGS -ggdb) + + # list(APPEND HIP_CXX_FLAGS -Wa,-adhln) + #list(APPEND HIP_CXX_FLAGS -adhln) + list(APPEND HIP_CXX_FLAGS -save-temps) + list(APPEND HIP_CXX_FLAGS -fverbose-asm) + set(HIP_HCC_FLAGS ${HIP_CXX_FLAGS}) # Ask hcc to generate device code during compilation so we can use diff --git a/fbgemm_gpu/cmake/tbe_sources.py b/fbgemm_gpu/cmake/tbe_sources.py index 31200b6190..dc3acace35 100644 --- a/fbgemm_gpu/cmake/tbe_sources.py +++ b/fbgemm_gpu/cmake/tbe_sources.py @@ -176,7 +176,6 @@ "_nobag" if nobag else "", ) for nobag in [ - True, False, ] for weighted in ( @@ -495,7 +494,6 @@ "_nobag" if nobag else "", ) for nobag in [ - True, False, ] for weighted in ( diff --git a/fbgemm_gpu/codegen/genscript/generate_backward_split.py b/fbgemm_gpu/codegen/genscript/generate_backward_split.py index f04f254acc..e1325eb352 100644 --- a/fbgemm_gpu/codegen/genscript/generate_backward_split.py +++ b/fbgemm_gpu/codegen/genscript/generate_backward_split.py @@ -52,7 +52,11 @@ def render_backward_templates( return weighted_options = [True, False] - nobag_options = [True, False] if (not is_gwd) else [False] + nobag_options = ( + [True, False] + if (not (is_gwd or kwargs.get("is_hip_optimized_backward"))) + else [False] + ) vbe_options = [True, False] if (kwargs.get("has_vbe_support")) else [False] ssd_options = [True, False] if kwargs.get("has_ssd_support") else [False] template = CodeTemplate.load(template_filepath) @@ -327,8 +331,7 @@ def generate_backward_indices() -> None: @staticmethod def generate_rocm_backward_split(**kwargs: Any) -> None: - # Generate backward device kernels based on weighted (True/False), VBE - # (True/False), no bag (True/False) + # Generate backward device kernels based on weighted (True/False) template_filepath = ( "training/backward/rocm/embedding_backward_split_device_kernel_template.hip" ) @@ -343,6 +346,7 @@ def generate_rocm_backward_split(**kwargs: Any) -> None: "has_ssd_support": False, "dense": False, "gen_once": False, + "is_hip_optimized_backward": True, }, ) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index 1d48ad0cd2..e458ca26e5 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -32,6 +32,14 @@ {%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %} {%- set locs_or_addrs_type = "int64_t" if ssd else "int32_t" %} +{%- set is_optimized_hip_kernel_supported_mode = is_rocm and + optimizer == "rowwise_adagrad" and + not dense and + not nobag and + not is_index_select and + not is_gwd_kernel and + not vbe and + not ssd %} #include "fbgemm_gpu/embedding_backward_template_helpers.cuh" #include "fbgemm_gpu/utils/tensor_accessor_builder.h" @@ -538,7 +546,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row {%- endif %} -{%- if is_rocm and not is_index_select and optimizer == "rowwise_adagrad" and not dense and not is_gwd_kernel and not vbe and not ssd %} +{%- if 
is_optimized_hip_kernel_supported_mode %} #include #include #include "fbgemm_gpu/rocm/split_embeddings_common.h" @@ -612,12 +620,8 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd {{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }} {%- endif %} ) { - {%- if not nobag %} int32_t T = D_offsets.size(0) - 1; - {%- else %} - int32_t T = weights_offsets.size(0); - {%- endif %} - + auto p_output_grad = grad_output.data(); auto p_emb_table = dev_weights.data(); auto p_hash_size_cumsum = hash_size_cumsum.data(); @@ -632,8 +636,6 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd constexpr int32_t segment_prefetch = 2; constexpr int32_t segment_unroll = 8; constexpr int32_t segment_split = 0; - auto batch = grad_output.size(0); - auto num_rows = dev_weights.size(0) / T / max_D; {%- if weighted %} constexpr bool is_weighted = true; {%- else %} @@ -646,22 +648,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd // weight_decay(_mode) is supplied as args.split_function_args_no_defaults opt_karg.weight_decay_mode = weight_decay_mode_v; opt_karg.weight_decay = weight_decay; - auto batch_mdiv = [](uint32_t d) -> rocm::magic_div_u32_t { - assert(d >= 1 && d <= INT32_MAX); - uint8_t shift; - for(shift = 0; shift < 32; shift++) - if((1U << shift) >= d) - break; - - uint64_t one = 1; - uint64_t magic = ((one << 32) * ((one << shift) - d)) / d + 1; - assert(magic <= 0xffffffffUL); - - rocm::magic_div_u32_t result; - result.magic = magic; - result.shift = shift; - return result; - }(batch); + rocm::split_tbe_backward_hip_kernel_{{kdesc}}< rocm::{{optimizer}}_optimizer_t, rocm::{{optimizer}}_kernel_arg_t, @@ -680,16 +667,11 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd p_sorted_linear_indices_run, p_sorted_linear_indices_cumulative_run_lengths, p_sorted_linear_indices_num_runs, - {%- if not nobag %} info_B_num_bits, info_B_mask, - {%- endif %} p_sorted_infos, - batch_mdiv, max_segment_length_per_warp, emb_dim, - batch, - num_rows, T, opt_karg {%- if weighted %} @@ -780,11 +762,11 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd {%- endmacro %} {%- macro hip_bulk_template_instantiations(kFixedMaxVecsPerThread, kThreadGroupSize, kUseVecBlocking) %} - {%- for grad_type in ['float', 'at::Half', 'at::BFloat16'] %} + {%- for grad_type in ['float', 'at::Half'] %} {%- for emb_type in ['float', 'at::Half'] %} {%- for cache_type in ['float', 'at::Half'] %} {%- for index_type in ['int32_t', 'int64_t'] %} - {%- for kEmbeddingDim in [64, 128, 160, 192, 256] %} + {%- for kEmbeddingDim in [64, 128, 160, 192, 256, 320] %} {%- for kWeighDecayMode in [0, 1, 2] %} {{ hip_template_instantiation( emb_type, diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 186a9d529f..c050e0329b 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -48,6 +48,15 @@ using namespace fbgemm_gpu; has_global_weight_decay_support, ssd) %} {%- set desc_suffix = get_desc_suffix(is_gwd_kernel) %} +{%- set is_optimized_hip_kernel_supported_mode = is_rocm and + optimizer == "rowwise_adagrad" and + not dense and + not nobag and + not is_index_select and + not is_gwd_kernel and + not vbe and + not ssd %} + template < typename emb_t, typename 
grad_t, @@ -227,8 +236,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( {%- endif %} ); -{%- if is_rocm and optimizer == "rowwise_adagrad" and not dense and not is_index_select - and not is_gwd_kernel and not vbe and not ssd %} +{%- if is_optimized_hip_kernel_supported_mode %} #include "fbgemm_gpu/rocm/split_embeddings_common.h" template < typename emb_t, @@ -852,8 +860,7 @@ Tensor {{ embedding_cuda_op }}( } {%- endif %} - {%- if is_rocm and optimizer == "rowwise_adagrad" and not dense and not is_index_select - and not is_gwd_kernel and not vbe and not ssd %} + {%- if is_optimized_hip_kernel_supported_mode %} {%- set hip_kernel = "hip_split_embedding{}_backward_codegen_{}_{}{}_kernel_warp_per_row_1".format( ndesc, optimizer, @@ -1187,18 +1194,17 @@ Tensor {{ embedding_cuda_op }}( get_max_thread_blocks_()); #ifdef USE_ROCM - {%- if is_rocm and not is_index_select and optimizer == "rowwise_adagrad" and - not dense and not is_gwd_kernel and not vbe and not ssd and not nobag %} + {%- if is_optimized_hip_kernel_supported_mode %} const static auto use_hip_kernel = fbgemm_gpu::config::is_feature_enabled(fbgemm_gpu::config::FeatureGateName::TBE_ROCM_HIP_BACKWARD_KERNEL); const auto supported_weights_type = dev_weights.scalar_type() == at::ScalarType::Half || dev_weights.scalar_type() == at::ScalarType::Float; - if (use_hip_kernel && supported_weights_type && !mixed_D && rocm::is_supported_cdna()) + if (use_hip_kernel && !mixed_D && supported_weights_type && rocm::is_supported_cdna()) { constexpr int segments_per_workgroup = 4; - {%- for kDimSize in [64, 128, 160, 192, 256] %} + {%- for kDimSize in [64, 128, 160, 192, 256, 320] %} {%- for kWeightDecayMode in [0, 1, 2] %} if (max_D == {{ kDimSize }} && weight_decay_mode == {{ kWeightDecayMode }}) { diff --git a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip index 2fcbba395e..97758db773 100644 --- a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip +++ b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip @@ -122,20 +122,11 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( const index_t* p_sorted_linear_indices_run, const int32_t* p_sorted_linear_indices_cumulative_run_lengths, const int32_t* p_sorted_linear_indices_num_runs, - {%- if not nobag %} const int32_t info_B_num_bits, const uint32_t info_B_mask, - {%- endif %} - {%- if not nobag %} const int32_t* p_sorted_infos, - {%- else %} - const int64_t* p_sorted_infos, - {%- endif %} - magic_div_u32_t batch_mdiv, uint32_t max_segment_length_per_warp, uint32_t emb_dim, - uint32_t batch, - uint32_t num_rows, uint32_t num_tables, optimizer_karg_t opt_karg, const float * p_sorted_indice_weights = nullptr) @@ -157,13 +148,9 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( const int32_t segment_start = p_sorted_linear_indices_cumulative_run_lengths[run_id]; const int32_t segment_end = p_sorted_linear_indices_cumulative_run_lengths[run_id + 1]; - {%- if nobag %} - const auto info_0 = p_sorted_infos[segment_start]; - int32_t t_0 = info_0 % num_tables; - {%- else %} const auto info_0 = reinterpret_cast(&p_sorted_infos[0])[segment_start]; const auto t_0 = info_0 >> info_B_num_bits; - {%- endif %} + int64_t hash_size = p_hash_size_cumsum[t_0]; const int64_t emb_idx = linear_index - hash_size; @@ -179,7 +166,7 @@ __device__ void 
split_tbe_backward_hip_kernel_{{ kdesc }}( const int32_t segment_length_mod = segment_length & length_mask; cache_t grad_acc[dword_per_row]; - int32_t infos[segment_unroll]; + uint32_t infos[segment_unroll]; grad_t grad_data[dword_per_row * segment_prefetch]; emb_t emb_data[dword_per_row]; float indice_weights[segment_unroll]; @@ -221,22 +208,16 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( // LOOP for(; itr < segment_length_mod; itr += segment_unroll) { - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); - {%- else %} table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[1], batch, table_index, bag_index); - {%- else %} table_index = infos[1] >> info_B_num_bits; bag_index = infos[1] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); if constexpr (!weighted){ #pragma unroll @@ -244,24 +225,20 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); - {%- else %} + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); - {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -284,24 +261,20 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id, indice_weights[j-2]); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); - {%- else %} + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[j-1]); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); - {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -322,22 +295,17 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( } // LAST - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); - {%- else %} table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; - {%- endif %} - 
load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[1], batch, table_index, bag_index); - {%- else %} + table_index = infos[1] >> info_B_num_bits; bag_index = infos[1] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); if constexpr (!weighted) { @@ -346,24 +314,20 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); - {%- else %} + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); - {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -377,24 +341,20 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id, indice_weights[j-2]); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); - {%- else %} + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[j-1]); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); - {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -414,14 +374,13 @@ L_tail_grad_acc: infos[0] = p_sorted_infos[segment_start]; p_sorted_infos++; - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); - {%- else %} table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; - {%- endif %} - load_row_per_warp::run( - &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + + // load_row_per_warp::run( + // &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + load_row_per_warp_v2::run(&grad_data[0], bag_index * num_tables, + p_output_grad + table_index, lane_id, embedding_dim); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id); @@ -435,14 +394,13 @@ L_tail_grad_acc: p_sorted_infos++; p_sorted_indice_weights++; - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); - {%- else %} table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] 
& info_B_mask; - {%- endif %} - load_row_per_warp::run( - &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + + // load_row_per_warp::run( + // &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + load_row_per_warp_v2::run(&grad_data[0], bag_index * num_tables, + p_output_grad + table_index, lane_id, embedding_dim); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id, indice_weights[0]); @@ -452,8 +410,9 @@ L_tail_grad_acc: } // load the old emb weight data - load_row_per_warp::run( - &emb_data[0], emb_idx, p_emb_table, lane_id); + // load_row_per_warp::run( + // &emb_data[0], emb_idx, p_emb_table, lane_id); + load_row_per_warp_v2::run(emb_data, emb_idx, p_emb_table, lane_id, embedding_dim); optimizer_t optimizer(opt_karg); optimizer.template update(grad_acc, emb_data, emb_idx); diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index 3720f1ea42..20c055e917 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -698,6 +698,7 @@ class {{ autograd_func }} : TORCH_CHECK(aux_tensor[IDX_LXU_CACHE_LOCATIONS].has_value(), "lxu_cache_locations should have value."); const auto lxu_cache_locations = aux_tensor[IDX_LXU_CACHE_LOCATIONS].value(); const auto is_experimental = aux_bool[IDX_IS_EXPERIMENTAL_TBE]; + const auto mixed_D = aux_bool[IDX_MIXED_D]; {%- endif %} // Default values for Dynamo tracing diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index fe8fad0af1..d69d685136 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -808,7 +808,7 @@ def __init__( # noqa C901 assert ( self.pooling_mode != PoolingMode.NONE ), "Mixed dimension tables only supported for pooling tables." - + self.mixed_D = mixed_D assert all( cd == compute_devices[0] for cd in compute_devices ), "Heterogenous compute_devices are NOT supported!" @@ -2262,6 +2262,7 @@ def forward( # noqa: C901 row_counter, iter_int, self.max_counter.item(), + mixed_D=self.mixed_D, ), ) elif self._used_rowwise_adagrad_with_global_weight_decay: @@ -2280,6 +2281,7 @@ def forward( # noqa: C901 # `Optional[Tensor]` but got `Union[Module, Tensor]`. 
prev_iter_dev=self.prev_iter_dev, gwd_lower_bound=self.gwd_lower_bound, + mixed_D=self.mixed_D, ), ) else: @@ -2289,6 +2291,7 @@ def forward( # noqa: C901 common_args, self.optimizer_args, momentum1, + mixed_D=self.mixed_D, ), ) diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h index b55fd72fce..447613c5fc 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h @@ -38,7 +38,7 @@ namespace fbgemm_gpu::rocm { [[nodiscard]] inline bool is_supported_cdna() { - const std::set supported_archs{"gfx942", "gfx90a"}; + const std::set supported_archs{"gfx942", "gfx90a", "gfx950"}; int device_id = 0; HIP_CHECK(hipGetDevice(&device_id)); hipDeviceProp_t dev_props; diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index b3a56c4b52..384e619b22 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2016 - 2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2016 - 2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -60,7 +60,12 @@ __device__ half llvm_amdgcn_raw_buffer_load_fp16( int32x4_t srsrc, int32_t voffset, int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16"); + int32_t glc_slc) +#if defined(__gfx950__) + __asm("llvm.amdgcn.raw.buffer.load.i16"); +#else + __asm("llvm.amdgcn.raw.buffer.load.f16"); +#endif __device__ float llvm_amdgcn_raw_buffer_load_fp32( int32x4_t srsrc, @@ -68,11 +73,23 @@ __device__ float llvm_amdgcn_raw_buffer_load_fp32( int32_t soffset, int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32"); +__device__ float2 llvm_amdgcn_raw_buffer_load_fp32x2( + int32x4_t srsrc, + int32_t voffset, + int32_t soffset, + int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f32"); + + __device__ half2 llvm_amdgcn_raw_buffer_load_fp16x2( int32x4_t srsrc, int32_t voffset, int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16"); + int32_t glc_slc) +#if defined(__gfx950__) + __asm("llvm.amdgcn.raw.buffer.load.i32"); +#else + __asm("llvm.amdgcn.raw.buffer.load.v2f16"); +#endif __device__ void llvm_amdgcn_raw_buffer_store_fp32( float vdata, @@ -89,6 +106,236 @@ __device__ void llvm_amdgcn_raw_buffer_store_fp32x2( int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32"); /******************************************************************************/ +// template +// struct load_row_per_warp_v2; +// { +// __device__ __forceinline__ static void +// run(data_t *dst, index_t row_index, const data_t *src, int lane_id, int dim) { +// // static_assert(std::is_same_v || +// // std::is_same_v || +// // std::is_same_v, +// // "Template parameters are not supported"); +// } +// }; + + + +// template +// struct load_row_per_warp_v2 { +// __device__ __forceinline__ static void +// run(data_t *dst, index_t row_index, const data_t *src, int lane_id, int dim) { +// // static_assert(std::is_same_v || +// // std::is_same_v || +// // std::is_same_v, +// // "Template parameters are not supported"); +// } +// }; + +// template +// __device__ __forceinline__ void +// 
load_row_per_warp(data_t *dst, index_t row_index, const data_t *src, +// int lane_id, int dim) {} + +// template +// struct load_row_per_warp_v2< +// emb_t, k_elements_per_thread, index_t, +// typename std::enable_if_t || +// std::is_same_v, +// half>> { +// __device__ __forceinline__ static void run(emb_t *dst, index_t row_index, +// const emb_t *src, int lane_id, +// int emb_dim) { +// static_assert(k_elements_per_thread > 0); + +// const auto row_offset = row_index * emb_dim; +// int32x4_t emb_res = amdgcn_make_buffer_resource(src + row_index * emb_dim); + +// #pragma unroll +// for (int idx = 0; idx < (k_elements_per_thread - 1) / 2; ++idx) { +// *reinterpret_cast(dst + idx * 2) = +// llvm_amdgcn_raw_buffer_load_fp32x2( +// emb_res, (lane_id + idx * warpSize) * sizeof(float2), 0, 0); +// } + +// if constexpr (k_elements_per_thread % 2 == 0) { +// constexpr int k_tailing_even_idx = (k_elements_per_thread - 2); +// dst[k_tailing_even_idx] = llvm_amdgcn_raw_buffer_load_fp32( +// emb_res, (lane_id + k_tailing_even_idx * warpSize) * sizeof(float), 0, +// 0); +// } + +// constexpr int k_tailing_last_idx = k_elements_per_thread - 1; +// const auto tailing_idx = lane_id + k_tailing_last_idx * warpSize; +// dst[k_tailing_last_idx] = +// tailing_idx < emb_dim ? llvm_amdgcn_raw_buffer_load_fp32( +// emb_res, tailing_idx * sizeof(float), 0, 0) +// : 0.f; +// } +// }; + +// template +// struct load_row_per_warp_v2 { +// __device__ __forceinline__ static void +// run(float *dst, index_t row_index, const float *src, int lane_id, int emb_dim) { +// static_assert(k_elements_per_thread > 0); + +// const auto row_offset = row_index * emb_dim; +// int32x4_t emb_res = amdgcn_make_buffer_resource(src + row_index * emb_dim); + +// #pragma unroll +// for (int idx = 0; idx < (k_elements_per_thread - 1) / 2; ++idx) { +// *reinterpret_cast(dst + idx * 2) = +// llvm_amdgcn_raw_buffer_load_fp32x2( +// emb_res, (lane_id + idx * warpSize) * sizeof(float2), 0, 0); +// } + +// if constexpr (k_elements_per_thread % 2 == 0) { +// constexpr int k_tailing_even_idx = (k_elements_per_thread - 2); +// dst[k_tailing_even_idx] = llvm_amdgcn_raw_buffer_load_fp32( +// emb_res, (lane_id + k_tailing_even_idx * warpSize) * sizeof(float), 0, +// 0); +// } + +// constexpr int k_tailing_last_idx = k_elements_per_thread - 1; +// const auto tailing_idx = lane_id + k_tailing_last_idx * warpSize; +// dst[k_tailing_last_idx] = +// tailing_idx < emb_dim ? 
llvm_amdgcn_raw_buffer_load_fp32( +// emb_res, tailing_idx * sizeof(float), 0, 0) +// : 0.f; +// } +// }; + +template +struct load_row_per_warp_v2 { + __device__ __forceinline__ static void + run(float *dst, index_t row_index, const float *src, int lane_id, int emb_dim) { + static_assert(k_elements_per_thread > 0); + + const auto row_offset = row_index * emb_dim; + int32x4_t emb_res = amdgcn_make_buffer_resource(src + row_index * emb_dim); + +#pragma unroll + for (int idx = 0; idx < (k_elements_per_thread - 1) / 2; ++idx) { + *reinterpret_cast(dst + idx * 2) = + llvm_amdgcn_raw_buffer_load_fp32x2( + emb_res, (lane_id + idx * warpSize) * sizeof(float2), 0, 0); + } + + if constexpr (k_elements_per_thread % 2 == 0) { + constexpr int k_tailing_even_idx = (k_elements_per_thread - 2); + dst[k_tailing_even_idx] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + k_tailing_even_idx * warpSize) * sizeof(float), 0, + 0); + } + + constexpr int k_tailing_last_idx = k_elements_per_thread - 1; + const auto tailing_idx = lane_id + k_tailing_last_idx * warpSize; + dst[k_tailing_last_idx] = + tailing_idx < emb_dim ? llvm_amdgcn_raw_buffer_load_fp32( + emb_res, tailing_idx * sizeof(float), 0, 0) + : 0.f; + } + + __device__ __forceinline__ static void + run(half *dst, index_t row_index, const half *src, int lane_id, int emb_dim) { + static_assert(k_elements_per_thread > 0); + + const auto row_offset = row_index * emb_dim; + int32x4_t emb_res = amdgcn_make_buffer_resource(src + row_index * emb_dim); + +#pragma unroll + for (int idx = 0; idx < (k_elements_per_thread - 1) / 2; ++idx) { + *reinterpret_cast(dst + idx * 2) = + llvm_amdgcn_raw_buffer_load_fp16x2( + emb_res, (lane_id + idx * warpSize) * sizeof(half2), 0, 0); + } + + if constexpr (k_elements_per_thread % 2 == 0) { + constexpr int k_tailing_even_idx = (k_elements_per_thread - 2); + dst[k_tailing_even_idx] = llvm_amdgcn_raw_buffer_load_fp16( + emb_res, (lane_id + k_tailing_even_idx * warpSize) * sizeof(half), 0, + 0); + } + + constexpr int k_tailing_last_idx = k_elements_per_thread - 1; + const auto tailing_idx = lane_id + k_tailing_last_idx * warpSize; + dst[k_tailing_last_idx] = + tailing_idx < emb_dim ? 
llvm_amdgcn_raw_buffer_load_fp16( + emb_res, tailing_idx * sizeof(half), 0, 0) + : half(0.f); + } + + __device__ __forceinline__ static void run(c10::Half *dst, index_t row_index, + const c10::Half *src, int lane_id, + int emb_dim) { + // load_row_per_warp_v2:: + run( + reinterpret_cast(dst), row_index, + reinterpret_cast(src), lane_id, emb_dim); + } + +}; + +// template +// struct load_row_per_warp_v2 { +// __device__ __forceinline__ static void +// run(half *dst, index_t row_index, const half *src, int lane_id, int emb_dim) { +// static_assert(k_elements_per_thread > 0); + +// const auto row_offset = row_index * emb_dim; +// int32x4_t emb_res = amdgcn_make_buffer_resource(src + row_index * emb_dim); + +// #pragma unroll +// for (int idx = 0; idx < (k_elements_per_thread - 1) / 2; ++idx) { +// *reinterpret_cast(dst + idx * 2) = +// llvm_amdgcn_raw_buffer_load_fp16x2( +// emb_res, (lane_id + idx * warpSize) * sizeof(half2), 0, 0); +// } + +// if constexpr (k_elements_per_thread % 2 == 0) { +// constexpr int k_tailing_even_idx = (k_elements_per_thread - 2); +// dst[k_tailing_even_idx] = llvm_amdgcn_raw_buffer_load_fp16( +// emb_res, (lane_id + k_tailing_even_idx * warpSize) * sizeof(half), 0, +// 0); +// } + +// constexpr int k_tailing_last_idx = k_elements_per_thread - 1; +// const auto tailing_idx = lane_id + k_tailing_last_idx * warpSize; +// dst[k_tailing_last_idx] = +// tailing_idx < emb_dim ? llvm_amdgcn_raw_buffer_load_fp16( +// emb_res, tailing_idx * sizeof(half), 0, 0) +// : half(0.f); +// } +// }; + +// template +// struct load_row_per_warp_v2 { +// __device__ __forceinline__ static void run(c10::Half *dst, index_t row_index, +// const c10::Half *src, int lane_id, +// int emb_dim) { +// load_row_per_warp_v2::run( +// reinterpret_cast(dst), row_index, +// reinterpret_cast(src), lane_id, emb_dim); +// } +// }; + +// template +// struct load_row_per_warp_v2 { +// __device__ __forceinline__ static void run(c10::Half *dst, index_t row_index, +// const c10::Half *src, int lane_id, +// int emb_dim) { +// load_row_per_warp_v2::run( +// reinterpret_cast(dst), row_index, +// reinterpret_cast(src), lane_id, emb_dim); +// } +// }; +// __device__ __forceinline__ void +// load_row_per_warp( +// c10::Half *dst, index_t row_index, const c10::Half *src, int lane_id, +// int emb_dim) { +// load_row_per_warp(reinterpret_cast(dst), row_index, +// reinterpret_cast(src), lane_id, emb_dim); +// } template struct load_row_per_warp { @@ -194,6 +441,22 @@ struct load_row_per_warp { } }; +template +struct load_row_per_warp { + static __device__ void + run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 320); + *reinterpret_cast(&emb_data[0]) = + llvm_amdgcn_raw_buffer_load_fp16x2( + emb_res, lane_id * sizeof(half2), 0, 0); + *reinterpret_cast(&emb_data[2]) = + llvm_amdgcn_raw_buffer_load_fp16x2( + emb_res, (lane_id + 64) * sizeof(half2), 0, 0); + emb_data[4] = p_emb_table[row_index * 320 + 256 + lane_id]; + } +}; + template struct load_row_per_warp { static __device__ void @@ -215,6 +478,24 @@ struct load_row_per_warp { } }; +template +struct load_row_per_warp { + static __device__ void run( + c10::Half* emb_data, + index_t row_index, + const c10::Half* p_emb_table, + int lane_id) { + load_row_per_warp::run( + reinterpret_cast(emb_data), + row_index, + reinterpret_cast(p_emb_table), + lane_id + ); + } + +}; + + template < typename emb_t, int32_t embedding_dim, @@ -233,7 +514,14 @@ struct 
accumulate_row_per_warp { } else { #pragma unroll for (int i = 0; i < dword_per_row; i++) { - acc[i] += static_cast((float)emb_data[i] * row_weight); + if constexpr (std::is_same_v) + { + acc[i] += static_cast(__half2float(emb_data[i]) * row_weight); + } + else + { + acc[i] += static_cast(static_cast(emb_data[i]) * row_weight); + } } } } @@ -259,6 +547,26 @@ struct store_row_per_warp { } }; +template <> +struct store_row_per_warp { + static __device__ void run(c10::Half* acc, c10::Half* p_output, int lane_id) { + auto out = reinterpret_cast(p_output); + out[lane_id] = *reinterpret_cast(acc); + out[lane_id + 64] = *reinterpret_cast(&acc[2]); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(c10::Half* acc, c10::Half* p_output, int lane_id) { + auto out = reinterpret_cast(p_output); + out[lane_id] = *reinterpret_cast(acc); + out[lane_id + 64] = *reinterpret_cast(&acc[2]); + p_output[lane_id + 256] = acc[4]; + } +}; + + template <> struct store_row_per_warp { static __device__ void run(float* acc, float* p_output, int lane_id) { @@ -471,7 +779,7 @@ __device__ __forceinline__ void generic_dpp_reduction(data_t& result) { // of trivial operation with an option to use custom operation template __device__ __forceinline__ void dpp_reduction(data_t& result) { -#if defined(__gfx942__) || defined(__gfx90a__) +#if defined(__gfx942__) || defined(__gfx90a__) || defined(__gfx950__) if constexpr (std::is_same_v) { DPP_REDUCE_F16_F32(add); return; diff --git a/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp b/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp index 35d2d87fa5..c1812123ec 100644 --- a/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp +++ b/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp @@ -131,7 +131,7 @@ torch::Tensor IndicesGenerator::generate() { // Now sort the indices by their tags. Use parallel sort for some extra speed // (vector is very large). std::sort( - std::execution::par, + // std::execution::par, std::begin(indicesWithTags), std::end(indicesWithTags), [](const std::pair& lhs,
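
Note on the index decoding above: the rewritten device kernel drops the magic-division helper (the deleted `batch_mdiv` lambda and `magic_div_u32_run_with_mod` calls) and instead decodes each packed `info` word with the `info_B_num_bits` / `info_B_mask` pair — table index in the high bits, bag index in the low bits. A minimal host-side sketch of that decode, assuming a hypothetical split of 26 low bits (the real value is computed by FBGEMM at runtime):

#include <cassert>
#include <cstdint>

int main() {
  const int32_t info_B_num_bits = 26;                        // hypothetical split
  const uint32_t info_B_mask = (1u << info_B_num_bits) - 1;

  const uint32_t table_index = 3;                            // made-up inputs
  const uint32_t bag_index = 12345;
  const uint32_t info = (table_index << info_B_num_bits) | bag_index;

  // Mirrors the two lines that replaced magic_div_u32_run_with_mod():
  const uint32_t t = info >> info_B_num_bits;                // table_index
  const uint32_t b = info & info_B_mask;                     // bag_index

  assert(t == table_index && b == bag_index);
  return 0;
}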
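Note on the new 320-wide `half` specializations: they follow the same per-warp layout as the existing 64/128/160/192/256 cases. With a 64-lane wavefront, each lane owns five `half` values per row — two `half2` pairs covering elements [0, 256) plus one scalar from the [256, 320) tail, and the added store specialization writes the same five values back symmetrically. A sketch of that addressing, with plain loads standing in for the `llvm_amdgcn_raw_buffer_load_*` intrinsics (`load_row_320_sketch` is an illustrative name, not part of the patch):

#include <hip/hip_fp16.h>

__device__ void load_row_320_sketch(
    __half dst[5],
    const __half* row,   // = p_emb_table + row_index * 320
    int lane_id) {       // 0..63
  // half2 pair #0 -> elements 2*lane_id, 2*lane_id + 1        (covers [0, 128))
  dst[0] = row[2 * lane_id];
  dst[1] = row[2 * lane_id + 1];
  // half2 pair #1 -> elements 2*(lane_id + 64), +1            (covers [128, 256))
  dst[2] = row[2 * (lane_id + 64)];
  dst[3] = row[2 * (lane_id + 64) + 1];
  // scalar tail -> element 256 + lane_id                      (covers [256, 320))
  dst[4] = row[256 + lane_id];
}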
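Note on the runtime-dimension loader (`load_row_per_warp_v2`) used above for the tail accumulation and weight load: only the last per-lane element is bounds-checked. That element sits at `lane_id + (k_elements_per_thread - 1) * warpSize` and is zero-filled when the index runs past `emb_dim`, so one instantiation can serve row widths that do not fill a whole multiple of `2 * warpSize`. A reduced sketch of just that guard (`kElems`, `load_row_tail_sketch`, and the plain load are stand-ins for the template parameter and the raw-buffer intrinsics):

#include <hip/hip_runtime.h>

template <int kElems>
__device__ void load_row_tail_sketch(
    float* dst, const float* row, int lane_id, int emb_dim) {
  // Vectorized body elided; only the guarded tail element is shown.
  constexpr int kLast = kElems - 1;
  const int tail_idx = lane_id + kLast * warpSize;
  dst[kLast] = (tail_idx < emb_dim) ? row[tail_idx] : 0.f;
}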