Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion python/restate/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@

import asyncio
import logging
from typing import Dict, TypedDict, Literal
import signal
from typing import Dict, Set, TypedDict, Literal

from restate.discovery import compute_discovery_json
from restate.endpoint import Endpoint
Expand Down Expand Up @@ -213,7 +214,23 @@ def asgi_app(endpoint: Endpoint) -> RestateAppT:
# Prepare request signer
identity_verifier = PyIdentityVerifier(endpoint.identity_keys)

active_channels: Set[ReceiveChannel] = set()
sigterm_installed = False

def _on_sigterm() -> None:
"""Notify all active receive channels of graceful shutdown."""
for ch in active_channels:
ch.notify_shutdown()

async def app(scope: Scope, receive: Receive, send: Send):
nonlocal sigterm_installed
if not sigterm_installed:
loop = asyncio.get_running_loop()
try:
loop.add_signal_handler(signal.SIGTERM, _on_sigterm)
except (NotImplementedError, RuntimeError):
pass # Windows or non-main thread
sigterm_installed = True
try:
if scope["type"] == "lifespan":
raise LifeSpanNotImplemented()
Expand Down Expand Up @@ -265,11 +282,13 @@ async def app(scope: Scope, receive: Receive, send: Send):
# Let us set up restate's execution context for this invocation and handler.
#
receive_channel = ReceiveChannel(receive)
active_channels.add(receive_channel)
try:
await process_invocation_to_completion(
VMWrapper(request_headers), handler, dict(request_headers), receive_channel, send
)
finally:
active_channels.discard(receive_channel)
await receive_channel.close()
except LifeSpanNotImplemented as e:
raise e
Expand Down
8 changes: 5 additions & 3 deletions python/restate/server_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,8 @@ async def leave(self):
# {'type': 'http.request', 'body': b'', 'more_body': True}
# {'type': 'http.request', 'body': b'', 'more_body': False}
# {'type': 'http.disconnect'}
# Wait for the runtime to explicitly close its side of the input.
# On SIGTERM, the shutdown event unblocks this instead of an arbitrary timeout.
await self.receive.block_until_http_input_closed()
# finally, we close our side
# it is important to do it, after the other side has closed his side,
Expand Down Expand Up @@ -545,9 +547,9 @@ async def wrapper(f):
continue
if chunk.get("type") == "http.disconnect":
raise DisconnectedException()
if chunk.get("body", None) is not None:
body = chunk.get("body")
assert isinstance(body, bytes)
# Skip empty body frames to avoid hot loop (see #175)
body: bytes | None = chunk.get("body", None) # type: ignore[assignment]
if body is not None and len(body) > 0:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good one!

self.vm.notify_input(body)
if not chunk.get("more_body", False):
self.vm.notify_input_closed()
Expand Down
16 changes: 14 additions & 2 deletions python/restate/server_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ class HTTPRequestEvent(TypedDict):
more_body: bool


class HTTPDisconnectEvent(TypedDict):
"""ASGI Disconnect event"""

type: Literal["http.disconnect"]


class HTTPResponseStartEvent(TypedDict):
"""ASGI Response start event"""

Expand All @@ -75,7 +81,7 @@ class HTTPResponseBodyEvent(TypedDict):
more_body: bool


ASGIReceiveEvent = HTTPRequestEvent
ASGIReceiveEvent = Union[HTTPRequestEvent, HTTPDisconnectEvent]


ASGISendEvent = Union[HTTPResponseStartEvent, HTTPResponseBodyEvent]
Expand Down Expand Up @@ -158,12 +164,18 @@ async def loop():

async def __call__(self) -> ASGIReceiveEvent | RestateEvent:
"""Get the next message."""
if self._disconnected.is_set() and self._queue.empty():
return {"type": "http.disconnect"}
what = await self._queue.get()
self._queue.task_done()
return what

def notify_shutdown(self) -> None:
"""Signal that a graceful shutdown has been requested (e.g. SIGTERM)."""
self._http_input_closed.set()

async def block_until_http_input_closed(self) -> None:
"""Wait until the HTTP input is closed"""
"""Wait until the HTTP input is closed or a shutdown signal is received."""
await self._http_input_closed.wait()

async def enqueue_restate_event(self, what: RestateEvent):
Expand Down
256 changes: 256 additions & 0 deletions tests/disconnect_hotloop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
#
# Copyright (c) 2023-2025 - Restate Software, Inc., Restate GmbH
#
# This file is part of the Restate SDK for Python,
# which is released under the MIT license.
#
# You can find a copy of the license in file LICENSE in the root
# directory of this repository or package, or at
# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
#
"""
Regression tests for disconnect and SIGTERM shutdown handling.

Covers:
- Hot-loop bug when BidiStream disconnects (empty queue, empty body frames)
- Graceful shutdown via notify_shutdown() unblocking block_until_http_input_closed()
"""

import asyncio
from typing import cast
from unittest.mock import MagicMock

import pytest

from restate.server_types import ASGIReceiveEvent, ReceiveChannel


@pytest.fixture(scope="session")
def anyio_backend():
return "asyncio"


pytestmark = [
pytest.mark.anyio,
]


async def test_receive_channel_returns_disconnect_when_drained():
"""After disconnect, an empty queue should return http.disconnect immediately."""
events = [
{"type": "http.request", "body": b"hello", "more_body": True},
{"type": "http.request", "body": b"", "more_body": False},
{"type": "http.disconnect"},
]
event_iter = iter(events)

async def mock_receive() -> ASGIReceiveEvent:
try:
return cast(ASGIReceiveEvent, next(event_iter))
except StopIteration:
# Block forever — simulates the real ASGI receive after disconnect
await asyncio.Event().wait()
raise RuntimeError("unreachable")

channel = ReceiveChannel(mock_receive)

# Drain all queued events
try:
result1 = await asyncio.wait_for(channel(), timeout=1.0)
assert result1["type"] == "http.request"

result2 = await asyncio.wait_for(channel(), timeout=1.0)
assert result2["type"] == "http.request"

result3 = await asyncio.wait_for(channel(), timeout=1.0)
assert result3["type"] == "http.disconnect"

# Now the queue is drained and _disconnected is set.
# This call should return immediately with a synthetic disconnect,
# NOT block forever.
result4 = await asyncio.wait_for(channel(), timeout=1.0)
assert result4["type"] == "http.disconnect"
finally:
await channel.close()


async def test_receive_channel_does_not_block_after_disconnect():
"""Repeated calls after disconnect should all return promptly."""
events = [
{"type": "http.disconnect"},
]
event_iter = iter(events)

async def mock_receive() -> ASGIReceiveEvent:
try:
return cast(ASGIReceiveEvent, next(event_iter))
except StopIteration:
await asyncio.Event().wait()
raise RuntimeError("unreachable")

channel = ReceiveChannel(mock_receive)

try:
# Consume the real disconnect
result = await asyncio.wait_for(channel(), timeout=1.0)
assert result["type"] == "http.disconnect"

# Subsequent calls should not block
for _ in range(5):
result = await asyncio.wait_for(channel(), timeout=0.5)
assert result["type"] == "http.disconnect"
finally:
await channel.close()


async def test_empty_body_frames_do_not_cause_hotloop():
"""
When the VM returns DoProgressReadFromInput and the chunk has body=b'',
notify_input should NOT be called (it would cause a tight loop).
The loop should exit via DisconnectedException when http.disconnect arrives.
"""
from restate.server_context import ServerInvocationContext, DisconnectedException
from restate.vm import DoProgressReadFromInput

# Build a minimal mock context
vm = MagicMock()
vm.take_output.return_value = None
vm.do_progress.return_value = DoProgressReadFromInput()

handler = MagicMock()
invocation = MagicMock()
send = MagicMock()

events = [
{"type": "http.request", "body": b"", "more_body": True},
{"type": "http.request", "body": b"", "more_body": False},
{"type": "http.disconnect"},
]
event_iter = iter(events)

async def mock_receive() -> ASGIReceiveEvent:
try:
return cast(ASGIReceiveEvent, next(event_iter))
except StopIteration:
await asyncio.Event().wait()
raise RuntimeError("unreachable")

receive_channel = ReceiveChannel(mock_receive)

ctx = ServerInvocationContext.__new__(ServerInvocationContext)
ctx.vm = vm
ctx.handler = handler
ctx.invocation = invocation
ctx.send = send
ctx.receive = receive_channel
ctx.run_coros_to_execute = {}
ctx.tasks = MagicMock()

try:
with pytest.raises(DisconnectedException):
await asyncio.wait_for(
ctx.create_poll_or_cancel_coroutine([0]),
timeout=2.0,
)

# notify_input should never have been called with empty bytes
for call in vm.notify_input.call_args_list:
arg = call[0][0]
assert len(arg) > 0, f"notify_input called with empty bytes: {arg!r}"
finally:
await receive_channel.close()


# ---- Shutdown / SIGTERM tests ----


async def test_block_until_http_input_closed_returns_on_normal_close():
"""block_until_http_input_closed returns when the runtime closes its input."""
events = [
{"type": "http.request", "body": b"data", "more_body": True},
{"type": "http.request", "body": b"", "more_body": False},
]
event_iter = iter(events)

async def mock_receive() -> ASGIReceiveEvent:
try:
return cast(ASGIReceiveEvent, next(event_iter))
except StopIteration:
await asyncio.Event().wait()
raise RuntimeError("unreachable")

channel = ReceiveChannel(mock_receive)
try:
# Should return promptly once more_body=False is received
await asyncio.wait_for(channel.block_until_http_input_closed(), timeout=1.0)
finally:
await channel.close()


async def test_block_until_http_input_closed_returns_on_shutdown():
"""block_until_http_input_closed returns when notify_shutdown() is called,
even if the runtime never closes its input."""

async def mock_receive() -> ASGIReceiveEvent:
# Never sends any events — simulates the runtime not closing its side
await asyncio.Event().wait()
raise RuntimeError("unreachable")

channel = ReceiveChannel(mock_receive)
try:
# Schedule shutdown after a short delay
async def trigger_shutdown():
await asyncio.sleep(0.05)
channel.notify_shutdown()

asyncio.create_task(trigger_shutdown())

# Should return promptly due to shutdown, NOT block forever
await asyncio.wait_for(channel.block_until_http_input_closed(), timeout=1.0)
finally:
await channel.close()


async def test_notify_shutdown_is_idempotent():
"""Calling notify_shutdown() multiple times does not raise."""

async def mock_receive() -> ASGIReceiveEvent:
await asyncio.Event().wait()
raise RuntimeError("unreachable")

channel = ReceiveChannel(mock_receive)
try:
channel.notify_shutdown()
channel.notify_shutdown() # should not raise

# Should return immediately since shutdown is already set
await asyncio.wait_for(channel.block_until_http_input_closed(), timeout=0.5)
finally:
await channel.close()


async def test_shutdown_unblocks_concurrent_waiters():
"""Multiple concurrent waiters on block_until_http_input_closed
should all be unblocked by a single notify_shutdown()."""

async def mock_receive() -> ASGIReceiveEvent:
await asyncio.Event().wait()
raise RuntimeError("unreachable")

channel = ReceiveChannel(mock_receive)
try:
results = []

async def waiter(idx: int):
await channel.block_until_http_input_closed()
results.append(idx)

tasks = [asyncio.create_task(waiter(i)) for i in range(3)]

await asyncio.sleep(0.05)
channel.notify_shutdown()

await asyncio.wait_for(asyncio.gather(*tasks), timeout=1.0)
assert sorted(results) == [0, 1, 2]
finally:
await channel.close()