Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
231 changes: 227 additions & 4 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -785,7 +785,7 @@ impl DefaultPhysicalPlanner {
// We pass the filters and let the provider handle the projection
let filters = extract_dml_filters(input, table_name)?;
// Extract assignments from the projection in input plan
let assignments = extract_update_assignments(input)?;
let assignments = extract_update_assignments(input, table_name)?;
provider
.table_provider
.update(session_state, assignments, filters)
Expand Down Expand Up @@ -2235,7 +2235,10 @@ fn strip_column_qualifiers(expr: Expr) -> Result<Expr> {
/// over the source table. This function extracts column name and expression pairs
/// from the projection. Column qualifiers are stripped from the expressions.
///
fn extract_update_assignments(input: &Arc<LogicalPlan>) -> Result<Vec<(String, Expr)>> {
fn extract_update_assignments(
input: &Arc<LogicalPlan>,
_target_table: &TableReference,
) -> Result<Vec<(String, Expr)>> {
// The UPDATE input plan structure is:
// Projection(updated columns as expressions with aliases)
// Filter(optional WHERE clause)
Expand Down Expand Up @@ -3115,7 +3118,8 @@ mod tests {

use crate::execution::session_state::SessionStateBuilder;
use arrow::array::{ArrayRef, DictionaryArray, Int32Array};
use arrow::datatypes::{DataType, Field, Int32Type};
use arrow::datatypes::{DataType, Field, Int32Type, Schema};
use arrow::record_batch::RecordBatch;
use arrow_schema::SchemaRef;
use datafusion_common::config::ConfigOptions;
use datafusion_common::{
Expand All @@ -3125,13 +3129,79 @@ mod tests {
use datafusion_execution::runtime_env::RuntimeEnv;
use datafusion_expr::builder::subquery_alias;
use datafusion_expr::{
LogicalPlanBuilder, TableSource, UserDefinedLogicalNodeCore, col, lit,
DmlStatement, LogicalPlanBuilder, TableSource, UserDefinedLogicalNodeCore,
WriteOp, col, lit,
};
use datafusion_functions_aggregate::count::count_all;
use datafusion_functions_aggregate::expr_fn::sum;
use datafusion_physical_expr::EquivalenceProperties;
use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};

async fn make_update_from_plan(
sql: &str,
) -> Result<(Arc<LogicalPlan>, TableReference)> {
let ctx = SessionContext::new();
let t1_schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Utf8, true),
Field::new("c", DataType::Float64, true),
Field::new("d", DataType::Int32, true),
]));
let t2_schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Utf8, true),
Field::new("c", DataType::Float64, true),
Field::new("d", DataType::Int32, true),
]));
let t1 = MemTable::try_new(
Arc::clone(&t1_schema),
vec![vec![RecordBatch::new_empty(Arc::clone(&t1_schema))]],
)?;
let t2 = MemTable::try_new(
Arc::clone(&t2_schema),
vec![vec![RecordBatch::new_empty(Arc::clone(&t2_schema))]],
)?;
ctx.register_table("t1", Arc::new(t1))?;
ctx.register_table("t2", Arc::new(t2))?;

let plan = ctx.sql(sql).await?.into_unoptimized_plan();
match plan {
LogicalPlan::Dml(DmlStatement {
table_name,
op: WriteOp::Update,
input,
..
}) => Ok((input, table_name)),
other => internal_err!("Expected UPDATE DML plan, got: {other}"),
}
}

/// Plans `sql` against a single registered in-memory table `t1` and returns
/// the `DELETE` statement's input plan plus its target table reference.
/// Errors if planning fails or the statement is not a DELETE.
async fn make_delete_plan(sql: &str) -> Result<(Arc<LogicalPlan>, TableReference)> {
    let ctx = SessionContext::new();
    // t1: a (non-null key), b, c, d — backed by one empty batch, since only
    // the logical plan is inspected.
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8, true),
        Field::new("c", DataType::Float64, true),
        Field::new("d", DataType::Int32, true),
    ]));
    let empty_batch = RecordBatch::new_empty(Arc::clone(&schema));
    let table = MemTable::try_new(Arc::clone(&schema), vec![vec![empty_batch]])?;
    ctx.register_table("t1", Arc::new(table))?;

    match ctx.sql(sql).await?.into_unoptimized_plan() {
        LogicalPlan::Dml(DmlStatement {
            table_name,
            op: WriteOp::Delete,
            input,
            ..
        }) => Ok((input, table_name)),
        other => internal_err!("Expected DELETE DML plan, got: {other}"),
    }
}

fn make_session_state() -> SessionState {
let runtime = Arc::new(RuntimeEnv::default());
let config = SessionConfig::new().with_target_partitions(4);
Expand Down Expand Up @@ -3418,6 +3488,159 @@ mod tests {
Ok(())
}

#[tokio::test]
#[ignore = "TODO(19950): enable once the implementation PR lands"]
async fn test_extract_update_assignments_preserves_source_qualifiers_for_update_from()
-> Result<()> {
    // TODO(19950): enable once the implementation PR lands.
    // UPDATE ... FROM with a distinct source table: assignments must keep
    // the source-side qualifier so the provider can resolve join columns.
    let sql = "UPDATE t1 AS dst \
         SET b = src.b, d = src.d \
         FROM t2 AS src \
         WHERE dst.a = src.a";
    let (input, table_name) = make_update_from_plan(sql).await?;

    let assignments = extract_update_assignments(&input, &table_name)?;
    // Look up the rendered assignment expression for a target column.
    let find = |column: &str| {
        assignments
            .iter()
            .find(|(name, _)| name == column)
            .map(|(_, expr)| expr.to_string())
    };
    let b_expr = find("b").ok_or_else(|| {
        internal_datafusion_err!("Expected assignment for target column b")
    })?;
    let d_expr = find("d").ok_or_else(|| {
        internal_datafusion_err!("Expected assignment for target column d")
    })?;

    assert!(
        b_expr.contains("src.b"),
        "Unexpected b assignment: {b_expr}"
    );
    assert!(
        d_expr.contains("src.d"),
        "Unexpected d assignment: {d_expr}"
    );
    // The identity projection of the untouched key column must not surface
    // as an assignment.
    assert!(
        !assignments.iter().any(|(name, _)| name == "a"),
        "Identity target columns should not be extracted as assignments"
    );

    Ok(())
}

#[tokio::test]
#[ignore = "TODO(19950): enable once the implementation PR lands"]
async fn test_extract_update_assignments_strips_target_qualifiers_single_table()
-> Result<()> {
    // TODO(19950): enable once the implementation PR lands.
    // Single-table UPDATE: any qualifier on the target alias should be
    // stripped from the extracted assignment expression.
    let sql = "UPDATE t1 AS dst SET d = dst.d + 1 WHERE dst.a > 0";
    let (input, table_name) = make_update_from_plan(sql).await?;

    let assignments = extract_update_assignments(&input, &table_name)?;
    let d_expr = match assignments.iter().find(|(name, _)| name == "d") {
        Some((_, expr)) => expr.to_string(),
        None => {
            return Err(internal_datafusion_err!(
                "Expected assignment for target column d"
            ));
        }
    };

    assert!(
        !d_expr.contains("dst."),
        "Single-table assignment should not keep target qualifiers: {d_expr}"
    );
    assert!(
        d_expr.contains("d"),
        "Unexpected assignment expression: {d_expr}"
    );

    Ok(())
}

#[tokio::test]
#[ignore = "TODO(19950): enable once the implementation PR lands"]
async fn test_extract_update_assignments_preserves_self_join_source_aliases()
-> Result<()> {
    // TODO(19950): enable once the implementation PR lands.
    // Self-join UPDATE ... FROM: even though src and dst are the same base
    // table, the source alias must survive assignment extraction.
    let sql = "UPDATE t1 AS dst \
         SET b = src.b, d = src.d \
         FROM t1 AS src \
         WHERE dst.a = src.a + 1";
    let (input, table_name) = make_update_from_plan(sql).await?;

    let assignments = extract_update_assignments(&input, &table_name)?;
    // Look up the rendered assignment expression for a target column.
    let find = |column: &str| {
        assignments
            .iter()
            .find(|(name, _)| name == column)
            .map(|(_, expr)| expr.to_string())
    };
    let b_expr = find("b").ok_or_else(|| {
        internal_datafusion_err!("Expected assignment for target column b")
    })?;
    let d_expr = find("d").ok_or_else(|| {
        internal_datafusion_err!("Expected assignment for target column d")
    })?;

    assert!(
        b_expr.contains("src.b"),
        "Self-join source alias should be preserved: {b_expr}"
    );
    assert!(
        d_expr.contains("src.d"),
        "Self-join source alias should be preserved: {d_expr}"
    );

    Ok(())
}

#[tokio::test]
#[ignore = "TODO(19950): enable once the implementation PR lands"]
async fn test_extract_dml_filters_delete_limit_without_where() -> Result<()> {
    // TODO(19950): enable once the implementation PR lands.
    // A bare LIMIT carries no predicate, so filter extraction should yield
    // nothing rather than inventing a filter from the limit node.
    let sql = "DELETE FROM t1 LIMIT 10";
    let (input, table_name) = make_delete_plan(sql).await?;

    let filters = extract_dml_filters(&input, &table_name)?;
    assert!(
        filters.is_empty(),
        "DELETE ... LIMIT without WHERE should not synthesize filters: {filters:?}"
    );

    Ok(())
}

#[tokio::test]
#[ignore = "TODO(19950): enable once the implementation PR lands"]
async fn test_extract_dml_filters_delete_where_limit() -> Result<()> {
    // TODO(19950): enable once the implementation PR lands.
    // WHERE under LIMIT: the single predicate should still be extracted.
    let sql = "DELETE FROM t1 WHERE a > 1 LIMIT 10";
    let (input, table_name) = make_delete_plan(sql).await?;

    let filters = extract_dml_filters(&input, &table_name)?;
    assert_eq!(
        filters.len(),
        1,
        "Expected one target predicate from DELETE WHERE ... LIMIT"
    );
    // The literal's rendered integer width may vary (e.g. Int32(1) vs
    // Int64(1)), so assert on the shape of the text, not an exact string.
    let rendered = filters[0].to_string();
    assert!(
        rendered.starts_with("a > Int") && rendered.ends_with("(1)"),
        "Unexpected rendered filter for DELETE WHERE ... LIMIT: {rendered}"
    );

    Ok(())
}

#[tokio::test]
async fn test_create_not() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]);
Expand Down
Loading