diff --git a/bigframes/bigquery/_operations/ml.py b/bigframes/bigquery/_operations/ml.py index e5a5c5dfb6..29ab19550b 100644 --- a/bigframes/bigquery/_operations/ml.py +++ b/bigframes/bigquery/_operations/ml.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import cast, Mapping, Optional, Union +from typing import cast, List, Mapping, Optional, Union import bigframes_vendored.constants import google.cloud.bigquery @@ -431,3 +431,92 @@ def transform( return bpd.read_gbq_query(sql) else: return session.read_gbq_query(sql) + + +@log_adapter.method_logger(custom_base_name="bigquery_ml") +def generate_text( + model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series], + input_: Union[pd.DataFrame, dataframe.DataFrame, str], + *, + temperature: Optional[float] = None, + max_output_tokens: Optional[int] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + flatten_json_output: Optional[bool] = None, + stop_sequences: Optional[List[str]] = None, + ground_with_google_search: Optional[bool] = None, + request_type: Optional[str] = None, +) -> dataframe.DataFrame: + """ + Generates text using a BigQuery ML model. + + See the `BigQuery ML GENERATE_TEXT function syntax + `_ + for additional reference. + + Args: + model (bigframes.ml.base.BaseEstimator or str): + The model to use for text generation. + input_ (Union[bigframes.pandas.DataFrame, str]): + The DataFrame or query to use for text generation. + temperature (float, optional): + A FLOAT64 value that is used for sampling promiscuity. The value + must be in the range ``[0.0, 1.0]``. A lower temperature works well + for prompts that expect a more deterministic and less open-ended + or creative response, while a higher temperature can lead to more + diverse or creative results. A temperature of ``0`` is + deterministic, meaning that the highest probability response is + always selected. + max_output_tokens (int, optional): + An INT64 value that sets the maximum number of tokens in the + generated text. + top_k (int, optional): + An INT64 value that changes how the model selects tokens for + output. A ``top_k`` of ``1`` means the next selected token is the + most probable among all tokens in the model's vocabulary. A + ``top_k`` of ``3`` means that the next token is selected from + among the three most probable tokens by using temperature. The + default value is ``40``. + top_p (float, optional): + A FLOAT64 value that changes how the model selects tokens for + output. Tokens are selected from most probable to least probable + until the sum of their probabilities equals the ``top_p`` value. + For example, if tokens A, B, and C have a probability of 0.3, 0.2, + and 0.1 and the ``top_p`` value is ``0.5``, then the model will + select either A or B as the next token by using temperature. The + default value is ``0.95``. + flatten_json_output (bool, optional): + A BOOL value that determines the content of the generated JSON column. + stop_sequences (List[str], optional): + An ARRAY value that contains the stop sequences for the model. + ground_with_google_search (bool, optional): + A BOOL value that determines whether to ground the model with Google Search. + request_type (str, optional): + A STRING value that contains the request type for the model. + + Returns: + bigframes.pandas.DataFrame: + The generated text. + """ + import bigframes.pandas as bpd + + model_name, session = _get_model_name_and_session(model, input_) + table_sql = _to_sql(input_) + + sql = bigframes.core.sql.ml.generate_text( + model_name=model_name, + table=table_sql, + temperature=temperature, + max_output_tokens=max_output_tokens, + top_k=top_k, + top_p=top_p, + flatten_json_output=flatten_json_output, + stop_sequences=stop_sequences, + ground_with_google_search=ground_with_google_search, + request_type=request_type, + ) + + if session is None: + return bpd.read_gbq_query(sql) + else: + return session.read_gbq_query(sql) diff --git a/bigframes/bigquery/ml.py b/bigframes/bigquery/ml.py index 6ceadb324d..ef9aa3288b 100644 --- a/bigframes/bigquery/ml.py +++ b/bigframes/bigquery/ml.py @@ -23,6 +23,7 @@ create_model, evaluate, explain_predict, + generate_text, global_explain, predict, transform, @@ -35,4 +36,5 @@ "explain_predict", "global_explain", "transform", + "generate_text", ] diff --git a/bigframes/core/sql/ml.py b/bigframes/core/sql/ml.py index 1749315925..0b9427b938 100644 --- a/bigframes/core/sql/ml.py +++ b/bigframes/core/sql/ml.py @@ -14,7 +14,9 @@ from __future__ import annotations -from typing import Dict, Mapping, Optional, Union +import collections.abc +import json +from typing import Any, Dict, List, Mapping, Optional, Union import bigframes.core.compile.googlesql as googlesql import bigframes.core.sql @@ -100,14 +102,41 @@ def create_model_ddl( def _build_struct_sql( - struct_options: Mapping[str, Union[str, int, float, bool]] + struct_options: Mapping[ + str, + Union[str, int, float, bool, Mapping[str, str], List[str], Mapping[str, Any]], + ] ) -> str: if not struct_options: return "" rendered_options = [] for option_name, option_value in struct_options.items(): - rendered_val = bigframes.core.sql.simple_literal(option_value) + if option_name == "model_params": + json_str = json.dumps(option_value) + # Escape single quotes for SQL string literal + sql_json_str = json_str.replace("'", "''") + rendered_val = f"JSON'{sql_json_str}'" + elif isinstance(option_value, collections.abc.Mapping): + struct_body = ", ".join( + [ + f"{bigframes.core.sql.simple_literal(v)} AS {k}" + for k, v in option_value.items() + ] + ) + rendered_val = f"STRUCT({struct_body})" + elif isinstance(option_value, list): + rendered_val = ( + "[" + + ", ".join( + [bigframes.core.sql.simple_literal(v) for v in option_value] + ) + + "]" + ) + elif isinstance(option_value, bool): + rendered_val = str(option_value).lower() + else: + rendered_val = bigframes.core.sql.simple_literal(option_value) rendered_options.append(f"{rendered_val} AS {option_name}") return f", STRUCT({', '.join(rendered_options)})" @@ -151,7 +180,7 @@ def predict( """Encode the ML.PREDICT statement. See https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-predict for reference. """ - struct_options = {} + struct_options: Dict[str, Union[str, int, float, bool]] = {} if threshold is not None: struct_options["threshold"] = threshold if keep_original_columns is not None: @@ -205,7 +234,7 @@ def global_explain( """Encode the ML.GLOBAL_EXPLAIN statement. See https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-global-explain for reference. """ - struct_options = {} + struct_options: Dict[str, Union[str, int, float, bool]] = {} if class_level_explain is not None: struct_options["class_level_explain"] = class_level_explain @@ -224,3 +253,46 @@ def transform( """ sql = f"SELECT * FROM ML.TRANSFORM(MODEL {googlesql.identifier(model_name)}, ({table}))\n" return sql + + +def generate_text( + model_name: str, + table: str, + *, + temperature: Optional[float] = None, + max_output_tokens: Optional[int] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + flatten_json_output: Optional[bool] = None, + stop_sequences: Optional[List[str]] = None, + ground_with_google_search: Optional[bool] = None, + request_type: Optional[str] = None, +) -> str: + """Encode the ML.GENERATE_TEXT statement. + See https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-generate-text for reference. + """ + struct_options: Dict[ + str, + Union[str, int, float, bool, Mapping[str, str], List[str], Mapping[str, Any]], + ] = {} + if temperature is not None: + struct_options["temperature"] = temperature + if max_output_tokens is not None: + struct_options["max_output_tokens"] = max_output_tokens + if top_k is not None: + struct_options["top_k"] = top_k + if top_p is not None: + struct_options["top_p"] = top_p + if flatten_json_output is not None: + struct_options["flatten_json_output"] = flatten_json_output + if stop_sequences is not None: + struct_options["stop_sequences"] = stop_sequences + if ground_with_google_search is not None: + struct_options["ground_with_google_search"] = ground_with_google_search + if request_type is not None: + struct_options["request_type"] = request_type + + sql = f"SELECT * FROM ML.GENERATE_TEXT(MODEL {googlesql.identifier(model_name)}, ({table})" + sql += _build_struct_sql(struct_options) + sql += ")\n" + return sql diff --git a/notebooks/ml/bq_dataframes_ml_cross_validation.ipynb b/notebooks/ml/bq_dataframes_ml_cross_validation.ipynb index 501bfc88d3..3dc0eabf5a 100644 --- a/notebooks/ml/bq_dataframes_ml_cross_validation.ipynb +++ b/notebooks/ml/bq_dataframes_ml_cross_validation.ipynb @@ -991,7 +991,7 @@ ], "metadata": { "kernelspec": { - "display_name": "venv", + "display_name": "venv (3.10.14)", "language": "python", "name": "python3" }, @@ -1005,7 +1005,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.15" + "version": "3.10.14" } }, "nbformat": 4, diff --git a/tests/unit/bigquery/test_ml.py b/tests/unit/bigquery/test_ml.py index 96b97d68fe..e52820f88a 100644 --- a/tests/unit/bigquery/test_ml.py +++ b/tests/unit/bigquery/test_ml.py @@ -163,3 +163,40 @@ def test_transform_with_pandas_dataframe(read_pandas_mock, read_gbq_query_mock): assert "ML.TRANSFORM" in generated_sql assert f"MODEL `{MODEL_NAME}`" in generated_sql assert "(SELECT * FROM `pandas_df`)" in generated_sql + + +@mock.patch("bigframes.pandas.read_gbq_query") +@mock.patch("bigframes.pandas.read_pandas") +def test_generate_text_with_pandas_dataframe(read_pandas_mock, read_gbq_query_mock): + df = pd.DataFrame({"col1": [1, 2, 3]}) + read_pandas_mock.return_value._to_sql_query.return_value = ( + "SELECT * FROM `pandas_df`", + [], + [], + ) + ml_ops.generate_text( + MODEL_SERIES, + input_=df, + temperature=0.5, + max_output_tokens=128, + top_k=20, + top_p=0.9, + flatten_json_output=True, + stop_sequences=["a", "b"], + ground_with_google_search=True, + request_type="TYPE", + ) + read_pandas_mock.assert_called_once() + read_gbq_query_mock.assert_called_once() + generated_sql = read_gbq_query_mock.call_args[0][0] + assert "ML.GENERATE_TEXT" in generated_sql + assert f"MODEL `{MODEL_NAME}`" in generated_sql + assert "(SELECT * FROM `pandas_df`)" in generated_sql + assert "STRUCT(0.5 AS temperature" in generated_sql + assert "128 AS max_output_tokens" in generated_sql + assert "20 AS top_k" in generated_sql + assert "0.9 AS top_p" in generated_sql + assert "true AS flatten_json_output" in generated_sql + assert "['a', 'b'] AS stop_sequences" in generated_sql + assert "true AS ground_with_google_search" in generated_sql + assert "'TYPE' AS request_type" in generated_sql diff --git a/tests/unit/core/sql/snapshots/test_ml/test_evaluate_model_with_options/evaluate_model_with_options.sql b/tests/unit/core/sql/snapshots/test_ml/test_evaluate_model_with_options/evaluate_model_with_options.sql index 01eb4d3781..848c36907b 100644 --- a/tests/unit/core/sql/snapshots/test_ml/test_evaluate_model_with_options/evaluate_model_with_options.sql +++ b/tests/unit/core/sql/snapshots/test_ml/test_evaluate_model_with_options/evaluate_model_with_options.sql @@ -1 +1 @@ -SELECT * FROM ML.EVALUATE(MODEL `my_model`, STRUCT(False AS perform_aggregation, 10 AS horizon, 0.95 AS confidence_level)) +SELECT * FROM ML.EVALUATE(MODEL `my_model`, STRUCT(false AS perform_aggregation, 10 AS horizon, 0.95 AS confidence_level)) diff --git a/tests/unit/core/sql/snapshots/test_ml/test_generate_text_model_basic/generate_text_model_basic.sql b/tests/unit/core/sql/snapshots/test_ml/test_generate_text_model_basic/generate_text_model_basic.sql new file mode 100644 index 0000000000..9d98687644 --- /dev/null +++ b/tests/unit/core/sql/snapshots/test_ml/test_generate_text_model_basic/generate_text_model_basic.sql @@ -0,0 +1 @@ +SELECT * FROM ML.GENERATE_TEXT(MODEL `my_project.my_dataset.my_model`, (SELECT * FROM new_data)) diff --git a/tests/unit/core/sql/snapshots/test_ml/test_generate_text_model_with_options/generate_text_model_with_options.sql b/tests/unit/core/sql/snapshots/test_ml/test_generate_text_model_with_options/generate_text_model_with_options.sql new file mode 100644 index 0000000000..7839ff3fbd --- /dev/null +++ b/tests/unit/core/sql/snapshots/test_ml/test_generate_text_model_with_options/generate_text_model_with_options.sql @@ -0,0 +1 @@ +SELECT * FROM ML.GENERATE_TEXT(MODEL `my_project.my_dataset.my_model`, (SELECT * FROM new_data), STRUCT(0.5 AS temperature, 128 AS max_output_tokens, 20 AS top_k, 0.9 AS top_p, true AS flatten_json_output, ['a', 'b'] AS stop_sequences, true AS ground_with_google_search, 'TYPE' AS request_type)) diff --git a/tests/unit/core/sql/snapshots/test_ml/test_global_explain_model_with_options/global_explain_model_with_options.sql b/tests/unit/core/sql/snapshots/test_ml/test_global_explain_model_with_options/global_explain_model_with_options.sql index 1a3baa0c13..b8d158acfc 100644 --- a/tests/unit/core/sql/snapshots/test_ml/test_global_explain_model_with_options/global_explain_model_with_options.sql +++ b/tests/unit/core/sql/snapshots/test_ml/test_global_explain_model_with_options/global_explain_model_with_options.sql @@ -1 +1 @@ -SELECT * FROM ML.GLOBAL_EXPLAIN(MODEL `my_model`, STRUCT(True AS class_level_explain)) +SELECT * FROM ML.GLOBAL_EXPLAIN(MODEL `my_model`, STRUCT(true AS class_level_explain)) diff --git a/tests/unit/core/sql/snapshots/test_ml/test_predict_model_with_options/predict_model_with_options.sql b/tests/unit/core/sql/snapshots/test_ml/test_predict_model_with_options/predict_model_with_options.sql index 96c8074e4c..f320d47fcf 100644 --- a/tests/unit/core/sql/snapshots/test_ml/test_predict_model_with_options/predict_model_with_options.sql +++ b/tests/unit/core/sql/snapshots/test_ml/test_predict_model_with_options/predict_model_with_options.sql @@ -1 +1 @@ -SELECT * FROM ML.PREDICT(MODEL `my_model`, (SELECT * FROM new_data), STRUCT(True AS keep_original_columns)) +SELECT * FROM ML.PREDICT(MODEL `my_model`, (SELECT * FROM new_data), STRUCT(true AS keep_original_columns)) diff --git a/tests/unit/core/sql/test_ml.py b/tests/unit/core/sql/test_ml.py index 9721f42fee..15e9ef0aa1 100644 --- a/tests/unit/core/sql/test_ml.py +++ b/tests/unit/core/sql/test_ml.py @@ -177,3 +177,27 @@ def test_transform_model_basic(snapshot): table="SELECT * FROM new_data", ) snapshot.assert_match(sql, "transform_model_basic.sql") + + +def test_generate_text_model_basic(snapshot): + sql = bigframes.core.sql.ml.generate_text( + model_name="my_project.my_dataset.my_model", + table="SELECT * FROM new_data", + ) + snapshot.assert_match(sql, "generate_text_model_basic.sql") + + +def test_generate_text_model_with_options(snapshot): + sql = bigframes.core.sql.ml.generate_text( + model_name="my_project.my_dataset.my_model", + table="SELECT * FROM new_data", + temperature=0.5, + max_output_tokens=128, + top_k=20, + top_p=0.9, + flatten_json_output=True, + stop_sequences=["a", "b"], + ground_with_google_search=True, + request_type="TYPE", + ) + snapshot.assert_match(sql, "generate_text_model_with_options.sql")