Skip to content

Commit 3b0a522

Browse files
fix(optimizer)!: query schema directly when type annotation fails for processing UNNEST source
1 parent f7458a4 commit 3b0a522

File tree

2 files changed

+124
-2
lines changed

2 files changed

+124
-2
lines changed

sqlglot/optimizer/resolver.py

Lines changed: 79 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,22 @@ def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequenc
144144
# in bigquery, unnest structs are automatically scoped as tables, so you can
145145
# directly select a struct field in a query.
146146
# this handles the case where the unnest is statically defined.
147-
if self.dialect.UNNEST_COLUMN_ONLY:
148-
if source.expression.is_type(exp.DataType.Type.STRUCT):
147+
if self.dialect.UNNEST_COLUMN_ONLY and isinstance(source.expression, exp.Unnest):
148+
unnest_type = source.expression.type
149+
150+
# if type is not annotated yet, try to get it from the schema
151+
if not unnest_type or unnest_type.is_type(exp.DataType.Type.UNKNOWN):
152+
unnest_expr = seq_get(source.expression.expressions, 0)
153+
if isinstance(unnest_expr, exp.Column) and self.scope.parent:
154+
unnest_type = self._get_unnest_column_type(unnest_expr)
155+
156+
# check if unnesting an ARRAY of STRUCTs - extract struct field names
157+
if unnest_type and unnest_type.is_type(exp.DataType.Type.ARRAY):
158+
element_types = unnest_type.expressions
159+
if element_types and element_types[0].is_type(exp.DataType.Type.STRUCT):
160+
for field in element_types[0].expressions: # type: ignore
161+
columns.append(field.name)
162+
elif source.expression.is_type(exp.DataType.Type.STRUCT):
149163
for k in source.expression.type.expressions: # type: ignore
150164
columns.append(k.name)
151165
elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
@@ -299,3 +313,66 @@ def _get_unambiguous_columns(
299313
unambiguous_columns[column] = table
300314

301315
return unambiguous_columns
316+
317+
def _get_unnest_column_type(self, column: exp.Column) -> t.Optional[exp.DataType]:
318+
"""
319+
Get the type of a column being unnested, tracing through CTEs/subqueries to find the base table.
320+
321+
Args:
322+
column: The column expression being unnested.
323+
324+
Returns:
325+
The DataType of the column, or None if not found.
326+
"""
327+
# start from parent scope and trace through sources to find the actual table
328+
scope = self.scope.parent
329+
if not scope:
330+
return None
331+
332+
# try each source in the parent scope to find which one contains this column
333+
for source_name in scope.sources:
334+
source = scope.sources[source_name]
335+
col_type: t.Optional[exp.DataType]
336+
337+
if isinstance(source, exp.Table):
338+
# found a base table - get the column type from schema
339+
col_type = self.schema.get_column_type(source, column)
340+
if col_type and not col_type.is_type(exp.DataType.Type.UNKNOWN):
341+
return col_type
342+
elif isinstance(source, Scope):
343+
# CTE or subquery - recursively check its sources
344+
col_type = self._get_column_type_from_scope(source, column.name)
345+
if col_type:
346+
return col_type
347+
348+
return None
349+
350+
def _get_column_type_from_scope(self, scope: Scope, col_name: str) -> t.Optional[exp.DataType]:
351+
"""
352+
Recursively find a column's type by tracing through nested scopes to the base table.
353+
354+
Args:
355+
scope: The scope to search.
356+
col_name: The column name to find.
357+
358+
Returns:
359+
The DataType of the column, or None if not found.
360+
"""
361+
for source_name in scope.sources:
362+
source = scope.sources[source_name]
363+
col_type: t.Optional[exp.DataType]
364+
365+
if isinstance(source, exp.Table):
366+
# found a base table - try to get the column type
367+
col_type = self.schema.get_column_type(
368+
source, exp.Column(this=exp.to_identifier(col_name))
369+
)
370+
if col_type and not col_type.is_type(exp.DataType.Type.UNKNOWN):
371+
return col_type
372+
elif isinstance(source, Scope):
373+
# nested scope - recurse
374+
col_type = self._get_column_type_from_scope(source, col_name)
375+
if col_type:
376+
return col_type
377+
378+
return None

tests/test_optimizer.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,9 @@ def setUp(self):
143143
"t_bool": {
144144
"a": "BOOLEAN",
145145
},
146+
"table1": {
147+
"credits": "ARRAY<STRUCT<name STRING, amount FLOAT64, type STRING>>",
148+
},
146149
}
147150

148151
def check_file(
@@ -516,6 +519,48 @@ def test_qualify_columns(self, logger):
516519
"SELECT a.b_id AS b_id FROM a AS a JOIN b AS b ON a.b_id = b.b_id JOIN c AS c ON b.b_id = c.b_id JOIN d AS d ON b.d_id = d.d_id",
517520
)
518521

522+
self.assertEqual(
523+
optimizer.qualify.qualify(
524+
parse_one(
525+
"""
526+
SELECT
527+
(SELECT SUM(c.amount)
528+
FROM UNNEST(credits) AS c
529+
WHERE type != 'promotion') as total
530+
FROM billing
531+
""",
532+
read="bigquery",
533+
),
534+
schema={"billing": {"credits": "ARRAY<STRUCT<amount FLOAT64, type STRING>>"}},
535+
dialect="bigquery",
536+
).sql(dialect="bigquery"),
537+
"SELECT (SELECT SUM(`c`.`amount`) AS `_col_0` FROM UNNEST(`billing`.`credits`) AS `c` WHERE `type` <> 'promotion') AS `total` FROM `billing` AS `billing`",
538+
)
539+
540+
self.assertEqual(
541+
optimizer.qualify.qualify(
542+
parse_one(
543+
"""
544+
WITH cte AS (SELECT * FROM base_table)
545+
SELECT
546+
(SELECT SUM(item.price)
547+
FROM UNNEST(items) AS item
548+
WHERE category = 'electronics') as electronics_total
549+
FROM cte
550+
""",
551+
read="bigquery",
552+
),
553+
schema={
554+
"base_table": {
555+
"id": "INT64",
556+
"items": "ARRAY<STRUCT<price FLOAT64, category STRING>>",
557+
}
558+
},
559+
dialect="bigquery",
560+
).sql(dialect="bigquery"),
561+
"WITH `cte` AS (SELECT `base_table`.`id` AS `id`, `base_table`.`items` AS `items` FROM `base_table` AS `base_table`) SELECT (SELECT SUM(`item`.`price`) AS `_col_0` FROM UNNEST(`cte`.`items`) AS `item` WHERE `category` = 'electronics') AS `electronics_total` FROM `cte` AS `cte`",
562+
)
563+
519564
self.check_file(
520565
"qualify_columns",
521566
qualify_columns,

0 commit comments

Comments
 (0)