Skip to content

Commit c4ceb71

Browse files
committed
Support decimal compare functions on different precision and scales (6207)
1 parent f1f275d commit c4ceb71

File tree

4 files changed

+333
-0
lines changed

4 files changed

+333
-0
lines changed

velox/docs/functions/spark/comparison.rst

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,5 +78,27 @@ Comparison Functions
7878
Returns true if x is not equal to y. Supports all scalar types. The types
7979
of x and y must be the same. Corresponds to Spark's operator ``!=``.
8080

81+
.. spark:function:: decimal_lt(x, y) -> boolean
8182
83+
Returns true if x is less than y. Supports decimal types with different precision and scales.
84+
Corresponds to Spark's operator ``<``.
8285

86+
.. spark:function:: decimal_lte(x, y) -> boolean
87+
88+
Returns true if x is less than y or x is equal to y. Supports decimal types with different precision and scales.
89+
Corresponds to Spark's operator ``<=``.
90+
91+
.. spark:function:: decimal_eq(x, y) -> boolean
92+
93+
Returns true if x is equal to y. Supports decimal types with different precision and scales.
94+
Corresponds to Spark's operator ``==``.
95+
96+
.. spark:function:: decimal_gt(x, y) -> boolean
97+
98+
Returns true if x is greater than y. Supports decimal types with different precision and scales.
99+
Corresponds to Spark's operator ``>``.
100+
101+
.. spark:function:: decimal_gte(x, y) -> boolean
102+
103+
Returns true if x is greater than y or x is equal to y. Supports decimal types with different precision and scales.
104+
Corresponds to Spark's operator ``>=``.

velox/functions/sparksql/DecimalVectorFunctions.cpp

Lines changed: 303 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,284 @@ class DecimalRoundFunction : public exec::VectorFunction {
247247
}
248248
}
249249
};
250+
251+
// Rescale two inputs as the same scale and compare. Returns 0 when a is equal
252+
// with b. Returns -1 when a is less than b. Returns 1 when a is greater than b.
253+
template <typename T>
254+
int32_t rescaleAndCompare(T a, T b, int32_t deltaScale) {
255+
T aScaled = a;
256+
T bScaled = b;
257+
if (deltaScale < 0) {
258+
aScaled = a * velox::DecimalUtil::kPowersOfTen[-deltaScale];
259+
} else if (deltaScale > 0) {
260+
bScaled = b * velox::DecimalUtil::kPowersOfTen[deltaScale];
261+
}
262+
if (aScaled == bScaled) {
263+
return 0;
264+
} else if (aScaled < bScaled) {
265+
return -1;
266+
} else {
267+
return 1;
268+
}
269+
}
270+
271+
// Compare two decimals. Rescale one of them When they are of different scales.
272+
int32_t decimalCompare(
273+
int128_t a,
274+
uint8_t aPrecision,
275+
uint8_t aScale,
276+
int128_t b,
277+
uint8_t bPrecision,
278+
uint8_t bScale) {
279+
int32_t deltaScale = aScale - bScale;
280+
// Check if we need 256-bits after adjusting the scale.
281+
bool need256 = (deltaScale < 0 &&
282+
aPrecision - deltaScale > LongDecimalType::kMaxPrecision) ||
283+
(bPrecision + deltaScale > LongDecimalType::kMaxPrecision);
284+
if (need256) {
285+
return rescaleAndCompare<int256_t>(
286+
static_cast<int256_t>(a), static_cast<int256_t>(b), deltaScale);
287+
}
288+
return rescaleAndCompare<int128_t>(a, b, deltaScale);
289+
}
290+
291+
// GreaterThan decimal compare function.
292+
class Gt {
293+
public:
294+
inline static bool apply(
295+
int128_t a,
296+
uint8_t aPrecision,
297+
uint8_t aScale,
298+
int128_t b,
299+
uint8_t bPrecision,
300+
uint8_t bScale) {
301+
return decimalCompare(a, aPrecision, aScale, b, bPrecision, bScale) > 0;
302+
}
303+
};
304+
305+
// GreaterThanOrEqual decimal compare function.
306+
class Gte {
307+
public:
308+
inline static bool apply(
309+
int128_t a,
310+
uint8_t aPrecision,
311+
uint8_t aScale,
312+
int128_t b,
313+
uint8_t bPrecision,
314+
uint8_t bScale) {
315+
return decimalCompare(a, aPrecision, aScale, b, bPrecision, bScale) >= 0;
316+
}
317+
};
318+
319+
// LessThan decimal compare function.
320+
class Lt {
321+
public:
322+
inline static bool apply(
323+
int128_t a,
324+
uint8_t aPrecision,
325+
uint8_t aScale,
326+
int128_t b,
327+
uint8_t bPrecision,
328+
uint8_t bScale) {
329+
return decimalCompare(a, aPrecision, aScale, b, bPrecision, bScale) < 0;
330+
}
331+
};
332+
333+
// LessThanOrEqual decimal compare function.
334+
class Lte {
335+
public:
336+
inline static bool apply(
337+
int128_t a,
338+
uint8_t aPrecision,
339+
uint8_t aScale,
340+
int128_t b,
341+
uint8_t bPrecision,
342+
uint8_t bScale) {
343+
return decimalCompare(a, aPrecision, aScale, b, bPrecision, bScale) <= 0;
344+
}
345+
};
346+
347+
// Equal decimal compare function.
348+
class Eq {
349+
public:
350+
inline static bool apply(
351+
int128_t a,
352+
uint8_t aPrecision,
353+
uint8_t aScale,
354+
int128_t b,
355+
uint8_t bPrecision,
356+
uint8_t bScale) {
357+
return decimalCompare(a, aPrecision, aScale, b, bPrecision, bScale) == 0;
358+
}
359+
};
360+
361+
// Class for decimal compare operations.
362+
template <typename A, typename B, typename Operation /* Arithmetic operation */>
363+
class DecimalCompareFunction : public exec::VectorFunction {
364+
public:
365+
DecimalCompareFunction(
366+
uint8_t aPrecision,
367+
uint8_t aScale,
368+
uint8_t bPrecision,
369+
uint8_t bScale)
370+
: aPrecision_(aPrecision),
371+
aScale_(aScale),
372+
bPrecision_(bPrecision),
373+
bScale_(bScale) {}
374+
375+
void apply(
376+
const SelectivityVector& rows,
377+
std::vector<VectorPtr>& args,
378+
const TypePtr& resultType,
379+
exec::EvalCtx& context,
380+
VectorPtr& result) const override {
381+
prepareResults(rows, resultType, context, result);
382+
383+
// Fast path when the first argument is a flat vector.
384+
if (args[0]->isFlatEncoding()) {
385+
auto rawA = args[0]->asUnchecked<FlatVector<A>>()->mutableRawValues();
386+
387+
if (args[1]->isConstantEncoding()) {
388+
auto constantB = args[1]->asUnchecked<SimpleVector<B>>()->valueAt(0);
389+
context.applyToSelectedNoThrow(rows, [&](auto row) {
390+
result->asUnchecked<FlatVector<bool>>()->set(
391+
row,
392+
Operation::apply(
393+
(int128_t)rawA[row],
394+
aPrecision_,
395+
aScale_,
396+
(int128_t)constantB,
397+
bPrecision_,
398+
bScale_));
399+
});
400+
return;
401+
}
402+
403+
if (args[1]->isFlatEncoding()) {
404+
auto rawB = args[1]->asUnchecked<FlatVector<B>>()->mutableRawValues();
405+
context.applyToSelectedNoThrow(rows, [&](auto row) {
406+
result->asUnchecked<FlatVector<bool>>()->set(
407+
row,
408+
Operation::apply(
409+
(int128_t)rawA[row],
410+
aPrecision_,
411+
aScale_,
412+
(int128_t)rawB[row],
413+
bPrecision_,
414+
bScale_));
415+
});
416+
return;
417+
}
418+
} else {
419+
// Fast path when the first argument is encoded but the second is
420+
// constant.
421+
exec::DecodedArgs decodedArgs(rows, args, context);
422+
auto aDecoded = decodedArgs.at(0);
423+
auto aDecodedData = aDecoded->data<A>();
424+
425+
if (args[1]->isConstantEncoding()) {
426+
auto constantB = args[1]->asUnchecked<SimpleVector<B>>()->valueAt(0);
427+
context.applyToSelectedNoThrow(rows, [&](auto row) {
428+
auto value = aDecodedData[aDecoded->index(row)];
429+
result->asUnchecked<FlatVector<bool>>()->set(
430+
row,
431+
Operation::apply(
432+
(int128_t)value,
433+
aPrecision_,
434+
aScale_,
435+
(int128_t)constantB,
436+
bPrecision_,
437+
bScale_));
438+
});
439+
return;
440+
}
441+
}
442+
443+
// Decode the input in all other cases.
444+
exec::DecodedArgs decodedArgs(rows, args, context);
445+
auto aDecoded = decodedArgs.at(0);
446+
auto bDecoded = decodedArgs.at(1);
447+
448+
auto aDecodedData = aDecoded->data<A>();
449+
auto bDecodedData = bDecoded->data<B>();
450+
451+
context.applyToSelectedNoThrow(rows, [&](auto row) {
452+
auto aValue = aDecodedData[aDecoded->index(row)];
453+
auto bValue = bDecodedData[bDecoded->index(row)];
454+
result->asUnchecked<FlatVector<bool>>()->set(
455+
row,
456+
Operation::apply(
457+
(int128_t)aValue,
458+
aPrecision_,
459+
aScale_,
460+
(int128_t)bValue,
461+
bPrecision_,
462+
bScale_));
463+
});
464+
}
465+
466+
private:
467+
void prepareResults(
468+
const SelectivityVector& rows,
469+
const TypePtr& resultType,
470+
exec::EvalCtx& context,
471+
VectorPtr& result) const {
472+
context.ensureWritable(rows, resultType, result);
473+
result->clearNulls(rows);
474+
}
475+
476+
const uint8_t aPrecision_;
477+
const uint8_t aScale_;
478+
const uint8_t bPrecision_;
479+
const uint8_t bScale_;
480+
};
481+
482+
template <typename Operation>
483+
std::shared_ptr<exec::VectorFunction> createDecimalCompareFunction(
484+
const std::string& name,
485+
const std::vector<exec::VectorFunctionArg>& inputArgs,
486+
const core::QueryConfig& /*config*/) {
487+
const auto& aType = inputArgs[0].type;
488+
const auto& bType = inputArgs[1].type;
489+
auto [aPrecision, aScale] = getDecimalPrecisionScale(*aType);
490+
auto [bPrecision, bScale] = getDecimalPrecisionScale(*bType);
491+
if (aType->isShortDecimal()) {
492+
if (bType->isShortDecimal()) {
493+
return std::make_shared<
494+
DecimalCompareFunction<int64_t, int64_t, Operation>>(
495+
aPrecision, aScale, bPrecision, bScale);
496+
} else if (bType->isLongDecimal()) {
497+
return std::make_shared<
498+
DecimalCompareFunction<int64_t, int128_t, Operation>>(
499+
aPrecision, aScale, bPrecision, bScale);
500+
}
501+
}
502+
if (aType->isLongDecimal()) {
503+
if (bType->isShortDecimal()) {
504+
return std::make_shared<
505+
DecimalCompareFunction<int128_t, int64_t, Operation>>(
506+
aPrecision, aScale, bPrecision, bScale);
507+
} else if (bType->isLongDecimal()) {
508+
return std::make_shared<
509+
DecimalCompareFunction<int128_t, int128_t, Operation>>(
510+
aPrecision, aScale, bPrecision, bScale);
511+
}
512+
}
513+
VELOX_UNREACHABLE();
514+
}
515+
516+
// Signature shared by the decimal comparison functions:
// (DECIMAL(a_precision, a_scale), DECIMAL(b_precision, b_scale)) -> boolean,
// with the four precision/scale values declared as integer variables.
std::vector<std::shared_ptr<exec::FunctionSignature>>
decimalCompareSignature() {
  exec::FunctionSignatureBuilder builder;
  builder.integerVariable("a_precision");
  builder.integerVariable("a_scale");
  builder.integerVariable("b_precision");
  builder.integerVariable("b_scale");
  builder.returnType("boolean");
  builder.argumentType("DECIMAL(a_precision, a_scale)");
  builder.argumentType("DECIMAL(b_precision, b_scale)");

  std::vector<std::shared_ptr<exec::FunctionSignature>> signatures;
  signatures.push_back(builder.build());
  return signatures;
}
250528
} // namespace
251529

252530
std::vector<std::shared_ptr<exec::FunctionSignature>>
@@ -308,4 +586,29 @@ VELOX_DECLARE_VECTOR_FUNCTION(
308586
udf_decimal_round,
309587
std::vector<std::shared_ptr<exec::FunctionSignature>>{},
310588
std::make_unique<DecimalRoundFunction>());
589+
590+
// Declares the five decimal comparison functions (gt, gte, lt, lte, eq). Each
// one uses decimalCompareSignature() as its signature and
// createDecimalCompareFunction<Op> as its factory, with Op being the matching
// comparison policy class.
VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
591+
udf_decimal_gt,
592+
decimalCompareSignature(),
593+
createDecimalCompareFunction<Gt>);
594+
595+
VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
596+
udf_decimal_gte,
597+
decimalCompareSignature(),
598+
createDecimalCompareFunction<Gte>);
599+
600+
VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
601+
udf_decimal_lt,
602+
decimalCompareSignature(),
603+
createDecimalCompareFunction<Lt>);
604+
605+
VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
606+
udf_decimal_lte,
607+
decimalCompareSignature(),
608+
createDecimalCompareFunction<Lte>);
609+
610+
VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
611+
udf_decimal_eq,
612+
decimalCompareSignature(),
613+
createDecimalCompareFunction<Eq>);
311614
} // namespace facebook::velox::functions::sparksql

velox/functions/sparksql/RegisterCompare.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@ void registerCompareFunctions(const std::string& prefix) {
5050
{prefix + "between"});
5151
registerFunction<BetweenFunction, bool, double, double, double>(
5252
{prefix + "between"});
53+
// Decimal comapre functions.
54+
VELOX_REGISTER_VECTOR_FUNCTION(udf_decimal_gt, "decimal_gt");
55+
VELOX_REGISTER_VECTOR_FUNCTION(udf_decimal_gte, "decimal_gte");
56+
VELOX_REGISTER_VECTOR_FUNCTION(udf_decimal_lt, "decimal_lt");
57+
VELOX_REGISTER_VECTOR_FUNCTION(udf_decimal_lte, "decimal_lte");
58+
VELOX_REGISTER_VECTOR_FUNCTION(udf_decimal_eq, "decimal_eq");
5359
}
5460

5561
} // namespace facebook::velox::functions::sparksql

velox/type/Type.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
*/
1616
#pragma once
1717

18+
#include <boost/multiprecision/cpp_int.hpp>
1819
#include <fmt/core.h>
1920
#include <fmt/format.h>
2021
#include <folly/Format.h>
@@ -44,6 +45,7 @@
4445
namespace facebook::velox {
4546

4647
using int128_t = __int128_t;
48+
using int256_t = boost::multiprecision::int256_t;
4749

4850
/// Velox type system supports a small set of SQL-compatible composeable types:
4951
/// BOOLEAN, TINYINT, SMALLINT, INTEGER, BIGINT, HUGEINT, REAL, DOUBLE, VARCHAR,

0 commit comments

Comments
 (0)