diff --git a/getstream/video/rtc/connection_manager.py b/getstream/video/rtc/connection_manager.py index 434a7a4e..b8d992f1 100644 --- a/getstream/video/rtc/connection_manager.py +++ b/getstream/video/rtc/connection_manager.py @@ -58,9 +58,18 @@ def __init__( create: bool = True, subscription_config: Optional[SubscriptionConfig] = None, max_join_retries: int = 3, - drain_video_frames: bool = False, + drain_video_frames: bool = True, **kwargs: Any, ): + """ + Args: + drain_video_frames: When True, attaches a MediaBlackhole to each + incoming video track so unconsumed frames are drained + automatically. This prevents unbounded queue growth in + RTCRtpReceiver when no subscriber is consuming the track. + The drain is stopped once a real subscriber is added via + add_track_subscriber. + """ super().__init__() # Public attributes diff --git a/getstream/video/rtc/pc.py b/getstream/video/rtc/pc.py index a07d5840..ab7302eb 100644 --- a/getstream/video/rtc/pc.py +++ b/getstream/video/rtc/pc.py @@ -131,7 +131,7 @@ def __init__( self, connection, configuration: aiortc.RTCConfiguration, - drain_video_frames: bool = False, + drain_video_frames: bool = True, ) -> None: logger.info( f"creating subscriber peer connection with configuration: {configuration}" @@ -142,8 +142,8 @@ def __init__( self.track_map = {} # track_id -> (MediaRelay, original_track) self.video_frame_trackers = {} # track_id -> VideoFrameTracker - self._video_blackholes: dict[str, MediaBlackhole] = {} - self._video_drain_tasks: dict[str, asyncio.Task] = {} + self._video_blackholes: dict[str, tuple[MediaBlackhole, asyncio.Task]] = {} + self._background_tasks: set[asyncio.Task] = set() @self.on("track") async def on_track(track: aiortc.mediastreams.MediaStreamTrack): @@ -189,11 +189,8 @@ def _emit_pcm(pcm: PcmData): drain_proxy = relay.subscribe(tracked_track) blackhole = MediaBlackhole() blackhole.addTrack(drain_proxy) - self._video_blackholes[track.id] = blackhole - self._video_drain_tasks[track.id] = asyncio.create_task( - blackhole.start() - ) - + drain_task = asyncio.create_task(blackhole.start()) + self._video_blackholes[track.id] = (blackhole, drain_task) self.emit("track_added", proxy, user) @self.on("icegatheringstatechange") @@ -208,6 +205,14 @@ def add_track_subscriber( """Add a new subscriber to an existing track's MediaRelay.""" track_data = self.track_map.get(track_id) + blackhole, drain_task = self._video_blackholes.pop(track_id, (None, None)) + + if blackhole and drain_task: + task = asyncio.create_task(blackhole.stop()) + drain_task.cancel() # safety net if start() becomes long-lived in future aiortc + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + if track_data: relay, original_track = track_data return relay.subscribe(original_track, buffered=False) diff --git a/getstream/video/rtc/peer_connection.py b/getstream/video/rtc/peer_connection.py index 26e89df0..685de269 100644 --- a/getstream/video/rtc/peer_connection.py +++ b/getstream/video/rtc/peer_connection.py @@ -28,7 +28,7 @@ class PeerConnectionManager: """Manages WebRTC peer connections for publishing and subscribing.""" - def __init__(self, connection_manager, drain_video_frames: bool = False): + def __init__(self, connection_manager, drain_video_frames: bool = True): self.connection_manager = connection_manager self._drain_video_frames = drain_video_frames self.publisher_pc: Optional[PublisherPeerConnection] = None diff --git a/tests/rtc/test_subscriber_drain.py b/tests/rtc/test_subscriber_drain.py new file mode 100644 index 00000000..08b86f51 --- /dev/null +++ b/tests/rtc/test_subscriber_drain.py @@ -0,0 +1,49 @@ +"""Tests for SubscriberPeerConnection video drain behavior.""" + +from unittest.mock import AsyncMock, Mock + +import pytest +from aiortc.contrib.media import MediaRelay + +from getstream.video.rtc.pc import SubscriberPeerConnection + + +@pytest.fixture +def subscriber_pc(): + """Create a SubscriberPeerConnection bypassing heavy parent inits.""" + pc = SubscriberPeerConnection.__new__(SubscriberPeerConnection) + pc.connection = Mock() + pc._drain_video_frames = True + pc.track_map = {} + pc.video_frame_trackers = {} + pc._video_blackholes = {} + pc._background_tasks = set() + pc._listeners = {} + return pc + + +class TestAddTrackSubscriberStopsDrain: + @pytest.mark.asyncio + async def test_blackhole_stopped_when_subscriber_added(self, subscriber_pc): + track_id = "user123:video:0" + relay = MediaRelay() + original_track = Mock() + subscriber_pc.track_map[track_id] = (relay, original_track) + + blackhole = Mock() + blackhole.stop = AsyncMock() + subscriber_pc._video_blackholes[track_id] = (blackhole, Mock()) + + subscriber_pc.add_track_subscriber(track_id) + + blackhole.stop.assert_called_once() + assert track_id not in subscriber_pc._video_blackholes + + def test_no_error_when_no_drain_exists(self, subscriber_pc): + track_id = "user123:video:0" + relay = MediaRelay() + original_track = Mock() + subscriber_pc.track_map[track_id] = (relay, original_track) + + result = subscriber_pc.add_track_subscriber(track_id) + assert result is not None