-
Notifications
You must be signed in to change notification settings - Fork 10
fix(patch): preserve native matmul backend on torch 2.9 #66
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1242,6 +1242,23 @@ def test_torch_backends_cuda_matmul_allow_tf32(self): | |
| # Restore original | ||
| torch.backends.cuda.matmul.allow_tf32 = original | ||
|
|
||
| def test_torch_backends_cuda_matmul_forwards_to_musa(self): | ||
| """Test torch.backends.cuda.matmul forwards settings to MUSA on MUSA.""" | ||
| import torch | ||
|
|
||
| import torchada | ||
|
|
||
| if not torchada.is_musa_platform() or not hasattr(torch.backends, "musa"): | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Severity: medium 🤖 Was this useful? React with 👍 or 👎, or 🚀 if it prevented an incident/outage. |
||
| pytest.skip("MUSA matmul backend not available") | ||
|
|
||
| original = torch.backends.musa.matmul.allow_tf32 | ||
| try: | ||
| torch.backends.cuda.matmul.allow_tf32 = not original | ||
| assert torch.backends.musa.matmul.allow_tf32 is (not original) | ||
| assert torch.backends.cuda.matmul.allow_tf32 is (not original) | ||
| finally: | ||
| torch.backends.cuda.matmul.allow_tf32 = original | ||
|
|
||
| def test_torch_backends_cuda_matmul_fp32_precision(self): | ||
| """Test torch.backends.cuda.matmul.fp32_precision is accessible.""" | ||
| import torch | ||
|
|
@@ -1258,12 +1275,9 @@ def test_torch_backends_cuda_matmul_fp32_precision(self): | |
| except (AttributeError, AssertionError): | ||
| pytest.skip("fp32_precision not available (torchada MUSA-specific attribute)") | ||
|
|
||
| if ( | ||
| not torchada.is_musa_platform() | ||
| and torch.__version__ >= torch.torch_version.TorchVersion("2.9.0") | ||
| ): | ||
| if torch.__version__ >= torch.torch_version.TorchVersion("2.9.0"): | ||
| # PyTorch 2.9+: Only use the new API. Do NOT call torch.get_float32_matmul_precision() | ||
| valid_precisions = ("ieee", "tf32") | ||
| valid_precisions = ("none", "ieee", "tf32") | ||
| test_values = ["ieee", "tf32"] | ||
| check_underlying_api = False # Critical: avoid mixing APIs | ||
| else: | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
src/torchada/_patch.py:1182: the probe `cuda_matmul.fp32_precision` may raise `AssertionError` (not just `AttributeError`) for unknown attributes (as noted in the tests), which would make `_patch_backends_cuda()` crash during import on some versions; consider catching `AssertionError` here as well. Severity: high
🤖 Was this useful? React with 👍 or 👎, or 🚀 if it prevented an incident/outage.