Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 13 additions & 14 deletions bigframes/bigquery/_operations/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@
from bigframes import series, session
from bigframes.bigquery._operations import utils as bq_utils
from bigframes.core import convert
from bigframes.core.compile.sqlglot import sql as sg_sql
from bigframes.core.logging import log_adapter
import bigframes.core.sql.literals
from bigframes.ml import base as ml_base
from bigframes.ml import core as ml_core
from bigframes.operations import ai_ops, output_schemas

Expand Down Expand Up @@ -392,7 +393,7 @@ def generate_double(

@log_adapter.method_logger(custom_base_name="bigquery_ai")
def generate_embedding(
model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series],
model: Union[ml_base.BaseEstimator, str, pd.Series],
data: Union[dataframe.DataFrame, series.Series, pd.DataFrame, pd.Series],
*,
output_dimensionality: Optional[int] = None,
Expand All @@ -416,7 +417,7 @@ def generate_embedding(
... ) # doctest: +SKIP

Args:
model (bigframes.ml.base.BaseEstimator or str):
model (ml_base.BaseEstimator or str):
The model to use for text embedding.
data (bigframes.pandas.DataFrame or bigframes.pandas.Series):
The data to generate embeddings for. If a Series is provided, it is
Expand Down Expand Up @@ -458,7 +459,7 @@ def generate_embedding(
model_name, session = bq_utils.get_model_name_and_session(model, data)
table_sql = bq_utils.to_sql(data)

struct_fields: Dict[str, bigframes.core.sql.literals.STRUCT_VALUES] = {}
struct_fields: Dict[str, Any] = {}
if output_dimensionality is not None:
struct_fields["OUTPUT_DIMENSIONALITY"] = output_dimensionality
if task_type is not None:
Expand All @@ -478,7 +479,7 @@ def generate_embedding(
FROM AI.GENERATE_EMBEDDING(
MODEL `{model_name}`,
({table_sql}),
{bigframes.core.sql.literals.struct_literal(struct_fields)}
{sg_sql.to_sql(sg_sql.literal(struct_fields))}
)
"""

Expand All @@ -490,7 +491,7 @@ def generate_embedding(

@log_adapter.method_logger(custom_base_name="bigquery_ai")
def generate_text(
model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series],
model: Union[ml_base.BaseEstimator, str, pd.Series],
data: Union[dataframe.DataFrame, series.Series, pd.DataFrame, pd.Series],
*,
temperature: Optional[float] = None,
Expand Down Expand Up @@ -519,7 +520,7 @@ def generate_text(
... ) # doctest: +SKIP

Args:
model (bigframes.ml.base.BaseEstimator or str):
model (ml_base.BaseEstimator or str):
The model to use for text generation.
data (bigframes.pandas.DataFrame or bigframes.pandas.Series):
The data to generate text for. If a Series is provided, it is
Expand Down Expand Up @@ -591,7 +592,7 @@ def generate_text(
FROM AI.GENERATE_TEXT(
MODEL `{model_name}`,
({table_sql}),
{bigframes.core.sql.literals.struct_literal(struct_fields)}
{sg_sql.to_sql(sg_sql.literal(struct_fields))}
)
"""

Expand All @@ -603,7 +604,7 @@ def generate_text(

@log_adapter.method_logger(custom_base_name="bigquery_ai")
def generate_table(
model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series],
model: Union[ml_base.BaseEstimator, str, pd.Series],
data: Union[dataframe.DataFrame, series.Series, pd.DataFrame, pd.Series],
*,
output_schema: Union[str, Mapping[str, str]],
Expand Down Expand Up @@ -635,7 +636,7 @@ def generate_table(
... ) # doctest: +SKIP

Args:
model (bigframes.ml.base.BaseEstimator or str):
model (ml_base.BaseEstimator or str):
The model to use for table generation.
data (bigframes.pandas.DataFrame or bigframes.pandas.Series):
The data to generate table for. If a Series is provided, it is
Expand Down Expand Up @@ -677,9 +678,7 @@ def generate_table(
else:
output_schema_str = output_schema

struct_fields_bq: Dict[str, bigframes.core.sql.literals.STRUCT_VALUES] = {
"output_schema": output_schema_str
}
struct_fields_bq: Dict[str, Any] = {"output_schema": output_schema_str}
if temperature is not None:
struct_fields_bq["temperature"] = temperature
if top_p is not None:
Expand All @@ -691,7 +690,7 @@ def generate_table(
if request_type is not None:
struct_fields_bq["request_type"] = request_type

struct_sql = bigframes.core.sql.literals.struct_literal(struct_fields_bq)
struct_sql = sg_sql.to_sql(sg_sql.literal(struct_fields_bq))
query = f"""
SELECT *
FROM AI.GENERATE_TABLE(
Expand Down
19 changes: 14 additions & 5 deletions bigframes/core/compile/sqlglot/sql/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,11 @@ def identifier(id: str) -> sge.Identifier:
return sge.to_identifier(id, quoted=QUOTED)


def literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
def literal(value: typing.Any, dtype: dtypes.Dtype | None = None) -> sge.Expression:
"""Return a string representing column reference in a SQL."""
if dtype is None:
dtype = dtypes.infer_literal_type(value)

sqlglot_type = sgt.from_bigframes_dtype(dtype) if dtype else None
if sqlglot_type is None:
if not pd.isna(value):
Expand All @@ -81,6 +84,14 @@ def literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
expressions=[literal(value=v, dtype=value_type) for v in value]
)
return values if len(value) > 0 else cast(values, sqlglot_type)
elif dtype == dtypes.FLOAT_DTYPE:
if pd.isna(value):
if isinstance(value, (float, np.floating)) and np.isnan(value):
return constants._NAN
return cast(sge.Null(), sqlglot_type)
if np.isinf(value):
return constants._INF if value > 0 else constants._NEG_INF
return sge.convert(value)
elif pd.isna(value) or (isinstance(value, pa.Scalar) and not value.is_valid):
return cast(sge.Null(), sqlglot_type)
elif dtype == dtypes.JSON_DTYPE:
Expand All @@ -100,13 +111,11 @@ def literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
return sge.func("ST_GEOGFROMTEXT", sge.convert(wkt))
elif dtype == dtypes.TIMEDELTA_DTYPE:
return sge.convert(utils.timedelta_to_micros(value))
elif dtype == dtypes.FLOAT_DTYPE:
if np.isinf(value):
return constants._INF if value > 0 else constants._NEG_INF
return sge.convert(value)
else:
if isinstance(value, np.generic):
value = value.item()
if isinstance(value, pa.Scalar):
value = value.as_py()
return sge.convert(value)


Expand Down
8 changes: 4 additions & 4 deletions bigframes/core/pyformat.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def _field_to_template_value(
dry_run: bool = False,
) -> str:
"""Convert value to something embeddable in a SQL string."""
import bigframes.core.sql # Avoid circular imports
import bigframes.core.compile.sqlglot.sql as sql # Avoid circular imports
import bigframes.dataframe # Avoid circular imports

_validate_type(name, value)
Expand All @@ -107,20 +107,20 @@ def _field_to_template_value(
if isinstance(value, str):
return value

return bigframes.core.sql.simple_literal(value)
return sql.to_sql(sql.literal(value))


def _validate_type(name: str, value: Any):
"""Raises TypeError if value is unsupported."""
import bigframes.core.sql # Avoid circular imports
import bigframes.dataframe # Avoid circular imports
import bigframes.dtypes # Avoid circular imports

if value is None:
return # None can't be used in isinstance, but is a valid literal.

supported_types = (
typing.get_args(_BQ_TABLE_TYPES)
+ typing.get_args(bigframes.core.sql.SIMPLE_LITERAL_TYPES)
+ bigframes.dtypes.SUPPORTED_LITERAL_TYPES
+ (bigframes.dataframe.DataFrame,)
+ (pandas.DataFrame,)
)
Expand Down
93 changes: 20 additions & 73 deletions bigframes/core/sql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,19 @@
Utility functions for SQL construction.
"""

import datetime
import decimal
import json
import math
from typing import cast, Collection, Iterable, Mapping, Optional, TYPE_CHECKING, Union
from typing import (
Any,
cast,
Collection,
Iterable,
Mapping,
Optional,
TYPE_CHECKING,
Union,
)

import bigframes_vendored.sqlglot.expressions as sge
import shapely.geometry.base # type: ignore

from bigframes.core.compile.sqlglot import sql

Expand All @@ -43,68 +48,8 @@
to_wkt = dumps


# Union of the Python value types that ``simple_literal`` is annotated to
# accept. ``simple_literal`` also handles ``None`` (added separately in its own
# annotation) and shapely geometries, neither of which is listed here.
SIMPLE_LITERAL_TYPES = Union[
bytes,
str,
int,
bool,
float,
datetime.datetime,
datetime.date,
datetime.time,
decimal.Decimal,
list,
]


### Writing SQL Values (literals, column references, table references, etc.)
def simple_literal(value: Union[SIMPLE_LITERAL_TYPES, None]) -> str:
    """Render a Python value as the text of a SQL literal.

    Args:
        value: The value to render. ``None`` renders as ``NULL`` and a list
            renders as a bracketed, comma-separated sequence of the literals
            of its elements.

    Returns:
        The SQL text for the literal.

    Raises:
        ValueError: If ``value`` has a type that is not handled here.
    """
    # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#literals
    if value is None:
        return "NULL"

    if isinstance(value, str):
        # Single quoting seems to work nicer with ibis than double quoting
        escaped = sql.escape_chars(value)
        return f"'{escaped}'"

    if isinstance(value, bytes):
        return repr(value)

    if isinstance(value, (bool, int)):
        return str(value)

    if isinstance(value, float):
        # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#floating_point_literals
        if math.isnan(value):
            return 'CAST("nan" as FLOAT)'
        if math.isinf(value):
            return 'CAST("+inf" as FLOAT)' if value > 0 else 'CAST("-inf" as FLOAT)'
        return str(value)

    # datetime must be tested before date because it is a subclass of date.
    if isinstance(value, datetime.datetime):
        iso_text = value.isoformat()
        if value.tzinfo is None:
            return f"DATETIME('{iso_text}')"
        return f"TIMESTAMP('{iso_text}')"

    if isinstance(value, datetime.date):
        return f"DATE('{value.isoformat()}')"

    if isinstance(value, datetime.time):
        return f"TIME(DATETIME('1970-01-01 {value.isoformat()}'))"

    if isinstance(value, shapely.geometry.base.BaseGeometry):
        # The geometry is passed as a quoted WKT string literal.
        wkt_literal = simple_literal(to_wkt(value))
        return f"ST_GEOGFROMTEXT({wkt_literal})"

    if isinstance(value, decimal.Decimal):
        # TODO: disambiguate BIGNUMERIC based on scale and/or precision
        decimal_text = str(value)
        return f"CAST('{decimal_text}' AS NUMERIC)"

    if isinstance(value, list):
        rendered_items = ", ".join([simple_literal(item) for item in value])
        return f"[{rendered_items}]"

    raise ValueError(f"Cannot produce literal for {value}")


def multi_literal(*values: str):
literal_strings = [simple_literal(i) for i in values]
def multi_literal(*values: Any):
literal_strings = [sql.to_sql(sql.literal(i)) for i in values]
return "(" + ", ".join(literal_strings) + ")"


Expand Down Expand Up @@ -210,7 +155,7 @@ def create_vector_index_ddl(

rendered_options = ", ".join(
[
f"{option_name} = {simple_literal(option_value)}"
f"{option_name} = {sql.to_sql(sql.literal(option_value))}"
for option_name, option_value in options.items()
]
)
Expand All @@ -237,24 +182,26 @@ def create_vector_search_sql(

vector_search_args = [
f"TABLE {sql.to_sql(sql.identifier(cast(str, base_table)))}",
f"{simple_literal(column_to_search)}",
f"{sql.to_sql(sql.literal(column_to_search))}",
f"({sql_string})",
]

if query_column_to_search is not None:
vector_search_args.append(
f"query_column_to_search => {simple_literal(query_column_to_search)}"
f"query_column_to_search => {sql.to_sql(sql.literal(query_column_to_search))}"
)

if top_k is not None:
vector_search_args.append(f"top_k=> {simple_literal(top_k)}")
vector_search_args.append(f"top_k=> {sql.to_sql(sql.literal(top_k))}")

if distance_type is not None:
vector_search_args.append(f"distance_type => {simple_literal(distance_type)}")
vector_search_args.append(
f"distance_type => {sql.to_sql(sql.literal(distance_type))}"
)

if options is not None:
vector_search_args.append(
f"options => {simple_literal(json.dumps(options, indent=None))}"
f"options => {sql.to_sql(sql.literal(json.dumps(options, indent=None)))}"
)

args_str = ",\n".join(vector_search_args)
Expand Down
58 changes: 0 additions & 58 deletions bigframes/core/sql/literals.py

This file was deleted.

Loading
Loading