Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
299 changes: 298 additions & 1 deletion mypyc/irbuild/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
BasicBlock,
Branch,
Call,
ComparisonOp,
InitStatic,
Integer,
LoadAddress,
Expand All @@ -82,6 +83,7 @@
none_rprimitive,
object_pointer_rprimitive,
object_rprimitive,
pointer_rprimitive,
)
from mypyc.irbuild.ast_helpers import is_borrow_friendly_expr, process_conditional
from mypyc.irbuild.builder import IRBuilder, create_type_params, int_borrow_friendly_op
Expand All @@ -102,18 +104,21 @@
AssignmentTargetTuple,
)
from mypyc.primitives.exc_ops import (
err_occurred_op,
error_catch_op,
error_clear_op,
exc_matches_op,
get_exc_info_op,
get_exc_value_op,
keep_propagating_op,
no_err_occurred_op,
propagate_if_error_op,
raise_exception_op,
raise_exception_with_tb_op,
reraise_exception_op,
restore_exc_info_op,
)
from mypyc.primitives.generic_ops import iter_op, next_raw_op, py_delattr_op
from mypyc.primitives.generic_ops import iter_op, next_raw_op, py_delattr_op, py_setattr_op
from mypyc.primitives.misc_ops import (
check_stop_op,
coro_op,
Expand Down Expand Up @@ -940,6 +945,19 @@ def transform_with(
is_async: bool,
line: int,
) -> None:

if (
not is_async
and isinstance(expr, mypy.nodes.CallExpr)
and isinstance(expr.callee, mypy.nodes.RefExpr)
and isinstance(dec := expr.callee.node, mypy.nodes.Decorator)
and len(dec.decorators) == 1
and isinstance(dec1 := dec.decorators[0], mypy.nodes.RefExpr)
and dec1.node
and dec1.node.fullname == "contextlib.contextmanager"
):
return _transform_with_contextmanager(builder, expr, target, body, line)

# This is basically a straight transcription of the Python code in PEP 343.
# I don't actually understand why a bunch of it is the way it is.
# We could probably optimize the case where the manager is compiled by us,
Expand Down Expand Up @@ -1017,6 +1035,285 @@ def finally_body() -> None:
)


def _transform_with_contextmanager(
builder: IRBuilder,
expr: mypy.nodes.CallExpr,
target: Lvalue | None,
with_body: GenFunc,
line: int,
) -> None:
assert isinstance(expr.callee, mypy.nodes.RefExpr)
dec = expr.callee.node
assert isinstance(dec, mypy.nodes.Decorator)

# mgrv = ctx.__wrapped__(*args, **kwargs)
wrapped_call = mypy.nodes.CallExpr(
mypy.nodes.MemberExpr(expr.callee, "__wrapped__"),
expr.args,
expr.arg_kinds,
expr.arg_names,
)
wrapped_call.line = line
gen = builder.maybe_spill(builder.accept(wrapped_call))

def raise_runtime_error_from_none(msg: str) -> None:
runtime_error = builder.load_module_attr_by_fullname("builtins.RuntimeError", line)
exc = builder.py_call(runtime_error, [builder.load_str(msg)], line)
builder.primitive_op(
py_setattr_op, [exc, builder.load_str("__cause__"), builder.none_object()], line
)
builder.primitive_op(
py_setattr_op,
[
exc,
builder.load_str("__suppress_context__"),
builder.coerce(builder.true(), object_rprimitive, line),
],
line,
)
builder.call_c(raise_exception_op, [exc], line)
builder.add(Unreachable())

# try:
# target = next(gen)
# except StopIteration:
# raise RuntimeError("generator didn't yield") from None
mgr_target = builder.call_c(next_raw_op, [builder.read(gen)], line)

runtime_block, main_block = BasicBlock(), BasicBlock()
builder.add(Branch(mgr_target, runtime_block, main_block, Branch.IS_ERROR))

builder.activate_block(runtime_block)
err_occurred = builder.call_c(err_occurred_op, [], line)
null = Integer(0, pointer_rprimitive, line)
has_error = builder.add(ComparisonOp(err_occurred, null, ComparisonOp.NEQ, line))
implicit_stop_block, error_exc_block = BasicBlock(), BasicBlock()
builder.add(Branch(has_error, error_exc_block, implicit_stop_block, Branch.BOOL))

builder.activate_block(error_exc_block)
old_exc = builder.maybe_spill(builder.call_c(error_catch_op, [], line))
stop_iteration = builder.load_module_attr_by_fullname("builtins.StopIteration", line)
is_stop_iteration = builder.call_c(exc_matches_op, [stop_iteration], line)
stop_block, propagate_block = BasicBlock(), BasicBlock()
builder.add(Branch(is_stop_iteration, stop_block, propagate_block, Branch.BOOL))

builder.activate_block(propagate_block)
builder.call_c(reraise_exception_op, [], NO_TRACEBACK_LINE_NO)
builder.add(Unreachable())

builder.activate_block(stop_block)
builder.call_c(restore_exc_info_op, [builder.read(old_exc)], line)
raise_runtime_error_from_none("generator didn't yield")

builder.activate_block(implicit_stop_block)
raise_runtime_error_from_none("generator didn't yield")

builder.activate_block(main_block)

exc = builder.maybe_spill_assignable(builder.true())

# try:
# {body}

def try_body() -> None:
if target:
builder.assign(builder.get_assignment_target(target), mgr_target, line)
with_body()

# except BaseException as e:
# try:
# gen.throw(type, value, traceback)
# except StopIteration as e2:
# if e2 is not e:
# raise
# return
# except RuntimeError:
# raise
# except BaseException:
# # approximately
# raise

def except_body() -> None:
builder.assign(exc, builder.false(), line)
exc_info = builder.call_c(get_exc_info_op, [], line)
exc_type = builder.add(TupleGet(exc_info, 0, line))
exc_value = builder.add(TupleGet(exc_info, 1, line))
exc_tb = builder.add(TupleGet(exc_info, 2, line))
exc_value_target = builder.maybe_spill_assignable(exc_value)

def reraise_original() -> None:
builder.call_c(
raise_exception_with_tb_op,
[exc_type, builder.read(exc_value_target), exc_tb],
line,
)
builder.add(Unreachable())

# Make sure we have an exception instance so identity comparisons are reliable.
none = builder.none_object()
is_none = builder.binary_op(builder.read(exc_value_target), none, "is", line)
value_block, value_done = BasicBlock(), BasicBlock()
builder.add(Branch(is_none, value_block, value_done, Branch.BOOL))
builder.activate_block(value_block)
new_value = builder.py_call(exc_type, [], line)
builder.assign(exc_value_target, new_value, line)
builder.goto(value_done)
builder.activate_block(value_done)

error_block, no_error_block = BasicBlock(), BasicBlock()
builder.builder.push_error_handler(error_block)
builder.goto_and_activate(BasicBlock())
builder.py_call(
builder.py_get_attr(builder.read(gen), "throw", line),
[exc_type, builder.read(exc_value_target), exc_tb],
line,
)
builder.goto(no_error_block)
builder.builder.pop_error_handler()

builder.activate_block(no_error_block)
builder.add(
RaiseStandardError(
RaiseStandardError.RUNTIME_ERROR, "generator didn't stop after throw()", line
)
)
builder.add(Unreachable())

builder.activate_block(error_block)
throw_old_exc = builder.maybe_spill(builder.call_c(error_catch_op, [], line))
stop_iteration = builder.load_module_attr_by_fullname("builtins.StopIteration", line)
is_stop_iteration = builder.call_c(exc_matches_op, [stop_iteration], line)
stop_block, runtime_check_block = BasicBlock(), BasicBlock()
builder.add(Branch(is_stop_iteration, stop_block, runtime_check_block, Branch.BOOL))

suppress_block = BasicBlock()

builder.activate_block(stop_block)
stop_exc = builder.call_c(get_exc_value_op, [], line)
is_same_exc = builder.binary_op(stop_exc, builder.read(exc_value_target), "is", line)
propagate_block = BasicBlock()
builder.add(Branch(is_same_exc, propagate_block, suppress_block, Branch.BOOL))

builder.activate_block(propagate_block)
reraise_original()

builder.activate_block(runtime_check_block)
runtime_error = builder.load_module_attr_by_fullname("builtins.RuntimeError", line)
is_runtime_error = builder.call_c(exc_matches_op, [runtime_error], line)
runtime_block, other_block = BasicBlock(), BasicBlock()
builder.add(Branch(is_runtime_error, runtime_block, other_block, Branch.BOOL))

builder.activate_block(runtime_block)
runtime_exc = builder.call_c(get_exc_value_op, [], line)
is_same_runtime = builder.binary_op(
runtime_exc, builder.read(exc_value_target), "is", line
)
runtime_same_block, runtime_cause_block = BasicBlock(), BasicBlock()
builder.add(Branch(is_same_runtime, runtime_same_block, runtime_cause_block, Branch.BOOL))

builder.activate_block(runtime_same_block)
reraise_original()

builder.activate_block(runtime_cause_block)
is_stop = builder.binary_op(exc_type, stop_iteration, "is", line)
cause_block, reraise_runtime_block = BasicBlock(), BasicBlock()
builder.add(Branch(is_stop, cause_block, reraise_runtime_block, Branch.BOOL))

builder.activate_block(cause_block)
cause = builder.py_get_attr(runtime_exc, "__cause__", line)
is_cause = builder.binary_op(cause, builder.read(exc_value_target), "is", line)
cause_match_block, cause_miss_block = BasicBlock(), BasicBlock()
builder.add(Branch(is_cause, cause_match_block, cause_miss_block, Branch.BOOL))

builder.activate_block(cause_match_block)
reraise_original()

builder.activate_block(cause_miss_block)
builder.call_c(reraise_exception_op, [], NO_TRACEBACK_LINE_NO)
builder.add(Unreachable())

builder.activate_block(reraise_runtime_block)
builder.call_c(reraise_exception_op, [], NO_TRACEBACK_LINE_NO)
builder.add(Unreachable())

builder.activate_block(other_block)
other_exc = builder.call_c(get_exc_value_op, [], line)
is_same_other = builder.binary_op(other_exc, builder.read(exc_value_target), "is", line)
other_same_block, other_reraise_block = BasicBlock(), BasicBlock()
builder.add(Branch(is_same_other, other_same_block, other_reraise_block, Branch.BOOL))

builder.activate_block(other_same_block)
reraise_original()

builder.activate_block(other_reraise_block)
builder.call_c(reraise_exception_op, [], NO_TRACEBACK_LINE_NO)
builder.add(Unreachable())

builder.activate_block(suppress_block)
builder.call_c(restore_exc_info_op, [builder.read(throw_old_exc)], line)
builder.call_c(error_clear_op, [], -1)

handlers = [(None, None, except_body)]

# finally (normal exit path):
# try:
# next(gen)
# except StopIteration:
# pass
# else:
# raise RuntimeError("generator didn't stop")

def normal_exit_body() -> None:
value = builder.call_c(next_raw_op, [builder.read(gen)], line)
stop_block, error_block = BasicBlock(), BasicBlock()
builder.add(Branch(value, stop_block, error_block, Branch.IS_ERROR))

builder.activate_block(error_block)
builder.add(
RaiseStandardError(RaiseStandardError.RUNTIME_ERROR, "generator didn't stop", line)
)
builder.add(Unreachable())

builder.activate_block(stop_block)
err_occurred = builder.call_c(err_occurred_op, [], line)
null = Integer(0, pointer_rprimitive, line)
has_error = builder.add(ComparisonOp(err_occurred, null, ComparisonOp.NEQ, line))
implicit_stop_block, error_exc_block = BasicBlock(), BasicBlock()
builder.add(Branch(has_error, error_exc_block, implicit_stop_block, Branch.BOOL))

builder.activate_block(error_exc_block)
old_exc = builder.maybe_spill(builder.call_c(error_catch_op, [], line))
stop_iteration = builder.load_module_attr_by_fullname("builtins.StopIteration", line)
is_stop_iteration = builder.call_c(exc_matches_op, [stop_iteration], line)
explicit_stop_block, propagate_block = BasicBlock(), BasicBlock()
builder.add(Branch(is_stop_iteration, explicit_stop_block, propagate_block, Branch.BOOL))

builder.activate_block(propagate_block)
builder.call_c(reraise_exception_op, [], NO_TRACEBACK_LINE_NO)
builder.add(Unreachable())

builder.activate_block(explicit_stop_block)
builder.call_c(restore_exc_info_op, [builder.read(old_exc)], line)
builder.goto(implicit_stop_block)

builder.activate_block(implicit_stop_block)
builder.call_c(error_clear_op, [], -1)

def finally_body() -> None:
out_block, exit_block = BasicBlock(), BasicBlock()
builder.add(Branch(builder.read(exc), exit_block, out_block, Branch.BOOL))
builder.activate_block(exit_block)
normal_exit_body()
builder.goto_and_activate(out_block)

transform_try_finally_stmt(
builder,
lambda: transform_try_except(builder, try_body, handlers, None, line),
finally_body,
line,
)


def transform_with_stmt(builder: IRBuilder, o: WithStmt) -> None:
# Generate separate logic for each expr in it, left to right
def generate(i: int) -> None:
Expand Down
5 changes: 5 additions & 0 deletions mypyc/primitives/exc_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@
arg_types=[], return_type=exc_rtuple, c_function_name="CPy_CatchError", error_kind=ERR_NEVER
)

# Clear the current exception.
error_clear_op = custom_op(
arg_types=[], return_type=void_rtype, c_function_name="PyErr_Clear", error_kind=ERR_NEVER
)

# Restore an old "currently handled exception" returned from.
# error_catch (by sticking it into sys.exc_info())
restore_exc_info_op = custom_op(
Expand Down
Loading
Loading