8 changes: 8 additions & 0 deletions fbgemm_gpu/cmake/Hip.cmake
@@ -78,6 +78,14 @@ if(HIP_FOUND)
list(APPEND HIP_CXX_FLAGS -mf16c)
list(APPEND HIP_CXX_FLAGS -mfma)
list(APPEND HIP_CXX_FLAGS -std=c++20)
list(APPEND HIP_CXX_FLAGS -g)
list(APPEND HIP_CXX_FLAGS -ggdb)

# list(APPEND HIP_CXX_FLAGS -Wa,-adhln)
# list(APPEND HIP_CXX_FLAGS -adhln)
list(APPEND HIP_CXX_FLAGS -save-temps)
list(APPEND HIP_CXX_FLAGS -fverbose-asm)

set(HIP_HCC_FLAGS ${HIP_CXX_FLAGS})
# Ask hcc to generate device code during compilation so we can use
2 changes: 0 additions & 2 deletions fbgemm_gpu/cmake/tbe_sources.py
@@ -176,7 +176,6 @@
"_nobag" if nobag else "",
)
for nobag in [
True,
False,
]
for weighted in (
@@ -495,7 +494,6 @@
"_nobag" if nobag else "",
)
for nobag in [
True,
False,
]
for weighted in (
10 changes: 7 additions & 3 deletions fbgemm_gpu/codegen/genscript/generate_backward_split.py
@@ -52,7 +52,11 @@ def render_backward_templates(
return

weighted_options = [True, False]
nobag_options = [True, False] if (not is_gwd) else [False]
nobag_options = (
[True, False]
if (not (is_gwd or kwargs.get("is_hip_optimized_backward")))
else [False]
)
vbe_options = [True, False] if (kwargs.get("has_vbe_support")) else [False]
ssd_options = [True, False] if kwargs.get("has_ssd_support") else [False]
template = CodeTemplate.load(template_filepath)
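A minimal sketch of how these option lists expand into rendered kernel variants (illustrative only; enumerate_variants and its signature are assumptions, not FBGEMM API). The new is_hip_optimized_backward flag collapses the nobag axis to [False], so the ROCm-optimized path renders only bag kernels:

    # Hypothetical illustration of the option expansion above.
    from itertools import product

    def enumerate_variants(is_gwd=False, is_hip_optimized_backward=False,
                           has_vbe_support=False, has_ssd_support=False):
        weighted_options = [True, False]
        # Mirrors the new guard: GWD and the HIP-optimized backward
        # both drop the nobag=True variants.
        nobag_options = (
            [True, False]
            if not (is_gwd or is_hip_optimized_backward)
            else [False]
        )
        vbe_options = [True, False] if has_vbe_support else [False]
        ssd_options = [True, False] if has_ssd_support else [False]
        return list(product(weighted_options, nobag_options,
                            vbe_options, ssd_options))

    # The ROCm path renders exactly two variants: weighted and unweighted,
    # each with nobag=False, vbe=False, ssd=False.
    assert len(enumerate_variants(is_hip_optimized_backward=True)) == 2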
@@ -327,8 +331,7 @@ def generate_backward_indices() -> None:

@staticmethod
def generate_rocm_backward_split(**kwargs: Any) -> None:
# Generate backward device kernels based on weighted (True/False), VBE
# (True/False), no bag (True/False)
# Generate backward device kernels based on weighted (True/False)
template_filepath = (
"training/backward/rocm/embedding_backward_split_device_kernel_template.hip"
)
@@ -343,6 +346,7 @@ def generate_rocm_backward_split(**kwargs: Any) -> None:
"has_ssd_support": False,
"dense": False,
"gen_once": False,
"is_hip_optimized_backward": True,
},
)

@@ -32,6 +32,14 @@

{%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %}
{%- set locs_or_addrs_type = "int64_t" if ssd else "int32_t" %}
{%- set is_optimized_hip_kernel_supported_mode = is_rocm and
optimizer == "rowwise_adagrad" and
not dense and
not nobag and
not is_index_select and
not is_gwd_kernel and
not vbe and
not ssd %}
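For readability outside Jinja, the consolidated guard above is equivalent to the following sketch (a direct Python transliteration, assuming boolean template variables). Note that it folds not nobag into this use site, which the old inline check below did not test; that is consistent with the nobag variants being dropped from codegen:

    def is_optimized_hip_kernel_supported_mode(
        is_rocm, optimizer, dense, nobag,
        is_index_select, is_gwd_kernel, vbe, ssd,
    ):
        # Only the plain ROCm rowwise_adagrad bag path gets the
        # optimized HIP kernel.
        return (
            is_rocm
            and optimizer == "rowwise_adagrad"
            and not dense
            and not nobag
            and not is_index_select
            and not is_gwd_kernel
            and not vbe
            and not ssd
        )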

#include "fbgemm_gpu/embedding_backward_template_helpers.cuh"
#include "fbgemm_gpu/utils/tensor_accessor_builder.h"
@@ -538,7 +546,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row

{%- endif %}

{%- if is_rocm and not is_index_select and optimizer == "rowwise_adagrad" and not dense and not is_gwd_kernel and not vbe and not ssd %}
{%- if is_optimized_hip_kernel_supported_mode %}
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include "fbgemm_gpu/rocm/split_embeddings_common.h"
@@ -612,12 +620,8 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
{{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }}
{%- endif %}
) {
{%- if not nobag %}
int32_t T = D_offsets.size(0) - 1;
{%- else %}
int32_t T = weights_offsets.size(0);
{%- endif %}


auto p_output_grad = grad_output.data();
auto p_emb_table = dev_weights.data();
auto p_hash_size_cumsum = hash_size_cumsum.data();
@@ -632,8 +636,6 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
constexpr int32_t segment_prefetch = 2;
constexpr int32_t segment_unroll = 8;
constexpr int32_t segment_split = 0;
auto batch = grad_output.size(0);
auto num_rows = dev_weights.size(0) / T / max_D;
{%- if weighted %}
constexpr bool is_weighted = true;
{%- else %}
@@ -646,22 +648,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
// weight_decay(_mode) is supplied as args.split_function_args_no_defaults
opt_karg.weight_decay_mode = weight_decay_mode_v;
opt_karg.weight_decay = weight_decay;
auto batch_mdiv = [](uint32_t d) -> rocm::magic_div_u32_t {
assert(d >= 1 && d <= INT32_MAX);
uint8_t shift;
for(shift = 0; shift < 32; shift++)
if((1U << shift) >= d)
break;

uint64_t one = 1;
uint64_t magic = ((one << 32) * ((one << shift) - d)) / d + 1;
assert(magic <= 0xffffffffUL);

rocm::magic_div_u32_t result;
result.magic = magic;
result.shift = shift;
return result;
}(batch);
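The deleted lambda built a magic-number divider for batch: it picks the smallest shift with 2^shift >= d, then sets magic = (2^32 * (2^shift - d)) / d + 1, so division by d reduces to a multiply-high, an add, and a shift. A sketch of the identity in Python; the runtime evaluation (umulhi(n, magic) + n) >> shift is an assumption about rocm::magic_div_u32_t, inferred from this construction rather than quoted from the header:

    def magic_div_params(d):
        # Same construction as the removed batch_mdiv lambda.
        assert 1 <= d <= 2**31 - 1
        shift = next(s for s in range(32) if (1 << s) >= d)
        magic = ((1 << 32) * ((1 << shift) - d)) // d + 1
        assert magic <= 0xFFFFFFFF
        return magic, shift

    def magic_div(n, magic, shift):
        mulhi = (n * magic) >> 32    # high half of the 32x32-bit product
        return (mulhi + n) >> shift  # the add must be widened past 32 bits

    for d in (1, 3, 7, 100, 2**31 - 1):
        magic, shift = magic_div_params(d)
        for n in (0, 1, d - 1, d, 123456789, 2**32 - 1):
            assert magic_div(n, magic, shift) == n // d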

rocm::split_tbe_backward_hip_kernel_{{kdesc}}<
rocm::{{optimizer}}_optimizer_t<cache_t, emb_t, embedding_dim, weight_decay_mode_v>,
rocm::{{optimizer}}_kernel_arg_t,
@@ -680,16 +667,11 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
p_sorted_linear_indices_run,
p_sorted_linear_indices_cumulative_run_lengths,
p_sorted_linear_indices_num_runs,
{%- if not nobag %}
info_B_num_bits,
info_B_mask,
{%- endif %}
p_sorted_infos,
batch_mdiv,
max_segment_length_per_warp,
emb_dim,
batch,
num_rows,
T,
opt_karg
{%- if weighted %}
@@ -780,11 +762,11 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
{%- endmacro %}

{%- macro hip_bulk_template_instantiations(kFixedMaxVecsPerThread, kThreadGroupSize, kUseVecBlocking) %}
{%- for grad_type in ['float', 'at::Half', 'at::BFloat16'] %}
{%- for grad_type in ['float', 'at::Half'] %}
{%- for emb_type in ['float', 'at::Half'] %}
{%- for cache_type in ['float', 'at::Half'] %}
{%- for index_type in ['int32_t', 'int64_t'] %}
{%- for kEmbeddingDim in [64, 128, 160, 192, 256] %}
{%- for kEmbeddingDim in [64, 128, 160, 192, 256, 320] %}
{%- for kWeighDecayMode in [0, 1, 2] %}
{{ hip_template_instantiation(
emb_type,
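Back-of-envelope count for the loops above (plain arithmetic from the visible loop bounds): dropping at::BFloat16 from grad_type while adding the 320 embedding dim shrinks the per-call instantiation count.

    # Instantiation counts implied by the loop bounds above.
    before = 3 * 2 * 2 * 2 * 5 * 3  # grad x emb x cache x index x dims x decay modes
    after  = 2 * 2 * 2 * 2 * 6 * 3  # BFloat16 grads dropped, 320-dim added
    print(before, after)            # 360 288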
@@ -48,6 +48,15 @@ using namespace fbgemm_gpu;
has_global_weight_decay_support,
ssd) %}
{%- set desc_suffix = get_desc_suffix(is_gwd_kernel) %}
{%- set is_optimized_hip_kernel_supported_mode = is_rocm and
optimizer == "rowwise_adagrad" and
not dense and
not nobag and
not is_index_select and
not is_gwd_kernel and
not vbe and
not ssd %}

template <
typename emb_t,
typename grad_t,
@@ -227,8 +236,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
{%- endif %}
);

{%- if is_rocm and optimizer == "rowwise_adagrad" and not dense and not is_index_select
and not is_gwd_kernel and not vbe and not ssd %}
{%- if is_optimized_hip_kernel_supported_mode %}
#include "fbgemm_gpu/rocm/split_embeddings_common.h"
template <
typename emb_t,
@@ -852,8 +860,7 @@ Tensor {{ embedding_cuda_op }}(
}
{%- endif %}

{%- if is_rocm and optimizer == "rowwise_adagrad" and not dense and not is_index_select
and not is_gwd_kernel and not vbe and not ssd %}
{%- if is_optimized_hip_kernel_supported_mode %}
{%- set hip_kernel = "hip_split_embedding{}_backward_codegen_{}_{}{}_kernel_warp_per_row_1".format(
ndesc,
optimizer,
@@ -1187,18 +1194,17 @@ Tensor {{ embedding_cuda_op }}(
get_max_thread_blocks_());

#ifdef USE_ROCM
{%- if is_rocm and not is_index_select and optimizer == "rowwise_adagrad" and
not dense and not is_gwd_kernel and not vbe and not ssd and not nobag %}
{%- if is_optimized_hip_kernel_supported_mode %}

const static auto use_hip_kernel = fbgemm_gpu::config::is_feature_enabled(fbgemm_gpu::config::FeatureGateName::TBE_ROCM_HIP_BACKWARD_KERNEL);

const auto supported_weights_type = dev_weights.scalar_type() == at::ScalarType::Half
|| dev_weights.scalar_type() == at::ScalarType::Float;

if (use_hip_kernel && supported_weights_type && !mixed_D && rocm::is_supported_cdna())
if (use_hip_kernel && !mixed_D && supported_weights_type && rocm::is_supported_cdna())
{
constexpr int segments_per_workgroup = 4;
{%- for kDimSize in [64, 128, 160, 192, 256] %}
{%- for kDimSize in [64, 128, 160, 192, 256, 320] %}
{%- for kWeightDecayMode in [0, 1, 2] %}
if (max_D == {{ kDimSize }} && weight_decay_mode == {{ kWeightDecayMode }})
{
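Roughly what the dispatch loops above render to, one branch per (max_D, weight_decay_mode) pair; this is a sketch of the generated shape, not the literal template output:

    # Shape of the generated dispatch chain (illustrative only).
    for dim in (64, 128, 160, 192, 256, 320):
        for mode in (0, 1, 2):
            print(f"if (max_D == {dim} && weight_decay_mode == {mode})"
                  f" {{ /* launch the {dim}-dim HIP kernel */ }}")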