Changes from all commits (51 commits)
523a317 Add gfx950 build support + fp16 fix + index type fix (avbokovoy, Jul 29, 2025)
aee3078 Change int64_t to index_t as template parameters in load_raw_per_warp (avbokovoy, Jul 29, 2025)
5a1ac2e Implement llvm fp16 buffer load for gfx950 (avbokovoy, Jul 29, 2025)
7856903 Fix c-style half to float cast (avbokovoy, Aug 11, 2025)
e1e246a Patch 256 half stores (avbokovoy, Aug 11, 2025)
6a99fe0 cta_per_row workgroup optim (shbiswas834, Aug 8, 2025)
349a7b5 Added mi350 guards (shbiswas834, Aug 11, 2025)
1178cd1 Fix index overflow in row load (shbiswas834, Aug 12, 2025)
606ad34 cta_per_row workgroup reduce by 4 optim (shbiswas834, Aug 12, 2025)
a22ddeb Fix mixed_D frontend to backend connection (avbokovoy, Aug 13, 2025)
6775452 changed max_segment_length_per_cta to 4096 (kudomcho, Aug 15, 2025)
90e6ba7 added rocm guards and removed comment (shbiswas834, Aug 18, 2025)
a9073ac clean debug statements in Hip.cmake (liligwu, Aug 20, 2025)
9a16e12 Merge pull request #121 (shbiswas834, Aug 28, 2025)
68630da Guard f16 llvm intrinsics with ROCm >=7.0 (avbokovoy, Sep 2, 2025)
bac0610 fix the bug in dimention 160 in ROCm optimization (liligwu, Sep 18, 2025)
a12112f Cleanup optimized warp_per_raw kernel (avbokovoy, Aug 19, 2025)
3ef64f7 Add 320 embedding dim support for optimized warp_per_row kernel (avbokovoy, Aug 20, 2025)
f601e55 changed the max length per warp and cta per row WG size (Sep 8, 2025)
04916da added DPP and changed max length per warp to 16k (kudomcho, Sep 9, 2025)
1e09555 guard max segment warp based on emb dim (kudomcho, Sep 10, 2025)
b41192b added guarding opt of max segment for the case batch size list=1 (kudomcho, Sep 10, 2025)
2b08f96 opt for grad_indice_weights kernel (Sep 18, 2025)
0c26470 added store row per warp on emb 192 and added accuracy test functiona… (kudomcho, Sep 23, 2025)
d6b491b workgroup tuning and loop unrolled (shbiswas834, Sep 22, 2025)
70ed5e2 specialize (Hardcode84, Sep 19, 2025)
cf6a2b1 explicitly link to tbb (liligwu, Sep 24, 2025)
1be9bd8 added warpReduceAllSum with rocm guards (shbiswas834, Sep 25, 2025)
9d3ee64 revert unroll and wg tuning (shbiswas834, Oct 13, 2025)
a5a3b1e Minor update embedding_forward_split_kernel_template.cu (liligwu, Oct 13, 2025)
28e93c0 add tbb-devel to the install_build_tools () (liligwu, Oct 17, 2025)
842846c fix lint issues (liligwu, Oct 21, 2025)
97aeb83 solve lint issues (liligwu, Oct 21, 2025)
00976c7 applied jinja is_rocm onto optimizations for backward and forward par… (kudomcho, Oct 22, 2025)
4c19030 Guard supported grad_t for optimized warp_per_row dispatch (avbokovoy, Oct 23, 2025)
9991cf1 Forward index_t to the optimizer (avbokovoy, Oct 23, 2025)
b61bd19 Guard f16 llvm intrinsics with ROCm >=7.0 (avbokovoy, Sep 2, 2025)
c38ff6f Fix buffer offset for emb_dim == 160 (avbokovoy, Oct 23, 2025)
e201e8b Remove sanity check (avbokovoy, Oct 27, 2025)
aaaf80c address the potential lint issues and revert the change in indices_ge… (liligwu, Oct 27, 2025)
b8aea67 addresss code style issue (liligwu, Oct 27, 2025)
b9a7759 removed guard rocm on mixed_D and refactored mixed_D var assignment (kudomcho, Oct 28, 2025)
a4b4431 Remove general load/store methods (avbokovoy, Oct 24, 2025)
5d4f2cd Move weight type check to compile-time (avbokovoy, Oct 24, 2025)
d3b7d7a Switch to 256B stores for float type (avbokovoy, Oct 27, 2025)
878d00f removed jinj is_rocm on total_L as USE_ROCM is already applied (kudomcho, Nov 3, 2025)
d2596c7 Change mixed_D default value to false (avbokovoy, Nov 6, 2025)
e076556 Make const work_group_size for CUDA (avbokovoy, Nov 6, 2025)
585300d Add jinja comments to grad_indice_weights kernel (avbokovoy, Nov 6, 2025)
e0db2f1 Remove redundand comment (avbokovoy, Nov 6, 2025)
bf143c7 Unify cuda and rocm loops (avbokovoy, Nov 6, 2025)
1 change: 1 addition & 0 deletions .github/scripts/utils_build.bash
@@ -370,6 +370,7 @@ install_build_tools () {
patchelf \
rhash \
scikit-build \
tbb-devel \
tbb \
wheel \
xz \
12 changes: 12 additions & 0 deletions cmake/modules/CppLibrary.cmake
@@ -168,6 +168,18 @@ function(cpp_library)
target_link_libraries(${lib_name} PUBLIC OpenMP::OpenMP_CXX)
endif()

if(NOT TARGET TBB::tbb)
find_package(TBB QUIET)
endif()
if(TBB_FOUND)
target_link_libraries(${lib_name} PUBLIC TBB::tbb)
else()
find_library(TBB_LIB NAMES tbb tbb12 HINTS $ENV{CONDA_PREFIX}/lib /usr/lib/x86_64-linux-gnu /usr/local/lib /lib/x86_64-linux-gnu)
if(TBB_LIB)
target_link_libraries(${lib_name} PUBLIC ${TBB_LIB})
endif()
endif()

# Add sanitizer options if needed
if(args_SANITIZER_OPTIONS)
target_link_options(${lib_name} PUBLIC
12 changes: 12 additions & 0 deletions cmake/modules/GpuCppLibrary.cmake
@@ -302,6 +302,18 @@ function(gpu_cpp_library)
list(APPEND library_dependencies ${NVML_LIB_PATH})
endif()

if(NOT TARGET TBB::tbb)
find_package(TBB QUIET)
endif()
if(TBB_FOUND)
list(APPEND library_dependencies TBB::tbb)
else()
find_library(TBB_LIB NAMES tbb tbb12 HINTS $ENV{CONDA_PREFIX}/lib /usr/lib/x86_64-linux-gnu /usr/local/lib /lib/x86_64-linux-gnu)
if(TBB_LIB)
list(APPEND library_dependencies ${TBB_LIB})
endif()
endif()

# Link against the external libraries as needed
target_link_libraries(${lib_name} PRIVATE ${library_dependencies})

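The two CMake hunks above follow the same pattern: prefer an imported TBB::tbb target, and otherwise fall back to find_library over the tbb / tbb12 names in common prefixes. As a quick sanity check that whichever library was resolved actually links, one could build a small translation unit that touches a TBB symbol; this is a hypothetical snippet for illustration, not part of this PR.

#include <cstdio>
#include <vector>
#include <tbb/parallel_for.h>

// Hypothetical link check: forces the linker to resolve a TBB symbol from
// whichever library the CMake fallback picked (TBB::tbb, libtbb, or libtbb12).
int main() {
  std::vector<int> v(1000);
  // Each index writes its own slot, so no synchronization is needed.
  tbb::parallel_for(std::size_t{0}, v.size(),
                    [&](std::size_t i) { v[i] = static_cast<int>(i) * 2; });
  std::printf("TBB linked OK, v[999] = %d\n", v[999]);
  return 0;
}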
2 changes: 0 additions & 2 deletions fbgemm_gpu/cmake/tbe_sources.py
@@ -176,7 +176,6 @@
"_nobag" if nobag else "",
)
for nobag in [
True,
Contributor (review comment): no nobag for gen_embedding_backward_split_{}{}_device_kernel_hip.hip?

Contributor (reply): Not now. The nobag support is planned in near future.

False,
]
for weighted in (
@@ -495,7 +494,6 @@
"_nobag" if nobag else "",
)
for nobag in [
True,
False,
]
for weighted in (
10 changes: 7 additions & 3 deletions fbgemm_gpu/codegen/genscript/generate_backward_split.py
@@ -52,7 +52,11 @@ def render_backward_templates(
return

weighted_options = [True, False]
nobag_options = [True, False] if (not is_gwd) else [False]
nobag_options = (
[True, False]
if (not (is_gwd or kwargs.get("is_hip_optimized_backward")))
else [False]
)
vbe_options = [True, False] if (kwargs.get("has_vbe_support")) else [False]
ssd_options = [True, False] if kwargs.get("has_ssd_support") else [False]
template = CodeTemplate.load(template_filepath)
@@ -327,8 +331,7 @@ def generate_backward_indices() -> None:

@staticmethod
def generate_rocm_backward_split(**kwargs: Any) -> None:
# Generate backward device kernels based on weighted (True/False), VBE
# (True/False), no bag (True/False)
# Generate backward device kernels based on weighted (True/False)
template_filepath = (
"training/backward/rocm/embedding_backward_split_device_kernel_template.hip"
)
@@ -343,6 +346,7 @@ def generate_rocm_backward_split(**kwargs: Any) -> None:
"has_ssd_support": False,
"dense": False,
"gen_once": False,
"is_hip_optimized_backward": True,
},
)

@@ -172,7 +172,7 @@ Tensor split_embedding_codegen_lookup_dense_function(
c10::SymInt /* max_B = -1 */,
c10::SymInt /* max_B_feature_rank = -1 */,
c10::SymInt /* vbe_output_size = -1 */,
bool /* mixed_D = true */) {
bool /* mixed_D = false */) {
return SplitLookupFunction_Dense_Op::apply(
host_weights,
weights_offsets,
@@ -960,7 +960,7 @@ class {{ autograd_func }} :

#ifdef USE_ROCM
constexpr int32_t BT_block_size = 64;
constexpr int32_t max_segment_length_per_warp = 64;
constexpr int32_t max_segment_length_per_warp = 16384;
#else
constexpr int32_t BT_block_size = 32;
constexpr int32_t max_segment_length_per_warp = 32;
@@ -1116,7 +1116,7 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function(
{%- else %}
const c10::SymInt vbe_output_size = -1,
{%- endif %}
const bool mixed_D = true
const bool mixed_D = false
) {
// TODO: refactor into macro
{%- if has_gpu_support %}
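On ROCm the hunk above raises max_segment_length_per_warp from 64 to 16384 (in line with the "added DPP and changed max length per warp to 16k" commit) and flips the mixed_D default to false. Below is a rough sketch of how such a per-warp cap is typically consumed when routing sorted index segments between the backward kernels; it is an assumption for illustration, not code from the PR.

// Sketch only: segments no longer than the cap stay on the warp_per_row
// backward kernel, longer ones fall back to the cta_per_row kernel. Raising
// the ROCm cap to 16384 keeps nearly all segments on the optimized
// warp_per_row path.
constexpr int32_t max_segment_length_per_warp = 16384;

inline bool use_warp_per_row_kernel(int32_t segment_length) {
  return segment_length <= max_segment_length_per_warp;
}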
134 changes: 133 additions & 1 deletion fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu
100644 → 100755
@@ -23,6 +23,10 @@
#include "fbgemm_gpu/utils/assert_macros.h"
#include "fbgemm_gpu/utils/kernel_launcher.cuh"

{%- if is_rocm %}
#include "fbgemm_gpu/rocm/cdna_guard.h"
{%- endif %}

using Tensor = at::Tensor;
using namespace fbgemm_gpu;

@@ -209,8 +213,127 @@ __global__ __launch_bounds__(kForwardMaxThreads) void
2, offset_idx + D_emb <= weights_numel, offset_idx
)
{%- endif %}
int32_t j = 0;
{%- if is_rocm and not ssd and not dense and not use_vec_blocking and not vbe %}
// Currently for split_embedding_codegen_grad_indice_weights_kernel only
if (placement != PlacementType::MANAGED_CACHING) {
for (; j < kWarpSize && l_start + j + 3 < L; j += 4) {
const auto offset_idx_j0 = shfl_sync(offset_idx, j);
const auto offset_idx_j1 = shfl_sync(offset_idx, j+1);
const auto offset_idx_j2 = shfl_sync(offset_idx, j+2);
const auto offset_idx_j3 = shfl_sync(offset_idx, j+3);

at::acc_type<cache_t, true> grad_indice_weight0 = 0.0;
at::acc_type<cache_t, true> grad_indice_weight1 = 0.0;
at::acc_type<cache_t, true> grad_indice_weight2 = 0.0;
at::acc_type<cache_t, true> grad_indice_weight3 = 0.0;

const auto weight_row0 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j0], D);
const auto weight_row1 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j1], D);
const auto weight_row2 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j2], D);
const auto weight_row3 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j3], D);

#pragma unroll kFixedMaxVecsPerThread
for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && (kWarpSize * vec + threadIdx.x) * kVecWidth < D; ++vec) {
const int32_t d = (kWarpSize * vec + threadIdx.x) * kVecWidth;

Vec4T<at::acc_type<cache_t, true>> weight0, weight1, weight2, weight3;
weight0 = weight_row0.load(d);
weight1 = weight_row1.load(d);
weight2 = weight_row2.load(d);
weight3 = weight_row3.load(d);

grad_indice_weight0 += weight0.acc.x * grad_out[vec].acc.x + weight0.acc.y * grad_out[vec].acc.y +
weight0.acc.z * grad_out[vec].acc.z + weight0.acc.w * grad_out[vec].acc.w;
grad_indice_weight1 += weight1.acc.x * grad_out[vec].acc.x + weight1.acc.y * grad_out[vec].acc.y +
weight1.acc.z * grad_out[vec].acc.z + weight1.acc.w * grad_out[vec].acc.w;
grad_indice_weight2 += weight2.acc.x * grad_out[vec].acc.x + weight2.acc.y * grad_out[vec].acc.y +
weight2.acc.z * grad_out[vec].acc.z + weight2.acc.w * grad_out[vec].acc.w;
grad_indice_weight3 += weight3.acc.x * grad_out[vec].acc.x + weight3.acc.y * grad_out[vec].acc.y +
weight3.acc.z * grad_out[vec].acc.z + weight3.acc.w * grad_out[vec].acc.w;
}

grad_indice_weight0 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight0);
grad_indice_weight1 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight1);
grad_indice_weight2 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight2);
grad_indice_weight3 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight3);

if (threadIdx.x == 0) {
grad_indice_weights[indices_start + l_start + j] = grad_indice_weight0;
grad_indice_weights[indices_start + l_start + j+1] = grad_indice_weight1;
grad_indice_weights[indices_start + l_start + j+2] = grad_indice_weight2;
grad_indice_weights[indices_start + l_start + j+3] = grad_indice_weight3;
}
}
} else {
for (; j < kWarpSize && l_start + j + 3 < L; j += 4) {
const auto offset_idx_j0 = shfl_sync(offset_idx, j);
const auto offset_idx_j1 = shfl_sync(offset_idx, j+1);
const auto offset_idx_j2 = shfl_sync(offset_idx, j+2);
const auto offset_idx_j3 = shfl_sync(offset_idx, j+3);

const auto cache_idx_j0 = shfl_sync(cache_idx, j);
const auto cache_idx_j1 = shfl_sync(cache_idx, j+1);
const auto cache_idx_j2 = shfl_sync(cache_idx, j+2);
const auto cache_idx_j3 = shfl_sync(cache_idx, j+3);

at::acc_type<cache_t, true> grad_indice_weight0 = 0.0;
at::acc_type<cache_t, true> grad_indice_weight1 = 0.0;
at::acc_type<cache_t, true> grad_indice_weight2 = 0.0;
at::acc_type<cache_t, true> grad_indice_weight3 = 0.0;

const auto weight_row0 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j0], D);
const auto weight_row1 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j1], D);
const auto weight_row2 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j2], D);
const auto weight_row3 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j3], D);

#pragma unroll kFixedMaxVecsPerThread
for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && (kWarpSize * vec + threadIdx.x) * kVecWidth < D; ++vec) {
const int32_t d = (kWarpSize * vec + threadIdx.x) * kVecWidth;

Vec4T<at::acc_type<cache_t, true>> weight0, weight1, weight2, weight3;
weight0 = (cache_idx_j0 != kCacheLocationMissing) ?
Vec4T<at::acc_type<cache_t, true>>(&lxu_cache_weights[cache_idx_j0][d]) :
weight_row0.load(d);

weight1 = (cache_idx_j1 != kCacheLocationMissing) ?
Vec4T<at::acc_type<cache_t, true>>(&lxu_cache_weights[cache_idx_j1][d]) :
weight_row1.load(d);

weight2 = (cache_idx_j2 != kCacheLocationMissing) ?
Vec4T<at::acc_type<cache_t, true>>(&lxu_cache_weights[cache_idx_j2][d]) :
weight_row2.load(d);

weight3 = (cache_idx_j3 != kCacheLocationMissing) ?
Vec4T<at::acc_type<cache_t, true>>(&lxu_cache_weights[cache_idx_j3][d]) :
weight_row3.load(d);


grad_indice_weight0 += weight0.acc.x * grad_out[vec].acc.x + weight0.acc.y * grad_out[vec].acc.y +
weight0.acc.z * grad_out[vec].acc.z + weight0.acc.w * grad_out[vec].acc.w;
grad_indice_weight1 += weight1.acc.x * grad_out[vec].acc.x + weight1.acc.y * grad_out[vec].acc.y +
weight1.acc.z * grad_out[vec].acc.z + weight1.acc.w * grad_out[vec].acc.w;
grad_indice_weight2 += weight2.acc.x * grad_out[vec].acc.x + weight2.acc.y * grad_out[vec].acc.y +
weight2.acc.z * grad_out[vec].acc.z + weight2.acc.w * grad_out[vec].acc.w;
grad_indice_weight3 += weight3.acc.x * grad_out[vec].acc.x + weight3.acc.y * grad_out[vec].acc.y +
weight3.acc.z * grad_out[vec].acc.z + weight3.acc.w * grad_out[vec].acc.w;
}

grad_indice_weight0 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight0);
grad_indice_weight1 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight1);
grad_indice_weight2 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight2);
grad_indice_weight3 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight3);

for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) {

Contributor (review comment): Here we should use a {% if rocm %} guard to select between rolled / unrolled versions of the loop on regular, non-ROCm paths.

Contributor (reply): Applied {% if is_rocm %} onto here

if (threadIdx.x == 0) {
grad_indice_weights[indices_start + l_start + j] = grad_indice_weight0;
grad_indice_weights[indices_start + l_start + j+1] = grad_indice_weight1;
grad_indice_weights[indices_start + l_start + j+2] = grad_indice_weight2;
grad_indice_weights[indices_start + l_start + j+3] = grad_indice_weight3;
}
}
}
{%- endif %}{#-/* if is_rocm and not ssd and not dense and not use_vec_blocking and not vbe */#}
for (; j < kWarpSize && l_start + j < L; ++j) {
Contributor (review comment): line 337 and 339, does it really make a difference? Otherwise, line 337-340 could just be {%- endif %}{#-/* if is_rocm */#}, since line 339 (i.e., `for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) {`) will be common for both rocm and cuda.

Contributor (reply): Done in bf143c7. Now it's unified, so CUDA uses it as a main loop while ROCm uses this loop to handle the tailing iterations

const auto offset_idx_j = shfl_sync(offset_idx, j);
{%- if not dense %}
const auto {{ locs_or_addrs_idx }}_j = shfl_sync({{ locs_or_addrs_idx }}, j);
@@ -359,6 +482,15 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda(
auto aligned_grad_output = aligned_grad_output_tensor_for_cuda_backwards(grad_output);

CUDA_DEVICE_GUARD(dev_weights);
#ifdef USE_ROCM
if (!rocm::is_supported_cdna()) {
TORCH_WARN_ONCE("Running on non-CDNA architecture. Performance may be suboptimal.");
}
else {
// Ensure we're running on a supported CDNA architecture (including MI350)
TORCH_WARN_ONCE("Running on CDNA architecture");
}
#endif

const auto T = D_offsets.size(0) - 1;
TORCH_CHECK_GT(T, 0);
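The ROCm-only block added above broadcasts four sorted offsets per iteration with shfl_sync, accumulates four partial dot products against grad_out, and reduces each one across the warp with warpReduceAllSum (introduced in commit 1be9bd8; its definition is not part of this hunk). A minimal sketch of such a warp-wide all-reduce follows, assuming a shuffle-based butterfly; the real helper may instead use the DPP operations referenced in the "added DPP" commit.

// Sketch of a warp-wide all-reduce sum. After the loop every lane holds the
// sum of `val` over the whole warp (wavefront size 64 on CDNA, 32 on CUDA).
template <typename T>
__device__ inline T warp_reduce_all_sum_sketch(T val, int warp_size) {
  for (int offset = warp_size / 2; offset > 0; offset >>= 1) {
#if defined(USE_ROCM)
    val += __shfl_xor(val, offset);
#else
    val += __shfl_xor_sync(0xffffffffu, val, offset);
#endif
  }
  return val;
}

Per the review reply, after bf143c7 the per-index loop that follows the guarded block is shared: CUDA uses it as the main loop, while ROCm runs the unrolled four-at-a-time loop first and uses the same loop only for the remaining tail iterations.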
@@ -32,6 +32,14 @@

{%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %}
{%- set locs_or_addrs_type = "int64_t" if ssd else "int32_t" %}
{%- set is_optimized_hip_kernel_supported_mode = is_rocm and
optimizer == "rowwise_adagrad" and
not dense and
not nobag and
not is_index_select and
not is_gwd_kernel and
not vbe and
not ssd %}

#include "fbgemm_gpu/embedding_backward_template_helpers.cuh"
#include "fbgemm_gpu/utils/tensor_accessor_builder.h"
@@ -538,7 +546,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row

{%- endif %}

{%- if is_rocm and not is_index_select and optimizer == "rowwise_adagrad" and not dense and not is_gwd_kernel and not vbe and not ssd %}
{%- if is_optimized_hip_kernel_supported_mode %}
Contributor (review comment): No nobag support?

Contributor (reply): Not now. The nobag support is planned in near future.

#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include "fbgemm_gpu/rocm/split_embeddings_common.h"
@@ -612,12 +620,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
{{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }}
{%- endif %}
) {
{%- if not nobag %}
int32_t T = D_offsets.size(0) - 1;
{%- else %}
int32_t T = weights_offsets.size(0);
{%- endif %}

auto p_output_grad = grad_output.data();
auto p_emb_table = dev_weights.data();
auto p_hash_size_cumsum = hash_size_cumsum.data();
@@ -632,8 +635,6 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
constexpr int32_t segment_prefetch = 2;
constexpr int32_t segment_unroll = 8;
constexpr int32_t segment_split = 0;
auto batch = grad_output.size(0);
auto num_rows = dev_weights.size(0) / T / max_D;
{%- if weighted %}
constexpr bool is_weighted = true;
{%- else %}
@@ -646,24 +647,9 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
// weight_decay(_mode) is supplied as args.split_function_args_no_defaults
opt_karg.weight_decay_mode = weight_decay_mode_v;
opt_karg.weight_decay = weight_decay;
auto batch_mdiv = [](uint32_t d) -> rocm::magic_div_u32_t {
assert(d >= 1 && d <= INT32_MAX);
uint8_t shift;
for(shift = 0; shift < 32; shift++)
if((1U << shift) >= d)
break;

uint64_t one = 1;
uint64_t magic = ((one << 32) * ((one << shift) - d)) / d + 1;
assert(magic <= 0xffffffffUL);

rocm::magic_div_u32_t result;
result.magic = magic;
result.shift = shift;
return result;
}(batch);

rocm::split_tbe_backward_hip_kernel_{{kdesc}}<
rocm::{{optimizer}}_optimizer_t<cache_t, emb_t, embedding_dim, weight_decay_mode_v>,
rocm::{{optimizer}}_optimizer_t<cache_t, emb_t, index_t, embedding_dim, weight_decay_mode_v>,
rocm::{{optimizer}}_kernel_arg_t,
emb_t,
cache_t,
@@ -680,16 +666,11 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
p_sorted_linear_indices_run,
p_sorted_linear_indices_cumulative_run_lengths,
p_sorted_linear_indices_num_runs,
{%- if not nobag %}
info_B_num_bits,
info_B_mask,
{%- endif %}
p_sorted_infos,
batch_mdiv,
max_segment_length_per_warp,
emb_dim,
batch,
num_rows,
T,
opt_karg
{%- if weighted %}
@@ -784,7 +765,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
{%- for emb_type in (['float', 'at::Half'] + (['at::Float8_e4m3fnuz'] if is_rocm else ['at::Float8_e4m3fn'])) %}
{%- for cache_type in ['float', 'at::Half'] %}
{%- for index_type in ['int32_t', 'int64_t'] %}
{%- for kEmbeddingDim in [64, 128, 160, 192, 256] %}
{%- for kEmbeddingDim in [64, 128, 160, 192, 256, 320] %}
{%- for kWeighDecayMode in [0, 1, 2] %}
{{ hip_template_instantiation(
emb_type,