diff --git a/e2e_test/vector_search/vector_nearest.slt.part b/e2e_test/vector_search/vector_nearest.slt.part index c20e006a76343..d3f992214bc68 100644 --- a/e2e_test/vector_search/vector_nearest.slt.part +++ b/e2e_test/vector_search/vector_nearest.slt.part @@ -88,6 +88,34 @@ select * from query_view order by distance; statement ok drop view query_view; +# test vector index lookup join on raw embedding column +# query T +# with input as (select '[3,2,1]'::vector(3) as embedding) select array(select row(id, text) from no_index_mv order by input.embedding <-> no_index_mv.embedding limit 2) as related_info from input; +# ---- +# {"(1,first)","(2,second)"} + +statement ok +create view query_view as with input as (select '[3,2,1]'::vector(3) as embedding) select array(select row(id, text) from items order by input.embedding <-> items.embedding limit 2) as related_info from input; + +# ensure that vector index is used +query T +explain(verbose) select * from query_view; +---- +BatchExchange { order: [], dist: Single } +└─BatchProject { exprs: [vector_info] } + └─BatchVectorSearch { top_n: 2, distance_type: L2Sqr, index_name: "i", vector: '[3,2,1]':Vector(3), lookup_output: [("items.id", Int32), ("text", Varchar)], include_distance: false } + └─BatchValues { rows: [['[3,2,1]':Vector(3)]] } + + + +query T +select * from query_view; +---- +{"(1,first)","(2,second)"} + +statement ok +drop view query_view; + statement ok drop index i; @@ -153,6 +181,33 @@ select * from query_view order by distance; statement ok drop view query_view; +# test vector index lookup join on functional embedding column +# query T +# with input as (select '[3,2,1]'::vector(3) as embedding) select array(select row(id, text) from no_index_mv order by input.embedding <-> get_embedding(no_index_mv.text) limit 2) as related_info from input; +# ---- +# {"(1,first)","(2,second)"} + +statement ok +create view query_view as with input as (select '[3,2,1]'::vector(3) as embedding) select array(select row(id, text) from items order by input.embedding <-> get_embedding(items.text) limit 2) as related_info from input; + +# ensure that vector index is used +query T +explain(verbose) select * from query_view; +---- +BatchExchange { order: [], dist: Single } +└─BatchProject { exprs: [vector_info] } + └─BatchVectorSearch { top_n: 2, distance_type: L2Sqr, index_name: "i", vector: '[3,2,1]':Vector(3), lookup_output: [("items.id", Int32), ("text", Varchar)], include_distance: false } + └─BatchValues { rows: [['[3,2,1]':Vector(3)]] } + + +query T +select * from query_view; +---- +{"(1,first)","(2,second)"} + +statement ok +drop view query_view; + statement ok drop index i; @@ -162,7 +217,7 @@ drop materialized view no_index_mv; statement ok drop table items; -# test flat index +# test hnsw index statement ok create table items (id int primary key, extra string, text string, embedding vector(3)) append only; diff --git a/src/common/src/catalog/schema.rs b/src/common/src/catalog/schema.rs index 4ea4aaeb059d5..bebe42439a288 100644 --- a/src/common/src/catalog/schema.rs +++ b/src/common/src/catalog/schema.rs @@ -198,7 +198,7 @@ impl Schema { } for (a, b) in self.fields.iter().zip_eq_fast(other.fields.iter()) { - if a.data_type != b.data_type { + if !a.data_type.equals_datatype(&b.data_type) { return false; } } diff --git a/src/frontend/planner_test/tests/testdata/input/vector_search.yaml b/src/frontend/planner_test/tests/testdata/input/vector_search.yaml index 9d75154d28b0e..c762b6c7d5506 100644 --- 
a/src/frontend/planner_test/tests/testdata/input/vector_search.yaml +++ b/src/frontend/planner_test/tests/testdata/input/vector_search.yaml @@ -143,6 +143,92 @@ - create_table_and_function_index sql: | SELECT id, name FROM items order by openai_embedding('{"model": "model"}'::jsonb, text)::vector(3) <#> '[3,1,2]' limit 5; + expected_outputs: + - logical_plan + - optimized_logical_plan_for_batch + - batch_plan +- id: create_correlated_tables + sql: | + create table items (id int primary key, name string, embedding vector(3)) append only; + create table events (event_id int primary key, time timestamp, embedding vector(3)); + expected_outputs: [] +- before: + - create_correlated_tables + id: correlated_read_without_embedding + sql: | + select + event_id, array( + select row(id, name) + from items + order by events.embedding <=> items.embedding + limit 3 + ) as related_info, + time + from events; + expected_outputs: + - logical_plan + - optimized_logical_plan_for_batch +- before: + - create_correlated_tables + id: correlated_read_with_embedding + sql: | + select + event_id, time, embedding, array( + select row(id, name) + from items + order by items.embedding <=> events.embedding + limit 3 + ) + as related_info from events; + expected_outputs: + - logical_plan + - optimized_logical_plan_for_batch +- before: + - create_correlated_tables + id: correlated_read_with_distance + sql: | + select + event_id, array( + select row(id, distance, name) + from (select id, name, events.embedding <=> items.embedding as distance from items order by distance limit 3) + ) as related_info, + time + from events; + expected_outputs: + - logical_plan + - optimized_logical_plan_for_batch +- id: create_correlated_tables_with_column_value_index + sql: | + create table items (id int primary key, name string, embedding vector(3)) append only; + create table events (event_id int primary key, time timestamp, embedding vector(3)); + create index i on items using flat (embedding) with (distance_type = 'l2'); + expected_outputs: [] +- before: + - create_correlated_tables_with_column_value_index + id: correlated_read_without_embedding + sql: | + select + event_id, array( + select row(name) + from (select name from items order by events.embedding <-> items.embedding limit 3) + ) as related_info, + time + from events; + expected_outputs: + - logical_plan + - optimized_logical_plan_for_batch + - batch_plan +- before: + - create_correlated_tables_with_column_value_index + id: correlated_read_without_embedding + sql: | + select + event_id, array( + select row(id, name, distance) + from (select id, name, events.embedding <-> items.embedding as distance from items order by distance limit 3) + ) as related_info, + time + from events; expected_outputs: - logical_plan - optimized_logical_plan_for_batch diff --git a/src/frontend/planner_test/tests/testdata/output/vector_search.yaml b/src/frontend/planner_test/tests/testdata/output/vector_search.yaml index 2bc2953d02792..1aa29ef2d40ec 100644 --- a/src/frontend/planner_test/tests/testdata/output/vector_search.yaml +++ b/src/frontend/planner_test/tests/testdata/output/vector_search.yaml @@ -362,3 +362,126 @@ └─BatchProjectSet { select_list: [Unnest($1)] } └─BatchVectorSearch { top_n: 5, distance_type: InnerProduct, index_name: "vector_index", vector: query_vector, lookup_output: [("name", Varchar), ("items.id", Int32)], include_distance: true } └─BatchValues { rows: [['[3,1,2]':Vector(3)]] } +- id: create_correlated_tables + sql: | + create table items (id int primary key, name string, 
embedding vector(3)) append only; + create table events (event_id int primary key, time timestamp, embedding vector(3)); +- id: correlated_read_without_embedding + before: + - create_correlated_tables + sql: "select \n event_id, array(\n select row(id, name)\n from items\n order by events.embedding <=> items.embedding\n limit 3\n ) as related_info, \n time\nfrom events;\n" + logical_plan: |- + LogicalProject { exprs: [events.event_id, $expr3, events.time] } + └─LogicalApply { type: LeftOuter, on: true, correlated_id: 1, max_one_row: true } + ├─LogicalScan { table: events, columns: [events.event_id, events.time, events.embedding, events._rw_timestamp] } + └─LogicalProject { exprs: [Coalesce(array_agg($expr1 order_by($expr2 ASC)), ARRAY[]:List(Struct(StructType { fields: [("f1", Int32), ("f2", Varchar)] }))) as $expr3] } + └─LogicalAgg { aggs: [array_agg($expr1 order_by($expr2 ASC))] } + └─LogicalTopN { order: [$expr2 ASC], limit: 3, offset: 0 } + └─LogicalProject { exprs: [Row(items.id, items.name) as $expr1, CosineDistance(CorrelatedInputRef { index: 2, correlated_id: 1 }, items.embedding) as $expr2] } + └─LogicalScan { table: items, columns: [items.id, items.name, items.embedding, items._rw_timestamp] } + optimized_logical_plan_for_batch: |- + LogicalProject { exprs: [events.event_id, array, events.time] } + └─LogicalVectorSearchLookupJoin { distance_type: Cosine, top_n: 3, input_vector: events.embedding:Vector(3), lookup_vector: items.embedding, lookup_output_columns: [items.id:Int32, items.name:Varchar], include_distance: false } + ├─LogicalScan { table: events, columns: [events.event_id, events.time, events.embedding] } + └─LogicalScan { table: items, columns: [items.id, items.name, items.embedding, items._rw_timestamp] } +- id: correlated_read_with_embedding + before: + - create_correlated_tables + sql: "select \n event_id, time, embedding, array(\n select row(id, name)\n from items\n order by items.embedding <=> events.embedding \n limit 3\n )\nas related_info from events;\n" + logical_plan: |- + LogicalProject { exprs: [events.event_id, events.time, events.embedding, $expr3] } + └─LogicalApply { type: LeftOuter, on: true, correlated_id: 1, max_one_row: true } + ├─LogicalScan { table: events, columns: [events.event_id, events.time, events.embedding, events._rw_timestamp] } + └─LogicalProject { exprs: [Coalesce(array_agg($expr1 order_by($expr2 ASC)), ARRAY[]:List(Struct(StructType { fields: [("f1", Int32), ("f2", Varchar)] }))) as $expr3] } + └─LogicalAgg { aggs: [array_agg($expr1 order_by($expr2 ASC))] } + └─LogicalTopN { order: [$expr2 ASC], limit: 3, offset: 0 } + └─LogicalProject { exprs: [Row(items.id, items.name) as $expr1, CosineDistance(items.embedding, CorrelatedInputRef { index: 2, correlated_id: 1 }) as $expr2] } + └─LogicalScan { table: items, columns: [items.id, items.name, items.embedding, items._rw_timestamp] } + optimized_logical_plan_for_batch: |- + LogicalProject { exprs: [events.event_id, events.time, events.embedding, array] } + └─LogicalVectorSearchLookupJoin { distance_type: Cosine, top_n: 3, input_vector: events.embedding:Vector(3), lookup_vector: items.embedding, lookup_output_columns: [items.id:Int32, items.name:Varchar], include_distance: false } + ├─LogicalScan { table: events, columns: [events.event_id, events.time, events.embedding] } + └─LogicalScan { table: items, columns: [items.id, items.name, items.embedding, items._rw_timestamp] } +- id: correlated_read_with_distance + before: + - create_correlated_tables + sql: "select \n event_id, array(\n select 
row(id, distance, name)\n from (select id, name, events.embedding <=> items.embedding as distance from items order by distance limit 3)\n ) as related_info, \n time\nfrom events;\n" + logical_plan: |- + LogicalProject { exprs: [events.event_id, $expr3, events.time] } + └─LogicalApply { type: LeftOuter, on: true, correlated_id: 1, max_one_row: true } + ├─LogicalScan { table: events, columns: [events.event_id, events.time, events.embedding, events._rw_timestamp] } + └─LogicalProject { exprs: [Coalesce(array_agg($expr2), ARRAY[]:List(Struct(StructType { fields: [("f1", Int32), ("f2", Float64), ("f3", Varchar)] }))) as $expr3] } + └─LogicalAgg { aggs: [array_agg($expr2)] } + └─LogicalProject { exprs: [Row(items.id, $expr1, items.name) as $expr2] } + └─LogicalTopN { order: [$expr1 ASC], limit: 3, offset: 0 } + └─LogicalProject { exprs: [items.id, items.name, CosineDistance(CorrelatedInputRef { index: 2, correlated_id: 1 }, items.embedding) as $expr1] } + └─LogicalScan { table: items, columns: [items.id, items.name, items.embedding, items._rw_timestamp] } + optimized_logical_plan_for_batch: |- + LogicalJoin { type: LeftOuter, on: IsNotDistinctFrom(events.embedding, events.embedding), output: [events.event_id, $expr3, events.time] } + ├─LogicalScan { table: events, columns: [events.event_id, events.time, events.embedding] } + └─LogicalProject { exprs: [events.embedding, Coalesce(array_agg($expr2) filter(IsNotNull(1:Int32)), ARRAY[]:List(Struct(StructType { fields: [("f1", Int32), ("f2", Float64), ("f3", Varchar)] }))) as $expr3] } + └─LogicalAgg { group_key: [events.embedding], aggs: [array_agg($expr2) filter(IsNotNull(1:Int32))] } + └─LogicalJoin { type: LeftOuter, on: IsNotDistinctFrom(events.embedding, events.embedding), output: [events.embedding, $expr2, 1:Int32] } + ├─LogicalAgg { group_key: [events.embedding], aggs: [] } + │ └─LogicalScan { table: events, columns: [events.embedding] } + └─LogicalProject { exprs: [events.embedding, Row(items.id, $expr1, items.name) as $expr2, 1:Int32] } + └─LogicalTopN { order: [$expr1 ASC], limit: 3, offset: 0, group_key: [events.embedding] } + └─LogicalProject { exprs: [events.embedding, items.id, items.name, CosineDistance(events.embedding, items.embedding) as $expr1] } + └─LogicalJoin { type: Inner, on: true, output: all } + ├─LogicalAgg { group_key: [events.embedding], aggs: [] } + │ └─LogicalScan { table: events, columns: [events.embedding] } + └─LogicalScan { table: items, columns: [items.id, items.name, items.embedding] } +- id: create_correlated_tables_with_column_value_index + sql: | + create table items (id int primary key, name string, embedding vector(3)) append only; + create table events (event_id int primary key, time timestamp, embedding vector(3)); + create index i on items using flat (embedding) with (distance_type = 'l2'); +- id: correlated_read_without_embedding + before: + - create_correlated_tables_with_column_value_index + sql: "select \n event_id, array(\n select row(name)\n from (select name from items order by events.embedding <-> items.embedding limit 3)\n ) as related_info, \n time\nfrom events;\n" + logical_plan: |- + LogicalProject { exprs: [events.event_id, $expr3, events.time] } + └─LogicalApply { type: LeftOuter, on: true, correlated_id: 1, max_one_row: true } + ├─LogicalScan { table: events, columns: [events.event_id, events.time, events.embedding, events._rw_timestamp] } + └─LogicalProject { exprs: [Coalesce(array_agg($expr2), ARRAY[]:List(Struct(StructType { fields: [("f1", Varchar)] }))) as $expr3] } + └─LogicalAgg { 
aggs: [array_agg($expr2)] } + └─LogicalProject { exprs: [Row(items.name) as $expr2] } + └─LogicalProject { exprs: [items.name] } + └─LogicalTopN { order: [$expr1 ASC], limit: 3, offset: 0 } + └─LogicalProject { exprs: [items.name, L2Distance(CorrelatedInputRef { index: 2, correlated_id: 1 }, items.embedding) as $expr1] } + └─LogicalScan { table: items, columns: [items.id, items.name, items.embedding, items._rw_timestamp] } + optimized_logical_plan_for_batch: |- + LogicalProject { exprs: [events.event_id, array, events.time] } + └─LogicalVectorSearchLookupJoin { distance_type: L2Sqr, top_n: 3, input_vector: events.embedding:Vector(3), lookup_vector: items.embedding, lookup_output_columns: [items.name:Varchar], include_distance: false } + ├─LogicalScan { table: events, columns: [events.event_id, events.time, events.embedding] } + └─LogicalScan { table: items, columns: [items.id, items.name, items.embedding, items._rw_timestamp] } + batch_plan: |- + BatchExchange { order: [], dist: Single } + └─BatchProject { exprs: [events.event_id, vector_info, events.time] } + └─BatchVectorSearch { schema: [events.event_id:Int32, events.time:Timestamp, events.embedding:Vector(3), vector_info:List(Struct(StructType { fields: [("name", Varchar)] }))], top_n: 3, distance_type: L2Sqr, index_name: "i", vector: events.embedding } + └─BatchScan { table: events, columns: [events.event_id, events.time, events.embedding], distribution: Single } +- id: correlated_read_without_embedding + before: + - create_correlated_tables_with_column_value_index + sql: "select \n event_id, array(\n select row(id, name, distance)\n from (select id, name, events.embedding <-> items.embedding as distance from items order by distance limit 3)\n ) as related_info, \n time\nfrom events;\n" + logical_plan: |- + LogicalProject { exprs: [events.event_id, $expr3, events.time] } + └─LogicalApply { type: LeftOuter, on: true, correlated_id: 1, max_one_row: true } + ├─LogicalScan { table: events, columns: [events.event_id, events.time, events.embedding, events._rw_timestamp] } + └─LogicalProject { exprs: [Coalesce(array_agg($expr2), ARRAY[]:List(Struct(StructType { fields: [("f1", Int32), ("f2", Varchar), ("f3", Float64)] }))) as $expr3] } + └─LogicalAgg { aggs: [array_agg($expr2)] } + └─LogicalProject { exprs: [Row(items.id, items.name, $expr1) as $expr2] } + └─LogicalTopN { order: [$expr1 ASC], limit: 3, offset: 0 } + └─LogicalProject { exprs: [items.id, items.name, L2Distance(CorrelatedInputRef { index: 2, correlated_id: 1 }, items.embedding) as $expr1] } + └─LogicalScan { table: items, columns: [items.id, items.name, items.embedding, items._rw_timestamp] } + optimized_logical_plan_for_batch: |- + LogicalProject { exprs: [events.event_id, array, events.time] } + └─LogicalVectorSearchLookupJoin { distance_type: L2Sqr, top_n: 3, input_vector: events.embedding:Vector(3), lookup_vector: items.embedding, lookup_output_columns: [items.id:Int32, items.name:Varchar], include_distance: true } + ├─LogicalScan { table: events, columns: [events.event_id, events.time, events.embedding] } + └─LogicalScan { table: items, columns: [items.id, items.name, items.embedding, items._rw_timestamp] } + batch_plan: |- + BatchExchange { order: [], dist: Single } + └─BatchProject { exprs: [events.event_id, vector_info, events.time] } + └─BatchVectorSearch { schema: [events.event_id:Int32, events.time:Timestamp, events.embedding:Vector(3), vector_info:List(Struct(StructType { fields: [("id", Int32), ("name", Varchar), ("__distance", Float64)] }))], top_n: 3, 
distance_type: L2Sqr, index_name: "i", vector: events.embedding } + └─BatchScan { table: events, columns: [events.event_id, events.time, events.embedding], distribution: Single } diff --git a/src/frontend/src/handler/create_index.rs b/src/frontend/src/handler/create_index.rs index e4d420a03c22c..3f4f985494a4c 100644 --- a/src/frontend/src/handler/create_index.rs +++ b/src/frontend/src/handler/create_index.rs @@ -22,7 +22,6 @@ use itertools::Itertools; use pgwire::pg_response::{PgResponse, StatementType}; use risingwave_common::catalog::{IndexId, TableId}; use risingwave_common::types::DataType; -use risingwave_common::util::recursive::{Recurse, tracker}; use risingwave_common::util::sort_util::{ColumnOrder, OrderType}; use risingwave_pb::catalog::{ CreateType, PbFlatIndexConfig, PbHnswFlatIndexConfig, PbIndex, PbIndexColumnProperties, @@ -39,7 +38,7 @@ use crate::binder::Binder; use crate::catalog::root_catalog::SchemaPath; use crate::catalog::{DatabaseId, SchemaId}; use crate::error::{ErrorCode, Result}; -use crate::expr::{Expr, ExprImpl, ExprRewriter, ExprVisitor, InputRef, is_impure_func_call}; +use crate::expr::{Expr, ExprImpl, ExprRewriter, InputRef}; use crate::handler::HandlerArgs; use crate::optimizer::plan_expr_rewriter::ConstEvalRewriter; use crate::optimizer::plan_node::utils::plan_can_use_background_ddl; @@ -50,8 +49,8 @@ use crate::optimizer::property::{Distribution, Order, RequiredDist}; use crate::optimizer::{LogicalPlanRoot, OptimizerContext, OptimizerContextRef, PlanRoot}; use crate::scheduler::streaming_manager::CreatingStreamingJobInfo; use crate::session::SessionImpl; -use crate::session::current::notice_to_user; use crate::stream_fragmenter::{GraphJobType, build_graph}; +use crate::utils::IndexColumnExprValidator; pub(crate) fn resolve_index_schema( session: &SessionImpl, @@ -73,66 +72,6 @@ pub(crate) fn resolve_index_schema( Ok((schema_name.to_owned(), table.clone(), index_table_name)) } -struct IndexColumnExprValidator { - allow_impure: bool, - result: Result<()>, -} - -impl IndexColumnExprValidator { - fn unsupported_expr_err(expr: &ExprImpl) -> ErrorCode { - ErrorCode::NotSupported( - format!("unsupported index column expression type: {:?}", expr), - "use columns or expressions instead".into(), - ) - } - - fn validate(expr: &ExprImpl, allow_impure: bool) -> Result<()> { - match expr { - ExprImpl::InputRef(_) | ExprImpl::FunctionCall(_) => {} - other_expr => { - return Err(Self::unsupported_expr_err(other_expr).into()); - } - } - let mut visitor = Self { - allow_impure, - result: Ok(()), - }; - visitor.visit_expr(expr); - visitor.result - } -} - -impl ExprVisitor for IndexColumnExprValidator { - fn visit_expr(&mut self, expr: &ExprImpl) { - if self.result.is_err() { - return; - } - tracker!().recurse(|t| { - if t.depth_reaches(crate::expr::EXPR_DEPTH_THRESHOLD) { - notice_to_user(crate::expr::EXPR_TOO_DEEP_NOTICE); - } - - match expr { - ExprImpl::InputRef(_) | ExprImpl::Literal(_) => {} - ExprImpl::FunctionCall(inner) => { - if !self.allow_impure && is_impure_func_call(inner) { - self.result = Err(ErrorCode::NotSupported( - "this expression is impure".into(), - "use a pure expression instead".into(), - ) - .into()); - return; - } - self.visit_function_call(inner) - } - other_expr => { - self.result = Err(Self::unsupported_expr_err(other_expr).into()); - } - } - }) - } -} - pub(crate) fn gen_create_index_plan( session: &SessionImpl, context: OptimizerContextRef, diff --git a/src/frontend/src/optimizer/logical_optimization.rs 
b/src/frontend/src/optimizer/logical_optimization.rs index a930f9c2e8f48..d0aa5f5bc3eec 100644 --- a/src/frontend/src/optimizer/logical_optimization.rs +++ b/src/frontend/src/optimizer/logical_optimization.rs @@ -524,6 +524,14 @@ static TOP_N_TO_VECTOR_SEARCH: LazyLock = LazyLock::new(|| { ) }); +static CORRELATED_TOP_N_TO_VECTOR_SEARCH: LazyLock = LazyLock::new(|| { + OptimizationStage::new( + "Correlated TopN to Vector Search", + vec![CorrelatedTopNToVectorSearchRule::create()], + ApplyOrder::BottomUp, + ) +}); + impl LogicalOptimizer { pub fn predicate_pushdown( plan: LogicalPlanRef, @@ -784,7 +792,7 @@ impl LogicalOptimizer { // In order to unnest a table function, we need to convert it into a `project_set` first. plan = plan.optimize_by_rules(&TABLE_FUNCTION_CONVERT)?; - plan = plan.optimize_by_rules(&TOP_N_TO_VECTOR_SEARCH)?; + plan = plan.optimize_by_rules(&CORRELATED_TOP_N_TO_VECTOR_SEARCH)?; plan = Self::subquery_unnesting(plan, false, explain_trace, &ctx)?; @@ -834,6 +842,8 @@ impl LogicalOptimizer { plan = plan.optimize_by_rules(&JOIN_COMMUTE)?; + plan = plan.optimize_by_rules(&TOP_N_TO_VECTOR_SEARCH)?; + // Do a final column pruning and predicate pushing down to clean up the plan. plan = Self::column_pruning(plan, explain_trace, &ctx); if last_total_rule_applied_before_predicate_pushdown != ctx.total_rule_applied() { diff --git a/src/frontend/src/optimizer/plan_node/logical_vector_search.rs b/src/frontend/src/optimizer/plan_node/logical_vector_search.rs index 20a3c9a81b7ad..28b7e084f5a2d 100644 --- a/src/frontend/src/optimizer/plan_node/logical_vector_search.rs +++ b/src/frontend/src/optimizer/plan_node/logical_vector_search.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::sync::Arc; + use itertools::Itertools; use pretty_xmlish::{Pretty, XmlNode}; use risingwave_common::array::VECTOR_DISTANCE_TYPE; @@ -26,6 +28,7 @@ use risingwave_pb::common::PbDistanceType; use risingwave_pb::plan_common::JoinType; use crate::OptimizerContextRef; +use crate::catalog::index_catalog::VectorIndex; use crate::error::ErrorCode; use crate::expr::{ Expr, ExprImpl, ExprRewriter, ExprType, ExprVisitor, FunctionCall, InputRef, Literal, @@ -124,9 +127,9 @@ impl LogicalVectorSearch { distance_type: PbDistanceType, left: ExprImpl, right: ExprImpl, - output_indices: Vec, input: PlanRef, ) -> Self { + let output_indices = (0..input.schema().len()).collect(); let core = VectorSearchCore { top_n, distance_type, @@ -142,10 +145,6 @@ impl LogicalVectorSearch { let base = PlanBase::new_logical_with_core(&core); Self { base, core } } - - pub(crate) fn i2o_mapping(&self) -> ColIndexMapping { - self.core.i2o_mapping() - } } impl_plan_tree_node_for_unary! 
{ Logical, LogicalVectorSearch } @@ -360,17 +359,24 @@ impl LogicalVectorSearch { } } -impl ToBatch for LogicalVectorSearch { - fn to_batch(&self) -> crate::error::Result { - if let Some((scan, vector_expr, vector_column_expr)) = self.as_vector_table_scan() - && !scan.vector_indexes().is_empty() - && self - .core - .ctx() - .session_ctx() - .config() - .enable_index_selection() - { +impl LogicalVectorSearch { + #[expect(clippy::type_complexity)] + pub fn resolve_vector_index_lookup<'a>( + scan: &'a LogicalScan, + vector_column_expr: &ExprImpl, + distance_type: PbDistanceType, + output_indices: &[usize], + ) -> Option<( + &'a Arc, + Vec, + Vec, + Vec<(bool, usize)>, + )> { + if !scan.vector_indexes().is_empty() { + let primary_table_cols_idx = output_indices + .iter() + .map(|input_idx| scan.output_col_idx()[*input_idx]) + .collect_vec(); for index in scan.vector_indexes() { if !Self::is_matched_vector_column_expr( &index.vector_expr, @@ -379,16 +385,10 @@ impl ToBatch for LogicalVectorSearch { ) { continue; } - if index.vector_index_info.distance_type() != self.core.distance_type { + if index.vector_index_info.distance_type() != distance_type { continue; } - let primary_table_cols_idx = self - .core - .output_indices - .iter() - .map(|input_idx| scan.output_col_idx()[*input_idx]) - .collect_vec(); let mut covered_table_cols_idx = Vec::new(); let mut non_covered_table_cols_idx = Vec::new(); let mut primary_table_col_in_output = @@ -405,6 +405,40 @@ impl ToBatch for LogicalVectorSearch { non_covered_table_cols_idx.push(*table_col_idx); } } + return Some(( + index, + covered_table_cols_idx, + non_covered_table_cols_idx, + primary_table_col_in_output, + )); + } + } + None + } +} + +impl ToBatch for LogicalVectorSearch { + fn to_batch(&self) -> Result { + if let Some((scan, vector_expr, vector_column_expr)) = self.as_vector_table_scan() + && self + .core + .ctx() + .session_ctx() + .config() + .enable_index_selection() + && let Some(( + index, + covered_table_cols_idx, + non_covered_table_cols_idx, + primary_table_col_in_output, + )) = Self::resolve_vector_index_lookup( + scan, + vector_column_expr, + self.core.distance_type, + &self.core.output_indices, + ) + { + { let vector_data_type = vector_expr.return_type(); let literal_vector_input = BatchValues::new(LogicalValues::new( vec![vec![vector_expr]], diff --git a/src/frontend/src/optimizer/plan_node/logical_vector_search_lookup_join.rs b/src/frontend/src/optimizer/plan_node/logical_vector_search_lookup_join.rs new file mode 100644 index 0000000000000..e9dfa6e472f00 --- /dev/null +++ b/src/frontend/src/optimizer/plan_node/logical_vector_search_lookup_join.rs @@ -0,0 +1,344 @@ +// Copyright 2025 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use pretty_xmlish::{Pretty, XmlNode}; +use risingwave_common::array::VECTOR_DISTANCE_TYPE; +use risingwave_common::bail; +use risingwave_common::catalog::{Field, Schema}; +use risingwave_common::types::{DataType, StructType}; +use risingwave_common::util::column_index_mapping::ColIndexMapping; +use risingwave_pb::catalog::vector_index_info; +use risingwave_pb::common::PbDistanceType; + +use crate::OptimizerContextRef; +use crate::expr::{ExprDisplay, ExprImpl}; +use crate::optimizer::plan_node::expr_visitable::ExprVisitable; +use crate::optimizer::plan_node::generic::{ + GenericPlanNode, GenericPlanRef, VectorIndexLookupJoin, ensure_sorted_required_cols, +}; +use crate::optimizer::plan_node::utils::{Distill, childless_record}; +use crate::optimizer::plan_node::{LogicalPlanRef as PlanRef, *}; +use crate::optimizer::property::FunctionalDependencySet; +use crate::utils::Condition; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct VectorSearchLookupJoinCore { + top_n: u64, + distance_type: PbDistanceType, + + input: PlanRef, + input_vector_col_idx: usize, + lookup: PlanRef, + lookup_vector: ExprImpl, + + /// The indices of lookup that will be included in the output. + /// The index of distance column is `lookup_output_indices.len()` + lookup_output_indices: Vec, + include_distance: bool, +} + +impl VectorSearchLookupJoinCore { + pub(crate) fn clone_with_input(&self, input: PlanRef, lookup: PlanRef) -> Self { + Self { + top_n: self.top_n, + distance_type: self.distance_type, + input, + input_vector_col_idx: self.input_vector_col_idx, + lookup, + lookup_vector: self.lookup_vector.clone(), + lookup_output_indices: self.lookup_output_indices.clone(), + include_distance: self.include_distance, + } + } +} + +impl GenericPlanNode for VectorSearchLookupJoinCore { + fn functional_dependency(&self) -> FunctionalDependencySet { + // TODO: include dependency of array_agg column + FunctionalDependencySet::new(self.input.schema().len() + 1) + } + + fn schema(&self) -> Schema { + let fields = self + .input + .schema() + .fields + .iter() + .cloned() + .chain([Field::new( + "array", + DataType::Struct(StructType::new( + self.lookup_output_indices + .iter() + .map(|i| { + let field = &self.lookup.schema().fields[*i]; + (field.name.clone(), field.data_type.clone()) + }) + .chain( + self.include_distance + .then(|| ("vector_distance".to_owned(), VECTOR_DISTANCE_TYPE)), + ), + )) + .list(), + )]) + .collect(); + + Schema { fields } + } + + fn stream_key(&self) -> Option> { + self.input.stream_key().map(|key| key.to_vec()) + } + + fn ctx(&self) -> OptimizerContextRef { + self.input.ctx() + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct LogicalVectorSearchLookupJoin { + pub base: PlanBase, + core: VectorSearchLookupJoinCore, +} + +impl LogicalVectorSearchLookupJoin { + pub(crate) fn new( + top_n: u64, + distance_type: PbDistanceType, + input: PlanRef, + input_vector_col_idx: usize, + lookup: PlanRef, + lookup_vector: ExprImpl, + lookup_output_indices: Vec, + include_distance: bool, + ) -> Self { + let core = VectorSearchLookupJoinCore { + top_n, + distance_type, + input, + input_vector_col_idx, + lookup, + lookup_vector, + lookup_output_indices, + include_distance, + }; + Self::with_core(core) + } + + fn with_core(core: VectorSearchLookupJoinCore) -> Self { + let base = PlanBase::new_logical_with_core(&core); + Self { base, core } + } +} + +impl_plan_tree_node_for_binary! 
{ Logical, LogicalVectorSearchLookupJoin } + +impl PlanTreeNodeBinary for LogicalVectorSearchLookupJoin { + fn left(&self) -> PlanRef { + self.core.input.clone() + } + + fn right(&self) -> PlanRef { + self.core.lookup.clone() + } + + fn clone_with_left_right(&self, left: PlanRef, right: PlanRef) -> Self { + let core = self.core.clone_with_input(left, right); + Self::with_core(core) + } +} + +impl Distill for LogicalVectorSearchLookupJoin { + fn distill<'a>(&self) -> XmlNode<'a> { + let verbose = self.base.ctx().is_explain_verbose(); + let mut vec = Vec::with_capacity(if verbose { 4 } else { 6 }); + vec.push(("distance_type", Pretty::debug(&self.core.distance_type))); + vec.push(("top_n", Pretty::debug(&self.core.top_n))); + vec.push(( + "input_vector", + Pretty::debug(&self.core.input.schema()[self.core.input_vector_col_idx]), + )); + + vec.push(( + "lookup_vector", + Pretty::debug(&ExprDisplay { + expr: &self.core.lookup_vector, + input_schema: self.core.lookup.schema(), + }), + )); + + if verbose { + vec.push(( + "lookup_output_columns", + Pretty::Array( + self.core + .lookup_output_indices + .iter() + .map(|input_idx| { + Pretty::debug(&self.core.lookup.schema().fields()[*input_idx]) + }) + .collect(), + ), + )); + vec.push(( + "include_distance", + Pretty::debug(&self.core.include_distance), + )); + } + + childless_record("LogicalVectorSearchLookupJoin", vec) + } +} + +impl ColPrunable for LogicalVectorSearchLookupJoin { + fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef { + let (project_exprs, mut required_cols) = + ensure_sorted_required_cols(required_cols, self.base.schema()); + assert!(required_cols.is_sorted()); + if let Some(last_col) = required_cols.last() + && *last_col == self.core.input.schema().len() + { + // pop the array_agg column, since we only prune base input + required_cols.pop(); + let output_vector = required_cols.contains(&self.core.input_vector_col_idx); + if !output_vector { + // include vector column in the input + required_cols.push(self.core.input_vector_col_idx); + } + + let new_input = self.core.input.prune_col(&required_cols, ctx); + let mut core = self + .core + .clone_with_input(new_input, self.core.lookup.clone()); + + core.input_vector_col_idx = ColIndexMapping::with_remaining_columns( + &required_cols, + self.core.input.schema().len(), + ) + .map(self.core.input_vector_col_idx); + let vector_search = Self::with_core(core).into(); + let input = if output_vector { + vector_search + } else { + // prune the vector column in the end of input, and include the array_agg column + LogicalProject::with_out_col_idx( + vector_search, + (0..required_cols.len() - 1).chain([required_cols.len()]), + ) + .into() + }; + + LogicalProject::create(input, project_exprs) + } else { + // the array_agg column is pruned, no need to lookup + let input = self.core.input.prune_col(&required_cols, ctx); + LogicalProject::create(input, project_exprs) + } + } +} + +impl ExprRewritable for LogicalVectorSearchLookupJoin {} + +impl ExprVisitable for LogicalVectorSearchLookupJoin {} + +impl PredicatePushdown for LogicalVectorSearchLookupJoin { + fn predicate_pushdown( + &self, + predicate: Condition, + ctx: &mut PredicatePushdownContext, + ) -> PlanRef { + // TODO: push down to input when possible + let input = self + .core + .input + .predicate_pushdown(Condition::true_cond(), ctx); + let lookup = self + .core + .lookup + .predicate_pushdown(Condition::true_cond(), ctx); + let core = self.core.clone_with_input(input, lookup); + 
LogicalFilter::create(Self::with_core(core).into(), predicate) + } +} + +impl ToStream for LogicalVectorSearchLookupJoin { + fn logical_rewrite_for_stream( + &self, + _ctx: &mut RewriteStreamContext, + ) -> crate::error::Result<(PlanRef, ColIndexMapping)> { + bail!("LogicalVectorSearch can only for batch plan, not stream plan"); + } + + fn to_stream(&self, _ctx: &mut ToStreamContext) -> crate::error::Result { + bail!("LogicalVectorSearch can only for batch plan, not stream plan"); + } +} + +impl ToBatch for LogicalVectorSearchLookupJoin { + fn to_batch(&self) -> Result { + if let Some(scan) = self.core.lookup.as_logical_scan() + && let Some(( + index, + _covered_table_cols_idx, + non_covered_table_cols_idx, + primary_table_col_in_output, + )) = LogicalVectorSearch::resolve_vector_index_lookup( + scan, + &self.core.lookup_vector, + self.core.distance_type, + &self.core.lookup_output_indices, + ) + && non_covered_table_cols_idx.is_empty() + { + let hnsw_ef_search = match index.vector_index_info.config.as_ref().unwrap() { + vector_index_info::Config::Flat(_) => None, + vector_index_info::Config::HnswFlat(_) => Some( + self.core + .ctx() + .session_ctx() + .config() + .batch_hnsw_ef_search(), + ), + }; + let info_output_indices = primary_table_col_in_output + .iter() + .map(|(covered, idx_in_index_info_columns)| { + assert!(*covered); + *idx_in_index_info_columns + }) + .collect(); + let core = VectorIndexLookupJoin { + input: self.core.input.to_batch()?, + top_n: self.core.top_n, + distance_type: self.core.distance_type, + index_name: index.index_table.name.clone(), + index_table_id: index.index_table.id, + info_column_desc: index.info_column_desc(), + info_output_indices, + include_distance: self.core.include_distance, + vector_column_idx: self.core.input_vector_col_idx, + hnsw_ef_search, + ctx: self.core.ctx(), + }; + return Ok(BatchVectorSearch::with_core(core).into()); + } + let todo = 0; + Ok(BatchValues::new(LogicalValues::new( + vec![], + self.base.schema().clone(), + self.base.ctx(), + )) + .into()) + } +} diff --git a/src/frontend/src/optimizer/plan_node/mod.rs b/src/frontend/src/optimizer/plan_node/mod.rs index 5dcda3df37ed6..eab5e7cb97cfe 100644 --- a/src/frontend/src/optimizer/plan_node/mod.rs +++ b/src/frontend/src/optimizer/plan_node/mod.rs @@ -1083,6 +1083,7 @@ mod logical_postgres_query; mod batch_vector_search; mod logical_mysql_query; mod logical_vector_search; +mod logical_vector_search_lookup_join; mod stream_cdc_table_scan; mod stream_share; mod stream_temporal_join; @@ -1165,6 +1166,7 @@ pub use logical_union::LogicalUnion; pub use logical_update::LogicalUpdate; pub use logical_values::LogicalValues; pub use logical_vector_search::LogicalVectorSearch; +pub use logical_vector_search_lookup_join::LogicalVectorSearchLookupJoin; pub use stream_asof_join::StreamAsOfJoin; pub use stream_cdc_table_scan::StreamCdcTableScan; pub use stream_changelog::StreamChangeLog; @@ -1274,6 +1276,7 @@ macro_rules! 
for_all_plan_nodes { , { Logical, VectorSearch } , { Logical, GetChannelDeltaStats } , { Logical, LocalityProvider } + , { Logical, VectorSearchLookupJoin } , { Batch, SimpleAgg } , { Batch, HashAgg } , { Batch, SortAgg } diff --git a/src/frontend/src/optimizer/plan_visitor/input_ref_validator.rs b/src/frontend/src/optimizer/plan_visitor/input_ref_validator.rs index d223df67f6d23..37389d1791b1b 100644 --- a/src/frontend/src/optimizer/plan_visitor/input_ref_validator.rs +++ b/src/frontend/src/optimizer/plan_visitor/input_ref_validator.rs @@ -28,7 +28,10 @@ struct ExprVis<'a> { impl ExprVisitor for ExprVis<'_> { fn visit_input_ref(&mut self, input_ref: &crate::expr::InputRef) { - if input_ref.data_type != self.schema[input_ref.index].data_type { + if !input_ref + .data_type + .equals_datatype(&self.schema[input_ref.index].data_type) + { self.string.replace(format!( "InputRef#{} has type {}, but its type is {} in the input schema", input_ref.index, input_ref.data_type, self.schema[input_ref.index].data_type diff --git a/src/frontend/src/optimizer/rule/mod.rs b/src/frontend/src/optimizer/rule/mod.rs index 4d6e6ca2f11d5..c069414088a1c 100644 --- a/src/frontend/src/optimizer/rule/mod.rs +++ b/src/frontend/src/optimizer/rule/mod.rs @@ -368,6 +368,7 @@ macro_rules! for_all_rules { , { AddLogstoreRule } , { EmptyAggRemoveRule } , { TopNToVectorSearchRule } + , { CorrelatedTopNToVectorSearchRule } } }; } diff --git a/src/frontend/src/optimizer/rule/top_n_to_vector_search_rule.rs b/src/frontend/src/optimizer/rule/top_n_to_vector_search_rule.rs index 74018716a501c..e66a5f2b798f3 100644 --- a/src/frontend/src/optimizer/rule/top_n_to_vector_search_rule.rs +++ b/src/frontend/src/optimizer/rule/top_n_to_vector_search_rule.rs @@ -14,18 +14,22 @@ use std::assert_matches::assert_matches; -use risingwave_common::types::DataType; +use risingwave_common::types::{DataType, ScalarImpl}; use risingwave_common::util::sort_util::ColumnOrder; +use risingwave_expr::aggregate::AggType; use risingwave_pb::common::PbDistanceType; +use risingwave_pb::plan_common::JoinType; -use crate::expr::{Expr, ExprImpl, ExprRewriter, ExprType, InputRef}; +use crate::expr::{Expr, ExprImpl, ExprType, InputRef}; use crate::optimizer::LogicalPlanRef; -use crate::optimizer::plan_node::generic::TopNLimit; +use crate::optimizer::plan_node::generic::{GenericPlanRef, TopNLimit}; use crate::optimizer::plan_node::{ - LogicalPlanRef as PlanRef, LogicalProject, LogicalTopN, LogicalVectorSearch, PlanTreeNodeUnary, + LogicalPlanNodeType, LogicalPlanRef as PlanRef, LogicalProject, LogicalTopN, + LogicalVectorSearch, LogicalVectorSearchLookupJoin, PlanTreeNodeBinary, PlanTreeNodeUnary, }; use crate::optimizer::rule::prelude::*; -use crate::optimizer::rule::{BoxedRule, ProjectMergeRule, Rule}; +use crate::optimizer::rule::{BoxedRule, PbAggKind, ProjectMergeRule, Rule}; +use crate::utils::IndexColumnExprValidator; pub struct TopNToVectorSearchRule; @@ -47,7 +51,13 @@ fn merge_consecutive_projections(input: LogicalPlanRef) -> Option<(Vec } impl TopNToVectorSearchRule { - fn resolve_vector_search(top_n: &LogicalTopN) -> Option<(LogicalVectorSearch, Vec)> { + #[expect(clippy::type_complexity)] + fn resolve_vector_search( + top_n: &LogicalTopN, + ) -> Option<( + (u64, PbDistanceType, ExprImpl, ExprImpl, PlanRef), + Vec, + )> { if !top_n.group_key().is_empty() { // vector search applies for only singleton top n return None; @@ -109,18 +119,9 @@ impl TopNToVectorSearchRule { assert_matches!(left.return_type(), DataType::Vector(_)); 
assert_matches!(right.return_type(), DataType::Vector(_)); - let vector_search = LogicalVectorSearch::new( - limit, - distance_type, - left.clone(), - right.clone(), - (0..projection_input.schema().len()).collect(), - projection_input.clone(), - ); - let mut i2o_mapping = vector_search.i2o_mapping(); let mut output_exprs = Vec::with_capacity(exprs.len()); for expr in &exprs[0..order.column_index] { - output_exprs.push(i2o_mapping.rewrite_expr(expr.clone())); + output_exprs.push(expr.clone()); } output_exprs.push(ExprImpl::InputRef( InputRef { @@ -130,9 +131,18 @@ impl TopNToVectorSearchRule { .into(), )); for expr in &exprs[order.column_index + 1..exprs.len()] { - output_exprs.push(i2o_mapping.rewrite_expr(expr.clone())); + output_exprs.push(expr.clone()); } - Some((vector_search, output_exprs)) + Some(( + ( + limit, + distance_type, + left.clone(), + right.clone(), + projection_input, + ), + output_exprs, + )) } } @@ -149,7 +159,163 @@ impl TopNToVectorSearchRule { impl Rule for TopNToVectorSearchRule { fn apply(&self, plan: PlanRef) -> Option { let top_n = plan.as_logical_top_n()?; - let (vector_search, project_exprs) = Self::resolve_vector_search(top_n)?; + let ((top_n, distance_type, left, right, input), project_exprs) = + TopNToVectorSearchRule::resolve_vector_search(top_n)?; + let vector_search = LogicalVectorSearch::new(top_n, distance_type, left, right, input); Some(LogicalProject::create(vector_search.into(), project_exprs)) } } + +pub struct CorrelatedTopNToVectorSearchRule; + +impl CorrelatedTopNToVectorSearchRule { + pub fn create() -> BoxedRule { + Box::new(CorrelatedTopNToVectorSearchRule) + } +} + +impl Rule for CorrelatedTopNToVectorSearchRule { + fn apply(&self, plan: PlanRef) -> Option { + let apply = plan.as_logical_apply()?; + if apply.join_type() != JoinType::LeftOuter { + return None; + } + if !apply.max_one_row() { + return None; + } + let correlated_id = apply.correlated_id(); + let input = apply.left(); + + // match pattern LogicalProject { exprs: [[Coalesce(array_agg($expr1 order_by($expr2 ASC)), ARRAY[]) as $expr3] } + let right = apply.right(); + let project = right.as_logical_project()?; + let Ok(expr) = project.exprs().as_slice().try_into() else { + return None; + }; + let [expr]: &[_; 1] = expr; + let func_call = expr.as_function_call()?; + if func_call.func_type() != ExprType::Coalesce { + return None; + } + let Ok(inputs) = func_call.inputs().try_into() else { + return None; + }; + let [first, second]: &[_; 2] = inputs; + let empty_array = second.as_literal()?; + let Some(ScalarImpl::List(empty_list)) = empty_array.get_data() else { + return None; + }; + if !empty_list.is_empty() { + return None; + } + + // match pattern of LogicalAgg { aggs: [array_agg($expr1 order_by($expr2 ASC))] } + let array_agg_input = first.as_input_ref()?; + let project_input = project.input(); + let agg = project_input.as_logical_agg()?; + if !agg.group_key().is_empty() { + return None; + } + let Ok(array_agg) = agg.agg_calls().as_slice().try_into() else { + return None; + }; + let [array_agg]: &[_; 1] = array_agg; + if array_agg.agg_type != AggType::Builtin(PbAggKind::ArrayAgg) { + return None; + } + assert_eq!(array_agg_input.index, 0); + let Ok(array_agg_input) = array_agg.inputs.as_slice().try_into() else { + return None; + }; + let [array_agg_input]: &[_; 1] = array_agg_input; + + let ((top_n, distance_type, left, right, lookup_input), project_exprs) = { + let mut prev_proj_exprs: Option> = None; + let mut input = agg.input(); + loop { + match input.node_type() { + 
LogicalPlanNodeType::LogicalProject => { + let proj = input.as_logical_project().expect("checked node type"); + prev_proj_exprs = Some(if let Some(prev_proj_exprs) = prev_proj_exprs { + ProjectMergeRule::merge_project_exprs( + prev_proj_exprs.as_slice(), + proj.exprs(), + false, + )? + } else { + proj.exprs().clone() + }); + input = proj.input(); + } + LogicalPlanNodeType::LogicalTopN => { + let (resolved_info, mut project_exprs) = + TopNToVectorSearchRule::resolve_vector_search( + input.as_logical_top_n().expect("checked node type"), + )?; + if let Some(prev_proj_exprs) = prev_proj_exprs { + project_exprs = ProjectMergeRule::merge_project_exprs( + prev_proj_exprs.as_slice(), + &project_exprs, + false, + )?; + } + break (resolved_info, project_exprs); + } + _ => { + return None; + } + } + } + }; + + let (input_vector_idx, lookup_expr) = match (left, right) { + (ExprImpl::CorrelatedInputRef(correlated), lookup_expr) + | (lookup_expr, ExprImpl::CorrelatedInputRef(correlated)) + if correlated.correlated_id() == correlated_id + && IndexColumnExprValidator::validate(&lookup_expr, true).is_ok() => + { + (correlated.index(), lookup_expr) + } + _ => { + return None; + } + }; + + // match pattern Row(lookup.col1, lookup.col2, ..) + let array_agg_input_expr = &project_exprs[array_agg_input.index]; + let row_input_func = array_agg_input_expr.as_function_call()?; + if row_input_func.func_type() != ExprType::Row { + return None; + } + let mut row_input_indices = vec![]; + let mut include_distance = false; + for (idx, row_input) in row_input_func.inputs().iter().enumerate() { + let input_index = row_input.as_input_ref()?.index; + if input_index == lookup_input.schema().len() { + // distance column included in the row output + if idx != row_input_func.inputs().len() - 1 { + // for simplicity, we require that distance column should be the last column in the row + return None; + } else { + include_distance = true; + } + } else { + row_input_indices.push(input_index); + } + } + + Some( + LogicalVectorSearchLookupJoin::new( + top_n, + distance_type, + input, + input_vector_idx, + lookup_input, + lookup_expr, + row_input_indices, + include_distance, + ) + .into(), + ) + } +} diff --git a/src/frontend/src/utils/mod.rs b/src/frontend/src/utils/mod.rs index e7f35af54ae85..1bab34bc1a65a 100644 --- a/src/frontend/src/utils/mod.rs +++ b/src/frontend/src/utils/mod.rs @@ -38,8 +38,11 @@ pub(crate) mod group_by; pub mod overwrite_options; pub use group_by::*; pub use overwrite_options::*; +use risingwave_common::util::recursive::{Recurse, tracker}; -use crate::expr::{Expr, ExprImpl, ExprRewriter, InputRef}; +use crate::error::ErrorCode; +use crate::expr::{Expr, ExprImpl, ExprRewriter, ExprVisitor, InputRef, is_impure_func_call}; +use crate::session::current::notice_to_user; pub static FRONTEND_RUNTIME: LazyLock = LazyLock::new(|| { tokio::runtime::Builder::new_multi_thread() @@ -56,12 +59,15 @@ pub struct Substitute { impl ExprRewriter for Substitute { fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl { - assert_eq!( - input_ref.return_type(), - self.mapping[input_ref.index()].return_type(), - "Type mismatch when substituting {:?} with {:?}", + assert!( + input_ref + .return_type() + .equals_datatype(&self.mapping[input_ref.index()].return_type()), + "Type mismatch when substituting {:?} of {:?} with {:?} of {:?}", input_ref, + input_ref.return_type(), self.mapping[input_ref.index()], + self.mapping[input_ref.index()].return_type() ); self.mapping[input_ref.index()].clone() } @@ -213,3 +219,63 @@ pub fn 
ordinal(i: usize) -> String { }; s + suffix } + +pub(crate) struct IndexColumnExprValidator { + allow_impure: bool, + result: crate::error::Result<()>, +} + +impl IndexColumnExprValidator { + fn unsupported_expr_err(expr: &ExprImpl) -> ErrorCode { + ErrorCode::NotSupported( + format!("unsupported index column expression type: {:?}", expr), + "use columns or expressions instead".into(), + ) + } + + pub(crate) fn validate(expr: &ExprImpl, allow_impure: bool) -> crate::error::Result<()> { + match expr { + ExprImpl::InputRef(_) | ExprImpl::FunctionCall(_) => {} + other_expr => { + return Err(Self::unsupported_expr_err(other_expr).into()); + } + } + let mut visitor = Self { + allow_impure, + result: Ok(()), + }; + visitor.visit_expr(expr); + visitor.result + } +} + +impl ExprVisitor for IndexColumnExprValidator { + fn visit_expr(&mut self, expr: &ExprImpl) { + if self.result.is_err() { + return; + } + tracker!().recurse(|t| { + if t.depth_reaches(crate::expr::EXPR_DEPTH_THRESHOLD) { + notice_to_user(crate::expr::EXPR_TOO_DEEP_NOTICE); + } + + match expr { + ExprImpl::InputRef(_) | ExprImpl::Literal(_) => {} + ExprImpl::FunctionCall(inner) => { + if !self.allow_impure && is_impure_func_call(inner) { + self.result = Err(ErrorCode::NotSupported( + "this expression is impure".into(), + "use a pure expression instead".into(), + ) + .into()); + return; + } + self.visit_function_call(inner) + } + other_expr => { + self.result = Err(Self::unsupported_expr_err(other_expr).into()); + } + } + }) + } +}
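
For reference, a minimal sketch of the query shape this change targets, assembled from the planner and e2e tests above; the `items`/`events` tables and the flat index `i` are the ones those tests create, and this block is illustration only, not part of the patch.

-- tables and a flat vector index, as created in the planner tests above
create table items (id int primary key, name string, embedding vector(3)) append only;
create table events (event_id int primary key, time timestamp, embedding vector(3));
create index i on items using flat (embedding) with (distance_type = 'l2');

-- a correlated top-n ordered by vector distance: CorrelatedTopNToVectorSearchRule
-- rewrites the apply into LogicalVectorSearchLookupJoin, and to_batch lowers it to
-- BatchVectorSearch when the index covers all requested lookup columns
-- (observable via explain(verbose), as in the e2e test above)
select
    event_id,
    array(
        select row(id, name)
        from items
        order by events.embedding <-> items.embedding
        limit 3
    ) as related_info,
    time
from events;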