18
18
use crate :: signature:: TypeSignature ;
19
19
use arrow:: datatypes:: {
20
20
DataType , FieldRef , TimeUnit , DECIMAL128_MAX_PRECISION , DECIMAL128_MAX_SCALE ,
21
- DECIMAL256_MAX_PRECISION , DECIMAL256_MAX_SCALE ,
21
+ DECIMAL256_MAX_PRECISION , DECIMAL256_MAX_SCALE , DECIMAL32_MAX_PRECISION ,
22
+ DECIMAL64_MAX_PRECISION ,
22
23
} ;
23
24
24
25
use datafusion_common:: { internal_err, plan_err, Result } ;
@@ -150,6 +151,18 @@ pub fn sum_return_type(arg_type: &DataType) -> Result<DataType> {
150
151
DataType :: Int64 => Ok ( DataType :: Int64 ) ,
151
152
DataType :: UInt64 => Ok ( DataType :: UInt64 ) ,
152
153
DataType :: Float64 => Ok ( DataType :: Float64 ) ,
154
+ DataType :: Decimal32 ( precision, scale) => {
155
+ // in the spark, the result type is DECIMAL(min(38,precision+10), s)
156
+ // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
157
+ let new_precision = DECIMAL32_MAX_PRECISION . min ( * precision + 10 ) ;
158
+ Ok ( DataType :: Decimal128 ( new_precision, * scale) )
159
+ }
160
+ DataType :: Decimal64 ( precision, scale) => {
161
+ // in the spark, the result type is DECIMAL(min(38,precision+10), s)
162
+ // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
163
+ let new_precision = DECIMAL64_MAX_PRECISION . min ( * precision + 10 ) ;
164
+ Ok ( DataType :: Decimal128 ( new_precision, * scale) )
165
+ }
153
166
DataType :: Decimal128 ( precision, scale) => {
154
167
// In the spark, the result type is DECIMAL(min(38,precision+10), s)
155
168
// Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
@@ -222,6 +235,16 @@ pub fn avg_return_type(func_name: &str, arg_type: &DataType) -> Result<DataType>
222
235
/// Internal sum type of an average
223
236
pub fn avg_sum_type ( arg_type : & DataType ) -> Result < DataType > {
224
237
match arg_type {
238
+ DataType :: Decimal32 ( precision, scale) => {
239
+ // In the spark, the sum type of avg is DECIMAL(min(38,precision+10), s)
240
+ let new_precision = DECIMAL32_MAX_PRECISION . min ( * precision + 10 ) ;
241
+ Ok ( DataType :: Decimal32 ( new_precision, * scale) )
242
+ }
243
+ DataType :: Decimal64 ( precision, scale) => {
244
+ // In the spark, the sum type of avg is DECIMAL(min(38,precision+10), s)
245
+ let new_precision = DECIMAL64_MAX_PRECISION . min ( * precision + 10 ) ;
246
+ Ok ( DataType :: Decimal64 ( new_precision, * scale) )
247
+ }
225
248
DataType :: Decimal128 ( precision, scale) => {
226
249
// In the spark, the sum type of avg is DECIMAL(min(38,precision+10), s)
227
250
let new_precision = DECIMAL128_MAX_PRECISION . min ( * precision + 10 ) ;
@@ -249,7 +272,7 @@ pub fn is_sum_support_arg_type(arg_type: &DataType) -> bool {
249
272
_ => matches ! (
250
273
arg_type,
251
274
arg_type if NUMERICS . contains( arg_type)
252
- || matches!( arg_type, DataType :: Decimal128 ( _, _) | DataType :: Decimal256 ( _, _) )
275
+ || matches!( arg_type, DataType :: Decimal32 ( _ , _ ) | DataType :: Decimal64 ( _ , _ ) | DataType :: Decimal128 ( _, _) | DataType :: Decimal256 ( _, _) )
253
276
) ,
254
277
}
255
278
}
@@ -262,7 +285,7 @@ pub fn is_avg_support_arg_type(arg_type: &DataType) -> bool {
262
285
_ => matches ! (
263
286
arg_type,
264
287
arg_type if NUMERICS . contains( arg_type)
265
- || matches!( arg_type, DataType :: Decimal128 ( _, _) | DataType :: Decimal256 ( _, _) )
288
+ || matches!( arg_type, DataType :: Decimal32 ( _ , _ ) | DataType :: Decimal64 ( _ , _ ) | DataType :: Decimal128 ( _, _) | DataType :: Decimal256 ( _, _) )
266
289
) ,
267
290
}
268
291
}
@@ -297,6 +320,8 @@ pub fn coerce_avg_type(func_name: &str, arg_types: &[DataType]) -> Result<Vec<Da
297
320
// Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
298
321
fn coerced_type ( func_name : & str , data_type : & DataType ) -> Result < DataType > {
299
322
match & data_type {
323
+ DataType :: Decimal32 ( p, s) => Ok ( DataType :: Decimal32 ( * p, * s) ) ,
324
+ DataType :: Decimal64 ( p, s) => Ok ( DataType :: Decimal64 ( * p, * s) ) ,
300
325
DataType :: Decimal128 ( p, s) => Ok ( DataType :: Decimal128 ( * p, * s) ) ,
301
326
DataType :: Decimal256 ( p, s) => Ok ( DataType :: Decimal256 ( * p, * s) ) ,
302
327
d if d. is_numeric ( ) => Ok ( DataType :: Float64 ) ,
0 commit comments