From ee87dafab396a36bc718b47f450ef2477e3596dc Mon Sep 17 00:00:00 2001 From: tdophung Date: Thu, 5 Mar 2026 17:27:16 -0800 Subject: [PATCH 1/2] add guard at bisected jax version where lower is segfault Signed-off-by: tdophung --- transformer_engine/jax/triton_extensions/utils.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index 2627a08929..ff23a263e9 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -162,6 +162,21 @@ def _check_triton_compatibility(): "If you don't need Triton, use transformer_engine.jax.cpp_extensions instead." ) from e +# Minimum jaxlib version required for Triton kernel dispatch to work correctly. +# jaxlib < 0.8.0.dev20250924 segfaults in pxla.py during Triton kernel execution +# (bisected: last known segfault = jax-2025-09-23, first known pass = jax-2025-09-24). +_JAXLIB_MIN_VERSION = "0.8.0.dev20250924" +import jaxlib # noqa: E402 + +if version.parse(jaxlib.__version__) < version.parse(_JAXLIB_MIN_VERSION): + raise RuntimeError( + f"jaxlib {jaxlib.__version__} is too old for transformer_engine.jax.triton_extensions.\n" + f"Triton kernel dispatch segfaults with jaxlib < {_JAXLIB_MIN_VERSION}.\n" + f"Please upgrade: pip install --upgrade jax jaxlib\n" + f"Or use a JAX nightly container dated 2025-09-24 or later.\n" + f"If you don't need Triton, use transformer_engine.jax.cpp_extensions instead." + ) + __all__ = ["triton_call_lowering", "get_triton_info"] From 14666c021fecfa5c89baccbba46ea25c9f99e46c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Mar 2026 01:35:38 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/triton_extensions/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index ff23a263e9..18aa09c6cd 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -172,9 +172,9 @@ def _check_triton_compatibility(): raise RuntimeError( f"jaxlib {jaxlib.__version__} is too old for transformer_engine.jax.triton_extensions.\n" f"Triton kernel dispatch segfaults with jaxlib < {_JAXLIB_MIN_VERSION}.\n" - f"Please upgrade: pip install --upgrade jax jaxlib\n" - f"Or use a JAX nightly container dated 2025-09-24 or later.\n" - f"If you don't need Triton, use transformer_engine.jax.cpp_extensions instead." + "Please upgrade: pip install --upgrade jax jaxlib\n" + "Or use a JAX nightly container dated 2025-09-24 or later.\n" + "If you don't need Triton, use transformer_engine.jax.cpp_extensions instead." )