diff --git a/bellows/ash.py b/bellows/ash.py index 48b492ca..c76ab712 100644 --- a/bellows/ash.py +++ b/bellows/ash.py @@ -378,7 +378,12 @@ def connection_lost(self, exc: Exception | None) -> None: self._ezsp_protocol.connection_lost(exc) def eof_received(self): + _LOGGER.warning("EOF received from remote end") self._ezsp_protocol.eof_received() + # Return True to prevent the transport from auto-closing. + # For serial-over-TCP connections (e.g. ser2net), the remote end may + # signal EOF during initialization without intending to close. + return True def _cancel_pending_data_frames( self, exc: BaseException = RuntimeError("Connection has been closed") @@ -445,7 +450,7 @@ def _unstuff_bytes(data: bytes) -> bytes: return out def data_received(self, data: bytes) -> None: - _LOGGER.debug("Received data %s", data.hex()) + _LOGGER.warning("ASH received %d bytes: %s", len(data), data[:32].hex()) self._buffer.extend(data) if len(self._buffer) > MAX_BUFFER_SIZE: @@ -742,5 +747,6 @@ async def send_data(self, data: bytes) -> None: ) def send_reset(self) -> None: + _LOGGER.warning("Sending ASH reset frame") # Some adapters seem to send a NAK immediately but still process the reset frame self._write_frame(RstFrame(), prefix=32 * (Reserved.CANCEL,)) diff --git a/bellows/ezsp/__init__.py b/bellows/ezsp/__init__.py index 7166847e..3b0c339c 100644 --- a/bellows/ezsp/__init__.py +++ b/bellows/ezsp/__init__.py @@ -117,6 +117,9 @@ def is_tcp_serial_port(self) -> bool: async def _startup_reset(self) -> None: """Start EZSP and reset the stack.""" + if self._gw is None: + raise EzspError("Gateway is not connected") + # `zigbeed` resets on startup if self.is_tcp_serial_port: try: @@ -220,8 +223,21 @@ async def get_xncp_features(self) -> xncp.FirmwareFeatures: async def disconnect(self): self.stop_ezsp() - if self._gw: - await self._gw.disconnect() + if self._gw is not None: + try: + await self._gw.disconnect() + except ConnectionError: + # The secondary event loop is dead. Force-close the + # underlying TCP socket so ser2net (or similar) releases + # the serial port for subsequent connection attempts. + try: + ash = self._gw._obj._transport + if ash is not None and ash._transport is not None: + sock = ash._transport.get_extra_info("socket") + if sock is not None: + sock.close() + except Exception: + pass self._gw = None async def _command(self, name: str, *args: Any, **kwargs: Any) -> Any: diff --git a/bellows/thread.py b/bellows/thread.py index 4311768d..270402f6 100644 --- a/bellows/thread.py +++ b/bellows/thread.py @@ -1,6 +1,7 @@ import asyncio from concurrent.futures import ThreadPoolExecutor import functools +import inspect import logging LOGGER = logging.getLogger(__name__) @@ -14,7 +15,7 @@ def __init__(self): self.thread_complete = None def run_coroutine_threadsafe(self, coroutine): - current_loop = asyncio.get_event_loop() + current_loop = asyncio.get_running_loop() future = asyncio.run_coroutine_threadsafe(coroutine, self.loop) return asyncio.wrap_future(future, loop=current_loop) @@ -30,7 +31,7 @@ def _thread_main(self, init_task): self.loop = None async def start(self): - current_loop = asyncio.get_event_loop() + current_loop = asyncio.get_running_loop() if self.loop is not None and not self.loop.is_closed(): return @@ -95,11 +96,21 @@ def func_wrapper(*args, **kwargs): if loop == curr_loop: return call() if loop.is_closed(): - # Disconnected - LOGGER.warning("Attempted to use a closed event loop") - return - if asyncio.iscoroutinefunction(func): - future = asyncio.run_coroutine_threadsafe(call(), loop) + raise ConnectionError( + "Attempted to use a closed event loop, " + "the connection may have been lost" + ) + if inspect.iscoroutinefunction(func): + coro = call() + try: + future = asyncio.run_coroutine_threadsafe(coro, loop) + except RuntimeError: + # Loop closed between is_closed() check and dispatch + coro.close() + raise ConnectionError( + "Attempted to use a closed event loop, " + "the connection may have been lost" + ) return asyncio.wrap_future(future, loop=curr_loop) else: diff --git a/bellows/uart.py b/bellows/uart.py index af274dc8..82043e0b 100644 --- a/bellows/uart.py +++ b/bellows/uart.py @@ -33,7 +33,12 @@ def data_received(self, data): def reset_received(self, code: t.NcpResetCode) -> None: """Reset acknowledgement frame receive handler""" - LOGGER.debug("Received reset: %r", code) + LOGGER.warning( + "Received reset: %r (reset_future=%s, startup_reset_future=%s)", + code, + self._reset_future, + self._startup_reset_future, + ) if self._reset_future and not self._reset_future.done(): self._reset_future.set_result(True) @@ -46,14 +51,15 @@ def reset_received(self, code: t.NcpResetCode) -> None: def error_received(self, code: t.NcpResetCode) -> None: """Error frame receive handler.""" if self._reset_future is not None or self._startup_reset_future is not None: - LOGGER.debug("Ignoring spurious error during reset: %r", code) + LOGGER.warning("Ignoring spurious error during reset: %r", code) else: + LOGGER.warning("Error received, entering failed state: %r", code) self._api.enter_failed_state(code) async def wait_for_startup_reset(self) -> None: """Wait for the first reset frame on startup.""" - assert self._startup_reset_future is None - self._startup_reset_future = asyncio.get_running_loop().create_future() + if self._startup_reset_future is None: + self._startup_reset_future = asyncio.get_running_loop().create_future() try: await self._startup_reset_future @@ -68,7 +74,7 @@ def connection_lost(self, exc): """Port was closed unexpectedly.""" super().connection_lost(exc) - LOGGER.debug("Connection lost: %r", exc) + LOGGER.warning("Gateway connection lost: %r", exc) reason = exc or ConnectionResetError("Remote server closed connection") # XXX: The startup reset future must be resolved with an error *before* the @@ -98,7 +104,7 @@ async def reset(self): return await self._reset_future self._transport.send_reset() - self._reset_future = asyncio.get_event_loop().create_future() + self._reset_future = asyncio.get_running_loop().create_future() self._reset_future.add_done_callback(self._reset_cleanup) async with asyncio_timeout(RESET_TIMEOUT): @@ -106,13 +112,18 @@ async def reset(self): async def _connect(config, api): - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() connection_done_future = loop.create_future() gateway = Gateway(api, connection_done_future) protocol = AshProtocol(gateway) + # Pre-create the startup reset future before opening the connection so that + # reset frames arriving immediately after connect are captured by + # reset_received() instead of triggering enter_failed_state(). + gateway._startup_reset_future = loop.create_future() + if config[zigpy.config.CONF_DEVICE_FLOW_CONTROL] is None: xon_xoff, rtscts = True, False else: @@ -135,7 +146,7 @@ async def _connect(config, api): async def connect(config, api, use_thread=True): if use_thread: - api = ThreadsafeProxy(api, asyncio.get_event_loop()) + api = ThreadsafeProxy(api, asyncio.get_running_loop()) thread = EventLoopThread() await thread.start() try: diff --git a/tests/test_ezsp.py b/tests/test_ezsp.py index d548309a..483ad085 100644 --- a/tests/test_ezsp.py +++ b/tests/test_ezsp.py @@ -789,6 +789,30 @@ async def wait_forever(*args, **kwargs): assert version_mock.await_count == 1 +async def test_startup_reset_gw_none(): + """Test _startup_reset raises EzspError when gateway is None.""" + ezsp = make_ezsp( + config={ + **DEVICE_CONFIG, + zigpy.config.CONF_DEVICE_PATH: "socket://localhost:1234", + } + ) + ezsp._gw = None + + with pytest.raises(EzspError, match="Gateway is not connected"): + await ezsp._startup_reset() + + +async def test_disconnect_gw_none(): + """Test disconnect doesn't raise when gateway is already None.""" + ezsp = make_ezsp() + ezsp._gw = None + + await ezsp.disconnect() # Should not raise + + assert ezsp._gw is None + + async def test_wait_for_stack_status(ezsp_f): assert not ezsp_f._stack_status_listeners[t.sl_Status.NETWORK_DOWN] diff --git a/tests/test_thread.py b/tests/test_thread.py index 72efa701..e8d35bc2 100644 --- a/tests/test_thread.py +++ b/tests/test_thread.py @@ -157,7 +157,8 @@ async def test_proxy_loop_closed(): obj = mock.MagicMock() proxy = ThreadsafeProxy(obj, loop) loop.close() - proxy.test() + with pytest.raises(ConnectionError, match="closed event loop"): + proxy.test() assert obj.test.call_count == 0