diff --git a/csrc/mla_preprocess/op_host/mla_preprocess.cpp b/csrc/mla_preprocess/op_host/mla_preprocess.cpp index 06e5c532..47a64236 100644 --- a/csrc/mla_preprocess/op_host/mla_preprocess.cpp +++ b/csrc/mla_preprocess/op_host/mla_preprocess.cpp @@ -43,7 +43,6 @@ constexpr uint32_t L1_BIAS_SIZE = 2048; constexpr uint32_t L0C_SIZE = 128 * 1024; constexpr uint32_t CONCAT_SIZE = 512; -constexpr uint32_t HIDDEN_STRATE = 7168; constexpr uint32_t HIDDEN_STRATE_ROPE = 192; constexpr uint32_t HIDDEN_STRATE_MM = 2112; constexpr uint32_t HIDDEN_STRATE_RMS = 1536; @@ -373,14 +372,14 @@ class MlaPreprocessTiling this->platformInfo = platformInfo; this->opParam = opParam; } - void Init(); + void Init(uint32_t hiddenStateDim); - void RmsNormQuantTiling(); + void RmsNormQuantTiling(uint32_t hiddenStateDim); void RopeConcatTiling(); void EinSumQuantTiling(); void SetTilingKey(); - void SetMlapoWorkSpace(); + void SetMlapoWorkSpace(uint32_t hiddenStateDim); private: MlaTilingData *tilingData; @@ -388,10 +387,10 @@ class MlaPreprocessTiling struct OpParam opParam; }; -void MlaPreprocessTiling::RmsNormQuantTiling() +void MlaPreprocessTiling::RmsNormQuantTiling(uint32_t hiddenStateDim) { tilingData->rmsNumCore1 = platformInfo.coreNumAiv; - tilingData->rmsNumCol1 = HIDDEN_STRATE; + tilingData->rmsNumCol1 = hiddenStateDim; tilingData->rmsNumRow1 = opParam.N; tilingData->rmsQuantMin1 = -CONST_128; tilingData->rmsNumCore2 = platformInfo.coreNumAiv; @@ -504,12 +503,12 @@ void MlaPreprocessTiling::EinSumQuantTiling() tilingData->esqColTail = esqColTail; } -void MlaPreprocessTiling::SetMlapoWorkSpace() +void MlaPreprocessTiling::SetMlapoWorkSpace(uint32_t hiddenStateDim) { uint64_t s1wsFactor = - static_cast(opParam.cacheMode == 2 ? std::max(HIDDEN_STRATE * sizeof(int8_t), + static_cast(opParam.cacheMode == 2 ? std::max(hiddenStateDim * sizeof(int8_t), opParam.headNum * AXES_ALIGN_SIZE * sizeof(uint16_t)) - : HIDDEN_STRATE * sizeof(int8_t)); + : hiddenStateDim * sizeof(int8_t)); uint64_t workSizeS1 = s1wsFactor; uint64_t workSizeS2 = opParam.headNum * HIDDEN_STRATE_ROPE * sizeof(uint16_t); uint64_t workSizeS3 = HIDDEN_STRATE_MM * sizeof(uint16_t); @@ -548,11 +547,11 @@ void MlaPreprocessTiling::SetTilingKey() tilingData->tilingKey = tilingKey; } -void MlaPreprocessTiling::Init() +void MlaPreprocessTiling::Init(uint32_t hiddenStateDim) { tilingData->numCore = platformInfo.coreNumAic; tilingData->n = opParam.N; - + tilingData->hiddenStateDim = hiddenStateDim; bool deqOnTheFly = false; if (opParam.inDtype == at::kBFloat16 || opParam.quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) { deqOnTheFly = true; @@ -561,7 +560,7 @@ void MlaPreprocessTiling::Init() PpMatmulTilingApi mm1TilingApi(platformInfo, 1, // numBatch opParam.N, // m - HIDDEN_STRATE, // k + hiddenStateDim, // k HIDDEN_STRATE_MM, // n false, // transA true, // transB @@ -591,11 +590,11 @@ void MlaPreprocessTiling::Init() deqOnTheFly); // in bf16.cce? mm3TilingApi.GetTilingData(tilingData->mm3); - RmsNormQuantTiling(); + RmsNormQuantTiling(hiddenStateDim); RopeConcatTiling(); EinSumQuantTiling(); - SetMlapoWorkSpace(); + SetMlapoWorkSpace(hiddenStateDim); SetTilingKey(); return; @@ -656,6 +655,7 @@ std::tuple mla_preproces int32_t N = hiddenState.sizes()[0]; int32_t headNum = wuk.sizes()[0]; + uint32_t hiddenStateDim = hiddenState.sizes().back(); OpParam opParam; opParam.N = N; @@ -667,7 +667,7 @@ std::tuple mla_preproces MlaTilingData tilingData; MlaPreprocessTiling mlaTiling(platformInfo, opParam, &tilingData); - mlaTiling.Init(); + mlaTiling.Init(hiddenStateDim); uint32_t blockDim = platformInfo.coreNumAic; // workspace diff --git a/csrc/mla_preprocess/op_host/tiling/mla_preprocess_tiling.h b/csrc/mla_preprocess/op_host/tiling/mla_preprocess_tiling.h index aab1f3a7..418aaa3a 100644 --- a/csrc/mla_preprocess/op_host/tiling/mla_preprocess_tiling.h +++ b/csrc/mla_preprocess/op_host/tiling/mla_preprocess_tiling.h @@ -90,6 +90,9 @@ struct MlaTilingData { uint32_t esqHeadTail{0}; uint32_t esqColLoop{0}; uint32_t esqColTail{0}; + + // hidden state dimension + uint32_t hiddenStateDim{7168}; }; #endif // MLAPREPROCESS_TILING_H diff --git a/csrc/mla_preprocess/op_kernel/mla_preprocess.h b/csrc/mla_preprocess/op_kernel/mla_preprocess.h index 35254112..d644f55f 100644 --- a/csrc/mla_preprocess/op_kernel/mla_preprocess.h +++ b/csrc/mla_preprocess/op_kernel/mla_preprocess.h @@ -49,7 +49,6 @@ constexpr uint8_t CACHE_MODE_INT8_NZCACHE = 2; // high performance KV NZ format constexpr uint8_t CACHE_MODE_NZCACHE = 3; // pp matmul -constexpr uint32_t HIDDTEN_STATE = 7168; constexpr uint32_t FLOAT_BLOCK_SIZE = 64; constexpr uint32_t HALF_BLOCK_SIZE = 64; constexpr uint32_t HALF_VECTOR_SIZE = 64; diff --git a/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp b/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp index f58f4aa7..2c894087 100644 --- a/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp +++ b/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp @@ -2386,6 +2386,7 @@ class MLAOperation this->num_row = mlaParams_.n; this->epsilon_ = 1e-6; this->mlaParams = mlaParams_; + this->hiddenStateDim = mlaParams_.hiddenStateDim; } __aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR gamma1Gm, GM_ADDR beta1Gm, GM_ADDR quantScale1Gm, @@ -2694,6 +2695,7 @@ class MLAOperation uint32_t blockOffset; uint32_t perTaskNum; uint32_t resTaskNum; + uint32_t hiddenStateDim; MlaTilingData mlaParams; uint32_t num_core_; @@ -2799,18 +2801,16 @@ MLAOperation input_tensor = buf.GetBuffer(0); - AscendC::LocalTensor scale_tensor = - buf.GetBuffer(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2); - AscendC::LocalTensor offset_tensor = buf.GetBuffer( - HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 32); - AscendC::LocalTensor res1_tensor = - buf.GetBuffer(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64); - AscendC::LocalTensor res3_tensor = buf.GetBuffer( - HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4); + AscendC::LocalTensor scale_tensor = buf.GetBuffer(base_offset); + AscendC::LocalTensor offset_tensor = buf.GetBuffer(base_offset + 32); + AscendC::LocalTensor res1_tensor = buf.GetBuffer(base_offset + 64); + AscendC::LocalTensor res3_tensor = + buf.GetBuffer(base_offset + 64 + num_col_align_f32 * 4); AscendC::LocalTensor output_tensor = buf.GetBuffer( - HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4 + - BUF_FACTOR * num_col_align_f32 * 4 + 64); + base_offset + 64 + num_col_align_f32 * 4 + BUF_FACTOR * num_col_align_f32 * 4 + 64); Quant1.Launch(output_tensor, input_tensor, scale_tensor, offset_tensor, res1_tensor, res3_tensor); } FftsCrossCoreSync(QUANT1); diff --git a/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_fp16.hpp b/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_fp16.hpp index 097fbc2a..ed8d19c8 100644 --- a/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_fp16.hpp +++ b/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_fp16.hpp @@ -2035,6 +2035,7 @@ class MLAOperation this->num_row = mlaParams_.n; this->epsilon_ = 1e-6; this->mlaParams = mlaParams_; + this->hiddenStateDim = mlaParams_.hiddenStateDim; } __aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR gamma1Gm, GM_ADDR beta1Gm, GM_ADDR quantScale1Gm, @@ -2297,6 +2298,7 @@ class MLAOperation uint32_t blockOffset; uint32_t perTaskNum; uint32_t resTaskNum; + uint32_t hiddenStateDim; MlaTilingData mlaParams; // rmsnormQuant @@ -2394,21 +2396,21 @@ __aicore__ inline void MLAOperation input_tensor = buf.GetBuffer(0); - AscendC::LocalTensor gamma_tensor = buf.GetBuffer(HIDDTEN_STATE * 2); - AscendC::LocalTensor beta_tensor = - buf.GetBuffer(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2); - AscendC::LocalTensor scale_tensor = - buf.GetBuffer(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2); - AscendC::LocalTensor offset_tensor = buf.GetBuffer( - HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 32); - AscendC::LocalTensor res1_tensor = - buf.GetBuffer(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64); - AscendC::LocalTensor res3_tensor = buf.GetBuffer( - HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4); + AscendC::LocalTensor gamma_tensor = buf.GetBuffer(gamma_offset); + AscendC::LocalTensor beta_tensor = buf.GetBuffer(beta_offset); + AscendC::LocalTensor scale_tensor = buf.GetBuffer(scale_offset); + AscendC::LocalTensor offset_tensor = buf.GetBuffer(scale_offset + 32); + AscendC::LocalTensor res1_tensor = buf.GetBuffer(scale_offset + 64); + AscendC::LocalTensor res3_tensor = + buf.GetBuffer(scale_offset + 64 + num_col_align_f32 * 4); AscendC::LocalTensor output_tensor = buf.GetBuffer( - HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4 + - BUF_FACTOR * num_col_align_f32 * 4 + 32); + scale_offset + 64 + num_col_align_f32 * 4 + BUF_FACTOR * num_col_align_f32 * 4 + 32); Quant1.Launch(output_tensor, input_tensor, gamma_tensor, beta_tensor, scale_tensor, offset_tensor, res1_tensor, res3_tensor); }