
Add guard at lowest JAX version that still supports triton kernel calling #2741

Open
tdophung wants to merge 2 commits into NVIDIA:main from tdophung:triton_jax_bwd_compat

Conversation

@tdophung
Collaborator

@tdophung tdophung commented Mar 6, 2026

Description

To stay backward compatible with older JAX versions, we need a guard for JAX versions that are too old to support Triton kernel calls. I used Claude Code to automate bisecting the JAX Toolbox nightly containers between Sep 1, 2025 and Oct 1, 2025 (*). The first passing container is the Sep 24, 2025 nightly, which corresponds to jax 0.8.0.dev20250924, so the guard is placed at that version.

(*) The date range comes from two data points: the officially released JAX Toolbox container (nvcr.io/nvidia/jax:25.10-py3) fails, while the Oct 1 nightly JAX container passes.
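
The bisection described above can be sketched as a binary search over nightly dates. This is only an illustration, not the script used for the PR: `first_passing_date` and `fake_nightly_passes` are hypothetical names, the predicate is a stand-in for pulling a nightly container and running the Triton tests, and treating Sep 1 as a known-failing endpoint is an assumption.

```python
from datetime import date, timedelta
from typing import Callable


def first_passing_date(
    known_bad: date, known_good: date, passes: Callable[[date], bool]
) -> date:
    """Return the earliest date in (known_bad, known_good] for which passes() is True.

    Assumes monotonic behaviour: once a nightly passes, every later nightly passes.
    """
    lo, hi = known_bad, known_good  # invariant: lo fails, hi passes
    while (hi - lo).days > 1:
        mid = lo + timedelta(days=(hi - lo).days // 2)
        if passes(mid):
            hi = mid  # first passing date is mid or earlier
        else:
            lo = mid  # first passing date is after mid
    return hi


# Hypothetical stand-in for "pull the nightly container for this date and run
# the Triton tests". It hard-codes the outcome reported in this PR.
def fake_nightly_passes(day: date) -> bool:
    return day >= date(2025, 9, 24)


result = first_passing_date(date(2025, 9, 1), date(2025, 10, 1), fake_nightly_passes)
print(result)  # 2025-09-24
```

With a 30-day window the search needs about five container runs instead of thirty, which is what makes automating it worthwhile.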

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Handles the segfault that occurs with jax < 0.8.0.dev20250924 when calling Triton kernels from the JAX side

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung
Collaborator Author

tdophung commented Mar 6, 2026

/te-ci jax

@tdophung tdophung changed the title from "add guard at bisected jax version where lower is segfault" to "Add guard at lowest JAX version that still supports triton kernel calling" Mar 6, 2026
Collaborator

@jberchtold-nvidia jberchtold-nvidia left a comment

Awesome, LGTM pending CI, thanks!

@greptile-apps
Contributor

greptile-apps bot commented Mar 6, 2026

Greptile Summary

This PR adds a backward-compatibility guard in transformer_engine/jax/triton_extensions/utils.py that raises a clear RuntimeError at import time when jaxlib < 0.8.0.dev20250924 is detected — versions known to segfault in pxla.py during Triton kernel dispatch, as determined via automated bisection through JAX nightly containers.

  • Guard logic is correct: packaging.version.parse handles PEP 440 dev/nightly version strings properly; 0.8.0.dev20250924 is treated as a pre-release of 0.8.0, so any released 0.8.0 or later passes, and any earlier nightly or stable release is blocked as intended.
  • Error message is user-friendly: The RuntimeError message includes the installed version, the minimum required version, an upgrade command, and the cpp_extensions escape hatch.
  • Minor style concern: import jaxlib is inserted mid-module (requiring # noqa: E402) even though jax — whose __version__ is always identical to jaxlib.__version__ — is already imported at the top of the file.

Confidence Score: 4/5

  • Safe to merge; the guard correctly prevents segfaults on old jaxlib versions with a clear error message.
  • The change is small, focused, and logically sound — version comparison using packaging.version.parse is correct for PEP 440 dev versions. The only concern is a minor style issue (mid-module import), which does not affect correctness or runtime behaviour.
  • No files require special attention beyond the minor import placement noted in transformer_engine/jax/triton_extensions/utils.py.

Important Files Changed

Filename | Overview
transformer_engine/jax/triton_extensions/utils.py | Adds a jaxlib minimum version guard (0.8.0.dev20250924) that raises RuntimeError on import for versions known to segfault during Triton kernel dispatch; logic and version comparison are correct, but the new import jaxlib is placed mid-module and could use the already-imported jax.__version__ instead.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[import transformer_engine.jax.triton_extensions.utils] --> B[_check_triton_compatibility]
    B --> C{Triton installed\nand valid?}
    C -- No --> D[raise ImportError]
    C -- Yes --> E[Import gpu_triton,\ntriton.compiler, etc.]
    E --> F{ImportError?}
    F -- Yes --> G[raise ImportError]
    F -- No --> H[import jaxlib]
    H --> I{jaxlib.__version__ <\n0.8.0.dev20250924?}
    I -- Yes --> J[raise RuntimeError\nversion too old - segfault risk]
    I -- No --> K[Module ready:\ntriton_call_lowering,\nget_triton_info exported]

Last reviewed commit: 14666c0

# jaxlib < 0.8.0.dev20250924 segfaults in pxla.py during Triton kernel execution
# (bisected: last known segfault = jax-2025-09-23, first known pass = jax-2025-09-24).
_JAXLIB_MIN_VERSION = "0.8.0.dev20250924"
import jaxlib # noqa: E402
Contributor

Import placed mid-module when jax is already available

jax is already imported at line 42, and jax.__version__ always matches jaxlib.__version__ (they are co-released together). The mid-file import jaxlib requires a # noqa: E402 suppression and adds an extra package import for information that's already available via the existing jax import.

Alternatively, if you specifically want jaxlib.__version__ for precision, this import should be hoisted to the top-of-file import block alongside the other jax/jaxlib imports rather than sitting between executable module-level code.

Suggested change
- import jaxlib # noqa: E402
+ import jaxlib # jaxlib version guard below

(move this line to the top of the file alongside the other imports at lines 41–43)

Or, to avoid the extra import entirely:

if version.parse(jax.__version__) < version.parse(_JAXLIB_MIN_VERSION):

since jax.__version__ == jaxlib.__version__ for any properly-installed JAX environment.

Collaborator

I was also about to comment about jax vs jaxlib.

Besides, we already have a helper for this at https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/jax/quantize/helper.py#L72-L78. Perhaps we should find a way to repurpose it, especially since we will need JAX version guards in many places in the future.

On a side note, should we check for a release version, i.e., >= 0.8.0, instead?

Contributor

Great points! Yes, reusing _jax_version_meet_requirement would be much cleaner. You could move it to a shared utility module if multiple files need it.

Regarding dev vs release version:

  • >= 0.8.0.dev20250924 allows: 0.8.0.dev20250924, later dev versions, and all 0.8.x releases
  • >= 0.8.0 only allows official releases (excludes all dev versions)

Since dev versions sort before releases in PEP 440, using >= 0.8.0 would actually block 0.8.0.dev20250924 even though your bisection shows it works.

Recommendation: If most users will use official releases (not nightlies), >= 0.8.0 is cleaner and safer. Dev/nightly users typically expect to handle version quirks. But if you want to explicitly support those working dev versions, keep >= 0.8.0.dev20250924.

For the code:

# At top of file with other imports
from ..quantize.helper import _jax_version_meet_requirement  # utils.py lives in triton_extensions/, one level below jax/

# Later in the file
_JAXLIB_MIN_VERSION = "0.8.0"  # or keep dev version if needed
if not _jax_version_meet_requirement(_JAXLIB_MIN_VERSION):
    raise RuntimeError(f"JAX >= {_JAXLIB_MIN_VERSION} required for Triton kernel support")

This eliminates the mid-file import and reuses existing infrastructure.
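
The ordering behind the dev-versus-release discussion above can be checked directly with `packaging.version`. The sketch below compares the two candidate thresholds against a handful of version strings; the candidate list is illustrative and not taken from the PR.

```python
from packaging.version import Version

DEV_THRESHOLD = Version("0.8.0.dev20250924")
RELEASE_THRESHOLD = Version("0.8.0")

# Illustrative version strings: an older release, nightlies on either side of
# the bisected date, and releases at or after 0.8.0.
candidates = [
    "0.7.2",
    "0.8.0.dev20250923",
    "0.8.0.dev20250924",
    "0.8.0.dev20251001",
    "0.8.0",
    "0.8.1",
]


def accepted(threshold: Version) -> list:
    """Return the candidates that satisfy `version >= threshold`."""
    return [v for v in candidates if Version(v) >= threshold]


accepted_by_dev = accepted(DEV_THRESHOLD)
accepted_by_release = accepted(RELEASE_THRESHOLD)

print(accepted_by_dev)
# ['0.8.0.dev20250924', '0.8.0.dev20251001', '0.8.0', '0.8.1']
print(accepted_by_release)
# ['0.8.0', '0.8.1']
```

Under PEP 440 a dev release sorts before the final release with the same number, so the release threshold rejects the 0.8.0 nightlies that the bisection found to work, while the dev threshold accepts them along with every later release.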

Collaborator Author

@phu0ngng I think Greptile has analyzed the choice between checking >= dev_version and >= release_version pretty thoroughly, so the decision is whether or not to support the dev versions after 0924. I think we should support these dev versions, in case other teams have not yet updated their pipelines to a newer or release version. Using the dev version instead of the release version is not that ugly and does not need any maintenance anyway.

Collaborator Author

Regarding the other points (jax vs. jaxlib, and using _jax_version_meet_requirement): I agree and will fix.

) from e

# Minimum jaxlib version required for Triton kernel dispatch to work correctly.
# jaxlib < 0.8.0.dev20250924 segfaults in pxla.py during Triton kernel execution
Collaborator

Great work! I think we should expose a function here, something like is_triton_extension_supported(), so that we or users can check as well. Then in our Triton tests we can add something like the following block at the top:

def test_some_triton_extension(...):
    if not is_triton_extension_supported():
        # pytest.skip raises, so no explicit return is needed
        pytest.skip("... same info message about triton jax compatibility in older versions ...")
    # main test code

This is similar to how we guard tests based on compute arch.

Without this change, I think the tests on old containers would still fail, though the new error message is much more informative than before!
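
One possible shape for such a predicate is sketched below. It is an assumption about how `is_triton_extension_supported()` might look, not code from this PR: the constant name and the optional `installed` parameter are hypothetical, and only the threshold value comes from the PR. It reads the installed version through `importlib.metadata` so the check itself does not have to import JAX.

```python
from importlib import metadata
from typing import Optional

from packaging.version import Version

# Threshold taken from the PR; the rest of this sketch is illustrative.
MIN_JAX_VERSION_FOR_TRITON = Version("0.8.0.dev20250924")


def is_triton_extension_supported(installed: Optional[str] = None) -> bool:
    """Return True if the given (or installed) JAX version meets the threshold.

    Passing `installed` explicitly keeps the check testable without JAX present.
    """
    if installed is None:
        try:
            installed = metadata.version("jax")
        except metadata.PackageNotFoundError:
            return False  # JAX is not installed, so the extension cannot work
    return Version(installed) >= MIN_JAX_VERSION_FOR_TRITON


print(is_triton_extension_supported("0.8.0.dev20250923"))  # False
print(is_triton_extension_supported("0.8.0.dev20250924"))  # True
print(is_triton_extension_supported("0.8.0"))              # True
```

A test module could then skip in one place instead of in every test body, for example with `pytest.mark.skipif(not is_triton_extension_supported(), reason="...")` applied through a module-level `pytestmark`.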

@jberchtold-nvidia jberchtold-nvidia self-requested a review March 6, 2026 16:01


3 participants