Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 59 additions & 2 deletions codeflash/code_utils/instrument_existing_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,53 @@ def __init__(
self.only_function_name = function.function_name
self.module_path = module_path
self.call_positions = call_positions
# Track instance variables when optimizing forward methods (PyTorch nn.Module pattern)
self.instance_variable_names: set[str] = set()
if len(function.parents) == 1 and function.parents[0].type == "ClassDef":
self.class_name = function.top_level_parent_name

def collect_instance_variables(self, func_node: ast.FunctionDef) -> None:
"""Collect variable names that are instances of the target class.

This handles the PyTorch nn.Module pattern where:
model = AlexNet(...)
model(input_data) # calls __call__ which invokes forward()

When optimizing ClassName.forward, we need to track variables assigned
from ClassName(...) so we can instrument calls to those variables.
"""
if self.class_name is None or self.only_function_name != "forward":
return

class_name = self.class_name
instance_vars = self.instance_variable_names

# Manually traverse only assignment nodes instead of walking entire tree
nodes_to_check = list(func_node.body)
while nodes_to_check:
node = nodes_to_check.pop()

# Look for assignments like: model = ClassName(...)
if isinstance(node, ast.Assign):
value = node.value
if isinstance(value, ast.Call):
func = value.func
if isinstance(func, ast.Name) and func.id == class_name:
for target in node.targets:
if isinstance(target, ast.Name):
instance_vars.add(target.id)

# Add nested statements to check
if hasattr(node, "body"):
nodes_to_check.extend(node.body)
if hasattr(node, "orelse"):
nodes_to_check.extend(node.orelse)
if hasattr(node, "finalbody"):
nodes_to_check.extend(node.finalbody)
if hasattr(node, "handlers"):
for handler in node.handlers:
nodes_to_check.extend(handler.body)

def find_and_update_line_node(
self, test_node: ast.stmt, node_name: str, index: str, test_class_name: str | None = None
) -> Iterable[ast.stmt] | None:
Expand Down Expand Up @@ -122,7 +166,16 @@ def iter_ast_calls(node):
codeflash_con = ast.Name(id="codeflash_con", ctx=ast.Load())

for node in iter_ast_calls(test_node):
if not node_in_call_position(node, self.call_positions):
# Check if this call is at a known position OR is an instance variable call
# for forward methods (PyTorch nn.Module pattern)
is_at_call_position = node_in_call_position(node, self.call_positions)
is_instance_call = (
isinstance(node.func, ast.Name)
and node.func.id in self.instance_variable_names
and self.only_function_name == "forward"
)

if not is_at_call_position and not is_instance_call:
continue

call_node = node
Expand All @@ -134,7 +187,8 @@ def iter_ast_calls(node):
function_name = node_func.id

# Check if this is the function we want to instrument
if function_name != fn_obj.function_name:
# Also match instance variable calls for forward methods
if function_name != fn_obj.function_name and function_name not in self.instance_variable_names:
continue

if fn_obj.is_async:
Expand Down Expand Up @@ -325,6 +379,9 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:

def visit_FunctionDef(self, node: ast.FunctionDef, test_class_name: str | None = None) -> ast.FunctionDef:
if node.name.startswith("test_"):
# Collect instance variables for forward method instrumentation (PyTorch pattern)
self.collect_instance_variables(node)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit (non-blocking): instance_variable_names is accumulated across all test functions without being cleared. If a file has multiple test functions, variable names collected from test_a will persist when processing test_b. This could cause false-positive instrumentation if a variable name from one test happens to be called in another.

Consider clearing the set at the start of each test function:

Suggested change
self.collect_instance_variables(node)
self.instance_variable_names.clear()
self.collect_instance_variables(node)


did_update = False
i = len(node.body) - 1
while i >= 0:
Expand Down
74 changes: 74 additions & 0 deletions tests/test_instrument_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3306,3 +3306,77 @@ def test_sleepfunc_sequence_short(self, n, expected_total_sleep_time):

finally:
test_path.unlink(missing_ok=True)


def test_pytorch_forward_method_instrumentation() -> None:
    """Verify forward() instrumentation for the instance-call pattern.

    Covers:
        model = MyModule(...)
        model(input_data)  # calls __call__ which invokes forward()

    The instrumentation must wrap the instance call even though the recorded
    position points at the class reference, not at the instance call site.
    """
    code = """
class MockModule:
    def __init__(self, num_classes=10):
        self.num_classes = num_classes

    def forward(self, x):
        return x * 2

def test_module():
    model = MockModule(num_classes=10)
    input_data = 5
    result = model(input_data)
    assert result == 10
"""
    tmp_dir = Path(tempfile.gettempdir())
    helper_path = tmp_dir / "mock_module.py"
    test_file = tmp_dir / "test_mock_module.py"

    try:
        # Same source serves as both the module under optimization and the test.
        for destination in (helper_path, test_file):
            with destination.open("w") as out:
                out.write(code)

        target = FunctionToOptimize(
            function_name="forward",
            parents=[FunctionParent("MockModule", "ClassDef")],
            file_path=helper_path,
            starting_line=6,
            ending_line=7,
            is_async=False,
        )

        # Position where MockModule is called (line 10 in 1-indexed: model = MockModule(...))
        positions = [CodePosition(line_no=10, col_no=12)]

        success, new_test = inject_profiling_into_existing_test(
            test_file,
            positions,
            target,
            test_file.parent,
            mode=TestingMode.PERFORMANCE,
        )

        assert success
        assert new_test is not None

        # The key assertion: model(input_data) should be wrapped with codeflash_wrap,
        # passing the instance itself as the callable.
        assert "codeflash_wrap(model," in new_test, (
            "Expected model(input_data) to be wrapped as codeflash_wrap(model, ..., input_data), "
            f"but got:\n{new_test}"
        )

        # The wrap must record the qualified name (MockModule.forward).
        assert "MockModule.forward" in new_test, (
            f"Expected 'MockModule.forward' to appear in the instrumented code, but got:\n{new_test}"
        )

    finally:
        for leftover in (helper_path, test_file):
            leftover.unlink(missing_ok=True)
Loading