Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2876,6 +2876,11 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
raise ValueError("At most one of `load_parameters_path` or `load_full_state_path` should be set.")
if self.elastic_enabled and not self.enable_single_controller:
raise ValueError("Elastic training is only supported with Pathways (`enable_single_controller=True`).")
if self.colocated_python_data_input and not self.enable_single_controller:
raise ValueError(
"Colocated python data input is only supported with Pathways (single"
" controller) enabled (`enable_single_controller=True`)."
)
if self.grain_use_elastic_iterator and self.grain_file_type != "arrayrecord":
raise ValueError(
"`grain_use_elastic_iterator=True` only supports `grain_file_type=arrayrecord`. "
Expand Down
13 changes: 7 additions & 6 deletions src/maxtext/input_pipeline/multihost_dataloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,7 @@ class RemoteIteratorWrapper:
"""Wrapper for RemoteIterator that handles device placement."""

def __init__(self, get_ds_fn, preprocessing_fn, global_mesh, global_shape, checkpoint_path="", elastic=False):
self.cpu_devices = _colocated_cpu_devices(jax.local_devices())
self.tpu_devices = jax.local_devices()
self.cpu_devices = _colocated_cpu_devices(tuple(global_mesh.devices.flat))
self.cpu_mesh = _colocated_cpu_mesh(global_mesh)
self.tpu_sharding = jax.sharding.NamedSharding(global_mesh, PartitionSpec(global_mesh.axis_names))
self.cpu_sharding = jax.sharding.NamedSharding(self.cpu_mesh, PartitionSpec(self.cpu_mesh.axis_names))
Expand All @@ -288,11 +287,13 @@ def __next__(self):
return jax.device_put(out, self.tpu_sharding)

def save_state(self, step):
step_array = jnp.full(self.dummy_array.shape, step, dtype=jnp.int32)
step_array = jax.device_put(step_array, self.cpu_sharding)
replicated_cpu_sharding = NamedSharding(self.cpu_mesh, PartitionSpec())
step_array = jnp.array(step, dtype=jnp.int32)
step_array = jax.device_put(step_array, replicated_cpu_sharding)
self.local_iterator.save_state(step_array)

def restore_state(self, step):
step_array = jnp.full(self.dummy_array.shape, step, dtype=jnp.int32)
step_array = jax.device_put(step_array, self.cpu_sharding)
replicated_cpu_sharding = NamedSharding(self.cpu_mesh, PartitionSpec())
step_array = jnp.array(step, dtype=jnp.int32)
step_array = jax.device_put(step_array, replicated_cpu_sharding)
self.local_iterator.restore_state(step_array)
166 changes: 148 additions & 18 deletions tests/unit/multihost_dataloading_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023–2025 Google LLC
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -14,28 +14,70 @@

# pylint: disable=missing-module-docstring, missing-function-docstring
import itertools
import json
import os
import pathlib
import sys
import unittest
import tempfile
from unittest import mock

import pytest

import numpy as np
from absl.testing import absltest
from absl.testing import parameterized

import jax
from jax.sharding import Mesh
from jax.experimental import mesh_utils

from jax.sharding import Mesh
from maxtext.configs import pyconfig
from maxtext.input_pipeline import multihost_dataloading
from tests.utils.test_helpers import get_test_config_path, get_test_dataset_path, get_test_base_output_directory
from tests.utils.test_helpers import get_test_base_output_directory
from tests.utils.test_helpers import get_test_config_path
from tests.utils.test_helpers import get_test_dataset_path
import numpy as np
import pytest


class MockIterator:
  """Minimal stateful iterator used as a test double.

  Every `next()` call increments an integer counter and returns a
  `(mesh_size, 1)` int32 NumPy array filled with the new counter value. The
  counter can be read with `get_state` and overwritten with `set_state`.
  """

  def __init__(self, mesh_size):
    # Counter starts at 0, so the first batch is filled with 1.
    self.state = 0
    self.mesh_size = mesh_size

  def __iter__(self) -> "MockIterator":
    # Iterator protocol: an iterator returns itself, which makes instances
    # usable with `iter()` and `for` loops.
    return self

  def __next__(self):
    self.state += 1
    return np.full((self.mesh_size, 1), self.state, dtype=np.int32)

  def get_state(self) -> dict[str, int]:
    """Return the current counter as `{"state": <int>}`."""
    return {"state": self.state}

  def set_state(self, state: dict[str, int]):
    """Overwrite the counter with `state["state"]`."""
    self.state = state["state"]


class MockDataloader:
  """Iterable test double; each `iter()` call yields a new `MockIterator`."""

  def __init__(self, mesh_size):
    # Passed unchanged to every iterator created by `__iter__`.
    self.mesh_size = mesh_size

  def __iter__(self) -> MockIterator:
    fresh_iterator = MockIterator(self.mesh_size)
    return fresh_iterator


def _get_test_mesh_shapes_named():
return [
("1_device", (1, 1)),
("2_devices", (2, 1)),
("4_devices", (2, 2)),
]


class MultihostDataloadingTest(unittest.TestCase):
class MultihostDataloadingTest(parameterized.TestCase):

def setUp(self):
super().setUp()
# Note: this test uses gs://max-experiments/ (not gs://runner-maxtext-logs) in cloud mode
base_output_directory = get_test_base_output_directory(cloud_path="gs://max-experiments/")
# Note: this test uses gs://max-experiments/ (not runner logs) in cloud mode
base_output_directory = get_test_base_output_directory(
cloud_path="gs://max-experiments/"
)
dataset_path = get_test_dataset_path(cloud_path="gs://maxtext-dataset/")
batch_size = len(jax.devices())
config = pyconfig.initialize(
Expand All @@ -50,20 +92,108 @@ def setUp(self):
enable_checkpointing=False,
)
mesh_shape_1d = (len(jax.devices()),)
self.mesh = Mesh(mesh_utils.create_device_mesh(mesh_shape_1d), config.mesh_axes)
# Create 2 distinct batches and cycle through them infinitely.
global_data = np.arange(batch_size * 2 * config.max_target_length, dtype=np.int32).reshape(
(batch_size * 2, config.max_target_length)
self.mesh = Mesh(
mesh_utils.create_device_mesh(mesh_shape_1d), config.mesh_axes
)
# Create 2 distinct batches and cycle through them infinitely.
global_data = np.arange(
batch_size * 2 * config.max_target_length, dtype=np.int32
).reshape((batch_size * 2, config.max_target_length))
data_batches = [global_data[:batch_size], global_data[batch_size:]]
self.multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(itertools.cycle(data_batches), self.mesh)
self.multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(
itertools.cycle(data_batches), self.mesh
)

@pytest.mark.tpu_only
def test_batch_sharded_data_pipeline(self):
  """Two consecutive batches from `self.multihost_gen` must differ."""
  first_batch = next(self.multihost_gen)
  sec_batch = next(self.multihost_gen)
  # One assertFalse is enough: the preceding `assertTrue(not ...)` performed
  # exactly the same comparison.
  self.assertFalse(np.array_equal(first_batch, sec_batch, equal_nan=True))

@parameterized.named_parameters(*_get_test_mesh_shapes_named())
def test_remote_iterator_wrapper_save_state(self, mesh_shape):
  """Checks that `save_state` writes the iterator state as one JSON file."""
  mesh_size = mesh_shape[0] * mesh_shape[1]
  if len(jax.devices()) < mesh_size:
    self.skipTest(
        f"Skipping test because available devices ({len(jax.devices())}) is"
        f" less than required mesh size ({mesh_size}) for shape {mesh_shape}."
    )

  # Build an ("x", "y") mesh over the first `mesh_size` devices.
  mesh = Mesh(
      mesh_utils.create_device_mesh(mesh_shape, jax.devices()[:mesh_size]),
      ("x", "y"),
  )

  def get_ds_fn(dataloading_host_index, dataloading_host_count):
    # The mock dataloader does not depend on the host index or count.
    del dataloading_host_index, dataloading_host_count
    return MockDataloader(mesh_size)

  def preprocessing_fn(dataset):
    # Identity preprocessing: the dataset is passed through untouched.
    return dataset

  with tempfile.TemporaryDirectory() as tmpdir:
    wrapper = multihost_dataloading.RemoteIteratorWrapper(
        get_ds_fn=get_ds_fn,
        preprocessing_fn=preprocessing_fn,
        global_mesh=mesh,
        global_shape=(mesh_size, 1),
        checkpoint_path=tmpdir,
        elastic=False,
    )
    # Pull one batch so the mock iterator's counter becomes 1.
    next(wrapper)

    wrapper.save_state(step=5)

    # Exactly one JSON file holding {"state": 1} must exist under tmpdir.
    json_files = list(pathlib.Path(tmpdir).glob("**/*.json"))
    self.assertEqual(
        len(json_files), 1, f"Expected 1 JSON file, found: {json_files}"
    )
    saved_state = json.loads(json_files[0].read_text())
    self.assertEqual(saved_state, {"state": 1})

@parameterized.named_parameters(*_get_test_mesh_shapes_named())
def test_remote_iterator_wrapper_restore_state(self, mesh_shape):
  """Checks that `restore_state` resumes from a pre-written JSON state."""
  mesh_size = mesh_shape[0] * mesh_shape[1]
  if len(jax.devices()) < mesh_size:
    self.skipTest(
        f"Skipping test because available devices ({len(jax.devices())}) is"
        f" less than required mesh size ({mesh_size}) for shape {mesh_shape}."
    )

  # Build an ("x", "y") mesh over the first `mesh_size` devices.
  mesh = Mesh(
      mesh_utils.create_device_mesh(mesh_shape, jax.devices()[:mesh_size]),
      ("x", "y"),
  )

  def get_ds_fn(dataloading_host_index, dataloading_host_count):
    # The mock dataloader does not depend on the host index or count.
    del dataloading_host_index, dataloading_host_count
    return MockDataloader(mesh_size)

  def preprocessing_fn(dataset):
    # Identity preprocessing: the dataset is passed through untouched.
    return dataset

  with tempfile.TemporaryDirectory() as tmpdir:
    step = 5
    # Seed the state file at <tmpdir>/<step>/iter/process_0-of-1.json.
    state_dir = pathlib.Path(tmpdir) / str(step) / "iter"
    state_dir.mkdir(parents=True, exist_ok=True)
    (state_dir / "process_0-of-1.json").write_text('{"state": 10}')

    wrapper = multihost_dataloading.RemoteIteratorWrapper(
        get_ds_fn=get_ds_fn,
        preprocessing_fn=preprocessing_fn,
        global_mesh=mesh,
        global_shape=(mesh_size, 1),
        checkpoint_path=tmpdir,
        elastic=False,
    )

    wrapper.restore_state(step=step)
    val = next(wrapper)

    # The restored counter is 10, so the next batch must be filled with 11.
    self.assertEqual(val.addressable_data(0)[0], 11)


if __name__ == "__main__":
  # Only absltest.main() is called: a preceding unittest.main() would exit the
  # process (it calls sys.exit by default), leaving this call unreachable.
  absltest.main()
Loading