Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
Support for OpenAI Realtime API, LLM, TTS, and STT APIs.

Also includes support for a large number of OpenAI-compatible APIs including Azure OpenAI, Cerebras,
Fireworks, Perplexity, Telnyx, xAI, Ollama, DeepSeek, OpenRouter, and OVHcloud AI Endpoints.
Fireworks, Perplexity, Telnyx, xAI, Ollama, DeepSeek, OpenRouter, Cloudflare AI Gateway, and
OVHcloud AI Endpoints.

See https://docs.livekit.io/agents/integrations/openai/ and
https://docs.livekit.io/agents/integrations/llm/ for more information.
Expand All @@ -27,6 +28,7 @@
from .embeddings import EmbeddingData, create_embeddings
from .llm import LLM, LLMStream
from .models import (
CloudflareGatewayOptions,
OpenRouterProviderPreferences,
OpenRouterWebPlugin,
STTModels,
Expand All @@ -42,6 +44,7 @@
"TTS",
"LLM",
"LLMStream",
"CloudflareGatewayOptions",
"OpenRouterProviderPreferences",
"OpenRouterWebPlugin",
"STTModels",
Expand Down
119 changes: 119 additions & 0 deletions livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from __future__ import annotations

import json
import os
from dataclasses import asdict, dataclass
from typing import Any, Literal
Expand Down Expand Up @@ -42,6 +43,7 @@
from .models import (
CerebrasChatModels,
ChatModels,
CloudflareGatewayOptions,
CometAPIChatModels,
DeepSeekChatModels,
NebiusChatModels,
Expand Down Expand Up @@ -938,6 +940,123 @@ def with_letta(
tool_choice=NOT_GIVEN,
)

@staticmethod
def with_cloudflare(
*,
model: str,
account_id: str | None = None,
api_key: str | None = None,
gateway_id: str | None = None,
base_url: str | None = None,
gateway_options: CloudflareGatewayOptions | None = None,
client: openai.AsyncClient | None = None,
user: NotGivenOr[str] = NOT_GIVEN,
temperature: NotGivenOr[float] = NOT_GIVEN,
parallel_tool_calls: NotGivenOr[bool] = NOT_GIVEN,
tool_choice: ToolChoice = "auto",
reasoning_effort: NotGivenOr[ReasoningEffort] = NOT_GIVEN,
safety_identifier: NotGivenOr[str] = NOT_GIVEN,
prompt_cache_key: NotGivenOr[str] = NOT_GIVEN,
top_p: NotGivenOr[float] = NOT_GIVEN,
timeout: httpx.Timeout | None = None,
) -> LLM:
"""
Create a new instance of an LLM backed by the Cloudflare AI Gateway.

Uses the gateway's OpenAI-compatible REST API
(``https://api.cloudflare.com/client/v4/accounts/<account_id>/ai/v1``). The endpoint URL
is built from ``account_id`` unless an explicit ``base_url`` is given, and the model is a
``provider/model`` string.

Args:
model (str): Model in ``provider/model`` form, e.g.
``"workers-ai/@cf/meta/llama-3.3-70b-instruct-fp8-fast"`` or ``"openai/gpt-4o"``.
account_id (str | None, optional): Cloudflare account ID used to build the endpoint
URL. Falls back to ``CLOUDFLARE_ACCOUNT_ID``. Required unless ``base_url`` is set.
api_key (str | None, optional): Cloudflare API token with the ``AI Gateway``
permission, sent as the ``Authorization: Bearer`` header. Falls back to
``CLOUDFLARE_API_KEY``.
gateway_id (str | None, optional): Route through a specific gateway via the
``cf-aig-gateway-id`` header. Defaults to the account's default gateway.
base_url (str | None, optional): Full endpoint URL, e.g.
``"https://api.cloudflare.com/client/v4/accounts/<account_id>/ai/v1"``.
Overrides ``account_id`` when provided.
gateway_options (CloudflareGatewayOptions | None, optional): Per-request gateway
options (caching, retries, timeout, metadata), translated into ``cf-aig-*``
request headers.

Returns:
LLM: A configured LLM instance routed through the Cloudflare AI Gateway.
"""

api_key = api_key or os.environ.get("CLOUDFLARE_API_KEY")
if not api_key:
raise ValueError(
"Cloudflare API token is required, either as argument or set "
"CLOUDFLARE_API_KEY environment variable"
)

if base_url is None:
account_id = account_id or os.environ.get("CLOUDFLARE_ACCOUNT_ID")
if account_id is None:
raise ValueError(
"Cloudflare account_id is required, either as argument or set "
"CLOUDFLARE_ACCOUNT_ID environment variable (or pass base_url directly)"
)
base_url = f"https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/v1"

parsed = urlparse(base_url)
if parsed.scheme not in {"http", "https"}:
raise ValueError(f"Invalid URL scheme: '{parsed.scheme}'. Must be 'http' or 'https'.")
if not parsed.netloc:
raise ValueError(f"URL '{base_url}' is missing a network location (e.g., domain name).")

default_headers: dict[str, str] = {}
if gateway_id:
default_headers["cf-aig-gateway-id"] = gateway_id

if gateway_options:
if "cache_ttl" in gateway_options:
default_headers["cf-aig-cache-ttl"] = str(gateway_options["cache_ttl"])
if gateway_options.get("skip_cache"):
default_headers["cf-aig-skip-cache"] = "true"
if "cache_key" in gateway_options:
default_headers["cf-aig-cache-key"] = gateway_options["cache_key"]
if "request_timeout" in gateway_options:
default_headers["cf-aig-request-timeout"] = str(gateway_options["request_timeout"])
if "max_attempts" in gateway_options:
default_headers["cf-aig-max-attempts"] = str(gateway_options["max_attempts"])
if "retry_delay" in gateway_options:
default_headers["cf-aig-retry-delay"] = str(gateway_options["retry_delay"])
if "backoff" in gateway_options:
default_headers["cf-aig-backoff"] = gateway_options["backoff"]
if "collect_log" in gateway_options:
default_headers["cf-aig-collect-log"] = (
"true" if gateway_options["collect_log"] else "false"
)
if "metadata" in gateway_options:
metadata = gateway_options["metadata"]
default_headers["cf-aig-metadata"] = (
metadata if isinstance(metadata, str) else json.dumps(metadata)
)

return LLM(
model=model,
api_key=api_key,
base_url=base_url,
client=client,
user=user,
temperature=temperature,
parallel_tool_calls=parallel_tool_calls,
tool_choice=tool_choice,
reasoning_effort=reasoning_effort,
safety_identifier=safety_identifier,
prompt_cache_key=prompt_cache_key,
top_p=top_p,
extra_headers=default_headers,
timeout=timeout,
)

def chat(
self,
*,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -326,3 +326,29 @@ class OpenRouterProviderPreferences(TypedDict, total=False):
quantizations: list[str]
sort: Literal["price", "throughput", "latency"]
max_price: dict[str, float]


class CloudflareGatewayOptions(TypedDict, total=False):
"""Per-request Cloudflare AI Gateway options, mapped to ``cf-aig-*`` request headers.

See https://developers.cloudflare.com/ai-gateway/configuration/ for details.
"""

cache_ttl: int
"""Cache duration in seconds (``cf-aig-cache-ttl``)."""
skip_cache: bool
"""Bypass the cache for this request (``cf-aig-skip-cache``)."""
cache_key: str
"""Override the default cache key (``cf-aig-cache-key``)."""
request_timeout: int
"""Per-request timeout in milliseconds (``cf-aig-request-timeout``)."""
max_attempts: int
"""Maximum number of request attempts (``cf-aig-max-attempts``)."""
retry_delay: int
"""Delay between retries in milliseconds (``cf-aig-retry-delay``)."""
backoff: Literal["constant", "linear", "exponential"]
"""Retry backoff strategy (``cf-aig-backoff``)."""
collect_log: bool
"""Enable or disable logging for this request (``cf-aig-collect-log``)."""
metadata: dict[str, str | int | bool] | str
"""Custom metadata attached to the request (``cf-aig-metadata``); a dict or a JSON string."""
149 changes: 149 additions & 0 deletions tests/test_openai_with_cloudflare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
from __future__ import annotations

import json

import pytest

from livekit.plugins import openai

pytestmark = pytest.mark.unit

_BASE_URL = "https://api.cloudflare.com/client/v4/accounts/acct/ai/v1"


@pytest.fixture(autouse=True)
def _clear_cloudflare_env(monkeypatch: pytest.MonkeyPatch) -> None:
# keep construction deterministic regardless of the host environment
monkeypatch.delenv("CLOUDFLARE_API_KEY", raising=False)
monkeypatch.delenv("CLOUDFLARE_ACCOUNT_ID", raising=False)


def test_builds_rest_api_url_from_account() -> None:
llm = openai.LLM.with_cloudflare(model="openai/gpt-4o", account_id="acct", api_key="cf-tok")
assert str(llm._client.base_url).rstrip("/") == _BASE_URL


def test_account_id_falls_back_to_env(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("CLOUDFLARE_ACCOUNT_ID", "env-acct")
llm = openai.LLM.with_cloudflare(model="openai/gpt-4o", api_key="cf-tok")
assert "/accounts/env-acct/ai/v1" in str(llm._client.base_url)


def test_api_key_falls_back_to_env(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("CLOUDFLARE_API_KEY", "env-tok")
llm = openai.LLM.with_cloudflare(model="openai/gpt-4o", account_id="acct")
assert llm._client.api_key == "env-tok"


def test_token_sent_as_bearer_authorization() -> None:
# the Cloudflare API token is the OpenAI client's api_key -> Authorization: Bearer <token>
llm = openai.LLM.with_cloudflare(model="openai/gpt-4o", account_id="acct", api_key="cf-tok")
assert llm._client.api_key == "cf-tok"
# the deprecated cf-aig-authorization header is no longer used
assert "cf-aig-authorization" not in (llm._opts.extra_headers or {})


def test_base_url_overrides_account_id() -> None:
llm = openai.LLM.with_cloudflare(
model="openai/gpt-4o", account_id="ignored", base_url=_BASE_URL, api_key="cf-tok"
)
assert str(llm._client.base_url).rstrip("/") == _BASE_URL


def test_gateway_id_sets_header() -> None:
llm = openai.LLM.with_cloudflare(
model="openai/gpt-4o", account_id="acct", api_key="cf-tok", gateway_id="prod"
)
assert llm._opts.extra_headers["cf-aig-gateway-id"] == "prod"


def test_no_gateway_id_omits_header() -> None:
llm = openai.LLM.with_cloudflare(model="openai/gpt-4o", account_id="acct", api_key="cf-tok")
assert "cf-aig-gateway-id" not in (llm._opts.extra_headers or {})


def test_gateway_options_map_to_headers() -> None:
llm = openai.LLM.with_cloudflare(
model="openai/gpt-4o",
account_id="acct",
api_key="cf-tok",
gateway_options={
"cache_ttl": 3600,
"cache_key": "k1",
"request_timeout": 2000,
"max_attempts": 3,
"retry_delay": 500,
"backoff": "exponential",
"metadata": {"room": "r1", "turn": 4, "live": True},
},
)
headers = llm._opts.extra_headers
assert headers["cf-aig-cache-ttl"] == "3600"
assert headers["cf-aig-cache-key"] == "k1"
assert headers["cf-aig-request-timeout"] == "2000"
assert headers["cf-aig-max-attempts"] == "3"
assert headers["cf-aig-retry-delay"] == "500"
assert headers["cf-aig-backoff"] == "exponential"
assert json.loads(headers["cf-aig-metadata"]) == {"room": "r1", "turn": 4, "live": True}


def test_metadata_accepts_json_string() -> None:
llm = openai.LLM.with_cloudflare(
model="openai/gpt-4o",
account_id="acct",
api_key="cf-tok",
gateway_options={"metadata": '{"room":"r1"}'},
)
# a pre-serialized JSON string is passed through unchanged
assert llm._opts.extra_headers["cf-aig-metadata"] == '{"room":"r1"}'


def test_collect_log_emitted_true_or_false() -> None:
on = openai.LLM.with_cloudflare(
model="openai/gpt-4o",
account_id="acct",
api_key="cf-tok",
gateway_options={"collect_log": True},
)
assert on._opts.extra_headers["cf-aig-collect-log"] == "true"

off = openai.LLM.with_cloudflare(
model="openai/gpt-4o",
account_id="acct",
api_key="cf-tok",
gateway_options={"collect_log": False},
)
assert off._opts.extra_headers["cf-aig-collect-log"] == "false"


def test_skip_cache_header_only_emitted_when_true() -> None:
enabled = openai.LLM.with_cloudflare(
model="openai/gpt-4o",
account_id="acct",
api_key="cf-tok",
gateway_options={"skip_cache": True},
)
assert enabled._opts.extra_headers["cf-aig-skip-cache"] == "true"

disabled = openai.LLM.with_cloudflare(
model="openai/gpt-4o",
account_id="acct",
api_key="cf-tok",
gateway_options={"skip_cache": False},
)
assert "cf-aig-skip-cache" not in disabled._opts.extra_headers


def test_invalid_base_url_raises() -> None:
with pytest.raises(ValueError):
openai.LLM.with_cloudflare(model="openai/gpt-4o", base_url="not-a-url", api_key="cf-tok")


def test_missing_api_key_raises() -> None:
with pytest.raises(ValueError, match=r"API token"):
openai.LLM.with_cloudflare(model="openai/gpt-4o", account_id="acct")


def test_missing_account_id_raises() -> None:
with pytest.raises(ValueError, match=r"account_id"):
openai.LLM.with_cloudflare(model="openai/gpt-4o", api_key="cf-tok")