Skip to content
7 changes: 5 additions & 2 deletions packages/bigframes/bigframes/bigquery/_operations/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import bigframes.dataframe as dataframe
import bigframes.ml.base
import bigframes.session
import bigframes.core.col as col
from bigframes.bigquery._operations import utils


Expand All @@ -50,7 +51,9 @@ def create_model(
input_schema: Optional[Mapping[str, str]] = None,
output_schema: Optional[Mapping[str, str]] = None,
connection_name: Optional[str] = None,
options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None,
options: Optional[
Mapping[str, Union[str, int, float, bool, list, "col.Expression"]]
] = None,
training_data: Optional[Union[pd.DataFrame, dataframe.DataFrame, str]] = None,
custom_holiday: Optional[Union[pd.DataFrame, dataframe.DataFrame, str]] = None,
session: Optional[bigframes.session.Session] = None,
Expand Down Expand Up @@ -78,7 +81,7 @@ def create_model(
The OUTPUT clause, which specifies the schema of the output data.
connection_name (str, optional):
The connection to use for the model.
options (Mapping[str, Union[str, int, float, bool, list]], optional):
options (Mapping[str, Union[str, int, float, bool, list, bigframes.core.col.Expression]], optional):
The OPTIONS clause, which specifies the model options.
training_data (Union[bigframes.pandas.DataFrame, str], optional):
The query or DataFrame to use for training the model.
Expand Down
11 changes: 9 additions & 2 deletions packages/bigframes/bigframes/core/sql/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

from typing import Any, Dict, List, Mapping, Optional, Union

import bigframes.core.col as col
from bigframes.core.compile.sqlglot import sql as sg_sql
from bigframes.core.compile.sqlglot.expression_compiler import expression_compiler


def create_model_ddl(
Expand All @@ -28,7 +30,9 @@ def create_model_ddl(
input_schema: Optional[Mapping[str, str]] = None,
output_schema: Optional[Mapping[str, str]] = None,
connection_name: Optional[str] = None,
options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None,
options: Optional[
Mapping[str, Union[str, int, float, bool, list, "col.Expression"]]
] = None,
training_data: Optional[str] = None,
custom_holiday: Optional[str] = None,
) -> str:
Expand Down Expand Up @@ -70,7 +74,10 @@ def create_model_ddl(
if options:
rendered_options = []
for option_name, option_value in options.items():
if isinstance(option_value, (list, tuple)):
if isinstance(option_value, col.Expression):
sg_expr = expression_compiler.compile_expression(option_value._value)
rendered_val = sg_sql.to_sql(sg_expr)
Comment on lines +77 to +79
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's fine to compile the expression directly here, because there are no columns in scope. Other functions that take DataFrame inputs may need to resolve labels to column IDs first.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I'll make sure to include this in the policy doc I'm putting together.

elif isinstance(option_value, (list, tuple)):
# Handle list options like model_registry="vertex_ai"
# wait, usually options are key=value.
# if value is list, it is [val1, val2]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
CREATE MODEL `my_model`
OPTIONS(l2_reg = 0.1 * 10, booster_type = 'gbtree')
AS SELECT * FROM t
24 changes: 24 additions & 0 deletions packages/bigframes/tests/unit/core/sql/test_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@

import pytest

import bigframes.core.col as col
import bigframes.core.expression as ex
import bigframes.core.sql.ml
import bigframes.dtypes as dtypes
import bigframes.operations.numeric_ops as numeric_ops

# Skip every test in this module when the ``pytest_snapshot`` package cannot be
# imported; the tests below take its ``snapshot`` fixture to compare rendered SQL.
pytest.importorskip("pytest_snapshot")

Expand Down Expand Up @@ -97,6 +101,26 @@ def test_create_model_list_option(snapshot):
snapshot.assert_match(sql, "create_model_list_option.sql")


def test_create_model_expression_option(snapshot):
    """Check the DDL produced when an option value is a ``col.Expression``.

    ``l2_reg`` is given as the product of two scalar constants (0.1 * 10)
    wrapped in ``col.Expression``, while ``booster_type`` stays a plain
    string. The generated statement is compared with the stored snapshot
    ``create_model_expression_option.sql``.
    """
    # Operands of the multiplication, each a typed scalar constant.
    operands = (
        ex.ScalarConstantExpression(0.1, dtypes.FLOAT_DTYPE),
        ex.ScalarConstantExpression(10, dtypes.INT_DTYPE),
    )
    product = ex.OpExpression(op=numeric_ops.mul_op, inputs=operands)

    # Key order is kept as-is: the expression option first, the literal second.
    model_options = {
        "l2_reg": col.Expression(product),
        "booster_type": "gbtree",
    }

    rendered_ddl = bigframes.core.sql.ml.create_model_ddl(
        model_name="my_model",
        options=model_options,
        training_data="SELECT * FROM t",
    )

    snapshot.assert_match(rendered_ddl, "create_model_expression_option.sql")


def test_evaluate_model_basic(snapshot):
sql = bigframes.core.sql.ml.evaluate(
model_name="my_project.my_dataset.my_model",
Expand Down
Loading