-
Notifications
You must be signed in to change notification settings - Fork 295
Support gds var length label pruning #4868
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
| #include "binder/query/query_graph_label_analyzer.h" | ||
|
|
||
| #include "catalog/catalog.h" | ||
| #include "catalog/catalog_entry/node_table_catalog_entry.h" | ||
| #include "catalog/catalog_entry/rel_table_catalog_entry.h" | ||
| #include "common/exception/binder.h" | ||
| #include "common/string_format.h" | ||
|
|
@@ -29,7 +30,7 @@ void QueryGraphLabelAnalyzer::pruneNode(const QueryGraph& graph, NodeExpression& | |
| if (queryRel->isRecursive()) { | ||
| continue; | ||
| } | ||
| common::table_id_set_t candidates; | ||
| table_id_set_t candidates; | ||
| std::unordered_set<std::string> candidateNamesSet; | ||
| auto isSrcConnect = *queryRel->getSrcNode() == node; | ||
| auto isDstConnect = *queryRel->getDstNode() == node; | ||
|
|
@@ -94,49 +95,305 @@ void QueryGraphLabelAnalyzer::pruneNode(const QueryGraph& graph, NodeExpression& | |
| } | ||
| } | ||
|
|
||
| void QueryGraphLabelAnalyzer::pruneRel(RelExpression& rel) const { | ||
| if (rel.isRecursive()) { | ||
| return; | ||
| } | ||
| std::vector<TableCatalogEntry*> QueryGraphLabelAnalyzer::pruneNonRecursiveRel( | ||
| const std::vector<TableCatalogEntry*>& relEntries, const table_id_set_t& srcTableIDSet, | ||
| const table_id_set_t& dstTableIDSet, const RelDirectionType directionType) const { | ||
|
|
||
| auto forwardPruningFunc = [&](table_id_t srcTableID, table_id_t dstTableID) { | ||
| return srcTableIDSet.contains(srcTableID) && dstTableIDSet.contains(dstTableID); | ||
| }; | ||
| auto backwardPruningFunc = [&](table_id_t srcTableID, table_id_t dstTableID) { | ||
| return dstTableIDSet.contains(srcTableID) && srcTableIDSet.contains(dstTableID); | ||
| }; | ||
| std::vector<TableCatalogEntry*> prunedEntries; | ||
| if (rel.getDirectionType() == RelDirectionType::BOTH) { | ||
| table_id_set_t srcBoundTableIDSet; | ||
| table_id_set_t dstBoundTableIDSet; | ||
| for (auto entry : rel.getSrcNode()->getEntries()) { | ||
| srcBoundTableIDSet.insert(entry->getTableID()); | ||
| for (auto& entry : relEntries) { | ||
| auto& relEntry = entry->constCast<RelTableCatalogEntry>(); | ||
| auto srcTableID = relEntry.getSrcTableID(); | ||
| auto dstTableID = relEntry.getDstTableID(); | ||
| auto satisfyForwardPruning = forwardPruningFunc(srcTableID, dstTableID); | ||
| if (directionType == RelDirectionType::BOTH) { | ||
| if (satisfyForwardPruning || backwardPruningFunc(srcTableID, dstTableID)) { | ||
| prunedEntries.push_back(entry); | ||
| } | ||
| } else { | ||
| if (satisfyForwardPruning) { | ||
| prunedEntries.push_back(entry); | ||
| } | ||
| } | ||
| } | ||
| return prunedEntries; | ||
| } | ||
|
|
||
| table_id_set_t QueryGraphLabelAnalyzer::collectRelNodes(const RelDataDirection direction, | ||
| std::vector<TableCatalogEntry*> relEntries) const { | ||
| table_id_set_t nodeIDs; | ||
| for (const auto& entry : relEntries) { | ||
| const auto& relEntry = entry->constCast<RelTableCatalogEntry>(); | ||
| if (direction == RelDataDirection::FWD) { | ||
| nodeIDs.insert(relEntry.getDstTableID()); | ||
| } else if (direction == RelDataDirection::BWD) { | ||
| nodeIDs.insert(relEntry.getSrcTableID()); | ||
| } else { | ||
| KU_UNREACHABLE; | ||
| } | ||
| for (auto entry : rel.getDstNode()->getEntries()) { | ||
| dstBoundTableIDSet.insert(entry->getTableID()); | ||
| } | ||
| return nodeIDs; | ||
| } | ||
|
|
||
| std::pair<std::vector<table_id_set_t>, std::vector<table_id_set_t>> | ||
| QueryGraphLabelAnalyzer::pruneRecursiveRel(const std::vector<TableCatalogEntry*>& relEntries, | ||
| const table_id_set_t srcTableIDSet, const table_id_set_t dstTableIDSet, size_t lowerBound, | ||
| size_t upperBound, RelDirectionType relDirectionType) const { | ||
| // src-->[dst,[rels]] | ||
| std::unordered_map<table_id_t, std::unordered_map<table_id_t, table_id_vector_t>> | ||
| stepFromLeftGraph; | ||
| std::unordered_map<table_id_t, std::unordered_map<table_id_t, table_id_vector_t>> | ||
| stepFromRightGraph; | ||
| table_id_t maxTableID = 0; | ||
| for (auto entry : relEntries) { | ||
| auto& relEntry = entry->constCast<RelTableCatalogEntry>(); | ||
| auto srcTableID = relEntry.getSrcTableID(); | ||
| auto dstTableID = relEntry.getDstTableID(); | ||
| auto tableID = relEntry.getTableID(); | ||
| stepFromLeftGraph[srcTableID][dstTableID].push_back(tableID); | ||
| stepFromRightGraph[dstTableID][srcTableID].push_back(tableID); | ||
| if (relDirectionType == RelDirectionType::BOTH) { | ||
| stepFromLeftGraph[dstTableID][srcTableID].push_back(tableID); | ||
| stepFromRightGraph[srcTableID][dstTableID].push_back(tableID); | ||
| } | ||
| for (auto& entry : rel.getEntries()) { | ||
| auto& relEntry = entry->constCast<RelTableCatalogEntry>(); | ||
| auto srcTableID = relEntry.getSrcTableID(); | ||
| auto dstTableID = relEntry.getDstTableID(); | ||
| if ((srcBoundTableIDSet.contains(srcTableID) && | ||
| dstBoundTableIDSet.contains(dstTableID)) || | ||
| (dstBoundTableIDSet.contains(srcTableID) && | ||
| srcBoundTableIDSet.contains(dstTableID))) { | ||
| prunedEntries.push_back(entry); | ||
| maxTableID = std::max(maxTableID, tableID); | ||
| } | ||
|
|
||
| auto stepFromLeft = pruneRecursiveRel(stepFromLeftGraph, stepFromRightGraph, srcTableIDSet, | ||
| dstTableIDSet, lowerBound, upperBound, maxTableID); | ||
| auto stepFromRight = pruneRecursiveRel(stepFromRightGraph, stepFromLeftGraph, dstTableIDSet, | ||
| srcTableIDSet, lowerBound, upperBound, maxTableID); | ||
| return {stepFromLeft, stepFromRight}; | ||
| } | ||
|
|
||
| std::vector<table_id_set_t> QueryGraphLabelAnalyzer::pruneRecursiveRel( | ||
| const std::unordered_map<table_id_t, std::unordered_map<table_id_t, table_id_vector_t>>& graph, | ||
| const std::unordered_map<table_id_t, std::unordered_map<table_id_t, table_id_vector_t>>& | ||
| reseveGraph, | ||
| const table_id_set_t& startTableIDSet, const table_id_set_t& endTableIDSet, size_t lowerBound, | ||
| size_t upperBound, table_id_t maxTableID) const { | ||
|
|
||
| // f[i][j] represent whether the edge numbered i can be reached by jumping j times through the | ||
| // set A. | ||
| std::unordered_map<table_id_t, std::vector<bool>> f, g; | ||
|
|
||
| auto initFunc = [upperBound](const std::unordered_map<table_id_t, | ||
| std::unordered_map<table_id_t, table_id_vector_t>>& _graph, | ||
| std::unordered_map<table_id_t, std::vector<bool>>& ans, | ||
| const table_id_set_t& beginTableIDSet) { | ||
| for (auto [_, map] : _graph) { | ||
| for (auto [_, rels] : map) { | ||
| for (auto rel : rels) { | ||
| ans.emplace(rel, std::vector<bool>(upperBound + 1, false)); | ||
| } | ||
| } | ||
| } | ||
| } else { | ||
| auto srcTableIDSet = rel.getSrcNode()->getTableIDsSet(); | ||
| auto dstTableIDSet = rel.getDstNode()->getTableIDsSet(); | ||
| for (auto& entry : rel.getEntries()) { | ||
| auto& relEntry = entry->constCast<RelTableCatalogEntry>(); | ||
| auto srcTableID = relEntry.getSrcTableID(); | ||
| auto dstTableID = relEntry.getDstTableID(); | ||
| if (!srcTableIDSet.contains(srcTableID) || !dstTableIDSet.contains(dstTableID)) { | ||
|
|
||
| for (auto tableID : beginTableIDSet) { | ||
| if (!_graph.contains(tableID)) { | ||
| continue; | ||
| } | ||
| prunedEntries.push_back(entry); | ||
| for (auto [dst, rels] : _graph.at(tableID)) { | ||
| for (auto rel : rels) { | ||
| ans[rel][1] = true; | ||
| ans[rel][0] = true; | ||
| } | ||
| } | ||
| } | ||
| }; | ||
|
|
||
| initFunc(graph, f, startTableIDSet); | ||
| initFunc(reseveGraph, g, endTableIDSet); | ||
|
|
||
| auto isOk = [&](const table_id_vector_t& rels, int j, | ||
| std::unordered_map<table_id_t, std::vector<bool>>& map) -> bool { | ||
| for (auto rel : rels) { | ||
| if (map[rel][j - 1]) { | ||
| return true; | ||
| } | ||
| } | ||
| return false; | ||
| }; | ||
|
|
||
| auto bfsFunc = | ||
| [upperBound, maxTableID, isOk]( | ||
| const std::unordered_map<table_id_t, std::unordered_map<table_id_t, table_id_vector_t>>& | ||
| _graph, | ||
| const std::unordered_map<table_id_t, std::unordered_map<table_id_t, table_id_vector_t>>& | ||
| _reseveGraph, | ||
| std::unordered_map<table_id_t, std::vector<bool>>& map) { | ||
| for (int j = 2; j <= upperBound; ++j) { | ||
| for (auto v = 0u; v < maxTableID; ++v) { | ||
| bool flag = false; | ||
| if (_reseveGraph.contains(v)) { | ||
| for (auto [_, rels] : _reseveGraph.at(v)) { | ||
| if (isOk(rels, j, map)) { | ||
| flag = true; | ||
| break; | ||
| } | ||
| } | ||
| } | ||
|
|
||
| if (flag && _graph.contains(v)) { | ||
| for (auto [dst, rels] : _graph.at(v)) { | ||
| for (auto rel : rels) { | ||
| map[rel][j] = true; | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| }; | ||
|
|
||
| bfsFunc(graph, reseveGraph, f); | ||
| bfsFunc(reseveGraph, graph, g); | ||
|
|
||
| std::vector<table_id_set_t> stepActiveTableIDs(upperBound); | ||
| for (auto [rel, vector] : f) { | ||
| for (int j = 0; j <= upperBound; ++j) { | ||
| if (!vector[j]) { | ||
| continue; | ||
| } | ||
| for (int k = 0; k <= upperBound; ++k) { | ||
| if (!g[rel][k]) { | ||
| continue; | ||
| } | ||
| auto step = j + k; | ||
| if (step != upperBound) { | ||
| // rel repeat count | ||
| step--; | ||
| } | ||
| if (step < lowerBound) { | ||
| continue; | ||
| } else if (step > upperBound) { | ||
| break; | ||
| } else { | ||
| int index = j == 0 ? 0 : j - 1; | ||
| stepActiveTableIDs[index].emplace(rel); | ||
| break; | ||
| } | ||
| } | ||
| } | ||
| } | ||
| rel.setEntries(prunedEntries); | ||
| return stepActiveTableIDs; | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let me try to summarize the logic here, and please correct me if I'm wrong. Consider topology […]. WLOG, we assume going from left to right. The […] and […] because […]. So eventually we only compute […]. If my understanding is correct, can you also give a benchmark number in the PR description to make sure we actually get a performance benefit? |
||
| } | ||
|
|
||
| // Maps each table ID in `tableIDs` to the entry returned by | ||
| // `catalog->getTableCatalogEntry(tx, id)`; results follow the set's iteration order. | ||
| std::vector<TableCatalogEntry*> QueryGraphLabelAnalyzer::getTableCatalogEntries( | ||
| table_id_set_t tableIDs) const { | ||
| std::vector<TableCatalogEntry*> resolved; | ||
| std::transform(tableIDs.begin(), tableIDs.end(), std::back_inserter(resolved), | ||
| [this](const auto& id) { return catalog->getTableCatalogEntry(tx, id); }); | ||
| return resolved; | ||
| } | ||
|
|
||
| // Gathers the table ID of every entry returned by `catalog->getNodeTableEntries(tx)`. | ||
| std::vector<table_id_t> QueryGraphLabelAnalyzer::getNodeTableIDs() const { | ||
| std::vector<table_id_t> ids; | ||
| // Binding with const auto& works whether the catalog returns a reference or a temporary. | ||
| const auto& nodeEntries = catalog->getNodeTableEntries(tx); | ||
| for (const auto& nodeEntry : nodeEntries) { | ||
| ids.push_back(nodeEntry->getTableID()); | ||
| } | ||
| return ids; | ||
| } | ||
|
|
||
| // Unions all table IDs found in the sets of `v1` and `v2`, then resolves each distinct ID | ||
| // through `catalog->getTableCatalogEntry(tx, id)` and returns the resulting entries. | ||
| std::unordered_set<TableCatalogEntry*> QueryGraphLabelAnalyzer::mergeTableIDs( | ||
| const std::vector<table_id_set_t>& v1, const std::vector<table_id_set_t>& v2) const { | ||
| std::unordered_set<table_id_t> temp; | ||
| // Iterate by const reference: iterating by value copied every set just to read it. | ||
| for (const auto& tableIDs : v1) { | ||
| temp.insert(tableIDs.begin(), tableIDs.end()); | ||
| } | ||
| for (const auto& tableIDs : v2) { | ||
| temp.insert(tableIDs.begin(), tableIDs.end()); | ||
| } | ||
| // De-duplicating IDs first means each table is looked up in the catalog only once. | ||
| std::unordered_set<TableCatalogEntry*> ans; | ||
| for (const table_id_t tableID : temp) { | ||
| ans.emplace(catalog->getTableCatalogEntry(tx, tableID)); | ||
| } | ||
| return ans; | ||
| } | ||
|
|
||
| // Returns the entries (compared by pointer value) that occur in both `v1` and `v2`. | ||
| // Both arguments are taken by value so they can be sorted locally without touching the caller. | ||
| static std::vector<catalog::TableCatalogEntry*> intersectEntries( | ||
| std::vector<catalog::TableCatalogEntry*> v1, std::vector<catalog::TableCatalogEntry*> v2) { | ||
| // std::set_intersection requires both input ranges to be sorted with the same ordering. | ||
| for (auto* side : {&v1, &v2}) { | ||
| std::sort(side->begin(), side->end()); | ||
| } | ||
| std::vector<catalog::TableCatalogEntry*> common; | ||
| std::set_intersection(v1.begin(), v1.end(), v2.begin(), v2.end(), std::back_inserter(common)); | ||
| return common; | ||
| } | ||
|
|
||
| // Order-insensitive comparison: sorts both copies by table ID, then checks that they have the | ||
| // same length and hold equal pointers position by position. | ||
| static bool isSameTableCatalogEntryVector(std::vector<TableCatalogEntry*> v1, | ||
| std::vector<TableCatalogEntry*> v2) { | ||
| const auto byTableID = [](TableCatalogEntry* lhs, TableCatalogEntry* rhs) { | ||
| return lhs->getTableID() < rhs->getTableID(); | ||
| }; | ||
| std::sort(v1.begin(), v1.end(), byTableID); | ||
| std::sort(v2.begin(), v2.end(), byTableID); | ||
| // Vector equality compares sizes first and then elements, like the four-iterator std::equal. | ||
| return v1 == v2; | ||
| } | ||
|
|
||
| void QueryGraphLabelAnalyzer::pruneRel(RelExpression& rel) const { | ||
| auto srcTableIDSet = rel.getSrcNode()->getTableIDsSet(); | ||
| auto dstTableIDSet = rel.getDstNode()->getTableIDsSet(); | ||
| if (rel.isRecursive()) { | ||
| auto nodeTableIDs = getNodeTableIDs(); | ||
| // there is no label on both sides | ||
| if (rel.getUpperBound() == 0 || rel.getEntries().empty()) { | ||
| return; | ||
| } | ||
|
|
||
| auto [stepFromLeftTableIDs, stepFromRightTableIDs] = | ||
| pruneRecursiveRel(rel.getEntries(), srcTableIDSet, dstTableIDSet, rel.getLowerBound(), | ||
| rel.getUpperBound(), rel.getDirectionType()); | ||
| auto recursiveInfo = rel.getRecursiveInfoUnsafe(); | ||
| recursiveInfo->stepFromLeftActivationRelInfos = stepFromLeftTableIDs; | ||
| recursiveInfo->stepFromRightActivationRelInfos = stepFromRightTableIDs; | ||
| // todo we need reset rel entries? | ||
| auto temp = mergeTableIDs(stepFromLeftTableIDs, stepFromRightTableIDs); | ||
| std::vector<TableCatalogEntry*> newRelEntries{temp.begin(), temp.end()}; | ||
| if (!isSameTableCatalogEntryVector(newRelEntries, rel.getEntries())) { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Shall we do L284-L312 as a separate PR, because […]
|
||
| rel.setEntries(newRelEntries); | ||
| recursiveInfo->rel->setEntries(newRelEntries); | ||
| // update src&dst entries | ||
| auto forwardRelNodes = collectRelNodes(RelDataDirection::BWD, | ||
| getTableCatalogEntries(stepFromLeftTableIDs.front())); | ||
|
|
||
| std::unordered_set<table_id_t> backwardRelNodes; | ||
| for (auto i = rel.getLowerBound(); i <= rel.getUpperBound(); ++i) { | ||
| if (i == 0) { | ||
| continue; | ||
| } | ||
| const auto relSrcNodes = collectRelNodes(RelDataDirection::FWD, | ||
| getTableCatalogEntries(stepFromLeftTableIDs.at(i - 1))); | ||
| backwardRelNodes.insert(relSrcNodes.begin(), relSrcNodes.end()); | ||
| } | ||
|
|
||
| if (rel.getDirectionType() == RelDirectionType::BOTH) { | ||
| forwardRelNodes.insert(backwardRelNodes.begin(), backwardRelNodes.end()); | ||
| backwardRelNodes = forwardRelNodes; | ||
| } | ||
|
|
||
| auto newSrcNodeEntries = intersectEntries(rel.getSrcNode()->getEntries(), | ||
| getTableCatalogEntries({forwardRelNodes.begin(), forwardRelNodes.end()})); | ||
| rel.getSrcNode()->setEntries(newSrcNodeEntries); | ||
|
|
||
| auto newDstNodeEntries = intersectEntries(rel.getDstNode()->getEntries(), | ||
| getTableCatalogEntries({backwardRelNodes.begin(), backwardRelNodes.end()})); | ||
| rel.getDstNode()->setEntries(newDstNodeEntries); | ||
| } | ||
| } else { | ||
| auto prunedEntries = pruneNonRecursiveRel(rel.getEntries(), srcTableIDSet, dstTableIDSet, | ||
| rel.getDirectionType()); | ||
| rel.setEntries(prunedEntries); | ||
| } | ||
| // Note the pruning for node should guarantee the following exception won't be triggered. | ||
| // For safety (and consistency) reason, we still write the check but skip coverage check. | ||
| // LCOV_EXCL_START | ||
| if (prunedEntries.empty()) { | ||
| if (rel.getEntries().empty()) { | ||
| if (throwOnViolate) { | ||
| throw BinderException(stringFormat("Cannot find a label for relationship {} that " | ||
| "connects to all of its neighbour nodes.", | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Up to you, but I would use
`std::unordered_map<table_id_t, STRUCT>`, where STRUCT contains […],
because we always access by srcTableID instead of
by the (srcTableID, dstTableID) pair.