From 4aedab8de241aab6de350eaafb258c2fb4d977c5 Mon Sep 17 00:00:00 2001 From: Luke Baumann Date: Wed, 10 Jun 2026 13:59:47 -0700 Subject: [PATCH] # Description This PR fixes a `ValueError: can only convert an array of size 1 to a Python scalar` that occurs in `RemoteIteratorWrapper` during state save/restore on multi-device topologies (size > 1). It also adds validation to ensure colocated Python data input is only used with Pathways (single controller) enabled, and replaces incorrect usages of `jax.local_devices()` with `global_mesh.devices`. # Root Cause 1. **ValueError in save/restore**: `RemoteIteratorWrapper.save_state` and `restore_state` were attempting to shape the step value array using `self.dummy_array.shape` and shard it across devices. On topologies with more than 1 device, this resulted in a partitioned array. When this partitioned array was passed to the local iterator, attempting to unpack it to a Python scalar (e.g. via `.item()` or direct conversion) failed because JAX does not allow converting partitioned arrays of size > 1 to Python scalars. 2. **Incorrect Device Resolution**: `RemoteIteratorWrapper` was using `jax.local_devices()` to determine CPU/TPU devices. Under Pathways (single-controller), all devices in the cluster are virtualized as local to the JAX client, meaning `jax.local_devices()` returns all devices (including inactive ones during elastic scale-down), which is incorrect for sharding and shape calculations. 3. **Missing Validation**: `colocated_python_data_input` relies on Pathways single-controller mode, but there was no validation enforcing this constraint, which could lead to cryptic failures if misconfigured. # Solution 1. **Replicated Scalar for Step**: Modified `RemoteIteratorWrapper.save_state` and `restore_state` in `multihost_dataloading.py` to pass the training step as a replicated 0D JAX scalar array (global shape `()`) with replicated sharding (`NamedSharding` with `PartitionSpec()`). 
This ensures the array has size 1 on all devices and can be safely converted to a Python scalar by the local iterator. 2. **Use Global Mesh Devices**: Replaced `jax.local_devices()` with `global_mesh.devices` (via `tuple(global_mesh.devices.flat)`) in `RemoteIteratorWrapper.__init__` to ensure it only uses the active devices defined by the global mesh, handling elastic scaling correctly. 3. **Config Validation**: Added a check in `types.py` to raise a `ValueError` if `colocated_python_data_input` is enabled but `enable_single_controller` is false. # Tests Added new unit tests in `third_party/py/maxtext/tests/unit/multihost_dataloading_test.py` to verify the fixes: 1. `test_remote_iterator_wrapper_save_state`: Parameterized over different mesh shapes (1, 2, and 4 devices). Instantiates `RemoteIteratorWrapper` and verifies that calling `save_state` successfully writes the state to a JSON file without raising `ValueError`. 2. `test_remote_iterator_wrapper_restore_state`: Parameterized over different mesh shapes. Verifies that `restore_state` successfully restores the state from a JSON file and resumes iteration correctly. These tests are configured to run with `XLA_FLAGS="--xla_force_host_platform_device_count=4"` via the `BUILD` target to simulate multi-device environments. # Checklist Before submitting this PR, please make sure (put X in square brackets): - [X] I have performed a self-review of my code. For an optional AI review, add the `gemini-review` label. - [X] I have necessary comments in my code, particularly in hard-to-understand areas. - [X] I have run end-to-end tests and provided workload links above if applicable. - [X] I have made or will make corresponding changes to the doc if needed. 
PiperOrigin-RevId: 930062734 --- src/maxtext/configs/types.py | 5 + .../input_pipeline/multihost_dataloading.py | 13 +- tests/unit/multihost_dataloading_test.py | 144 ++++++++++++++++-- 3 files changed, 145 insertions(+), 17 deletions(-) diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index cb1987eb77..36165e8045 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -2881,6 +2881,11 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de raise ValueError("At most one of `load_parameters_path` or `load_full_state_path` should be set.") if self.elastic_enabled and not self.enable_single_controller: raise ValueError("Elastic training is only supported with Pathways (`enable_single_controller=True`).") + if self.colocated_python_data_input and not self.enable_single_controller: + raise ValueError( + "Colocated python data input is only supported with Pathways (single" + " controller) enabled (`enable_single_controller=True`)." + ) if self.grain_use_elastic_iterator and self.grain_file_type != "arrayrecord": raise ValueError( "`grain_use_elastic_iterator=True` only supports `grain_file_type=arrayrecord`. 
" diff --git a/src/maxtext/input_pipeline/multihost_dataloading.py b/src/maxtext/input_pipeline/multihost_dataloading.py index 221a6ed338..9ae3d9ca34 100644 --- a/src/maxtext/input_pipeline/multihost_dataloading.py +++ b/src/maxtext/input_pipeline/multihost_dataloading.py @@ -264,8 +264,7 @@ class RemoteIteratorWrapper: """Wrapper for RemoteIterator that handles device placement.""" def __init__(self, get_ds_fn, preprocessing_fn, global_mesh, global_shape, checkpoint_path="", elastic=False): - self.cpu_devices = _colocated_cpu_devices(jax.local_devices()) - self.tpu_devices = jax.local_devices() + self.cpu_devices = _colocated_cpu_devices(tuple(global_mesh.devices.flat)) self.cpu_mesh = _colocated_cpu_mesh(global_mesh) self.tpu_sharding = jax.sharding.NamedSharding(global_mesh, PartitionSpec(global_mesh.axis_names)) self.cpu_sharding = jax.sharding.NamedSharding(self.cpu_mesh, PartitionSpec(self.cpu_mesh.axis_names)) @@ -288,11 +287,13 @@ def __next__(self): return jax.device_put(out, self.tpu_sharding) def save_state(self, step): - step_array = jnp.full(self.dummy_array.shape, step, dtype=jnp.int32) - step_array = jax.device_put(step_array, self.cpu_sharding) + replicated_cpu_sharding = NamedSharding(self.cpu_mesh, PartitionSpec()) + step_array = jnp.array(step, dtype=jnp.int32) + step_array = jax.device_put(step_array, replicated_cpu_sharding) self.local_iterator.save_state(step_array) def restore_state(self, step): - step_array = jnp.full(self.dummy_array.shape, step, dtype=jnp.int32) - step_array = jax.device_put(step_array, self.cpu_sharding) + replicated_cpu_sharding = NamedSharding(self.cpu_mesh, PartitionSpec()) + step_array = jnp.array(step, dtype=jnp.int32) + step_array = jax.device_put(step_array, replicated_cpu_sharding) self.local_iterator.restore_state(step_array) diff --git a/tests/unit/multihost_dataloading_test.py b/tests/unit/multihost_dataloading_test.py index d4a6172141..64c0c87c09 100644 --- a/tests/unit/multihost_dataloading_test.py +++ 
b/tests/unit/multihost_dataloading_test.py @@ -1,4 +1,4 @@ -# Copyright 2023–2025 Google LLC +# Copyright 2023–2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,27 +14,67 @@ # pylint: disable=missing-module-docstring, missing-function-docstring import itertools +import json + +import pathlib import sys -import unittest +import tempfile -import pytest -import numpy as np +from absl.testing import absltest +from absl.testing import parameterized import jax -from jax.sharding import Mesh from jax.experimental import mesh_utils - +from jax.sharding import Mesh from maxtext.configs import pyconfig from maxtext.input_pipeline import multihost_dataloading -from tests.utils.test_helpers import get_test_config_path, get_test_dataset_path, get_test_base_output_directory +from tests.utils.test_helpers import get_test_base_output_directory +from tests.utils.test_helpers import get_test_config_path +from tests.utils.test_helpers import get_test_dataset_path +import numpy as np +import pytest + + +class MockIterator: + + def __init__(self, mesh_size): + self.state = 0 + self.mesh_size = mesh_size + + def __next__(self): + self.state += 1 + return np.full((self.mesh_size, 1), self.state, dtype=np.int32) + + def get_state(self) -> dict[str, int]: + return {"state": self.state} + + def set_state(self, state: dict[str, int]): + self.state = state["state"] + + +class MockDataloader: + + def __init__(self, mesh_size): + self.mesh_size = mesh_size + + def __iter__(self) -> MockIterator: + return MockIterator(self.mesh_size) -class MultihostDataloadingTest(unittest.TestCase): +def _get_test_mesh_shapes_named(): + return [ + ("1_device", (1, 1)), + ("2_devices", (2, 1)), + ("4_devices", (2, 2)), + ] + + +class MultihostDataloadingTest(parameterized.TestCase): def setUp(self): super().setUp() - # Note: this test uses gs://max-experiments/ (not gs://runner-maxtext-logs) in cloud mode 
+ # Note: this test uses gs://max-experiments/ (not runner logs) in cloud mode base_output_directory = get_test_base_output_directory(cloud_path="gs://max-experiments/") dataset_path = get_test_dataset_path(cloud_path="gs://maxtext-dataset/") batch_size = len(jax.devices()) @@ -62,8 +102,90 @@ def setUp(self): def test_batch_sharded_data_pipeline(self): first_batch = next(self.multihost_gen) sec_batch = next(self.multihost_gen) - self.assertTrue(not np.array_equal(first_batch, sec_batch, equal_nan=True)) + self.assertFalse(np.array_equal(first_batch, sec_batch, equal_nan=True)) + + @parameterized.named_parameters(*_get_test_mesh_shapes_named()) + def test_remote_iterator_wrapper_save_state(self, mesh_shape): + mesh_size = mesh_shape[0] * mesh_shape[1] + if mesh_size > len(jax.devices()): + self.skipTest( + f"Skipping test because available devices ({len(jax.devices())}) is" + f" less than required mesh size ({mesh_size}) for shape {mesh_shape}." + ) + + devs = jax.devices()[:mesh_size] + devices = mesh_utils.create_device_mesh(mesh_shape, devs) + mesh = Mesh(devices, ("x", "y")) + + def get_ds_fn(dataloading_host_index, dataloading_host_count): + del dataloading_host_index, dataloading_host_count + return MockDataloader(mesh_size) + + preprocessing_fn = lambda dataset: dataset + global_shape = (mesh_size, 1) + + with tempfile.TemporaryDirectory() as tmpdir: + wrapper = multihost_dataloading.RemoteIteratorWrapper( + get_ds_fn=get_ds_fn, + preprocessing_fn=preprocessing_fn, + global_mesh=mesh, + global_shape=global_shape, + checkpoint_path=tmpdir, + elastic=False, + ) + # Advance state once so the value is 1 + next(wrapper) + + wrapper.save_state(step=5) + + # Verify that a file was written in the tempdir containing {"state": 1} + json_files = list(pathlib.Path(tmpdir).glob("**/*.json")) + self.assertEqual(len(json_files), 1, f"Expected 1 JSON file, found: {json_files}") + written_content = json_files[0].read_text() + self.assertEqual(json.loads(written_content), 
{"state": 1}) + + @parameterized.named_parameters(*_get_test_mesh_shapes_named()) + def test_remote_iterator_wrapper_restore_state(self, mesh_shape): + mesh_size = mesh_shape[0] * mesh_shape[1] + if mesh_size > len(jax.devices()): + self.skipTest( + f"Skipping test because available devices ({len(jax.devices())}) is" + f" less than required mesh size ({mesh_size}) for shape {mesh_shape}." + ) + + devs = jax.devices()[:mesh_size] + devices = mesh_utils.create_device_mesh(mesh_shape, devs) + mesh = Mesh(devices, ("x", "y")) + + def get_ds_fn(dataloading_host_index, dataloading_host_count): + del dataloading_host_index, dataloading_host_count + return MockDataloader(mesh_size) + + preprocessing_fn = lambda dataset: dataset + global_shape = (mesh_size, 1) + + with tempfile.TemporaryDirectory() as tmpdir: + step = 5 + state_dir = pathlib.Path(tmpdir) / str(step) / "iter" + state_dir.mkdir(parents=True, exist_ok=True) + state_file = state_dir / "process_0-of-1.json" + state_file.write_text('{"state": 10}') + + wrapper = multihost_dataloading.RemoteIteratorWrapper( + get_ds_fn=get_ds_fn, + preprocessing_fn=preprocessing_fn, + global_mesh=mesh, + global_shape=global_shape, + checkpoint_path=tmpdir, + elastic=False, + ) + + wrapper.restore_state(step=5) + val = next(wrapper) + + # Next value should be 11 (state 10 + 1) + self.assertEqual(val.addressable_data(0)[0], 11) if __name__ == "__main__": - unittest.main() + absltest.main()