30 changes: 15 additions & 15 deletions csrc/mla_preprocess/op_host/mla_preprocess.cpp
@@ -43,7 +43,6 @@ constexpr uint32_t L1_BIAS_SIZE = 2048;
constexpr uint32_t L0C_SIZE = 128 * 1024;
constexpr uint32_t CONCAT_SIZE = 512;

constexpr uint32_t HIDDEN_STRATE = 7168;
constexpr uint32_t HIDDEN_STRATE_ROPE = 192;
constexpr uint32_t HIDDEN_STRATE_MM = 2112;
constexpr uint32_t HIDDEN_STRATE_RMS = 1536;
@@ -373,25 +372,25 @@ class MlaPreprocessTiling
this->platformInfo = platformInfo;
this->opParam = opParam;
}
void Init();
void Init(uint32_t hiddenStateDim);

void RmsNormQuantTiling();
void RmsNormQuantTiling(uint32_t hiddenStateDim);
void RopeConcatTiling();
void EinSumQuantTiling();

void SetTilingKey();
void SetMlapoWorkSpace();
void SetMlapoWorkSpace(uint32_t hiddenStateDim);

private:
MlaTilingData *tilingData;
struct PlatformInfo platformInfo;
struct OpParam opParam;
};

void MlaPreprocessTiling::RmsNormQuantTiling()
void MlaPreprocessTiling::RmsNormQuantTiling(uint32_t hiddenStateDim)
{
tilingData->rmsNumCore1 = platformInfo.coreNumAiv;
tilingData->rmsNumCol1 = HIDDEN_STRATE;
tilingData->rmsNumCol1 = hiddenStateDim;
tilingData->rmsNumRow1 = opParam.N;
tilingData->rmsQuantMin1 = -CONST_128;
tilingData->rmsNumCore2 = platformInfo.coreNumAiv;
@@ -504,12 +503,12 @@ void MlaPreprocessTiling::EinSumQuantTiling()
tilingData->esqColTail = esqColTail;
}

void MlaPreprocessTiling::SetMlapoWorkSpace()
void MlaPreprocessTiling::SetMlapoWorkSpace(uint32_t hiddenStateDim)
{
uint64_t s1wsFactor =
static_cast<uint64_t>(opParam.cacheMode == 2 ? std::max(HIDDEN_STRATE * sizeof(int8_t),
static_cast<uint64_t>(opParam.cacheMode == 2 ? std::max(hiddenStateDim * sizeof(int8_t),
opParam.headNum * AXES_ALIGN_SIZE * sizeof(uint16_t))
: HIDDEN_STRATE * sizeof(int8_t));
: hiddenStateDim * sizeof(int8_t));
uint64_t workSizeS1 = s1wsFactor;
uint64_t workSizeS2 = opParam.headNum * HIDDEN_STRATE_ROPE * sizeof(uint16_t);
uint64_t workSizeS3 = HIDDEN_STRATE_MM * sizeof(uint16_t);
@@ -548,11 +547,11 @@ void MlaPreprocessTiling::SetTilingKey()
tilingData->tilingKey = tilingKey;
}

void MlaPreprocessTiling::Init()
void MlaPreprocessTiling::Init(uint32_t hiddenStateDim)
{
tilingData->numCore = platformInfo.coreNumAic;
tilingData->n = opParam.N;

tilingData->hiddenStateDim = hiddenStateDim;
bool deqOnTheFly = false;
if (opParam.inDtype == at::kBFloat16 || opParam.quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) {
deqOnTheFly = true;
@@ -561,7 +560,7 @@ void MlaPreprocessTiling::Init()
PpMatmulTilingApi mm1TilingApi(platformInfo,
1, // numBatch
opParam.N, // m
HIDDEN_STRATE, // k
hiddenStateDim, // k
HIDDEN_STRATE_MM, // n
false, // transA
true, // transB
@@ -591,11 +590,11 @@
deqOnTheFly); // in bf16.cce?
mm3TilingApi.GetTilingData(tilingData->mm3);

RmsNormQuantTiling();
RmsNormQuantTiling(hiddenStateDim);
RopeConcatTiling();
EinSumQuantTiling();

SetMlapoWorkSpace();
SetMlapoWorkSpace(hiddenStateDim);
SetTilingKey();

return;
@@ -656,6 +655,7 @@ std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preproces

int32_t N = hiddenState.sizes()[0];
int32_t headNum = wuk.sizes()[0];
uint32_t hiddenStateDim = hiddenState.sizes().back();

OpParam opParam;
opParam.N = N;
@@ -667,7 +667,7 @@ std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preproces
MlaTilingData tilingData;
MlaPreprocessTiling mlaTiling(platformInfo, opParam, &tilingData);

mlaTiling.Init();
mlaTiling.Init(hiddenStateDim);
uint32_t blockDim = platformInfo.coreNumAic;

// workspace
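Note: a minimal standalone sketch of the workspace-size arithmetic that SetMlapoWorkSpace now performs with the runtime hidden-state dimension. The free-function form and parameter names are illustrative only, not part of this PR; AXES_ALIGN_SIZE is taken as a parameter here because its value is defined elsewhere in the file.

#include <algorithm>
#include <cstdint>

// Illustrative sketch (not the PR's code): the first-stage workspace factor.
// For cacheMode == 2 the buffer must hold either the int8 hidden state or the
// per-head fp16 aligned output, whichever is larger; otherwise only the int8
// hidden state is needed.
uint64_t S1WorkspaceFactor(uint32_t hiddenStateDim, uint32_t headNum,
                           int cacheMode, uint32_t axesAlignSize)
{
    const uint64_t hiddenBytes = static_cast<uint64_t>(hiddenStateDim) * sizeof(int8_t);
    if (cacheMode == 2) {
        const uint64_t headBytes =
            static_cast<uint64_t>(headNum) * axesAlignSize * sizeof(uint16_t);
        return std::max(hiddenBytes, headBytes);
    }
    return hiddenBytes;
}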
3 changes: 3 additions & 0 deletions csrc/mla_preprocess/op_host/tiling/mla_preprocess_tiling.h
@@ -90,6 +90,9 @@ struct MlaTilingData {
uint32_t esqHeadTail{0};
uint32_t esqColLoop{0};
uint32_t esqColTail{0};

// hidden state dimension
uint32_t hiddenStateDim{7168};
};

#endif // MLAPREPROCESS_TILING_H
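Note: a hedged sketch of how the new tiling field is expected to flow from the host entry point into the tiling data, mirroring the mla_preprocess.cpp changes above; the reduced struct and helper below are hypothetical scaffolding, not code from this PR.

#include <cstdint>

// Hypothetical scaffolding; the real MlaTilingData has many more members and
// Init() also configures the matmul, rmsnorm, rope and einsum tilings.
struct MlaTilingDataSketch {
    uint32_t hiddenStateDim{7168};  // default matches the previous constant
};

void InitSketch(MlaTilingDataSketch &tiling, uint32_t hiddenStateDim)
{
    // The host reads hiddenState.sizes().back() and passes it in, so the
    // kernels no longer depend on a compile-time hidden-state constant.
    tiling.hiddenStateDim = hiddenStateDim;
}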
1 change: 0 additions & 1 deletion csrc/mla_preprocess/op_kernel/mla_preprocess.h
@@ -49,7 +49,6 @@ constexpr uint8_t CACHE_MODE_INT8_NZCACHE = 2; // high performance KV NZ format
constexpr uint8_t CACHE_MODE_NZCACHE = 3;

// pp matmul
constexpr uint32_t HIDDTEN_STATE = 7168;
constexpr uint32_t FLOAT_BLOCK_SIZE = 64;
constexpr uint32_t HALF_BLOCK_SIZE = 64;
constexpr uint32_t HALF_VECTOR_SIZE = 64;
20 changes: 10 additions & 10 deletions csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp
@@ -2386,6 +2386,7 @@ class MLAOperation
this->num_row = mlaParams_.n;
this->epsilon_ = 1e-6;
this->mlaParams = mlaParams_;
this->hiddenStateDim = mlaParams_.hiddenStateDim;
}

__aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR gamma1Gm, GM_ADDR beta1Gm, GM_ADDR quantScale1Gm,
@@ -2694,6 +2695,7 @@ class MLAOperation
uint32_t blockOffset;
uint32_t perTaskNum;
uint32_t resTaskNum;
uint32_t hiddenStateDim;
MlaTilingData mlaParams;

uint32_t num_core_;
@@ -2799,18 +2801,16 @@ MLAOperation<InDtype, CACHE_MODE, weightFormat1, weightFormat2, weightFormat3, q
uint32_t num_col_align_int8 = (num_col_1 + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256;
uint32_t num_col_align_f16 = (num_col_1 + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128;
uint32_t num_col_align_f32 = (num_col_1 + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64;
const uint32_t base_offset = hiddenStateDim * 6;

AscendC::LocalTensor<InDtype> input_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(0);
AscendC::LocalTensor<InDtype> scale_tensor =
buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2);
AscendC::LocalTensor<int8_t> offset_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(
HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 32);
AscendC::LocalTensor<float> res1_tensor =
buf.GetBuffer<BufferType::ASCEND_UB, float>(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64);
AscendC::LocalTensor<float> res3_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(
HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4);
AscendC::LocalTensor<InDtype> scale_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(base_offset);
AscendC::LocalTensor<int8_t> offset_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(base_offset + 32);
AscendC::LocalTensor<float> res1_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(base_offset + 64);
AscendC::LocalTensor<float> res3_tensor =
buf.GetBuffer<BufferType::ASCEND_UB, float>(base_offset + 64 + num_col_align_f32 * 4);
AscendC::LocalTensor<int8_t> output_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(
HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4 +
BUF_FACTOR * num_col_align_f32 * 4 + 64);
base_offset + 64 + num_col_align_f32 * 4 + BUF_FACTOR * num_col_align_f32 * 4 + 64);
Quant1.Launch(output_tensor, input_tensor, scale_tensor, offset_tensor, res1_tensor, res3_tensor);
}
FftsCrossCoreSync<PIPE_MTE3, 0>(QUANT1);
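Note: a minimal sketch of the UB byte-offset layout the bf16 RmsNormQuant stage derives above. The struct and function are illustrative only; they assume three 2-byte-per-element regions of length hiddenStateDim precede base_offset, matching base_offset = hiddenStateDim * 6 in the diff.

#include <cstdint>

// Illustrative sketch (not kernel code): UB byte offsets for the bf16 path.
struct QuantUbOffsetsBf16 {
    uint32_t scale;   // InDtype scale tensor
    uint32_t offset;  // int8 offset tensor
    uint32_t res1;    // fp32 residual buffer 1
    uint32_t res3;    // fp32 residual buffer 3
    uint32_t output;  // int8 quantized output
};

QuantUbOffsetsBf16 ComputeQuantUbOffsetsBf16(uint32_t hiddenStateDim,
                                             uint32_t numColAlignF32,
                                             uint32_t bufFactor)
{
    const uint32_t base = hiddenStateDim * 6;  // three 2-byte regions
    QuantUbOffsetsBf16 o{};
    o.scale  = base;
    o.offset = base + 32;
    o.res1   = base + 64;
    o.res3   = base + 64 + numColAlignF32 * 4;
    o.output = base + 64 + numColAlignF32 * 4 + bufFactor * numColAlignF32 * 4 + 64;
    return o;
}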
28 changes: 15 additions & 13 deletions csrc/mla_preprocess/op_kernel/mla_preprocess_mix_fp16.hpp
@@ -2035,6 +2035,7 @@ class MLAOperation
this->num_row = mlaParams_.n;
this->epsilon_ = 1e-6;
this->mlaParams = mlaParams_;
this->hiddenStateDim = mlaParams_.hiddenStateDim;
}

__aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR gamma1Gm, GM_ADDR beta1Gm, GM_ADDR quantScale1Gm,
@@ -2297,6 +2298,7 @@ class MLAOperation
uint32_t blockOffset;
uint32_t perTaskNum;
uint32_t resTaskNum;
uint32_t hiddenStateDim;
MlaTilingData mlaParams;

// rmsnormQuant
@@ -2394,21 +2396,21 @@ __aicore__ inline void MLAOperation<cacheMode, weightFormat1, weightFormat2, wei
uint32_t num_col_align_int8 = (num_col_1 + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256;
uint32_t num_col_align_f16 = (num_col_1 + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128;
uint32_t num_col_align_f32 = (num_col_1 + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64;

const uint32_t gamma_offset = hiddenStateDim * 2;
const uint32_t beta_offset = gamma_offset + hiddenStateDim * 2;
const uint32_t scale_offset = beta_offset + hiddenStateDim * 2;

AscendC::LocalTensor<half> input_tensor = buf.GetBuffer<BufferType::ASCEND_UB, half>(0);
AscendC::LocalTensor<half> gamma_tensor = buf.GetBuffer<BufferType::ASCEND_UB, half>(HIDDTEN_STATE * 2);
AscendC::LocalTensor<half> beta_tensor =
buf.GetBuffer<BufferType::ASCEND_UB, half>(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2);
AscendC::LocalTensor<half> scale_tensor =
buf.GetBuffer<BufferType::ASCEND_UB, half>(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2);
AscendC::LocalTensor<int8_t> offset_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(
HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 32);
AscendC::LocalTensor<float> res1_tensor =
buf.GetBuffer<BufferType::ASCEND_UB, float>(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64);
AscendC::LocalTensor<float> res3_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(
HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4);
AscendC::LocalTensor<half> gamma_tensor = buf.GetBuffer<BufferType::ASCEND_UB, half>(gamma_offset);
AscendC::LocalTensor<half> beta_tensor = buf.GetBuffer<BufferType::ASCEND_UB, half>(beta_offset);
AscendC::LocalTensor<half> scale_tensor = buf.GetBuffer<BufferType::ASCEND_UB, half>(scale_offset);
AscendC::LocalTensor<int8_t> offset_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(scale_offset + 32);
AscendC::LocalTensor<float> res1_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(scale_offset + 64);
AscendC::LocalTensor<float> res3_tensor =
buf.GetBuffer<BufferType::ASCEND_UB, float>(scale_offset + 64 + num_col_align_f32 * 4);
AscendC::LocalTensor<int8_t> output_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(
HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4 +
BUF_FACTOR * num_col_align_f32 * 4 + 32);
scale_offset + 64 + num_col_align_f32 * 4 + BUF_FACTOR * num_col_align_f32 * 4 + 32);
Quant1.Launch(output_tensor, input_tensor, gamma_tensor, beta_tensor, scale_tensor, offset_tensor, res1_tensor,
res3_tensor);
}
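Note: the same layout sketch for the fp16 path above, where the gamma, beta and scale regions are addressed individually; again illustrative only, mirroring the offsets computed in the diff.

#include <cstdint>

// Illustrative sketch (not kernel code): UB byte offsets for the fp16 path.
// The input tensor sits at 0, followed by gamma, beta and scale regions of
// hiddenStateDim half-precision (2-byte) elements each.
struct QuantUbOffsetsFp16 {
    uint32_t gamma, beta, scale, offset, res1, res3, output;
};

QuantUbOffsetsFp16 ComputeQuantUbOffsetsFp16(uint32_t hiddenStateDim,
                                             uint32_t numColAlignF32,
                                             uint32_t bufFactor)
{
    QuantUbOffsetsFp16 o{};
    o.gamma  = hiddenStateDim * 2;
    o.beta   = o.gamma + hiddenStateDim * 2;
    o.scale  = o.beta + hiddenStateDim * 2;
    o.offset = o.scale + 32;
    o.res1   = o.scale + 64;
    o.res3   = o.scale + 64 + numColAlignF32 * 4;
    o.output = o.scale + 64 + numColAlignF32 * 4 + bufFactor * numColAlignF32 * 4 + 32;
    return o;
}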