Skip to content
This repository was archived by the owner on Oct 10, 2025. It is now read-only.

Commit 78f3840

Browse files
author
wangqiang
committed
Support gds var length label pruning
1 parent ab5e13a commit 78f3840

25 files changed

+430
-59
lines changed

extension/fts/src/function/query_fts_index.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ void QueryFTSAlgorithm::exec(processor::ExecutionContext* executionContext) {
265265
auto edgeCompute = std::make_unique<QFTSEdgeCompute>(scores, dfs);
266266
auto auxiliaryState = std::make_unique<EmptyGDSAuxiliaryState>();
267267
auto compState = GDSComputeState(std::move(frontierPair), std::move(edgeCompute),
268-
std::move(auxiliaryState), nullptr /* outputNodeMask */);
268+
std::move(auxiliaryState), {}, nullptr /* outputNodeMask */);
269269
GDSUtils::runFrontiersUntilConvergence(executionContext, compState, graph, ExtendDirection::FWD,
270270
1 /* maxIters */, QueryFTSAlgorithm::TERM_FREQUENCY_PROP_NAME);
271271

src/binder/query/query_graph_label_analyzer.cpp

Lines changed: 289 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "binder/query/query_graph_label_analyzer.h"
22

33
#include "catalog/catalog.h"
4+
#include "catalog/catalog_entry/node_table_catalog_entry.h"
45
#include "catalog/catalog_entry/rel_table_catalog_entry.h"
56
#include "common/exception/binder.h"
67
#include "common/string_format.h"
@@ -29,7 +30,7 @@ void QueryGraphLabelAnalyzer::pruneNode(const QueryGraph& graph, NodeExpression&
2930
if (queryRel->isRecursive()) {
3031
continue;
3132
}
32-
common::table_id_set_t candidates;
33+
table_id_set_t candidates;
3334
std::unordered_set<std::string> candidateNamesSet;
3435
auto isSrcConnect = *queryRel->getSrcNode() == node;
3536
auto isDstConnect = *queryRel->getDstNode() == node;
@@ -94,49 +95,305 @@ void QueryGraphLabelAnalyzer::pruneNode(const QueryGraph& graph, NodeExpression&
9495
}
9596
}
9697

97-
void QueryGraphLabelAnalyzer::pruneRel(RelExpression& rel) const {
98-
if (rel.isRecursive()) {
99-
return;
100-
}
98+
std::vector<TableCatalogEntry*> QueryGraphLabelAnalyzer::pruneNonRecursiveRel(
99+
const std::vector<TableCatalogEntry*>& relEntries, const table_id_set_t& srcTableIDSet,
100+
const table_id_set_t& dstTableIDSet, const RelDirectionType directionType) const {
101+
102+
auto forwardPruningFunc = [&](table_id_t srcTableID, table_id_t dstTableID) {
103+
return srcTableIDSet.contains(srcTableID) && dstTableIDSet.contains(dstTableID);
104+
};
105+
auto backwardPruningFunc = [&](table_id_t srcTableID, table_id_t dstTableID) {
106+
return dstTableIDSet.contains(srcTableID) && srcTableIDSet.contains(dstTableID);
107+
};
101108
std::vector<TableCatalogEntry*> prunedEntries;
102-
if (rel.getDirectionType() == RelDirectionType::BOTH) {
103-
table_id_set_t srcBoundTableIDSet;
104-
table_id_set_t dstBoundTableIDSet;
105-
for (auto entry : rel.getSrcNode()->getEntries()) {
106-
srcBoundTableIDSet.insert(entry->getTableID());
109+
for (auto& entry : relEntries) {
110+
auto& relEntry = entry->constCast<RelTableCatalogEntry>();
111+
auto srcTableID = relEntry.getSrcTableID();
112+
auto dstTableID = relEntry.getDstTableID();
113+
auto satisfyForwardPruning = forwardPruningFunc(srcTableID, dstTableID);
114+
if (directionType == RelDirectionType::BOTH) {
115+
if (satisfyForwardPruning || backwardPruningFunc(srcTableID, dstTableID)) {
116+
prunedEntries.push_back(entry);
117+
}
118+
} else {
119+
if (satisfyForwardPruning) {
120+
prunedEntries.push_back(entry);
121+
}
122+
}
123+
}
124+
return prunedEntries;
125+
}
126+
127+
table_id_set_t QueryGraphLabelAnalyzer::collectRelNodes(const RelDataDirection direction,
128+
std::vector<TableCatalogEntry*> relEntries) const {
129+
table_id_set_t nodeIDs;
130+
for (const auto& entry : relEntries) {
131+
const auto& relEntry = entry->constCast<RelTableCatalogEntry>();
132+
if (direction == RelDataDirection::FWD) {
133+
nodeIDs.insert(relEntry.getDstTableID());
134+
} else if (direction == RelDataDirection::BWD) {
135+
nodeIDs.insert(relEntry.getSrcTableID());
136+
} else {
137+
KU_UNREACHABLE;
107138
}
108-
for (auto entry : rel.getDstNode()->getEntries()) {
109-
dstBoundTableIDSet.insert(entry->getTableID());
139+
}
140+
return nodeIDs;
141+
}
142+
143+
std::pair<std::vector<table_id_set_t>, std::vector<table_id_set_t>>
144+
QueryGraphLabelAnalyzer::pruneRecursiveRel(const std::vector<TableCatalogEntry*>& relEntries,
145+
const table_id_set_t srcTableIDSet, const table_id_set_t dstTableIDSet, size_t lowerBound,
146+
size_t upperBound, RelDirectionType relDirectionType) const {
147+
// src-->[dst,[rels]]
148+
std::unordered_map<table_id_t, std::unordered_map<table_id_t, table_id_vector_t>>
149+
stepFromLeftGraph;
150+
std::unordered_map<table_id_t, std::unordered_map<table_id_t, table_id_vector_t>>
151+
stepFromRightGraph;
152+
table_id_t maxTableID = 0;
153+
for (auto entry : relEntries) {
154+
auto& relEntry = entry->constCast<RelTableCatalogEntry>();
155+
auto srcTableID = relEntry.getSrcTableID();
156+
auto dstTableID = relEntry.getDstTableID();
157+
auto tableID = relEntry.getTableID();
158+
stepFromLeftGraph[srcTableID][dstTableID].push_back(tableID);
159+
stepFromRightGraph[dstTableID][srcTableID].push_back(tableID);
160+
if (relDirectionType == RelDirectionType::BOTH) {
161+
stepFromLeftGraph[dstTableID][srcTableID].push_back(tableID);
162+
stepFromRightGraph[srcTableID][dstTableID].push_back(tableID);
110163
}
111-
for (auto& entry : rel.getEntries()) {
112-
auto& relEntry = entry->constCast<RelTableCatalogEntry>();
113-
auto srcTableID = relEntry.getSrcTableID();
114-
auto dstTableID = relEntry.getDstTableID();
115-
if ((srcBoundTableIDSet.contains(srcTableID) &&
116-
dstBoundTableIDSet.contains(dstTableID)) ||
117-
(dstBoundTableIDSet.contains(srcTableID) &&
118-
srcBoundTableIDSet.contains(dstTableID))) {
119-
prunedEntries.push_back(entry);
164+
maxTableID = std::max(maxTableID, tableID);
165+
}
166+
167+
auto stepFromLeft = pruneRecursiveRel(stepFromLeftGraph, stepFromRightGraph, srcTableIDSet,
168+
dstTableIDSet, lowerBound, upperBound, maxTableID);
169+
auto stepFromRight = pruneRecursiveRel(stepFromRightGraph, stepFromLeftGraph, dstTableIDSet,
170+
srcTableIDSet, lowerBound, upperBound, maxTableID);
171+
return {stepFromLeft, stepFromRight};
172+
}
173+
174+
std::vector<table_id_set_t> QueryGraphLabelAnalyzer::pruneRecursiveRel(
175+
const std::unordered_map<table_id_t, std::unordered_map<table_id_t, table_id_vector_t>>& graph,
176+
const std::unordered_map<table_id_t, std::unordered_map<table_id_t, table_id_vector_t>>&
177+
reseveGraph,
178+
const table_id_set_t& startTableIDSet, const table_id_set_t& endTableIDSet, size_t lowerBound,
179+
size_t upperBound, table_id_t maxTableID) const {
180+
181+
// f[i][j] represent whether the edge numbered i can be reached by jumping j times through the
182+
// set A.
183+
std::unordered_map<table_id_t, std::vector<bool>> f, g;
184+
185+
auto initFunc = [upperBound](const std::unordered_map<table_id_t,
186+
std::unordered_map<table_id_t, table_id_vector_t>>& _graph,
187+
std::unordered_map<table_id_t, std::vector<bool>>& ans,
188+
const table_id_set_t& beginTableIDSet) {
189+
for (auto [_, map] : _graph) {
190+
for (auto [_, rels] : map) {
191+
for (auto rel : rels) {
192+
ans.emplace(rel, std::vector<bool>(upperBound + 1, false));
193+
}
120194
}
121195
}
122-
} else {
123-
auto srcTableIDSet = rel.getSrcNode()->getTableIDsSet();
124-
auto dstTableIDSet = rel.getDstNode()->getTableIDsSet();
125-
for (auto& entry : rel.getEntries()) {
126-
auto& relEntry = entry->constCast<RelTableCatalogEntry>();
127-
auto srcTableID = relEntry.getSrcTableID();
128-
auto dstTableID = relEntry.getDstTableID();
129-
if (!srcTableIDSet.contains(srcTableID) || !dstTableIDSet.contains(dstTableID)) {
196+
197+
for (auto tableID : beginTableIDSet) {
198+
if (!_graph.contains(tableID)) {
130199
continue;
131200
}
132-
prunedEntries.push_back(entry);
201+
for (auto [dst, rels] : _graph.at(tableID)) {
202+
for (auto rel : rels) {
203+
ans[rel][1] = true;
204+
ans[rel][0] = true;
205+
}
206+
}
207+
}
208+
};
209+
210+
initFunc(graph, f, startTableIDSet);
211+
initFunc(reseveGraph, g, endTableIDSet);
212+
213+
auto isOk = [&](const table_id_vector_t& rels, int j,
214+
std::unordered_map<table_id_t, std::vector<bool>>& map) -> bool {
215+
for (auto rel : rels) {
216+
if (map[rel][j - 1]) {
217+
return true;
218+
}
219+
}
220+
return false;
221+
};
222+
223+
auto bfsFunc =
224+
[upperBound, maxTableID, isOk](
225+
const std::unordered_map<table_id_t, std::unordered_map<table_id_t, table_id_vector_t>>&
226+
_graph,
227+
const std::unordered_map<table_id_t, std::unordered_map<table_id_t, table_id_vector_t>>&
228+
_reseveGraph,
229+
std::unordered_map<table_id_t, std::vector<bool>>& map) {
230+
for (int j = 2; j <= upperBound; ++j) {
231+
for (auto v = 0u; v < maxTableID; ++v) {
232+
bool flag = false;
233+
if (_reseveGraph.contains(v)) {
234+
for (auto [_, rels] : _reseveGraph.at(v)) {
235+
if (isOk(rels, j, map)) {
236+
flag = true;
237+
break;
238+
}
239+
}
240+
}
241+
242+
if (flag && _graph.contains(v)) {
243+
for (auto [dst, rels] : _graph.at(v)) {
244+
for (auto rel : rels) {
245+
map[rel][j] = true;
246+
}
247+
}
248+
}
249+
}
250+
}
251+
};
252+
253+
bfsFunc(graph, reseveGraph, f);
254+
bfsFunc(reseveGraph, graph, g);
255+
256+
std::vector<table_id_set_t> stepActiveTableIDs(upperBound);
257+
for (auto [rel, vector] : f) {
258+
for (int j = 0; j <= upperBound; ++j) {
259+
if (!vector[j]) {
260+
continue;
261+
}
262+
for (int k = 0; k <= upperBound; ++k) {
263+
if (!g[rel][k]) {
264+
continue;
265+
}
266+
auto step = j + k;
267+
if (step != upperBound) {
268+
// rel repeat count
269+
step--;
270+
}
271+
if (step < lowerBound) {
272+
continue;
273+
} else if (step > upperBound) {
274+
break;
275+
} else {
276+
int index = j == 0 ? 0 : j - 1;
277+
stepActiveTableIDs[index].emplace(rel);
278+
break;
279+
}
280+
}
133281
}
134282
}
135-
rel.setEntries(prunedEntries);
283+
return stepActiveTableIDs;
284+
}
285+
286+
std::vector<TableCatalogEntry*> QueryGraphLabelAnalyzer::getTableCatalogEntries(
287+
table_id_set_t tableIDs) const {
288+
std::vector<TableCatalogEntry*> relEntries;
289+
for (const auto& tableID : tableIDs) {
290+
relEntries.push_back(catalog->getTableCatalogEntry(tx, tableID));
291+
}
292+
return relEntries;
293+
}
294+
295+
std::vector<table_id_t> QueryGraphLabelAnalyzer::getNodeTableIDs() const {
296+
std::vector<table_id_t> nodeTableIDs;
297+
for (auto node_table_entry : catalog->getNodeTableEntries(tx)) {
298+
nodeTableIDs.push_back(node_table_entry->getTableID());
299+
}
300+
return nodeTableIDs;
301+
}
302+
303+
std::unordered_set<TableCatalogEntry*> QueryGraphLabelAnalyzer::mergeTableIDs(
304+
const std::vector<table_id_set_t>& v1, const std::vector<table_id_set_t>& v2) const {
305+
std::unordered_set<table_id_t> temp;
306+
for (auto tableIDs : v1) {
307+
temp.insert(tableIDs.begin(), tableIDs.end());
308+
}
309+
for (auto tableIDs : v2) {
310+
temp.insert(tableIDs.begin(), tableIDs.end());
311+
}
312+
std::unordered_set<TableCatalogEntry*> ans;
313+
for (table_id_t tableID : temp) {
314+
ans.emplace(catalog->getTableCatalogEntry(tx, tableID));
315+
}
316+
return ans;
317+
}
318+
319+
static std::vector<catalog::TableCatalogEntry*> intersectEntries(
320+
std::vector<catalog::TableCatalogEntry*> v1, std::vector<catalog::TableCatalogEntry*> v2) {
321+
std::sort(v1.begin(), v1.end());
322+
std::sort(v2.begin(), v2.end());
323+
std::vector<catalog::TableCatalogEntry*> intersection;
324+
std::set_intersection(v1.begin(), v1.end(), v2.begin(), v2.end(),
325+
std::back_inserter(intersection));
326+
return intersection;
327+
}
328+
329+
static bool isSameTableCatalogEntryVector(std::vector<TableCatalogEntry*> v1,
330+
std::vector<TableCatalogEntry*> v2) {
331+
auto compareFunc = [](TableCatalogEntry* a, TableCatalogEntry* b) {
332+
return a->getTableID() < b->getTableID();
333+
};
334+
std::sort(v1.begin(), v1.end(), compareFunc);
335+
std::sort(v2.begin(), v2.end(), compareFunc);
336+
return std::equal(v1.begin(), v1.end(), v2.begin(), v2.end());
337+
}
338+
339+
void QueryGraphLabelAnalyzer::pruneRel(RelExpression& rel) const {
340+
auto srcTableIDSet = rel.getSrcNode()->getTableIDsSet();
341+
auto dstTableIDSet = rel.getDstNode()->getTableIDsSet();
342+
if (rel.isRecursive()) {
343+
auto nodeTableIDs = getNodeTableIDs();
344+
// there is no label on both sides
345+
if (rel.getUpperBound() == 0 || rel.getEntries().empty()) {
346+
return;
347+
}
348+
349+
auto [stepFromLeftTableIDs, stepFromRightTableIDs] =
350+
pruneRecursiveRel(rel.getEntries(), srcTableIDSet, dstTableIDSet, rel.getLowerBound(),
351+
rel.getUpperBound(), rel.getDirectionType());
352+
auto recursiveInfo = rel.getRecursiveInfoUnsafe();
353+
recursiveInfo->stepFromLeftActivationRelInfos = stepFromLeftTableIDs;
354+
recursiveInfo->stepFromRightActivationRelInfos = stepFromRightTableIDs;
355+
// todo we need reset rel entries?
356+
auto temp = mergeTableIDs(stepFromLeftTableIDs, stepFromRightTableIDs);
357+
std::vector<TableCatalogEntry*> newRelEntries{temp.begin(), temp.end()};
358+
if (!isSameTableCatalogEntryVector(newRelEntries, rel.getEntries())) {
359+
rel.setEntries(newRelEntries);
360+
recursiveInfo->rel->setEntries(newRelEntries);
361+
// update src&dst entries
362+
auto forwardRelNodes = collectRelNodes(RelDataDirection::BWD,
363+
getTableCatalogEntries(stepFromLeftTableIDs.front()));
364+
365+
std::unordered_set<table_id_t> backwardRelNodes;
366+
for (auto i = rel.getLowerBound(); i <= rel.getUpperBound(); ++i) {
367+
if (i == 0) {
368+
continue;
369+
}
370+
const auto relSrcNodes = collectRelNodes(RelDataDirection::FWD,
371+
getTableCatalogEntries(stepFromLeftTableIDs.at(i - 1)));
372+
backwardRelNodes.insert(relSrcNodes.begin(), relSrcNodes.end());
373+
}
374+
375+
if (rel.getDirectionType() == RelDirectionType::BOTH) {
376+
forwardRelNodes.insert(backwardRelNodes.begin(), backwardRelNodes.end());
377+
backwardRelNodes = forwardRelNodes;
378+
}
379+
380+
auto newSrcNodeEntries = intersectEntries(rel.getSrcNode()->getEntries(),
381+
getTableCatalogEntries({forwardRelNodes.begin(), forwardRelNodes.end()}));
382+
rel.getSrcNode()->setEntries(newSrcNodeEntries);
383+
384+
auto newDstNodeEntries = intersectEntries(rel.getDstNode()->getEntries(),
385+
getTableCatalogEntries({backwardRelNodes.begin(), backwardRelNodes.end()}));
386+
rel.getDstNode()->setEntries(newDstNodeEntries);
387+
}
388+
} else {
389+
auto prunedEntries = pruneNonRecursiveRel(rel.getEntries(), srcTableIDSet, dstTableIDSet,
390+
rel.getDirectionType());
391+
rel.setEntries(prunedEntries);
392+
}
136393
// Note the pruning for node should guarantee the following exception won't be triggered.
137394
// For safety (and consistency) reason, we still write the check but skip coverage check.
138395
// LCOV_EXCL_START
139-
if (prunedEntries.empty()) {
396+
if (rel.getEntries().empty()) {
140397
if (throwOnViolate) {
141398
throw BinderException(stringFormat("Cannot find a label for relationship {} that "
142399
"connects to all of its neighbour nodes.",

src/function/gds/asp_destinations.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,8 @@ class AllSPDestinationsAlgorithm final : public RJAlgorithm {
196196
std::make_unique<ASPDestinationsEdgeCompute>(frontierPair.get(), multiplicities);
197197
auto auxiliaryState = std::make_unique<ASPDestinationsAuxiliaryState>(multiplicities);
198198
auto gdsState = std::make_unique<GDSComputeState>(std::move(frontierPair),
199-
std::move(edgeCompute), std::move(auxiliaryState), sharedState->getOutputNodeMaskMap());
199+
std::move(edgeCompute), std::move(auxiliaryState),
200+
std::vector<common::table_id_set_t>(), sharedState->getOutputNodeMaskMap());
200201
return RJCompState(std::move(gdsState), std::move(outputWriter));
201202
}
202203
};

src/function/gds/asp_paths.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,8 @@ class AllSPPathsAlgorithm final : public RJAlgorithm {
9191
std::make_unique<ASPPathsEdgeCompute>(frontierPair.get(), bfsGraph.get());
9292
auto auxiliaryState = std::make_unique<PathAuxiliaryState>(std::move(bfsGraph));
9393
auto gdsState = std::make_unique<GDSComputeState>(std::move(frontierPair),
94-
std::move(edgeCompute), std::move(auxiliaryState), sharedState->getOutputNodeMaskMap());
94+
std::move(edgeCompute), std::move(auxiliaryState),
95+
std::vector<common::table_id_set_t>(), sharedState->getOutputNodeMaskMap());
9596
return RJCompState(std::move(gdsState), std::move(outputWriter));
9697
}
9798
};

src/function/gds/degrees.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ struct DegreesUtils {
7272
auto ec = std::make_unique<DegreeEdgeCompute>(degrees);
7373
auto auxiliaryState = std::make_unique<EmptyGDSAuxiliaryState>();
7474
auto computeState = GDSComputeState(std::move(frontierPair), std::move(ec),
75-
std::move(auxiliaryState), nullptr);
75+
std::move(auxiliaryState), {}, nullptr);
7676
GDSUtils::runFrontiersUntilConvergence(context, computeState, graph, direction,
7777
1 /* maxIters */);
7878
}

0 commit comments

Comments
 (0)