diff --git a/mlir/include/mlir/Dialect/Rock/IR/AccelEmitter.h b/mlir/include/mlir/Dialect/Rock/IR/AccelEmitter.h index f98917aad4fd..8753488e43eb 100644 --- a/mlir/include/mlir/Dialect/Rock/IR/AccelEmitter.h +++ b/mlir/include/mlir/Dialect/Rock/IR/AccelEmitter.h @@ -198,6 +198,12 @@ struct MfmaEmitter : public AccelEmitter { int64_t getRowGroupSize() const; + // Return the MFMA K dimension + int64_t getMfmaK() const; + + // Return the MFMA instruction's non-K dimension + int64_t getMfmaNonKDim() const; + static bool classof(const AccelEmitter *AE) { return AE->getKind() == AccelEmitterKind::AEK_MFMAEmitter; } diff --git a/mlir/include/mlir/Dialect/Rock/IR/RockOps.td b/mlir/include/mlir/Dialect/Rock/IR/RockOps.td index 08924b5f4757..48326417777e 100644 --- a/mlir/include/mlir/Dialect/Rock/IR/RockOps.td +++ b/mlir/include/mlir/Dialect/Rock/IR/RockOps.td @@ -1131,6 +1131,29 @@ defvar SameShapeVectorOfI1 = [{ ::mlir::VectorType::get(::llvm::cast<::mlir::ShapedType>($_self).getShape(), ::mlir::IntegerType::get($_ctxt, 1)) }]; +// lds_transpose_load +def Rock_LDSTransposeLoadOp : + Rock_Op<"lds_transpose_load", + [DeclareOpInterfaceMethods]>, + Arguments<(ins + Arg, "LDS source buffer">:$source, + Variadic:$indices)>, + Results<(outs VectorOfLengthAndType<[4], [F16, BF16]>:$fragment)> { + let summary = "Hardware-assisted LDS transpose panel load (v4 fragment)"; + let description = [{ + LDS transpose panel load (vector<4 x F16/BF16>) for MFMA-friendly ordering. + Loads a single panel (4 elements) from LDS memory using AMD ds.read.tr16.* instructions. + The returned fragment ($fragment) is always a vector<4 x elementType>. + Helps to reduced bank conflict. Only supported for gfx950. + }]; + + let assemblyFormat = [{ + $source (`[` $indices^ `]`)? 
attr-dict `:` type($source) `->` type($fragment) + }]; + + let hasVerifier = 1; +} + // threadwise_read_into def Rock_ThreadwiseReadIntoOp : Rock_Op< diff --git a/mlir/lib/Dialect/Rock/IR/RockDialect.cpp b/mlir/lib/Dialect/Rock/IR/RockDialect.cpp index 7ad59a311f1e..dfee10779c84 100644 --- a/mlir/lib/Dialect/Rock/IR/RockDialect.cpp +++ b/mlir/lib/Dialect/Rock/IR/RockDialect.cpp @@ -1956,6 +1956,47 @@ LogicalResult InBoundsStoreOp::verify() { return success(); } +//===-----------------------------------------------------===// +// LDSTransposeLoadOp +//===-----------------------------------------------------===// +void LDSTransposeLoadOp::getEffects( + SmallVectorImpl &effects) { + // This op only reads from LDS (workgroup) memory + auto *read = MemoryEffects::Read::get(); + effects.emplace_back(read, &getSourceMutable()); +} + +LogicalResult LDSTransposeLoadOp::verify() { + // Source must be memref in workgroup (LDS) address space + MemRefType srcType = getSource().getType(); + Attribute memSpaceAttr = srcType.getMemorySpace(); + if (!memSpaceAttr) + return emitOpError( + "source memref must have an address space (workgroup/LDS)"); + auto gpuMemSpaceAttr = dyn_cast(memSpaceAttr); + bool isWorkgroup = false; + if (gpuMemSpaceAttr && + gpuMemSpaceAttr.getValue() == gpu::AddressSpace::Workgroup) + isWorkgroup = true; + else if (auto intAttr = dyn_cast(memSpaceAttr)) { + // Accept raw integer 3 as LDS (common textual form memref<... 
, 3>) + if (intAttr.getInt() == 3) + isWorkgroup = true; + } + if (!isWorkgroup) + return emitOpError("source must reside in workgroup (LDS) memory"); + + // Indices size must match rank + if (getIndices().size() != srcType.getRank()) + return emitOpError("expected " + Twine(srcType.getRank()) + " indices"); + for (Value idx : getIndices()) { + if (!idx.getType().isIndex()) + return emitOpError("indices must be of index type"); + } + + return success(); +} + //===-----------------------------------------------------===// // ThreadwiseReadIntoOp //===-----------------------------------------------------===// diff --git a/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp b/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp index 3ee76d9c2a19..d9ff450c7f45 100644 --- a/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp @@ -32,15 +32,15 @@ #include "mlir/Dialect/Rock/utility/math.h" #include "mlir/Dialect/Rock/utility/transformMapUtils.h" +#include "LdsTransposeLoad.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Rock/IR/AccelEmitter.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/Transforms/DialectConversion.h" - -#include "mlir/Dialect/Rock/IR/AccelEmitter.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Debug.h" @@ -436,6 +436,8 @@ struct BlockwiseGemmAccelRewritePattern int64_t kBasePerThread = params.kBasePerThread; auto tid = WorkitemIdOp::create(b, loc, b.getIndexType()); + // Retrieve the stored lds transpose load decision. 
+ auto globalDecisionOpt = hwtranspose::getDecisionLdsTranspose(); LLVM_DEBUG(llvm::dbgs() << "argVectorType A: " << argTypeA << "\n" @@ -516,10 +518,15 @@ struct BlockwiseGemmAccelRewritePattern b, inputBuffer, getElementTypeOrSelf(argType), shapeForLoad); } // regs = read from LDS - ThreadwiseReadIntoOp::create( + auto twr = ThreadwiseReadIntoOp::create( b, loc, wrappedLDSBufferForLoad, viewForReadInto, b.getArrayAttr({}), ValueRange{tid, loopVar}, /*forceUnroll=*/true, /*useIndexDiffs=*/true); + // Apply stored transpose attributes if a valid decision exists. + if (globalDecisionOpt && + hwtranspose::isApplicable(*globalDecisionOpt)) { + hwtranspose::attachAttributes(twr, *globalDecisionOpt, b, isA); + } } else { if (cast(buffer.getType()).getRank() == 1) { StringRef dk = isA ? "mk" : "nk"; diff --git a/mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp b/mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp index 80b3da613e30..2df993f62b5a 100644 --- a/mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp @@ -17,6 +17,7 @@ //===----------------------------------------------------------------------===// #include "GridLayoutEmitter.h" +#include "LdsTransposeLoad.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -75,7 +76,8 @@ class LoweringBlockwiseLoadTileOp final Location loc, PatternRewriter &b, const std::unique_ptr &accelEmitterPtr, Value tid, StringRef dName, Value ldsView, Value regs, int64_t blockSize, - bool forceUnroll, const BlockwiseMatrixParamsAttr &matrixParams) const { + bool forceUnroll, const BlockwiseMatrixParamsAttr &matrixParams, + bool isA) const { // wrapLDSBufferForLoad is reading a single set of Ks into private memory // A/B[m/n, 0:kBasePerThread] @@ -111,10 +113,15 @@ class LoweringBlockwiseLoadTileOp final regs = rock::transform(b, regs, 
b.getArrayAttr({mkRegBuilder.get()})); } - ThreadwiseReadIntoOp::create(b, loc, ldsViewForLoad, regs, - b.getArrayAttr({}), ValueRange{tid}, - /*forceUnroll=*/forceUnroll, - /*useIndexDiffs=*/true); + auto globalDecisionOpt = hwtranspose::getDecisionLdsTranspose(); + auto twr = ThreadwiseReadIntoOp::create(b, loc, ldsViewForLoad, regs, + b.getArrayAttr({}), ValueRange{tid}, + /*forceUnroll=*/forceUnroll, + /*useIndexDiffs=*/true); + // Apply the global decision if it exists and is marked as usable. + if (globalDecisionOpt && hwtranspose::isApplicable(*globalDecisionOpt)) { + hwtranspose::attachAttributes(twr, *globalDecisionOpt, b, isA); + } } std::pair createOrGetStage(PatternRewriter &b, Location loc, @@ -404,7 +411,8 @@ class LoweringBlockwiseLoadTileOp final } generateReadLoop(loc, b, accelEmitterPtr, tid, dName, ldsViewForGemm, - destRegisters, blockSize, forceUnroll, matrixParams); + destRegisters, blockSize, forceUnroll, matrixParams, + isA); if (stageLDSReadNew) rock::YieldOp::create(b, loc); } diff --git a/mlir/lib/Dialect/Rock/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Rock/Transforms/CMakeLists.txt index bfa41f0e3986..062188f4ef6b 100644 --- a/mlir/lib/Dialect/Rock/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Rock/Transforms/CMakeLists.txt @@ -33,6 +33,7 @@ add_rocmlir_dialect_library(MLIRRockTransforms FindFirstGemmIndex.cpp RemoveOutputAlloc.cpp BlockwiseLoadTileToThreadwise.cpp + LdsTransposeLoad.cpp AnnotateLiveness.cpp ADDITIONAL_HEADER_DIRS diff --git a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp index 4ba6147ef2d9..17257bf2e64c 100644 --- a/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp @@ -33,6 +33,8 @@ #include "mlir/Dialect/Rock/utility/math.h" #include "mlir/Dialect/Rock/utility/transformMapUtils.h" +#include "GridLayoutEmitter.h" +#include "LdsTransposeLoad.h" #include 
"mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -40,6 +42,7 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Rock/IR/AccelEmitter.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/Transforms.h" #include "mlir/Dialect/Utils/IndexingUtils.h" @@ -54,9 +57,6 @@ #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/Passes.h" #include "mlir/Transforms/RegionUtils.h" - -#include "GridLayoutEmitter.h" -#include "mlir/Dialect/Rock/IR/AccelEmitter.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" @@ -1301,8 +1301,8 @@ struct GridwiseAttentionAccelRewritePattern switch (outOfScopeType) { case OutOfScopeType::KVCache: assert(currentSeqLen != nullptr); - isInvalid = arith::CmpIOp::create(thenb, - loc, arith::CmpIPredicate::ugt, mIndex, currentSeqLen); + isInvalid = arith::CmpIOp::create( + thenb, loc, arith::CmpIPredicate::ugt, mIndex, currentSeqLen); break; case OutOfScopeType::Causal: Value nIndex = lowerCoords[1]; @@ -1310,8 +1310,8 @@ struct GridwiseAttentionAccelRewritePattern nIndex = thenb.createOrFold(loc, nIndex, constNumRepeatsGQA); - isInvalid = arith::CmpIOp::create(thenb, - loc, arith::CmpIPredicate::ugt, mIndex, nIndex); + isInvalid = arith::CmpIOp::create( + thenb, loc, arith::CmpIPredicate::ugt, mIndex, nIndex); break; } @@ -1722,7 +1722,7 @@ struct GridwiseAttentionAccelRewritePattern // so we need to take the minimum of currentSeqLen and maxRowOfBlock if (effectiveSeqLen) maxRowOfBlock = arith::MinUIOp::create(rewriter, loc, currentSeqLen, - maxRowOfBlock); + maxRowOfBlock); effectiveSeqLen = maxRowOfBlock; } @@ -1730,7 +1730,7 @@ struct GridwiseAttentionAccelRewritePattern Value constGemm0MPerBlock = rewriter.createOrFold(loc, gemm0MPerBlock); Value numerator = 
arith::AddIOp::create(rewriter, loc, effectiveSeqLen, - constGemm0MPerBlock); + constGemm0MPerBlock); end = rewriter.createOrFold(loc, numerator, constGemm0MPerBlock); Value one = rewriter.createOrFold(loc, 1); @@ -1751,12 +1751,12 @@ struct GridwiseAttentionAccelRewritePattern rewriter.createOrFold(loc, numerator, constSplitKV); // if split-kv is enabled, we need to compute the start and end indices. - start = arith::MulIOp::create(rewriter, loc, gridCoordsGemm0.split_block, - gemm0MIterations); - Value splitPlusOne = arith::AddIOp::create(rewriter, loc, - gridCoordsGemm0.split_block, one); + start = arith::MulIOp::create( + rewriter, loc, gridCoordsGemm0.split_block, gemm0MIterations); + Value splitPlusOne = arith::AddIOp::create( + rewriter, loc, gridCoordsGemm0.split_block, one); Value endSplitKV = arith::MulIOp::create(rewriter, loc, splitPlusOne, - gemm0MIterations); + gemm0MIterations); end = arith::MinUIOp::create(rewriter, loc, end, endSplitKV); } // compute last iteration of the block, this will be used later in @@ -1773,9 +1773,10 @@ struct GridwiseAttentionAccelRewritePattern Value one = rewriter.createOrFold(loc, 1); start = arith::MulIOp::create(rewriter, loc, gridCoordsGemm0.split_block, gemm0MIterations); - Value splitPlusOne = - arith::AddIOp::create(rewriter, loc, gridCoordsGemm0.split_block, one); - end = arith::MulIOp::create(rewriter, loc, splitPlusOne, gemm0MIterations); + Value splitPlusOne = arith::AddIOp::create( + rewriter, loc, gridCoordsGemm0.split_block, one); + end = + arith::MulIOp::create(rewriter, loc, splitPlusOne, gemm0MIterations); } return std::make_tuple(start, end, gemm0MBlocksLastIter, currentSeqLen); } @@ -2975,6 +2976,21 @@ struct GridwiseGemmAccelRewritePattern zeroAccBuffer(b, loc, regCAllocOp); + hwtranspose::Decision decision; + auto *mfma = dyn_cast(accelEmitterPtr.get()); + // Only compute a transpose decision if both layouts are DxK disabled. 
+ if (mfma && !ldsLayoutConfigA.ldsLayoutDxK && + !ldsLayoutConfigB.ldsLayoutDxK) { + hwtranspose::MfmaInstrShape shape{mfma->getMfmaNonKDim(), + mfma->getMfmaK()}; + decision = hwtranspose::makeDecision( + arch, elementTypeA, elementTypeB, directToLDS, shape, + hwtranspose::OperandKind::A, hwtranspose::OperandKind::B, mPerBlock, + nPerBlock, kPerBlock); + // Store the computed decision globally for later use. + hwtranspose::setDecisionLdsTranspose(decision); + } + // Emit loop. Value nIterations = ConstantIndexOp::create(b, loc, K / kPerBlock); Value step = ConstantIndexOp::create(b, loc, 1); diff --git a/mlir/lib/Dialect/Rock/Transforms/LdsTransposeLoad.cpp b/mlir/lib/Dialect/Rock/Transforms/LdsTransposeLoad.cpp new file mode 100644 index 000000000000..f9122a6aaad6 --- /dev/null +++ b/mlir/lib/Dialect/Rock/Transforms/LdsTransposeLoad.cpp @@ -0,0 +1,565 @@ +//===- LdsTransposeLoad.cpp - MLIR helper for rock.lds_transpose_load +// +// Copyright 2025 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file defines helper functions for MLIR code generation related to +// rock.lds_transpose_load operations. It provides utilities for computing +// panel offsets, generating indices, and emitting calls to the LDS +// transpose load operation in a MFMA-friendly layout. 
+//
+// It is intended to simplify the IR generation logic and ensure
+// consistent handling of f16/bf16 panel loads from LDS memory.
+//
+//===----------------------------------------------------------------------===//
+
+#include "LdsTransposeLoad.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Rock/IR/Rock.h"
+#include "mlir/Dialect/Rock/utility/builderUtils.h"
+#include "mlir/Dialect/Rock/utility/transformMapUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "rock-hw-transpose-support"
+
+using namespace mlir;
+using namespace mlir::rock;
+
+namespace mlir::rock::hwtranspose {
+namespace {
+
+// True when the architecture string mentions "gfx950" anywhere (substring
+// match, so any prefix or suffix around the chip name is accepted).
+bool archSupported(StringRef arch) {
+  return arch.find("gfx950") != StringRef::npos;
+}
+
+// One row of the layout table: ties a LayoutKind to its MFMA geometry and to
+// the textual name stored in the `rock.hw_lds_transpose_layout` attribute.
+struct LayoutConfig {
+  LayoutKind kind; // enumerator identifying the layout
+  int64_t mnDim;   // MFMA non-K (M or N) dimension
+  int64_t kDim;    // MFMA K dimension
+  StringRef name;  // spelling used when the layout is written as a string
+};
+
+// Every layout the helpers in this file know about. The lookups below scan
+// this table linearly; the anonymous namespace already gives it internal
+// linkage.
+constexpr LayoutConfig kLayoutConfigs[] = {
+    {LayoutKind::L16x32, 16, 32, "16x32"},
+    {LayoutKind::L32x16, 32, 16, "32x16"},
+    {LayoutKind::L16x16, 16, 16, "16x16"},
+    {LayoutKind::L32x8, 32, 8, "32x8"},
+};
+} // namespace
+
+// Calculates the number of M/N/K panels per block based on the MFMA instruction
+// shape. Returns `true` if the dimensions divide evenly, otherwise `false`.
+// A non-positive MFMA dimension counts as "does not divide": the function
+// returns `false` instead of evaluating `x % 0`.
+//
+// `mPerBlock` / `nPerBlock` are taken by reference but only read. `kPanels` is
+// written as soon as the K check passes, so on a later failure the outputs may
+// be partially updated.
+bool calculatePanels(const MfmaInstrShape &shape, OperandKind operandA,
+                     OperandKind operandB, int64_t &mPerBlock,
+                     int64_t &nPerBlock, int64_t kPerBlock, int64_t &mPanels,
+                     int64_t &nPanels, int64_t &kPanels) {
+  // Division/modulo by zero is undefined behaviour; reject degenerate shapes.
+  if (shape.kMfma <= 0)
+    return false;
+  if (kPerBlock % shape.kMfma != 0)
+    return false;
+  kPanels = kPerBlock / shape.kMfma;
+
+  // M/N panel counts are only derived for the (A, B) operand pairing. Any
+  // other pairing leaves mPanels/nPanels untouched and still reports success.
+  if (operandA != OperandKind::A || operandB != OperandKind::B)
+    return true;
+
+  if (shape.mnMfma <= 0)
+    return false;
+
+  if (mPerBlock % shape.mnMfma != 0)
+    return false;
+  mPanels = mPerBlock / shape.mnMfma;
+
+  if (nPerBlock % shape.mnMfma != 0)
+    return false;
+  nPanels = nPerBlock / shape.mnMfma;
+  return true;
+}
+
+// Maps an MFMA (non-K, K) geometry to its LayoutKind by scanning
+// kLayoutConfigs; returns LayoutKind::None when no table row matches.
+LayoutKind selectLayout(int64_t mnDim, int64_t kDim) {
+  for (const auto &config : kLayoutConfigs) {
+    if (config.mnDim == mnDim && config.kDim == kDim)
+      return config.kind;
+  }
+  return LayoutKind::None;
+}
+
+// Storage behind getDecisionLdsTransposeContext().
+// NOTE(review): this is a single file-static object with no locking in this
+// file, so every user in the process shares it and the last value written
+// stays visible until something overwrites it -- confirm that is intended.
+static DecisionLdsTransposeContext ldsTransposeDecision;
+
+DecisionLdsTransposeContext &getDecisionLdsTransposeContext() {
+  return ldsTransposeDecision;
+}
+
+// Analyzes GEMM tiling and MFMA instruction parameters to determine
+// if the hardware LDS transpose optimization can be applied.
+// Returns a `Decision` struct indicating applicability and layout details.
+Decision makeDecision(StringRef arch, Type elemTypeA, Type elemTypeB,
+                      bool DirectToLds, const MfmaInstrShape &shape,
+                      OperandKind operandA, OperandKind operandB,
+                      int64_t mPerBlock, int64_t nPerBlock, int64_t kPerBlock) {
+  // Every early exit below hands back this record with `usable` still false.
+  Decision result;
+  result.operandA = operandA;
+  result.operandB = operandB;
+  result.mPerBlock = mPerBlock;
+  result.nPerBlock = nPerBlock;
+
+  // Gate 1: supported architecture and the direct-to-LDS flag.
+  const bool targetOk = archSupported(arch) && DirectToLds;
+  if (!targetOk)
+    return result;
+
+  // Gate 2: both operands share one element type, and it is f16 or bf16.
+  auto isHalfPrecision = [](Type t) { return t.isF16() || t.isBF16(); };
+  if (elemTypeA != elemTypeB || !isHalfPrecision(elemTypeA) ||
+      !isHalfPrecision(elemTypeB))
+    return result;
+
+  // Gate 3: MFMA geometry must be in the accepted set and map onto a layout.
+  const bool mnOk = shape.mnMfma == 16 || shape.mnMfma == 32;
+  const bool kOk =
+      shape.kMfma == 8 || shape.kMfma == 16 || shape.kMfma == 32;
+  if (!(mnOk && kOk))
+    return result;
+
+  result.layout = selectLayout(shape.mnMfma, shape.kMfma);
+  if (result.layout == LayoutKind::None)
+    return result;
+
+  // Gate 4: the block tile must split evenly into MFMA-sized panels. The
+  // panel counts are written straight into the result.
+  const bool panelsOk = calculatePanels(
+      shape, result.operandA, result.operandB, result.mPerBlock,
+      result.nPerBlock, kPerBlock, result.mPanels, result.nPanels,
+      result.kPanels);
+  result.usable = panelsOk;
+  return result;
+}
+
+// Textual name of a layout as listed in kLayoutConfigs, or "none" when the
+// kind has no table row.
+StringRef layoutName(LayoutKind kind) {
+  StringRef found = "none";
+  for (const auto &entry : kLayoutConfigs) {
+    if (entry.kind == kind) {
+      found = entry.name;
+      break;
+    }
+  }
+  return found;
+}
+
+// Attaches attributes to a `ThreadwiseReadIntoOp` to encode the chosen
+// LDS transpose configuration for later lowering.
+void attachAttributes(Operation *readIntoOp, const Decision &dec,
+                      PatternRewriter &rewriter, bool isA) {
+  // Nothing is attached for a decision that was not marked usable.
+  if (!dec.usable)
+    return;
+
+  // Small helper: store `value` as an i64 integer attribute named `name`.
+  auto setI64 = [&](StringRef name, int64_t value) {
+    readIntoOp->setAttr(name, rewriter.getI64IntegerAttr(value));
+  };
+
+  readIntoOp->setAttr("rock.hw_lds_transpose_enabled", rewriter.getUnitAttr());
+  readIntoOp->setAttr("rock.hw_lds_transpose_layout",
+                      rewriter.getStringAttr(layoutName(dec.layout)));
+  readIntoOp->setAttr("rock.hw_lds_transpose_operand",
+                      rewriter.getStringAttr(isA ? "A" : "B"));
+
+  // Operand-specific values: the M-side numbers for A, the N-side for B.
+  const int64_t perBlock = isA ? dec.mPerBlock : dec.nPerBlock;
+  const int64_t panels = isA ? dec.mPanels : dec.nPanels;
+
+  // A zero per-block size is not recorded.
+  if (perBlock)
+    setI64(isA ? "rock.hw_lds_transpose_mperblock"
+               : "rock.hw_lds_transpose_nperblock",
+           perBlock);
+
+  // Panel counts are only recorded when they exceed one.
+  if (panels > 1)
+    setI64(isA ? "rock.hw_lds_transpose_mpanels"
+               : "rock.hw_lds_transpose_npanels",
+           panels);
+  if (dec.kPanels > 1)
+    setI64("rock.hw_lds_transpose_kpanels", dec.kPanels);
+
+  LLVM_DEBUG(llvm::dbgs() << "[lds_transpose] attachAttributes: enabled layout="
+                          << layoutName(dec.layout) << "\n");
+}
+
+// Inverse of layoutName(): finds the table row whose name equals `s` and
+// returns its kind, or LayoutKind::None when no row matches.
+static LayoutKind layoutFromString(StringRef s) {
+  for (const auto &entry : kLayoutConfigs)
+    if (entry.name == s)
+      return entry.kind;
+  return LayoutKind::None;
+}
+
+// Derived lowering-time configuration extracted from operation attributes.
+// Used to drive emission of LDS transpose load instructions.
+//
+// The result is marked usable only when the op carries a recognised
+// `rock.hw_lds_transpose_layout` AND a `rock.hw_lds_transpose_operand` of
+// "A" or "B". Without a valid operand the per-block stride cannot be chosen,
+// so the info is returned with `usable` left at its default.
+// `b` is currently unused; it is kept so the signature stays unchanged.
+LoweringInfo deriveLoweringInfo(ThreadwiseReadIntoOp op, PatternRewriter &b) {
+  (void)b;
+  LoweringInfo info;
+
+  auto layoutAttr =
+      op->getAttrOfType<StringAttr>("rock.hw_lds_transpose_layout");
+  if (!layoutAttr)
+    return info;
+
+  info.layout = layoutFromString(layoutAttr.getValue());
+  if (info.layout == LayoutKind::None)
+    return info;
+
+  // Element type comes from the destination buffer.
+  // NOTE(review): the cast's template argument was missing in the reviewed
+  // text; MemRefType is assumed here -- confirm against the op definition.
+  auto destType = cast<MemRefType>(op.getDest().getType());
+  info.elemType = destType.getElementType();
+
+  // The operand tag decides which per-block / panel attributes apply.
+  auto operandAttr =
+      op->getAttrOfType<StringAttr>("rock.hw_lds_transpose_operand");
+  if (!operandAttr)
+    return info;
+  const StringRef operandName = operandAttr.getValue();
+  const bool isA = operandName == "A";
+  if (!isA && operandName != "B")
+    return info;
+
+  // Copies an optional i64 attribute into `slot`; an absent attribute leaves
+  // the LoweringInfo default in place.
+  auto readI64 = [&](StringRef name, int64_t &slot) {
+    if (auto attr = op->getAttrOfType<IntegerAttr>(name))
+      slot = attr.getInt();
+  };
+
+  if (isA) {
+    info.operand = OperandKind::A;
+    readI64("rock.hw_lds_transpose_mperblock", info.mPerBlock);
+    readI64("rock.hw_lds_transpose_mpanels", info.mPanels);
+  } else {
+    info.operand = OperandKind::B;
+    readI64("rock.hw_lds_transpose_nperblock", info.nPerBlock);
+    readI64("rock.hw_lds_transpose_npanels", info.nPanels);
+  }
+  // The K panel count uses the same attribute for both operands.
+  readI64("rock.hw_lds_transpose_kpanels", info.kPanels);
+
+  info.usable = true;
+  return info;
+}
+
+// Helper to get layout dimensions consistently.
+// Returns {mnDim, kDim} from the kLayoutConfigs row for `kind`, or {0, 0}
+// when the kind has no row.
+static std::pair<int64_t, int64_t> getLayoutDims(LayoutKind kind) {
+  for (const auto &config : kLayoutConfigs) {
+    if (config.kind == kind)
+      return {config.mnDim, config.kDim};
+  }
+  return {0, 0};
+}
+
+//===----------------------------------------------------------------------===//
+// getBasePanelOffsets - Compute per-panel LDS offsets for a given lane ID
+//
+// Given a wavefront lane ID and a specific MFMA layout (L16x32, L16x16,
etc.),
+// this function computes the base byte offsets into LDS memory where each
+// lane should read its operands from.
+//
+// These offsets are derived from AMD’s LDS tiling and MFMA operand layout
+// conventions (e.g., 16x16, 16x32 panels). The goal is to map each lane’s
+// register to the correct element position in LDS.
+//
+// Result: a two-element vector {k offset, m/n offset}, in that order.
+// NOTE(review): the caller combines these as `k * stride + m` into a single
+// index, so they look like element coordinates rather than byte offsets --
+// confirm the wording above.
+// NOTE(review): the template arguments of `SmallVector` and of every
+// `b.create(...)` below are missing in this text (they appear to have been
+// stripped); the helper names suggest constant / add / mul / div / rem ops,
+// but the exact op types (e.g. signed vs unsigned) cannot be read from here.
+//===----------------------------------------------------------------------===//
+static SmallVector getBasePanelOffsets(LayoutKind layout, Value lane,
+                                       PatternRewriter &b,
+                                       Location loc) {
+  // Shorthand builders used by the offset arithmetic below.
+  auto cst = [&](int64_t v) {
+    return b.create(loc, v);
+  };
+
+  auto add = [&](Value a, Value m) {
+    return b.create(loc, a, m);
+  };
+  auto mul = [&](Value a, Value m) {
+    return b.create(loc, a, m);
+  };
+  auto div = [&](Value a, Value m) {
+    return b.create(loc, a, m);
+  };
+  auto rem = [&](Value a, Value m) {
+    return b.create(loc, a, m);
+  };
+  SmallVector panelOffsets;
+  Value c16 = cst(16), c4 = cst(4), c2 = cst(2);
+  // Lanes are grouped in blocks of 16: blockId = lane / 16,
+  // laneInBlock = lane % 16.
+  Value blockId = div(lane, c16);
+  Value laneInBlock = rem(lane, c16);
+  // Base offset calculations
+  // mOffsetBase = (laneInBlock % 4) * 4, kOffsetBase = laneInBlock / 4
+  Value mOffsetBase = mul(rem(laneInBlock, c4), c4);
+  Value kOffsetBase = div(laneInBlock, c4);
+
+  switch (layout) {
+  case LayoutKind::L16x32: {
+    // No block-dependent term here; the caller adds the per-block K offset
+    // for this layout itself.
+    panelOffsets = {kOffsetBase, mOffsetBase};
+    break;
+  }
+  case LayoutKind::L16x16: {
+    // kbase = kOffsetBase + (blockId * 4)
+    Value kBase = add(mul(blockId, c4), kOffsetBase);
+    panelOffsets = {kBase, mOffsetBase};
+    break;
+  }
+  case LayoutKind::L32x16: {
+    // mbase = mOffsetBase + (blockId % 2) * 16
+    Value mBase = add(mul(rem(blockId, c2), cst(16)), mOffsetBase);
+    panelOffsets = {kOffsetBase, mBase};
+    break;
+  }
+  case LayoutKind::L32x8: {
+    // k_base_local = kOffsetBase + (blockId / 2) * 4
+    Value kBase = add(mul(div(blockId, c2), c4), kOffsetBase);
+
+    // m_offset_base = mOffsetBase + (blockId % 2) * 16
+    Value mBase = add(mul(rem(blockId, c2), cst(16)), mOffsetBase);
+    panelOffsets = {kBase, mBase};
+    break;
+  }
+  default:
+    // LayoutKind::None (or any future kind) is a caller error.
+    llvm_unreachable("Unsupported layout in getBasePanelOffsets");
+  }
+  return panelOffsets;
+}
+
+// Replaces a ThreadwiseReadIntoOp with explicit LDS transpose loads.
+//
+// For the layout and operand described by `info`, this emits one load
+// (single-rate layouts) or two loads (double-rate layouts) per K tile, each
+// producing a 4-element vector, then writes the loaded elements one by one
+// into the op's destination buffer and erases the op.
+//
+// Parameters:
+//   op   - the read to lower; its source view, destination buffer and extra
+//          indices are consumed.
+//   info - lowering configuration; nothing is emitted unless `info.usable`.
+//   b    - rewriter used to create the new ops and replace `op`.
+//
+// Returns failure() when `info.usable` is false, an op error when the number
+// of generated loads differs from the expected count, and success() after
+// `op` has been replaced.
+//
+// NOTE(review): template arguments of `cast`, `SmallVector`, `b.create` and
+// `b.createOrFold` are missing throughout this text (apparently stripped), so
+// the concrete op types are not visible here.
+LogicalResult emitThreadwiseHWTranspose(ThreadwiseReadIntoOp op,
+                                        const LoweringInfo &info,
+                                        PatternRewriter &b) {
+  if (!info.usable)
+    return failure();
+
+  Location loc = op.getLoc();
+  auto dest = op.getDest();
+  auto destType = cast(dest.getType());
+  Type elemType = info.elemType;
+  Value sourceView = op.getSource();
+  // Only the first result of untransform() is used; it is passed to the loads
+  // below as the buffer to read from.
+  // NOTE(review): `__` contains a double underscore, which is a reserved
+  // identifier in C++ -- consider renaming.
+  auto [rawSrc, _, __] = untransform(b, sourceView);
+
+  Value tid = b.createOrFold(loc, b.getIndexType());
+  auto cst = [&](int64_t v) {
+    return b.create(loc, v);
+  };
+  // Compute lane ID within the wavefront (0–63).
+  Value lane = b.create(loc, tid, cst(64));
+
+  // Use mPerBlock as stride for operand A, nPerBlock for operand B
+  int64_t ldsStride =
+      (info.operand == OperandKind::A) ? info.mPerBlock : info.nPerBlock;
+
+  // Compute base LDS panel offsets according to the layout and lane mapping.
+  SmallVector panelOffsets =
+      getBasePanelOffsets(info.layout, lane, b, loc);
+
+  // Determine if this is a double-rate instruction
+  // Double-rate ONLY for L32x16 (32x32x16 MFMA) and L16x32 (16x16x32 MFMA)
+  // L16x16 (16x16x16 MFMA) and L32x8 (32x32x8 MFMA) are SINGLE-RATE
+  // instruction.
+  auto [nonKDim, instrK] = getLayoutDims(info.layout);
+  bool isDoubleRate =
+      (info.layout == LayoutKind::L32x16 || info.layout == LayoutKind::L16x32);
+
+  // Each ds_read_tr16_b64 call ALWAYS returns vector<4xf16>
+  // For double-rate, we make 2 calls and store all 8 elements separately
+  VectorType panelVecType = VectorType::get({4}, elemType);
+
+  // panelVectors will contain:
+  // - Single-rate: 1 vector<4xf16> per K tile
+  // - Double-rate: 2 vector<4xf16> per K tile (low + high)
+  SmallVector panelVectors;
+
+  // Get base offsets from getBasePanelOffsets
+  // For panelOffsets[0] = k_base_local, panelOffsets[1] = m_offset_base
+  Value k_base_local = panelOffsets[0];
+  Value m_offset_base = panelOffsets[1];
+
+  // M/N stride: MNMfma (e.g., 32)
+  // K stride per tile: KMfma (e.g., 8)
+  int64_t mnStride = nonKDim;
+  int64_t kTileStride = instrK;
+
+  Value mnStrideVal = cst(mnStride);
+  Value kTileStrideVal = cst(kTileStride);
+  Value ldsStrideVal = cst(ldsStride);
+
+  // The extra indices tell us WHICH M/N tile we're loading in this iteration.
+  // Check if there's an extra index for M/N tile selection
+  ValueRange extraIndices = op.getExtraIndices();
+  Value mnTileIndex = nullptr;
+
+  // Extra indices format: [tid, m_tile_idx] for A or [tid, n_tile_idx] for B
+  if (extraIndices.size() >= 2) {
+    mnTileIndex = extraIndices[1]; // Second index is the M/N tile iterator
+  }
+
+  // If we have an M/N tile index from the outer loop, use it
+  // Otherwise, generate all M/N tiles (fallback for single-tile case)
+  int64_t startMnIdx = 0;
+  int64_t endMnIdx = 1;
+  bool useDynamicMnIndex = false;
+
+  if (mnTileIndex) {
+    // Outer loop handles M/N iteration, we load only ONE M/N tile per call
+    useDynamicMnIndex = true;
+    endMnIdx = 1;
+  } else {
+    // No outer loop, generate all M/N tiles statically
+    if (info.operand == OperandKind::A) {
+      endMnIdx = info.mPanels;
+    } else if (info.operand == OperandKind::B) {
+      endMnIdx = info.nPanels;
+    }
+  }
+
+  int64_t kPanels = info.kPanels;
+
+  // For double-rate layouts ONLY (L32x16, L16x32), compute k_offset_base
+  // L32x16 (32x32x16): k_offset_base = (block_id / 2) * 8
+  // L16x32 (16x16x32): k_offset_base = block_id * 8
+  Value blockId = nullptr;
+  Value kOffsetBase = nullptr;
+
+  if (isDoubleRate) {
+    Value c16 = cst(16), c2 = cst(2), c8 = cst(8);
+    // Same 16-lane block grouping as in getBasePanelOffsets.
+    blockId = b.create(loc, lane, c16);
+
+    if (info.layout == LayoutKind::L32x16) {
+      // k_offset_base = (block_id / 2) * 8
+      kOffsetBase = b.create(
+          loc, b.create(loc, blockId, c2), c8);
+    } else if (info.layout == LayoutKind::L16x32) {
+      // k_offset_base = block_id * 8
+      kOffsetBase = b.create(loc, blockId, c8);
+    }
+  }
+
+  // Generate loads: If outer loop exists, load one M/N tile with all K tiles
+  // Otherwise, load all M/N tiles with all K tiles
+  for (int64_t mnIdxLocal = startMnIdx; mnIdxLocal < endMnIdx; ++mnIdxLocal) {
+    for (int64_t kIdx = 0; kIdx < kPanels; ++kIdx) {
+      // Calculate m_base for this M/N tile
+      Value m_base = m_offset_base;
+      if (useDynamicMnIndex) {
+        // Use dynamic index from outer loop: m_base += mnTileIndex * mnStride
+        Value mnOffsetAdd =
+            b.create(loc, mnTileIndex, mnStrideVal);
+        m_base = b.create(loc, m_base, mnOffsetAdd);
+      } else if (mnIdxLocal > 0) {
+        // Use static index: m_base += mnIdxLocal * mnStride
+        Value mnOffsetAdd =
+            b.create(loc, mnStrideVal, cst(mnIdxLocal));
+        m_base = b.create(loc, m_base, mnOffsetAdd);
+      }
+
+      if (!isDoubleRate) {
+        // SINGLE-RATE (L32x8, L16x16): One load per K tile
+        // k_base = k_base_local + kIdx * kTileStride
+        Value k_base = k_base_local;
+        if (kIdx > 0) {
+          Value kOffsetAdd =
+              b.create(loc, kTileStrideVal, cst(kIdx));
+          k_base = b.create(loc, k_base, kOffsetAdd);
+        }
+
+        // final_offset = k_base * ldsStride + m_base
+        Value final_offset = b.create(
+            loc, m_base, b.create(loc, k_base, ldsStrideVal));
+
+        // Perform LDS transpose load (ds_read_tr16_b64) -> returns
+        // vector<4xf16>
+        // A single linearized index is passed as the load's index list.
+        auto l = b.create(loc, panelVecType, rawSrc,
+                          ValueRange{final_offset});
+        panelVectors.push_back(l.getFragment());
+
+      } else {
+        // DOUBLE-RATE (L32x16, L16x32): TWO loads per K tile
+        // Each load returns vector<4xf16>, total 8 elements per K tile
+        // k_offset_low = k_offset_base + k_tile * KMfma
+        // k_offset_high = k_offset_base + 4 + k_tile * KMfma
+
+        // Unlike the single-rate path, the tile offset is also built when
+        // kIdx is 0.
+        Value kTileOffset =
+            b.create(loc, kTileStrideVal, cst(kIdx));
+        Value k_offset_low =
+            b.create(loc, kOffsetBase, kTileOffset);
+        Value k_offset_high =
+            b.create(loc, k_offset_low, cst(4));
+
+        Value k_base_low =
+            b.create(loc, k_base_local, k_offset_low);
+        Value k_base_high =
+            b.create(loc, k_base_local, k_offset_high);
+
+        // offset_low = k_base_low * ldsStride + m_base
+        Value offset_low = b.create(
+            loc, m_base,
+            b.create(loc, k_base_low, ldsStrideVal));
+
+        // offset_high = k_base_high * ldsStride + m_base
+        Value offset_high = b.create(
+            loc, m_base,
+            b.create(loc, k_base_high, ldsStrideVal));
+
+        // Load low half: returns vector<4xf16>
+        auto load_low = b.create(
+            loc, panelVecType, rawSrc, ValueRange{offset_low});
+        panelVectors.push_back(load_low.getFragment());
+
+        // Load high half: returns vector<4xf16>
+        auto load_high = b.create(
+            loc, panelVecType, rawSrc, ValueRange{offset_high});
+        panelVectors.push_back(load_high.getFragment());
+      }
+    }
+  }
+
+  // Calculate expected number of loads
+  // - Single-rate: 1 load per K tile → actualMnTiles × kPanels loads
+  // - Double-rate: 2 loads per K tile → actualMnTiles × kPanels × 2 loads
+  int64_t actualMnTiles = endMnIdx - startMnIdx;
+  int64_t loadsPerKTile = isDoubleRate ? 2 : 1;
+  int64_t expectedLoads = actualMnTiles * kPanels * loadsPerKTile;
+
+  // Each load ALWAYS produces 4 elements (ds_read_tr16_b64 → vector<4xf16>)
+  int64_t sliceElems = expectedLoads * 4;
+
+  // Verify we generated the expected number of loads
+  // NOTE(review): this error path is reached after the loads above have
+  // already been created -- confirm callers tolerate that.
+  if (panelVectors.size() != (size_t)expectedLoads) {
+    return op.emitOpError("Mismatch in number of generated loads: expected ")
+           << expectedLoads << ", got " << panelVectors.size();
+  }
+
+  // Scalar buffer path rank-1.
+  // NOTE(review): only dimension 0 of the destination shape is read and no
+  // rank check is visible here -- presumably the destination is always
+  // rank-1; confirm. `std::min` also needs <algorithm>, which this file does
+  // not include directly.
+  int64_t destCap = destType.getShape()[0];
+  int64_t targetElems = std::min(sliceElems, destCap);
+  int64_t produced = 0;
+
+  // Write each extracted element from the loaded panel vectors into `dest`.
+  // The destination is rank-1, meaning scalar sequential layout.
+  // Elements beyond `targetElems` (the smaller of loaded count and
+  // destination capacity) are dropped.
+  for (Value pv : panelVectors) {
+    for (int lane = 0; lane < 4 && produced < targetElems; ++lane) {
+      Value ciLane = cst(lane);
+      Value elem = b.create(loc, pv, ciLane);
+      Value idx = cst(produced++);
+      b.create(loc, elem, dest, ValueRange{idx});
+    }
+    if (produced >= targetElems)
+      break; // Stop once we have written all target elements
+  }
+
+  // The original read produces no results, so it is replaced by nothing.
+  b.replaceOp(op, ValueRange{});
+  return success();
+}
+
+} // namespace mlir::rock::hwtranspose
diff --git a/mlir/lib/Dialect/Rock/Transforms/LdsTransposeLoad.h b/mlir/lib/Dialect/Rock/Transforms/LdsTransposeLoad.h
new file mode 100644
index 000000000000..561f2896ee83
--- /dev/null
+++ b/mlir/lib/Dialect/Rock/Transforms/LdsTransposeLoad.h
@@ -0,0 +1,123 @@
+//===- LdsTransposeLoad.h - MLIR helper for rock.lds_transpose_load -------===//
+//
+// Copyright 2025 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines helper functions for MLIR code generation related to
+// rock.lds_transpose_load operations. It provides utilities for computing
+// panel offsets, generating indices, and emitting calls to the LDS
+// transpose load operation in a MFMA-friendly layout.
+//
+// It is intended to simplify the IR generation logic and ensure
+// consistent handling of f16/bf16 panel loads from LDS memory.
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_LIB_DIALECT_ROCK_TRANSFORMS_LDS_TRANSPOSE_LOAD_H +#define MLIR_LIB_DIALECT_ROCK_TRANSFORMS_LDS_TRANSPOSE_LOAD_H + +#include "mlir/Dialect/Rock/IR/AmdArchDb.h" +#include "mlir/Dialect/Rock/IR/Rock.h" +#include "mlir/IR/PatternMatch.h" + +namespace mlir::rock::hwtranspose { + +// Operand selector (A or B matrix) +enum class OperandKind { A, B }; + +// Simplified layout kinds +enum class LayoutKind { None, L16x32, L32x16, L16x16, L32x8 }; + +// Shape of a single MFMA instruction this load cooperates with. +struct MfmaInstrShape { + int64_t mnMfma; + int64_t kMfma; +}; + +// Structure to hold the outcome of the hardware transpose analysis. +struct Decision { + bool usable{false}; + LayoutKind layout{LayoutKind::None}; + OperandKind operandA{OperandKind::A}; + OperandKind operandB{OperandKind::B}; + int64_t mPanels{1}; + int64_t nPanels{1}; + int64_t kPanels{1}; + int64_t mPerBlock{1}; + int64_t nPerBlock{1}; +}; + +struct DecisionLdsTransposeContext { + std::optional currentDecision; +}; + +// Global accessor for shared decision state during codegen. +DecisionLdsTransposeContext &getDecisionLdsTransposeContext(); + +inline void setDecisionLdsTranspose(const Decision &dec) { + getDecisionLdsTransposeContext().currentDecision = dec; +} + +inline std::optional getDecisionLdsTranspose() { + return getDecisionLdsTransposeContext().currentDecision; +} + +// The main decision-making function. It analyzes the GEMM parameters and +// returns a Decision struct indicating if the optimization is applicable and +// with which paneling configuration. +Decision makeDecision(StringRef arch, Type elemTypeA, Type elemTypeB, + bool DirectToLds, const MfmaInstrShape &shape, + OperandKind operandA, OperandKind operandB, + int64_t mPerBlock, int64_t nPerBlock, int64_t kPerBlock); + +// A convenience wrapper around makeDecision to quickly check for applicability. 
+inline bool isApplicable(const Decision &dec) { return dec.usable; } + +// Select a layout kind based on the MFMA instruction shape. +LayoutKind selectLayout(int64_t nonKDim, int64_t instrK); + +// Attach attributes to the ThreadwiseReadIntoOp based on the decision. +void attachAttributes(Operation *readIntoOp, const Decision &dec, + PatternRewriter &rewriter, bool isA); + +// Lowering-time description. +struct LoweringInfo { + bool usable{false}; + LayoutKind layout{LayoutKind::None}; + OperandKind operand{OperandKind::A}; + Type elemType; + bool destIsVector{false}; + VectorType destVecType; + int64_t mPanels{1}; + int64_t nPanels{1}; + int64_t kPanels{1}; + int64_t mPerBlock{1}; + int64_t nPerBlock{1}; +}; + +// Derives lowering information from the attributes of a ThreadwiseReadIntoOp. +LoweringInfo deriveLoweringInfo(ThreadwiseReadIntoOp op, PatternRewriter &b); + +// Emits the actual hardware transpose load sequence. +LogicalResult emitThreadwiseHWTranspose(ThreadwiseReadIntoOp op, + const LoweringInfo &info, + PatternRewriter &b); + +// Utility to get the string name of a layout. +StringRef layoutName(LayoutKind kind); + +} // namespace mlir::rock::hwtranspose + +#endif // MLIR_LIB_DIALECT_ROCK_TRANSFORMS_LDS_TRANSPOSE_LOAD_H diff --git a/mlir/lib/Dialect/Rock/Transforms/SugarToLoops.cpp b/mlir/lib/Dialect/Rock/Transforms/SugarToLoops.cpp index 730f156745fb..ca8ad544401f 100644 --- a/mlir/lib/Dialect/Rock/Transforms/SugarToLoops.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/SugarToLoops.cpp @@ -1565,6 +1565,23 @@ struct GlobalStoreRewritePattern : public OpRewritePattern { } }; +//===----------------------------------------------------------------------===// +// LDSTransposeLoadOp lowering. 
+//===----------------------------------------------------------------------===// +struct LDSTransposeLoadRewritePattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(LDSTransposeLoadOp op, + PatternRewriter &b) const override { + + // Replace with amdgpu.transpose_load having identical semantics. + auto newOp = b.create( + op.getLoc(), op.getResult().getType(), op.getSource(), op.getIndices()); + b.replaceOp(op, newOp.getResult()); + return success(); + } +}; + //===----------------------------------------------------------------------===// // InBoundsLoad lowering. //===----------------------------------------------------------------------===// @@ -1620,8 +1637,8 @@ void RockSugarToLoopsPass::runOnOperation() { RewritePatternSet patterns(ctx); patterns.add(ctx); + GlobalStoreRewritePattern, LDSTransposeLoadRewritePattern, + InBoundsLoadRewritePattern, InBoundsStoreRewritePattern>(ctx); if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) signalPassFailure(); diff --git a/mlir/lib/Dialect/Rock/Transforms/ThreadwiseGemmLowering.cpp b/mlir/lib/Dialect/Rock/Transforms/ThreadwiseGemmLowering.cpp index da08c64f7c8d..11df6df70fb5 100644 --- a/mlir/lib/Dialect/Rock/Transforms/ThreadwiseGemmLowering.cpp +++ b/mlir/lib/Dialect/Rock/Transforms/ThreadwiseGemmLowering.cpp @@ -34,15 +34,15 @@ #include "mlir/Dialect/Rock/utility/math.h" #include "mlir/Dialect/Rock/utility/transformMapUtils.h" +#include "LdsTransposeLoad.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Rock/IR/AccelEmitter.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/Passes.h" - -#include "mlir/Dialect/Rock/IR/AccelEmitter.h" #include "llvm/Support/Debug.h" #include @@ -604,6 +604,21 @@ LogicalResult 
ThreadwiseReadIntoRewritePattern::matchAndRewrite( } } + // Check if the operation has the attribute for LDS Transpose Load + if (op->hasAttr("rock.hw_lds_transpose_enabled")) { + // Derive lowering info from attributes (layout, panel counts, operand). + auto info = mlir::rock::hwtranspose::deriveLoweringInfo(op, b); + if (info.usable) { + if (failed(rock::hwtranspose::emitThreadwiseHWTranspose(op, info, b))) { + return failure(); + } + } else { + return op.emitOpError("LDS transpose load emission is not usable with " + "the derived attributes"); + } + return success(); + } + size_t extraIdxCount = op.getExtraIndices().size(); // We are vectorizing in the iter dimension, not block ID or thread ID auto elementType = sourceViewType.getElementType(); diff --git a/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp b/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp index c520c8dde13c..31e0bd998719 100644 --- a/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp +++ b/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp @@ -568,6 +568,16 @@ int64_t MfmaEmitter::getRowGroupSize() const { return mfmaAttr.rowGroupSize; } +int64_t MfmaEmitter::getMfmaK() const { + MfmaInsnAttr mfmaAttr = mfmaGroup.getInsnAttr(); + return mfmaAttr.k; +} + +int64_t MfmaEmitter::getMfmaNonKDim() const { + MfmaInsnAttr mfmaAttr = mfmaGroup.getInsnAttr(); + return mfmaAttr.mfmaNonKDim; +} + llvm::FailureOr MfmaEmitter::createAccelGemmOperandTransforms( OpBuilder &b, Location loc, int64_t kIters, diff --git a/mlir/test/Dialect/Rock/load_transpose_lds.mlir b/mlir/test/Dialect/Rock/load_transpose_lds.mlir new file mode 100644 index 000000000000..7a8afa1341b6 --- /dev/null +++ b/mlir/test/Dialect/Rock/load_transpose_lds.mlir @@ -0,0 +1,17 @@ +// RUN: rocmlir-opt --rock-sugar-to-loops %s | FileCheck %s + +// CHECK-LABEL: func @test_load_transpose_fp16 +module { + func.func @test_load_transpose_fp16(%src: memref<128x256xf16, 3>, %i: index, %j: index) -> vector<4xf16> { + // CHECK: amdgpu.transpose_load %arg0[%arg1, 
%arg2] : memref<128x256xf16, 3> -> vector<4xf16> + %v = rock.lds_transpose_load %src[%i, %j] : memref<128x256xf16, 3> -> vector<4xf16> + return %v : vector<4xf16> + } + +// CHECK-LABEL: func @test_load_transpose_bf16 + func.func @test_load_transpose_bf16(%src: memref<64x128xbf16, 3>, %i: index, %j: index) -> vector<4xbf16> { + // CHECK: amdgpu.transpose_load %arg0[%arg1, %arg2] : memref<64x128xbf16, 3> -> vector<4xbf16> + %v = rock.lds_transpose_load %src[%i, %j] : memref<64x128xbf16, 3> -> vector<4xbf16> + return %v : vector<4xbf16> + } +} diff --git a/mlir/test/e2e/CMakeLists.txt b/mlir/test/e2e/CMakeLists.txt index 3cc7b4709b7e..42cacac32d74 100644 --- a/mlir/test/e2e/CMakeLists.txt +++ b/mlir/test/e2e/CMakeLists.txt @@ -46,6 +46,7 @@ if (ROCMLIR_DRIVER_PR_E2E_TEST_ENABLED) PrConvElementwiseGemmF16SplitK PrConvElementwiseGemmBF16SplitK PrGemmDirectToLDS + PrLdsTransposeLoad PrConvDirectToLDS ) set(GEN_MODE "") diff --git a/mlir/test/e2e/PRLdsTransposeLoad.cfg b/mlir/test/e2e/PRLdsTransposeLoad.cfg new file mode 100644 index 000000000000..4d2974376114 --- /dev/null +++ b/mlir/test/e2e/PRLdsTransposeLoad.cfg @@ -0,0 +1,2 @@ +if config.arch != "gfx950": + config.unsupported = True diff --git a/mlir/test/e2e/PrLdsTransposeLoad.toml b/mlir/test/e2e/PrLdsTransposeLoad.toml new file mode 100644 index 000000000000..e98a2b6d52df --- /dev/null +++ b/mlir/test/e2e/PrLdsTransposeLoad.toml @@ -0,0 +1,42 @@ +directory = "PrLdsTransposeLoad" +prefix = "rocmlir-gen" +suffix = "--operation gemm --arch %arch %pv %random_data %rocmlir_gen_flags | rocmlir-driver -c | mlir-runner -O2 --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=" + +[[axis]] +name = "transA" +values = ["true"] +prefix = "--transA=" + +[[axis]] +name = "transB" +values 
= ["false"] +prefix = "--transB=" + +[[axis]] +name = "data type" +values = ["f16", "bf16"] +prefix = "-t " + +[[suite]] +name = "lds_transpose_load" + +[[suite.test]] +config = "-g 1 -m 16 -k 16 -n 16 --perf_config v3:16,16,16,16,16,1,1,3,2,1,1" + +[[suite.test]] +config = "-g 1 -m 32 -k 8 -n 32 --perf_config v3:32,32,8,32,32,1,1,4,2,1,1" + +[[suite.test]] +config = "-g 1 -m 16 -k 64 -n 16 --perf_config v3:16,16,64,16,16,1,1,4,2,1,1" + +[[suite.test]] +config = "-g 1 -m 32 -k 128 -n 32 --perf_config v3:32,32,32,32,32,4,1,3,2,1,1" + +[[suite.test]] +config = "-g 1 -m 32 -k 256 -n 32 --perf_config v3:32,32,256,32,32,1,1,4,2,1,1" + +[[suite.test]] +config = "-g 256 -m 32 -k 32 -n 32 --perf_config v3:32,32,32,32,32,1,1,3,2,1,1" + +[[suite.test]] +config = "-g 256 -m 16 -k 64 -n 16 --perf_config v3:16,16,64,16,16,1,1,4,2,1,1"