Skip to content
137 changes: 134 additions & 3 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,11 @@ use datafusion_expr::expr::{
};
use datafusion_expr::expr_rewriter::unnormalize_cols;
use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary;
use datafusion_expr::utils::split_conjunction;
use datafusion_expr::{
Analyze, DescribeTable, DmlStatement, Explain, ExplainFormat, Extension, FetchType,
Filter, JoinType, RecursiveQuery, SkipType, StringifiedPlan, WindowFrame,
WindowFrameBound, WriteOp,
Analyze, BinaryExpr, DescribeTable, DmlStatement, Explain, ExplainFormat, Extension,
FetchType, Filter, JoinType, Operator, RecursiveQuery, SkipType, StringifiedPlan,
WindowFrame, WindowFrameBound, WriteOp,
};
use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};
use datafusion_physical_expr::expressions::{Column, Literal};
Expand All @@ -90,6 +91,7 @@ use datafusion_physical_expr::{
use datafusion_physical_optimizer::PhysicalOptimizerRule;
use datafusion_physical_plan::empty::EmptyExec;
use datafusion_physical_plan::execution_plan::InvariantLevel;
use datafusion_physical_plan::joins::PiecewiseMergeJoinExec;
use datafusion_physical_plan::placeholder_row::PlaceholderRowExec;
use datafusion_physical_plan::recursive_query::RecursiveQueryExec;
use datafusion_physical_plan::unnest::ListUnnest;
Expand Down Expand Up @@ -1089,8 +1091,42 @@ impl DefaultPhysicalPlanner {
})
.collect::<Result<join_utils::JoinOn>>()?;

// TODO: `num_range_filters` can be used later on for ASOF joins (`num_range_filters > 1`)
let mut num_range_filters = 0;
let mut range_filters: Vec<Expr> = Vec::new();
let mut total_filters = 0;

let join_filter = match filter {
Some(expr) => {
let split_expr = split_conjunction(expr);
for expr in split_expr.iter() {
match *expr {
Expr::BinaryExpr(BinaryExpr {
left: _,
right: _,
op,
}) => {
if matches!(
op,
Operator::Lt
| Operator::LtEq
| Operator::Gt
| Operator::GtEq
) {
range_filters.push((**expr).clone());
num_range_filters += 1;
}
total_filters += 1;
}
// TODO: Want to deal with `Expr::Between` for IEJoins, it counts as two range predicates
// which is why it is not dealt with in PWMJ
// Expr::Between(_) => {},
_ => {
total_filters += 1;
}
}
}

// Extract the columns referenced by the filter expression and save them in a HashSet
let cols = expr.column_refs();

Expand Down Expand Up @@ -1146,6 +1182,7 @@ impl DefaultPhysicalPlanner {
)?;
let filter_schema =
Schema::new_with_metadata(filter_fields, metadata);

let filter_expr = create_physical_expr(
expr,
&filter_df_schema,
Expand All @@ -1168,10 +1205,104 @@ impl DefaultPhysicalPlanner {
let prefer_hash_join =
session_state.config_options().optimizer.prefer_hash_join;

let cfg = session_state.config();

let can_run_single =
cfg.target_partitions() == 1 || !cfg.repartition_joins();

// TODO: Allow PWMJ to deal with residual equijoin conditions
let join: Arc<dyn ExecutionPlan> = if join_on.is_empty() {
if join_filter.is_none() && matches!(join_type, JoinType::Inner) {
// use a cross join if there are no join conditions and no join filter is set
Arc::new(CrossJoinExec::new(physical_left, physical_right))
} else if num_range_filters == 1
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like to refactor this in a follow-up pull request. It is purely a refactor and should be simple to do; I just wanted to get this version in first.

&& total_filters == 1
&& can_run_single
&& !matches!(
join_type,
JoinType::LeftSemi
| JoinType::RightSemi
| JoinType::LeftAnti
| JoinType::RightAnti
| JoinType::LeftMark
| JoinType::RightMark
)
{
let Expr::BinaryExpr(be) = &range_filters[0] else {
return plan_err!(
"Unsupported expression for PWMJ: Expected `Expr::BinaryExpr`"
);
};

let mut op = be.op;
if !matches!(
op,
Operator::Lt | Operator::LtEq | Operator::Gt | Operator::GtEq
) {
return plan_err!(
"Unsupported operator for PWMJ: {:?}. Expected one of <, <=, >, >=",
op
);
}

/// Returns the comparison operator that keeps a predicate equivalent once
/// its two operands trade places (`a < b` is the same predicate as `b > a`).
///
/// Any operator other than `<`, `<=`, `>`, `>=` is handed back unchanged.
fn reverse_ineq(op: Operator) -> Operator {
    match op {
        Operator::Gt => Operator::Lt,
        Operator::GtEq => Operator::LtEq,
        Operator::Lt => Operator::Gt,
        Operator::LtEq => Operator::GtEq,
        // Not an inequality: nothing to mirror.
        unchanged => unchanged,
    }
}

// Classifies which join input an operand of the range predicate belongs to.
//
// Returns `Ok("left")` when every column referenced by `e` resolves in
// `left_df_schema` but not every one resolves in `right_df_schema`, and
// `Ok("right")` for the mirrored case. Anything else is reported as a
// planning error rather than a panic, because it depends on the user's
// query and is therefore reachable:
//   * both checks are true when `e` references no columns at all
//     (`Iterator::all` is vacuously true on an empty set), or when every
//     column resolves in both schemas;
//   * both checks are false when `e` has at least one column that does
//     not resolve in each schema, e.g. it mixes columns from both inputs.
let side_of = |e: &Expr| -> Result<&'static str> {
    let cols = e.column_refs();
    let in_left = cols
        .iter()
        .all(|c| left_df_schema.index_of_column(c).is_ok());
    let in_right = cols
        .iter()
        .all(|c| right_df_schema.index_of_column(c).is_ok());
    match (in_left, in_right) {
        (true, false) => Ok("left"),
        (false, true) => Ok("right"),
        _ => plan_err!(
            "Unsupported expression for PWMJ: {}. Expected each side of the predicate to reference columns from exactly one join input",
            e
        ),
    }
};

let mut lhs_logical = &be.left;
let mut rhs_logical = &be.right;

let left_side = side_of(lhs_logical)?;
let right_side = side_of(rhs_logical)?;
if left_side == "right" && right_side == "left" {
std::mem::swap(&mut lhs_logical, &mut rhs_logical);
op = reverse_ineq(op);
} else if !(left_side == "left" && right_side == "right") {
return plan_err!(
"Unsupported operator for PWMJ: {:?}. Expected one of <, <=, >, >=",
op
);
}

let on_left = create_physical_expr(
lhs_logical,
left_df_schema,
session_state.execution_props(),
)?;
let on_right = create_physical_expr(
rhs_logical,
right_df_schema,
session_state.execution_props(),
)?;

Arc::new(PiecewiseMergeJoinExec::try_new(
physical_left,
physical_right,
(on_left, on_right),
op,
*join_type,
)?)
} else {
// there is no equal join condition, use the nested loop join
Arc::new(NestedLoopJoinExec::try_new(
Expand Down
1 change: 1 addition & 0 deletions datafusion/physical-plan/src/joins/hash_join/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,7 @@ impl HashJoinStream {
let (left_side, right_side) = get_final_indices_from_shared_bitmap(
build_side.left_data.visited_indices_bitmap(),
self.join_type,
true,
);
let empty_right_batch = RecordBatch::new_empty(self.right.schema());
// use the left and right indices to produce the batch result
Expand Down
2 changes: 2 additions & 0 deletions datafusion/physical-plan/src/joins/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@ pub use hash_join::HashJoinExec;
pub use nested_loop_join::NestedLoopJoinExec;
use parking_lot::Mutex;
// Note: SortMergeJoin is not used in plans yet
pub use piecewise_merge_join::PiecewiseMergeJoinExec;
pub use sort_merge_join::SortMergeJoinExec;
pub use symmetric_hash_join::SymmetricHashJoinExec;
mod cross_join;
mod hash_join;
mod nested_loop_join;
mod piecewise_merge_join;
mod sort_merge_join;
mod stream_join_utils;
mod symmetric_hash_join;
Expand Down
Loading