diff --git a/velox/exec/AggregateCompanionAdapter.cpp b/velox/exec/AggregateCompanionAdapter.cpp index ea3c64c5c9ea..f77400e952d0 100644 --- a/velox/exec/AggregateCompanionAdapter.cpp +++ b/velox/exec/AggregateCompanionAdapter.cpp @@ -267,10 +267,13 @@ bool CompanionFunctionsRegistrar::registerPartialFunction( const core::QueryConfig& config) -> std::unique_ptr { if (auto func = getAggregateFunctionEntry(name)) { + core::AggregationNode::Step usedStep{ + core::AggregationNode::Step::kPartial}; if (!exec::isRawInput(step)) { - step = core::AggregationNode::Step::kIntermediate; + usedStep = core::AggregationNode::Step::kIntermediate; } - auto fn = func->factory(step, argTypes, resultType, config); + auto fn = + func->factory(usedStep, argTypes, resultType, config); VELOX_CHECK_NOT_NULL(fn); return std::make_unique< AggregateCompanionAdapter::PartialFunction>( @@ -409,26 +412,51 @@ bool CompanionFunctionsRegistrar::registerMergeExtractFunction( const std::vector& signatures, const AggregateFunctionMetadata& metadata, bool overwrite) { + bool registered = false; if (CompanionSignatures::hasSameIntermediateTypesAcrossSignatures( signatures)) { - return registerMergeExtractFunctionWithSuffix( - name, signatures, metadata, overwrite); + registered |= + registerMergeExtractFunctionWithSuffix(name, signatures, metadata, overwrite); } auto mergeExtractSignatures = CompanionSignatures::mergeExtractFunctionSignatures(signatures); if (mergeExtractSignatures.empty()) { - return false; + return registered; } auto mergeExtractFunctionName = CompanionSignatures::mergeExtractFunctionName(name); - return registerMergeExtractFunctionInternal( - name, - mergeExtractFunctionName, - std::move(mergeExtractSignatures), - metadata, - overwrite); + registered |= + exec::registerAggregateFunction( + mergeExtractFunctionName, + std::move(mergeExtractSignatures), + [name, mergeExtractFunctionName]( + core::AggregationNode::Step /*step*/, + const std::vector& argTypes, + const TypePtr& resultType, + const core::QueryConfig& config) -> std::unique_ptr { + if (auto func = getAggregateFunctionEntry(name)) { + auto fn = func->factory( + core::AggregationNode::Step::kFinal, + argTypes, + resultType, + config); + VELOX_CHECK_NOT_NULL(fn); + return std::make_unique< + AggregateCompanionAdapter::MergeExtractFunction>( + std::move(fn), resultType); + } + VELOX_FAIL( + "Original aggregation function {} not found: {}", + name, + mergeExtractFunctionName); + }, + metadata, + /*registerCompanionFunctions*/ false, + overwrite) + .mainFunction; + return registered; } VectorFunctionFactory getVectorFunctionFactory(