diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index d306124d8c0f..9f5cbbea5d11 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -57,6 +57,7 @@ BasicBlock, Branch, Call, + ComparisonOp, InitStatic, Integer, LoadAddress, @@ -82,6 +83,7 @@ none_rprimitive, object_pointer_rprimitive, object_rprimitive, + pointer_rprimitive, ) from mypyc.irbuild.ast_helpers import is_borrow_friendly_expr, process_conditional from mypyc.irbuild.builder import IRBuilder, create_type_params, int_borrow_friendly_op @@ -102,7 +104,9 @@ AssignmentTargetTuple, ) from mypyc.primitives.exc_ops import ( + err_occurred_op, error_catch_op, + error_clear_op, exc_matches_op, get_exc_info_op, get_exc_value_op, @@ -110,10 +114,11 @@ no_err_occurred_op, propagate_if_error_op, raise_exception_op, + raise_exception_with_tb_op, reraise_exception_op, restore_exc_info_op, ) -from mypyc.primitives.generic_ops import iter_op, next_raw_op, py_delattr_op +from mypyc.primitives.generic_ops import iter_op, next_raw_op, py_delattr_op, py_setattr_op from mypyc.primitives.misc_ops import ( check_stop_op, coro_op, @@ -940,6 +945,19 @@ def transform_with( is_async: bool, line: int, ) -> None: + + if ( + not is_async + and isinstance(expr, mypy.nodes.CallExpr) + and isinstance(expr.callee, mypy.nodes.RefExpr) + and isinstance(dec := expr.callee.node, mypy.nodes.Decorator) + and len(dec.decorators) == 1 + and isinstance(dec1 := dec.decorators[0], mypy.nodes.RefExpr) + and dec1.node + and dec1.node.fullname == "contextlib.contextmanager" + ): + return _transform_with_contextmanager(builder, expr, target, body, line) + # This is basically a straight transcription of the Python code in PEP 343. # I don't actually understand why a bunch of it is the way it is. # We could probably optimize the case where the manager is compiled by us, @@ -1017,6 +1035,285 @@ def finally_body() -> None: ) +def _transform_with_contextmanager( + builder: IRBuilder, + expr: mypy.nodes.CallExpr, + target: Lvalue | None, + with_body: GenFunc, + line: int, +) -> None: + assert isinstance(expr.callee, mypy.nodes.RefExpr) + dec = expr.callee.node + assert isinstance(dec, mypy.nodes.Decorator) + + # mgrv = ctx.__wrapped__(*args, **kwargs) + wrapped_call = mypy.nodes.CallExpr( + mypy.nodes.MemberExpr(expr.callee, "__wrapped__"), + expr.args, + expr.arg_kinds, + expr.arg_names, + ) + wrapped_call.line = line + gen = builder.maybe_spill(builder.accept(wrapped_call)) + + def raise_runtime_error_from_none(msg: str) -> None: + runtime_error = builder.load_module_attr_by_fullname("builtins.RuntimeError", line) + exc = builder.py_call(runtime_error, [builder.load_str(msg)], line) + builder.primitive_op( + py_setattr_op, [exc, builder.load_str("__cause__"), builder.none_object()], line + ) + builder.primitive_op( + py_setattr_op, + [ + exc, + builder.load_str("__suppress_context__"), + builder.coerce(builder.true(), object_rprimitive, line), + ], + line, + ) + builder.call_c(raise_exception_op, [exc], line) + builder.add(Unreachable()) + + # try: + # target = next(gen) + # except StopIteration: + # raise RuntimeError("generator didn't yield") from None + mgr_target = builder.call_c(next_raw_op, [builder.read(gen)], line) + + runtime_block, main_block = BasicBlock(), BasicBlock() + builder.add(Branch(mgr_target, runtime_block, main_block, Branch.IS_ERROR)) + + builder.activate_block(runtime_block) + err_occurred = builder.call_c(err_occurred_op, [], line) + null = Integer(0, pointer_rprimitive, line) + has_error = builder.add(ComparisonOp(err_occurred, null, ComparisonOp.NEQ, line)) + implicit_stop_block, error_exc_block = BasicBlock(), BasicBlock() + builder.add(Branch(has_error, error_exc_block, implicit_stop_block, Branch.BOOL)) + + builder.activate_block(error_exc_block) + old_exc = builder.maybe_spill(builder.call_c(error_catch_op, [], line)) + stop_iteration = builder.load_module_attr_by_fullname("builtins.StopIteration", line) + is_stop_iteration = builder.call_c(exc_matches_op, [stop_iteration], line) + stop_block, propagate_block = BasicBlock(), BasicBlock() + builder.add(Branch(is_stop_iteration, stop_block, propagate_block, Branch.BOOL)) + + builder.activate_block(propagate_block) + builder.call_c(reraise_exception_op, [], NO_TRACEBACK_LINE_NO) + builder.add(Unreachable()) + + builder.activate_block(stop_block) + builder.call_c(restore_exc_info_op, [builder.read(old_exc)], line) + raise_runtime_error_from_none("generator didn't yield") + + builder.activate_block(implicit_stop_block) + raise_runtime_error_from_none("generator didn't yield") + + builder.activate_block(main_block) + + exc = builder.maybe_spill_assignable(builder.true()) + + # try: + # {body} + + def try_body() -> None: + if target: + builder.assign(builder.get_assignment_target(target), mgr_target, line) + with_body() + + # except BaseException as e: + # try: + # gen.throw(type, value, traceback) + # except StopIteration as e2: + # if e2 is not e: + # raise + # return + # except RuntimeError: + # raise + # except BaseException: + # # approximately + # raise + + def except_body() -> None: + builder.assign(exc, builder.false(), line) + exc_info = builder.call_c(get_exc_info_op, [], line) + exc_type = builder.add(TupleGet(exc_info, 0, line)) + exc_value = builder.add(TupleGet(exc_info, 1, line)) + exc_tb = builder.add(TupleGet(exc_info, 2, line)) + exc_value_target = builder.maybe_spill_assignable(exc_value) + + def reraise_original() -> None: + builder.call_c( + raise_exception_with_tb_op, + [exc_type, builder.read(exc_value_target), exc_tb], + line, + ) + builder.add(Unreachable()) + + # Make sure we have an exception instance so identity comparisons are reliable. + none = builder.none_object() + is_none = builder.binary_op(builder.read(exc_value_target), none, "is", line) + value_block, value_done = BasicBlock(), BasicBlock() + builder.add(Branch(is_none, value_block, value_done, Branch.BOOL)) + builder.activate_block(value_block) + new_value = builder.py_call(exc_type, [], line) + builder.assign(exc_value_target, new_value, line) + builder.goto(value_done) + builder.activate_block(value_done) + + error_block, no_error_block = BasicBlock(), BasicBlock() + builder.builder.push_error_handler(error_block) + builder.goto_and_activate(BasicBlock()) + builder.py_call( + builder.py_get_attr(builder.read(gen), "throw", line), + [exc_type, builder.read(exc_value_target), exc_tb], + line, + ) + builder.goto(no_error_block) + builder.builder.pop_error_handler() + + builder.activate_block(no_error_block) + builder.add( + RaiseStandardError( + RaiseStandardError.RUNTIME_ERROR, "generator didn't stop after throw()", line + ) + ) + builder.add(Unreachable()) + + builder.activate_block(error_block) + throw_old_exc = builder.maybe_spill(builder.call_c(error_catch_op, [], line)) + stop_iteration = builder.load_module_attr_by_fullname("builtins.StopIteration", line) + is_stop_iteration = builder.call_c(exc_matches_op, [stop_iteration], line) + stop_block, runtime_check_block = BasicBlock(), BasicBlock() + builder.add(Branch(is_stop_iteration, stop_block, runtime_check_block, Branch.BOOL)) + + suppress_block = BasicBlock() + + builder.activate_block(stop_block) + stop_exc = builder.call_c(get_exc_value_op, [], line) + is_same_exc = builder.binary_op(stop_exc, builder.read(exc_value_target), "is", line) + propagate_block = BasicBlock() + builder.add(Branch(is_same_exc, propagate_block, suppress_block, Branch.BOOL)) + + builder.activate_block(propagate_block) + reraise_original() + + builder.activate_block(runtime_check_block) + runtime_error = builder.load_module_attr_by_fullname("builtins.RuntimeError", line) + is_runtime_error = builder.call_c(exc_matches_op, [runtime_error], line) + runtime_block, other_block = BasicBlock(), BasicBlock() + builder.add(Branch(is_runtime_error, runtime_block, other_block, Branch.BOOL)) + + builder.activate_block(runtime_block) + runtime_exc = builder.call_c(get_exc_value_op, [], line) + is_same_runtime = builder.binary_op( + runtime_exc, builder.read(exc_value_target), "is", line + ) + runtime_same_block, runtime_cause_block = BasicBlock(), BasicBlock() + builder.add(Branch(is_same_runtime, runtime_same_block, runtime_cause_block, Branch.BOOL)) + + builder.activate_block(runtime_same_block) + reraise_original() + + builder.activate_block(runtime_cause_block) + is_stop = builder.binary_op(exc_type, stop_iteration, "is", line) + cause_block, reraise_runtime_block = BasicBlock(), BasicBlock() + builder.add(Branch(is_stop, cause_block, reraise_runtime_block, Branch.BOOL)) + + builder.activate_block(cause_block) + cause = builder.py_get_attr(runtime_exc, "__cause__", line) + is_cause = builder.binary_op(cause, builder.read(exc_value_target), "is", line) + cause_match_block, cause_miss_block = BasicBlock(), BasicBlock() + builder.add(Branch(is_cause, cause_match_block, cause_miss_block, Branch.BOOL)) + + builder.activate_block(cause_match_block) + reraise_original() + + builder.activate_block(cause_miss_block) + builder.call_c(reraise_exception_op, [], NO_TRACEBACK_LINE_NO) + builder.add(Unreachable()) + + builder.activate_block(reraise_runtime_block) + builder.call_c(reraise_exception_op, [], NO_TRACEBACK_LINE_NO) + builder.add(Unreachable()) + + builder.activate_block(other_block) + other_exc = builder.call_c(get_exc_value_op, [], line) + is_same_other = builder.binary_op(other_exc, builder.read(exc_value_target), "is", line) + other_same_block, other_reraise_block = BasicBlock(), BasicBlock() + builder.add(Branch(is_same_other, other_same_block, other_reraise_block, Branch.BOOL)) + + builder.activate_block(other_same_block) + reraise_original() + + builder.activate_block(other_reraise_block) + builder.call_c(reraise_exception_op, [], NO_TRACEBACK_LINE_NO) + builder.add(Unreachable()) + + builder.activate_block(suppress_block) + builder.call_c(restore_exc_info_op, [builder.read(throw_old_exc)], line) + builder.call_c(error_clear_op, [], -1) + + handlers = [(None, None, except_body)] + + # finally (normal exit path): + # try: + # next(gen) + # except StopIteration: + # pass + # else: + # raise RuntimeError("generator didn't stop") + + def normal_exit_body() -> None: + value = builder.call_c(next_raw_op, [builder.read(gen)], line) + stop_block, error_block = BasicBlock(), BasicBlock() + builder.add(Branch(value, stop_block, error_block, Branch.IS_ERROR)) + + builder.activate_block(error_block) + builder.add( + RaiseStandardError(RaiseStandardError.RUNTIME_ERROR, "generator didn't stop", line) + ) + builder.add(Unreachable()) + + builder.activate_block(stop_block) + err_occurred = builder.call_c(err_occurred_op, [], line) + null = Integer(0, pointer_rprimitive, line) + has_error = builder.add(ComparisonOp(err_occurred, null, ComparisonOp.NEQ, line)) + implicit_stop_block, error_exc_block = BasicBlock(), BasicBlock() + builder.add(Branch(has_error, error_exc_block, implicit_stop_block, Branch.BOOL)) + + builder.activate_block(error_exc_block) + old_exc = builder.maybe_spill(builder.call_c(error_catch_op, [], line)) + stop_iteration = builder.load_module_attr_by_fullname("builtins.StopIteration", line) + is_stop_iteration = builder.call_c(exc_matches_op, [stop_iteration], line) + explicit_stop_block, propagate_block = BasicBlock(), BasicBlock() + builder.add(Branch(is_stop_iteration, explicit_stop_block, propagate_block, Branch.BOOL)) + + builder.activate_block(propagate_block) + builder.call_c(reraise_exception_op, [], NO_TRACEBACK_LINE_NO) + builder.add(Unreachable()) + + builder.activate_block(explicit_stop_block) + builder.call_c(restore_exc_info_op, [builder.read(old_exc)], line) + builder.goto(implicit_stop_block) + + builder.activate_block(implicit_stop_block) + builder.call_c(error_clear_op, [], -1) + + def finally_body() -> None: + out_block, exit_block = BasicBlock(), BasicBlock() + builder.add(Branch(builder.read(exc), exit_block, out_block, Branch.BOOL)) + builder.activate_block(exit_block) + normal_exit_body() + builder.goto_and_activate(out_block) + + transform_try_finally_stmt( + builder, + lambda: transform_try_except(builder, try_body, handlers, None, line), + finally_body, + line, + ) + + def transform_with_stmt(builder: IRBuilder, o: WithStmt) -> None: # Generate separate logic for each expr in it, left to right def generate(i: int) -> None: diff --git a/mypyc/primitives/exc_ops.py b/mypyc/primitives/exc_ops.py index e1234f807afa..01258712d3ad 100644 --- a/mypyc/primitives/exc_ops.py +++ b/mypyc/primitives/exc_ops.py @@ -80,6 +80,11 @@ arg_types=[], return_type=exc_rtuple, c_function_name="CPy_CatchError", error_kind=ERR_NEVER ) +# Clear the current exception. +error_clear_op = custom_op( + arg_types=[], return_type=void_rtype, c_function_name="PyErr_Clear", error_kind=ERR_NEVER +) + # Restore an old "currently handled exception" returned from. # error_catch (by sticking it into sys.exc_info()) restore_exc_info_op = custom_op( diff --git a/mypyc/test-data/run-functions.test b/mypyc/test-data/run-functions.test index 9bc5bb05c8d6..ebf77a187e96 100644 --- a/mypyc/test-data/run-functions.test +++ b/mypyc/test-data/run-functions.test @@ -1240,6 +1240,200 @@ def test_special_case() -> None: with f(): a.pop() +[case testContextManagerSpecialCaseSemantics] +from contextlib import contextmanager +from typing import Any, Generator, Iterator, cast +import traceback + +@contextmanager +def cm_no_yield() -> Iterator[None]: + if False: + yield + return + +@contextmanager +def cm_no_stop() -> Iterator[str]: + print("enter cm_no_stop") + yield "first" + print("after first") + yield "second" + +@contextmanager +def cm_value_error_suppress() -> Iterator[None]: + try: + yield + except ValueError as e: + print("suppress", str(e)) + +@contextmanager +def cm_throw_twice() -> Iterator[None]: + try: + yield + except ValueError: + print("throw yield") + yield + +@contextmanager +def cm_stop_passthrough() -> Iterator[None]: + yield + +@contextmanager +def cm_stop_suppress() -> Iterator[None]: + try: + yield + except StopIteration as e: + print("suppress stop", str(e)) + +@contextmanager +def cm_value() -> Iterator[int]: + print("enter cm_value") + try: + yield 1 + finally: + print("exit cm_value") + +@contextmanager +def cm_gen() -> Iterator[int]: + print("cm enter") + try: + yield 1 + finally: + print("cm exit") + +@contextmanager +def cm_log(tag: str) -> Iterator[None]: + print(tag, "enter") + try: + yield + finally: + print(tag, "exit") + +def test_no_yield() -> None: + try: + with cm_no_yield(): + print("body no yield") + except RuntimeError as e: + e_any = cast(Any, e) + print( + "no_yield", + str(e), + e_any.__cause__ is None, + e_any.__suppress_context__, + ) + +def test_no_stop() -> None: + try: + with cm_no_stop() as v: + print("body", v) + except RuntimeError as e: + print("no_stop", str(e)) + +def test_suppress_value_error() -> None: + with cm_value_error_suppress(): + raise ValueError("boom") + print("after suppress value") + +def test_throw_twice() -> None: + try: + with cm_throw_twice(): + raise ValueError("oops") + except RuntimeError as e: + print("throw_twice", str(e)) + +def test_stop_iteration_passthrough() -> None: + try: + with cm_stop_passthrough(): + raise StopIteration("stop") + except Exception as e: + print("stop passthrough", type(e).__name__, str(e)) + +def test_stop_iteration_suppress() -> None: + with cm_stop_suppress(): + raise StopIteration("stop2") + print("after suppress stop") + +def test_traceback() -> None: + try: + with cm_value(): + raise ValueError("trace") + except Exception as e: + e_any = cast(Any, e) + tb = e_any.__traceback__ + assert tb is not None + tb_mod = cast(Any, traceback) + names = [frame.name for frame in tb_mod.extract_tb(tb)] + print("traceback_has_cm_value", "cm_value" in names) + +def test_generator_with_yield() -> None: + def gen() -> Generator[int, None, None]: + with cm_gen() as v: + yield v + print("after yield in with") + + g = gen() + print(next(g)) + try: + next(g) + except StopIteration: + print("gen done") + +def test_return_in_with() -> None: + def inner() -> int: + with cm_log("ret"): + print("ret body") + return 5 + return 0 + + print("ret result", inner()) + +[file driver.py] +from native import ( + test_no_yield, + test_no_stop, + test_suppress_value_error, + test_throw_twice, + test_stop_iteration_passthrough, + test_stop_iteration_suppress, + test_traceback, + test_generator_with_yield, + test_return_in_with, +) + +test_no_yield() +test_no_stop() +test_suppress_value_error() +test_throw_twice() +test_stop_iteration_passthrough() +test_stop_iteration_suppress() +test_traceback() +test_generator_with_yield() +test_return_in_with() +[out] +no_yield generator didn't yield True True +enter cm_no_stop +body first +after first +no_stop generator didn't stop +suppress boom +after suppress value +throw yield +throw_twice generator didn't stop after throw() +stop passthrough StopIteration stop +suppress stop stop2 +after suppress stop +enter cm_value +exit cm_value +traceback_has_cm_value False +cm enter +1 +after yield in with +cm exit +gen done +ret enter +ret body +ret exit +ret result 5 + [case testUnpackKwargsCompiled] from typing import TypedDict from typing_extensions import Unpack