Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 93 additions & 3 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ use datafusion_expr::expr::{
};
use datafusion_expr::expr_rewriter::unnormalize_cols;
use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary;
use datafusion_expr::logical_plan::dml::UPDATE_FROM_OLD_COLUMN_PREFIX;
use datafusion_expr::utils::{expr_to_columns, split_conjunction};
use datafusion_expr::{
Analyze, BinaryExpr, DescribeTable, DmlStatement, Explain, ExplainFormat, Extension,
Expand Down Expand Up @@ -781,6 +782,12 @@ impl DefaultPhysicalPlanner {
if let Some(provider) =
target.as_any().downcast_ref::<DefaultTableSource>()
{
if has_update_from_old_row_projection(input)? {
return not_impl_err!(
"UPDATE ... FROM execution is not yet supported"
);
}

// For UPDATE, the assignments are encoded in the projection of input
// We pass the filters and let the provider handle the projection
let filters = extract_dml_filters(input, table_name)?;
Expand Down Expand Up @@ -2212,6 +2219,24 @@ fn predicate_is_on_target_multi(
}))
}

/// Reports whether `input` contains a projection that emits hidden
/// `UPDATE ... FROM` old-row columns, i.e. aliases whose name begins with
/// [`UPDATE_FROM_OLD_COLUMN_PREFIX`].
///
/// The presence of such a projection marks the plan as an `UPDATE ... FROM`
/// statement, which the physical planner cannot yet execute.
fn has_update_from_old_row_projection(input: &Arc<LogicalPlan>) -> Result<bool> {
    let mut found = false;
    input.apply(|node| {
        // Only Projection nodes can carry the hidden old-row aliases.
        let marks_update_from = match node {
            LogicalPlan::Projection(projection) => projection.expr.iter().any(|expr| {
                matches!(expr, Expr::Alias(alias) if alias.name.starts_with(UPDATE_FROM_OLD_COLUMN_PREFIX))
            }),
            _ => false,
        };

        if marks_update_from {
            found = true;
            // One marker is enough; stop traversing the plan tree.
            Ok(TreeNodeRecursion::Stop)
        } else {
            Ok(TreeNodeRecursion::Continue)
        }
    })?;

    Ok(found)
}

/// Strip table qualifiers from column references in an expression.
/// This is needed because DML filter expressions contain qualified column names
/// (e.g., "table.column") but the TableProvider's schema only has simple names.
Expand Down Expand Up @@ -2248,6 +2273,11 @@ fn extract_update_assignments(input: &Arc<LogicalPlan>) -> Result<Vec<(String, E
if let LogicalPlan::Projection(projection) = input.as_ref() {
for expr in &projection.expr {
if let Expr::Alias(alias) = expr {
// Hidden old-row aliases are planner metadata for UPDATE ... FROM,
// not provider-visible assignments.
if alias.name.starts_with(UPDATE_FROM_OLD_COLUMN_PREFIX) {
continue;
}
// The alias name is the column name being updated
// The inner expression is the new value
let column_name = alias.name.clone();
Expand All @@ -2266,6 +2296,9 @@ fn extract_update_assignments(input: &Arc<LogicalPlan>) -> Result<Vec<(String, E
if let LogicalPlan::Projection(projection) = node {
for expr in &projection.expr {
if let Expr::Alias(alias) = expr {
if alias.name.starts_with(UPDATE_FROM_OLD_COLUMN_PREFIX) {
continue;
}
let column_name = alias.name.clone();
if !is_identity_assignment(&alias.expr, &column_name) {
let stripped_expr =
Expand Down Expand Up @@ -3114,8 +3147,8 @@ mod tests {
use crate::test_util::{scan_empty, scan_empty_with_partitions};

use crate::execution::session_state::SessionStateBuilder;
use arrow::array::{ArrayRef, DictionaryArray, Int32Array};
use arrow::datatypes::{DataType, Field, Int32Type};
use arrow::array::{ArrayRef, DictionaryArray, Int32Array, RecordBatch};
use arrow::datatypes::{DataType, Field, Int32Type, Schema};
use arrow_schema::SchemaRef;
use datafusion_common::config::ConfigOptions;
use datafusion_common::{
Expand All @@ -3125,13 +3158,44 @@ mod tests {
use datafusion_execution::runtime_env::RuntimeEnv;
use datafusion_expr::builder::subquery_alias;
use datafusion_expr::{
LogicalPlanBuilder, TableSource, UserDefinedLogicalNodeCore, col, lit,
DmlStatement, LogicalPlanBuilder, TableSource, UserDefinedLogicalNodeCore,
WriteOp, col, lit,
};
use datafusion_functions_aggregate::count::count_all;
use datafusion_functions_aggregate::expr_fn::sum;
use datafusion_physical_expr::EquivalenceProperties;
use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};

/// Plans `sql` (expected to be an UPDATE statement) against two empty,
/// identically-shaped in-memory tables `t1` and `t2`, and returns the
/// `input` of the resulting `Dml` update plan.
///
/// Returns an internal error if planning produces anything other than an
/// `UPDATE` DML statement.
async fn make_update_from_input(sql: &str) -> Result<Arc<LogicalPlan>> {
    let ctx = SessionContext::new();
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8, true),
        Field::new("c", DataType::Float64, true),
        Field::new("d", DataType::Int32, true),
    ]));

    // Register two empty tables sharing the same schema so the UPDATE can
    // name both a target (t1) and a source (t2).
    for table_name in ["t1", "t2"] {
        let table = MemTable::try_new(
            Arc::clone(&schema),
            vec![vec![RecordBatch::new_empty(Arc::clone(&schema))]],
        )?;
        ctx.register_table(table_name, Arc::new(table))?;
    }

    match ctx.sql(sql).await?.into_unoptimized_plan() {
        LogicalPlan::Dml(DmlStatement {
            op: WriteOp::Update,
            input,
            ..
        }) => Ok(input),
        other => internal_err!("Expected UPDATE DML plan, got: {other}"),
    }
}

fn make_session_state() -> SessionState {
let runtime = Arc::new(RuntimeEnv::default());
let config = SessionConfig::new().with_target_partitions(4);
Expand Down Expand Up @@ -3751,6 +3815,32 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn test_extract_update_assignments_skips_old_row_aliases() -> Result<()> {
    // Plan an UPDATE ... FROM; the SQL planner appends hidden old-row aliases
    // to the projection, and those must never surface as provider-visible
    // assignments.
    let input = make_update_from_input(
        "UPDATE t1 AS dst \
            SET c = src.a + dst.a, d = src.d \
            FROM t2 AS src \
            WHERE dst.a = src.a",
    )
    .await?;

    let assignments = extract_update_assignments(&input)?;

    let has_hidden_alias = assignments
        .iter()
        .any(|(name, _)| name.starts_with(UPDATE_FROM_OLD_COLUMN_PREFIX));
    assert!(
        !has_hidden_alias,
        "Hidden old-row aliases should not be emitted as assignments: {assignments:?}"
    );

    // The explicit SET c = ... assignment must still be extracted.
    let has_c_assignment = assignments.iter().any(|(name, _)| name == "c");
    assert!(has_c_assignment, "Expected assignment for c");

    Ok(())
}

#[tokio::test]
async fn test_aggregate_count_all_with_alias() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Expand Down
16 changes: 8 additions & 8 deletions datafusion/core/tests/custom_sources_cases/dml_planning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -725,7 +725,7 @@ async fn test_delete_target_table_scoping() -> Result<()> {

#[tokio::test]
async fn test_update_from_drops_non_target_predicates() -> Result<()> {
// UPDATE ... FROM is currently not working
// UPDATE ... FROM should plan successfully but fail at execution time.
// TODO fix https://github.com/apache/datafusion/issues/19950
let target_provider = Arc::new(CaptureUpdateProvider::new_with_filter_pushdown(
test_schema(),
Expand All @@ -743,20 +743,20 @@ async fn test_update_from_drops_non_target_predicates() -> Result<()> {
let source_table = datafusion::datasource::empty::EmptyTable::new(source_schema);
ctx.register_table("t2", Arc::new(source_table))?;

let result = ctx
let df = ctx
.sql(
"UPDATE t1 SET value = 1 FROM t2 \
WHERE t1.id = t2.id AND t2.src_only = 'active' AND t1.value > 10",
)
.await;
.await?;

// Verify UPDATE ... FROM is rejected with appropriate error
// Verify execution-layer rejection to preserve planner/executor boundary
// TODO fix https://github.com/apache/datafusion/issues/19950
assert!(result.is_err());
let err = result.unwrap_err();
let err = df.collect().await.unwrap_err();
assert!(
err.to_string().contains("UPDATE ... FROM is not supported"),
"Expected 'UPDATE ... FROM is not supported' error, got: {err}"
err.to_string()
.contains("UPDATE ... FROM execution is not yet supported"),
"Expected 'UPDATE ... FROM execution is not yet supported' error, got: {err}"
);
Ok(())
}
Expand Down
10 changes: 10 additions & 0 deletions datafusion/expr/src/logical_plan/dml.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,16 @@ use datafusion_common::{DFSchemaRef, TableReference};

use crate::{LogicalPlan, TableSource};

/// Prefix used for hidden columns carrying the original target-row values in
/// `UPDATE ... FROM` plans.
pub const UPDATE_FROM_OLD_COLUMN_PREFIX: &str = "__df_update_old_";

/// Returns the hidden `UPDATE ... FROM` column name used to carry the original
/// value of `column_name` from the target table.
pub fn update_from_old_column_name(column_name: &str) -> String {
    // Preallocate exactly once: prefix + column name.
    let mut name =
        String::with_capacity(UPDATE_FROM_OLD_COLUMN_PREFIX.len() + column_name.len());
    name.push_str(UPDATE_FROM_OLD_COLUMN_PREFIX);
    name.push_str(column_name);
    name
}

/// Operator that copies the contents of a database to file(s)
#[derive(Clone)]
pub struct CopyTo {
Expand Down
58 changes: 38 additions & 20 deletions datafusion/sql/src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use crate::parser::{
LexOrdering, ResetStatement, Statement as DFStatement,
};
use crate::planner::{
ContextProvider, PlannerContext, SqlToRel, object_name_to_qualifier,
ContextProvider, IdentNormalizer, PlannerContext, SqlToRel, object_name_to_qualifier,
};
use crate::utils::normalize_ident;

Expand All @@ -38,7 +38,7 @@ use datafusion_common::{
internal_err, not_impl_err, plan_datafusion_err, plan_err, schema_err,
unqualified_field_not_found,
};
use datafusion_expr::dml::{CopyTo, InsertOp};
use datafusion_expr::dml::{CopyTo, InsertOp, update_from_old_column_name};
use datafusion_expr::expr_rewriter::normalize_col_with_schemas_and_ambiguity_check;
use datafusion_expr::logical_plan::DdlStatement;
use datafusion_expr::logical_plan::builder::project;
Expand Down Expand Up @@ -1084,12 +1084,6 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
}
let update_from = from_clauses.and_then(|mut f| f.pop());

// UPDATE ... FROM is currently not working
// TODO fix https://github.com/apache/datafusion/issues/19950
if update_from.is_some() {
return not_impl_err!("UPDATE ... FROM is not supported");
}

if returning.is_some() {
plan_err!("Update-returning clause not yet supported")?;
}
Expand Down Expand Up @@ -2191,6 +2185,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
.collect::<Result<HashMap<String, SQLExpr>>>()?;

// Build scan, join with from table if it exists.
let has_update_from = from.is_some();
let mut input_tables = vec![table];
input_tables.extend(from);
let scan = self.plan_from_tables(input_tables, &mut planner_context)?;
Expand All @@ -2216,7 +2211,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
};

// Build updated values for each column, using the previous value if not modified
let exprs = table_schema
let mut exprs = table_schema
.iter()
.map(|(qualifier, field)| {
let expr = match assign_map.remove(field.name()) {
Expand All @@ -2236,22 +2231,29 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
// Cast to target column type, if necessary
expr.cast_to(field.data_type(), source.schema())?
}
None => {
// If the target table has an alias, use it to qualify the column name
if let Some(alias) = &table_alias {
Expr::Column(Column::new(
Some(self.ident_normalizer.normalize(alias.name.clone())),
field.name(),
))
} else {
Expr::Column(Column::from((qualifier, field)))
}
}
None => Self::update_target_column_expr(
&self.ident_normalizer,
&table_alias,
qualifier,
field,
),
};
Ok(expr.alias(field.name()))
})
.collect::<Result<Vec<_>>>()?;

if has_update_from {
exprs.extend(table_schema.iter().map(|(qualifier, field)| {
Self::update_target_column_expr(
&self.ident_normalizer,
&table_alias,
qualifier,
field,
)
.alias(update_from_old_column_name(field.name()))
}));
}

let source = project(source, exprs)?;

let plan = LogicalPlan::Dml(DmlStatement::new(
Expand All @@ -2263,6 +2265,22 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
Ok(plan)
}

/// Builds the expression that reads the current value of `field` from the
/// target table of an UPDATE, qualifying the column with the table alias when
/// one is present.
fn update_target_column_expr(
    ident_normalizer: &IdentNormalizer,
    table_alias: &Option<ast::TableAlias>,
    qualifier: Option<&TableReference>,
    field: &FieldRef,
) -> Expr {
    match table_alias {
        // When the target table is aliased, columns must be qualified by the
        // normalized alias name rather than the underlying table reference.
        Some(alias) => Expr::Column(Column::new(
            Some(ident_normalizer.normalize(alias.name.clone())),
            field.name(),
        )),
        None => Expr::Column(Column::from((qualifier, field))),
    }
}

fn insert_to_plan(
&self,
table_name: ObjectName,
Expand Down
62 changes: 62 additions & 0 deletions datafusion/sql/tests/sql_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,68 @@ fn plan_update() {
);
}

#[test]
fn plan_update_from() {
let sql = "update person \
set last_name = src.last_name, age = src.age \
from person as src \
where person.id = src.id";
let plan = logical_plan(sql).unwrap();
let expected = [
"Dml: op=[Update] table=[person]",
" Projection: person.id AS id, person.first_name AS first_name, src.last_name AS last_name, src.age AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀, person.id AS __df_update_old_id, person.first_name AS __df_update_old_first_name, person.last_name AS __df_update_old_last_name, person.age AS __df_update_old_age, person.state AS __df_update_old_state, person.salary AS __df_update_old_salary, person.birth_date AS __df_update_old_birth_date, person.😀 AS __df_update_old_😀",
" Filter: person.id = src.id",
" Cross Join:",
" TableScan: person",
" SubqueryAlias: src",
" TableScan: person",
]
.join("\n");
assert_eq!(format!("{plan}"), expected);
}

#[test]
fn plan_update_from_before_set() {
let sql = "update person \
from person as src \
set last_name = src.last_name, age = src.age \
where person.id = src.id";
let plan = logical_plan(sql).unwrap();
let expected = [
"Dml: op=[Update] table=[person]",
" Projection: person.id AS id, person.first_name AS first_name, src.last_name AS last_name, src.age AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀, person.id AS __df_update_old_id, person.first_name AS __df_update_old_first_name, person.last_name AS __df_update_old_last_name, person.age AS __df_update_old_age, person.state AS __df_update_old_state, person.salary AS __df_update_old_salary, person.birth_date AS __df_update_old_birth_date, person.😀 AS __df_update_old_😀",
" Filter: person.id = src.id",
" Cross Join:",
" TableScan: person",
" SubqueryAlias: src",
" TableScan: person",
]
.join("\n");
assert_eq!(format!("{plan}"), expected);
}

#[test]
fn plan_update_from_with_aliases_projects_original_target_row() {
    // When the target table is aliased, both the assignment projection and
    // the hidden old-row columns must be qualified by the alias (`dst`),
    // not by the underlying table name.
    let sql = "update person as dst \
            set last_name = src.last_name, age = src.age \
            from person as src \
            where dst.id = src.id";
    let plan = logical_plan(sql).unwrap();
    assert_snapshot!(
        plan,
        @r#"
    Dml: op=[Update] table=[person]
      Projection: dst.id AS id, dst.first_name AS first_name, src.last_name AS last_name, src.age AS age, dst.state AS state, dst.salary AS salary, dst.birth_date AS birth_date, dst.😀 AS 😀, dst.id AS __df_update_old_id, dst.first_name AS __df_update_old_first_name, dst.last_name AS __df_update_old_last_name, dst.age AS __df_update_old_age, dst.state AS __df_update_old_state, dst.salary AS __df_update_old_salary, dst.birth_date AS __df_update_old_birth_date, dst.😀 AS __df_update_old_😀
        Filter: dst.id = src.id
          Cross Join: 
            SubqueryAlias: dst
              TableScan: person
            SubqueryAlias: src
              TableScan: person
    "#
    );
}

#[rstest]
#[case::missing_assignment_target("UPDATE person SET doesnotexist = true")]
#[case::missing_assignment_expression("UPDATE person SET age = doesnotexist + 42")]
Expand Down
Loading
Loading