Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion getstream/video/rtc/connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,18 @@ def __init__(
create: bool = True,
subscription_config: Optional[SubscriptionConfig] = None,
max_join_retries: int = 3,
drain_video_frames: bool = False,
drain_video_frames: bool = True,
**kwargs: Any,
):
"""
Args:
drain_video_frames: When True, attaches a MediaBlackhole to each
incoming video track so unconsumed frames are drained
automatically. This prevents unbounded queue growth in
RTCRtpReceiver when no subscriber is consuming the track.
The drain is stopped once a real subscriber is added via
add_track_subscriber.
"""
super().__init__()

# Public attributes
Expand Down
21 changes: 13 additions & 8 deletions getstream/video/rtc/pc.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def __init__(
self,
connection,
configuration: aiortc.RTCConfiguration,
drain_video_frames: bool = False,
drain_video_frames: bool = True,
) -> None:
logger.info(
f"creating subscriber peer connection with configuration: {configuration}"
Expand All @@ -142,8 +142,8 @@ def __init__(

self.track_map = {} # track_id -> (MediaRelay, original_track)
self.video_frame_trackers = {} # track_id -> VideoFrameTracker
self._video_blackholes: dict[str, MediaBlackhole] = {}
self._video_drain_tasks: dict[str, asyncio.Task] = {}
self._video_blackholes: dict[str, tuple[MediaBlackhole, asyncio.Task]] = {}
self._background_tasks: set[asyncio.Task] = set()

@self.on("track")
async def on_track(track: aiortc.mediastreams.MediaStreamTrack):
Expand Down Expand Up @@ -189,11 +189,8 @@ def _emit_pcm(pcm: PcmData):
drain_proxy = relay.subscribe(tracked_track)
blackhole = MediaBlackhole()
blackhole.addTrack(drain_proxy)
self._video_blackholes[track.id] = blackhole
self._video_drain_tasks[track.id] = asyncio.create_task(
blackhole.start()
)

drain_task = asyncio.create_task(blackhole.start())
self._video_blackholes[track.id] = (blackhole, drain_task)
self.emit("track_added", proxy, user)

@self.on("icegatheringstatechange")
Expand All @@ -208,6 +205,14 @@ def add_track_subscriber(
"""Add a new subscriber to an existing track's MediaRelay."""
track_data = self.track_map.get(track_id)

blackhole, drain_task = self._video_blackholes.pop(track_id, (None, None))

if blackhole and drain_task:
task = asyncio.create_task(blackhole.stop())
drain_task.cancel() # safety net if start() becomes long-lived in future aiortc
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)

if track_data:
relay, original_track = track_data
return relay.subscribe(original_track, buffered=False)
Expand Down
2 changes: 1 addition & 1 deletion getstream/video/rtc/peer_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
class PeerConnectionManager:
"""Manages WebRTC peer connections for publishing and subscribing."""

def __init__(self, connection_manager, drain_video_frames: bool = False):
def __init__(self, connection_manager, drain_video_frames: bool = True):
self.connection_manager = connection_manager
self._drain_video_frames = drain_video_frames
self.publisher_pc: Optional[PublisherPeerConnection] = None
Expand Down
49 changes: 49 additions & 0 deletions tests/rtc/test_subscriber_drain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""Tests for SubscriberPeerConnection video drain behavior."""

from unittest.mock import AsyncMock, Mock

import pytest
from aiortc.contrib.media import MediaRelay

from getstream.video.rtc.pc import SubscriberPeerConnection


@pytest.fixture
def subscriber_pc():
"""Create a SubscriberPeerConnection bypassing heavy parent inits."""
pc = SubscriberPeerConnection.__new__(SubscriberPeerConnection)
pc.connection = Mock()
pc._drain_video_frames = True
pc.track_map = {}
pc.video_frame_trackers = {}
pc._video_blackholes = {}
pc._background_tasks = set()
pc._listeners = {}
return pc


class TestAddTrackSubscriberStopsDrain:
@pytest.mark.asyncio
async def test_blackhole_stopped_when_subscriber_added(self, subscriber_pc):
track_id = "user123:video:0"
relay = MediaRelay()
original_track = Mock()
subscriber_pc.track_map[track_id] = (relay, original_track)

blackhole = Mock()
blackhole.stop = AsyncMock()
subscriber_pc._video_blackholes[track_id] = (blackhole, Mock())

subscriber_pc.add_track_subscriber(track_id)

blackhole.stop.assert_called_once()
assert track_id not in subscriber_pc._video_blackholes

def test_no_error_when_no_drain_exists(self, subscriber_pc):
track_id = "user123:video:0"
relay = MediaRelay()
original_track = Mock()
subscriber_pc.track_map[track_id] = (relay, original_track)

result = subscriber_pc.add_track_subscriber(track_id)
assert result is not None