diff --git a/velox/experimental/cudf/exec/CudfHashAggregation.cpp b/velox/experimental/cudf/exec/CudfHashAggregation.cpp index dfc03a1b2580..b0feb7820b88 100644 --- a/velox/experimental/cudf/exec/CudfHashAggregation.cpp +++ b/velox/experimental/cudf/exec/CudfHashAggregation.cpp @@ -392,8 +392,6 @@ std::unique_ptr createAggregator( uint32_t inputIndex, VectorPtr constant, bool isGlobal) { - // Companion function may be count_merge_extract or count_partial or others, - // so use this to map if (kind.rfind("sum", 0) == 0) { return std::make_unique( step, inputIndex, constant, isGlobal); @@ -414,6 +412,39 @@ std::unique_ptr createAggregator( } } +static const std::unordered_map + companionStep = { + {"_partial", core::AggregationNode::Step::kPartial}, + {"_merge", core::AggregationNode::Step::kIntermediate}, + {"_merge_extract", core::AggregationNode::Step::kFinal}}; + +/// \brief Convert companion function to step for the aggregation function +/// +/// Companion functions are functions that are registered in velox along with +/// their main aggregation functions. These are designed to always function +/// with a fixed `step`. This is to allow spark style planNodes where `step` is +/// the property of the aggregation function rather than the planNode. +/// Companion functions allow us to override the planNode's step and use +/// aggregations of different steps in the same planNode +core::AggregationNode::Step getCompanionStep( + std::string const& kind, + core::AggregationNode::Step step) { + for (const auto& [k, v] : companionStep) { + if (folly::StringPiece(kind).endsWith(k)) { + step = v; + break; + } + } + return step; +} + +bool hasFinalAggs( + std::vector const& aggregates) { + return std::any_of(aggregates.begin(), aggregates.end(), [](auto const& agg) { + return folly::StringPiece(agg.call->name()).endsWith("_merge_extract"); + }); +} + auto toAggregators( core::AggregationNode const& aggregationNode, exec::OperatorCtx const& operatorCtx) { @@ -453,8 +484,9 @@ auto toAggregators( auto const kind = aggregate.call->name(); auto const inputIndex = aggInputs[0]; auto const constant = aggConstants.empty() ? nullptr : aggConstants[0]; + auto const companionStep = getCompanionStep(kind, step); aggregators.push_back( - createAggregator(step, kind, inputIndex, constant, isGlobal)); + createAggregator(companionStep, kind, inputIndex, constant, isGlobal)); } return aggregators; } @@ -505,7 +537,9 @@ CudfHashAggregation::CudfHashAggregation( operatorId, fmt::format("[{}]", aggregationNode->id())), aggregationNode_(aggregationNode), - isPartialOutput_(exec::isPartialOutput(aggregationNode->step())), + isPartialOutput_( + exec::isPartialOutput(aggregationNode->step()) && + !hasFinalAggs(aggregationNode->aggregates())), isGlobal_(aggregationNode->groupingKeys().empty()), isDistinct_(!isGlobal_ && aggregationNode->aggregates().empty()), maxPartialAggregationMemoryUsage_( diff --git a/velox/experimental/cudf/tests/AggregationTest.cpp b/velox/experimental/cudf/tests/AggregationTest.cpp index 2f405b672e2b..a55b7256aa2b 100644 --- a/velox/experimental/cudf/tests/AggregationTest.cpp +++ b/velox/experimental/cudf/tests/AggregationTest.cpp @@ -479,6 +479,33 @@ TEST_F(AggregationTest, countPartialFinalGlobal) { assertQuery(op, "SELECT count(*) FROM tmp"); } +/// Tests the spark scenario of having different types of aggs in the same +/// planNode Specific example being tested is +/// https://github.com/facebookincubator/velox/issues/12830#issuecomment-2783340233 +TEST_F(AggregationTest, CompanionAggs) { + std::vector keys0{1, 1, 1, 2, 1, 1, 2, 2}; + std::vector keys1{1, 2, 1, 2, 1, 2, 1, 2}; + std::vector values{1, 2, 3, 4, 5, 6, 7, 8}; + auto rowVector = makeRowVector( + {makeFlatVector(keys0), + makeFlatVector(keys1), + makeFlatVector(values)}); + + createDuckDbTable({rowVector}); + + auto op = + PlanBuilder() + .values({rowVector}) + .singleAggregation({"c2", "c0"}, {"count_partial(c1)"}) + .localPartition({"c2", "c0"}) + .singleAggregation({"c0"}, {"count_merge(a0)", "count_partial(c2)"}) + .localPartition({"c0"}) + .singleAggregation({"c0"}, {"count_merge(a0)", "count_merge(a1)"}) + .planNode(); + assertQuery( + op, "SELECT c0, count(c1), count(distinct c2) FROM tmp GROUP BY c0"); +} + TEST_F(AggregationTest, partialAggregationMemoryLimit) { auto vectors = { makeRowVector({makeFlatVector(