Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 64 additions & 6 deletions codex_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,17 @@
"""
from __future__ import annotations

import asyncio
import json
from contextlib import AsyncExitStack
from typing import Any, Iterable

from llm_router_host import _cached_tokens
from provider_adapters.common import (
before_first_output,
first_token_timeout_err,
first_token_timeout_s,
)

CODEX_BASE_URL = "https://chatgpt.com/backend-api/codex"

Expand Down Expand Up @@ -267,6 +274,20 @@ def aggregate_codex_sse(lines: Iterable[str], latency_ms: int) -> dict:
}


def _codex_line_has_output_delta(line: str) -> bool:
line = line.strip()
if not line or not line.startswith("data:"):
return False
payload = line[len("data:"):].strip()
if payload == "[DONE]":
return False
try:
ev = json.loads(payload)
except ValueError:
return False
return ev.get("type") == "response.output_text.delta" and bool(ev.get("delta"))


def make_codex_async_call_provider(
auth,
base_url: str = CODEX_BASE_URL,
Expand Down Expand Up @@ -309,25 +330,62 @@ async def call(request: dict) -> dict:
timeout = (request.get("timeout_ms") or int(timeout_s * 1000)) / 1000.0

t0 = time.monotonic()
saw_output = False
status_seen = False
first_timeout_s = first_token_timeout_s(request)

def _latency() -> int:
return int((time.monotonic() - t0) * 1000)

def _saw_output() -> bool:
return saw_output

def _timeout_err() -> dict:
return first_token_timeout_err(first_timeout_s, _latency())

try:
async with httpx.AsyncClient(timeout=timeout) as c:
async with c.stream("POST", url, json=body, headers=headers) as resp:
async with AsyncExitStack() as stack:
try:
resp = await before_first_output(stack.enter_async_context(
c.stream("POST", url, json=body, headers=headers)),
first_timeout_s, t0, _saw_output)
except (asyncio.TimeoutError, TimeoutError):
if not status_seen:
_notify(0)
return _timeout_err()

_notify(resp.status_code, resp.headers)
latency = int((time.monotonic() - t0) * 1000)
status_seen = True
latency = _latency()
if resp.status_code == 401:
return _err("auth_error", 401, latency, "codex token rejected")
if resp.status_code == 429:
return _err("rate_limit", 429, latency, "codex rate limited")
if resp.status_code >= 400:
detail = (await resp.aread()).decode("utf-8", "replace")[:500]
return _err("server_error", resp.status_code, latency, detail)
lines = [line async for line in resp.aiter_lines()]
return aggregate_codex_sse(lines, int((time.monotonic() - t0) * 1000))
lines = []
stream_lines = resp.aiter_lines().__aiter__()
while True:
try:
line = await before_first_output(
stream_lines.__anext__(), first_timeout_s, t0, _saw_output)
except StopAsyncIteration:
break
except (asyncio.TimeoutError, TimeoutError):
if not saw_output:
return _timeout_err()
raise
lines.append(line)
if _codex_line_has_output_delta(line):
saw_output = True
return aggregate_codex_sse(lines, _latency())
except httpx.TimeoutException:
_notify(0)
return _err("timeout", 0, int((time.monotonic() - t0) * 1000), "codex request timed out")
return _err("timeout", 0, _latency(), "codex request timed out")
except (httpx.NetworkError, httpx.RequestError) as e:
_notify(0)
return _err("network_error", 0, int((time.monotonic() - t0) * 1000), str(e))
return _err("network_error", 0, _latency(), str(e))

return call
32 changes: 32 additions & 0 deletions provider_adapters/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations

import json
import asyncio
import time
from typing import Awaitable, Callable, Any

Expand Down Expand Up @@ -159,10 +160,41 @@ def elapsed_ms(t0: float) -> int:
return int((time.monotonic() - t0) * 1000)


def first_token_timeout_s(request: dict) -> "float | None":
    """Read the first-token budget from *request* and convert it to seconds.

    The value is taken from ``request["first_token_timeout_ms"]`` and passed
    through ``float()``, so numbers and numeric strings are both accepted.

    Args:
        request: Request mapping that may carry ``first_token_timeout_ms``.

    Returns:
        The budget in seconds when it is a positive, finite number; otherwise
        None (key missing, value not convertible, zero, negative, NaN or
        infinite).
    """
    raw = request.get("first_token_timeout_ms")
    try:
        seconds = float(raw) / 1000.0
    except (TypeError, ValueError, OverflowError):
        # OverflowError: an integer too large for a float (JSON integers are
        # unbounded in Python) is as unusable as a non-numeric value.
        return None
    # NaN fails the ``> 0`` comparison; an infinite budget can never expire,
    # so it is reported the same way as "no budget configured".
    if seconds > 0 and seconds != float("inf"):
        return seconds
    return None


async def before_first_output(
    awaitable,
    timeout_s: "float | None",
    t0: float,
    saw_output: Callable[[], bool],
):
    """Await *awaitable*, bounded by whatever is left of the first-output budget.

    The budget is ``timeout_s`` seconds measured from ``t0``. It applies only
    while ``saw_output()`` is false; once output has been seen, or when no
    budget is configured, the awaitable is awaited without any extra limit.

    Args:
        awaitable: Coroutine, future or other awaitable to wait on.
        timeout_s: Total budget in seconds, or None for no budget.
        t0: ``time.monotonic()`` reading the budget is measured from.
        saw_output: Zero-argument callable returning True once output arrived.

    Returns:
        The result of *awaitable*.

    Raises:
        asyncio.TimeoutError: The budget was already used up, or ran out
            while waiting.
    """
    if timeout_s is None or saw_output():
        return await awaitable
    remaining = timeout_s - (time.monotonic() - t0)
    if remaining <= 0:
        # The awaitable will never be awaited on this path. Dispose of it the
        # way asyncio.wait_for would on timeout, so an unstarted coroutine
        # does not trigger a "never awaited" RuntimeWarning and a future does
        # not keep running after the caller has been told it timed out.
        if asyncio.isfuture(awaitable):
            awaitable.cancel()
        elif asyncio.iscoroutine(awaitable):
            awaitable.close()
        raise asyncio.TimeoutError
    return await asyncio.wait_for(awaitable, timeout=remaining)


def first_token_timeout_err(timeout_s: float, latency_ms: int) -> dict:
    """Build the ``timeout`` error result for an exhausted first-token budget.

    Args:
        timeout_s: The budget in seconds. None is tolerated because callers
            pass the value returned by ``first_token_timeout_s``, which may
            be None.
        latency_ms: Elapsed time to report, in milliseconds.

    Returns:
        Whatever ``err`` returns for kind ``"timeout"``, status 0, the given
        latency and a message naming the budget in milliseconds.
    """
    if timeout_s is None:
        # No budget to quote; still report the timeout rather than failing
        # with a TypeError while building the message.
        return err("timeout", 0, latency_ms, "first token timed out")
    # round(), not int(): the ms -> s -> ms round trip is inexact in binary
    # floating point (1.001 * 1000 == 1000.9999999999999), and truncation
    # would report a 1001 ms budget as 1000ms.
    return err("timeout", 0, latency_ms,
               f"first token timed out after {round(timeout_s * 1000)}ms")


# Back-compat names used by older tests/scripts that import through
# llm_router_host's re-export layer.
# Each underscore-prefixed name is bound to the same object as its public
# counterpart, so both spellings resolve to the identical helper.
_cached_tokens = cached_tokens
_classify_status = classify_status
_err = err
_elapsed_ms = elapsed_ms
_first_token_timeout_s = first_token_timeout_s
_before_first_output = before_first_output
_first_token_timeout_err = first_token_timeout_err
_provider_error_message = provider_error_message
42 changes: 20 additions & 22 deletions provider_adapters/openai_compatible.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,16 @@
import json
import os
import time
from contextlib import AsyncExitStack
from typing import Awaitable, Callable, Any

from provider_adapters.common import (
CallProviderHook,
AsyncCallProviderHook,
TokenProvider,
before_first_output,
first_token_timeout_err,
first_token_timeout_s,
_cached_tokens,
_classify_status,
_elapsed_ms,
Expand Down Expand Up @@ -308,26 +312,25 @@ async def stream_openai_compatible(
usage: dict = {}
raw_model = None
saw_output = False
first_token_timeout_ms = request.get("first_token_timeout_ms")
try:
first_token_timeout_s = (
float(first_token_timeout_ms) / 1000.0
if first_token_timeout_ms is not None and float(first_token_timeout_ms) > 0
else None
)
except (TypeError, ValueError):
first_token_timeout_s = None
first_timeout_s = first_token_timeout_s(request)

def _latency() -> int:
return int((time.monotonic() - t0) * 1000)

def _first_token_timeout_error() -> dict:
return _err("timeout", 0, _latency(),
f"first token timed out after {int(first_token_timeout_s * 1000)}ms")
def _saw_output() -> bool:
return saw_output

def _timeout_err() -> dict:
return first_token_timeout_err(first_timeout_s, _latency())

try:
async with client.stream("POST", url, json=body, headers=headers,
timeout=timeout) as resp:
async with AsyncExitStack() as stack:
try:
resp = await before_first_output(stack.enter_async_context(
client.stream("POST", url, json=body, headers=headers,
timeout=timeout)), first_timeout_s, t0, _saw_output)
except (asyncio.TimeoutError, TimeoutError):
return _timeout_err()
if not (200 <= resp.status_code < 300):
raw = (await resp.aread()).decode("utf-8", "replace")[:500]
kind = _classify_from_map(raw, rules.get("error_map")) \
Expand All @@ -337,18 +340,13 @@ def _first_token_timeout_error() -> dict:
lines = resp.aiter_lines().__aiter__()
while True:
try:
if first_token_timeout_s is not None and not saw_output:
remaining = first_token_timeout_s - (time.monotonic() - t0)
if remaining <= 0:
return _first_token_timeout_error()
line = await asyncio.wait_for(lines.__anext__(), timeout=remaining)
else:
line = await lines.__anext__()
line = await before_first_output(
lines.__anext__(), first_timeout_s, t0, _saw_output)
except StopAsyncIteration:
break
except (asyncio.TimeoutError, TimeoutError):
if not saw_output:
return _first_token_timeout_error()
return _timeout_err()
raise
if not line or not line.startswith("data:"):
continue
Expand Down
37 changes: 33 additions & 4 deletions streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,18 @@

import json
import re
import asyncio
import time
import uuid
from contextlib import AsyncExitStack
from typing import Any, Awaitable, Callable

from provider_adapters.common import _err
from provider_adapters.common import (
_err,
before_first_output,
first_token_timeout_err,
first_token_timeout_s,
)
# The openai-compatible STREAM backend lives with its non-stream sibling in the
# adapter leaf (it shares _prepare_openai_call); re-exported here so the dispatcher
# and shim keep their import site. This module → provider_adapters (the allowed
Expand Down Expand Up @@ -76,13 +83,25 @@ def _notify(status: int, headers=None) -> None:
t0 = time.monotonic()
emitted = False
lines: list[str] = []
first_timeout_s = first_token_timeout_s(request)

def _latency() -> int:
return int((time.monotonic() - t0) * 1000)

def _emitted() -> bool:
return emitted

def _timeout_err() -> dict:
return first_token_timeout_err(first_timeout_s, _latency())

try:
async with client.stream("POST", url, json=body, headers=headers,
timeout=timeout) as resp:
async with AsyncExitStack() as stack:
try:
resp = await before_first_output(stack.enter_async_context(
client.stream("POST", url, json=body, headers=headers,
timeout=timeout)), first_timeout_s, t0, _emitted)
except (asyncio.TimeoutError, TimeoutError):
return _timeout_err()
_notify(resp.status_code, resp.headers)
if resp.status_code == 401:
return _err("auth_error", 401, _latency(), "codex token rejected")
Expand All @@ -91,7 +110,17 @@ def _latency() -> int:
if resp.status_code >= 400:
detail = (await resp.aread()).decode("utf-8", "replace")[:500]
return _err("server_error", resp.status_code, _latency(), detail)
async for line in resp.aiter_lines():
stream_lines = resp.aiter_lines().__aiter__()
while True:
try:
line = await before_first_output(
stream_lines.__anext__(), first_timeout_s, t0, _emitted)
except StopAsyncIteration:
break
except (asyncio.TimeoutError, TimeoutError):
if not emitted:
return _timeout_err()
raise
lines.append(line)
if line.startswith("data:"):
data = line[len("data:"):].strip()
Expand Down
Loading