def _canonical_path(path: str) -> str:
    """Return the absolute, symlink-resolved form of ``path``.

    ``os.path.abspath`` runs first, then ``os.path.realpath``; every helper
    below goes through this one function so paths are always compared in the
    same canonical form.
    """
    return os.path.realpath(os.path.abspath(path))


def _normalize_exclude_dirs(exclude_dirs: Optional[List[str]]) -> List[str]:
    """Normalize exclude directory paths to the form used by SimplePorting.

    Args:
        exclude_dirs: Directory paths to exclude. ``None`` or an empty
            collection yields ``[]``. A single ``str`` / ``os.PathLike`` is
            treated as one directory instead of being iterated character by
            character.

    Returns:
        Canonical (absolute, symlink-resolved) paths in first-seen order, with
        falsy entries, whitespace-only entries and duplicates removed.
    """
    if not exclude_dirs:
        return []
    if isinstance(exclude_dirs, (str, os.PathLike)):
        # Iterating a bare string would produce one bogus "directory" per
        # character, so wrap it into a one-element list.
        exclude_dirs = [exclude_dirs]

    # Blank entries are dropped *before* canonicalization because
    # os.path.abspath("") resolves to the current working directory.
    stripped = (str(entry).strip() for entry in exclude_dirs if entry)
    # dict.fromkeys keeps insertion order while removing duplicates.
    return list(dict.fromkeys(_canonical_path(text) for text in stripped if text))


def _get_env_exclude_dirs() -> List[str]:
    """Read TORCHADA_EXCLUDE_DIRS as an os.pathsep-delimited directory list.

    Returns:
        The normalized directories, or ``[]`` when the variable is unset or
        empty.
    """
    env_value = os.environ.get("TORCHADA_EXCLUDE_DIRS", "")
    if not env_value:
        return []
    return _normalize_exclude_dirs(env_value.split(os.pathsep))


def _is_path_in_dir(path: str, directory: str) -> bool:
    """Return True when path is directory itself or lives below it.

    Both arguments are canonicalized and case-normalized before comparison, so
    a sibling that merely shares a name prefix (``vendor`` vs ``vendor_extra``)
    does not match. A blank ``directory`` never matches: it must not be
    interpreted as the current working directory.
    """
    if not directory or not str(directory).strip():
        return False

    path_real = os.path.normcase(_canonical_path(path))
    dir_real = os.path.normcase(_canonical_path(directory))
    try:
        return os.path.commonpath([path_real, dir_real]) == dir_real
    except ValueError:
        # commonpath raises when the paths cannot share a prefix at all
        # (for example different drives on Windows).
        return False


def _same_real_path(left: str, right: str) -> bool:
    """Return True when two paths resolve to the same filesystem location."""
    return os.path.normcase(_canonical_path(left)) == os.path.normcase(
        _canonical_path(right)
    )


def _collect_simple_porting_ignore_dirs(source_dir: str, exclude_dirs: List[str]) -> List[str]:
    """
    Return ignore_dir_paths compatible with torch_musa SimplePorting.

    SimplePorting checks ignore dirs by exact root equality while walking the
    source tree, so every existing descendant directory of an excluded subtree
    is listed explicitly to make nested files get skipped too.

    Args:
        source_dir: Root directory that is about to be ported.
        exclude_dirs: Candidate directories to exclude; entries that are not
            ``source_dir`` itself or below it are ignored.

    Returns:
        Canonical directory paths in discovery order, without duplicates.
    """
    source_dir = _canonical_path(source_dir)
    ignore_dirs = {}  # used as an insertion-ordered set

    for exclude_dir in _normalize_exclude_dirs(exclude_dirs):
        if not _is_path_in_dir(exclude_dir, source_dir):
            continue

        # Register the excluded root unconditionally: os.walk yields nothing
        # for a root it cannot list, which would otherwise drop the exclusion.
        ignore_dirs[exclude_dir] = None
        if os.path.isdir(exclude_dir):
            for root, _, _ in os.walk(exclude_dir):
                ignore_dirs[_canonical_path(root)] = None

    return list(ignore_dirs)
@@ -520,20 +610,34 @@ def _port_directory(self, source_dir, mapping_rule=None): Args: source_dir: Path to directory containing CUDA sources mapping_rule: Optional custom mapping rules (uses get_mapping_rule() if None) + exclude_dirs: Optional directory paths to pass to SimplePorting as + ignore_dir_paths when they are inside source_dir Returns: str: Path to the ported directory (source_dir + "_musa") """ + if mapping_rule is None: mapping_rule = self.get_mapping_rule() + if exclude_dirs is None: + exclude_dirs = self.get_exclude_dirs() + exclude_dirs = _normalize_exclude_dirs(exclude_dirs) + + source_dir = os.path.realpath(os.path.abspath(source_dir)) + if self._is_excluded_dir(source_dir, exclude_dirs): + return source_dir - source_dir = os.path.abspath(source_dir) musa_dir = source_dir + "_musa" if source_dir not in self._ported_dirs: musa_sp.LOGGER.setLevel(logging.ERROR) + ignore_dir_paths = _collect_simple_porting_ignore_dirs( + source_dir, exclude_dirs + ) musa_sp.SimplePorting( - cuda_dir_path=source_dir, mapping_rule=mapping_rule + cuda_dir_path=source_dir, + ignore_dir_paths=ignore_dir_paths, + mapping_rule=mapping_rule, ).run() self._ported_dirs.add(source_dir) @@ -556,24 +660,32 @@ def _port_directory(self, source_dir, mapping_rule=None): return musa_dir - def _convert_source_path(self, source): + def _convert_source_path(self, source, exclude_dirs=None): """ Convert a CUDA source path to its ported MUSA equivalent. 
Args: source: Original source file path (e.g., "csrc/kernel.cu") + exclude_dirs: Optional directory paths excluded from porting Returns: tuple: (converted_path, needs_porting) - converted_path: Path to ported file (e.g., "csrc_musa/kernel.mu") - needs_porting: True if the source directory needs porting """ - source_path = os.path.abspath(source) + if exclude_dirs is None: + exclude_dirs = self.get_exclude_dirs() + exclude_dirs = _normalize_exclude_dirs(exclude_dirs) + + source_path = os.path.realpath(os.path.abspath(source)) source_dir = os.path.dirname(source_path) source_file = os.path.basename(source_path) base_name, ext_name = os.path.splitext(source_file) ext_name_lower = ext_name.lower() + if self._is_excluded_dir(source_dir, exclude_dirs): + return source, False + # Port all source files that may contain CUDA references: # - .cu/.cuh: CUDA source/header files # - .cc/.cpp/.cxx: C++ files that may reference CUDA symbols @@ -609,6 +721,7 @@ def run(self): 5. Calls parent run() to perform actual compilation """ mapping_rule = self.get_mapping_rule() + exclude_dirs = _normalize_exclude_dirs(self.get_exclude_dirs()) for ext in self.extensions: new_sources = [] @@ -619,10 +732,12 @@ def run(self): ( new_source, needs_porting, - ) = self._convert_source_path(source) + ) = self._convert_source_path(source, exclude_dirs) new_sources.append(new_source) if needs_porting: - source_dir = os.path.dirname(os.path.abspath(source)) + source_dir = os.path.dirname( + os.path.realpath(os.path.abspath(source)) + ) dirs_to_port.add(source_dir) # Sort directories by depth (deepest first) to ensure proper porting order dirs_to_port = sorted( @@ -630,7 +745,7 @@ def run(self): ) # Port each unique directory for cuda_dir in dirs_to_port: - self._port_directory(cuda_dir, mapping_rule) + self._port_directory(cuda_dir, mapping_rule, exclude_dirs) # Update extension sources to point to ported files ext.sources = new_sources @@ -671,9 +786,13 @@ def run(self): pass if has_cuda_headers: - 
class TestSimplePortingExcludeDirs:
    """Exercise the directory-exclusion helpers that feed SimplePorting."""

    def test_get_env_exclude_dirs_uses_path_separator(self, monkeypatch, tmp_path):
        """Several os.pathsep-joined entries in TORCHADA_EXCLUDE_DIRS are all returned."""
        from torchada.utils.cpp_extension import _get_env_exclude_dirs

        expected_dirs = (tmp_path / "vendor", tmp_path / "third_party")
        joined = os.pathsep.join(str(entry) for entry in expected_dirs)
        monkeypatch.setenv("TORCHADA_EXCLUDE_DIRS", joined)

        resolved = _get_env_exclude_dirs()

        for entry in expected_dirs:
            assert os.path.realpath(str(entry)) in resolved

    def test_get_env_exclude_dirs_ignores_whitespace_entries(self, monkeypatch):
        """Entries made only of whitespace are dropped, not mapped to the cwd."""
        from torchada.utils.cpp_extension import _get_env_exclude_dirs

        blank_entries = " " + os.pathsep + "\t"
        monkeypatch.setenv("TORCHADA_EXCLUDE_DIRS", blank_entries)

        resolved = _get_env_exclude_dirs()

        assert resolved == []

    def test_is_path_in_dir_does_not_match_prefix_siblings(self, tmp_path):
        """A directory matches itself, but a sibling sharing its name prefix does not."""
        from torchada.utils.cpp_extension import _is_path_in_dir

        excluded_dir = str(tmp_path / "vendor")
        lookalike_dir = str(tmp_path / "vendor_extra")

        assert _is_path_in_dir(excluded_dir, excluded_dir)
        assert not _is_path_in_dir(lookalike_dir, excluded_dir)

    def test_same_real_path_matches_equivalent_paths(self, tmp_path):
        """Two spellings of one directory compare equal."""
        from torchada.utils.cpp_extension import _same_real_path

        include_dir = tmp_path / "include"
        include_dir.mkdir()
        roundabout = include_dir / ".." / "include"

        assert _same_real_path(str(include_dir), str(roundabout))

    def test_collect_simple_porting_ignore_dirs_includes_nested_dirs(self, tmp_path):
        """Both the excluded directory and its sub-directories are listed."""
        from torchada.utils.cpp_extension import _collect_simple_porting_ignore_dirs

        source_root = tmp_path / "csrc"
        excluded_dir = source_root / "vendor"
        nested_dir = excluded_dir / "cub"
        nested_dir.mkdir(parents=True)

        ignored = _collect_simple_porting_ignore_dirs(
            str(source_root), [str(excluded_dir)]
        )

        assert os.path.realpath(str(excluded_dir)) in ignored
        assert os.path.realpath(str(nested_dir)) in ignored

    def test_collect_simple_porting_ignore_dirs_ignores_outside_dirs(self, tmp_path):
        """Excludes located outside the ported source root are filtered out."""
        from torchada.utils.cpp_extension import _collect_simple_porting_ignore_dirs

        source_root = tmp_path / "csrc"
        unrelated_dir = tmp_path / "other_vendor"
        for directory in (source_root, unrelated_dir):
            directory.mkdir()

        ignored = _collect_simple_porting_ignore_dirs(
            str(source_root), [str(unrelated_dir)]
        )

        assert ignored == []

    def test_subclass_can_override_get_exclude_dirs(self, monkeypatch, tmp_path):
        """A BuildExtension subclass can append its own excludes to the env ones."""
        # Only meaningful on a MUSA platform; elsewhere the test returns early.
        if not torchada.is_musa_platform():
            return

        from torchada.utils.cpp_extension import _get_build_extension_class

        from_env = tmp_path / "env_vendor"
        from_code = tmp_path / "custom_vendor"
        monkeypatch.setenv("TORCHADA_EXCLUDE_DIRS", str(from_env))

        BuildExtBase = _get_build_extension_class()

        class CustomBuildExt(BuildExtBase):
            def get_exclude_dirs(self):
                inherited = super().get_exclude_dirs()
                return inherited + [str(from_code)]

        # __new__ bypasses __init__, which is not needed to call get_exclude_dirs.
        builder = CustomBuildExt.__new__(CustomBuildExt)
        combined = builder.get_exclude_dirs()

        assert os.path.realpath(str(from_env)) in combined
        assert str(from_code) in combined