Skip to content

Commit c67397e

Browse files
JkSelfGlutenPerfBot
authored and
GlutenPerfBot
committed
[11771] [11772] Fix smj result mismatch issue
1 parent fa33e12 commit c67397e

File tree

5 files changed

+272
-113
lines changed

5 files changed

+272
-113
lines changed

velox/exec/MergeJoin.cpp

+102-104
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
#include "velox/exec/Task.h"
1919
#include "velox/expression/FieldReference.h"
2020

21+
#include <iostream>
22+
2123
namespace facebook::velox::exec {
2224

2325
MergeJoin::MergeJoin(
@@ -92,7 +94,7 @@ void MergeJoin::initialize() {
9294
joinNode_->isRightJoin() || joinNode_->isFullJoin()) {
9395
joinTracker_ = JoinTracker(outputBatchSize_, pool());
9496
}
95-
} else if (joinNode_->isAntiJoin()) {
97+
} else if (joinNode_->isAntiJoin() || joinNode_->isFullJoin()) {
9698
// Anti join needs to track the left side rows that have no match on the
9799
// right.
98100
joinTracker_ = JoinTracker(outputBatchSize_, pool());
@@ -386,7 +388,8 @@ bool MergeJoin::tryAddOutputRow(
386388
const RowVectorPtr& leftBatch,
387389
vector_size_t leftRow,
388390
const RowVectorPtr& rightBatch,
389-
vector_size_t rightRow) {
391+
vector_size_t rightRow,
392+
bool isRightJoinForFullOuter) {
390393
if (outputSize_ == outputBatchSize_) {
391394
return false;
392395
}
@@ -420,12 +423,15 @@ bool MergeJoin::tryAddOutputRow(
420423
filterRightInputProjections_);
421424

422425
if (joinTracker_) {
423-
if (isRightJoin(joinType_)) {
426+
if (isRightJoin(joinType_) ||
427+
(isFullJoin(joinType_) && isRightJoinForFullOuter)) {
424428
// Record right-side row with a match on the left-side.
425-
joinTracker_->addMatch(rightBatch, rightRow, outputSize_);
429+
joinTracker_->addMatch(
430+
rightBatch, rightRow, outputSize_, isRightJoinForFullOuter);
426431
} else {
427432
// Record left-side row with a match on the right-side.
428-
joinTracker_->addMatch(leftBatch, leftRow, outputSize_);
433+
joinTracker_->addMatch(
434+
leftBatch, leftRow, outputSize_, isRightJoinForFullOuter);
429435
}
430436
}
431437
}
@@ -435,7 +441,8 @@ bool MergeJoin::tryAddOutputRow(
435441
if (isAntiJoin(joinType_)) {
436442
VELOX_CHECK(joinTracker_.has_value());
437443
// Record left-side row with a match on the right-side.
438-
joinTracker_->addMatch(leftBatch, leftRow, outputSize_);
444+
joinTracker_->addMatch(
445+
leftBatch, leftRow, outputSize_, isRightJoinForFullOuter);
439446
}
440447

441448
++outputSize_;
@@ -454,14 +461,14 @@ bool MergeJoin::prepareOutput(
454461
return true;
455462
}
456463

457-
if (isRightJoin(joinType_) && right != currentRight_) {
458-
return true;
459-
}
460-
461464
// If there is a new right, we need to flatten the dictionary.
462465
if (!isRightFlattened_ && right && currentRight_ != right) {
463466
flattenRightProjections();
464467
}
468+
469+
if (right != currentRight_) {
470+
return true;
471+
}
465472
return false;
466473
}
467474

@@ -573,6 +580,39 @@ bool MergeJoin::prepareOutput(
573580
bool MergeJoin::addToOutput() {
574581
if (isRightJoin(joinType_) || isRightSemiFilterJoin(joinType_)) {
575582
return addToOutputForRightJoin();
583+
} else if (isFullJoin(joinType_) && filter_) {
584+
if (!leftForRightJoinMatch_) {
585+
leftForRightJoinMatch_ = leftMatch_;
586+
rightForRightJoinMatch_ = rightMatch_;
587+
}
588+
589+
if (leftMatch_ && rightMatch_ && !leftJoinForFullFinished_) {
590+
auto left = addToOutputForLeftJoin();
591+
if (!leftMatch_) {
592+
leftJoinForFullFinished_ = true;
593+
}
594+
if (left) {
595+
if (!leftMatch_) {
596+
leftMatch_ = leftForRightJoinMatch_;
597+
rightMatch_ = rightForRightJoinMatch_;
598+
}
599+
600+
return true;
601+
}
602+
}
603+
604+
if (!leftMatch_ && !rightJoinForFullFinished_) {
605+
leftMatch_ = leftForRightJoinMatch_;
606+
rightMatch_ = rightForRightJoinMatch_;
607+
rightJoinForFullFinished_ = true;
608+
}
609+
610+
auto right = addToOutputForRightJoin();
611+
612+
leftForRightJoinMatch_ = leftMatch_;
613+
rightForRightJoinMatch_ = rightMatch_;
614+
615+
return right;
576616
} else {
577617
return addToOutputForLeftJoin();
578618
}
@@ -719,7 +759,13 @@ bool MergeJoin::addToOutputForRightJoin() {
719759
}
720760

721761
for (auto j = leftStartRow; j < leftEndRow; ++j) {
722-
if (!tryAddOutputRow(leftBatch, j, rightBatch, i)) {
762+
auto isRightJoinForFullOuter = false;
763+
if (isFullJoin(joinType_)) {
764+
isRightJoinForFullOuter = true;
765+
}
766+
767+
if (!tryAddOutputRow(
768+
leftBatch, j, rightBatch, i, isRightJoinForFullOuter)) {
723769
// If we run out of space in the current output_, we will need to
724770
// produce a buffer and continue processing left later. In this
725771
// case, we cannot leave left as a lazy vector, since we cannot have
@@ -818,7 +864,7 @@ RowVectorPtr MergeJoin::getOutput() {
818864
continue;
819865
} else if (isAntiJoin(joinType_)) {
820866
output = filterOutputForAntiJoin(output);
821-
if (output) {
867+
if (output != nullptr && output->size() > 0) {
822868
return output;
823869
}
824870

@@ -904,6 +950,8 @@ RowVectorPtr MergeJoin::doGetOutput() {
904950
// results from the current match.
905951
if (addToOutput()) {
906952
return std::move(output_);
953+
} else {
954+
previousLeftMatch_ = leftMatch_;
907955
}
908956
}
909957

@@ -968,6 +1016,8 @@ RowVectorPtr MergeJoin::doGetOutput() {
9681016

9691017
if (addToOutput()) {
9701018
return std::move(output_);
1019+
} else {
1020+
previousLeftMatch_ = leftMatch_;
9711021
}
9721022
}
9731023

@@ -1107,7 +1157,7 @@ RowVectorPtr MergeJoin::doGetOutput() {
11071157
isFullJoin(joinType_)) {
11081158
// If output_ is currently wrapping a different buffer, return it
11091159
// first.
1110-
if (prepareOutput(input_, nullptr)) {
1160+
if (prepareOutput(input_, rightInput_)) {
11111161
output_->resize(outputSize_);
11121162
return std::move(output_);
11131163
}
@@ -1132,7 +1182,7 @@ RowVectorPtr MergeJoin::doGetOutput() {
11321182
if (isRightJoin(joinType_) || isFullJoin(joinType_)) {
11331183
// If output_ is currently wrapping a different buffer, return it
11341184
// first.
1135-
if (prepareOutput(nullptr, rightInput_)) {
1185+
if (prepareOutput(input_, rightInput_)) {
11361186
output_->resize(outputSize_);
11371187
return std::move(output_);
11381188
}
@@ -1184,6 +1234,8 @@ RowVectorPtr MergeJoin::doGetOutput() {
11841234
endRightRow < rightInput_->size(),
11851235
std::nullopt};
11861236

1237+
leftJoinForFullFinished_ = false;
1238+
rightJoinForFullFinished_ = false;
11871239
if (!leftMatch_->complete || !rightMatch_->complete) {
11881240
if (!leftMatch_->complete) {
11891241
// Need to continue looking for the end of match.
@@ -1212,6 +1264,8 @@ RowVectorPtr MergeJoin::doGetOutput() {
12121264

12131265
if (addToOutput()) {
12141266
return std::move(output_);
1267+
} else {
1268+
previousLeftMatch_ = leftMatch_;
12151269
}
12161270

12171271
if (!rightInput_) {
@@ -1228,8 +1282,6 @@ RowVectorPtr MergeJoin::doGetOutput() {
12281282
RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
12291283
const auto numRows = output->size();
12301284

1231-
RowVectorPtr fullOuterOutput = nullptr;
1232-
12331285
BufferPtr indices = allocateIndices(numRows, pool());
12341286
auto* rawIndices = indices->asMutable<vector_size_t>();
12351287
vector_size_t numPassed = 0;
@@ -1246,76 +1298,29 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
12461298

12471299
// If all matches for a given left-side row fail the filter, add a row to
12481300
// the output with nulls for the right-side columns.
1249-
const auto onMiss = [&](auto row) {
1250-
if (isAntiJoin(joinType_)) {
1251-
return;
1252-
}
1253-
rawIndices[numPassed++] = row;
1254-
1255-
if (isFullJoin(joinType_)) {
1256-
// For filtered rows, it is necessary to insert additional data
1257-
// to ensure the result set is complete. Specifically, we
1258-
// need to generate two records: one record containing the
1259-
// columns from the left table along with nulls for the
1260-
// right table, and another record containing the columns
1261-
// from the right table along with nulls for the left table.
1262-
// For instance, the current output is filtered based on the condition
1263-
// t > 1.
1264-
1265-
// 1, 1
1266-
// 2, 2
1267-
// 3, 3
1268-
1269-
// In this scenario, we need to additionally insert a record 1, 1.
1270-
// Subsequently, we will set the values of the columns on the left to
1271-
// null and the values of the columns on the right to null as well. By
1272-
// doing so, we will obtain the final result set.
1273-
1274-
// 1, null
1275-
// null, 1
1276-
// 2, 2
1277-
// 3, 3
1278-
fullOuterOutput = BaseVector::create<RowVector>(
1279-
output->type(), output->size() + 1, pool());
1280-
1281-
for (auto i = 0; i < row + 1; ++i) {
1282-
for (auto j = 0; j < output->type()->size(); ++j) {
1283-
fullOuterOutput->childAt(j)->copy(
1284-
output->childAt(j).get(), i, i, 1);
1301+
auto onMiss = [&](auto row, bool flag) {
1302+
if (!isLeftSemiFilterJoin(joinType_) &&
1303+
!isRightSemiFilterJoin(joinType_)) {
1304+
rawIndices[numPassed++] = row;
1305+
1306+
if (!isRightJoin(joinType_)) {
1307+
if (isFullJoin(joinType_) && flag) {
1308+
for (auto& projection : leftProjections_) {
1309+
auto target = output->childAt(projection.outputChannel);
1310+
target->setNull(row, true);
1311+
}
1312+
} else {
1313+
for (auto& projection : rightProjections_) {
1314+
auto target = output->childAt(projection.outputChannel);
1315+
target->setNull(row, true);
1316+
}
12851317
}
1286-
}
1287-
1288-
for (auto j = 0; j < output->type()->size(); ++j) {
1289-
fullOuterOutput->childAt(j)->copy(
1290-
output->childAt(j).get(), row + 1, row, 1);
1291-
}
1292-
1293-
for (auto i = row + 1; i < output->size(); ++i) {
1294-
for (auto j = 0; j < output->type()->size(); ++j) {
1295-
fullOuterOutput->childAt(j)->copy(
1296-
output->childAt(j).get(), i + 1, i, 1);
1318+
} else {
1319+
for (auto& projection : leftProjections_) {
1320+
auto target = output->childAt(projection.outputChannel);
1321+
target->setNull(row, true);
12971322
}
12981323
}
1299-
1300-
for (auto& projection : leftProjections_) {
1301-
auto& target = fullOuterOutput->childAt(projection.outputChannel);
1302-
target->setNull(row, true);
1303-
}
1304-
1305-
for (auto& projection : rightProjections_) {
1306-
auto& target = fullOuterOutput->childAt(projection.outputChannel);
1307-
target->setNull(row + 1, true);
1308-
}
1309-
} else if (!isRightJoin(joinType_)) {
1310-
for (auto& projection : rightProjections_) {
1311-
auto& target = output->childAt(projection.outputChannel);
1312-
target->setNull(row, true);
1313-
}
1314-
} else {
1315-
for (auto& projection : leftProjections_) {
1316-
auto& target = output->childAt(projection.outputChannel);
1317-
target->setNull(row, true);
1318-
}
13191324
}
13201325
};
13211326

@@ -1326,12 +1331,8 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
13261331

13271332
joinTracker_->processFilterResult(i, passed, onMiss);
13281333

1329-
if (isAntiJoin(joinType_)) {
1330-
if (!passed) {
1331-
rawIndices[numPassed++] = i;
1332-
}
1333-
} else {
1334-
if (passed) {
1334+
if (!isAntiJoin(joinType_)) {
1335+
if (passed && !joinTracker_->isRightJoinForFullOuter(i)) {
13351336
rawIndices[numPassed++] = i;
13361337
}
13371338
}
@@ -1344,26 +1345,30 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
13441345

13451346
// Every time we start a new left key match, `processFilterResult()` will
13461347
// check if at least one row from the previous match passed the filter. If
1347-
// none did, it calls onMiss to add a record with null right projections to
1348-
// the output.
1348+
// none did, it calls onMiss to add a record with null right projections
1349+
// to the output.
13491350
//
13501351
// Before we leave the current buffer, since we may not have seen the next
1351-
// left key match yet, the last key match may still be pending to produce a
1352-
// row (because `processFilterResult()` was not called yet).
1352+
// left key match yet, the last key match may still be pending to produce
1353+
// a row (because `processFilterResult()` was not called yet).
13531354
//
13541355
// To handle this, we need to call `noMoreFilterResults()` unless the
1355-
// same current left key match may continue in the next buffer. So there are
1356-
// two cases to check:
1356+
// same current left key match may continue in the next buffer. So there
1357+
// are two cases to check:
13571358
//
1358-
// 1. If leftMatch_ is nullopt, there for sure the next buffer will contain
1359-
// a different key match.
1359+
// 1. If leftMatch_ is nullopt, there for sure the next buffer will
1360+
// contain a different key match.
13601361
//
13611362
// 2. leftMatch_ may not be nullopt, but may be related to a different
13621363
// (subsequent) left key. So we check if the last row in the batch has the
13631364
// same left row number as the last key match.
13641365
if (!leftMatch_ || !joinTracker_->isCurrentLeftMatch(numRows - 1)) {
13651366
joinTracker_->noMoreFilterResults(onMiss);
13661367
}
1368+
1369+
if (isAntiJoin(joinType_) && leftMatch_ && !previousLeftMatch_) {
1370+
joinTracker_->noMoreFilterResults(onMiss);
1371+
}
13671372
} else {
13681373
filterRows_.resize(numRows);
13691374
filterRows_.setAll();
@@ -1385,17 +1390,10 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
13851390

13861391
if (numPassed == numRows) {
13871392
// All rows passed.
1388-
if (fullOuterOutput) {
1389-
return fullOuterOutput;
1390-
}
13911393
return output;
13921394
}
13931395

13941396
// Some, but not all rows passed.
1395-
if (fullOuterOutput) {
1396-
return wrap(numPassed, indices, fullOuterOutput);
1397-
}
1398-
13991397
return wrap(numPassed, indices, output);
14001398
}
14011399

0 commit comments

Comments
 (0)