
Conversation

@liligwu (Contributor) commented Sep 24, 2025

Backward (bwd) performance optimization for ROCm.
Fixes numerical issues.

meta-cla bot added the "cla signed" label on Sep 24, 2025
netlify bot commented Sep 24, 2025

Deploy Preview for pytorch-fbgemm-docs ready!

🔨 Latest commit: 570f148
🔍 Latest deploy log: https://app.netlify.com/projects/pytorch-fbgemm-docs/deploys/690124a2371622000880421b
😎 Deploy Preview: https://deploy-preview-4925--pytorch-fbgemm-docs.netlify.app

@facebook-github-bot

@haoyuz has imported this pull request. If you are a Meta employee, you can view this in D83116315.

@q10 (Contributor) commented Oct 2, 2025

@liligwu we're seeing

OSError: libtbb.so.12: cannot open shared object file: No such file or directory

We already install tbb here, so it might just be an issue of updating the build scripts to put libtbb in the LD_LIBRARY_PATH

@meta-codesync bot commented Oct 13, 2025

@q10 has imported this pull request. If you are a Meta employee, you can view this in D83116315.

@liligwu (Contributor, Author) commented Oct 13, 2025

> @liligwu we're seeing
>
> OSError: libtbb.so.12: cannot open shared object file: No such file or directory
>
> We already install tbb here, so it might just be an issue of updating the build scripts to put libtbb in the LD_LIBRARY_PATH

Hi @q10, sorry I missed your message.
I actually have this commit 4d2bfdd that links tbb explicitly, and it works in your container. Do you have any suggestions for fixing this issue in CI, please?

BTW, we discovered a numerical issue in 986cceb and reverted it in 85417b4, which unblocks merging the bwd optimization first.

Thank you.

@q10 (Contributor) commented Oct 14, 2025


I think this commit only addresses the build step, where we need to link to tbb. However, for runtime, you might need to run a find in $CONDA_PREFIX from inside the container and manually update LD_LIBRARY_PATH, or create a symlink, something like:

(print_exec ln -s "${conda_prefix}/lib/librhash.so" "${conda_prefix}/lib/librhash.so.0") || return 1

@liligwu (Contributor, Author) commented Oct 15, 2025


Hi @q10, I can see, a few lines below your example, a line that links tbb:

(print_exec ln -s "${conda_prefix}/lib/libtbb.so.12" "${conda_prefix}/lib/libtbb.so") || return 1

One more thing: installing tbb alone may not be sufficient. For example, on CentOS we dnf install -y tbb-devel tbb.

@liligwu changed the title from "forward performance tuning for MI350" to "backward performance optimization for MI350" on Oct 20, 2025
@ionuthristodorescu left a comment

I will also send some diffs on Slack.


// Compute shared memory size for cta_per_row
constexpr auto kCacheAccBytes = sizeof(at::acc_type<cache_t, true>);
int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize;


Is this line and the one below common between CUDA and ROCm? If yes, we should add {% if rocm %} guards around them.


This is common.


{% if rocm %} guards are already applied. Do we need additional changes here?


Let me take a look and review in the final version so that we're on the same page.

// Compute shared memory size for cta_per_row
constexpr auto kCacheAccBytes = sizeof(at::acc_type<cache_t, true>);
int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize;
int32_t total_L = indices.numel();


See the comment above: total_L seems to be used only in the ROCm path, so move it under #ifdef USE_ROCM?


Moved int32_t total_L = indices.numel(); down under {% if is_rocm %}.


I couldn't find the fixes; I'm not sure whether it's another PR I missed or they haven't been pushed yet. Would you mind sharing the latest version with all the fixes?


The fix for this is already at line 1065 in this PR; this display is outdated.


const bool use_deterministic_algorithms = at::globalContext().deterministicAlgorithms();
- const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 1024;
+ const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 4096;


This seems to affect the regular CUDA path as well; please use a {% if rocm %} guard to select between 1024 and 4096.


Applied a {% if is_rocm %} guard here.
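For reference, a minimal sketch of what such a guard could look like in the template, assuming the is_rocm flag mentioned in this thread (the placement is illustrative, not the final diff):

    const bool use_deterministic_algorithms = at::globalContext().deterministicAlgorithms();
    {%- if is_rocm %}
    const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 4096;
    {%- else %}
    const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 1024;
    {%- endif %}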

constexpr auto kCacheAccBytes = sizeof(at::acc_type<cache_t, true>);
int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize;
int32_t total_L = indices.numel();
#ifdef USE_ROCM


USE_ROCM is ROCm-specific; could we guard it with a {% if rocm %} so it does not bleed into the CUDA codegen?


Applied a {% if is_rocm %} guard here.

FBGEMM_LAUNCH_KERNEL(
backward_cta_per_row_kernel,
cta_per_row_grid_size,
// (64, 2)


Do we need this comment?


Removed.

TORCH_CHECK_EQ(grad_outputs.size(), 1);

- constexpr int32_t max_segment_length_per_warp = 32;
+ constexpr int32_t max_segment_length_per_warp = 16384;


This path seems to be common with the regular CUDA path; could we use a {% if rocm %} guard to select between 32 and 16384?


The max_segment_length_per_warp passed by the host will be modified in embedding_split_host_pt2_autograd_template.cpp later. Reverting max_segment_length_per_warp here back to 32.
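For illustration, one way the host-side selection could be expressed, assuming the host template also exposes an is_rocm flag (a sketch, not the actual change in embedding_split_host_pt2_autograd_template.cpp):

    {%- if is_rocm %}
    constexpr int32_t max_segment_length_per_warp = 16384;
    {%- else %}
    constexpr int32_t max_segment_length_per_warp = 32;
    {%- endif %}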

const auto permute_output_dim_0_1 =
ctx->saved_data["permute_output_dim_0_1"].toBool();

constexpr int32_t max_segment_length_per_warp = 32;


This path seems to be common with the regular CUDA path; could we use a {% if rocm %} guard to select between 32 and 16384?


The max_segment_length_per_warp passed by the host will be modified in embedding_split_host_pt2_autograd_template.cpp later. Reverting max_segment_length_per_warp here back to 32.

TORCH_CHECK(aux_tensor[IDX_LXU_CACHE_LOCATIONS].has_value(), "lxu_cache_locations should have value.");
const auto lxu_cache_locations = aux_tensor[IDX_LXU_CACHE_LOCATIONS].value();
const auto is_experimental = aux_bool[IDX_IS_EXPERIMENTAL_TBE];
const auto mixed_D = aux_bool[IDX_MIXED_D];


This path seems to be common with the regular CUDA path; could we use a {% if rocm %} guard to select between 32 and 16384?


Applied a {% if is_rocm %} guard here.


I don't think rocm needs to be guarded here. mixed_D has been passed to all. It was not used in the forward function, so it's just being saved for backward.

Please use static_cast<bool>(aux_bool[IDX_MIXED_D]); then you can replace this line with ctx->saved_data["mixed_D"] = mixed_D.
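A minimal sketch of the suggested change, using the names from the snippet above (surrounding autograd code omitted):

    const auto lxu_cache_locations = aux_tensor[IDX_LXU_CACHE_LOCATIONS].value();
    const auto is_experimental = aux_bool[IDX_IS_EXPERIMENTAL_TBE];
    const auto mixed_D = static_cast<bool>(aux_bool[IDX_MIXED_D]);
    // ... later, when saving values for the backward pass:
    ctx->saved_data["mixed_D"] = mixed_D;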

// Workaround. Should not be upstreamed in any way.
// Redistribute all cta_per_row work to warp_per_row.
int32_t total_L = indices.numel();
{%- if (not nobag) and


Could we add a {% if rocm %} guard around this code? total_L is used only in the USE_ROCM path, and USE_ROCM is ROCm-specific (so if we do not guard it with {% if rocm %}, it will be codegen'ed for the regular CUDA paths as well).

@kudomcho commented Oct 22, 2025

By adding a Jinja guard, do you mean only applying {% if rocm %} around total_L, @ionuthristodorescu, i.e.

{% if rocm %}
    int32_t total_L = indices.numel();
{%- endif %}

or changing from USE_ROCM to {% if rocm %} entirely?


The source code will be generated, but it shouldn't be compiled. Are we trying to double-check here, or does it cause any issues? The CUDA path should not compile this. Besides, I think rocm is not passed into this file as a global variable, so the condition will always be false and total_L will never show up in the generated source code.


I agree with you. Either USE_ROCM or the Jinja {% if is_rocm %} guard is fine for us. Please let us know which one we should stick with.


In this file, we should stick with USE_ROCM. I think your original change is correct: rocm is not defined in this file, so Jinja will see it as false, and int32_t total_L = indices.numel(); may not show up in the generated source code. Let me review it in the final version.
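Under that decision, the guarded code would look roughly like the following (a sketch assuming total_L is only consumed inside the ROCm-specific block):

    // Workaround. Should not be upstreamed in any way.
    // Redistribute all cta_per_row work to warp_per_row.
    #ifdef USE_ROCM
    int32_t total_L = indices.numel();
    // ... ROCm-only redistribution logic that uses total_L ...
    #endif // USE_ROCM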

)
{%- endif %}

for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) {


Here we should use a {% if rocm %} guard to select between the rolled and unrolled versions of the loop, so the regular, non-ROCm paths are not affected.


Applied a {% if is_rocm %} guard here.
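Structurally, that guard would look something like the sketch below; the loop bodies are placeholders, since the unrolled ROCm variant is not shown in this thread:

    {%- if is_rocm %}
    // manually unrolled ROCm variant of the loop (placeholder)
    {%- else %}
    for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) {
      // original rolled loop body (placeholder)
    }
    {%- endif %}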


self.pooling_mode != PoolingMode.NONE
), "Mixed dimension tables only supported for pooling tables."

self.mixed_D = mixed_D

On this change, I assume mixed_D needs to be accessible as a module attribute? Would it cause any issues if it's not stored as self.mixed_D? I'm asking to check whether we need to split the PR into backend (C++ source code and codegen) and frontend (split_table_batched_embeddings_ops_training.py) changes.


The optimization of the warp_per_row and cta_per_row kernels will not be activated if self.mixed_D is not present.


