diff --git a/tests/unit/sharding_compare_test.py b/tests/unit/sharding_compare_test.py
index 8d2d7bc7fb..5bdb30fd31 100644
--- a/tests/unit/sharding_compare_test.py
+++ b/tests/unit/sharding_compare_test.py
@@ -111,6 +111,7 @@ def compare_sharding_jsons(json1: dict, model1_name: str, json2: dict, model2_na
 
 
 @pytest.mark.parametrize("model_name, topology, num_slice", TEST_CASES)
+@pytest.mark.tpu_only
 def test_sharding_dump_for_model(model_name: str, topology: str, num_slice: str) -> None:
   """
   Test sharding configurations from train_compile.get_shaped_inputs.
@@ -193,7 +194,14 @@ def abstract_state_and_shardings(request):
   config = pyconfig.initialize(params)
   validate_config(config)
 
-  topology_mesh = get_topology_mesh(config)
+  try:
+    topology_mesh = get_topology_mesh(config)
+  except RuntimeError as e:
+    if "JAX TPU support not installed" in str(e):
+      pytest.skip("Skipping test because JAX TPU support is not installed.")
+    else:
+      raise
+
   quant = quantizations.configure_quantization(config)
   model = Transformer(config, mesh=topology_mesh, quant=quant)
 
diff --git a/tests/unit/train_compile_test.py b/tests/unit/train_compile_test.py
index e6c44533d6..0a5f6fc40f 100644
--- a/tests/unit/train_compile_test.py
+++ b/tests/unit/train_compile_test.py
@@ -31,6 +31,16 @@
 
 pytestmark = [pytest.mark.external_training]
 
 
+def run_train_compile_main(args):
+  try:
+    train_compile_main(args)
+  except RuntimeError as e:
+    if "JAX TPU support not installed" in str(e):
+      pytest.skip("Skipping test because JAX TPU support is not installed.")
+    else:
+      raise
+
+
 class TrainCompile(unittest.TestCase):
   """Tests for the Ahead of Time Compilation functionality, train_compile.py"""
@@ -38,7 +48,7 @@ class TrainCompile(unittest.TestCase):
   def test_save_compiled_v4(self):
     temp_dir = gettempdir()
     compiled_trainstep_file = os.path.join(temp_dir, "test_compiled_v4.pickle")
-    train_compile_main(
+    run_train_compile_main(
         (
             "",
             get_test_config_path(),
diff --git a/tests/utils/sharding_dump.py b/tests/utils/sharding_dump.py
index ec7b98b752..9defd5a79e 100644
--- a/tests/utils/sharding_dump.py
+++ b/tests/utils/sharding_dump.py
@@ -420,7 +420,10 @@ def main(argv: Sequence[str]) -> None:
     optimizers.get_optimizer(config, learning_rate_schedule)
     shaped_train_args, _, state_mesh_shardings, logical_annotations, _ = get_shaped_inputs(topology_mesh, config)
   except Exception as e:  # pylint: disable=broad-except
-    print(f"Error generating inputs: {e}")
+    if "JAX TPU support not installed" in str(e):
+      print(f"Error generating inputs: {e} (You may need to install JAX with TPU support)")
+    else:
+      print(f"Error generating inputs: {e}")
     return
 
   if not state_mesh_shardings: