diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs
index 6618d9495d78..4ce845bbbdb2 100644
--- a/datafusion/core/src/physical_planner.rs
+++ b/datafusion/core/src/physical_planner.rs
@@ -77,10 +77,11 @@ use datafusion_expr::expr::{
 };
 use datafusion_expr::expr_rewriter::unnormalize_cols;
 use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary;
+use datafusion_expr::utils::split_conjunction;
 use datafusion_expr::{
-    Analyze, DescribeTable, DmlStatement, Explain, ExplainFormat, Extension, FetchType,
-    Filter, JoinType, RecursiveQuery, SkipType, StringifiedPlan, WindowFrame,
-    WindowFrameBound, WriteOp,
+    Analyze, BinaryExpr, DescribeTable, DmlStatement, Explain, ExplainFormat, Extension,
+    FetchType, Filter, JoinType, Operator, RecursiveQuery, SkipType, StringifiedPlan,
+    WindowFrame, WindowFrameBound, WriteOp,
 };
 use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};
 use datafusion_physical_expr::expressions::{Column, Literal};
@@ -90,6 +91,7 @@ use datafusion_physical_expr::{
 use datafusion_physical_optimizer::PhysicalOptimizerRule;
 use datafusion_physical_plan::empty::EmptyExec;
 use datafusion_physical_plan::execution_plan::InvariantLevel;
+use datafusion_physical_plan::joins::PiecewiseMergeJoinExec;
 use datafusion_physical_plan::placeholder_row::PlaceholderRowExec;
 use datafusion_physical_plan::recursive_query::RecursiveQueryExec;
 use datafusion_physical_plan::unnest::ListUnnest;
@@ -1089,8 +1091,42 @@ impl DefaultPhysicalPlanner {
                     })
                     .collect::<Result<join_utils::JoinOn>>()?;
 
+                // TODO: `num_range_filters` can be used later on for ASOF joins (`num_range_filters > 1`)
+                let mut num_range_filters = 0;
+                let mut range_filters: Vec<Expr> = Vec::new();
+                let mut total_filters = 0;
+
                 let join_filter = match filter {
                     Some(expr) => {
+                        let split_expr = split_conjunction(expr);
+                        for expr in split_expr.iter() {
+                            match *expr {
+                                Expr::BinaryExpr(BinaryExpr {
+                                    left: _,
+                                    right: _,
+                                    op,
+                                }) => {
+                                    if matches!(
+                                        op,
+                                        Operator::Lt
+                                            | Operator::LtEq
+                                            | Operator::Gt
+                                            | Operator::GtEq
+                                    ) {
+                                        range_filters.push((**expr).clone());
+                                        num_range_filters += 1;
+                                    }
+                                    total_filters += 1;
+                                }
+                                // TODO: Handle `Expr::Between` for IEJoins; it counts as two range
+                                // predicates, which is why it is not handled in PWMJ
+                                // Expr::Between(_) => {},
+                                _ => {
+                                    total_filters += 1;
+                                }
+                            }
+                        }
+
                         // Extract columns from filter expression and saved in a HashSet
                         let cols = expr.column_refs();
@@ -1146,6 +1182,7 @@ impl DefaultPhysicalPlanner {
                         )?;
                         let filter_schema = Schema::new_with_metadata(filter_fields, metadata);
+
                         let filter_expr = create_physical_expr(
                             expr,
                             &filter_df_schema,
@@ -1168,10 +1205,104 @@ impl DefaultPhysicalPlanner {
                 let prefer_hash_join =
                     session_state.config_options().optimizer.prefer_hash_join;
 
+                let cfg = session_state.config();
+
+                let can_run_single =
+                    cfg.target_partitions() == 1 || !cfg.repartition_joins();
+
+                // TODO: Allow PWMJ to deal with residual equijoin conditions
                 let join: Arc<dyn ExecutionPlan> = if join_on.is_empty() {
                     if join_filter.is_none() && matches!(join_type, JoinType::Inner) {
                         // cross join if there is no join conditions and no join filter set
                         Arc::new(CrossJoinExec::new(physical_left, physical_right))
+                    } else if num_range_filters == 1
+                        && total_filters == 1
+                        && can_run_single
+                        && !matches!(
+                            join_type,
+                            JoinType::LeftSemi
+                                | JoinType::RightSemi
+                                | JoinType::LeftAnti
+                                | JoinType::RightAnti
+                                | JoinType::LeftMark
+                                | JoinType::RightMark
+                        )
+                    {
+                        let Expr::BinaryExpr(be) = &range_filters[0]
+                        else {
+                            return plan_err!(
+                                "Unsupported expression for PWMJ: Expected `Expr::BinaryExpr`"
+                            );
+                        };
+
+                        let mut op = be.op;
+                        if !matches!(
+                            op,
+                            Operator::Lt | Operator::LtEq | Operator::Gt | Operator::GtEq
+                        ) {
+                            return plan_err!(
+                                "Unsupported operator for PWMJ: {:?}. Expected one of <, <=, >, >=",
+                                op
+                            );
+                        }
+
+                        fn reverse_ineq(op: Operator) -> Operator {
+                            match op {
+                                Operator::Lt => Operator::Gt,
+                                Operator::LtEq => Operator::GtEq,
+                                Operator::Gt => Operator::Lt,
+                                Operator::GtEq => Operator::LtEq,
+                                _ => op,
+                            }
+                        }
+
+                        let side_of = |e: &Expr| -> Result<&'static str> {
+                            let cols = e.column_refs();
+                            let in_left = cols
+                                .iter()
+                                .all(|c| left_df_schema.index_of_column(c).is_ok());
+                            let in_right = cols
+                                .iter()
+                                .all(|c| right_df_schema.index_of_column(c).is_ok());
+                            match (in_left, in_right) {
+                                (true, false) => Ok("left"),
+                                (false, true) => Ok("right"),
+                                _ => unreachable!(),
+                            }
+                        };
+
+                        let mut lhs_logical = &be.left;
+                        let mut rhs_logical = &be.right;
+
+                        let left_side = side_of(lhs_logical)?;
+                        let right_side = side_of(rhs_logical)?;
+                        if left_side == "right" && right_side == "left" {
+                            std::mem::swap(&mut lhs_logical, &mut rhs_logical);
+                            op = reverse_ineq(op);
+                        } else if !(left_side == "left" && right_side == "right") {
+                            return plan_err!(
+                                "Unsupported join condition for PWMJ: each side of the range predicate must reference columns from exactly one join input"
+                            );
+                        }
+
+                        let on_left = create_physical_expr(
+                            lhs_logical,
+                            left_df_schema,
+                            session_state.execution_props(),
+                        )?;
+                        let on_right = create_physical_expr(
+                            rhs_logical,
+                            right_df_schema,
+                            session_state.execution_props(),
+                        )?;
+
+                        Arc::new(PiecewiseMergeJoinExec::try_new(
+                            physical_left,
+                            physical_right,
+                            (on_left, on_right),
+                            op,
+                            *join_type,
+                        )?)
                     } else {
                         // there is no equal join condition, use the nested loop join
                         Arc::new(NestedLoopJoinExec::try_new(
diff --git a/datafusion/physical-plan/src/joins/hash_join/stream.rs b/datafusion/physical-plan/src/joins/hash_join/stream.rs
index d368a9cf8ee2..0adb9b7a69cb 100644
--- a/datafusion/physical-plan/src/joins/hash_join/stream.rs
+++ b/datafusion/physical-plan/src/joins/hash_join/stream.rs
@@ -589,6 +589,7 @@ impl HashJoinStream {
             let (left_side, right_side) = get_final_indices_from_shared_bitmap(
                 build_side.left_data.visited_indices_bitmap(),
                 self.join_type,
+                true,
             );
             let empty_right_batch = RecordBatch::new_empty(self.right.schema());
             // use the left and right indices to produce the batch result
diff --git a/datafusion/physical-plan/src/joins/mod.rs b/datafusion/physical-plan/src/joins/mod.rs
index 1d36db996434..b0c28cf994f7 100644
--- a/datafusion/physical-plan/src/joins/mod.rs
+++ b/datafusion/physical-plan/src/joins/mod.rs
@@ -24,11 +24,13 @@ pub use hash_join::HashJoinExec;
 pub use nested_loop_join::NestedLoopJoinExec;
 use parking_lot::Mutex;
 // Note: SortMergeJoin is not used in plans yet
+pub use piecewise_merge_join::PiecewiseMergeJoinExec;
 pub use sort_merge_join::SortMergeJoinExec;
 pub use symmetric_hash_join::SymmetricHashJoinExec;
 
 mod cross_join;
 mod hash_join;
 mod nested_loop_join;
+mod piecewise_merge_join;
 mod sort_merge_join;
 mod stream_join_utils;
 mod symmetric_hash_join;
diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs
new file mode 100644
index 000000000000..55c8245b4507
--- /dev/null
+++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs
@@ -0,0 +1,1471 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more
+// contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Stream Implementation for PiecewiseMergeJoin's Classic Join (Left, Right, Full, Inner)
+
+use arrow::array::{
+    new_null_array, Array, PrimitiveArray, PrimitiveBuilder, RecordBatchOptions,
+};
+use arrow::compute::take;
+use arrow::datatypes::{UInt32Type, UInt64Type};
+use arrow::{
+    array::{
+        ArrayRef, RecordBatch, UInt32Array, UInt32Builder, UInt64Array, UInt64Builder,
+    },
+    compute::{sort_to_indices, take_record_batch},
+};
+use arrow_schema::{ArrowError, Schema, SchemaRef, SortOptions};
+use datafusion_common::NullEquality;
+use datafusion_common::{exec_err, internal_err, Result};
+use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream};
+use datafusion_expr::{JoinType, Operator};
+use datafusion_physical_expr::PhysicalExprRef;
+use futures::{Stream, StreamExt};
+use std::{cmp::Ordering, task::ready};
+use std::{sync::Arc, task::Poll};
+
+use crate::handle_state;
+use crate::joins::piecewise_merge_join::exec::{BufferedSide, BufferedSideReadyState};
+use crate::joins::piecewise_merge_join::utils::need_produce_result_in_final;
+use crate::joins::utils::{compare_join_arrays, get_final_indices_from_shared_bitmap};
+use crate::joins::utils::{BuildProbeJoinMetrics, StatefulStreamResult};
+
+pub(super) enum PiecewiseMergeJoinStreamState {
+    WaitBufferedSide,
+    FetchStreamBatch,
+    ProcessStreamBatch(StreamedBatch),
+    ExhaustedStreamSide,
+    Completed,
+}
+
+impl PiecewiseMergeJoinStreamState {
+    // Grab a mutable reference to the current stream batch
+    fn try_as_process_stream_batch_mut(&mut self) -> Result<&mut StreamedBatch> {
+        match self {
+            PiecewiseMergeJoinStreamState::ProcessStreamBatch(state) => Ok(state),
+            _ => internal_err!("Expected streamed batch in StreamBatch"),
+        }
+    }
+}
+
+pub(super) struct StreamedBatch {
+    pub batch: RecordBatch,
+    values: Vec<ArrayRef>,
+}
+
+impl StreamedBatch {
+    #[allow(dead_code)]
+    fn new(batch: RecordBatch, values: Vec<ArrayRef>) -> Self {
+        Self { batch, values }
+    }
+
+    fn values(&self) -> &Vec<ArrayRef> {
+        &self.values
+    }
+}
+
+pub(super) struct ClassicPWMJStream {
+    // Output schema of the `PiecewiseMergeJoin`
+    pub schema: Arc<Schema>,
+
+    // Physical expression that is evaluated on the streamed side.
+    // We do not need `on_buffered` because it is already evaluated when
+    // creating the buffered side, which happens before initializing
+    // `PiecewiseMergeJoinStream`
+    pub on_streamed: PhysicalExprRef,
+    // Type of join
+    pub join_type: JoinType,
+    // Comparison operator
+    pub operator: Operator,
+    // Streamed side input stream
+    pub streamed: SendableRecordBatchStream,
+    // Streamed schema
+    streamed_schema: SchemaRef,
+    // Buffered side data
+    buffered_side: BufferedSide,
+    // Tracks the state of the `PiecewiseMergeJoin`
+    state: PiecewiseMergeJoinStreamState,
+    // Sort option for buffered and streamed side (specifies whether
+    // the sort is ascending
+    // or descending)
+    sort_option: SortOptions,
+    // Metrics for build + probe joins
+    join_metrics: BuildProbeJoinMetrics,
+    // Tracking incremental state for emitting record batches
+    batch_process_state: BatchProcessState,
+    // Target size for output batches
+    batch_size: usize,
+}
+
+impl RecordBatchStream for ClassicPWMJStream {
+    fn schema(&self) -> SchemaRef {
+        Arc::clone(&self.schema)
+    }
+}
+
+// `PiecewiseMergeJoinStreamState` is separated into `WaitBufferedSide`, `FetchStreamBatch`,
+// `ProcessStreamBatch`, `ExhaustedStreamSide` and `Completed`.
+//
+// Classic Joins
+// 1. `WaitBufferedSide` - Load the buffered side data into memory.
+// 2. `FetchStreamBatch` - Fetch + sort incoming stream batches. We switch the state to
+//    `ExhaustedStreamSide` once stream batches are exhausted.
+// 3. `ProcessStreamBatch` - Compare stream batch row values against the buffered side data.
+// 4. `ExhaustedStreamSide` - If the join type is Right or Inner we move straight to
+//    `Completed`; for Full and Left we still need to emit the unmatched buffered rows.
+impl ClassicPWMJStream {
+    // Creates a new `ClassicPWMJStream` instance
+    #[allow(clippy::too_many_arguments)]
+    pub fn try_new(
+        schema: Arc<Schema>,
+        on_streamed: PhysicalExprRef,
+        join_type: JoinType,
+        operator: Operator,
+        streamed: SendableRecordBatchStream,
+        buffered_side: BufferedSide,
+        state: PiecewiseMergeJoinStreamState,
+        sort_option: SortOptions,
+        join_metrics: BuildProbeJoinMetrics,
+        batch_size: usize,
+    ) -> Self {
+        let streamed_schema = streamed.schema();
+        Self {
+            schema,
+            on_streamed,
+            join_type,
+            operator,
+            streamed_schema,
+            streamed,
+            buffered_side,
+            state,
+            sort_option,
+            join_metrics,
+            batch_process_state: BatchProcessState::new(),
+            batch_size,
+        }
+    }
+
+    fn poll_next_impl(
+        &mut self,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Option<Result<RecordBatch>>> {
+        loop {
+            return match self.state {
+                PiecewiseMergeJoinStreamState::WaitBufferedSide => {
+                    handle_state!(ready!(self.collect_buffered_side(cx)))
+                }
+                PiecewiseMergeJoinStreamState::FetchStreamBatch => {
+                    handle_state!(ready!(self.fetch_stream_batch(cx)))
+                }
+                PiecewiseMergeJoinStreamState::ProcessStreamBatch(_) => {
+                    handle_state!(self.process_stream_batch())
+                }
+                PiecewiseMergeJoinStreamState::ExhaustedStreamSide => {
+                    handle_state!(self.process_unmatched_buffered_batch())
+                }
+                PiecewiseMergeJoinStreamState::Completed => Poll::Ready(None),
+            };
+        }
+    }
+
+    // Collects buffered side data
+    fn collect_buffered_side(
+        &mut self,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
+        let build_timer = self.join_metrics.build_time.timer();
+        let buffered_data = ready!(self
+            .buffered_side
+            .try_as_initial_mut()?
+            .buffered_fut
+            .get_shared(cx))?;
+        build_timer.done();
+
+        // We will start fetching stream batches for classic joins
+        self.state = PiecewiseMergeJoinStreamState::FetchStreamBatch;
+
+        self.buffered_side =
+            BufferedSide::Ready(BufferedSideReadyState { buffered_data });
+
+        Poll::Ready(Ok(StatefulStreamResult::Continue))
+    }
+
+    // Fetches incoming stream batches
+    fn fetch_stream_batch(
+        &mut self,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
+        match ready!(self.streamed.poll_next_unpin(cx)) {
+            None => {
+                self.state = PiecewiseMergeJoinStreamState::ExhaustedStreamSide;
+            }
+            Some(Ok(batch)) => {
+                // Evaluate the streamed physical expression on the stream batch
+                let stream_values: ArrayRef = self
+                    .on_streamed
+                    .evaluate(&batch)?
+                    .into_array(batch.num_rows())?;
+
+                self.join_metrics.input_batches.add(1);
+                self.join_metrics.input_rows.add(batch.num_rows());
+
+                // Sort the stream values and reorder the streamed record batch accordingly
+                let indices = sort_to_indices(
+                    stream_values.as_ref(),
+                    Some(self.sort_option),
+                    None,
+                )?;
+                let stream_batch = take_record_batch(&batch, &indices)?;
+                let stream_values = take(stream_values.as_ref(), &indices, None)?;
+
+                self.state =
+                    PiecewiseMergeJoinStreamState::ProcessStreamBatch(StreamedBatch {
+                        batch: stream_batch,
+                        values: vec![stream_values],
+                    });
+            }
+            Some(Err(err)) => return Poll::Ready(Err(err)),
+        };
+
+        Poll::Ready(Ok(StatefulStreamResult::Continue))
+    }
+
+    // Only classic joins call this. It processes stream batches and evaluates them
+    // against the buffered side data.
+    fn process_stream_batch(
+        &mut self,
+    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
+        let buffered_side = self.buffered_side.try_as_ready_mut()?;
+        let stream_batch = self.state.try_as_process_stream_batch_mut()?;
+
+        let batch = resolve_classic_join(
+            buffered_side,
+            stream_batch,
+            Arc::clone(&self.schema),
+            self.operator,
+            self.sort_option,
+            self.join_type,
+            &mut self.batch_process_state,
+            self.batch_size,
+        )?;
+
+        if self.batch_process_state.continue_process {
+            return Ok(StatefulStreamResult::Ready(Some(batch)));
+        }
+
+        self.state = PiecewiseMergeJoinStreamState::FetchStreamBatch;
+        Ok(StatefulStreamResult::Ready(Some(batch)))
+    }
+
+    // Process remaining unmatched buffered rows
+    fn process_unmatched_buffered_batch(
+        &mut self,
+    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
+        // Return early for `JoinType::Right` and `JoinType::Inner`
+        if matches!(self.join_type, JoinType::Right | JoinType::Inner) {
+            self.state = PiecewiseMergeJoinStreamState::Completed;
+            return Ok(StatefulStreamResult::Ready(None));
+        }
+
+        let timer = self.join_metrics.join_time.timer();
+
+        let buffered_data =
+            Arc::clone(&self.buffered_side.try_as_ready().unwrap().buffered_data);
+
+        // Check if the same batch still has rows that need to be processed
+        if let Some(start_idx) = self.batch_process_state.process_rest {
+            if let Some(buffered_indices) = &self.batch_process_state.buffered_indices {
+                let remaining = buffered_indices.len() - start_idx;
+
+                // Take this branch and return a batch if more rows remain than fit
+                // in one output batch
+                if remaining > self.batch_size {
+                    let buffered_batch = buffered_data.batch();
+                    let empty_stream_batch =
+                        RecordBatch::new_empty(Arc::clone(&self.streamed_schema));
+
+                    let buffered_chunk_ref =
+                        buffered_indices.slice(start_idx, self.batch_size);
+                    let new_buffered_indices = buffered_chunk_ref
+                        .as_any()
+                        .downcast_ref::<UInt64Array>()
+                        .expect("downcast to UInt64Array after slice");
+
+                    let streamed_indices: UInt32Array =
+                        (0..new_buffered_indices.len() as u32).collect();
+
+                    let batch = build_matched_indices(
+                        Arc::clone(&self.schema),
+                        &empty_stream_batch,
+                        buffered_batch,
+                        streamed_indices,
+                        new_buffered_indices.clone(),
+                    )?;
+
+                    self.batch_process_state
+                        .set_process_rest(Some(start_idx + self.batch_size));
+                    self.batch_process_state.continue_process = true;
+
+                    return Ok(StatefulStreamResult::Ready(Some(batch)));
+                }
+
+                let buffered_batch = buffered_data.batch();
+                let empty_stream_batch =
+                    RecordBatch::new_empty(Arc::clone(&self.streamed_schema));
+
+                let buffered_chunk_ref = buffered_indices.slice(start_idx, remaining);
+                let new_buffered_indices = buffered_chunk_ref
+                    .as_any()
+                    .downcast_ref::<UInt64Array>()
+                    .expect("downcast to UInt64Array after slice");
+
+                let streamed_indices: UInt32Array =
+                    (0..new_buffered_indices.len() as
+                        u32).collect();
+
+                let batch = build_matched_indices(
+                    Arc::clone(&self.schema),
+                    &empty_stream_batch,
+                    buffered_batch,
+                    streamed_indices,
+                    new_buffered_indices.clone(),
+                )?;
+
+                self.batch_process_state.reset();
+
+                timer.done();
+                self.join_metrics.output_batches.add(1);
+                self.state = PiecewiseMergeJoinStreamState::Completed;
+
+                return Ok(StatefulStreamResult::Ready(Some(batch)));
+            }
+
+            return exec_err!("Batch process state should hold buffered indices");
+        }
+
+        let (buffered_indices, streamed_indices) = get_final_indices_from_shared_bitmap(
+            &buffered_data.visited_indices_bitmap,
+            self.join_type,
+            true,
+        );
+
+        // If the output indices exceed the limit for incremental batching, output all
+        // matches up to that index, return the batch, and resume matching from the
+        // updated index (`process_rest`)
+        if buffered_indices.len() > self.batch_size {
+            let buffered_batch = buffered_data.batch();
+            let empty_stream_batch =
+                RecordBatch::new_empty(Arc::clone(&self.streamed_schema));
+
+            let indices_chunk_ref = buffered_indices
+                .slice(self.batch_process_state.start_idx, self.batch_size);
+
+            let indices_chunk = indices_chunk_ref
+                .as_any()
+                .downcast_ref::<UInt64Array>()
+                .expect("downcast to UInt64Array after slice");
+
+            let batch = build_matched_indices(
+                Arc::clone(&self.schema),
+                &empty_stream_batch,
+                buffered_batch,
+                streamed_indices,
+                indices_chunk.clone(),
+            )?;
+
+            self.batch_process_state.buffered_indices = Some(buffered_indices);
+            self.batch_process_state
+                .set_process_rest(Some(self.batch_size));
+            self.batch_process_state.continue_process = true;
+
+            return Ok(StatefulStreamResult::Ready(Some(batch)));
+        }
+
+        let buffered_batch = buffered_data.batch();
+        let empty_stream_batch =
+            RecordBatch::new_empty(Arc::clone(&self.streamed_schema));
+
+        let batch = build_matched_indices(
+            Arc::clone(&self.schema),
+            &empty_stream_batch,
+            buffered_batch,
+            streamed_indices,
+            buffered_indices,
+        )?;
+
+        timer.done();
+        self.join_metrics.output_batches.add(1);
+        self.state = PiecewiseMergeJoinStreamState::Completed;
+
+        Ok(StatefulStreamResult::Ready(Some(batch)))
+    }
+}
+
+// Holds all information for processing incremental output
+struct BatchProcessState {
+    // Used to pick up from the last index on the stream side
+    start_idx: usize,
+    // Used to pick up from the last index on the buffered side
+    pivot: usize,
+    // Tracks the number of rows processed; default starts at 0
+    num_rows: usize,
+    // Start index for processing the rest of the current batch, if any
+    process_rest: Option<usize>,
+    // Used to skip fully processing the row
+    not_found: bool,
+    // Signals whether to call `ProcessStreamBatch` again
+    continue_process: bool,
+    // Holds the buffered indices when processing the remaining marked rows.
+    buffered_indices: Option<UInt64Array>,
+}
+
+impl BatchProcessState {
+    pub fn new() -> Self {
+        Self {
+            start_idx: 0,
+            num_rows: 0,
+            pivot: 0,
+            process_rest: None,
+            not_found: false,
+            continue_process: false,
+            buffered_indices: None,
+        }
+    }
+
+    fn reset(&mut self) {
+        self.start_idx = 0;
+        self.num_rows = 0;
+        self.pivot = 0;
+        self.process_rest = None;
+        self.not_found = false;
+        self.continue_process = false;
+        self.buffered_indices = None;
+    }
+
+    fn pivot(&self) -> usize {
+        self.pivot
+    }
+
+    fn set_pivot(&mut self, pivot: usize) {
+        self.pivot = pivot;
+    }
+
+    fn set_start_idx(&mut self, start_idx: usize) {
+        self.start_idx = start_idx;
+    }
+
+    fn set_rows(&mut self, num_rows: usize) {
+        self.num_rows = num_rows;
+    }
+
+    fn set_process_rest(&mut self, process_rest: Option<usize>) {
+        self.process_rest = process_rest;
+    }
+}
+
+impl Stream for ClassicPWMJStream {
+    type Item = Result<RecordBatch>;
+
+    fn poll_next(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        self.poll_next_impl(cx)
+    }
+}
+
+// For Left, Right, Full, and Inner joins, incoming stream batches have already been
+// sorted by `fetch_stream_batch` by the time this is called.
+#[allow(clippy::too_many_arguments)]
+fn resolve_classic_join(
+    buffered_side: &mut BufferedSideReadyState,
+    stream_batch: &StreamedBatch,
+    join_schema: Arc<Schema>,
+    operator: Operator,
+    sort_options: SortOptions,
+    join_type: JoinType,
+    batch_process_state: &mut BatchProcessState,
+    batch_size: usize,
+) -> Result<RecordBatch> {
+    let buffered_values = buffered_side.buffered_data.values();
+    let buffered_len = buffered_values.len();
+    let stream_values = stream_batch.values();
+
+    let mut buffered_indices = UInt64Builder::default();
+    let mut stream_indices = UInt32Builder::default();
+
+    // Our pivot variable allows us to start probing the buffered side where we last
+    // matched for the previous stream row.
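+    // Because both inputs are sorted consistently with `operator`, each stream row's
+    // first match can only move forward on the buffered side, so the probe never
+    // restarts from index zero.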
+    let mut pivot = batch_process_state.pivot();
+    for row_idx in batch_process_state.start_idx..stream_values[0].len() {
+        let mut found = false;
+
+        // Check whether we are re-entering only to emit the null row for this stream
+        // row; if so, skip probing the buffered side again
+        if !batch_process_state.not_found {
+            while pivot < buffered_values.len()
+                || batch_process_state.process_rest.is_some()
+            {
+                // If there is still data left in the batch to process, use the index and output
+                if let Some(start_idx) = batch_process_state.process_rest {
+                    let count = buffered_values.len() - start_idx;
+                    if count >= batch_size {
+                        let stream_repeated = vec![row_idx as u32; batch_size];
+                        batch_process_state
+                            .set_process_rest(Some(start_idx + batch_size));
+                        batch_process_state
+                            .set_rows(batch_process_state.num_rows + batch_size);
+                        let buffered_range: Vec<u64> = (start_idx as u64
+                            ..((start_idx as u64) + (batch_size as u64)))
+                            .collect();
+                        stream_indices.append_slice(&stream_repeated);
+                        buffered_indices.append_slice(&buffered_range);
+
+                        let batch = process_batch(
+                            &mut buffered_indices,
+                            &mut stream_indices,
+                            stream_batch,
+                            buffered_side,
+                            join_type,
+                            join_schema,
+                        )?;
+                        batch_process_state.continue_process = true;
+                        batch_process_state.set_rows(0);
+
+                        return Ok(batch);
+                    }
+
+                    batch_process_state.set_rows(batch_process_state.num_rows + count);
+                    let stream_repeated = vec![row_idx as u32; count];
+                    let buffered_range: Vec<u64> =
+                        (start_idx as u64..buffered_len as u64).collect();
+                    stream_indices.append_slice(&stream_repeated);
+                    buffered_indices.append_slice(&buffered_range);
+                    batch_process_state.process_rest = None;
+
+                    found = true;
+
+                    break;
+                }
+
+                let compare = compare_join_arrays(
+                    &[Arc::clone(&stream_values[0])],
+                    row_idx,
+                    &[Arc::clone(buffered_values)],
+                    pivot,
+                    &[sort_options],
+                    NullEquality::NullEqualsNothing,
+                )?;
+
+                // If we find a match we append all indices and move to the next stream row index
+                match operator {
+                    Operator::Gt | Operator::Lt => {
+                        if matches!(compare, Ordering::Less) {
+                            let count = buffered_values.len() - pivot;
+
+                            // If the rows already produced plus this match run would
+                            // exceed the target batch size, emit a partial batch and
+                            // record where to resume
+                            if batch_process_state.num_rows + count >= batch_size {
+                                let process_batch_size =
+                                    batch_size - batch_process_state.num_rows;
+                                let stream_repeated =
+                                    vec![row_idx as u32; process_batch_size];
+                                batch_process_state.set_rows(
+                                    batch_process_state.num_rows + process_batch_size,
+                                );
+
+                                let buffered_range: Vec<u64> = (pivot as u64
+                                    ..(pivot + process_batch_size) as u64)
+                                    .collect();
+                                stream_indices.append_slice(&stream_repeated);
+                                buffered_indices.append_slice(&buffered_range);
+
+                                let batch = process_batch(
+                                    &mut buffered_indices,
+                                    &mut stream_indices,
+                                    stream_batch,
+                                    buffered_side,
+                                    join_type,
+                                    join_schema,
+                                )?;
+
+                                batch_process_state
+                                    .set_process_rest(Some(pivot + process_batch_size));
+                                batch_process_state.continue_process = true;
+                                // Update the start index so processing resumes at this row
+                                batch_process_state.set_start_idx(row_idx);
+                                batch_process_state.set_pivot(pivot);
+                                batch_process_state.set_rows(0);
+
+                                return Ok(batch);
+                            }
+
+                            // Update the number of rows processed
+                            batch_process_state
+                                .set_rows(batch_process_state.num_rows + count);
+
+                            let stream_repeated = vec![row_idx as u32; count];
+                            let buffered_range: Vec<u64> =
+                                (pivot as u64..buffered_len as u64).collect();
+
+                            stream_indices.append_slice(&stream_repeated);
+                            buffered_indices.append_slice(&buffered_range);
+                            found = true;
+
+                            break;
+                        }
+                    }
+                    Operator::GtEq | Operator::LtEq => {
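+                        // Same as the `Gt`/`Lt` arm, except that `Ordering::Equal` at
+                        // the pivot also counts as the first match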
+                        if matches!(compare, Ordering::Equal | Ordering::Less) {
+                            let count = buffered_values.len() - pivot;
+
+                            // If the rows already produced plus this match run would
+                            // exceed the target batch size, emit a partial batch and
+                            // record where to resume
+                            if batch_process_state.num_rows + count >= batch_size {
+                                // Update the start index so processing resumes at this row
+                                batch_process_state.set_start_idx(row_idx);
+                                batch_process_state.set_pivot(pivot);
+
+                                let process_batch_size =
+                                    batch_size - batch_process_state.num_rows;
+                                let stream_repeated =
+                                    vec![row_idx as u32; process_batch_size];
+                                batch_process_state
+                                    .set_process_rest(Some(pivot + process_batch_size));
+                                batch_process_state.set_rows(
+                                    batch_process_state.num_rows + process_batch_size,
+                                );
+                                let buffered_range: Vec<u64> = (pivot as u64
+                                    ..(pivot + process_batch_size) as u64)
+                                    .collect();
+                                stream_indices.append_slice(&stream_repeated);
+                                buffered_indices.append_slice(&buffered_range);
+
+                                let batch = process_batch(
+                                    &mut buffered_indices,
+                                    &mut stream_indices,
+                                    stream_batch,
+                                    buffered_side,
+                                    join_type,
+                                    join_schema,
+                                )?;
+
+                                batch_process_state.continue_process = true;
+                                batch_process_state.set_rows(0);
+
+                                return Ok(batch);
+                            }
+
+                            // Update the number of rows processed
+                            batch_process_state
+                                .set_rows(batch_process_state.num_rows + count);
+                            let stream_repeated = vec![row_idx as u32; count];
+                            let buffered_range: Vec<u64> =
+                                (pivot as u64..buffered_len as u64).collect();
+
+                            stream_indices.append_slice(&stream_repeated);
+                            buffered_indices.append_slice(&buffered_range);
+                            found = true;
+
+                            break;
+                        }
+                    }
+                    _ => {
+                        return exec_err!(
+                            "PiecewiseMergeJoin does not support operator {}",
+                            operator
+                        )
+                    }
+                };
+
+                // Increment pivot after every buffered row
+                pivot += 1;
+            }
+        }
+
+        // If no match was found we append a null row for `JoinType::Right` and `JoinType::Full`
+        if (!found || batch_process_state.not_found)
+            && matches!(join_type, JoinType::Right | JoinType::Full)
+        {
+            let remaining = batch_size.saturating_sub(batch_process_state.num_rows);
+            if remaining == 0 {
+                let batch = process_batch(
+                    &mut buffered_indices,
+                    &mut stream_indices,
+                    stream_batch,
+                    buffered_side,
+                    join_type,
+                    join_schema,
+                )?;
+
+                // Update the start index so processing resumes at this row
+                batch_process_state.set_start_idx(row_idx);
+                batch_process_state.set_pivot(pivot);
+                batch_process_state.not_found = true;
+                batch_process_state.continue_process = true;
+                batch_process_state.set_rows(0);
+
+                return Ok(batch);
+            }
+
+            // Append the right (stream) side value plus a null value for the left
+            stream_indices.append_value(row_idx as u32);
+            buffered_indices.append_null();
+            batch_process_state.set_rows(batch_process_state.num_rows + 1);
+            batch_process_state.not_found = false;
+        }
+    }
+
+    let batch = process_batch(
+        &mut buffered_indices,
+        &mut stream_indices,
+        stream_batch,
+        buffered_side,
+        join_type,
+        join_schema,
+    )?;
+
+    // Resets the batch process state so `Left` + `Full` unmatched-row processing
+    // starts from a clean state
+    batch_process_state.reset();
+
+    Ok(batch)
+}
+
+fn process_batch(
+    buffered_indices: &mut PrimitiveBuilder<UInt64Type>,
+    stream_indices: &mut PrimitiveBuilder<UInt32Type>,
+    stream_batch: &StreamedBatch,
+    buffered_side: &mut BufferedSideReadyState,
+    join_type: JoinType,
+    join_schema: Arc<Schema>,
+) -> Result<RecordBatch> {
+    let stream_indices_array = stream_indices.finish();
+    let buffered_indices_array = buffered_indices.finish();
+
+    // We need to mark the matched buffered side indices for `JoinType::Full` and `JoinType::Left`
+    if need_produce_result_in_final(join_type) {
+        let mut bitmap = buffered_side.buffered_data.visited_indices_bitmap.lock();
+
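+        // Mark matched buffered rows in the shared bitmap; unmatched buffered rows
+        // are emitted with nulls by `process_unmatched_buffered_batch` once the
+        // stream side is exhausted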
+        buffered_indices_array.iter().flatten().for_each(|i| {
+            bitmap.set_bit(i as usize, true);
+        });
+    }
+
+    let batch = build_matched_indices(
+        join_schema,
+        &stream_batch.batch,
+        &buffered_side.buffered_data.batch,
+        stream_indices_array,
+        buffered_indices_array,
+    )?;
+
+    Ok(batch)
+}
+
+fn build_matched_indices(
+    schema: Arc<Schema>,
+    streamed_batch: &RecordBatch,
+    buffered_batch: &RecordBatch,
+    streamed_indices: UInt32Array,
+    buffered_indices: UInt64Array,
+) -> Result<RecordBatch> {
+    if schema.fields().is_empty() {
+        // Build an "empty" RecordBatch with just row-count metadata
+        let options = RecordBatchOptions::new()
+            .with_match_field_names(true)
+            .with_row_count(Some(streamed_indices.len()));
+        return Ok(RecordBatch::try_new_with_options(
+            Arc::new((*schema).clone()),
+            vec![],
+            &options,
+        )?);
+    }
+
+    // Gather streamed columns by taking the rows at the streamed indices
+    let streamed_columns = streamed_batch
+        .columns()
+        .iter()
+        .map(|column_array| {
+            if column_array.is_empty()
+                || streamed_indices.null_count() == streamed_indices.len()
+            {
+                assert_eq!(streamed_indices.null_count(), streamed_indices.len());
+                Ok(new_null_array(
+                    column_array.data_type(),
+                    streamed_indices.len(),
+                ))
+            } else {
+                take(column_array, &streamed_indices, None)
+            }
+        })
+        .collect::<Result<Vec<_>, ArrowError>>()?;
+
+    let mut buffered_columns = buffered_batch
+        .columns()
+        .iter()
+        .map(|column_array| take(column_array, &buffered_indices, None))
+        .collect::<Result<Vec<_>, ArrowError>>()?;
+
+    buffered_columns.extend(streamed_columns);
+
+    Ok(RecordBatch::try_new(
+        Arc::new((*schema).clone()),
+        buffered_columns,
+    )?)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::{
+        common,
+        joins::PiecewiseMergeJoinExec,
+        test::{build_table_i32, TestMemoryExec},
+        ExecutionPlan,
+    };
+    use arrow::array::{Date32Array, Date64Array};
+    use arrow_schema::{DataType, Field};
+    use datafusion_common::test_util::batches_to_string;
+    use datafusion_execution::TaskContext;
+    use datafusion_expr::JoinType;
+    use datafusion_physical_expr::{expressions::Column, PhysicalExpr};
+    use insta::assert_snapshot;
+    use std::sync::Arc;
+
+    fn columns(schema: &Schema) -> Vec<String> {
+        schema.fields().iter().map(|f| f.name().clone()).collect()
+    }
+
+    fn build_table(
+        a: (&str, &Vec<i32>),
+        b: (&str, &Vec<i32>),
+        c: (&str, &Vec<i32>),
+    ) -> Arc<dyn ExecutionPlan> {
+        let batch = build_table_i32(a, b, c);
+        let schema = batch.schema();
+        TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap()
+    }
+
+    fn build_date_table(
+        a: (&str, &Vec<i32>),
+        b: (&str, &Vec<i32>),
+        c: (&str, &Vec<i32>),
+    ) -> Arc<dyn ExecutionPlan> {
+        let schema = Schema::new(vec![
+            Field::new(a.0, DataType::Date32, false),
+            Field::new(b.0, DataType::Date32, false),
+            Field::new(c.0, DataType::Date32, false),
+        ]);
+
+        let batch = RecordBatch::try_new(
+            Arc::new(schema),
+            vec![
+                Arc::new(Date32Array::from(a.1.clone())),
+                Arc::new(Date32Array::from(b.1.clone())),
+                Arc::new(Date32Array::from(c.1.clone())),
+            ],
+        )
+        .unwrap();
+
+        let schema = batch.schema();
+        TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap()
+    }
+
+    fn build_date64_table(
+        a: (&str, &Vec<i64>),
+        b: (&str, &Vec<i64>),
+        c: (&str, &Vec<i64>),
+    ) -> Arc<dyn ExecutionPlan> {
+        let schema = Schema::new(vec![
+            Field::new(a.0, DataType::Date64, false),
+            Field::new(b.0, DataType::Date64, false),
+            Field::new(c.0, DataType::Date64, false),
+        ]);
+
+        let batch = RecordBatch::try_new(
+            Arc::new(schema),
+            vec![
+                Arc::new(Date64Array::from(a.1.clone())),
+                Arc::new(Date64Array::from(b.1.clone())),
+                Arc::new(Date64Array::from(c.1.clone())),
+            ],
+        )
+        .unwrap();
+
+        let
+            schema = batch.schema();
+        TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap()
+    }
+
+    fn join(
+        left: Arc<dyn ExecutionPlan>,
+        right: Arc<dyn ExecutionPlan>,
+        on: (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>),
+        operator: Operator,
+        join_type: JoinType,
+    ) -> Result<PiecewiseMergeJoinExec> {
+        PiecewiseMergeJoinExec::try_new(left, right, on, operator, join_type)
+    }
+
+    async fn join_collect(
+        left: Arc<dyn ExecutionPlan>,
+        right: Arc<dyn ExecutionPlan>,
+        on: (PhysicalExprRef, PhysicalExprRef),
+        operator: Operator,
+        join_type: JoinType,
+    ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
+        join_collect_with_options(left, right, on, operator, join_type).await
+    }
+
+    async fn join_collect_with_options(
+        left: Arc<dyn ExecutionPlan>,
+        right: Arc<dyn ExecutionPlan>,
+        on: (PhysicalExprRef, PhysicalExprRef),
+        operator: Operator,
+        join_type: JoinType,
+    ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
+        let task_ctx = Arc::new(TaskContext::default());
+        let join = join(left, right, on, operator, join_type)?;
+        let columns = columns(&join.schema());
+
+        let stream = join.execute(0, task_ctx)?;
+        let batches = common::collect(stream).await?;
+        Ok((columns, batches))
+    }
+
+    #[tokio::test]
+    async fn join_inner_less_than() -> Result<()> {
+        // +----+----+----+
+        // | a1 | b1 | c1 |
+        // +----+----+----+
+        // | 1  | 3  | 7  |
+        // | 2  | 2  | 8  |
+        // | 3  | 1  | 9  |
+        // +----+----+----+
+        let left = build_table(
+            ("a1", &vec![1, 2, 3]),
+            ("b1", &vec![3, 2, 1]), // this has a repetition
+            ("c1", &vec![7, 8, 9]),
+        );
+
+        // +----+----+----+
+        // | a2 | b1 | c2 |
+        // +----+----+----+
+        // | 10 | 2  | 70 |
+        // | 20 | 3  | 80 |
+        // | 30 | 4  | 90 |
+        // +----+----+----+
+        let right = build_table(
+            ("a2", &vec![10, 20, 30]),
+            ("b1", &vec![2, 3, 4]),
+            ("c2", &vec![70, 80, 90]),
+        );
+
+        let on = (
+            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
+        );
+
+        let (_, batches) =
+            join_collect(left, right, on, Operator::Lt, JoinType::Inner).await?;
+
+        assert_snapshot!(batches_to_string(&batches), @r#"
+        +----+----+----+----+----+----+
+        | a1 | b1 | c1 | a2 | b1 | c2 |
+        +----+----+----+----+----+----+
+        | 1  | 3  | 7  | 30 | 4  | 90 |
+        | 2  | 2  | 8  | 30 | 4  | 90 |
+        | 3  | 1  | 9  | 30 | 4  | 90 |
+        | 2  | 2  | 8  | 20 | 3  | 80 |
+        | 3  | 1  | 9  | 20 | 3  | 80 |
+        | 3  | 1  | 9  | 10 | 2  | 70 |
+        +----+----+----+----+----+----+
+        "#);
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn join_inner_less_than_unsorted() -> Result<()> {
+        // +----+----+----+
+        // | a1 | b1 | c1 |
+        // +----+----+----+
+        // | 1  | 3  | 7  |
+        // | 2  | 2  | 8  |
+        // | 3  | 1  | 9  |
+        // +----+----+----+
+        let left = build_table(
+            ("a1", &vec![1, 2, 3]),
+            ("b1", &vec![3, 2, 1]), // this has a repetition
+            ("c1", &vec![7, 8, 9]),
+        );
+
+        // +----+----+----+
+        // | a2 | b1 | c2 |
+        // +----+----+----+
+        // | 10 | 3  | 70 |
+        // | 20 | 2  | 80 |
+        // | 30 | 4  | 90 |
+        // +----+----+----+
+        let right = build_table(
+            ("a2", &vec![10, 20, 30]),
+            ("b1", &vec![3, 2, 4]),
+            ("c2", &vec![70, 80, 90]),
+        );
+
+        let on = (
+            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+            Arc::new(Column::new_with_schema("b1", &right.schema())?)
+                as _,
+        );
+
+        let (_, batches) =
+            join_collect(left, right, on, Operator::Lt, JoinType::Inner).await?;
+
+        assert_snapshot!(batches_to_string(&batches), @r#"
+        +----+----+----+----+----+----+
+        | a1 | b1 | c1 | a2 | b1 | c2 |
+        +----+----+----+----+----+----+
+        | 1  | 3  | 7  | 30 | 4  | 90 |
+        | 2  | 2  | 8  | 30 | 4  | 90 |
+        | 3  | 1  | 9  | 30 | 4  | 90 |
+        | 2  | 2  | 8  | 10 | 3  | 70 |
+        | 3  | 1  | 9  | 10 | 3  | 70 |
+        | 3  | 1  | 9  | 20 | 2  | 80 |
+        +----+----+----+----+----+----+
+        "#);
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn join_inner_greater_than_equal_to() -> Result<()> {
+        // +----+----+----+
+        // | a1 | b1 | c1 |
+        // +----+----+----+
+        // | 1  | 2  | 7  |
+        // | 2  | 3  | 8  |
+        // | 3  | 4  | 9  |
+        // +----+----+----+
+        let left = build_table(
+            ("a1", &vec![1, 2, 3]),
+            ("b1", &vec![2, 3, 4]),
+            ("c1", &vec![7, 8, 9]),
+        );
+
+        // +----+----+----+
+        // | a2 | b1 | c2 |
+        // +----+----+----+
+        // | 10 | 3  | 70 |
+        // | 20 | 2  | 80 |
+        // | 30 | 1  | 90 |
+        // +----+----+----+
+        let right = build_table(
+            ("a2", &vec![10, 20, 30]),
+            ("b1", &vec![3, 2, 1]),
+            ("c2", &vec![70, 80, 90]),
+        );
+
+        let on = (
+            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
+        );
+
+        let (_, batches) =
+            join_collect(left, right, on, Operator::GtEq, JoinType::Inner).await?;
+
+        assert_snapshot!(batches_to_string(&batches), @r#"
+        +----+----+----+----+----+----+
+        | a1 | b1 | c1 | a2 | b1 | c2 |
+        +----+----+----+----+----+----+
+        | 1  | 2  | 7  | 30 | 1  | 90 |
+        | 2  | 3  | 8  | 30 | 1  | 90 |
+        | 3  | 4  | 9  | 30 | 1  | 90 |
+        | 1  | 2  | 7  | 20 | 2  | 80 |
+        | 2  | 3  | 8  | 20 | 2  | 80 |
+        | 3  | 4  | 9  | 20 | 2  | 80 |
+        | 2  | 3  | 8  | 10 | 3  | 70 |
+        | 3  | 4  | 9  | 10 | 3  | 70 |
+        +----+----+----+----+----+----+
+        "#);
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn join_inner_empty_left() -> Result<()> {
+        // +----+----+----+
+        // | a1 | b1 | c1 |
+        // +----+----+----+
+        // (empty)
+        // +----+----+----+
+        let left = build_table(
+            ("a1", &Vec::<i32>::new()),
+            ("b1", &Vec::<i32>::new()),
+            ("c1", &Vec::<i32>::new()),
+        );
+
+        // +----+----+----+
+        // | a2 | b1 | c2 |
+        // +----+----+----+
+        // | 1  | 1  | 1  |
+        // | 2  | 2  | 2  |
+        // +----+----+----+
+        let right = build_table(
+            ("a2", &vec![1, 2]),
+            ("b1", &vec![1, 2]),
+            ("c2", &vec![1, 2]),
+        );
+
+        let on = (
+            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
+        );
+        let (_, batches) =
+            join_collect(left, right, on, Operator::LtEq, JoinType::Inner).await?;
+        assert_snapshot!(batches_to_string(&batches), @r#"
+        +----+----+----+----+----+----+
+        | a1 | b1 | c1 | a2 | b1 | c2 |
+        +----+----+----+----+----+----+
+        +----+----+----+----+----+----+
+        "#);
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn join_full_greater_than_equal_to() -> Result<()> {
+        // +----+----+-----+
+        // | a1 | b1 | c1  |
+        // +----+----+-----+
+        // | 1  | 1  | 100 |
+        // | 2  | 2  | 200 |
+        // +----+----+-----+
+        let left = build_table(
+            ("a1", &vec![1, 2]),
+            ("b1", &vec![1, 2]),
+            ("c1", &vec![100, 200]),
+        );
+
+        // +----+----+-----+
+        // | a2 | b1 | c2  |
+        // +----+----+-----+
+        // | 10 | 3  | 300 |
+        // | 20 | 2  | 400 |
+        // +----+----+-----+
+        let right = build_table(
+            ("a2", &vec![10, 20]),
+            ("b1", &vec![3, 2]),
+            ("c2", &vec![300, 400]),
+        );
+
+        let on = (
+            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+            Arc::new(Column::new_with_schema("b1", &right.schema())?)
as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::GtEq, JoinType::Full).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+-----+----+----+-----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+-----+----+----+-----+ + | 2 | 2 | 200 | 20 | 2 | 400 | + | | | | 10 | 3 | 300 | + | 1 | 1 | 100 | | | | + +----+----+-----+----+----+-----+ + "#); + + Ok(()) + } + + #[tokio::test] + async fn join_left_greater_than() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 1 | 7 | + // | 2 | 3 | 8 | + // | 3 | 4 | 9 | + // +----+----+----+ + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![1, 3, 4]), + ("c1", &vec![7, 8, 9]), + ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 3 | 70 | + // | 20 | 2 | 80 | + // | 30 | 1 | 90 | + // +----+----+----+ + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![3, 2, 1]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Gt, JoinType::Left).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 2 | 3 | 8 | 30 | 1 | 90 | + | 3 | 4 | 9 | 30 | 1 | 90 | + | 2 | 3 | 8 | 20 | 2 | 80 | + | 3 | 4 | 9 | 20 | 2 | 80 | + | 3 | 4 | 9 | 10 | 3 | 70 | + | 1 | 1 | 7 | | | | + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_right_greater_than() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 1 | 7 | + // | 2 | 3 | 8 | + // | 3 | 4 | 9 | + // +----+----+----+ + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![1, 3, 4]), + ("c1", &vec![7, 8, 9]), + ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 5 | 70 | + // | 20 | 3 | 80 | + // | 30 | 2 | 90 | + // +----+----+----+ + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![5, 3, 2]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Gt, JoinType::Right).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 2 | 3 | 8 | 30 | 2 | 90 | + | 3 | 4 | 9 | 30 | 2 | 90 | + | 3 | 4 | 9 | 20 | 3 | 80 | + | | | | 10 | 5 | 70 | + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_right_less_than() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 4 | 7 | + // | 2 | 3 | 8 | + // | 3 | 1 | 9 | + // +----+----+----+ + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 3, 1]), + ("c1", &vec![7, 8, 9]), + ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 2 | 70 | + // | 20 | 3 | 80 | + // | 30 | 5 | 90 | + // +----+----+----+ + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![2, 3, 5]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) 
as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Lt, JoinType::Right).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 1 | 4 | 7 | 30 | 5 | 90 | + | 2 | 3 | 8 | 30 | 5 | 90 | + | 3 | 1 | 9 | 30 | 5 | 90 | + | 3 | 1 | 9 | 20 | 3 | 80 | + | 3 | 1 | 9 | 10 | 2 | 70 | + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_date32_inner_less_than() -> Result<()> { + // +----+-------+----+ + // | a1 | b1 | c1 | + // +----+-------+----+ + // | 1 | 19107 | 7 | + // | 2 | 19107 | 8 | + // | 3 | 19105 | 9 | + // +----+-------+----+ + let left = build_date_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![19107, 19107, 19105]), + ("c1", &vec![7, 8, 9]), + ); + + // +----+-------+----+ + // | a2 | b1 | c2 | + // +----+-------+----+ + // | 10 | 19105 | 70 | + // | 20 | 19103 | 80 | + // | 30 | 19107 | 90 | + // +----+-------+----+ + let right = build_date_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![19105, 19103, 19107]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Lt, JoinType::Inner).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +------------+------------+------------+------------+------------+------------+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +------------+------------+------------+------------+------------+------------+ + | 1970-01-04 | 2022-04-23 | 1970-01-10 | 1970-01-31 | 2022-04-25 | 1970-04-01 | + +------------+------------+------------+------------+------------+------------+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_date64_inner_less_than() -> Result<()> { + // +----+---------------+----+ + // | a1 | b1 | c1 | + // +----+---------------+----+ + // | 1 | 1650903441000 | 7 | + // | 2 | 1650903441000 | 8 | + // | 3 | 1650703441000 | 9 | + // +----+---------------+----+ + let left = build_date64_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![1650903441000, 1650903441000, 1650703441000]), + ("c1", &vec![7, 8, 9]), + ); + + // +----+---------------+----+ + // | a2 | b1 | c2 | + // +----+---------------+----+ + // | 10 | 1650703441000 | 70 | + // | 20 | 1650503441000 | 80 | + // | 30 | 1650903441000 | 90 | + // +----+---------------+----+ + let right = build_date64_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![1650703441000, 1650503441000, 1650903441000]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) 
as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Lt, JoinType::Inner).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ + | 1970-01-01T00:00:00.003 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.009 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 | + +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_date64_right_less_than() -> Result<()> { + // +----+---------------+----+ + // | a1 | b1 | c1 | + // +----+---------------+----+ + // | 1 | 1650903441000 | 7 | + // | 2 | 1650703441000 | 8 | + // +----+---------------+----+ + let left = build_date64_table( + ("a1", &vec![1, 2]), + ("b1", &vec![1650903441000, 1650703441000]), + ("c1", &vec![7, 8]), + ); + + // +----+---------------+----+ + // | a2 | b1 | c2 | + // +----+---------------+----+ + // | 10 | 1650703441000 | 80 | + // | 20 | 1650903441000 | 90 | + // +----+---------------+----+ + let right = build_date64_table( + ("a2", &vec![10, 20]), + ("b1", &vec![1650703441000, 1650903441000]), + ("c2", &vec![80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Lt, JoinType::Right).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ + | 1970-01-01T00:00:00.002 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.008 | 1970-01-01T00:00:00.020 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 | + | | | | 1970-01-01T00:00:00.010 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.080 | + +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ +"#); + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs new file mode 100644 index 000000000000..4bcd1ffa6f80 --- /dev/null +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs @@ -0,0 +1,729 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::Array;
+use arrow::{
+    array::{ArrayRef, BooleanBufferBuilder, RecordBatch},
+    compute::concat_batches,
+    util::bit_util,
+};
+use arrow_schema::{SchemaRef, SortOptions};
+use datafusion_common::not_impl_err;
+use datafusion_common::{internal_err, JoinSide, Result};
+use datafusion_execution::{
+    memory_pool::{MemoryConsumer, MemoryReservation},
+    SendableRecordBatchStream,
+};
+use datafusion_expr::{JoinType, Operator};
+use datafusion_physical_expr::equivalence::join_equivalence_properties;
+use datafusion_physical_expr::{
+    LexOrdering, OrderingRequirements, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr,
+};
+use datafusion_physical_expr_common::physical_expr::fmt_sql;
+use futures::TryStreamExt;
+use parking_lot::Mutex;
+use std::fmt::Formatter;
+use std::sync::Arc;
+
+use crate::execution_plan::{boundedness_from_children, EmissionType};
+
+use crate::joins::piecewise_merge_join::classic_join::{
+    ClassicPWMJStream, PiecewiseMergeJoinStreamState,
+};
+use crate::joins::piecewise_merge_join::utils::{
+    build_visited_indices_map, is_existence_join, is_right_existence_join,
+};
+use crate::joins::utils::symmetric_join_output_partitioning;
+use crate::{
+    joins::{
+        utils::{build_join_schema, BuildProbeJoinMetrics, OnceAsync, OnceFut},
+        SharedBitmapBuilder,
+    },
+    metrics::ExecutionPlanMetricsSet,
+    spill::get_record_batch_memory_size,
+    ExecutionPlan, PlanProperties,
+};
+use crate::{DisplayAs, DisplayFormatType, ExecutionPlanProperties};
+
+/// `PiecewiseMergeJoinExec` is a join execution plan that evaluates a single range filter.
+///
+/// The physical planner chooses this join when there is exactly one range predicate: a binary
+/// expression using [`Operator::Lt`], [`Operator::LtEq`], [`Operator::Gt`], or [`Operator::GtEq`].
+/// Examples:
+/// - `col0 < colb`, `col0 <= colb`, `col0 > colb`, `col0 >= colb`
+///
+/// Since the join only supports range predicates, equijoins are not supported by
+/// `PiecewiseMergeJoinExec`; however, you can evaluate another join first and then run
+/// `PiecewiseMergeJoinExec` if you are left with one range predicate.
+///
+/// # Execution Plan Inputs
+/// For `PiecewiseMergeJoin` we label the right input as the 'streamed' side and the left input
+/// as the 'buffered' side.
+///
+/// `PiecewiseMergeJoin` takes a sorted input for the side to be buffered and sorts streamed record
+/// batches during processing. Sorted input must specifically be ascending/descending based on the
+/// operator.
+///
+/// # Algorithms
+/// Classic joins are processed differently compared to existence joins.
+///
+/// ## Classic Joins (Inner, Full, Left, Right)
+/// For classic joins we buffer the left side and incrementally process the right (streamed) side.
+/// Every streamed batch is sorted so we can perform a sort merge algorithm. For the buffered side
+/// we want it already sorted, either ascending or descending based on the operator, as this allows
+/// us to emit all the rows from a given point to the end as matches.
+/// Sorting the streamed side allows us to start the pointer from the previous row's match on the
+/// buffered side.
+///
+/// For `Lt` (`<`) + `LtEq` (`<=`) operations both inputs are to be sorted in descending order, and
+/// in ascending order for `Gt` (`>`) + `GtEq` (`>=`) operations. `SortExec` is used to enforce
+/// sorting on the buffered side, and the streamed side is sorted in memory.
+///
+/// The pseudocode for the algorithm looks like this:
+///
+/// ```text
+/// for stream_row in stream_batch:
+///     for buffer_row in buffer_batch:
+///         if compare(stream_row, buffer_row):
+///             output stream_row X buffer_batch[buffer_row:]
+///         else:
+///             continue
+/// ```
+///
+/// The streamed side drives the loop: every row on the stream side iterates the buffered side to
+/// find its first match.
+///
+/// Here is an example:
+///
+/// We perform a `JoinType::Left` with these two batches and the operator being `Operator::Lt` (<).
+/// For each row on the streamed side we move a pointer on the buffered side until it matches the
+/// condition. Once we reach the row which matches (in this case row 1 on streamed has its first
+/// match on row 2 on buffered; 100 < 200 is true), we can emit all rows after that match. We can
+/// emit the rows like this because, if the batch is sorted in ascending order, every subsequent
+/// row will also satisfy the condition as they will all be larger values.
+///
+/// ```text
+/// SQL statement:
+/// SELECT *
+/// FROM (VALUES (100), (200), (500)) AS streamed(a)
+/// LEFT JOIN (VALUES (100), (200), (200), (300), (400)) AS buffered(b)
+/// ON streamed.a < buffered.b;
+///
+/// Processing Row 1:
+///
+///   Sorted Buffered Side                                 Sorted Streamed Side
+///   ┌──────────────────┐                                 ┌──────────────────┐
+/// 1 │       100        │                               1 │       100        │
+///   ├──────────────────┤                                 ├──────────────────┤
+/// 2 │       200        │ ─┐                            2 │       200        │
+///   ├──────────────────┤  │ For row 1 on the streamed    ├──────────────────┤
+/// 3 │       200        │  │ side with value 100, we    3 │       500        │
+///   ├──────────────────┤  │ emit rows 2 - 5 as matches   └──────────────────┘
+/// 4 │       300        │  │ when the operator is
+///   ├──────────────────┤  │ `Operator::Lt` (<), emitting
+/// 5 │       400        │ ─┘ all rows after the first
+///   └──────────────────┘    match (row 2 on the buffered
+///                           side; 100 < 200)
+///
+/// Processing Row 2:
+/// By sorting the streamed side we know the first match for row 2 can only be at or
+/// after the previous row's match, so we resume probing from there.
+///
+///   Sorted Buffered Side                                 Sorted Streamed Side
+///   ┌──────────────────┐                                 ┌──────────────────┐
+/// 1 │       100        │                               1 │       100        │
+///   ├──────────────────┤                                 ├──────────────────┤
+/// 2 │       200        │ <- Start here when probing    2 │       200        │
+///   ├──────────────────┤    for streamed side row 2.     ├──────────────────┤
+/// 3 │       200        │                               3 │       500        │
+///   ├──────────────────┤                                 └──────────────────┘
+/// 4 │       300        │
+///   ├──────────────────┤
+/// 5 │       400        │
+///   └──────────────────┘
+/// ```
+///
+/// ## Existence Joins (Semi, Anti, Mark)
+/// Existence joins are orders of magnitude faster with a `PiecewiseMergeJoin`, as we only need to
+/// find the min/max value of the streamed side to emit all matches on the buffered side. By
+/// putting the side we need to mark onto the sorted buffered side, we can emit all these matches
+/// at once.
+///
+/// For less than operations (`<`) both inputs are to be sorted in descending order and vice versa
+/// for greater than (`>`) operations. `SortExec` is used to enforce sorting on the buffered side,
+/// and the streamed side does not need to be sorted due to only needing to find the min/max.
+///
+/// For Left Semi, Anti, and Mark joins we swap the inputs so that the marked side is on the
+/// buffered side.
+///
+/// The pseudocode for the algorithm looks like this:
+///
+/// ```text
+/// // Using the example of a less than `<` operation
+/// let max = max_batch(streamed_batch)
+///
+/// for buffer_row in buffer_batch:
+///     if buffer_row < max:
+///         output buffer_batch[buffer_row:]
+/// ```
+///
+/// We only need to find the min/max value and iterate through the buffered side once.
+///
+/// Here is an example:
+/// We perform a `JoinType::LeftSemi` with these two batches and the operator being `Operator::Lt`
+/// (<). Because the operator is `Operator::Lt` we can find the minimum value in the streamed side;
+/// in this case it is 200. We can then advance a pointer from the start of the buffered side until
+/// we find the first value that satisfies the predicate. All rows after that first matched value
+/// satisfy the condition 200 < x, so we can mark all of those rows as matched.
+///
+/// ```text
+/// SQL statement:
+/// SELECT *
+/// FROM (VALUES (500), (200), (300)) AS streamed(a)
+/// LEFT SEMI JOIN (VALUES (100), (200), (200), (300), (400)) AS buffered(b)
+/// ON streamed.a < buffered.b;
+///
+///   Sorted Buffered Side                  Unsorted Streamed Side
+///   ┌──────────────────┐                  ┌──────────────────┐
+/// 1 │       100        │                1 │       500        │
+///   ├──────────────────┤                  ├──────────────────┤
+/// 2 │       200        │                2 │       200        │
+///   ├──────────────────┤                  ├──────────────────┤
+/// 3 │       200        │                3 │       300        │
+///   ├──────────────────┤                  └──────────────────┘
+/// 4 │       300        │ ─┐
+///   ├──────────────────┤  │ We emit matches for rows 4 - 5
+/// 5 │       400        │ ─┘ on the buffered side.
+///   └──────────────────┘
+///   min value: 200
+/// ```
+///
+/// For both types of joins, the buffered side must be sorted ascending for `Operator::Lt` (<) or
+/// `Operator::LtEq` (<=) and descending for `Operator::Gt` (>) or `Operator::GtEq` (>=).
+///
+/// # Performance Explanation (cost)
+/// Piecewise Merge Join is preferred over Nested Loop Join due to its superior performance. Here
+/// is the breakdown:
+///
+/// ## Piecewise Merge Join (PWMJ)
+/// ### Classic Join
+/// Requires sorting the probe side and, for each probe row, scanning the buffered side until the
+/// first match is found.
+/// Complexity: `O(sort(S) + |S| * scan(R))`.
+///
+/// ### Mark Join
+/// Computes the min/max range of the probe keys and scans the buffered side only within that
+/// range.
+/// Complexity: `O(|S| + scan(R[range]))`.
+///
+/// ## Nested Loop Join
+/// Compares every row from `S` with every row from `R`.
+/// Complexity: `O(|S| * |R|)`.
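+///
+/// ## Example Query Shape
+/// As an illustrative sketch (the table and column names here are hypothetical, not taken from
+/// the planner tests), a join whose only condition is a single inequality between one column
+/// from each input is the shape the planner can route to `PiecewiseMergeJoinExec` instead of
+/// `NestedLoopJoinExec`, assuming a single target partition and no equijoin keys:
+///
+/// ```text
+/// SELECT *
+/// FROM orders o
+/// JOIN shipments s
+/// ON o.order_date < s.ship_date;
+/// ```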
+/// +/// # Further Reference Material +/// DuckDB blog on Range Joins: [Range Joins in DuckDB](https://duckdb.org/2022/05/27/iejoin.html) +#[derive(Debug)] +pub struct PiecewiseMergeJoinExec { + /// Left buffered execution plan + pub buffered: Arc, + /// Right streamed execution plan + pub streamed: Arc, + /// The two expressions being compared + pub on: (Arc, Arc), + /// Comparison operator in the range predicate + pub operator: Operator, + /// How the join is performed + pub join_type: JoinType, + /// The schema once the join is applied + schema: SchemaRef, + /// Buffered data + buffered_fut: OnceAsync, + /// Execution metrics + metrics: ExecutionPlanMetricsSet, + /// The left sort order, descending for `<`, `<=` operations + ascending for `>`, `>=` operations + left_sort_exprs: LexOrdering, + /// The right sort order, descending for `<`, `<=` operations + ascending for `>`, `>=` operations + /// Unsorted for mark joins + right_sort_exprs: LexOrdering, + /// Sort options of join columns used in sorting the stream and buffered execution plans + sort_options: SortOptions, + /// Cache holding plan properties like equivalences, output partitioning etc. + cache: PlanProperties, +} + +impl PiecewiseMergeJoinExec { + pub fn try_new( + buffered: Arc, + streamed: Arc, + on: (Arc, Arc), + operator: Operator, + join_type: JoinType, + ) -> Result { + // TODO: Implement existence joins for PiecewiseMergeJoin + if is_existence_join(join_type) { + return not_impl_err!( + "Existence Joins are currently not supported for PiecewiseMergeJoin" + ); + } + + // Take the operator and enforce a sort order on the streamed + buffered side based on + // the operator type. + let sort_options = match operator { + Operator::Lt | Operator::LtEq => { + // For left existence joins the inputs will be swapped so the sort + // options are switched + if is_right_existence_join(join_type) { + SortOptions::new(false, false) + } else { + SortOptions::new(true, false) + } + } + Operator::Gt | Operator::GtEq => { + if is_right_existence_join(join_type) { + SortOptions::new(true, false) + } else { + SortOptions::new(false, false) + } + } + _ => { + return internal_err!( + "Cannot contain non-range operator in PiecewiseMergeJoinExec" + ) + } + }; + + // Give the same `sort_option for comparison later` + let left_sort_exprs = + vec![PhysicalSortExpr::new(Arc::clone(&on.0), sort_options)]; + let right_sort_exprs = + vec![PhysicalSortExpr::new(Arc::clone(&on.1), sort_options)]; + + let Some(left_sort_exprs) = LexOrdering::new(left_sort_exprs) else { + return internal_err!( + "PiecewiseMergeJoinExec requires valid sort expressions for its left side" + ); + }; + let Some(right_sort_exprs) = LexOrdering::new(right_sort_exprs) else { + return internal_err!( + "PiecewiseMergeJoinExec requires valid sort expressions for its right side" + ); + }; + + let buffered_schema = buffered.schema(); + let streamed_schema = streamed.schema(); + + // Create output schema for the join + let schema = + Arc::new(build_join_schema(&buffered_schema, &streamed_schema, &join_type).0); + let cache = Self::compute_properties( + &buffered, + &streamed, + Arc::clone(&schema), + join_type, + &on, + )?; + + Ok(Self { + streamed, + buffered, + on, + operator, + join_type, + schema, + buffered_fut: Default::default(), + metrics: ExecutionPlanMetricsSet::new(), + left_sort_exprs, + right_sort_exprs, + sort_options, + cache, + }) + } + + /// Reference to buffered side execution plan + pub fn buffered(&self) -> &Arc { + &self.buffered + } + + /// Reference to streamed 
side execution plan + pub fn streamed(&self) -> &Arc<dyn ExecutionPlan> { + &self.streamed + } + + /// Join type + pub fn join_type(&self) -> JoinType { + self.join_type + } + + /// Reference to sort options + pub fn sort_options(&self) -> &SortOptions { + &self.sort_options + } + + /// Get the probe side (streamed side) for the PiecewiseMergeJoin. + /// In the current implementation, the probe side is determined by the join type. + pub fn probe_side(join_type: &JoinType) -> JoinSide { + match join_type { + JoinType::Right + | JoinType::Inner + | JoinType::Full + | JoinType::RightSemi + | JoinType::RightAnti + | JoinType::RightMark => JoinSide::Right, + JoinType::Left + | JoinType::LeftAnti + | JoinType::LeftSemi + | JoinType::LeftMark => JoinSide::Left, + } + } + + pub fn compute_properties( + buffered: &Arc<dyn ExecutionPlan>, + streamed: &Arc<dyn ExecutionPlan>, + schema: SchemaRef, + join_type: JoinType, + join_on: &(PhysicalExprRef, PhysicalExprRef), + ) -> Result<PlanProperties> { + let eq_properties = join_equivalence_properties( + buffered.equivalence_properties().clone(), + streamed.equivalence_properties().clone(), + &join_type, + schema, + &Self::maintains_input_order(join_type), + Some(Self::probe_side(&join_type)), + std::slice::from_ref(join_on), + )?; + + let output_partitioning = + symmetric_join_output_partitioning(buffered, streamed, &join_type)?; + + Ok(PlanProperties::new( + eq_properties, + output_partitioning, + EmissionType::Incremental, + boundedness_from_children([buffered, streamed]), + )) + } + + // TODO: Add input order + fn maintains_input_order(join_type: JoinType) -> Vec<bool> { + match join_type { + // The existence side is expected to come in sorted + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => { + vec![false, false] + } + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => { + vec![false, false] + } + // Left, Right, Full, and Inner joins are not guaranteed to maintain + // input order, as the streamed side will be sorted during + // execution of the `PiecewiseMergeJoin` + _ => vec![false, false], + } + } + + // TODO + pub fn swap_inputs(&self) -> Result<Arc<dyn ExecutionPlan>> { + todo!() + } +} + +impl ExecutionPlan for PiecewiseMergeJoinExec { + fn name(&self) -> &str { + "PiecewiseMergeJoinExec" + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.cache + } + + fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> { + vec![&self.buffered, &self.streamed] + } + + fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> { + // Existence joins only need ordering enforced on one side. + if is_right_existence_join(self.join_type) { + // The right side needs to be sorted because it will be swapped to the + // buffered side + vec![ + None, + Some(OrderingRequirements::from(self.right_sort_exprs.clone())), + ] + } else { + // The right side is sorted in memory, so we do not need to enforce any + // ordering on it + vec![ + Some(OrderingRequirements::from(self.left_sort_exprs.clone())), + None, + ] + } + } + + fn with_new_children( + self: Arc<Self>, + children: Vec<Arc<dyn ExecutionPlan>>, + ) -> Result<Arc<dyn ExecutionPlan>> { + match &children[..]
{ + [left, right] => Ok(Arc::new(PiecewiseMergeJoinExec::try_new( + Arc::clone(left), + Arc::clone(right), + self.on.clone(), + self.operator, + self.join_type, + )?)), + _ => internal_err!( + "PiecewiseMergeJoin should have 2 children, found {}", + children.len() + ), + } + } + + fn execute( + &self, + partition: usize, + context: Arc<TaskContext>, + ) -> Result<SendableRecordBatchStream> { + let on_buffered = Arc::clone(&self.on.0); + let on_streamed = Arc::clone(&self.on.1); + + // If the join type is RightSemi, RightAnti, or RightMark we swap the inputs + // and the sort ordering because we want the marked side to be the buffered side. + let (buffered, streamed, on_buffered, on_streamed, operator) = + if is_right_existence_join(self.join_type) { + ( + Arc::clone(&self.streamed), + Arc::clone(&self.buffered), + on_streamed, + on_buffered, + self.operator.swap().unwrap(), + ) + } else { + ( + Arc::clone(&self.buffered), + Arc::clone(&self.streamed), + on_buffered, + on_streamed, + self.operator, + ) + }; + + let metrics = BuildProbeJoinMetrics::new(0, &self.metrics); + let buffered_fut = self.buffered_fut.try_once(|| { + let reservation = MemoryConsumer::new("PiecewiseMergeJoinInput") + .register(context.memory_pool()); + let buffered_stream = buffered.execute(partition, Arc::clone(&context))?; + Ok(build_buffered_data( + buffered_stream, + Arc::clone(&on_buffered), + metrics.clone(), + reservation, + build_visited_indices_map(self.join_type), + )) + })?; + + let streamed = streamed.execute(partition, Arc::clone(&context))?; + + let batch_size = context.session_config().batch_size(); + + // TODO: Add existence joins; this case is guarded in the physical planner + if is_existence_join(self.join_type()) { + unreachable!() + } else { + Ok(Box::pin(ClassicPWMJStream::try_new( + Arc::clone(&self.schema), + on_streamed, + self.join_type, + operator, + streamed, + BufferedSide::Initial(BufferedSideInitialState { buffered_fut }), + PiecewiseMergeJoinStreamState::WaitBufferedSide, + self.sort_options, + metrics, + batch_size, + ))) + } + } +} + +impl DisplayAs for PiecewiseMergeJoinExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + let on_str = format!( + "({} {} {})", + fmt_sql(self.on.0.as_ref()), + self.operator, + fmt_sql(self.on.1.as_ref()) + ); + + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!( + f, + "PiecewiseMergeJoin: operator={:?}, join_type={:?}, on={}", + self.operator, self.join_type, on_str + ) + } + + DisplayFormatType::TreeRender => { + writeln!(f, "operator={:?}", self.operator)?; + if self.join_type != JoinType::Inner { + writeln!(f, "join_type={:?}", self.join_type)?; + } + writeln!(f, "on={on_str}") + } + } + } +} + +async fn build_buffered_data( + buffered: SendableRecordBatchStream, + on_buffered: PhysicalExprRef, + metrics: BuildProbeJoinMetrics, + reservation: MemoryReservation, + build_map: bool, +) -> Result<BufferedSideData> { + let schema = buffered.schema(); + + // Combine batches and record number of rows + let initial = (Vec::new(), 0, metrics, reservation); + let (batches, num_rows, metrics, mut reservation) = buffered + .try_fold(initial, |mut acc, batch| async { + let batch_size = get_record_batch_memory_size(&batch); + acc.3.try_grow(batch_size)?; + acc.2.build_mem_used.add(batch_size); + acc.2.build_input_batches.add(1); + acc.2.build_input_rows.add(batch.num_rows()); + // Update row count + acc.1 += batch.num_rows(); + // Push batch to output + acc.0.push(batch); + Ok(acc) + }) + .await?; + + let batches_iter = batches.iter().rev(); + let
single_batch = concat_batches(&schema, batches_iter)?; + + // Evaluate the physical expression on the buffered side. + let buffered_values = on_buffered + .evaluate(&single_batch)? + .into_array(single_batch.num_rows())?; + + // The size estimation covers the single batch plus the memory of the + // evaluated join key values + let size_estimation = get_record_batch_memory_size(&single_batch) + + buffered_values.get_array_memory_size(); + reservation.try_grow(size_estimation)?; + metrics.build_mem_used.add(size_estimation); + + // Create the visited-indices bitmap only if the join type requires it + let visited_indices_bitmap = if build_map { + let bitmap_size = bit_util::ceil(single_batch.num_rows(), 8); + reservation.try_grow(bitmap_size)?; + metrics.build_mem_used.add(bitmap_size); + + let mut bitmap_buffer = BooleanBufferBuilder::new(single_batch.num_rows()); + bitmap_buffer.append_n(num_rows, false); + bitmap_buffer + } else { + BooleanBufferBuilder::new(0) + }; + + let buffered_data = BufferedSideData::new( + single_batch, + buffered_values, + Mutex::new(visited_indices_bitmap), + reservation, + ); + + Ok(buffered_data) +} + +pub(super) struct BufferedSideData { + pub(super) batch: RecordBatch, + values: ArrayRef, + pub(super) visited_indices_bitmap: SharedBitmapBuilder, + _reservation: MemoryReservation, +} + +impl BufferedSideData { + pub(super) fn new( + batch: RecordBatch, + values: ArrayRef, + visited_indices_bitmap: SharedBitmapBuilder, + reservation: MemoryReservation, + ) -> Self { + Self { + batch, + values, + visited_indices_bitmap, + _reservation: reservation, + } + } + + pub(super) fn batch(&self) -> &RecordBatch { + &self.batch + } + + pub(super) fn values(&self) -> &ArrayRef { + &self.values + } +} + +pub(super) enum BufferedSide { + /// Indicates that the buffered-side data has not been collected yet + Initial(BufferedSideInitialState), + /// Indicates that the buffered-side data has been collected + Ready(BufferedSideReadyState), +} + +impl BufferedSide { + /// Returns a mutable reference to the initial buffered-side state + pub(super) fn try_as_initial_mut(&mut self) -> Result<&mut BufferedSideInitialState> { + match self { + BufferedSide::Initial(state) => Ok(state), + _ => internal_err!("Expected build side in initial state"), + } + } + + pub(super) fn try_as_ready(&self) -> Result<&BufferedSideReadyState> { + match self { + BufferedSide::Ready(state) => Ok(state), + _ => { + internal_err!("Expected build side in ready state") + } + } + } + + /// Tries to extract `BufferedSideReadyState` from the `BufferedSide` enum. + /// Returns an error if the state is not `Ready`. + pub(super) fn try_as_ready_mut(&mut self) -> Result<&mut BufferedSideReadyState> { + match self { + BufferedSide::Ready(state) => Ok(state), + _ => internal_err!("Expected build side in ready state"), + } + } +} + +pub(super) struct BufferedSideInitialState { + pub(crate) buffered_fut: OnceFut<BufferedSideData>, +} + +pub(super) struct BufferedSideReadyState { + /// Collected buffered-side data + pub(super) buffered_data: Arc<BufferedSideData>, +} diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/mod.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/mod.rs new file mode 100644 index 000000000000..f66de0ddab43 --- /dev/null +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/mod.rs @@ -0,0 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership.
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub use exec::PiecewiseMergeJoinExec; + +mod classic_join; +mod exec; +mod utils; diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/utils.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/utils.rs new file mode 100644 index 000000000000..5bbb496322b5 --- /dev/null +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/utils.rs @@ -0,0 +1,61 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::JoinType; + +// Returns boolean for whether the join is a right existence join +pub(super) fn is_right_existence_join(join_type: JoinType) -> bool { + matches!( + join_type, + JoinType::RightAnti | JoinType::RightSemi | JoinType::RightMark + ) +} + +// Returns boolean for whether the join is an existence join +pub(super) fn is_existence_join(join_type: JoinType) -> bool { + matches!( + join_type, + JoinType::LeftAnti + | JoinType::RightAnti + | JoinType::LeftSemi + | JoinType::RightSemi + | JoinType::LeftMark + | JoinType::RightMark + ) +} + +// Returns boolean to check if the join type needs to record +// buffered side matches for classic joins +pub(super) fn need_produce_result_in_final(join_type: JoinType) -> bool { + matches!(join_type, JoinType::Full | JoinType::Left) +} + +// Returns boolean for whether or not we need to build the buffered side +// bitmap for marking matched rows on the buffered side. 
+pub(super) fn build_visited_indices_map(join_type: JoinType) -> bool { + matches!( + join_type, + JoinType::Full + | JoinType::Left + | JoinType::LeftAnti + | JoinType::RightAnti + | JoinType::LeftSemi + | JoinType::RightSemi + | JoinType::LeftMark + | JoinType::RightMark + ) +} diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs b/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs index d28a9bad17ec..27fd0c4c2121 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs @@ -34,7 +34,7 @@ use std::sync::Arc; use std::task::{Context, Poll}; use crate::joins::sort_merge_join::metrics::SortMergeJoinMetrics; -use crate::joins::utils::JoinFilter; +use crate::joins::utils::{compare_join_arrays, JoinFilter}; use crate::spill::spill_manager::SpillManager; use crate::{PhysicalExpr, RecordBatchStream, SendableRecordBatchStream}; @@ -1852,99 +1852,6 @@ fn join_arrays(batch: &RecordBatch, on_column: &[PhysicalExprRef]) -> Vec<ArrayRef> -/// Get comparison result of two rows of join arrays -fn compare_join_arrays( - left_arrays: &[ArrayRef], - left: usize, - right_arrays: &[ArrayRef], - right: usize, - sort_options: &[SortOptions], - null_equality: NullEquality, -) -> Result<Ordering> { - let mut res = Ordering::Equal; - for ((left_array, right_array), sort_options) in - left_arrays.iter().zip(right_arrays).zip(sort_options) - { - macro_rules! compare_value { - ($T:ty) => {{ - let left_array = left_array.as_any().downcast_ref::<$T>().unwrap(); - let right_array = right_array.as_any().downcast_ref::<$T>().unwrap(); - match (left_array.is_null(left), right_array.is_null(right)) { - (false, false) => { - let left_value = &left_array.value(left); - let right_value = &right_array.value(right); - res = left_value.partial_cmp(right_value).unwrap(); - if sort_options.descending { - res = res.reverse(); - } - } - (true, false) => { - res = if sort_options.nulls_first { - Ordering::Less - } else { - Ordering::Greater - }; - } - (false, true) => { - res = if sort_options.nulls_first { - Ordering::Greater - } else { - Ordering::Less - }; - } - _ => { - res = match null_equality { - NullEquality::NullEqualsNothing => Ordering::Less, - NullEquality::NullEqualsNull => Ordering::Equal, - }; - } - } - }}; - } - - match left_array.data_type() { - DataType::Null => {} - DataType::Boolean => compare_value!(BooleanArray), - DataType::Int8 => compare_value!(Int8Array), - DataType::Int16 => compare_value!(Int16Array), - DataType::Int32 => compare_value!(Int32Array), - DataType::Int64 => compare_value!(Int64Array), - DataType::UInt8 => compare_value!(UInt8Array), - DataType::UInt16 => compare_value!(UInt16Array), - DataType::UInt32 => compare_value!(UInt32Array), - DataType::UInt64 => compare_value!(UInt64Array), - DataType::Float32 => compare_value!(Float32Array), - DataType::Float64 => compare_value!(Float64Array), - DataType::Utf8 => compare_value!(StringArray), - DataType::Utf8View => compare_value!(StringViewArray), - DataType::LargeUtf8 => compare_value!(LargeStringArray), - DataType::Binary => compare_value!(BinaryArray), - DataType::BinaryView => compare_value!(BinaryViewArray), - DataType::FixedSizeBinary(_) => compare_value!(FixedSizeBinaryArray), - DataType::LargeBinary => compare_value!(LargeBinaryArray), - DataType::Decimal128(..)
=> compare_value!(Decimal128Array), - DataType::Timestamp(time_unit, None) => match time_unit { - TimeUnit::Second => compare_value!(TimestampSecondArray), - TimeUnit::Millisecond => compare_value!(TimestampMillisecondArray), - TimeUnit::Microsecond => compare_value!(TimestampMicrosecondArray), - TimeUnit::Nanosecond => compare_value!(TimestampNanosecondArray), - }, - DataType::Date32 => compare_value!(Date32Array), - DataType::Date64 => compare_value!(Date64Array), - dt => { - return not_impl_err!( - "Unsupported data type in sort merge join comparator: {}", - dt - ); - } - } - if !res.is_eq() { - break; - } - } - Ok(res) -} - /// A faster version of compare_join_arrays() that only output whether /// the given two rows are equal fn is_join_arrays_equal( diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index d392650f88dd..b41a3e0514cf 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -17,7 +17,7 @@ //! Join related functionality used both on logical and physical plans -use std::cmp::min; +use std::cmp::{min, Ordering}; use std::collections::HashSet; use std::fmt::{self, Debug}; use std::future::Future; @@ -43,7 +43,13 @@ use arrow::array::{ BooleanBufferBuilder, NativeAdapter, PrimitiveArray, RecordBatch, RecordBatchOptions, UInt32Array, UInt32Builder, UInt64Array, }; -use arrow::array::{ArrayRef, BooleanArray}; +use arrow::array::{ + ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array, + Decimal128Array, FixedSizeBinaryArray, Float32Array, Float64Array, Int16Array, + Int32Array, Int64Array, Int8Array, LargeBinaryArray, LargeStringArray, StringArray, + StringViewArray, TimestampMicrosecondArray, TimestampMillisecondArray, + TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt8Array, +}; use arrow::buffer::{BooleanBuffer, NullBuffer}; use arrow::compute::kernels::cmp::eq; use arrow::compute::{self, and, take, FilterBuilder}; @@ -51,12 +57,13 @@ use arrow::datatypes::{ ArrowNativeType, Field, Schema, SchemaBuilder, UInt32Type, UInt64Type, }; use arrow_ord::cmp::not_distinct; -use arrow_schema::ArrowError; +use arrow_schema::{ArrowError, DataType, SortOptions, TimeUnit}; use datafusion_common::cast::as_boolean_array; use datafusion_common::hash_utils::create_hashes; use datafusion_common::stats::Precision; use datafusion_common::{ - plan_err, DataFusionError, JoinSide, JoinType, NullEquality, Result, SharedResult, + not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, NullEquality, Result, + SharedResult, }; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::Operator; @@ -284,7 +291,7 @@ pub fn build_join_schema( JoinType::LeftSemi | JoinType::LeftAnti => left_fields().unzip(), JoinType::LeftMark => { let right_field = once(( - Field::new("mark", arrow::datatypes::DataType::Boolean, false), + Field::new("mark", DataType::Boolean, false), ColumnIndex { index: 0, side: JoinSide::None, @@ -295,7 +302,7 @@ pub fn build_join_schema( JoinType::RightSemi | JoinType::RightAnti => right_fields().unzip(), JoinType::RightMark => { let left_field = once(( - Field::new("mark", arrow_schema::DataType::Boolean, false), + Field::new("mark", DataType::Boolean, false), ColumnIndex { index: 0, side: JoinSide::None, @@ -817,9 +824,10 @@ pub(crate) fn need_produce_result_in_final(join_type: JoinType) -> bool { pub(crate) fn get_final_indices_from_shared_bitmap( shared_bitmap: &SharedBitmapBuilder, join_type: JoinType, + piecewise: 
bool, ) -> (UInt64Array, UInt32Array) { let bitmap = shared_bitmap.lock(); - get_final_indices_from_bit_map(&bitmap, join_type) + get_final_indices_from_bit_map(&bitmap, join_type, piecewise) } /// In the end of join execution, need to use bit map of the matched @@ -834,16 +842,22 @@ pub(crate) fn get_final_indices_from_shared_bitmap( pub(crate) fn get_final_indices_from_bit_map( left_bit_map: &BooleanBufferBuilder, join_type: JoinType, + // Flag for whether this is being called from the `PiecewiseMergeJoin`, + // because there the bitmap can correspond to both left and right `JoinType`s + piecewise: bool, ) -> (UInt64Array, UInt32Array) { let left_size = left_bit_map.len(); - if join_type == JoinType::LeftMark { + if join_type == JoinType::LeftMark || (join_type == JoinType::RightMark && piecewise) + { let left_indices = (0..left_size as u64).collect::<UInt64Array>(); let right_indices = (0..left_size) .map(|idx| left_bit_map.get_bit(idx).then_some(0)) .collect::<UInt32Array>(); return (left_indices, right_indices); } - let left_indices = if join_type == JoinType::LeftSemi { + let left_indices = if join_type == JoinType::LeftSemi + || (join_type == JoinType::RightSemi && piecewise) + { (0..left_size) .filter_map(|idx| (left_bit_map.get_bit(idx)).then_some(idx as u64)) .collect::<UInt64Array>() @@ -1753,6 +1767,99 @@ fn eq_dyn_null( } } +/// Get comparison result of two rows of join arrays +pub fn compare_join_arrays( + left_arrays: &[ArrayRef], + left: usize, + right_arrays: &[ArrayRef], + right: usize, + sort_options: &[SortOptions], + null_equality: NullEquality, +) -> Result<Ordering> { + let mut res = Ordering::Equal; + for ((left_array, right_array), sort_options) in + left_arrays.iter().zip(right_arrays).zip(sort_options) + { + macro_rules! compare_value { + ($T:ty) => {{ + let left_array = left_array.as_any().downcast_ref::<$T>().unwrap(); + let right_array = right_array.as_any().downcast_ref::<$T>().unwrap(); + match (left_array.is_null(left), right_array.is_null(right)) { + (false, false) => { + let left_value = &left_array.value(left); + let right_value = &right_array.value(right); + res = left_value.partial_cmp(right_value).unwrap(); + if sort_options.descending { + res = res.reverse(); + } + } + (true, false) => { + res = if sort_options.nulls_first { + Ordering::Less + } else { + Ordering::Greater + }; + } + (false, true) => { + res = if sort_options.nulls_first { + Ordering::Greater + } else { + Ordering::Less + }; + } + _ => { + res = match null_equality { + NullEquality::NullEqualsNothing => Ordering::Less, + NullEquality::NullEqualsNull => Ordering::Equal, + }; + } + } + }}; + } + + match left_array.data_type() { + DataType::Null => {} + DataType::Boolean => compare_value!(BooleanArray), + DataType::Int8 => compare_value!(Int8Array), + DataType::Int16 => compare_value!(Int16Array), + DataType::Int32 => compare_value!(Int32Array), + DataType::Int64 => compare_value!(Int64Array), + DataType::UInt8 => compare_value!(UInt8Array), + DataType::UInt16 => compare_value!(UInt16Array), + DataType::UInt32 => compare_value!(UInt32Array), + DataType::UInt64 => compare_value!(UInt64Array), + DataType::Float32 => compare_value!(Float32Array), + DataType::Float64 => compare_value!(Float64Array), + DataType::Binary => compare_value!(BinaryArray), + DataType::BinaryView => compare_value!(BinaryViewArray), + DataType::FixedSizeBinary(_) => compare_value!(FixedSizeBinaryArray), + DataType::LargeBinary => compare_value!(LargeBinaryArray), + DataType::Utf8 => compare_value!(StringArray), + DataType::Utf8View => compare_value!(StringViewArray), +
DataType::LargeUtf8 => compare_value!(LargeStringArray), + DataType::Decimal128(..) => compare_value!(Decimal128Array), + DataType::Timestamp(time_unit, None) => match time_unit { + TimeUnit::Second => compare_value!(TimestampSecondArray), + TimeUnit::Millisecond => compare_value!(TimestampMillisecondArray), + TimeUnit::Microsecond => compare_value!(TimestampMicrosecondArray), + TimeUnit::Nanosecond => compare_value!(TimestampNanosecondArray), + }, + DataType::Date32 => compare_value!(Date32Array), + DataType::Date64 => compare_value!(Date64Array), + dt => { + return not_impl_err!( + "Unsupported data type in sort merge join comparator: {}", + dt + ); + } + } + if !res.is_eq() { + break; + } + } + Ok(res) +} + #[cfg(test)] mod tests { use std::collections::HashMap; diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index ad21bdac6d2d..c9bbf3cf5734 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -4148,10 +4148,11 @@ logical_plan 03)----TableScan: left_table projection=[a, b, c] 04)----TableScan: right_table projection=[x, y, z] physical_plan -01)SortExec: expr=[x@3 ASC NULLS LAST], preserve_partitioning=[false] -02)--NestedLoopJoinExec: join_type=Inner, filter=a@0 < x@1 -03)----DataSourceExec: partitions=1, partition_sizes=[0] -04)----DataSourceExec: partitions=1, partition_sizes=[0] +01)SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--PiecewiseMergeJoin: operator=Lt, join_type=Inner, on=(a < x) +03)----SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] +04)------DataSourceExec: partitions=1, partition_sizes=[0] +05)----DataSourceExec: partitions=1, partition_sizes=[0] query TT EXPLAIN SELECT * FROM left_table JOIN right_table ON left_table.a<right_table.x; +query TT +EXPLAIN SELECT * FROM t0 FULL JOIN t1 ON t0.c2 >= t1.c2 LIMIT 20; +---- +logical_plan +01)Limit: skip=0, fetch=20 +02)--Full Join: Filter: t0.c2 >= t1.c2 +03)----TableScan: t0 projection=[c1, c2] +04)----TableScan: t1 projection=[c1, c2, c3] +physical_plan +01)GlobalLimitExec: skip=0, fetch=20 +02)--PiecewiseMergeJoin: operator=GtEq, join_type=Full, on=(c2 >= c2) +03)----SortExec: expr=[c2@1 ASC NULLS LAST], preserve_partitioning=[false] +04)------DataSourceExec: partitions=1, partition_sizes=[2] +05)----DataSourceExec: partitions=1, partition_sizes=[2] query IIIIB rowsort -- Note: using LIMIT value higher than cardinality before LIMIT to avoid query non-determinism SELECT * FROM t0 FULL JOIN t1 ON t0.c2 >= t1.c2 LIMIT 20; @@ -4212,6 +4237,9 @@ SELECT * FROM t0 FULL JOIN t1 ON t0.c2 >= t1.c2 LIMIT 20; 4 4 3 3 false 4 4 3 3 true +statement ok +set datafusion.execution.batch_size = 3; query IIIIB rowsort -- Note: using LIMIT value higher than cardinality before LIMIT to avoid query non-determinism SELECT * FROM t0 FULL JOIN t1 ON t0.c1 = t1.c1 AND t0.c2 >= t1.c2 LIMIT 20; @@ -4238,7 +4266,7 @@ physical_plan 03)----DataSourceExec: partitions=1, partition_sizes=[2] 04)----DataSourceExec: partitions=1, partition_sizes=[2] -## Test join.on.is_empty() && join.filter.is_some() +## Test join.on.is_empty() && join.filter.is_some() -> single filter now a PWMJ query TT EXPLAIN SELECT * FROM t0 FULL JOIN t1 ON t0.c2 >= t1.c2 LIMIT 2; ---- @@ -4249,9 +4277,10 @@ logical_plan 04)----TableScan: t1 projection=[c1, c2, c3] physical_plan 01)GlobalLimitExec: skip=0, fetch=2 -02)--NestedLoopJoinExec: join_type=Full, filter=c2@0 >= c2@1 -03)----DataSourceExec: partitions=1, partition_sizes=[2] -04)----DataSourceExec: partitions=1, partition_sizes=[2]
+02)--PiecewiseMergeJoin: operator=GtEq, join_type=Full, on=(c2 >= c2) +03)----SortExec: expr=[c2@1 ASC NULLS LAST], preserve_partitioning=[false] +04)------DataSourceExec: partitions=1, partition_sizes=[2] +05)----DataSourceExec: partitions=1, partition_sizes=[2] ## Test !join.on.is_empty() && join.filter.is_some() query TT @@ -5161,6 +5190,44 @@ WHERE k1 < 0 ---- +# PiecewiseMergeJoin Test +statement ok +set datafusion.execution.batch_size = 8192; + +# TODO: partitioned PWMJ execution +statement ok +set datafusion.execution.target_partitions = 1; + +query II +SELECT join_t1.t1_id, join_t2.t2_id +FROM join_t1 +INNER JOIN join_t2 ON join_t1.t1_id > join_t2.t2_id +WHERE join_t1.t1_id > 10 AND join_t2.t2_int > 1 +ORDER BY 1 +---- +22 11 +33 11 +44 11 + +query TT +EXPLAIN +SELECT join_t1.t1_id, join_t2.t2_id +FROM join_t1 +INNER JOIN join_t2 ON join_t1.t1_id > join_t2.t2_id +WHERE join_t1.t1_id > 10 AND join_t2.t2_int > 1 +ORDER BY 1 +---- +physical_plan +01)SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--PiecewiseMergeJoin: operator=Gt, join_type=Inner, on=(t1_id > t2_id) +03)----SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[false] +04)------CoalesceBatchesExec: target_batch_size=8192 +05)--------FilterExec: t1_id@0 > 10 +06)----------DataSourceExec: partitions=1, partition_sizes=[1] +07)----CoalesceBatchesExec: target_batch_size=8192 +08)------FilterExec: t2_int@1 > 1, projection=[t2_id@0] +09)--------DataSourceExec: partitions=1, partition_sizes=[1] + statement ok DROP TABLE t1; diff --git a/datafusion/sqllogictest/test_files/pwmj.slt b/datafusion/sqllogictest/test_files/pwmj.slt new file mode 100644 index 000000000000..ee9622e6bb1b --- /dev/null +++ b/datafusion/sqllogictest/test_files/pwmj.slt @@ -0,0 +1,215 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
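The `pwmj.slt` tests that follow pin `target_partitions` to 1, since the planner currently selects `PiecewiseMergeJoinExec` only for single-partition plans. For reference, a hedged sketch of the equivalent setup from Rust, assuming the standard `SessionConfig`/`SessionContext` APIs and mirroring the tables used below:

```rust
use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    // Single partition, matching `set datafusion.execution.target_partitions = 1;`
    let config = SessionConfig::new().with_target_partitions(1);
    let ctx = SessionContext::new_with_config(config);

    // Same empty in-memory tables the slt tests create
    ctx.sql("CREATE TABLE join_t1 (t1_id INT)").await?;
    ctx.sql("CREATE TABLE join_t2 (t2_id INT, t2_name TEXT, t2_int INT)")
        .await?;

    // With a single range predicate and one partition, the plan should show
    // `PiecewiseMergeJoin` rather than `NestedLoopJoinExec`.
    ctx.sql(
        "EXPLAIN SELECT t1.t1_id, t2.t2_id \
         FROM join_t1 t1 JOIN join_t2 t2 ON t1.t1_id > t2.t2_id",
    )
    .await?
    .show()
    .await?;
    Ok(())
}
```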
+ + +statement ok +set datafusion.execution.target_partitions = 1; + +statement ok +CREATE TABLE join_t1 (t1_id INT); + +statement ok +CREATE TABLE join_t2 (t2_id INT, t2_name TEXT, t2_int INT); + +statement ok +INSERT INTO join_t1 VALUES (11), (22), (33), (44); + +statement ok +INSERT INTO join_t2 VALUES + (11, 'z', 3), + (22, 'y', 1), + (44, 'x', 3), + (55, 'w', 3); + +query II +SELECT t1.t1_id, t2.t2_id +FROM join_t1 t1 +JOIN join_t2 t2 + ON t1.t1_id > t2.t2_id +WHERE t1.t1_id > 10 + AND t2.t2_int > 1 +ORDER BY 1; +---- +22 11 +33 11 +44 11 + +query TT +EXPLAIN +SELECT t1.t1_id, t2.t2_id +FROM join_t1 t1 +JOIN join_t2 t2 + ON t1.t1_id > t2.t2_id +WHERE t1.t1_id > 10 + AND t2.t2_int > 1 +ORDER BY 1; +---- +logical_plan +01)Sort: t1.t1_id ASC NULLS LAST +02)--Inner Join: Filter: t1.t1_id > t2.t2_id +03)----SubqueryAlias: t1 +04)------Filter: join_t1.t1_id > Int32(10) +05)--------TableScan: join_t1 projection=[t1_id] +06)----SubqueryAlias: t2 +07)------Projection: join_t2.t2_id +08)--------Filter: join_t2.t2_int > Int32(1) +09)----------TableScan: join_t2 projection=[t2_id, t2_int] +physical_plan +01)SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--PiecewiseMergeJoin: operator=Gt, join_type=Inner, on=(t1_id > t2_id) +03)----SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[false] +04)------CoalesceBatchesExec: target_batch_size=8192 +05)--------FilterExec: t1_id@0 > 10 +06)----------DataSourceExec: partitions=1, partition_sizes=[1] +07)----CoalesceBatchesExec: target_batch_size=8192 +08)------FilterExec: t2_int@1 > 1, projection=[t2_id@0] +09)--------DataSourceExec: partitions=1, partition_sizes=[1] + +query II +SELECT t1.t1_id, t2.t2_id +FROM join_t1 t1 +JOIN join_t2 t2 + ON t1.t1_id >= t2.t2_id +WHERE t1.t1_id >= 22 + AND t2.t2_int = 3 +ORDER BY 1,2; +---- +22 11 +33 11 +44 11 +44 44 + +query TT +EXPLAIN +SELECT t1.t1_id, t2.t2_id +FROM join_t1 t1 +JOIN join_t2 t2 + ON t1.t1_id >= t2.t2_id +WHERE t1.t1_id >= 22 + AND t2.t2_int = 3 +ORDER BY 1,2; +---- +logical_plan +01)Sort: t1.t1_id ASC NULLS LAST, t2.t2_id ASC NULLS LAST +02)--Inner Join: Filter: t1.t1_id >= t2.t2_id +03)----SubqueryAlias: t1 +04)------Filter: join_t1.t1_id >= Int32(22) +05)--------TableScan: join_t1 projection=[t1_id] +06)----SubqueryAlias: t2 +07)------Projection: join_t2.t2_id +08)--------Filter: join_t2.t2_int = Int32(3) +09)----------TableScan: join_t2 projection=[t2_id, t2_int] +physical_plan +01)SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--PiecewiseMergeJoin: operator=GtEq, join_type=Inner, on=(t1_id >= t2_id) +03)----SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[false] +04)------CoalesceBatchesExec: target_batch_size=8192 +05)--------FilterExec: t1_id@0 >= 22 +06)----------DataSourceExec: partitions=1, partition_sizes=[1] +07)----CoalesceBatchesExec: target_batch_size=8192 +08)------FilterExec: t2_int@1 = 3, projection=[t2_id@0] +09)--------DataSourceExec: partitions=1, partition_sizes=[1] + +query II +SELECT t1.t1_id, t2.t2_id +FROM join_t1 t1 +JOIN join_t2 t2 + ON t1.t1_id < t2.t2_id +WHERE t2.t2_int >= 3 +ORDER BY 1,2; +---- +11 44 +11 55 +22 44 +22 55 +33 44 +33 55 +44 55 + +query TT +EXPLAIN +SELECT t1.t1_id, t2.t2_id +FROM join_t1 t1 +JOIN join_t2 t2 + ON t1.t1_id < t2.t2_id +WHERE t2.t2_int >= 3 +ORDER BY 1,2; +---- +logical_plan +01)Sort: t1.t1_id ASC NULLS LAST, t2.t2_id ASC NULLS LAST +02)--Inner Join: Filter: t1.t1_id < t2.t2_id +03)----SubqueryAlias: t1 +04)------TableScan: join_t1 projection=[t1_id]
+05)----SubqueryAlias: t2 +06)------Projection: join_t2.t2_id +07)--------Filter: join_t2.t2_int >= Int32(3) +08)----------TableScan: join_t2 projection=[t2_id, t2_int] +physical_plan +01)SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--PiecewiseMergeJoin: operator=Lt, join_type=Inner, on=(t1_id < t2_id) +03)----SortExec: expr=[t1_id@0 DESC NULLS LAST], preserve_partitioning=[false] +04)------DataSourceExec: partitions=1, partition_sizes=[1] +05)----CoalesceBatchesExec: target_batch_size=8192 +06)------FilterExec: t2_int@1 >= 3, projection=[t2_id@0] +07)--------DataSourceExec: partitions=1, partition_sizes=[1] + + +query II +SELECT t1.t1_id, t2.t2_id +FROM join_t1 t1 +JOIN join_t2 t2 + ON t1.t1_id <= t2.t2_id +WHERE t1.t1_id IN (11, 44) + AND t2.t2_name <> 'y' +ORDER BY 1,2; +---- +11 11 +11 44 +11 55 +44 44 +44 55 + +query TT +EXPLAIN +SELECT t1.t1_id, t2.t2_id +FROM join_t1 t1 +JOIN join_t2 t2 + ON t1.t1_id <= t2.t2_id +WHERE t1.t1_id IN (11, 44) + AND t2.t2_name <> 'y' +ORDER BY 1,2; +---- +logical_plan +01)Sort: t1.t1_id ASC NULLS LAST, t2.t2_id ASC NULLS LAST +02)--Inner Join: Filter: t1.t1_id <= t2.t2_id +03)----SubqueryAlias: t1 +04)------Filter: join_t1.t1_id = Int32(11) OR join_t1.t1_id = Int32(44) +05)--------TableScan: join_t1 projection=[t1_id] +06)----SubqueryAlias: t2 +07)------Projection: join_t2.t2_id +08)--------Filter: join_t2.t2_name != Utf8View("y") +09)----------TableScan: join_t2 projection=[t2_id, t2_name] +physical_plan +01)SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--PiecewiseMergeJoin: operator=LtEq, join_type=Inner, on=(t1_id <= t2_id) +03)----SortExec: expr=[t1_id@0 DESC NULLS LAST], preserve_partitioning=[false] +04)------CoalesceBatchesExec: target_batch_size=8192 +05)--------FilterExec: t1_id@0 = 11 OR t1_id@0 = 44 +06)----------DataSourceExec: partitions=1, partition_sizes=[1] +07)----CoalesceBatchesExec: target_batch_size=8192 +08)------FilterExec: t2_name@1 != y, projection=[t2_id@0] +09)--------DataSourceExec: partitions=1, partition_sizes=[1]
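As a closing recap of the sort-direction rule in `try_new`: `<` and `<=` sort the inputs descending, `>` and `>=` ascending, with the options flipped when the inputs will later be swapped for a right existence join. The helper below restates that mapping; it is an illustrative sketch, not part of the PR:

```rust
use arrow_schema::SortOptions;
use datafusion_expr::Operator;

/// Mirrors the `SortOptions` selection in `PiecewiseMergeJoinExec::try_new`.
/// Returns `None` for non-range operators, which the constructor rejects.
fn pwmj_sort_options(op: Operator, swapped_right_existence: bool) -> Option<SortOptions> {
    let descending = match op {
        Operator::Lt | Operator::LtEq => !swapped_right_existence,
        Operator::Gt | Operator::GtEq => swapped_right_existence,
        _ => return None,
    };
    // `nulls_first` is always false, as in the constructor
    Some(SortOptions::new(descending, false))
}

fn main() {
    // Matches the slt plans: Lt sorts DESC, Gt sorts ASC.
    assert!(pwmj_sort_options(Operator::Lt, false).unwrap().descending);
    assert!(!pwmj_sort_options(Operator::Gt, false).unwrap().descending);
    assert!(pwmj_sort_options(Operator::Eq, false).is_none());
}
```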