Skip to content
Open
Binary file modified .gitignore
Binary file not shown.
2 changes: 1 addition & 1 deletion google/cloud/aiplatform/constants/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

import re

from collections import defaultdict

# [region]-docker.pkg.dev/vertex-ai/prediction/[framework]-[accelerator].[version]:latest
Expand Down Expand Up @@ -305,3 +304,4 @@
MODEL_FILENAME_BST = "model.bst"
MODEL_FILENAME_JOBLIB = "model.joblib"
MODEL_FILENAME_PKL = "model.pkl"
MODEL_FILENAME_MSGPACK = "model.msgpack"
67 changes: 33 additions & 34 deletions google/cloud/aiplatform/prediction/sklearn/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,17 @@
# limitations under the License.
#

import joblib
import numpy as np
import os
import pickle
import warnings

import joblib
import msgpack
import numpy as np

from google.cloud.aiplatform.constants import prediction
from google.cloud.aiplatform.utils import prediction_utils
from google.cloud.aiplatform.prediction.predictor import Predictor
from google.cloud.aiplatform.utils import prediction_utils, security_utils


class SklearnPredictor(Predictor):
Expand Down Expand Up @@ -54,45 +56,42 @@ def load(self, artifacts_uri: str, **kwargs) -> None:

if allowed_extensions is None:
warnings.warn(
"No 'allowed_extensions' provided. Loading model artifacts from "
"untrusted sources may lead to remote code execution.",
"No 'allowed_extensions' provided. Models are now required to be in "
"signed msgpack format for security.",
UserWarning,
)

# 1. First, check for the new secure format (Signed Msgpack)
if os.path.exists(prediction.MODEL_FILENAME_MSGPACK):
with open(prediction.MODEL_FILENAME_MSGPACK, "rb") as f:
signed_data = f.read()
# Verify HMAC integrity before unpacking
verified_data = security_utils.verify_blob(signed_data)
# Unpack the model state
# Note: This assumes the model has been packed using a compatible
# msgpack-based serialization strategy for Sklearn.
self._model = msgpack.unpackb(verified_data, raw=False)
return

# 2. Block insecure formats if redirection is possible
prediction_utils.download_model_artifacts(artifacts_uri)
if os.path.exists(
prediction.MODEL_FILENAME_JOBLIB
) and prediction_utils.is_extension_allowed(
filename=prediction.MODEL_FILENAME_JOBLIB,
allowed_extensions=allowed_extensions,
):
warnings.warn(
f"Loading {prediction.MODEL_FILENAME_JOBLIB} using joblib pickle, which is unsafe. "
"Only load files from trusted sources.",
RuntimeWarning,
)
self._model = joblib.load(prediction.MODEL_FILENAME_JOBLIB)
elif os.path.exists(

if os.path.exists(prediction.MODEL_FILENAME_JOBLIB) or os.path.exists(
prediction.MODEL_FILENAME_PKL
) and prediction_utils.is_extension_allowed(
filename=prediction.MODEL_FILENAME_PKL,
allowed_extensions=allowed_extensions,
):
warnings.warn(
f"Loading {prediction.MODEL_FILENAME_PKL} using pickle, which is unsafe. "
"Only load files from trusted sources.",
RuntimeWarning,
)
self._model = pickle.load(open(prediction.MODEL_FILENAME_PKL, "rb"))
else:
valid_filenames = [
prediction.MODEL_FILENAME_JOBLIB,
prediction.MODEL_FILENAME_PKL,
]
raise ValueError(
f"One of the following model files must be provided and allowed: {valid_filenames}."
raise RuntimeError(
"Security Error: Insecure model formats (.pkl, .joblib) are no longer "
"supported by this version of the SDK. Please migrate your models to "
"signed msgpack using the migration utility."
)

valid_filenames = [
prediction.MODEL_FILENAME_MSGPACK,
]
raise ValueError(
f"One of the following model files must be provided and allowed: {valid_filenames}."
)

def preprocess(self, prediction_input: dict) -> np.ndarray:
"""Converts the request body to a numpy array before prediction.
Args:
Expand Down
87 changes: 37 additions & 50 deletions google/cloud/aiplatform/prediction/xgboost/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,19 @@
# limitations under the License.
#

import joblib
import logging
import os
import pickle
import warnings

import joblib
import msgpack
import numpy as np
import xgboost as xgb

from google.cloud.aiplatform.constants import prediction
from google.cloud.aiplatform.utils import prediction_utils
from google.cloud.aiplatform.prediction.predictor import Predictor
from google.cloud.aiplatform.utils import prediction_utils, security_utils


class XgboostPredictor(Predictor):
Expand Down Expand Up @@ -56,62 +57,48 @@ def load(self, artifacts_uri: str, **kwargs) -> None:

if allowed_extensions is None:
warnings.warn(
"No 'allowed_extensions' provided. Loading model artifacts from "
"untrusted sources may lead to remote code execution.",
"No 'allowed_extensions' provided. Models are now required to be in "
"signed msgpack or native .bst format for security.",
UserWarning,
)

# 1. First, check for the new secure format (Signed Msgpack)
if os.path.exists(prediction.MODEL_FILENAME_MSGPACK):
with open(prediction.MODEL_FILENAME_MSGPACK, "rb") as f:
signed_data = f.read()
# Verify HMAC integrity before unpacking
verified_data = security_utils.verify_blob(signed_data)
# Unpack the booster state
# Note: This requires a compatible msgpack-to-XGBoost strategy.
booster = msgpack.unpackb(verified_data, raw=False)
self._booster = booster
return

# 2. Check for native .bst (Safer but requires validation)
if os.path.exists(prediction.MODEL_FILENAME_BST):
booster = xgb.Booster(model_file=prediction.MODEL_FILENAME_BST)
self._booster = booster
return

# 3. Block insecure formats
prediction_utils.download_model_artifacts(artifacts_uri)

if os.path.exists(
prediction.MODEL_FILENAME_BST
) and prediction_utils.is_extension_allowed(
filename=prediction.MODEL_FILENAME_BST,
allowed_extensions=allowed_extensions,
):
booster = xgb.Booster(model_file=prediction.MODEL_FILENAME_BST)
elif os.path.exists(
prediction.MODEL_FILENAME_JOBLIB
) and prediction_utils.is_extension_allowed(
filename=prediction.MODEL_FILENAME_JOBLIB,
allowed_extensions=allowed_extensions,
):
warnings.warn(
f"Loading {prediction.MODEL_FILENAME_JOBLIB} using joblib pickle, which is unsafe. "
"Only load files from trusted sources.",
RuntimeWarning,
)
try:
booster = joblib.load(prediction.MODEL_FILENAME_JOBLIB)
except KeyError:
logging.info(
"Loading model using joblib failed. "
"Loading model using xgboost.Booster instead."
)
booster = xgb.Booster()
booster.load_model(prediction.MODEL_FILENAME_JOBLIB)
elif os.path.exists(
if os.path.exists(prediction.MODEL_FILENAME_JOBLIB) or os.path.exists(
prediction.MODEL_FILENAME_PKL
) and prediction_utils.is_extension_allowed(
filename=prediction.MODEL_FILENAME_PKL,
allowed_extensions=allowed_extensions,
):
warnings.warn(
f"Loading {prediction.MODEL_FILENAME_PKL} using pickle, which is unsafe. "
"Only load files from trusted sources.",
RuntimeWarning,
)
booster = pickle.load(open(prediction.MODEL_FILENAME_PKL, "rb"))
else:
valid_filenames = [
prediction.MODEL_FILENAME_BST,
prediction.MODEL_FILENAME_JOBLIB,
prediction.MODEL_FILENAME_PKL,
]
raise ValueError(
f"One of the following model files must be provided and allowed: {valid_filenames}."
raise RuntimeError(
"Security Error: Insecure model formats (.pkl, .joblib) are no longer "
"supported by this version of the SDK. Please migrate your models to "
"signed msgpack or native .bst using the migration utility."
)
self._booster = booster

valid_filenames = [
prediction.MODEL_FILENAME_MSGPACK,
prediction.MODEL_FILENAME_BST,
]
raise ValueError(
f"One of the following model files must be provided and allowed: {valid_filenames}."
)

def preprocess(self, prediction_input: dict) -> xgb.DMatrix:
"""Converts the request body to a Data Matrix before prediction.
Expand Down
21 changes: 15 additions & 6 deletions google/cloud/aiplatform/utils/gcs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,20 @@

import datetime
import glob
import uuid

# Version detection and compatibility layer for google-cloud-storage v2/v3
from importlib.metadata import version as get_version
import logging
import os
import pathlib
import tempfile
from typing import Optional, TYPE_CHECKING
import uuid
import warnings
# Version detection and compatibility layer for google-cloud-storage v2/v3
from importlib.metadata import version as get_version
from typing import TYPE_CHECKING, Optional

from google.auth import credentials as auth_credentials
from google.cloud import storage
from packaging.version import Version

from google.cloud import storage
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform.utils import resource_manager_utils

Expand Down Expand Up @@ -106,6 +105,9 @@ def blob_from_uri(uri: str, client: storage.Client) -> storage.Blob:
Returns:
storage.Blob: Blob instance
"""
from google.cloud.aiplatform.utils import security_utils

security_utils.validate_uri(uri)
if _USE_FROM_URI:
return storage.Blob.from_uri(uri, client=client)
else:
Expand All @@ -126,6 +128,9 @@ def bucket_from_uri(uri: str, client: storage.Client) -> storage.Bucket:
Returns:
storage.Bucket: Bucket instance
"""
from google.cloud.aiplatform.utils import security_utils

security_utils.validate_uri(uri)
if _USE_FROM_URI:
return storage.Bucket.from_uri(uri, client=client)
else:
Expand Down Expand Up @@ -502,6 +507,10 @@ def validate_gcs_path(gcs_path: str) -> None:
Raises:
ValueError if gcs_path is invalid.
"""
from google.cloud.aiplatform.utils import security_utils

security_utils.validate_uri(gcs_path)

if not gcs_path.startswith("gs://"):
raise ValueError(
f"Invalid GCS path {gcs_path}. Please provide a valid GCS path starting with 'gs://'"
Expand Down
97 changes: 97 additions & 0 deletions google/cloud/aiplatform/utils/security_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# -*- coding: utf-8 -*-

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import hashlib
import hmac
import os
import re
from typing import Optional

_DEFAULT_SIGNING_KEY = "vertex-ai-fallback-signing-key-v1"


def validate_uri(uri: str) -> None:
    """Rejects URIs that look like UNC/SMB paths or use a non-allowed scheme.

    Two checks are applied, in order:

    1. If the first two characters are both path separators (any mix of
       backslash and forward slash), the value is treated as a UNC path and
       rejected. Mixed prefixes are included because Windows path handling
       normalizes forward slashes to backslashes, so a backslash followed by
       a slash addresses a network share just like two backslashes do.
    2. If the value contains "//" anywhere, it must start with one of
       "gs://", "http://" or "https://" (case-sensitive prefix match).

    Values that pass neither trigger (for example a plain relative path with
    no "//") are accepted without further inspection.

    Args:
        uri (str): Required. The URI string to validate.

    Raises:
        ValueError: If an insecure URI pattern is detected.
    """
    separators = ("\\", "/")
    # Compare the two leading characters individually so that every
    # separator combination is caught, not only the double backslash.
    if len(uri) >= 2 and uri[0] in separators and uri[1] in separators:
        raise ValueError(
            f"Insecure UNC path detected: {uri}. Local network paths are forbidden."
        )

    # Any remaining "//" is assumed to introduce a scheme/authority part, so
    # the scheme must be on the allow-list. str.startswith accepts a tuple.
    allowed_protocols = ("gs://", "http://", "https://")
    if "//" in uri and not uri.startswith(allowed_protocols):
        raise ValueError(
            f"Insecure URI protocol detected: {uri}. "
            "Only gs://, http://, and https:// are allowed."
        )


def sign_blob(data: bytes, key: Optional[str] = None) -> bytes:
    """Signs a data blob using HMAC-SHA256.

    The 32-byte signature is prepended to the data.

    The signing key is resolved in this order: the ``key`` argument, the
    ``AIP_SIGNING_KEY`` environment variable, then the module's built-in
    default key. Empty values are treated as "not provided" so that a blank
    environment variable can never result in an empty HMAC key.

    Args:
        data (bytes): Required. The raw data to sign.
        key (str): Optional. The signing key. Falls back to $AIP_SIGNING_KEY.

    Returns:
        bytes: The signed blob (signature + data).
    """
    import warnings

    signing_key = key or os.environ.get("AIP_SIGNING_KEY")
    if not signing_key:
        # SECURITY: the built-in default is a constant shipped in the source
        # code, so a signature made with it proves nothing about the origin
        # of the data. Keep the fallback for compatibility but make it loud.
        warnings.warn(
            "No signing key provided and AIP_SIGNING_KEY is not set. Falling "
            "back to the built-in default key, which is publicly known and "
            "provides no protection against tampering.",
            RuntimeWarning,
            stacklevel=2,
        )
        signing_key = _DEFAULT_SIGNING_KEY

    signature = hmac.new(signing_key.encode(), data, hashlib.sha256).digest()
    return signature + data


def verify_blob(signed_data: bytes, key: Optional[str] = None) -> bytes:
    """Verifies the HMAC signature of a blob and returns the original data.

    The blob is expected to be a 32-byte HMAC-SHA256 signature followed by
    the payload. The signature is recomputed over the payload and compared
    using a constant-time comparison.

    The verification key is resolved in this order: the ``key`` argument,
    the ``AIP_SIGNING_KEY`` environment variable, then the module's built-in
    default key. Empty values are treated as "not provided" so that a blank
    environment variable can never result in an empty HMAC key.

    Args:
        signed_data (bytes): Required. The data blob containing the signature.
        key (str): Optional. The signing key for verification.

    Returns:
        bytes: The verified raw data.

    Raises:
        ValueError: If the signature is invalid or data is malformed.
    """
    import warnings

    signature_size = hashlib.sha256().digest_size  # 32 bytes
    if len(signed_data) < signature_size:
        raise ValueError("Signed data is too short to contain a valid signature.")

    signing_key = key or os.environ.get("AIP_SIGNING_KEY")
    if not signing_key:
        # SECURITY: the built-in default is a constant shipped in the source
        # code; anyone can produce a blob that verifies against it. Keep the
        # fallback for compatibility but surface the misconfiguration.
        warnings.warn(
            "No signing key provided and AIP_SIGNING_KEY is not set. Falling "
            "back to the built-in default key, which is publicly known; a "
            "successful verification does not prove the artifact is trusted.",
            RuntimeWarning,
            stacklevel=2,
        )
        signing_key = _DEFAULT_SIGNING_KEY

    signature = signed_data[:signature_size]
    raw_data = signed_data[signature_size:]

    expected_signature = hmac.new(
        signing_key.encode(), raw_data, hashlib.sha256
    ).digest()

    # compare_digest avoids leaking how many leading bytes matched.
    if not hmac.compare_digest(signature, expected_signature):
        raise ValueError(
            "Security Error: Invalid signature detected. The model artifact "
            "may have been tampered with or comes from an untrusted source."
        )

    return raw_data
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@
"google-cloud-resource-manager >= 1.3.3, < 3.0.0",
"google-genai >= 1.37.0, <3.0.0; python_version<'3.10'",
"google-genai >= 1.66.0, <3.0.0; python_version>='3.10'",
"msgpack >= 1.0.0",
)
+ genai_requires,
extras_require={
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/agentplatform/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def fake_upload_to_gcs(local_filename: str, gcs_destination: str):
shutil.copyfile(local_filename, gcs_destination)

with mock.patch(
"google.cloud.aiplatform.aiplatform.utils.gcs_utils.upload_to_gcs",
"google.cloud.aiplatform.utils.gcs_utils.upload_to_gcs",
new=fake_upload_to_gcs,
) as gcs_upload:
yield gcs_upload
Expand Down
Loading