Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 123 additions & 23 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1672,34 +1672,68 @@ impl LogicalPlan {
let mut param_types: HashMap<String, Option<FieldRef>> = HashMap::new();

self.apply_with_subqueries(|plan| {
let row_count_field = Arc::new(Field::new("", DataType::Int64, true));
plan.apply_expressions(|expr| {
expr.apply(|expr| {
if let Expr::Placeholder(Placeholder { id, field }) = expr {
let prev = param_types.get(id);
match (prev, field) {
(Some(Some(prev)), Some(field)) => {
check_metadata_with_storage_equal(
(field.data_type(), Some(field.metadata())),
(prev.data_type(), Some(prev.metadata())),
"parameter",
&format!(": Conflicting types for id {id}"),
)?;
}
(_, Some(field)) => {
param_types.insert(id.clone(), Some(Arc::clone(field)));
}
_ => {
param_types.insert(id.clone(), None);
}
}
}
Ok(TreeNodeRecursion::Continue)
})
if matches!(plan, LogicalPlan::Limit(_)) {
let expr = Self::infer_limit_row_count_parameter_fields(
expr,
&row_count_field,
)?;
Self::collect_parameter_fields(&expr, &mut param_types)
} else {
Self::collect_parameter_fields(expr, &mut param_types)
}
})
})
.map(|_| param_types)
}

/// Prepares a `Limit` row-count expression (`skip` or `fetch`) for parameter
/// collection.
///
/// A clone of `expr` is first passed through `infer_placeholder_types` against
/// an empty schema. If the result is itself a placeholder that still carries no
/// field, it is rebuilt with `row_count_field` attached; any other expression
/// is returned exactly as inference produced it.
///
/// # Errors
///
/// Propagates any error returned by `infer_placeholder_types`.
fn infer_limit_row_count_parameter_fields(
    expr: &Expr,
    row_count_field: &FieldRef,
) -> Result<Expr> {
    let (inferred, _) = expr
        .clone()
        .infer_placeholder_types(&DFSchema::empty())?;

    Ok(match inferred {
        // Only a top-level, still-untyped placeholder receives the row-count field.
        Expr::Placeholder(Placeholder { id, field: None }) => Expr::Placeholder(
            Placeholder::new_with_field(id, Some(Arc::clone(row_count_field))),
        ),
        other => other,
    })
}

/// Walks `expr` and records the field of every placeholder in `param_types`,
/// keyed by placeholder id.
///
/// For each placeholder encountered:
/// * if the id already maps to a known field and this occurrence also carries
///   a field, the two are compared with `check_metadata_with_storage_equal`
///   and the map is left untouched;
/// * otherwise the id is (re)inserted with this occurrence's field, or `None`
///   when the occurrence has no field.
///
/// # Errors
///
/// Returns the error from `check_metadata_with_storage_equal` (its message is
/// built with the suffix `": Conflicting types for id {id}"`) when that
/// comparison fails.
fn collect_parameter_fields(
    expr: &Expr,
    param_types: &mut HashMap<String, Option<FieldRef>>,
) -> Result<TreeNodeRecursion> {
    expr.apply(|node| {
        let Expr::Placeholder(Placeholder { id, field }) = node else {
            return Ok(TreeNodeRecursion::Continue);
        };

        let recorded = match (param_types.get(id), field.as_ref()) {
            (Some(Some(known)), Some(seen)) => {
                check_metadata_with_storage_equal(
                    (seen.data_type(), Some(seen.metadata())),
                    (known.data_type(), Some(known.metadata())),
                    "parameter",
                    &format!(": Conflicting types for id {id}"),
                )?;
                // The existing entry stays as it is; nothing to insert.
                return Ok(TreeNodeRecursion::Continue);
            }
            (_, seen) => seen.map(Arc::clone),
        };

        param_types.insert(id.clone(), recorded);
        Ok(TreeNodeRecursion::Continue)
    })
}

// ------------
// Various implementations for printing out LogicalPlans
// ------------
Expand Down Expand Up @@ -4887,7 +4921,7 @@ mod tests {
use crate::select_expr::SelectExpr;
use crate::test::function_stub::{count, count_udaf};
use crate::{
GroupingSet, binary_expr, col, exists, in_subquery, lit, placeholder,
GroupingSet, binary_expr, cast, col, exists, in_subquery, lit, placeholder,
scalar_subquery,
};
use datafusion_common::metadata::ScalarAndMetadata;
Expand Down Expand Up @@ -6204,6 +6238,72 @@ mod tests {
assert_eq!(parameter_type, None);
}

/// Bare placeholders used as `skip` and `fetch` of a `Limit` must be reported
/// as `Int64` by `get_parameter_types` and as an unnamed, nullable `Int64`
/// field by `get_parameter_fields`.
#[test]
fn test_limit_parameter_fields_infer_row_counts() {
    let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]);
    let scan = table_scan(TableReference::none(), &schema, None)
        .unwrap()
        .build()
        .unwrap();
    let plan = LogicalPlan::Limit(Limit {
        skip: Some(Box::new(placeholder("$1"))),
        fetch: Some(Box::new(placeholder("$2"))),
        input: Arc::new(scan),
    });

    let param_types = plan.get_parameter_types().unwrap();
    assert_eq!(param_types.len(), 2);
    for id in ["$1", "$2"] {
        assert_eq!(param_types.get(id), Some(&Some(DataType::Int64)));
    }

    let param_fields = plan.get_parameter_fields().unwrap();
    assert_eq!(param_fields.len(), 2);
    let expected_field = Arc::new(Field::new("", DataType::Int64, true));
    for id in ["$1", "$2"] {
        let actual_field = param_fields.get(id).and_then(|field| field.as_ref());
        assert_eq!(actual_field, Some(&expected_field));
    }
}

/// A placeholder nested inside a `fetch` expression (`$1 * 2i64`) must be
/// reported as `Int64` by `get_parameter_types`.
#[test]
fn test_limit_nested_parameter_fields_infer_from_expression() {
    let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]);
    let scan = table_scan(TableReference::none(), &schema, None)
        .unwrap()
        .build()
        .unwrap();
    let fetch = placeholder("$1") * lit(2i64);
    let plan = LogicalPlan::Limit(Limit {
        skip: None,
        fetch: Some(Box::new(fetch)),
        input: Arc::new(scan),
    });

    let param_types = plan.get_parameter_types().unwrap();
    assert_eq!(param_types.len(), 1);
    assert_eq!(param_types.get("$1"), Some(&Some(DataType::Int64)));
}

/// A placeholder wrapped in an explicit cast under `Limit::fetch` must be
/// reported with no type (`None`) by `get_parameter_types`.
#[test]
fn test_limit_cast_parameter_field_stays_local_to_cast() {
    let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]);
    let scan = table_scan(TableReference::none(), &schema, None)
        .unwrap()
        .build()
        .unwrap();
    let fetch = cast(placeholder("$1"), DataType::Int32);
    let plan = LogicalPlan::Limit(Limit {
        skip: None,
        fetch: Some(Box::new(fetch)),
        input: Arc::new(scan),
    });

    let param_types = plan.get_parameter_types().unwrap();
    assert_eq!(param_types.len(), 1);
    assert_eq!(param_types.get("$1"), Some(&None));
}

#[test]
fn test_join_with_new_exprs() -> Result<()> {
fn create_test_join(
Expand Down
64 changes: 64 additions & 0 deletions datafusion/sql/tests/cases/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -987,6 +987,70 @@ fn test_prepare_statement_to_plan_having() {
);
}

// Checks parameter handling for `OFFSET $1 LIMIT $2`: both placeholders are
// expected to be typed `Int64`, and the snapshot of `test.run()` records an
// initial plan with `skip=$1, fetch=$2` and a final plan with `skip=10,
// fetch=200`.
#[test]
fn test_infer_types_from_limit() {
let test = ParameterTest {
sql: "SELECT id FROM person OFFSET $1 LIMIT $2",
// Expected type per placeholder id.
expected_types: vec![
("$1", Some(DataType::Int64)),
("$2", Some(DataType::Int64)),
],
// Values for `$1` and `$2`; they appear as `skip=10, fetch=200` in the final plan.
param_values: vec![ScalarValue::Int64(Some(10)), ScalarValue::Int64(Some(200))],
};

// NOTE(review): `ParameterTest::run` is defined outside this view; presumably it
// validates `expected_types` and renders both plans. Confirm against its definition.
assert_snapshot!(
test.run(),
@r"
** Initial Plan:
Limit: skip=$1, fetch=$2
Projection: person.id
TableScan: person
** Final Plan:
Limit: skip=10, fetch=200
Projection: person.id
TableScan: person
"
);
}

/// Reusing `$1` both in a comparison against `first_name` and as `OFFSET`
/// must make `get_parameter_types` fail with a "Conflicting types" error.
#[test]
fn test_infer_types_from_limit_conflicting_parameter() {
    let sql = "SELECT id FROM person WHERE first_name = $1 OFFSET $1";
    let message = logical_plan(sql)
        .unwrap()
        .get_parameter_types()
        .unwrap_err()
        .to_string();

    assert_contains!(message, "Conflicting types for id $1");
}

// Checks a `PREPARE` statement that declares no parameter types and uses
// `OFFSET $1 LIMIT $2`: the snapshots expect the prepared plan to list
// `[Int64, Int64]`, and binding 10 and 200 to yield `skip=10, fetch=200`.
#[test]
fn test_prepare_statement_infer_types_from_limit() {
let sql = "PREPARE my_plan AS
SELECT id FROM person \
OFFSET $1 LIMIT $2";
// `dt` is snapshotted below as the parameter data types of the prepared statement.
let (plan, dt) = generate_prepare_stmt_and_data_types(sql);
assert_snapshot!(
plan,
@r#"
Prepare: "my_plan" [Int64, Int64]
Limit: skip=$1, fetch=$2
Projection: person.id
TableScan: person
"#
);
assert_snapshot!(dt, @"Int64, Int64");

// Bind concrete values and verify the placeholders are replaced in the plan.
let param_values = vec![ScalarValue::Int64(Some(10)), ScalarValue::Int64(Some(200))];
let plan_with_params = plan.with_param_values(param_values).unwrap();
assert_snapshot!(
plan_with_params,
@r"
Limit: skip=10, fetch=200
Projection: person.id
TableScan: person
"
);
}

#[test]
fn test_prepare_statement_to_plan_limit() {
let sql = "PREPARE my_plan(BIGINT, BIGINT) AS
Expand Down