Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions datafusion/optimizer/tests/optimizer_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,20 +74,20 @@ fn recursive_cte_with_nested_subquery() -> Result<()> {

assert_snapshot!(
format!("{plan}"),
@r"
@"
SubqueryAlias: numbers
Projection: sub.id AS id, sub.level AS level
RecursiveQuery: is_distinct=false
RecursiveQuery: is_distinct=false
Projection: sub.id AS id, sub.level AS level
SubqueryAlias: sub
Projection: test.col_int32 AS id, Int64(1) AS level
TableScan: test projection=[col_int32]
Projection: t.col_int32, numbers.level + Int64(1)
Inner Join: CAST(t.col_int32 AS Int64) = CAST(numbers.id AS Int64) + Int64(1)
SubqueryAlias: t
Filter: CAST(test.col_int32 AS Int64) IS NOT NULL
TableScan: test projection=[col_int32]
Filter: CAST(numbers.id AS Int64) + Int64(1) IS NOT NULL
TableScan: numbers projection=[id, level]
Projection: t.col_int32, numbers.level + Int64(1)
Inner Join: CAST(t.col_int32 AS Int64) = CAST(numbers.id AS Int64) + Int64(1)
SubqueryAlias: t
Filter: CAST(test.col_int32 AS Int64) IS NOT NULL
TableScan: test projection=[col_int32]
Filter: CAST(numbers.id AS Int64) + Int64(1) IS NOT NULL
TableScan: numbers projection=[id, level]
"
);

Expand Down
33 changes: 25 additions & 8 deletions datafusion/sql/src/cte.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel};

use arrow::datatypes::{Schema, SchemaRef};
use datafusion_common::{
Result, not_impl_err, plan_err,
Result, TableReference, not_impl_err, plan_err,
tree_node::{TreeNode, TreeNodeRecursion},
};
use datafusion_expr::{LogicalPlan, LogicalPlanBuilder, TableSource};
use sqlparser::ast::{Query, SetExpr, SetOperator, With};
use sqlparser::ast::{Ident, Query, SetExpr, SetOperator, With};

impl<S: ContextProvider> SqlToRel<'_, S> {
pub(super) fn plan_with_clause(
Expand All @@ -46,14 +46,24 @@ impl<S: ContextProvider> SqlToRel<'_, S> {

// Create a logical plan for the CTE
let cte_plan = if is_recursive {
self.recursive_cte(&cte_name, *cte.query, planner_context)?
let columns = cte.alias.columns.iter().map(|c| c.name.clone()).collect();
self.recursive_cte(&cte_name, columns, *cte.query, planner_context)?
} else {
self.non_recursive_cte(*cte.query, planner_context)?
};

// Each `WITH` block can change the column names in the last
// projection (e.g. "WITH table(t1, t2) AS SELECT 1, 2").
let final_plan = self.apply_table_alias(cte_plan, cte.alias)?;
// Each `WITH` block can change the column names in the last projection
// (e.g. "WITH table(t1, t2) AS SELECT 1, 2"). Recursive CTEs apply those
// to the static term in recursive_cte(), so only the relation name here.
let final_plan = if is_recursive {
LogicalPlanBuilder::from(cte_plan)
.alias(TableReference::bare(
self.ident_normalizer.normalize(cte.alias.name),
))?
.build()?
} else {
self.apply_table_alias(cte_plan, cte.alias)?
};
// Export the CTE to the outer query
planner_context.insert_cte(cte_name, final_plan);
}
Expand All @@ -71,6 +81,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
fn recursive_cte(
&self,
cte_name: &str,
columns: Vec<Ident>,
mut cte_query: Query,
planner_context: &mut PlannerContext,
) -> Result<LogicalPlan> {
Expand All @@ -91,9 +102,11 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
set_quantifier,
} => (left, right, set_quantifier),
other => {
// If the query is not a UNION, then it is not a recursive CTE
// Not a UNION, so not actually a recursive CTE. The caller adds only
// the relation name for recursive CTEs, so apply the column aliases here.
*cte_query.body = other;
return self.non_recursive_cte(cte_query, planner_context);
let plan = self.non_recursive_cte(cte_query, planner_context)?;
return self.apply_expr_alias(plan, columns);
}
};

Expand All @@ -111,6 +124,10 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
// ---------- Step 1: Compile the static term ------------------
let static_plan = self.set_expr_to_plan(*left_expr, planner_context)?;

// Apply the declared column-list aliases (e.g. `t(n)`) to the static term, so
// the work table built from its schema below exposes the declared names.
let static_plan = self.apply_expr_alias(static_plan, columns)?;

// Since the recursive CTEs include a component that references a
// table with its name, like the example below:
//
Expand Down
105 changes: 105 additions & 0 deletions datafusion/sqllogictest/test_files/cte.slt
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,111 @@ physical_plan
07)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
08)----------WorkTableExec: name=nodes

# recursive CTE with a column-list alias (e.g. `t(n)`): the declared names must be
# applied to the static term so the recursive self-reference can resolve them
query I rowsort
WITH RECURSIVE t(n) AS (

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice improvement! One small suggestion: could we add a regression test with a quoted recursive CTE column-list alias, for example WITH RECURSIVE t("N") AS (...) SELECT "N" FROM t? I think it would be helpful to document that quoted and case-sensitive aliases are preserved in the recursive work table as well. This is not blocking since the implementation already goes through the existing alias normalization path.

SELECT 1
UNION ALL
SELECT n + 1 FROM t WHERE n < 10
)
SELECT n FROM t
----
1
10
2
3
4
5
6
7
8
9

# recursive CTE with a multi-column column-list alias
query II rowsort
WITH RECURSIVE t(a, b) AS (
SELECT 1, 2
UNION ALL
SELECT a + 1, b * 2 FROM t WHERE a < 5
)
SELECT a, b FROM t
----
1 2
2 4
3 8
4 16
5 32

# recursive CTE with a column-list alias and UNION (DISTINCT)
query I rowsort
WITH RECURSIVE t(n) AS (
SELECT 1
UNION
SELECT n + 1 FROM t WHERE n < 5
)
SELECT n FROM t
----
1
2
3
4
5

# recursive CTE column-list alias arity mismatch is rejected cleanly (raised at
# the static term, rather than the old confusing "No field named ...")
query error DataFusion error: Error during planning: Source table contains 1 columns but only 2 names given as column alias
WITH RECURSIVE t(a, b) AS (
SELECT 1
UNION ALL
SELECT a + 1 FROM t WHERE a < 3
)
SELECT * FROM t

# explain a column-list-aliased recursive CTE: the declared name is applied to
# the static term, so there is no extra projection on top of RecursiveQuery
query TT
EXPLAIN WITH RECURSIVE t(n) AS (
SELECT 1
UNION ALL
SELECT n + 1 FROM t WHERE n < 10
)
SELECT * FROM t
----
logical_plan
01)SubqueryAlias: t
02)--RecursiveQuery: is_distinct=false
03)----Projection: Int64(1) AS n
04)------EmptyRelation: rows=1
05)----Projection: t.n + Int64(1)
06)------Filter: t.n < Int64(10)
07)--------TableScan: t projection=[n]
physical_plan
01)RecursiveQueryExec: name=t, is_distinct=false
02)--ProjectionExec: expr=[CAST(1 AS Int64) as n]
03)----PlaceholderRowExec
04)--CoalescePartitionsExec
05)----ProjectionExec: expr=[n@0 + 1 as n]
06)------FilterExec: n@0 < 10
07)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
08)----------WorkTableExec: name=t

# recursive CTE with a quoted, case-sensitive column-list alias: `"N"` must be
# preserved (not lowercased) so the recursive self-reference resolves it
query I rowsort
WITH RECURSIVE t("N") AS (
SELECT 1
UNION ALL
SELECT "N" + 1 FROM t WHERE "N" < 5
)
SELECT "N" FROM t
----
1
2
3
4
5

# simple deduplicating recursive CTE works
query I
WITH RECURSIVE nodes AS (
Expand Down
Loading