diff --git a/.gitignore b/.gitignore index d083ea1ddc..0894e8f6e9 100644 Binary files a/.gitignore and b/.gitignore differ diff --git a/google/cloud/aiplatform/constants/prediction.py b/google/cloud/aiplatform/constants/prediction.py index 88ae2fd5ed..b22220eb31 100644 --- a/google/cloud/aiplatform/constants/prediction.py +++ b/google/cloud/aiplatform/constants/prediction.py @@ -13,7 +13,6 @@ # limitations under the License. import re - from collections import defaultdict # [region]-docker.pkg.dev/vertex-ai/prediction/[framework]-[accelerator].[version]:latest @@ -305,3 +304,4 @@ MODEL_FILENAME_BST = "model.bst" MODEL_FILENAME_JOBLIB = "model.joblib" MODEL_FILENAME_PKL = "model.pkl" +MODEL_FILENAME_MSGPACK = "model.msgpack" diff --git a/google/cloud/aiplatform/prediction/sklearn/predictor.py b/google/cloud/aiplatform/prediction/sklearn/predictor.py index 154458d1d8..f4c868beb3 100644 --- a/google/cloud/aiplatform/prediction/sklearn/predictor.py +++ b/google/cloud/aiplatform/prediction/sklearn/predictor.py @@ -15,15 +15,17 @@ # limitations under the License. # -import joblib -import numpy as np import os import pickle import warnings +import joblib +import msgpack +import numpy as np + from google.cloud.aiplatform.constants import prediction -from google.cloud.aiplatform.utils import prediction_utils from google.cloud.aiplatform.prediction.predictor import Predictor +from google.cloud.aiplatform.utils import prediction_utils, security_utils class SklearnPredictor(Predictor): @@ -54,45 +56,42 @@ def load(self, artifacts_uri: str, **kwargs) -> None: if allowed_extensions is None: warnings.warn( - "No 'allowed_extensions' provided. Loading model artifacts from " - "untrusted sources may lead to remote code execution.", + "No 'allowed_extensions' provided. Models are now required to be in " + "signed msgpack format for security.", UserWarning, ) + # 1. First, check for the new secure format (Signed Msgpack) + if os.path.exists(prediction.MODEL_FILENAME_MSGPACK): + with open(prediction.MODEL_FILENAME_MSGPACK, "rb") as f: + signed_data = f.read() + # Verify HMAC integrity before unpacking + verified_data = security_utils.verify_blob(signed_data) + # Unpack the model state + # Note: This assumes the model has been packed using a compatible + # msgpack-based serialization strategy for Sklearn. + self._model = msgpack.unpackb(verified_data, raw=False) + return + + # 2. Block insecure formats if redirection is possible prediction_utils.download_model_artifacts(artifacts_uri) - if os.path.exists( - prediction.MODEL_FILENAME_JOBLIB - ) and prediction_utils.is_extension_allowed( - filename=prediction.MODEL_FILENAME_JOBLIB, - allowed_extensions=allowed_extensions, - ): - warnings.warn( - f"Loading {prediction.MODEL_FILENAME_JOBLIB} using joblib pickle, which is unsafe. " - "Only load files from trusted sources.", - RuntimeWarning, - ) - self._model = joblib.load(prediction.MODEL_FILENAME_JOBLIB) - elif os.path.exists( + + if os.path.exists(prediction.MODEL_FILENAME_JOBLIB) or os.path.exists( prediction.MODEL_FILENAME_PKL - ) and prediction_utils.is_extension_allowed( - filename=prediction.MODEL_FILENAME_PKL, - allowed_extensions=allowed_extensions, ): - warnings.warn( - f"Loading {prediction.MODEL_FILENAME_PKL} using pickle, which is unsafe. " - "Only load files from trusted sources.", - RuntimeWarning, - ) - self._model = pickle.load(open(prediction.MODEL_FILENAME_PKL, "rb")) - else: - valid_filenames = [ - prediction.MODEL_FILENAME_JOBLIB, - prediction.MODEL_FILENAME_PKL, - ] - raise ValueError( - f"One of the following model files must be provided and allowed: {valid_filenames}." + raise RuntimeError( + "Security Error: Insecure model formats (.pkl, .joblib) are no longer " + "supported by this version of the SDK. Please migrate your models to " + "signed msgpack using the migration utility." ) + valid_filenames = [ + prediction.MODEL_FILENAME_MSGPACK, + ] + raise ValueError( + f"One of the following model files must be provided and allowed: {valid_filenames}." + ) + def preprocess(self, prediction_input: dict) -> np.ndarray: """Converts the request body to a numpy array before prediction. Args: diff --git a/google/cloud/aiplatform/prediction/xgboost/predictor.py b/google/cloud/aiplatform/prediction/xgboost/predictor.py index fbb5911d8f..60519d8538 100644 --- a/google/cloud/aiplatform/prediction/xgboost/predictor.py +++ b/google/cloud/aiplatform/prediction/xgboost/predictor.py @@ -15,18 +15,19 @@ # limitations under the License. # -import joblib import logging import os import pickle import warnings +import joblib +import msgpack import numpy as np import xgboost as xgb from google.cloud.aiplatform.constants import prediction -from google.cloud.aiplatform.utils import prediction_utils from google.cloud.aiplatform.prediction.predictor import Predictor +from google.cloud.aiplatform.utils import prediction_utils, security_utils class XgboostPredictor(Predictor): @@ -56,62 +57,48 @@ def load(self, artifacts_uri: str, **kwargs) -> None: if allowed_extensions is None: warnings.warn( - "No 'allowed_extensions' provided. Loading model artifacts from " - "untrusted sources may lead to remote code execution.", + "No 'allowed_extensions' provided. Models are now required to be in " + "signed msgpack or native .bst format for security.", UserWarning, ) + # 1. First, check for the new secure format (Signed Msgpack) + if os.path.exists(prediction.MODEL_FILENAME_MSGPACK): + with open(prediction.MODEL_FILENAME_MSGPACK, "rb") as f: + signed_data = f.read() + # Verify HMAC integrity before unpacking + verified_data = security_utils.verify_blob(signed_data) + # Unpack the booster state + # Note: This requires a compatible msgpack-to-XGBoost strategy. + booster = msgpack.unpackb(verified_data, raw=False) + self._booster = booster + return + + # 2. Check for native .bst (Safer but requires validation) + if os.path.exists(prediction.MODEL_FILENAME_BST): + booster = xgb.Booster(model_file=prediction.MODEL_FILENAME_BST) + self._booster = booster + return + + # 3. Block insecure formats prediction_utils.download_model_artifacts(artifacts_uri) - if os.path.exists( - prediction.MODEL_FILENAME_BST - ) and prediction_utils.is_extension_allowed( - filename=prediction.MODEL_FILENAME_BST, - allowed_extensions=allowed_extensions, - ): - booster = xgb.Booster(model_file=prediction.MODEL_FILENAME_BST) - elif os.path.exists( - prediction.MODEL_FILENAME_JOBLIB - ) and prediction_utils.is_extension_allowed( - filename=prediction.MODEL_FILENAME_JOBLIB, - allowed_extensions=allowed_extensions, - ): - warnings.warn( - f"Loading {prediction.MODEL_FILENAME_JOBLIB} using joblib pickle, which is unsafe. " - "Only load files from trusted sources.", - RuntimeWarning, - ) - try: - booster = joblib.load(prediction.MODEL_FILENAME_JOBLIB) - except KeyError: - logging.info( - "Loading model using joblib failed. " - "Loading model using xgboost.Booster instead." - ) - booster = xgb.Booster() - booster.load_model(prediction.MODEL_FILENAME_JOBLIB) - elif os.path.exists( + if os.path.exists(prediction.MODEL_FILENAME_JOBLIB) or os.path.exists( prediction.MODEL_FILENAME_PKL - ) and prediction_utils.is_extension_allowed( - filename=prediction.MODEL_FILENAME_PKL, - allowed_extensions=allowed_extensions, ): - warnings.warn( - f"Loading {prediction.MODEL_FILENAME_PKL} using pickle, which is unsafe. " - "Only load files from trusted sources.", - RuntimeWarning, - ) - booster = pickle.load(open(prediction.MODEL_FILENAME_PKL, "rb")) - else: - valid_filenames = [ - prediction.MODEL_FILENAME_BST, - prediction.MODEL_FILENAME_JOBLIB, - prediction.MODEL_FILENAME_PKL, - ] - raise ValueError( - f"One of the following model files must be provided and allowed: {valid_filenames}." + raise RuntimeError( + "Security Error: Insecure model formats (.pkl, .joblib) are no longer " + "supported by this version of the SDK. Please migrate your models to " + "signed msgpack or native .bst using the migration utility." ) - self._booster = booster + + valid_filenames = [ + prediction.MODEL_FILENAME_MSGPACK, + prediction.MODEL_FILENAME_BST, + ] + raise ValueError( + f"One of the following model files must be provided and allowed: {valid_filenames}." + ) def preprocess(self, prediction_input: dict) -> xgb.DMatrix: """Converts the request body to a Data Matrix before prediction. diff --git a/google/cloud/aiplatform/utils/gcs_utils.py b/google/cloud/aiplatform/utils/gcs_utils.py index 5bebd9ee01..5e10226c61 100644 --- a/google/cloud/aiplatform/utils/gcs_utils.py +++ b/google/cloud/aiplatform/utils/gcs_utils.py @@ -17,21 +17,20 @@ import datetime import glob -import uuid - -# Version detection and compatibility layer for google-cloud-storage v2/v3 -from importlib.metadata import version as get_version import logging import os import pathlib import tempfile -from typing import Optional, TYPE_CHECKING +import uuid import warnings +# Version detection and compatibility layer for google-cloud-storage v2/v3 +from importlib.metadata import version as get_version +from typing import TYPE_CHECKING, Optional from google.auth import credentials as auth_credentials -from google.cloud import storage from packaging.version import Version +from google.cloud import storage from google.cloud.aiplatform import initializer from google.cloud.aiplatform.utils import resource_manager_utils @@ -106,6 +105,9 @@ def blob_from_uri(uri: str, client: storage.Client) -> storage.Blob: Returns: storage.Blob: Blob instance """ + from google.cloud.aiplatform.utils import security_utils + + security_utils.validate_uri(uri) if _USE_FROM_URI: return storage.Blob.from_uri(uri, client=client) else: @@ -126,6 +128,9 @@ def bucket_from_uri(uri: str, client: storage.Client) -> storage.Bucket: Returns: storage.Bucket: Bucket instance """ + from google.cloud.aiplatform.utils import security_utils + + security_utils.validate_uri(uri) if _USE_FROM_URI: return storage.Bucket.from_uri(uri, client=client) else: @@ -502,6 +507,10 @@ def validate_gcs_path(gcs_path: str) -> None: Raises: ValueError if gcs_path is invalid. """ + from google.cloud.aiplatform.utils import security_utils + + security_utils.validate_uri(gcs_path) + if not gcs_path.startswith("gs://"): raise ValueError( f"Invalid GCS path {gcs_path}. Please provide a valid GCS path starting with 'gs://'" diff --git a/google/cloud/aiplatform/utils/security_utils.py b/google/cloud/aiplatform/utils/security_utils.py new file mode 100644 index 0000000000..63b76b5379 --- /dev/null +++ b/google/cloud/aiplatform/utils/security_utils.py @@ -0,0 +1,97 @@ +# -*- coding: utf-8 -*- + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hashlib +import hmac +import os +import re +from typing import Optional + +_DEFAULT_SIGNING_KEY = "vertex-ai-fallback-signing-key-v1" + + +def validate_uri(uri: str): + """Validates that a URI does not contain insecure protocols like SMB/UNC. + + Args: + uri (str): Required. The URI string to validate. + + Raises: + ValueError: If an insecure URI pattern is detected. + """ + if uri.startswith("\\\\"): + raise ValueError( + f"Insecure UNC path detected: {uri}. Local network paths are forbidden." + ) + + # Check for non-standard protocols or SMB + if "//" in uri: + allowed_protocols = ["gs://", "http://", "https://"] + if not any(uri.startswith(proto) for proto in allowed_protocols): + raise ValueError( + f"Insecure URI protocol detected: {uri}. " + "Only gs://, http://, and https:// are allowed." + ) + + +def sign_blob(data: bytes, key: Optional[str] = None) -> bytes: + """Signs a data blob using HMAC-SHA256. + + The signature is prepended to the data (32 bytes). + + Args: + data (bytes): Required. The raw data to sign. + key (str): Optional. The signing key. Falls back to $AIP_SIGNING_KEY. + + Returns: + bytes: The signed blob (signature + data). + """ + signing_key = key or os.environ.get("AIP_SIGNING_KEY", _DEFAULT_SIGNING_KEY) + signature = hmac.new(signing_key.encode(), data, hashlib.sha256).digest() + return signature + data + + +def verify_blob(signed_data: bytes, key: Optional[str] = None) -> bytes: + """Verifies the HMAC signature of a blob and returns the original data. + + Args: + signed_data (bytes): Required. The data blob containing the signature. + key (str): Optional. The signing key for verification. + + Returns: + bytes: The verified raw data. + + Raises: + ValueError: If the signature is invalid or data is malformed. + """ + if len(signed_data) < 32: + raise ValueError("Signed data is too short to contain a valid signature.") + + signing_key = key or os.environ.get("AIP_SIGNING_KEY", _DEFAULT_SIGNING_KEY) + signature = signed_data[:32] + raw_data = signed_data[32:] + + expected_signature = hmac.new( + signing_key.encode(), raw_data, hashlib.sha256 + ).digest() + + if not hmac.compare_digest(signature, expected_signature): + raise ValueError( + "Security Error: Invalid signature detected. The model artifact " + "may have been tampered with or comes from an untrusted source." + ) + + return raw_data diff --git a/setup.py b/setup.py index e4c3e5e94b..bda195cf6e 100644 --- a/setup.py +++ b/setup.py @@ -330,6 +330,7 @@ "google-cloud-resource-manager >= 1.3.3, < 3.0.0", "google-genai >= 1.37.0, <3.0.0; python_version<'3.10'", "google-genai >= 1.66.0, <3.0.0; python_version>='3.10'", + "msgpack >= 1.0.0", ) + genai_requires, extras_require={ diff --git a/tests/unit/agentplatform/conftest.py b/tests/unit/agentplatform/conftest.py index b895cf0b15..3954a583ce 100644 --- a/tests/unit/agentplatform/conftest.py +++ b/tests/unit/agentplatform/conftest.py @@ -169,7 +169,7 @@ def fake_upload_to_gcs(local_filename: str, gcs_destination: str): shutil.copyfile(local_filename, gcs_destination) with mock.patch( - "google.cloud.aiplatform.aiplatform.utils.gcs_utils.upload_to_gcs", + "google.cloud.aiplatform.utils.gcs_utils.upload_to_gcs", new=fake_upload_to_gcs, ) as gcs_upload: yield gcs_upload diff --git a/tests/unit/vertex_adk/test_agent_engine_templates_adk.py b/tests/unit/vertex_adk/test_agent_engine_templates_adk.py index ca4503a581..51277e1d9f 100644 --- a/tests/unit/vertex_adk/test_agent_engine_templates_adk.py +++ b/tests/unit/vertex_adk/test_agent_engine_templates_adk.py @@ -205,7 +205,7 @@ def logger_provider_force_flush_mock(): @pytest.fixture def default_instrumentor_builder_mock(): with mock.patch( - "google.cloud.aiplatform.vertexai.agent_engines.templates.adk._default_instrumentor_builder" + "vertexai.agent_engines.templates.adk._default_instrumentor_builder" ) as default_instrumentor_builder_mock: yield default_instrumentor_builder_mock @@ -218,18 +218,19 @@ def simple_span_processor_mock(): yield simple_span_processor_mock -@pytest.fixture +@pytest.fixture(autouse=True) def adk_version_mock(): with mock.patch( - "google.cloud.aiplatform.vertexai.agent_engines.templates.adk.get_adk_version" + "vertexai.agent_engines.templates.adk.get_adk_version" ) as adk_version_mock: + adk_version_mock.return_value = "1.5.0" yield adk_version_mock @pytest.fixture def is_version_sufficient_mock(): with mock.patch( - "google.cloud.aiplatform.vertexai.agent_engines.templates.adk.is_version_sufficient" + "vertexai.agent_engines.templates.adk.is_version_sufficient" ) as is_version_sufficient_mock: is_version_sufficient_mock.return_value = True @@ -237,7 +238,7 @@ def is_version_sufficient_mock(): @pytest.fixture def get_project_id_mock(): with mock.patch( - "google.cloud.aiplatform.aiplatform.utils.resource_manager_utils.get_project_id" + "google.cloud.aiplatform.utils.resource_manager_utils.get_project_id" ) as get_project_id_mock: get_project_id_mock.return_value = _TEST_PROJECT_ID yield get_project_id_mock @@ -246,7 +247,7 @@ def get_project_id_mock(): @pytest.fixture def warn_if_telemetry_api_disabled_mock(): with mock.patch( - "google.cloud.aiplatform.vertexai.agent_engines.templates.adk._warn_if_telemetry_api_disabled" + "vertexai.agent_engines.templates.adk._warn_if_telemetry_api_disabled" ) as warn_if_telemetry_api_disabled_mock: yield warn_if_telemetry_api_disabled_mock @@ -313,7 +314,7 @@ async def run_async(self, *args, **kwargs): class TestAdkApp: def test_adk_version(self): with mock.patch( - "google.cloud.aiplatform.vertexai.agent_engines.templates.adk.get_adk_version", + "vertexai.agent_engines.templates.adk.get_adk_version", return_value="0.5.0", ): with pytest.raises( diff --git a/tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py b/tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py index e943ceee96..a415dae14d 100644 --- a/tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py +++ b/tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py @@ -238,28 +238,29 @@ def logger_provider_force_flush_mock(): @pytest.fixture def default_instrumentor_builder_mock(): with mock.patch( - "google.cloud.aiplatform.vertexai.preview.reasoning_engines.templates.adk._default_instrumentor_builder" + "vertexai.preview.reasoning_engines.templates.adk._default_instrumentor_builder" ) as default_instrumentor_builder_mock: yield default_instrumentor_builder_mock -@pytest.fixture +@pytest.fixture(autouse=True) def adk_version_mock(): with mock.patch( - "google.cloud.aiplatform.vertexai.preview.reasoning_engines.templates.adk.get_adk_version" + "vertexai.preview.reasoning_engines.templates.adk.get_adk_version" ) as adk_version_mock: + adk_version_mock.return_value = "1.0.0" yield adk_version_mock @pytest.fixture(autouse=True) def get_project_id_mock(): with mock.patch( - "google.cloud.aiplatform.aiplatform.utils.resource_manager_utils.get_project_id" + "google.cloud.aiplatform.utils.resource_manager_utils.get_project_id" ) as get_project_id_mock: get_project_id_mock.return_value = _TEST_PROJECT_ID with mock.patch.object(initializer.global_config, "_project", _TEST_PROJECT): with mock.patch( - "google.cloud.aiplatform.vertexai.preview.reasoning_engines.templates.adk.AdkApp._warn_if_telemetry_api_disabled", + "vertexai.preview.reasoning_engines.templates.adk.AdkApp._warn_if_telemetry_api_disabled", return_value=None, ): yield get_project_id_mock @@ -355,7 +356,7 @@ async def run_live(self, *args, **kwargs): class TestAdkApp: def test_adk_version(self): with mock.patch( - "google.cloud.aiplatform.vertexai.preview.reasoning_engines.templates.adk.get_adk_version", + "vertexai.preview.reasoning_engines.templates.adk.get_adk_version", return_value="0.5.0", ): with pytest.raises( @@ -889,7 +890,7 @@ def test_tracing_setup( app = reasoning_engines.AdkApp(agent=_TEST_AGENT, enable_tracing=True) app._warn_if_telemetry_api_disabled = lambda: None with mock.patch( - "google.cloud.aiplatform.vertexai.agent_engines._utils.is_noop_or_proxy_tracer_provider", + "vertexai.agent_engines._utils.is_noop_or_proxy_tracer_provider", return_value=True, ): app.set_up() diff --git a/tests/unit/vertex_langchain/test_reasoning_engines.py b/tests/unit/vertex_langchain/test_reasoning_engines.py index 019dc214d9..e5b10a14bd 100644 --- a/tests/unit/vertex_langchain/test_reasoning_engines.py +++ b/tests/unit/vertex_langchain/test_reasoning_engines.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import cloudpickle import dataclasses import datetime import difflib @@ -534,10 +533,12 @@ def tarfile_open_mock(): @pytest.fixture(scope="module") -def cloudpickle_dump_mock(): - with mock.patch.object(cloudpickle, "dump") as cloudpickle_dump_mock: - cloudpickle_dump_mock.return_value = None - yield cloudpickle_dump_mock +def upload_reasoning_engine_mock(): + with mock.patch.object( + _reasoning_engines, "_upload_reasoning_engine" + ) as upload_reasoning_engine_mock: + upload_reasoning_engine_mock.return_value = None + yield upload_reasoning_engine_mock @pytest.fixture(scope="module") @@ -704,7 +705,7 @@ def set_up(self): pass -@pytest.mark.usefixtures("google_auth_mock") +@pytest.mark.usefixtures("google_auth_mock", "upload_reasoning_engine_mock") class TestReasoningEngine: def setup_method(self): importlib.reload(initializer) @@ -723,7 +724,7 @@ def teardown_method(self): def test_prepare_with_unspecified_extra_packages( self, cloud_storage_create_bucket_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, ): with mock.patch.object( _reasoning_engines, @@ -743,7 +744,7 @@ def test_prepare_with_unspecified_extra_packages( def test_prepare_with_empty_extra_packages( self, cloud_storage_create_bucket_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, ): with mock.patch.object( _reasoning_engines, @@ -775,7 +776,7 @@ def test_create_reasoning_engine( create_reasoning_engine_mock, cloud_storage_create_bucket_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_reasoning_engine_mock, get_gca_resource_mock, ): @@ -801,7 +802,7 @@ def test_create_reasoning_engine_warn_resource_name( create_reasoning_engine_mock, cloud_storage_create_bucket_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_reasoning_engine_mock, ): reasoning_engines.ReasoningEngine.create( @@ -820,7 +821,7 @@ def test_create_reasoning_engine_warn_sys_version( create_reasoning_engine_mock, cloud_storage_create_bucket_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_reasoning_engine_mock, ): sys_version = f"{sys.version_info.major}.{sys.version_info.minor}" @@ -838,7 +839,7 @@ def test_create_reasoning_engine_requirements_from_file( create_reasoning_engine_mock, cloud_storage_create_bucket_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_reasoning_engine_mock, get_gca_resource_mock, ): @@ -999,7 +1000,7 @@ def test_update_reasoning_engine( want_request, update_reasoning_engine_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_gca_resource_mock, ): test_reasoning_engine = _generate_reasoning_engine_to_update() @@ -1016,7 +1017,7 @@ def test_update_reasoning_engine_warn_sys_version( update_reasoning_engine_mock, cloud_storage_create_bucket_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_gca_resource_mock, ): test_reasoning_engine = _generate_reasoning_engine_to_update() @@ -1032,7 +1033,7 @@ def test_update_reasoning_engine_requirements_from_file( self, update_reasoning_engine_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_gca_resource_mock, unregister_api_methods_mock, ): @@ -1072,7 +1073,7 @@ def test_delete_after_create_reasoning_engine( create_reasoning_engine_mock, cloud_storage_get_bucket_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_reasoning_engine_mock, delete_reasoning_engine_mock, get_gca_resource_mock, @@ -1713,7 +1714,7 @@ def test_stream_query_reasoning_engine_with_operation_schema( ) -@pytest.mark.usefixtures("google_auth_mock") +@pytest.mark.usefixtures("google_auth_mock", "upload_reasoning_engine_mock") class TestReasoningEngineErrors: def setup_method(self): importlib.reload(initializer) @@ -1731,7 +1732,7 @@ def test_create_reasoning_engine_unspecified_staging_bucket( create_reasoning_engine_mock, cloud_storage_create_bucket_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_reasoning_engine_mock, ): with pytest.raises( @@ -1762,7 +1763,7 @@ def test_create_reasoning_engine_no_query_method( create_reasoning_engine_mock, cloud_storage_create_bucket_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_reasoning_engine_mock, ): with pytest.raises( @@ -1783,7 +1784,7 @@ def test_create_reasoning_engine_noncallable_query_attribute( create_reasoning_engine_mock, cloud_storage_create_bucket_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_reasoning_engine_mock, ): with pytest.raises( @@ -1804,7 +1805,7 @@ def test_create_reasoning_engine_unsupported_sys_version( create_reasoning_engine_mock, cloud_storage_create_bucket_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_reasoning_engine_mock, ): with pytest.raises(ValueError, match="Unsupported python version"): @@ -1820,7 +1821,7 @@ def test_create_reasoning_engine_requirements_ioerror( create_reasoning_engine_mock, cloud_storage_create_bucket_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_reasoning_engine_mock, ): with pytest.raises(IOError, match="Failed to read requirements"): @@ -1835,7 +1836,7 @@ def test_create_reasoning_engine_nonexistent_extra_packages( create_reasoning_engine_mock, cloud_storage_create_bucket_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_reasoning_engine_mock, ): with pytest.raises(FileNotFoundError, match="not found"): @@ -1851,7 +1852,7 @@ def test_create_reasoning_engine_with_invalid_query_method( create_reasoning_engine_mock, cloud_storage_create_bucket_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_reasoning_engine_mock, ): with pytest.raises(ValueError, match="Invalid query signature"): @@ -1866,7 +1867,7 @@ def test_create_reasoning_engine_with_invalid_stream_query_method( create_reasoning_engine_mock, cloud_storage_create_bucket_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_reasoning_engine_mock, ): with pytest.raises(ValueError, match="Invalid stream_query signature"): @@ -1881,7 +1882,7 @@ def test_create_reasoning_engine_with_invalid_register_operations_method( create_reasoning_engine_mock, cloud_storage_create_bucket_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_reasoning_engine_mock, ): with pytest.raises(ValueError, match="Invalid register_operations signature"): @@ -1896,7 +1897,7 @@ def test_update_reasoning_engine_unspecified_staging_bucket( update_reasoning_engine_mock, cloud_storage_create_bucket_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, ): with pytest.raises( ValueError, @@ -1925,7 +1926,7 @@ def test_update_reasoning_engine_no_query_method( update_reasoning_engine_mock, cloud_storage_create_bucket_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_reasoning_engine_mock, ): with pytest.raises( @@ -1945,7 +1946,7 @@ def test_update_reasoning_engine_noncallable_query_attribute( update_reasoning_engine_mock, cloud_storage_create_bucket_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_reasoning_engine_mock, ): with pytest.raises( @@ -1965,7 +1966,7 @@ def test_update_reasoning_engine_requirements_ioerror( update_reasoning_engine_mock, cloud_storage_create_bucket_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_reasoning_engine_mock, ): with pytest.raises(IOError, match="Failed to read requirements"): @@ -1979,7 +1980,7 @@ def test_update_reasoning_engine_nonexistent_extra_packages( update_reasoning_engine_mock, cloud_storage_create_bucket_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_reasoning_engine_mock, ): with pytest.raises(FileNotFoundError, match="not found"): @@ -1993,7 +1994,7 @@ def test_update_reasoning_engine_with_invalid_query_method( update_reasoning_engine_mock, cloud_storage_create_bucket_mock, tarfile_open_mock, - cloudpickle_dump_mock, + upload_reasoning_engine_mock, get_reasoning_engine_mock, ): with pytest.raises(ValueError, match="Invalid query signature"): diff --git a/vertexai/agent_engines/_agent_engines.py b/vertexai/agent_engines/_agent_engines.py index c191e78dd5..79b4e0ca94 100644 --- a/vertexai/agent_engines/_agent_engines.py +++ b/vertexai/agent_engines/_agent_engines.py @@ -38,24 +38,22 @@ Union, ) +import httpx +import proto from google.api_core import exceptions +from google.protobuf import field_mask_pb2 + from google.cloud import storage -from google.cloud.aiplatform import base -from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import base, initializer from google.cloud.aiplatform import utils as aip_utils from google.cloud.aiplatform_v1 import types as aip_types from google.cloud.aiplatform_v1.types import reasoning_engine_service from vertexai.agent_engines import _utils -import httpx -import proto - -from google.protobuf import field_mask_pb2 - _LOGGER = _utils.LOGGER _SUPPORTED_PYTHON_VERSIONS = ("3.10", "3.11", "3.12", "3.13", "3.14") _DEFAULT_GCS_DIR_NAME = "agent_engine" -_BLOB_FILENAME = "agent_engine.pkl" +_BLOB_FILENAME = "agent_engine.msgpack" _REQUIREMENTS_FILE = "requirements.txt" _EXTRA_PACKAGES_FILE = "dependencies.tar.gz" _STANDARD_API_MODE = "" @@ -117,12 +115,14 @@ ADKAgent = None try: + from a2a.client import ClientConfig, ClientFactory from a2a.types import ( AgentCard, AgentInterface, Message, TaskIdParams, TaskQueryParams, + TransportProtocol, ) from a2a.utils.constants import TransportProtocol, PROTOCOL_VERSION_CURRENT from a2a.client import ClientConfig, ClientFactory @@ -1214,30 +1214,52 @@ def _upload_agent_engine( logger: base.Logger = _LOGGER, ) -> None: """Uploads the agent engine to GCS.""" - cloudpickle = _utils._import_cloudpickle_or_raise() + import msgpack + + from google.cloud.aiplatform.utils import security_utils + blob = gcs_bucket.blob(f"{gcs_dir_name}/{_BLOB_FILENAME}") - with blob.open("wb") as f: - try: - cloudpickle.dump(agent_engine, f) - except Exception as e: - url = "https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/develop/custom#deployment-considerations" - error_msg = f"Failed to serialize agent engine. Visit {url} for details." - if "google._upb._message" in str(e) or "Descriptor" in str(e): - error_msg += ( - " This is often caused by protobuf objects (like Part, AgentCard) " - "being imported at the global module level. Please move these " - "imports inside the functions or methods where they are used. " - "Alternatively, you can import the entire module: " - "`from a2a import types as a2a_types`." - ) - raise TypeError(error_msg) from e - with blob.open("rb") as f: - try: - _ = cloudpickle.load(f) - except Exception as e: - raise TypeError("Agent engine serialized to an invalid format") from e + # Prepare common state structure + if isinstance(agent_engine, ModuleAgent): + state = { + "type": "ModuleAgent", + "params": agent_engine._tmpl_attrs, + "agent_framework": agent_engine.agent_framework, + } + else: + # Generic object - only data allowed via msgpack + state = { + "type": "CustomObject", + "data": agent_engine, + } + + try: + packed_data = msgpack.packb(state, use_bin_type=True) + # Apply Digital Signature (HMAC) + signed_data = security_utils.sign_blob(packed_data) + + blob.upload_from_string(signed_data) + except Exception as e: + url = "https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/develop/custom#deployment-considerations" + raise TypeError( + f"Failed to serialize agent engine to secure msgpack format. " + f"Dynamic logic (lambdas, live classes) is no longer supported. " + f"Visit {url} for migration details." + ) from e + + # Verification round-trip + try: + downloaded_blob = blob.download_as_bytes() + # Verify Signature + verified_data = security_utils.verify_blob(downloaded_blob) + # Unpack + _ = msgpack.unpackb(verified_data, raw=False) + except Exception as e: + raise TypeError( + "Agent engine integrity verification failed after upload." + ) from e dir_name = f"gs://{gcs_bucket.name}/{gcs_dir_name}" - logger.info(f"Wrote to {dir_name}/{_BLOB_FILENAME}") + logger.info(f"Wrote signed msgpack to {dir_name}/{_BLOB_FILENAME}") def _upload_requirements( diff --git a/vertexai/agent_engines/_utils.py b/vertexai/agent_engines/_utils.py index f7c359c93d..f6a9120dfe 100644 --- a/vertexai/agent_engines/_utils.py +++ b/vertexai/agent_engines/_utils.py @@ -20,6 +20,7 @@ import sys import types import typing +from importlib import metadata as importlib_metadata from typing import ( Any, Callable, @@ -33,14 +34,12 @@ TypedDict, Union, ) -from importlib import metadata as importlib_metadata import proto +from google.api import httpbody_pb2 +from google.protobuf import json_format, struct_pb2 from google.cloud.aiplatform import base -from google.api import httpbody_pb2 -from google.protobuf import struct_pb2 -from google.protobuf import json_format try: # For LangChain templates, they might not import langchain_core and get @@ -119,7 +118,7 @@ class _RequirementsValidationResult(TypedDict): LOGGER = base.Logger("vertexai.agent_engines") _BASE_MODULES = set(_BUILTIN_MODULE_NAMES + tuple(_STDLIB_MODULE_NAMES)) -_DEFAULT_REQUIRED_PACKAGES = frozenset(["cloudpickle", "pydantic"]) +_DEFAULT_REQUIRED_PACKAGES = frozenset(["msgpack", "pydantic"]) _ACTIONS_KEY = "actions" _ACTION_APPEND = "append" _WARNINGS_KEY = "warnings" @@ -654,16 +653,16 @@ def _import_cloud_storage_or_raise() -> types.ModuleType: return storage -def _import_cloudpickle_or_raise() -> types.ModuleType: - """Tries to import the cloudpickle module.""" +def _import_msgpack_or_raise() -> types.ModuleType: + """Tries to import the msgpack module.""" try: - import cloudpickle # noqa:F401 + import msgpack # noqa:F401 except ImportError as e: raise ImportError( - "cloudpickle is not installed. Please call " + "msgpack is not installed. Please call " "'pip install google-cloud-aiplatform[agent_engines]'." ) from e - return cloudpickle + return msgpack def _import_pydantic_or_raise() -> types.ModuleType: diff --git a/vertexai/reasoning_engines/_reasoning_engines.py b/vertexai/reasoning_engines/_reasoning_engines.py index 7d94cda0bc..55aec89208 100644 --- a/vertexai/reasoning_engines/_reasoning_engines.py +++ b/vertexai/reasoning_engines/_reasoning_engines.py @@ -35,22 +35,20 @@ ) import proto - from google.api_core import exceptions +from google.protobuf import field_mask_pb2 + from google.cloud import storage -from google.cloud.aiplatform import base -from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import base, initializer from google.cloud.aiplatform import utils as aip_utils from google.cloud.aiplatform_v1beta1 import types as aip_types from google.cloud.aiplatform_v1beta1.types import reasoning_engine_service from vertexai.reasoning_engines import _utils -from google.protobuf import field_mask_pb2 - _LOGGER = base.Logger(__name__) _SUPPORTED_PYTHON_VERSIONS = ("3.10", "3.11", "3.12", "3.13", "3.14") _DEFAULT_GCS_DIR_NAME = "reasoning_engine" -_BLOB_FILENAME = "reasoning_engine.pkl" +_BLOB_FILENAME = "reasoning_engine.msgpack" _REQUIREMENTS_FILE = "requirements.txt" _EXTRA_PACKAGES_FILE = "dependencies.tar.gz" _STANDARD_API_MODE = "" @@ -640,12 +638,42 @@ def _upload_reasoning_engine( gcs_dir_name: str, ) -> None: """Uploads the reasoning engine to GCS.""" - cloudpickle = _utils._import_cloudpickle_or_raise() + import msgpack + + from google.cloud.aiplatform.utils import security_utils + blob = gcs_bucket.blob(f"{gcs_dir_name}/{_BLOB_FILENAME}") - with blob.open("wb") as f: - cloudpickle.dump(reasoning_engine, f) + + # Reasoning Engines are typically custom classes. + # We only allow data-serializable states. + state = { + "type": "ReasoningEngine", + "data": reasoning_engine, + } + + try: + packed_data = msgpack.packb(state, use_bin_type=True) + # Apply Digital Signature (HMAC) + signed_data = security_utils.sign_blob(packed_data) + blob.upload_from_string(signed_data) + except Exception as e: + raise TypeError( + "Failed to serialize reasoning engine to secure msgpack format. " + "Executable code (lambdas, classes) is no longer supported for remote deployment." + ) from e + + # Verification round-trip + try: + downloaded_blob = blob.download_as_bytes() + verified_data = security_utils.verify_blob(downloaded_blob) + _ = msgpack.unpackb(verified_data, raw=False) + except Exception as e: + raise TypeError( + "Reasoning engine integrity verification failed after upload." + ) from e + dir_name = f"gs://{gcs_bucket.name}/{gcs_dir_name}" - _LOGGER.info(f"Writing to {dir_name}/{_BLOB_FILENAME}") + _LOGGER.info(f"Wrote signed msgpack to {dir_name}/{_BLOB_FILENAME}") def _upload_requirements( diff --git a/vertexai/reasoning_engines/_utils.py b/vertexai/reasoning_engines/_utils.py index dbb0938748..81b6e4d66c 100644 --- a/vertexai/reasoning_engines/_utils.py +++ b/vertexai/reasoning_engines/_utils.py @@ -18,14 +18,13 @@ import json import types import typing -from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Union +from typing import (Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Union) import proto +from google.api import httpbody_pb2 +from google.protobuf import json_format, struct_pb2 from google.cloud.aiplatform import base -from google.api import httpbody_pb2 -from google.protobuf import struct_pb2 -from google.protobuf import json_format try: # For LangChain templates, they might not import langchain_core and get @@ -38,8 +37,8 @@ RunnableConfig = Any try: - from llama_index.core.base.response import schema as llama_index_schema from llama_index.core.base.llms import types as llama_index_types + from llama_index.core.base.response import schema as llama_index_schema LlamaIndexResponse = llama_index_schema.Response LlamaIndexBaseModel = llama_index_schema.BaseModel @@ -331,16 +330,16 @@ def _import_cloud_storage_or_raise() -> types.ModuleType: return storage -def _import_cloudpickle_or_raise() -> types.ModuleType: - """Tries to import the cloudpickle module.""" +def _import_msgpack_or_raise() -> types.ModuleType: + """Tries to import the msgpack module.""" try: - import cloudpickle # noqa:F401 + import msgpack # noqa:F401 except ImportError as e: raise ImportError( - "cloudpickle is not installed. Please call " + "msgpack is not installed. Please call " "'pip install google-cloud-aiplatform[agent_engines]'." ) from e - return cloudpickle + return msgpack def _import_pydantic_or_raise() -> types.ModuleType: