diff --git a/crates/storage-query-datafusion/src/remote_query_scanner_client.rs b/crates/storage-query-datafusion/src/remote_query_scanner_client.rs index 4bc2b28252..11b804117f 100644 --- a/crates/storage-query-datafusion/src/remote_query_scanner_client.rs +++ b/crates/storage-query-datafusion/src/remote_query_scanner_client.rs @@ -16,6 +16,7 @@ use async_trait::async_trait; use datafusion::arrow::datatypes::SchemaRef; use datafusion::error::DataFusionError; use datafusion::execution::SendableRecordBatchStream; +use datafusion::physical_expr_common::physical_expr::snapshot_generation; use datafusion::physical_plan::PhysicalExpr; use datafusion::physical_plan::stream::RecordBatchReceiverStream; use tracing::debug; @@ -46,7 +47,10 @@ impl RemoteScanner { } } - async fn next_batch(&self) -> Result { + async fn next_batch( + &self, + next_predicate: Option, + ) -> Result { let Some(ref connection) = self.connection else { return Err(DataFusionError::Internal( "connection used after forget()".to_string(), @@ -67,6 +71,7 @@ impl RemoteScanner { .send_rpc( RemoteQueryScannerNext { scanner_id: self.scanner_id, + next_predicate, }, None, ) @@ -150,6 +155,9 @@ pub fn remote_scan_as_datafusion_stream( let tx = builder.tx(); let task = async move { + // get a snapshot of the initial predicate + let mut predicate_generation = predicate.as_ref().map(snapshot_generation).unwrap_or(0); + let initial_predicate = match &predicate { Some(predicate) => Some(RemoteQueryScannerPredicate { serialized_physical_expression: encode_expr(predicate)?, @@ -176,7 +184,26 @@ pub fn remote_scan_as_datafusion_stream( // loop while we have record_batch coming in // loop { - let batch = match remote_scanner.next_batch().await { + let next_predicate = if predicate_generation != 0 { + // generation 0 means the predicate is static (or we never had one) + let predicate = predicate + .as_ref() + .expect("must have a predicate if generation != 0"); + let current_predicate_generation = snapshot_generation(predicate); + + if current_predicate_generation != predicate_generation { + predicate_generation = current_predicate_generation; + Some(RemoteQueryScannerPredicate { + serialized_physical_expression: encode_expr(predicate)?, + }) + } else { + None + } + } else { + None + }; + + let batch = match remote_scanner.next_batch(next_predicate).await { Err(e) => { return Err(e); } diff --git a/crates/storage-query-datafusion/src/remote_query_scanner_server.rs b/crates/storage-query-datafusion/src/remote_query_scanner_server.rs index 095d98bd5a..45242c12c7 100644 --- a/crates/storage-query-datafusion/src/remote_query_scanner_server.rs +++ b/crates/storage-query-datafusion/src/remote_query_scanner_server.rs @@ -135,7 +135,10 @@ impl RemoteQueryScannerServer { // do that again here. If we do, we might end up dead-locking the map because we are holding a // reference into it (scanner). if let Err(mpsc::error::SendError(request)) = - scanner.send(super::scanner_task::NextRequest { reciprocal }) + scanner.send(super::scanner_task::NextRequest { + reciprocal, + next_predicate: req.next_predicate, + }) { tracing::info!( "No such scanner {}. This could be an expired scanner due to a slow scan with no activity.", diff --git a/crates/storage-query-datafusion/src/scanner_task.rs b/crates/storage-query-datafusion/src/scanner_task.rs index f519978ecb..2e682e1a69 100644 --- a/crates/storage-query-datafusion/src/scanner_task.rs +++ b/crates/storage-query-datafusion/src/scanner_task.rs @@ -12,7 +12,8 @@ use std::sync::{Arc, Weak}; use std::time::Duration; use anyhow::Context; -use datafusion::execution::SendableRecordBatchStream; +use datafusion::arrow::datatypes::SchemaRef; +use datafusion::execution::{SendableRecordBatchStream, TaskContext}; use datafusion::physical_plan::PhysicalExpr; use datafusion::prelude::SessionContext; use tokio::sync::mpsc; @@ -23,7 +24,8 @@ use restate_core::network::{Oneshot, Reciprocal}; use restate_core::{TaskCenter, TaskKind}; use restate_types::GenerationalNodeId; use restate_types::net::remote_query_scanner::{ - RemoteQueryScannerNextResult, RemoteQueryScannerOpen, ScannerBatch, ScannerFailure, ScannerId, + RemoteQueryScannerNextResult, RemoteQueryScannerOpen, RemoteQueryScannerPredicate, + ScannerBatch, ScannerFailure, ScannerId, }; use crate::remote_query_scanner_manager::RemoteScannerManager; @@ -34,6 +36,7 @@ const SCANNER_EXPIRATION: Duration = Duration::from_secs(60); pub(crate) struct NextRequest { pub reciprocal: Reciprocal>, + pub next_predicate: Option, } pub(crate) type ScannerHandle = mpsc::UnboundedSender; @@ -45,6 +48,8 @@ pub(crate) struct ScannerTask { stream: SendableRecordBatchStream, rx: mpsc::UnboundedReceiver, scanners: Weak, + ctx: Arc, + schema: SchemaRef, predicate: Option>, } @@ -88,6 +93,8 @@ impl ScannerTask { stream, rx, scanners: Arc::downgrade(scanners), + ctx: SessionContext::new().task_ctx(), + schema, predicate, }; @@ -133,6 +140,21 @@ impl ScannerTask { } }; + if let Some(next_predicate) = request.next_predicate { + match decode_expr( + &self.ctx, + &self.schema, + &next_predicate.serialized_physical_expression, + ) { + // for now, we are not updating the predicate being passed to ScanPartition, + // so we rely on the filtering below to apply dynamic filters + Ok(next_predicate) => self.predicate = Some(next_predicate), + Err(e) => { + warn!("Failed to decode next predicate: {e}") + } + } + } + let record_batch = loop { // connection/request has been closed, don't bother with driving the stream. // The scanner will be dropped because we want to make sure that we don't get supurious diff --git a/crates/storage-query-datafusion/src/table_providers.rs b/crates/storage-query-datafusion/src/table_providers.rs index 4a7c518aa2..7f273313ed 100644 --- a/crates/storage-query-datafusion/src/table_providers.rs +++ b/crates/storage-query-datafusion/src/table_providers.rs @@ -20,6 +20,9 @@ use datafusion::execution::context::TaskContext; use datafusion::logical_expr::{Expr, TableProviderFilterPushDown}; use datafusion::physical_expr::EquivalenceProperties; use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; +use datafusion::physical_plan::filter_pushdown::{ + FilterPushdownPhase, FilterPushdownPropagation, PushedDown, +}; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::{ DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PhysicalExpr, PlanProperties, @@ -242,6 +245,10 @@ where ) -> datafusion::common::Result> { let res = filters .iter() + // if we set this to exact, we might be able to remove a FilterExec higher up the plan. + // however, it means that fields we filter on won't end up in our projection, meaning we + // have to manage a projected schema and a filter schema - defer this complexity for + // future optimization. .map(|_| TableProviderFilterPushDown::Inexact) .collect(); @@ -347,6 +354,46 @@ where sequential_scanners_stream, ))) } + + fn handle_child_pushdown_result( + &self, + phase: datafusion::physical_plan::filter_pushdown::FilterPushdownPhase, + child_pushdown_result: datafusion::physical_plan::filter_pushdown::ChildPushdownResult, + _config: &datafusion::config::ConfigOptions, + ) -> datafusion::error::Result< + datafusion::physical_plan::filter_pushdown::FilterPushdownPropagation< + Arc, + >, + > { + if !matches!(phase, FilterPushdownPhase::Post) { + return Ok(FilterPushdownPropagation::if_all(child_pushdown_result)); + } + + let filters = child_pushdown_result + .parent_filters + .iter() + .map(|f| f.filter.clone()); + + let predicate = match &self.predicate { + Some(predicate) => datafusion::physical_expr::conjunction( + std::iter::once(predicate.clone()).chain(filters), + ), + None => datafusion::physical_expr::conjunction(filters), + }; + + let mut plan = self.clone(); + plan.predicate = Some(predicate); + + Ok(FilterPushdownPropagation { + // we report all filters as unsupported as we don't guarantee to apply them exactly as there can be a delay before new filters are used + filters: child_pushdown_result + .parent_filters + .iter() + .map(|_| PushedDown::No) + .collect(), + updated_node: Some(Arc::new(plan)), + }) + } } impl DisplayAs for PartitionedExecutionPlan diff --git a/crates/types/src/net/remote_query_scanner.rs b/crates/types/src/net/remote_query_scanner.rs index 2a60b9402c..269e6a5640 100644 --- a/crates/types/src/net/remote_query_scanner.rs +++ b/crates/types/src/net/remote_query_scanner.rs @@ -95,6 +95,9 @@ pub enum RemoteQueryScannerOpened { pub struct RemoteQueryScannerNext { #[bilrost(1)] pub scanner_id: ScannerId, + #[bilrost(tag(2))] + #[serde(default)] + pub next_predicate: Option, } #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, bilrost::Message)]