48 changes: 23 additions & 25 deletions src/torchada/_patch.py
@@ -1137,12 +1137,7 @@ def _patch_backends_cuda():
This patches:
- is_built() to return True when MUSA is available (since we're using
torch.cuda APIs that are redirected to MUSA)
- torch.backends.cuda.matmul.fp32_precision to use torch.get/set_float32_matmul_precision
(this attribute is missing in some torch_musa versions)

Note: Other torch.backends.cuda.matmul properties (allow_tf32, etc.) work
as-is because they are settings that apply to the internal PyTorch
operations regardless of backend.
- torch.backends.cuda.matmul attribute access to MUSA matmul semantics
"""
if not hasattr(torch, "backends") or not hasattr(torch.backends, "cuda"):
return
@@ -1167,44 +1162,47 @@ def patched_is_built():

torch.backends.cuda.is_built = patched_is_built

if (
if not (
is_musa_platform()
and hasattr(torch.backends, "musa")
and hasattr(torch.backends.musa, "matmul")
and hasattr(torch.backends.cuda, "matmul")
):
torch.backends.cuda.matmul = torch.backends.musa.matmul

# Patch cuBLASModule to support fp32_precision attribute
# This attribute is in newer PyTorch but may be missing in torch_musa's version
matmul = torch.backends.cuda.matmul
matmul_class = matmul.__class__

# Check if fp32_precision is already supported
try:
_ = matmul.fp32_precision
# Already supported, no need to patch
return
except AttributeError:
pass

# Store original methods
cuda_matmul = torch.backends.cuda.matmul
musa_matmul = torch.backends.musa.matmul
matmul_class = cuda_matmul.__class__
original_getattr = matmul_class.__getattr__
original_setattr = matmul_class.__setattr__

try:
_ = cuda_matmul.fp32_precision
has_native_fp32_precision = True
except AttributeError:
@augmentcode Bot May 14, 2026

src/torchada/_patch.py:1182: the probe cuda_matmul.fp32_precision may raise AssertionError (not just AttributeError) for unknown attributes (as noted in the tests), which would make _patch_backends_cuda() crash during import on some versions; consider catching AssertionError here as well.

Severity: high
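For illustration, a sketch of the probe with the wider exception handling this comment suggests, mirroring the `(AttributeError, AssertionError)` pairs already used in `patched_getattr`/`patched_setattr` below (same names as the diff; a hedged suggestion, not part of the PR):

```python
# Probe whether the cuBLAS-style module exposes fp32_precision natively.
# torch_musa's matmul module may assert on unknown attribute names, so an
# AssertionError here is treated the same as AttributeError.
try:
    _ = cuda_matmul.fp32_precision
    has_native_fp32_precision = True
except (AttributeError, AssertionError):
    has_native_fp32_precision = False
```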

has_native_fp32_precision = False

def patched_getattr(self, name):
if name == "fp32_precision":
if name == "fp32_precision" and not has_native_fp32_precision:
return torch.get_float32_matmul_precision()
return original_getattr(self, name)
try:
return getattr(musa_matmul, name)
except (AttributeError, AssertionError):
return original_getattr(self, name)

def patched_setattr(self, name, value):
if name == "fp32_precision":
if name == "fp32_precision" and not has_native_fp32_precision:
return torch.set_float32_matmul_precision(value)
return original_setattr(self, name, value)
try:
return setattr(musa_matmul, name, value)
except (AttributeError, AssertionError):
return original_setattr(self, name, value)

matmul_class.__getattr__ = patched_getattr
matmul_class.__setattr__ = patched_setattr
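For context, a hedged usage sketch of what the forwarding wired up above is meant to provide on a MUSA build; it assumes, as the new test below does, that importing `torchada` is enough to have the patch applied:

```python
import torch

import torchada

# Only meaningful on a MUSA platform where torch.backends.musa.matmul exists.
if (
    torchada.is_musa_platform()
    and hasattr(torch.backends, "musa")
    and hasattr(torch.backends.musa, "matmul")
):
    # Writes to the CUDA-named module are forwarded to the MUSA module by the
    # patched __setattr__, so both views stay in sync.
    torch.backends.cuda.matmul.allow_tf32 = True
    assert torch.backends.musa.matmul.allow_tf32 is True

    # Reads of fp32_precision fall back to torch.get_float32_matmul_precision()
    # when the backend module does not expose the attribute natively.
    print(torch.backends.cuda.matmul.fp32_precision)
```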



@patch_function
@requires_import("torchada.utils.cpp_extension", "torch.utils.cpp_extension")
def _patch_cpp_extension():
24 changes: 19 additions & 5 deletions tests/test_cuda_patching.py
@@ -1242,6 +1242,23 @@ def test_torch_backends_cuda_matmul_allow_tf32(self):
# Restore original
torch.backends.cuda.matmul.allow_tf32 = original

def test_torch_backends_cuda_matmul_forwards_to_musa(self):
"""Test torch.backends.cuda.matmul forwards settings to MUSA on MUSA."""
import torch

import torchada

if not torchada.is_musa_platform() or not hasattr(torch.backends, "musa"):
@augmentcode Bot May 14, 2026

tests/test_cuda_patching.py:1251: this skip guard checks hasattr(torch.backends, "musa") but the test immediately dereferences torch.backends.musa.matmul; if musa exists without matmul, this will error instead of skipping.

Severity: medium
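For illustration, a sketch of a stricter skip guard that also checks for the `matmul` attribute, so the test skips rather than errors when `torch.backends.musa` exists without it (a hedged suggestion, not part of the PR):

```python
# Skip unless the full torch.backends.musa.matmul path is available.
if not (
    torchada.is_musa_platform()
    and hasattr(torch.backends, "musa")
    and hasattr(torch.backends.musa, "matmul")
):
    pytest.skip("MUSA matmul backend not available")
```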

pytest.skip("MUSA matmul backend not available")

original = torch.backends.musa.matmul.allow_tf32
try:
torch.backends.cuda.matmul.allow_tf32 = not original
assert torch.backends.musa.matmul.allow_tf32 is (not original)
assert torch.backends.cuda.matmul.allow_tf32 is (not original)
finally:
torch.backends.cuda.matmul.allow_tf32 = original

def test_torch_backends_cuda_matmul_fp32_precision(self):
"""Test torch.backends.cuda.matmul.fp32_precision is accessible."""
import torch
@@ -1258,12 +1275,9 @@ def test_torch_backends_cuda_matmul_fp32_precision(self):
except (AttributeError, AssertionError):
pytest.skip("fp32_precision not available (torchada MUSA-specific attribute)")

if (
not torchada.is_musa_platform()
and torch.__version__ >= torch.torch_version.TorchVersion("2.9.0")
):
if torch.__version__ >= torch.torch_version.TorchVersion("2.9.0"):
# PyTorch 2.9+: Only use the new API. Do NOT call torch.get_float32_matmul_precision()
valid_precisions = ("ieee", "tf32")
valid_precisions = ("none", "ieee", "tf32")
test_values = ["ieee", "tf32"]
check_underlying_api = False # Critical: avoid mixing APIs
else: