Commit f6f8905

[torch][quant] Quantized torch.mm for linalg with end-to-end test (#2750)
This includes custom-op matching for decomposed quantized operations and fuses dequantization into dense operations. As validation, we compare against the dequant+mm torch implementation (sketched below).
1 parent 60bf6c2 commit f6f8905
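The validation strategy can be pictured in eager PyTorch. Below is a minimal sketch of that comparison (illustrative only; the shapes, scales, and zero points are made up, and this is not the repository's actual e2e test): the reference path dequantizes then runs a float `torch.mm`, while the fused path accumulates in i32 and rescales once, mirroring what `linalg.quantized_matmul` computes.

```python
import torch

# Quantized operands: int8 values with per-tensor scale and zero point.
lhs_scale, lhs_zp = 0.05, 3
rhs_scale, rhs_zp = 0.1, -2
lhs_q = torch.randint(-128, 128, (4, 8), dtype=torch.int8)
rhs_q = torch.randint(-128, 128, (8, 5), dtype=torch.int8)

# Reference path: dequantize each operand, then a float torch.mm.
lhs_f = (lhs_q.to(torch.float32) - lhs_zp) * lhs_scale
rhs_f = (rhs_q.to(torch.float32) - rhs_zp) * rhs_scale
ref = torch.mm(lhs_f, rhs_f)

# Fused path: subtract zero points in i32, accumulate, rescale once.
acc = torch.mm(lhs_q.to(torch.int32) - lhs_zp,
               rhs_q.to(torch.int32) - rhs_zp)
fused = acc.to(torch.float32) * (lhs_scale * rhs_scale)

# Loose tolerances: the two paths round in different orders.
torch.testing.assert_close(ref, fused, rtol=1e-4, atol=1e-3)
```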

File tree

13 files changed: +577 -8 lines changed

externals/llvm-project

include/torch-mlir/Dialect/Torch/Transforms/Passes.h

Lines changed: 4 additions & 0 deletions
@@ -106,6 +106,10 @@ createDecomposeComplexOpsPass(ArrayRef<std::string> legalOps);
 
 std::unique_ptr<OperationPass<func::FuncOp>> createRecomposeComplexOpsPass();
 
+std::unique_ptr<OperationPass<func::FuncOp>> createFuseQuantizedOpsPass();
+std::unique_ptr<OperationPass<func::FuncOp>>
+createMatchQuantizedCustomOpsPass();
+
 std::unique_ptr<OperationPass<ModuleOp>>
 createReifyShapeCalculationsPass(StringRef extraLibrary);

include/torch-mlir/Dialect/Torch/Transforms/Passes.td

Lines changed: 28 additions & 0 deletions
@@ -258,6 +258,34 @@ def RecomposeComplexOps : Pass<"torch-recompose-complex-ops", "func::FuncOp"> {
   }];
 }
 
+def FuseQuantizedOps : Pass<"torch-fuse-quantized-ops", "func::FuncOp"> {
+  let summary = "QDQ: Fuse recognized QDQ op sequences.";
+  let constructor = "mlir::torch::Torch::createFuseQuantizedOpsPass()";
+  let description = [{
+    Torch models often represent quantized operations as the sequence:
+      Dequantize
+      DenseOp
+      Quantize
+    This allows the existing dense operations to be used without specifically
+    representing quantized types. It is more computationally efficient to
+    perform the dense operation in the quantized domain, so we fuse the
+    quantization / dequantization behavior together and represent the result
+    as purely quantized operations.
+  }];
+}
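For readers unfamiliar with the QDQ form, the sequence described above looks like this in eager PyTorch (an illustrative sketch; the shapes and quantization parameters are arbitrary):

```python
import torch

x = torch.randn(4, 8)
w = torch.randn(8, 5)

# Quantize once up front.
x_q = torch.quantize_per_tensor(x, scale=0.05, zero_point=3, dtype=torch.qint8)
w_q = torch.quantize_per_tensor(w, scale=0.10, zero_point=0, dtype=torch.qint8)

# The pattern the pass recognizes: Dequantize -> DenseOp -> Quantize.
y = torch.mm(x_q.dequantize(), w_q.dequantize())
y_q = torch.quantize_per_tensor(y, scale=0.20, zero_point=0, dtype=torch.qint8)
```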
+
+def MatchQuantizedCustomOps : Pass<"torch-match-quantized-custom-ops", "func::FuncOp"> {
+  let summary = "Match quantized operations that occur in a different namespace.";
+  let constructor = "mlir::torch::Torch::createMatchQuantizedCustomOpsPass()";
+  let description = [{
+    Torch quantization utilities generate custom-op versions of known aten
+    quantization operations. We can match these specially named operations and
+    rewrite them to the corresponding aten quantized operations.
+
+    We handle this post-import to maintain a simplified import process.
+  }];
+}
+
 def ReifyShapeCalculations : Pass<"torch-reify-shape-calculations", "ModuleOp"> {
   let summary = "Reify shape calculations.";
   let constructor = [{
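Concretely, PyTorch's PT2E-style quantization emits decomposed quantize/dequantize custom ops of the kind MatchQuantizedCustomOps targets. A hedged sketch of calling such ops directly — the `quantized_decomposed` namespace, the registering import, and the trailing `(quant_min, quant_max, dtype)` arguments are assumptions about torch.ao's decomposed-op library, not taken from this commit:

```python
import torch
import torch.ao.quantization.fx._decomposed  # registers quantized_decomposed ops (assumption)

x = torch.randn(4, 8)
# Decomposed quantize/dequantize as emitted by PT2E-style quantization.
x_q = torch.ops.quantized_decomposed.quantize_per_tensor(
    x, 0.05, 3, -128, 127, torch.int8)
x_dq = torch.ops.quantized_decomposed.dequantize_per_tensor(
    x_q, 0.05, 3, -128, 127, torch.int8)
```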

lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 43 additions & 2 deletions
@@ -29,6 +29,13 @@ using namespace mlir::torch;
 using namespace mlir::torch::Torch;
 
 namespace {
+
+static void getZeroPoint(Value value, Value &zeropoint) {
+  if (auto make = value.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
+    zeropoint = make.getZeroPoint();
+  }
+}
+
 class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -64,11 +71,27 @@ class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
         op.getSelf().getType().cast<ValueTensorType>();
     ValueTensorType rhsTorchType =
         op.getMat2().getType().cast<ValueTensorType>();
+
+    Value lhsZeroPoint, rhsZeroPoint;
+    getZeroPoint(op.getSelf(), lhsZeroPoint);
+    getZeroPoint(op.getMat2(), rhsZeroPoint);
+
+    if (static_cast<bool>(lhsZeroPoint) != static_cast<bool>(rhsZeroPoint)) {
+      return rewriter.notifyMatchFailure(
+          op, "unsupported: aten.mm with mixed quantization");
+    }
+
     if (lhsTorchType.getDtype() != rhsTorchType.getDtype()) {
       return rewriter.notifyMatchFailure(
           op, "unsupported: aten.mm with different input element types");
     }
 
+    bool isUnsigned = torch_to_linalg::isUnsignedTorchType(lhsTorchType);
+    if (lhsZeroPoint && isUnsigned) {
+      return rewriter.notifyMatchFailure(
+          op, "unsupported: unsigned quantized matmul not supported");
+    }
+
     Value lhsDim0 = rewriter.create<tensor::DimOp>(loc, lhs, 0);
     Value rhsDim1 = rewriter.create<tensor::DimOp>(loc, rhs, 1);
 
@@ -89,8 +112,26 @@ class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
         rewriter, loc, ValueRange{lhsDim0, rhsDim1}, elementType);
 
     Value matmul;
-    auto intType = dyn_cast<mlir::IntegerType>(lhsTorchType.getDtype());
-    if (intType && intType.isUnsigned()) {
+    if (lhsZeroPoint && !isUnsigned) {
+      lhsZeroPoint = typeConverter->materializeTargetConversion(
+          rewriter, loc,
+          getTypeConverter()->convertType(lhsZeroPoint.getType()),
+          lhsZeroPoint);
+      rhsZeroPoint = typeConverter->materializeTargetConversion(
+          rewriter, loc,
+          getTypeConverter()->convertType(rhsZeroPoint.getType()),
+          rhsZeroPoint);
+      lhsZeroPoint = rewriter.create<arith::TruncIOp>(
+          loc, rewriter.getI32Type(), lhsZeroPoint);
+      rhsZeroPoint = rewriter.create<arith::TruncIOp>(
+          loc, rewriter.getI32Type(), rhsZeroPoint);
+      matmul =
+          rewriter
+              .create<linalg::QuantizedMatmulOp>(
+                  loc, zeroFill.getType(),
+                  ValueRange{lhs, rhs, lhsZeroPoint, rhsZeroPoint}, zeroFill)
+              .getResult(0);
+    } else if (isUnsigned) {
       matmul = rewriter
                    .create<linalg::MatmulUnsignedOp>(
                        loc, zeroFill.getType(), ValueRange{lhs, rhs}, zeroFill)
lib/Dialect/Torch/Transforms/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -3,10 +3,12 @@ add_mlir_library(TorchMLIRTorchPasses
   DecomposeComplexOps.cpp
   DropAbstractInterpCalculations.cpp
   EraseModuleInitializer.cpp
+  FuseQuantizedOps.cpp
   Passes.cpp
   GlobalizeObjectGraph.cpp
   InlineGlobalSlots.cpp
   LowerToBackendContract.cpp
+  MatchQuantizedOps.cpp
   MaximizeValueSemantics.cpp
   PrepareForGlobalizeObjectGraph.cpp
   RecomposeComplexOps.cpp
lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp

Lines changed: 214 additions & 0 deletions
@@ -0,0 +1,214 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
+#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
+
+using namespace mlir;
+using namespace mlir::torch;
+using namespace mlir::torch::Torch;
+
+namespace {
+
+template <typename SrcOp>
+class QuantizeOperands : public OpRewritePattern<SrcOp> {
+public:
+  using OpRewritePattern<SrcOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SrcOp op,
+                                PatternRewriter &rewriter) const override {
+    llvm::SmallVector<Value> operands(op->getOperands());
+
+    bool dequanted = false;
+    for (auto &operand : operands) {
+      if (auto dequant = operand.getDefiningOp<AtenDequantizeTensorOp>()) {
+        operand = dequant.getOperand();
+        dequanted = true;
+      }
+      if (auto dequant = operand.getDefiningOp<AtenDequantizeSelfOp>()) {
+        operand = dequant.getOperand();
+        dequanted = true;
+      }
+    }
+
+    if (!dequanted) {
+      return rewriter.notifyMatchFailure(op, "no dequantizations found");
+    }
+
+    rewriter.replaceOpWithNewOp<SrcOp>(op, op.getType(), operands);
+    return success();
+  }
+};
+
+template <typename SrcOp> class QuantizeBias : public OpRewritePattern<SrcOp> {
+public:
+  using OpRewritePattern<SrcOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SrcOp op,
+                                PatternRewriter &rewriter) const override {
+    llvm::SmallVector<Value> operands(op->getOperands());
+    if (operands.size() < 3)
+      return failure();
+
+    Value bias = operands[2];
+    if (bias.getDefiningOp<AtenDequantizeTensorOp>())
+      return failure();
+
+    Value lhsScale;
+    if (auto qLhs =
+            operands[0].getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>())
+      lhsScale = qLhs.getScale();
+
+    Value rhsScale;
+    if (auto qRhs =
+            operands[1].getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>())
+      rhsScale = qRhs.getScale();
+
+    if (!rhsScale || !lhsScale)
+      return failure();
+
+    auto biasTy = bias.getType().cast<ValueTensorType>();
+    auto biasETy = biasTy.getOptionalDtype();
+    if (!biasETy || !isa<mlir::FloatType>(biasETy))
+      return failure();
+
+    Value biasScale = rewriter.create<AtenMulFloatOp>(
+        op.getLoc(), lhsScale.getType(), lhsScale, rhsScale);
+
+    Value zero = rewriter.create<Torch::ConstantIntOp>(
+        op.getLoc(), rewriter.getType<Torch::IntType>(),
+        rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
+
+    auto qi32Ty = rewriter.getType<QInt32Type>();
+    auto newBiasTy =
+        rewriter.getType<ValueTensorType>(biasTy.getOptionalSizes(), qi32Ty);
+    Value dtype = getDtypeIntValueForType(rewriter, op.getLoc(), qi32Ty);
+    bias = rewriter.create<AtenQuantizePerTensorOp>(
+        op.getLoc(), newBiasTy, bias, biasScale, zero, dtype);
+
+    operands[2] = bias;
+    rewriter.replaceOpWithNewOp<SrcOp>(op, op.getType(), operands);
+    return success();
+  }
+};
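The choice of `biasScale = lhsScale * rhsScale` above is what lets the bias fold into the integer accumulator: the i32 matmul/convolution accumulator is in units of `lhsScale * rhsScale`, so a float bias quantized with that scale and zero point 0 adds in directly. A quick numeric check of that reasoning (a sketch; values are arbitrary):

```python
import torch

s_lhs, s_rhs = 0.05, 0.1
bias = torch.randn(5)

# Quantize the bias with scale = s_lhs * s_rhs and zero point 0.
bias_q = torch.round(bias / (s_lhs * s_rhs)).to(torch.int32)

# Dequantizing with the same scale recovers the bias to within half a step.
restored = bias_q.to(torch.float32) * (s_lhs * s_rhs)
torch.testing.assert_close(restored, bias, atol=0.5 * s_lhs * s_rhs, rtol=0.0)
```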
+
+template <typename SrcOp>
+class QuantizeAccumulator : public OpRewritePattern<SrcOp> {
+public:
+  using OpRewritePattern<SrcOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SrcOp op,
+                                PatternRewriter &rewriter) const override {
+    auto lhs = op.getOperand(0);
+    auto rhs = op.getOperand(1);
+
+    auto resultTy = dyn_cast_or_null<ValueTensorType>(op.getType());
+    if (!resultTy || !resultTy.hasDtype())
+      return failure();
+
+    Type resultETy = resultTy.getDtype();
+    if (!resultETy.isa<mlir::FloatType>())
+      return failure();
+
+    Value lhsScale;
+    if (auto defining =
+            lhs.template getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
+      lhsScale = defining.getScale();
+    }
+
+    Value rhsScale;
+    if (auto defining =
+            rhs.template getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
+      rhsScale = defining.getScale();
+    }
+
+    if (!lhsScale || !rhsScale)
+      return failure();
+
+    // Build the zero point and accumulated scale for the quantized result:
+    Value zero = rewriter.create<Torch::ConstantIntOp>(
+        op.getLoc(), rewriter.getType<Torch::IntType>(),
+        rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
+
+    auto qi32Ty = rewriter.getType<QInt32Type>();
+    Value biasScale = rewriter.create<AtenMulFloatOp>(
+        op.getLoc(), lhsScale.getType(), lhsScale, rhsScale);
+
+    // Update the quantized type:
+    llvm::SmallVector<Value> operands(op.getOperands());
+
+    auto newResultTy =
+        rewriter.getType<ValueTensorType>(resultTy.getOptionalSizes(), qi32Ty);
+    auto conv = rewriter.create<SrcOp>(op.getLoc(), newResultTy, operands);
+
+    // Attach the quantization information to the resulting qint32:
+    auto intReprTy = rewriter.getType<ValueTensorType>(
+        resultTy.getOptionalSizes(),
+        rewriter.getIntegerType(32, IntegerType::Signed));
+    auto intRepr = rewriter.create<AtenIntReprOp>(op.getLoc(), intReprTy, conv);
+
+    auto quantTy =
+        rewriter.getType<ValueTensorType>(resultTy.getOptionalSizes(), qi32Ty);
+    auto quant = rewriter.create<Aten_MakePerTensorQuantizedTensorOp>(
+        op.getLoc(), quantTy, intRepr, biasScale, zero);
+    auto dequant =
+        rewriter.create<AtenDequantizeTensorOp>(op.getLoc(), resultTy, quant);
+    rewriter.replaceOp(op, dequant);
+
+    return success();
+  }
+};
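The `int_repr` / `_make_per_tensor_quantized_tensor` / `dequantize` chain above has a direct eager analogue. A sketch of what the rewritten graph computes, assuming an i32 accumulator from the quantized dense op and assuming `torch._make_per_tensor_quantized_tensor` accepts an int32 tensor (it is the eager counterpart of the aten op used above):

```python
import torch

s_lhs, s_rhs = 0.05, 0.1
# Stand-in for the i32 accumulator produced by the quantized dense op.
acc = torch.randint(-1000, 1000, (4, 5), dtype=torch.int32)

# Attach the accumulated scale (s_lhs * s_rhs, zero point 0) to the i32
# values, then dequantize back to the original float result type.
q = torch._make_per_tensor_quantized_tensor(acc, s_lhs * s_rhs, 0)
out = q.dequantize()
```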
+
+template <typename SrcOp> class RemoveUnused : public OpRewritePattern<SrcOp> {
+public:
+  using OpRewritePattern<SrcOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SrcOp op,
+                                PatternRewriter &rewriter) const override {
+    auto result = op.getResult();
+    if (result.use_empty()) {
+      op.erase();
+      return success();
+    }
+    return failure();
+  }
+};
+
+class FuseQuantizedOpsPass : public FuseQuantizedOpsBase<FuseQuantizedOpsPass> {
+public:
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(context);
+    patterns
+        .insert<RemoveUnused<AtenDequantizeSelfOp>,
+                RemoveUnused<AtenDequantizeTensorOp>,
+                RemoveUnused<AtenQuantizePerTensorOp>,
+                QuantizeOperands<AtenConvolutionOp>, QuantizeOperands<AtenMmOp>,
+                QuantizeAccumulator<AtenConvolutionOp>,
+                QuantizeAccumulator<AtenMmOp>, QuantizeBias<AtenConvolutionOp>>(
+            context);
+
+    GreedyRewriteConfig config;
+    if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
+                                            config))) {
+      return signalPassFailure();
+    }
+  }
+};
+
+} // namespace
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+mlir::torch::Torch::createFuseQuantizedOpsPass() {
+  return std::make_unique<FuseQuantizedOpsPass>();
+}
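The new pass pairs with torch-match-quantized-custom-ops: matching runs first to turn the specially named custom ops into aten quantization ops, and fusion then rewrites the resulting QDQ sequences. The pipeline wiring presumably lives in Passes.cpp, one of the 13 changed files not expanded here; for manual experimentation, something like `torch-mlir-opt --torch-match-quantized-custom-ops --torch-fuse-quantized-ops input.mlir` should exercise both passes.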
