Skip to content

Commit 8306a04

Browse files
committed
Merge branch 'abokovoi/mi350-fix-optimized-segfault' into mi350_dev
2 parents 48d3161 + 0129050 commit 8306a04

File tree

4 files changed, with 10 additions (+10) and 7 deletions (-7).

fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -650,7 +650,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
650 650
opt_karg.weight_decay = weight_decay;
651 651

652 652
rocm::split_tbe_backward_hip_kernel_{{kdesc}}<
653-
rocm::{{optimizer}}_optimizer_t<cache_t, emb_t, embedding_dim, weight_decay_mode_v>,
653+
rocm::{{optimizer}}_optimizer_t<cache_t, emb_t, index_t, embedding_dim, weight_decay_mode_v>,
654 654
rocm::{{optimizer}}_kernel_arg_t,
655 655
emb_t,
656 656
cache_t,

fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1238,7 +1238,9 @@ Tensor {{ embedding_cuda_op }}(
1238 1238
const auto supported_weights_type = dev_weights.scalar_type() == at::ScalarType::Half
1239 1239
|| dev_weights.scalar_type() == at::ScalarType::Float;
1240 1240

1241-
if (use_hip_kernel && !mixed_D && supported_weights_type && rocm::is_supported_cdna())
1241+
constexpr bool supported_grad_type = std::is_same_v<grad_t, float> || std::is_same_v<grad_t, at::Half>;
1242+
1243+
if (use_hip_kernel && !mixed_D && supported_weights_type && supported_grad_type && rocm::is_supported_cdna())
1242 1244
{
1243 1245
constexpr int segments_per_workgroup = 4;
1244 1246
{%- for kDimSize in [64, 128, 160, 192, 256, 320] %}

fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@
27 27
#include "fbgemm_gpu/rocm/split_embeddings_common.h"
28 28

29 29
namespace fbgemm_gpu::rocm {
30-
template <typename cache_t, typename emb_t, int32_t embedding_dim, int32_t weight_decay_mode>
30+
template <typename cache_t, typename emb_t, typename index_t, int32_t embedding_dim, int32_t weight_decay_mode>
31 31
struct rowwise_adagrad_optimizer_t
32 32
{
33 33
__device__ rowwise_adagrad_optimizer_t(const rowwise_adagrad_kernel_arg_t& karg_)
@@ -36,7 +36,7 @@ struct rowwise_adagrad_optimizer_t
36 36
}
37 37

38 38
template <int32_t thread_length, int32_t segment_split>
39-
__device__ void update(cache_t* acc, emb_t* weight, uint32_t row_index)
39+
__device__ void update(cache_t* acc, emb_t* weight, index_t row_index)
40 40
{
41 41
if constexpr(segment_split == 0)
42 42
{

fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@
24 24
#include <c10/util/Half.h>
25 25
#include <hip/hip_fp16.h>
26 26
#include <hip/hip_runtime.h>
27+
#include <rocm-core/rocm_version.h>
27 28

28 29
/******************************************************************************/
29 30
typedef int32_t int32x4_t __attribute__((ext_vector_type(4)));
@@ -61,7 +62,7 @@ __device__ half llvm_amdgcn_raw_buffer_load_fp16(
61 62
int32_t voffset,
62 63
int32_t soffset,
63 64
int32_t glc_slc)
64-
#if defined(__gfx950__)
65+
#if ROCM_VERSION_MAJOR >= 7
65 66
__asm("llvm.amdgcn.raw.buffer.load.i16");
66 67
#else
67 68
__asm("llvm.amdgcn.raw.buffer.load.f16");
@@ -78,7 +79,7 @@ __device__ half2 llvm_amdgcn_raw_buffer_load_fp16x2(
78 79
int32_t voffset,
79 80
int32_t soffset,
80 81
int32_t glc_slc)
81-
#if defined(__gfx950__)
82+
#if ROCM_VERSION_MAJOR >= 7
82 83
__asm("llvm.amdgcn.raw.buffer.load.i32");
83 84
#else
84 85
__asm("llvm.amdgcn.raw.buffer.load.v2f16");
@@ -164,7 +165,7 @@ struct load_row_per_warp<half, 160, index_t> {
164 165
static __device__ void
165 166
run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) {
166 167
int32x4_t emb_res =
167-
amdgcn_make_buffer_resource(p_emb_table + row_index * 192);
168+
amdgcn_make_buffer_resource(p_emb_table + row_index * 160);
168 169
*reinterpret_cast<half2*>(emb_data) = llvm_amdgcn_raw_buffer_load_fp16x2(
169 170
emb_res, lane_id * sizeof(half2), 0, 0);
170 171
if ((lane_id + 128) % 192 < 160) {

0 commit comments

Comments
 (0)