Skip to content

Commit 28a07c1

Browse files
fix test skipping
Signed-off-by: Jeremy Berchtold <[email protected]>
1 parent b9ba7c5 commit 28a07c1

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

tests/jax/test_custom_call_compute.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1925,7 +1925,7 @@ def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape):
19251925
assert_allclose(prim_wgrad, ref_wgrad, dtype=bwd_dtype)
19261926
assert_allclose(prim_dbias, ref_dbias, dtype=dtype)
19271927

1928-
1928+
@pytest.mark.skipif(not is_fp4_supported, reason=fp4_unsupported_reason)
19291929
class TestFFICompatibility:
19301930

19311931
HLO_DIR = os.path.join(os.path.dirname(__file__), "ffi_hlo")
@@ -2077,7 +2077,6 @@ def _make_args_based_on_input_tensor_shape_and_dtype(self, stablehlo_text: str):
20772077
parsed_args.append(jnp.ones(shape, dtype=dtype))
20782078
return parsed_args
20792079

2080-
@pytest.mark.skipif(is_fp4_supported, reason=fp4_unsupported_reason)
20812080
def test_ffi_compatibility(self, ffi_hlo_name):
20822081
"""Tests that the current FFI bindings are compatible with the provided HLO and there are no API mismatches."""
20832082
from jax.extend.backend import get_backend

0 commit comments

Comments
 (0)