From 144ca86c025cfeb5e8366648a0f8c4486773c305 Mon Sep 17 00:00:00 2001
From: Todobe
Date: Wed, 10 Sep 2025 19:26:35 +0800
Subject: [PATCH 1/4] mlapo fit different hidden state dim

---
 .../mla_preprocess/op_host/mla_preprocess.cpp | 30 +++++++++----------
 .../op_host/tiling/mla_preprocess_tiling.h    |  3 ++
 .../mla_preprocess/op_kernel/mla_preprocess.h |  1 -
 .../op_kernel/mla_preprocess_mix_bf16.hpp     | 12 ++++----
 .../op_kernel/mla_preprocess_mix_fp16.hpp     | 16 +++++-----
 5 files changed, 34 insertions(+), 28 deletions(-)

diff --git a/csrc/mla_preprocess/op_host/mla_preprocess.cpp b/csrc/mla_preprocess/op_host/mla_preprocess.cpp
index 06e5c532..e37533a9 100644
--- a/csrc/mla_preprocess/op_host/mla_preprocess.cpp
+++ b/csrc/mla_preprocess/op_host/mla_preprocess.cpp
@@ -43,7 +43,6 @@
 constexpr uint32_t L1_BIAS_SIZE = 2048;
 constexpr uint32_t L0C_SIZE = 128 * 1024;
 constexpr uint32_t CONCAT_SIZE = 512;
-constexpr uint32_t HIDDEN_STRATE = 7168;
 constexpr uint32_t HIDDEN_STRATE_ROPE = 192;
 constexpr uint32_t HIDDEN_STRATE_MM = 2112;
 constexpr uint32_t HIDDEN_STRATE_RMS = 1536;
@@ -373,14 +372,14 @@ class MlaPreprocessTiling
         this->platformInfo = platformInfo;
         this->opParam = opParam;
     }
-    void Init();
+    void Init(uint32_t hiddenStateDim);
 
-    void RmsNormQuantTiling();
+    void RmsNormQuantTiling(uint32_t hiddenStateDim);
     void RopeConcatTiling();
     void EinSumQuantTiling();
     void SetTilingKey();
-    void SetMlapoWorkSpace();
+    void SetMlapoWorkSpace(uint32_t hiddenStateDim);
 
 private:
     MlaTilingData *tilingData;
@@ -388,10 +387,10 @@
     struct OpParam opParam;
 };
 
-void MlaPreprocessTiling::RmsNormQuantTiling()
+void MlaPreprocessTiling::RmsNormQuantTiling(uint32_t hiddenStateDim)
 {
     tilingData->rmsNumCore1 = platformInfo.coreNumAiv;
-    tilingData->rmsNumCol1 = HIDDEN_STRATE;
+    tilingData->rmsNumCol1 = hiddenStateDim;
     tilingData->rmsNumRow1 = opParam.N;
     tilingData->rmsQuantMin1 = -CONST_128;
     tilingData->rmsNumCore2 = platformInfo.coreNumAiv;
@@ -504,12 +503,12 @@ void MlaPreprocessTiling::EinSumQuantTiling()
     tilingData->esqColTail = esqColTail;
 }
 
-void MlaPreprocessTiling::SetMlapoWorkSpace()
+void MlaPreprocessTiling::SetMlapoWorkSpace(uint32_t hiddenStateDim)
 {
     uint64_t s1wsFactor =
-        static_cast(opParam.cacheMode == 2 ? std::max(HIDDEN_STRATE * sizeof(int8_t),
+        static_cast(opParam.cacheMode == 2 ? std::max(hiddenStateDim * sizeof(int8_t),
                                                       opParam.headNum * AXES_ALIGN_SIZE * sizeof(uint16_t))
-                                            : HIDDEN_STRATE * sizeof(int8_t));
+                                            : hiddenStateDim * sizeof(int8_t));
     uint64_t workSizeS1 = s1wsFactor;
     uint64_t workSizeS2 = opParam.headNum * HIDDEN_STRATE_ROPE * sizeof(uint16_t);
     uint64_t workSizeS3 = HIDDEN_STRATE_MM * sizeof(uint16_t);
@@ -548,11 +547,11 @@
     tilingData->tilingKey = tilingKey;
 }
 
-void MlaPreprocessTiling::Init()
+void MlaPreprocessTiling::Init(uint32_t hiddenStateDim)
 {
     tilingData->numCore = platformInfo.coreNumAic;
     tilingData->n = opParam.N;
-
+    tilingData->hiddenStateDim = hiddenStateDim;
     bool deqOnTheFly = false;
     if (opParam.inDtype == at::kBFloat16 || opParam.quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) {
         deqOnTheFly = true;
@@ -561,7 +560,7 @@
     PpMatmulTilingApi mm1TilingApi(platformInfo,
                                    1,                // numBatch
                                    opParam.N,        // m
-                                   HIDDEN_STRATE,    // k
+                                   hiddenStateDim, // k
                                    HIDDEN_STRATE_MM, // n
                                    false,            // transA
                                    true,             // transB
@@ -591,11 +590,11 @@
                                    deqOnTheFly);     // in bf16.cce?
     mm3TilingApi.GetTilingData(tilingData->mm3);
 
-    RmsNormQuantTiling();
+    RmsNormQuantTiling(hiddenStateDim);
     RopeConcatTiling();
     EinSumQuantTiling();
 
-    SetMlapoWorkSpace();
+    SetMlapoWorkSpace(hiddenStateDim);
     SetTilingKey();
 
     return;
@@ -656,6 +655,7 @@ std::tuple mla_preproces
     int32_t N = hiddenState.sizes()[0];
     int32_t headNum = wuk.sizes()[0];
+    uint32_t hiddenStateDim = hiddenState.sizes().back();
 
     OpParam opParam;
     opParam.N = N;
@@ -667,7 +667,7 @@
     MlaTilingData tilingData;
     MlaPreprocessTiling mlaTiling(platformInfo, opParam, &tilingData);
-    mlaTiling.Init();
+    mlaTiling.Init(hiddenStateDim);
 
     uint32_t blockDim = platformInfo.coreNumAic;
     // workspace
diff --git a/csrc/mla_preprocess/op_host/tiling/mla_preprocess_tiling.h b/csrc/mla_preprocess/op_host/tiling/mla_preprocess_tiling.h
index aab1f3a7..418aaa3a 100644
--- a/csrc/mla_preprocess/op_host/tiling/mla_preprocess_tiling.h
+++ b/csrc/mla_preprocess/op_host/tiling/mla_preprocess_tiling.h
@@ -90,6 +90,9 @@ struct MlaTilingData {
     uint32_t esqHeadTail{0};
     uint32_t esqColLoop{0};
     uint32_t esqColTail{0};
+
+    // hidden state dimension
+    uint32_t hiddenStateDim{7168};
 };
 
 #endif  // MLAPREPROCESS_TILING_H
diff --git a/csrc/mla_preprocess/op_kernel/mla_preprocess.h b/csrc/mla_preprocess/op_kernel/mla_preprocess.h
index 35254112..d644f55f 100644
--- a/csrc/mla_preprocess/op_kernel/mla_preprocess.h
+++ b/csrc/mla_preprocess/op_kernel/mla_preprocess.h
@@ -49,7 +49,6 @@ constexpr uint8_t CACHE_MODE_INT8_NZCACHE = 2;  // high performance KV NZ format
 constexpr uint8_t CACHE_MODE_NZCACHE = 3;
 
 // pp matmul
-constexpr uint32_t HIDDTEN_STATE = 7168;
 constexpr uint32_t FLOAT_BLOCK_SIZE = 64;
 constexpr uint32_t HALF_BLOCK_SIZE = 64;
 constexpr uint32_t HALF_VECTOR_SIZE = 64;
diff --git a/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp b/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp
index f58f4aa7..3de2266d 100644
--- a/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp
+++ b/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp
@@ -2386,6 +2386,7 @@ class MLAOperation
         this->num_row = mlaParams_.n;
         this->epsilon_ = 1e-6;
         this->mlaParams = mlaParams_;
+        this->hiddenStateDim = mlaParams_.hiddenStateDim;
     }
 
     __aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR gamma1Gm, GM_ADDR beta1Gm, GM_ADDR quantScale1Gm,
@@ -2694,6 +2695,7 @@ class MLAOperation
     uint32_t blockOffset;
     uint32_t perTaskNum;
     uint32_t resTaskNum;
+    uint32_t hiddenStateDim;
     MlaTilingData mlaParams;
 
     uint32_t num_core_;
@@ -2801,15 +2803,15 @@ MLAOperation
     input_tensor = buf.GetBuffer(0);
     AscendC::LocalTensor scale_tensor =
-        buf.GetBuffer(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2);
+        buf.GetBuffer(hiddenStateDim * 2 + hiddenStateDim * 2 + hiddenStateDim * 2);
     AscendC::LocalTensor offset_tensor = buf.GetBuffer(
-        HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 32);
+        hiddenStateDim * 2 + hiddenStateDim * 2 + hiddenStateDim * 2 + 32);
     AscendC::LocalTensor res1_tensor =
-        buf.GetBuffer(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64);
+        buf.GetBuffer(hiddenStateDim * 2 + hiddenStateDim * 2 + hiddenStateDim * 2 + 64);
     AscendC::LocalTensor res3_tensor = buf.GetBuffer(
-        HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4);
+        hiddenStateDim * 2 + hiddenStateDim * 2 + hiddenStateDim * 2 + 64 + num_col_align_f32 * 4);
     AscendC::LocalTensor output_tensor = buf.GetBuffer(
-        HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4 +
+        hiddenStateDim * 2 + hiddenStateDim * 2 + hiddenStateDim * 2 + 64 + num_col_align_f32 * 4 +
         BUF_FACTOR * num_col_align_f32 * 4 + 64);
     Quant1.Launch(output_tensor, input_tensor, scale_tensor, offset_tensor, res1_tensor, res3_tensor);
 }
diff --git a/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_fp16.hpp b/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_fp16.hpp
index 097fbc2a..81c12940 100644
--- a/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_fp16.hpp
+++ b/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_fp16.hpp
@@ -2035,6 +2035,7 @@ class MLAOperation
         this->num_row = mlaParams_.n;
         this->epsilon_ = 1e-6;
         this->mlaParams = mlaParams_;
+        this->hiddenStateDim = mlaParams_.hiddenStateDim;
     }
 
     __aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR gamma1Gm, GM_ADDR beta1Gm, GM_ADDR quantScale1Gm,
@@ -2297,6 +2298,7 @@ class MLAOperation
     uint32_t blockOffset;
     uint32_t perTaskNum;
     uint32_t resTaskNum;
+    uint32_t hiddenStateDim;
     MlaTilingData mlaParams;
 
     // rmsnormQuant
@@ -2395,19 +2397,19 @@ __aicore__ inline void MLAOperation
     input_tensor = buf.GetBuffer(0);
-    AscendC::LocalTensor gamma_tensor = buf.GetBuffer(HIDDTEN_STATE * 2);
+    AscendC::LocalTensor gamma_tensor = buf.GetBuffer(hiddenStateDim * 2);
     AscendC::LocalTensor beta_tensor =
-        buf.GetBuffer(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2);
+        buf.GetBuffer(hiddenStateDim * 2 + hiddenStateDim * 2);
     AscendC::LocalTensor scale_tensor =
-        buf.GetBuffer(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2);
+        buf.GetBuffer(hiddenStateDim * 2 + hiddenStateDim * 2 + hiddenStateDim * 2);
     AscendC::LocalTensor offset_tensor = buf.GetBuffer(
-        HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 32);
+        hiddenStateDim * 2 + hiddenStateDim * 2 + hiddenStateDim * 2 + 32);
     AscendC::LocalTensor res1_tensor =
-        buf.GetBuffer(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64);
+        buf.GetBuffer(hiddenStateDim * 2 + hiddenStateDim * 2 + hiddenStateDim * 2 + 64);
     AscendC::LocalTensor res3_tensor = buf.GetBuffer(
-        HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4);
+        hiddenStateDim * 2 + hiddenStateDim * 2 + hiddenStateDim * 2 + 64 + num_col_align_f32 * 4);
     AscendC::LocalTensor output_tensor = buf.GetBuffer(
-        HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4 +
+        hiddenStateDim * 2 + hiddenStateDim * 2 + hiddenStateDim * 2 + 64 + num_col_align_f32 * 4 +
         BUF_FACTOR * num_col_align_f32 * 4 + 32);
     Quant1.Launch(output_tensor, input_tensor, gamma_tensor, beta_tensor, scale_tensor, offset_tensor, res1_tensor,
                   res3_tensor);

From f8d416b2dad40c3ab86411a29888d5d53196f228 Mon Sep 17 00:00:00 2001
From: Todobe
Date: Thu, 11 Sep 2025 09:12:52 +0800
Subject: [PATCH 2/4] fix code assisit

---
 .../op_kernel/mla_preprocess_mix_bf16.hpp | 13 ++++++------
 .../op_kernel/mla_preprocess_mix_fp16.hpp | 20 +++++++++++--------
 2 files changed, 19 insertions(+), 14 deletions(-)

diff --git a/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp b/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp
index 3de2266d..1b22f82d 100644
--- a/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp
+++ b/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp
@@ -2801,18 +2801,19 @@ MLAOperation
     input_tensor = buf.GetBuffer(0);
     AscendC::LocalTensor scale_tensor =
-        buf.GetBuffer(hiddenStateDim * 2 + hiddenStateDim * 2 + hiddenStateDim * 2);
+        buf.GetBuffer(base_offset);
     AscendC::LocalTensor offset_tensor = buf.GetBuffer(
-        hiddenStateDim * 2 + hiddenStateDim * 2 + hiddenStateDim * 2 + 32);
+        base_offset + 32);
     AscendC::LocalTensor res1_tensor =
-        buf.GetBuffer(hiddenStateDim * 2 + hiddenStateDim * 2 + hiddenStateDim * 2 + 64);
+        buf.GetBuffer(base_offset + 64);
     AscendC::LocalTensor res3_tensor = buf.GetBuffer(
-        hiddenStateDim * 2 + hiddenStateDim * 2 + hiddenStateDim * 2 + 64 + num_col_align_f32 * 4);
+        base_offset + 64 + num_col_align_f32 * 4);
     AscendC::LocalTensor output_tensor = buf.GetBuffer(
-        hiddenStateDim * 2 + hiddenStateDim * 2 + hiddenStateDim * 2 + 64 + num_col_align_f32 * 4 +
-        BUF_FACTOR * num_col_align_f32 * 4 + 64);
+        base_offset + 64 + num_col_align_f32 * 4 + BUF_FACTOR * num_col_align_f32 * 4 + 64);
     Quant1.Launch(output_tensor, input_tensor, scale_tensor, offset_tensor, res1_tensor, res3_tensor);
 }
 FftsCrossCoreSync(QUANT1);
diff --git a/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_fp16.hpp b/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_fp16.hpp
index 81c12940..6651d7b9 100644
--- a/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_fp16.hpp
+++ b/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_fp16.hpp
@@ -2396,21 +2396,25 @@ __aicore__ inline void MLAOperation
     input_tensor = buf.GetBuffer(0);
-    AscendC::LocalTensor gamma_tensor = buf.GetBuffer(hiddenStateDim * 2);
+    AscendC::LocalTensor gamma_tensor = buf.GetBuffer(gamma_offset);
     AscendC::LocalTensor beta_tensor =
-        buf.GetBuffer(hiddenStateDim * 2 + hiddenStateDim * 2);
+        buf.GetBuffer(beta_offset);
     AscendC::LocalTensor scale_tensor =
-        buf.GetBuffer(hiddenStateDim * 2 + hiddenStateDim * 2 + hiddenStateDim * 2);
+        buf.GetBuffer(scale_offset);
     AscendC::LocalTensor offset_tensor = buf.GetBuffer(
-        hiddenStateDim * 2 + hiddenStateDim * 2 + hiddenStateDim * 2 + 32);
+        scale_offset + 32);
     AscendC::LocalTensor res1_tensor =
-        buf.GetBuffer(hiddenStateDim * 2 + hiddenStateDim * 2 + hiddenStateDim * 2 + 64);
+        buf.GetBuffer(scale_offset + 64);
     AscendC::LocalTensor res3_tensor = buf.GetBuffer(
-        hiddenStateDim * 2 + hiddenStateDim * 2 + hiddenStateDim * 2 + 64 + num_col_align_f32 * 4);
+        scale_offset + 64 + num_col_align_f32 * 4);
     AscendC::LocalTensor output_tensor = buf.GetBuffer(
-        hiddenStateDim * 2 + hiddenStateDim * 2 + hiddenStateDim * 2 + 64 + num_col_align_f32 * 4 +
-        BUF_FACTOR * num_col_align_f32 * 4 + 32);
+        scale_offset + 64 + num_col_align_f32 * 4 + BUF_FACTOR * num_col_align_f32 * 4 + 32);
     Quant1.Launch(output_tensor, input_tensor, gamma_tensor, beta_tensor, scale_tensor, offset_tensor, res1_tensor,
                   res3_tensor);
 }

From c2bf76e5bca1f55e3f0176c882cf162c21658fb5 Mon Sep 17 00:00:00 2001
From: Todobe
Date: Thu, 18 Sep 2025 17:36:15 +0800
Subject: [PATCH 3/4] clean code

---
 csrc/mla_preprocess/op_host/mla_preprocess.cpp |  2 +-
 .../op_kernel/mla_preprocess_mix_bf16.hpp      | 15 ++++++---------
 .../op_kernel/mla_preprocess_mix_fp16.hpp      | 16 ++++++----------
 3 files changed, 13 insertions(+), 20 deletions(-)

diff --git a/csrc/mla_preprocess/op_host/mla_preprocess.cpp b/csrc/mla_preprocess/op_host/mla_preprocess.cpp
index e37533a9..47a64236 100644
--- a/csrc/mla_preprocess/op_host/mla_preprocess.cpp
+++ b/csrc/mla_preprocess/op_host/mla_preprocess.cpp
@@ -560,7 +560,7 @@ void MlaPreprocessTiling::Init(uint32_t hiddenStateDim)
     PpMatmulTilingApi mm1TilingApi(platformInfo,
                                    1,                // numBatch
                                    opParam.N,        // m
-                                   hiddenStateDim, // k
+                                   hiddenStateDim,   // k
                                    HIDDEN_STRATE_MM, // n
                                    false,            // transA
                                    true,             // transB
diff --git a/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp b/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp
index 1b22f82d..4c821553 100644
--- a/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp
+++ b/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp
@@ -2802,16 +2802,13 @@ MLAOperation
     input_tensor = buf.GetBuffer(0);
-    AscendC::LocalTensor scale_tensor =
-        buf.GetBuffer(base_offset);
-    AscendC::LocalTensor offset_tensor = buf.GetBuffer(
-        base_offset + 32);
-    AscendC::LocalTensor res1_tensor =
-        buf.GetBuffer(base_offset + 64);
-    AscendC::LocalTensor res3_tensor = buf.GetBuffer(
-        base_offset + 64 + num_col_align_f32 * 4);
+    AscendC::LocalTensor scale_tensor = buf.GetBuffer(base_offset);
+    AscendC::LocalTensor offset_tensor = buf.GetBuffer(base_offset + 32);
+    AscendC::LocalTensor res1_tensor = buf.GetBuffer(base_offset + 64);
+    AscendC::LocalTensor res3_tensor =
+        buf.GetBuffer(base_offset + 64 + num_col_align_f32 * 4);
     AscendC::LocalTensor output_tensor = buf.GetBuffer(
         base_offset + 64 + num_col_align_f32 * 4 + BUF_FACTOR * num_col_align_f32 * 4 + 64);
     Quant1.Launch(output_tensor, input_tensor, scale_tensor, offset_tensor, res1_tensor, res3_tensor);
diff --git a/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_fp16.hpp b/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_fp16.hpp
index 6651d7b9..f0e85e9a 100644
--- a/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_fp16.hpp
+++ b/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_fp16.hpp
@@ -2403,16 +2403,12 @@ __aicore__ inline void MLAOperation
     input_tensor = buf.GetBuffer(0);
     AscendC::LocalTensor gamma_tensor = buf.GetBuffer(gamma_offset);
-    AscendC::LocalTensor beta_tensor =
-        buf.GetBuffer(beta_offset);
-    AscendC::LocalTensor scale_tensor =
-        buf.GetBuffer(scale_offset);
-    AscendC::LocalTensor offset_tensor = buf.GetBuffer(
-        scale_offset + 32);
-    AscendC::LocalTensor res1_tensor =
-        buf.GetBuffer(scale_offset + 64);
-    AscendC::LocalTensor res3_tensor = buf.GetBuffer(
-        scale_offset + 64 + num_col_align_f32 * 4);
+    AscendC::LocalTensor beta_tensor = buf.GetBuffer(beta_offset);
+    AscendC::LocalTensor scale_tensor = buf.GetBuffer(scale_offset);
+    AscendC::LocalTensor offset_tensor = buf.GetBuffer(scale_offset + 32);
+    AscendC::LocalTensor res1_tensor = buf.GetBuffer(scale_offset + 64);
+    AscendC::LocalTensor res3_tensor =
+        buf.GetBuffer(scale_offset + 64 + num_col_align_f32 * 4);
     AscendC::LocalTensor output_tensor = buf.GetBuffer(
         scale_offset + 64 + num_col_align_f32 * 4 + BUF_FACTOR * num_col_align_f32 * 4 + 32);
     Quant1.Launch(output_tensor, input_tensor, gamma_tensor, beta_tensor, scale_tensor, offset_tensor, res1_tensor,

From 8e6cc8ba676cc76b1c19f5fe29176194fad58a95 Mon Sep 17 00:00:00 2001
From: Todobe
Date: Thu, 18 Sep 2025 17:37:44 +0800
Subject: [PATCH 4/4] clean code

---
 csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp | 2 +-
 csrc/mla_preprocess/op_kernel/mla_preprocess_mix_fp16.hpp | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp b/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp
index 4c821553..2c894087 100644
--- a/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp
+++ b/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp
@@ -2807,7 +2807,7 @@ MLAOperation
     scale_tensor = buf.GetBuffer(base_offset);
     AscendC::LocalTensor offset_tensor = buf.GetBuffer(base_offset + 32);
     AscendC::LocalTensor res1_tensor = buf.GetBuffer(base_offset + 64);
-    AscendC::LocalTensor res3_tensor = 
+    AscendC::LocalTensor res3_tensor =
         buf.GetBuffer(base_offset + 64 + num_col_align_f32 * 4);
     AscendC::LocalTensor output_tensor = buf.GetBuffer(
         base_offset + 64 + num_col_align_f32 * 4 + BUF_FACTOR * num_col_align_f32 * 4 + 64);
diff --git a/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_fp16.hpp b/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_fp16.hpp
index f0e85e9a..ed8d19c8 100644
--- a/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_fp16.hpp
+++ b/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_fp16.hpp
@@ -2407,7 +2407,7 @@ __aicore__ inline void MLAOperation
     scale_tensor = buf.GetBuffer(scale_offset);
     AscendC::LocalTensor offset_tensor = buf.GetBuffer(scale_offset + 32);
     AscendC::LocalTensor res1_tensor = buf.GetBuffer(scale_offset + 64);
-    AscendC::LocalTensor res3_tensor = 
+    AscendC::LocalTensor res3_tensor =
        buf.GetBuffer(scale_offset + 64 + num_col_align_f32 * 4);
     AscendC::LocalTensor output_tensor = buf.GetBuffer(
         scale_offset + 64 + num_col_align_f32 * 4 + BUF_FACTOR * num_col_align_f32 * 4 + 32);