diff --git a/mypyc/codegen/emitclass.py b/mypyc/codegen/emitclass.py index 8f8d74255a87..8c3fa5de98f8 100644 --- a/mypyc/codegen/emitclass.py +++ b/mypyc/codegen/emitclass.py @@ -412,9 +412,7 @@ def emit_line() -> None: emitter.emit_line() if generate_full: - generate_setup_for_class( - cl, defaults_fn, vtable_name, shadow_vtable_name, coroutine_setup_name, emitter - ) + generate_setup_for_class(cl, defaults_fn, vtable_name, shadow_vtable_name, emitter) emitter.emit_line() generate_constructor_for_class(cl, cl.ctor, init_fn, setup_name, vtable_name, emitter) emitter.emit_line() @@ -606,7 +604,6 @@ def generate_setup_for_class( defaults_fn: FuncIR | None, vtable_name: str, shadow_vtable_name: str | None, - coroutine_setup_name: str, emitter: Emitter, ) -> None: """Generate a native function that allocates an instance of a class.""" @@ -662,13 +659,6 @@ def generate_setup_for_class( if defaults_fn is not None: emit_attr_defaults_func_call(defaults_fn, "self", emitter) - # Initialize function wrapper for callable classes. As opposed to regular functions, - # each instance of a callable class needs its own wrapper because they might be instantiated - # inside other functions. - if cl.coroutine_name: - emitter.emit_line(f"if ({NATIVE_PREFIX}{coroutine_setup_name}((PyObject *)self) != 1)") - emitter.emit_line(" return NULL;") - emitter.emit_line("return (PyObject *)self;") emitter.emit_line("}") diff --git a/mypyc/irbuild/builder.py b/mypyc/irbuild/builder.py index 88161f45bf03..5099f7053b92 100644 --- a/mypyc/irbuild/builder.py +++ b/mypyc/irbuild/builder.py @@ -69,6 +69,7 @@ Assign, BasicBlock, Branch, + Call, ComparisonOp, GetAttr, InitStatic, @@ -91,6 +92,7 @@ RType, RUnion, bitmap_rprimitive, + bool_rprimitive, bytes_rprimitive, c_pyssize_t_rprimitive, dict_rprimitive, @@ -1461,6 +1463,20 @@ def get_current_class_ir(self) -> ClassIR | None: type_info = self.fn_info.fitem.info return self.mapper.type_to_ir.get(type_info) + def add_coroutine_setup_call(self, class_name: str, obj: Value) -> Value: + return self.add( + Call( + FuncDecl( + class_name + "_coroutine_setup", + None, + self.module_name, + FuncSignature([RuntimeArg("type", object_rprimitive)], bool_rprimitive), + ), + [obj], + -1, + ) + ) + def gen_arg_defaults(builder: IRBuilder) -> None: """Generate blocks for arguments that have default values. diff --git a/mypyc/irbuild/callable_class.py b/mypyc/irbuild/callable_class.py index 59645d2597a7..22784ca12dfa 100644 --- a/mypyc/irbuild/callable_class.py +++ b/mypyc/irbuild/callable_class.py @@ -232,4 +232,9 @@ def instantiate_callable_class(builder: IRBuilder, fn_info: FuncInfo) -> Value: curr_env_reg = builder.fn_info.curr_env_reg if curr_env_reg: builder.add(SetAttr(func_reg, ENV_ATTR_NAME, curr_env_reg, fitem.line)) + # Initialize function wrapper for callable classes. As opposed to regular functions, + # each instance of a callable class needs its own wrapper because they might be instantiated + # inside other functions. + if not fn_info.in_non_ext and fn_info.is_coroutine: + builder.add_coroutine_setup_call(fn_info.callable_class.ir.name, func_reg) return func_reg diff --git a/mypyc/irbuild/classdef.py b/mypyc/irbuild/classdef.py index 2e67d7aa785e..2fd7357ec4f0 100644 --- a/mypyc/irbuild/classdef.py +++ b/mypyc/irbuild/classdef.py @@ -31,7 +31,7 @@ from mypy.types import Instance, UnboundType, get_proper_type from mypyc.common import PROPSET_PREFIX from mypyc.ir.class_ir import ClassIR, NonExtClassInfo -from mypyc.ir.func_ir import FuncDecl, FuncSignature, RuntimeArg +from mypyc.ir.func_ir import FuncDecl, FuncSignature from mypyc.ir.ops import ( NAMESPACE_TYPE, BasicBlock, @@ -473,19 +473,7 @@ def allocate_class(builder: IRBuilder, cdef: ClassDef) -> Value: -1, ) ) - - builder.add( - Call( - FuncDecl( - cdef.name + "_coroutine_setup", - None, - builder.module_name, - FuncSignature([RuntimeArg("type", object_rprimitive)], bool_rprimitive), - ), - [tp], - -1, - ) - ) + builder.add_coroutine_setup_call(cdef.name, tp) # Populate a '__mypyc_attrs__' field containing the list of attrs builder.primitive_op( diff --git a/mypyc/test-data/run-async.test b/mypyc/test-data/run-async.test index 361dcffbbe73..f0320e60ee1a 100644 --- a/mypyc/test-data/run-async.test +++ b/mypyc/test-data/run-async.test @@ -1739,3 +1739,63 @@ async def m_future_with_reraised_exception(first_exc: Exception, second_exc: Exc return await make_future(first_exc) except type(first_exc): raise second_exc + +[case testCPyFunctionWithFreedInstance] +import asyncio + +from functools import wraps +from typing import Any + +class ctx_man: + async def __aenter__(self) -> None: + pass + + async def __aexit__(self, *args: Any) -> None: + pass + +def with_ctx_man(): + def decorator(f): + @wraps(f) + async def inner(): + async with ctx_man(): + return await f() + + return inner + + return decorator + +async def func() -> int: + return 33 + +async def run_wrapped(): + wrapped = with_ctx_man()(func) + return await wrapped() + +def test_native(): + assert asyncio.run(run_wrapped()) == 33 + +[file driver.py] +import asyncio + +from native import test_native, with_ctx_man + +async def func() -> int: + return 42 + +async def run_wrapped(): + wrapped = with_ctx_man()(func) + return await wrapped() + +def test_interpreted(): + assert asyncio.run(run_wrapped()) == 42 + +# Run multiple times to test that the CPyFunction attribute is still +# set correctly after reusing a freed instance of the inner callable class object. +for i in range(10): + test_interpreted() + test_native() + +[file asyncio/__init__.pyi] +from typing import Any, Generator + +def run(x: object) -> object: ...