Skip to content
This repository was archived by the owner on Apr 23, 2021. It is now read-only.

Commit c3ae3ae

Browse files
antiagainsttensorflower-gardener
authored andcommitted
NFC: Wire up DRR settings for SPIR-V canonicalization patterns
This CL added necessary files and settings for using DRR to write SPIR-V canonicalization patterns and also converted the patterns for spv.Bitcast and spv.LogicalNot. PiperOrigin-RevId: 282132786
1 parent c82c4e1 commit c3ae3ae

File tree

3 files changed

+55
-58
lines changed

3 files changed

+55
-58
lines changed

lib/Dialect/SPIRV/CMakeLists.txt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
set(LLVM_TARGET_DEFINITIONS SPIRVCanonicalization.td)
2+
mlir_tablegen(SPIRVCanonicalization.inc -gen-rewriters)
3+
add_public_tablegen_target(MLIRSPIRVCanonicalizationIncGen)
4+
15
add_llvm_library(MLIRSPIRV
26
DialectRegistration.cpp
37
LayoutUtils.cpp
@@ -11,8 +15,9 @@ add_llvm_library(MLIRSPIRV
1115
)
1216

1317
add_dependencies(MLIRSPIRV
14-
MLIRSPIRVOpsIncGen
18+
MLIRSPIRVCanonicalizationIncGen
1519
MLIRSPIRVEnumsIncGen
20+
MLIRSPIRVOpsIncGen
1621
MLIRSPIRVOpUtilsGen)
1722

1823
target_link_libraries(MLIRSPIRV
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
//==- SPIRVCanonicalization.td - Canonicalization Patterns ---*- tablegen -*==//
2+
3+
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file defines SPIR-V canonicalization patterns.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
include "mlir/Dialect/SPIRV/SPIRVOps.td"
14+
15+
//===----------------------------------------------------------------------===//
16+
// spv.Bitcast
17+
//===----------------------------------------------------------------------===//
18+
19+
def ConvertChainedBitcast : Pat<(SPV_BitcastOp (SPV_BitcastOp $operand)),
20+
(SPV_BitcastOp $operand)>;
21+
22+
//===----------------------------------------------------------------------===//
23+
// spv.LogicalNot
24+
//===----------------------------------------------------------------------===//
25+
26+
def ConvertLogicalNotOfIEqual : Pat<
27+
(SPV_LogicalNotOp (SPV_IEqualOp $lhs, $rhs)),
28+
(SPV_INotEqualOp $lhs, $rhs)>;
29+
30+
def ConvertLogicalNotOfINotEqual : Pat<
31+
(SPV_LogicalNotOp (SPV_INotEqualOp $lhs, $rhs)),
32+
(SPV_IEqualOp $lhs, $rhs)>;
33+
34+
def ConvertLogicalNotOfLogicalEqual : Pat<
35+
(SPV_LogicalNotOp (SPV_LogicalEqualOp $lhs, $rhs)),
36+
(SPV_LogicalNotEqualOp $lhs, $rhs)>;
37+
38+
def ConvertLogicalNotOfLogicalNotEqual : Pat<
39+
(SPV_LogicalNotOp (SPV_LogicalNotEqualOp $lhs, $rhs)),
40+
(SPV_LogicalEqualOp $lhs, $rhs)>;

lib/Dialect/SPIRV/SPIRVOps.cpp

Lines changed: 9 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,12 @@ static inline bool isMergeBlock(Block &block) {
377377
isa<spirv::MergeOp>(block.front());
378378
}
379379

380+
//===----------------------------------------------------------------------===//
381+
// TableGen'erated canonicalizers
382+
//===----------------------------------------------------------------------===//
383+
384+
#include "SPIRVCanonicalization.inc"
385+
380386
//===----------------------------------------------------------------------===//
381387
// Common parsers and printers
382388
//===----------------------------------------------------------------------===//
@@ -771,30 +777,6 @@ static LogicalResult verify(spirv::BitcastOp bitcastOp) {
771777
return success();
772778
}
773779

774-
namespace {
775-
776-
/// Converts chained `spirv::BitcastOp` operations into one
777-
/// `spirv::BitcastOp` operation.
778-
struct ConvertChainedBitcast : public OpRewritePattern<spirv::BitcastOp> {
779-
using OpRewritePattern<spirv::BitcastOp>::OpRewritePattern;
780-
781-
PatternMatchResult matchAndRewrite(spirv::BitcastOp bitcastOp,
782-
PatternRewriter &rewriter) const override {
783-
auto parentBitcastOp = dyn_cast_or_null<spirv::BitcastOp>(
784-
bitcastOp.operand()->getDefiningOp());
785-
786-
if (!parentBitcastOp) {
787-
return matchFailure();
788-
}
789-
790-
rewriter.replaceOpWithNewOp<spirv::BitcastOp>(
791-
/*valuesToRemoveIfDead=*/{parentBitcastOp.result()}, bitcastOp,
792-
bitcastOp.result()->getType(), parentBitcastOp.operand());
793-
return matchSuccess();
794-
}
795-
};
796-
} // end anonymous namespace
797-
798780
void spirv::BitcastOp::getCanonicalizationPatterns(
799781
OwningRewritePatternList &results, MLIRContext *context) {
800782
results.insert<ConvertChainedBitcast>(context);
@@ -1587,41 +1569,11 @@ static LogicalResult verify(spirv::LoadOp loadOp) {
15871569
// spv.LogicalNot
15881570
//===----------------------------------------------------------------------===//
15891571

1590-
namespace {
1591-
1592-
/// Converts `spirv::LogicalNotOp` to the given `NewOp` using the first and the
1593-
/// second operands from the given `ParentOp`.
1594-
template <typename NewOp, typename ParentOp>
1595-
struct ConvertLogicalNotOp : public OpRewritePattern<spirv::LogicalNotOp> {
1596-
using OpRewritePattern<spirv::LogicalNotOp>::OpRewritePattern;
1597-
1598-
PatternMatchResult matchAndRewrite(spirv::LogicalNotOp logicalNotOp,
1599-
PatternRewriter &rewriter) const override {
1600-
auto parentOp =
1601-
dyn_cast_or_null<ParentOp>(logicalNotOp.operand()->getDefiningOp());
1602-
1603-
if (!parentOp) {
1604-
return this->matchFailure();
1605-
}
1606-
1607-
rewriter.replaceOpWithNewOp<NewOp>(
1608-
/*valuesToRemoveIfDead=*/{parentOp.result()}, logicalNotOp,
1609-
logicalNotOp.result()->getType(), parentOp.operand1(),
1610-
parentOp.operand2());
1611-
1612-
return this->matchSuccess();
1613-
}
1614-
};
1615-
} // end anonymous namespace
1616-
16171572
void spirv::LogicalNotOp::getCanonicalizationPatterns(
16181573
OwningRewritePatternList &results, MLIRContext *context) {
1619-
results.insert<
1620-
ConvertLogicalNotOp<spirv::INotEqualOp, spirv::IEqualOp>,
1621-
ConvertLogicalNotOp<spirv::IEqualOp, spirv::INotEqualOp>,
1622-
ConvertLogicalNotOp<spirv::LogicalNotEqualOp, spirv::LogicalEqualOp>,
1623-
ConvertLogicalNotOp<spirv::LogicalEqualOp, spirv::LogicalNotEqualOp>>(
1624-
context);
1574+
results.insert<ConvertLogicalNotOfIEqual, ConvertLogicalNotOfINotEqual,
1575+
ConvertLogicalNotOfLogicalEqual,
1576+
ConvertLogicalNotOfLogicalNotEqual>(context);
16251577
}
16261578

16271579
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)