Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,16 @@
from __future__ import annotations

import asyncio
import contextlib
import dataclasses
import time
import weakref
from collections.abc import AsyncGenerator, AsyncIterable, Callable
from collections.abc import AsyncGenerator, AsyncIterable, Callable, Coroutine
from dataclasses import dataclass
from datetime import timedelta
from typing import cast, get_args
from typing import Any, cast, get_args

from grpc.aio import StreamStreamCall

import google.auth
from google.api_core.client_options import ClientOptions
Expand Down Expand Up @@ -748,14 +751,28 @@ async def input_generator(
None,
]:
nonlocal audio_pushed
stop_task = asyncio.create_task(should_stop.wait())
try:
yield self._build_init_request(client)

async for frame in self._input_ch:
# When the stream is aborted for a reconnect, this input_generator must stop
# consuming frames. Once the generator stops, the previous gRPC stream
# closes.
if should_stop.is_set():
input_iter = aiter(self._input_ch)
while True:
# Race the next-frame await against should_stop so this generator
# can exit even when no audio is flowing. Without this, on reconnect
# the generator stays parked on _input_ch and pins the previous
# gRPC streaming call, leaking it across iterations.
frame_task = asyncio.create_task(
cast(Coroutine[rtc.AudioFrame, Any, Any], anext(input_iter))
)
done, _ = await asyncio.wait(
[frame_task, stop_task], return_when=asyncio.FIRST_COMPLETED
)
if stop_task in done:
frame_task.cancel()
return
try:
frame = frame_task.result()
except StopAsyncIteration:
return

if isinstance(frame, rtc.AudioFrame):
Expand All @@ -765,6 +782,8 @@ async def input_generator(

except Exception:
logger.exception("an error occurred while streaming input to google STT")
finally:
stop_task.cancel()

async def process_stream(
client: SpeechAsyncClientV2 | SpeechAsyncClientV1,
Expand Down Expand Up @@ -882,6 +901,12 @@ async def process_stream(
self._reconnect_event.clear()
finally:
should_stop.set()
# Cancel the streaming RPC so its underlying call object releases
# its read/write tasks and request iterator. Without this the
# call (and the input_generator that yielded into it) stays
# pinned across reconnects and leaks ~0.4 MB per cycle.
with contextlib.suppress(Exception):
cast(StreamStreamCall, stream).cancel()
if not process_stream_task.done() and not wait_reconnect_task.done():
# try to gracefully stop the process_stream_task
try:
Expand All @@ -896,6 +921,10 @@ async def process_stream(
if e.code == 409:
if audio_pushed:
logger.debug("stream timed out, restarting.")
else:
# No audio has been pushed — back off so we don't tight-loop
# against Google's server-side 5-minute idle timeout.
await asyncio.sleep(5.0)
else:
raise APIStatusError(
f"{e.message} {e.details}", status_code=e.code or -1
Expand Down