diff --git a/code_to_optimize/sample_code.py b/code_to_optimize/sample_code.py
index d356ce807..704bda3cb 100644
--- a/code_to_optimize/sample_code.py
+++ b/code_to_optimize/sample_code.py
@@ -1,12 +1,24 @@
 from functools import partial
+from typing import Any
 
 import jax.numpy as jnp
 import numpy as np
 import tensorflow as tf
 import torch
 from jax import lax
+from torch import nn
+class AlexNet(nn.Module):
+    def __init__(self, num_classes=10, *args: Any, **kwargs: Any):
+        super().__init__(*args, **kwargs)
+        self.num_classes = num_classes
+        self.layer = nn.Linear(5,10)
+
+    def forward(self, x):
+        x = self.layer(x)
+        return x
+
 
 def tridiagonal_solve(a: np.ndarray, b: np.ndarray, c: np.ndarray, d: np.ndarray) -> np.ndarray:
     n = len(b)
diff --git a/code_to_optimize/tests/pytest/test_alexnet.py b/code_to_optimize/tests/pytest/test_alexnet.py
new file mode 100644
index 000000000..1a0c9b6e6
--- /dev/null
+++ b/code_to_optimize/tests/pytest/test_alexnet.py
@@ -0,0 +1,63 @@
+import torch
+
+from code_to_optimize.sample_code import AlexNet
+
+def test_models():
+    torch.manual_seed(42)
+    model = AlexNet(num_classes=10)
+    input_data = torch.randn(2,5)
+    assert torch.allclose(model(input_data), torch.Tensor([
+        [0.2655223608, 0.3765228391, -0.4080065191, 0.3314782381,
+         0.6830080152, 0.5442206264, 0.1187968627, 0.2742837071,
+         0.3680166304, 0.3558489084],
+        [-0.9252133369, -0.8182569146, -0.5546661019, 0.6546985507,
+         -0.1227166206, -0.0484373420, -0.5192810893, -0.4771555662,
+         0.2874411345, -0.4801278412]]))
+
+def test_models1():
+    torch.manual_seed(42)
+    model = AlexNet(num_classes=10)
+    input_data = torch.randn(2,5)
+    assert torch.allclose(model(input_data), torch.Tensor([
+        [0.2655223608, 0.3765228391, -0.4080065191, 0.3314782381,
+         0.6830080152, 0.5442206264, 0.1187968627, 0.2742837071,
+         0.3680166304, 0.3558489084],
+        [-0.9252133369, -0.8182569146, -0.5546661019, 0.6546985507,
+         -0.1227166206, -0.0484373420, -0.5192810893, -0.4771555662,
+         0.2874411345, -0.4801278412]]))
+
+def test_models2():
+    torch.manual_seed(42)
+    model = AlexNet(num_classes=10)
+    input_data = torch.randn(2,5)
+    assert torch.allclose(model(input_data), torch.Tensor([
+        [0.2655223608, 0.3765228391, -0.4080065191, 0.3314782381,
+         0.6830080152, 0.5442206264, 0.1187968627, 0.2742837071,
+         0.3680166304, 0.3558489084],
+        [-0.9252133369, -0.8182569146, -0.5546661019, 0.6546985507,
+         -0.1227166206, -0.0484373420, -0.5192810893, -0.4771555662,
+         0.2874411345, -0.4801278412]]))
+
+def test_models3():
+    torch.manual_seed(42)
+    model = AlexNet(num_classes=10)
+    input_data = torch.randn(2,5)
+    assert torch.allclose(model(input_data), torch.Tensor([
+        [0.2655223608, 0.3765228391, -0.4080065191, 0.3314782381,
+         0.6830080152, 0.5442206264, 0.1187968627, 0.2742837071,
+         0.3680166304, 0.3558489084],
+        [-0.9252133369, -0.8182569146, -0.5546661019, 0.6546985507,
+         -0.1227166206, -0.0484373420, -0.5192810893, -0.4771555662,
+         0.2874411345, -0.4801278412]]))
+
+def test_models4():
+    torch.manual_seed(42)
+    model = AlexNet(num_classes=10)
+    input_data = torch.randn(2,5)
+    assert torch.allclose(model(input_data), torch.Tensor([
+        [0.2655223608, 0.3765228391, -0.4080065191, 0.3314782381,
+         0.6830080152, 0.5442206264, 0.1187968627, 0.2742837071,
+         0.3680166304, 0.3558489084],
+        [-0.9252133369, -0.8182569146, -0.5546661019, 0.6546985507,
+         -0.1227166206, -0.0484373420, -0.5192810893, -0.4771555662,
+         0.2874411345, -0.4801278412]]))
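Note that the tests above never name `forward` — every assertion goes through `model(input_data)`. That works because `torch.nn.Module.__call__` dispatches to `forward`, which is exactly what the discovery and instrumentation changes below have to account for. A minimal sketch of that dispatch (simplified; the real `nn.Module.__call__` also runs registered hooks):

```python
# Minimal sketch of nn.Module's call dispatch (simplified; the real
# torch.nn.Module.__call__ also runs registered pre/post hooks).
class Module:
    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


class Net(Module):
    def forward(self, x):
        return x * 2


model = Net()
assert model(3) == 6  # "forward" never appears at the call site
```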
diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py
index 4366468d0..d86a695ab 100644
--- a/codeflash/code_utils/instrument_existing_tests.py
+++ b/codeflash/code_utils/instrument_existing_tests.py
@@ -79,9 +79,53 @@ def __init__(
         self.only_function_name = function.function_name
         self.module_path = module_path
         self.call_positions = call_positions
+        # Track instance variables when optimizing forward methods (PyTorch nn.Module pattern)
+        self.instance_variable_names: set[str] = set()
         if len(function.parents) == 1 and function.parents[0].type == "ClassDef":
             self.class_name = function.top_level_parent_name
 
+    def collect_instance_variables(self, func_node: ast.FunctionDef) -> None:
+        """Collect variable names that are instances of the target class.
+
+        This handles the PyTorch nn.Module pattern where:
+            model = AlexNet(...)
+            model(input_data)  # calls __call__ which invokes forward()
+
+        When optimizing ClassName.forward, we need to track variables assigned
+        from ClassName(...) so we can instrument calls to those variables.
+        """
+        if self.class_name is None or self.only_function_name != "forward":
+            return
+
+        class_name = self.class_name
+        instance_vars = self.instance_variable_names
+
+        # Traverse statement bodies with an explicit worklist instead of walking the entire tree
+        nodes_to_check = list(func_node.body)
+        while nodes_to_check:
+            node = nodes_to_check.pop()
+
+            # Look for assignments like: model = ClassName(...)
+            if isinstance(node, ast.Assign):
+                value = node.value
+                if isinstance(value, ast.Call):
+                    func = value.func
+                    if isinstance(func, ast.Name) and func.id == class_name:
+                        for target in node.targets:
+                            if isinstance(target, ast.Name):
+                                instance_vars.add(target.id)
+
+            # Add nested statements to check
+            if hasattr(node, "body"):
+                nodes_to_check.extend(node.body)
+            if hasattr(node, "orelse"):
+                nodes_to_check.extend(node.orelse)
+            if hasattr(node, "finalbody"):
+                nodes_to_check.extend(node.finalbody)
+            if hasattr(node, "handlers"):
+                for handler in node.handlers:
+                    nodes_to_check.extend(handler.body)
+
     def find_and_update_line_node(
         self, test_node: ast.stmt, node_name: str, index: str, test_class_name: str | None = None
     ) -> Iterable[ast.stmt] | None:
@@ -122,7 +166,16 @@ def iter_ast_calls(node):
         codeflash_con = ast.Name(id="codeflash_con", ctx=ast.Load())
 
         for node in iter_ast_calls(test_node):
-            if not node_in_call_position(node, self.call_positions):
+            # Check if this call is at a known position OR is an instance variable call
+            # for forward methods (PyTorch nn.Module pattern)
+            is_at_call_position = node_in_call_position(node, self.call_positions)
+            is_instance_call = (
+                isinstance(node.func, ast.Name)
+                and node.func.id in self.instance_variable_names
+                and self.only_function_name == "forward"
+            )
+
+            if not is_at_call_position and not is_instance_call:
                 continue
             call_node = node
@@ -134,7 +187,8 @@ def iter_ast_calls(node):
                 function_name = node_func.id
 
                 # Check if this is the function we want to instrument
-                if function_name != fn_obj.function_name:
+                # Also match instance variable calls for forward methods
+                if function_name != fn_obj.function_name and function_name not in self.instance_variable_names:
                     continue
 
                 if fn_obj.is_async:
@@ -325,6 +379,9 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
 
     def visit_FunctionDef(self, node: ast.FunctionDef, test_class_name: str | None = None) -> ast.FunctionDef:
         if node.name.startswith("test_"):
+            # Collect instance variables for forward method instrumentation (PyTorch pattern)
+            self.collect_instance_variables(node)
+
             did_update = False
             i = len(node.body) - 1
             while i >= 0:
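`collect_instance_variables` above deliberately uses an explicit worklist rather than `ast.walk`, descending only into statement bodies (`body`, `orelse`, `finalbody`, exception handlers) instead of every expression subtree. A self-contained sketch of the same traversal (the helper name and sample source are illustrative only):

```python
# Standalone sketch of the worklist traversal used by collect_instance_variables:
# find variables assigned from ClassName(...) inside a function body.
import ast

SOURCE = """
def test_module():
    model = AlexNet(num_classes=10)
    out = model(data)
"""


def collect_instances(func_node: ast.FunctionDef, class_name: str) -> set[str]:
    found: set[str] = set()
    nodes_to_check = list(func_node.body)
    while nodes_to_check:
        node = nodes_to_check.pop()
        # Record targets of assignments like: model = ClassName(...)
        if isinstance(node, ast.Assign) and isinstance(node.value, ast.Call):
            func = node.value.func
            if isinstance(func, ast.Name) and func.id == class_name:
                found.update(t.id for t in node.targets if isinstance(t, ast.Name))
        # Descend only into statement bodies (if/for/while/with/try)
        for field in ("body", "orelse", "finalbody"):
            nodes_to_check.extend(getattr(node, field, []))
        for handler in getattr(node, "handlers", []):
            nodes_to_check.extend(handler.body)
    return found


tree = ast.parse(SOURCE)
assert collect_instances(tree.body[0], "AlexNet") == {"model"}
```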
diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py
index d1ef28a8d..b4a4b007c 100644
--- a/codeflash/discovery/discover_unit_tests.py
+++ b/codeflash/discovery/discover_unit_tests.py
@@ -265,27 +265,22 @@ def visit_Import(self, node: ast.Import) -> None:
 
     def visit_Assign(self, node: ast.Assign) -> None:
         """Track variable assignments, especially class instantiations."""
-        if self.found_any_target_function:
-            return
-
-        # Check if the assignment is a class instantiation
+        # Always track instance assignments, even if we've found a target function
+        # This is needed for the PyTorch nn.Module pattern where model(x) calls forward(x)
        value = node.value
         if isinstance(value, ast.Call) and isinstance(value.func, ast.Name):
             class_name = value.func.id
             if class_name in self.imported_modules:
                 # Map the variable to the actual class name (handling aliases)
                 original_class = self.alias_mapping.get(class_name, class_name)
-                # Use list comprehension for direct assignment to instance_mapping, reducing loop overhead
                 targets = node.targets
-                instance_mapping = self.instance_mapping
-                # since ast.Name nodes are heavily used, avoid local lookup for isinstance
-                # and reuse locals for faster attribute access
                 for target in targets:
                     if isinstance(target, ast.Name):
-                        instance_mapping[target.id] = original_class
+                        self.instance_mapping[target.id] = original_class
 
-        # Continue visiting child nodes
-        self.generic_visit(node)
+        # Continue visiting child nodes if we haven't found a target function yet
+        if not self.found_any_target_function:
+            self.generic_visit(node)
 
     def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
         """Handle 'from module import name' statements."""
@@ -405,7 +400,7 @@ def visit_Attribute(self, node: ast.Attribute) -> None:
         ast.NodeVisitor.generic_visit(self, node)
 
     def visit_Call(self, node: ast.Call) -> None:
-        """Handle function calls, particularly __import__."""
+        """Handle function calls, particularly __import__ and instance calls for nn.Module.forward."""
         if self.found_any_target_function:
             return
 
@@ -415,6 +410,19 @@
             # When __import__ is used, any target function could potentially be imported
             # Be conservative and assume it might import target functions
 
+        # Check if this is a call on an instance variable (PyTorch nn.Module pattern)
+        # When model = AlexNet(...) and we call model(input_data), this invokes forward()
+        if isinstance(node.func, ast.Name):
+            instance_name = node.func.id
+            if instance_name in self.instance_mapping:
+                class_name = self.instance_mapping[instance_name]
+                # Check if ClassName.forward is in our target functions
+                roots_possible = self._dot_methods.get("forward")
+                if roots_possible and class_name in roots_possible:
+                    self.found_any_target_function = True
+                    self.found_qualified_name = self._class_method_to_target[(class_name, "forward")]
+                    return
+
         self.generic_visit(node)
 
     def visit_Name(self, node: ast.Name) -> None:
@@ -495,6 +503,68 @@ def _fast_generic_visit(self, node: ast.AST) -> None:
                 append((value._fields, value))
 
 
+class InstanceMappingExtractor(ast.NodeVisitor):
+    """Simple visitor to extract instance-to-class mappings from a file.
+
+    This is needed for detecting PyTorch nn.Module.forward calls where model(x) calls forward(x).
+ """ + + def __init__(self) -> None: + self.imported_modules: set[str] = set() + self.alias_mapping: dict[str, str] = {} + self.instance_mapping: dict[str, str] = {} + + def visit_Import(self, node: ast.Import) -> None: + for alias in node.names: + module_name = alias.asname if alias.asname else alias.name + self.imported_modules.add(module_name) + self.generic_visit(node) + + def visit_ImportFrom(self, node: ast.ImportFrom) -> None: + if not node.module: + return + for alias in node.names: + if alias.name == "*": + continue + imported_name = alias.asname if alias.asname else alias.name + self.imported_modules.add(imported_name) + if alias.asname: + self.alias_mapping[imported_name] = alias.name + self.generic_visit(node) + + def visit_Assign(self, node: ast.Assign) -> None: + value = node.value + if isinstance(value, ast.Call) and isinstance(value.func, ast.Name): + class_name = value.func.id + if class_name in self.imported_modules: + original_class = self.alias_mapping.get(class_name, class_name) + for target in node.targets: + if isinstance(target, ast.Name): + self.instance_mapping[target.id] = original_class + self.generic_visit(node) + + +def extract_instance_mapping(test_file_path: Path) -> dict[str, str]: + """Extract instance-to-class mappings from a test file. + + Args: + test_file_path: Path to the test file. + + Returns: + Dictionary mapping instance variable names to class names. + + """ + try: + with test_file_path.open("r", encoding="utf-8") as f: + source_code = f.read() + tree = ast.parse(source_code, filename=str(test_file_path)) + extractor = InstanceMappingExtractor() + extractor.visit(tree) + return extractor.instance_mapping + except (SyntaxError, FileNotFoundError): + return {} + + def analyze_imports_in_test_file(test_file_path: Path | str, target_functions: set[str]) -> bool: """Analyze a test file to see if it imports any of the target functions.""" try: @@ -879,6 +949,10 @@ def process_test_files( top_level_functions = {name.name: name for name in all_names_top if name.type == "function"} top_level_classes = {name.name: name for name in all_names_top if name.type == "class"} + # Get instance-to-class mappings for PyTorch nn.Module.forward detection + # When model = AlexNet(...) 
diff --git a/codeflash/verification/parse_test_output.py b/codeflash/verification/parse_test_output.py
index c80a287e5..4c2c809eb 100644
--- a/codeflash/verification/parse_test_output.py
+++ b/codeflash/verification/parse_test_output.py
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import contextlib
 import os
 import re
 import sqlite3
@@ -22,6 +21,9 @@
 )
 from codeflash.discovery.discover_unit_tests import discover_parameters_unittest
 from codeflash.languages import is_javascript
+
+# Import Jest-specific parsing from the JavaScript language module
+from codeflash.languages.javascript.parse import parse_jest_test_xml as _parse_jest_test_xml
 from codeflash.models.models import (
     ConcurrencyMetrics,
     FunctionTestInvocation,
@@ -32,10 +34,6 @@
 )
 from codeflash.verification.coverage_utils import CoverageUtils, JestCoverageUtils
 
-# Import Jest-specific parsing from the JavaScript language module
-from codeflash.languages.javascript.parse import jest_end_pattern, jest_start_pattern
-from codeflash.languages.javascript.parse import parse_jest_test_xml as _parse_jest_test_xml
-
 if TYPE_CHECKING:
     import subprocess
diff --git a/tests/test_instrument_tests.py b/tests/test_instrument_tests.py
index a8cd75b70..c5a6ab19f 100644
--- a/tests/test_instrument_tests.py
+++ b/tests/test_instrument_tests.py
@@ -3306,3 +3306,77 @@ def test_sleepfunc_sequence_short(self, n, expected_total_sleep_time):
     finally:
         test_path.unlink(missing_ok=True)
+
+
+def test_pytorch_forward_method_instrumentation() -> None:
+    """Test instrumentation of PyTorch nn.Module forward method when called via instance().
+
+    This tests the pattern:
+        model = MyModule(...)
+        model(input_data)  # calls __call__ which invokes forward()
+
+    The instrumentation should wrap the instance call even though the position
+    recorded is where the class is referenced, not where the instance is called.
+    """
+    code = """
+class MockModule:
+    def __init__(self, num_classes=10):
+        self.num_classes = num_classes
+
+    def forward(self, x):
+        return x * 2
+
+def test_module():
+    model = MockModule(num_classes=10)
+    input_data = 5
+    result = model(input_data)
+    assert result == 10
+"""
+    code_path = Path(tempfile.gettempdir()) / "mock_module.py"
+    test_path = Path(tempfile.gettempdir()) / "test_mock_module.py"
+
+    try:
+        with code_path.open("w") as f:
+            f.write(code)
+
+        with test_path.open("w") as f:
+            f.write(code)
+
+        func = FunctionToOptimize(
+            function_name="forward",
+            parents=[FunctionParent("MockModule", "ClassDef")],
+            file_path=code_path,
+            starting_line=6,
+            ending_line=7,
+            is_async=False,
+        )
+
+        # Position where MockModule is called (line 10 in 1-indexed: model = MockModule(...))
+        call_positions = [CodePosition(line_no=10, col_no=12)]
+
+        success, new_test = inject_profiling_into_existing_test(
+            test_path,
+            call_positions,
+            func,
+            test_path.parent,
+            mode=TestingMode.PERFORMANCE,
+        )
+
+        assert success
+        assert new_test is not None
+
+        # The key assertion: model(input_data) should be wrapped with codeflash_wrap
+        # The wrap should be around 'model', passing the instance as the callable
+        assert "codeflash_wrap(model," in new_test, (
+            "Expected model(input_data) to be wrapped as codeflash_wrap(model, ..., input_data), "
+            f"but got:\n{new_test}"
+        )
+
+        # Verify the function name in the wrap is the qualified name (MockModule.forward)
+        assert "MockModule.forward" in new_test, (
+            f"Expected 'MockModule.forward' to appear in the instrumented code, but got:\n{new_test}"
+        )
+
+    finally:
+        code_path.unlink(missing_ok=True)
+        test_path.unlink(missing_ok=True)
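The two assertions pin down the instrumented shape: `model(input_data)` becomes `codeflash_wrap(model, ..., input_data)` with `"MockModule.forward"` recorded as the qualified name. The instance itself is passed as the callable, so the wrapper times `__call__` and therefore `forward`. A toy stand-in showing why passing the bound instance suffices (names hypothetical; the real `codeflash_wrap` takes more bookkeeping arguments and persists its measurements):

```python
# Toy stand-in for codeflash_wrap (hypothetical; the real helper takes many
# more arguments and writes timings to a database). Passing the instance as
# the callable is enough: invoking it goes through __call__, which
# dispatches to forward().
import time


def toy_codeflash_wrap(wrapped, qualified_name, *args, **kwargs):
    start = time.perf_counter_ns()
    result = wrapped(*args, **kwargs)
    print(f"{qualified_name}: {time.perf_counter_ns() - start} ns")
    return result


class MockModule:
    def __call__(self, x):
        return self.forward(x)

    def forward(self, x):
        return x * 2


model = MockModule()
assert toy_codeflash_wrap(model, "MockModule.forward", 5) == 10
```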