Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 83 additions & 10 deletions livekit-agents/livekit/agents/telemetry/traces.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
import threading
import time
from collections.abc import Iterator
from collections.abc import Callable, Iterator
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any

Expand Down Expand Up @@ -82,6 +82,76 @@ def start_as_current_span(self, *args: Any, **kwargs: Any) -> Iterator[Span]:
tracer: _DynamicTracer = _DynamicTracer("livekit-agents")


class _UploadGate:
"""Process-wide gate that stops observability uploads once LiveKit Cloud reports data
recording is disabled for the project. Reset per session from JobContext.init_recording().
"""

# substrings identifying the 401 "data recording is disabled by owner" rejection. Other
# 401s ("missing project id", "operation requires observability write grant") share the
# same status/grpc code, so we match the message text rather than the code.
_DISABLED_MARKERS = ("data recording is disabled", "disabled by owner")

def __init__(self) -> None:
self._disabled = False

def reset(self) -> None:
self._disabled = False

@property
def disabled(self) -> bool:
return self._disabled

def disable(self) -> None:
if self._disabled:
return
self._disabled = True
logger.warning(
"LiveKit Cloud data recording is disabled for this project; "
"skipping telemetry and recording uploads for this session"
)

@staticmethod
def is_disabled_response(status_code: int, body: bytes) -> bool:
"""Return True if an upload response means recording is disabled by the project owner."""
if status_code not in (401, 403):
return False
text = body.decode("utf-8", "ignore").lower()
return any(marker in text for marker in _UploadGate._DISABLED_MARKERS)


# Module-level gate instance: consulted by _AuthRefreshingSession.request and
# _upload_session_report, and re-armed at the top of _setup_cloud_tracer.
_upload_gate = _UploadGate()


class _AuthRefreshingSession(requests.Session):
    """requests.Session that refreshes its auth headers and honours the upload gate.

    Every request first merges the headers returned by ``header_provider`` into the
    session headers. Once ``_upload_gate`` is disabled, requests are no longer sent and a
    synthetic 200 is returned instead, so callers see the export as successful.
    """

    def __init__(self, header_provider: Callable[[], dict[str, str]]) -> None:
        super().__init__()
        self._header_provider = header_provider

    @staticmethod
    def _make_ok_response() -> requests.Response:
        """Build an empty-bodied 200 response that never touched the network."""
        ok = requests.Response()
        ok._content = b""
        ok.status_code = 200
        return ok

    def request(self, *args: Any, **kwargs: Any) -> requests.Response:
        """Send the request unless the gate is closed; hide the "disabled" rejection.

        Returns the real response for everything except the "recording disabled"
        rejection, which closes the gate and is replaced by a synthetic 200.
        """
        if not _upload_gate.disabled:
            fresh_headers = self._header_provider()
            self.headers.update(fresh_headers)

            upstream = super().request(*args, **kwargs)
            if not _upload_gate.is_disabled_response(upstream.status_code, upstream.content):
                return upstream

            # the project has recording turned off: latch so later calls short-circuit
            _upload_gate.disable()

        return self._make_ok_response()


class _MetadataSpanProcessor(SpanProcessor):
def __init__(self, metadata: dict[str, AttributeValue]) -> None:
self._metadata = metadata
Expand Down Expand Up @@ -162,18 +232,13 @@ def _setup_cloud_tracer(
enable_traces: bool = True,
enable_logs: bool = True,
) -> None:
# new session's telemetry begins here; re-arm the upload gate so a prior session's
# "disabled" state (the providers are process-global) doesn't carry over
_upload_gate.reset()

token_ttl = timedelta(hours=6)
refresh_margin = timedelta(minutes=5)

class _AuthRefreshingSession(requests.Session):
    """requests.Session that merges freshly provided auth headers into every request."""

    def __init__(self, header_provider: _AuthHeaderProvider) -> None:
        super().__init__()
        self._header_provider = header_provider

    def request(self, *args: Any, **kwargs: Any) -> requests.Response:
        """Update the session headers from the provider, then delegate to the parent."""
        auth_headers = self._header_provider()
        self.headers.update(auth_headers)
        return super().request(*args, **kwargs)

class _AuthHeaderProvider:
def __init__(self) -> None:
self._lock = threading.Lock()
Expand Down Expand Up @@ -449,6 +514,9 @@ async def _upload_session_report(
tagger: Tagger,
http_session: aiohttp.ClientSession,
) -> None:
if _upload_gate.disabled:
return

def _get_logger(name: str) -> Any:
return get_logger_provider().get_logger(
name=name,
Expand Down Expand Up @@ -633,6 +701,11 @@ def _build_multipart() -> aiohttp.MultipartWriter:
if resp.status < 400:
break

body = await resp.read()
if _upload_gate.is_disabled_response(resp.status, body):
_upload_gate.disable()
return

retry_delay = await _parse_retry_delay(resp)
if retry_delay is None or attempt == max_retries:
resp.raise_for_status()
Expand Down
138 changes: 138 additions & 0 deletions tests/test_telemetry_recording_disabled.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
"""Unit tests for the "recording disabled by owner" upload latch in
``livekit.agents.telemetry.traces``.

When a LiveKit Cloud project has data recording disabled, the OTLP exporters (and the
one-shot recording upload) get a 401/403 whose body says recording is disabled. The SDK
detects that signal, warns once per session, and short-circuits further uploads by handing
the exporter a synthetic 200 so it stops logging errors.
"""

from __future__ import annotations

import logging

import pytest

from livekit.agents.telemetry import traces

pytestmark = pytest.mark.unit


@pytest.fixture(autouse=True)
def _reset_latch():
    """Re-arm the shared upload gate before and after every test in this module."""
    traces._upload_gate.reset()  # setup: start each test with uploads enabled
    yield
    traces._upload_gate.reset()  # teardown: leave no latched state behind


def _status_proto(message: str) -> bytes:
    """Return the serialized bytes of a google.rpc.Status with code 7 and *message*."""
    # imported lazily so the dependency is only needed when a test builds a body
    from google.rpc import status_pb2  # type: ignore[import-untyped]

    return status_pb2.Status(code=7, message=message).SerializeToString()


# the exact message LiveKit Cloud returns (cloud-observability gin.go), always inside a
# google.rpc.Status protobuf body; it contains both substrings listed in
# traces._UploadGate._DISABLED_MARKERS
_DISABLED_MSG = "project data recording is disabled by owner"


@pytest.mark.parametrize(
    "status_code, body, expected",
    [
        # the real wire shape: 401 with a protobuf google.rpc.Status body
        (401, _status_proto(_DISABLED_MSG), True),
        # plain-text body (defensive: if the gateway ever returns text)
        (401, _DISABLED_MSG.encode(), True),
        (403, b"data recording is disabled", True),
        # wrong status code
        (200, _status_proto(_DISABLED_MSG), False),
        (500, _status_proto(_DISABLED_MSG), False),
        # sibling 401s that share the same status/grpc code -> must NOT latch
        (401, _status_proto("missing project id"), False),
        (401, _status_proto("operation requires observability write grant"), False),
        (401, b"", False),
    ],
)
def test_is_disabled_response(status_code: int, body: bytes, expected: bool):
    """The detector fires only for 401/403 responses whose body carries a marker."""
    verdict = traces._UploadGate.is_disabled_response(status_code, body)
    assert verdict is expected


def test_make_ok_response_is_ok():
    """The synthetic response must look like a successful HTTP 200."""
    synthetic = traces._AuthRefreshingSession._make_ok_response()
    assert synthetic.ok  # OTLP exporters treat this as a successful export
    assert synthetic.status_code == 200


def test_disable_uploads_warns_once_per_session(caplog):
    """disable() warns on the first call only, and again after the gate is reset."""

    def warning_count() -> int:
        return sum(1 for record in caplog.records if record.levelno == logging.WARNING)

    gate = traces._upload_gate
    assert not gate.disabled

    with caplog.at_level(logging.WARNING, logger="livekit.agents"):
        gate.disable()
        gate.disable()  # already latched: must stay silent

    assert gate.disabled
    assert warning_count() == 1

    # a new session re-arms the latch and warns again
    gate.reset()
    assert not gate.disabled
    with caplog.at_level(logging.WARNING, logger="livekit.agents"):
        gate.disable()
    assert warning_count() == 2


class _FakeResponse:
def __init__(self, status_code: int, content: bytes) -> None:
self.status_code = status_code
self.content = content

@property
def text(self) -> str:
return self.content.decode("utf-8", "ignore")


def test_session_latches_and_warns_then_short_circuits(monkeypatch, caplog):
    """The first "disabled" rejection latches the gate; later requests skip the network."""
    parent_calls = []

    def fake_request(self, *args, **kwargs):
        parent_calls.append(args)
        return _FakeResponse(401, _status_proto(_DISABLED_MSG))

    # patch the parent's request so no real network/credentials are needed
    monkeypatch.setattr("requests.Session.request", fake_request, raising=True)

    session = traces._AuthRefreshingSession(lambda: {"Authorization": "Bearer x"})
    endpoint = "https://example/observability/metrics/otlp/v0"

    with caplog.at_level(logging.WARNING, logger="livekit.agents"):
        # the first export detects the rejection; the next two must not hit the network
        responses = [session.request("POST", endpoint) for _ in range(3)]

    # OTel sees success every time -> it never logs the 401
    assert all(resp.ok for resp in responses)
    # only the detecting request reached the parent; later ones short-circuited
    assert len(parent_calls) == 1
    assert traces._upload_gate.disabled

    warning_records = [rec for rec in caplog.records if rec.levelno == logging.WARNING]
    assert len(warning_records) == 1


def test_session_passes_through_success_and_unrelated_errors(monkeypatch):
    """Responses other than the "disabled" rejection come back untouched, gate open."""
    ok_response = _FakeResponse(200, b"ok")
    monkeypatch.setattr(
        "requests.Session.request", lambda self, *a, **k: ok_response, raising=True
    )

    session = traces._AuthRefreshingSession(lambda: {"Authorization": "Bearer x"})
    assert session.request("POST", "https://example") is ok_response
    assert not traces._upload_gate.disabled

    # an unrelated 401 (e.g. bad token) is returned as-is and does NOT latch
    bad_token = _FakeResponse(401, b"invalid token")

    def reject(self, *args, **kwargs):
        return bad_token

    monkeypatch.setattr("requests.Session.request", reject, raising=True)
    assert session.request("POST", "https://example") is bad_token
    assert not traces._upload_gate.disabled