diff --git a/ci/scripts/run-backfill-tests.sh b/ci/scripts/run-backfill-tests.sh index ea7a5f0b6ef94..8c719dafa85d0 100755 --- a/ci/scripts/run-backfill-tests.sh +++ b/ci/scripts/run-backfill-tests.sh @@ -422,6 +422,16 @@ test_cross_db_snapshot_backfill() { kill_cluster } +test_locality_backfill() { + echo "--- e2e, locality backfill test, $RUNTIME_CLUSTER_PROFILE" + + risedev ci-start $RUNTIME_CLUSTER_PROFILE + + sqllogictest -p 4566 -d dev 'e2e_test/backfill/locality_backfill/basic.slt' + + kill_cluster +} + main() { set -euo pipefail test_snapshot_and_upstream_read @@ -433,6 +443,7 @@ main() { test_scale_in test_cross_db_snapshot_backfill + test_locality_backfill # Only if profile is "ci-release", run it. if [[ ${profile:-} == "ci-release" ]]; then diff --git a/e2e_test/backfill/locality_backfill/basic.slt b/e2e_test/backfill/locality_backfill/basic.slt new file mode 100644 index 0000000000000..285e59b213a1a --- /dev/null +++ b/e2e_test/backfill/locality_backfill/basic.slt @@ -0,0 +1,35 @@ +statement ok +set enable_locality_backfill=true; + +statement ok +create table t1(a int, b int); + +statement ok +create table t2(a int, b int); + +statement ok +insert into t1 select i, 123 from generate_series(1, 1000, 1) i; + +statement ok +insert into t2 select i, 123 from generate_series(1, 1000, 1) i ; + +statement ok +flush; + +statement ok +create materialized view mv as select count(*) from t1 join t2 on t1.a = t2.a group by t1.b; + + +query ? +select * from mv; +---- +1000 + +statement ok +drop materialized view mv; + +statement ok +drop table t1; + +statement ok +drop table t2; diff --git a/e2e_test/batch/catalog/pg_settings.slt.part b/e2e_test/batch/catalog/pg_settings.slt.part index 587f34ebdda91..e32b89e8aa344 100644 --- a/e2e_test/batch/catalog/pg_settings.slt.part +++ b/e2e_test/batch/catalog/pg_settings.slt.part @@ -40,6 +40,7 @@ user disable_purify_definition user dml_rate_limit user enable_index_selection user enable_join_ordering +user enable_locality_backfill user enable_share_plan user enable_two_phase_agg user extra_float_digits diff --git a/proto/stream_plan.proto b/proto/stream_plan.proto index 16fbdfc750aea..58bcf1f5d09eb 100644 --- a/proto/stream_plan.proto +++ b/proto/stream_plan.proto @@ -998,6 +998,15 @@ message UpstreamSinkUnionNode { repeated UpstreamSinkInfo init_upstreams = 1; } +message LocalityProviderNode { + // Column indices that define locality + repeated uint32 locality_columns = 1; + // State table for buffering input data + optional catalog.Table state_table = 2; + // Progress table for tracking backfill progress + optional catalog.Table progress_table = 3; +} + message StreamNode { // This field used to be a `bool append_only`. // Enum variants are ordered for backwards compatibility. @@ -1060,6 +1069,7 @@ message StreamNode { MaterializedExprsNode materialized_exprs = 149; VectorIndexWriteNode vector_index_write = 150; UpstreamSinkUnionNode upstream_sink_union = 151; + LocalityProviderNode locality_provider = 152; } // The id for the operator. This is local per mview. // TODO: should better be a uint32. diff --git a/src/common/src/catalog/mod.rs b/src/common/src/catalog/mod.rs index deda6f60712c3..974e3697caa9c 100644 --- a/src/common/src/catalog/mod.rs +++ b/src/common/src/catalog/mod.rs @@ -637,7 +637,8 @@ macro_rules! 
for_all_fragment_type_flags { CrossDbSnapshotBackfillStreamScan, StreamCdcScan, VectorIndexWrite, - UpstreamSinkUnion + UpstreamSinkUnion, + LocalityProvider }, {}, 0 @@ -892,6 +893,11 @@ mod tests { 65536, "UPSTREAM_SINK_UNION", ), + ( + LocalityProvider, + 131072, + "LOCALITY_PROVIDER", + ), ] "#]] .assert_debug_eq( diff --git a/src/common/src/session_config/mod.rs b/src/common/src/session_config/mod.rs index cc279e999c027..67f2f797bdedf 100644 --- a/src/common/src/session_config/mod.rs +++ b/src/common/src/session_config/mod.rs @@ -421,6 +421,10 @@ pub struct SessionConfig { /// Enable index selection for queries #[parameter(default = true)] enable_index_selection: bool, + + /// Enable locality backfill for streaming queries. Defaults to false. + #[parameter(default = false)] + enable_locality_backfill: bool, } fn check_iceberg_engine_connection(val: &str) -> Result<(), String> { diff --git a/src/common/src/util/stream_graph_visitor.rs b/src/common/src/util/stream_graph_visitor.rs index f901b47c49112..c4115e4e31b24 100644 --- a/src/common/src/util/stream_graph_visitor.rs +++ b/src/common/src/util/stream_graph_visitor.rs @@ -309,6 +309,10 @@ pub fn visit_stream_node_tables_inner( always!(node.table, "StreamVectorIndexWrite"); } + NodeBody::LocalityProvider(node) => { + always!(node.state_table, "LocalityProviderState"); + always!(node.progress_table, "LocalityProviderProgress"); + } _ => {} } }; diff --git a/src/frontend/planner_test/tests/testdata/input/locality_backfill.yaml b/src/frontend/planner_test/tests/testdata/input/locality_backfill.yaml new file mode 100644 index 0000000000000..6b197d70b7e9e --- /dev/null +++ b/src/frontend/planner_test/tests/testdata/input/locality_backfill.yaml @@ -0,0 +1,19 @@ +- sql: | + set enable_locality_backfill = true; + create table t (a int, b int, c int); + select count(*) from t group by b; + expected_outputs: + - stream_plan +- sql: | + set enable_locality_backfill = true; + create table t1 (a int, b int, c int); + create table t2 (a int, b int, c int); + select count(*) from t1 join t2 on t1.a = t2.a group by t1.b; + expected_outputs: + - stream_plan +- sql: | + set enable_locality_backfill = true; + create table t (a int, b int, c int, primary key (b, a)); + select count(*) from t group by a, b; + expected_outputs: + - stream_plan \ No newline at end of file diff --git a/src/frontend/planner_test/tests/testdata/output/locality_backfill.yaml b/src/frontend/planner_test/tests/testdata/output/locality_backfill.yaml new file mode 100644 index 0000000000000..2c72134309bec --- /dev/null +++ b/src/frontend/planner_test/tests/testdata/output/locality_backfill.yaml @@ -0,0 +1,43 @@ +# This file is automatically generated. See `src/frontend/planner_test/README.md` for more information. 
+- sql: | + set enable_locality_backfill = true; + create table t (a int, b int, c int); + select count(*) from t group by b; + stream_plan: |- + StreamMaterialize { columns: [count, t.b(hidden)], stream_key: [t.b], pk_columns: [t.b], pk_conflict: NoCheck } + └─StreamProject { exprs: [count, t.b] } + └─StreamHashAgg { group_key: [t.b], aggs: [count] } + └─StreamLocalityProvider { locality_columns: [t.b] } + └─StreamExchange { dist: HashShard(t.b) } + └─StreamTableScan { table: t, columns: [t.b, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) } +- sql: | + set enable_locality_backfill = true; + create table t1 (a int, b int, c int); + create table t2 (a int, b int, c int); + select count(*) from t1 join t2 on t1.a = t2.a group by t1.b; + stream_plan: |- + StreamMaterialize { columns: [count, t1.b(hidden)], stream_key: [t1.b], pk_columns: [t1.b], pk_conflict: NoCheck } + └─StreamProject { exprs: [count, t1.b] } + └─StreamHashAgg { group_key: [t1.b], aggs: [count] } + └─StreamLocalityProvider { locality_columns: [t1.b] } + └─StreamExchange { dist: HashShard(t1.b) } + └─StreamHashJoin { type: Inner, predicate: t1.a = t2.a, output: [t1.b, t1._row_id, t1.a, t2._row_id] } + ├─StreamExchange { dist: HashShard(t1.a) } + │ └─StreamLocalityProvider { locality_columns: [t1.a] } + │ └─StreamExchange { dist: HashShard(t1.a) } + │ └─StreamTableScan { table: t1, columns: [t1.a, t1.b, t1._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t1._row_id], pk: [_row_id], dist: UpstreamHashShard(t1._row_id) } + └─StreamExchange { dist: HashShard(t2.a) } + └─StreamLocalityProvider { locality_columns: [t2.a] } + └─StreamExchange { dist: HashShard(t2.a) } + └─StreamTableScan { table: t2, columns: [t2.a, t2._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t2._row_id], pk: [_row_id], dist: UpstreamHashShard(t2._row_id) } +- sql: | + set enable_locality_backfill = true; + create table t (a int, b int, c int, primary key (b, a)); + select count(*) from t group by a, b; + stream_plan: |- + StreamMaterialize { columns: [count, t.a(hidden), t.b(hidden)], stream_key: [t.a, t.b], pk_columns: [t.a, t.b], pk_conflict: NoCheck } + └─StreamProject { exprs: [count, t.a, t.b] } + └─StreamHashAgg { group_key: [t.a, t.b], aggs: [count] } + └─StreamLocalityProvider { locality_columns: [t.a, t.b] } + └─StreamExchange [no_shuffle] { dist: UpstreamHashShard(t.b, t.a) } + └─StreamTableScan { table: t, columns: [t.a, t.b], stream_scan_type: ArrangementBackfill, stream_key: [t.b, t.a], pk: [b, a], dist: UpstreamHashShard(t.b, t.a) } diff --git a/src/frontend/src/optimizer/plan_node/generic/locality_provider.rs b/src/frontend/src/optimizer/plan_node/generic/locality_provider.rs new file mode 100644 index 0000000000000..f1268ede8d886 --- /dev/null +++ b/src/frontend/src/optimizer/plan_node/generic/locality_provider.rs @@ -0,0 +1,78 @@ +// Copyright 2025 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use pretty_xmlish::Pretty; +use risingwave_common::catalog::{FieldDisplay, Schema}; + +use super::{GenericPlanNode, GenericPlanRef, impl_distill_unit_from_fields}; +use crate::expr::ExprRewriter; +use crate::optimizer::optimizer_context::OptimizerContextRef; +use crate::optimizer::property::FunctionalDependencySet; + +/// `LocalityProvider` provides locality for operators during backfilling. +/// It buffers input data into a state table using locality columns as primary key prefix. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct LocalityProvider { + pub input: PlanRef, + /// Columns that define the locality + pub locality_columns: Vec, +} + +impl LocalityProvider { + pub fn new(input: PlanRef, locality_columns: Vec) -> Self { + Self { + input, + locality_columns, + } + } + + pub fn fields_pretty<'a>(&self) -> Vec<(&'a str, Pretty<'a>)> { + let locality_columns_display = self + .locality_columns + .iter() + .map(|&i| Pretty::display(&FieldDisplay(self.input.schema().fields.get(i).unwrap()))) + .collect(); + vec![("locality_columns", Pretty::Array(locality_columns_display))] + } +} + +impl GenericPlanNode for LocalityProvider { + fn schema(&self) -> Schema { + self.input.schema().clone() + } + + fn stream_key(&self) -> Option> { + Some(self.input.stream_key()?.to_vec()) + } + + fn ctx(&self) -> OptimizerContextRef { + self.input.ctx() + } + + fn functional_dependency(&self) -> FunctionalDependencySet { + self.input.functional_dependency().clone() + } +} + +impl LocalityProvider { + pub fn rewrite_exprs(&mut self, _r: &mut dyn ExprRewriter) { + // LocalityProvider doesn't contain expressions to rewrite + } + + pub fn visit_exprs(&self, _v: &mut dyn crate::expr::ExprVisitor) { + // LocalityProvider doesn't contain expressions to visit + } +} + +impl_distill_unit_from_fields!(LocalityProvider, GenericPlanRef); diff --git a/src/frontend/src/optimizer/plan_node/generic/mod.rs b/src/frontend/src/optimizer/plan_node/generic/mod.rs index 6394e8e9348b6..dbea1ae29b2ac 100644 --- a/src/frontend/src/optimizer/plan_node/generic/mod.rs +++ b/src/frontend/src/optimizer/plan_node/generic/mod.rs @@ -95,6 +95,9 @@ pub use postgres_query::*; mod mysql_query; pub use mysql_query::*; +mod locality_provider; +pub use locality_provider::*; + pub trait DistillUnit { fn distill_with_name<'a>(&self, name: impl Into>) -> XmlNode<'a>; } diff --git a/src/frontend/src/optimizer/plan_node/logical_agg.rs b/src/frontend/src/optimizer/plan_node/logical_agg.rs index 299425c8a446e..3238f71936577 100644 --- a/src/frontend/src/optimizer/plan_node/logical_agg.rs +++ b/src/frontend/src/optimizer/plan_node/logical_agg.rs @@ -1416,10 +1416,14 @@ impl ToStream for LogicalAgg { use super::stream::prelude::*; let eowc = ctx.emit_on_window_close(); - let input = self - .input() - .try_better_locality(&self.group_key().to_vec()) - .unwrap_or_else(|| self.input()); + let input = if self.group_key().is_empty() { + self.input() + } else { + self.input() + .try_better_locality(&self.group_key().to_vec()) + .unwrap_or_else(|| self.input()) + }; + let stream_input = input.to_stream(ctx)?; // Use Dedup operator, if possible. 
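The `LogicalAgg` change above only asks its input for better locality when the group key is non-empty: a simple aggregation has no columns to cluster on, and the `LogicalLocalityProvider` introduced later in this patch asserts a non-empty locality-column list. The snippet below is a self-contained toy model (an illustrative `Plan` type, not the planner types in this diff) of the fallback that the scan-side and join-side `try_better_locality` changes further down implement: prefer a plan that already delivers the requested locality, and only wrap the input in a locality provider when `enable_locality_backfill` is set.

```rust
// Toy model of the fallback; `Plan` is an illustrative stand-in, not a RisingWave type.
#[derive(Debug, Clone)]
enum Plan {
    Scan { covering_index_on: Option<Vec<usize>> },
    LocalityProvider { input: Box<Plan>, locality_columns: Vec<usize> },
}

fn try_better_locality(plan: &Plan, columns: &[usize], enable_locality_backfill: bool) -> Option<Plan> {
    // 1. Prefer a "free" locality source, e.g. a covering index already ordered
    //    on the requested columns.
    if let Plan::Scan { covering_index_on: Some(idx) } = plan {
        if idx.as_slice() == columns {
            return Some(plan.clone());
        }
    }
    // 2. Otherwise buffer behind a locality provider, but only when the session
    //    opted in and there is at least one column to key the buffer on.
    if enable_locality_backfill && !columns.is_empty() {
        return Some(Plan::LocalityProvider {
            input: Box::new(plan.clone()),
            locality_columns: columns.to_vec(),
        });
    }
    None
}

fn main() {
    let scan = Plan::Scan { covering_index_on: None };
    println!("{:?}", try_better_locality(&scan, &[1], true));  // wrapped in a provider
    println!("{:?}", try_better_locality(&scan, &[1], false)); // None: feature is off
}
```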
diff --git a/src/frontend/src/optimizer/plan_node/logical_join.rs b/src/frontend/src/optimizer/plan_node/logical_join.rs index 48884018bb2f7..ad1b4c68d860d 100644 --- a/src/frontend/src/optimizer/plan_node/logical_join.rs +++ b/src/frontend/src/optimizer/plan_node/logical_join.rs @@ -28,9 +28,9 @@ use super::generic::{ }; use super::utils::{Distill, childless_record}; use super::{ - BatchPlanRef, ColPrunable, ExprRewritable, Logical, LogicalPlanRef as PlanRef, PlanBase, - PlanTreeNodeBinary, PredicatePushdown, StreamHashJoin, StreamPlanRef, StreamProject, ToBatch, - ToStream, generic, + BatchPlanRef, ColPrunable, ExprRewritable, Logical, LogicalLocalityProvider, + LogicalPlanRef as PlanRef, PlanBase, PlanTreeNodeBinary, PredicatePushdown, StreamHashJoin, + StreamPlanRef, StreamProject, ToBatch, ToStream, generic, }; use crate::error::{ErrorCode, Result, RwError}; use crate::expr::{CollectInputRef, Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, InputRef}; @@ -1541,6 +1541,26 @@ impl LogicalJoin { .into()), } } + + fn try_better_locality_inner(&self, columns: &[usize]) -> Option { + let mut ctx = ToStreamContext::new(false); + // only pass through the locality information if it can be converted to dynamic filter + if let Ok(Some(_)) = self.to_stream_dynamic_filter(self.on().clone(), &mut ctx) { + // since dynamic filter only supports left input ref in the output indices, we can safely use o2i mapping to convert the required columns. + let o2i_mapping = self.core.o2i_col_mapping(); + let left_input_columns = columns + .iter() + .map(|&col| o2i_mapping.try_map(col)) + .collect::>>()?; + if let Some(better_left_plan) = self.left().try_better_locality(&left_input_columns) { + return Some( + self.clone_with_left_right(better_left_plan, self.right()) + .into(), + ); + } + } + None + } } impl ToBatch for LogicalJoin { @@ -1754,23 +1774,19 @@ impl ToStream for LogicalJoin { } fn try_better_locality(&self, columns: &[usize]) -> Option { - let mut ctx = ToStreamContext::new(false); - // only pass through the locality information if it can be converted to dynamic filter - if let Ok(Some(_)) = self.to_stream_dynamic_filter(self.on().clone(), &mut ctx) { - // since dynamic filter only supports left input ref in the output indices, we can safely use o2i mapping to convert the required columns. - let o2i_mapping = self.core.o2i_col_mapping(); - let left_input_columns = columns - .iter() - .map(|&col| o2i_mapping.try_map(col)) - .collect::>>()?; - if let Some(better_left_plan) = self.left().try_better_locality(&left_input_columns) { - return Some( - self.clone_with_left_right(better_left_plan, self.right()) - .into(), - ); - } + if let Some(better_plan) = self.try_better_locality_inner(columns) { + Some(better_plan) + } else if self.ctx().session_ctx().config().enable_locality_backfill() { + Some( + LogicalLocalityProvider::new( + self.clone_with_left_right(self.left(), self.right()).into(), + columns.to_owned(), + ) + .into(), + ) + } else { + None } - None } } diff --git a/src/frontend/src/optimizer/plan_node/logical_locality_provider.rs b/src/frontend/src/optimizer/plan_node/logical_locality_provider.rs new file mode 100644 index 0000000000000..b487e0917fd4f --- /dev/null +++ b/src/frontend/src/optimizer/plan_node/logical_locality_provider.rs @@ -0,0 +1,181 @@ +// Copyright 2025 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use itertools::Itertools; + +use super::generic::GenericPlanRef; +use super::utils::impl_distill_by_unit; +use super::{ + BatchPlanRef, ColPrunable, ExprRewritable, Logical, LogicalPlanRef as PlanRef, LogicalProject, + PlanBase, PlanTreeNodeUnary, PredicatePushdown, StreamExchange, StreamPlanRef, ToBatch, + ToStream, generic, +}; +use crate::error::Result; +use crate::expr::{ExprRewriter, ExprVisitor}; +use crate::optimizer::plan_node::expr_visitable::ExprVisitable; +use crate::optimizer::plan_node::{ + ColumnPruningContext, PredicatePushdownContext, RewriteStreamContext, ToStreamContext, +}; +use crate::optimizer::property::RequiredDist; +use crate::utils::{ColIndexMapping, Condition}; + +/// `LogicalLocalityProvider` provides locality for operators during backfilling. +/// It buffers input data into a state table using locality columns as primary key prefix. +/// +/// The `LocalityProvider` has 2 states: +/// - One is used to buffer data during backfilling and provide data locality. +/// - The other one is a progress table like normal backfill operator to track the backfilling progress of itself. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct LogicalLocalityProvider { + pub base: PlanBase, + core: generic::LocalityProvider, +} + +impl LogicalLocalityProvider { + pub fn new(input: PlanRef, locality_columns: Vec) -> Self { + assert!(!locality_columns.is_empty()); + let core = generic::LocalityProvider::new(input, locality_columns); + let base = PlanBase::new_logical_with_core(&core); + LogicalLocalityProvider { base, core } + } + + pub fn create(input: PlanRef, locality_columns: Vec) -> PlanRef { + LogicalLocalityProvider::new(input, locality_columns).into() + } + + pub fn locality_columns(&self) -> &[usize] { + &self.core.locality_columns + } +} + +impl PlanTreeNodeUnary for LogicalLocalityProvider { + fn input(&self) -> PlanRef { + self.core.input.clone() + } + + fn clone_with_input(&self, input: PlanRef) -> Self { + Self::new(input, self.locality_columns().to_vec()) + } + + fn rewrite_with_input( + &self, + input: PlanRef, + input_col_change: ColIndexMapping, + ) -> (Self, ColIndexMapping) { + let locality_columns = self + .locality_columns() + .iter() + .map(|&i| input_col_change.map(i)) + .collect(); + + (Self::new(input, locality_columns), input_col_change) + } +} + +impl_plan_tree_node_for_unary! { Logical, LogicalLocalityProvider} +impl_distill_by_unit!(LogicalLocalityProvider, core, "LogicalLocalityProvider"); + +impl ColPrunable for LogicalLocalityProvider { + fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef { + // No pruning. 
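+        // Keep every input column: the provider buffers whole input rows in its state table,
+        // so pruning is not pushed through it; the `LogicalProject` built below restores `required_cols`.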
+ let input_required_cols = (0..self.input().schema().len()).collect_vec(); + LogicalProject::with_out_col_idx( + self.clone_with_input(self.input().prune_col(&input_required_cols, ctx)) + .into(), + required_cols.iter().cloned(), + ) + .into() + } +} + +impl PredicatePushdown for LogicalLocalityProvider { + fn predicate_pushdown( + &self, + predicate: Condition, + ctx: &mut PredicatePushdownContext, + ) -> PlanRef { + let new_input = self.input().predicate_pushdown(predicate, ctx); + let new_provider = self.clone_with_input(new_input); + new_provider.into() + } +} + +impl ToBatch for LogicalLocalityProvider { + fn to_batch(&self) -> Result { + // LocalityProvider is a streaming-only operator + Err(crate::error::ErrorCode::NotSupported( + "LocalityProvider in batch mode".to_owned(), + "LocalityProvider is only supported in streaming mode for backfilling".to_owned(), + ) + .into()) + } +} + +impl ToStream for LogicalLocalityProvider { + fn to_stream(&self, ctx: &mut ToStreamContext) -> Result { + use super::StreamLocalityProvider; + + let input = self.input().to_stream(ctx)?; + let required_dist = + RequiredDist::shard_by_key(self.input().schema().len(), self.locality_columns()); + let input = required_dist.streaming_enforce_if_not_satisfies(input)?; + let input = if input.as_stream_exchange().is_none() { + // Force a no shuffle exchange to ensure locality provider is in its own fragment. + // This is important to ensure the backfill ordering can recognize and build + // the dependency graph among different backfill-needed fragments. + StreamExchange::new_no_shuffle(input).into() + } else { + input + }; + let stream_core = generic::LocalityProvider::new(input, self.locality_columns().to_vec()); + Ok(StreamLocalityProvider::new(stream_core).into()) + } + + fn logical_rewrite_for_stream( + &self, + ctx: &mut RewriteStreamContext, + ) -> Result<(PlanRef, ColIndexMapping)> { + let (input, input_col_change) = self.input().logical_rewrite_for_stream(ctx)?; + let (locality_provider, out_col_change) = self.rewrite_with_input(input, input_col_change); + Ok((locality_provider.into(), out_col_change)) + } +} + +impl ExprRewritable for LogicalLocalityProvider { + fn has_rewritable_expr(&self) -> bool { + false + } + + fn rewrite_exprs(&self, _r: &mut dyn ExprRewriter) -> PlanRef { + self.clone().into() + } +} + +impl ExprVisitable for LogicalLocalityProvider { + fn visit_exprs(&self, _v: &mut dyn ExprVisitor) { + // No expressions to visit + } +} + +impl LogicalLocalityProvider { + pub fn try_better_locality(&self, columns: &[usize]) -> Option { + if columns == self.locality_columns() { + Some(self.clone().into()) + } else if let Some(better_input) = self.input().try_better_locality(columns) { + Some(better_input) + } else { + Some(Self::new(self.input(), columns.to_owned()).into()) + } + } +} diff --git a/src/frontend/src/optimizer/plan_node/logical_scan.rs b/src/frontend/src/optimizer/plan_node/logical_scan.rs index 1041986a4c5e5..06f542b9a696b 100644 --- a/src/frontend/src/optimizer/plan_node/logical_scan.rs +++ b/src/frontend/src/optimizer/plan_node/logical_scan.rs @@ -26,8 +26,8 @@ use super::generic::{GenericPlanNode, GenericPlanRef}; use super::utils::{Distill, childless_record}; use super::{ BatchFilter, BatchPlanRef, BatchProject, ColPrunable, ExprRewritable, Logical, - LogicalPlanRef as PlanRef, PlanBase, PlanNodeId, PredicatePushdown, StreamTableScan, ToBatch, - ToStream, generic, + LogicalLocalityProvider, LogicalPlanRef as PlanRef, PlanBase, PlanNodeId, PredicatePushdown, + 
StreamTableScan, ToBatch, ToStream, generic, }; use crate::TableCatalog; use crate::binder::BoundBaseTable; @@ -565,6 +565,52 @@ impl LogicalScan { None } + + fn try_better_locality_inner(&self, columns: &[usize]) -> Option { + if !self + .core + .ctx() + .session_ctx() + .config() + .enable_index_selection() + { + return None; + } + if columns.is_empty() { + return None; + } + if self.table_indexes().is_empty() { + return None; + } + let orders = if columns.len() <= 3 { + OrderType::all() + } else { + // Limit the number of order type combinations to avoid explosion. + // For more than 3 columns, we only consider ascending nulls last and descending. + // Since by default, indexes are created with ascending nulls last. + // This is a heuristic to reduce the search space. + vec![OrderType::ascending_nulls_last(), OrderType::descending()] + }; + for order_type_combo in columns + .iter() + .map(|&col| orders.iter().map(move |ot| ColumnOrder::new(col, *ot))) + .multi_cartesian_product() + .take(256) + // limit the number of combinations + { + let required_order = Order { + column_orders: order_type_combo, + }; + + let order_satisfied_index = self.indexes_satisfy_order(&required_order); + for index in order_satisfied_index { + if let Some(index_scan) = self.to_index_scan_if_index_covered(index) { + return Some(index_scan.into()); + } + } + } + None + } } impl ToBatch for LogicalScan { @@ -680,48 +726,12 @@ impl ToStream for LogicalScan { } fn try_better_locality(&self, columns: &[usize]) -> Option { - if !self - .core - .ctx() - .session_ctx() - .config() - .enable_index_selection() - { - return None; - } - if columns.is_empty() { - return None; - } - if self.table_indexes().is_empty() { - return None; - } - let orders = if columns.len() <= 3 { - OrderType::all() + if let Some(better_plan) = self.try_better_locality_inner(columns) { + Some(better_plan) + } else if self.ctx().session_ctx().config().enable_locality_backfill() { + Some(LogicalLocalityProvider::new(self.clone().into(), columns.to_owned()).into()) } else { - // Limit the number of order type combinations to avoid explosion. - // For more than 3 columns, we only consider ascending nulls last and descending. - // Since by default, indexes are created with ascending nulls last. - // This is a heuristic to reduce the search space. 
- vec![OrderType::ascending_nulls_last(), OrderType::descending()] - }; - for order_type_combo in columns - .iter() - .map(|&col| orders.iter().map(move |ot| ColumnOrder::new(col, *ot))) - .multi_cartesian_product() - .take(256) - // limit the number of combinations - { - let required_order = Order { - column_orders: order_type_combo, - }; - - let order_satisfied_index = self.indexes_satisfy_order(&required_order); - for index in order_satisfied_index { - if let Some(index_scan) = self.to_index_scan_if_index_covered(index) { - return Some(index_scan.into()); - } - } + None } - None } } diff --git a/src/frontend/src/optimizer/plan_node/mod.rs b/src/frontend/src/optimizer/plan_node/mod.rs index bc09d7beece09..5dcda3df37ed6 100644 --- a/src/frontend/src/optimizer/plan_node/mod.rs +++ b/src/frontend/src/optimizer/plan_node/mod.rs @@ -1013,6 +1013,7 @@ mod logical_intersect; mod logical_join; mod logical_kafka_scan; mod logical_limit; +mod logical_locality_provider; mod logical_max_one_row; mod logical_multi_join; mod logical_now; @@ -1047,6 +1048,7 @@ mod stream_hash_join; mod stream_hop_window; mod stream_join_common; mod stream_local_approx_percentile; +mod stream_locality_provider; mod stream_materialize; mod stream_materialized_exprs; mod stream_now; @@ -1143,6 +1145,7 @@ pub use logical_intersect::LogicalIntersect; pub use logical_join::LogicalJoin; pub use logical_kafka_scan::LogicalKafkaScan; pub use logical_limit::LogicalLimit; +pub use logical_locality_provider::LogicalLocalityProvider; pub use logical_max_one_row::LogicalMaxOneRow; pub use logical_multi_join::{LogicalMultiJoin, LogicalMultiJoinBuilder}; pub use logical_mysql_query::LogicalMySqlQuery; @@ -1181,6 +1184,7 @@ pub use stream_hash_join::StreamHashJoin; pub use stream_hop_window::StreamHopWindow; use stream_join_common::StreamJoinCommon; pub use stream_local_approx_percentile::StreamLocalApproxPercentile; +pub use stream_locality_provider::StreamLocalityProvider; pub use stream_materialize::StreamMaterialize; pub use stream_materialized_exprs::StreamMaterializedExprs; pub use stream_now::StreamNow; @@ -1269,6 +1273,7 @@ macro_rules! for_all_plan_nodes { , { Logical, MySqlQuery } , { Logical, VectorSearch } , { Logical, GetChannelDeltaStats } + , { Logical, LocalityProvider } , { Batch, SimpleAgg } , { Batch, HashAgg } , { Batch, SortAgg } @@ -1346,6 +1351,7 @@ macro_rules! for_all_plan_nodes { , { Stream, MaterializedExprs } , { Stream, VectorIndexWrite } , { Stream, UpstreamSinkUnion } + , { Stream, LocalityProvider } $(,$rest)* } }; diff --git a/src/frontend/src/optimizer/plan_node/stream_locality_provider.rs b/src/frontend/src/optimizer/plan_node/stream_locality_provider.rs new file mode 100644 index 0000000000000..15db8ef9bc152 --- /dev/null +++ b/src/frontend/src/optimizer/plan_node/stream_locality_provider.rs @@ -0,0 +1,207 @@ +// Copyright 2025 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
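Stepping back to the index-order search that `logical_scan.rs` moves into `try_better_locality_inner` above: it builds one `ColumnOrder` per requested column and walks the cartesian product of per-column order types, so with at most three columns there are at most 4^3 = 64 candidate orders (assuming `OrderType::all()` covers the four asc/desc × nulls-first/last variants), while longer column lists use only two variants and the product is further capped by `.take(256)`. Below is a dependency-free toy illustration of that enumeration, with simplified stand-ins rather than the real `OrderType`/`ColumnOrder`.

```rust
// Simplified stand-ins for OrderType/ColumnOrder, only to show the search-space shape.
#[derive(Debug, Clone, Copy)]
enum OrderKind { AscNullsFirst, AscNullsLast, DescNullsFirst, DescNullsLast }

fn candidate_orders(columns: &[usize], cap: usize) -> Vec<Vec<(usize, OrderKind)>> {
    // Mirror of the heuristic in the patch: all variants for short column lists,
    // only two variants beyond three columns to keep the product small.
    let per_column = if columns.len() <= 3 {
        vec![OrderKind::AscNullsFirst, OrderKind::AscNullsLast, OrderKind::DescNullsFirst, OrderKind::DescNullsLast]
    } else {
        vec![OrderKind::AscNullsLast, OrderKind::DescNullsFirst]
    };
    // Cartesian product: one order kind chosen per requested column.
    let mut combos: Vec<Vec<(usize, OrderKind)>> = vec![vec![]];
    for &col in columns {
        combos = combos
            .into_iter()
            .flat_map(|prefix| {
                per_column.iter().map(move |&kind| {
                    let mut next = prefix.clone();
                    next.push((col, kind));
                    next
                })
            })
            .collect();
    }
    combos.truncate(cap); // plays the role of `.take(256)` in the patch
    combos
}

fn main() {
    // Two requested columns with four order kinds each: 4 * 4 = 16 candidates.
    assert_eq!(candidate_orders(&[1, 3], 256).len(), 16);
}
```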
+ +use itertools::Itertools; +use pretty_xmlish::XmlNode; +use risingwave_common::catalog::Field; +use risingwave_common::hash::VirtualNode; +use risingwave_common::types::DataType; +use risingwave_common::util::sort_util::OrderType; +use risingwave_pb::stream_plan::LocalityProviderNode; +use risingwave_pb::stream_plan::stream_node::PbNodeBody; + +use super::stream::prelude::*; +use super::utils::{Distill, TableCatalogBuilder, childless_record}; +use super::{ExprRewritable, PlanTreeNodeUnary, StreamNode, StreamPlanRef as PlanRef, generic}; +use crate::TableCatalog; +use crate::catalog::TableId; +use crate::expr::{ExprRewriter, ExprVisitor}; +use crate::optimizer::plan_node::PlanBase; +use crate::optimizer::plan_node::expr_visitable::ExprVisitable; +use crate::optimizer::property::Distribution; +use crate::stream_fragmenter::BuildFragmentGraphState; + +/// `StreamLocalityProvider` implements [`super::LogicalLocalityProvider`] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct StreamLocalityProvider { + pub base: PlanBase, + core: generic::LocalityProvider, +} + +impl StreamLocalityProvider { + pub fn new(core: generic::LocalityProvider) -> Self { + let input = core.input.clone(); + + let dist = match input.distribution() { + Distribution::HashShard(keys) => { + // If the input is hash-distributed, we make it a UpstreamHashShard distribution + // just like a normal table scan. It is used to ensure locality provider is in its own fragment. + // This is important to ensure the backfill ordering can recognize and build + // the dependency graph among different backfill-needed fragments. + Distribution::UpstreamHashShard(keys.clone(), TableId::placeholder()) + } + Distribution::UpstreamHashShard(keys, table_id) => { + Distribution::UpstreamHashShard(keys.clone(), *table_id) + } + _ => { + panic!("LocalityProvider input must be hash-distributed"); + } + }; + + // LocalityProvider maintains the append-only behavior if input is append-only + let base = PlanBase::new_stream_with_core( + &core, + dist, + input.stream_kind(), + input.emit_on_window_close(), + input.watermark_columns().clone(), + input.columns_monotonicity().clone(), + ); + StreamLocalityProvider { base, core } + } + + pub fn locality_columns(&self) -> &[usize] { + &self.core.locality_columns + } +} + +impl PlanTreeNodeUnary for StreamLocalityProvider { + fn input(&self) -> PlanRef { + self.core.input.clone() + } + + fn clone_with_input(&self, input: PlanRef) -> Self { + let mut core = self.core.clone(); + core.input = input; + Self::new(core) + } +} + +impl_plan_tree_node_for_unary! 
{ Stream, StreamLocalityProvider } + +impl Distill for StreamLocalityProvider { + fn distill<'a>(&self) -> XmlNode<'a> { + let vec = self.core.fields_pretty(); + childless_record("StreamLocalityProvider", vec) + } +} + +impl StreamNode for StreamLocalityProvider { + fn to_stream_prost_body(&self, state: &mut BuildFragmentGraphState) -> PbNodeBody { + let state_table = self.build_state_catalog(state); + let progress_table = self.build_progress_catalog(state); + + let locality_provider_node = LocalityProviderNode { + locality_columns: self.locality_columns().iter().map(|&i| i as u32).collect(), + // State table for buffering input data + state_table: Some(state_table.to_prost()), + // Progress table for tracking backfill progress + progress_table: Some(progress_table.to_prost()), + }; + + PbNodeBody::LocalityProvider(Box::new(locality_provider_node)) + } +} + +impl ExprRewritable for StreamLocalityProvider { + fn has_rewritable_expr(&self) -> bool { + false + } + + fn rewrite_exprs(&self, _r: &mut dyn ExprRewriter) -> PlanRef { + self.clone().into() + } +} + +impl ExprVisitable for StreamLocalityProvider { + fn visit_exprs(&self, _v: &mut dyn ExprVisitor) { + // No expressions to visit + } +} + +impl StreamLocalityProvider { + /// Build the state table catalog for buffering input data + /// Schema: same as input schema (locality handled by primary key ordering) + /// Key: `locality_columns` (vnode handled internally by `StateTable`) + fn build_state_catalog(&self, state: &mut BuildFragmentGraphState) -> TableCatalog { + let mut catalog_builder = TableCatalogBuilder::default(); + let input = self.input(); + let input_schema = input.schema(); + + // Add all input columns in original order + for field in &input_schema.fields { + catalog_builder.add_column(field); + } + + // Set locality columns as primary key. 
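+        // Putting the locality columns first in the key means a prefix range scan over this
+        // table returns rows already clustered by locality; the input stream key appended
+        // below keeps each buffered row's key unique.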
+ for locality_col_idx in self.locality_columns() { + catalog_builder.add_order_column(*locality_col_idx, OrderType::ascending()); + } + // add streaming key of the input as the rest of the primary key + for &key_col_idx in input.expect_stream_key() { + catalog_builder.add_order_column(key_col_idx, OrderType::ascending()); + } + + catalog_builder.set_value_indices((0..input_schema.len()).collect()); + + catalog_builder + .build( + self.input().distribution().dist_column_indices().to_vec(), + 0, + ) + .with_id(state.gen_table_id_wrapped()) + } + + /// Build the progress table catalog for tracking backfill progress + /// Schema: | vnode | pk(locality columns + input stream keys) | `backfill_finished` | `row_count` | + /// Key: | vnode | pk(locality columns + input stream keys) | + fn build_progress_catalog(&self, state: &mut BuildFragmentGraphState) -> TableCatalog { + let mut catalog_builder = TableCatalogBuilder::default(); + let input = self.input(); + let input_schema = input.schema(); + + // Add vnode column as primary key + catalog_builder.add_column(&Field::with_name(VirtualNode::RW_TYPE, "vnode")); + catalog_builder.add_order_column(0, OrderType::ascending()); + + // Add locality columns as part of primary key + for &locality_col_idx in self.locality_columns() { + let field = &input_schema.fields[locality_col_idx]; + catalog_builder.add_column(field); + } + + // Add stream key columns as part of primary key (excluding those already added as locality columns) + for &key_col_idx in input.expect_stream_key() { + let field = &input_schema.fields[key_col_idx]; + catalog_builder.add_column(field); + } + + // Add backfill_finished column + catalog_builder.add_column(&Field::with_name(DataType::Boolean, "backfill_finished")); + + // Add row_count column + catalog_builder.add_column(&Field::with_name(DataType::Int64, "row_count")); + + // Set vnode column index and distribution key + catalog_builder.set_vnode_col_idx(0); + catalog_builder.set_dist_key_in_pk(vec![0]); + + let num_of_columns = catalog_builder.columns().len(); + catalog_builder.set_value_indices((0..num_of_columns).collect_vec()); + + catalog_builder + .build(vec![0], 1) + .with_id(state.gen_table_id_wrapped()) + } +} diff --git a/src/frontend/src/stream_fragmenter/mod.rs b/src/frontend/src/stream_fragmenter/mod.rs index acf85856c8c93..58055249ad2a2 100644 --- a/src/frontend/src/stream_fragmenter/mod.rs +++ b/src/frontend/src/stream_fragmenter/mod.rs @@ -482,6 +482,12 @@ fn build_fragment( .add(FragmentTypeFlag::UpstreamSinkUnion); } + NodeBody::LocalityProvider(_) => { + current_fragment + .fragment_type_mask + .add(FragmentTypeFlag::LocalityProvider); + } + _ => {} }; diff --git a/src/meta/src/barrier/backfill_order_control.rs b/src/meta/src/barrier/backfill_order_control.rs index a9c0ddc006f1f..09852b7babbbd 100644 --- a/src/meta/src/barrier/backfill_order_control.rs +++ b/src/meta/src/barrier/backfill_order_control.rs @@ -67,10 +67,11 @@ impl BackfillOrderState { let mut backfill_nodes: HashMap = HashMap::new(); for fragment in stream_job_fragments.fragments() { - if fragment - .fragment_type_mask - .contains_any([FragmentTypeFlag::StreamScan, FragmentTypeFlag::SourceScan]) - { + if fragment.fragment_type_mask.contains_any([ + FragmentTypeFlag::StreamScan, + FragmentTypeFlag::SourceScan, + FragmentTypeFlag::LocalityProvider, + ]) { let fragment_id = fragment.fragment_id; backfill_nodes.insert( fragment_id, diff --git a/src/meta/src/barrier/progress.rs b/src/meta/src/barrier/progress.rs index 76018c3e84eac..7a4ec26647b25 
100644 --- a/src/meta/src/barrier/progress.rs +++ b/src/meta/src/barrier/progress.rs @@ -149,6 +149,10 @@ impl Progress { BackfillUpstreamType::Values => { // do not consider progress for values } + BackfillUpstreamType::LocalityProvider => { + // Track LocalityProvider progress similar to MView + self.mv_backfill_consumed_rows += new - old; + } } self.states.insert(actor, new_state); next_backfill_nodes @@ -183,6 +187,7 @@ impl Progress { BackfillUpstreamType::MView => mv_count += 1, BackfillUpstreamType::Source => source_count += 1, BackfillUpstreamType::Values => (), + BackfillUpstreamType::LocalityProvider => mv_count += 1, /* Count LocalityProvider as an MView for progress */ } } diff --git a/src/meta/src/model/stream.rs b/src/meta/src/model/stream.rs index e8735a0d708e4..e416cde7ebcba 100644 --- a/src/meta/src/model/stream.rs +++ b/src/meta/src/model/stream.rs @@ -503,6 +503,7 @@ impl StreamJobFragments { FragmentTypeFlag::Values, FragmentTypeFlag::StreamScan, FragmentTypeFlag::SourceScan, + FragmentTypeFlag::LocalityProvider, ]) { actor_ids.extend(fragment.actors.iter().map(|actor| { ( @@ -782,6 +783,7 @@ pub enum BackfillUpstreamType { MView, Values, Source, + LocalityProvider, } impl BackfillUpstreamType { @@ -789,12 +791,13 @@ impl BackfillUpstreamType { let is_mview = mask.contains(FragmentTypeFlag::StreamScan); let is_values = mask.contains(FragmentTypeFlag::Values); let is_source = mask.contains(FragmentTypeFlag::SourceScan); + let is_locality_provider = mask.contains(FragmentTypeFlag::LocalityProvider); // Note: in theory we can have multiple backfill executors in one fragment, but currently it's not possible. // See . debug_assert!( - is_mview as u8 + is_values as u8 + is_source as u8 == 1, - "a backfill fragment should either be mview, value or source, found {:?}", + is_mview as u8 + is_values as u8 + is_source as u8 + is_locality_provider as u8 == 1, + "a backfill fragment should either be mview, value, source, or locality provider, found {:?}", mask ); @@ -804,6 +807,8 @@ impl BackfillUpstreamType { BackfillUpstreamType::Values } else if is_source { BackfillUpstreamType::Source + } else if is_locality_provider { + BackfillUpstreamType::LocalityProvider } else { unreachable!("invalid fragment type mask: {:?}", mask); } diff --git a/src/meta/src/stream/stream_graph/fragment.rs b/src/meta/src/stream/stream_graph/fragment.rs index 5c218b1fe2ee1..3ba9103d1efd9 100644 --- a/src/meta/src/stream/stream_graph/fragment.rs +++ b/src/meta/src/stream/stream_graph/fragment.rs @@ -1020,6 +1020,8 @@ impl StreamFragmentGraph { pub fn create_fragment_backfill_ordering(&self) -> FragmentBackfillOrder { let mapping = self.collect_backfill_mapping(); let mut fragment_ordering: HashMap> = HashMap::new(); + + // 1. Add backfill dependencies for (rel_id, downstream_rel_ids) in &self.backfill_order.order { let fragment_ids = mapping.get(rel_id).unwrap(); for fragment_id in fragment_ids { @@ -1032,8 +1034,153 @@ impl StreamFragmentGraph { fragment_ordering.insert(*fragment_id, downstream_fragment_ids); } } + + // If no backfill order is specified, we still need to ensure that all backfill fragments + // run before LocalityProvider fragments. + if fragment_ordering.is_empty() { + for value in mapping.values() { + for &fragment_id in value { + fragment_ordering.entry(fragment_id).or_default(); + } + } + } + + // 2. 
Add dependencies: all backfill fragments should run before LocalityProvider fragments + let locality_provider_dependencies = self.find_locality_provider_dependencies(); + + let backfill_fragments: HashSet = mapping.values().flatten().copied().collect(); + + // Calculate LocalityProvider root fragments (zero indegree) + // Root fragments are those that appear as keys but never appear as downstream dependencies + let all_locality_provider_fragments: HashSet = + locality_provider_dependencies.keys().copied().collect(); + let downstream_locality_provider_fragments: HashSet = locality_provider_dependencies + .values() + .flatten() + .copied() + .collect(); + let locality_provider_root_fragments: Vec = all_locality_provider_fragments + .difference(&downstream_locality_provider_fragments) + .copied() + .collect(); + + // For each backfill fragment, add only the root LocalityProvider fragments as dependents + // This ensures backfill completes before any LocalityProvider starts, while minimizing dependencies + for &backfill_fragment_id in &backfill_fragments { + fragment_ordering + .entry(backfill_fragment_id) + .or_default() + .extend(locality_provider_root_fragments.iter().copied()); + } + + // 3. Add LocalityProvider internal dependencies + for (fragment_id, downstream_fragments) in locality_provider_dependencies { + fragment_ordering + .entry(fragment_id) + .or_default() + .extend(downstream_fragments); + } + fragment_ordering } + + /// Find dependency relationships among fragments containing `LocalityProvider` nodes. + /// Returns a mapping where each fragment ID maps to a list of fragment IDs that should be processed after it. + /// Following the same semantics as `FragmentBackfillOrder`: + /// `G[10] -> [1, 2, 11]` means `LocalityProvider` in fragment 10 should be processed + /// before `LocalityProviders` in fragments 1, 2, and 11. + /// + /// This method assumes each fragment contains at most one `LocalityProvider` node. 
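The helper documented above is defined next; to make the merged ordering concrete first, here is a small self-contained reconstruction with made-up fragment ids (step 1's user-specified backfill order is omitted): every plain backfill fragment must finish before the root locality-provider fragments, and locality providers additionally respect their own upstream-to-downstream edges.

```rust
use std::collections::{HashMap, HashSet};

type FragmentId = u32;

// Toy reconstruction of steps 2 and 3 above; fragment ids are made up.
fn merged_order(
    backfill: &HashSet<FragmentId>,
    lp_edges: &HashMap<FragmentId, Vec<FragmentId>>, // locality provider -> downstream providers
) -> HashMap<FragmentId, Vec<FragmentId>> {
    let mut order: HashMap<FragmentId, Vec<FragmentId>> = HashMap::new();
    // Root providers: no other locality provider feeds into them.
    let downstream: HashSet<FragmentId> = lp_edges.values().flatten().copied().collect();
    let roots: Vec<FragmentId> = lp_edges
        .keys()
        .copied()
        .filter(|id| !downstream.contains(id))
        .collect();
    // Step 2: every backfill fragment precedes the root providers.
    for &b in backfill {
        order.entry(b).or_default().extend(roots.iter().copied());
    }
    // Step 3: providers keep their own upstream -> downstream ordering.
    for (&up, downs) in lp_edges {
        order.entry(up).or_default().extend(downs.iter().copied());
    }
    order
}

fn main() {
    // Two scan backfill fragments (1, 2); providers 10 and 11 above each scan and
    // provider 12 above the join, so 10 -> 12 and 11 -> 12.
    let backfill: HashSet<FragmentId> = [1, 2].into_iter().collect();
    let lp_edges: HashMap<FragmentId, Vec<FragmentId>> =
        [(10, vec![12]), (11, vec![12]), (12, vec![])].into_iter().collect();
    // e.g. {1: [10, 11], 2: [10, 11], 10: [12], 11: [12], 12: []} (vector order may vary)
    println!("{:?}", merged_order(&backfill, &lp_edges));
}
```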
+ pub fn find_locality_provider_dependencies(&self) -> HashMap> { + let mut locality_provider_fragments = HashSet::new(); + let mut dependencies: HashMap> = HashMap::new(); + + // First, identify all fragments that contain LocalityProvider nodes + for (fragment_id, fragment) in &self.fragments { + let fragment_id = fragment_id.as_global_id(); + let has_locality_provider = self.fragment_has_locality_provider(fragment); + + if has_locality_provider { + locality_provider_fragments.insert(fragment_id); + dependencies.entry(fragment_id).or_default(); + } + } + + // Build dependency relationships between LocalityProvider fragments + // For each LocalityProvider fragment, find all downstream LocalityProvider fragments + // The upstream fragment should be processed before the downstream fragments + for &provider_fragment_id in &locality_provider_fragments { + let provider_fragment_global_id = GlobalFragmentId::new(provider_fragment_id); + + // Find all fragments downstream from this LocalityProvider fragment + let mut visited = HashSet::new(); + let mut downstream_locality_providers = Vec::new(); + + self.collect_downstream_locality_providers( + provider_fragment_global_id, + &locality_provider_fragments, + &mut visited, + &mut downstream_locality_providers, + ); + + // This fragment should be processed before all its downstream LocalityProvider fragments + dependencies + .entry(provider_fragment_id) + .or_default() + .extend(downstream_locality_providers); + } + + dependencies + } + + fn fragment_has_locality_provider(&self, fragment: &BuildingFragment) -> bool { + let mut has_locality_provider = false; + + if let Some(node) = fragment.node.as_ref() { + visit_stream_node_cont(node, |stream_node| { + if let Some(NodeBody::LocalityProvider(_)) = stream_node.node_body.as_ref() { + has_locality_provider = true; + false // Stop visiting once we find a LocalityProvider + } else { + true // Continue visiting + } + }); + } + + has_locality_provider + } + + /// Recursively collect downstream `LocalityProvider` fragments + fn collect_downstream_locality_providers( + &self, + current_fragment_id: GlobalFragmentId, + locality_provider_fragments: &HashSet, + visited: &mut HashSet, + downstream_providers: &mut Vec, + ) { + if visited.contains(¤t_fragment_id) { + return; + } + visited.insert(current_fragment_id); + + // Check all downstream fragments + for &downstream_id in self.get_downstreams(current_fragment_id).keys() { + let downstream_fragment_id = downstream_id.as_global_id(); + + // If the downstream fragment is a LocalityProvider, add it to results + if locality_provider_fragments.contains(&downstream_fragment_id) { + downstream_providers.push(downstream_fragment_id); + } + + // Recursively check further downstream + self.collect_downstream_locality_providers( + downstream_id, + locality_provider_fragments, + visited, + downstream_providers, + ); + } + } } /// Fill snapshot epoch for `StreamScanNode` of `SnapshotBackfill`. diff --git a/src/prost/build.rs b/src/prost/build.rs index babc61295e84f..0d818976d015b 100644 --- a/src/prost/build.rs +++ b/src/prost/build.rs @@ -170,6 +170,7 @@ fn main() -> Result<(), Box> { .boxed(".stream_plan.StreamNode.node_body.sync_log_store") .boxed(".stream_plan.StreamNode.node_body.materialized_exprs") .boxed(".stream_plan.StreamNode.node_body.vector_index_write") + .boxed(".stream_plan.StreamNode.node_body.locality_provider") // `Udf` is 248 bytes, while 2nd largest field is 32 bytes. 
.boxed(".expr.ExprNode.rex_node.udf") // Eq + Hash are for plan nodes to do common sub-plan detection. diff --git a/src/stream/src/executor/locality_provider.rs b/src/stream/src/executor/locality_provider.rs new file mode 100644 index 0000000000000..519c488832eb8 --- /dev/null +++ b/src/stream/src/executor/locality_provider.rs @@ -0,0 +1,817 @@ +// Copyright 2025 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::HashMap; +use std::sync::Arc; + +use either::Either; +use futures::stream::select_with_strategy; +use futures::{TryStreamExt, pin_mut, stream}; +use futures_async_stream::try_stream; +use itertools::Itertools; +use risingwave_common::array::{DataChunk, Op, StreamChunk}; +use risingwave_common::catalog::Schema; +use risingwave_common::hash::{VirtualNode, VnodeBitmapExt}; +use risingwave_common::row::{OwnedRow, Row, RowExt}; +use risingwave_common::types::{Datum, ToOwnedDatum}; +use risingwave_common::util::chunk_coalesce::DataChunkBuilder; +use risingwave_common::util::sort_util::cmp_datum_iter; +use risingwave_common_rate_limit::RateLimit; +use risingwave_storage::StateStore; +use risingwave_storage::store::PrefetchOptions; + +use crate::common::table::state_table::StateTable; +use crate::executor::backfill::utils::create_builder; +use crate::executor::prelude::*; +use crate::task::{CreateMviewProgressReporter, FragmentId}; + +type Builders = HashMap; + +/// Progress state for tracking backfill per vnode +#[derive(Clone, Debug, PartialEq, Eq)] +enum LocalityBackfillProgress { + /// Backfill not started for this vnode + NotStarted, + /// Backfill in progress, tracking current position + InProgress { + /// Current position in the locality-ordered scan + current_pos: OwnedRow, + /// Number of rows processed for this vnode + processed_rows: u64, + }, + /// Backfill completed for this vnode + Completed { + /// Final position reached + final_pos: OwnedRow, + /// Total rows processed for this vnode + total_rows: u64, + }, +} + +/// State management for locality provider backfill process +#[derive(Clone, Debug)] +struct LocalityBackfillState { + /// Progress per vnode + per_vnode: HashMap, + /// Total snapshot rows read across all vnodes + total_snapshot_rows: u64, +} + +impl LocalityBackfillState { + fn new(vnodes: impl Iterator) -> Self { + let per_vnode = vnodes + .map(|vnode| (vnode, LocalityBackfillProgress::NotStarted)) + .collect(); + Self { + per_vnode, + total_snapshot_rows: 0, + } + } + + fn is_completed(&self) -> bool { + self.per_vnode + .values() + .all(|progress| matches!(progress, LocalityBackfillProgress::Completed { .. })) + } + + fn vnodes(&self) -> impl Iterator { + self.per_vnode + .iter() + .map(|(&vnode, progress)| (vnode, progress)) + } + + fn has_progress(&self) -> bool { + self.per_vnode + .values() + .any(|progress| matches!(progress, LocalityBackfillProgress::InProgress { .. 
})) + } + + fn update_progress(&mut self, vnode: VirtualNode, new_pos: OwnedRow, row_count_delta: u64) { + let progress = self.per_vnode.get_mut(&vnode).unwrap(); + match progress { + LocalityBackfillProgress::NotStarted => { + *progress = LocalityBackfillProgress::InProgress { + current_pos: new_pos, + processed_rows: row_count_delta, + }; + } + LocalityBackfillProgress::InProgress { processed_rows, .. } => { + *progress = LocalityBackfillProgress::InProgress { + current_pos: new_pos, + processed_rows: *processed_rows + row_count_delta, + }; + } + LocalityBackfillProgress::Completed { .. } => { + // Already completed, shouldn't update + } + } + self.total_snapshot_rows += row_count_delta; + } + + fn finish_vnode(&mut self, vnode: VirtualNode, pk_len: usize) { + let progress = self.per_vnode.get_mut(&vnode).unwrap(); + match progress { + LocalityBackfillProgress::NotStarted => { + // Create a final position with pk_len NULL values to indicate completion + let final_pos = OwnedRow::new(vec![None; pk_len]); + *progress = LocalityBackfillProgress::Completed { + final_pos, + total_rows: 0, + }; + } + LocalityBackfillProgress::InProgress { + current_pos, + processed_rows, + } => { + *progress = LocalityBackfillProgress::Completed { + final_pos: current_pos.clone(), + total_rows: *processed_rows, + }; + } + LocalityBackfillProgress::Completed { .. } => { + // Already completed + } + } + } + + fn get_progress(&self, vnode: &VirtualNode) -> &LocalityBackfillProgress { + self.per_vnode.get(vnode).unwrap() + } +} + +/// The `LocalityProviderExecutor` provides locality for operators during backfilling. +/// It buffers input data into a state table using locality columns as primary key prefix. +/// +/// The executor implements a proper backfill process similar to arrangement backfill: +/// 1. Backfill phase: Buffer incoming data and provide locality-ordered snapshot reads +/// 2. 
Forward phase: Once backfill is complete, forward upstream messages directly +/// +/// Key improvements over the original implementation: +/// - Removes arbitrary barrier buffer limit +/// - Implements proper upstream chunk tracking during backfill +/// - Uses per-vnode progress tracking for better state management +pub struct LocalityProviderExecutor { + /// Upstream input + upstream: Executor, + + /// Locality columns (indices in input schema) + #[allow(dead_code)] + locality_columns: Vec, + + /// State table for buffering input data + state_table: StateTable, + + /// Progress table for tracking backfill progress per vnode + progress_table: StateTable, + + input_schema: Schema, + + /// Progress reporter for materialized view creation + progress: CreateMviewProgressReporter, + + fragment_id: FragmentId, + + actor_id: ActorId, + + /// Metrics + metrics: Arc, + + /// Chunk size for output + chunk_size: usize, +} + +impl LocalityProviderExecutor { + #[allow(clippy::too_many_arguments)] + pub fn new( + upstream: Executor, + locality_columns: Vec, + state_table: StateTable, + progress_table: StateTable, + input_schema: Schema, + progress: CreateMviewProgressReporter, + metrics: Arc, + chunk_size: usize, + fragment_id: FragmentId, + ) -> Self { + Self { + upstream, + locality_columns, + state_table, + progress_table, + input_schema, + actor_id: progress.actor_id(), + progress, + metrics, + chunk_size, + fragment_id, + } + } + + /// Creates a snapshot stream that reads from state table in locality order + #[try_stream(ok = Option<(VirtualNode, OwnedRow)>, error = StreamExecutorError)] + async fn make_snapshot_stream<'a>( + state_table: &'a StateTable, + backfill_state: LocalityBackfillState, + ) { + // Read from state table per vnode in locality order + for vnode in state_table.vnodes().iter_vnodes() { + let progress = backfill_state.get_progress(&vnode); + + let current_pos = match progress { + LocalityBackfillProgress::NotStarted => None, + LocalityBackfillProgress::Completed { .. } => { + // Skip completed vnodes + continue; + } + LocalityBackfillProgress::InProgress { current_pos, .. } => { + Some(current_pos.clone()) + } + }; + + // Compute range bounds for iteration based on current position + let range_bounds = if let Some(ref pos) = current_pos { + let start_bound = std::ops::Bound::Excluded(pos.as_inner()); + (start_bound, std::ops::Bound::<&[Datum]>::Unbounded) + } else { + ( + std::ops::Bound::<&[Datum]>::Unbounded, + std::ops::Bound::<&[Datum]>::Unbounded, + ) + }; + + // Iterate over rows for this vnode + let iter = state_table + .iter_with_vnode( + vnode, + &range_bounds, + PrefetchOptions::prefetch_for_small_range_scan(), + ) + .await?; + pin_mut!(iter); + + while let Some(row) = iter.try_next().await? 
{ + yield Some((vnode, row)); + } + } + + // Signal end of stream + yield None; + } + + /// Persist backfill state to progress table + async fn persist_backfill_state( + progress_table: &mut StateTable, + backfill_state: &LocalityBackfillState, + ) -> StreamExecutorResult<()> { + for (vnode, progress) in &backfill_state.per_vnode { + let (is_finished, current_pos, row_count) = match progress { + LocalityBackfillProgress::NotStarted => continue, // Don't persist NotStarted + LocalityBackfillProgress::InProgress { + current_pos, + processed_rows, + } => (false, current_pos.clone(), *processed_rows), + LocalityBackfillProgress::Completed { + final_pos, + total_rows, + } => (true, final_pos.clone(), *total_rows), + }; + + // Build progress row: vnode + current_pos + is_finished + row_count + let mut row_data = vec![Some(vnode.to_scalar().into())]; + row_data.extend(current_pos); + row_data.push(Some(risingwave_common::types::ScalarImpl::Bool( + is_finished, + ))); + row_data.push(Some(risingwave_common::types::ScalarImpl::Int64( + row_count as i64, + ))); + + let new_row = OwnedRow::new(row_data); + + // Check if there's an existing row for this vnode to determine insert vs update + // This ensures state operation consistency - update existing rows, insert new ones + let key_data = vec![Some(vnode.to_scalar().into())]; + let key = OwnedRow::new(key_data); + + if let Some(existing_row) = progress_table.get_row(&key).await? { + // Update existing state - ensures proper state transition for recovery + progress_table.update(existing_row, new_row); + } else { + // Insert new state - first time persisting for this vnode + progress_table.insert(new_row); + } + } + Ok(()) + } + + /// Load backfill state from progress table + async fn load_backfill_state( + progress_table: &StateTable, + ) -> StreamExecutorResult { + let mut backfill_state = LocalityBackfillState::new(progress_table.vnodes().iter_vnodes()); + let mut total_snapshot_rows = 0; + + // For each vnode, try to get its progress state + for vnode in progress_table.vnodes().iter_vnodes() { + // Build key: vnode + NULL values for locality columns (to match progress table schema) + let key_data = vec![Some(vnode.to_scalar().into())]; + + let key = OwnedRow::new(key_data); + + if let Some(row) = progress_table.get_row(&key).await? 
{ + // Parse is_finished flag (second to last column) + let finished_col_idx = row.len() - 2; + let is_finished = row + .datum_at(finished_col_idx) + .map(|d| d.into_bool()) + .unwrap_or(false); + + // Parse row count (last column) + let row_count = row + .datum_at(row.len() - 1) + .map(|d| d.into_int64() as u64) + .unwrap_or(0); + + let current_pos_data: Vec = (1..finished_col_idx) + .map(|i| row.datum_at(i).to_owned_datum()) + .collect(); + let current_pos = OwnedRow::new(current_pos_data); + + // Set progress based on is_finished flag + let progress = if is_finished { + LocalityBackfillProgress::Completed { + final_pos: current_pos, + total_rows: row_count, + } + } else { + LocalityBackfillProgress::InProgress { + current_pos, + processed_rows: row_count, + } + }; + + backfill_state.per_vnode.insert(vnode, progress); + total_snapshot_rows += row_count; + } + // If no row found, keep the default NotStarted state + } + + backfill_state.total_snapshot_rows = total_snapshot_rows; + Ok(backfill_state) + } + + /// Mark chunk for forwarding based on backfill progress + fn mark_chunk( + chunk: StreamChunk, + backfill_state: &LocalityBackfillState, + state_table: &StateTable, + ) -> StreamExecutorResult { + let chunk = chunk.compact(); + let (data, ops) = chunk.into_parts(); + let mut new_visibility = risingwave_common::bitmap::BitmapBuilder::with_capacity(ops.len()); + + let pk_indices = state_table.pk_indices(); + let pk_order = state_table.pk_serde().get_order_types(); + + for row in data.rows() { + // Project to primary key columns for comparison + let pk = row.project(pk_indices); + let vnode = state_table.compute_vnode_by_pk(pk); + + let visible = match backfill_state.get_progress(&vnode) { + LocalityBackfillProgress::Completed { .. } => true, + LocalityBackfillProgress::NotStarted => false, + LocalityBackfillProgress::InProgress { current_pos, .. 
+                    // Compare primary key with current position
+                    cmp_datum_iter(pk.iter(), current_pos.iter(), pk_order.iter().copied()).is_le()
+                }
+            };
+
+            new_visibility.append(visible);
+        }
+
+        let (columns, _) = data.into_parts();
+        let chunk = StreamChunk::with_visibility(ops, columns, new_visibility.finish());
+        Ok(chunk)
+    }
+
+    fn handle_snapshot_chunk(
+        data_chunk: DataChunk,
+        vnode: VirtualNode,
+        pk_indices: &[usize],
+        backfill_state: &mut LocalityBackfillState,
+        cur_barrier_snapshot_processed_rows: &mut u64,
+    ) -> StreamExecutorResult<StreamChunk> {
+        let chunk = StreamChunk::from_parts(vec![Op::Insert; data_chunk.cardinality()], data_chunk);
+        let chunk_cardinality = chunk.cardinality() as u64;
+
+        // Extract primary key from the last row to update progress.
+        // As snapshot read streams are ordered by pk, we can use the last row to update current_pos.
+        if let Some(last_row) = chunk.rows().last() {
+            let pk = last_row.1.project(pk_indices);
+            let pk_owned = pk.into_owned_row();
+            backfill_state.update_progress(vnode, pk_owned, chunk_cardinality);
+        }
+
+        *cur_barrier_snapshot_processed_rows += chunk_cardinality;
+        Ok(chunk)
+    }
+}
+
+impl<S: StateStore> Execute for LocalityProviderExecutor<S> {
+    fn execute(self: Box<Self>) -> BoxedMessageStream {
+        self.execute_inner().boxed()
+    }
+}
+
+impl<S: StateStore> LocalityProviderExecutor<S> {
+    #[try_stream(ok = Message, error = StreamExecutorError)]
+    async fn execute_inner(mut self) {
+        let mut upstream = self.upstream.execute();
+
+        // Wait for first barrier to initialize
+        let first_barrier = expect_first_barrier(&mut upstream).await?;
+        let first_epoch = first_barrier.epoch;
+
+        // Propagate the first barrier
+        yield Message::Barrier(first_barrier);
+
+        let mut state_table = self.state_table;
+        let mut progress_table = self.progress_table;
+
+        // Initialize state tables
+        state_table.init_epoch(first_epoch).await?;
+        progress_table.init_epoch(first_epoch).await?;
+
+        // Load backfill state from progress table
+        let mut backfill_state = Self::load_backfill_state(&progress_table).await?;
+
+        // Get pk info from state table
+        let pk_indices = state_table.pk_indices().iter().cloned().collect_vec();
+
+        let need_backfill = !backfill_state.is_completed();
+
+        let need_buffering = backfill_state
+            .per_vnode
+            .values()
+            .all(|progress| matches!(progress, LocalityBackfillProgress::NotStarted));
+
+        // Initial buffering phase before backfill - wait for StartFragmentBackfill mutation (if needed)
+        if need_buffering {
+            // Enter buffering phase - buffer data until StartFragmentBackfill is received
+            let mut start_backfill = false;
+
+            #[for_await]
+            for msg in upstream.by_ref() {
+                let msg = msg?;
+
+                match msg {
+                    Message::Watermark(_) => {
+                        // Ignore watermarks during initial buffering
+                    }
+                    Message::Chunk(chunk) => {
+                        state_table.write_chunk(chunk);
+                        state_table.try_flush().await?;
+                    }
+                    Message::Barrier(barrier) => {
+                        let epoch = barrier.epoch;
+
+                        // Check for StartFragmentBackfill mutation
+                        if let Some(mutation) = barrier.mutation.as_deref() {
+                            use crate::executor::Mutation;
+                            if let Mutation::StartFragmentBackfill { fragment_ids } = mutation {
+                                if fragment_ids.contains(&self.fragment_id) {
+                                    tracing::info!(
+                                        "Start backfill of locality provider with fragment id: {:?}",
+                                        &self.fragment_id
+                                    );
+                                    start_backfill = true;
+                                }
+                            }
+                        }
+
+                        // Commit state tables
+                        let post_commit1 = state_table.commit(epoch).await?;
+                        let post_commit2 = progress_table.commit(epoch).await?;
+
+                        yield Message::Barrier(barrier);
+                        post_commit1.post_yield_barrier(None).await?;
+                        post_commit2.post_yield_barrier(None).await?;
+
+                        // Start backfill when StartFragmentBackfill mutation is received
+                        if start_backfill {
+                            break;
+                        }
+                    }
+                }
+            }
+        }
+
+        // Locality Provider Backfill Algorithm (adapted from Arrangement Backfill):
+        //
+        //       backfill_stream
+        //      /               \
+        //  upstream        snapshot (from state_table)
+        //
+        // We construct a backfill stream with upstream as its left input and locality-ordered
+        // snapshot read stream as its right input. When a chunk comes from upstream, we buffer it.
+        //
+        // When a barrier comes from upstream:
+        //  - For each row of the upstream chunk buffer, compute vnode.
+        //  - Get the `current_pos` corresponding to the vnode. Forward it to downstream if its
+        //    locality key <= `current_pos`, otherwise ignore it.
+        //  - Flush all buffered upstream_chunks to state table.
+        //  - Persist backfill progress to progress table.
+        //  - Reconstruct the whole backfill stream with upstream and new snapshot read stream.
+        //
+        // When a chunk comes from snapshot, we forward it to the downstream and raise
+        // `current_pos`.
+        //
+        // When we reach the end of the snapshot read stream, it means backfill has been
+        // finished.
+        //
+        // Once the backfill loop ends, we forward the upstream directly to the downstream.
+
+        if need_backfill {
+            let mut upstream_chunk_buffer: Vec<StreamChunk> = vec![];
+            let mut pending_barrier: Option<Barrier> = None;
+
+            let metrics = self
+                .metrics
+                .new_backfill_metrics(state_table.table_id(), self.actor_id);
+
+            // Create builders for snapshot data chunks
+            let snapshot_data_types = self.input_schema.data_types();
+            let mut builders: Builders = state_table
+                .vnodes()
+                .iter_vnodes()
+                .map(|vnode| {
+                    let builder = create_builder(
+                        RateLimit::Disabled,
+                        self.chunk_size,
+                        snapshot_data_types.clone(),
+                    );
+                    (vnode, builder)
+                })
+                .collect();
+
+            'backfill_loop: loop {
+                let mut cur_barrier_snapshot_processed_rows: u64 = 0;
+                let mut cur_barrier_upstream_processed_rows: u64 = 0;
+
+                // Create the backfill stream with upstream and snapshot
+                {
+                    let left_upstream = upstream.by_ref().map(Either::Left);
+                    let right_snapshot = pin!(
+                        Self::make_snapshot_stream(&state_table, backfill_state.clone())
+                            .map(Either::Right)
+                    );
+
+                    // Prefer to select upstream, so we can stop snapshot stream as soon as the
+                    // barrier comes.
+                    let mut backfill_stream =
+                        select_with_strategy(left_upstream, right_snapshot, |_: &mut ()| {
+                            stream::PollNext::Left
+                        });
+
+                    #[for_await]
+                    for either in &mut backfill_stream {
+                        match either {
+                            // Upstream
+                            Either::Left(msg) => {
+                                match msg? {
+                                    Message::Barrier(barrier) => {
+                                        // We have to process the barrier outside of the loop.
+                                        pending_barrier = Some(barrier);
+                                        break;
+                                    }
+                                    Message::Chunk(chunk) => {
+                                        // Buffer the upstream chunk.
+                                        upstream_chunk_buffer.push(chunk.compact());
+                                    }
+                                    Message::Watermark(_) => {
+                                        // Ignore watermark during backfill.
+                                    }
+                                }
+                            }
+                            // Snapshot read
+                            Either::Right(msg) => {
+                                match msg? {
+                                    None => {
+                                        // End of the snapshot read stream.
+                                        // Consume remaining rows in the builders.
+                                        for (vnode, builder) in &mut builders {
+                                            if let Some(data_chunk) = builder.consume_all() {
+                                                let chunk = Self::handle_snapshot_chunk(
+                                                    data_chunk,
+                                                    *vnode,
+                                                    &pk_indices,
+                                                    &mut backfill_state,
+                                                    &mut cur_barrier_snapshot_processed_rows,
+                                                )?;
+                                                yield Message::Chunk(chunk);
+                                            }
+                                        }
+
+                                        // Consume remaining rows in the upstream buffer.
+                                        for chunk in upstream_chunk_buffer.drain(..) {
+                                            let chunk_cardinality = chunk.cardinality() as u64;
+                                            cur_barrier_upstream_processed_rows +=
+                                                chunk_cardinality;
+                                            yield Message::Chunk(chunk);
+                                        }
+                                        metrics
+                                            .backfill_snapshot_read_row_count
+                                            .inc_by(cur_barrier_snapshot_processed_rows);
+                                        metrics
+                                            .backfill_upstream_output_row_count
+                                            .inc_by(cur_barrier_upstream_processed_rows);
+                                        break 'backfill_loop;
+                                    }
+                                    Some((vnode, row)) => {
+                                        // Use builder to batch rows efficiently
+                                        let builder = builders.get_mut(&vnode).unwrap();
+                                        if let Some(data_chunk) = builder.append_one_row(row) {
+                                            // Builder is full, handle the chunk
+                                            let chunk = Self::handle_snapshot_chunk(
+                                                data_chunk,
+                                                vnode,
+                                                &pk_indices,
+                                                &mut backfill_state,
+                                                &mut cur_barrier_snapshot_processed_rows,
+                                            )?;
+                                            yield Message::Chunk(chunk);
+                                        }
+                                        // If append_one_row returns None, the row is buffered but no
+                                        // chunk is produced yet. Progress will be updated when the
+                                        // builder is consumed later.
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+
+                // Process barrier
+                let barrier = match pending_barrier.take() {
+                    Some(barrier) => barrier,
+                    None => break 'backfill_loop, // Reached end of backfill
+                };
+
+                // Consume remaining rows from builders at barrier
+                for (vnode, builder) in &mut builders {
+                    if let Some(data_chunk) = builder.consume_all() {
+                        let chunk = Self::handle_snapshot_chunk(
+                            data_chunk,
+                            *vnode,
+                            &pk_indices,
+                            &mut backfill_state,
+                            &mut cur_barrier_snapshot_processed_rows,
+                        )?;
+                        yield Message::Chunk(chunk);
+                    }
+                }
+
+                // Process upstream buffer chunks with marking
+                for chunk in upstream_chunk_buffer.drain(..) {
+                    cur_barrier_upstream_processed_rows += chunk.cardinality() as u64;
+
+                    // Mark chunk based on backfill progress
+                    if backfill_state.has_progress() {
+                        let marked_chunk =
+                            Self::mark_chunk(chunk.clone(), &backfill_state, &state_table)?;
+                        yield Message::Chunk(marked_chunk);
+                    }
+                }
+
+                // no-op commit state table
+                state_table
+                    .commit_assert_no_update_vnode_bitmap(barrier.epoch)
+                    .await?;
+
+                // Update progress with current epoch and snapshot read count
+                let total_snapshot_processed_rows: u64 = backfill_state
+                    .vnodes()
+                    .map(|(_, progress)| match *progress {
+                        LocalityBackfillProgress::InProgress { processed_rows, .. } => {
+                            processed_rows
+                        }
+                        LocalityBackfillProgress::Completed { total_rows, .. } => total_rows,
+                        LocalityBackfillProgress::NotStarted => 0,
+                    })
+                    .sum();
+
+                self.progress.update(
+                    barrier.epoch,
+                    barrier.epoch.curr, // Use barrier epoch as snapshot read epoch
+                    total_snapshot_processed_rows,
+                );
+
+                // Persist backfill progress
+                Self::persist_backfill_state(&mut progress_table, &backfill_state).await?;
+                let barrier_epoch = barrier.epoch;
+                let post_commit = progress_table.commit(barrier_epoch).await?;
+
+                metrics
+                    .backfill_snapshot_read_row_count
+                    .inc_by(cur_barrier_snapshot_processed_rows);
+                metrics
+                    .backfill_upstream_output_row_count
+                    .inc_by(cur_barrier_upstream_processed_rows);
+
+                yield Message::Barrier(barrier);
+                post_commit.post_yield_barrier(None).await?;
+            }
+        }
+
+        tracing::debug!("Locality provider backfill finished, forwarding upstream directly");
+
+        // Wait for the first barrier after backfill completion to mark progress as finished
+        if need_backfill && !backfill_state.is_completed() {
+            while let Some(Ok(msg)) = upstream.next().await {
+                match msg {
+                    Message::Barrier(barrier) => {
+                        // no-op commit state table
+                        state_table
+                            .commit_assert_no_update_vnode_bitmap(barrier.epoch)
+                            .await?;
+
+                        // Mark all vnodes as completed
+                        for vnode in state_table.vnodes().iter_vnodes() {
+                            backfill_state.finish_vnode(vnode, pk_indices.len());
+                        }
+
+                        // Calculate final total processed rows
+                        let total_snapshot_processed_rows: u64 = backfill_state
+                            .vnodes()
+                            .map(|(_, progress)| match *progress {
+                                LocalityBackfillProgress::Completed { total_rows, .. } => {
+                                    total_rows
+                                }
+                                LocalityBackfillProgress::InProgress { processed_rows, .. } => {
+                                    processed_rows
+                                }
+                                LocalityBackfillProgress::NotStarted => 0,
+                            })
+                            .sum();
+
+                        // Finish progress reporting
+                        self.progress
+                            .finish(barrier.epoch, total_snapshot_processed_rows);
+
+                        // Persist final state
+                        Self::persist_backfill_state(&mut progress_table, &backfill_state).await?;
+                        let post_commit = progress_table.commit(barrier.epoch).await?;
+
+                        yield Message::Barrier(barrier);
+                        post_commit.post_yield_barrier(None).await?;
+                        break; // Exit the loop after processing the barrier
+                    }
+                    Message::Chunk(chunk) => {
+                        // Forward chunks directly during completion phase
+                        yield Message::Chunk(chunk);
+                    }
+                    Message::Watermark(watermark) => {
+                        // Forward watermarks directly during completion phase
+                        yield Message::Watermark(watermark);
+                    }
+                }
+            }
+        }
+
+        // TODO: truncate the state table after backfill.
+
+        // After backfill completion, forward messages directly
+        #[for_await]
+        for msg in upstream {
+            let msg = msg?;
+
+            match msg {
+                Message::Barrier(barrier) => {
+                    // Commit state tables but don't modify them
+                    state_table
+                        .commit_assert_no_update_vnode_bitmap(barrier.epoch)
+                        .await?;
+                    progress_table
+                        .commit_assert_no_update_vnode_bitmap(barrier.epoch)
+                        .await?;
+                    yield Message::Barrier(barrier);
+                }
+                _ => {
+                    // Forward all other messages directly
+                    yield msg;
+                }
+            }
+        }
+    }
+}
diff --git a/src/stream/src/executor/mod.rs b/src/stream/src/executor/mod.rs
index 28c8e77f34afd..ad519820c7114 100644
--- a/src/stream/src/executor/mod.rs
+++ b/src/stream/src/executor/mod.rs
@@ -93,6 +93,7 @@ mod filter;
 pub mod hash_join;
 mod hop_window;
 mod join;
+pub mod locality_provider;
 mod lookup;
 mod lookup_union;
 mod merge;
diff --git a/src/stream/src/from_proto/locality_provider.rs b/src/stream/src/from_proto/locality_provider.rs
new file mode 100644
index 0000000000000..7472705f8d34e
--- /dev/null
+++ b/src/stream/src/from_proto/locality_provider.rs
@@ -0,0 +1,86 @@
+// Copyright 2025 RisingWave Labs
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::sync::Arc;
+
+use risingwave_pb::stream_plan::LocalityProviderNode;
+use risingwave_storage::StateStore;
+
+use super::*;
+use crate::common::table::state_table::StateTableBuilder;
+use crate::executor::Executor;
+use crate::executor::locality_provider::LocalityProviderExecutor;
+
+impl ExecutorBuilder for LocalityProviderBuilder {
+    type Node = LocalityProviderNode;
+
+    async fn new_boxed_executor(
+        params: ExecutorParams,
+        node: &Self::Node,
+        store: impl StateStore,
+    ) -> StreamResult<Executor> {
+        let [input]: [_; 1] = params.input.try_into().unwrap();
+
+        let locality_columns = node
+            .locality_columns
+            .iter()
+            .map(|&i| i as usize)
+            .collect::<Vec<_>>();
+
+        let input_schema = input.schema().clone();
+
+        let vnodes = Some(Arc::new(
+            params
+                .vnode_bitmap
+                .expect("vnodes not set for locality provider"),
+        ));
+
+        // Create state table for buffering input data
+        let state_table = StateTableBuilder::new(
+            node.get_state_table().unwrap(),
+            store.clone(),
+            vnodes.clone(),
+        )
+        .enable_preload_all_rows_by_config(&params.actor_context.streaming_config)
+        .build()
+        .await;
+
+        // Create progress table for tracking backfill progress
+        let progress_table =
+            StateTableBuilder::new(node.get_progress_table().unwrap(), store, vnodes)
+                .enable_preload_all_rows_by_config(&params.actor_context.streaming_config)
+                .build()
+                .await;
+
+        let progress = params
+            .local_barrier_manager
+            .register_create_mview_progress(params.actor_context.id);
+
+        let exec = LocalityProviderExecutor::new(
+            input,
+            locality_columns,
+            state_table,
+            progress_table,
+            input_schema,
+            progress,
+            params.executor_stats.clone(),
+            params.env.config().developer.chunk_size,
+            params.actor_context.fragment_id,
+        );
+
+        Ok((params.info, exec).into())
+    }
+}
+
+pub struct LocalityProviderBuilder;
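Note on the progress-table layout: `persist_backfill_state` and `load_backfill_state` above have to agree on one row encoding per vnode — vnode, then the current-position columns, then `is_finished`, then `row_count`. The following is a minimal, dependency-free sketch of that encode/decode round trip for readers who want the layout in isolation; `Cell`, `encode_progress`, and `decode_progress` are illustrative names only and are not part of this patch.

// Sketch of the per-vnode progress row layout assumed by the executor above.
#[derive(Clone, Debug, PartialEq)]
enum Cell {
    Int(i64),
    Bool(bool),
    Null,
}

// Layout: vnode | current_pos (one cell per position column) | is_finished | row_count
fn encode_progress(vnode: i64, current_pos: &[Cell], is_finished: bool, row_count: u64) -> Vec<Cell> {
    let mut row = Vec::with_capacity(current_pos.len() + 3);
    row.push(Cell::Int(vnode));
    row.extend_from_slice(current_pos);
    row.push(Cell::Bool(is_finished));
    row.push(Cell::Int(row_count as i64));
    row
}

fn decode_progress(row: &[Cell]) -> (i64, Vec<Cell>, bool, u64) {
    // Mirrors load_backfill_state: flags sit in the last two columns.
    let finished_idx = row.len() - 2;
    let vnode = match row[0] {
        Cell::Int(v) => v,
        _ => panic!("vnode must be an integer"),
    };
    let is_finished = matches!(row[finished_idx], Cell::Bool(true));
    let row_count = match row[row.len() - 1] {
        Cell::Int(v) => v as u64,
        _ => 0,
    };
    let current_pos = row[1..finished_idx].to_vec();
    (vnode, current_pos, is_finished, row_count)
}

fn main() {
    let row = encode_progress(42, &[Cell::Int(7), Cell::Null], false, 1000);
    let (vnode, pos, finished, count) = decode_progress(&row);
    assert_eq!((vnode, finished, count), (42, false, 1000));
    assert_eq!(pos, vec![Cell::Int(7), Cell::Null]);
    println!("round-trip ok");
}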
diff --git a/src/stream/src/from_proto/mod.rs b/src/stream/src/from_proto/mod.rs
index 90f2eeb3a6eee..25c0bfa8d5169 100644
--- a/src/stream/src/from_proto/mod.rs
+++ b/src/stream/src/from_proto/mod.rs
@@ -30,6 +30,7 @@ mod group_top_n;
 mod hash_agg;
 mod hash_join;
 mod hop_window;
+mod locality_provider;
 mod lookup;
 mod lookup_union;
 mod materialized_exprs;
@@ -85,6 +86,7 @@ use self::group_top_n::GroupTopNExecutorBuilder;
 use self::hash_agg::*;
 use self::hash_join::*;
 use self::hop_window::*;
+use self::locality_provider::*;
 use self::lookup::*;
 use self::lookup_union::*;
 use self::materialized_exprs::MaterializedExprsExecutorBuilder;
@@ -201,5 +203,6 @@ pub async fn create_executor(
         NodeBody::MaterializedExprs => MaterializedExprsExecutorBuilder,
         NodeBody::VectorIndexWrite => VectorIndexWriteExecutorBuilder,
         NodeBody::UpstreamSinkUnion => UpstreamSinkUnionExecutorBuilder,
+        NodeBody::LocalityProvider => LocalityProviderBuilder,
     }
 }
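Taken together, the executor hinges on a small per-vnode state machine (NotStarted, InProgress, Completed) plus one invariant: an upstream row is forwarded only once the snapshot read has passed its key; otherwise it is held back and will be emitted by a later snapshot chunk. Below is a self-contained sketch of that invariant using plain `Vec<i64>` keys instead of the patch's `OwnedRow`/`Datum` types; `Progress`, `update`, and `is_visible` are illustrative stand-ins, not the types introduced here.

// Simplified model of the per-vnode backfill progress and chunk-marking predicate.
#[derive(Clone, Debug)]
enum Progress {
    NotStarted,
    InProgress { current_pos: Vec<i64>, processed_rows: u64 },
    Completed { total_rows: u64 },
}

impl Progress {
    /// Raise `current_pos` after emitting a snapshot chunk whose last key is `pos`.
    fn update(&mut self, pos: Vec<i64>, rows: u64) {
        let processed = match self {
            Progress::NotStarted => rows,
            Progress::InProgress { processed_rows, .. } => *processed_rows + rows,
            Progress::Completed { .. } => return, // already finished, nothing to raise
        };
        *self = Progress::InProgress { current_pos: pos, processed_rows: processed };
    }

    /// An upstream row is forwarded only if backfill has already covered its key.
    fn is_visible(&self, pk: &[i64]) -> bool {
        match self {
            Progress::Completed { .. } => true,
            Progress::NotStarted => false,
            Progress::InProgress { current_pos, .. } => pk <= current_pos.as_slice(),
        }
    }
}

fn main() {
    let mut p = Progress::NotStarted;
    assert!(!p.is_visible(&[5]));
    p.update(vec![10], 10); // snapshot has emitted keys up to 10
    assert!(p.is_visible(&[5]));
    assert!(!p.is_visible(&[11])); // not covered yet; a later snapshot chunk will emit it
    p = Progress::Completed { total_rows: 10 };
    assert!(p.is_visible(&[11]));
    println!("invariant holds");
}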