diff --git a/src/maxtext/common/managed_mldiagnostics.py b/src/maxtext/common/managed_mldiagnostics.py index 62b39051b0..91ca91b4d0 100644 --- a/src/maxtext/common/managed_mldiagnostics.py +++ b/src/maxtext/common/managed_mldiagnostics.py @@ -13,6 +13,7 @@ # limitations under the License. """Create the managed mldiagnostics run.""" + import json from typing import Any @@ -24,16 +25,16 @@ class ManagedMLDiagnostics: - """ - ML Diagnostics Run, implemented with the Singleton pattern. + """ML Diagnostics Run, implemented with the Singleton pattern. + Ensures that only one instance of the class can exist. """ _instance = None # Class attribute to hold the single instance def __new__(cls, *args: Any, **kwargs: Any): - """ - Overrides the instance creation method. + """Overrides the instance creation method. + If an instance already exists, it is returned instead of creating a new one. """ if cls._instance is None: @@ -42,9 +43,7 @@ def __new__(cls, *args: Any, **kwargs: Any): return cls._instance def __init__(self, config): - """ - Initializes the ManagedMLDiagnostics, ensuring this method runs only once. - """ + """Initializes the ManagedMLDiagnostics, ensuring this method runs only once.""" # We need a flag to ensure __init__ only runs once, # as the object is returned multiple times by __new__. if hasattr(self, "_initialized"): @@ -67,11 +66,11 @@ def should_log_key(key, value): config_dict = {key: value for key, value in config.get_keys().items() if should_log_key(key, value)} # Create a run for the managed mldiagnostics, and upload the configuration. + region = config.managed_mldiagnostics_region if config.managed_mldiagnostics_region else None mldiag.machinelearning_run( name=f"{config.run_name}", run_group=config.managed_mldiagnostics_run_group, configs=config_dict, gcs_path=config.managed_mldiagnostics_dir, - # TODO: b/455623960 - Remove the following once multi-region and prod support are enabled. - region="us-central1", + region=region, ) diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index c973087a09..7048ee8e4b 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -839,6 +839,7 @@ tpu_num_sparse_cores_to_trace: 2 # - upload training metrics, at the defined log_period interval. managed_mldiagnostics: false # Whether to enable the managed diagnostics managed_mldiagnostics_run_group: "" # Optional. Used to group multiple runs. +managed_mldiagnostics_region: "" # Optional. GCP region for managed mldiagnostics. If empty, it will be auto-detected by the SDK. # Dump HLO and jaxpr options dump_hlo: false diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index b70b7238d3..b8f5e1dc86 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -1786,6 +1786,7 @@ class ManagedMLDiagnostics(BaseModel): managed_mldiagnostics: bool = Field(False, description="Enable managed mldiagnostics.") managed_mldiagnostics_run_group: str = Field("", description="Name used to group multiple runs.") + managed_mldiagnostics_region: str = Field("", description="GCP region for managed mldiagnostics.") class Goodput(BaseModel): diff --git a/tests/unit/managed_mldiagnostics_test.py b/tests/unit/managed_mldiagnostics_test.py new file mode 100644 index 0000000000..f3c4695e68 --- /dev/null +++ b/tests/unit/managed_mldiagnostics_test.py @@ -0,0 +1,81 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for ManagedMLDiagnostics.""" + +import unittest +from unittest import mock + +from maxtext.common.managed_mldiagnostics import ManagedMLDiagnostics +import pytest + + +@pytest.mark.cpu_only +class ManagedMLDiagnosticsTest(unittest.TestCase): + # pylint: disable=protected-access + + def setUp(self): + super().setUp() + # Reset singleton instance between tests + ManagedMLDiagnostics._instance = None + + def test_not_enabled_noop(self): + mock_config = mock.MagicMock() + mock_config.managed_mldiagnostics = False + + with mock.patch.object(ManagedMLDiagnostics.mldiag, "machinelearning_run") as mock_run: + ManagedMLDiagnostics(mock_config) + mock_run.assert_not_called() + + def test_enabled_empty_region_passes_none(self): + mock_config = mock.MagicMock() + mock_config.managed_mldiagnostics = True + mock_config.managed_mldiagnostics_region = "" + mock_config.run_name = "test_run" + mock_config.managed_mldiagnostics_run_group = "test_group" + mock_config.managed_mldiagnostics_dir = "gs://test_dir" + mock_config.get_keys.return_value = {"key1": "val1"} + + with mock.patch.object(ManagedMLDiagnostics.mldiag, "machinelearning_run") as mock_run: + ManagedMLDiagnostics(mock_config) + mock_run.assert_called_once_with( + name="test_run", + run_group="test_group", + configs={"key1": "val1"}, + gcs_path="gs://test_dir", + region=None, + ) + + def test_enabled_populated_region_passes_region(self): + mock_config = mock.MagicMock() + mock_config.managed_mldiagnostics = True + mock_config.managed_mldiagnostics_region = "us-east1" + mock_config.run_name = "test_run" + mock_config.managed_mldiagnostics_run_group = "test_group" + mock_config.managed_mldiagnostics_dir = "gs://test_dir" + mock_config.get_keys.return_value = {"key1": "val1"} + + with mock.patch.object(ManagedMLDiagnostics.mldiag, "machinelearning_run") as mock_run: + ManagedMLDiagnostics(mock_config) + mock_run.assert_called_once_with( + name="test_run", + run_group="test_group", + configs={"key1": "val1"}, + gcs_path="gs://test_dir", + region="us-east1", + ) + + +if __name__ == "__main__": + unittest.main()