-
Notifications
You must be signed in to change notification settings - Fork 653
Add guard at lowest JAX version that still supports triton kernel calling #2741
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -162,6 +162,21 @@ def _check_triton_compatibility(): | |||||
| "If you don't need Triton, use transformer_engine.jax.cpp_extensions instead." | ||||||
| ) from e | ||||||
|
|
||||||
| # Minimum jaxlib version required for Triton kernel dispatch to work correctly. | ||||||
| # jaxlib < 0.8.0.dev20250924 segfaults in pxla.py during Triton kernel execution | ||||||
| # (bisected: last known segfault = jax-2025-09-23, first known pass = jax-2025-09-24). | ||||||
| _JAXLIB_MIN_VERSION = "0.8.0.dev20250924" | ||||||
| import jaxlib # noqa: E402 | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Import placed mid-module when
Alternatively, if you specifically want
Suggested change
(move this line to the top of the file alongside the other imports at lines 41–43) Or, to avoid the extra import entirely: if version.parse(jax.__version__) < version.parse(_JAXLIB_MIN_VERSION):since
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was also about to comment about jax vs jaxlib. Besides, we have an existing helper here as well https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/jax/quantize/helper.py#L72-L78, perhaps we should find a way to repurpose this helper, especially since we will need to have jax guard all over places in the future. On the side note, should we check for a release version, i.e., >= 0.8.0 instead?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Great points! Yes, reusing Regarding dev vs release version:
Since dev versions sort before releases in PEP 440, using Recommendation: If most users will use official releases (not nightlies), For the code: This eliminates the mid-file import and reuses existing infrastructure.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @phu0ngng I think greptile has analyzed it pretty thoroughly on checking for > dev_version or >release_version. So the decision is between supporting dev versions after 0924 or not. I think we should support these dev versions, just in case other teams have not updated their pipeline to use a newer/release version. The ugly code to have the dev version instead of release is not that ugly and does not need any maintenance anyways
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. regarding the other poitns (jax vs jaxlib) and using the |
||||||
|
|
||||||
| if version.parse(jaxlib.__version__) < version.parse(_JAXLIB_MIN_VERSION): | ||||||
| raise RuntimeError( | ||||||
| f"jaxlib {jaxlib.__version__} is too old for transformer_engine.jax.triton_extensions.\n" | ||||||
| f"Triton kernel dispatch segfaults with jaxlib < {_JAXLIB_MIN_VERSION}.\n" | ||||||
| "Please upgrade: pip install --upgrade jax jaxlib\n" | ||||||
| "Or use a JAX nightly container dated 2025-09-24 or later.\n" | ||||||
| "If you don't need Triton, use transformer_engine.jax.cpp_extensions instead." | ||||||
| ) | ||||||
|
|
||||||
|
|
||||||
| __all__ = ["triton_call_lowering", "get_triton_info"] | ||||||
|
|
||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great work! I think we should expose a function here, something like
is_triton_extension_supported()so us or users can check as well. Then in our Triton tests we can add something like the following block at the top:Similar to how we guard tests based on compute arch
Without this change, I think the tests on old containers would still fail, though the new error message is much more informative then before!