diff --git a/native/core/src/execution/joins/exec.rs b/native/core/src/execution/joins/exec.rs new file mode 100644 index 0000000000..f88bf50677 --- /dev/null +++ b/native/core/src/execution/joins/exec.rs @@ -0,0 +1,431 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Execution plan for semi/anti sort-merge joins. +//! +//! Ported from Apache DataFusion. + +use std::any::Any; +use std::fmt::Formatter; +use std::sync::Arc; + +use super::stream::SemiAntiSortMergeJoinStream; +use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; +use datafusion::physical_plan::expressions::PhysicalSortExpr; +use datafusion::physical_plan::joins::utils::{ + build_join_schema, check_join_is_valid, JoinFilter, JoinOn, JoinOnRef, +}; +use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; +use datafusion::physical_plan::{ + DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties, + Partitioning, PlanProperties, SendableRecordBatchStream, Statistics, +}; + +use arrow::compute::SortOptions; +use arrow::datatypes::SchemaRef; +use datafusion::common::{ + assert_eq_or_internal_err, plan_err, JoinSide, JoinType, NullEquality, Result, +}; +use datafusion::execution::TaskContext; +use datafusion::physical_expr::equivalence::join_equivalence_properties; +use datafusion::physical_expr_common::physical_expr::fmt_sql; +use datafusion::physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements}; + +/// Sort-merge join operator specialized for semi/anti joins. +/// +/// # Motivation +/// +/// The general-purpose `SortMergeJoinExec` handles semi/anti joins by +/// materializing `(outer, inner)` row pairs, applying a filter, then using a +/// "corrected filter mask" to deduplicate. Semi/anti joins only need a boolean +/// per outer row (does a match exist?), not pairs. The pair-based approach +/// incurs unnecessary memory allocation and intermediate batches. +/// +/// This operator instead tracks matches with a per-outer-batch bitset, +/// avoiding all pair materialization. +/// +/// Supports: `LeftSemi`, `LeftAnti`, `RightSemi`, `RightAnti`. +/// +/// # Algorithm +/// +/// Both inputs must be sorted by the join keys. The stream performs a merge +/// scan across the two sorted inputs. At each step: +/// +/// - **outer < inner**: Skip the outer key group (no match exists). +/// - **outer > inner**: Skip the inner key group. +/// - **outer == inner**: Process the match. +/// +/// **Without filter**: All outer rows in the key group are marked as matched. +/// +/// **With filter**: The inner key group is buffered. For each buffered inner +/// row, the filter is evaluated against the outer key group as a batch. +/// Results are OR'd into the matched bitset. Short-circuits when all outer +/// rows in the group are matched. +/// +/// On emit: +/// Semi -> filter_record_batch(outer_batch, &matched) +/// Anti -> filter_record_batch(outer_batch, &NOT(matched)) +#[derive(Debug, Clone)] +pub struct SemiAntiSortMergeJoinExec { + pub left: Arc, + pub right: Arc, + pub on: JoinOn, + pub filter: Option, + pub join_type: JoinType, + schema: SchemaRef, + metrics: ExecutionPlanMetricsSet, + left_sort_exprs: LexOrdering, + right_sort_exprs: LexOrdering, + pub sort_options: Vec, + pub null_equality: NullEquality, + cache: PlanProperties, +} + +impl SemiAntiSortMergeJoinExec { + pub fn try_new( + left: Arc, + right: Arc, + on: JoinOn, + filter: Option, + join_type: JoinType, + sort_options: Vec, + null_equality: NullEquality, + ) -> Result { + if !matches!( + join_type, + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::RightSemi | JoinType::RightAnti + ) { + return plan_err!( + "SemiAntiSortMergeJoinExec only supports semi/anti joins, got {:?}", + join_type + ); + } + + let left_schema = left.schema(); + let right_schema = right.schema(); + check_join_is_valid(&left_schema, &right_schema, &on)?; + + if sort_options.len() != on.len() { + return plan_err!( + "Expected number of sort options: {}, actual: {}", + on.len(), + sort_options.len() + ); + } + + let (left_sort_exprs, right_sort_exprs): (Vec<_>, Vec<_>) = on + .iter() + .zip(sort_options.iter()) + .map(|((l, r), sort_op)| { + let left = PhysicalSortExpr { + expr: Arc::clone(l), + options: *sort_op, + }; + let right = PhysicalSortExpr { + expr: Arc::clone(r), + options: *sort_op, + }; + (left, right) + }) + .unzip(); + + let Some(left_sort_exprs) = LexOrdering::new(left_sort_exprs) else { + return plan_err!( + "SemiAntiSortMergeJoinExec requires valid sort expressions for its left side" + ); + }; + let Some(right_sort_exprs) = LexOrdering::new(right_sort_exprs) else { + return plan_err!( + "SemiAntiSortMergeJoinExec requires valid sort expressions for its right side" + ); + }; + + let schema = Arc::new(build_join_schema(&left_schema, &right_schema, &join_type).0); + let cache = Self::compute_properties(&left, &right, Arc::clone(&schema), join_type, &on)?; + + Ok(Self { + left, + right, + on, + filter, + join_type, + schema, + metrics: ExecutionPlanMetricsSet::new(), + left_sort_exprs, + right_sort_exprs, + sort_options, + null_equality, + cache, + }) + } + + /// The outer (probe) side: Left for LeftSemi/LeftAnti, Right for RightSemi/RightAnti. + pub fn probe_side(join_type: &JoinType) -> JoinSide { + match join_type { + JoinType::RightSemi | JoinType::RightAnti => JoinSide::Right, + _ => JoinSide::Left, + } + } + + fn maintains_input_order(join_type: JoinType) -> Vec { + match join_type { + JoinType::LeftSemi | JoinType::LeftAnti => vec![true, false], + JoinType::RightSemi | JoinType::RightAnti => vec![false, true], + _ => vec![false, false], + } + } + + fn compute_properties( + left: &Arc, + right: &Arc, + schema: SchemaRef, + join_type: JoinType, + join_on: JoinOnRef, + ) -> Result { + let eq_properties = join_equivalence_properties( + left.equivalence_properties().clone(), + right.equivalence_properties().clone(), + &join_type, + schema, + &Self::maintains_input_order(join_type), + Some(Self::probe_side(&join_type)), + join_on, + )?; + let output_partitioning = symmetric_join_output_partitioning(left, right, &join_type)?; + Ok(PlanProperties::new( + eq_properties, + output_partitioning, + EmissionType::Incremental, + boundedness_from_children([left, right]), + )) + } +} + +/// Inlined from `datafusion_physical_plan::execution_plan::boundedness_from_children` +/// which is `pub(crate)` in DF 52.2.0. +fn boundedness_from_children<'a>( + children: impl IntoIterator>, +) -> Boundedness { + let mut unbounded_with_finite_mem = false; + for child in children { + match child.boundedness() { + Boundedness::Unbounded { + requires_infinite_memory: true, + } => { + return Boundedness::Unbounded { + requires_infinite_memory: true, + }; + } + Boundedness::Unbounded { + requires_infinite_memory: false, + } => { + unbounded_with_finite_mem = true; + } + Boundedness::Bounded => {} + } + } + if unbounded_with_finite_mem { + Boundedness::Unbounded { + requires_infinite_memory: false, + } + } else { + Boundedness::Bounded + } +} + +/// Inlined from `datafusion_physical_plan::joins::utils::symmetric_join_output_partitioning` +/// which is `pub(crate)` in DF 52.2.0. +fn symmetric_join_output_partitioning( + left: &Arc, + right: &Arc, + join_type: &JoinType, +) -> Result { + let left_partitioning = left.output_partitioning(); + let right_partitioning = right.output_partitioning(); + let result = match join_type { + JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => { + left_partitioning.clone() + } + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => { + right_partitioning.clone() + } + _ => Partitioning::UnknownPartitioning(right_partitioning.partition_count()), + }; + Ok(result) +} + +impl DisplayAs for SemiAntiSortMergeJoinExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + let on = self + .on + .iter() + .map(|(c1, c2)| format!("({c1}, {c2})")) + .collect::>() + .join(", "); + write!( + f, + "{}: join_type={:?}, on=[{}]{}", + Self::static_name(), + self.join_type, + on, + self.filter.as_ref().map_or_else( + || "".to_string(), + |filt| format!(", filter={}", filt.expression()) + ), + ) + } + DisplayFormatType::TreeRender => { + let on = self + .on + .iter() + .map(|(c1, c2)| { + format!("({} = {})", fmt_sql(c1.as_ref()), fmt_sql(c2.as_ref())) + }) + .collect::>() + .join(", "); + + writeln!(f, "join_type={:?}", self.join_type)?; + writeln!(f, "on={on}")?; + Ok(()) + } + } + } +} + +impl ExecutionPlan for SemiAntiSortMergeJoinExec { + fn name(&self) -> &'static str { + "SemiAntiSortMergeJoinExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.cache + } + + fn required_input_distribution(&self) -> Vec { + let (left_expr, right_expr) = self + .on + .iter() + .map(|(l, r)| (Arc::clone(l), Arc::clone(r))) + .unzip(); + vec![ + Distribution::HashPartitioned(left_expr), + Distribution::HashPartitioned(right_expr), + ] + } + + fn required_input_ordering(&self) -> Vec> { + vec![ + Some(OrderingRequirements::from(self.left_sort_exprs.clone())), + Some(OrderingRequirements::from(self.right_sort_exprs.clone())), + ] + } + + fn maintains_input_order(&self) -> Vec { + Self::maintains_input_order(self.join_type) + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.left, &self.right] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + match &children[..] { + [left, right] => Ok(Arc::new(Self::try_new( + Arc::clone(left), + Arc::clone(right), + self.on.clone(), + self.filter.clone(), + self.join_type, + self.sort_options.clone(), + self.null_equality, + )?)), + _ => datafusion::common::internal_err!( + "SemiAntiSortMergeJoinExec wrong number of children" + ), + } + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + let left_partitions = self.left.output_partitioning().partition_count(); + let right_partitions = self.right.output_partitioning().partition_count(); + assert_eq_or_internal_err!( + left_partitions, + right_partitions, + "Invalid SemiAntiSortMergeJoinExec, partition count mismatch \ + {left_partitions}!={right_partitions}" + ); + + let (on_left, on_right): (Vec<_>, Vec<_>) = self.on.iter().cloned().unzip(); + + let (outer, inner, on_outer, on_inner) = + if Self::probe_side(&self.join_type) == JoinSide::Left { + ( + Arc::clone(&self.left), + Arc::clone(&self.right), + on_left, + on_right, + ) + } else { + ( + Arc::clone(&self.right), + Arc::clone(&self.left), + on_right, + on_left, + ) + }; + + let outer = outer.execute(partition, Arc::clone(&context))?; + let inner = inner.execute(partition, Arc::clone(&context))?; + let batch_size = context.session_config().batch_size(); + + Ok(Box::pin(SemiAntiSortMergeJoinStream::try_new( + Arc::clone(&self.schema), + self.sort_options.clone(), + self.null_equality, + outer, + inner, + on_outer, + on_inner, + self.filter.clone(), + self.join_type, + batch_size, + partition, + &self.metrics, + )?)) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn partition_statistics(&self, _partition: Option) -> Result { + Ok(Statistics::new_unknown(&self.schema)) + } +} diff --git a/native/core/src/execution/joins/mod.rs b/native/core/src/execution/joins/mod.rs new file mode 100644 index 0000000000..dcf86050b8 --- /dev/null +++ b/native/core/src/execution/joins/mod.rs @@ -0,0 +1,26 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Specialized Sort Merge Join for Semi/Anti joins. +//! +//! Ported from Apache DataFusion. See [`SemiAntiSortMergeJoinExec`] for +//! algorithm details and motivation. + +pub use exec::SemiAntiSortMergeJoinExec; + +mod exec; +mod stream; diff --git a/native/core/src/execution/joins/stream.rs b/native/core/src/execution/joins/stream.rs new file mode 100644 index 0000000000..63b950f8dc --- /dev/null +++ b/native/core/src/execution/joins/stream.rs @@ -0,0 +1,918 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Stream implementation for semi/anti sort-merge joins. +//! +//! Ported from Apache DataFusion. + +use std::cmp::Ordering; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use arrow::array::{Array, ArrayRef, BooleanArray, BooleanBufferBuilder, RecordBatch}; +use arrow::compute::{filter_record_batch, not, BatchCoalescer, SortOptions}; +use arrow::datatypes::SchemaRef; +use arrow::util::bit_chunk_iterator::UnalignedBitChunk; +use arrow::util::bit_util::apply_bitwise_binary_op; +use datafusion::common::{internal_err, JoinSide, JoinType, NullEquality, Result, ScalarValue}; +use datafusion::execution::SendableRecordBatchStream; +use datafusion::physical_expr_common::physical_expr::PhysicalExprRef; +use datafusion::physical_plan::joins::utils::{compare_join_arrays, JoinFilter}; +use datafusion::physical_plan::metrics::{ + BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, +}; +use datafusion::physical_plan::RecordBatchStream; + +use futures::{ready, Stream, StreamExt}; + +/// Evaluates join key expressions against a batch, returning one array per key. +fn evaluate_join_keys(batch: &RecordBatch, on: &[PhysicalExprRef]) -> Result> { + on.iter() + .map(|expr: &PhysicalExprRef| { + let num_rows = batch.num_rows(); + let val = expr.evaluate(batch)?; + val.into_array(num_rows) + }) + .collect() +} + +/// Find the first index in `key_arrays` starting from `from` where the key +/// differs from the key at `from`. Uses binary search with `compare_join_arrays`. +fn find_key_group_end( + key_arrays: &[ArrayRef], + from: usize, + len: usize, + sort_options: &[SortOptions], + null_equality: NullEquality, +) -> Result { + let mut lo = from + 1; + let mut hi = len; + while lo < hi { + let mid = lo + (hi - lo) / 2; + if compare_join_arrays( + key_arrays, + from, + key_arrays, + mid, + sort_options, + null_equality, + )? == Ordering::Equal + { + lo = mid + 1; + } else { + hi = mid; + } + } + Ok(lo) +} + +/// Tracks whether we're mid-key-group when `poll_next_outer_batch` returns +/// `Poll::Pending` inside the Equal branch's boundary loop. +#[derive(Debug)] +enum BoundaryState { + /// Normal processing — not inside a boundary poll. + Normal, + /// The no-filter boundary loop's `poll_next_outer_batch` returned + /// Pending. Carries the key arrays and index from the last emitted + /// batch so we can compare with the next batch's first key. + NoFilterPending { + saved_keys: Vec, + saved_idx: usize, + }, + /// The filtered boundary loop's `poll_next_outer_batch` returned + /// Pending. The `inner_key_buffer` field already holds the buffered + /// inner rows needed to resume filter evaluation. + FilteredPending, +} + +pub(super) struct SemiAntiSortMergeJoinStream { + /// true for semi (emit matched), false for anti (emit unmatched) + is_semi: bool, + + // Input streams + outer: SendableRecordBatchStream, + inner: SendableRecordBatchStream, + + // Current batches and cursor positions + outer_batch: Option, + outer_offset: usize, + outer_key_arrays: Vec, + inner_batch: Option, + inner_offset: usize, + inner_key_arrays: Vec, + + // Per-outer-batch match tracking (bit-packed) + matched: BooleanBufferBuilder, + + // Inner key group buffer for filtered joins + inner_key_buffer: Vec, + + // Tracks partial buffering across Pending re-entries + buffering_inner_pending: bool, + + // Boundary re-entry state + boundary_state: BoundaryState, + + // Join condition + on_outer: Vec, + on_inner: Vec, + filter: Option, + sort_options: Vec, + null_equality: NullEquality, + outer_is_left: bool, + + // Output + coalescer: BatchCoalescer, + schema: SchemaRef, + + // Metrics + join_time: datafusion::physical_plan::metrics::Time, + input_batches: Count, + input_rows: Count, + baseline_metrics: BaselineMetrics, + + // Guards against double-emit on Pending re-entry + batch_emitted: bool, +} + +impl SemiAntiSortMergeJoinStream { + #[allow(clippy::too_many_arguments)] + pub fn try_new( + schema: SchemaRef, + sort_options: Vec, + null_equality: NullEquality, + outer: SendableRecordBatchStream, + inner: SendableRecordBatchStream, + on_outer: Vec, + on_inner: Vec, + filter: Option, + join_type: JoinType, + batch_size: usize, + partition: usize, + metrics: &ExecutionPlanMetricsSet, + ) -> Result { + let is_semi = matches!(join_type, JoinType::LeftSemi | JoinType::RightSemi); + let outer_is_left = matches!(join_type, JoinType::LeftSemi | JoinType::LeftAnti); + + let join_time = MetricBuilder::new(metrics).subset_time("join_time", partition); + let input_batches = MetricBuilder::new(metrics).counter("input_batches", partition); + let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition); + let baseline_metrics = BaselineMetrics::new(metrics, partition); + + Ok(Self { + is_semi, + outer, + inner, + outer_batch: None, + outer_offset: 0, + outer_key_arrays: vec![], + inner_batch: None, + inner_offset: 0, + inner_key_arrays: vec![], + matched: BooleanBufferBuilder::new(0), + inner_key_buffer: vec![], + buffering_inner_pending: false, + boundary_state: BoundaryState::Normal, + on_outer, + on_inner, + filter, + sort_options, + null_equality, + outer_is_left, + coalescer: BatchCoalescer::new(Arc::clone(&schema), batch_size) + .with_biggest_coalesce_batch_size(Some(batch_size / 2)), + schema, + join_time, + input_batches, + input_rows, + baseline_metrics, + batch_emitted: false, + }) + } + + /// Poll for the next outer batch. Returns true if a batch was loaded. + fn poll_next_outer_batch(&mut self, cx: &mut Context<'_>) -> Poll> { + loop { + match ready!(self.outer.poll_next_unpin(cx)) { + None => return Poll::Ready(Ok(false)), + Some(Err(e)) => return Poll::Ready(Err(e)), + Some(Ok(batch)) => { + self.input_batches.add(1); + self.input_rows.add(batch.num_rows()); + if batch.num_rows() == 0 { + continue; + } + let keys = evaluate_join_keys(&batch, &self.on_outer)?; + let num_rows = batch.num_rows(); + self.outer_batch = Some(batch); + self.outer_offset = 0; + self.outer_key_arrays = keys; + self.batch_emitted = false; + self.matched = BooleanBufferBuilder::new(num_rows); + self.matched.append_n(num_rows, false); + return Poll::Ready(Ok(true)); + } + } + } + } + + /// Poll for the next inner batch. Returns true if a batch was loaded. + fn poll_next_inner_batch(&mut self, cx: &mut Context<'_>) -> Poll> { + loop { + match ready!(self.inner.poll_next_unpin(cx)) { + None => return Poll::Ready(Ok(false)), + Some(Err(e)) => return Poll::Ready(Err(e)), + Some(Ok(batch)) => { + self.input_batches.add(1); + self.input_rows.add(batch.num_rows()); + if batch.num_rows() == 0 { + continue; + } + let keys = evaluate_join_keys(&batch, &self.on_inner)?; + self.inner_batch = Some(batch); + self.inner_offset = 0; + self.inner_key_arrays = keys; + return Poll::Ready(Ok(true)); + } + } + } + } + + /// Emit the current outer batch through the coalescer, applying the + /// matched bitset as a selection mask. + fn emit_outer_batch(&mut self) -> Result<()> { + if self.batch_emitted { + return Ok(()); + } + self.batch_emitted = true; + + let batch = self.outer_batch.as_ref().unwrap(); + + let selection = BooleanArray::new(self.matched.finish(), None); + + let selection = if self.is_semi { + selection + } else { + not(&selection)? + }; + + let filtered = filter_record_batch(batch, &selection)?; + if filtered.num_rows() > 0 { + self.coalescer.push_batch(filtered)?; + } + Ok(()) + } + + /// Process a key match between outer and inner sides (no filter). + fn process_key_match_no_filter(&mut self) -> Result<()> { + let outer_batch = self.outer_batch.as_ref().unwrap(); + let num_outer = outer_batch.num_rows(); + + let outer_group_end = find_key_group_end( + &self.outer_key_arrays, + self.outer_offset, + num_outer, + &self.sort_options, + self.null_equality, + )?; + + for i in self.outer_offset..outer_group_end { + self.matched.set_bit(i, true); + } + + self.outer_offset = outer_group_end; + Ok(()) + } + + /// Advance inner past the current key group. Returns Ok(true) if inner + /// is exhausted. + fn advance_inner_past_key_group(&mut self, cx: &mut Context<'_>) -> Poll> { + loop { + let inner_batch = match &self.inner_batch { + Some(b) => b, + None => return Poll::Ready(Ok(true)), + }; + let num_inner = inner_batch.num_rows(); + + let group_end = find_key_group_end( + &self.inner_key_arrays, + self.inner_offset, + num_inner, + &self.sort_options, + self.null_equality, + )?; + + if group_end < num_inner { + self.inner_offset = group_end; + return Poll::Ready(Ok(false)); + } + + // Key group extends to end of batch — need to check next batch + let last_key_idx = num_inner - 1; + let saved_inner_keys = self.inner_key_arrays.clone(); + + match ready!(self.poll_next_inner_batch(cx)) { + Err(e) => return Poll::Ready(Err(e)), + Ok(false) => { + return Poll::Ready(Ok(true)); + } + Ok(true) => { + if keys_match( + &saved_inner_keys, + last_key_idx, + &self.inner_key_arrays, + 0, + &self.sort_options, + self.null_equality, + )? { + continue; + } else { + return Poll::Ready(Ok(false)); + } + } + } + } + } + + /// Buffer inner key group for filter evaluation. Collects all inner rows + /// with the current key across batch boundaries. + fn buffer_inner_key_group(&mut self, cx: &mut Context<'_>) -> Poll> { + let mut resume_from_poll = false; + if self.buffering_inner_pending { + self.buffering_inner_pending = false; + resume_from_poll = true; + } else { + self.inner_key_buffer.clear(); + } + + loop { + let inner_batch = match &self.inner_batch { + Some(b) => b, + None => return Poll::Ready(Ok(true)), + }; + let num_inner = inner_batch.num_rows(); + let group_end = find_key_group_end( + &self.inner_key_arrays, + self.inner_offset, + num_inner, + &self.sort_options, + self.null_equality, + )?; + + if !resume_from_poll { + let slice = inner_batch.slice(self.inner_offset, group_end - self.inner_offset); + self.inner_key_buffer.push(slice); + + if group_end < num_inner { + self.inner_offset = group_end; + return Poll::Ready(Ok(false)); + } + } + resume_from_poll = false; + + // Key group extends to end of batch — check next + let last_key_idx = num_inner - 1; + let saved_inner_keys = self.inner_key_arrays.clone(); + + self.buffering_inner_pending = true; + match ready!(self.poll_next_inner_batch(cx)) { + Err(e) => { + self.buffering_inner_pending = false; + return Poll::Ready(Err(e)); + } + Ok(false) => { + self.buffering_inner_pending = false; + return Poll::Ready(Ok(true)); + } + Ok(true) => { + self.buffering_inner_pending = false; + if keys_match( + &saved_inner_keys, + last_key_idx, + &self.inner_key_arrays, + 0, + &self.sort_options, + self.null_equality, + )? { + continue; + } else { + return Poll::Ready(Ok(false)); + } + } + } + } + } + + /// Process a key match with a filter. For each inner row in the buffered + /// key group, evaluates the filter against the outer key group and ORs + /// the results into the matched bitset. + fn process_key_match_with_filter(&mut self) -> Result<()> { + let filter = self.filter.as_ref().unwrap(); + let outer_batch = self.outer_batch.as_ref().unwrap(); + let num_outer = outer_batch.num_rows(); + + debug_assert!( + !self.inner_key_buffer.is_empty(), + "process_key_match_with_filter called with empty inner_key_buffer" + ); + debug_assert!( + self.outer_offset < num_outer, + "outer_offset must be within the current batch" + ); + debug_assert!( + self.matched.len() == num_outer, + "matched vector must be sized for the current outer batch" + ); + + let outer_group_end = find_key_group_end( + &self.outer_key_arrays, + self.outer_offset, + num_outer, + &self.sort_options, + self.null_equality, + )?; + let outer_group_len = outer_group_end - self.outer_offset; + let outer_slice = outer_batch.slice(self.outer_offset, outer_group_len); + + let mut matched_count = + UnalignedBitChunk::new(self.matched.as_slice(), self.outer_offset, outer_group_len) + .count_ones(); + + 'outer: for inner_slice in &self.inner_key_buffer { + for inner_row in 0..inner_slice.num_rows() { + if matched_count == outer_group_len { + break 'outer; + } + + let filter_result = evaluate_filter_for_inner_row( + self.outer_is_left, + filter, + &outer_slice, + inner_slice, + inner_row, + )?; + + let filter_buf = filter_result.values(); + apply_bitwise_binary_op( + self.matched.as_slice_mut(), + self.outer_offset, + filter_buf.inner().as_slice(), + filter_buf.offset(), + outer_group_len, + |a, b| a | b, + ); + + matched_count = UnalignedBitChunk::new( + self.matched.as_slice(), + self.outer_offset, + outer_group_len, + ) + .count_ones(); + } + } + + self.outer_offset = outer_group_end; + Ok(()) + } + + /// Main loop: drive the merge-scan to produce output batches. + #[allow(clippy::panicking_unwrap)] + fn poll_join(&mut self, cx: &mut Context<'_>) -> Poll>> { + let join_time = self.join_time.clone(); + let _timer = join_time.timer(); + + loop { + // 1. Ensure we have an outer batch + if self.outer_batch.is_none() { + match ready!(self.poll_next_outer_batch(cx)) { + Err(e) => return Poll::Ready(Err(e)), + Ok(false) => { + // Outer exhausted — flush coalescer + self.boundary_state = BoundaryState::Normal; + self.coalescer.finish_buffered_batch()?; + if let Some(batch) = self.coalescer.next_completed_batch() { + return Poll::Ready(Ok(Some(batch))); + } + return Poll::Ready(Ok(None)); + } + Ok(true) => { + match std::mem::replace(&mut self.boundary_state, BoundaryState::Normal) { + BoundaryState::NoFilterPending { + saved_keys, + saved_idx, + } => { + let same_key = keys_match( + &saved_keys, + saved_idx, + &self.outer_key_arrays, + 0, + &self.sort_options, + self.null_equality, + )?; + if same_key { + self.process_key_match_no_filter()?; + let num_outer = self.outer_batch.as_ref().unwrap().num_rows(); + if self.outer_offset >= num_outer { + let new_saved = self.outer_key_arrays.clone(); + let new_idx = num_outer - 1; + self.boundary_state = BoundaryState::NoFilterPending { + saved_keys: new_saved, + saved_idx: new_idx, + }; + self.emit_outer_batch()?; + self.outer_batch = None; + continue; + } + } + } + BoundaryState::FilteredPending => { + if !self.inner_key_buffer.is_empty() { + let first_inner = &self.inner_key_buffer[0]; + let inner_keys = + evaluate_join_keys(first_inner, &self.on_inner)?; + let same_key = keys_match( + &self.outer_key_arrays, + 0, + &inner_keys, + 0, + &self.sort_options, + self.null_equality, + )?; + if same_key { + self.process_key_match_with_filter()?; + let num_outer = + self.outer_batch.as_ref().unwrap().num_rows(); + if self.outer_offset >= num_outer { + self.boundary_state = BoundaryState::FilteredPending; + self.emit_outer_batch()?; + self.outer_batch = None; + continue; + } + } + } + self.inner_key_buffer.clear(); + } + BoundaryState::Normal => {} + } + } + } + } + + // 2. Ensure we have an inner batch (unless inner is exhausted) + if self.inner_batch.is_none() && matches!(self.boundary_state, BoundaryState::Normal) { + match ready!(self.poll_next_inner_batch(cx)) { + Err(e) => return Poll::Ready(Err(e)), + Ok(false) => { + // Inner exhausted — emit remaining outer batches + self.emit_outer_batch()?; + self.outer_batch = None; + + loop { + match ready!(self.poll_next_outer_batch(cx)) { + Err(e) => return Poll::Ready(Err(e)), + Ok(false) => break, + Ok(true) => { + self.emit_outer_batch()?; + self.outer_batch = None; + } + } + } + + self.coalescer.finish_buffered_batch()?; + if let Some(batch) = self.coalescer.next_completed_batch() { + return Poll::Ready(Ok(Some(batch))); + } + return Poll::Ready(Ok(None)); + } + Ok(true) => {} + } + } + + // 3. Main merge-scan loop + let outer_batch = self.outer_batch.as_ref().unwrap(); + let num_outer = outer_batch.num_rows(); + + if self.outer_offset >= num_outer { + self.emit_outer_batch()?; + self.outer_batch = None; + + if let Some(batch) = self.coalescer.next_completed_batch() { + return Poll::Ready(Ok(Some(batch))); + } + continue; + } + + let inner_batch = match &self.inner_batch { + Some(b) => b, + None => { + self.emit_outer_batch()?; + self.outer_batch = None; + continue; + } + }; + let num_inner = inner_batch.num_rows(); + + if self.inner_offset >= num_inner { + match ready!(self.poll_next_inner_batch(cx)) { + Err(e) => return Poll::Ready(Err(e)), + Ok(false) => { + self.inner_batch = None; + continue; + } + Ok(true) => continue, + } + } + + // 4. Compare keys at current positions + let cmp = compare_join_arrays( + &self.outer_key_arrays, + self.outer_offset, + &self.inner_key_arrays, + self.inner_offset, + &self.sort_options, + self.null_equality, + )?; + + match cmp { + Ordering::Less => { + let group_end = find_key_group_end( + &self.outer_key_arrays, + self.outer_offset, + num_outer, + &self.sort_options, + self.null_equality, + )?; + self.outer_offset = group_end; + } + Ordering::Greater => { + let group_end = find_key_group_end( + &self.inner_key_arrays, + self.inner_offset, + num_inner, + &self.sort_options, + self.null_equality, + )?; + if group_end >= num_inner { + let saved_keys = self.inner_key_arrays.clone(); + let saved_idx = num_inner - 1; + match ready!(self.poll_next_inner_batch(cx)) { + Err(e) => return Poll::Ready(Err(e)), + Ok(false) => { + self.inner_batch = None; + continue; + } + Ok(true) => { + if keys_match( + &saved_keys, + saved_idx, + &self.inner_key_arrays, + 0, + &self.sort_options, + self.null_equality, + )? { + match ready!(self.advance_inner_past_key_group(cx)) { + Err(e) => return Poll::Ready(Err(e)), + Ok(_) => continue, + } + } + continue; + } + } + } else { + self.inner_offset = group_end; + } + } + Ordering::Equal => { + if self.filter.is_some() { + // Buffer inner key group (may span batches) + match ready!(self.buffer_inner_key_group(cx)) { + Err(e) => return Poll::Ready(Err(e)), + Ok(_inner_exhausted) => {} + } + + // Process outer rows against buffered inner group + loop { + self.process_key_match_with_filter()?; + + let outer_batch = self.outer_batch.as_ref().unwrap(); + if self.outer_offset >= outer_batch.num_rows() { + self.emit_outer_batch()?; + self.boundary_state = BoundaryState::FilteredPending; + + match ready!(self.poll_next_outer_batch(cx)) { + Err(e) => return Poll::Ready(Err(e)), + Ok(false) => { + self.boundary_state = BoundaryState::Normal; + self.outer_batch = None; + break; + } + Ok(true) => { + self.boundary_state = BoundaryState::Normal; + if !self.inner_key_buffer.is_empty() { + let first_inner = &self.inner_key_buffer[0]; + let inner_keys = + evaluate_join_keys(first_inner, &self.on_inner)?; + let same = keys_match( + &self.outer_key_arrays, + 0, + &inner_keys, + 0, + &self.sort_options, + self.null_equality, + )?; + if same { + continue; + } + } + break; + } + } + } else { + break; + } + } + + self.inner_key_buffer.clear(); + } else { + // No filter: advance inner past key group, then + // mark all outer rows with this key as matched. + match ready!(self.advance_inner_past_key_group(cx)) { + Err(e) => return Poll::Ready(Err(e)), + Ok(_inner_exhausted) => {} + } + + loop { + self.process_key_match_no_filter()?; + + let num_outer = self.outer_batch.as_ref().unwrap().num_rows(); + if self.outer_offset >= num_outer { + let saved_keys = self.outer_key_arrays.clone(); + let saved_idx = num_outer - 1; + + self.emit_outer_batch()?; + self.boundary_state = BoundaryState::NoFilterPending { + saved_keys, + saved_idx, + }; + + match ready!(self.poll_next_outer_batch(cx)) { + Err(e) => return Poll::Ready(Err(e)), + Ok(false) => { + self.boundary_state = BoundaryState::Normal; + self.outer_batch = None; + break; + } + Ok(true) => { + // Recover saved_keys from boundary state + let BoundaryState::NoFilterPending { + saved_keys, + saved_idx, + } = std::mem::replace( + &mut self.boundary_state, + BoundaryState::Normal, + ) + else { + unreachable!() + }; + let same_key = keys_match( + &saved_keys, + saved_idx, + &self.outer_key_arrays, + 0, + &self.sort_options, + self.null_equality, + )?; + if same_key { + continue; + } + break; + } + } + } else { + break; + } + } + } + } + } + + // Check for completed coalescer batch + if let Some(batch) = self.coalescer.next_completed_batch() { + return Poll::Ready(Ok(Some(batch))); + } + } + } +} + +/// Compare two key rows for equality. +fn keys_match( + left_arrays: &[ArrayRef], + left_idx: usize, + right_arrays: &[ArrayRef], + right_idx: usize, + sort_options: &[SortOptions], + null_equality: NullEquality, +) -> Result { + let cmp = compare_join_arrays( + left_arrays, + left_idx, + right_arrays, + right_idx, + sort_options, + null_equality, + )?; + Ok(cmp == Ordering::Equal) +} + +/// Evaluate the join filter for one inner row against a slice of outer rows. +fn evaluate_filter_for_inner_row( + outer_is_left: bool, + filter: &JoinFilter, + outer_slice: &RecordBatch, + inner_batch: &RecordBatch, + inner_idx: usize, +) -> Result { + let num_outer_rows = outer_slice.num_rows(); + + let mut columns: Vec = Vec::with_capacity(filter.column_indices().len()); + for col_idx in filter.column_indices() { + let (side_batch, side_idx) = if outer_is_left { + match col_idx.side { + JoinSide::Left => (outer_slice, None), + JoinSide::Right => (inner_batch, Some(inner_idx)), + JoinSide::None => { + return internal_err!("Unexpected JoinSide::None in filter"); + } + } + } else { + match col_idx.side { + JoinSide::Left => (inner_batch, Some(inner_idx)), + JoinSide::Right => (outer_slice, None), + JoinSide::None => { + return internal_err!("Unexpected JoinSide::None in filter"); + } + } + }; + + match side_idx { + None => { + columns.push(Arc::clone(side_batch.column(col_idx.index))); + } + Some(idx) => { + let scalar = + ScalarValue::try_from_array(side_batch.column(col_idx.index).as_ref(), idx)?; + columns.push(scalar.to_array_of_size(num_outer_rows)?); + } + } + } + + let filter_batch = RecordBatch::try_new(Arc::clone(filter.schema()), columns)?; + let result = filter + .expression() + .evaluate(&filter_batch)? + .into_array(num_outer_rows)?; + let bool_arr = result + .as_any() + .downcast_ref::() + .ok_or_else(|| { + datafusion::common::DataFusionError::Internal( + "Filter expression did not return BooleanArray".to_string(), + ) + })?; + // Treat nulls as false + if bool_arr.null_count() > 0 { + Ok(arrow::compute::prep_null_mask_filter(bool_arr)) + } else { + Ok(bool_arr.clone()) + } +} + +impl Stream for SemiAntiSortMergeJoinStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let poll = self.poll_join(cx).map(|result| result.transpose()); + self.baseline_metrics.record_poll(poll) + } +} + +impl RecordBatchStream for SemiAntiSortMergeJoinStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } +} diff --git a/native/core/src/execution/mod.rs b/native/core/src/execution/mod.rs index 85fc672461..f3d62be809 100644 --- a/native/core/src/execution/mod.rs +++ b/native/core/src/execution/mod.rs @@ -19,6 +19,7 @@ pub mod columnar_to_row; pub mod expressions; pub mod jni_api; +pub(crate) mod joins; pub(crate) mod metrics; pub mod operators; pub(crate) mod planner; diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index b79b43f6c9..100716a0ac 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -75,6 +75,7 @@ use datafusion_comet_spark_expr::{ }; use iceberg::expr::Bind; +use crate::execution::joins::SemiAntiSortMergeJoinExec; use crate::execution::operators::ExecutionError::GeneralError; use crate::execution::shuffle::{CometPartitioning, CompressionCodec}; use crate::execution::spark_plan::SparkPlan; @@ -1571,53 +1572,87 @@ impl PhysicalPlanner { let left = Arc::clone(&join_params.left.native_plan); let right = Arc::clone(&join_params.right.native_plan); - let join = Arc::new(SortMergeJoinExec::try_new( - Arc::clone(&left), - Arc::clone(&right), - join_params.join_on, - join_params.join_filter, + let is_semi_anti = matches!( join_params.join_type, - sort_options, - // null doesn't equal to null in Spark join key. If the join key is - // `EqualNullSafe`, Spark will rewrite it during planning. - NullEquality::NullEqualsNothing, - )?); + DFJoinType::LeftSemi + | DFJoinType::LeftAnti + | DFJoinType::RightSemi + | DFJoinType::RightAnti + ); - if join.filter.is_some() { - // SMJ with join filter produces lots of tiny batches - let coalesce_batches: Arc = - Arc::new(CoalesceBatchesExec::new( - Arc::::clone(&join), - self.session_ctx - .state() - .config_options() - .execution - .batch_size, - )); - Ok(( - scans, - Arc::new(SparkPlan::new_with_additional( - spark_plan.plan_id, - coalesce_batches, - vec![ - Arc::clone(&join_params.left), - Arc::clone(&join_params.right), - ], - vec![join], - )), - )) - } else { + if is_semi_anti { + let join_exec: Arc = + Arc::new(SemiAntiSortMergeJoinExec::try_new( + Arc::clone(&left), + Arc::clone(&right), + join_params.join_on, + join_params.join_filter, + join_params.join_type, + sort_options, + NullEquality::NullEqualsNothing, + )?); + // SemiAntiSortMergeJoinExec has an internal BatchCoalescer, + // so no need for CoalesceBatchesExec wrapping. Ok(( scans, Arc::new(SparkPlan::new( spark_plan.plan_id, - join, + join_exec, vec![ Arc::clone(&join_params.left), Arc::clone(&join_params.right), ], )), )) + } else { + let join = Arc::new(SortMergeJoinExec::try_new( + Arc::clone(&left), + Arc::clone(&right), + join_params.join_on, + join_params.join_filter, + join_params.join_type, + sort_options, + // null doesn't equal to null in Spark join key. If the join key is + // `EqualNullSafe`, Spark will rewrite it during planning. + NullEquality::NullEqualsNothing, + )?); + + if join.filter.is_some() { + // SMJ with join filter produces lots of tiny batches + let coalesce_batches: Arc = + Arc::new(CoalesceBatchesExec::new( + Arc::::clone(&join), + self.session_ctx + .state() + .config_options() + .execution + .batch_size, + )); + Ok(( + scans, + Arc::new(SparkPlan::new_with_additional( + spark_plan.plan_id, + coalesce_batches, + vec![ + Arc::clone(&join_params.left), + Arc::clone(&join_params.right), + ], + vec![join], + )), + )) + } else { + Ok(( + scans, + Arc::new(SparkPlan::new( + spark_plan.plan_id, + join, + vec![ + Arc::clone(&join_params.left), + Arc::clone(&join_params.right), + ], + )), + )) + } } } OpStruct::HashJoin(join) => { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index da2ae21a95..7ed57d1cc2 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -1963,7 +1963,12 @@ object CometSortMergeJoinExec extends CometOperatorSerde[SortMergeJoinExec] { } } - if (join.condition.isDefined && + val isSemiAnti = join.joinType match { + case LeftSemi | LeftAnti => true + case _ => false + } + + if (join.condition.isDefined && !isSemiAnti && !CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED .get(join.conf)) { withInfo(