Add guard at lowest JAX version that still supports triton kernel calling#2741
tdophung wants to merge 2 commits into NVIDIA:main
Signed-off-by: tdophung <tdophung@nvidia.com>
/te-ci jax
for more information, see https://pre-commit.ci
jberchtold-nvidia
left a comment
Awesome, LGTM pending CI, thanks!
Greptile Summary
This PR adds a backward-compatibility guard in
Confidence Score: 4/5
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[import transformer_engine.jax.triton_extensions.utils] --> B[_check_triton_compatibility]
B --> C{Triton installed\nand valid?}
C -- No --> D[raise ImportError]
C -- Yes --> E[Import gpu_triton,\ntriton.compiler, etc.]
E --> F{ImportError?}
F -- Yes --> G[raise ImportError]
F -- No --> H[import jaxlib]
H --> I{jaxlib.__version__ <\n0.8.0.dev20250924?}
I -- Yes --> J[raise RuntimeError\nversion too old - segfault risk]
I -- No --> K[Module ready:\ntriton_call_lowering,\nget_triton_info exported]
Last reviewed commit: 14666c0
# jaxlib < 0.8.0.dev20250924 segfaults in pxla.py during Triton kernel execution
# (bisected: last known segfault = jax-2025-09-23, first known pass = jax-2025-09-24).
_JAXLIB_MIN_VERSION = "0.8.0.dev20250924"
import jaxlib  # noqa: E402
Import placed mid-module when jax is already available
jax is already imported at line 42, and jax.__version__ always matches jaxlib.__version__ (they are co-released together). The mid-file import jaxlib requires a # noqa: E402 suppression and adds an extra package import for information that's already available via the existing jax import.
Alternatively, if you specifically want jaxlib.__version__ for precision, this import should be hoisted to the top-of-file import block alongside the other jax/jaxlib imports rather than sitting between executable module-level code.
- import jaxlib  # noqa: E402
+ import jaxlib  # jaxlib version guard below
(move this line to the top of the file alongside the other imports at lines 41–43)
Or, to avoid the extra import entirely:
if version.parse(jax.__version__) < version.parse(_JAXLIB_MIN_VERSION):
since jax.__version__ == jaxlib.__version__ for any properly-installed JAX environment.
There was a problem hiding this comment.
I was also about to comment about jax vs jaxlib.
Besides, we have an existing helper here as well https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/jax/quantize/helper.py#L72-L78, perhaps we should find a way to repurpose this helper, especially since we will need to have jax guard all over places in the future.
On a side note, should we check for a release version, i.e., >= 0.8.0, instead?
There was a problem hiding this comment.
Great points! Yes, reusing _jax_version_meet_requirement would be much cleaner. You could move it to a shared utility module if multiple files need it.
Regarding dev vs release version:
>= 0.8.0.dev20250924 allows: 0.8.0.dev20250924, later dev versions, and all 0.8.x releases
>= 0.8.0 allows the 0.8.0 release and anything newer (excludes all 0.8.0 dev versions)
Since dev versions sort before releases in PEP 440, using >= 0.8.0 would actually block 0.8.0.dev20250924 even though your bisection shows it works.
Recommendation: If most users will use official releases (not nightlies), >= 0.8.0 is cleaner and safer. Dev/nightly users typically expect to handle version quirks. But if you want to explicitly support those working dev versions, keep >= 0.8.0.dev20250924.
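The ordering rule behind this trade-off can be sketched in a few lines. The snippet below is only an illustration: `simple_key` and `meets` are hypothetical names, and the key handles just `X.Y.Z` and `X.Y.Z.devN` strings, not full PEP 440. Real code would more likely rely on `packaging.version` or the existing helper mentioned above.

```python
import re

# Simplified stand-in for PEP 440 ordering (assumption: only "X.Y.Z" and
# "X.Y.Z.devN" strings appear; anything else is rejected).
_PATTERN = re.compile(r"(\d+)\.(\d+)\.(\d+)(?:\.dev(\d+))?")


def simple_key(version: str) -> tuple:
    """Return a sortable key for a version string."""
    match = _PATTERN.fullmatch(version)
    if match is None:
        raise ValueError(f"unsupported version string: {version!r}")
    major, minor, patch, dev = match.groups()
    # A dev build (0, N) sorts before the final release (1, 0) of the same X.Y.Z.
    stage = (0, int(dev)) if dev is not None else (1, 0)
    return (int(major), int(minor), int(patch)) + stage


def meets(installed: str, minimum: str) -> bool:
    """True if `installed` is at least `minimum` under the simplified ordering."""
    return simple_key(installed) >= simple_key(minimum)


# Guard at the dev version: the bisected nightly and the release both pass.
print(meets("0.8.0.dev20250924", "0.8.0.dev20250924"))
print(meets("0.8.0", "0.8.0.dev20250924"))
# Guard at the release: the working nightly is rejected.
print(meets("0.8.0.dev20250924", "0.8.0"))
```

With the minimum set to the release, the nightly that the bisection found to work is turned away, which is the cost of the cleaner-looking constant.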
For the code:
# At top of file with other imports
from ..quantize.helper import _jax_version_meet_requirement
# Later in the file
_JAXLIB_MIN_VERSION = "0.8.0" # or keep dev version if needed
if not _jax_version_meet_requirement(_JAXLIB_MIN_VERSION):
    raise RuntimeError(f"JAX >= {_JAXLIB_MIN_VERSION} required for Triton kernel support")
This eliminates the mid-file import and reuses existing infrastructure.
There was a problem hiding this comment.
@phu0ngng I think greptile has analyzed the trade-off between checking for >= dev_version and >= release_version pretty thoroughly. So the decision is whether or not to support dev versions after 0924. I think we should support these dev versions, in case other teams have not yet updated their pipelines to a newer or release version. Using the dev version string instead of the release one is not that ugly and does not need any maintenance anyway.
There was a problem hiding this comment.
Regarding the other points (jax vs. jaxlib, and using _jax_version_meet_requirement): I agree and will fix.
) from e

# Minimum jaxlib version required for Triton kernel dispatch to work correctly.
# jaxlib < 0.8.0.dev20250924 segfaults in pxla.py during Triton kernel execution
Great work! I think we should expose a function here, something like is_triton_extension_supported(), so that we or users can check as well. Then in our Triton tests we can add something like the following block at the top:
def test_some_triton_extension(...):
    if not is_triton_extension_supported():
        pytest.skip("... same info message about triton jax compatibility in older versions ...")
        return
    # main test code
Similar to how we guard tests based on compute arch.
Without this change, I think the tests on old containers would still fail, though the new error message is much more informative than before!
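One possible shape for such a check, as a sketch rather than the PR's implementation: since the guarded module raises at import time (ImportError or RuntimeError, per the flowchart above), a helper living outside that module could simply attempt the import. The placement outside the guarded module and the `module_name` parameter are assumptions made for illustration.

```python
import importlib

# Module path as shown in the flowchart in this thread.
_TRITON_UTILS = "transformer_engine.jax.triton_extensions.utils"


def is_triton_extension_supported(module_name: str = _TRITON_UTILS) -> bool:
    """Return True if the guarded module imports cleanly.

    Assumption: the version guard raises ImportError or RuntimeError at
    import time, so a failed import means "not supported here".
    """
    try:
        importlib.import_module(module_name)
    except (ImportError, RuntimeError):
        return False
    return True
```

A test module could then use `pytest.mark.skipif(not is_triton_extension_supported(), reason=...)` on each Triton test, or call `pytest.skip(...)` inside the test body as in the snippet above.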
Description
To provide backward compatibility with older JAX versions, we need a safeguard for JAX versions too old to work with Triton kernel calling. Using Claude Code to automate bisecting through JAX toolbox nightly containers between Sep 1, 2025 and Oct 1, 2025 (*), I found that the first passing container is the one from Sep 24, 2025, corresponding to jax 0.8.0.dev20250924, so the guard is placed at that version.
(*) The date range was determined from two data points: the officially released JAX toolbox container (nvcr.io/nvidia/jax:25.10-py3) fails, while the nightly JAX container from Oct 1 passes.
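For readers unfamiliar with this kind of bisection, here is a minimal sketch of the idea, not the procedure actually used for this PR. `first_passing` and `fake_run` are hypothetical names; `fake_run` is a stand-in that only mirrors the reported outcome, where a real run would launch the nightly container for that date and execute the Triton tests. The search assumes results are monotonic (every nightly before some date fails, every one from that date on passes).

```python
from datetime import date, timedelta
from typing import Callable, Sequence


def first_passing(candidates: Sequence[date], passes: Callable[[date], bool]) -> date:
    """Binary-search the earliest candidate for which passes() is True.

    Assumes `candidates` is sorted and results are monotonic (fail..., pass...).
    """
    lo, hi = 0, len(candidates) - 1
    if not passes(candidates[hi]):
        raise ValueError("the newest candidate fails; nothing to bisect")
    while lo < hi:
        mid = (lo + hi) // 2
        if passes(candidates[mid]):
            hi = mid        # the first pass is at mid or earlier
        else:
            lo = mid + 1    # the first pass is after mid
    return candidates[lo]


# Nightlies from Sep 1 through Oct 1, 2025 (the range stated above).
nightlies = [date(2025, 9, 1) + timedelta(days=i) for i in range(31)]


def fake_run(nightly: date) -> bool:
    # Stand-in predicate reflecting the reported result, not a real test run.
    return nightly >= date(2025, 9, 24)


print(first_passing(nightlies, fake_run))
```

With 31 candidates, this needs about five container runs beyond the initial endpoint check instead of one per day.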
Fixes # (issue)
Type of change
Changes
Handles the segfault that occurs with jax < 0.8.0.dev20250924 when calling Triton kernels from the JAX side
Checklist: