# Reviewed form of the patch to codeflash/languages/java/parser.py, laid out as
# plain Python.  Only the definitions the patch adds or rewrites appear here; the
# module's existing imports and the remaining JavaAnalyzer members are context
# lines in the patch (unchanged) and are not repeated.
import re
from bisect import bisect_right

# Tokens that open a region whose content must not be read as code.
_COMMENT_OR_LITERAL_OPEN_RE = re.compile(r'//|/\*|"""|"|\'')
_NON_NEWLINE_RE = re.compile(r"[^\n]")
_NEWLINE_RE = re.compile(r"\n")
_BRACE_RE = re.compile(r"[{}]")
# ``(?<!\.)`` keeps class literals such as ``Foo.class`` from being matched.
_DECLARATION_RE = re.compile(r"(?<!\.)\b(class|interface|enum)\s+([A-Za-z_$][\w$]*)")
_ANNOTATION_RE = re.compile(r"@[\w.$]+(?:\s*\([^()]*\))?")
# A run of modifiers/annotations that ends exactly where the search window ends.
_LEADING_MODIFIERS_RE = re.compile(
    r"(?:(?:\b(?:non-sealed|public|protected|private|abstract|static|final|strictfp|sealed)\b"
    r"|@[\w.$]+(?:\s*\([^()]*\))?)\s*)+\Z"
)
_HEADER_TOKEN_RE = re.compile(r"[<>]|\b(?:extends|implements|permits)\b")


class _BodyNodeLike:
    """Stand-in for a tree-sitter body node.

    Exposes only ``start_byte`` and ``end_byte``.
    """

    __slots__ = ("start_byte", "end_byte")

    def __init__(self, start_byte: int, end_byte: int) -> None:
        self.start_byte = start_byte
        self.end_byte = end_byte


class NodeLike:
    """Stand-in for a tree-sitter declaration node.

    Exposes ``start_byte``, ``end_byte`` and ``child_by_field_name``; the latter
    only knows the ``"body"`` field.
    """

    __slots__ = ("_body", "start_byte", "end_byte")

    def __init__(self, body_start: int, body_end: int, start_byte: int, end_byte: int) -> None:
        self._body = _BodyNodeLike(body_start, body_end)
        self.start_byte = start_byte
        self.end_byte = end_byte

    def child_by_field_name(self, field: str):
        """Return the body stand-in for ``"body"`` and None for any other field."""
        if field == "body":
            return self._body
        return None


class _UnpicklableMarker:
    """Object whose ``__reduce__`` raises TypeError, so pickling it always fails."""

    def __reduce__(self):
        raise TypeError("cannot pickle '_UnpicklableMarker' object")


def _literal_end(src: str, start: int, delimiter: str) -> int:
    """Return the index just past the literal that ``delimiter`` opens at ``start``.

    A backslash skips the character after it.  One-character delimiters (string
    and char literals) also stop at a newline so an unterminated literal cannot
    swallow the rest of the file; the three-quote text block may span lines.
    """
    single_line = len(delimiter) == 1
    limit = len(src)
    pos = start + len(delimiter)
    while pos < limit:
        if src[pos] == "\\":
            pos += 2
        elif src.startswith(delimiter, pos):
            return pos + len(delimiter)
        elif single_line and src[pos] == "\n":
            return pos
        else:
            pos += 1
    return limit


def _mask_non_code(src: str) -> tuple[str, dict[int, int]]:
    """Blank out comments and literals while keeping every character offset.

    Returns:
        A pair ``(code, javadoc_by_end)``.  ``code`` has the same length as
        ``src`` with every non-newline character of a comment, string, char
        literal or text block replaced by a space.  ``javadoc_by_end`` maps the
        index just past each closed ``/** ... */`` comment to its start index.

    """
    pieces: list[str] = []
    javadoc_by_end: dict[int, int] = {}
    size = len(src)
    pos = 0
    while True:
        opener = _COMMENT_OR_LITERAL_OPEN_RE.search(src, pos)
        if opener is None:
            break
        begin = opener.start()
        token = opener.group()
        if token == "//":
            newline = src.find("\n", begin)
            end = size if newline == -1 else newline
        elif token == "/*":
            close = src.find("*/", begin + 2)
            end = size if close == -1 else close + 2
            # ``/**/`` is an empty block comment, not a Javadoc comment.
            if close != -1 and src.startswith("/**", begin) and end - begin > 4:
                javadoc_by_end[end] = begin
        else:
            end = _literal_end(src, begin, token)
        pieces.append(src[pos:begin])
        pieces.append(_NON_NEWLINE_RE.sub(" ", src[begin:end]))
        pos = end
    pieces.append(src[pos:])
    return "".join(pieces), javadoc_by_end


def _matching_brace(code: str, open_idx: int) -> int:
    """Return the index of the ``}`` closing the ``{`` at ``open_idx``, or -1.

    ``code`` must already be masked, so every brace seen here is real code.
    """
    depth = 0
    for brace in _BRACE_RE.finditer(code, open_idx):
        if brace.group() == "{":
            depth += 1
        else:
            depth -= 1
            if depth == 0:
                return brace.start()
    return -1


def _header_clauses(header: str) -> dict[str, str]:
    """Split the text between a type name and its ``{`` into keyword clauses.

    Keywords nested inside angle brackets (type-parameter bounds such as
    ``<T extends Comparable<T>>``) are ignored.  Keys are the keywords found at
    nesting depth zero; values are the stripped text that follows each one.
    """
    clauses: dict[str, str] = {}
    depth = 0
    keyword = None
    begin = 0
    for token in _HEADER_TOKEN_RE.finditer(header):
        text = token.group()
        if text == "<":
            depth += 1
        elif text == ">":
            depth = max(depth - 1, 0)
        elif depth == 0:
            if keyword is not None:
                clauses[keyword] = header[begin : token.start()].strip()
            keyword = text
            begin = token.end()
    if keyword is not None:
        clauses[keyword] = header[begin:].strip()
    return clauses


def _split_top_level(type_list: str) -> list[str]:
    """Split a comma-separated type list, ignoring commas inside ``<...>``."""
    parts: list[str] = []
    depth = 0
    begin = 0
    for idx, ch in enumerate(type_list):
        if ch == "<":
            depth += 1
        elif ch == ">":
            depth -= 1
        elif ch == "," and depth == 0:
            parts.append(type_list[begin:idx])
            begin = idx + 1
    parts.append(type_list[begin:])
    return [part.strip() for part in parts if part.strip()]


class JavaAnalyzer:
    """Java code analysis.

    Only ``__init__`` and ``find_classes`` (with its helper) are shown; see the
    note at the top of this block.
    """

    def __init__(self) -> None:
        """Initialize the Java analyzer."""
        # Keep the ``None`` placeholder the pre-patch code used instead of a
        # sentinel object: a fresh analyzer stays picklable, as it was before.
        # NOTE(review): the ``parser`` property body is outside this patch;
        # presumably it builds the tree-sitter Parser lazily when this is None,
        # which a non-None sentinel would defeat -- confirm.
        self._parser: Parser | None = None

    def find_classes(self, source: str) -> "list[JavaClassNode]":
        """Find class, interface and enum declarations in Java source text.

        Top-level and nested declarations are both reported, in source order.

        Args:
            source: Java source code.

        Returns:
            List of JavaClassNode objects.

        """
        if not source:
            return []
        return [JavaClassNode(**fields) for fields in self._describe_classes(source)]

    @staticmethod
    def _describe_classes(source: str) -> list[dict[str, object]]:
        """Compute the JavaClassNode keyword arguments for each declaration.

        Declarations are located by scanning text in which comments and
        literals have been blanked, so keywords and braces inside them are
        ignored.  Lines and columns are 1-based character positions; the
        ``node`` stand-in carries UTF-8 byte offsets.  A declaration starts at
        its first leading modifier or annotation and ends at its closing brace.
        """
        code, javadoc_by_end = _mask_non_code(source)
        line_starts = [0, *(newline.end() for newline in _NEWLINE_RE.finditer(source))]
        ascii_only = source.isascii()

        def line_col(pos: int) -> tuple[int, int]:
            row = bisect_right(line_starts, pos) - 1
            return row + 1, pos - line_starts[row] + 1

        def byte_offset(pos: int) -> int:
            # Character and byte offsets coincide for pure-ASCII sources.
            return pos if ascii_only else len(source[:pos].encode("utf8"))

        described: list[dict[str, object]] = []
        for decl in _DECLARATION_RE.finditer(code):
            open_idx = code.find("{", decl.end())
            if open_idx == -1:
                continue
            header = code[decl.end() : open_idx]
            if ";" in header:
                # A statement ends before any brace: not a type declaration.
                continue
            close_idx = _matching_brace(code, open_idx)
            if close_idx == -1:
                continue

            # Modifiers sit in front of the keyword; look no further back than
            # the previous statement terminator.
            floor = code.rfind(";", 0, decl.start()) + 1
            run = _LEADING_MODIFIERS_RE.search(code, floor, decl.start())
            if run is None:
                start_idx = decl.start()
                modifiers: set[str] = set()
            else:
                start_idx = run.start()
                modifiers = set(_ANNOTATION_RE.sub(" ", run.group()).split())

            clauses = _header_clauses(header)
            supertypes = _split_top_level(clauses.get("extends", ""))
            interfaces = _split_top_level(clauses.get("implements", ""))

            # A Javadoc comment counts only when nothing but whitespace lies
            # between its end and the start of the declaration.
            probe = start_idx
            while probe > 0 and source[probe - 1].isspace():
                probe -= 1
            javadoc_idx = javadoc_by_end.get(probe)

            start_line, start_col = line_col(start_idx)
            end_line, end_col = line_col(close_idx)
            end_byte = byte_offset(close_idx + 1)
            described.append(
                {
                    "name": decl.group(2),
                    "node": NodeLike(byte_offset(open_idx), end_byte, byte_offset(start_idx), end_byte),
                    "start_line": start_line,
                    "end_line": end_line,
                    "start_col": start_col,
                    "end_col": end_col,
                    "is_public": "public" in modifiers,
                    "is_abstract": "abstract" in modifiers,
                    "is_final": "final" in modifiers,
                    "is_static": "static" in modifiers,
                    "extends": supertypes[0] if supertypes else None,
                    "implements": interfaces,
                    "source_text": source[start_idx : close_idx + 1],
                    "javadoc_start_line": None if javadoc_idx is None else line_col(javadoc_idx)[0],
                }
            )
        return described