From 23f941f4559fc667d1ce31802a6cf7863774e426 Mon Sep 17 00:00:00 2001 From: Subham Singhal Date: Thu, 30 Apr 2026 20:29:26 +0530 Subject: [PATCH 1/5] fix: track join_arrays memory in reservation after SMJ spill --- .../sort_merge_join/materializing_stream.rs | 36 +- .../src/joins/sort_merge_join/tests.rs | 440 ++++++++++++++++++ 2 files changed, 473 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs b/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs index 4840b56f55fff..e7e1f3139bef2 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs @@ -58,6 +58,7 @@ use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_physical_expr_common::physical_expr::PhysicalExprRef; use futures::{Stream, StreamExt}; +use itertools::join; /// State of SMJ stream #[derive(Debug, PartialEq, Eq)] @@ -235,6 +236,14 @@ pub(super) struct BufferedBatch { pub null_joined: Vec, /// Size estimation used for reserving / releasing memory pub size_estimation: usize, + /// Actual amount tracked in the memory reservation for this batch. + /// + /// - `InMemory`: equals `size_estimation` (full batch + join_arrays + metadata) + /// - `Spilled`: equals join_arrays memory if `try_grow` succeeded after spill, else 0 + /// + /// Invariant: `free_reservation()` shrinks by exactly this amount, so we never + /// shrink by more than we grew. + pub reserved_amount: usize, /// Tracks filter outcomes for buffered rows in full outer joins. /// Indexed by absolute row position within the batch. See [`FilterState`]. 
pub join_filter_status: Vec, @@ -274,10 +283,20 @@ impl BufferedBatch { join_arrays, null_joined: vec![], size_estimation, + reserved_amount: 0, // set by allocate_reservation() join_filter_status: vec![FilterState::Unvisited; num_rows], num_rows, } } + + /// Memory footprint of join key arrays that remain in memory even after + /// the main batch is spilled to disk + fn join_arrays_mem(&self) -> usize { + self.join_arrays + .iter() + .map(|arr| arr.get_array_memory_size()) + .sum() + } } // TODO: Spill join arrays (https://github.com/apache/datafusion/pull/17429) @@ -948,10 +967,9 @@ impl MaterializingSortMergeJoinStream { } fn free_reservation(&mut self, buffered_batch: &BufferedBatch) -> Result<()> { - // Shrink memory usage for in-memory batches only - if let BufferedBatchState::InMemory(_) = buffered_batch.batch { + if buffered_batch.reserved_amount > 0 { self.reservation - .try_shrink(buffered_batch.size_estimation)?; + .try_shrink(buffered_batch.reserved_amount)?; } Ok(()) } @@ -959,6 +977,7 @@ impl MaterializingSortMergeJoinStream { fn allocate_reservation(&mut self, mut buffered_batch: BufferedBatch) -> Result<()> { match self.reservation.try_grow(buffered_batch.size_estimation) { Ok(_) => { + buffered_batch.reserved_amount = buffered_batch.size_estimation; self.join_metrics .peak_mem_used() .set_max(self.reservation.size()); @@ -978,6 +997,17 @@ impl MaterializingSortMergeJoinStream { .unwrap(); // Operation only return None if no batches are spilled, here we ensure that at least one batch is spilled buffered_batch.batch = BufferedBatchState::Spilled(spill_file); + + // Track remaining in-memory data (join key arrays) that + // stay in memory even after the batch is spilled. This is + // much smaller than the full batch, so try_grow should + // usually succeed. If it fails, reserved_amount stays 0 - + // best-effort tracking, free_reservation will safely be a no-op. 
+ let join_arrays_mem = buffered_batch.join_arrays_mem(); + if self.reservation.try_grow(join_arrays_mem).is_ok() { + buffered_batch.reserved_amount = join_arrays_mem; + } + Ok(()) } _ => internal_err!("Buffered batch has empty body"), diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs index 5d70530528728..a9289e3180871 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs @@ -2487,6 +2487,446 @@ async fn overallocation_multi_batch_spill() -> Result<()> { Ok(()) } +/// Test spilling with many buffered batches sharing the same join key. +/// +/// This exercises the memory accounting fix where `join_arrays` (evaluated join +/// key columns) remain in memory after the main batch is spilled to disk. +/// With many spilled batches for the same key, the untracked join_arrays +/// memory can add up significantly. +#[tokio::test] +async fn spill_many_batches_same_key() -> Result<()> { + // 10 left batches, all with b1=1 (same join key) + let left_batches: Vec = (0..10) + .map(|i| { + build_table_i32( + ("a1", &vec![i * 2, i * 2 + 1]), + ("b1", &vec![1, 1]), + ("c1", &vec![100 + i * 2, 101 + i * 2]), + ) + }) + .collect(); + let left = build_table_from_batches(left_batches); + + // 5 right batches, all with b2=1 + let right_batches: Vec = (0..5) + .map(|i| { + build_table_i32( + ("a2", &vec![i * 2, i * 2 + 1]), + ("b2", &vec![1, 1]), + ("c2", &vec![200 + i * 2, 201 + i * 2]), + ) + }) + .collect(); + let right = build_table_from_batches(right_batches); + + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) 
as _, + )]; + let sort_options = vec![SortOptions::default(); on.len()]; + + let join_types = [ + // Semi/anti/mark joins use BitwiseSortMergeJoinStream which only tracks + // inner key buffer memory; tested separately in bitwise spill tests. + Inner, Left, Right, Full, + ]; + + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(500, 1.0) + .with_disk_manager_builder( + DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory), + ) + .build_arc()?; + + for batch_size in [1, 50] { + let session_config = SessionConfig::default().with_batch_size(batch_size); + + for join_type in &join_types { + let task_ctx = TaskContext::default() + .with_session_config(session_config.clone()) + .with_runtime(Arc::clone(&runtime)); + let task_ctx = Arc::new(task_ctx); + + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + *join_type, + sort_options.clone(), + NullEquality::NullEqualsNothing, + )?; + let stream = join.execute(0, task_ctx)?; + let spilled_join_result = common::collect(stream).await.unwrap(); + + assert!(join.metrics().is_some()); + assert!(join.metrics().unwrap().spill_count().unwrap() > 0); + assert!(join.metrics().unwrap().spilled_bytes().unwrap() > 0); + assert!(join.metrics().unwrap().spilled_rows().unwrap() > 0); + + // Run without spilling for baseline comparison + let task_ctx_no_spill = + TaskContext::default().with_session_config(session_config.clone()); + let task_ctx_no_spill = Arc::new(task_ctx_no_spill); + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + *join_type, + sort_options.clone(), + NullEquality::NullEqualsNothing, + )?; + let stream = join.execute(0, task_ctx_no_spill)?; + let no_spilled_join_result = common::collect(stream).await.unwrap(); + + assert!(join.metrics().is_some()); + assert_eq!(join.metrics().unwrap().spill_count(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0)); + 
assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0)); + assert_eq!(spilled_join_result, no_spilled_join_result); + } + } + + Ok(()) +} + +/// Test spilling with string (Utf8) join keys. +/// +/// String join keys produce larger `join_arrays` than integer keys. +/// This test verifies correctness when spilled batches retain +/// string-valued join_arrays in memory. +#[tokio::test] +async fn spill_string_join_keys() -> Result<()> { + use arrow::array::StringArray; + + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("key", DataType::Utf8, false), + Field::new("val", DataType::Int32, false), + ])); + + // 5 left batches, all with key="same_key_value" + let left_batches: Vec = (0..5) + .map(|i| { + RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![i * 2, i * 2 + 1])), + Arc::new(StringArray::from(vec!["same_key_value", "same_key_value"])), + Arc::new(Int32Array::from(vec![100 + i, 200 + i])), + ], + ) + .unwrap() + }) + .collect(); + + let right_schema = Arc::new(Schema::new(vec![ + Field::new("id2", DataType::Int32, false), + Field::new("key2", DataType::Utf8, false), + Field::new("val2", DataType::Int32, false), + ])); + + // 3 right batches, all with key2="same_key_value" + let right_batches: Vec = (0..3) + .map(|i| { + RecordBatch::try_new( + Arc::clone(&right_schema), + vec![ + Arc::new(Int32Array::from(vec![i * 2 + 50, i * 2 + 51])), + Arc::new(StringArray::from(vec!["same_key_value", "same_key_value"])), + Arc::new(Int32Array::from(vec![300 + i, 400 + i])), + ], + ) + .unwrap() + }) + .collect(); + + let left: Arc = + TestMemoryExec::try_new_exec(&[left_batches], Arc::clone(&schema), None).unwrap(); + let right: Arc = + TestMemoryExec::try_new_exec(&[right_batches], Arc::clone(&right_schema), None) + .unwrap(); + + let on: JoinOn = vec![( + Arc::new(Column::new_with_schema("key", &schema)?) as _, + Arc::new(Column::new_with_schema("key2", &right_schema)?) 
as _, + )]; + let sort_options = vec![SortOptions::default(); on.len()]; + + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(500, 1.0) + .with_disk_manager_builder( + DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory), + ) + .build_arc()?; + + let join_types = [ + // Semi/anti/mark joins use BitwiseSortMergeJoinStream which only tracks + // inner key buffer memory; tested separately in bitwise spill tests. + Inner, Left, Right, Full, + ]; + + for batch_size in [1, 50] { + let session_config = SessionConfig::default().with_batch_size(batch_size); + + for join_type in &join_types { + let task_ctx = TaskContext::default() + .with_session_config(session_config.clone()) + .with_runtime(Arc::clone(&runtime)); + let task_ctx = Arc::new(task_ctx); + + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + *join_type, + sort_options.clone(), + NullEquality::NullEqualsNothing, + )?; + let stream = join.execute(0, task_ctx)?; + let spilled_join_result = common::collect(stream).await.unwrap(); + + assert!(join.metrics().is_some()); + assert!(join.metrics().unwrap().spill_count().unwrap() > 0); + assert!(join.metrics().unwrap().spilled_bytes().unwrap() > 0); + assert!(join.metrics().unwrap().spilled_rows().unwrap() > 0); + + // Baseline without spilling + let task_ctx_no_spill = + TaskContext::default().with_session_config(session_config.clone()); + let task_ctx_no_spill = Arc::new(task_ctx_no_spill); + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + *join_type, + sort_options.clone(), + NullEquality::NullEqualsNothing, + )?; + let stream = join.execute(0, task_ctx_no_spill)?; + let no_spilled_join_result = common::collect(stream).await.unwrap(); + + assert!(join.metrics().is_some()); + assert_eq!(join.metrics().unwrap().spill_count(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0)); 
+ assert_eq!(spilled_join_result, no_spilled_join_result); + } + } + + Ok(()) +} + +/// Test spilling with multiple distinct keys where only some match. +/// +/// Exercises partial spilling — not all key groups trigger spill, and some +/// keys exist only on one side (testing outer join NULL rows from spilled batches). +#[tokio::test] +async fn spill_mixed_keys_some_match() -> Result<()> { + // Left: keys 1,1, 2,2, 3,3 across 3 batches + let left_batch_1 = build_table_i32( + ("a1", &vec![10, 11]), + ("b1", &vec![1, 1]), + ("c1", &vec![100, 101]), + ); + let left_batch_2 = build_table_i32( + ("a1", &vec![20, 21]), + ("b1", &vec![2, 2]), + ("c1", &vec![200, 201]), + ); + let left_batch_3 = build_table_i32( + ("a1", &vec![30, 31]), + ("b1", &vec![3, 3]), + ("c1", &vec![300, 301]), + ); + + // Right: keys 1,1, 2,2, 4,4 — key=3 has no match, key=4 has no match on left + let right_batch_1 = build_table_i32( + ("a2", &vec![40, 41]), + ("b2", &vec![1, 1]), + ("c2", &vec![400, 401]), + ); + let right_batch_2 = build_table_i32( + ("a2", &vec![50, 51]), + ("b2", &vec![2, 2]), + ("c2", &vec![500, 501]), + ); + let right_batch_3 = build_table_i32( + ("a2", &vec![60, 61]), + ("b2", &vec![4, 4]), + ("c2", &vec![600, 601]), + ); + + let left = build_table_from_batches(vec![left_batch_1, left_batch_2, left_batch_3]); + let right = + build_table_from_batches(vec![right_batch_1, right_batch_2, right_batch_3]); + + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, + )]; + let sort_options = vec![SortOptions::default(); on.len()]; + + let join_types = [ + // Semi/anti/mark joins use BitwiseSortMergeJoinStream which only tracks + // inner key buffer memory; tested separately in bitwise spill tests. 
+ Inner, Left, Right, Full, + ]; + + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(300, 1.0) + .with_disk_manager_builder( + DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory), + ) + .build_arc()?; + + for batch_size in [1, 50] { + let session_config = SessionConfig::default().with_batch_size(batch_size); + + for join_type in &join_types { + let task_ctx = TaskContext::default() + .with_session_config(session_config.clone()) + .with_runtime(Arc::clone(&runtime)); + let task_ctx = Arc::new(task_ctx); + + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + *join_type, + sort_options.clone(), + NullEquality::NullEqualsNothing, + )?; + let stream = join.execute(0, task_ctx)?; + let spilled_join_result = common::collect(stream).await.unwrap(); + + assert!(join.metrics().is_some()); + assert!(join.metrics().unwrap().spill_count().unwrap() > 0); + assert!(join.metrics().unwrap().spilled_bytes().unwrap() > 0); + assert!(join.metrics().unwrap().spilled_rows().unwrap() > 0); + + // Baseline + let task_ctx_no_spill = + TaskContext::default().with_session_config(session_config.clone()); + let task_ctx_no_spill = Arc::new(task_ctx_no_spill); + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + *join_type, + sort_options.clone(), + NullEquality::NullEqualsNothing, + )?; + let stream = join.execute(0, task_ctx_no_spill)?; + let no_spilled_join_result = common::collect(stream).await.unwrap(); + + assert!(join.metrics().is_some()); + assert_eq!(join.metrics().unwrap().spill_count(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0)); + assert_eq!(spilled_join_result, no_spilled_join_result); + } + } + + Ok(()) +} + +/// Test that memory accounting is correct after spilling — reservation +/// properly tracks join_arrays and is fully released after the join completes. 
+/// +/// This verifies the fix for the join_arrays memory leak (TODO at +/// materializing_stream.rs:283). Before the fix, spilled batches' join_arrays +/// were invisible to the memory pool. After the fix, `reserved_amount` tracks +/// them and `free_reservation()` releases them symmetrically. +#[tokio::test] +async fn spill_join_arrays_memory_accounting() -> Result<()> { + // Many left batches with same key → forces heavy spilling on buffered side + let left_batches: Vec = (0..8) + .map(|i| { + build_table_i32( + ("a1", &vec![i * 2, i * 2 + 1]), + ("b1", &vec![1, 1]), + ("c1", &vec![100 + i * 2, 101 + i * 2]), + ) + }) + .collect(); + let left = build_table_from_batches(left_batches); + + let right_batches: Vec = (0..4) + .map(|i| { + build_table_i32( + ("a2", &vec![i * 2, i * 2 + 1]), + ("b2", &vec![1, 1]), + ("c2", &vec![200 + i * 2, 201 + i * 2]), + ) + }) + .collect(); + let right = build_table_from_batches(right_batches); + + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) 
as _, + )]; + let sort_options = vec![SortOptions::default(); on.len()]; + + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(500, 1.0) + .with_disk_manager_builder( + DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory), + ) + .build_arc()?; + + let session_config = SessionConfig::default().with_batch_size(50); + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(session_config) + .with_runtime(Arc::clone(&runtime)), + ); + + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + Inner, + sort_options.clone(), + NullEquality::NullEqualsNothing, + )?; + + let stream = join.execute(0, task_ctx)?; + let _result = common::collect(stream).await.unwrap(); + + let metrics = join.metrics().unwrap(); + + // Verify spilling happened + assert!( + metrics.spill_count().unwrap() > 0, + "Expected spilling to occur" + ); + + // After the stream is fully consumed and dropped, the memory pool + // reservation must be fully released. This verifies that + // free_reservation() correctly shrinks by reserved_amount for both + // in-memory and spilled batches. + assert_eq!( + runtime.memory_pool.reserved(), + 0, + "Memory pool should be fully released after join completes" + ); + + // peak_mem_used should be > 0, confirming that the reservation + // tracked memory during execution (including join_arrays of spilled + // batches after the fix). + let peak_mem = metrics + .sum_by_name("peak_mem_used") + .map(|m| m.as_usize()) + .unwrap_or(0); + assert!( + peak_mem > 0, + "peak_mem_used should reflect tracked memory during join execution" + ); + + Ok(()) +} + /// Build a c1 < c2 filter on the third column of each side. 
fn build_c1_lt_c2_filter(left_schema: &Schema, right_schema: &Schema) -> JoinFilter { JoinFilter::new( From a0e19e32b848166d181af2c3f4266c35846f4dbe Mon Sep 17 00:00:00 2001 From: Subham Singhal Date: Thu, 30 Apr 2026 20:30:34 +0530 Subject: [PATCH 2/5] Remove unused import --- .../src/joins/sort_merge_join/materializing_stream.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs b/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs index e7e1f3139bef2..e17ddc80ea78f 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs @@ -58,7 +58,6 @@ use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_physical_expr_common::physical_expr::PhysicalExprRef; use futures::{Stream, StreamExt}; -use itertools::join; /// State of SMJ stream #[derive(Debug, PartialEq, Eq)] @@ -1007,7 +1006,7 @@ impl MaterializingSortMergeJoinStream { if self.reservation.try_grow(join_arrays_mem).is_ok() { buffered_batch.reserved_amount = join_arrays_mem; } - + Ok(()) } _ => internal_err!("Buffered batch has empty body"), From 4ef0ce83cfe741b72801cee6a4076a1511adcebb Mon Sep 17 00:00:00 2001 From: Subham Singhal Date: Thu, 30 Apr 2026 21:03:57 +0530 Subject: [PATCH 3/5] Fix lint --- datafusion/physical-plan/src/joins/sort_merge_join/tests.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs index a9289e3180871..99a8203ba59e3 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs @@ -2615,7 +2615,7 @@ async fn spill_string_join_keys() -> Result<()> { Arc::new(Int32Array::from(vec![100 + i, 200 + i])), ], ) - .unwrap() + .unwrap() }) 
.collect(); @@ -2636,7 +2636,7 @@ async fn spill_string_join_keys() -> Result<()> { Arc::new(Int32Array::from(vec![300 + i, 400 + i])), ], ) - .unwrap() + .unwrap() }) .collect(); From 89af1972b7810004ab88af15a10af91d0f862c8f Mon Sep 17 00:00:00 2001 From: Subham Singhal Date: Thu, 30 Apr 2026 22:51:25 +0530 Subject: [PATCH 4/5] Adds max mem usage metrics --- .../sort_merge_join/materializing_stream.rs | 5 +- .../src/joins/sort_merge_join/tests.rs | 440 ------------------ 2 files changed, 4 insertions(+), 441 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs b/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs index e17ddc80ea78f..6682098340337 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs @@ -282,7 +282,7 @@ impl BufferedBatch { join_arrays, null_joined: vec![], size_estimation, - reserved_amount: 0, // set by allocate_reservation() + reserved_amount: 0, join_filter_status: vec![FilterState::Unvisited; num_rows], num_rows, } @@ -1005,6 +1005,9 @@ impl MaterializingSortMergeJoinStream { let join_arrays_mem = buffered_batch.join_arrays_mem(); if self.reservation.try_grow(join_arrays_mem).is_ok() { buffered_batch.reserved_amount = join_arrays_mem; + self.join_metrics + .peak_mem_used() + .set_max(self.reservation.size()); } Ok(()) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs index 99a8203ba59e3..5d70530528728 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs @@ -2487,446 +2487,6 @@ async fn overallocation_multi_batch_spill() -> Result<()> { Ok(()) } -/// Test spilling with many buffered batches sharing the same join key. 
-/// -/// This exercises the memory accounting fix where `join_arrays` (evaluated join -/// key columns) remain in memory after the main batch is spilled to disk. -/// With many spilled batches for the same key, the untracked join_arrays -/// memory can add up significantly. -#[tokio::test] -async fn spill_many_batches_same_key() -> Result<()> { - // 10 left batches, all with b1=1 (same join key) - let left_batches: Vec = (0..10) - .map(|i| { - build_table_i32( - ("a1", &vec![i * 2, i * 2 + 1]), - ("b1", &vec![1, 1]), - ("c1", &vec![100 + i * 2, 101 + i * 2]), - ) - }) - .collect(); - let left = build_table_from_batches(left_batches); - - // 5 right batches, all with b2=1 - let right_batches: Vec = (0..5) - .map(|i| { - build_table_i32( - ("a2", &vec![i * 2, i * 2 + 1]), - ("b2", &vec![1, 1]), - ("c2", &vec![200 + i * 2, 201 + i * 2]), - ) - }) - .collect(); - let right = build_table_from_batches(right_batches); - - let on = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, - Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, - )]; - let sort_options = vec![SortOptions::default(); on.len()]; - - let join_types = [ - // Semi/anti/mark joins use BitwiseSortMergeJoinStream which only tracks - // inner key buffer memory; tested separately in bitwise spill tests. 
- Inner, Left, Right, Full, - ]; - - let runtime = RuntimeEnvBuilder::new() - .with_memory_limit(500, 1.0) - .with_disk_manager_builder( - DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory), - ) - .build_arc()?; - - for batch_size in [1, 50] { - let session_config = SessionConfig::default().with_batch_size(batch_size); - - for join_type in &join_types { - let task_ctx = TaskContext::default() - .with_session_config(session_config.clone()) - .with_runtime(Arc::clone(&runtime)); - let task_ctx = Arc::new(task_ctx); - - let join = join_with_options( - Arc::clone(&left), - Arc::clone(&right), - on.clone(), - *join_type, - sort_options.clone(), - NullEquality::NullEqualsNothing, - )?; - let stream = join.execute(0, task_ctx)?; - let spilled_join_result = common::collect(stream).await.unwrap(); - - assert!(join.metrics().is_some()); - assert!(join.metrics().unwrap().spill_count().unwrap() > 0); - assert!(join.metrics().unwrap().spilled_bytes().unwrap() > 0); - assert!(join.metrics().unwrap().spilled_rows().unwrap() > 0); - - // Run without spilling for baseline comparison - let task_ctx_no_spill = - TaskContext::default().with_session_config(session_config.clone()); - let task_ctx_no_spill = Arc::new(task_ctx_no_spill); - let join = join_with_options( - Arc::clone(&left), - Arc::clone(&right), - on.clone(), - *join_type, - sort_options.clone(), - NullEquality::NullEqualsNothing, - )?; - let stream = join.execute(0, task_ctx_no_spill)?; - let no_spilled_join_result = common::collect(stream).await.unwrap(); - - assert!(join.metrics().is_some()); - assert_eq!(join.metrics().unwrap().spill_count(), Some(0)); - assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0)); - assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0)); - assert_eq!(spilled_join_result, no_spilled_join_result); - } - } - - Ok(()) -} - -/// Test spilling with string (Utf8) join keys. -/// -/// String join keys produce larger `join_arrays` than integer keys. 
-/// This test verifies correctness when spilled batches retain -/// string-valued join_arrays in memory. -#[tokio::test] -async fn spill_string_join_keys() -> Result<()> { - use arrow::array::StringArray; - - let schema = Arc::new(Schema::new(vec![ - Field::new("id", DataType::Int32, false), - Field::new("key", DataType::Utf8, false), - Field::new("val", DataType::Int32, false), - ])); - - // 5 left batches, all with key="same_key_value" - let left_batches: Vec = (0..5) - .map(|i| { - RecordBatch::try_new( - Arc::clone(&schema), - vec![ - Arc::new(Int32Array::from(vec![i * 2, i * 2 + 1])), - Arc::new(StringArray::from(vec!["same_key_value", "same_key_value"])), - Arc::new(Int32Array::from(vec![100 + i, 200 + i])), - ], - ) - .unwrap() - }) - .collect(); - - let right_schema = Arc::new(Schema::new(vec![ - Field::new("id2", DataType::Int32, false), - Field::new("key2", DataType::Utf8, false), - Field::new("val2", DataType::Int32, false), - ])); - - // 3 right batches, all with key2="same_key_value" - let right_batches: Vec = (0..3) - .map(|i| { - RecordBatch::try_new( - Arc::clone(&right_schema), - vec![ - Arc::new(Int32Array::from(vec![i * 2 + 50, i * 2 + 51])), - Arc::new(StringArray::from(vec!["same_key_value", "same_key_value"])), - Arc::new(Int32Array::from(vec![300 + i, 400 + i])), - ], - ) - .unwrap() - }) - .collect(); - - let left: Arc = - TestMemoryExec::try_new_exec(&[left_batches], Arc::clone(&schema), None).unwrap(); - let right: Arc = - TestMemoryExec::try_new_exec(&[right_batches], Arc::clone(&right_schema), None) - .unwrap(); - - let on: JoinOn = vec![( - Arc::new(Column::new_with_schema("key", &schema)?) as _, - Arc::new(Column::new_with_schema("key2", &right_schema)?) 
as _, - )]; - let sort_options = vec![SortOptions::default(); on.len()]; - - let runtime = RuntimeEnvBuilder::new() - .with_memory_limit(500, 1.0) - .with_disk_manager_builder( - DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory), - ) - .build_arc()?; - - let join_types = [ - // Semi/anti/mark joins use BitwiseSortMergeJoinStream which only tracks - // inner key buffer memory; tested separately in bitwise spill tests. - Inner, Left, Right, Full, - ]; - - for batch_size in [1, 50] { - let session_config = SessionConfig::default().with_batch_size(batch_size); - - for join_type in &join_types { - let task_ctx = TaskContext::default() - .with_session_config(session_config.clone()) - .with_runtime(Arc::clone(&runtime)); - let task_ctx = Arc::new(task_ctx); - - let join = join_with_options( - Arc::clone(&left), - Arc::clone(&right), - on.clone(), - *join_type, - sort_options.clone(), - NullEquality::NullEqualsNothing, - )?; - let stream = join.execute(0, task_ctx)?; - let spilled_join_result = common::collect(stream).await.unwrap(); - - assert!(join.metrics().is_some()); - assert!(join.metrics().unwrap().spill_count().unwrap() > 0); - assert!(join.metrics().unwrap().spilled_bytes().unwrap() > 0); - assert!(join.metrics().unwrap().spilled_rows().unwrap() > 0); - - // Baseline without spilling - let task_ctx_no_spill = - TaskContext::default().with_session_config(session_config.clone()); - let task_ctx_no_spill = Arc::new(task_ctx_no_spill); - let join = join_with_options( - Arc::clone(&left), - Arc::clone(&right), - on.clone(), - *join_type, - sort_options.clone(), - NullEquality::NullEqualsNothing, - )?; - let stream = join.execute(0, task_ctx_no_spill)?; - let no_spilled_join_result = common::collect(stream).await.unwrap(); - - assert!(join.metrics().is_some()); - assert_eq!(join.metrics().unwrap().spill_count(), Some(0)); - assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0)); - assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0)); 
- assert_eq!(spilled_join_result, no_spilled_join_result); - } - } - - Ok(()) -} - -/// Test spilling with multiple distinct keys where only some match. -/// -/// Exercises partial spilling — not all key groups trigger spill, and some -/// keys exist only on one side (testing outer join NULL rows from spilled batches). -#[tokio::test] -async fn spill_mixed_keys_some_match() -> Result<()> { - // Left: keys 1,1, 2,2, 3,3 across 3 batches - let left_batch_1 = build_table_i32( - ("a1", &vec![10, 11]), - ("b1", &vec![1, 1]), - ("c1", &vec![100, 101]), - ); - let left_batch_2 = build_table_i32( - ("a1", &vec![20, 21]), - ("b1", &vec![2, 2]), - ("c1", &vec![200, 201]), - ); - let left_batch_3 = build_table_i32( - ("a1", &vec![30, 31]), - ("b1", &vec![3, 3]), - ("c1", &vec![300, 301]), - ); - - // Right: keys 1,1, 2,2, 4,4 — key=3 has no match, key=4 has no match on left - let right_batch_1 = build_table_i32( - ("a2", &vec![40, 41]), - ("b2", &vec![1, 1]), - ("c2", &vec![400, 401]), - ); - let right_batch_2 = build_table_i32( - ("a2", &vec![50, 51]), - ("b2", &vec![2, 2]), - ("c2", &vec![500, 501]), - ); - let right_batch_3 = build_table_i32( - ("a2", &vec![60, 61]), - ("b2", &vec![4, 4]), - ("c2", &vec![600, 601]), - ); - - let left = build_table_from_batches(vec![left_batch_1, left_batch_2, left_batch_3]); - let right = - build_table_from_batches(vec![right_batch_1, right_batch_2, right_batch_3]); - - let on = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, - Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, - )]; - let sort_options = vec![SortOptions::default(); on.len()]; - - let join_types = [ - // Semi/anti/mark joins use BitwiseSortMergeJoinStream which only tracks - // inner key buffer memory; tested separately in bitwise spill tests. 
- Inner, Left, Right, Full, - ]; - - let runtime = RuntimeEnvBuilder::new() - .with_memory_limit(300, 1.0) - .with_disk_manager_builder( - DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory), - ) - .build_arc()?; - - for batch_size in [1, 50] { - let session_config = SessionConfig::default().with_batch_size(batch_size); - - for join_type in &join_types { - let task_ctx = TaskContext::default() - .with_session_config(session_config.clone()) - .with_runtime(Arc::clone(&runtime)); - let task_ctx = Arc::new(task_ctx); - - let join = join_with_options( - Arc::clone(&left), - Arc::clone(&right), - on.clone(), - *join_type, - sort_options.clone(), - NullEquality::NullEqualsNothing, - )?; - let stream = join.execute(0, task_ctx)?; - let spilled_join_result = common::collect(stream).await.unwrap(); - - assert!(join.metrics().is_some()); - assert!(join.metrics().unwrap().spill_count().unwrap() > 0); - assert!(join.metrics().unwrap().spilled_bytes().unwrap() > 0); - assert!(join.metrics().unwrap().spilled_rows().unwrap() > 0); - - // Baseline - let task_ctx_no_spill = - TaskContext::default().with_session_config(session_config.clone()); - let task_ctx_no_spill = Arc::new(task_ctx_no_spill); - let join = join_with_options( - Arc::clone(&left), - Arc::clone(&right), - on.clone(), - *join_type, - sort_options.clone(), - NullEquality::NullEqualsNothing, - )?; - let stream = join.execute(0, task_ctx_no_spill)?; - let no_spilled_join_result = common::collect(stream).await.unwrap(); - - assert!(join.metrics().is_some()); - assert_eq!(join.metrics().unwrap().spill_count(), Some(0)); - assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0)); - assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0)); - assert_eq!(spilled_join_result, no_spilled_join_result); - } - } - - Ok(()) -} - -/// Test that memory accounting is correct after spilling — reservation -/// properly tracks join_arrays and is fully released after the join completes. 
-/// -/// This verifies the fix for the join_arrays memory leak (TODO at -/// materializing_stream.rs:283). Before the fix, spilled batches' join_arrays -/// were invisible to the memory pool. After the fix, `reserved_amount` tracks -/// them and `free_reservation()` releases them symmetrically. -#[tokio::test] -async fn spill_join_arrays_memory_accounting() -> Result<()> { - // Many left batches with same key → forces heavy spilling on buffered side - let left_batches: Vec = (0..8) - .map(|i| { - build_table_i32( - ("a1", &vec![i * 2, i * 2 + 1]), - ("b1", &vec![1, 1]), - ("c1", &vec![100 + i * 2, 101 + i * 2]), - ) - }) - .collect(); - let left = build_table_from_batches(left_batches); - - let right_batches: Vec = (0..4) - .map(|i| { - build_table_i32( - ("a2", &vec![i * 2, i * 2 + 1]), - ("b2", &vec![1, 1]), - ("c2", &vec![200 + i * 2, 201 + i * 2]), - ) - }) - .collect(); - let right = build_table_from_batches(right_batches); - - let on = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, - Arc::new(Column::new_with_schema("b2", &right.schema())?) 
as _, - )]; - let sort_options = vec![SortOptions::default(); on.len()]; - - let runtime = RuntimeEnvBuilder::new() - .with_memory_limit(500, 1.0) - .with_disk_manager_builder( - DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory), - ) - .build_arc()?; - - let session_config = SessionConfig::default().with_batch_size(50); - let task_ctx = Arc::new( - TaskContext::default() - .with_session_config(session_config) - .with_runtime(Arc::clone(&runtime)), - ); - - let join = join_with_options( - Arc::clone(&left), - Arc::clone(&right), - on.clone(), - Inner, - sort_options.clone(), - NullEquality::NullEqualsNothing, - )?; - - let stream = join.execute(0, task_ctx)?; - let _result = common::collect(stream).await.unwrap(); - - let metrics = join.metrics().unwrap(); - - // Verify spilling happened - assert!( - metrics.spill_count().unwrap() > 0, - "Expected spilling to occur" - ); - - // After the stream is fully consumed and dropped, the memory pool - // reservation must be fully released. This verifies that - // free_reservation() correctly shrinks by reserved_amount for both - // in-memory and spilled batches. - assert_eq!( - runtime.memory_pool.reserved(), - 0, - "Memory pool should be fully released after join completes" - ); - - // peak_mem_used should be > 0, confirming that the reservation - // tracked memory during execution (including join_arrays of spilled - // batches after the fix). - let peak_mem = metrics - .sum_by_name("peak_mem_used") - .map(|m| m.as_usize()) - .unwrap_or(0); - assert!( - peak_mem > 0, - "peak_mem_used should reflect tracked memory during join execution" - ); - - Ok(()) -} - /// Build a c1 < c2 filter on the third column of each side. 
fn build_c1_lt_c2_filter(left_schema: &Schema, right_schema: &Schema) -> JoinFilter { JoinFilter::new( From ccefc2108e28e7ca796afb9c6b825b700ac5b467 Mon Sep 17 00:00:00 2001 From: Subham Singhal Date: Thu, 30 Apr 2026 23:23:06 +0530 Subject: [PATCH 5/5] Adds UT --- .../src/joins/sort_merge_join/tests.rs | 108 ++++++++++++++++++ 1 file changed, 108 insertions(+) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs index 5d70530528728..8d3a7374a0d47 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs @@ -2487,6 +2487,114 @@ async fn overallocation_multi_batch_spill() -> Result<()> { Ok(()) } +/// Verifies that `peak_mem_used` reflects join_arrays memory on the spill path. +/// +/// Uses a memory limit smaller than a single batch's `size_estimation` so that +/// every batch spills — the `Ok` arm of `allocate_reservation` is never hit. +/// Before the fix, `peak_mem_used` would stay 0 because `set_max` was only +/// called in the `Ok` arm. After the fix, the spill path calls +/// `try_grow(join_arrays_mem)` + `set_max`, so `peak_mem_used > 0`. +#[tokio::test] +async fn spill_join_arrays_memory_accounting() -> Result<()> { + use arrow::array::Array; + + let left_batch = build_table_i32( + ("a1", &vec![0, 1]), + ("b1", &vec![1, 1]), + ("c1", &vec![4, 5]), + ); + let size_estimation = left_batch.get_array_memory_size() + + Int32Array::from(vec![1, 1]).get_array_memory_size() + + 2usize.next_power_of_two() * size_of::<usize>() + + size_of::<Range<usize>>() + + size_of::<usize>(); + let join_arrays_mem = Int32Array::from(vec![1, 1]).get_array_memory_size(); + + // Memory limit: too small for a full batch, large enough for join_arrays. + // Every batch hits the Err arm → spills → try_grow(join_arrays_mem).
+ let memory_limit = (size_estimation + join_arrays_mem) / 2; + assert!( + memory_limit < size_estimation && memory_limit > join_arrays_mem, + "limit {memory_limit} must be between join_arrays_mem {join_arrays_mem} \ + and size_estimation {size_estimation}" + ); + + let left_batches: Vec<RecordBatch> = (0..4) + .map(|i| { + build_table_i32( + ("a1", &vec![i * 2, i * 2 + 1]), + ("b1", &vec![1, 1]), + ("c1", &vec![100 + i, 101 + i]), + ) + }) + .collect(); + let left = build_table_from_batches(left_batches); + + let right_batches: Vec<RecordBatch> = (0..2) + .map(|i| { + build_table_i32( + ("a2", &vec![i * 2, i * 2 + 1]), + ("b2", &vec![1, 1]), + ("c2", &vec![200 + i, 201 + i]), + ) + }) + .collect(); + let right = build_table_from_batches(right_batches); + + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, + )]; + let sort_options = vec![SortOptions::default(); on.len()]; + + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(memory_limit, 1.0) + .with_disk_manager_builder( + DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory), + ) + .build_arc()?; + + let session_config = SessionConfig::default().with_batch_size(50); + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(session_config) + .with_runtime(Arc::clone(&runtime)), + ); + + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + Inner, + sort_options, + NullEquality::NullEqualsNothing, + )?; + + let stream = join.execute(0, task_ctx)?; + let _result = common::collect(stream).await.unwrap(); + + let metrics = join.metrics().unwrap(); + assert!( + metrics.spill_count().unwrap() > 0, + "Expected spilling to occur" + ); + + // Before the fix, peak_mem_used was 0 here because set_max was only + // called in the Ok arm of allocate_reservation, which is never reached + // when every batch spills.
After the fix, the spill path tracks + // join_arrays via try_grow + set_max. + let peak_mem = metrics + .sum_by_name("peak_mem_used") + .map(|m| m.as_usize()) + .unwrap_or(0); + assert!( + peak_mem > 0, + "peak_mem_used should reflect join_arrays tracked on spill path" + ); + + Ok(()) +} + /// Build a c1 < c2 filter on the third column of each side. fn build_c1_lt_c2_filter(left_schema: &Schema, right_schema: &Schema) -> JoinFilter { JoinFilter::new(