@@ -974,34 +974,53 @@ void IdModel::initializeLoopGraph(const StatefulInliningInfo& info) {
974974
975975 // loop_ids only contains IDs between logical->loop
976976 VectorOfUniqueEntries<IterDomain*> loop_ids;
977- for (TensorView* tv : tvs_) {
978- loop_ids.pushBack (tv->getLogicalDomain ());
979- loop_ids.pushBack (tv->getLoopDomain ());
980- // TODO: put this into TensorDomain instead like TensorDomain::allIDs()
981- auto path =
982- getExprsBetween<IRBFS>(
983- {tv->getLogicalDomain ().begin (), tv->getLogicalDomain ().end ()},
984- {tv->getLoopDomain ().begin (), tv->getLoopDomain ().end ()},
985- false )
986- .first ;
987- for (auto [expr, _] : path) {
988- loop_ids.pushBack (ir_utils::filterByType<IterDomain>(expr->outputs ()));
989- loop_ids.pushBack (ir_utils::filterByType<IterDomain>(expr->inputs ()));
990- }
991977
992- if (tv->hasRoot ()) {
993- loop_ids.pushBack (tv->getRootDomain ());
994- auto path =
995- getExprsBetween<IRBFS>(
996- {tv->getRootDomain ().begin (), tv->getRootDomain ().end ()},
997- {tv->getLogicalDomain ().begin (), tv->getLogicalDomain ().end ()},
998- false )
999- .first ;
1000- for (auto [expr, _] : path) {
1001- loop_ids.pushBack (ir_utils::filterByType<IterDomain>(expr->outputs ()));
1002- loop_ids.pushBack (ir_utils::filterByType<IterDomain>(expr->inputs ()));
1003- }
1004- }
978+ auto all_ids_except_allocation = [&loop_ids](TensorView* tv) { // Appends tv's IterDomains (domain members plus IDs on exprs found between domain pairs) to the captured loop_ids.
979+ std::vector<const std::vector<IterDomain*>*> all_domains = { // Pointers to the domains to collect; no allocation domain is listed here (hence the lambda's name).
980+ &tv->getLoopDomain (),
981+ &tv->getLogicalDomain (),
982+ &tv->getInitialLoopDomain (),
983+ &tv->domain ()->additionalIDs ()};
984+ if (tv->hasRoot ()) { // Root domain is only added when the tensor has one.
985+ all_domains.push_back (&tv->getRootDomain ());
986+ }
987+ if (tv->getAlternateLoopDomain ().has_value ()) { // Alternate loop domain is optional; add it only when set.
988+ all_domains.push_back (&tv->getAlternateLoopDomain ().value ());
989+ }
990+
991+ for (auto domain : all_domains) {
992+ loop_ids.pushBack (*domain); // Seed with the members of every listed domain.
993+ }
994+
995+ // We only care about IDs on the shortest path between domains
996+ std::unordered_multimap<IterDomain*, IterDomain*> out2in; // NOTE(review): never read or written in this lambda; looks like a leftover, confirm and remove.
997+ for (auto i : arange (all_domains.size () - 1 )) { // size() - 1 cannot underflow: all_domains always holds at least the four entries from its initializer.
998+ if (all_domains[i]->empty ()) { // Skip empty domains; no traversal is started from them.
999+ continue ;
1000+ }
1001+ for (auto j : arange (i + 1 , all_domains.size ())) { // Presumably a half-open range, so each pair (i, j) with i < j is visited once; confirm against arange.
1002+ if (all_domains[j]->empty ()) {
1003+ continue ;
1004+ }
1005+ auto path = getExprsBetween<IRBFS>( // Traverse from domain i to domain j; only the first element of the returned pair is kept.
1006+ {all_domains[i]->begin (), all_domains[i]->end ()},
1007+ {all_domains[j]->begin (), all_domains[j]->end ()},
1008+ false ) // NOTE(review): the meaning of this flag is not visible here; check getExprsBetween.
1009+ .first ;
1010+ for (auto [expr, _] : path) { // Second element of each path entry is unused; each entry is copied by this binding.
1011+ loop_ids.pushBack (
1012+ ir_utils::filterByType<IterDomain>(expr->outputs ())); // Outputs are inserted before inputs; this fixes the insertion order into loop_ids.
1013+ loop_ids.pushBack (
1014+ ir_utils::filterByType<IterDomain>(expr->inputs ()));
1015+ }
1016+ }
1017+ }
1018+ return loop_ids.vector (); // Returns by value (deduced return type) everything accumulated in loop_ids so far, not just this tv's IDs; the visible caller discards it.
1019+ };
1020+
1021+
1022+ for (TensorView* tv : tvs_) {
1023+ all_ids_except_allocation (tv);
10051024 }
10061025 std::vector<IterDomain*> all_ids = loop_ids.vector ();
10071026
0 commit comments