backward performance optimization for MI350 #4925
base: main
Changes from all commits
523a317
aee3078
5a1ac2e
7856903
e1e246a
6a99fe0
349a7b5
1178cd1
606ad34
a22ddeb
6775452
90e6ba7
a9073ac
9a16e12
68630da
bac0610
a12112f
3ef64f7
f601e55
04916da
1e09555
b41192b
2b08f96
0c26470
d6b491b
70ed5e2
cf6a2b1
1be9bd8
9d3ee64
a5a3b1e
28e93c0
842846c
97aeb83
00976c7
4c19030
9991cf1
b61bd19
c38ff6f
e201e8b
aaaf80c
b8aea67
b9a7759
a4b4431
5d4f2cd
d3b7d7a
878d00f
d2596c7
e076556
585300d
e0db2f1
bf143c7
```diff
@@ -370,6 +370,7 @@ install_build_tools () {
     patchelf \
     rhash \
     scikit-build \
+    tbb-devel \
     tbb \
     wheel \
     xz \
```
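`tbb-devel` is added alongside `tbb` here, presumably because the runtime package alone ships only the shared libraries, while the `-devel` package carries the TBB headers and CMake config needed at build time.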
```diff
@@ -23,6 +23,10 @@
 #include "fbgemm_gpu/utils/assert_macros.h"
 #include "fbgemm_gpu/utils/kernel_launcher.cuh"
 
+{%- if is_rocm %}
+#include "fbgemm_gpu/rocm/cdna_guard.h"
+{%- endif %}
+
 using Tensor = at::Tensor;
 using namespace fbgemm_gpu;
 
@@ -209,8 +213,127 @@ __global__ __launch_bounds__(kForwardMaxThreads) void
         2, offset_idx + D_emb <= weights_numel, offset_idx
     )
     {%- endif %}
+    int32_t j = 0;
+    {%- if is_rocm and not ssd and not dense and not use_vec_blocking and not vbe %}
+    // Currently for split_embedding_codegen_grad_indice_weights_kernel only
+    if (placement != PlacementType::MANAGED_CACHING) {
+      for (; j < kWarpSize && l_start + j + 3 < L; j += 4) {
+        const auto offset_idx_j0 = shfl_sync(offset_idx, j);
+        const auto offset_idx_j1 = shfl_sync(offset_idx, j+1);
+        const auto offset_idx_j2 = shfl_sync(offset_idx, j+2);
+        const auto offset_idx_j3 = shfl_sync(offset_idx, j+3);
+
+        at::acc_type<cache_t, true> grad_indice_weight0 = 0.0;
+        at::acc_type<cache_t, true> grad_indice_weight1 = 0.0;
+        at::acc_type<cache_t, true> grad_indice_weight2 = 0.0;
+        at::acc_type<cache_t, true> grad_indice_weight3 = 0.0;
+
+        const auto weight_row0 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j0], D);
+        const auto weight_row1 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j1], D);
+        const auto weight_row2 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j2], D);
+        const auto weight_row3 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j3], D);
+
+        #pragma unroll kFixedMaxVecsPerThread
+        for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && (kWarpSize * vec + threadIdx.x) * kVecWidth < D; ++vec) {
+          const int32_t d = (kWarpSize * vec + threadIdx.x) * kVecWidth;
+
+          Vec4T<at::acc_type<cache_t, true>> weight0, weight1, weight2, weight3;
+          weight0 = weight_row0.load(d);
+          weight1 = weight_row1.load(d);
+          weight2 = weight_row2.load(d);
+          weight3 = weight_row3.load(d);
+
+          grad_indice_weight0 += weight0.acc.x * grad_out[vec].acc.x + weight0.acc.y * grad_out[vec].acc.y +
+                                 weight0.acc.z * grad_out[vec].acc.z + weight0.acc.w * grad_out[vec].acc.w;
+          grad_indice_weight1 += weight1.acc.x * grad_out[vec].acc.x + weight1.acc.y * grad_out[vec].acc.y +
+                                 weight1.acc.z * grad_out[vec].acc.z + weight1.acc.w * grad_out[vec].acc.w;
+          grad_indice_weight2 += weight2.acc.x * grad_out[vec].acc.x + weight2.acc.y * grad_out[vec].acc.y +
+                                 weight2.acc.z * grad_out[vec].acc.z + weight2.acc.w * grad_out[vec].acc.w;
+          grad_indice_weight3 += weight3.acc.x * grad_out[vec].acc.x + weight3.acc.y * grad_out[vec].acc.y +
+                                 weight3.acc.z * grad_out[vec].acc.z + weight3.acc.w * grad_out[vec].acc.w;
+        }
+
+        grad_indice_weight0 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight0);
+        grad_indice_weight1 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight1);
+        grad_indice_weight2 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight2);
+        grad_indice_weight3 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight3);
+
+        if (threadIdx.x == 0) {
+          grad_indice_weights[indices_start + l_start + j] = grad_indice_weight0;
+          grad_indice_weights[indices_start + l_start + j+1] = grad_indice_weight1;
+          grad_indice_weights[indices_start + l_start + j+2] = grad_indice_weight2;
+          grad_indice_weights[indices_start + l_start + j+3] = grad_indice_weight3;
+        }
+      }
+    } else {
+      for (; j < kWarpSize && l_start + j + 3 < L; j += 4) {
+        const auto offset_idx_j0 = shfl_sync(offset_idx, j);
+        const auto offset_idx_j1 = shfl_sync(offset_idx, j+1);
+        const auto offset_idx_j2 = shfl_sync(offset_idx, j+2);
+        const auto offset_idx_j3 = shfl_sync(offset_idx, j+3);
+
+        const auto cache_idx_j0 = shfl_sync(cache_idx, j);
+        const auto cache_idx_j1 = shfl_sync(cache_idx, j+1);
+        const auto cache_idx_j2 = shfl_sync(cache_idx, j+2);
+        const auto cache_idx_j3 = shfl_sync(cache_idx, j+3);
+
+        at::acc_type<cache_t, true> grad_indice_weight0 = 0.0;
+        at::acc_type<cache_t, true> grad_indice_weight1 = 0.0;
+        at::acc_type<cache_t, true> grad_indice_weight2 = 0.0;
+        at::acc_type<cache_t, true> grad_indice_weight3 = 0.0;
+
+        const auto weight_row0 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j0], D);
+        const auto weight_row1 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j1], D);
+        const auto weight_row2 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j2], D);
+        const auto weight_row3 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j3], D);
+
+        #pragma unroll kFixedMaxVecsPerThread
+        for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && (kWarpSize * vec + threadIdx.x) * kVecWidth < D; ++vec) {
+          const int32_t d = (kWarpSize * vec + threadIdx.x) * kVecWidth;
+
+          Vec4T<at::acc_type<cache_t, true>> weight0, weight1, weight2, weight3;
+          weight0 = (cache_idx_j0 != kCacheLocationMissing) ?
+            Vec4T<at::acc_type<cache_t, true>>(&lxu_cache_weights[cache_idx_j0][d]) :
+            weight_row0.load(d);
+
+          weight1 = (cache_idx_j1 != kCacheLocationMissing) ?
+            Vec4T<at::acc_type<cache_t, true>>(&lxu_cache_weights[cache_idx_j1][d]) :
+            weight_row1.load(d);
+
+          weight2 = (cache_idx_j2 != kCacheLocationMissing) ?
+            Vec4T<at::acc_type<cache_t, true>>(&lxu_cache_weights[cache_idx_j2][d]) :
+            weight_row2.load(d);
+
+          weight3 = (cache_idx_j3 != kCacheLocationMissing) ?
+            Vec4T<at::acc_type<cache_t, true>>(&lxu_cache_weights[cache_idx_j3][d]) :
+            weight_row3.load(d);
+
+          grad_indice_weight0 += weight0.acc.x * grad_out[vec].acc.x + weight0.acc.y * grad_out[vec].acc.y +
+                                 weight0.acc.z * grad_out[vec].acc.z + weight0.acc.w * grad_out[vec].acc.w;
+          grad_indice_weight1 += weight1.acc.x * grad_out[vec].acc.x + weight1.acc.y * grad_out[vec].acc.y +
+                                 weight1.acc.z * grad_out[vec].acc.z + weight1.acc.w * grad_out[vec].acc.w;
+          grad_indice_weight2 += weight2.acc.x * grad_out[vec].acc.x + weight2.acc.y * grad_out[vec].acc.y +
+                                 weight2.acc.z * grad_out[vec].acc.z + weight2.acc.w * grad_out[vec].acc.w;
+          grad_indice_weight3 += weight3.acc.x * grad_out[vec].acc.x + weight3.acc.y * grad_out[vec].acc.y +
+                                 weight3.acc.z * grad_out[vec].acc.z + weight3.acc.w * grad_out[vec].acc.w;
+        }
+
+        grad_indice_weight0 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight0);
+        grad_indice_weight1 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight1);
+        grad_indice_weight2 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight2);
+        grad_indice_weight3 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight3);
+
+        if (threadIdx.x == 0) {
+          grad_indice_weights[indices_start + l_start + j] = grad_indice_weight0;
+          grad_indice_weights[indices_start + l_start + j+1] = grad_indice_weight1;
+          grad_indice_weights[indices_start + l_start + j+2] = grad_indice_weight2;
+          grad_indice_weights[indices_start + l_start + j+3] = grad_indice_weight3;
+        }
+      }
+    }
+    {%- endif %}{#-/* if is_rocm and not ssd and not dense and not use_vec_blocking and not vbe */#}
-    for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) {
+    for (; j < kWarpSize && l_start + j < L; ++j) {
         const auto offset_idx_j = shfl_sync(offset_idx, j);
         {%- if not dense %}
         const auto {{ locs_or_addrs_idx }}_j = shfl_sync({{ locs_or_addrs_idx }}, j);
```

Review thread on the removed loop header (`for (auto j = 0; ...)`):

> Here we should use a …
>
> **Reply:** Applied `{% if is_rocm %}` onto here

Review thread on the new tail loop (`for (; j < kWarpSize && l_start + j < L; ++j)`):

> line 337 and 339, does it really make a difference? Otherwise, line 337-340 could just be …
>
> **Reply:** Done in bf143c7. Now it's unified, so CUDA uses it as a main loop while ROCm uses this loop to handle the tailing iterations
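The heart of the change: on ROCm (outside the SSD/dense/vec-blocking/VBE paths) the per-bag loop is unrolled four ways. `shfl_sync` broadcasts four pooled indices per pass, every lane accumulates four partial dot products against `grad_out`, and one round of `warpReduceAllSum` finishes all four results. A minimal standalone CUDA sketch of the pattern, with hypothetical names and plain `float` rows in place of `WeightRowAccessor`/`Vec4T`, assuming `L <= 32`:

```cuda
// Hypothetical sketch of the 4-way shuffle-broadcast unroll above;
// not the FBGEMM kernel. One warp computes
//   grad_indice[l] = dot(weights[indices[l]], grad_out)  for l in [0, L)
#include <cuda_runtime.h>

constexpr int kWarpSize = 32;
constexpr unsigned kFullMask = 0xffffffffu;

__global__ void grad_indice_weights_sketch(
    const float* __restrict__ weights,   // [num_rows x D], row-major
    const float* __restrict__ grad_out,  // [D]
    const int* __restrict__ indices,     // [L], L <= 32
    float* __restrict__ grad_indice,     // [L]
    int D,
    int L) {
  const int lane = threadIdx.x;
  // Each lane owns one index, like offset_idx in the real kernel.
  const int my_idx = (lane < L) ? indices[lane] : 0;

  int j = 0;
  // Main loop: broadcast four indices per pass and keep four
  // independent accumulators in flight per lane.
  for (; j + 3 < L; j += 4) {
    const int i0 = __shfl_sync(kFullMask, my_idx, j);
    const int i1 = __shfl_sync(kFullMask, my_idx, j + 1);
    const int i2 = __shfl_sync(kFullMask, my_idx, j + 2);
    const int i3 = __shfl_sync(kFullMask, my_idx, j + 3);

    float g0 = 0.f, g1 = 0.f, g2 = 0.f, g3 = 0.f;
    // Strided partial dot products: lane covers d, d+32, d+64, ...
    for (int d = lane; d < D; d += kWarpSize) {
      const float go = grad_out[d];
      g0 += weights[i0 * D + d] * go;
      g1 += weights[i1 * D + d] * go;
      g2 += weights[i2 * D + d] * go;
      g3 += weights[i3 * D + d] * go;
    }
    // Butterfly warp all-reduce, standing in for warpReduceAllSum.
    for (int off = kWarpSize / 2; off > 0; off >>= 1) {
      g0 += __shfl_xor_sync(kFullMask, g0, off);
      g1 += __shfl_xor_sync(kFullMask, g1, off);
      g2 += __shfl_xor_sync(kFullMask, g2, off);
      g3 += __shfl_xor_sync(kFullMask, g3, off);
    }
    if (lane == 0) {
      grad_indice[j]     = g0;
      grad_indice[j + 1] = g1;
      grad_indice[j + 2] = g2;
      grad_indice[j + 3] = g3;
    }
  }
  // Tail loop: one index at a time, as in the unified loop in the diff.
  for (; j < L; ++j) {
    const int i = __shfl_sync(kFullMask, my_idx, j);
    float g = 0.f;
    for (int d = lane; d < D; d += kWarpSize) {
      g += weights[i * D + d] * grad_out[d];
    }
    for (int off = kWarpSize / 2; off > 0; off >>= 1) {
      g += __shfl_xor_sync(kFullMask, g, off);
    }
    if (lane == 0) {
      grad_indice[j] = g;
    }
  }
}
```

Four independent accumulators per lane raise the number of outstanding loads the scheduler can overlap, which is where the backward-pass win on MI350-class parts is expected; the scalar tail loop picks up the remainder exactly like the unified loop in the diff.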
```diff
@@ -359,6 +482,15 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda(
     auto aligned_grad_output = aligned_grad_output_tensor_for_cuda_backwards(grad_output);
 
     CUDA_DEVICE_GUARD(dev_weights);
+    #ifdef USE_ROCM
+    if (!rocm::is_supported_cdna()) {
+      TORCH_WARN_ONCE("Running on non-CDNA architecture. Performance may be suboptimal.");
+    }
+    else {
+      // Ensure we're running on a supported CDNA architecture (including MI350)
+      TORCH_WARN_ONCE("Running on CDNA architecture");
+    }
+    #endif
 
     const auto T = D_offsets.size(0) - 1;
     TORCH_CHECK_GT(T, 0);
```
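`rocm::is_supported_cdna()` comes from the newly included `fbgemm_gpu/rocm/cdna_guard.h`, whose implementation is not part of this diff. Guards like this are typically a substring match on the device's `gcnArchName`; a hedged sketch follows (the helper name and the arch list are assumptions, not FBGEMM's actual code):

```cpp
// Hypothetical sketch of a CDNA architecture guard; not the actual
// contents of fbgemm_gpu/rocm/cdna_guard.h.
#include <hip/hip_runtime.h>
#include <string>

inline bool is_supported_cdna_sketch() {
  int dev = 0;
  if (hipGetDevice(&dev) != hipSuccess) return false;
  hipDeviceProp_t prop;
  if (hipGetDeviceProperties(&prop, dev) != hipSuccess) return false;
  const std::string arch(prop.gcnArchName);  // e.g. "gfx942:sramecc+:xnack-"
  // gfx90a = MI200 (CDNA2), gfx942 = MI300 (CDNA3), gfx950 = MI350 (CDNA4).
  for (const char* cdna : {"gfx90a", "gfx942", "gfx950"}) {
    if (arch.find(cdna) != std::string::npos) return true;
  }
  return false;
}
```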
```diff
@@ -32,6 +32,14 @@
 {%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %}
 {%- set locs_or_addrs_type = "int64_t" if ssd else "int32_t" %}
+{%- set is_optimized_hip_kernel_supported_mode = is_rocm and
+        optimizer == "rowwise_adagrad" and
+        not dense and
+        not nobag and
+        not is_index_select and
+        not is_gwd_kernel and
+        not vbe and
+        not ssd %}
 
 #include "fbgemm_gpu/embedding_backward_template_helpers.cuh"
 #include "fbgemm_gpu/utils/tensor_accessor_builder.h"
```
```diff
@@ -538,7 +546,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row
 {%- endif %}
 
-{%- if is_rocm and not is_index_select and optimizer == "rowwise_adagrad" and not dense and not is_gwd_kernel and not vbe and not ssd %}
+{%- if is_optimized_hip_kernel_supported_mode %}
 
 #include <hip/hip_runtime.h>
 #include <hip/hip_fp16.h>
 #include "fbgemm_gpu/rocm/split_embeddings_common.h"
```

Review thread on the gating condition:

> No nobag support?
>
> **Reply:** Not now. The nobag support is planned in the near future.
```diff
@@ -612,12 +620,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}
     {{ args.split_kernel_args | replace_pta_namespace() | join(",\n    ") }}
     {%- endif %}
 ) {
-    {%- if not nobag %}
-    int32_t T = D_offsets.size(0) - 1;
-    {%- else %}
-    int32_t T = weights_offsets.size(0);
-    {%- endif %}
-
     auto p_output_grad = grad_output.data();
     auto p_emb_table = dev_weights.data();
     auto p_hash_size_cumsum = hash_size_cumsum.data();
```
```diff
@@ -632,8 +635,6 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}
     constexpr int32_t segment_prefetch = 2;
     constexpr int32_t segment_unroll = 8;
     constexpr int32_t segment_split = 0;
-    auto batch = grad_output.size(0);
-    auto num_rows = dev_weights.size(0) / T / max_D;
     {%- if weighted %}
     constexpr bool is_weighted = true;
     {%- else %}
```
```diff
@@ -646,24 +647,9 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}
     // weight_decay(_mode) is supplied as args.split_function_args_no_defaults
     opt_karg.weight_decay_mode = weight_decay_mode_v;
     opt_karg.weight_decay = weight_decay;
-    auto batch_mdiv = [](uint32_t d) -> rocm::magic_div_u32_t {
-        assert(d >= 1 && d <= INT32_MAX);
-        uint8_t shift;
-        for(shift = 0; shift < 32; shift++)
-            if((1U << shift) >= d)
-                break;
-
-        uint64_t one = 1;
-        uint64_t magic = ((one << 32) * ((one << shift) - d)) / d + 1;
-        assert(magic <= 0xffffffffUL);
-
-        rocm::magic_div_u32_t result;
-        result.magic = magic;
-        result.shift = shift;
-        return result;
-    }(batch);
-
     rocm::split_tbe_backward_hip_kernel_{{kdesc}}<
-        rocm::{{optimizer}}_optimizer_t<cache_t, emb_t, embedding_dim, weight_decay_mode_v>,
+        rocm::{{optimizer}}_optimizer_t<cache_t, emb_t, index_t, embedding_dim, weight_decay_mode_v>,
         rocm::{{optimizer}}_kernel_arg_t,
         emb_t,
         cache_t,
```
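The deleted lambda built a `rocm::magic_div_u32_t` so the kernel could divide by `batch` without an integer divide; with the kernel now taking `info_B_num_bits`/`info_B_mask` (see the call-site hunk below), the divider is no longer needed. For reference, a host-side sketch of the scheme the lambda implemented. The multiply-and-shift in `magic_div` follows the standard magic-division recipe (as in composable_kernel) and is an assumption about what `rocm::magic_div_u32_t`'s runtime half does:

```cpp
// Sketch of magic-number division: precompute (magic, shift) for a
// divisor d, then n / d == (umulhi(n, magic) + n) >> shift.
#include <cassert>
#include <cstdint>
#include <cstdio>

struct magic_div_u32_t { uint32_t magic; uint8_t shift; };

magic_div_u32_t make_magic(uint32_t d) {
  assert(d >= 1);  // the removed lambda also required d <= INT32_MAX
  uint8_t shift = 0;
  while ((1ULL << shift) < d) ++shift;  // smallest shift with 2^shift >= d
  const uint64_t one = 1;
  const uint64_t magic = ((one << 32) * ((one << shift) - d)) / d + 1;
  assert(magic <= 0xffffffffULL);
  return {static_cast<uint32_t>(magic), shift};
}

uint32_t magic_div(magic_div_u32_t m, uint32_t n) {
  // High 32 bits of the 64-bit product, i.e. __umulhi(n, magic) on device.
  const uint32_t hi = static_cast<uint32_t>((uint64_t(n) * m.magic) >> 32);
  return (hi + n) >> m.shift;  // == n / d, with no hardware divide
}

int main() {
  for (uint32_t d : {1u, 3u, 7u, 128u, 1000u}) {
    const auto m = make_magic(d);
    for (uint32_t n : {0u, 1u, 99u, 12345u, 1u << 30}) {
      assert(magic_div(m, n) == n / d);
    }
  }
  puts("magic division matches n / d");
}
```

The removed code is exactly the `make_magic` half; dropping it is safe once nothing in the kernel divides by `batch` anymore.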
```diff
@@ -680,16 +666,11 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}
         p_sorted_linear_indices_run,
         p_sorted_linear_indices_cumulative_run_lengths,
         p_sorted_linear_indices_num_runs,
+        {%- if not nobag %}
+        info_B_num_bits,
+        info_B_mask,
+        {%- endif %}
         p_sorted_infos,
-        batch_mdiv,
         max_segment_length_per_warp,
         emb_dim,
-        batch,
-        num_rows,
-        T,
         opt_karg
         {%- if weighted %}
```
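The call now passes `info_B_num_bits` and `info_B_mask` where `batch_mdiv`, `batch`, `num_rows`, and `T` used to go: instead of magic-dividing a flat index by `batch`, the kernel can decode the (table, batch) pair directly from the packed `sorted_infos` entries. A sketch of that packing, following the usual FBGEMM TBE convention of the batch index in the low bits; treat the exact field layout as an assumption here:

```cpp
// Hypothetical sketch of the info-word packing behind info_B_num_bits
// and info_B_mask; the real encode/decode lives in FBGEMM's TBE code.
#include <cassert>
#include <cstdint>

struct InfoCodec {
  uint32_t info_B_num_bits;  // bits reserved for the batch index b
  uint32_t info_B_mask;      // (1u << info_B_num_bits) - 1

  uint32_t pack(uint32_t t, uint32_t b) const {
    assert(b <= info_B_mask);
    return (t << info_B_num_bits) | b;  // table in high bits, batch in low
  }
  uint32_t table(uint32_t info) const { return info >> info_B_num_bits; }
  uint32_t batch(uint32_t info) const { return info & info_B_mask; }
};

int main() {
  const InfoCodec codec{/*info_B_num_bits=*/17, /*info_B_mask=*/(1u << 17) - 1};
  const uint32_t info = codec.pack(/*t=*/5, /*b=*/12345);
  assert(codec.table(info) == 5 && codec.batch(info) == 12345);
}
```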
```diff
@@ -784,7 +765,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}
 {%- for emb_type in (['float', 'at::Half'] + (['at::Float8_e4m3fnuz'] if is_rocm else ['at::Float8_e4m3fn'])) %}
 {%- for cache_type in ['float', 'at::Half'] %}
 {%- for index_type in ['int32_t', 'int64_t'] %}
-{%- for kEmbeddingDim in [64, 128, 160, 192, 256] %}
+{%- for kEmbeddingDim in [64, 128, 160, 192, 256, 320] %}
 {%- for kWeighDecayMode in [0, 1, 2] %}
 {{ hip_template_instantiation(
     emb_type,
```
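With `320` added to the `kEmbeddingDim` list, the instantiation space becomes 3 `emb_type`s × 2 `cache_type`s × 2 `index_type`s × 6 embedding dims × 3 `kWeighDecayMode`s = 216 explicit instantiations of the HIP kernel, up from 180, extending coverage to 320-wide embedding rows at some extra compile-time cost.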
Review thread on this file:

> no nobag for gen_embedding_backward_split_{}{}_device_kernel_hip.hip?
>
> **Reply:** Not now. The nobag support is planned in the near future.