Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion tests/unit/sharding_compare_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def compare_sharding_jsons(json1: dict, model1_name: str, json2: dict, model2_na


@pytest.mark.parametrize("model_name, topology, num_slice", TEST_CASES)
@pytest.mark.tpu_only
def test_sharding_dump_for_model(model_name: str, topology: str, num_slice: str) -> None:
"""
Test sharding configurations from train_compile.get_shaped_inputs.
Expand Down Expand Up @@ -193,7 +194,14 @@ def abstract_state_and_shardings(request):
config = pyconfig.initialize(params)
validate_config(config)

topology_mesh = get_topology_mesh(config)
try:
topology_mesh = get_topology_mesh(config)
except RuntimeError as e:
if "JAX TPU support not installed" in str(e):
pytest.skip("Skipping test because JAX TPU support is not installed.")
else:
raise e

quant = quantizations.configure_quantization(config)
model = Transformer(config, mesh=topology_mesh, quant=quant)

Expand Down
12 changes: 11 additions & 1 deletion tests/unit/train_compile_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,24 @@
pytestmark = [pytest.mark.external_training]


def run_train_compile_main(args):
try:
train_compile_main(args)
except RuntimeError as e:
if "JAX TPU support not installed" in str(e):
pytest.skip("Skipping test because JAX TPU support is not installed.")
else:
raise e


class TrainCompile(unittest.TestCase):
"""Tests for the Ahead of Time Compilation functionality, train_compile.py"""

@pytest.mark.cpu_only
def test_save_compiled_v4(self):
temp_dir = gettempdir()
compiled_trainstep_file = os.path.join(temp_dir, "test_compiled_v4.pickle")
train_compile_main(
run_train_compile_main(
(
"",
get_test_config_path(),
Expand Down
5 changes: 4 additions & 1 deletion tests/utils/sharding_dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,10 @@ def main(argv: Sequence[str]) -> None:
optimizers.get_optimizer(config, learning_rate_schedule)
shaped_train_args, _, state_mesh_shardings, logical_annotations, _ = get_shaped_inputs(topology_mesh, config)
except Exception as e: # pylint: disable=broad-except
print(f"Error generating inputs: {e}")
if "JAX TPU support not installed" in str(e):
print(f"Error generating inputs: {e} (You may need to install JAX with TPU support)")
else:
print(f"Error generating inputs: {e}")
return

if not state_mesh_shardings:
Expand Down
Loading