Skip to content

Commit 7428a39

Browse files
fix(optimizer)!: query schema directly when type annotation fails for processing UNNEST source
1 parent f7458a4 commit 7428a39

File tree

2 files changed

+114
-3
lines changed

2 files changed

+114
-3
lines changed

sqlglot/optimizer/resolver.py

Lines changed: 72 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,25 @@ def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequenc
144144
# in bigquery, unnest structs are automatically scoped as tables, so you can
145145
# directly select a struct field in a query.
146146
# this handles the case where the unnest is statically defined.
147-
if self.dialect.UNNEST_COLUMN_ONLY:
148-
if source.expression.is_type(exp.DataType.Type.STRUCT):
149-
for k in source.expression.type.expressions: # type: ignore
147+
if self.dialect.UNNEST_COLUMN_ONLY and isinstance(source.expression, exp.Unnest):
148+
unnest = source.expression
149+
150+
# if type is not annotated yet, try to get it from the schema
151+
if not unnest.type or unnest.type.is_type(exp.DataType.Type.UNKNOWN):
152+
unnest_expr = seq_get(unnest.expressions, 0)
153+
if isinstance(unnest_expr, exp.Column) and self.scope.parent:
154+
col_type = self._get_unnest_column_type(unnest_expr)
155+
# extract element type if it's an ARRAY
156+
if col_type and col_type.is_type(exp.DataType.Type.ARRAY):
157+
element_types = col_type.expressions
158+
if element_types:
159+
unnest.type = element_types[0].copy()
160+
else:
161+
if col_type:
162+
unnest.type = col_type.copy()
163+
# check if the result type is a STRUCT - extract struct field names
164+
if unnest.is_type(exp.DataType.Type.STRUCT):
165+
for k in unnest.type.expressions: # type: ignore
150166
columns.append(k.name)
151167
elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
152168
columns = self.get_source_columns_from_set_op(source.expression)
@@ -299,3 +315,56 @@ def _get_unambiguous_columns(
299315
unambiguous_columns[column] = table
300316

301317
return unambiguous_columns
318+
319+
def _get_unnest_column_type(self, column: exp.Column) -> t.Optional[exp.DataType]:
320+
"""
321+
Get the type of a column being unnested, tracing through CTEs/subqueries to find the base table.
322+
323+
Args:
324+
column: The column expression being unnested.
325+
326+
Returns:
327+
The DataType of the column, or None if not found.
328+
"""
329+
scope = self.scope.parent
330+
331+
# if column is qualified, use that table, otherwise disambiguate using the resolver
332+
if column.table:
333+
table_name = column.table
334+
else:
335+
# use the parent scope's resolver to disambiguate the column
336+
parent_resolver = Resolver(scope, self.schema, self._infer_schema)
337+
table_identifier = parent_resolver.get_table(column)
338+
if not table_identifier:
339+
return None
340+
table_name = table_identifier.name
341+
342+
source = scope.sources.get(table_name)
343+
return self._get_column_type_from_scope(source, column) if source else None
344+
345+
def _get_column_type_from_scope(
346+
self, source: t.Union[Scope, exp.Table], column: exp.Column
347+
) -> t.Optional[exp.DataType]:
348+
"""
349+
Get a column's type by tracing through scopes/tables to find the base table.
350+
351+
Args:
352+
source: The source to search - can be a Scope (to iterate its sources) or a Table.
353+
column: The column to find the type for.
354+
355+
Returns:
356+
The DataType of the column, or None if not found.
357+
"""
358+
if isinstance(source, exp.Table):
359+
# base table - get the column type from schema
360+
col_type: t.Optional[exp.DataType] = self.schema.get_column_type(source, column)
361+
if col_type and not col_type.is_type(exp.DataType.Type.UNKNOWN):
362+
return col_type
363+
elif isinstance(source, Scope):
364+
# iterate over all sources in the scope
365+
for source_name, nested_source in source.sources.items():
366+
col_type = self._get_column_type_from_scope(nested_source, column)
367+
if col_type and not col_type.is_type(exp.DataType.Type.UNKNOWN):
368+
return col_type
369+
370+
return None

tests/test_optimizer.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,48 @@ def test_qualify_columns(self, logger):
516516
"SELECT a.b_id AS b_id FROM a AS a JOIN b AS b ON a.b_id = b.b_id JOIN c AS c ON b.b_id = c.b_id JOIN d AS d ON b.d_id = d.d_id",
517517
)
518518

519+
self.assertEqual(
520+
optimizer.qualify.qualify(
521+
parse_one(
522+
"""
523+
SELECT
524+
(SELECT SUM(c.amount)
525+
FROM UNNEST(credits) AS c
526+
WHERE type != 'promotion') as total
527+
FROM billing
528+
""",
529+
read="bigquery",
530+
),
531+
schema={"billing": {"credits": "ARRAY<STRUCT<amount FLOAT64, type STRING>>"}},
532+
dialect="bigquery",
533+
).sql(dialect="bigquery"),
534+
"SELECT (SELECT SUM(`c`.`amount`) AS `_col_0` FROM UNNEST(`billing`.`credits`) AS `c` WHERE `type` <> 'promotion') AS `total` FROM `billing` AS `billing`",
535+
)
536+
537+
self.assertEqual(
538+
optimizer.qualify.qualify(
539+
parse_one(
540+
"""
541+
WITH cte AS (SELECT * FROM base_table)
542+
SELECT
543+
(SELECT SUM(item.price)
544+
FROM UNNEST(items) AS item
545+
WHERE category = 'electronics') as electronics_total
546+
FROM cte
547+
""",
548+
read="bigquery",
549+
),
550+
schema={
551+
"base_table": {
552+
"id": "INT64",
553+
"items": "ARRAY<STRUCT<price FLOAT64, category STRING>>",
554+
}
555+
},
556+
dialect="bigquery",
557+
).sql(dialect="bigquery"),
558+
"WITH `cte` AS (SELECT `base_table`.`id` AS `id`, `base_table`.`items` AS `items` FROM `base_table` AS `base_table`) SELECT (SELECT SUM(`item`.`price`) AS `_col_0` FROM UNNEST(`cte`.`items`) AS `item` WHERE `category` = 'electronics') AS `electronics_total` FROM `cte` AS `cte`",
559+
)
560+
519561
self.check_file(
520562
"qualify_columns",
521563
qualify_columns,

0 commit comments

Comments
 (0)