6 changes: 6 additions & 0 deletions proto/stream_plan.proto
@@ -992,6 +992,11 @@ message VectorIndexWriteNode {
catalog.Table table = 1;
}

message VectorIndexLookupJoinNode {
plan_common.VectorIndexReaderDesc reader_desc = 1;
uint32 vector_column_idx = 2;
}

message UpstreamSinkUnionNode {
// It is always empty in the persisted metadata, and gets filled before we spawn the actors.
// The actual upstream info may be added and removed dynamically at runtime.
@@ -1070,6 +1075,7 @@ message StreamNode {
VectorIndexWriteNode vector_index_write = 150;
UpstreamSinkUnionNode upstream_sink_union = 151;
LocalityProviderNode locality_provider = 152;
VectorIndexLookupJoinNode vector_index_lookup_join = 153;
}
// The id for the operator. This is local per mview.
// TODO: should better be a uint32.
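The new `VectorIndexLookupJoinNode` pairs a `plan_common.VectorIndexReaderDesc` — which, as populated in `StreamVectorIndexLookupJoin::to_stream_prost_body` below, carries the index table id, the lookup output columns (`info_column_desc` / `info_output_indices`), `top_n`, the distance type, the HNSW search parameter, and whether the distance is included — with `vector_column_idx`, the index of the input column that supplies the query vector.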
19 changes: 17 additions & 2 deletions src/frontend/planner_test/tests/testdata/input/vector_search.yaml
@@ -200,7 +200,7 @@
- id: create_correlated_tables_with_column_value_index
sql: |
create table items (id int primary key, name string, embedding vector(3)) append only;
-create table events (event_id int primary key, time timestamp, embedding vector(3));
+create table events (event_id int primary key, time timestamp, embedding vector(3)) append only;
create index i on items using flat (embedding) with (distance_type = 'l2');
expected_outputs: []
- before:
@@ -232,4 +232,19 @@
expected_outputs:
- logical_plan
- optimized_logical_plan_for_batch
- batch_plan
- before:
- create_correlated_tables_with_column_value_index
id: correlated_read_without_embedding
sql: |
select
event_id, array(
select row(id, name, distance)
from (select id, name, events.embedding <-> items.embedding as distance from items order by distance limit 3)
) as related_info,
time
from events;
expected_outputs:
- logical_plan
- optimized_logical_plan_for_stream
- stream_plan
30 changes: 27 additions & 3 deletions src/frontend/planner_test/tests/testdata/output/vector_search.yaml
@@ -434,7 +434,7 @@
- id: create_correlated_tables_with_column_value_index
sql: |
create table items (id int primary key, name string, embedding vector(3)) append only;
-create table events (event_id int primary key, time timestamp, embedding vector(3));
+create table events (event_id int primary key, time timestamp, embedding vector(3)) append only;
create index i on items using flat (embedding) with (distance_type = 'l2');
- id: correlated_read_without_embedding
before:
batch_plan: |-
BatchExchange { order: [], dist: Single }
└─BatchProject { exprs: [events.event_id, vector_info, events.time] }
-└─BatchVectorSearch { schema: [events.event_id:Int32, events.time:Timestamp, events.embedding:Vector(3), vector_info:List(Struct(StructType { fields: [("name", Varchar)] }))], top_n: 3, distance_type: L2Sqr, index_name: "i", vector: events.embedding }
+└─BatchVectorSearch { top_n: 3, distance_type: L2Sqr, index_name: "i", vector: events.embedding, lookup_output: [("name", Varchar)], include_distance: false }
└─BatchScan { table: events, columns: [events.event_id, events.time, events.embedding], distribution: Single }
- id: correlated_read_without_embedding
before:
batch_plan: |-
BatchExchange { order: [], dist: Single }
└─BatchProject { exprs: [events.event_id, vector_info, events.time] }
-└─BatchVectorSearch { schema: [events.event_id:Int32, events.time:Timestamp, events.embedding:Vector(3), vector_info:List(Struct(StructType { fields: [("id", Int32), ("name", Varchar), ("__distance", Float64)] }))], top_n: 3, distance_type: L2Sqr, index_name: "i", vector: events.embedding }
+└─BatchVectorSearch { top_n: 3, distance_type: L2Sqr, index_name: "i", vector: events.embedding, lookup_output: [("id", Int32), ("name", Varchar)], include_distance: true }
└─BatchScan { table: events, columns: [events.event_id, events.time, events.embedding], distribution: Single }
- id: correlated_read_without_embedding
before:
- create_correlated_tables_with_column_value_index
sql: "select \n event_id, array(\n select row(id, name, distance)\n from (select id, name, events.embedding <-> items.embedding as distance from items order by distance limit 3)\n ) as related_info, \n time\nfrom events;\n"
logical_plan: |-
LogicalProject { exprs: [events.event_id, $expr3, events.time] }
└─LogicalApply { type: LeftOuter, on: true, correlated_id: 1, max_one_row: true }
├─LogicalScan { table: events, columns: [events.event_id, events.time, events.embedding, events._rw_timestamp] }
└─LogicalProject { exprs: [Coalesce(array_agg($expr2), ARRAY[]:List(Struct(StructType { fields: [("f1", Int32), ("f2", Varchar), ("f3", Float64)] }))) as $expr3] }
└─LogicalAgg { aggs: [array_agg($expr2)] }
└─LogicalProject { exprs: [Row(items.id, items.name, $expr1) as $expr2] }
└─LogicalTopN { order: [$expr1 ASC], limit: 3, offset: 0 }
└─LogicalProject { exprs: [items.id, items.name, L2Distance(CorrelatedInputRef { index: 2, correlated_id: 1 }, items.embedding) as $expr1] }
└─LogicalScan { table: items, columns: [items.id, items.name, items.embedding, items._rw_timestamp] }
optimized_logical_plan_for_stream: |-
LogicalProject { exprs: [events.event_id, array, events.time] }
└─LogicalVectorSearchLookupJoin { distance_type: L2Sqr, top_n: 3, input_vector: events.embedding:Vector(3), lookup_vector: items.embedding, lookup_output_columns: [items.id:Int32, items.name:Varchar], include_distance: true }
├─LogicalScan { table: events, columns: [events.event_id, events.time, events.embedding] }
└─LogicalScan { table: items, columns: [items.id, items.name, items.embedding, items._rw_timestamp] }
stream_plan: |-
StreamMaterialize { columns: [event_id, related_info, time], stream_key: [event_id], pk_columns: [event_id], pk_conflict: NoCheck }
└─StreamProject { exprs: [events.event_id, vector_info, events.time] }
└─StreamVectorIndexLookupJoin { top_n: 3, distance_type: L2Sqr, index_name: "i", vector: events.embedding, lookup_output: [("id", Int32), ("name", Varchar)], include_distance: true }
└─StreamTableScan { table: events, columns: [events.event_id, events.time, events.embedding], stream_scan_type: ArrangementBackfill, stream_key: [events.event_id], pk: [event_id], dist: UpstreamHashShard(events.event_id) }
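Note the fixture change above: `events` must now be created `append only`, since `StreamVectorIndexLookupJoin` (in the new plan node file below) rejects non-append-only inputs at construction time.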
2 changes: 2 additions & 0 deletions src/frontend/src/optimizer/logical_optimization.rs
@@ -678,6 +678,8 @@ impl LogicalOptimizer {
// In order to unnest a table function, we need to convert it into a `project_set` first.
plan = plan.optimize_by_rules(&TABLE_FUNCTION_CONVERT)?;

plan = plan.optimize_by_rules(&CORRELATED_TOP_N_TO_VECTOR_SEARCH)?;

plan = Self::subquery_unnesting(plan, enable_share_plan, explain_trace, &ctx)?;
if has_logical_max_one_row(plan.clone()) {
// `MaxOneRow` is currently only used for the runtime check of
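The new rule is applied before `subquery_unnesting`, presumably so that `CORRELATED_TOP_N_TO_VECTOR_SEARCH` can match the correlated `LogicalApply` + TopN shape (visible in the logical plans above) before unnesting rewrites it.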
@@ -275,18 +275,39 @@ impl PredicatePushdown for LogicalVectorSearchLookupJoin {
impl ToStream for LogicalVectorSearchLookupJoin {
fn logical_rewrite_for_stream(
&self,
-_ctx: &mut RewriteStreamContext,
+ctx: &mut RewriteStreamContext,
) -> crate::error::Result<(PlanRef, ColIndexMapping)> {
bail!("LogicalVectorSearch can only for batch plan, not stream plan");
if !self
.core
.input
.logical_rewrite_for_stream(ctx)?
.1
.is_identity()
{
// TODO: support it
bail!(
"LogicalVectorSearchLookupJoin does not support input that can possibly be rewritten"
)
}
Ok((
self.clone().into(),
ColIndexMapping::identity(self.base.schema().len()),
))
}

-fn to_stream(&self, _ctx: &mut ToStreamContext) -> crate::error::Result<StreamPlanRef> {
-bail!("LogicalVectorSearch can only for batch plan, not stream plan");
+fn to_stream(&self, ctx: &mut ToStreamContext) -> crate::error::Result<StreamPlanRef> {
+if let Some(core) = self.to_vector_index_lookup_join(|plan| plan.to_stream(ctx))? {
+return Ok(StreamVectorIndexLookupJoin::new(core)?.into());
+}
+bail!("LogicalVectorSearchLookupJoin should use a proper vector index in a streaming job")
}
}

-impl ToBatch for LogicalVectorSearchLookupJoin {
-fn to_batch(&self) -> Result<BatchPlanRef> {
+impl LogicalVectorSearchLookupJoin {
+fn to_vector_index_lookup_join<PlanRef>(
+&self,
+gen_input: impl FnOnce(&LogicalPlanRef) -> Result<PlanRef>,
+) -> Result<Option<VectorIndexLookupJoin<PlanRef>>> {
if let Some(scan) = self.core.lookup.as_logical_scan()
&& let Some((
index,
@@ -319,7 +340,7 @@
})
.collect();
let core = VectorIndexLookupJoin {
-input: self.core.input.to_batch()?,
+input: gen_input(&self.core.input)?,
top_n: self.core.top_n,
distance_type: self.core.distance_type,
index_name: index.index_table.name.clone(),
hnsw_ef_search,
ctx: self.core.ctx(),
};
return Ok(Some(core));
}
Ok(None)
}
}

impl ToBatch for LogicalVectorSearchLookupJoin {
fn to_batch(&self) -> Result<BatchPlanRef> {
if let Some(core) = self.to_vector_index_lookup_join(|plan| plan.to_batch())? {
return Ok(BatchVectorSearch::with_core(core).into());
}
let todo = 0;
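The key refactor in this file: the index-matching and core-construction logic moves out of `ToBatch::to_batch` into `to_vector_index_lookup_join`, which is generic over the target plan type and takes a closure for lowering the input, so the batch and stream paths share it. A minimal sketch of that pattern, with hypothetical stand-in types rather than the actual planner types:

```rust
// Sketch of the "generic core + input-lowering closure" pattern.
// `Core<P>` mirrors `VectorIndexLookupJoin<PlanRef>`; everything else here
// is an illustrative stand-in.
struct Core<P> {
    input: P,
    top_n: u64,
}

struct BatchPlan(&'static str);
struct StreamPlan(&'static str);

struct LogicalNode {
    input: &'static str,
    top_n: u64,
}

impl LogicalNode {
    // One lowering routine, parameterized over the target plan type. The real
    // code returns Ok(None) when the lookup side has no usable vector index;
    // this sketch always succeeds.
    fn lower<P>(
        &self,
        gen_input: impl FnOnce(&'static str) -> Result<P, String>,
    ) -> Result<Option<Core<P>>, String> {
        Ok(Some(Core {
            input: gen_input(self.input)?,
            top_n: self.top_n,
        }))
    }

    fn to_batch(&self) -> Result<Option<Core<BatchPlan>>, String> {
        self.lower(|s| Ok(BatchPlan(s)))
    }

    fn to_stream(&self) -> Result<Option<Core<StreamPlan>>, String> {
        self.lower(|s| Ok(StreamPlan(s)))
    }
}

fn main() -> Result<(), String> {
    let node = LogicalNode { input: "scan(events)", top_n: 3 };
    let batch = node.to_batch()?.expect("no vector index found");
    let stream = node.to_stream()?.expect("no vector index found");
    println!("batch:  input={}, top_n={}", batch.input.0, batch.top_n);
    println!("stream: input={}, top_n={}", stream.input.0, stream.top_n);
    Ok(())
}
```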
3 changes: 3 additions & 0 deletions src/frontend/src/optimizer/plan_node/mod.rs
@@ -1089,6 +1089,7 @@ mod stream_share;
mod stream_temporal_join;
mod stream_union;
mod stream_upstream_sink_union;
mod stream_vector_index_lookup_join;
mod stream_vector_index_write;
pub mod utils;

@@ -1209,6 +1210,7 @@ pub use stream_topn::StreamTopN;
pub use stream_union::StreamUnion;
pub use stream_upstream_sink_union::StreamUpstreamSinkUnion;
pub use stream_values::StreamValues;
pub use stream_vector_index_lookup_join::StreamVectorIndexLookupJoin;
pub use stream_vector_index_write::StreamVectorIndexWrite;
pub use stream_watermark_filter::StreamWatermarkFilter;

@@ -1353,6 +1355,7 @@ macro_rules! for_all_plan_nodes {
, { Stream, SyncLogStore }
, { Stream, MaterializedExprs }
, { Stream, VectorIndexWrite }
, { Stream, VectorIndexLookupJoin }
, { Stream, UpstreamSinkUnion }
, { Stream, LocalityProvider }
$(,$rest)*
114 changes: 114 additions & 0 deletions src/frontend/src/optimizer/plan_node/stream_vector_index_lookup_join.rs
@@ -0,0 +1,114 @@
// Copyright 2025 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use pretty_xmlish::XmlNode;
use risingwave_common::bail;
use risingwave_pb::plan_common::PbVectorIndexReaderDesc;
use risingwave_pb::stream_plan::PbVectorIndexLookupJoinNode;

use crate::optimizer::plan_node::expr_visitable::ExprVisitable;
use crate::optimizer::plan_node::generic::{PhysicalPlanRef, VectorIndexLookupJoin};
use crate::optimizer::plan_node::stream::StreamPlanNodeMetadata;
use crate::optimizer::plan_node::utils::{Distill, childless_record};
use crate::optimizer::plan_node::{
ExprRewritable, PlanBase, PlanTreeNodeUnary, Stream, StreamNode, StreamPlanRef,
};
use crate::optimizer::property::StreamKind;
use crate::stream_fragmenter::BuildFragmentGraphState;

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct StreamVectorIndexLookupJoin {
pub base: PlanBase<Stream>,
pub core: VectorIndexLookupJoin<StreamPlanRef>,
}

impl StreamVectorIndexLookupJoin {
pub fn new(core: VectorIndexLookupJoin<StreamPlanRef>) -> crate::error::Result<Self> {
if core.input.stream_kind() != StreamKind::AppendOnly {
bail!("StreamVectorIndexLookupJoin only support append only input")
}
Ok(Self::with_core(core))
}

fn with_core(core: VectorIndexLookupJoin<StreamPlanRef>) -> Self {
assert_eq!(core.input.stream_kind(), StreamKind::AppendOnly);
let base = PlanBase::new_stream_with_core(
&core,
core.input.distribution().clone(),
core.input.stream_kind(),
core.input.emit_on_window_close(),
core.input.watermark_columns().clone(),
core.input.columns_monotonicity().clone(),
);
Self { base, core }
}
}

impl Distill for StreamVectorIndexLookupJoin {
fn distill<'a>(&self) -> XmlNode<'a> {
let fields = self.core.distill();
childless_record("StreamVectorIndexLookupJoin", fields)
}
}

impl PlanTreeNodeUnary<Stream> for StreamVectorIndexLookupJoin {
fn input(&self) -> crate::PlanRef<Stream> {
self.core.input.clone()
}

fn clone_with_input(&self, input: crate::PlanRef<Stream>) -> Self {
let mut core = self.core.clone();
core.input = input;
Self::with_core(core)
}
}

impl_plan_tree_node_for_unary!(Stream, StreamVectorIndexLookupJoin);

impl StreamNode for StreamVectorIndexLookupJoin {
fn to_stream_prost_body(
&self,
_state: &mut BuildFragmentGraphState,
) -> risingwave_pb::stream_plan::stream_node::NodeBody {
risingwave_pb::stream_plan::stream_node::NodeBody::VectorIndexLookupJoin(
PbVectorIndexLookupJoinNode {
reader_desc: Some(PbVectorIndexReaderDesc {
table_id: self.core.index_table_id.table_id,
info_column_desc: self
.core
.info_column_desc
.iter()
.map(|col| col.to_protobuf())
.collect(),
top_n: self.core.top_n as _,
distance_type: self.core.distance_type as _,
hnsw_ef_search: self.core.hnsw_ef_search.unwrap_or(0) as _,
info_output_indices: self
.core
.info_output_indices
.iter()
.map(|&idx| idx as _)
.collect(),
include_distance: self.core.include_distance,
}),
vector_column_idx: self.core.vector_column_idx as _,
}
.into(),
)
}
}

impl ExprVisitable for StreamVectorIndexLookupJoin {}

impl ExprRewritable<Stream> for StreamVectorIndexLookupJoin {}
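A note on the constructor split in this new node: `new` is fallible and rejects non-append-only inputs with an error when the plan is first built, while `with_core` only asserts the invariant and is used by `clone_with_input`, where it already holds.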
1 change: 1 addition & 0 deletions src/prost/build.rs
@@ -171,6 +171,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
.boxed(".stream_plan.StreamNode.node_body.materialized_exprs")
.boxed(".stream_plan.StreamNode.node_body.vector_index_write")
.boxed(".stream_plan.StreamNode.node_body.locality_provider")
.boxed(".stream_plan.StreamNode.node_body.vector_index_lookup_join")
// `Udf` is 248 bytes, while 2nd largest field is 32 bytes.
.boxed(".expr.ExprNode.rex_node.udf")
// Eq + Hash are for plan nodes to do common sub-plan detection.
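For context on `.boxed(...)`: every `node_body` variant lives inline in one generated enum, and an enum is as large as its largest variant, so oversized bodies are boxed to keep `StreamNode` small — the same reasoning as the existing `Udf` comment. A minimal, self-contained illustration (not prost-generated code; sizes assume a typical 64-bit target):

```rust
#[allow(dead_code)]
enum Unboxed {
    Small(u32),
    Big([u8; 248]), // comparable to the 248-byte `Udf` noted above
}

#[allow(dead_code)]
enum Boxed {
    Small(u32),
    Big(Box<[u8; 248]>), // payload moves to the heap; the variant is pointer-sized
}

fn main() {
    // The unboxed enum pays for its largest variant everywhere (~252 bytes);
    // boxing shrinks every value of the enum to tag + pointer (~16 bytes).
    println!("unboxed: {} bytes", std::mem::size_of::<Unboxed>());
    println!("boxed:   {} bytes", std::mem::size_of::<Boxed>());
}
```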
4 changes: 2 additions & 2 deletions src/stream/src/executor/mod.rs
@@ -130,7 +130,7 @@ mod sync_kv_log_store;
#[cfg(any(test, feature = "test"))]
pub mod test_utils;
mod utils;
-mod vector_index;
+mod vector;

pub use actor::{Actor, ActorContext, ActorContextRef};
use anyhow::Context;
@@ -181,7 +181,7 @@ pub use union::UnionExecutor;
pub use upstream_sink_union::{UpstreamFragmentInfo, UpstreamSinkUnionExecutor};
pub use utils::DummyExecutor;
pub use values::ValuesExecutor;
-pub use vector_index::VectorIndexWriteExecutor;
+pub use vector::*;
pub use watermark_filter::WatermarkFilterExecutor;
pub use wrapper::WrapperExecutor;

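The executor module is renamed from `vector_index` to `vector` and its re-export broadened to `pub use vector::*;`, suggesting the module now hosts more than `VectorIndexWriteExecutor` — presumably the runtime executor for the new lookup join, which is not part of this diff.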