Skip to content
Closed
53 changes: 53 additions & 0 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,25 @@ def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> Literal
return False


class ReturnTypeFinder(TraverserVisitor):
"""Visitor to collect return types from return statements in a function body.

This is used to infer return types for functions without explicit return type annotations.
"""

def __init__(self, typemap: dict[Expression, Type]) -> None:
self.typemap = typemap
self.return_types: list[Type] = []

def visit_return_stmt(self, o: ReturnStmt) -> None:
if o.expr is not None and o.expr in self.typemap:
self.return_types.append(self.typemap[o.expr])

def visit_func_def(self, o: FuncDef) -> None:
# Skip nested functions
pass


class TypeChecker(NodeVisitor[None], TypeCheckerSharedApi):
"""Mypy type checker.

Expand Down Expand Up @@ -1600,6 +1619,40 @@ def check_func_def(
):
self.note(message_registry.EMPTY_BODY_ABSTRACT, defn)

# Infer return type from return statements if function has no explicit return type annotation
if isinstance(item, FuncDef) and isinstance(typ, CallableType):

def is_unannotated_any(t: Type) -> bool:
if not isinstance(t, ProperType):
return False
return isinstance(t, AnyType) and t.type_of_any == TypeOfAny.unannotated

ret_type_proper = get_proper_type(typ.ret_type)
# Only infer for functions without explicit return type annotations
# Skip generators and coroutines as they have special return type handling
if (
is_unannotated_any(ret_type_proper)
and not defn.is_generator
and not defn.is_coroutine
and not self.dynamic_funcs[-1]
and item.body is not None
):
# Collect return types from return statements
# Use the master type map (first in stack) where final types are stored
# At this point in type checking, return statement types should be in the master map
finder = ReturnTypeFinder(self._type_maps[0])
item.body.accept(finder)
return_types_list = finder.return_types

if return_types_list:
# Create union of all return types
inferred_ret_type = make_simplified_union(return_types_list)
# Update the function's return type
typ = typ.copy_modified(ret_type=inferred_ret_type)
item.type = typ
# Update the return_types stack as well
self.return_types[-1] = inferred_ret_type

self.return_types.pop()

self.binder = old_binder
Expand Down
23 changes: 23 additions & 0 deletions test-data/unit/check-statements.test
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,29 @@ def f() -> Iterator[int]:
return "foo" # E: No return value expected
[out]

[case testInferReturnTypeFromReturnStatements]
# Test that mypy infers return type from return statements when function has no explicit return type annotation
def f(x: int):
if x > 0:
return "positive"
else:
return 0

reveal_type(f(1)) # N: Revealed type is "builtins.str | builtins.int"

def g(x: bool):
return x

reveal_type(g(True)) # N: Revealed type is "builtins.bool"

def h(x: int):
if x > 0:
return "positive"
return None

reveal_type(h(1)) # N: Revealed type is "builtins.str | None"

[out]

-- If statement
-- ------------
Expand Down
Loading