Skip to content

Commit 0592834

Browse files
authored
[Feat] Add A Pass to Handle Negative Index (#1192)
1 parent 777881e commit 0592834

File tree

4 files changed

+233
-0
lines changed

4 files changed

+233
-0
lines changed
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
/*!
2+
* \file legalize_negative_index.cc
3+
* \brief Legalize negative indices in buffer load expressions.
4+
*/
5+
6+
#include <tvm/ffi/reflection/registry.h>
7+
#include <tvm/runtime/logging.h>
8+
#include <tvm/tir/stmt_functor.h>
9+
#include <tvm/tir/transform.h>
10+
11+
#include <unordered_map>
12+
#include <vector>
13+
14+
#include "arith/ir_mutator_with_analyzer.h"
15+
#include "arith/ir_visitor_with_analyzer.h"
16+
17+
namespace tvm {
18+
namespace tl {
19+
20+
using namespace tir;
21+
using arith::IRVisitorWithAnalyzer;
22+
23+
enum class IndexSignState { kNonNegative, kNegative, kUnknown };
24+
25+
class NegativeIndexAnalyzer : public IRVisitorWithAnalyzer {
26+
public:
27+
explicit NegativeIndexAnalyzer(
28+
std::unordered_map<const BufferLoadNode *, std::vector<IndexSignState>>
29+
*result)
30+
: result_(result) {}
31+
32+
void VisitExpr_(const BufferLoadNode *op) final {
33+
auto load = tvm::ffi::GetRef<BufferLoad>(op);
34+
std::vector<IndexSignState> states;
35+
states.reserve(op->indices.size());
36+
bool needs_record = false;
37+
38+
for (size_t i = 0; i < op->indices.size(); ++i) {
39+
PrimExpr simplified = analyzer_.Simplify(op->indices[i]);
40+
if (analyzer_.CanProve(simplified >= 0)) {
41+
states.push_back(IndexSignState::kNonNegative);
42+
continue;
43+
}
44+
45+
if (analyzer_.CanProve(simplified < 0)) {
46+
states.push_back(IndexSignState::kNegative);
47+
needs_record = true;
48+
continue;
49+
}
50+
51+
states.push_back(IndexSignState::kUnknown);
52+
needs_record = true;
53+
LOG(WARNING) << "LegalizeNegativeIndex: cannot prove non-negative index "
54+
<< simplified << " for buffer " << load->buffer->name
55+
<< " (axis " << i << ").";
56+
}
57+
58+
if (needs_record) {
59+
(*result_)[op] = std::move(states);
60+
}
61+
62+
IRVisitorWithAnalyzer::VisitExpr_(op);
63+
}
64+
65+
private:
66+
std::unordered_map<const BufferLoadNode *, std::vector<IndexSignState>>
67+
*result_;
68+
};
69+
70+
class NegativeIndexRewriter : public arith::IRMutatorWithAnalyzer {
71+
public:
72+
static PrimFunc
73+
Apply(PrimFunc func,
74+
const std::unordered_map<const BufferLoadNode *,
75+
std::vector<IndexSignState>> &states) {
76+
arith::Analyzer analyzer;
77+
NegativeIndexRewriter rewriter(&analyzer, states);
78+
if (!func->body.defined()) {
79+
return func;
80+
}
81+
PrimFuncNode *func_node = func.CopyOnWrite();
82+
func_node->body = rewriter.VisitStmt(func_node->body);
83+
return func;
84+
}
85+
86+
private:
87+
NegativeIndexRewriter(
88+
arith::Analyzer *analyzer,
89+
const std::unordered_map<const BufferLoadNode *,
90+
std::vector<IndexSignState>> &states)
91+
: arith::IRMutatorWithAnalyzer(analyzer), states_(states) {}
92+
93+
PrimExpr VisitExpr_(const BufferLoadNode *op) final {
94+
BufferLoad load =
95+
Downcast<BufferLoad>(arith::IRMutatorWithAnalyzer::VisitExpr_(op));
96+
97+
auto it = states_.find(op);
98+
if (it == states_.end()) {
99+
return load;
100+
}
101+
102+
auto indices = load->indices;
103+
bool changed = false;
104+
105+
const auto &state_vector = it->second;
106+
ICHECK_EQ(state_vector.size(), indices.size())
107+
<< "State vector size mismatch for buffer load " << load->buffer->name;
108+
109+
for (size_t i = 0; i < indices.size(); ++i) {
110+
if (state_vector[i] != IndexSignState::kNegative) {
111+
continue;
112+
}
113+
PrimExpr extent = load->buffer->shape[i];
114+
indices.Set(i, analyzer_->Simplify(extent + indices[i]));
115+
changed = true;
116+
}
117+
118+
if (!changed) {
119+
return load;
120+
}
121+
122+
return BufferLoad(load->buffer, indices);
123+
}
124+
125+
const std::unordered_map<const BufferLoadNode *, std::vector<IndexSignState>>
126+
&states_;
127+
};
128+
129+
PrimFunc LegalizeNegativeIndex(PrimFunc func) {
130+
if (!func->body.defined()) {
131+
return func;
132+
}
133+
134+
std::unordered_map<const BufferLoadNode *, std::vector<IndexSignState>>
135+
states;
136+
NegativeIndexAnalyzer analyzer(&states);
137+
analyzer(func->body);
138+
if (states.empty()) {
139+
return func;
140+
}
141+
142+
return NegativeIndexRewriter::Apply(std::move(func), states);
143+
}
144+
145+
tvm::transform::Pass LegalizeNegativeIndexPass() {
146+
using namespace tir::transform;
147+
auto pass_func = [](PrimFunc f, const IRModule &, PassContext) {
148+
return LegalizeNegativeIndex(std::move(f));
149+
};
150+
return CreatePrimFuncPass(pass_func, 0, "tl.LegalizeNegativeIndex", {});
151+
}
152+
153+
TVM_FFI_STATIC_INIT_BLOCK() {
154+
namespace refl = tvm::ffi::reflection;
155+
refl::GlobalDef().def("tl.transform.LegalizeNegativeIndex",
156+
LegalizeNegativeIndexPass);
157+
}
158+
159+
} // namespace tl
160+
} // namespace tvm
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from tilelang import tvm
2+
import tilelang as tl
3+
import tilelang.testing
4+
from tvm.script import tir as T
5+
6+
7+
@T.prim_func
8+
def negative_index_before(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
9+
T.func_attr({"tir.noalias": True})
10+
B[0] = A[T.int32(-1)]
11+
12+
13+
@T.prim_func
14+
def negative_index_expected(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
15+
T.func_attr({"tir.noalias": True})
16+
B[0] = A[T.int32(15)]
17+
18+
19+
@T.prim_func
20+
def negative_index_loop_before(A: T.Buffer((16,), "float32"), B: T.Buffer((4,), "float32")):
21+
T.func_attr({"tir.noalias": True})
22+
for i in T.serial(4):
23+
B[i] = A[-i - 1]
24+
25+
26+
@T.prim_func
27+
def negative_index_loop_expected(A: T.Buffer((16,), "float32"), B: T.Buffer((4,), "float32")):
28+
T.func_attr({"tir.noalias": True})
29+
for i in T.serial(4):
30+
B[i] = A[15 - i]
31+
32+
33+
@T.prim_func
34+
def negative_index_symbolic_before(shift: T.int32, A: T.Buffer((16,), "float32"),
35+
B: T.Buffer((16,), "float32")):
36+
T.func_attr({"tir.noalias": True})
37+
for i in T.serial(16):
38+
B[i] = A[shift + i]
39+
40+
41+
def test_legalize_negative_index_scalar():
42+
mod = tvm.IRModule({"main": negative_index_before})
43+
transformed = tl.transform.LegalizeNegativeIndex()(mod)
44+
tvm.ir.assert_structural_equal(transformed["main"].body, negative_index_expected.body)
45+
46+
47+
def test_legalize_negative_index_affine_expr():
48+
mod = tvm.IRModule({"main": negative_index_loop_before})
49+
transformed = tl.transform.LegalizeNegativeIndex()(mod)
50+
tvm.ir.assert_structural_equal(transformed["main"].body, negative_index_loop_expected.body)
51+
52+
53+
def test_legalize_negative_index_symbolic_passthrough():
54+
mod = tvm.IRModule({"main": negative_index_symbolic_before})
55+
transformed = tl.transform.LegalizeNegativeIndex()(mod)
56+
tvm.ir.assert_structural_equal(transformed["main"].body, negative_index_symbolic_before.body)
57+
58+
59+
if __name__ == "__main__":
60+
tilelang.testing.main()

tilelang/engine/phase.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
9696
mod = tilelang.transform.LetInline()(mod)
9797
# Add wrapper for single buf store
9898
mod = tilelang.transform.AddWrapperForSingleBufStore()(mod)
99+
# Normalize negative indices to canonical non-negative form
100+
mod = tilelang.transform.LegalizeNegativeIndex()(mod)
99101
# Inject assumes to speedup tvm prover
100102
mod = tilelang.transform.InjectAssumes()(mod)
101103
# Simplify the IR expressions

tilelang/transform/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,17 @@ def FrontendLegalize():
8080
return _ffi_api.FrontendLegalize() # type: ignore
8181

8282

83+
def LegalizeNegativeIndex():
84+
"""Legalize negative indices in buffer loads.
85+
86+
Returns
87+
-------
88+
fpass : tvm.transform.Pass
89+
The result pass
90+
"""
91+
return _ffi_api.LegalizeNegativeIndex() # type: ignore
92+
93+
8394
def InjectAssumes():
8495
"""Inject Assumes
8596

0 commit comments

Comments
 (0)