Skip to content

Commit 7c25ce5

Browse files
authored
GH-52: Make RangeEqualsVisitor of RunEndEncodedVector more efficient (#761)
## What's Changed Avoid doing a binary search on every step to make the RangeEqualsVisitor of RunEndEncodedVector more efficient. Closes #52 .
1 parent 6be33bb commit 7c25ce5

File tree

4 files changed

+181
-31
lines changed

4 files changed

+181
-31
lines changed

vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java

Lines changed: 19 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import org.apache.arrow.vector.complex.ListViewVector;
4444
import org.apache.arrow.vector.complex.NonNullableStructVector;
4545
import org.apache.arrow.vector.complex.RunEndEncodedVector;
46+
import org.apache.arrow.vector.complex.RunEndEncodedVector.RangeIterator;
4647
import org.apache.arrow.vector.complex.StructVector;
4748
import org.apache.arrow.vector.complex.UnionVector;
4849

@@ -270,42 +271,35 @@ protected boolean compareRunEndEncodedVectors(Range range) {
270271
RunEndEncodedVector leftVector = (RunEndEncodedVector) left;
271272
RunEndEncodedVector rightVector = (RunEndEncodedVector) right;
272273

273-
final int leftRangeEnd = range.getLeftStart() + range.getLength();
274-
final int rightRangeEnd = range.getRightStart() + range.getLength();
274+
final RunEndEncodedVector.RangeIterator leftIterator =
275+
new RunEndEncodedVector.RangeIterator(leftVector, range.getLeftStart(), range.getLength());
276+
final RunEndEncodedVector.RangeIterator rightIterator =
277+
new RunEndEncodedVector.RangeIterator(
278+
rightVector, range.getRightStart(), range.getLength());
275279

276280
FieldVector leftValuesVector = leftVector.getValuesVector();
277281
FieldVector rightValuesVector = rightVector.getValuesVector();
278282

279283
RangeEqualsVisitor innerVisitor = createInnerVisitor(leftValuesVector, rightValuesVector, null);
280284

281-
int leftLogicalIndex = range.getLeftStart();
282-
int rightLogicalIndex = range.getRightStart();
285+
while (nextRun(leftIterator, rightIterator)) {
286+
int leftPhysicalIndex = leftIterator.getRunIndex();
287+
int rightPhysicalIndex = rightIterator.getRunIndex();
283288

284-
while (leftLogicalIndex < leftRangeEnd) {
285-
// TODO: implement it more efficient
286-
// https://github.com/apache/arrow/issues/44157
287-
int leftPhysicalIndex = leftVector.getPhysicalIndex(leftLogicalIndex);
288-
int rightPhysicalIndex = rightVector.getPhysicalIndex(rightLogicalIndex);
289-
if (leftValuesVector.accept(
290-
innerVisitor, new Range(leftPhysicalIndex, rightPhysicalIndex, 1))) {
291-
int leftRunEnd = leftVector.getRunEnd(leftLogicalIndex);
292-
int rightRunEnd = rightVector.getRunEnd(rightLogicalIndex);
293-
294-
int leftRunLength = Math.min(leftRunEnd, leftRangeEnd) - leftLogicalIndex;
295-
int rightRunLength = Math.min(rightRunEnd, rightRangeEnd) - rightLogicalIndex;
296-
297-
if (leftRunLength != rightRunLength) {
298-
return false;
299-
} else {
300-
leftLogicalIndex = leftRunEnd;
301-
rightLogicalIndex = rightRunEnd;
302-
}
303-
} else {
289+
if (leftIterator.getRunLength() != rightIterator.getRunLength()
290+
|| !leftValuesVector.accept(
291+
innerVisitor, new Range(leftPhysicalIndex, rightPhysicalIndex, 1))) {
304292
return false;
305293
}
306294
}
307295

308-
return true;
296+
return leftIterator.isEnd() && rightIterator.isEnd();
297+
}
298+
299+
private static boolean nextRun(RangeIterator leftIterator, RangeIterator rightIterator) {
300+
boolean left = leftIterator.nextRun();
301+
boolean right = rightIterator.nextRun();
302+
return left && right;
309303
}
310304

311305
protected RangeEqualsVisitor createInnerVisitor(

vector/src/main/java/org/apache/arrow/vector/complex/RunEndEncodedVector.java

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.apache.arrow.memory.OutOfMemoryException;
2929
import org.apache.arrow.memory.util.ByteFunctionHelpers;
3030
import org.apache.arrow.memory.util.hash.ArrowBufHasher;
31+
import org.apache.arrow.util.Preconditions;
3132
import org.apache.arrow.vector.BaseIntVector;
3233
import org.apache.arrow.vector.BaseValueVector;
3334
import org.apache.arrow.vector.BigIntVector;
@@ -820,4 +821,101 @@ static int getPhysicalIndex(FieldVector runEndVector, int logicalIndex) {
820821

821822
return result;
822823
}
824+
825+
public static class RangeIterator {
826+
827+
private final RunEndEncodedVector runEndEncodedVector;
828+
private final int rangeEnd;
829+
private int runIndex;
830+
private int runEnd;
831+
private int logicalPos;
832+
833+
/**
834+
* Constructs a new RangeIterator for iterating over a range of values in a RunEndEncodedVector.
835+
*
836+
* @param runEndEncodedVector The vector to iterate over
837+
* @param startIndex The logical start index of the range (inclusive)
838+
* @param length The number of values to include in the range
839+
* @throws IllegalArgumentException if startIndex is negative or (startIndex + length) exceeds
840+
* vector bounds
841+
*/
842+
public RangeIterator(RunEndEncodedVector runEndEncodedVector, int startIndex, int length) {
843+
int rangeEnd = startIndex + length;
844+
Preconditions.checkArgument(
845+
startIndex >= 0, "startIndex %s must be non negative.", startIndex);
846+
Preconditions.checkArgument(
847+
rangeEnd <= runEndEncodedVector.getValueCount(),
848+
"(startIndex + length) %s out of range[0, %s].",
849+
rangeEnd,
850+
runEndEncodedVector.getValueCount());
851+
852+
this.rangeEnd = rangeEnd;
853+
this.runEndEncodedVector = runEndEncodedVector;
854+
this.runIndex = runEndEncodedVector.getPhysicalIndex(startIndex) - 1;
855+
this.runEnd = startIndex;
856+
this.logicalPos = -1;
857+
}
858+
859+
/**
860+
* Advances to the next run in the range.
861+
*
862+
* @return true if there is another run available, false if iteration has completed
863+
*/
864+
public boolean nextRun() {
865+
logicalPos = runEnd;
866+
if (logicalPos >= rangeEnd) {
867+
return false;
868+
}
869+
updateRun();
870+
return true;
871+
}
872+
873+
private void updateRun() {
874+
runIndex++;
875+
runEnd = (int) ((BaseIntVector) runEndEncodedVector.runEndsVector).getValueAsLong(runIndex);
876+
}
877+
878+
/**
879+
* Advances to the next value in the range.
880+
*
881+
* @return true if there is another value available, false if iteration has completed
882+
*/
883+
public boolean nextValue() {
884+
logicalPos++;
885+
if (logicalPos >= rangeEnd) {
886+
return false;
887+
}
888+
if (logicalPos == runEnd) {
889+
updateRun();
890+
}
891+
return true;
892+
}
893+
894+
/**
895+
* Gets the current run index (physical position in the run-ends vector).
896+
*
897+
* @return the current run index
898+
*/
899+
public int getRunIndex() {
900+
return runIndex;
901+
}
902+
903+
/**
904+
* Gets the length of the current run within the iterator's range.
905+
*
906+
* @return the number of remaining values in current run within the iterator's range
907+
*/
908+
public int getRunLength() {
909+
return Math.min(runEnd, rangeEnd) - logicalPos;
910+
}
911+
912+
/**
913+
* Checks if iteration has completed.
914+
*
915+
* @return true if all values in the range have been processed, false otherwise
916+
*/
917+
public boolean isEnd() {
918+
return logicalPos >= rangeEnd;
919+
}
920+
}
823921
}

vector/src/test/java/org/apache/arrow/vector/TestRunEndEncodedVector.java

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -148,12 +148,18 @@ public void testRangeCompare() {
148148
assertTrue(
149149
constantVector.accept(
150150
new RangeEqualsVisitor(constantVector, constantVector), new Range(1, 2, 13)));
151-
assertFalse(
152-
constantVector.accept(
153-
new RangeEqualsVisitor(constantVector, constantVector), new Range(1, 10, 10)));
154-
assertFalse(
155-
constantVector.accept(
156-
new RangeEqualsVisitor(constantVector, constantVector), new Range(10, 1, 10)));
151+
152+
// throws exception if the range end is out the bound of the vector
153+
assertThrows(
154+
IllegalArgumentException.class,
155+
() ->
156+
constantVector.accept(
157+
new RangeEqualsVisitor(constantVector, constantVector), new Range(1, 10, 10)));
158+
assertThrows(
159+
IllegalArgumentException.class,
160+
() ->
161+
constantVector.accept(
162+
new RangeEqualsVisitor(constantVector, constantVector), new Range(10, 1, 10)));
157163

158164
// Create REE vector representing: [1, 2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 5].
159165
RunEndEncodedVector reeVector =

vector/src/test/java/org/apache/arrow/vector/compare/TestRangeEqualsVisitor.java

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
import java.nio.charset.Charset;
2424
import java.util.Arrays;
25+
import java.util.List;
2526
import org.apache.arrow.memory.BufferAllocator;
2627
import org.apache.arrow.memory.RootAllocator;
2728
import org.apache.arrow.vector.BigIntVector;
@@ -39,6 +40,7 @@
3940
import org.apache.arrow.vector.complex.LargeListViewVector;
4041
import org.apache.arrow.vector.complex.ListVector;
4142
import org.apache.arrow.vector.complex.ListViewVector;
43+
import org.apache.arrow.vector.complex.RunEndEncodedVector;
4244
import org.apache.arrow.vector.complex.StructVector;
4345
import org.apache.arrow.vector.complex.UnionVector;
4446
import org.apache.arrow.vector.complex.impl.NullableStructWriter;
@@ -53,7 +55,9 @@
5355
import org.apache.arrow.vector.holders.NullableUInt4Holder;
5456
import org.apache.arrow.vector.types.FloatingPointPrecision;
5557
import org.apache.arrow.vector.types.Types;
58+
import org.apache.arrow.vector.types.Types.MinorType;
5659
import org.apache.arrow.vector.types.pojo.ArrowType;
60+
import org.apache.arrow.vector.types.pojo.ArrowType.RunEndEncoded;
5761
import org.apache.arrow.vector.types.pojo.Field;
5862
import org.apache.arrow.vector.types.pojo.FieldType;
5963
import org.junit.jupiter.api.AfterEach;
@@ -1003,6 +1007,54 @@ public void testLargeListViewVectorApproxEquals() {
10031007
}
10041008
}
10051009

1010+
@Test
1011+
public void testRunEndEncodedFloat8ApproxEquals() {
1012+
try (final Float8Vector vector1 = new Float8Vector("float", allocator);
1013+
final Float8Vector vector2 = new Float8Vector("float", allocator);
1014+
final Float8Vector vector3 = new Float8Vector("float", allocator);
1015+
final IntVector reeVector = new IntVector("ree", allocator)) {
1016+
1017+
final float epsilon = 1.0E-6f;
1018+
setVector(vector1, 1.1, 2.2);
1019+
setVector(vector2, 1.1 + epsilon / 2, 2.2 + epsilon / 2);
1020+
setVector(vector3, 1.1 + epsilon * 2, 2.2 + epsilon * 2);
1021+
setVector(reeVector, 1, 3);
1022+
1023+
ArrowType type = MinorType.FLOAT8.getType();
1024+
final FieldType valueType = FieldType.notNullable(type);
1025+
final FieldType runEndType = FieldType.notNullable(MinorType.INT.getType());
1026+
1027+
final Field valueField = new Field("value", valueType, null);
1028+
final Field runEndField = new Field("ree", runEndType, null);
1029+
1030+
Field field =
1031+
new Field(
1032+
"ree_float",
1033+
FieldType.notNullable(RunEndEncoded.INSTANCE),
1034+
List.of(runEndField, valueField));
1035+
1036+
try (final RunEndEncodedVector encodedVector1 =
1037+
new RunEndEncodedVector(field, allocator, reeVector, vector1, null);
1038+
final RunEndEncodedVector encodedVector2 =
1039+
new RunEndEncodedVector(field, allocator, reeVector, vector2, null);
1040+
final RunEndEncodedVector encodedVector3 =
1041+
new RunEndEncodedVector(field, allocator, reeVector, vector3, null)) {
1042+
1043+
encodedVector1.setValueCount(3);
1044+
encodedVector2.setValueCount(3);
1045+
encodedVector3.setValueCount(3);
1046+
1047+
Range range = new Range(0, 0, encodedVector1.getValueCount());
1048+
assertTrue(
1049+
new ApproxEqualsVisitor(encodedVector1, encodedVector2, epsilon, epsilon)
1050+
.rangeEquals(range));
1051+
assertFalse(
1052+
new ApproxEqualsVisitor(encodedVector1, encodedVector3, epsilon, epsilon)
1053+
.rangeEquals(range));
1054+
}
1055+
}
1056+
}
1057+
10061058
private void writeStructVector(NullableStructWriter writer, int value1, long value2) {
10071059
writer.start();
10081060
writer.integer("f0").writeInt(value1);

0 commit comments

Comments
 (0)