Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/Rock/IR/AccelEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,12 @@ struct MfmaEmitter : public AccelEmitter {

int64_t getRowGroupSize() const;

// Return the K dimension of the MFMA instruction
int64_t getMfmaK() const;

// Return the non-K dimension of the MFMA instruction
int64_t getMfmaNonKDim() const;

static bool classof(const AccelEmitter *AE) {
return AE->getKind() == AccelEmitterKind::AEK_MFMAEmitter;
}
Expand Down
23 changes: 23 additions & 0 deletions mlir/include/mlir/Dialect/Rock/IR/RockOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1124,6 +1124,29 @@ defvar SameShapeVectorOfI1 = [{
::mlir::VectorType::get(::llvm::cast<::mlir::ShapedType>($_self).getShape(), ::mlir::IntegerType::get($_ctxt, 1))
}];

// lds_transpose_load
// Reads one 4-element f16/bf16 panel from an LDS memref through the hardware
// transpose-load path. Memory effects are declared manually in C++ (see
// LDSTransposeLoadOp::getEffects) and operand invariants that ODS cannot
// express are checked by the custom verifier.
def Rock_LDSTransposeLoadOp :
    Rock_Op<"lds_transpose_load",
            [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]>,
    Arguments<(ins
      Arg<MemRefOf<[F16, BF16]>, "LDS source buffer">:$source,
      Variadic<Index>:$indices)>,
    Results<(outs VectorOfLengthAndType<[4], [F16, BF16]>:$fragment)> {
  let summary = "Hardware-assisted LDS transpose panel load (v4 fragment)";
  let description = [{
    LDS transpose panel load (vector<4 x F16/BF16>) for MFMA-friendly ordering.
    Loads a single panel (4 elements) from LDS memory using AMD
    ds.read.tr16.* instructions.

    The returned fragment ($fragment) is always a vector<4 x elementType>.
    `$source` must be a memref in workgroup (LDS) memory, and exactly one
    index must be supplied per dimension of `$source`.

    Helps reduce LDS bank conflicts. Only supported on gfx950.
  }];

  let assemblyFormat = [{
    $source (`[` $indices^ `]`)? attr-dict `:` type($source) `->` type($fragment)
  }];

  let hasVerifier = 1;
}

// threadwise_read_into
def Rock_ThreadwiseReadIntoOp
: Rock_Op<
Expand Down
41 changes: 41 additions & 0 deletions mlir/lib/Dialect/Rock/IR/RockDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1774,6 +1774,47 @@ LogicalResult InBoundsStoreOp::verify() {
return success();
}

//===-----------------------------------------------------===//
// LDSTransposeLoadOp
//===-----------------------------------------------------===//
void LDSTransposeLoadOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // The op's sole memory effect is a read of its source operand; it never
  // writes memory.
  OpOperand &sourceOperand = getSourceMutable();
  effects.emplace_back(MemoryEffects::Read::get(), &sourceOperand);
}

/// Verifies the invariants of `rock.lds_transpose_load` that the declarative
/// operand/result constraints do not cover:
///   * the source memref lives in workgroup (LDS) memory,
///   * the fragment element type equals the source element type,
///   * exactly one index is supplied per source dimension.
LogicalResult LDSTransposeLoadOp::verify() {
  MemRefType srcType = getSource().getType();

  // Source must be a memref in the workgroup (LDS) address space.
  Attribute memSpaceAttr = srcType.getMemorySpace();
  if (!memSpaceAttr)
    return emitOpError(
        "source memref must have an address space (workgroup/LDS)");

  bool isWorkgroup = false;
  if (auto gpuMemSpaceAttr = dyn_cast<gpu::AddressSpaceAttr>(memSpaceAttr)) {
    isWorkgroup = gpuMemSpaceAttr.getValue() == gpu::AddressSpace::Workgroup;
  } else if (auto intAttr = dyn_cast<IntegerAttr>(memSpaceAttr)) {
    // Accept raw integer 3 as LDS (common textual form memref<... , 3>).
    isWorkgroup = intAttr.getInt() == 3;
  }
  if (!isWorkgroup)
    return emitOpError("source must reside in workgroup (LDS) memory");

  // The load only moves data; it must not reinterpret f16 as bf16 or vice
  // versa, so the fragment has to carry the source's element type.
  Type srcElemType = srcType.getElementType();
  Type fragmentElemType = getFragment().getType().getElementType();
  if (fragmentElemType != srcElemType)
    return emitOpError("fragment element type ")
           << fragmentElemType << " must match source element type "
           << srcElemType;

  // Exactly one index per source dimension. The cast avoids a signed/unsigned
  // comparison between size() and getRank(). Index operands are already
  // constrained to `index` type by the op definition, so that is not
  // re-checked here.
  if (static_cast<int64_t>(getIndices().size()) != srcType.getRank())
    return emitOpError("expected ") << srcType.getRank() << " indices";

  return success();
}

//===-----------------------------------------------------===//
// ThreadwiseReadIntoOp
//===-----------------------------------------------------===//
Expand Down
13 changes: 10 additions & 3 deletions mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@
#include "mlir/Dialect/Rock/utility/math.h"
#include "mlir/Dialect/Rock/utility/transformMapUtils.h"

#include "LdsTransposeLoad.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Rock/IR/AccelEmitter.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"

#include "mlir/Dialect/Rock/IR/AccelEmitter.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
Expand Down Expand Up @@ -441,6 +441,8 @@ struct BlockwiseGemmAccelRewritePattern
int64_t kBasePerThread = params.kBasePerThread;

auto tid = WorkitemIdOp::create(b, loc, b.getIndexType());
// Retrieve the stored lds transpose load decision.
auto globalDecisionOpt = hwtranspose::getDecisionLdsTranspose();

LLVM_DEBUG(llvm::dbgs()
<< "argVectorType A: " << argTypeA << "\n"
Expand Down Expand Up @@ -524,10 +526,15 @@ struct BlockwiseGemmAccelRewritePattern
b, inputBuffer, getElementTypeOrSelf(argType), shapeForLoad);
}
// regs = read from LDS
ThreadwiseReadIntoOp::create(
auto twr = ThreadwiseReadIntoOp::create(
b, loc, wrappedLDSBufferForLoad, viewForReadInto,
b.getArrayAttr({}), ValueRange{tid, loopVar}, /*forceUnroll=*/true,
/*useIndexDiffs=*/true);
// Apply stored transpose attributes if a valid decision exists.
if (globalDecisionOpt &&
hwtranspose::isApplicable(*globalDecisionOpt)) {
hwtranspose::attachAttributes(twr, *globalDecisionOpt, b, isA);
}
} else {
if (cast<ShapedType>(buffer.getType()).getRank() == 1) {
StringRef dk = isA ? "mk" : "nk";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
//===----------------------------------------------------------------------===//

#include "GridLayoutEmitter.h"
#include "LdsTransposeLoad.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
Expand Down Expand Up @@ -76,7 +77,7 @@ class LoweringBlockwiseLoadTileOp final
const std::unique_ptr<rock::accel::AccelEmitter> &accelEmitterPtr,
Value tid, StringRef dName, Value ldsView, Value regs, int64_t blockSize,
int64_t inDPerThread, bool rotateDWithK, bool forceUnroll,
bool directToLDS, bool ldsLayoutDxK) const {
bool directToLDS, bool ldsLayoutDxK, bool isA) const {

// wrapLDSBufferForLoad is reading a single set of Ks into private memory
// A/B[m/n, 0:kBasePerThread]
Expand Down Expand Up @@ -113,10 +114,15 @@ class LoweringBlockwiseLoadTileOp final
regs = rock::transform(b, regs, b.getArrayAttr({mkRegBuilder.get()}));
}

ThreadwiseReadIntoOp::create(b, loc, ldsViewForLoad, regs,
b.getArrayAttr({}), ValueRange{tid},
/*forceUnroll=*/forceUnroll,
/*useIndexDiffs=*/true);
auto globalDecisionOpt = hwtranspose::getDecisionLdsTranspose();
auto twr = ThreadwiseReadIntoOp::create(b, loc, ldsViewForLoad, regs,
b.getArrayAttr({}), ValueRange{tid},
/*forceUnroll=*/forceUnroll,
/*useIndexDiffs=*/true);
// Apply the global decision if it exists and is marked as usable.
if (globalDecisionOpt && hwtranspose::isApplicable(*globalDecisionOpt)) {
hwtranspose::attachAttributes(twr, *globalDecisionOpt, b, isA);
}
}

std::pair<StageOp, bool> createOrGetStage(PatternRewriter &b, Location loc,
Expand Down Expand Up @@ -406,7 +412,7 @@ class LoweringBlockwiseLoadTileOp final
generateReadLoop(loc, b, accelEmitterPtr, tid, dName, ldsViewForGemm,
destRegisters, blockSize, copyDPerThread,
ldsLayoutConfig.doRotateWithK, forceUnroll,
directToLDS, ldsLayoutConfig.ldsLayoutDxK);
directToLDS, ldsLayoutConfig.ldsLayoutDxK, isA);
if (stageLDSReadNew)
rock::YieldOp::create(b, loc);
}
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Rock/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ add_rocmlir_dialect_library(MLIRRockTransforms
FindFirstGemmIndex.cpp
RemoveOutputAlloc.cpp
BlockwiseLoadTileToThreadwise.cpp
LdsTransposeLoad.cpp
AnnotateLiveness.cpp

ADDITIONAL_HEADER_DIRS
Expand Down
50 changes: 33 additions & 17 deletions mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,16 @@
#include "mlir/Dialect/Rock/utility/math.h"
#include "mlir/Dialect/Rock/utility/transformMapUtils.h"

#include "GridLayoutEmitter.h"
#include "LdsTransposeLoad.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Rock/IR/AccelEmitter.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
Expand All @@ -54,9 +57,6 @@
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/RegionUtils.h"

#include "GridLayoutEmitter.h"
#include "mlir/Dialect/Rock/IR/AccelEmitter.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
Expand Down Expand Up @@ -1283,17 +1283,17 @@ struct GridwiseAttentionAccelRewritePattern
switch (outOfScopeType) {
case OutOfScopeType::KVCache:
assert(currentSeqLen != nullptr);
isInvalid = arith::CmpIOp::create(thenb,
loc, arith::CmpIPredicate::ugt, mIndex, currentSeqLen);
isInvalid = arith::CmpIOp::create(
thenb, loc, arith::CmpIPredicate::ugt, mIndex, currentSeqLen);
break;
case OutOfScopeType::Causal:
Value nIndex = lowerCoords[1];
if (constNumRepeatsGQA)
nIndex = thenb.createOrFold<arith::DivUIOp>(loc, nIndex,
constNumRepeatsGQA);

isInvalid = arith::CmpIOp::create(thenb,
loc, arith::CmpIPredicate::ugt, mIndex, nIndex);
isInvalid = arith::CmpIOp::create(
thenb, loc, arith::CmpIPredicate::ugt, mIndex, nIndex);
break;
}

Expand Down Expand Up @@ -1704,15 +1704,15 @@ struct GridwiseAttentionAccelRewritePattern
// so we need to take the minimum of currentSeqLen and maxRowOfBlock
if (effectiveSeqLen)
maxRowOfBlock = arith::MinUIOp::create(rewriter, loc, currentSeqLen,
maxRowOfBlock);
maxRowOfBlock);
effectiveSeqLen = maxRowOfBlock;
}

// compute end index
Value constGemm0MPerBlock =
rewriter.createOrFold<arith::ConstantIndexOp>(loc, gemm0MPerBlock);
Value numerator = arith::AddIOp::create(rewriter, loc, effectiveSeqLen,
constGemm0MPerBlock);
constGemm0MPerBlock);
end = rewriter.createOrFold<arith::DivUIOp>(loc, numerator,
constGemm0MPerBlock);
Value one = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 1);
Expand All @@ -1733,12 +1733,12 @@ struct GridwiseAttentionAccelRewritePattern
rewriter.createOrFold<arith::DivUIOp>(loc, numerator, constSplitKV);

// if split-kv is enabled, we need to compute the start and end indices.
start = arith::MulIOp::create(rewriter, loc, gridCoordsGemm0.split_block,
gemm0MIterations);
Value splitPlusOne = arith::AddIOp::create(rewriter, loc,
gridCoordsGemm0.split_block, one);
start = arith::MulIOp::create(
rewriter, loc, gridCoordsGemm0.split_block, gemm0MIterations);
Value splitPlusOne = arith::AddIOp::create(
rewriter, loc, gridCoordsGemm0.split_block, one);
Value endSplitKV = arith::MulIOp::create(rewriter, loc, splitPlusOne,
gemm0MIterations);
gemm0MIterations);
end = arith::MinUIOp::create(rewriter, loc, end, endSplitKV);
}
// compute last iteration of the block, this will be used later in
Expand All @@ -1755,9 +1755,10 @@ struct GridwiseAttentionAccelRewritePattern
Value one = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 1);
start = arith::MulIOp::create(rewriter, loc, gridCoordsGemm0.split_block,
gemm0MIterations);
Value splitPlusOne =
arith::AddIOp::create(rewriter, loc, gridCoordsGemm0.split_block, one);
end = arith::MulIOp::create(rewriter, loc, splitPlusOne, gemm0MIterations);
Value splitPlusOne = arith::AddIOp::create(
rewriter, loc, gridCoordsGemm0.split_block, one);
end =
arith::MulIOp::create(rewriter, loc, splitPlusOne, gemm0MIterations);
}
return std::make_tuple(start, end, gemm0MBlocksLastIter, currentSeqLen);
}
Expand Down Expand Up @@ -3158,6 +3159,21 @@ struct GridwiseGemmAccelRewritePattern

zeroAccBuffer(b, loc, regCAllocOp);

hwtranspose::Decision decision;
auto *mfma = dyn_cast<rock::accel::MfmaEmitter>(accelEmitterPtr.get());
// Compute an LDS transpose-load decision only when the accel emitter is an
// MFMA emitter and neither operand uses the DxK LDS layout.
// NOTE(review): the decision is stored only inside this branch. When the
// condition is false nothing is stored, so later getDecisionLdsTranspose()
// callers presumably observe whatever an earlier rewrite left behind --
// confirm this is intended, or store the default-constructed `decision`
// unconditionally. Also confirm the shared storage is safe if functions are
// lowered concurrently.
if (mfma && !ldsLayoutConfigA.ldsLayoutDxK &&
!ldsLayoutConfigB.ldsLayoutDxK) {
hwtranspose::MfmaInstrShape shape{mfma->getMfmaNonKDim(),
mfma->getMfmaK()};
decision = hwtranspose::makeDecision(
arch, elementTypeA, elementTypeB, directToLDS, shape,
hwtranspose::OperandKind::A, hwtranspose::OperandKind::B, mPerBlock,
nPerBlock, kPerBlock);
// Store the computed decision globally for later use.
hwtranspose::setDecisionLdsTranspose(decision);
}

// Emit loop.
Value nIterations = ConstantIndexOp::create(b, loc, K / kPerBlock);
Value step = ConstantIndexOp::create(b, loc, 1);
Expand Down
Loading