Skip to content
This repository was archived by the owner on Oct 10, 2025. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion extension/fts/src/function/query_fts_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ void QueryFTSAlgorithm::exec(processor::ExecutionContext* executionContext) {
auto edgeCompute = std::make_unique<QFTSEdgeCompute>(scores, dfs);
auto auxiliaryState = std::make_unique<EmptyGDSAuxiliaryState>();
auto compState = GDSComputeState(std::move(frontierPair), std::move(edgeCompute),
std::move(auxiliaryState), nullptr /* outputNodeMask */);
std::move(auxiliaryState), {}, nullptr /* outputNodeMask */);
GDSUtils::runFrontiersUntilConvergence(executionContext, compState, graph, ExtendDirection::FWD,
1 /* maxIters */, QueryFTSAlgorithm::TERM_FREQUENCY_PROP_NAME);

Expand Down
321 changes: 289 additions & 32 deletions src/binder/query/query_graph_label_analyzer.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "binder/query/query_graph_label_analyzer.h"

#include "catalog/catalog.h"
#include "catalog/catalog_entry/node_table_catalog_entry.h"
#include "catalog/catalog_entry/rel_table_catalog_entry.h"
#include "common/exception/binder.h"
#include "common/string_format.h"
Expand Down Expand Up @@ -29,7 +30,7 @@ void QueryGraphLabelAnalyzer::pruneNode(const QueryGraph& graph, NodeExpression&
if (queryRel->isRecursive()) {
continue;
}
common::table_id_set_t candidates;
table_id_set_t candidates;
std::unordered_set<std::string> candidateNamesSet;
auto isSrcConnect = *queryRel->getSrcNode() == node;
auto isDstConnect = *queryRel->getDstNode() == node;
Expand Down Expand Up @@ -94,49 +95,305 @@ void QueryGraphLabelAnalyzer::pruneNode(const QueryGraph& graph, NodeExpression&
}
}

void QueryGraphLabelAnalyzer::pruneRel(RelExpression& rel) const {
if (rel.isRecursive()) {
return;
}
std::vector<TableCatalogEntry*> QueryGraphLabelAnalyzer::pruneNonRecursiveRel(
const std::vector<TableCatalogEntry*>& relEntries, const table_id_set_t& srcTableIDSet,
const table_id_set_t& dstTableIDSet, const RelDirectionType directionType) const {

auto forwardPruningFunc = [&](table_id_t srcTableID, table_id_t dstTableID) {
return srcTableIDSet.contains(srcTableID) && dstTableIDSet.contains(dstTableID);
};
auto backwardPruningFunc = [&](table_id_t srcTableID, table_id_t dstTableID) {
return dstTableIDSet.contains(srcTableID) && srcTableIDSet.contains(dstTableID);
};
std::vector<TableCatalogEntry*> prunedEntries;
if (rel.getDirectionType() == RelDirectionType::BOTH) {
table_id_set_t srcBoundTableIDSet;
table_id_set_t dstBoundTableIDSet;
for (auto entry : rel.getSrcNode()->getEntries()) {
srcBoundTableIDSet.insert(entry->getTableID());
for (auto& entry : relEntries) {
auto& relEntry = entry->constCast<RelTableCatalogEntry>();
auto srcTableID = relEntry.getSrcTableID();
auto dstTableID = relEntry.getDstTableID();
auto satisfyForwardPruning = forwardPruningFunc(srcTableID, dstTableID);
if (directionType == RelDirectionType::BOTH) {
if (satisfyForwardPruning || backwardPruningFunc(srcTableID, dstTableID)) {
prunedEntries.push_back(entry);
}
} else {
if (satisfyForwardPruning) {
prunedEntries.push_back(entry);
}
}
}
return prunedEntries;
}

table_id_set_t QueryGraphLabelAnalyzer::collectRelNodes(const RelDataDirection direction,
std::vector<TableCatalogEntry*> relEntries) const {
table_id_set_t nodeIDs;
for (const auto& entry : relEntries) {
const auto& relEntry = entry->constCast<RelTableCatalogEntry>();
if (direction == RelDataDirection::FWD) {
nodeIDs.insert(relEntry.getDstTableID());
} else if (direction == RelDataDirection::BWD) {
nodeIDs.insert(relEntry.getSrcTableID());
} else {
KU_UNREACHABLE;
}
for (auto entry : rel.getDstNode()->getEntries()) {
dstBoundTableIDSet.insert(entry->getTableID());
}
return nodeIDs;
}

std::pair<std::vector<table_id_set_t>, std::vector<table_id_set_t>>
QueryGraphLabelAnalyzer::pruneRecursiveRel(const std::vector<TableCatalogEntry*>& relEntries,
const table_id_set_t srcTableIDSet, const table_id_set_t dstTableIDSet, size_t lowerBound,
size_t upperBound, RelDirectionType relDirectionType) const {
// src-->[dst,[rels]]
std::unordered_map<table_id_t, std::unordered_map<table_id_t, table_id_vector_t>>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Up to use but I would use std::unrdered_map<table_id_t, STRUCT>

and STRUCT contains

table_id_t dstTableID
std::vector<tabel_id_t> relTableIDs

because we always access by srcTableID instead of src&dstTableID

stepFromLeftGraph;
std::unordered_map<table_id_t, std::unordered_map<table_id_t, table_id_vector_t>>
stepFromRightGraph;
table_id_t maxTableID = 0;
for (auto entry : relEntries) {
auto& relEntry = entry->constCast<RelTableCatalogEntry>();
auto srcTableID = relEntry.getSrcTableID();
auto dstTableID = relEntry.getDstTableID();
auto tableID = relEntry.getTableID();
stepFromLeftGraph[srcTableID][dstTableID].push_back(tableID);
stepFromRightGraph[dstTableID][srcTableID].push_back(tableID);
if (relDirectionType == RelDirectionType::BOTH) {
stepFromLeftGraph[dstTableID][srcTableID].push_back(tableID);
stepFromRightGraph[srcTableID][dstTableID].push_back(tableID);
}
for (auto& entry : rel.getEntries()) {
auto& relEntry = entry->constCast<RelTableCatalogEntry>();
auto srcTableID = relEntry.getSrcTableID();
auto dstTableID = relEntry.getDstTableID();
if ((srcBoundTableIDSet.contains(srcTableID) &&
dstBoundTableIDSet.contains(dstTableID)) ||
(dstBoundTableIDSet.contains(srcTableID) &&
srcBoundTableIDSet.contains(dstTableID))) {
prunedEntries.push_back(entry);
maxTableID = std::max(maxTableID, tableID);
}

auto stepFromLeft = pruneRecursiveRel(stepFromLeftGraph, stepFromRightGraph, srcTableIDSet,
dstTableIDSet, lowerBound, upperBound, maxTableID);
auto stepFromRight = pruneRecursiveRel(stepFromRightGraph, stepFromLeftGraph, dstTableIDSet,
srcTableIDSet, lowerBound, upperBound, maxTableID);
return {stepFromLeft, stepFromRight};
}

std::vector<table_id_set_t> QueryGraphLabelAnalyzer::pruneRecursiveRel(
const std::unordered_map<table_id_t, std::unordered_map<table_id_t, table_id_vector_t>>& graph,
const std::unordered_map<table_id_t, std::unordered_map<table_id_t, table_id_vector_t>>&
reseveGraph,
const table_id_set_t& startTableIDSet, const table_id_set_t& endTableIDSet, size_t lowerBound,
size_t upperBound, table_id_t maxTableID) const {

// f[i][j] represent whether the edge numbered i can be reached by jumping j times through the
// set A.
std::unordered_map<table_id_t, std::vector<bool>> f, g;

auto initFunc = [upperBound](const std::unordered_map<table_id_t,
std::unordered_map<table_id_t, table_id_vector_t>>& _graph,
std::unordered_map<table_id_t, std::vector<bool>>& ans,
const table_id_set_t& beginTableIDSet) {
for (auto [_, map] : _graph) {
for (auto [_, rels] : map) {
for (auto rel : rels) {
ans.emplace(rel, std::vector<bool>(upperBound + 1, false));
}
}
}
} else {
auto srcTableIDSet = rel.getSrcNode()->getTableIDsSet();
auto dstTableIDSet = rel.getDstNode()->getTableIDsSet();
for (auto& entry : rel.getEntries()) {
auto& relEntry = entry->constCast<RelTableCatalogEntry>();
auto srcTableID = relEntry.getSrcTableID();
auto dstTableID = relEntry.getDstTableID();
if (!srcTableIDSet.contains(srcTableID) || !dstTableIDSet.contains(dstTableID)) {

for (auto tableID : beginTableIDSet) {
if (!_graph.contains(tableID)) {
continue;
}
prunedEntries.push_back(entry);
for (auto [dst, rels] : _graph.at(tableID)) {
for (auto rel : rels) {
ans[rel][1] = true;
ans[rel][0] = true;
}
}
}
};

initFunc(graph, f, startTableIDSet);
initFunc(reseveGraph, g, endTableIDSet);

auto isOk = [&](const table_id_vector_t& rels, int j,
std::unordered_map<table_id_t, std::vector<bool>>& map) -> bool {
for (auto rel : rels) {
if (map[rel][j - 1]) {
return true;
}
}
return false;
};

auto bfsFunc =
[upperBound, maxTableID, isOk](
const std::unordered_map<table_id_t, std::unordered_map<table_id_t, table_id_vector_t>>&
_graph,
const std::unordered_map<table_id_t, std::unordered_map<table_id_t, table_id_vector_t>>&
_reseveGraph,
std::unordered_map<table_id_t, std::vector<bool>>& map) {
for (int j = 2; j <= upperBound; ++j) {
for (auto v = 0u; v < maxTableID; ++v) {
bool flag = false;
if (_reseveGraph.contains(v)) {
for (auto [_, rels] : _reseveGraph.at(v)) {
if (isOk(rels, j, map)) {
flag = true;
break;
}
}
}

if (flag && _graph.contains(v)) {
for (auto [dst, rels] : _graph.at(v)) {
for (auto rel : rels) {
map[rel][j] = true;
}
}
}
}
}
};

bfsFunc(graph, reseveGraph, f);
bfsFunc(reseveGraph, graph, g);

std::vector<table_id_set_t> stepActiveTableIDs(upperBound);
for (auto [rel, vector] : f) {
for (int j = 0; j <= upperBound; ++j) {
if (!vector[j]) {
continue;
}
for (int k = 0; k <= upperBound; ++k) {
if (!g[rel][k]) {
continue;
}
auto step = j + k;
if (step != upperBound) {
// rel repeat count
step--;
}
if (step < lowerBound) {
continue;
} else if (step > upperBound) {
break;
} else {
int index = j == 0 ? 0 : j - 1;
stepActiveTableIDs[index].emplace(rel);
break;
}
}
}
}
rel.setEntries(prunedEntries);
return stepActiveTableIDs;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me try to summarize the logic here and pls correct me if I'm wrong.

Consider topology Person - knows -> Person and Person - livesIn -> City and query

MATCH (a:Person) - [*2..2] -> (b:City)

WLOG, we assume going from left to right.

The dfs part will first generate all size 2 path

<<knows>, <knows>>
<<knows>, <livesIn>>

and because <<knows>, <knows>> dst is not b:City it will be pruned.

So eventually we only compute <<knows>, <livesIn>>.

If my understanding is correct, can u also give a benchmark number in the PR description to make sure we actually get performance benefit.

}

std::vector<TableCatalogEntry*> QueryGraphLabelAnalyzer::getTableCatalogEntries(
table_id_set_t tableIDs) const {
std::vector<TableCatalogEntry*> relEntries;
for (const auto& tableID : tableIDs) {
relEntries.push_back(catalog->getTableCatalogEntry(tx, tableID));
}
return relEntries;
}

std::vector<table_id_t> QueryGraphLabelAnalyzer::getNodeTableIDs() const {
std::vector<table_id_t> nodeTableIDs;
for (auto node_table_entry : catalog->getNodeTableEntries(tx)) {
nodeTableIDs.push_back(node_table_entry->getTableID());
}
return nodeTableIDs;
}

std::unordered_set<TableCatalogEntry*> QueryGraphLabelAnalyzer::mergeTableIDs(
const std::vector<table_id_set_t>& v1, const std::vector<table_id_set_t>& v2) const {
std::unordered_set<table_id_t> temp;
for (auto tableIDs : v1) {
temp.insert(tableIDs.begin(), tableIDs.end());
}
for (auto tableIDs : v2) {
temp.insert(tableIDs.begin(), tableIDs.end());
}
std::unordered_set<TableCatalogEntry*> ans;
for (table_id_t tableID : temp) {
ans.emplace(catalog->getTableCatalogEntry(tx, tableID));
}
return ans;
}

static std::vector<catalog::TableCatalogEntry*> intersectEntries(
std::vector<catalog::TableCatalogEntry*> v1, std::vector<catalog::TableCatalogEntry*> v2) {
std::sort(v1.begin(), v1.end());
std::sort(v2.begin(), v2.end());
std::vector<catalog::TableCatalogEntry*> intersection;
std::set_intersection(v1.begin(), v1.end(), v2.begin(), v2.end(),
std::back_inserter(intersection));
return intersection;
}

static bool isSameTableCatalogEntryVector(std::vector<TableCatalogEntry*> v1,
std::vector<TableCatalogEntry*> v2) {
auto compareFunc = [](TableCatalogEntry* a, TableCatalogEntry* b) {
return a->getTableID() < b->getTableID();
};
std::sort(v1.begin(), v1.end(), compareFunc);
std::sort(v2.begin(), v2.end(), compareFunc);
return std::equal(v1.begin(), v1.end(), v2.begin(), v2.end());
}

void QueryGraphLabelAnalyzer::pruneRel(RelExpression& rel) const {
auto srcTableIDSet = rel.getSrcNode()->getTableIDsSet();
auto dstTableIDSet = rel.getDstNode()->getTableIDsSet();
if (rel.isRecursive()) {
auto nodeTableIDs = getNodeTableIDs();
// there is no label on both sides
if (rel.getUpperBound() == 0 || rel.getEntries().empty()) {
return;
}

auto [stepFromLeftTableIDs, stepFromRightTableIDs] =
pruneRecursiveRel(rel.getEntries(), srcTableIDSet, dstTableIDSet, rel.getLowerBound(),
rel.getUpperBound(), rel.getDirectionType());
auto recursiveInfo = rel.getRecursiveInfoUnsafe();
recursiveInfo->stepFromLeftActivationRelInfos = stepFromLeftTableIDs;
recursiveInfo->stepFromRightActivationRelInfos = stepFromRightTableIDs;
// todo we need reset rel entries?
auto temp = mergeTableIDs(stepFromLeftTableIDs, stepFromRightTableIDs);
std::vector<TableCatalogEntry*> newRelEntries{temp.begin(), temp.end()};
if (!isSameTableCatalogEntryVector(newRelEntries, rel.getEntries())) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we do l284-l312 as a separate PR because

  • it can be viewed as another optimization after pruning recursive rel labels; and
  • we will also need to get a benchmark number to make sure there are performance improvements for node pruning.

rel.setEntries(newRelEntries);
recursiveInfo->rel->setEntries(newRelEntries);
// update src&dst entries
auto forwardRelNodes = collectRelNodes(RelDataDirection::BWD,
getTableCatalogEntries(stepFromLeftTableIDs.front()));

std::unordered_set<table_id_t> backwardRelNodes;
for (auto i = rel.getLowerBound(); i <= rel.getUpperBound(); ++i) {
if (i == 0) {
continue;
}
const auto relSrcNodes = collectRelNodes(RelDataDirection::FWD,
getTableCatalogEntries(stepFromLeftTableIDs.at(i - 1)));
backwardRelNodes.insert(relSrcNodes.begin(), relSrcNodes.end());
}

if (rel.getDirectionType() == RelDirectionType::BOTH) {
forwardRelNodes.insert(backwardRelNodes.begin(), backwardRelNodes.end());
backwardRelNodes = forwardRelNodes;
}

auto newSrcNodeEntries = intersectEntries(rel.getSrcNode()->getEntries(),
getTableCatalogEntries({forwardRelNodes.begin(), forwardRelNodes.end()}));
rel.getSrcNode()->setEntries(newSrcNodeEntries);

auto newDstNodeEntries = intersectEntries(rel.getDstNode()->getEntries(),
getTableCatalogEntries({backwardRelNodes.begin(), backwardRelNodes.end()}));
rel.getDstNode()->setEntries(newDstNodeEntries);
}
} else {
auto prunedEntries = pruneNonRecursiveRel(rel.getEntries(), srcTableIDSet, dstTableIDSet,
rel.getDirectionType());
rel.setEntries(prunedEntries);
}
// Note the pruning for node should guarantee the following exception won't be triggered.
// For safety (and consistency) reason, we still write the check but skip coverage check.
// LCOV_EXCL_START
if (prunedEntries.empty()) {
if (rel.getEntries().empty()) {
if (throwOnViolate) {
throw BinderException(stringFormat("Cannot find a label for relationship {} that "
"connects to all of its neighbour nodes.",
Expand Down
3 changes: 2 additions & 1 deletion src/function/gds/asp_destinations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,8 @@ class AllSPDestinationsAlgorithm final : public RJAlgorithm {
std::make_unique<ASPDestinationsEdgeCompute>(frontierPair.get(), multiplicities);
auto auxiliaryState = std::make_unique<ASPDestinationsAuxiliaryState>(multiplicities);
auto gdsState = std::make_unique<GDSComputeState>(std::move(frontierPair),
std::move(edgeCompute), std::move(auxiliaryState), sharedState->getOutputNodeMaskMap());
std::move(edgeCompute), std::move(auxiliaryState),
std::vector<common::table_id_set_t>(), sharedState->getOutputNodeMaskMap());
return RJCompState(std::move(gdsState), std::move(outputWriter));
}
};
Expand Down
3 changes: 2 additions & 1 deletion src/function/gds/asp_paths.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ class AllSPPathsAlgorithm final : public RJAlgorithm {
std::make_unique<ASPPathsEdgeCompute>(frontierPair.get(), bfsGraph.get());
auto auxiliaryState = std::make_unique<PathAuxiliaryState>(std::move(bfsGraph));
auto gdsState = std::make_unique<GDSComputeState>(std::move(frontierPair),
std::move(edgeCompute), std::move(auxiliaryState), sharedState->getOutputNodeMaskMap());
std::move(edgeCompute), std::move(auxiliaryState),
std::vector<common::table_id_set_t>(), sharedState->getOutputNodeMaskMap());
return RJCompState(std::move(gdsState), std::move(outputWriter));
}
};
Expand Down
2 changes: 1 addition & 1 deletion src/function/gds/degrees.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ struct DegreesUtils {
auto ec = std::make_unique<DegreeEdgeCompute>(degrees);
auto auxiliaryState = std::make_unique<EmptyGDSAuxiliaryState>();
auto computeState = GDSComputeState(std::move(frontierPair), std::move(ec),
std::move(auxiliaryState), nullptr);
std::move(auxiliaryState), {}, nullptr);
GDSUtils::runFrontiersUntilConvergence(context, computeState, graph, direction,
1 /* maxIters */);
}
Expand Down
Loading