Skip to content
This repository was archived by the owner on Apr 23, 2021. It is now read-only.

Commit 288ac1b

Browse files
River707tensorflower-gardener
authored andcommitted
Add Ch.5 of the toy tutorial.
This chapter adds a partial lowering of toy operations, all but PrintOp, to a combination of the Affine and Std dialects. This chapter focuses on introducing the conversion framework, the benefits of partial lowering, and how easily dialects may co-exist in the IR. PiperOrigin-RevId: 275150649
1 parent 48f7ec8 commit 288ac1b

36 files changed

+1956
-2681
lines changed

examples/toy/Ch3/mlir/Dialect.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,22 +55,27 @@ static void buildConstantOp(mlir::Builder *builder, mlir::OperationState &state,
5555
ConstantOp::build(builder, state, dataType, dataAttribute);
5656
}
5757

58-
/// Verifier for constant operation.
58+
/// Verifier for the constant operation. This corresponds to the `::verify(...)`
59+
/// in the op definition.
5960
static mlir::LogicalResult verify(ConstantOp op) {
6061
// If the return type of the constant is not an unranked tensor, the shape
6162
// must match the shape of the attribute holding the data.
6263
auto resultType = op.getResult()->getType().cast<RankedTensorType>();
6364
if (!resultType)
6465
return success();
6566

67+
// Check that the rank of the attribute type matches the rank of the constant
68+
// result type.
6669
auto attrType = op.value().getType().cast<mlir::TensorType>();
6770
if (attrType.getRank() != resultType.getRank()) {
6871
return op.emitOpError(
6972
"return type must match the one of the attached value "
7073
"attribute: ")
7174
<< attrType.getRank() << " != " << resultType.getRank();
7275
}
73-
for (int dim = 0; dim < attrType.getRank(); ++dim) {
76+
77+
// Check that each of the dimensions match between the two types.
78+
for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) {
7479
if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
7580
return op.emitOpError(
7681
"return type shape mismatches its attribute at dimension ")

examples/toy/Ch4/mlir/Dialect.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,14 +118,18 @@ static mlir::LogicalResult verify(ConstantOp op) {
118118
if (!resultType)
119119
return success();
120120

121+
// Check that the rank of the attribute type matches the rank of the constant
122+
// result type.
121123
auto attrType = op.value().getType().cast<mlir::TensorType>();
122124
if (attrType.getRank() != resultType.getRank()) {
123125
return op.emitOpError(
124126
"return type must match the one of the attached value "
125127
"attribute: ")
126128
<< attrType.getRank() << " != " << resultType.getRank();
127129
}
128-
for (int dim = 0; dim < attrType.getRank(); ++dim) {
130+
131+
// Check that each of the dimensions match between the two types.
132+
for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) {
129133
if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
130134
return op.emitOpError(
131135
"return type shape mismatches its attribute at dimension ")

examples/toy/Ch4/mlir/MLIRGen.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ class MLIRGenImpl {
7979
// the structural properties of the IR and invoke any specific verifiers we
8080
// have on the Toy operations.
8181
if (failed(mlir::verify(theModule))) {
82-
theModule.emitError("Module verification error");
82+
theModule.emitError("module verification error");
8383
return nullptr;
8484
}
8585

@@ -229,7 +229,7 @@ class MLIRGenImpl {
229229
if (auto *variable = symbolTable.lookup(expr.getName()))
230230
return variable;
231231

232-
emitError(loc(expr.loc()), "Error: unknown variable '")
232+
emitError(loc(expr.loc()), "error: unknown variable '")
233233
<< expr.getName() << "'";
234234
return nullptr;
235235
}
@@ -289,7 +289,8 @@ class MLIRGenImpl {
289289
auto dataAttribute =
290290
mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(data));
291291

292-
// Build the MLIR op `toy.constant`.
292+
// Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build`
293+
// method.
293294
return builder.create<ConstantOp>(loc(lit.loc()), type, dataAttribute);
294295
}
295296

@@ -389,7 +390,7 @@ class MLIRGenImpl {
389390
auto init = vardecl.getInitVal();
390391
if (!init) {
391392
emitError(loc(vardecl.loc()),
392-
"Missing initializer in variable declaration");
393+
"missing initializer in variable declaration");
393394
return nullptr;
394395
}
395396

examples/toy/Ch5/CMakeLists.txt

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,42 @@
1+
add_subdirectory(include)
2+
13
set(LLVM_LINK_COMPONENTS
24
Support
35
)
46

7+
set(LLVM_TARGET_DEFINITIONS mlir/ToyCombine.td)
8+
mlir_tablegen(ToyCombine.inc -gen-rewriters "-I${CMAKE_CURRENT_SOURCE_DIR}/include")
9+
add_public_tablegen_target(ToyCh5CombineIncGen)
10+
511
add_toy_chapter(toyc-ch5
612
toyc.cpp
713
parser/AST.cpp
8-
mlir/EarlyLowering.cpp
9-
mlir/LateLowering.cpp
1014
mlir/MLIRGen.cpp
15+
mlir/Dialect.cpp
16+
mlir/DeadFunctionEliminationPass.cpp
17+
mlir/LowerToAffineLoops.cpp
1118
mlir/ShapeInferencePass.cpp
12-
mlir/ToyDialect.cpp
1319
mlir/ToyCombine.cpp
1420
)
21+
22+
add_dependencies(toyc-ch5 ToyCh5ShapeInferenceInterfaceIncGen)
23+
add_dependencies(toyc-ch5 ToyCh5OpsIncGen)
24+
add_dependencies(toyc-ch5 ToyCh5CombineIncGen)
25+
add_dependencies(toyc-ch5 MLIRCallOpInterfacesIncGen)
1526
include_directories(include/)
16-
include_directories(../../Linalg/Linalg1/include/)
17-
include_directories(../../Linalg/Linalg2/include/)
18-
include_directories(../../Linalg/Linalg3/include/)
27+
include_directories(${CMAKE_CURRENT_BINARY_DIR})
28+
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/)
1929
target_link_libraries(toyc-ch5
2030
PRIVATE
21-
Linalg3DialectConstruction
22-
Linalg3
23-
Linalg2
24-
Linalg1
31+
MLIRAffineOps
2532
MLIRAnalysis
26-
MLIREDSC
27-
MLIRExecutionEngine
2833
MLIRIR
29-
MLIRLLVMIR
3034
MLIRParser
3135
MLIRPass
32-
MLIRTargetLLVMIR
33-
MLIRTransforms
34-
MLIRSupport
35-
)
36+
MLIRStandardOps
37+
MLIRTransforms)
38+
3639
whole_archive_link(toyc-ch5
3740
MLIRAffineOps
3841
MLIRStandardOps
39-
)
40-
42+
)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
add_subdirectory(toy)

examples/toy/Ch5/include/toy/AST.h

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,9 @@
3333

3434
namespace toy {
3535

36-
/// A variable
36+
/// A variable type with shape information.
3737
struct VarType {
38-
enum { TY_FLOAT, TY_INT } elt_ty;
39-
std::vector<int> shape;
38+
std::vector<int64_t> shape;
4039
};
4140

4241
/// Base class for all expression nodes.
@@ -50,9 +49,7 @@ class ExprAST {
5049
Expr_Var,
5150
Expr_BinOp,
5251
Expr_Call,
53-
Expr_Print, // builtin
54-
Expr_If,
55-
Expr_For,
52+
Expr_Print,
5653
};
5754

5855
ExprAST(ExprASTKind kind, Location location)
@@ -85,7 +82,7 @@ class NumberExprAST : public ExprAST {
8582
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Num; }
8683
};
8784

88-
///
85+
/// Expression class for a literal value.
8986
class LiteralExprAST : public ExprAST {
9087
std::vector<std::unique_ptr<ExprAST>> values;
9188
std::vector<int64_t> dims;
@@ -116,7 +113,7 @@ class VariableExprAST : public ExprAST {
116113
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; }
117114
};
118115

119-
///
116+
/// Expression class for defining a variable.
120117
class VarDeclExprAST : public ExprAST {
121118
std::string name;
122119
VarType type;
@@ -136,7 +133,7 @@ class VarDeclExprAST : public ExprAST {
136133
static bool classof(const ExprAST *C) { return C->getKind() == Expr_VarDecl; }
137134
};
138135

139-
///
136+
/// Expression class for a return operator.
140137
class ReturnExprAST : public ExprAST {
141138
llvm::Optional<std::unique_ptr<ExprAST>> expr;
142139

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
set(LLVM_TARGET_DEFINITIONS Ops.td)
2+
mlir_tablegen(Ops.h.inc -gen-op-decls "-I${CMAKE_CURRENT_SOURCE_DIR}/..")
3+
mlir_tablegen(Ops.cpp.inc -gen-op-defs "-I${CMAKE_CURRENT_SOURCE_DIR}/..")
4+
add_public_tablegen_target(ToyCh5OpsIncGen)
5+
6+
set(LLVM_TARGET_DEFINITIONS ShapeInferenceInterface.td)
7+
mlir_tablegen(ShapeInferenceOpInterfaces.h.inc -gen-op-interface-decls)
8+
mlir_tablegen(ShapeInferenceOpInterfaces.cpp.inc -gen-op-interface-defs)
9+
add_public_tablegen_target(ToyCh5ShapeInferenceInterfaceIncGen)

0 commit comments

Comments
 (0)