Commit e8cf5e2

Support Decimal32/64 types
1 parent 5bb476f commit e8cf5e2

File tree

9 files changed (+560, -67 lines)

datafusion/common/src/cast.rs

Lines changed: 14 additions & 3 deletions
@@ -22,9 +22,10 @@
 use crate::{downcast_value, Result};
 use arrow::array::{
-    BinaryViewArray, DurationMicrosecondArray, DurationMillisecondArray,
-    DurationNanosecondArray, DurationSecondArray, Float16Array, Int16Array, Int8Array,
-    LargeBinaryArray, LargeStringArray, StringViewArray, UInt16Array,
+    BinaryViewArray, Decimal32Array, Decimal64Array, DurationMicrosecondArray,
+    DurationMillisecondArray, DurationNanosecondArray, DurationSecondArray, Float16Array,
+    Int16Array, Int8Array, LargeBinaryArray, LargeStringArray, StringViewArray,
+    UInt16Array,
 };
 use arrow::{
     array::{
@@ -97,6 +98,16 @@ pub fn as_uint64_array(array: &dyn Array) -> Result<&UInt64Array> {
     Ok(downcast_value!(array, UInt64Array))
 }
 
+// Downcast Array to Decimal32Array
+pub fn as_decimal32_array(array: &dyn Array) -> Result<&Decimal32Array> {
+    Ok(downcast_value!(array, Decimal32Array))
+}
+
+// Downcast Array to Decimal64Array
+pub fn as_decimal64_array(array: &dyn Array) -> Result<&Decimal64Array> {
+    Ok(downcast_value!(array, Decimal64Array))
+}
+
 // Downcast Array to Decimal128Array
 pub fn as_decimal128_array(array: &dyn Array) -> Result<&Decimal128Array> {
     Ok(downcast_value!(array, Decimal128Array))
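For context, a minimal sketch of how the new downcast helpers can be used; the values and the small driver function are illustrative, not part of the commit:

    use std::sync::Arc;

    use arrow::array::{ArrayRef, Decimal32Array};
    use datafusion_common::cast::as_decimal32_array;
    use datafusion_common::Result;

    fn demo() -> Result<()> {
        // Hypothetical unscaled values; 12345 at scale 2 would read as 123.45.
        let array: ArrayRef = Arc::new(Decimal32Array::from(vec![12345_i32, 67890]));
        // The helper downcasts a `&dyn Array` back to the concrete decimal type,
        // returning a DataFusion error instead of panicking on a type mismatch.
        let decimals = as_decimal32_array(array.as_ref())?;
        assert_eq!(decimals.value(0), 12345);
        Ok(())
    }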

datafusion/common/src/scalar/mod.rs

Lines changed: 338 additions & 36 deletions
Large diffs are not rendered by default.
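Most of this file's diff (collapsed above) adds `Decimal32`/`Decimal64` variants to `ScalarValue`. A rough sketch of the shape of the new variants, with made-up precision and scale:

    use arrow::datatypes::DataType;
    use datafusion_common::ScalarValue;

    fn demo() {
        // (value, precision, scale), mirroring the existing Decimal128/256 variants.
        let v = ScalarValue::Decimal32(Some(12345), 7, 2); // represents 123.45
        assert_eq!(v.data_type(), DataType::Decimal32(7, 2));
    }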

datafusion/expr-common/src/type_coercion/aggregates.rs

Lines changed: 28 additions & 3 deletions
@@ -18,7 +18,8 @@
 use crate::signature::TypeSignature;
 use arrow::datatypes::{
     DataType, FieldRef, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
-    DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE,
+    DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, DECIMAL32_MAX_PRECISION,
+    DECIMAL64_MAX_PRECISION,
 };
 
 use datafusion_common::{internal_err, plan_err, Result};
@@ -150,6 +151,18 @@ pub fn sum_return_type(arg_type: &DataType) -> Result<DataType> {
         DataType::Int64 => Ok(DataType::Int64),
         DataType::UInt64 => Ok(DataType::UInt64),
         DataType::Float64 => Ok(DataType::Float64),
+        DataType::Decimal32(precision, scale) => {
+            // in the spark, the result type is DECIMAL(min(38,precision+10), s)
+            // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
+            let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10);
+            Ok(DataType::Decimal128(new_precision, *scale))
+        }
+        DataType::Decimal64(precision, scale) => {
+            // in the spark, the result type is DECIMAL(min(38,precision+10), s)
+            // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
+            let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10);
+            Ok(DataType::Decimal128(new_precision, *scale))
+        }
         DataType::Decimal128(precision, scale) => {
             // In the spark, the result type is DECIMAL(min(38,precision+10), s)
             // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
@@ -222,6 +235,16 @@ pub fn avg_return_type(func_name: &str, arg_type: &DataType) -> Result<DataType>
 /// Internal sum type of an average
 pub fn avg_sum_type(arg_type: &DataType) -> Result<DataType> {
     match arg_type {
+        DataType::Decimal32(precision, scale) => {
+            // In the spark, the sum type of avg is DECIMAL(min(38,precision+10), s)
+            let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10);
+            Ok(DataType::Decimal32(new_precision, *scale))
+        }
+        DataType::Decimal64(precision, scale) => {
+            // In the spark, the sum type of avg is DECIMAL(min(38,precision+10), s)
+            let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10);
+            Ok(DataType::Decimal64(new_precision, *scale))
+        }
         DataType::Decimal128(precision, scale) => {
             // In the spark, the sum type of avg is DECIMAL(min(38,precision+10), s)
             let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
@@ -249,7 +272,7 @@ pub fn is_sum_support_arg_type(arg_type: &DataType) -> bool {
         _ => matches!(
             arg_type,
             arg_type if NUMERICS.contains(arg_type)
-                || matches!(arg_type, DataType::Decimal128(_, _) | DataType::Decimal256(_, _))
+                || matches!(arg_type, DataType::Decimal32(_, _) | DataType::Decimal64(_, _) | DataType::Decimal128(_, _) | DataType::Decimal256(_, _))
         ),
     }
 }
@@ -262,7 +285,7 @@ pub fn is_avg_support_arg_type(arg_type: &DataType) -> bool {
         _ => matches!(
             arg_type,
             arg_type if NUMERICS.contains(arg_type)
-                || matches!(arg_type, DataType::Decimal128(_, _)| DataType::Decimal256(_, _))
+                || matches!(arg_type, DataType::Decimal32(_, _) | DataType::Decimal64(_, _) | DataType::Decimal128(_, _) | DataType::Decimal256(_, _))
         ),
     }
 }
@@ -297,6 +320,8 @@ pub fn coerce_avg_type(func_name: &str, arg_types: &[DataType]) -> Result<Vec<Da
     // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
     fn coerced_type(func_name: &str, data_type: &DataType) -> Result<DataType> {
         match &data_type {
+            DataType::Decimal32(p, s) => Ok(DataType::Decimal32(*p, *s)),
+            DataType::Decimal64(p, s) => Ok(DataType::Decimal64(*p, *s)),
            DataType::Decimal128(p, s) => Ok(DataType::Decimal128(*p, *s)),
            DataType::Decimal256(p, s) => Ok(DataType::Decimal256(*p, *s)),
            d if d.is_numeric() => Ok(DataType::Float64),
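A small illustrative check of the new coercion arms, assuming the arrow-rs constants DECIMAL32_MAX_PRECISION = 9 and DECIMAL64_MAX_PRECISION = 18, and assuming the helpers are reachable at `datafusion_expr_common::type_coercion::aggregates`:

    use arrow::datatypes::DataType;
    use datafusion_common::Result;
    use datafusion_expr_common::type_coercion::aggregates::{avg_sum_type, sum_return_type};

    fn demo() -> Result<()> {
        // SUM over Decimal32(7, 2): precision widens to min(9, 7 + 10) = 9 and the
        // result is promoted to Decimal128, matching the arm added above.
        assert_eq!(
            sum_return_type(&DataType::Decimal32(7, 2))?,
            DataType::Decimal128(9, 2)
        );
        // The internal sum type of AVG stays at the original decimal width.
        assert_eq!(
            avg_sum_type(&DataType::Decimal64(10, 4))?,
            DataType::Decimal64(18, 4)
        );
        Ok(())
    }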

datafusion/expr/src/type_coercion/mod.rs

Lines changed: 9 additions & 1 deletion
@@ -51,6 +51,8 @@ pub fn is_signed_numeric(dt: &DataType) -> bool {
             | DataType::Float16
             | DataType::Float32
             | DataType::Float64
+            | DataType::Decimal32(_, _)
+            | DataType::Decimal64(_, _)
             | DataType::Decimal128(_, _)
             | DataType::Decimal256(_, _),
     )
@@ -89,5 +91,11 @@ pub fn is_utf8_or_utf8view_or_large_utf8(dt: &DataType) -> bool {
 
 /// Determine whether the given data type `dt` is a `Decimal`.
 pub fn is_decimal(dt: &DataType) -> bool {
-    matches!(dt, DataType::Decimal128(_, _) | DataType::Decimal256(_, _))
+    matches!(
+        dt,
+        DataType::Decimal32(_, _)
+            | DataType::Decimal64(_, _)
+            | DataType::Decimal128(_, _)
+            | DataType::Decimal256(_, _)
+    )
 }
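A quick illustration of the widened predicates; the precision and scale values are arbitrary:

    use arrow::datatypes::DataType;
    use datafusion_expr::type_coercion::{is_decimal, is_signed_numeric};

    fn demo() {
        // Both predicates now cover all four decimal widths.
        assert!(is_decimal(&DataType::Decimal32(5, 2)));
        assert!(is_signed_numeric(&DataType::Decimal64(12, 3)));
    }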

datafusion/functions-aggregate/src/average.rs

Lines changed: 69 additions & 4 deletions
@@ -24,9 +24,10 @@ use arrow::array::{
 
 use arrow::compute::sum;
 use arrow::datatypes::{
-    i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType,
-    DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType,
-    DurationSecondType, Field, FieldRef, Float64Type, TimeUnit, UInt64Type,
+    i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, Decimal32Type,
+    Decimal64Type, DecimalType, DurationMicrosecondType, DurationMillisecondType,
+    DurationNanosecondType, DurationSecondType, Field, FieldRef, Float64Type, TimeUnit,
+    UInt64Type,
 };
 use datafusion_common::{
     exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue,
@@ -128,6 +129,28 @@ impl AggregateUDFImpl for Avg {
         } else {
             match (&data_type, acc_args.return_field.data_type()) {
                 (Float64, Float64) => Ok(Box::<AvgAccumulator>::default()),
+                (
+                    Decimal32(sum_precision, sum_scale),
+                    Decimal32(target_precision, target_scale),
+                ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal32Type> {
+                    sum: None,
+                    count: 0,
+                    sum_scale: *sum_scale,
+                    sum_precision: *sum_precision,
+                    target_precision: *target_precision,
+                    target_scale: *target_scale,
+                })),
+                (
+                    Decimal64(sum_precision, sum_scale),
+                    Decimal64(target_precision, target_scale),
+                ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal64Type> {
+                    sum: None,
+                    count: 0,
+                    sum_scale: *sum_scale,
+                    sum_precision: *sum_precision,
+                    target_precision: *target_precision,
+                    target_scale: *target_scale,
+                })),
                 (
                     Decimal128(sum_precision, sum_scale),
                     Decimal128(target_precision, target_scale),
@@ -202,7 +225,11 @@ impl AggregateUDFImpl for Avg {
     fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
         matches!(
             args.return_field.data_type(),
-            DataType::Float64 | DataType::Decimal128(_, _) | DataType::Duration(_)
+            DataType::Float64
+                | DataType::Decimal32(_, _)
+                | DataType::Decimal64(_, _)
+                | DataType::Decimal128(_, _)
+                | DataType::Duration(_)
         ) && !args.is_distinct
     }
 
@@ -222,6 +249,44 @@ impl AggregateUDFImpl for Avg {
                     |sum: f64, count: u64| Ok(sum / count as f64),
                 )))
             }
+            (
+                Decimal32(_sum_precision, sum_scale),
+                Decimal32(target_precision, target_scale),
+            ) => {
+                let decimal_averager = DecimalAverager::<Decimal32Type>::try_new(
+                    *sum_scale,
+                    *target_precision,
+                    *target_scale,
+                )?;
+
+                let avg_fn =
+                    move |sum: i32, count: u64| decimal_averager.avg(sum, count as i32);
+
+                Ok(Box::new(AvgGroupsAccumulator::<Decimal32Type, _>::new(
+                    &data_type,
+                    args.return_field.data_type(),
+                    avg_fn,
+                )))
+            }
+            (
+                Decimal64(_sum_precision, sum_scale),
+                Decimal64(target_precision, target_scale),
+            ) => {
+                let decimal_averager = DecimalAverager::<Decimal64Type>::try_new(
+                    *sum_scale,
+                    *target_precision,
+                    *target_scale,
+                )?;
+
+                let avg_fn =
+                    move |sum: i64, count: u64| decimal_averager.avg(sum, count as i64);
+
+                Ok(Box::new(AvgGroupsAccumulator::<Decimal64Type, _>::new(
+                    &data_type,
+                    args.return_field.data_type(),
+                    avg_fn,
+                )))
+            }
             (
                 Decimal128(_sum_precision, sum_scale),
                 Decimal128(target_precision, target_scale),
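A rough sketch of the arithmetic the new narrow-decimal group accumulators delegate to, assuming `DecimalAverager` is the helper exported from `datafusion_functions_aggregate_common::utils`; the scales and values are made up:

    use arrow::datatypes::Decimal32Type;
    use datafusion_common::Result;
    use datafusion_functions_aggregate_common::utils::DecimalAverager;

    fn demo() -> Result<()> {
        // Configured here with sum scale 2 and a hypothetical target of
        // precision 9, scale 6.
        let averager = DecimalAverager::<Decimal32Type>::try_new(2, 9, 6)?;
        // A group sum of 123.45 (unscaled 12345 at scale 2) over 3 rows averages
        // to 41.15, i.e. 41_150_000 at the target scale of 6.
        assert_eq!(averager.avg(12345_i32, 3_i32)?, 41_150_000);
        Ok(())
    }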

datafusion/functions-aggregate/src/sum.rs

Lines changed: 26 additions & 5 deletions
@@ -18,6 +18,8 @@
 //! Defines `SUM` and `SUM DISTINCT` aggregate accumulators
 
 use ahash::RandomState;
+use arrow::datatypes::DECIMAL32_MAX_PRECISION;
+use arrow::datatypes::DECIMAL64_MAX_PRECISION;
 use datafusion_expr::utils::AggregateOrderSensitivity;
 use std::any::Any;
 use std::mem::size_of_val;
@@ -27,8 +29,8 @@ use arrow::array::ArrowNativeTypeOp;
 use arrow::array::{ArrowNumericType, AsArray};
 use arrow::datatypes::{ArrowNativeType, FieldRef};
 use arrow::datatypes::{
-    DataType, Decimal128Type, Decimal256Type, Float64Type, Int64Type, UInt64Type,
-    DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
+    DataType, Decimal128Type, Decimal256Type, Decimal32Type, Decimal64Type, Float64Type,
+    Int64Type, UInt64Type, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
 };
 use arrow::{array::ArrayRef, datatypes::Field};
 use datafusion_common::{
@@ -71,6 +73,12 @@ macro_rules! downcast_sum {
        DataType::Float64 => {
            $helper!(Float64Type, $args.return_field.data_type().clone())
        }
+        DataType::Decimal32(_, _) => {
+            $helper!(Decimal32Type, $args.return_field.data_type().clone())
+        }
+        DataType::Decimal64(_, _) => {
+            $helper!(Decimal64Type, $args.return_field.data_type().clone())
+        }
        DataType::Decimal128(_, _) => {
            $helper!(Decimal128Type, $args.return_field.data_type().clone())
        }
@@ -145,9 +153,10 @@ impl AggregateUDFImpl for Sum {
             DataType::Dictionary(_, v) => coerced_type(v),
             // in the spark, the result type is DECIMAL(min(38,precision+10), s)
             // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
-            DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => {
-                Ok(data_type.clone())
-            }
+            DataType::Decimal32(_, _)
+            | DataType::Decimal64(_, _)
+            | DataType::Decimal128(_, _)
+            | DataType::Decimal256(_, _) => Ok(data_type.clone()),
             dt if dt.is_signed_integer() => Ok(DataType::Int64),
             dt if dt.is_unsigned_integer() => Ok(DataType::UInt64),
             dt if dt.is_floating() => Ok(DataType::Float64),
@@ -163,6 +172,18 @@ impl AggregateUDFImpl for Sum {
             DataType::Int64 => Ok(DataType::Int64),
             DataType::UInt64 => Ok(DataType::UInt64),
             DataType::Float64 => Ok(DataType::Float64),
+            DataType::Decimal32(precision, scale) => {
+                // in the spark, the result type is DECIMAL(min(38,precision+10), s)
+                // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
+                let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10);
+                Ok(DataType::Decimal128(new_precision, *scale))
+            }
+            DataType::Decimal64(precision, scale) => {
+                // in the spark, the result type is DECIMAL(min(38,precision+10), s)
+                // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
+                let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10);
+                Ok(DataType::Decimal128(new_precision, *scale))
+            }
             DataType::Decimal128(precision, scale) => {
                 // in the spark, the result type is DECIMAL(min(38,precision+10), s)
                 // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
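Taken together with the coercion changes above, SUM and AVG over a narrow decimal column should now plan and execute. A hedged end-to-end sketch; the table name, values, and the tokio runtime are illustrative, not from the commit:

    use std::sync::Arc;

    use arrow::array::{ArrayRef, Decimal32Array, RecordBatch};
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::prelude::SessionContext;

    #[tokio::main]
    async fn main() -> datafusion::error::Result<()> {
        // A single Decimal32(7, 2) column holding 1.00, 2.50 and 3.25.
        let schema = Arc::new(Schema::new(vec![Field::new(
            "price",
            DataType::Decimal32(7, 2),
            false,
        )]));
        let prices = Decimal32Array::from(vec![100_i32, 250, 325])
            .with_precision_and_scale(7, 2)?;
        let batch = RecordBatch::try_new(schema, vec![Arc::new(prices) as ArrayRef])?;

        let ctx = SessionContext::new();
        ctx.register_batch("t", batch)?;
        // Per the return_type arm above, SUM should be reported as Decimal128(9, 2).
        ctx.sql("SELECT sum(price), avg(price) FROM t").await?.show().await?;
        Ok(())
    }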

datafusion/proto-common/src/from_proto/mod.rs

Lines changed: 6 additions & 7 deletions
@@ -37,7 +37,6 @@ use datafusion_common::{
         TableParquetOptions,
     },
     file_options::{csv_writer::CsvWriterOptions, json_writer::JsonWriterOptions},
-    not_impl_err,
     parsers::CompressionTypeVariant,
     plan_datafusion_err,
     stats::Precision,
@@ -478,13 +477,13 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
                 let null_type: DataType = v.try_into()?;
                 null_type.try_into().map_err(Error::DataFusionError)?
             }
-            Value::Decimal32Value(_val) => {
-                return not_impl_err!("Decimal32 protobuf deserialization")
-                    .map_err(Error::DataFusionError)
+            Value::Decimal32Value(val) => {
+                let array = vec_to_array(val.value.clone());
+                Self::Decimal32(Some(i32::from_be_bytes(array)), val.p as u8, val.s as i8)
             }
-            Value::Decimal64Value(_val) => {
-                return not_impl_err!("Decimal64 protobuf deserialization")
-                    .map_err(Error::DataFusionError)
+            Value::Decimal64Value(val) => {
+                let array = vec_to_array(val.value.clone());
+                Self::Decimal64(Some(i64::from_be_bytes(array)), val.p as u8, val.s as i8)
             }
             Value::Decimal128Value(val) => {
                 let array = vec_to_array(val.value.clone());

datafusion/proto-common/src/to_proto/mod.rs

Lines changed: 36 additions & 0 deletions
@@ -405,6 +405,42 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue {
                     })
                 })
             }
+            ScalarValue::Decimal32(val, p, s) => match *val {
+                Some(v) => {
+                    let array = v.to_be_bytes();
+                    let vec_val: Vec<u8> = array.to_vec();
+                    Ok(protobuf::ScalarValue {
+                        value: Some(Value::Decimal32Value(protobuf::Decimal32 {
+                            value: vec_val,
+                            p: *p as i64,
+                            s: *s as i64,
+                        })),
+                    })
+                }
+                None => Ok(protobuf::ScalarValue {
+                    value: Some(protobuf::scalar_value::Value::NullValue(
+                        (&data_type).try_into()?,
+                    )),
+                }),
+            },
+            ScalarValue::Decimal64(val, p, s) => match *val {
+                Some(v) => {
+                    let array = v.to_be_bytes();
+                    let vec_val: Vec<u8> = array.to_vec();
+                    Ok(protobuf::ScalarValue {
+                        value: Some(Value::Decimal64Value(protobuf::Decimal64 {
+                            value: vec_val,
+                            p: *p as i64,
+                            s: *s as i64,
+                        })),
+                    })
+                }
+                None => Ok(protobuf::ScalarValue {
+                    value: Some(protobuf::scalar_value::Value::NullValue(
+                        (&data_type).try_into()?,
+                    )),
+                }),
+            },
             ScalarValue::Decimal128(val, p, s) => match *val {
                 Some(v) => {
                     let array = v.to_be_bytes();
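Both proto arms reuse the big-endian encoding already used for Decimal128/256: the unscaled integer is written with `to_be_bytes` and read back with `from_be_bytes` (the existing `vec_to_array` helper handles the slice-to-fixed-array step). A minimal standalone illustration of the round trip:

    fn demo() {
        // Serialize: the unscaled value of a hypothetical Decimal32(7, 2) scalar,
        // -123.45, is stored as its big-endian bytes.
        let unscaled: i32 = -12345;
        let wire: Vec<u8> = unscaled.to_be_bytes().to_vec();

        // Deserialize: copy the bytes into a fixed-size array and decode.
        let mut buf = [0u8; 4];
        buf.copy_from_slice(&wire);
        assert_eq!(i32::from_be_bytes(buf), unscaled);
    }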
