@@ -247,6 +247,284 @@ class DecimalRoundFunction : public exec::VectorFunction {
247
247
}
248
248
}
249
249
};
250
+
251
+ // Rescale two inputs as the same scale and compare. Returns 0 when a is equal
252
+ // with b. Returns -1 when a is less than b. Returns 1 when a is greater than b.
253
+ template <typename T>
254
+ int32_t rescaleAndCompare (T a, T b, int32_t deltaScale) {
255
+ T aScaled = a;
256
+ T bScaled = b;
257
+ if (deltaScale < 0 ) {
258
+ aScaled = a * velox::DecimalUtil::kPowersOfTen [-deltaScale];
259
+ } else if (deltaScale > 0 ) {
260
+ bScaled = b * velox::DecimalUtil::kPowersOfTen [deltaScale];
261
+ }
262
+ if (aScaled == bScaled) {
263
+ return 0 ;
264
+ } else if (aScaled < bScaled) {
265
+ return -1 ;
266
+ } else {
267
+ return 1 ;
268
+ }
269
+ }
270
+
271
+ // Compare two decimals. Rescale one of them When they are of different scales.
272
+ int32_t decimalCompare (
273
+ int128_t a,
274
+ uint8_t aPrecision,
275
+ uint8_t aScale,
276
+ int128_t b,
277
+ uint8_t bPrecision,
278
+ uint8_t bScale) {
279
+ int32_t deltaScale = aScale - bScale;
280
+ // Check if we need 256-bits after adjusting the scale.
281
+ bool need256 = (deltaScale < 0 &&
282
+ aPrecision - deltaScale > LongDecimalType::kMaxPrecision ) ||
283
+ (bPrecision + deltaScale > LongDecimalType::kMaxPrecision );
284
+ if (need256) {
285
+ return rescaleAndCompare<int256_t >(
286
+ static_cast <int256_t >(a), static_cast <int256_t >(b), deltaScale);
287
+ }
288
+ return rescaleAndCompare<int128_t >(a, b, deltaScale);
289
+ }
290
+
291
+ // GreaterThan decimal compare function.
292
+ class Gt {
293
+ public:
294
+ inline static bool apply (
295
+ int128_t a,
296
+ uint8_t aPrecision,
297
+ uint8_t aScale,
298
+ int128_t b,
299
+ uint8_t bPrecision,
300
+ uint8_t bScale) {
301
+ return decimalCompare (a, aPrecision, aScale, b, bPrecision, bScale) > 0 ;
302
+ }
303
+ };
304
+
305
+ // GreaterThanOrEqual decimal compare function.
306
+ class Gte {
307
+ public:
308
+ inline static bool apply (
309
+ int128_t a,
310
+ uint8_t aPrecision,
311
+ uint8_t aScale,
312
+ int128_t b,
313
+ uint8_t bPrecision,
314
+ uint8_t bScale) {
315
+ return decimalCompare (a, aPrecision, aScale, b, bPrecision, bScale) >= 0 ;
316
+ }
317
+ };
318
+
319
+ // LessThan decimal compare function.
320
+ class Lt {
321
+ public:
322
+ inline static bool apply (
323
+ int128_t a,
324
+ uint8_t aPrecision,
325
+ uint8_t aScale,
326
+ int128_t b,
327
+ uint8_t bPrecision,
328
+ uint8_t bScale) {
329
+ return decimalCompare (a, aPrecision, aScale, b, bPrecision, bScale) < 0 ;
330
+ }
331
+ };
332
+
333
+ // LessThanOrEqual decimal compare function.
334
+ class Lte {
335
+ public:
336
+ inline static bool apply (
337
+ int128_t a,
338
+ uint8_t aPrecision,
339
+ uint8_t aScale,
340
+ int128_t b,
341
+ uint8_t bPrecision,
342
+ uint8_t bScale) {
343
+ return decimalCompare (a, aPrecision, aScale, b, bPrecision, bScale) <= 0 ;
344
+ }
345
+ };
346
+
347
+ // Equal decimal compare function.
348
+ class Eq {
349
+ public:
350
+ inline static bool apply (
351
+ int128_t a,
352
+ uint8_t aPrecision,
353
+ uint8_t aScale,
354
+ int128_t b,
355
+ uint8_t bPrecision,
356
+ uint8_t bScale) {
357
+ return decimalCompare (a, aPrecision, aScale, b, bPrecision, bScale) == 0 ;
358
+ }
359
+ };
360
+
361
+ // Class for decimal compare operations.
362
+ template <typename A, typename B, typename Operation /* Arithmetic operation */ >
363
+ class DecimalCompareFunction : public exec ::VectorFunction {
364
+ public:
365
+ DecimalCompareFunction (
366
+ uint8_t aPrecision,
367
+ uint8_t aScale,
368
+ uint8_t bPrecision,
369
+ uint8_t bScale)
370
+ : aPrecision_(aPrecision),
371
+ aScale_ (aScale),
372
+ bPrecision_(bPrecision),
373
+ bScale_(bScale) {}
374
+
375
+ void apply (
376
+ const SelectivityVector& rows,
377
+ std::vector<VectorPtr>& args,
378
+ const TypePtr& resultType,
379
+ exec::EvalCtx& context,
380
+ VectorPtr& result) const override {
381
+ prepareResults (rows, resultType, context, result);
382
+
383
+ // Fast path when the first argument is a flat vector.
384
+ if (args[0 ]->isFlatEncoding ()) {
385
+ auto rawA = args[0 ]->asUnchecked <FlatVector<A>>()->mutableRawValues ();
386
+
387
+ if (args[1 ]->isConstantEncoding ()) {
388
+ auto constantB = args[1 ]->asUnchecked <SimpleVector<B>>()->valueAt (0 );
389
+ context.applyToSelectedNoThrow (rows, [&](auto row) {
390
+ result->asUnchecked <FlatVector<bool >>()->set (
391
+ row,
392
+ Operation::apply (
393
+ (int128_t )rawA[row],
394
+ aPrecision_,
395
+ aScale_,
396
+ (int128_t )constantB,
397
+ bPrecision_,
398
+ bScale_));
399
+ });
400
+ return ;
401
+ }
402
+
403
+ if (args[1 ]->isFlatEncoding ()) {
404
+ auto rawB = args[1 ]->asUnchecked <FlatVector<B>>()->mutableRawValues ();
405
+ context.applyToSelectedNoThrow (rows, [&](auto row) {
406
+ result->asUnchecked <FlatVector<bool >>()->set (
407
+ row,
408
+ Operation::apply (
409
+ (int128_t )rawA[row],
410
+ aPrecision_,
411
+ aScale_,
412
+ (int128_t )rawB[row],
413
+ bPrecision_,
414
+ bScale_));
415
+ });
416
+ return ;
417
+ }
418
+ } else {
419
+ // Fast path when the first argument is encoded but the second is
420
+ // constant.
421
+ exec::DecodedArgs decodedArgs (rows, args, context);
422
+ auto aDecoded = decodedArgs.at (0 );
423
+ auto aDecodedData = aDecoded->data <A>();
424
+
425
+ if (args[1 ]->isConstantEncoding ()) {
426
+ auto constantB = args[1 ]->asUnchecked <SimpleVector<B>>()->valueAt (0 );
427
+ context.applyToSelectedNoThrow (rows, [&](auto row) {
428
+ auto value = aDecodedData[aDecoded->index (row)];
429
+ result->asUnchecked <FlatVector<bool >>()->set (
430
+ row,
431
+ Operation::apply (
432
+ (int128_t )value,
433
+ aPrecision_,
434
+ aScale_,
435
+ (int128_t )constantB,
436
+ bPrecision_,
437
+ bScale_));
438
+ });
439
+ return ;
440
+ }
441
+ }
442
+
443
+ // Decode the input in all other cases.
444
+ exec::DecodedArgs decodedArgs (rows, args, context);
445
+ auto aDecoded = decodedArgs.at (0 );
446
+ auto bDecoded = decodedArgs.at (1 );
447
+
448
+ auto aDecodedData = aDecoded->data <A>();
449
+ auto bDecodedData = bDecoded->data <B>();
450
+
451
+ context.applyToSelectedNoThrow (rows, [&](auto row) {
452
+ auto aValue = aDecodedData[aDecoded->index (row)];
453
+ auto bValue = bDecodedData[bDecoded->index (row)];
454
+ result->asUnchecked <FlatVector<bool >>()->set (
455
+ row,
456
+ Operation::apply (
457
+ (int128_t )aValue,
458
+ aPrecision_,
459
+ aScale_,
460
+ (int128_t )bValue,
461
+ bPrecision_,
462
+ bScale_));
463
+ });
464
+ }
465
+
466
+ private:
467
+ void prepareResults (
468
+ const SelectivityVector& rows,
469
+ const TypePtr& resultType,
470
+ exec::EvalCtx& context,
471
+ VectorPtr& result) const {
472
+ context.ensureWritable (rows, resultType, result);
473
+ result->clearNulls (rows);
474
+ }
475
+
476
+ const uint8_t aPrecision_;
477
+ const uint8_t aScale_;
478
+ const uint8_t bPrecision_;
479
+ const uint8_t bScale_;
480
+ };
481
+
482
+ template <typename Operation>
483
+ std::shared_ptr<exec::VectorFunction> createDecimalCompareFunction (
484
+ const std::string& name,
485
+ const std::vector<exec::VectorFunctionArg>& inputArgs,
486
+ const core::QueryConfig& /* config*/ ) {
487
+ const auto & aType = inputArgs[0 ].type ;
488
+ const auto & bType = inputArgs[1 ].type ;
489
+ auto [aPrecision, aScale] = getDecimalPrecisionScale (*aType);
490
+ auto [bPrecision, bScale] = getDecimalPrecisionScale (*bType);
491
+ if (aType->isShortDecimal ()) {
492
+ if (bType->isShortDecimal ()) {
493
+ return std::make_shared<
494
+ DecimalCompareFunction<int64_t , int64_t , Operation>>(
495
+ aPrecision, aScale, bPrecision, bScale);
496
+ } else if (bType->isLongDecimal ()) {
497
+ return std::make_shared<
498
+ DecimalCompareFunction<int64_t , int128_t , Operation>>(
499
+ aPrecision, aScale, bPrecision, bScale);
500
+ }
501
+ }
502
+ if (aType->isLongDecimal ()) {
503
+ if (bType->isShortDecimal ()) {
504
+ return std::make_shared<
505
+ DecimalCompareFunction<int128_t , int64_t , Operation>>(
506
+ aPrecision, aScale, bPrecision, bScale);
507
+ } else if (bType->isLongDecimal ()) {
508
+ return std::make_shared<
509
+ DecimalCompareFunction<int128_t , int128_t , Operation>>(
510
+ aPrecision, aScale, bPrecision, bScale);
511
+ }
512
+ }
513
+ VELOX_UNREACHABLE ();
514
+ }
515
+
516
+ std::vector<std::shared_ptr<exec::FunctionSignature>>
517
+ decimalCompareSignature () {
518
+ return {exec::FunctionSignatureBuilder ()
519
+ .integerVariable (" a_precision" )
520
+ .integerVariable (" a_scale" )
521
+ .integerVariable (" b_precision" )
522
+ .integerVariable (" b_scale" )
523
+ .returnType (" boolean" )
524
+ .argumentType (" DECIMAL(a_precision, a_scale)" )
525
+ .argumentType (" DECIMAL(b_precision, b_scale)" )
526
+ .build ()};
527
+ }
250
528
} // namespace
251
529
252
530
std::vector<std::shared_ptr<exec::FunctionSignature>>
@@ -308,4 +586,29 @@ VELOX_DECLARE_VECTOR_FUNCTION(
308
586
udf_decimal_round,
309
587
std::vector<std::shared_ptr<exec::FunctionSignature>>{},
310
588
std::make_unique<DecimalRoundFunction>());
589
+
590
+ VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION (
591
+ udf_decimal_gt,
592
+ decimalCompareSignature (),
593
+ createDecimalCompareFunction<Gt>);
594
+
595
+ VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION (
596
+ udf_decimal_gte,
597
+ decimalCompareSignature (),
598
+ createDecimalCompareFunction<Gte>);
599
+
600
+ VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION (
601
+ udf_decimal_lt,
602
+ decimalCompareSignature (),
603
+ createDecimalCompareFunction<Lt>);
604
+
605
+ VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION (
606
+ udf_decimal_lte,
607
+ decimalCompareSignature (),
608
+ createDecimalCompareFunction<Lte>);
609
+
610
+ VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION (
611
+ udf_decimal_eq,
612
+ decimalCompareSignature (),
613
+ createDecimalCompareFunction<Eq>);
311
614
} // namespace facebook::velox::functions::sparksql
0 commit comments