@@ -267,10 +267,13 @@ bool CompanionFunctionsRegistrar::registerPartialFunction(
267
267
const core::QueryConfig& config)
268
268
-> std::unique_ptr<Aggregate> {
269
269
if (auto func = getAggregateFunctionEntry (name)) {
270
+ core::AggregationNode::Step usedStep{
271
+ core::AggregationNode::Step::kPartial };
270
272
if (!exec::isRawInput (step)) {
271
- step = core::AggregationNode::Step::kIntermediate ;
273
+ usedStep = core::AggregationNode::Step::kIntermediate ;
272
274
}
273
- auto fn = func->factory (step, argTypes, resultType, config);
275
+ auto fn =
276
+ func->factory (usedStep, argTypes, resultType, config);
274
277
VELOX_CHECK_NOT_NULL (fn);
275
278
return std::make_unique<
276
279
AggregateCompanionAdapter::PartialFunction>(
@@ -409,26 +412,51 @@ bool CompanionFunctionsRegistrar::registerMergeExtractFunction(
409
412
const std::vector<AggregateFunctionSignaturePtr>& signatures,
410
413
const AggregateFunctionMetadata& metadata,
411
414
bool overwrite) {
415
+ bool registered = false ;
412
416
if (CompanionSignatures::hasSameIntermediateTypesAcrossSignatures (
413
417
signatures)) {
414
- return registerMergeExtractFunctionWithSuffix (
415
- name, signatures, metadata, overwrite);
418
+ registered |=
419
+ registerMergeExtractFunctionWithSuffix ( name, signatures, metadata, overwrite);
416
420
}
417
421
418
422
auto mergeExtractSignatures =
419
423
CompanionSignatures::mergeExtractFunctionSignatures (signatures);
420
424
if (mergeExtractSignatures.empty ()) {
421
- return false ;
425
+ return registered ;
422
426
}
423
427
424
428
auto mergeExtractFunctionName =
425
429
CompanionSignatures::mergeExtractFunctionName (name);
426
- return registerMergeExtractFunctionInternal (
427
- name,
428
- mergeExtractFunctionName,
429
- std::move (mergeExtractSignatures),
430
- metadata,
431
- overwrite);
430
+ registered |=
431
+ exec::registerAggregateFunction (
432
+ mergeExtractFunctionName,
433
+ std::move (mergeExtractSignatures),
434
+ [name, mergeExtractFunctionName](
435
+ core::AggregationNode::Step /* step*/ ,
436
+ const std::vector<TypePtr>& argTypes,
437
+ const TypePtr& resultType,
438
+ const core::QueryConfig& config) -> std::unique_ptr<Aggregate> {
439
+ if (auto func = getAggregateFunctionEntry (name)) {
440
+ auto fn = func->factory (
441
+ core::AggregationNode::Step::kFinal ,
442
+ argTypes,
443
+ resultType,
444
+ config);
445
+ VELOX_CHECK_NOT_NULL (fn);
446
+ return std::make_unique<
447
+ AggregateCompanionAdapter::MergeExtractFunction>(
448
+ std::move (fn), resultType);
449
+ }
450
+ VELOX_FAIL (
451
+ " Original aggregation function {} not found: {}" ,
452
+ name,
453
+ mergeExtractFunctionName);
454
+ },
455
+ metadata,
456
+ /* registerCompanionFunctions*/ false ,
457
+ overwrite)
458
+ .mainFunction ;
459
+ return registered;
432
460
}
433
461
434
462
VectorFunctionFactory getVectorFunctionFactory (
0 commit comments