Skip to content

Commit c7837d9

Browse files
PHILO-HEJkSelf
authored andcommitted
Add registerCompanionFunctions and overwrite as parameters in agg registration (7110)
1 parent b437c4e commit c7837d9

18 files changed

+193
-72
lines changed

velox/functions/lib/aggregates/BitwiseAggregateBase.h

+7-2
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,10 @@ class BitwiseAggregateBase : public SimpleNumericAggregate<T, T, T> {
7070
};
7171

7272
template <template <typename U> class T>
73-
exec::AggregateRegistrationResult registerBitwise(const std::string& name) {
73+
exec::AggregateRegistrationResult registerBitwise(
74+
const std::string& name,
75+
bool registerCompanionFunctions,
76+
bool overwrite) {
7477
std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures;
7578
for (const auto& inputType : {"tinyint", "smallint", "integer", "bigint"}) {
7679
signatures.push_back(exec::AggregateFunctionSignatureBuilder()
@@ -106,7 +109,9 @@ exec::AggregateRegistrationResult registerBitwise(const std::string& name) {
106109
name,
107110
inputType->kindName());
108111
}
109-
});
112+
},
113+
registerCompanionFunctions,
114+
overwrite);
110115
}
111116

112117
} // namespace facebook::velox::functions::aggregate

velox/functions/prestosql/aggregates/BitwiseAggregates.cpp

+8-3
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,14 @@ class BitwiseAndAggregate : public BitwiseAggregateBase<T> {
101101

102102
} // namespace
103103

104-
void registerBitwiseAggregates(const std::string& prefix) {
105-
registerBitwise<BitwiseOrAggregate>(prefix + kBitwiseOr);
106-
registerBitwise<BitwiseAndAggregate>(prefix + kBitwiseAnd);
104+
void registerBitwiseAggregates(
105+
const std::string& prefix,
106+
bool registerCompanionFunctions,
107+
bool overwrite) {
108+
registerBitwise<BitwiseOrAggregate>(
109+
prefix + kBitwiseOr, registerCompanionFunctions, overwrite);
110+
registerBitwise<BitwiseAndAggregate>(
111+
prefix + kBitwiseAnd, registerCompanionFunctions, overwrite);
107112
}
108113

109114
} // namespace facebook::velox::aggregate::prestosql

velox/functions/prestosql/aggregates/CountAggregate.cpp

+6-2
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,9 @@ class CountAggregate : public SimpleNumericAggregate<bool, int64_t, int64_t> {
151151
} // namespace
152152

153153
exec::AggregateRegistrationResult registerCountAggregate(
154-
const std::string& prefix) {
154+
const std::string& prefix,
155+
bool registerCompanionFunctions,
156+
bool overwrite) {
155157
std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures{
156158
exec::AggregateFunctionSignatureBuilder()
157159
.returnType("bigint")
@@ -178,7 +180,9 @@ exec::AggregateRegistrationResult registerCountAggregate(
178180
VELOX_CHECK_LE(
179181
argTypes.size(), 1, "{} takes at most one argument", name);
180182
return std::make_unique<CountAggregate>();
181-
});
183+
},
184+
registerCompanionFunctions,
185+
overwrite);
182186
}
183187

184188
} // namespace facebook::velox::aggregate::prestosql

velox/functions/prestosql/aggregates/CovarianceAggregates.cpp

+21-8
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,10 @@ template <
558558
typename TIntermediateInput,
559559
typename TIntermediateResult,
560560
typename TResultAccessor>
561-
exec::AggregateRegistrationResult registerCovariance(const std::string& name) {
561+
exec::AggregateRegistrationResult registerCovariance(
562+
const std::string& name,
563+
bool registerCompanionFunctions,
564+
bool overwrite) {
562565
std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures = {
563566
// (double, double) -> double
564567
exec::AggregateFunctionSignatureBuilder()
@@ -607,37 +610,47 @@ exec::AggregateRegistrationResult registerCovariance(const std::string& name) {
607610
"Unsupported raw input type: {}. Expected DOUBLE or REAL.",
608611
rawInputType->toString())
609612
}
610-
});
613+
},
614+
registerCompanionFunctions,
615+
overwrite);
611616
}
612617

613618
} // namespace
614619

615-
void registerCovarianceAggregates(const std::string& prefix) {
620+
void registerCovarianceAggregates(
621+
const std::string& prefix,
622+
bool registerCompanionFunctions,
623+
bool overwrite) {
616624
registerCovariance<
617625
CovarAccumulator,
618626
CovarIntermediateInput,
619627
CovarIntermediateResult,
620-
CovarPopResultAccessor>(prefix + kCovarPop);
628+
CovarPopResultAccessor>(
629+
prefix + kCovarPop, registerCompanionFunctions, overwrite);
621630
registerCovariance<
622631
CovarAccumulator,
623632
CovarIntermediateInput,
624633
CovarIntermediateResult,
625-
CovarSampResultAccessor>(prefix + kCovarSamp);
634+
CovarSampResultAccessor>(
635+
prefix + kCovarSamp, registerCompanionFunctions, overwrite);
626636
registerCovariance<
627637
CorrAccumulator,
628638
CorrIntermediateInput,
629639
CorrIntermediateResult,
630-
CorrResultAccessor>(prefix + kCorr);
640+
CorrResultAccessor>(
641+
prefix + kCorr, registerCompanionFunctions, overwrite);
631642
registerCovariance<
632643
RegrAccumulator,
633644
RegrIntermediateInput,
634645
RegrIntermediateResult,
635-
RegrInterceptResultAccessor>(prefix + kRegrIntercept);
646+
RegrInterceptResultAccessor>(
647+
prefix + kRegrIntercept, registerCompanionFunctions, overwrite);
636648
registerCovariance<
637649
RegrAccumulator,
638650
RegrIntermediateInput,
639651
RegrIntermediateResult,
640-
RegrSlopeResultAccessor>(prefix + kRegrSlop);
652+
RegrSlopeResultAccessor>(
653+
prefix + kRegrSlop, registerCompanionFunctions, overwrite);
641654
}
642655

643656
} // namespace facebook::velox::aggregate::prestosql

velox/functions/prestosql/aggregates/MinMaxAggregates.cpp

+13-5
Original file line numberDiff line numberDiff line change
@@ -904,7 +904,10 @@ template <
904904
typename TNonNumeric,
905905
template <typename T>
906906
class TNumericN>
907-
exec::AggregateRegistrationResult registerMinMax(const std::string& name) {
907+
exec::AggregateRegistrationResult registerMinMax(
908+
const std::string& name,
909+
bool registerCompanionFunctions,
910+
bool overwrite) {
908911
std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures;
909912
signatures.push_back(exec::AggregateFunctionSignatureBuilder()
910913
.orderableTypeVariable("T")
@@ -1008,16 +1011,21 @@ exec::AggregateRegistrationResult registerMinMax(const std::string& name) {
10081011
inputType->kindName());
10091012
}
10101013
}
1011-
});
1014+
},
1015+
registerCompanionFunctions,
1016+
overwrite);
10121017
}
10131018

10141019
} // namespace
10151020

1016-
void registerMinMaxAggregates(const std::string& prefix) {
1021+
void registerMinMaxAggregates(
1022+
const std::string& prefix,
1023+
bool registerCompanionFunctions,
1024+
bool overwrite) {
10171025
registerMinMax<MinAggregate, NonNumericMinAggregate, MinNAggregate>(
1018-
prefix + kMin);
1026+
prefix + kMin, registerCompanionFunctions, overwrite);
10191027
registerMinMax<MaxAggregate, NonNumericMaxAggregate, MaxNAggregate>(
1020-
prefix + kMax);
1028+
prefix + kMax, registerCompanionFunctions, overwrite);
10211029
}
10221030

10231031
} // namespace facebook::velox::aggregate::prestosql

velox/functions/prestosql/aggregates/RegisterAggregateFunctions.cpp

+28-11
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@ extern exec::AggregateRegistrationResult registerBitwiseXorAggregate(
3333
extern exec::AggregateRegistrationResult registerChecksumAggregate(
3434
const std::string& prefix);
3535
extern exec::AggregateRegistrationResult registerCountAggregate(
36-
const std::string& prefix);
36+
const std::string& prefix,
37+
bool registerCompanionFunctions,
38+
bool overwrite);
3739
extern exec::AggregateRegistrationResult registerCountIfAggregate(
3840
const std::string& prefix);
3941
extern exec::AggregateRegistrationResult registerEntropyAggregate(
@@ -62,30 +64,45 @@ extern exec::AggregateRegistrationResult registerSetUnionAggregate(
6264
const std::string& prefix);
6365

6466
extern void registerApproxDistinctAggregates(const std::string& prefix);
65-
extern void registerBitwiseAggregates(const std::string& prefix);
67+
extern void registerBitwiseAggregates(
68+
const std::string& prefix,
69+
bool registerCompanionFunctions,
70+
bool overwrite);
6671
extern void registerBoolAggregates(const std::string& prefix);
6772
extern void registerCentralMomentsAggregates(const std::string& prefix);
68-
extern void registerCovarianceAggregates(const std::string& prefix);
69-
extern void registerMinMaxAggregates(const std::string& prefix);
73+
extern void registerCovarianceAggregates(
74+
const std::string& prefix,
75+
bool registerCompanionFunctions,
76+
bool overwrite);
77+
extern void registerMinMaxAggregates(
78+
const std::string& prefix,
79+
bool registerCompanionFunctions,
80+
bool overwrite);
7081
extern void registerMinMaxByAggregates(const std::string& prefix);
7182
extern void registerSumAggregate(const std::string& prefix);
72-
extern void registerVarianceAggregates(const std::string& prefix);
83+
extern void registerVarianceAggregates(
84+
const std::string& prefix,
85+
bool registerCompanionFunctions,
86+
bool overwrite);
7387

74-
void registerAllAggregateFunctions(const std::string& prefix) {
88+
void registerAllAggregateFunctions(
89+
const std::string& prefix,
90+
bool registerCompanionFunctions,
91+
bool overwrite) {
7592
registerApproxDistinctAggregates(prefix);
7693
registerApproxMostFrequentAggregate(prefix);
7794
registerApproxPercentileAggregate(prefix);
7895
registerArbitraryAggregate(prefix);
7996
registerArrayAggAggregate(prefix);
8097
registerAverageAggregate(prefix);
81-
registerBitwiseAggregates(prefix);
98+
registerBitwiseAggregates(prefix, registerCompanionFunctions, overwrite);
8299
registerBitwiseXorAggregate(prefix);
83100
registerBoolAggregates(prefix);
84101
registerCentralMomentsAggregates(prefix);
85102
registerChecksumAggregate(prefix);
86-
registerCountAggregate(prefix);
103+
registerCountAggregate(prefix, registerCompanionFunctions, overwrite);
87104
registerCountIfAggregate(prefix);
88-
registerCovarianceAggregates(prefix);
105+
registerCovarianceAggregates(prefix, registerCompanionFunctions, overwrite);
89106
registerEntropyAggregate(prefix);
90107
registerGeometricMeanAggregate(prefix);
91108
registerHistogramAggregate(prefix);
@@ -95,13 +112,13 @@ void registerAllAggregateFunctions(const std::string& prefix) {
95112
registerMaxDataSizeForStatsAggregate(prefix);
96113
registerMultiMapAggAggregate(prefix);
97114
registerSumDataSizeForStatsAggregate(prefix);
98-
registerMinMaxAggregates(prefix);
115+
registerMinMaxAggregates(prefix, registerCompanionFunctions, overwrite);
99116
registerMinMaxByAggregates(prefix);
100117
registerReduceAgg(prefix);
101118
registerSetAggAggregate(prefix);
102119
registerSetUnionAggregate(prefix);
103120
registerSumAggregate(prefix);
104-
registerVarianceAggregates(prefix);
121+
registerVarianceAggregates(prefix, registerCompanionFunctions, overwrite);
105122
}
106123

107124
} // namespace facebook::velox::aggregate::prestosql

velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h

+4-1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919

2020
namespace facebook::velox::aggregate::prestosql {
2121

22-
void registerAllAggregateFunctions(const std::string& prefix = "");
22+
void registerAllAggregateFunctions(
23+
const std::string& prefix = "",
24+
bool registerCompanionFunctions = false,
25+
bool overwrite = false);
2326

2427
} // namespace facebook::velox::aggregate::prestosql

velox/functions/prestosql/aggregates/VarianceAggregates.cpp

+23-9
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,10 @@ void checkSumCountRowType(
459459
}
460460

461461
template <template <typename TInput> class TClass>
462-
exec::AggregateRegistrationResult registerVariance(const std::string& name) {
462+
exec::AggregateRegistrationResult registerVariance(
463+
const std::string& name,
464+
bool registerCompanionFunctions,
465+
bool overwrite) {
463466
std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures;
464467
std::vector<std::string> inputTypes = {
465468
"smallint", "integer", "bigint", "real", "double"};
@@ -508,18 +511,29 @@ exec::AggregateRegistrationResult registerVariance(const std::string& name) {
508511
"(count:bigint, mean:double, m2:double) struct");
509512
return std::make_unique<TClass<int64_t>>(resultType);
510513
}
511-
});
514+
},
515+
registerCompanionFunctions,
516+
overwrite);
512517
}
513518

514519
} // namespace
515520

516-
void registerVarianceAggregates(const std::string& prefix) {
517-
registerVariance<StdDevSampAggregate>(prefix + kStdDev);
518-
registerVariance<StdDevPopAggregate>(prefix + kStdDevPop);
519-
registerVariance<StdDevSampAggregate>(prefix + kStdDevSamp);
520-
registerVariance<VarSampAggregate>(prefix + kVariance);
521-
registerVariance<VarPopAggregate>(prefix + kVarPop);
522-
registerVariance<VarSampAggregate>(prefix + kVarSamp);
521+
void registerVarianceAggregates(
522+
const std::string& prefix,
523+
bool registerCompanionFunctions,
524+
bool overwrite) {
525+
registerVariance<StdDevSampAggregate>(
526+
prefix + kStdDev, registerCompanionFunctions, overwrite);
527+
registerVariance<StdDevPopAggregate>(
528+
prefix + kStdDevPop, registerCompanionFunctions, overwrite);
529+
registerVariance<StdDevSampAggregate>(
530+
prefix + kStdDevSamp, registerCompanionFunctions, overwrite);
531+
registerVariance<VarSampAggregate>(
532+
prefix + kVariance, registerCompanionFunctions, overwrite);
533+
registerVariance<VarPopAggregate>(
534+
prefix + kVarPop, registerCompanionFunctions, overwrite);
535+
registerVariance<VarSampAggregate>(
536+
prefix + kVarSamp, registerCompanionFunctions, overwrite);
523537
}
524538

525539
} // namespace facebook::velox::aggregate::prestosql

velox/functions/sparksql/aggregates/AverageAggregate.cpp

+6-2
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,10 @@ class DecimalAverageAggregate : public DecimalAggregate<TInputType> {
362362
/// REAL | DOUBLE | DOUBLE
363363
/// ALL INTs | DOUBLE | DOUBLE
364364
/// DECIMAL | DECIMAL | DECIMAL
365-
exec::AggregateRegistrationResult registerAverage(const std::string& name) {
365+
exec::AggregateRegistrationResult registerAverage(
366+
const std::string& name,
367+
bool registerCompanionFunctions,
368+
bool overwrite) {
366369
std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures;
367370

368371
for (const auto& inputType :
@@ -495,7 +498,8 @@ exec::AggregateRegistrationResult registerAverage(const std::string& name) {
495498
}
496499
}
497500
},
498-
/*registerCompanionFunctions*/ true);
501+
registerCompanionFunctions,
502+
overwrite);
499503
}
500504

501505
} // namespace facebook::velox::functions::aggregate::sparksql

velox/functions/sparksql/aggregates/AverageAggregate.h

+4-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222

2323
namespace facebook::velox::functions::aggregate::sparksql {
2424

25-
exec::AggregateRegistrationResult registerAverage(const std::string& name);
25+
exec::AggregateRegistrationResult registerAverage(
26+
const std::string& name,
27+
bool registerCompanionFunctions,
28+
bool overwrite);
2629

2730
} // namespace facebook::velox::functions::aggregate::sparksql

velox/functions/sparksql/aggregates/BitwiseXorAggregate.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,11 @@ class BitwiseXorAggregate : public BitwiseAggregateBase<T> {
6666
} // namespace
6767

6868
exec::AggregateRegistrationResult registerBitwiseXorAggregate(
69-
const std::string& prefix) {
69+
const std::string& prefix,
70+
bool registerCompanionFunctions,
71+
bool overwrite) {
7072
return functions::aggregate::registerBitwise<BitwiseXorAggregate>(
71-
prefix + "bit_xor");
73+
prefix + "bit_xor", registerCompanionFunctions, overwrite);
7274
}
7375

7476
} // namespace facebook::velox::functions::aggregate::sparksql

velox/functions/sparksql/aggregates/BitwiseXorAggregate.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
namespace facebook::velox::functions::aggregate::sparksql {
2424

2525
exec::AggregateRegistrationResult registerBitwiseXorAggregate(
26-
const std::string& name);
26+
const std::string& name,
27+
bool registerCompanionFunctions,
28+
bool overwrite);
2729

2830
} // namespace facebook::velox::functions::aggregate::sparksql

0 commit comments

Comments
 (0)