diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index c154bc7c92fa5..5f80cfd4d6e0f 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -1672,34 +1672,68 @@ impl LogicalPlan { let mut param_types: HashMap> = HashMap::new(); self.apply_with_subqueries(|plan| { + let row_count_field = Arc::new(Field::new("", DataType::Int64, true)); plan.apply_expressions(|expr| { - expr.apply(|expr| { - if let Expr::Placeholder(Placeholder { id, field }) = expr { - let prev = param_types.get(id); - match (prev, field) { - (Some(Some(prev)), Some(field)) => { - check_metadata_with_storage_equal( - (field.data_type(), Some(field.metadata())), - (prev.data_type(), Some(prev.metadata())), - "parameter", - &format!(": Conflicting types for id {id}"), - )?; - } - (_, Some(field)) => { - param_types.insert(id.clone(), Some(Arc::clone(field))); - } - _ => { - param_types.insert(id.clone(), None); - } - } - } - Ok(TreeNodeRecursion::Continue) - }) + if matches!(plan, LogicalPlan::Limit(_)) { + let expr = Self::infer_limit_row_count_parameter_fields( + expr, + &row_count_field, + )?; + Self::collect_parameter_fields(&expr, &mut param_types) + } else { + Self::collect_parameter_fields(expr, &mut param_types) + } }) }) .map(|_| param_types) } + fn infer_limit_row_count_parameter_fields( + expr: &Expr, + row_count_field: &FieldRef, + ) -> Result { + let empty_schema = DFSchema::empty(); + let (expr, _) = expr.clone().infer_placeholder_types(&empty_schema)?; + + if let Expr::Placeholder(Placeholder { id, field: None }) = expr { + return Ok(Expr::Placeholder(Placeholder::new_with_field( + id, + Some(Arc::clone(row_count_field)), + ))); + } + + Ok(expr) + } + + fn collect_parameter_fields( + expr: &Expr, + param_types: &mut HashMap>, + ) -> Result { + expr.apply(|expr| { + if let Expr::Placeholder(Placeholder { id, field }) = expr { + let field = field.as_ref().map(Arc::clone); + let prev = param_types.get(id); + match (prev, &field) { + (Some(Some(prev)), Some(field)) => { + check_metadata_with_storage_equal( + (field.data_type(), Some(field.metadata())), + (prev.data_type(), Some(prev.metadata())), + "parameter", + &format!(": Conflicting types for id {id}"), + )?; + } + (_, Some(field)) => { + param_types.insert(id.clone(), Some(Arc::clone(field))); + } + _ => { + param_types.insert(id.clone(), None); + } + } + } + Ok(TreeNodeRecursion::Continue) + }) + } + // ------------ // Various implementations for printing out LogicalPlans // ------------ @@ -4887,7 +4921,7 @@ mod tests { use crate::select_expr::SelectExpr; use crate::test::function_stub::{count, count_udaf}; use crate::{ - GroupingSet, binary_expr, col, exists, in_subquery, lit, placeholder, + GroupingSet, binary_expr, cast, col, exists, in_subquery, lit, placeholder, scalar_subquery, }; use datafusion_common::metadata::ScalarAndMetadata; @@ -6204,6 +6238,72 @@ mod tests { assert_eq!(parameter_type, None); } + #[test] + fn test_limit_parameter_fields_infer_row_counts() { + let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]); + let input = table_scan(TableReference::none(), &schema, None) + .unwrap() + .build() + .unwrap(); + let plan = LogicalPlan::Limit(Limit { + skip: Some(Box::new(placeholder("$1"))), + fetch: Some(Box::new(placeholder("$2"))), + input: Arc::new(input), + }); + + let params = plan.get_parameter_types().unwrap(); + assert_eq!(params.len(), 2); + assert_eq!(params.get("$1"), Some(&Some(DataType::Int64))); + assert_eq!(params.get("$2"), Some(&Some(DataType::Int64))); + + let fields = plan.get_parameter_fields().unwrap(); + assert_eq!(fields.len(), 2); + assert_eq!( + fields.get("$1").and_then(|field| field.as_ref()), + Some(&Arc::new(Field::new("", DataType::Int64, true))) + ); + assert_eq!( + fields.get("$2").and_then(|field| field.as_ref()), + Some(&Arc::new(Field::new("", DataType::Int64, true))) + ); + } + + #[test] + fn test_limit_nested_parameter_fields_infer_from_expression() { + let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]); + let input = table_scan(TableReference::none(), &schema, None) + .unwrap() + .build() + .unwrap(); + let plan = LogicalPlan::Limit(Limit { + skip: None, + fetch: Some(Box::new(placeholder("$1") * lit(2i64))), + input: Arc::new(input), + }); + + let params = plan.get_parameter_types().unwrap(); + assert_eq!(params.len(), 1); + assert_eq!(params.get("$1"), Some(&Some(DataType::Int64))); + } + + #[test] + fn test_limit_cast_parameter_field_stays_local_to_cast() { + let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]); + let input = table_scan(TableReference::none(), &schema, None) + .unwrap() + .build() + .unwrap(); + let plan = LogicalPlan::Limit(Limit { + skip: None, + fetch: Some(Box::new(cast(placeholder("$1"), DataType::Int32))), + input: Arc::new(input), + }); + + let params = plan.get_parameter_types().unwrap(); + assert_eq!(params.len(), 1); + assert_eq!(params.get("$1"), Some(&None)); + } + #[test] fn test_join_with_new_exprs() -> Result<()> { fn create_test_join( diff --git a/datafusion/sql/tests/cases/params.rs b/datafusion/sql/tests/cases/params.rs index 68c560ead68cd..bac452cc59376 100644 --- a/datafusion/sql/tests/cases/params.rs +++ b/datafusion/sql/tests/cases/params.rs @@ -987,6 +987,70 @@ fn test_prepare_statement_to_plan_having() { ); } +#[test] +fn test_infer_types_from_limit() { + let test = ParameterTest { + sql: "SELECT id FROM person OFFSET $1 LIMIT $2", + expected_types: vec![ + ("$1", Some(DataType::Int64)), + ("$2", Some(DataType::Int64)), + ], + param_values: vec![ScalarValue::Int64(Some(10)), ScalarValue::Int64(Some(200))], + }; + + assert_snapshot!( + test.run(), + @r" + ** Initial Plan: + Limit: skip=$1, fetch=$2 + Projection: person.id + TableScan: person + ** Final Plan: + Limit: skip=10, fetch=200 + Projection: person.id + TableScan: person + " + ); +} + +#[test] +fn test_infer_types_from_limit_conflicting_parameter() { + let plan = + logical_plan("SELECT id FROM person WHERE first_name = $1 OFFSET $1").unwrap(); + + let err = plan.get_parameter_types().unwrap_err(); + assert_contains!(err.to_string(), "Conflicting types for id $1"); +} + +#[test] +fn test_prepare_statement_infer_types_from_limit() { + let sql = "PREPARE my_plan AS + SELECT id FROM person \ + OFFSET $1 LIMIT $2"; + let (plan, dt) = generate_prepare_stmt_and_data_types(sql); + assert_snapshot!( + plan, + @r#" + Prepare: "my_plan" [Int64, Int64] + Limit: skip=$1, fetch=$2 + Projection: person.id + TableScan: person + "# + ); + assert_snapshot!(dt, @"Int64, Int64"); + + let param_values = vec![ScalarValue::Int64(Some(10)), ScalarValue::Int64(Some(200))]; + let plan_with_params = plan.with_param_values(param_values).unwrap(); + assert_snapshot!( + plan_with_params, + @r" + Limit: skip=10, fetch=200 + Projection: person.id + TableScan: person + " + ); +} + #[test] fn test_prepare_statement_to_plan_limit() { let sql = "PREPARE my_plan(BIGINT, BIGINT) AS