|
1 | 1 | #include "binder/query/query_graph_label_analyzer.h" |
2 | 2 |
|
3 | 3 | #include "catalog/catalog.h" |
| 4 | +#include "catalog/catalog_entry/node_table_catalog_entry.h" |
4 | 5 | #include "catalog/catalog_entry/rel_table_catalog_entry.h" |
5 | 6 | #include "common/exception/binder.h" |
6 | 7 | #include "common/string_format.h" |
@@ -29,7 +30,7 @@ void QueryGraphLabelAnalyzer::pruneNode(const QueryGraph& graph, NodeExpression& |
29 | 30 | if (queryRel->isRecursive()) { |
30 | 31 | continue; |
31 | 32 | } |
32 | | - common::table_id_set_t candidates; |
| 33 | + table_id_set_t candidates; |
33 | 34 | std::unordered_set<std::string> candidateNamesSet; |
34 | 35 | auto isSrcConnect = *queryRel->getSrcNode() == node; |
35 | 36 | auto isDstConnect = *queryRel->getDstNode() == node; |
@@ -94,49 +95,305 @@ void QueryGraphLabelAnalyzer::pruneNode(const QueryGraph& graph, NodeExpression& |
94 | 95 | } |
95 | 96 | } |
96 | 97 |
|
97 | | -void QueryGraphLabelAnalyzer::pruneRel(RelExpression& rel) const { |
98 | | - if (rel.isRecursive()) { |
99 | | - return; |
100 | | - } |
| 98 | +std::vector<TableCatalogEntry*> QueryGraphLabelAnalyzer::pruneNonRecursiveRel( |
| 99 | + const std::vector<TableCatalogEntry*>& relEntries, const table_id_set_t& srcTableIDSet, |
| 100 | + const table_id_set_t& dstTableIDSet, const RelDirectionType directionType) const { |
| 101 | + |
| 102 | + auto forwardPruningFunc = [&](table_id_t srcTableID, table_id_t dstTableID) { |
| 103 | + return srcTableIDSet.contains(srcTableID) && dstTableIDSet.contains(dstTableID); |
| 104 | + }; |
| 105 | + auto backwardPruningFunc = [&](table_id_t srcTableID, table_id_t dstTableID) { |
| 106 | + return dstTableIDSet.contains(srcTableID) && srcTableIDSet.contains(dstTableID); |
| 107 | + }; |
101 | 108 | std::vector<TableCatalogEntry*> prunedEntries; |
102 | | - if (rel.getDirectionType() == RelDirectionType::BOTH) { |
103 | | - table_id_set_t srcBoundTableIDSet; |
104 | | - table_id_set_t dstBoundTableIDSet; |
105 | | - for (auto entry : rel.getSrcNode()->getEntries()) { |
106 | | - srcBoundTableIDSet.insert(entry->getTableID()); |
| 109 | + for (auto& entry : relEntries) { |
| 110 | + auto& relEntry = entry->constCast<RelTableCatalogEntry>(); |
| 111 | + auto srcTableID = relEntry.getSrcTableID(); |
| 112 | + auto dstTableID = relEntry.getDstTableID(); |
| 113 | + auto satisfyForwardPruning = forwardPruningFunc(srcTableID, dstTableID); |
| 114 | + if (directionType == RelDirectionType::BOTH) { |
| 115 | + if (satisfyForwardPruning || backwardPruningFunc(srcTableID, dstTableID)) { |
| 116 | + prunedEntries.push_back(entry); |
| 117 | + } |
| 118 | + } else { |
| 119 | + if (satisfyForwardPruning) { |
| 120 | + prunedEntries.push_back(entry); |
| 121 | + } |
| 122 | + } |
| 123 | + } |
| 124 | + return prunedEntries; |
| 125 | +} |
| 126 | + |
| 127 | +table_id_set_t QueryGraphLabelAnalyzer::collectRelNodes(const RelDataDirection direction, |
| 128 | + std::vector<TableCatalogEntry*> relEntries) const { |
| 129 | + table_id_set_t nodeIDs; |
| 130 | + for (const auto& entry : relEntries) { |
| 131 | + const auto& relEntry = entry->constCast<RelTableCatalogEntry>(); |
| 132 | + if (direction == RelDataDirection::FWD) { |
| 133 | + nodeIDs.insert(relEntry.getDstTableID()); |
| 134 | + } else if (direction == RelDataDirection::BWD) { |
| 135 | + nodeIDs.insert(relEntry.getSrcTableID()); |
| 136 | + } else { |
| 137 | + KU_UNREACHABLE; |
107 | 138 | } |
108 | | - for (auto entry : rel.getDstNode()->getEntries()) { |
109 | | - dstBoundTableIDSet.insert(entry->getTableID()); |
| 139 | + } |
| 140 | + return nodeIDs; |
| 141 | +} |
| 142 | + |
| 143 | +std::pair<std::vector<table_id_set_t>, std::vector<table_id_set_t>> |
| 144 | +QueryGraphLabelAnalyzer::pruneRecursiveRel(const std::vector<TableCatalogEntry*>& relEntries, |
| 145 | + const table_id_set_t srcTableIDSet, const table_id_set_t dstTableIDSet, size_t lowerBound, |
| 146 | + size_t upperBound, RelDirectionType relDirectionType) const { |
| 147 | + // src-->[dst,[rels]] |
| 148 | + std::unordered_map<table_id_t, std::unordered_map<table_id_t, table_id_vector_t>> |
| 149 | + stepFromLeftGraph; |
| 150 | + std::unordered_map<table_id_t, std::unordered_map<table_id_t, table_id_vector_t>> |
| 151 | + stepFromRightGraph; |
| 152 | + table_id_t maxTableID = 0; |
| 153 | + for (auto entry : relEntries) { |
| 154 | + auto& relEntry = entry->constCast<RelTableCatalogEntry>(); |
| 155 | + auto srcTableID = relEntry.getSrcTableID(); |
| 156 | + auto dstTableID = relEntry.getDstTableID(); |
| 157 | + auto tableID = relEntry.getTableID(); |
| 158 | + stepFromLeftGraph[srcTableID][dstTableID].push_back(tableID); |
| 159 | + stepFromRightGraph[dstTableID][srcTableID].push_back(tableID); |
| 160 | + if (relDirectionType == RelDirectionType::BOTH) { |
| 161 | + stepFromLeftGraph[dstTableID][srcTableID].push_back(tableID); |
| 162 | + stepFromRightGraph[srcTableID][dstTableID].push_back(tableID); |
110 | 163 | } |
111 | | - for (auto& entry : rel.getEntries()) { |
112 | | - auto& relEntry = entry->constCast<RelTableCatalogEntry>(); |
113 | | - auto srcTableID = relEntry.getSrcTableID(); |
114 | | - auto dstTableID = relEntry.getDstTableID(); |
115 | | - if ((srcBoundTableIDSet.contains(srcTableID) && |
116 | | - dstBoundTableIDSet.contains(dstTableID)) || |
117 | | - (dstBoundTableIDSet.contains(srcTableID) && |
118 | | - srcBoundTableIDSet.contains(dstTableID))) { |
119 | | - prunedEntries.push_back(entry); |
| 164 | + maxTableID = std::max(maxTableID, tableID); |
| 165 | + } |
| 166 | + |
| 167 | + auto stepFromLeft = pruneRecursiveRel(stepFromLeftGraph, stepFromRightGraph, srcTableIDSet, |
| 168 | + dstTableIDSet, lowerBound, upperBound, maxTableID); |
| 169 | + auto stepFromRight = pruneRecursiveRel(stepFromRightGraph, stepFromLeftGraph, dstTableIDSet, |
| 170 | + srcTableIDSet, lowerBound, upperBound, maxTableID); |
| 171 | + return {stepFromLeft, stepFromRight}; |
| 172 | +} |
| 173 | + |
| 174 | +std::vector<table_id_set_t> QueryGraphLabelAnalyzer::pruneRecursiveRel( |
| 175 | + const std::unordered_map<table_id_t, std::unordered_map<table_id_t, table_id_vector_t>>& graph, |
| 176 | + const std::unordered_map<table_id_t, std::unordered_map<table_id_t, table_id_vector_t>>& |
| 177 | + reseveGraph, |
| 178 | + const table_id_set_t& startTableIDSet, const table_id_set_t& endTableIDSet, size_t lowerBound, |
| 179 | + size_t upperBound, table_id_t maxTableID) const { |
| 180 | + |
| 181 | + // f[i][j] represent whether the edge numbered i can be reached by jumping j times through the |
| 182 | + // set A. |
| 183 | + std::unordered_map<table_id_t, std::vector<bool>> f, g; |
| 184 | + |
| 185 | + auto initFunc = [upperBound](const std::unordered_map<table_id_t, |
| 186 | + std::unordered_map<table_id_t, table_id_vector_t>>& _graph, |
| 187 | + std::unordered_map<table_id_t, std::vector<bool>>& ans, |
| 188 | + const table_id_set_t& beginTableIDSet) { |
| 189 | + for (auto [_, map] : _graph) { |
| 190 | + for (auto [_, rels] : map) { |
| 191 | + for (auto rel : rels) { |
| 192 | + ans.emplace(rel, std::vector<bool>(upperBound + 1, false)); |
| 193 | + } |
120 | 194 | } |
121 | 195 | } |
122 | | - } else { |
123 | | - auto srcTableIDSet = rel.getSrcNode()->getTableIDsSet(); |
124 | | - auto dstTableIDSet = rel.getDstNode()->getTableIDsSet(); |
125 | | - for (auto& entry : rel.getEntries()) { |
126 | | - auto& relEntry = entry->constCast<RelTableCatalogEntry>(); |
127 | | - auto srcTableID = relEntry.getSrcTableID(); |
128 | | - auto dstTableID = relEntry.getDstTableID(); |
129 | | - if (!srcTableIDSet.contains(srcTableID) || !dstTableIDSet.contains(dstTableID)) { |
| 196 | + |
| 197 | + for (auto tableID : beginTableIDSet) { |
| 198 | + if (!_graph.contains(tableID)) { |
130 | 199 | continue; |
131 | 200 | } |
132 | | - prunedEntries.push_back(entry); |
| 201 | + for (auto [dst, rels] : _graph.at(tableID)) { |
| 202 | + for (auto rel : rels) { |
| 203 | + ans[rel][1] = true; |
| 204 | + ans[rel][0] = true; |
| 205 | + } |
| 206 | + } |
| 207 | + } |
| 208 | + }; |
| 209 | + |
| 210 | + initFunc(graph, f, startTableIDSet); |
| 211 | + initFunc(reseveGraph, g, endTableIDSet); |
| 212 | + |
| 213 | + auto isOk = [&](const table_id_vector_t& rels, int j, |
| 214 | + std::unordered_map<table_id_t, std::vector<bool>>& map) -> bool { |
| 215 | + for (auto rel : rels) { |
| 216 | + if (map[rel][j - 1]) { |
| 217 | + return true; |
| 218 | + } |
| 219 | + } |
| 220 | + return false; |
| 221 | + }; |
| 222 | + |
| 223 | + auto bfsFunc = |
| 224 | + [upperBound, maxTableID, isOk]( |
| 225 | + const std::unordered_map<table_id_t, std::unordered_map<table_id_t, table_id_vector_t>>& |
| 226 | + _graph, |
| 227 | + const std::unordered_map<table_id_t, std::unordered_map<table_id_t, table_id_vector_t>>& |
| 228 | + _reseveGraph, |
| 229 | + std::unordered_map<table_id_t, std::vector<bool>>& map) { |
| 230 | + for (int j = 2; j <= upperBound; ++j) { |
| 231 | + for (auto v = 0u; v < maxTableID; ++v) { |
| 232 | + bool flag = false; |
| 233 | + if (_reseveGraph.contains(v)) { |
| 234 | + for (auto [_, rels] : _reseveGraph.at(v)) { |
| 235 | + if (isOk(rels, j, map)) { |
| 236 | + flag = true; |
| 237 | + break; |
| 238 | + } |
| 239 | + } |
| 240 | + } |
| 241 | + |
| 242 | + if (flag && _graph.contains(v)) { |
| 243 | + for (auto [dst, rels] : _graph.at(v)) { |
| 244 | + for (auto rel : rels) { |
| 245 | + map[rel][j] = true; |
| 246 | + } |
| 247 | + } |
| 248 | + } |
| 249 | + } |
| 250 | + } |
| 251 | + }; |
| 252 | + |
| 253 | + bfsFunc(graph, reseveGraph, f); |
| 254 | + bfsFunc(reseveGraph, graph, g); |
| 255 | + |
| 256 | + std::vector<table_id_set_t> stepActiveTableIDs(upperBound); |
| 257 | + for (auto [rel, vector] : f) { |
| 258 | + for (int j = 0; j <= upperBound; ++j) { |
| 259 | + if (!vector[j]) { |
| 260 | + continue; |
| 261 | + } |
| 262 | + for (int k = 0; k <= upperBound; ++k) { |
| 263 | + if (!g[rel][k]) { |
| 264 | + continue; |
| 265 | + } |
| 266 | + auto step = j + k; |
| 267 | + if (step != upperBound) { |
| 268 | + // rel repeat count |
| 269 | + step--; |
| 270 | + } |
| 271 | + if (step < lowerBound) { |
| 272 | + continue; |
| 273 | + } else if (step > upperBound) { |
| 274 | + break; |
| 275 | + } else { |
| 276 | + int index = j == 0 ? 0 : j - 1; |
| 277 | + stepActiveTableIDs[index].emplace(rel); |
| 278 | + break; |
| 279 | + } |
| 280 | + } |
133 | 281 | } |
134 | 282 | } |
135 | | - rel.setEntries(prunedEntries); |
| 283 | + return stepActiveTableIDs; |
| 284 | +} |
| 285 | + |
| 286 | +std::vector<TableCatalogEntry*> QueryGraphLabelAnalyzer::getTableCatalogEntries( |
| 287 | + table_id_set_t tableIDs) const { |
| 288 | + std::vector<TableCatalogEntry*> relEntries; |
| 289 | + for (const auto& tableID : tableIDs) { |
| 290 | + relEntries.push_back(catalog->getTableCatalogEntry(tx, tableID)); |
| 291 | + } |
| 292 | + return relEntries; |
| 293 | +} |
| 294 | + |
| 295 | +std::vector<table_id_t> QueryGraphLabelAnalyzer::getNodeTableIDs() const { |
| 296 | + std::vector<table_id_t> nodeTableIDs; |
| 297 | + for (auto node_table_entry : catalog->getNodeTableEntries(tx)) { |
| 298 | + nodeTableIDs.push_back(node_table_entry->getTableID()); |
| 299 | + } |
| 300 | + return nodeTableIDs; |
| 301 | +} |
| 302 | + |
| 303 | +std::unordered_set<TableCatalogEntry*> QueryGraphLabelAnalyzer::mergeTableIDs( |
| 304 | + const std::vector<table_id_set_t>& v1, const std::vector<table_id_set_t>& v2) const { |
| 305 | + std::unordered_set<table_id_t> temp; |
| 306 | + for (auto tableIDs : v1) { |
| 307 | + temp.insert(tableIDs.begin(), tableIDs.end()); |
| 308 | + } |
| 309 | + for (auto tableIDs : v2) { |
| 310 | + temp.insert(tableIDs.begin(), tableIDs.end()); |
| 311 | + } |
| 312 | + std::unordered_set<TableCatalogEntry*> ans; |
| 313 | + for (table_id_t tableID : temp) { |
| 314 | + ans.emplace(catalog->getTableCatalogEntry(tx, tableID)); |
| 315 | + } |
| 316 | + return ans; |
| 317 | +} |
| 318 | + |
| 319 | +static std::vector<catalog::TableCatalogEntry*> intersectEntries( |
| 320 | + std::vector<catalog::TableCatalogEntry*> v1, std::vector<catalog::TableCatalogEntry*> v2) { |
| 321 | + std::sort(v1.begin(), v1.end()); |
| 322 | + std::sort(v2.begin(), v2.end()); |
| 323 | + std::vector<catalog::TableCatalogEntry*> intersection; |
| 324 | + std::set_intersection(v1.begin(), v1.end(), v2.begin(), v2.end(), |
| 325 | + std::back_inserter(intersection)); |
| 326 | + return intersection; |
| 327 | +} |
| 328 | + |
| 329 | +static bool isSameTableCatalogEntryVector(std::vector<TableCatalogEntry*> v1, |
| 330 | + std::vector<TableCatalogEntry*> v2) { |
| 331 | + auto compareFunc = [](TableCatalogEntry* a, TableCatalogEntry* b) { |
| 332 | + return a->getTableID() < b->getTableID(); |
| 333 | + }; |
| 334 | + std::sort(v1.begin(), v1.end(), compareFunc); |
| 335 | + std::sort(v2.begin(), v2.end(), compareFunc); |
| 336 | + return std::equal(v1.begin(), v1.end(), v2.begin(), v2.end()); |
| 337 | +} |
| 338 | + |
| 339 | +void QueryGraphLabelAnalyzer::pruneRel(RelExpression& rel) const { |
| 340 | + auto srcTableIDSet = rel.getSrcNode()->getTableIDsSet(); |
| 341 | + auto dstTableIDSet = rel.getDstNode()->getTableIDsSet(); |
| 342 | + if (rel.isRecursive()) { |
| 343 | + auto nodeTableIDs = getNodeTableIDs(); |
| 344 | + // there is no label on both sides |
| 345 | + if (rel.getUpperBound() == 0 || rel.getEntries().empty()) { |
| 346 | + return; |
| 347 | + } |
| 348 | + |
| 349 | + auto [stepFromLeftTableIDs, stepFromRightTableIDs] = |
| 350 | + pruneRecursiveRel(rel.getEntries(), srcTableIDSet, dstTableIDSet, rel.getLowerBound(), |
| 351 | + rel.getUpperBound(), rel.getDirectionType()); |
| 352 | + auto recursiveInfo = rel.getRecursiveInfoUnsafe(); |
| 353 | + recursiveInfo->stepFromLeftActivationRelInfos = stepFromLeftTableIDs; |
| 354 | + recursiveInfo->stepFromRightActivationRelInfos = stepFromRightTableIDs; |
| 355 | + // todo we need reset rel entries? |
| 356 | + auto temp = mergeTableIDs(stepFromLeftTableIDs, stepFromRightTableIDs); |
| 357 | + std::vector<TableCatalogEntry*> newRelEntries{temp.begin(), temp.end()}; |
| 358 | + if (!isSameTableCatalogEntryVector(newRelEntries, rel.getEntries())) { |
| 359 | + rel.setEntries(newRelEntries); |
| 360 | + recursiveInfo->rel->setEntries(newRelEntries); |
| 361 | + // update src&dst entries |
| 362 | + auto forwardRelNodes = collectRelNodes(RelDataDirection::BWD, |
| 363 | + getTableCatalogEntries(stepFromLeftTableIDs.front())); |
| 364 | + |
| 365 | + std::unordered_set<table_id_t> backwardRelNodes; |
| 366 | + for (auto i = rel.getLowerBound(); i <= rel.getUpperBound(); ++i) { |
| 367 | + if (i == 0) { |
| 368 | + continue; |
| 369 | + } |
| 370 | + const auto relSrcNodes = collectRelNodes(RelDataDirection::FWD, |
| 371 | + getTableCatalogEntries(stepFromLeftTableIDs.at(i - 1))); |
| 372 | + backwardRelNodes.insert(relSrcNodes.begin(), relSrcNodes.end()); |
| 373 | + } |
| 374 | + |
| 375 | + if (rel.getDirectionType() == RelDirectionType::BOTH) { |
| 376 | + forwardRelNodes.insert(backwardRelNodes.begin(), backwardRelNodes.end()); |
| 377 | + backwardRelNodes = forwardRelNodes; |
| 378 | + } |
| 379 | + |
| 380 | + auto newSrcNodeEntries = intersectEntries(rel.getSrcNode()->getEntries(), |
| 381 | + getTableCatalogEntries({forwardRelNodes.begin(), forwardRelNodes.end()})); |
| 382 | + rel.getSrcNode()->setEntries(newSrcNodeEntries); |
| 383 | + |
| 384 | + auto newDstNodeEntries = intersectEntries(rel.getDstNode()->getEntries(), |
| 385 | + getTableCatalogEntries({backwardRelNodes.begin(), backwardRelNodes.end()})); |
| 386 | + rel.getDstNode()->setEntries(newDstNodeEntries); |
| 387 | + } |
| 388 | + } else { |
| 389 | + auto prunedEntries = pruneNonRecursiveRel(rel.getEntries(), srcTableIDSet, dstTableIDSet, |
| 390 | + rel.getDirectionType()); |
| 391 | + rel.setEntries(prunedEntries); |
| 392 | + } |
136 | 393 | // Note the pruning for node should guarantee the following exception won't be triggered. |
137 | 394 | // For safety (and consistency) reason, we still write the check but skip coverage check. |
138 | 395 | // LCOV_EXCL_START |
139 | | - if (prunedEntries.empty()) { |
| 396 | + if (rel.getEntries().empty()) { |
140 | 397 | if (throwOnViolate) { |
141 | 398 | throw BinderException(stringFormat("Cannot find a label for relationship {} that " |
142 | 399 | "connects to all of its neighbour nodes.", |
|
0 commit comments