Commit 48d3161

applied jinja is_rocm onto optimizations for backward and forward parameters
1 parent 04d590e commit 48d3161

5 files changed: +24 -15 lines changed

fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu

Lines changed: 4 additions & 1 deletion
@@ -213,7 +213,7 @@ __global__ __launch_bounds__(kForwardMaxThreads) void
       2, offset_idx + D_emb <= weights_numel, offset_idx
     )
     {%- endif %}
-
+    {%- if is_rocm %}
     int32_t j = 0;
     {%- if not ssd and not dense and not use_vec_blocking and not vbe %}
     // Currently for split_embedding_codegen_grad_indice_weights_kernel only
@@ -335,6 +335,9 @@ __global__ __launch_bounds__(kForwardMaxThreads) void
     }
     {%- endif %}
     for (; j < kWarpSize && l_start + j < L; ++j) {
+    {%- else %} // if is_rocm
+    for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) {
+    {%- endif %} // if is_rocm
       const auto offset_idx_j = shfl_sync(offset_idx, j);
       {%- if not dense %}
       const auto {{ locs_or_addrs_idx }}_j = shfl_sync({{ locs_or_addrs_idx }}, j);
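The hunk above gates the existing ROCm loop prologue behind {%- if is_rocm %} and restores the plain for (auto j = 0; ...) loop in the else branch, so the choice is made when the .cu source is generated rather than when it is compiled. A minimal sketch of that mechanism, rendering a trimmed-down fragment with the jinja2 package directly (illustrative only, not FBGEMM's actual codegen driver):

# Sketch only (not FBGEMM's codegen driver): rendering a trimmed fragment of the
# template with is_rocm=True vs. is_rocm=False shows that each generated .cu
# file contains exactly one of the two loop headers.
from jinja2 import Template

LOOP_FRAGMENT = """\
{%- if is_rocm %}
int32_t j = 0;
// ... ROCm-specific prologue may advance j before the main loop ...
for (; j < kWarpSize && l_start + j < L; ++j) {
{%- else %}
for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) {
{%- endif %}
"""

print(Template(LOOP_FRAGMENT).render(is_rocm=True))   # ROCm render: prologue + continuation loop
print(Template(LOOP_FRAGMENT).render(is_rocm=False))  # non-ROCm render: self-contained loop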

fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu

Lines changed: 12 additions & 10 deletions
@@ -987,8 +987,11 @@ Tensor {{ embedding_cuda_op }}(
   auto num_long_run_ids = at::zeros({1}, indices.options().dtype(at::kInt));

   const bool use_deterministic_algorithms = at::globalContext().deterministicAlgorithms();
-  const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 4096;
-
+  {% if is_rocm %}
+  const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 4096;
+  {% else %}
+  const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 1024;
+  {%- endif %}
   Tensor long_run_id_to_really_long_run_ids;
   if (use_deterministic_algorithms) {
     long_run_id_to_really_long_run_ids =
@@ -1059,8 +1062,8 @@ Tensor {{ embedding_cuda_op }}(

   // Compute shared memory size for cta_per_row
   constexpr auto kCacheAccBytes = sizeof(at::acc_type<cache_t, true>);
-  int32_t total_L = indices.numel();
-#ifdef USE_ROCM
+  {% if is_rocm %}
+  int32_t total_L = indices.numel();
   int32_t num_cta_per_row_groups;
   int32_t work_group_size;
   if (total_L/total_B > 1){
@@ -1071,10 +1074,10 @@ Tensor {{ embedding_cuda_op }}(
     num_cta_per_row_groups = kMaxThreads / kWarpSize;
     work_group_size = kMaxThreads;
   }
-#else
+  {%- else %}
   int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize;
   int32_t work_group_size = kMaxThreads;
-#endif
+  {%- endif %}
   const size_t cta_per_row_smem_bytes = compute_num_groups_and_dynamic_smem_bytes(
       &num_cta_per_row_groups,
       [&] (int num_groups) {
@@ -1091,7 +1094,6 @@ Tensor {{ embedding_cuda_op }}(
   FBGEMM_LAUNCH_KERNEL(
       backward_cta_per_row_kernel,
       cta_per_row_grid_size,
-      // (64, 2)
       dim3(kThreadGroupSize, num_cta_per_row_groups),
       cta_per_row_smem_bytes,
       at::cuda::getCurrentCUDAStream(),
@@ -1195,7 +1197,7 @@ Tensor {{ embedding_cuda_op }}(
       kUseVecBlocking>;

   // Compute shared memory size for warp_per_row
-#ifdef USE_ROCM
+  {%- if is_rocm %}
   int32_t num_warp_per_row_groups;

   if (total_L/total_B > 1){
@@ -1204,9 +1206,9 @@ Tensor {{ embedding_cuda_op }}(
   else{
     num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize;
   }
-#else
+  {%- else %}
   int32_t num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize;
-#endif
+  {%- endif %}
   int32_t warp_per_row_smem_bytes = 0;

   if constexpr (kUseVecBlocking) {
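Beyond swapping #ifdef USE_ROCM for {% if is_rocm %}, this file also makes max_segment_length_per_cta backend-dependent: 4096 on ROCm, 1024 otherwise, with INT_MAX still used whenever deterministic algorithms are requested. A quick sanity check of which constant each render picks up, again a sketch that runs jinja2 on a trimmed fragment rather than the full template:

# Sketch only: confirm which max_segment_length_per_cta cap each backend's
# generated source ends up with. The fragment mirrors the hunk above but is trimmed.
from jinja2 import Template

CTA_FRAGMENT = """\
{% if is_rocm %}
const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 4096;
{% else %}
const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 1024;
{%- endif %}
"""

rocm_src = Template(CTA_FRAGMENT).render(is_rocm=True)
cuda_src = Template(CTA_FRAGMENT).render(is_rocm=False)
assert "4096" in rocm_src and "4096" not in cuda_src
assert "1024" in cuda_src and "1024" not in rocm_src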

fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu

Lines changed: 2 additions & 2 deletions
@@ -458,15 +458,15 @@ batch_index_select_dim0_codegen_forward_cuda(

   CUDA_DEVICE_GUARD(dev_weights);

-#ifdef USE_ROCM
+  {% if is_rocm %}
   if (!rocm::is_supported_cdna()) {
     TORCH_WARN_ONCE("Running on non-CDNA architecture. Performance may be suboptimal.");
   }
   else {
     // Ensure we're running on a supported CDNA architecture (including MI350)
     TORCH_WARN_ONCE("Running on CDNA architecture");
   }
-#endif
+  {%- endif %}

   {%- if not nobag %}
   int32_t T = D_offsets.numel() - 1;

fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp

Lines changed: 2 additions & 2 deletions
@@ -341,7 +341,7 @@ class BatchIndexSelectDim0GPUOp
     Tensor grad_dev_weights;
     TORCH_CHECK_EQ(grad_outputs.size(), 1);

-    constexpr int32_t max_segment_length_per_warp = 16384;
+    constexpr int32_t max_segment_length_per_warp = 32;

     auto grad_output = grad_outputs[0];

@@ -656,7 +656,7 @@ class BatchIndexSelectDim0TensorGPUOp
     const auto permute_output_dim_0_1 =
         ctx->saved_data["permute_output_dim_0_1"].toBool();

-    constexpr int32_t max_segment_length_per_warp = 16384;
+    constexpr int32_t max_segment_length_per_warp = 32;

     auto grad_output = grad_outputs[0];


fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp

Lines changed: 4 additions & 0 deletions
@@ -698,8 +698,10 @@ class {{ autograd_func }} :
     TORCH_CHECK(aux_tensor[IDX_LXU_CACHE_LOCATIONS].has_value(), "lxu_cache_locations should have value.");
     const auto lxu_cache_locations = aux_tensor[IDX_LXU_CACHE_LOCATIONS].value();
     const auto is_experimental = aux_bool[IDX_IS_EXPERIMENTAL_TBE];
+    {% if is_rocm %}
     const auto mixed_D = aux_bool[IDX_MIXED_D];
     {%- endif %}
+    {%- endif %}

     // Default values for Dynamo tracing
     // SymInt does not support bitshifts operator
@@ -1009,7 +1011,9 @@ static torch::autograd::variable_list backward(
     int32_t max_segment_length_per_warp = 64;
     // Workaround. Should not be upstreamed in any way.
     // Redistribute all cta_per_row work to warp_per_row.
+    {% if is_rocm %}
     int32_t total_L = indices.numel();
+    {%- endif %}
     {%- if (not nobag) and
         (optimizer == "rowwise_adagrad") and
         (not vbe) and
