backward performance optimization for MI350 #4925
Conversation
✅ Deploy Preview for pytorch-fbgemm-docs ready!
Hi @q10, sorry I missed your message. BTW, we discovered a numerical issue in 986cceb and reverted it in 85417b4. That unblocks merging the bwd optimization first. Thank you.
I think this commit only addresses the build step, where we need to link to tbb. However, for runtime, you might need to do a find in $CONDA_PREFIX from inside the container and manually update LD_LIBRARY_PATH, or create a symlink (something like FBGEMM/.github/scripts/utils_build.bash, line 383, in 0d49628).
Hi @q10, I can see a few lines below your example that link tbb (FBGEMM/.github/scripts/utils_build.bash, line 388, in 0d49628). One more thing is that installing tbb may not be sufficient.
I will also send some diffs on Slack.
// Compute shared memory size for cta_per_row
constexpr auto kCacheAccBytes = sizeof(at::acc_type<cache_t, true>);
int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize;
Is this line and the one below common between CUDA and ROCm? If yes, we should add {% if rocm %} guards around them.
This is common.
{% if rocm %} guards are already applied. Do we need additional changes here?
Let me take a look and review in the final version so that we're on the same page.
// Compute shared memory size for cta_per_row
constexpr auto kCacheAccBytes = sizeof(at::acc_type<cache_t, true>);
int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize;
int32_t total_L = indices.numel();
See above comment - total_L seems to be used only in the ROCm path, so should we move it under #ifdef USE_ROCM?
Moved int32_t total_L = indices.numel(); down under {% if is_rocm %}.
I couldn't find the fixes, not sure if it's another PR I missed or it hasn't been pushed. Would you mind sharing the latest version with all the fixes?
The fix for this is already on line 1065 of this PR; this display is outdated.
  const bool use_deterministic_algorithms = at::globalContext().deterministicAlgorithms();
- const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 1024;
+ const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 4096;
This seems to affect the regular CUDA path as well - please use a {% if rocm %} guard to select between 1024 and 4096.
Applied {% if is_rocm %} here.
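To illustrate the guard discussed above: the PR selects the segment-length cap at codegen time with a Jinja {% if is_rocm %} guard; a minimal standalone sketch of the equivalent compile-time selection via the USE_ROCM preprocessor symbol (function name and structure here are illustrative, not from the PR) could look like this.

```cpp
#include <cassert>
#include <climits>

// Illustrative sketch: select the per-CTA segment-length cap depending on
// the backend. In the PR this choice is made at codegen time via a Jinja
// {% if is_rocm %} guard; here USE_ROCM stands in as the selector.
int max_segment_length_per_cta(bool use_deterministic_algorithms) {
#ifdef USE_ROCM
  // Larger cap used on the ROCm path for backward-pass performance.
  return use_deterministic_algorithms ? INT_MAX : 4096;
#else
  // Original value kept unchanged on the regular CUDA path.
  return use_deterministic_algorithms ? INT_MAX : 1024;
#endif
}
```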
constexpr auto kCacheAccBytes = sizeof(at::acc_type<cache_t, true>);
int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize;
int32_t total_L = indices.numel();
#ifdef USE_ROCM
USE_ROCM is ROCm-specific; could we guard it with {% if rocm %} so it does not bleed into the CUDA codegen?
Applied {% if is_rocm %} here.
FBGEMM_LAUNCH_KERNEL(
    backward_cta_per_row_kernel,
    cta_per_row_grid_size,
    // (64, 2)
Do we need this comment?
Removed.
  TORCH_CHECK_EQ(grad_outputs.size(), 1);

- constexpr int32_t max_segment_length_per_warp = 32;
+ constexpr int32_t max_segment_length_per_warp = 16384;
This path seems common with regular CUDA; could we use a {% if rocm %} guard to select between 32 and 16384?
The max_segment_length_per_warp passed by the host will be modified in embedding_split_host_pt2_autograd_template.cpp later. Reverting max_segment_length_per_warp here back to 32.
const auto permute_output_dim_0_1 =
    ctx->saved_data["permute_output_dim_0_1"].toBool();

constexpr int32_t max_segment_length_per_warp = 32;
This path seems common with regular CUDA; could we use a {% if rocm %} guard to select between 32 and 16384?
The max_segment_length_per_warp passed by the host will be modified in embedding_split_host_pt2_autograd_template.cpp later. Reverting max_segment_length_per_warp here back to 32.
TORCH_CHECK(aux_tensor[IDX_LXU_CACHE_LOCATIONS].has_value(), "lxu_cache_locations should have value.");
const auto lxu_cache_locations = aux_tensor[IDX_LXU_CACHE_LOCATIONS].value();
const auto is_experimental = aux_bool[IDX_IS_EXPERIMENTAL_TBE];
const auto mixed_D = aux_bool[IDX_MIXED_D];
This path seems common with regular CUDA, could we guard it with a {% if rocm %} guard to select between 32 and 16384?
Applied {% if is_rocm %} here.
I don't think this needs a rocm guard. mixed_D has been passed to all paths; it was not used in the forward function, so it's just being saved for backward.
Please use static_cast<bool>(aux_bool[IDX_MIXED_D]);
then you can replace this line with ctx->saved_data["mixed_D"] = mixed_D;
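The suggestion above can be sketched as a standalone snippet. In the real code, aux_bool is an integer list read from a tensor and the values go into the autograd context's saved_data; here a std::vector and a std::map stand in for those, and IDX_MIXED_D's position is illustrative.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <string>
#include <vector>

// Illustrative index into the aux_bool list; the real index comes from the
// PR's IDX_MIXED_D constant.
constexpr std::size_t IDX_MIXED_D = 0;

// Cast the integer flag explicitly to bool before saving it for backward,
// mirroring ctx->saved_data["mixed_D"] = mixed_D in the reviewer's suggestion.
bool save_mixed_d(const std::vector<int64_t>& aux_bool,
                  std::map<std::string, bool>& saved_data) {
  const bool mixed_D = static_cast<bool>(aux_bool[IDX_MIXED_D]);
  saved_data["mixed_D"] = mixed_D;
  return mixed_D;
}
```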
// Workaround. Should not be upstreamed in any way.
// Redistribute all cta_per_row work to warp_per_row.
int32_t total_L = indices.numel();
{%- if (not nobag) and
Could we add a {% if rocm %} guard around this code? total_L is used only in the USE_ROCM path, and USE_ROCM is ROCm-specific (so if we do not guard it with {% if rocm %}, it will be codegen'ed for the regular CUDA paths as well).
By adding a jinja guard, do you mean only applying {% if rocm %} around total_L, @ionuthristodorescu, i.e.
{% if rocm %}
int32_t total_L = indices.numel();
{%- endif %}
Or changing from USE_ROCM to {% if rocm %} entirely?
The source code will be generated, but it shouldn't be compiled. Are we trying to double-check here, or does it cause any issues? The CUDA path should not compile this. Besides, I think rocm is not passed into this file as a global variable, so the condition will always be false and total_L will never show up in the generated source code.
I agree with you. Either USE_ROCM or the jinja {% if is_rocm %} is fine for us. Please let us know which one we should stick with.
In this file, we should stick with USE_ROCM. I think your original change is correct: rocm is not defined in this file, so jinja will treat it as false, and int32_t total_L = indices.numel(); may not show up in the generated source code. Let me review the final version.
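The USE_ROCM approach settled on above can be sketched as follows. This is a minimal illustrative snippet, not the PR's code: the function name and the redistribution threshold are invented for the example; only the pattern (total_L computed and consulted solely inside the USE_ROCM block) reflects the discussion.

```cpp
#include <cassert>
#include <cstdint>

// Illustrative sketch of the workaround: under USE_ROCM, total_L is used to
// decide whether to redistribute cta_per_row work to warp_per_row. On the
// CUDA path the value is ignored, so nothing bleeds into regular codegen.
int32_t cta_per_row_groups(int32_t num_groups, int32_t total_L) {
#ifdef USE_ROCM
  // Hypothetical threshold: for small workloads, send all work to
  // warp_per_row (value is illustrative, not from the PR).
  if (total_L < 1024) {
    return 0;
  }
#else
  (void)total_L;  // total_L is not referenced on the CUDA path
#endif
  return num_groups;
}
```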
)
{%- endif %}

for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) {
Here we should use a {% if rocm %} guard to select between rolled / unrolled versions of the loop on regular, non-ROCm paths.
Applied {% if is_rocm %} here.
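The rolled loop form kept for the ROCm path can be shown with a small self-contained sketch. The function, data layout, and kWarpSize value here are illustrative stand-ins; only the loop bound pattern `j < kWarpSize && l_start + j < L` comes from the diff above.

```cpp
#include <cassert>
#include <cstdint>

// Illustrative: ROCm wavefronts are 64 lanes wide; the real kWarpSize comes
// from the FBGEMM headers.
constexpr int32_t kWarpSize = 64;

// Rolled accumulation over at most kWarpSize entries starting at l_start,
// clamped to the segment length L - the loop shape under discussion. An
// unrolled CUDA variant would process a fixed number of iterations per step.
int64_t accumulate_rolled(const int32_t* vals, int32_t l_start, int32_t L) {
  int64_t acc = 0;
  for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) {
    acc += vals[l_start + j];
  }
  return acc;
}
```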
warp per row wg change
    self.pooling_mode != PoolingMode.NONE
), "Mixed dimension tables only supported for pooling tables."

self.mixed_D = mixed_D
With this change, I assume mixed_D needs to be accessible as a module attribute? Would it cause any issues if it's not self.mixed_D? Asking to check whether we need to split the PRs into backend (C++ source code and codegen) and frontend (split_table_batched_embeddings_ops_training.py) changes.
The optimizations in the warp_per_row and cta_per_row kernels will not be activated if self.mixed_D is not present.
…ad-store' into mi350_dev
bwd performance optimization for ROCm.
Fix numerical issues