-
Notifications
You must be signed in to change notification settings - Fork 10
feat(cpp_extension): add exclude dirs for MUSA source porting #64
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -95,6 +95,78 @@ def _is_musa_file(path: str) -> bool: | |
| return ext in [".cu", ".cuh", ".mu", ".muh"] | ||
|
|
||
|
|
||
def _normalize_exclude_dirs(exclude_dirs: Optional[List[str]]) -> List[str]:
    """Normalize exclude directory paths to the form used by SimplePorting.

    Each entry is converted to ``str``, stripped of surrounding whitespace,
    has a leading ``~`` expanded to the user's home directory, and is then
    resolved to an absolute, symlink-free path. Empty/``None`` entries are
    dropped, and duplicates (after resolution) are removed while keeping
    first-seen order.

    Args:
        exclude_dirs: Directory paths to normalize, or ``None``.

    Returns:
        List of resolved absolute directory paths (possibly empty).
    """
    if not exclude_dirs:
        return []

    normalized = []
    seen = set()
    for exclude_dir in exclude_dirs:
        exclude_dir = str(exclude_dir).strip() if exclude_dir else ""
        if not exclude_dir:
            continue
        # Expand "~" before abspath(): without this, "~/x" would be treated as
        # the relative path "<cwd>/~/x" and the exclusion would never match.
        expanded = os.path.expanduser(exclude_dir)
        real_dir = os.path.realpath(os.path.abspath(expanded))
        if real_dir not in seen:
            normalized.append(real_dir)
            seen.add(real_dir)
    return normalized
|
|
||
|
|
||
def _get_env_exclude_dirs() -> List[str]:
    """Return the exclude directories listed in ``TORCHADA_EXCLUDE_DIRS``.

    The variable is split on ``os.pathsep`` and the pieces are passed through
    ``_normalize_exclude_dirs``. An unset or empty variable yields ``[]``.
    """
    raw_value = os.environ.get("TORCHADA_EXCLUDE_DIRS", "")
    if raw_value:
        return _normalize_exclude_dirs(raw_value.split(os.pathsep))
    return []
|
|
||
|
|
||
def _is_path_in_dir(path: str, directory: str) -> bool:
    """Tell whether ``path`` is ``directory`` itself or located beneath it.

    Both arguments are made absolute, resolved through symlinks, and
    case-normalized before being compared. Returns ``False`` when
    ``os.path.commonpath`` raises ``ValueError`` for the pair.
    """

    def _canonical(raw):
        # Same normalization chain for both operands so they compare fairly.
        return os.path.normcase(os.path.realpath(os.path.abspath(raw)))

    candidate = _canonical(path)
    container = _canonical(directory)
    try:
        shared_prefix = os.path.commonpath([candidate, container])
    except ValueError:
        return False
    return shared_prefix == container
|
|
||
|
|
||
def _same_real_path(left: str, right: str) -> bool:
    """Tell whether ``left`` and ``right`` resolve to one filesystem location.

    Each path is made absolute, resolved through symlinks, and
    case-normalized; the two results are then compared for equality.
    """
    resolved_left, resolved_right = (
        os.path.normcase(os.path.realpath(os.path.abspath(entry)))
        for entry in (left, right)
    )
    return resolved_left == resolved_right
|
|
||
|
|
||
def _collect_simple_porting_ignore_dirs(source_dir: str, exclude_dirs: List[str]) -> List[str]:
    """
    Build the ignore_dir_paths list to hand to torch_musa SimplePorting.

    SimplePorting compares ignore dirs against the walked root by exact
    equality, so excluding only the top of a subtree would not skip nested
    files. For every excluded directory located inside ``source_dir``, the
    directory and each of its existing descendant directories are listed.

    Args:
        source_dir: Root directory that is going to be ported.
        exclude_dirs: Directory paths to exclude; entries outside
            ``source_dir`` are ignored.

    Returns:
        De-duplicated list of resolved directory paths, in discovery order.
    """
    source_dir = os.path.realpath(os.path.abspath(source_dir))
    collected = []
    already_listed = set()

    for excluded in _normalize_exclude_dirs(exclude_dirs):
        # Exclusions outside the tree being ported are irrelevant here.
        if not _is_path_in_dir(excluded, source_dir):
            continue

        if os.path.isdir(excluded):
            # os.walk yields the directory itself first, then its descendants.
            subtree = [walk_root for walk_root, _dirs, _files in os.walk(excluded)]
        else:
            subtree = [excluded]

        for entry in subtree:
            resolved = os.path.realpath(os.path.abspath(entry))
            if resolved in already_listed:
                continue
            collected.append(resolved)
            already_listed.add(resolved)

    return collected
|
|
||
|
|
||
| def _patch_simple_porting_load_replaced_mapping(musa_sp): | ||
| """ | ||
| Patch SimplePorting.load_replaced_mapping to suppress unwanted print output. | ||
|
|
@@ -505,12 +577,30 @@ def get_mapping_rule(self): | |
| """ | ||
| return _MAPPING_RULE.copy() | ||
|
|
||
def get_exclude_dirs(self):
    """
    Return the directories to leave out of CUDA->MUSA porting.

    The default implementation returns the entries read from the
    TORCHADA_EXCLUDE_DIRS environment variable, which are separated with
    os.pathsep (":" on Unix, ";" on Windows). Subclasses may override this
    to add project-specific exclusions.
    """
    env_exclusions = _get_env_exclude_dirs()
    return env_exclusions
|
|
||
def build_extensions(self):
    """Register CUDA suffixes as source extensions, then run the parent build."""
    # Let the compiler accept .cu / .cuh files as valid sources.
    cuda_suffixes = [".cu", ".cuh"]
    self.compiler.src_extensions += cuda_suffixes
    super().build_extensions()
|
|
||
| def _port_directory(self, source_dir, mapping_rule=None): | ||
def _is_excluded_dir(self, dir_path, exclude_dirs=None):
    """Report whether dir_path equals, or lies beneath, any excluded directory.

    When exclude_dirs is None the list comes from self.get_exclude_dirs().
    """
    candidates = self.get_exclude_dirs() if exclude_dirs is None else exclude_dirs
    for excluded in candidates:
        if _is_path_in_dir(dir_path, excluded):
            return True
    return False
|
|
||
| def _port_directory(self, source_dir, mapping_rule=None, exclude_dirs=None): | ||
| """ | ||
| Port a directory containing CUDA sources to MUSA. | ||
|
|
||
|
|
@@ -520,20 +610,34 @@ def _port_directory(self, source_dir, mapping_rule=None): | |
| Args: | ||
| source_dir: Path to directory containing CUDA sources | ||
| mapping_rule: Optional custom mapping rules (uses get_mapping_rule() if None) | ||
| exclude_dirs: Optional directory paths to pass to SimplePorting as | ||
| ignore_dir_paths when they are inside source_dir | ||
|
|
||
| Returns: | ||
| str: Path to the ported directory (source_dir + "_musa") | ||
| """ | ||
|
|
||
| if mapping_rule is None: | ||
| mapping_rule = self.get_mapping_rule() | ||
| if exclude_dirs is None: | ||
| exclude_dirs = self.get_exclude_dirs() | ||
| exclude_dirs = _normalize_exclude_dirs(exclude_dirs) | ||
|
|
||
| source_dir = os.path.realpath(os.path.abspath(source_dir)) | ||
| if self._is_excluded_dir(source_dir, exclude_dirs): | ||
| return source_dir | ||
|
|
||
| source_dir = os.path.abspath(source_dir) | ||
| musa_dir = source_dir + "_musa" | ||
|
|
||
| if source_dir not in self._ported_dirs: | ||
| musa_sp.LOGGER.setLevel(logging.ERROR) | ||
| ignore_dir_paths = _collect_simple_porting_ignore_dirs( | ||
| source_dir, exclude_dirs | ||
| ) | ||
| musa_sp.SimplePorting( | ||
| cuda_dir_path=source_dir, mapping_rule=mapping_rule | ||
| cuda_dir_path=source_dir, | ||
| ignore_dir_paths=ignore_dir_paths, | ||
| mapping_rule=mapping_rule, | ||
| ).run() | ||
| self._ported_dirs.add(source_dir) | ||
|
|
||
|
|
@@ -556,24 +660,32 @@ def _port_directory(self, source_dir, mapping_rule=None): | |
|
|
||
| return musa_dir | ||
|
|
||
| def _convert_source_path(self, source): | ||
| def _convert_source_path(self, source, exclude_dirs=None): | ||
| """ | ||
| Convert a CUDA source path to its ported MUSA equivalent. | ||
|
|
||
| Args: | ||
| source: Original source file path (e.g., "csrc/kernel.cu") | ||
| exclude_dirs: Optional directory paths excluded from porting | ||
|
|
||
| Returns: | ||
| tuple: (converted_path, needs_porting) | ||
| - converted_path: Path to ported file (e.g., "csrc_musa/kernel.mu") | ||
| - needs_porting: True if the source directory needs porting | ||
| """ | ||
| source_path = os.path.abspath(source) | ||
| if exclude_dirs is None: | ||
| exclude_dirs = self.get_exclude_dirs() | ||
| exclude_dirs = _normalize_exclude_dirs(exclude_dirs) | ||
|
|
||
| source_path = os.path.realpath(os.path.abspath(source)) | ||
| source_dir = os.path.dirname(source_path) | ||
| source_file = os.path.basename(source_path) | ||
| base_name, ext_name = os.path.splitext(source_file) | ||
| ext_name_lower = ext_name.lower() | ||
|
|
||
| if self._is_excluded_dir(source_dir, exclude_dirs): | ||
| return source, False | ||
|
|
||
| # Port all source files that may contain CUDA references: | ||
| # - .cu/.cuh: CUDA source/header files | ||
| # - .cc/.cpp/.cxx: C++ files that may reference CUDA symbols | ||
|
|
@@ -609,6 +721,7 @@ def run(self): | |
| 5. Calls parent run() to perform actual compilation | ||
| """ | ||
| mapping_rule = self.get_mapping_rule() | ||
| exclude_dirs = _normalize_exclude_dirs(self.get_exclude_dirs()) | ||
|
|
||
| for ext in self.extensions: | ||
| new_sources = [] | ||
|
|
@@ -619,18 +732,20 @@ def run(self): | |
| ( | ||
| new_source, | ||
| needs_porting, | ||
| ) = self._convert_source_path(source) | ||
| ) = self._convert_source_path(source, exclude_dirs) | ||
| new_sources.append(new_source) | ||
| if needs_porting: | ||
| source_dir = os.path.dirname(os.path.abspath(source)) | ||
| source_dir = os.path.dirname( | ||
| os.path.realpath(os.path.abspath(source)) | ||
| ) | ||
| dirs_to_port.add(source_dir) | ||
| # Sort directories by depth (deepest first) to ensure proper porting order | ||
| dirs_to_port = sorted( | ||
| dirs_to_port, key=lambda p: p.count("/"), reverse=True | ||
| ) | ||
| # Port each unique directory | ||
| for cuda_dir in dirs_to_port: | ||
| self._port_directory(cuda_dir, mapping_rule) | ||
| self._port_directory(cuda_dir, mapping_rule, exclude_dirs) | ||
|
|
||
| # Update extension sources to point to ported files | ||
| ext.sources = new_sources | ||
|
|
@@ -671,9 +786,13 @@ def run(self): | |
| pass | ||
|
|
||
| if has_cuda_headers: | ||
| ported_dir = self._port_directory(inc_dir_abs, mapping_rule) | ||
| ported_dir = self._port_directory( | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If Severity: low ⏳ Generating Fix in Augment link... 🤖 Was this useful? React with 👍 or 👎, or 🚀 if it prevented an incident/outage. |
||
| inc_dir_abs, mapping_rule, exclude_dirs | ||
| ) | ||
| # Add ported dir first so ported headers take precedence | ||
| if os.path.isdir(ported_dir): | ||
| if os.path.isdir(ported_dir) and not _same_real_path( | ||
| ported_dir, inc_dir_abs | ||
| ): | ||
| new_include_dirs.append(ported_dir) | ||
| new_include_dirs.append(inc_dir_abs) | ||
| ext.include_dirs = new_include_dirs | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`_collect_simple_porting_ignore_dirs()` returns `realpath`-canonicalized entries, but `_port_directory()` passes `cuda_dir_path=source_dir` where `source_dir` is only `abspath`-canonicalized. If `source_dir` is reached via a symlink, SimplePorting’s exact-path ignore matching may not exclude the intended subtree. Severity: medium
⏳ Generating Fix in Augment link...
🤖 Was this useful? React with 👍 or 👎, or 🚀 if it prevented an incident/outage.