18
18
#include " velox/exec/Task.h"
19
19
#include " velox/expression/FieldReference.h"
20
20
21
+ #include < iostream>
22
+
21
23
namespace facebook ::velox::exec {
22
24
23
25
MergeJoin::MergeJoin (
@@ -92,7 +94,7 @@ void MergeJoin::initialize() {
92
94
joinNode_->isRightJoin () || joinNode_->isFullJoin ()) {
93
95
joinTracker_ = JoinTracker (outputBatchSize_, pool ());
94
96
}
95
- } else if (joinNode_->isAntiJoin ()) {
97
+ } else if (joinNode_->isAntiJoin () || joinNode_-> isFullJoin () ) {
96
98
// Anti join needs to track the left side rows that have no match on the
97
99
// right. Full join also gets a tracker in this branch.
98
100
joinTracker_ = JoinTracker (outputBatchSize_, pool ());
@@ -386,7 +388,8 @@ bool MergeJoin::tryAddOutputRow(
386
388
const RowVectorPtr& leftBatch,
387
389
vector_size_t leftRow,
388
390
const RowVectorPtr& rightBatch,
389
- vector_size_t rightRow) {
391
+ vector_size_t rightRow,
392
+ bool isRightJoinForFullOuter) {
390
393
if (outputSize_ == outputBatchSize_) {
391
394
return false ;
392
395
}
@@ -420,12 +423,15 @@ bool MergeJoin::tryAddOutputRow(
420
423
filterRightInputProjections_);
421
424
422
425
if (joinTracker_) {
423
- if (isRightJoin (joinType_)) {
426
+ if (isRightJoin (joinType_) ||
427
+ (isFullJoin (joinType_) && isRightJoinForFullOuter)) {
424
428
// Record right-side row with a match on the left-side.
425
- joinTracker_->addMatch (rightBatch, rightRow, outputSize_);
429
+ joinTracker_->addMatch (
430
+ rightBatch, rightRow, outputSize_, isRightJoinForFullOuter);
426
431
} else {
427
432
// Record left-side row with a match on the right-side.
428
- joinTracker_->addMatch (leftBatch, leftRow, outputSize_);
433
+ joinTracker_->addMatch (
434
+ leftBatch, leftRow, outputSize_, isRightJoinForFullOuter);
429
435
}
430
436
}
431
437
}
@@ -435,7 +441,8 @@ bool MergeJoin::tryAddOutputRow(
435
441
if (isAntiJoin (joinType_)) {
436
442
VELOX_CHECK (joinTracker_.has_value ());
437
443
// Record left-side row with a match on the right-side.
438
- joinTracker_->addMatch (leftBatch, leftRow, outputSize_);
444
+ joinTracker_->addMatch (
445
+ leftBatch, leftRow, outputSize_, isRightJoinForFullOuter);
439
446
}
440
447
441
448
++outputSize_;
@@ -454,14 +461,14 @@ bool MergeJoin::prepareOutput(
454
461
return true ;
455
462
}
456
463
457
- if (isRightJoin (joinType_) && right != currentRight_) {
458
- return true ;
459
- }
460
-
461
464
// If there is a new right, we need to flatten the dictionary.
462
465
if (!isRightFlattened_ && right && currentRight_ != right) {
463
466
flattenRightProjections ();
464
467
}
468
+
469
+ if (right != currentRight_) {
470
+ return true ;
471
+ }
465
472
return false ;
466
473
}
467
474
@@ -573,6 +580,39 @@ bool MergeJoin::prepareOutput(
573
580
bool MergeJoin::addToOutput () {
574
581
if (isRightJoin (joinType_) || isRightSemiFilterJoin (joinType_)) {
575
582
return addToOutputForRightJoin ();
583
+ } else if (isFullJoin (joinType_) && filter_) {
584
+ if (!leftForRightJoinMatch_) {
585
+ leftForRightJoinMatch_ = leftMatch_;
586
+ rightForRightJoinMatch_ = rightMatch_;
587
+ }
588
+
589
+ if (leftMatch_ && rightMatch_ && !leftJoinForFullFinished_) {
590
+ auto left = addToOutputForLeftJoin ();
591
+ if (!leftMatch_) {
592
+ leftJoinForFullFinished_ = true ;
593
+ }
594
+ if (left) {
595
+ if (!leftMatch_) {
596
+ leftMatch_ = leftForRightJoinMatch_;
597
+ rightMatch_ = rightForRightJoinMatch_;
598
+ }
599
+
600
+ return true ;
601
+ }
602
+ }
603
+
604
+ if (!leftMatch_ && !rightJoinForFullFinished_) {
605
+ leftMatch_ = leftForRightJoinMatch_;
606
+ rightMatch_ = rightForRightJoinMatch_;
607
+ rightJoinForFullFinished_ = true ;
608
+ }
609
+
610
+ auto right = addToOutputForRightJoin ();
611
+
612
+ leftForRightJoinMatch_ = leftMatch_;
613
+ rightForRightJoinMatch_ = rightMatch_;
614
+
615
+ return right;
576
616
} else {
577
617
return addToOutputForLeftJoin ();
578
618
}
@@ -719,7 +759,13 @@ bool MergeJoin::addToOutputForRightJoin() {
719
759
}
720
760
721
761
for (auto j = leftStartRow; j < leftEndRow; ++j) {
722
- if (!tryAddOutputRow (leftBatch, j, rightBatch, i)) {
762
+ auto isRightJoinForFullOuter = false ;
763
+ if (isFullJoin (joinType_)) {
764
+ isRightJoinForFullOuter = true ;
765
+ }
766
+
767
+ if (!tryAddOutputRow (
768
+ leftBatch, j, rightBatch, i, isRightJoinForFullOuter)) {
723
769
// If we run out of space in the current output_, we will need to
724
770
// produce a buffer and continue processing left later. In this
725
771
// case, we cannot leave left as a lazy vector, since we cannot have
@@ -818,7 +864,7 @@ RowVectorPtr MergeJoin::getOutput() {
818
864
continue ;
819
865
} else if (isAntiJoin (joinType_)) {
820
866
output = filterOutputForAntiJoin (output);
821
- if (output) {
867
+ if (output != nullptr && output-> size () > 0 ) {
822
868
return output;
823
869
}
824
870
@@ -904,6 +950,8 @@ RowVectorPtr MergeJoin::doGetOutput() {
904
950
// results from the current match.
905
951
if (addToOutput ()) {
906
952
return std::move (output_);
953
+ } else {
954
+ previousLeftMatch_ = leftMatch_;
907
955
}
908
956
}
909
957
@@ -968,6 +1016,8 @@ RowVectorPtr MergeJoin::doGetOutput() {
968
1016
969
1017
if (addToOutput ()) {
970
1018
return std::move (output_);
1019
+ } else {
1020
+ previousLeftMatch_ = leftMatch_;
971
1021
}
972
1022
}
973
1023
@@ -1107,7 +1157,7 @@ RowVectorPtr MergeJoin::doGetOutput() {
1107
1157
isFullJoin (joinType_)) {
1108
1158
// If output_ is currently wrapping a different buffer, return it
1109
1159
// first.
1110
- if (prepareOutput (input_, nullptr )) {
1160
+ if (prepareOutput (input_, rightInput_ )) {
1111
1161
output_->resize (outputSize_);
1112
1162
return std::move (output_);
1113
1163
}
@@ -1132,7 +1182,7 @@ RowVectorPtr MergeJoin::doGetOutput() {
1132
1182
if (isRightJoin (joinType_) || isFullJoin (joinType_)) {
1133
1183
// If output_ is currently wrapping a different buffer, return it
1134
1184
// first.
1135
- if (prepareOutput (nullptr , rightInput_)) {
1185
+ if (prepareOutput (input_ , rightInput_)) {
1136
1186
output_->resize (outputSize_);
1137
1187
return std::move (output_);
1138
1188
}
@@ -1184,6 +1234,8 @@ RowVectorPtr MergeJoin::doGetOutput() {
1184
1234
endRightRow < rightInput_->size (),
1185
1235
std::nullopt};
1186
1236
1237
+ leftJoinForFullFinished_ = false ;
1238
+ rightJoinForFullFinished_ = false ;
1187
1239
if (!leftMatch_->complete || !rightMatch_->complete ) {
1188
1240
if (!leftMatch_->complete ) {
1189
1241
// Need to continue looking for the end of match.
@@ -1212,6 +1264,8 @@ RowVectorPtr MergeJoin::doGetOutput() {
1212
1264
1213
1265
if (addToOutput ()) {
1214
1266
return std::move (output_);
1267
+ } else {
1268
+ previousLeftMatch_ = leftMatch_;
1215
1269
}
1216
1270
1217
1271
if (!rightInput_) {
@@ -1228,8 +1282,6 @@ RowVectorPtr MergeJoin::doGetOutput() {
1228
1282
RowVectorPtr MergeJoin::applyFilter (const RowVectorPtr& output) {
1229
1283
const auto numRows = output->size ();
1230
1284
1231
- RowVectorPtr fullOuterOutput = nullptr ;
1232
-
1233
1285
BufferPtr indices = allocateIndices (numRows, pool ());
1234
1286
auto * rawIndices = indices->asMutable <vector_size_t >();
1235
1287
vector_size_t numPassed = 0 ;
@@ -1246,76 +1298,29 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
1246
1298
1247
1299
// If all matches for a given left-side row fail the filter, add a row to
1248
1300
// the output with nulls for the right-side columns.
1249
- const auto onMiss = [&](auto row) {
1250
- if (isAntiJoin (joinType_)) {
1251
- return ;
1252
- }
1253
- rawIndices[numPassed++] = row;
1254
-
1255
- if (isFullJoin (joinType_)) {
1256
- // For filtered rows, it is necessary to insert additional data
1257
- // to ensure the result set is complete. Specifically, we
1258
- // need to generate two records: one record containing the
1259
- // columns from the left table along with nulls for the
1260
- // right table, and another record containing the columns
1261
- // from the right table along with nulls for the left table.
1262
- // For instance, the current output is filtered based on the condition
1263
- // t > 1.
1264
-
1265
- // 1, 1
1266
- // 2, 2
1267
- // 3, 3
1268
-
1269
- // In this scenario, we need to additionally insert a record 1, 1.
1270
- // Subsequently, we will set the values of the columns on the left to
1271
- // null and the values of the columns on the right to null as well. By
1272
- // doing so, we will obtain the final result set.
1273
-
1274
- // 1, null
1275
- // null, 1
1276
- // 2, 2
1277
- // 3, 3
1278
- fullOuterOutput = BaseVector::create<RowVector>(
1279
- output->type (), output->size () + 1 , pool ());
1280
-
1281
- for (auto i = 0 ; i < row + 1 ; ++i) {
1282
- for (auto j = 0 ; j < output->type ()->size (); ++j) {
1283
- fullOuterOutput->childAt (j)->copy (
1284
- output->childAt (j).get (), i, i, 1 );
1301
+ auto onMiss = [&](auto row, bool flag) {
1302
+ if (!isLeftSemiFilterJoin (joinType_) &&
1303
+ !isRightSemiFilterJoin (joinType_)) {
1304
+ rawIndices[numPassed++] = row;
1305
+
1306
+ if (!isRightJoin (joinType_)) {
1307
+ if (isFullJoin (joinType_) && flag) {
1308
+ for (auto & projection : leftProjections_) {
1309
+ auto target = output->childAt (projection.outputChannel );
1310
+ target->setNull (row, true );
1311
+ }
1312
+ } else {
1313
+ for (auto & projection : rightProjections_) {
1314
+ auto target = output->childAt (projection.outputChannel );
1315
+ target->setNull (row, true );
1316
+ }
1285
1317
}
1286
- }
1287
-
1288
- for (auto j = 0 ; j < output->type ()->size (); ++j) {
1289
- fullOuterOutput->childAt (j)->copy (
1290
- output->childAt (j).get (), row + 1 , row, 1 );
1291
- }
1292
-
1293
- for (auto i = row + 1 ; i < output->size (); ++i) {
1294
- for (auto j = 0 ; j < output->type ()->size (); ++j) {
1295
- fullOuterOutput->childAt (j)->copy (
1296
- output->childAt (j).get (), i + 1 , i, 1 );
1318
+ } else {
1319
+ for (auto & projection : leftProjections_) {
1320
+ auto target = output->childAt (projection.outputChannel );
1321
+ target->setNull (row, true );
1297
1322
}
1298
1323
}
1299
-
1300
- for (auto & projection : leftProjections_) {
1301
- auto & target = fullOuterOutput->childAt (projection.outputChannel );
1302
- target->setNull (row, true );
1303
- }
1304
-
1305
- for (auto & projection : rightProjections_) {
1306
- auto & target = fullOuterOutput->childAt (projection.outputChannel );
1307
- target->setNull (row + 1 , true );
1308
- }
1309
- } else if (!isRightJoin (joinType_)) {
1310
- for (auto & projection : rightProjections_) {
1311
- auto & target = output->childAt (projection.outputChannel );
1312
- target->setNull (row, true );
1313
- }
1314
- } else {
1315
- for (auto & projection : leftProjections_) {
1316
- auto & target = output->childAt (projection.outputChannel );
1317
- target->setNull (row, true );
1318
- }
1319
1324
}
1320
1325
};
1321
1326
@@ -1326,12 +1331,8 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
1326
1331
1327
1332
joinTracker_->processFilterResult (i, passed, onMiss);
1328
1333
1329
- if (isAntiJoin (joinType_)) {
1330
- if (!passed) {
1331
- rawIndices[numPassed++] = i;
1332
- }
1333
- } else {
1334
- if (passed) {
1334
+ if (!isAntiJoin (joinType_)) {
1335
+ if (passed && !joinTracker_->isRightJoinForFullOuter (i)) {
1335
1336
rawIndices[numPassed++] = i;
1336
1337
}
1337
1338
}
@@ -1344,26 +1345,30 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
1344
1345
1345
1346
// Every time we start a new left key match, `processFilterResult()` will
1346
1347
// check if at least one row from the previous match passed the filter. If
1347
- // none did, it calls onMiss to add a record with null right projections to
1348
- // the output.
1348
+ // none did, it calls onMiss to add a record with null right projections
1349
+ // to the output.
1349
1350
//
1350
1351
// Before we leave the current buffer, since we may not have seen the next
1351
- // left key match yet, the last key match may still be pending to produce a
1352
- // row (because `processFilterResult()` was not called yet).
1352
+ // left key match yet, the last key match may still be pending to produce
1353
+ // a row (because `processFilterResult()` was not called yet).
1353
1354
//
1354
1355
// To handle this, we need to call `noMoreFilterResults()` unless the
1355
- // same current left key match may continue in the next buffer. So there are
1356
- // two cases to check:
1356
+ // same current left key match may continue in the next buffer. So there
1357
+ // are two cases to check:
1357
1358
//
1358
- // 1. If leftMatch_ is nullopt, there for sure the next buffer will contain
1359
- // a different key match.
1359
+ // 1. If leftMatch_ is nullopt, then for sure the next buffer will
1360
+ // contain a different key match.
1360
1361
//
1361
1362
// 2. leftMatch_ may not be nullopt, but may be related to a different
1362
1363
// (subsequent) left key. So we check if the last row in the batch has the
1363
1364
// same left row number as the last key match.
1364
1365
if (!leftMatch_ || !joinTracker_->isCurrentLeftMatch (numRows - 1 )) {
1365
1366
joinTracker_->noMoreFilterResults (onMiss);
1366
1367
}
1368
+
1369
+ if (isAntiJoin (joinType_) && leftMatch_ && !previousLeftMatch_) {
1370
+ joinTracker_->noMoreFilterResults (onMiss);
1371
+ }
1367
1372
} else {
1368
1373
filterRows_.resize (numRows);
1369
1374
filterRows_.setAll ();
@@ -1385,17 +1390,10 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
1385
1390
1386
1391
if (numPassed == numRows) {
1387
1392
// All rows passed.
1388
- if (fullOuterOutput) {
1389
- return fullOuterOutput;
1390
- }
1391
1393
return output;
1392
1394
}
1393
1395
1394
1396
// Some, but not all rows passed.
1395
- if (fullOuterOutput) {
1396
- return wrap (numPassed, indices, fullOuterOutput);
1397
- }
1398
-
1399
1397
return wrap (numPassed, indices, output);
1400
1398
}
1401
1399
0 commit comments