Skip to content

Commit 5f89e14

Browse files
committed
wip
1 parent c507644 commit 5f89e14

File tree

1 file changed

+46
-27
lines changed

1 file changed

+46
-27
lines changed

csrc/id_model/id_model.cpp

Lines changed: 46 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -974,34 +974,53 @@ void IdModel::initializeLoopGraph(const StatefulInliningInfo& info) {
974974

975975
// loop_ids only contains at IDs between logical->loop
976976
VectorOfUniqueEntries<IterDomain*> loop_ids;
977-
for (TensorView* tv : tvs_) {
978-
loop_ids.pushBack(tv->getLogicalDomain());
979-
loop_ids.pushBack(tv->getLoopDomain());
980-
// TODO: put this into TensorDomain instead like TensorDomain::allIDs()
981-
auto path =
982-
getExprsBetween<IRBFS>(
983-
{tv->getLogicalDomain().begin(), tv->getLogicalDomain().end()},
984-
{tv->getLoopDomain().begin(), tv->getLoopDomain().end()},
985-
false)
986-
.first;
987-
for (auto [expr, _] : path) {
988-
loop_ids.pushBack(ir_utils::filterByType<IterDomain>(expr->outputs()));
989-
loop_ids.pushBack(ir_utils::filterByType<IterDomain>(expr->inputs()));
990-
}
991977

992-
if (tv->hasRoot()) {
993-
loop_ids.pushBack(tv->getRootDomain());
994-
auto path =
995-
getExprsBetween<IRBFS>(
996-
{tv->getRootDomain().begin(), tv->getRootDomain().end()},
997-
{tv->getLogicalDomain().begin(), tv->getLogicalDomain().end()},
998-
false)
999-
.first;
1000-
for (auto [expr, _] : path) {
1001-
loop_ids.pushBack(ir_utils::filterByType<IterDomain>(expr->outputs()));
1002-
loop_ids.pushBack(ir_utils::filterByType<IterDomain>(expr->inputs()));
1003-
}
1004-
}
978+
auto all_ids_except_allocation = [&loop_ids](TensorView* tv) {
979+
std::vector<const std::vector<IterDomain*>*> all_domains = {
980+
&tv->getLoopDomain(),
981+
&tv->getLogicalDomain(),
982+
&tv->getInitialLoopDomain(),
983+
&tv->domain()->additionalIDs()};
984+
if (tv->hasRoot()) {
985+
all_domains.push_back(&tv->getRootDomain());
986+
}
987+
if (tv->getAlternateLoopDomain().has_value()) {
988+
all_domains.push_back(&tv->getAlternateLoopDomain().value());
989+
}
990+
991+
for (auto domain : all_domains) {
992+
loop_ids.pushBack(*domain);
993+
}
994+
995+
// We only care about IDs on the shortest path between domains
996+
std::unordered_multimap<IterDomain*, IterDomain*> out2in;
997+
for (auto i : arange(all_domains.size() - 1)) {
998+
if (all_domains[i]->empty()) {
999+
continue;
1000+
}
1001+
for (auto j : arange(i + 1, all_domains.size())) {
1002+
if (all_domains[j]->empty()) {
1003+
continue;
1004+
}
1005+
auto path = getExprsBetween<IRBFS>(
1006+
{all_domains[i]->begin(), all_domains[i]->end()},
1007+
{all_domains[j]->begin(), all_domains[j]->end()},
1008+
false)
1009+
.first;
1010+
for (auto [expr, _] : path) {
1011+
loop_ids.pushBack(
1012+
ir_utils::filterByType<IterDomain>(expr->outputs()));
1013+
loop_ids.pushBack(
1014+
ir_utils::filterByType<IterDomain>(expr->inputs()));
1015+
}
1016+
}
1017+
}
1018+
return loop_ids.vector();
1019+
};
1020+
1021+
1022+
for (TensorView* tv : tvs_) {
1023+
all_ids_except_allocation(tv);
10051024
}
10061025
std::vector<IterDomain*> all_ids = loop_ids.vector();
10071026

0 commit comments

Comments
 (0)