Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,22 @@ from torch.utils.cpp_extension import CUDAExtension, BuildExtension
ext = CUDAExtension("my_ext", sources=["kernel.cu"])
```

To skip CUDA→MUSA source porting for specified directories, set
`TORCHADA_EXCLUDE_DIRS` to an `os.pathsep`-delimited directory list before
building the extension. On Linux this separator is `:`.

```bash
export TORCHADA_EXCLUDE_DIRS=/path/to/dir0:/path/to/dir1
```

Projects can also extend `BuildExtension` for code-level configuration:

```python
class MyBuildExt(BuildExtension):
def get_exclude_dirs(self):
return super().get_exclude_dirs() + ["/path/to/dir"]
```

### Custom Ops

```python
Expand Down
15 changes: 15 additions & 0 deletions README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,21 @@ from torch.utils.cpp_extension import CUDAExtension, BuildExtension
ext = CUDAExtension("my_ext", sources=["kernel.cu"])
```

如果特定的目录不需要 CUDA→MUSA 源码转换,可以在构建前设置
`TORCHADA_EXCLUDE_DIRS`。多个目录使用 `os.pathsep` 分隔;Linux 上分隔符是 `:`。

```bash
export TORCHADA_EXCLUDE_DIRS=/path/to/dir0:/path/to/dir1
```

项目也可以继承 `BuildExtension`,在代码中追加排除目录:

```python
class MyBuildExt(BuildExtension):
def get_exclude_dirs(self):
return super().get_exclude_dirs() + ["/path/to/dir"]
```

### 自定义算子

```python
Expand Down
139 changes: 129 additions & 10 deletions src/torchada/utils/cpp_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,78 @@ def _is_musa_file(path: str) -> bool:
return ext in [".cu", ".cuh", ".mu", ".muh"]


def _normalize_exclude_dirs(exclude_dirs: Optional[List[str]]) -> List[str]:
"""Normalize exclude directory paths to the form used by SimplePorting."""
if not exclude_dirs:
return []

normalized = []
seen = set()
for exclude_dir in exclude_dirs:
exclude_dir = str(exclude_dir).strip() if exclude_dir else ""
if not exclude_dir:
continue
real_dir = os.path.realpath(os.path.abspath(exclude_dir))
if real_dir not in seen:
normalized.append(real_dir)
seen.add(real_dir)
return normalized


def _get_env_exclude_dirs() -> List[str]:
"""Read TORCHADA_EXCLUDE_DIRS as an os.pathsep-delimited directory list."""
env_value = os.environ.get("TORCHADA_EXCLUDE_DIRS", "")
if not env_value:
return []
return _normalize_exclude_dirs(env_value.split(os.pathsep))


def _is_path_in_dir(path: str, directory: str) -> bool:
"""Return True when path is directory itself or lives below it."""
path_real = os.path.normcase(os.path.realpath(os.path.abspath(path)))
dir_real = os.path.normcase(os.path.realpath(os.path.abspath(directory)))
try:
return os.path.commonpath([path_real, dir_real]) == dir_real
except ValueError:
return False


def _same_real_path(left: str, right: str) -> bool:
"""Return True when two paths resolve to the same filesystem location."""
left_real = os.path.normcase(os.path.realpath(os.path.abspath(left)))
right_real = os.path.normcase(os.path.realpath(os.path.abspath(right)))
return left_real == right_real


def _collect_simple_porting_ignore_dirs(source_dir: str, exclude_dirs: List[str]) -> List[str]:
"""
Return ignore_dir_paths compatible with torch_musa SimplePorting.

SimplePorting checks ignore dirs by exact root equality while walking the source
tree. Include every existing descendant directory of an excluded subtree so
nested files are skipped too.
"""
source_dir = os.path.realpath(os.path.abspath(source_dir))
ignore_dirs = []
seen = set()

for exclude_dir in _normalize_exclude_dirs(exclude_dirs):
if not _is_path_in_dir(exclude_dir, source_dir):
continue

candidates = [exclude_dir]
if os.path.isdir(exclude_dir):
candidates = [root for root, _, _ in os.walk(exclude_dir)]

for candidate in candidates:
real_candidate = os.path.realpath(os.path.abspath(candidate))
if real_candidate not in seen:
ignore_dirs.append(real_candidate)
seen.add(real_candidate)

return ignore_dirs


def _patch_simple_porting_load_replaced_mapping(musa_sp):
"""
Patch SimplePorting.load_replaced_mapping to suppress unwanted print output.
Expand Down Expand Up @@ -505,12 +577,30 @@ def get_mapping_rule(self):
"""
return _MAPPING_RULE.copy()

def get_exclude_dirs(self):
"""
Get directories that should be excluded from CUDA->MUSA porting.

Override this in subclasses to add project-specific exclusions.
Environment entries come from TORCHADA_EXCLUDE_DIRS and are
separated with os.pathsep (":" on Unix, ";" on Windows).
"""
return _get_env_exclude_dirs()

def build_extensions(self):
# Register .cu, .cuh as valid source extensions
self.compiler.src_extensions += [".cu", ".cuh"]
super().build_extensions()

def _port_directory(self, source_dir, mapping_rule=None):
def _is_excluded_dir(self, dir_path, exclude_dirs=None):
"""Check whether dir_path is exactly excluded or under an excluded dir."""
if exclude_dirs is None:
exclude_dirs = self.get_exclude_dirs()
return any(
_is_path_in_dir(dir_path, exclude_dir) for exclude_dir in exclude_dirs
)

def _port_directory(self, source_dir, mapping_rule=None, exclude_dirs=None):
"""
Port a directory containing CUDA sources to MUSA.

Expand All @@ -520,20 +610,34 @@ def _port_directory(self, source_dir, mapping_rule=None):
Args:
source_dir: Path to directory containing CUDA sources
mapping_rule: Optional custom mapping rules (uses get_mapping_rule() if None)
exclude_dirs: Optional directory paths to pass to SimplePorting as
ignore_dir_paths when they are inside source_dir

Returns:
str: Path to the ported directory (source_dir + "_musa")
"""

if mapping_rule is None:
mapping_rule = self.get_mapping_rule()
if exclude_dirs is None:
exclude_dirs = self.get_exclude_dirs()
exclude_dirs = _normalize_exclude_dirs(exclude_dirs)

source_dir = os.path.realpath(os.path.abspath(source_dir))
if self._is_excluded_dir(source_dir, exclude_dirs):
return source_dir

source_dir = os.path.abspath(source_dir)
musa_dir = source_dir + "_musa"

if source_dir not in self._ported_dirs:
musa_sp.LOGGER.setLevel(logging.ERROR)
ignore_dir_paths = _collect_simple_porting_ignore_dirs(
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_collect_simple_porting_ignore_dirs() returns realpath-canonicalized entries, but _port_directory() passes cuda_dir_path=source_dir where source_dir is only abspath-canonicalized. If source_dir is reached via a symlink, SimplePorting’s exact-path ignore matching may not exclude the intended subtree.

Severity: medium

Generating Fix in Augment link...

🤖 Was this useful? React with 👍 or 👎, or 🚀 if it prevented an incident/outage.

source_dir, exclude_dirs
)
musa_sp.SimplePorting(
cuda_dir_path=source_dir, mapping_rule=mapping_rule
cuda_dir_path=source_dir,
ignore_dir_paths=ignore_dir_paths,
mapping_rule=mapping_rule,
).run()
self._ported_dirs.add(source_dir)

Expand All @@ -556,24 +660,32 @@ def _port_directory(self, source_dir, mapping_rule=None):

return musa_dir

def _convert_source_path(self, source):
def _convert_source_path(self, source, exclude_dirs=None):
"""
Convert a CUDA source path to its ported MUSA equivalent.

Args:
source: Original source file path (e.g., "csrc/kernel.cu")
exclude_dirs: Optional directory paths excluded from porting

Returns:
tuple: (converted_path, needs_porting)
- converted_path: Path to ported file (e.g., "csrc_musa/kernel.mu")
- needs_porting: True if the source directory needs porting
"""
source_path = os.path.abspath(source)
if exclude_dirs is None:
exclude_dirs = self.get_exclude_dirs()
exclude_dirs = _normalize_exclude_dirs(exclude_dirs)

source_path = os.path.realpath(os.path.abspath(source))
source_dir = os.path.dirname(source_path)
source_file = os.path.basename(source_path)
base_name, ext_name = os.path.splitext(source_file)
ext_name_lower = ext_name.lower()

if self._is_excluded_dir(source_dir, exclude_dirs):
return source, False

# Port all source files that may contain CUDA references:
# - .cu/.cuh: CUDA source/header files
# - .cc/.cpp/.cxx: C++ files that may reference CUDA symbols
Expand Down Expand Up @@ -609,6 +721,7 @@ def run(self):
5. Calls parent run() to perform actual compilation
"""
mapping_rule = self.get_mapping_rule()
exclude_dirs = _normalize_exclude_dirs(self.get_exclude_dirs())

for ext in self.extensions:
new_sources = []
Expand All @@ -619,18 +732,20 @@ def run(self):
(
new_source,
needs_porting,
) = self._convert_source_path(source)
) = self._convert_source_path(source, exclude_dirs)
new_sources.append(new_source)
if needs_porting:
source_dir = os.path.dirname(os.path.abspath(source))
source_dir = os.path.dirname(
os.path.realpath(os.path.abspath(source))
)
dirs_to_port.add(source_dir)
# Sort directories by depth (deepest first) to ensure proper porting order
dirs_to_port = sorted(
dirs_to_port, key=lambda p: p.count("/"), reverse=True
)
# Port each unique directory
for cuda_dir in dirs_to_port:
self._port_directory(cuda_dir, mapping_rule)
self._port_directory(cuda_dir, mapping_rule, exclude_dirs)

# Update extension sources to point to ported files
ext.sources = new_sources
Expand Down Expand Up @@ -671,9 +786,13 @@ def run(self):
pass

if has_cuda_headers:
ported_dir = self._port_directory(inc_dir_abs, mapping_rule)
ported_dir = self._port_directory(
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If inc_dir_abs is excluded, _port_directory() returns inc_dir_abs, and this loop will append the same include directory twice (ported_dir and then inc_dir_abs). This can unexpectedly duplicate include paths and potentially affect include precedence assumptions for excluded dirs.

Severity: low

Generating Fix in Augment link...

🤖 Was this useful? React with 👍 or 👎, or 🚀 if it prevented an incident/outage.

inc_dir_abs, mapping_rule, exclude_dirs
)
# Add ported dir first so ported headers take precedence
if os.path.isdir(ported_dir):
if os.path.isdir(ported_dir) and not _same_real_path(
ported_dir, inc_dir_abs
):
new_include_dirs.append(ported_dir)
new_include_dirs.append(inc_dir_abs)
ext.include_dirs = new_include_dirs
Expand Down
94 changes: 94 additions & 0 deletions tests/test_cpp_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,97 @@ def test_mapping_rule_has_expected_entries(self):
assert rules.get("cudaStream_t") == "musaStream_t"
assert rules.get("at::cuda") == "at::musa"
assert rules.get("c10::cuda") == "c10::musa"


class TestSimplePortingExcludeDirs:
"""Test directory exclusion helpers used with SimplePorting."""

def test_get_env_exclude_dirs_uses_path_separator(self, monkeypatch, tmp_path):
"""TORCHADA_EXCLUDE_DIRS should accept multiple pathsep-delimited paths."""
from torchada.utils.cpp_extension import _get_env_exclude_dirs

first = tmp_path / "vendor"
second = tmp_path / "third_party"
monkeypatch.setenv("TORCHADA_EXCLUDE_DIRS", f"{first}{os.pathsep}{second}")

result = _get_env_exclude_dirs()

assert os.path.realpath(str(first)) in result
assert os.path.realpath(str(second)) in result

def test_get_env_exclude_dirs_ignores_whitespace_entries(self, monkeypatch):
"""Whitespace-only env entries must not resolve to the current directory."""
from torchada.utils.cpp_extension import _get_env_exclude_dirs

monkeypatch.setenv("TORCHADA_EXCLUDE_DIRS", f" {os.pathsep}\t")

assert _get_env_exclude_dirs() == []

def test_is_path_in_dir_does_not_match_prefix_siblings(self, tmp_path):
"""A sibling with a shared prefix must not be treated as excluded."""
from torchada.utils.cpp_extension import _is_path_in_dir

excluded = tmp_path / "vendor"
sibling = tmp_path / "vendor_extra"

assert _is_path_in_dir(str(excluded), str(excluded))
assert not _is_path_in_dir(str(sibling), str(excluded))

def test_same_real_path_matches_equivalent_paths(self, tmp_path):
"""Equivalent absolute paths should be recognized before adding includes."""
from torchada.utils.cpp_extension import _same_real_path

include_dir = tmp_path / "include"
include_dir.mkdir()

assert _same_real_path(str(include_dir), str(include_dir / ".." / "include"))

def test_collect_simple_porting_ignore_dirs_includes_nested_dirs(self, tmp_path):
"""SimplePorting needs exact ignore entries for nested excluded dirs."""
from torchada.utils.cpp_extension import _collect_simple_porting_ignore_dirs

source_dir = tmp_path / "csrc"
excluded = source_dir / "vendor"
nested = excluded / "cub"
nested.mkdir(parents=True)

result = _collect_simple_porting_ignore_dirs(str(source_dir), [str(excluded)])

assert os.path.realpath(str(excluded)) in result
assert os.path.realpath(str(nested)) in result

def test_collect_simple_porting_ignore_dirs_ignores_outside_dirs(self, tmp_path):
"""Only excludes inside the ported source root should be passed down."""
from torchada.utils.cpp_extension import _collect_simple_porting_ignore_dirs

source_dir = tmp_path / "csrc"
outside = tmp_path / "other_vendor"
source_dir.mkdir()
outside.mkdir()

result = _collect_simple_porting_ignore_dirs(str(source_dir), [str(outside)])

assert result == []

def test_subclass_can_override_get_exclude_dirs(self, monkeypatch, tmp_path):
"""BuildExtension subclasses can merge env and project-specific excludes."""
if not torchada.is_musa_platform():
return

from torchada.utils.cpp_extension import _get_build_extension_class

env_exclude = tmp_path / "env_vendor"
custom_exclude = tmp_path / "custom_vendor"
monkeypatch.setenv("TORCHADA_EXCLUDE_DIRS", str(env_exclude))

BaseClass = _get_build_extension_class()

class CustomBuildExt(BaseClass):
def get_exclude_dirs(self):
return super().get_exclude_dirs() + [str(custom_exclude)]

instance = CustomBuildExt.__new__(CustomBuildExt)
result = instance.get_exclude_dirs()

assert os.path.realpath(str(env_exclude)) in result
assert str(custom_exclude) in result