diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs b/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs index 4840b56f55fff..6682098340337 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs @@ -235,6 +235,14 @@ pub(super) struct BufferedBatch { pub null_joined: Vec, /// Size estimation used for reserving / releasing memory pub size_estimation: usize, + /// Actual amount tracked in the memory reservation for this batch. + /// + /// - `InMemory`: equals `size_estimation` (full batch + join_arrays + metadata) + /// - `Spilled`: equals join_arrays memory if `try_grow` succeeded after spill, else 0 + /// + /// Invariant: `free_reservation()` shrinks by exactly this amount, so we never + /// shrink by more than we grew. + pub reserved_amount: usize, /// Tracks filter outcomes for buffered rows in full outer joins. /// Indexed by absolute row position within the batch. See [`FilterState`]. pub join_filter_status: Vec, @@ -274,10 +282,20 @@ impl BufferedBatch { join_arrays, null_joined: vec![], size_estimation, + reserved_amount: 0, join_filter_status: vec![FilterState::Unvisited; num_rows], num_rows, } } + + /// Memory footprint of join key arrays that remain in memory even after + /// the main batch is spilled to disk + fn join_arrays_mem(&self) -> usize { + self.join_arrays + .iter() + .map(|arr| arr.get_array_memory_size()) + .sum() + } } // TODO: Spill join arrays (https://github.com/apache/datafusion/pull/17429) @@ -948,10 +966,9 @@ impl MaterializingSortMergeJoinStream { } fn free_reservation(&mut self, buffered_batch: &BufferedBatch) -> Result<()> { - // Shrink memory usage for in-memory batches only - if let BufferedBatchState::InMemory(_) = buffered_batch.batch { + if buffered_batch.reserved_amount > 0 { self.reservation - .try_shrink(buffered_batch.size_estimation)?; + .try_shrink(buffered_batch.reserved_amount)?; } Ok(()) } @@ -959,6 +976,7 @@ impl MaterializingSortMergeJoinStream { fn allocate_reservation(&mut self, mut buffered_batch: BufferedBatch) -> Result<()> { match self.reservation.try_grow(buffered_batch.size_estimation) { Ok(_) => { + buffered_batch.reserved_amount = buffered_batch.size_estimation; self.join_metrics .peak_mem_used() .set_max(self.reservation.size()); @@ -978,6 +996,20 @@ impl MaterializingSortMergeJoinStream { .unwrap(); // Operation only return None if no batches are spilled, here we ensure that at least one batch is spilled buffered_batch.batch = BufferedBatchState::Spilled(spill_file); + + // Track remaining in-memory data (join key arrays) that + // stay in memory even after the batch is spilled. This is + // much smaller than the full batch, so try_grow should + // usually succeed. If it fails, reserved_amount stays 0 - + // best-effort tracking, free_reservation will safely be a no-op. + let join_arrays_mem = buffered_batch.join_arrays_mem(); + if self.reservation.try_grow(join_arrays_mem).is_ok() { + buffered_batch.reserved_amount = join_arrays_mem; + self.join_metrics + .peak_mem_used() + .set_max(self.reservation.size()); + } + Ok(()) } _ => internal_err!("Buffered batch has empty body"), diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs index 5d70530528728..8d3a7374a0d47 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs @@ -2487,6 +2487,114 @@ async fn overallocation_multi_batch_spill() -> Result<()> { Ok(()) } +/// Verifies that `peak_mem_used` reflects join_arrays memory on the spill path. +/// +/// Uses a memory limit smaller than a single batch's `size_estimation` so that +/// every batch spills — the `Ok` arm of `allocate_reservation` is never hit. +/// Before the fix, `peak_mem_used` would stay 0 because `set_max` was only +/// called in the `Ok` arm. After the fix, the spill path calls +/// `try_grow(join_arrays_mem)` + `set_max`, so `peak_mem_used > 0`. +#[tokio::test] +async fn spill_join_arrays_memory_accounting() -> Result<()> { + use arrow::array::Array; + + let left_batch = build_table_i32( + ("a1", &vec![0, 1]), + ("b1", &vec![1, 1]), + ("c1", &vec![4, 5]), + ); + let size_estimation = left_batch.get_array_memory_size() + + Int32Array::from(vec![1, 1]).get_array_memory_size() + + 2usize.next_power_of_two() * size_of::() + + size_of::>() + + size_of::(); + let join_arrays_mem = Int32Array::from(vec![1, 1]).get_array_memory_size(); + + // Memory limit: too small for a full batch, large enough for join_arrays. + // Every batch hits the Err arm → spills → try_grow(join_arrays_mem). + let memory_limit = (size_estimation + join_arrays_mem) / 2; + assert!( + memory_limit < size_estimation && memory_limit > join_arrays_mem, + "limit {memory_limit} must be between join_arrays_mem {join_arrays_mem} \ + and size_estimation {size_estimation}" + ); + + let left_batches: Vec = (0..4) + .map(|i| { + build_table_i32( + ("a1", &vec![i * 2, i * 2 + 1]), + ("b1", &vec![1, 1]), + ("c1", &vec![100 + i, 101 + i]), + ) + }) + .collect(); + let left = build_table_from_batches(left_batches); + + let right_batches: Vec = (0..2) + .map(|i| { + build_table_i32( + ("a2", &vec![i * 2, i * 2 + 1]), + ("b2", &vec![1, 1]), + ("c2", &vec![200 + i, 201 + i]), + ) + }) + .collect(); + let right = build_table_from_batches(right_batches); + + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, + )]; + let sort_options = vec![SortOptions::default(); on.len()]; + + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(memory_limit, 1.0) + .with_disk_manager_builder( + DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory), + ) + .build_arc()?; + + let session_config = SessionConfig::default().with_batch_size(50); + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(session_config) + .with_runtime(Arc::clone(&runtime)), + ); + + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + Inner, + sort_options, + NullEquality::NullEqualsNothing, + )?; + + let stream = join.execute(0, task_ctx)?; + let _result = common::collect(stream).await.unwrap(); + + let metrics = join.metrics().unwrap(); + assert!( + metrics.spill_count().unwrap() > 0, + "Expected spilling to occur" + ); + + // Before the fix, peak_mem_used was 0 here because set_max was only + // called in the Ok arm of allocate_reservation, which is never reached + // when every batch spills. After the fix, the spill path tracks + // join_arrays via try_grow + set_max. + let peak_mem = metrics + .sum_by_name("peak_mem_used") + .map(|m| m.as_usize()) + .unwrap_or(0); + assert!( + peak_mem > 0, + "peak_mem_used should reflect join_arrays tracked on spill path" + ); + + Ok(()) +} + /// Build a c1 < c2 filter on the third column of each side. fn build_c1_lt_c2_filter(left_schema: &Schema, right_schema: &Schema) -> JoinFilter { JoinFilter::new(