Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 90 additions & 1 deletion bigframes/bigquery/_operations/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from __future__ import annotations

from typing import cast, Mapping, Optional, Union
from typing import cast, List, Mapping, Optional, Union

import bigframes_vendored.constants
import google.cloud.bigquery
Expand Down Expand Up @@ -431,3 +431,92 @@ def transform(
return bpd.read_gbq_query(sql)
else:
return session.read_gbq_query(sql)


@log_adapter.method_logger(custom_base_name="bigquery_ml")
def generate_text(
    model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series],
    input_: Union[pd.DataFrame, dataframe.DataFrame, str],
    *,
    temperature: Optional[float] = None,
    max_output_tokens: Optional[int] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    flatten_json_output: Optional[bool] = None,
    stop_sequences: Optional[List[str]] = None,
    ground_with_google_search: Optional[bool] = None,
    request_type: Optional[str] = None,
) -> dataframe.DataFrame:
    """
    Run ``ML.GENERATE_TEXT`` with a BigQuery ML model and return the result.

    See the `BigQuery ML GENERATE_TEXT function syntax
    <https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-generate-text>`_
    for additional reference.

    Every keyword option left as ``None`` is passed through as ``None`` to
    the SQL builder, which omits it from the generated statement.

    Args:
        model (bigframes.ml.base.BaseEstimator, str, or pandas.Series):
            The model to use for text generation.
        input_ (Union[pandas.DataFrame, bigframes.pandas.DataFrame, str]):
            The DataFrame or query that supplies the input rows.
        temperature (float, optional):
            A FLOAT64 value in the range ``[0.0, 1.0]`` that controls how
            much randomness is used when sampling. Lower values suit prompts
            that expect a deterministic, less open-ended response; higher
            values can produce more diverse or creative results. A
            temperature of ``0`` is deterministic: the highest probability
            response is always selected.
        max_output_tokens (int, optional):
            An INT64 value capping the number of tokens in the generated
            text.
        top_k (int, optional):
            An INT64 value that changes how the model selects tokens for
            output. With ``top_k`` of ``1`` the next token is the single
            most probable one in the model's vocabulary; with ``top_k`` of
            ``3`` it is chosen among the three most probable tokens by
            using temperature. The default value is ``40``.
        top_p (float, optional):
            A FLOAT64 value that changes how the model selects tokens for
            output. Tokens are taken from most to least probable until
            their probabilities sum to ``top_p``. For example, if tokens A,
            B, and C have probabilities 0.3, 0.2, and 0.1 and ``top_p`` is
            ``0.5``, the model picks either A or B by using temperature.
            The default value is ``0.95``.
        flatten_json_output (bool, optional):
            A BOOL value that determines the content of the generated JSON
            column.
        stop_sequences (List[str], optional):
            An ARRAY<STRING> value holding the stop sequences for the model.
        ground_with_google_search (bool, optional):
            A BOOL value that determines whether to ground the model with
            Google Search.
        request_type (str, optional):
            A STRING value holding the request type for the model.

    Returns:
        bigframes.pandas.DataFrame:
            The generated text.
    """
    import bigframes.pandas as bpd

    model_name, session = _get_model_name_and_session(model, input_)
    input_sql = _to_sql(input_)

    query = bigframes.core.sql.ml.generate_text(
        model_name=model_name,
        table=input_sql,
        temperature=temperature,
        max_output_tokens=max_output_tokens,
        top_k=top_k,
        top_p=top_p,
        flatten_json_output=flatten_json_output,
        stop_sequences=stop_sequences,
        ground_with_google_search=ground_with_google_search,
        request_type=request_type,
    )

    # Without a session tied to the inputs, fall back to the module-level
    # reader in bigframes.pandas.
    run_query = bpd.read_gbq_query if session is None else session.read_gbq_query
    return run_query(query)
2 changes: 2 additions & 0 deletions bigframes/bigquery/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
create_model,
evaluate,
explain_predict,
generate_text,
global_explain,
predict,
transform,
Expand All @@ -35,4 +36,5 @@
"explain_predict",
"global_explain",
"transform",
"generate_text",
]
82 changes: 77 additions & 5 deletions bigframes/core/sql/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

from __future__ import annotations

from typing import Dict, Mapping, Optional, Union
import collections.abc
import json
from typing import Any, Dict, List, Mapping, Optional, Union

import bigframes.core.compile.googlesql as googlesql
import bigframes.core.sql
Expand Down Expand Up @@ -100,14 +102,41 @@ def create_model_ddl(


def _render_struct_scalar(value: Any) -> str:
    """Render one scalar as a SQL literal, spelling booleans ``true``/``false``."""
    # Booleans are lowercased here so that values nested inside arrays or
    # structs match the spelling used for top-level boolean options.
    if isinstance(value, bool):
        return str(value).lower()
    return bigframes.core.sql.simple_literal(value)


def _build_struct_sql(
    struct_options: Mapping[
        str,
        Union[str, int, float, bool, Mapping[str, str], List[str], Mapping[str, Any]],
    ]
) -> str:
    """Render options as a trailing ``, STRUCT(...)`` argument.

    Args:
        struct_options:
            Option names mapped to their values. The option named
            ``model_params`` is serialized with ``json.dumps`` and emitted
            as a ``JSON'...'`` literal. Any other mapping becomes a nested
            ``STRUCT(...)``, a list becomes an array literal, and all
            remaining values are rendered as scalar literals.

    Returns:
        An empty string when ``struct_options`` is empty; otherwise
        ``", STRUCT(<value> AS <name>, ...)"`` with fields in the mapping's
        iteration order.
    """
    if not struct_options:
        return ""

    rendered_options = []
    for option_name, option_value in struct_options.items():
        if option_name == "model_params":
            json_str = json.dumps(option_value)
            # GoogleSQL string literals use backslash escapes. Backslashes
            # must be doubled first so JSON escapes such as \" or \n reach
            # the JSON parser intact; then single quotes are escaped as \'.
            # Doubling the quote ('') is not a valid escape in GoogleSQL.
            sql_json_str = json_str.replace("\\", "\\\\").replace("'", "\\'")
            rendered_val = f"JSON'{sql_json_str}'"
        elif isinstance(option_value, collections.abc.Mapping):
            struct_body = ", ".join(
                f"{_render_struct_scalar(v)} AS {k}" for k, v in option_value.items()
            )
            rendered_val = f"STRUCT({struct_body})"
        elif isinstance(option_value, list):
            array_body = ", ".join(_render_struct_scalar(v) for v in option_value)
            rendered_val = f"[{array_body}]"
        else:
            rendered_val = _render_struct_scalar(option_value)
        rendered_options.append(f"{rendered_val} AS {option_name}")
    return f", STRUCT({', '.join(rendered_options)})"

Expand Down Expand Up @@ -151,7 +180,7 @@ def predict(
"""Encode the ML.PREDICT statement.
See https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-predict for reference.
"""
struct_options = {}
struct_options: Dict[str, Union[str, int, float, bool]] = {}
if threshold is not None:
struct_options["threshold"] = threshold
if keep_original_columns is not None:
Expand Down Expand Up @@ -205,7 +234,7 @@ def global_explain(
"""Encode the ML.GLOBAL_EXPLAIN statement.
See https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-global-explain for reference.
"""
struct_options = {}
struct_options: Dict[str, Union[str, int, float, bool]] = {}
if class_level_explain is not None:
struct_options["class_level_explain"] = class_level_explain

Expand All @@ -224,3 +253,46 @@ def transform(
"""
sql = f"SELECT * FROM ML.TRANSFORM(MODEL {googlesql.identifier(model_name)}, ({table}))\n"
return sql


def generate_text(
    model_name: str,
    table: str,
    *,
    temperature: Optional[float] = None,
    max_output_tokens: Optional[int] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    flatten_json_output: Optional[bool] = None,
    stop_sequences: Optional[List[str]] = None,
    ground_with_google_search: Optional[bool] = None,
    request_type: Optional[str] = None,
) -> str:
    """Encode the ML.GENERATE_TEXT statement.
    See https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-generate-text for reference.

    Options left as ``None`` are omitted from the STRUCT argument; when all
    of them are ``None`` no STRUCT argument is emitted.
    """
    # Insertion order here fixes the field order of the rendered STRUCT.
    candidate_options: Dict[str, Any] = {
        "temperature": temperature,
        "max_output_tokens": max_output_tokens,
        "top_k": top_k,
        "top_p": top_p,
        "flatten_json_output": flatten_json_output,
        "stop_sequences": stop_sequences,
        "ground_with_google_search": ground_with_google_search,
        "request_type": request_type,
    }
    struct_options: Dict[
        str,
        Union[str, int, float, bool, Mapping[str, str], List[str], Mapping[str, Any]],
    ] = {
        name: value for name, value in candidate_options.items() if value is not None
    }

    model_ref = googlesql.identifier(model_name)
    options_sql = _build_struct_sql(struct_options)
    return f"SELECT * FROM ML.GENERATE_TEXT(MODEL {model_ref}, ({table}){options_sql})\n"
4 changes: 2 additions & 2 deletions notebooks/ml/bq_dataframes_ml_cross_validation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -991,7 +991,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"display_name": "venv (3.10.14)",
"language": "python",
"name": "python3"
},
Expand All @@ -1005,7 +1005,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.15"
"version": "3.10.14"
}
},
"nbformat": 4,
Expand Down
37 changes: 37 additions & 0 deletions tests/unit/bigquery/test_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,40 @@ def test_transform_with_pandas_dataframe(read_pandas_mock, read_gbq_query_mock):
assert "ML.TRANSFORM" in generated_sql
assert f"MODEL `{MODEL_NAME}`" in generated_sql
assert "(SELECT * FROM `pandas_df`)" in generated_sql


@mock.patch("bigframes.pandas.read_gbq_query")
@mock.patch("bigframes.pandas.read_pandas")
def test_generate_text_with_pandas_dataframe(read_pandas_mock, read_gbq_query_mock):
    """generate_text on a pandas DataFrame emits ML.GENERATE_TEXT with every option."""
    read_pandas_mock.return_value._to_sql_query.return_value = (
        "SELECT * FROM `pandas_df`",
        [],
        [],
    )
    pandas_input = pd.DataFrame({"col1": [1, 2, 3]})
    options = dict(
        temperature=0.5,
        max_output_tokens=128,
        top_k=20,
        top_p=0.9,
        flatten_json_output=True,
        stop_sequences=["a", "b"],
        ground_with_google_search=True,
        request_type="TYPE",
    )

    ml_ops.generate_text(MODEL_SERIES, input_=pandas_input, **options)

    read_pandas_mock.assert_called_once()
    read_gbq_query_mock.assert_called_once()

    # The SQL text is the first positional argument given to read_gbq_query.
    positional_args, _ = read_gbq_query_mock.call_args
    generated_sql = positional_args[0]

    expected_fragments = (
        "ML.GENERATE_TEXT",
        f"MODEL `{MODEL_NAME}`",
        "(SELECT * FROM `pandas_df`)",
        "STRUCT(0.5 AS temperature",
        "128 AS max_output_tokens",
        "20 AS top_k",
        "0.9 AS top_p",
        "true AS flatten_json_output",
        "['a', 'b'] AS stop_sequences",
        "true AS ground_with_google_search",
        "'TYPE' AS request_type",
    )
    for fragment in expected_fragments:
        assert fragment in generated_sql
Original file line number Diff line number Diff line change
@@ -1 +1 @@
SELECT * FROM ML.EVALUATE(MODEL `my_model`, STRUCT(False AS perform_aggregation, 10 AS horizon, 0.95 AS confidence_level))
SELECT * FROM ML.EVALUATE(MODEL `my_model`, STRUCT(false AS perform_aggregation, 10 AS horizon, 0.95 AS confidence_level))
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
SELECT * FROM ML.GENERATE_TEXT(MODEL `my_project.my_dataset.my_model`, (SELECT * FROM new_data))
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
SELECT * FROM ML.GENERATE_TEXT(MODEL `my_project.my_dataset.my_model`, (SELECT * FROM new_data), STRUCT(0.5 AS temperature, 128 AS max_output_tokens, 20 AS top_k, 0.9 AS top_p, true AS flatten_json_output, ['a', 'b'] AS stop_sequences, true AS ground_with_google_search, 'TYPE' AS request_type))
Original file line number Diff line number Diff line change
@@ -1 +1 @@
SELECT * FROM ML.GLOBAL_EXPLAIN(MODEL `my_model`, STRUCT(True AS class_level_explain))
SELECT * FROM ML.GLOBAL_EXPLAIN(MODEL `my_model`, STRUCT(true AS class_level_explain))
Original file line number Diff line number Diff line change
@@ -1 +1 @@
SELECT * FROM ML.PREDICT(MODEL `my_model`, (SELECT * FROM new_data), STRUCT(True AS keep_original_columns))
SELECT * FROM ML.PREDICT(MODEL `my_model`, (SELECT * FROM new_data), STRUCT(true AS keep_original_columns))
24 changes: 24 additions & 0 deletions tests/unit/core/sql/test_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,27 @@ def test_transform_model_basic(snapshot):
table="SELECT * FROM new_data",
)
snapshot.assert_match(sql, "transform_model_basic.sql")


def test_generate_text_model_basic(snapshot):
    """Snapshot the ML.GENERATE_TEXT statement built with no optional arguments."""
    builder_kwargs = {
        "model_name": "my_project.my_dataset.my_model",
        "table": "SELECT * FROM new_data",
    }
    rendered = bigframes.core.sql.ml.generate_text(**builder_kwargs)
    snapshot.assert_match(rendered, "generate_text_model_basic.sql")


def test_generate_text_model_with_options(snapshot):
    """Snapshot the ML.GENERATE_TEXT statement built with every option set."""
    generation_options = {
        "temperature": 0.5,
        "max_output_tokens": 128,
        "top_k": 20,
        "top_p": 0.9,
        "flatten_json_output": True,
        "stop_sequences": ["a", "b"],
        "ground_with_google_search": True,
        "request_type": "TYPE",
    }
    rendered = bigframes.core.sql.ml.generate_text(
        model_name="my_project.my_dataset.my_model",
        table="SELECT * FROM new_data",
        **generation_options,
    )
    snapshot.assert_match(rendered, "generate_text_model_with_options.sql")
Loading