diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py
index 4366468d0..f3e929688 100644
--- a/codeflash/code_utils/instrument_existing_tests.py
+++ b/codeflash/code_utils/instrument_existing_tests.py
@@ -636,6 +636,7 @@ def inject_async_profiling_into_existing_test(
     function_to_optimize: FunctionToOptimize,
     tests_project_root: Path,
     mode: TestingMode = TestingMode.BEHAVIOR,
+    gpu: bool = False,
 ) -> tuple[bool, str | None]:
     """Inject profiling for async function calls by setting environment variables before each call."""
     with test_path.open(encoding="utf8") as f:
@@ -708,6 +709,7 @@ def inject_profiling_into_existing_test(
     function_to_optimize: FunctionToOptimize,
     tests_project_root: Path,
     mode: TestingMode = TestingMode.BEHAVIOR,
+    gpu: bool = False,
 ) -> tuple[bool, str | None]:
     if function_to_optimize.is_async:
         return inject_async_profiling_into_existing_test(
@@ -752,7 +754,7 @@ def inject_profiling_into_existing_test(
         else:
             # If there's an alias, use it (e.g., "import torch as th")
             new_imports.append(ast.Import(names=[ast.alias(name=framework_name, asname=framework_alias)]))
-    additional_functions = [create_wrapper_function(mode, used_frameworks)]
+    additional_functions = [create_wrapper_function(mode, used_frameworks, gpu)]
     tree.body = [*new_imports, *additional_functions, *tree.body]
     return True, sort_imports(ast.unparse(tree), float_to_top=True)
@@ -908,6 +910,60 @@ def _create_device_sync_precompute_statements(used_frameworks: dict[str, str] |
     return precompute_statements
 
 
+def _create_gpu_event_timing_precompute_statements(used_frameworks: dict[str, str] | None) -> list[ast.stmt]:
+    """Create AST statements to pre-compute GPU event timing conditions.
+
+    This generates:
+        _codeflash_use_gpu_timer = torch.cuda.is_available() and torch.cuda.is_initialized()
+
+    Args:
+        used_frameworks: Dict mapping framework names to their import aliases
+
+    Returns:
+        List of AST statements that pre-compute GPU timer availability
+
+    """
+    if not used_frameworks or "torch" not in used_frameworks:
+        return []
+
+    torch_alias = used_frameworks["torch"]
+
+    # _codeflash_use_gpu_timer = torch.cuda.is_available() and torch.cuda.is_initialized()
+    return [
+        ast.Assign(
+            targets=[ast.Name(id="_codeflash_use_gpu_timer", ctx=ast.Store())],
+            value=ast.BoolOp(
+                op=ast.And(),
+                values=[
+                    ast.Call(
+                        func=ast.Attribute(
+                            value=ast.Attribute(
+                                value=ast.Name(id=torch_alias, ctx=ast.Load()), attr="cuda", ctx=ast.Load()
+                            ),
+                            attr="is_available",
+                            ctx=ast.Load(),
+                        ),
+                        args=[],
+                        keywords=[],
+                    ),
+                    ast.Call(
+                        func=ast.Attribute(
+                            value=ast.Attribute(
+                                value=ast.Name(id=torch_alias, ctx=ast.Load()), attr="cuda", ctx=ast.Load()
+                            ),
+                            attr="is_initialized",
+                            ctx=ast.Load(),
+                        ),
+                        args=[],
+                        keywords=[],
+                    ),
+                ],
+            ),
+            lineno=1,
+        )
+    ]
+
+
 def _create_device_sync_statements(
     used_frameworks: dict[str, str] | None, for_return_value: bool = False
 ) -> list[ast.stmt]:
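For reference, the precompute helper added above can be sanity-checked by unparsing its output. A minimal sketch (illustrative only: it imports a private helper from the patched module and assumes the patch is applied):

```python
import ast

# Private helper defined in the file patched above; importing it directly is
# only for illustration.
from codeflash.code_utils.instrument_existing_tests import (
    _create_gpu_event_timing_precompute_statements,
)

# With the alias mapping {"torch": "th"} the generated guard mirrors the docstring.
stmts = _create_gpu_event_timing_precompute_statements({"torch": "th"})
print(ast.unparse(stmts[0]))
# _codeflash_use_gpu_timer = th.cuda.is_available() and th.cuda.is_initialized()
```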
@@ -1030,8 +1086,338 @@ def _create_device_sync_statements(
     return sync_statements
 
 
+def _create_gpu_timing_try_body(torch_alias: str) -> list[ast.stmt]:
+    """Create AST statements for the GPU event timing try body.
+
+    Generates:
+        _codeflash_start_event = torch.cuda.Event(enable_timing=True)
+        _codeflash_end_event = torch.cuda.Event(enable_timing=True)
+        _codeflash_start_event.record()
+        return_value = codeflash_wrapped(*args, **kwargs)
+        _codeflash_end_event.record()
+        torch.cuda.synchronize()
+        codeflash_duration = int(_codeflash_start_event.elapsed_time(_codeflash_end_event) * 1_000_000)
+
+    Args:
+        torch_alias: The import alias for torch (e.g., "torch" or "th")
+
+    Returns:
+        List of AST statements for GPU event timing
+
+    """
+    return [
+        # _codeflash_start_event = torch.cuda.Event(enable_timing=True)
+        ast.Assign(
+            targets=[ast.Name(id="_codeflash_start_event", ctx=ast.Store())],
+            value=ast.Call(
+                func=ast.Attribute(
+                    value=ast.Attribute(value=ast.Name(id=torch_alias, ctx=ast.Load()), attr="cuda", ctx=ast.Load()),
+                    attr="Event",
+                    ctx=ast.Load(),
+                ),
+                args=[],
+                keywords=[ast.keyword(arg="enable_timing", value=ast.Constant(value=True))],
+            ),
+            lineno=1,
+        ),
+        # _codeflash_end_event = torch.cuda.Event(enable_timing=True)
+        ast.Assign(
+            targets=[ast.Name(id="_codeflash_end_event", ctx=ast.Store())],
+            value=ast.Call(
+                func=ast.Attribute(
+                    value=ast.Attribute(value=ast.Name(id=torch_alias, ctx=ast.Load()), attr="cuda", ctx=ast.Load()),
+                    attr="Event",
+                    ctx=ast.Load(),
+                ),
+                args=[],
+                keywords=[ast.keyword(arg="enable_timing", value=ast.Constant(value=True))],
+            ),
+            lineno=1,
+        ),
+        # _codeflash_start_event.record()
+        ast.Expr(
+            value=ast.Call(
+                func=ast.Attribute(
+                    value=ast.Name(id="_codeflash_start_event", ctx=ast.Load()), attr="record", ctx=ast.Load()
+                ),
+                args=[],
+                keywords=[],
+            )
+        ),
+        # return_value = codeflash_wrapped(*args, **kwargs)
+        ast.Assign(
+            targets=[ast.Name(id="return_value", ctx=ast.Store())],
+            value=ast.Call(
+                func=ast.Name(id="codeflash_wrapped", ctx=ast.Load()),
+                args=[ast.Starred(value=ast.Name(id="args", ctx=ast.Load()), ctx=ast.Load())],
+                keywords=[ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()))],
+            ),
+            lineno=1,
+        ),
+        # _codeflash_end_event.record()
+        ast.Expr(
+            value=ast.Call(
+                func=ast.Attribute(
+                    value=ast.Name(id="_codeflash_end_event", ctx=ast.Load()), attr="record", ctx=ast.Load()
+                ),
+                args=[],
+                keywords=[],
+            )
+        ),
+        # torch.cuda.synchronize()
+        ast.Expr(
+            value=ast.Call(
+                func=ast.Attribute(
+                    value=ast.Attribute(value=ast.Name(id=torch_alias, ctx=ast.Load()), attr="cuda", ctx=ast.Load()),
+                    attr="synchronize",
+                    ctx=ast.Load(),
+                ),
+                args=[],
+                keywords=[],
+            )
+        ),
+        # codeflash_duration = int(_codeflash_start_event.elapsed_time(_codeflash_end_event) * 1_000_000)
+        ast.Assign(
+            targets=[ast.Name(id="codeflash_duration", ctx=ast.Store())],
+            value=ast.Call(
+                func=ast.Name(id="int", ctx=ast.Load()),
+                args=[
+                    ast.BinOp(
+                        left=ast.Call(
+                            func=ast.Attribute(
+                                value=ast.Name(id="_codeflash_start_event", ctx=ast.Load()),
+                                attr="elapsed_time",
+                                ctx=ast.Load(),
+                            ),
+                            args=[ast.Name(id="_codeflash_end_event", ctx=ast.Load())],
+                            keywords=[],
+                        ),
+                        op=ast.Mult(),
+                        right=ast.Constant(value=1_000_000),
+                    )
+                ],
+                keywords=[],
+            ),
+            lineno=1,
+        ),
+    ]
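The try body above encodes the standard CUDA event timing pattern. For readers unfamiliar with it, a standalone sketch of the same pattern in plain PyTorch (independent of the AST builders; `run_once` is a hypothetical zero-argument callable):

```python
import torch

def _time_once_on_gpu_ns(run_once):
    # CUDA events measure elapsed GPU time in milliseconds; the generated
    # wrapper converts that to nanoseconds via the * 1_000_000 factor.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = run_once()
    end.record()
    torch.cuda.synchronize()  # ensure both events have completed before reading the timer
    return result, int(start.elapsed_time(end) * 1_000_000)
```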
+
+
+def _create_gpu_timing_except_body(torch_alias: str) -> list[ast.stmt]:
+    """Create AST statements for the GPU event timing exception handler.
+
+    Generates:
+        torch.cuda.synchronize()
+        codeflash_duration = 0
+        exception = e
+
+    Args:
+        torch_alias: The import alias for torch (e.g., "torch" or "th")
+
+    Returns:
+        List of AST statements for GPU timing exception handling
+
+    """
+    return [
+        # torch.cuda.synchronize()
+        ast.Expr(
+            value=ast.Call(
+                func=ast.Attribute(
+                    value=ast.Attribute(value=ast.Name(id=torch_alias, ctx=ast.Load()), attr="cuda", ctx=ast.Load()),
+                    attr="synchronize",
+                    ctx=ast.Load(),
+                ),
+                args=[],
+                keywords=[],
+            )
+        ),
+        # codeflash_duration = 0
+        ast.Assign(targets=[ast.Name(id="codeflash_duration", ctx=ast.Store())], value=ast.Constant(value=0), lineno=1),
+        # exception = e
+        ast.Assign(
+            targets=[ast.Name(id="exception", ctx=ast.Store())], value=ast.Name(id="e", ctx=ast.Load()), lineno=1
+        ),
+    ]
+
+
+def _create_cpu_timing_try_body(used_frameworks: dict[str, str] | None) -> list[ast.stmt]:
+    """Create AST statements for the CPU timing try body.
+
+    Generates standard time.perf_counter_ns() timing with device sync.
+
+    Args:
+        used_frameworks: Dict mapping framework names to their import aliases
+
+    Returns:
+        List of AST statements for CPU timing
+
+    """
+    return [
+        # Pre-sync: synchronize device before starting timer
+        *_create_device_sync_statements(used_frameworks, for_return_value=False),
+        # counter = time.perf_counter_ns()
+        ast.Assign(
+            targets=[ast.Name(id="counter", ctx=ast.Store())],
+            value=ast.Call(
+                func=ast.Attribute(value=ast.Name(id="time", ctx=ast.Load()), attr="perf_counter_ns", ctx=ast.Load()),
+                args=[],
+                keywords=[],
+            ),
+            lineno=1,
+        ),
+        # return_value = codeflash_wrapped(*args, **kwargs)
+        ast.Assign(
+            targets=[ast.Name(id="return_value", ctx=ast.Store())],
+            value=ast.Call(
+                func=ast.Name(id="codeflash_wrapped", ctx=ast.Load()),
+                args=[ast.Starred(value=ast.Name(id="args", ctx=ast.Load()), ctx=ast.Load())],
+                keywords=[ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()))],
+            ),
+            lineno=1,
+        ),
+        # Post-sync: synchronize device after function call
+        *_create_device_sync_statements(used_frameworks, for_return_value=True),
+        # codeflash_duration = time.perf_counter_ns() - counter
+        ast.Assign(
+            targets=[ast.Name(id="codeflash_duration", ctx=ast.Store())],
+            value=ast.BinOp(
+                left=ast.Call(
+                    func=ast.Attribute(
+                        value=ast.Name(id="time", ctx=ast.Load()), attr="perf_counter_ns", ctx=ast.Load()
+                    ),
+                    args=[],
+                    keywords=[],
+                ),
+                op=ast.Sub(),
+                right=ast.Name(id="counter", ctx=ast.Load()),
+            ),
+            lineno=1,
+        ),
+    ]
+
+
+def _create_cpu_timing_except_body() -> list[ast.stmt]:
+    """Create AST statements for the CPU timing exception handler.
+
+    Generates:
+        codeflash_duration = time.perf_counter_ns() - counter
+        exception = e
+
+    Returns:
+        List of AST statements for CPU timing exception handling
+
+    """
+    return [
+        # codeflash_duration = time.perf_counter_ns() - counter
+        ast.Assign(
+            targets=[ast.Name(id="codeflash_duration", ctx=ast.Store())],
+            value=ast.BinOp(
+                left=ast.Call(
+                    func=ast.Attribute(
+                        value=ast.Name(id="time", ctx=ast.Load()), attr="perf_counter_ns", ctx=ast.Load()
+                    ),
+                    args=[],
+                    keywords=[],
+                ),
+                op=ast.Sub(),
+                right=ast.Name(id="counter", ctx=ast.Load()),
+            ),
+            lineno=1,
+        ),
+        # exception = e
+        ast.Assign(
+            targets=[ast.Name(id="exception", ctx=ast.Store())], value=ast.Name(id="e", ctx=ast.Load()), lineno=1
+        ),
+    ]
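The CPU fallback emitted by the two helpers above is equivalent in spirit to the following sketch (plain Python/PyTorch, not part of the patch; `run_once` is again a hypothetical callable):

```python
import time

import torch

def _time_once_on_cpu_ns(run_once):
    # Wall-clock timing with a device barrier on each side when CUDA is live,
    # mirroring the pre-/post-sync statements generated above.
    should_sync = torch.cuda.is_available() and torch.cuda.is_initialized()
    if should_sync:
        torch.cuda.synchronize()
    counter = time.perf_counter_ns()
    result = run_once()
    if should_sync:
        torch.cuda.synchronize()
    return result, time.perf_counter_ns() - counter
```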
+
+
+def _create_timing_try_block(used_frameworks: dict[str, str] | None, gpu: bool, lineno: int) -> list[ast.stmt]:
+    """Create the timing try block, handling both GPU and CPU timing modes.
+
+    When gpu=True and torch is available, generates an if/else structure:
+        if _codeflash_use_gpu_timer:
+            # GPU event timing path
+        else:
+            # CPU timing fallback path
+
+    Otherwise, generates standard CPU timing.
+
+    Args:
+        used_frameworks: Dict mapping framework names to their import aliases
+        gpu: Whether to use GPU event timing when possible
+        lineno: Current line number for AST nodes
+
+    Returns:
+        List containing the try statement(s) for timing
+
+    """
+    use_gpu_timing = gpu and used_frameworks and "torch" in used_frameworks
+
+    if use_gpu_timing:
+        torch_alias = used_frameworks["torch"]
+
+        # Create GPU timing try block
+        gpu_try = ast.Try(
+            body=_create_gpu_timing_try_body(torch_alias),
+            handlers=[
+                ast.ExceptHandler(
+                    type=ast.Name(id="Exception", ctx=ast.Load()),
+                    name="e",
+                    body=_create_gpu_timing_except_body(torch_alias),
+                    lineno=lineno + 14,
+                )
+            ],
+            orelse=[],
+            finalbody=[],
+            lineno=lineno + 11,
+        )
+
+        # Create CPU timing try block (fallback)
+        cpu_try = ast.Try(
+            body=_create_cpu_timing_try_body(used_frameworks),
+            handlers=[
+                ast.ExceptHandler(
+                    type=ast.Name(id="Exception", ctx=ast.Load()),
+                    name="e",
+                    body=_create_cpu_timing_except_body(),
+                    lineno=lineno + 14,
+                )
+            ],
+            orelse=[],
+            finalbody=[],
+            lineno=lineno + 11,
+        )
+
+        # Wrap in if/else based on _codeflash_use_gpu_timer
+        return [
+            ast.If(
+                test=ast.Name(id="_codeflash_use_gpu_timer", ctx=ast.Load()),
+                body=[gpu_try],
+                orelse=[cpu_try],
+                lineno=lineno + 11,
+            )
+        ]
+    # Standard CPU timing
+    return [
+        ast.Try(
+            body=_create_cpu_timing_try_body(used_frameworks),
+            handlers=[
+                ast.ExceptHandler(
+                    type=ast.Name(id="Exception", ctx=ast.Load()),
+                    name="e",
+                    body=_create_cpu_timing_except_body(),
+                    lineno=lineno + 14,
+                )
+            ],
+            orelse=[],
+            finalbody=[],
+            lineno=lineno + 11,
+        )
+    ]
+
+
 def create_wrapper_function(
-    mode: TestingMode = TestingMode.BEHAVIOR, used_frameworks: dict[str, str] | None = None
+    mode: TestingMode = TestingMode.BEHAVIOR, used_frameworks: dict[str, str] | None = None, gpu: bool = False
 ) -> ast.FunctionDef:
     lineno = 1
     wrapper_body: list[ast.stmt] = [
@@ -1193,8 +1579,14 @@ def create_wrapper_function(
         ast.Assign(
             targets=[ast.Name(id="exception", ctx=ast.Store())], value=ast.Constant(value=None), lineno=lineno + 10
         ),
-        # Pre-compute device sync conditions before profiling to avoid overhead during timing
-        *_create_device_sync_precompute_statements(used_frameworks),
+        # Pre-compute conditions before profiling to avoid overhead during timing
+        *(
+            # When gpu=True with torch, we need both the GPU timer check AND device sync conditions for the fallback
+            _create_gpu_event_timing_precompute_statements(used_frameworks)
+            + _create_device_sync_precompute_statements(used_frameworks)
+            if gpu and used_frameworks and "torch" in used_frameworks
+            else _create_device_sync_precompute_statements(used_frameworks)
+        ),
         ast.Expr(
             value=ast.Call(
                 func=ast.Attribute(value=ast.Name(id="gc", ctx=ast.Load()), attr="disable", ctx=ast.Load()),
                 args=[],
                 keywords=[],
             ),
             lineno=lineno + 9,
         ),
@@ -1203,83 +1595,7 @@ def create_wrapper_function(
-        ast.Try(
-            body=[
-                # Pre-sync: synchronize device before starting timer
-                *_create_device_sync_statements(used_frameworks, for_return_value=False),
-                ast.Assign(
-                    targets=[ast.Name(id="counter", ctx=ast.Store())],
-                    value=ast.Call(
-                        func=ast.Attribute(
-                            value=ast.Name(id="time", ctx=ast.Load()), attr="perf_counter_ns", ctx=ast.Load()
-                        ),
-                        args=[],
-                        keywords=[],
-                    ),
-                    lineno=lineno + 11,
-                ),
-                ast.Assign(
-                    targets=[ast.Name(id="return_value", ctx=ast.Store())],
- value=ast.Call( - func=ast.Name(id="codeflash_wrapped", ctx=ast.Load()), - args=[ast.Starred(value=ast.Name(id="args", ctx=ast.Load()), ctx=ast.Load())], - keywords=[ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()))], - ), - lineno=lineno + 12, - ), - # Post-sync: synchronize device after function call to ensure all device work is complete - *_create_device_sync_statements(used_frameworks, for_return_value=True), - ast.Assign( - targets=[ast.Name(id="codeflash_duration", ctx=ast.Store())], - value=ast.BinOp( - left=ast.Call( - func=ast.Attribute( - value=ast.Name(id="time", ctx=ast.Load()), attr="perf_counter_ns", ctx=ast.Load() - ), - args=[], - keywords=[], - ), - op=ast.Sub(), - right=ast.Name(id="counter", ctx=ast.Load()), - ), - lineno=lineno + 13, - ), - ], - handlers=[ - ast.ExceptHandler( - type=ast.Name(id="Exception", ctx=ast.Load()), - name="e", - body=[ - ast.Assign( - targets=[ast.Name(id="codeflash_duration", ctx=ast.Store())], - value=ast.BinOp( - left=ast.Call( - func=ast.Attribute( - value=ast.Name(id="time", ctx=ast.Load()), - attr="perf_counter_ns", - ctx=ast.Load(), - ), - args=[], - keywords=[], - ), - op=ast.Sub(), - right=ast.Name(id="counter", ctx=ast.Load()), - ), - lineno=lineno + 15, - ), - ast.Assign( - targets=[ast.Name(id="exception", ctx=ast.Store())], - value=ast.Name(id="e", ctx=ast.Load()), - lineno=lineno + 13, - ), - ], - lineno=lineno + 14, - ) - ], - orelse=[], - finalbody=[], - lineno=lineno + 11, - ), + *_create_timing_try_block(used_frameworks, gpu, lineno), ast.Expr( value=ast.Call( func=ast.Attribute(value=ast.Name(id="gc", ctx=ast.Load()), attr="enable", ctx=ast.Load()), diff --git a/tests/test_inject_profiling_used_frameworks.py b/tests/test_inject_profiling_used_frameworks.py index 826be09c8..ede5559df 100644 --- a/tests/test_inject_profiling_used_frameworks.py +++ b/tests/test_inject_profiling_used_frameworks.py @@ -1492,3 +1492,435 @@ def test_my_function(): result = normalize_instrumented_code(instrumented_code) expected = EXPECTED_ALL_FRAMEWORKS_PERFORMANCE assert result == expected + + +# ============================================================================ +# Expected instrumented code for GPU timing mode +# ============================================================================ + +EXPECTED_TORCH_GPU_BEHAVIOR = """import gc +import inspect +import os +import sqlite3 +import time + +import dill as pickle +import torch +from mymodule import my_function + + +def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, codeflash_cur, codeflash_con, *args, **kwargs): + test_id = f'{codeflash_test_module_name}:{codeflash_test_class_name}:{codeflash_test_name}:{codeflash_line_id}:{codeflash_loop_index}' + if not hasattr(codeflash_wrap, 'index'): + codeflash_wrap.index = {} + if test_id in codeflash_wrap.index: + codeflash_wrap.index[test_id] += 1 + else: + codeflash_wrap.index[test_id] = 0 + codeflash_test_index = codeflash_wrap.index[test_id] + invocation_id = f'{codeflash_line_id}_{codeflash_test_index}' + test_stdout_tag = f'{codeflash_test_module_name}:{(codeflash_test_class_name + '.' 
if codeflash_test_class_name else '')}{codeflash_test_name}:{codeflash_function_name}:{codeflash_loop_index}:{invocation_id}' + print(f'!$######{test_stdout_tag}######$!') + exception = None + _codeflash_use_gpu_timer = torch.cuda.is_available() and torch.cuda.is_initialized() + _codeflash_should_sync_cuda = torch.cuda.is_available() and torch.cuda.is_initialized() + _codeflash_should_sync_mps = not _codeflash_should_sync_cuda and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() and hasattr(torch.mps, 'synchronize') + gc.disable() + if _codeflash_use_gpu_timer: + try: + _codeflash_start_event = torch.cuda.Event(enable_timing=True) + _codeflash_end_event = torch.cuda.Event(enable_timing=True) + _codeflash_start_event.record() + return_value = codeflash_wrapped(*args, **kwargs) + _codeflash_end_event.record() + torch.cuda.synchronize() + codeflash_duration = int(_codeflash_start_event.elapsed_time(_codeflash_end_event) * 1000000) + except Exception as e: + torch.cuda.synchronize() + codeflash_duration = 0 + exception = e + else: + try: + if _codeflash_should_sync_cuda: + torch.cuda.synchronize() + elif _codeflash_should_sync_mps: + torch.mps.synchronize() + counter = time.perf_counter_ns() + return_value = codeflash_wrapped(*args, **kwargs) + if _codeflash_should_sync_cuda: + torch.cuda.synchronize() + elif _codeflash_should_sync_mps: + torch.mps.synchronize() + codeflash_duration = time.perf_counter_ns() - counter + except Exception as e: + codeflash_duration = time.perf_counter_ns() - counter + exception = e + gc.enable() + print(f'!######{test_stdout_tag}######!') + pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(return_value) + codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call')) + codeflash_con.commit() + if exception: + raise exception + return return_value + +def test_my_function(): + codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX']) + codeflash_iteration = os.environ['CODEFLASH_TEST_ITERATION'] + codeflash_con = sqlite3.connect(f'{CODEFLASH_DB_PATH}') + codeflash_cur = codeflash_con.cursor() + codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)') + _call__bound__arguments = inspect.signature(my_function).bind(1, 2) + _call__bound__arguments.apply_defaults() + result = codeflash_wrap(my_function, 'test_example', None, 'test_my_function', 'my_function', '0', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs) + assert result == 3 + codeflash_con.close() +""" + +EXPECTED_TORCH_GPU_PERFORMANCE = """import gc +import os +import time + +import torch +from mymodule import my_function + + +def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, *args, **kwargs): + test_id = f'{codeflash_test_module_name}:{codeflash_test_class_name}:{codeflash_test_name}:{codeflash_line_id}:{codeflash_loop_index}' + if not hasattr(codeflash_wrap, 'index'): + codeflash_wrap.index = {} + if test_id in codeflash_wrap.index: + 
codeflash_wrap.index[test_id] += 1 + else: + codeflash_wrap.index[test_id] = 0 + codeflash_test_index = codeflash_wrap.index[test_id] + invocation_id = f'{codeflash_line_id}_{codeflash_test_index}' + test_stdout_tag = f'{codeflash_test_module_name}:{(codeflash_test_class_name + '.' if codeflash_test_class_name else '')}{codeflash_test_name}:{codeflash_function_name}:{codeflash_loop_index}:{invocation_id}' + print(f'!$######{test_stdout_tag}######$!') + exception = None + _codeflash_use_gpu_timer = torch.cuda.is_available() and torch.cuda.is_initialized() + _codeflash_should_sync_cuda = torch.cuda.is_available() and torch.cuda.is_initialized() + _codeflash_should_sync_mps = not _codeflash_should_sync_cuda and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() and hasattr(torch.mps, 'synchronize') + gc.disable() + if _codeflash_use_gpu_timer: + try: + _codeflash_start_event = torch.cuda.Event(enable_timing=True) + _codeflash_end_event = torch.cuda.Event(enable_timing=True) + _codeflash_start_event.record() + return_value = codeflash_wrapped(*args, **kwargs) + _codeflash_end_event.record() + torch.cuda.synchronize() + codeflash_duration = int(_codeflash_start_event.elapsed_time(_codeflash_end_event) * 1000000) + except Exception as e: + torch.cuda.synchronize() + codeflash_duration = 0 + exception = e + else: + try: + if _codeflash_should_sync_cuda: + torch.cuda.synchronize() + elif _codeflash_should_sync_mps: + torch.mps.synchronize() + counter = time.perf_counter_ns() + return_value = codeflash_wrapped(*args, **kwargs) + if _codeflash_should_sync_cuda: + torch.cuda.synchronize() + elif _codeflash_should_sync_mps: + torch.mps.synchronize() + codeflash_duration = time.perf_counter_ns() - counter + except Exception as e: + codeflash_duration = time.perf_counter_ns() - counter + exception = e + gc.enable() + print(f'!######{test_stdout_tag}:{codeflash_duration}######!') + if exception: + raise exception + return return_value + +def test_my_function(): + codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX']) + result = codeflash_wrap(my_function, 'test_example', None, 'test_my_function', 'my_function', '0', codeflash_loop_index, 1, 2) + assert result == 3 +""" + +EXPECTED_TORCH_ALIASED_GPU_BEHAVIOR = """import gc +import inspect +import os +import sqlite3 +import time + +import dill as pickle +import torch as th +from mymodule import my_function + + +def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, codeflash_cur, codeflash_con, *args, **kwargs): + test_id = f'{codeflash_test_module_name}:{codeflash_test_class_name}:{codeflash_test_name}:{codeflash_line_id}:{codeflash_loop_index}' + if not hasattr(codeflash_wrap, 'index'): + codeflash_wrap.index = {} + if test_id in codeflash_wrap.index: + codeflash_wrap.index[test_id] += 1 + else: + codeflash_wrap.index[test_id] = 0 + codeflash_test_index = codeflash_wrap.index[test_id] + invocation_id = f'{codeflash_line_id}_{codeflash_test_index}' + test_stdout_tag = f'{codeflash_test_module_name}:{(codeflash_test_class_name + '.' 
if codeflash_test_class_name else '')}{codeflash_test_name}:{codeflash_function_name}:{codeflash_loop_index}:{invocation_id}' + print(f'!$######{test_stdout_tag}######$!') + exception = None + _codeflash_use_gpu_timer = th.cuda.is_available() and th.cuda.is_initialized() + _codeflash_should_sync_cuda = th.cuda.is_available() and th.cuda.is_initialized() + _codeflash_should_sync_mps = not _codeflash_should_sync_cuda and hasattr(th.backends, 'mps') and th.backends.mps.is_available() and hasattr(th.mps, 'synchronize') + gc.disable() + if _codeflash_use_gpu_timer: + try: + _codeflash_start_event = th.cuda.Event(enable_timing=True) + _codeflash_end_event = th.cuda.Event(enable_timing=True) + _codeflash_start_event.record() + return_value = codeflash_wrapped(*args, **kwargs) + _codeflash_end_event.record() + th.cuda.synchronize() + codeflash_duration = int(_codeflash_start_event.elapsed_time(_codeflash_end_event) * 1000000) + except Exception as e: + th.cuda.synchronize() + codeflash_duration = 0 + exception = e + else: + try: + if _codeflash_should_sync_cuda: + th.cuda.synchronize() + elif _codeflash_should_sync_mps: + th.mps.synchronize() + counter = time.perf_counter_ns() + return_value = codeflash_wrapped(*args, **kwargs) + if _codeflash_should_sync_cuda: + th.cuda.synchronize() + elif _codeflash_should_sync_mps: + th.mps.synchronize() + codeflash_duration = time.perf_counter_ns() - counter + except Exception as e: + codeflash_duration = time.perf_counter_ns() - counter + exception = e + gc.enable() + print(f'!######{test_stdout_tag}######!') + pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(return_value) + codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call')) + codeflash_con.commit() + if exception: + raise exception + return return_value + +def test_my_function(): + codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX']) + codeflash_iteration = os.environ['CODEFLASH_TEST_ITERATION'] + codeflash_con = sqlite3.connect(f'{CODEFLASH_DB_PATH}') + codeflash_cur = codeflash_con.cursor() + codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)') + _call__bound__arguments = inspect.signature(my_function).bind(1, 2) + _call__bound__arguments.apply_defaults() + result = codeflash_wrap(my_function, 'test_example', None, 'test_my_function', 'my_function', '0', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs) + assert result == 3 + codeflash_con.close() +""" + + +# ============================================================================ +# Tests for GPU timing mode +# ============================================================================ + + +class TestInjectProfilingGpuTimingMode: + """Tests for inject_profiling_into_existing_test with gpu=True.""" + + def test_torch_gpu_behavior_mode(self, tmp_path: Path) -> None: + """Test instrumentation with PyTorch and gpu=True in BEHAVIOR mode.""" + code = """import torch +from mymodule import my_function + +def test_my_function(): + result = my_function(1, 2) + assert result == 3 +""" + test_file = tmp_path / 
"test_example.py" + test_file.write_text(code) + + func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py")) + + success, instrumented_code = inject_profiling_into_existing_test( + test_path=test_file, + call_positions=[CodePosition(5, 13)], + function_to_optimize=func, + tests_project_root=tmp_path, + mode=TestingMode.BEHAVIOR, + gpu=True, + ) + + result = normalize_instrumented_code(instrumented_code) + expected = EXPECTED_TORCH_GPU_BEHAVIOR + assert result == expected + + def test_torch_gpu_performance_mode(self, tmp_path: Path) -> None: + """Test instrumentation with PyTorch and gpu=True in PERFORMANCE mode.""" + code = """import torch +from mymodule import my_function + +def test_my_function(): + result = my_function(1, 2) + assert result == 3 +""" + test_file = tmp_path / "test_example.py" + test_file.write_text(code) + + func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py")) + + success, instrumented_code = inject_profiling_into_existing_test( + test_path=test_file, + call_positions=[CodePosition(5, 13)], + function_to_optimize=func, + tests_project_root=tmp_path, + mode=TestingMode.PERFORMANCE, + gpu=True, + ) + + result = normalize_instrumented_code(instrumented_code) + expected = EXPECTED_TORCH_GPU_PERFORMANCE + assert result == expected + + def test_torch_aliased_gpu_behavior_mode(self, tmp_path: Path) -> None: + """Test instrumentation with PyTorch alias and gpu=True in BEHAVIOR mode.""" + code = """import torch as th +from mymodule import my_function + +def test_my_function(): + result = my_function(1, 2) + assert result == 3 +""" + test_file = tmp_path / "test_example.py" + test_file.write_text(code) + + func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py")) + + success, instrumented_code = inject_profiling_into_existing_test( + test_path=test_file, + call_positions=[CodePosition(5, 13)], + function_to_optimize=func, + tests_project_root=tmp_path, + mode=TestingMode.BEHAVIOR, + gpu=True, + ) + + result = normalize_instrumented_code(instrumented_code) + expected = EXPECTED_TORCH_ALIASED_GPU_BEHAVIOR + assert result == expected + + def test_no_torch_gpu_flag_uses_cpu_timing(self, tmp_path: Path) -> None: + """Test that gpu=True without torch uses standard CPU timing.""" + code = """from mymodule import my_function + +def test_my_function(): + result = my_function(1, 2) + assert result == 3 +""" + test_file = tmp_path / "test_example.py" + test_file.write_text(code) + + func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py")) + + success, instrumented_code = inject_profiling_into_existing_test( + test_path=test_file, + call_positions=[CodePosition(4, 13)], + function_to_optimize=func, + tests_project_root=tmp_path, + mode=TestingMode.PERFORMANCE, + gpu=True, + ) + + result = normalize_instrumented_code(instrumented_code) + # gpu=True without torch should produce the same result as gpu=False + expected = EXPECTED_NO_FRAMEWORKS_PERFORMANCE + assert result == expected + + def test_gpu_false_with_torch_uses_device_sync(self, tmp_path: Path) -> None: + """Test that gpu=False with torch uses device sync (existing behavior).""" + code = """import torch +from mymodule import my_function + +def test_my_function(): + result = my_function(1, 2) + assert result == 3 +""" + test_file = tmp_path / "test_example.py" + test_file.write_text(code) + + func = FunctionToOptimize(function_name="my_function", parents=[], 
file_path=Path("mymodule.py")) + + success, instrumented_code = inject_profiling_into_existing_test( + test_path=test_file, + call_positions=[CodePosition(5, 13)], + function_to_optimize=func, + tests_project_root=tmp_path, + mode=TestingMode.PERFORMANCE, + gpu=False, + ) + + result = normalize_instrumented_code(instrumented_code) + # gpu=False with torch should produce device sync code + expected = EXPECTED_TORCH_PERFORMANCE + assert result == expected + + def test_torch_submodule_import_gpu_mode(self, tmp_path: Path) -> None: + """Test that gpu=True works with torch submodule imports like 'from torch import nn'.""" + code = """from torch import nn +from mymodule import my_function + +def test_my_function(): + result = my_function(1, 2) + assert result == 3 +""" + test_file = tmp_path / "test_example.py" + test_file.write_text(code) + + func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py")) + + success, instrumented_code = inject_profiling_into_existing_test( + test_path=test_file, + call_positions=[CodePosition(5, 13)], + function_to_optimize=func, + tests_project_root=tmp_path, + mode=TestingMode.PERFORMANCE, + gpu=True, + ) + + assert success + # Verify GPU timing code is present (torch detected from submodule import) + assert "_codeflash_use_gpu_timer = torch.cuda.is_available()" in instrumented_code + assert "torch.cuda.Event(enable_timing=True)" in instrumented_code + assert "elapsed_time" in instrumented_code + + def test_torch_dotted_import_gpu_mode(self, tmp_path: Path) -> None: + """Test that gpu=True works with torch dotted imports like 'import torch.nn'.""" + code = """import torch.nn +from mymodule import my_function + +def test_my_function(): + result = my_function(1, 2) + assert result == 3 +""" + test_file = tmp_path / "test_example.py" + test_file.write_text(code) + + func = FunctionToOptimize(function_name="my_function", parents=[], file_path=Path("mymodule.py")) + + success, instrumented_code = inject_profiling_into_existing_test( + test_path=test_file, + call_positions=[CodePosition(5, 13)], + function_to_optimize=func, + tests_project_root=tmp_path, + mode=TestingMode.PERFORMANCE, + gpu=True, + ) + + assert success + # Verify GPU timing code is present (torch detected from dotted import) + assert "_codeflash_use_gpu_timer = torch.cuda.is_available()" in instrumented_code + assert "torch.cuda.Event(enable_timing=True)" in instrumented_code + assert "elapsed_time" in instrumented_code
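One unit detail worth keeping in mind when reading the expected outputs above: `torch.cuda.Event.elapsed_time()` reports milliseconds, while the instrumented wrapper stores nanoseconds, hence the `* 1_000_000` factor (a trivial worked example):

```python
# A 2.5 ms GPU measurement is recorded as 2_500_000 ns in the results table.
elapsed_ms = 2.5
duration_ns = int(elapsed_ms * 1_000_000)
assert duration_ns == 2_500_000
```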