diff --git a/extension/fts/src/function/query_fts_index.cpp b/extension/fts/src/function/query_fts_index.cpp index b54c5ff123f..c17c283ba2c 100644 --- a/extension/fts/src/function/query_fts_index.cpp +++ b/extension/fts/src/function/query_fts_index.cpp @@ -265,7 +265,7 @@ void QueryFTSAlgorithm::exec(processor::ExecutionContext* executionContext) { auto edgeCompute = std::make_unique(scores, dfs); auto auxiliaryState = std::make_unique(); auto compState = GDSComputeState(std::move(frontierPair), std::move(edgeCompute), - std::move(auxiliaryState), nullptr /* outputNodeMask */); + std::move(auxiliaryState), {}, nullptr /* outputNodeMask */); GDSUtils::runFrontiersUntilConvergence(executionContext, compState, graph, ExtendDirection::FWD, 1 /* maxIters */, QueryFTSAlgorithm::TERM_FREQUENCY_PROP_NAME); diff --git a/src/binder/query/query_graph_label_analyzer.cpp b/src/binder/query/query_graph_label_analyzer.cpp index 1b39af1e6da..ff0c7a7c874 100644 --- a/src/binder/query/query_graph_label_analyzer.cpp +++ b/src/binder/query/query_graph_label_analyzer.cpp @@ -1,6 +1,7 @@ #include "binder/query/query_graph_label_analyzer.h" #include "catalog/catalog.h" +#include "catalog/catalog_entry/node_table_catalog_entry.h" #include "catalog/catalog_entry/rel_table_catalog_entry.h" #include "common/exception/binder.h" #include "common/string_format.h" @@ -29,7 +30,7 @@ void QueryGraphLabelAnalyzer::pruneNode(const QueryGraph& graph, NodeExpression& if (queryRel->isRecursive()) { continue; } - common::table_id_set_t candidates; + table_id_set_t candidates; std::unordered_set candidateNamesSet; auto isSrcConnect = *queryRel->getSrcNode() == node; auto isDstConnect = *queryRel->getDstNode() == node; @@ -94,49 +95,305 @@ void QueryGraphLabelAnalyzer::pruneNode(const QueryGraph& graph, NodeExpression& } } -void QueryGraphLabelAnalyzer::pruneRel(RelExpression& rel) const { - if (rel.isRecursive()) { - return; - } +std::vector QueryGraphLabelAnalyzer::pruneNonRecursiveRel( + const std::vector& relEntries, const table_id_set_t& srcTableIDSet, + const table_id_set_t& dstTableIDSet, const RelDirectionType directionType) const { + + auto forwardPruningFunc = [&](table_id_t srcTableID, table_id_t dstTableID) { + return srcTableIDSet.contains(srcTableID) && dstTableIDSet.contains(dstTableID); + }; + auto backwardPruningFunc = [&](table_id_t srcTableID, table_id_t dstTableID) { + return dstTableIDSet.contains(srcTableID) && srcTableIDSet.contains(dstTableID); + }; std::vector prunedEntries; - if (rel.getDirectionType() == RelDirectionType::BOTH) { - table_id_set_t srcBoundTableIDSet; - table_id_set_t dstBoundTableIDSet; - for (auto entry : rel.getSrcNode()->getEntries()) { - srcBoundTableIDSet.insert(entry->getTableID()); + for (auto& entry : relEntries) { + auto& relEntry = entry->constCast(); + auto srcTableID = relEntry.getSrcTableID(); + auto dstTableID = relEntry.getDstTableID(); + auto satisfyForwardPruning = forwardPruningFunc(srcTableID, dstTableID); + if (directionType == RelDirectionType::BOTH) { + if (satisfyForwardPruning || backwardPruningFunc(srcTableID, dstTableID)) { + prunedEntries.push_back(entry); + } + } else { + if (satisfyForwardPruning) { + prunedEntries.push_back(entry); + } + } + } + return prunedEntries; +} + +table_id_set_t QueryGraphLabelAnalyzer::collectRelNodes(const RelDataDirection direction, + std::vector relEntries) const { + table_id_set_t nodeIDs; + for (const auto& entry : relEntries) { + const auto& relEntry = entry->constCast(); + if (direction == RelDataDirection::FWD) { + nodeIDs.insert(relEntry.getDstTableID()); + } else if (direction == RelDataDirection::BWD) { + nodeIDs.insert(relEntry.getSrcTableID()); + } else { + KU_UNREACHABLE; } - for (auto entry : rel.getDstNode()->getEntries()) { - dstBoundTableIDSet.insert(entry->getTableID()); + } + return nodeIDs; +} + +std::pair, std::vector> +QueryGraphLabelAnalyzer::pruneRecursiveRel(const std::vector& relEntries, + const table_id_set_t srcTableIDSet, const table_id_set_t dstTableIDSet, size_t lowerBound, + size_t upperBound, RelDirectionType relDirectionType) const { + // src-->[dst,[rels]] + std::unordered_map> + stepFromLeftGraph; + std::unordered_map> + stepFromRightGraph; + table_id_t maxTableID = 0; + for (auto entry : relEntries) { + auto& relEntry = entry->constCast(); + auto srcTableID = relEntry.getSrcTableID(); + auto dstTableID = relEntry.getDstTableID(); + auto tableID = relEntry.getTableID(); + stepFromLeftGraph[srcTableID][dstTableID].push_back(tableID); + stepFromRightGraph[dstTableID][srcTableID].push_back(tableID); + if (relDirectionType == RelDirectionType::BOTH) { + stepFromLeftGraph[dstTableID][srcTableID].push_back(tableID); + stepFromRightGraph[srcTableID][dstTableID].push_back(tableID); } - for (auto& entry : rel.getEntries()) { - auto& relEntry = entry->constCast(); - auto srcTableID = relEntry.getSrcTableID(); - auto dstTableID = relEntry.getDstTableID(); - if ((srcBoundTableIDSet.contains(srcTableID) && - dstBoundTableIDSet.contains(dstTableID)) || - (dstBoundTableIDSet.contains(srcTableID) && - srcBoundTableIDSet.contains(dstTableID))) { - prunedEntries.push_back(entry); + maxTableID = std::max(maxTableID, tableID); + } + + auto stepFromLeft = pruneRecursiveRel(stepFromLeftGraph, stepFromRightGraph, srcTableIDSet, + dstTableIDSet, lowerBound, upperBound, maxTableID); + auto stepFromRight = pruneRecursiveRel(stepFromRightGraph, stepFromLeftGraph, dstTableIDSet, + srcTableIDSet, lowerBound, upperBound, maxTableID); + return {stepFromLeft, stepFromRight}; +} + +std::vector QueryGraphLabelAnalyzer::pruneRecursiveRel( + const std::unordered_map>& graph, + const std::unordered_map>& + reseveGraph, + const table_id_set_t& startTableIDSet, const table_id_set_t& endTableIDSet, size_t lowerBound, + size_t upperBound, table_id_t maxTableID) const { + + // f[i][j] represent whether the edge numbered i can be reached by jumping j times through the + // set A. + std::unordered_map> f, g; + + auto initFunc = [upperBound](const std::unordered_map>& _graph, + std::unordered_map>& ans, + const table_id_set_t& beginTableIDSet) { + for (auto [_, map] : _graph) { + for (auto [_, rels] : map) { + for (auto rel : rels) { + ans.emplace(rel, std::vector(upperBound + 1, false)); + } } } - } else { - auto srcTableIDSet = rel.getSrcNode()->getTableIDsSet(); - auto dstTableIDSet = rel.getDstNode()->getTableIDsSet(); - for (auto& entry : rel.getEntries()) { - auto& relEntry = entry->constCast(); - auto srcTableID = relEntry.getSrcTableID(); - auto dstTableID = relEntry.getDstTableID(); - if (!srcTableIDSet.contains(srcTableID) || !dstTableIDSet.contains(dstTableID)) { + + for (auto tableID : beginTableIDSet) { + if (!_graph.contains(tableID)) { continue; } - prunedEntries.push_back(entry); + for (auto [dst, rels] : _graph.at(tableID)) { + for (auto rel : rels) { + ans[rel][1] = true; + ans[rel][0] = true; + } + } + } + }; + + initFunc(graph, f, startTableIDSet); + initFunc(reseveGraph, g, endTableIDSet); + + auto isOk = [&](const table_id_vector_t& rels, int j, + std::unordered_map>& map) -> bool { + for (auto rel : rels) { + if (map[rel][j - 1]) { + return true; + } + } + return false; + }; + + auto bfsFunc = + [upperBound, maxTableID, isOk]( + const std::unordered_map>& + _graph, + const std::unordered_map>& + _reseveGraph, + std::unordered_map>& map) { + for (int j = 2; j <= upperBound; ++j) { + for (auto v = 0u; v < maxTableID; ++v) { + bool flag = false; + if (_reseveGraph.contains(v)) { + for (auto [_, rels] : _reseveGraph.at(v)) { + if (isOk(rels, j, map)) { + flag = true; + break; + } + } + } + + if (flag && _graph.contains(v)) { + for (auto [dst, rels] : _graph.at(v)) { + for (auto rel : rels) { + map[rel][j] = true; + } + } + } + } + } + }; + + bfsFunc(graph, reseveGraph, f); + bfsFunc(reseveGraph, graph, g); + + std::vector stepActiveTableIDs(upperBound); + for (auto [rel, vector] : f) { + for (int j = 0; j <= upperBound; ++j) { + if (!vector[j]) { + continue; + } + for (int k = 0; k <= upperBound; ++k) { + if (!g[rel][k]) { + continue; + } + auto step = j + k; + if (step != upperBound) { + // rel repeat count + step--; + } + if (step < lowerBound) { + continue; + } else if (step > upperBound) { + break; + } else { + int index = j == 0 ? 0 : j - 1; + stepActiveTableIDs[index].emplace(rel); + break; + } + } } } - rel.setEntries(prunedEntries); + return stepActiveTableIDs; +} + +std::vector QueryGraphLabelAnalyzer::getTableCatalogEntries( + table_id_set_t tableIDs) const { + std::vector relEntries; + for (const auto& tableID : tableIDs) { + relEntries.push_back(catalog->getTableCatalogEntry(tx, tableID)); + } + return relEntries; +} + +std::vector QueryGraphLabelAnalyzer::getNodeTableIDs() const { + std::vector nodeTableIDs; + for (auto node_table_entry : catalog->getNodeTableEntries(tx)) { + nodeTableIDs.push_back(node_table_entry->getTableID()); + } + return nodeTableIDs; +} + +std::unordered_set QueryGraphLabelAnalyzer::mergeTableIDs( + const std::vector& v1, const std::vector& v2) const { + std::unordered_set temp; + for (auto tableIDs : v1) { + temp.insert(tableIDs.begin(), tableIDs.end()); + } + for (auto tableIDs : v2) { + temp.insert(tableIDs.begin(), tableIDs.end()); + } + std::unordered_set ans; + for (table_id_t tableID : temp) { + ans.emplace(catalog->getTableCatalogEntry(tx, tableID)); + } + return ans; +} + +static std::vector intersectEntries( + std::vector v1, std::vector v2) { + std::sort(v1.begin(), v1.end()); + std::sort(v2.begin(), v2.end()); + std::vector intersection; + std::set_intersection(v1.begin(), v1.end(), v2.begin(), v2.end(), + std::back_inserter(intersection)); + return intersection; +} + +static bool isSameTableCatalogEntryVector(std::vector v1, + std::vector v2) { + auto compareFunc = [](TableCatalogEntry* a, TableCatalogEntry* b) { + return a->getTableID() < b->getTableID(); + }; + std::sort(v1.begin(), v1.end(), compareFunc); + std::sort(v2.begin(), v2.end(), compareFunc); + return std::equal(v1.begin(), v1.end(), v2.begin(), v2.end()); +} + +void QueryGraphLabelAnalyzer::pruneRel(RelExpression& rel) const { + auto srcTableIDSet = rel.getSrcNode()->getTableIDsSet(); + auto dstTableIDSet = rel.getDstNode()->getTableIDsSet(); + if (rel.isRecursive()) { + auto nodeTableIDs = getNodeTableIDs(); + // there is no label on both sides + if (rel.getUpperBound() == 0 || rel.getEntries().empty()) { + return; + } + + auto [stepFromLeftTableIDs, stepFromRightTableIDs] = + pruneRecursiveRel(rel.getEntries(), srcTableIDSet, dstTableIDSet, rel.getLowerBound(), + rel.getUpperBound(), rel.getDirectionType()); + auto recursiveInfo = rel.getRecursiveInfoUnsafe(); + recursiveInfo->stepFromLeftActivationRelInfos = stepFromLeftTableIDs; + recursiveInfo->stepFromRightActivationRelInfos = stepFromRightTableIDs; + // todo we need reset rel entries? + auto temp = mergeTableIDs(stepFromLeftTableIDs, stepFromRightTableIDs); + std::vector newRelEntries{temp.begin(), temp.end()}; + if (!isSameTableCatalogEntryVector(newRelEntries, rel.getEntries())) { + rel.setEntries(newRelEntries); + recursiveInfo->rel->setEntries(newRelEntries); + // update src&dst entries + auto forwardRelNodes = collectRelNodes(RelDataDirection::BWD, + getTableCatalogEntries(stepFromLeftTableIDs.front())); + + std::unordered_set backwardRelNodes; + for (auto i = rel.getLowerBound(); i <= rel.getUpperBound(); ++i) { + if (i == 0) { + continue; + } + const auto relSrcNodes = collectRelNodes(RelDataDirection::FWD, + getTableCatalogEntries(stepFromLeftTableIDs.at(i - 1))); + backwardRelNodes.insert(relSrcNodes.begin(), relSrcNodes.end()); + } + + if (rel.getDirectionType() == RelDirectionType::BOTH) { + forwardRelNodes.insert(backwardRelNodes.begin(), backwardRelNodes.end()); + backwardRelNodes = forwardRelNodes; + } + + auto newSrcNodeEntries = intersectEntries(rel.getSrcNode()->getEntries(), + getTableCatalogEntries({forwardRelNodes.begin(), forwardRelNodes.end()})); + rel.getSrcNode()->setEntries(newSrcNodeEntries); + + auto newDstNodeEntries = intersectEntries(rel.getDstNode()->getEntries(), + getTableCatalogEntries({backwardRelNodes.begin(), backwardRelNodes.end()})); + rel.getDstNode()->setEntries(newDstNodeEntries); + } + } else { + auto prunedEntries = pruneNonRecursiveRel(rel.getEntries(), srcTableIDSet, dstTableIDSet, + rel.getDirectionType()); + rel.setEntries(prunedEntries); + } // Note the pruning for node should guarantee the following exception won't be triggered. // For safety (and consistency) reason, we still write the check but skip coverage check. // LCOV_EXCL_START - if (prunedEntries.empty()) { + if (rel.getEntries().empty()) { if (throwOnViolate) { throw BinderException(stringFormat("Cannot find a label for relationship {} that " "connects to all of its neighbour nodes.", diff --git a/src/function/gds/asp_destinations.cpp b/src/function/gds/asp_destinations.cpp index de871f1f65e..5c35394ca68 100644 --- a/src/function/gds/asp_destinations.cpp +++ b/src/function/gds/asp_destinations.cpp @@ -196,7 +196,8 @@ class AllSPDestinationsAlgorithm final : public RJAlgorithm { std::make_unique(frontierPair.get(), multiplicities); auto auxiliaryState = std::make_unique(multiplicities); auto gdsState = std::make_unique(std::move(frontierPair), - std::move(edgeCompute), std::move(auxiliaryState), sharedState->getOutputNodeMaskMap()); + std::move(edgeCompute), std::move(auxiliaryState), + std::vector(), sharedState->getOutputNodeMaskMap()); return RJCompState(std::move(gdsState), std::move(outputWriter)); } }; diff --git a/src/function/gds/asp_paths.cpp b/src/function/gds/asp_paths.cpp index 8760ad3a3eb..76c9bda242b 100644 --- a/src/function/gds/asp_paths.cpp +++ b/src/function/gds/asp_paths.cpp @@ -91,7 +91,8 @@ class AllSPPathsAlgorithm final : public RJAlgorithm { std::make_unique(frontierPair.get(), bfsGraph.get()); auto auxiliaryState = std::make_unique(std::move(bfsGraph)); auto gdsState = std::make_unique(std::move(frontierPair), - std::move(edgeCompute), std::move(auxiliaryState), sharedState->getOutputNodeMaskMap()); + std::move(edgeCompute), std::move(auxiliaryState), + std::vector(), sharedState->getOutputNodeMaskMap()); return RJCompState(std::move(gdsState), std::move(outputWriter)); } }; diff --git a/src/function/gds/degrees.h b/src/function/gds/degrees.h index 24c3e43a25d..12c23bebedd 100644 --- a/src/function/gds/degrees.h +++ b/src/function/gds/degrees.h @@ -72,7 +72,7 @@ struct DegreesUtils { auto ec = std::make_unique(degrees); auto auxiliaryState = std::make_unique(); auto computeState = GDSComputeState(std::move(frontierPair), std::move(ec), - std::move(auxiliaryState), nullptr); + std::move(auxiliaryState), {}, nullptr); GDSUtils::runFrontiersUntilConvergence(context, computeState, graph, direction, 1 /* maxIters */); } diff --git a/src/function/gds/gds_state.cpp b/src/function/gds/gds_state.cpp index 21d7cc056f1..75f8f36424c 100644 --- a/src/function/gds/gds_state.cpp +++ b/src/function/gds/gds_state.cpp @@ -14,5 +14,18 @@ void GDSComputeState::beginFrontierCompute(common::table_id_t currTableID, auxiliaryState->beginFrontierCompute(currTableID, nextTableID); } +common::table_id_set_t GDSComputeState::getActiveRelTableIDs(size_t index, graph::Graph* graph) { + if (stepActiveRelTableIDs.empty()) { + auto nodeIDs = graph->getRelTableIDs(); + common::table_id_set_t set; + set.insert(nodeIDs.begin(), nodeIDs.end()); + stepActiveRelTableIDs.push_back(set); + } + if (index < stepActiveRelTableIDs.size()) { + return stepActiveRelTableIDs[index]; + } else { + return stepActiveRelTableIDs.back(); + } +} } // namespace function } // namespace kuzu diff --git a/src/function/gds/gds_utils.cpp b/src/function/gds/gds_utils.cpp index 7ea5ce77ecf..029b7ced11d 100644 --- a/src/function/gds/gds_utils.cpp +++ b/src/function/gds/gds_utils.cpp @@ -64,11 +64,16 @@ void GDSUtils::runFrontiersUntilConvergence(processor::ExecutionContext* context compState.edgeCompute->terminate(*compState.outputNodeMask)) { break; } + auto activeRelTableIDs = + compState.getActiveRelTableIDs(frontierPair->getCurrentIter() - 1, graph); for (auto info : graph->getGraphEntry()->nodeInfos) { auto fromEntry = info.entry; for (auto& nbrInfo : graph->getForwardNbrTableInfos(fromEntry->getTableID())) { auto toEntry = nbrInfo.nodeEntry; auto relEntry = nbrInfo.relEntry; + if (!activeRelTableIDs.contains(relEntry->getTableID())) { + continue; + } switch (extendDirection) { case ExtendDirection::FWD: { compState.beginFrontierCompute(fromEntry->getTableID(), toEntry->getTableID()); diff --git a/src/function/gds/k_core_decomposition.cpp b/src/function/gds/k_core_decomposition.cpp index 07ba1c57b4b..a0b0433472c 100644 --- a/src/function/gds/k_core_decomposition.cpp +++ b/src/function/gds/k_core_decomposition.cpp @@ -224,9 +224,9 @@ class KCoreDecomposition final : public GDSAlgorithm { std::make_unique(currentFrontier, nextFrontier); // Compute Core values auto removeVertexEdgeCompute = std::make_unique(degrees); - auto computeState = - GDSComputeState(std::move(frontierPair), std::move(removeVertexEdgeCompute), - std::move(auxiliaryState), sharedState->getOutputNodeMaskMap()); + auto computeState = GDSComputeState(std::move(frontierPair), + std::move(removeVertexEdgeCompute), std::move(auxiliaryState), + std::vector(), sharedState->getOutputNodeMaskMap()); auto coreValue = 0u; auto numNodes = graph->getNumNodes(clientContext->getTransaction()); auto numNodesComputed = 0u; diff --git a/src/function/gds/output_writer.cpp b/src/function/gds/output_writer.cpp index 9498d087c25..38cbc1e5d2f 100644 --- a/src/function/gds/output_writer.cpp +++ b/src/function/gds/output_writer.cpp @@ -339,7 +339,7 @@ void PathsOutputWriter::writePath(const std::vector& path) const { if (path.size() == 0) { return; } - if (!info.flipPath) { + if (!info.extendRightToLeft) { // By default, write path in reverse direction because we append ParentList from dst to src. writePathBwd(path); } else { diff --git a/src/function/gds/page_rank.cpp b/src/function/gds/page_rank.cpp index 134fa72ffb8..b70c69497ba 100644 --- a/src/function/gds/page_rank.cpp +++ b/src/function/gds/page_rank.cpp @@ -283,7 +283,7 @@ class PageRank final : public GDSAlgorithm { frontierPair->setActiveNodesForNextIter(); frontierPair->getNextSparseFrontier().disable(); auto computeState = GDSComputeState(std::move(frontierPair), nullptr, nullptr, - sharedState->getOutputNodeMaskMap()); + std::vector(), sharedState->getOutputNodeMaskMap()); auto pNextUpdateConstant = (1 - pageRankBindData->dampingFactor) * ((double)1 / numNodes); while (currentIter < pageRankBindData->maxIteration) { computeState.edgeCompute = diff --git a/src/function/gds/rec_joins.cpp b/src/function/gds/rec_joins.cpp index 9da4b5a83e0..7c191ac01ef 100644 --- a/src/function/gds/rec_joins.cpp +++ b/src/function/gds/rec_joins.cpp @@ -25,7 +25,7 @@ RJBindData::RJBindData(const RJBindData& other) : GDSBindData{other} { upperBound = other.upperBound; semantic = other.semantic; extendDirection = other.extendDirection; - flipPath = other.flipPath; + extendRightToLeft = other.extendRightToLeft; writePath = other.writePath; directionExpr = other.directionExpr; lengthExpr = other.lengthExpr; @@ -33,13 +33,15 @@ RJBindData::RJBindData(const RJBindData& other) : GDSBindData{other} { pathEdgeIDsExpr = other.pathEdgeIDsExpr; weightPropertyExpr = other.weightPropertyExpr; weightOutputExpr = other.weightOutputExpr; + stepFromLeftActivationRelInfos = other.stepFromLeftActivationRelInfos; + stepFromRightActivationRelInfos = other.stepFromRightActivationRelInfos; } PathsOutputWriterInfo RJBindData::getPathWriterInfo() const { auto info = PathsOutputWriterInfo(); info.semantic = semantic; info.lowerBound = lowerBound; - info.flipPath = flipPath; + info.extendRightToLeft = extendRightToLeft; info.writeEdgeDirection = writePath && extendDirection == ExtendDirection::BOTH; info.writePath = writePath; return info; @@ -50,6 +52,14 @@ void RJAlgorithm::bind(const kuzu::function::GDSBindInput&, main::ClientContext& "Try cypher patter ()-[*]->() instead."); } +std::vector RJBindData::getStepActiveRelTableIDs() const { + if (extendRightToLeft) { + return stepFromRightActivationRelInfos; + } else { + return stepFromLeftActivationRelInfos; + } +} + void RJAlgorithm::setToNoPath() { bindData->ptrCast()->writePath = false; } diff --git a/src/function/gds/ssp_destinations.cpp b/src/function/gds/ssp_destinations.cpp index 8dc91449c91..61922a353b9 100644 --- a/src/function/gds/ssp_destinations.cpp +++ b/src/function/gds/ssp_destinations.cpp @@ -105,7 +105,8 @@ class SingleSPDestinationsAlgorithm : public RJAlgorithm { auto edgeCompute = std::make_unique(frontierPair.get()); auto auxiliaryState = std::make_unique(); auto gdsState = std::make_unique(std::move(frontierPair), - std::move(edgeCompute), std::move(auxiliaryState), sharedState->getOutputNodeMaskMap()); + std::move(edgeCompute), std::move(auxiliaryState), + std::vector(), sharedState->getOutputNodeMaskMap()); return RJCompState(std::move(gdsState), std::move(outputWriter)); } }; diff --git a/src/function/gds/ssp_paths.cpp b/src/function/gds/ssp_paths.cpp index 3fcc3ef2bac..313c6657fa0 100644 --- a/src/function/gds/ssp_paths.cpp +++ b/src/function/gds/ssp_paths.cpp @@ -82,7 +82,8 @@ class SingleSPPathsAlgorithm : public RJAlgorithm { std::make_unique(frontierPair.get(), bfsGraph.get()); auto auxiliaryState = std::make_unique(std::move(bfsGraph)); auto gdsState = std::make_unique(std::move(frontierPair), - std::move(edgeCompute), std::move(auxiliaryState), sharedState->getOutputNodeMaskMap()); + std::move(edgeCompute), std::move(auxiliaryState), + std::vector(), sharedState->getOutputNodeMaskMap()); return RJCompState(std::move(gdsState), std::move(outputWriter)); } }; diff --git a/src/function/gds/variable_length_path.cpp b/src/function/gds/variable_length_path.cpp index d04d9994037..504ee544d80 100644 --- a/src/function/gds/variable_length_path.cpp +++ b/src/function/gds/variable_length_path.cpp @@ -121,7 +121,8 @@ class VarLenJoinsAlgorithm final : public RJAlgorithm { std::make_unique(frontierPair.get(), bfsGraph.get()); auto auxiliaryState = std::make_unique(std::move(bfsGraph)); auto gdsState = std::make_unique(std::move(frontierPair), - std::move(edgeCompute), std::move(auxiliaryState), sharedState->getOutputNodeMaskMap()); + std::move(edgeCompute), std::move(auxiliaryState), + rjBindData->getStepActiveRelTableIDs(), sharedState->getOutputNodeMaskMap()); return RJCompState(std::move(gdsState), std::move(outputWriter)); } }; diff --git a/src/function/gds/weakly_connected_components.cpp b/src/function/gds/weakly_connected_components.cpp index ac784e558ff..d65acf35fce 100644 --- a/src/function/gds/weakly_connected_components.cpp +++ b/src/function/gds/weakly_connected_components.cpp @@ -189,7 +189,8 @@ class WeaklyConnectedComponent final : public GDSAlgorithm { sharedState.get(), std::move(writer), componentIDs); auto auxiliaryState = std::make_unique(componentIDs); auto computeState = GDSComputeState(std::move(frontierPair), std::move(edgeCompute), - std::move(auxiliaryState), sharedState->getOutputNodeMaskMap()); + std::move(auxiliaryState), std::vector(), + sharedState->getOutputNodeMaskMap()); GDSUtils::runFrontiersUntilConvergence(context, computeState, graph, ExtendDirection::BOTH, MAX_ITERATION); GDSUtils::runVertexCompute(context, graph, *vertexCompute); diff --git a/src/function/gds/weighted_shortest_paths.cpp b/src/function/gds/weighted_shortest_paths.cpp index 489065f92aa..4204183f86d 100644 --- a/src/function/gds/weighted_shortest_paths.cpp +++ b/src/function/gds/weighted_shortest_paths.cpp @@ -292,9 +292,9 @@ class WeightedSPDestinationsAlgorithm : public RJAlgorithm { std::unique_ptr gdsState; visit(rjBindData->weightPropertyExpr->getDataType(), [&](T) { auto edgeCompute = std::make_unique>(costs); - gdsState = - std::make_unique(std::move(frontierPair), std::move(edgeCompute), - std::move(auxiliaryState), sharedState->getOutputNodeMaskMap()); + gdsState = std::make_unique(std::move(frontierPair), + std::move(edgeCompute), std::move(auxiliaryState), + std::vector(), sharedState->getOutputNodeMaskMap()); }); return RJCompState(std::move(gdsState), std::move(outputWriter)); } @@ -343,9 +343,9 @@ class WeightedSPPathsAlgorithm : public RJAlgorithm { visit(rjBindData->weightPropertyExpr->getDataType(), [&](T) { auto edgeCompute = std::make_unique>(*bfsGraph); auto auxiliaryState = std::make_unique(std::move(bfsGraph)); - gdsState = - std::make_unique(std::move(frontierPair), std::move(edgeCompute), - std::move(auxiliaryState), sharedState->getOutputNodeMaskMap()); + gdsState = std::make_unique(std::move(frontierPair), + std::move(edgeCompute), std::move(auxiliaryState), + std::vector(), sharedState->getOutputNodeMaskMap()); }); return RJCompState(std::move(gdsState), std::move(outputWriter)); } diff --git a/src/include/binder/expression/rel_expression.h b/src/include/binder/expression/rel_expression.h index d809487a568..55240bd64a7 100644 --- a/src/include/binder/expression/rel_expression.h +++ b/src/include/binder/expression/rel_expression.h @@ -3,6 +3,7 @@ #include "common/constants.h" #include "common/enums/extend_direction.h" #include "common/enums/query_rel_type.h" +#include "common/types/types.h" #include "node_expression.h" namespace kuzu { @@ -55,6 +56,9 @@ struct RecursiveInfo { // Edge property representing weight std::shared_ptr weightPropertyExpr = nullptr; std::shared_ptr weightOutputExpr = nullptr; + + std::vector stepFromLeftActivationRelInfos; + std::vector stepFromRightActivationRelInfos; }; class RelExpression final : public NodeOrRelExpression { @@ -102,6 +106,7 @@ class RelExpression final : public NodeOrRelExpression { recursiveInfo = std::move(recursiveInfo_); } const RecursiveInfo* getRecursiveInfo() const { return recursiveInfo.get(); } + RecursiveInfo* getRecursiveInfoUnsafe() { return recursiveInfo.get(); } size_t getLowerBound() const { return recursiveInfo->lowerBound; } size_t getUpperBound() const { return recursiveInfo->upperBound; } std::shared_ptr getLengthExpression() const { diff --git a/src/include/binder/query/query_graph_label_analyzer.h b/src/include/binder/query/query_graph_label_analyzer.h index b6f16cad846..d5294e40c7b 100644 --- a/src/include/binder/query/query_graph_label_analyzer.h +++ b/src/include/binder/query/query_graph_label_analyzer.h @@ -1,5 +1,6 @@ #pragma once +#include "common/enums/rel_direction.h" #include "main/client_context.h" #include "query_graph.h" @@ -9,7 +10,10 @@ namespace binder { class QueryGraphLabelAnalyzer { public: explicit QueryGraphLabelAnalyzer(const main::ClientContext& clientContext, bool throwOnViolate) - : throwOnViolate{throwOnViolate}, clientContext{clientContext} {} + : throwOnViolate{throwOnViolate}, clientContext{clientContext} { + tx = clientContext.getTransaction(); + catalog = clientContext.getCatalog(); + } void pruneLabel(QueryGraph& graph) const; @@ -17,9 +21,62 @@ class QueryGraphLabelAnalyzer { void pruneNode(const QueryGraph& graph, NodeExpression& node) const; void pruneRel(RelExpression& rel) const; + common::table_id_set_t collectRelNodes(const common::RelDataDirection direction, + std::vector relEntries) const; + + std::pair, std::vector> + pruneRecursiveRel(const std::vector& relEntries, + const common::table_id_set_t srcTableIDSet, const common::table_id_set_t dstTableIDSet, + size_t lowerBound, size_t upperBound, RelDirectionType relDirectionType) const; + + std::vector pruneNonRecursiveRel( + const std::vector& relEntries, + const common::table_id_set_t& srcTableIDSet, const common::table_id_set_t& dstTableIDSet, + const RelDirectionType directionType) const; + + std::vector getTableCatalogEntries( + common::table_id_set_t tableIDs) const; + + std::vector pruneRecursiveRel( + const std::unordered_map>& graph, + const std::unordered_map>& reserveGraph, + const common::table_id_set_t& startTableIDSet, const common::table_id_set_t& endTableIDSet, + size_t lowerBound, size_t upperBound,common::table_id_t maxTableID) const; + + std::vector getNodeTableIDs() const; + + std::unordered_set mergeTableIDs( + const std::vector& v1, + const std::vector& v2) const; + + class Path { + public: + void addNode(common::table_id_t node) { + nodeIndex.emplace(node, nodes.size()); + nodes.push_back(node); + } + + void addRel(const common::table_id_set_t& rel) { rels.push_back(rel); } + + void pop_back() { + + } + + const std::vector& getRels() { return rels; } + + private: + std::vector rels; + std::vector nodes; + std::unordered_map nodeIndex; + }; + private: bool throwOnViolate; const main::ClientContext& clientContext; + transaction::Transaction* tx; + catalog::Catalog* catalog; }; } // namespace binder diff --git a/src/include/function/gds/gds_state.h b/src/include/function/gds/gds_state.h index 1dfffde37be..7eb2161a881 100644 --- a/src/include/function/gds/gds_state.h +++ b/src/include/function/gds/gds_state.h @@ -10,15 +10,19 @@ struct KUZU_API GDSComputeState { std::unique_ptr frontierPair = nullptr; std::unique_ptr edgeCompute = nullptr; std::unique_ptr auxiliaryState = nullptr; + // While stepActiveRelTableIDs is empty, using all relTableIDs in graph + std::vector stepActiveRelTableIDs; processor::NodeOffsetMaskMap* outputNodeMask = nullptr; GDSComputeState(std::unique_ptr frontierPair, std::unique_ptr edgeCompute, std::unique_ptr auxiliaryState, + std::vector stepActiveRelTableIDs, processor::NodeOffsetMaskMap* outputNodeMask) : frontierPair{std::move(frontierPair)}, edgeCompute{std::move(edgeCompute)}, - auxiliaryState{std::move(auxiliaryState)}, outputNodeMask{outputNodeMask} {} + auxiliaryState{std::move(auxiliaryState)}, + stepActiveRelTableIDs{std::move(stepActiveRelTableIDs)}, outputNodeMask{outputNodeMask} {} void initSource(common::nodeID_t sourceNodeID) const; // When performing computations on multi-label graphs, it is beneficial to fix a single @@ -32,6 +36,8 @@ struct KUZU_API GDSComputeState { // RJOutputs, to possibly avoid them doing lookups of S and T-related data structures, // e.g., maps, internally. void beginFrontierCompute(common::table_id_t currTableID, common::table_id_t nextTableID) const; + + common::table_id_set_t getActiveRelTableIDs(size_t index, graph::Graph* graph); }; } // namespace function diff --git a/src/include/function/gds/output_writer.h b/src/include/function/gds/output_writer.h index d16c0aeda2c..9e95e0712ee 100644 --- a/src/include/function/gds/output_writer.h +++ b/src/include/function/gds/output_writer.h @@ -57,7 +57,7 @@ struct PathsOutputWriterInfo { // Range uint16_t lowerBound = 0; // Direction - bool flipPath = false; + bool extendRightToLeft = false; bool writeEdgeDirection = false; bool writePath = false; // Node predicate mask diff --git a/src/include/function/gds/rec_joins.h b/src/include/function/gds/rec_joins.h index f62c004e67e..7f8f4425745 100644 --- a/src/include/function/gds/rec_joins.h +++ b/src/include/function/gds/rec_joins.h @@ -22,7 +22,10 @@ struct RJBindData : public GDSBindData { common::ExtendDirection extendDirection = common::ExtendDirection::FWD; - bool flipPath = false; // See PathsOutputWriterInfo::flipPath for comments. + std::vector stepFromLeftActivationRelInfos; + std::vector stepFromRightActivationRelInfos; + + bool extendRightToLeft = false; // See PathsOutputWriterInfo::extendLeftToRight for comments. bool writePath = true; std::shared_ptr directionExpr = nullptr; @@ -43,6 +46,8 @@ struct RJBindData : public GDSBindData { PathsOutputWriterInfo getPathWriterInfo() const; + std::vector getStepActiveRelTableIDs() const; + std::unique_ptr copy() const override { return std::make_unique(*this); } diff --git a/src/include/processor/operator/recursive_extend/bfs_state.h b/src/include/processor/operator/recursive_extend/bfs_state.h index 368ba009543..d9dff750c4c 100644 --- a/src/include/processor/operator/recursive_extend/bfs_state.h +++ b/src/include/processor/operator/recursive_extend/bfs_state.h @@ -71,6 +71,7 @@ class BaseBFSState { inline void finalizeCurrentLevel() { moveNextLevelAsCurrentLevel(); } inline size_t getNumFrontiers() const { return frontiers.size(); } inline Frontier* getFrontier(common::idx_t idx) const { return frontiers[idx].get(); } + inline uint8_t getCurrentLevel() { return currentLevel; } protected: inline bool isCurrentFrontierEmpty() const { return currentFrontier->nodeIDs.empty(); } diff --git a/src/planner/plan/append_extend.cpp b/src/planner/plan/append_extend.cpp index 59b1317215e..cc71e3cc348 100644 --- a/src/planner/plan/append_extend.cpp +++ b/src/planner/plan/append_extend.cpp @@ -143,8 +143,10 @@ void Planner::appendRecursiveExtendAsGDS(const std::shared_ptr& bindData->upperBound = recursiveInfo->upperBound; bindData->semantic = semantic; bindData->extendDirection = direction; + bindData->stepFromLeftActivationRelInfos = recursiveInfo->stepFromLeftActivationRelInfos; + bindData->stepFromRightActivationRelInfos = recursiveInfo->stepFromRightActivationRelInfos; // If we extend from right to left, we need to print path in reverse direction. - bindData->flipPath = *boundNode == *rel->getRightNode(); + bindData->extendRightToLeft = *boundNode == *rel->getRightNode(); if (direction == common::ExtendDirection::BOTH) { bindData->directionExpr = recursiveInfo->pathEdgeDirectionsExpr; }