From 61612a8defaf4b5a415c384686f3f4570df44f54 Mon Sep 17 00:00:00 2001 From: Mateusz Kuprowski Date: Fri, 24 Apr 2026 16:08:35 +0200 Subject: [PATCH 1/4] First iteration of changes to SDK handling new client secrets and token-exchange --- .../unit/auth/__init__.py | 1 + .../unit/auth/_mock_transport.py | 95 ++++++ .../auth/test_async_client_credentials.py | 167 +++++++++ .../unit/auth/test_auth_header_hook.py | 212 ++++++++++++ .../unit/auth/test_client_credentials.py | 321 ++++++++++++++++++ .../unit/auth/test_legacy_key_exchange.py | 99 ++++++ .../_hooks/custom/__init__.py | 1 + .../_hooks/custom/auth_header_hook.py | 55 +++ .../_hooks/registration.py | 11 +- src/unstructured_client/auth/__init__.py | 40 +++ src/unstructured_client/auth/_base.py | 193 +++++++++++ src/unstructured_client/auth/_exceptions.py | 24 ++ .../auth/client_credentials.py | 276 +++++++++++++++ .../auth/legacy_api_key.py | 88 +++++ src/unstructured_client/sdk.py | 11 +- 15 files changed, 1590 insertions(+), 4 deletions(-) create mode 100644 _test_unstructured_client/unit/auth/__init__.py create mode 100644 _test_unstructured_client/unit/auth/_mock_transport.py create mode 100644 _test_unstructured_client/unit/auth/test_async_client_credentials.py create mode 100644 _test_unstructured_client/unit/auth/test_auth_header_hook.py create mode 100644 _test_unstructured_client/unit/auth/test_client_credentials.py create mode 100644 _test_unstructured_client/unit/auth/test_legacy_key_exchange.py create mode 100644 src/unstructured_client/_hooks/custom/auth_header_hook.py create mode 100644 src/unstructured_client/auth/__init__.py create mode 100644 src/unstructured_client/auth/_base.py create mode 100644 src/unstructured_client/auth/_exceptions.py create mode 100644 src/unstructured_client/auth/client_credentials.py create mode 100644 src/unstructured_client/auth/legacy_api_key.py diff --git a/_test_unstructured_client/unit/auth/__init__.py b/_test_unstructured_client/unit/auth/__init__.py new file 
mode 100644 index 00000000..8b137891 --- /dev/null +++ b/_test_unstructured_client/unit/auth/__init__.py @@ -0,0 +1 @@ + diff --git a/_test_unstructured_client/unit/auth/_mock_transport.py b/_test_unstructured_client/unit/auth/_mock_transport.py new file mode 100644 index 00000000..857f78b0 --- /dev/null +++ b/_test_unstructured_client/unit/auth/_mock_transport.py @@ -0,0 +1,95 @@ +"""Shared helpers for exercising token-exchange auth callables. + +The mock transports here let tests script a sequence of responses / exceptions +for the ``POST /auth/token-exchange`` endpoint without standing up a real +account-service. +""" + +from __future__ import annotations + +import json +from typing import Any, Callable, Iterable, List, Optional, Union + +import httpx + + +ResponseStep = Union[httpx.Response, Exception, Callable[[httpx.Request], httpx.Response]] + + +class ScriptedTransport(httpx.MockTransport): + """A MockTransport that walks through a scripted sequence of responses. + + Each element can be an :class:`httpx.Response`, an ``Exception`` instance + (raised instead of returned), or a callable that accepts the request and + returns a response. Tests can inspect :attr:`requests` to assert how many + exchanges took place and what bodies were sent. 
+ """ + + def __init__(self, steps: Iterable[ResponseStep]) -> None: + self._steps: List[ResponseStep] = list(steps) + self.requests: List[httpx.Request] = [] + super().__init__(self._handler) + + def _handler(self, request: httpx.Request) -> httpx.Response: + self.requests.append(request) + if not self._steps: + raise AssertionError( + "ScriptedTransport exhausted; unexpected extra request to " + f"{request.url}", + ) + step = self._steps.pop(0) + if isinstance(step, Exception): + raise step + if callable(step): + return step(request) + return step + + +class AsyncScriptedTransport(httpx.MockTransport): + """Async counterpart to :class:`ScriptedTransport`.""" + + def __init__(self, steps: Iterable[ResponseStep]) -> None: + self._steps: List[ResponseStep] = list(steps) + self.requests: List[httpx.Request] = [] + + async def _handler(request: httpx.Request) -> httpx.Response: + self.requests.append(request) + if not self._steps: + raise AssertionError( + "AsyncScriptedTransport exhausted; unexpected extra " + f"request to {request.url}", + ) + step = self._steps.pop(0) + if isinstance(step, Exception): + raise step + if callable(step): + return step(request) + return step + + super().__init__(_handler) + + +def exchange_response( + access_token: Optional[str] = "jwt-1", + *, + expires_in: int = 900, + token_exchange_enabled: bool = True, + token_type: str = "bearer", + status_code: int = 200, + extra: Optional[dict] = None, +) -> httpx.Response: + """Build a canned ``/auth/token-exchange`` response body.""" + body: dict[str, Any] = { + "access_token": access_token, + "token_type": token_type, + "expires_in": expires_in, + "token_exchange_enabled": token_exchange_enabled, + } + if extra: + body.update(extra) + return httpx.Response(status_code, json=body) + + +def body_of(request: httpx.Request) -> dict: + """Decode the JSON body from an outgoing exchange request.""" + return json.loads(request.content.decode("utf-8")) diff --git 
a/_test_unstructured_client/unit/auth/test_async_client_credentials.py b/_test_unstructured_client/unit/auth/test_async_client_credentials.py new file mode 100644 index 00000000..7f533247 --- /dev/null +++ b/_test_unstructured_client/unit/auth/test_async_client_credentials.py @@ -0,0 +1,167 @@ +"""Unit tests for :class:`unstructured_client.auth.AsyncClientCredentials`.""" + +from __future__ import annotations + +import asyncio +from typing import List + +import httpx +import pytest + +from unstructured_client.auth import ( + AsyncClientCredentials, + InvalidCredentialError, + TokenExchangeError, +) + +from ._mock_transport import AsyncScriptedTransport, body_of, exchange_response + +SERVER_URL = "https://accounts.example.test" +SECRET = "uns_sk_async_example" + + +@pytest.fixture(autouse=True) +def _no_sleep(monkeypatch): + async def _noop(*_args, **_kwargs): + return None + + monkeypatch.setattr( + "unstructured_client.auth.client_credentials.asyncio.sleep", + _noop, + ) + + +@pytest.fixture +def fake_clock(monkeypatch): + state = {"now": 2_000_000.0} + + def _now() -> float: + return state["now"] + + monkeypatch.setattr("unstructured_client.auth._base.time.monotonic", _now) + monkeypatch.setattr( + "unstructured_client.auth.client_credentials.time.monotonic", _now + ) + return state + + +class DescribeAsyncClientCredentials: + @pytest.mark.asyncio + async def it_exchanges_then_caches(self, fake_clock): + transport = AsyncScriptedTransport( + [exchange_response(access_token="jwt-1", expires_in=900)] + ) + http_client = httpx.AsyncClient(transport=transport) + acc = AsyncClientCredentials( + client_secret=SECRET, + server_url=SERVER_URL, + http_client=http_client, + ) + + first = await acc.acquire() + second = await acc.acquire() + + assert first == second == "jwt-1" + assert len(transport.requests) == 1 + assert body_of(transport.requests[0]) == { + "grant_type": "client_credentials", + "client_secret": SECRET, + } + + @pytest.mark.asyncio + async def 
it_raises_invalid_credential_on_401(self, fake_clock): + transport = AsyncScriptedTransport( + [httpx.Response(401, json={"detail": "bad"})] + ) + http_client = httpx.AsyncClient(transport=transport) + acc = AsyncClientCredentials( + client_secret=SECRET, + server_url=SERVER_URL, + http_client=http_client, + max_retries=5, + ) + + with pytest.raises(InvalidCredentialError): + await acc.acquire() + + @pytest.mark.asyncio + async def it_retries_5xx_then_succeeds(self, fake_clock): + transport = AsyncScriptedTransport( + [ + httpx.Response(500), + httpx.Response(502), + exchange_response(access_token="jwt-1", expires_in=900), + ] + ) + http_client = httpx.AsyncClient(transport=transport) + acc = AsyncClientCredentials( + client_secret=SECRET, + server_url=SERVER_URL, + http_client=http_client, + max_retries=3, + ) + + assert await acc.acquire() == "jwt-1" + assert len(transport.requests) == 3 + + @pytest.mark.asyncio + async def it_serializes_concurrent_acquires(self, fake_clock): + """Ten concurrent ``acquire()`` calls must share one exchange.""" + transport = AsyncScriptedTransport( + [exchange_response(access_token="jwt-1", expires_in=900)] + ) + http_client = httpx.AsyncClient(transport=transport) + acc = AsyncClientCredentials( + client_secret=SECRET, + server_url=SERVER_URL, + http_client=http_client, + ) + + results: List[str] = await asyncio.gather(*(acc.acquire() for _ in range(10))) + + assert results == ["jwt-1"] * 10 + assert len(transport.requests) == 1 + + @pytest.mark.asyncio + async def it_raises_outage_error_without_cached_token(self, fake_clock): + transport = AsyncScriptedTransport([httpx.Response(500)] * 4) + http_client = httpx.AsyncClient(transport=transport) + acc = AsyncClientCredentials( + client_secret=SECRET, + server_url=SERVER_URL, + http_client=http_client, + max_retries=3, + ) + + with pytest.raises(TokenExchangeError): + await acc.acquire() + + def it_sync_call_works_outside_running_loop(self, fake_clock): + """``__call__`` is the SDK 
entry point; must work without a loop.""" + transport = AsyncScriptedTransport( + [exchange_response(access_token="jwt-1", expires_in=900)] + ) + http_client = httpx.AsyncClient(transport=transport) + acc = AsyncClientCredentials( + client_secret=SECRET, + server_url=SERVER_URL, + http_client=http_client, + ) + + assert acc() == "jwt-1" + + @pytest.mark.asyncio + async def it_sync_call_works_inside_running_loop(self, fake_clock): + """Driving __call__ from a running loop offloads to a worker thread.""" + transport = AsyncScriptedTransport( + [exchange_response(access_token="jwt-1", expires_in=900)] + ) + http_client = httpx.AsyncClient(transport=transport) + acc = AsyncClientCredentials( + client_secret=SECRET, + server_url=SERVER_URL, + http_client=http_client, + ) + + token = await asyncio.to_thread(acc) + assert token == "jwt-1" diff --git a/_test_unstructured_client/unit/auth/test_auth_header_hook.py b/_test_unstructured_client/unit/auth/test_auth_header_hook.py new file mode 100644 index 00000000..1477cd5a --- /dev/null +++ b/_test_unstructured_client/unit/auth/test_auth_header_hook.py @@ -0,0 +1,212 @@ +"""Tests for :class:`AuthHeaderBeforeRequestHook`. + +These verify the header swap happens for our token-exchange callables and is +a no-op for plain string ``api_key_auth`` or arbitrary user-supplied +callables. 
+""" + +from __future__ import annotations + +from typing import Optional + +import httpx +import pytest + +from unstructured_client import UnstructuredClient +from unstructured_client._hooks.custom.auth_header_hook import ( + AuthHeaderBeforeRequestHook, +) +from unstructured_client._hooks.types import BeforeRequestContext, HookContext +from unstructured_client.auth import ClientCredentials, LegacyKeyExchange + +from ._mock_transport import ScriptedTransport, exchange_response + +ACCOUNTS_URL = "https://accounts.example.test" +SERVER_URL = "https://api.example.test" +SECRET = "uns_sk_hook_test" +FAKE_KEY = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + + +def _make_request(headers: dict) -> httpx.Request: + return httpx.Request("GET", "https://api.example.test/api/v1/jobs/", headers=headers) + + +def _make_hook_ctx(security_source) -> BeforeRequestContext: + inner = HookContext( + config=None, # type: ignore[arg-type] + base_url=SERVER_URL, + operation_id="cancel_job", + oauth2_scopes=[], + security_source=security_source, + ) + return BeforeRequestContext(inner) + + +class DescribeAuthHeaderHookDirect: + def it_rewrites_header_when_source_is_client_credentials(self): + transport = ScriptedTransport([exchange_response()]) + cc = ClientCredentials( + client_secret=SECRET, + server_url=ACCOUNTS_URL, + http_client=httpx.Client(transport=transport), + ) + + # Simulate what sdk.py builds: a factory with __wrapped_callable__ + def factory(): + return None + + setattr(factory, "__wrapped_callable__", cc) + + hook = AuthHeaderBeforeRequestHook() + request = _make_request({"unstructured-api-key": "jwt-value"}) + + result = hook.before_request(_make_hook_ctx(factory), request) + + assert isinstance(result, httpx.Request) + assert result.headers.get("Authorization") == "Bearer jwt-value" + assert "unstructured-api-key" not in result.headers + + def it_rewrites_header_when_source_is_legacy_key_exchange(self): + transport = ScriptedTransport([exchange_response()]) + lke = 
LegacyKeyExchange( + api_key="legacy", + server_url=ACCOUNTS_URL, + http_client=httpx.Client(transport=transport), + ) + + def factory(): + return None + + setattr(factory, "__wrapped_callable__", lke) + + hook = AuthHeaderBeforeRequestHook() + request = _make_request({"unstructured-api-key": "jwt-value"}) + + result = hook.before_request(_make_hook_ctx(factory), request) + + assert isinstance(result, httpx.Request) + assert result.headers.get("Authorization") == "Bearer jwt-value" + assert "unstructured-api-key" not in result.headers + + def it_is_noop_for_plain_string_security_source(self): + # When api_key_auth is a string, sdk.py passes a `shared.Security` + # instance as `security_source`, not a callable. The hook must + # leave the request untouched. + from unstructured_client.models import shared + + hook = AuthHeaderBeforeRequestHook() + request = _make_request({"unstructured-api-key": FAKE_KEY}) + + result = hook.before_request( + _make_hook_ctx(shared.Security(api_key_auth=FAKE_KEY)), + request, + ) + + assert isinstance(result, httpx.Request) + assert result.headers.get("unstructured-api-key") == FAKE_KEY + assert "Authorization" not in result.headers + + def it_is_noop_for_arbitrary_user_callable(self): + def user_callable() -> str: + return "whatever" + + def factory(): + return None + + setattr(factory, "__wrapped_callable__", user_callable) + + hook = AuthHeaderBeforeRequestHook() + request = _make_request({"unstructured-api-key": "whatever"}) + + result = hook.before_request(_make_hook_ctx(factory), request) + + assert isinstance(result, httpx.Request) + assert result.headers.get("unstructured-api-key") == "whatever" + assert "Authorization" not in result.headers + + def it_is_noop_when_security_source_is_none(self): + hook = AuthHeaderBeforeRequestHook() + request = _make_request({"unstructured-api-key": FAKE_KEY}) + + result = hook.before_request(_make_hook_ctx(None), request) + + assert isinstance(result, httpx.Request) + assert 
result.headers.get("unstructured-api-key") == FAKE_KEY + assert "Authorization" not in result.headers + + +class DescribeAuthHeaderHookIntegration: + """End-to-end: instantiate :class:`UnstructuredClient` with a + :class:`ClientCredentials` and assert the outgoing request to the + downstream API carries ``Authorization: Bearer {token}`` and no + ``unstructured-api-key``. + """ + + def it_sends_bearer_header_for_client_credentials(self): + exchange_transport = ScriptedTransport( + [exchange_response(access_token="jwt-abc", expires_in=900)] + ) + exchange_http_client = httpx.Client(transport=exchange_transport) + cc = ClientCredentials( + client_secret=SECRET, + server_url=ACCOUNTS_URL, + http_client=exchange_http_client, + ) + + captured: dict = {} + + def _mock(request: httpx.Request) -> httpx.Response: + captured["headers"] = dict(request.headers) + return httpx.Response(200, json={}) + + downstream_transport = httpx.MockTransport(_mock) + downstream_client = httpx.Client(transport=downstream_transport) + + session = UnstructuredClient( + api_key_auth=cc, + client=downstream_client, + server_url=SERVER_URL, + ) + + try: + # Any operation triggers a request; cancel_job is lightweight. + from unstructured_client.models import operations + + session.jobs.cancel_job( + request=operations.CancelJobRequest(job_id="test-job-id"), + ) + except Exception: # noqa: BLE001 + # The mocked 200 with empty JSON won't unmarshal correctly, but + # by then the request already fired and the header was captured. 
+ pass + + headers = captured.get("headers", {}) + assert headers.get("authorization") == "Bearer jwt-abc" + assert "unstructured-api-key" not in {k.lower() for k in headers} + + def it_leaves_legacy_path_unchanged_for_plain_string(self): + captured: dict = {} + + def _mock(request: httpx.Request) -> httpx.Response: + captured["headers"] = dict(request.headers) + return httpx.Response(200, json={}) + + client = httpx.Client(transport=httpx.MockTransport(_mock)) + session = UnstructuredClient( + api_key_auth=FAKE_KEY, + client=client, + server_url=SERVER_URL, + ) + + try: + from unstructured_client.models import operations + + session.jobs.cancel_job( + request=operations.CancelJobRequest(job_id="test-job-id"), + ) + except Exception: # noqa: BLE001 + pass + + headers = captured.get("headers", {}) + assert headers.get("unstructured-api-key") == FAKE_KEY + assert "authorization" not in {k.lower() for k in headers} diff --git a/_test_unstructured_client/unit/auth/test_client_credentials.py b/_test_unstructured_client/unit/auth/test_client_credentials.py new file mode 100644 index 00000000..7e606527 --- /dev/null +++ b/_test_unstructured_client/unit/auth/test_client_credentials.py @@ -0,0 +1,321 @@ +"""Unit tests for :class:`unstructured_client.auth.ClientCredentials`. + +Uses :class:`httpx.MockTransport` to script the ``/auth/token-exchange`` +endpoint; no real network IO. 
+""" + +from __future__ import annotations + +import logging +import threading +from typing import List + +import httpx +import pytest + +from unstructured_client.auth import ( + ClientCredentials, + InvalidCredentialError, + TokenExchangeDisabledError, + TokenExchangeError, +) + +from ._mock_transport import ( + ScriptedTransport, + body_of, + exchange_response, +) + +SERVER_URL = "https://accounts.example.test" +SECRET = "uns_sk_example_secret" + + +def _make_client_credentials( + steps: List, + *, + refresh_buffer_seconds: int = 60, + max_retries: int = 3, +) -> tuple[ClientCredentials, ScriptedTransport]: + """Build a :class:`ClientCredentials` wired to a scripted transport.""" + transport = ScriptedTransport(steps) + http_client = httpx.Client(transport=transport) + cc = ClientCredentials( + client_secret=SECRET, + server_url=SERVER_URL, + refresh_buffer_seconds=refresh_buffer_seconds, + max_retries=max_retries, + http_client=http_client, + ) + return cc, transport + + +@pytest.fixture(autouse=True) +def _no_sleep(monkeypatch): + """Neutralize exponential-backoff sleeps so 5xx tests run instantly.""" + monkeypatch.setattr( + "unstructured_client.auth.client_credentials.time.sleep", + lambda *_args, **_kwargs: None, + ) + + +@pytest.fixture +def fake_clock(monkeypatch): + """Controllable ``time.monotonic`` shared by the auth module.""" + state = {"now": 1_000_000.0} + + def _now() -> float: + return state["now"] + + monkeypatch.setattr("unstructured_client.auth._base.time.monotonic", _now) + monkeypatch.setattr( + "unstructured_client.auth.client_credentials.time.monotonic", _now + ) + return state + + +class DescribeClientCredentialsFirstExchange: + def it_posts_client_credentials_body(self, fake_clock): + cc, transport = _make_client_credentials( + [exchange_response(access_token="jwt-1", expires_in=900)] + ) + + token = cc() + + assert token == "jwt-1" + assert len(transport.requests) == 1 + req = transport.requests[0] + assert req.method == "POST" + assert 
req.url.path == "/auth/token-exchange" + assert req.headers["content-type"] == "application/json" + assert body_of(req) == { + "grant_type": "client_credentials", + "client_secret": SECRET, + } + + def it_strips_trailing_slash_from_server_url(self, fake_clock): + transport = ScriptedTransport([exchange_response()]) + http_client = httpx.Client(transport=transport) + cc = ClientCredentials( + client_secret=SECRET, + server_url=f"{SERVER_URL}/", + http_client=http_client, + ) + + cc() + + assert str(transport.requests[0].url).endswith("/auth/token-exchange") + assert "//auth/token-exchange" not in str(transport.requests[0].url) + + +class DescribeClientCredentialsCaching: + def it_returns_cached_jwt_within_ttl(self, fake_clock): + cc, transport = _make_client_credentials( + [exchange_response(access_token="jwt-1", expires_in=900)] + ) + + first = cc() + second = cc() + third = cc() + + assert first == second == third == "jwt-1" + assert len(transport.requests) == 1 + + def it_refreshes_when_within_buffer_of_expiry(self, fake_clock): + cc, transport = _make_client_credentials( + [ + exchange_response(access_token="jwt-1", expires_in=900), + exchange_response(access_token="jwt-2", expires_in=900), + ], + refresh_buffer_seconds=60, + ) + + assert cc() == "jwt-1" + fake_clock["now"] += 900 - 59 # within the 60s refresh buffer + assert cc() == "jwt-2" + assert len(transport.requests) == 2 + + def it_does_not_refresh_outside_buffer(self, fake_clock): + cc, transport = _make_client_credentials( + [exchange_response(access_token="jwt-1", expires_in=900)], + refresh_buffer_seconds=60, + ) + + cc() + fake_clock["now"] += 900 - 120 # still 120s from expiry + cc() + + assert len(transport.requests) == 1 + + +class DescribeClientCredentialsErrors: + def it_raises_invalid_credential_on_401_without_retry(self, fake_clock): + cc, transport = _make_client_credentials( + [httpx.Response(401, json={"detail": "invalid"})], + max_retries=5, + ) + + with 
pytest.raises(InvalidCredentialError): + cc() + assert len(transport.requests) == 1 + + def it_raises_on_400_without_retry(self, fake_clock): + cc, transport = _make_client_credentials( + [httpx.Response(400, json={"detail": "bad"})], + max_retries=5, + ) + + with pytest.raises(TokenExchangeError): + cc() + assert len(transport.requests) == 1 + + def it_raises_disabled_when_server_opts_out(self, fake_clock): + cc, transport = _make_client_credentials( + [exchange_response(access_token=None, expires_in=0, token_exchange_enabled=False)] + ) + + with pytest.raises(TokenExchangeDisabledError): + cc() + assert len(transport.requests) == 1 + + +class DescribeClientCredentialsRetry: + def it_retries_5xx_then_succeeds(self, fake_clock): + cc, transport = _make_client_credentials( + [ + httpx.Response(503, json={}), + httpx.Response(500, json={}), + exchange_response(access_token="jwt-1", expires_in=900), + ], + max_retries=3, + ) + + assert cc() == "jwt-1" + assert len(transport.requests) == 3 + + def it_retries_network_errors_then_succeeds(self, fake_clock): + cc, transport = _make_client_credentials( + [ + httpx.ConnectError("refused"), + httpx.ReadTimeout("slow"), + exchange_response(access_token="jwt-1", expires_in=900), + ], + max_retries=3, + ) + + assert cc() == "jwt-1" + assert len(transport.requests) == 3 + + def it_raises_when_retries_exhausted_without_cached_token(self, fake_clock): + cc, transport = _make_client_credentials( + [httpx.Response(500, json={})] * 4, + max_retries=3, + ) + + with pytest.raises(TokenExchangeError): + cc() + assert len(transport.requests) == 4 + + def it_serves_cached_jwt_during_outage_when_still_within_ttl( + self, fake_clock, caplog + ): + cc, transport = _make_client_credentials( + [ + exchange_response(access_token="jwt-1", expires_in=900), + httpx.Response(500, json={}), + httpx.Response(502, json={}), + httpx.Response(503, json={}), + httpx.Response(504, json={}), + ], + max_retries=3, + refresh_buffer_seconds=60, + ) + + assert 
cc() == "jwt-1" + fake_clock["now"] += 900 - 30 # past refresh buffer but before absolute expiry + + caplog.set_level(logging.WARNING, logger="unstructured-client.auth") + assert cc() == "jwt-1" + assert any( + "serving cached JWT" in record.getMessage() for record in caplog.records + ) + + def it_raises_when_cached_token_has_fully_expired(self, fake_clock): + cc, transport = _make_client_credentials( + [ + exchange_response(access_token="jwt-1", expires_in=900), + httpx.Response(500, json={}), + httpx.Response(500, json={}), + httpx.Response(500, json={}), + httpx.Response(500, json={}), + ], + max_retries=3, + refresh_buffer_seconds=60, + ) + + assert cc() == "jwt-1" + fake_clock["now"] += 1000 # past absolute expiry + + with pytest.raises(TokenExchangeError): + cc() + + +class DescribeClientCredentialsConcurrency: + def it_collapses_concurrent_calls_into_one_exchange(self): + """Ten threads calling `cc()` at once must produce exactly one + exchange HTTP request.""" + barrier = threading.Barrier(10) + transport = ScriptedTransport( + [exchange_response(access_token="jwt-1", expires_in=900)] + ) + http_client = httpx.Client(transport=transport) + cc = ClientCredentials( + client_secret=SECRET, + server_url=SERVER_URL, + http_client=http_client, + ) + + results: List[str] = [] + + def _worker(): + barrier.wait() + results.append(cc()) + + threads = [threading.Thread(target=_worker) for _ in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert results == ["jwt-1"] * 10 + assert len(transport.requests) == 1 + + +class DescribeClientCredentialsConstructionValidation: + @pytest.mark.parametrize("bad_secret", ["", None]) + def it_rejects_empty_secret(self, bad_secret): + with pytest.raises(ValueError): + ClientCredentials( + client_secret=bad_secret, # type: ignore[arg-type] + server_url=SERVER_URL, + ) + + def it_rejects_empty_server_url(self): + with pytest.raises(ValueError): + ClientCredentials(client_secret=SECRET, server_url="") + + 
@pytest.mark.parametrize("bad_buffer", [-1, -100]) + def it_rejects_negative_refresh_buffer(self, bad_buffer): + with pytest.raises(ValueError): + ClientCredentials( + client_secret=SECRET, + server_url=SERVER_URL, + refresh_buffer_seconds=bad_buffer, + ) + + def it_rejects_negative_max_retries(self): + with pytest.raises(ValueError): + ClientCredentials( + client_secret=SECRET, + server_url=SERVER_URL, + max_retries=-1, + ) diff --git a/_test_unstructured_client/unit/auth/test_legacy_key_exchange.py b/_test_unstructured_client/unit/auth/test_legacy_key_exchange.py new file mode 100644 index 00000000..bf52591d --- /dev/null +++ b/_test_unstructured_client/unit/auth/test_legacy_key_exchange.py @@ -0,0 +1,99 @@ +"""Unit tests for :class:`unstructured_client.auth.LegacyKeyExchange`. + +These only assert the surface differences from :class:`ClientCredentials` +(grant type, credential field). The full caching / retry / concurrency +behavior is covered in ``test_client_credentials.py`` since +:class:`LegacyKeyExchange` inherits that machinery unchanged. 
+""" + +from __future__ import annotations + +import httpx +import pytest + +from unstructured_client.auth import ( + AsyncLegacyKeyExchange, + InvalidCredentialError, + LegacyKeyExchange, +) + +from ._mock_transport import ( + AsyncScriptedTransport, + ScriptedTransport, + body_of, + exchange_response, +) + +SERVER_URL = "https://accounts.example.test" +LEGACY_KEY = "uns_ak_legacy_example" + + +@pytest.fixture(autouse=True) +def _no_sleep(monkeypatch): + monkeypatch.setattr( + "unstructured_client.auth.client_credentials.time.sleep", + lambda *_args, **_kwargs: None, + ) + monkeypatch.setattr( + "unstructured_client.auth.client_credentials.asyncio.sleep", + _noop_async_sleep, + ) + + +async def _noop_async_sleep(*_args, **_kwargs): + return None + + +class DescribeLegacyKeyExchangeBody: + def it_sends_grant_type_api_key_and_api_key_field(self): + transport = ScriptedTransport([exchange_response(access_token="jwt-1")]) + http_client = httpx.Client(transport=transport) + lke = LegacyKeyExchange( + api_key=LEGACY_KEY, + server_url=SERVER_URL, + http_client=http_client, + ) + + assert lke() == "jwt-1" + + req = transport.requests[0] + assert body_of(req) == {"grant_type": "api_key", "api_key": LEGACY_KEY} + + def it_propagates_401_as_invalid_credential(self): + transport = ScriptedTransport([httpx.Response(401, json={"detail": "bad"})]) + http_client = httpx.Client(transport=transport) + lke = LegacyKeyExchange( + api_key=LEGACY_KEY, + server_url=SERVER_URL, + http_client=http_client, + ) + + with pytest.raises(InvalidCredentialError): + lke() + + +class DescribeAsyncLegacyKeyExchangeBody: + @pytest.mark.asyncio + async def it_sends_grant_type_api_key_and_api_key_field(self): + transport = AsyncScriptedTransport([exchange_response(access_token="jwt-1")]) + http_client = httpx.AsyncClient(transport=transport) + alke = AsyncLegacyKeyExchange( + api_key=LEGACY_KEY, + server_url=SERVER_URL, + http_client=http_client, + ) + + token = await alke.acquire() + + assert token 
== "jwt-1" + assert body_of(transport.requests[0]) == { + "grant_type": "api_key", + "api_key": LEGACY_KEY, + } + + +class DescribeLegacyKeyExchangeConstruction: + @pytest.mark.parametrize("bad", ["", None]) + def it_rejects_empty_api_key(self, bad): + with pytest.raises(ValueError): + LegacyKeyExchange(api_key=bad, server_url=SERVER_URL) # type: ignore[arg-type] diff --git a/src/unstructured_client/_hooks/custom/__init__.py b/src/unstructured_client/_hooks/custom/__init__.py index 8917d508..824ec546 100644 --- a/src/unstructured_client/_hooks/custom/__init__.py +++ b/src/unstructured_client/_hooks/custom/__init__.py @@ -1,3 +1,4 @@ +from .auth_header_hook import AuthHeaderBeforeRequestHook from .clean_server_url_hook import CleanServerUrlSDKInitHook from .logger_hook import LoggerHook from .split_pdf_hook import SplitPdfHook diff --git a/src/unstructured_client/_hooks/custom/auth_header_hook.py b/src/unstructured_client/_hooks/custom/auth_header_hook.py new file mode 100644 index 00000000..7ee02618 --- /dev/null +++ b/src/unstructured_client/_hooks/custom/auth_header_hook.py @@ -0,0 +1,55 @@ +"""Before-request hook that promotes exchanged JWTs to ``Authorization: Bearer``. + +Speakeasy's generated ``Security`` model places ``api_key_auth`` in the +``unstructured-api-key`` header. When the user supplies a token-exchange +callable (``ClientCredentials`` or ``LegacyKeyExchange``) the value is a JWT +and must be sent as ``Authorization: Bearer {token}`` so the service-side +``utic-jwt-auth`` validator picks it up (see ``core-product`` auth_context +and ``platform-api`` public_api/dependencies). + +Plain-string ``api_key_auth`` is untouched. 
+""" + +from __future__ import annotations + +from typing import Union + +import httpx + +from unstructured_client._hooks.types import BeforeRequestContext, BeforeRequestHook + + +class AuthHeaderBeforeRequestHook(BeforeRequestHook): + """Rewrite ``unstructured-api-key`` -> ``Authorization: Bearer`` when the + active security source is a known token-exchange callable.""" + + def before_request( + self, hook_ctx: BeforeRequestContext, request: httpx.Request + ) -> Union[httpx.Request, Exception]: + if not self._is_exchange_callable(hook_ctx.security_source): + return request + + token = request.headers.get("unstructured-api-key") + if not token: + return request + + del request.headers["unstructured-api-key"] + request.headers["Authorization"] = f"Bearer {token}" + return request + + @staticmethod + def _is_exchange_callable(security_source: object) -> bool: + """Return True when ``security_source`` was built from one of our + token-exchange callables. + + The SDK wraps a user-supplied callable into an internal factory and + attaches ``__wrapped_callable__`` to it (see ``sdk.py``). We import + the base class lazily to avoid any cycle at module load. 
+ """ + from unstructured_client.auth._base import _ExchangeCallableBase + + if security_source is None: + return False + + candidate = getattr(security_source, "__wrapped_callable__", security_source) + return isinstance(candidate, _ExchangeCallableBase) diff --git a/src/unstructured_client/_hooks/registration.py b/src/unstructured_client/_hooks/registration.py index 22cd276d..7554dc7d 100644 --- a/src/unstructured_client/_hooks/registration.py +++ b/src/unstructured_client/_hooks/registration.py @@ -1,6 +1,7 @@ """Registration of custom, human-written hooks.""" from .custom import ( + AuthHeaderBeforeRequestHook, CleanServerUrlSDKInitHook, LoggerHook, SplitPdfHook, @@ -21,6 +22,7 @@ def init_hooks(hooks: Hooks): """ # Initialize custom hooks + auth_header_hook = AuthHeaderBeforeRequestHook() clean_server_url_hook = CleanServerUrlSDKInitHook() logger_hook = LoggerHook() split_pdf_hook = SplitPdfHook() @@ -33,7 +35,11 @@ def init_hooks(hooks: Hooks): hooks.register_sdk_init_hook(logger_hook) hooks.register_sdk_init_hook(split_pdf_hook) - # Register Before Request hooks + # Register Before Request hooks. + # `auth_header_hook` MUST run first so subsequent before-request hooks + # (e.g. `split_pdf_hook`) see the final `Authorization` header when the + # caller is using a ClientCredentials / LegacyKeyExchange callable. + hooks.register_before_request_hook(auth_header_hook) hooks.register_before_request_hook(split_pdf_hook) # Register After Error hooks @@ -42,4 +48,5 @@ def init_hooks(hooks: Hooks): # Register After Error hooks hooks.register_after_error_hook(split_pdf_hook) - hooks.register_after_error_hook(logger_hook) + hooks.register_after_error_hook(logger_hook) + diff --git a/src/unstructured_client/auth/__init__.py b/src/unstructured_client/auth/__init__.py new file mode 100644 index 00000000..0cffd986 --- /dev/null +++ b/src/unstructured_client/auth/__init__.py @@ -0,0 +1,40 @@ +"""Transparent token-exchange auth helpers for ``unstructured-client``. 
+ +Pass one of these instances as ``api_key_auth`` to :class:`~unstructured_client.UnstructuredClient` +and the SDK will automatically exchange your credential for a short-lived +account-service JWT, cache it, refresh before expiry, and send it as +``Authorization: Bearer`` instead of ``unstructured-api-key``:: + + from unstructured_client import UnstructuredClient + from unstructured_client.auth import ClientCredentials + + client = UnstructuredClient( + api_key_auth=ClientCredentials( + client_secret="uns_sk_...", + server_url="https://accounts.unstructuredapp.io", + ), + ) + +Plain-string ``api_key_auth="..."`` continues to work unchanged and is sent +as the ``unstructured-api-key`` header. +""" + +from ._base import _ExchangeCallableBase +from ._exceptions import ( + InvalidCredentialError, + TokenExchangeDisabledError, + TokenExchangeError, +) +from .client_credentials import AsyncClientCredentials, ClientCredentials +from .legacy_api_key import AsyncLegacyKeyExchange, LegacyKeyExchange + +__all__ = [ + "AsyncClientCredentials", + "AsyncLegacyKeyExchange", + "ClientCredentials", + "InvalidCredentialError", + "LegacyKeyExchange", + "TokenExchangeDisabledError", + "TokenExchangeError", + "_ExchangeCallableBase", +] diff --git a/src/unstructured_client/auth/_base.py b/src/unstructured_client/auth/_base.py new file mode 100644 index 00000000..e03db337 --- /dev/null +++ b/src/unstructured_client/auth/_base.py @@ -0,0 +1,193 @@ +"""Shared internals for token-exchange auth callables. + +This module holds the abstract base class used by :class:`ClientCredentials`, +:class:`AsyncClientCredentials`, :class:`LegacyKeyExchange`, and +:class:`AsyncLegacyKeyExchange`. It implements: + +* In-memory caching of the most recent access token with TTL math. +* Lock-guarded refresh (``threading.Lock`` for sync, ``asyncio.Lock`` for async) + so concurrent callers collapse to a single in-flight exchange. +* Exponential-backoff retry on 5xx / network errors. 
+* Fallback to a still-unexpired cached token when account-service is + unavailable. + +No public API lives here. Users import from :mod:`unstructured_client.auth`. +""" + +from __future__ import annotations + +import asyncio +import logging +import random +import threading +import time +from dataclasses import dataclass +from typing import Any, Dict, Optional + +import httpx + +from ._exceptions import ( + InvalidCredentialError, + TokenExchangeDisabledError, + TokenExchangeError, +) + +logger = logging.getLogger("unstructured-client.auth") + + +TOKEN_EXCHANGE_PATH = "/auth/token-exchange" + +# Exponential-backoff delays (seconds) used for 5xx / network failures. +_BACKOFF_BASE_SECONDS = 0.5 +_BACKOFF_EXPONENT = 2.0 + + +@dataclass +class _CachedToken: + access_token: str + expires_at: float + + +class _ExchangeCallableBase: + """Shared state container for sync/async token-exchange callables. + + Subclasses only need to provide :meth:`_build_request_body`. The detection + logic in :class:`AuthHeaderBeforeRequestHook` uses ``isinstance`` against + this class to decide whether to rewrite the outgoing ``unstructured-api-key`` + header into ``Authorization: Bearer``. 
+ """ + + def __init__( + self, + *, + server_url: str, + refresh_buffer_seconds: int = 60, + request_timeout_seconds: float = 30.0, + max_retries: int = 3, + ) -> None: + if not server_url: + raise ValueError("server_url must be a non-empty account-service base URL") + if refresh_buffer_seconds < 0: + raise ValueError("refresh_buffer_seconds must be >= 0") + if max_retries < 0: + raise ValueError("max_retries must be >= 0") + if request_timeout_seconds <= 0: + raise ValueError("request_timeout_seconds must be > 0") + + self._server_url = server_url.rstrip("/") + self._refresh_buffer_seconds = refresh_buffer_seconds + self._request_timeout_seconds = request_timeout_seconds + self._max_retries = max_retries + + self._cache: Optional[_CachedToken] = None + self._lock = threading.Lock() + + def _build_request_body(self) -> Dict[str, Any]: + """Return the JSON body for the `/auth/token-exchange` POST. + + Subclasses override to select `grant_type` and the credential field. + """ + raise NotImplementedError + + @property + def _exchange_url(self) -> str: + return f"{self._server_url}{TOKEN_EXCHANGE_PATH}" + + def _cached_token_if_fresh(self, now: float) -> Optional[str]: + """Return the cached JWT if it is still valid beyond the refresh buffer.""" + cached = self._cache + if cached is None: + return None + if now >= cached.expires_at - self._refresh_buffer_seconds: + return None + return cached.access_token + + def _cached_token_if_not_expired(self, now: float) -> Optional[str]: + """Return the cached JWT if it has not yet crossed absolute expiry. + + Used as an outage fallback: when account-service is unreachable but a + previously fetched token is still technically valid (past the refresh + buffer but before ``expires_at``), serving it keeps the caller working + until the real expiry lands. 
+ """ + cached = self._cache + if cached is None: + return None + if now >= cached.expires_at: + return None + return cached.access_token + + def _parse_exchange_response(self, response: httpx.Response) -> str: + """Parse a successful ``/auth/token-exchange`` response into a JWT. + + Updates the in-memory cache with the new token and its absolute expiry. + Raises :class:`TokenExchangeDisabledError` if the server reports + ``token_exchange_enabled=False``, and :class:`TokenExchangeError` for + malformed payloads. + """ + try: + payload = response.json() + except ValueError as exc: + raise TokenExchangeError( + f"Account-service returned a non-JSON body on token exchange: {exc}" + ) from exc + + if not payload.get("token_exchange_enabled", True): + raise TokenExchangeDisabledError( + "Account-service reports token_exchange_enabled=False. " + "ClientCredentials / LegacyKeyExchange require a server with " + "DEPLOYMENT_MODE=dedicated (or equivalent) that accepts token " + "exchange. Fall back to plain api_key_auth= if needed.", + ) + + access_token = payload.get("access_token") + expires_in = payload.get("expires_in") + if not access_token or not isinstance(expires_in, (int, float)) or expires_in <= 0: + raise TokenExchangeError( + "Account-service returned a malformed token-exchange response: " + f"access_token={'' if access_token else ''}, " + f"expires_in={expires_in!r}", + ) + + self._cache = _CachedToken( + access_token=access_token, + expires_at=time.monotonic() + float(expires_in), + ) + return access_token + + @staticmethod + def _backoff_delay(attempt: int) -> float: + """Exponential backoff with a small jitter to avoid thundering herds.""" + base = _BACKOFF_BASE_SECONDS * (_BACKOFF_EXPONENT ** attempt) + jitter = random.uniform(0, _BACKOFF_BASE_SECONDS) + return base + jitter + + def _raise_for_status(self, response: httpx.Response) -> None: + """Map HTTP status to auth-specific exceptions before retry decisions.""" + if response.status_code == 401: + raise 
InvalidCredentialError( + "Account-service rejected the credential (401). Check that the " + "client secret / API key is correct and not revoked.", + ) + if response.status_code == 400: + raise TokenExchangeError( + f"Account-service rejected the token-exchange request (400): " + f"{response.text[:500]}", + ) + + def _handle_outage(self, last_error: Optional[Exception]) -> str: + """Serve a still-unexpired cached JWT after exhausting retries, else raise.""" + now = time.monotonic() + cached = self._cached_token_if_not_expired(now) + if cached is not None: + logger.warning( + "Account-service unavailable during token exchange; " + "serving cached JWT while still within its absolute TTL. " + "Last error: %s", + last_error, + ) + return cached + raise TokenExchangeError( + f"Token exchange failed after {self._max_retries + 1} attempt(s) " + f"and no valid cached token is available: {last_error}", + ) from last_error diff --git a/src/unstructured_client/auth/_exceptions.py b/src/unstructured_client/auth/_exceptions.py new file mode 100644 index 00000000..680f5e17 --- /dev/null +++ b/src/unstructured_client/auth/_exceptions.py @@ -0,0 +1,24 @@ +"""Exceptions raised by the token-exchange auth callables.""" + +from __future__ import annotations + + +class TokenExchangeError(Exception): + """Base error for failures during `/auth/token-exchange` calls.""" + + +class TokenExchangeDisabledError(TokenExchangeError): + """Raised when account-service responds with `token_exchange_enabled=False`. + + The user explicitly opted into `ClientCredentials` / `LegacyKeyExchange`, so + the server not supporting exchange is treated as a misconfiguration rather + than silently returning a null token. + """ + + +class InvalidCredentialError(TokenExchangeError): + """Raised when account-service returns 401 Unauthorized. + + The supplied client secret or legacy API key was not recognized. Retrying + will not help, so the exchange callable raises immediately. 
+ """ diff --git a/src/unstructured_client/auth/client_credentials.py b/src/unstructured_client/auth/client_credentials.py new file mode 100644 index 00000000..5223b6b9 --- /dev/null +++ b/src/unstructured_client/auth/client_credentials.py @@ -0,0 +1,276 @@ +"""``ClientCredentials`` callable - transparent client-secret -> JWT exchange. + +Usage:: + + from unstructured_client import UnstructuredClient + from unstructured_client.auth import ClientCredentials + + client = UnstructuredClient( + api_key_auth=ClientCredentials( + client_secret="uns_sk_...", + server_url="https://accounts.unstructuredapp.io", + ), + ) + +The SDK invokes the callable on each request; this class caches the exchanged +JWT in-memory and refreshes it shortly before expiry. +""" + +from __future__ import annotations + +import asyncio +import time +from typing import Any, Dict, Optional + +import httpx + +from ._base import _ExchangeCallableBase +from ._exceptions import InvalidCredentialError, TokenExchangeError + + +class ClientCredentials(_ExchangeCallableBase): + """Synchronous ``client_credentials`` grant callable. + + Exchanges a long-lived client secret for a short-lived account-service JWT + via ``POST /auth/token-exchange``. Thread-safe: concurrent callers collapse + onto a single in-flight exchange via an internal lock. + """ + + def __init__( + self, + client_secret: str, + *, + server_url: str, + refresh_buffer_seconds: int = 60, + request_timeout_seconds: float = 30.0, + max_retries: int = 3, + http_client: Optional[httpx.Client] = None, + ) -> None: + """ + :param client_secret: ``uns_sk_...`` client secret provisioned via + account-service. + :param server_url: Base URL of account-service (e.g. + ``https://accounts.unstructuredapp.io``). + :param refresh_buffer_seconds: Re-exchange when fewer than this many + seconds remain before the token's absolute expiry. + :param request_timeout_seconds: Per-attempt timeout for the exchange + HTTP call. 
+ :param max_retries: Number of additional attempts on 5xx / network + errors before serving a cached JWT or raising. + :param http_client: Optional :class:`httpx.Client` injected for tests + or shared connection pooling. If omitted, a private client is + created lazily. + """ + if not client_secret: + raise ValueError("client_secret must be a non-empty string") + super().__init__( + server_url=server_url, + refresh_buffer_seconds=refresh_buffer_seconds, + request_timeout_seconds=request_timeout_seconds, + max_retries=max_retries, + ) + self._client_secret = client_secret + self._http_client = http_client + self._owns_http_client = http_client is None + + def _build_request_body(self) -> Dict[str, Any]: + return { + "grant_type": "client_credentials", + "client_secret": self._client_secret, + } + + def _get_http_client(self) -> httpx.Client: + if self._http_client is None: + self._http_client = httpx.Client(timeout=self._request_timeout_seconds) + return self._http_client + + def __call__(self) -> str: + """Return a valid JWT, performing an exchange only when necessary.""" + now = time.monotonic() + cached = self._cached_token_if_fresh(now) + if cached is not None: + return cached + + with self._lock: + now = time.monotonic() + cached = self._cached_token_if_fresh(now) + if cached is not None: + return cached + return self._exchange() + + def _exchange(self) -> str: + client = self._get_http_client() + body = self._build_request_body() + last_error: Optional[Exception] = None + + for attempt in range(self._max_retries + 1): + try: + response = client.post( + self._exchange_url, + json=body, + headers={"Content-Type": "application/json"}, + timeout=self._request_timeout_seconds, + ) + except httpx.HTTPError as exc: + last_error = exc + if attempt < self._max_retries: + time.sleep(self._backoff_delay(attempt)) + continue + break + + self._raise_for_status(response) + + if 500 <= response.status_code < 600: + last_error = TokenExchangeError( + f"Account-service 
returned {response.status_code} on token exchange", + ) + if attempt < self._max_retries: + time.sleep(self._backoff_delay(attempt)) + continue + break + + if response.status_code != 200: + raise TokenExchangeError( + f"Unexpected status {response.status_code} from token exchange: " + f"{response.text[:500]}", + ) + + return self._parse_exchange_response(response) + + return self._handle_outage(last_error) + + def close(self) -> None: + """Close the private HTTP client, if one was created internally.""" + if self._owns_http_client and self._http_client is not None: + self._http_client.close() + self._http_client = None + + +class AsyncClientCredentials(_ExchangeCallableBase): + """Asynchronous twin of :class:`ClientCredentials`. + + The synchronous wrapper (:meth:`__call__`) runs the async exchange via + :func:`asyncio.run` when invoked from a non-async context, so it can still + be plugged into the SDK's sync-only ``api_key_auth`` callable hook. When + already inside a running loop, it uses that loop's executor to avoid + deadlocking. 
+ """ + + def __init__( + self, + client_secret: str, + *, + server_url: str, + refresh_buffer_seconds: int = 60, + request_timeout_seconds: float = 30.0, + max_retries: int = 3, + http_client: Optional[httpx.AsyncClient] = None, + ) -> None: + if not client_secret: + raise ValueError("client_secret must be a non-empty string") + super().__init__( + server_url=server_url, + refresh_buffer_seconds=refresh_buffer_seconds, + request_timeout_seconds=request_timeout_seconds, + max_retries=max_retries, + ) + self._client_secret = client_secret + self._http_client = http_client + self._owns_http_client = http_client is None + self._async_lock = asyncio.Lock() + + def _build_request_body(self) -> Dict[str, Any]: + return { + "grant_type": "client_credentials", + "client_secret": self._client_secret, + } + + def _get_http_client(self) -> httpx.AsyncClient: + if self._http_client is None: + self._http_client = httpx.AsyncClient(timeout=self._request_timeout_seconds) + return self._http_client + + async def acquire(self) -> str: + """Async variant of ``__call__``. 
Returns a valid JWT.""" + now = time.monotonic() + cached = self._cached_token_if_fresh(now) + if cached is not None: + return cached + + async with self._async_lock: + now = time.monotonic() + cached = self._cached_token_if_fresh(now) + if cached is not None: + return cached + return await self._exchange() + + async def _exchange(self) -> str: + client = self._get_http_client() + body = self._build_request_body() + last_error: Optional[Exception] = None + + for attempt in range(self._max_retries + 1): + try: + response = await client.post( + self._exchange_url, + json=body, + headers={"Content-Type": "application/json"}, + timeout=self._request_timeout_seconds, + ) + except httpx.HTTPError as exc: + last_error = exc + if attempt < self._max_retries: + await asyncio.sleep(self._backoff_delay(attempt)) + continue + break + + self._raise_for_status(response) + + if 500 <= response.status_code < 600: + last_error = TokenExchangeError( + f"Account-service returned {response.status_code} on token exchange", + ) + if attempt < self._max_retries: + await asyncio.sleep(self._backoff_delay(attempt)) + continue + break + + if response.status_code != 200: + raise TokenExchangeError( + f"Unexpected status {response.status_code} from token exchange: " + f"{response.text[:500]}", + ) + + return self._parse_exchange_response(response) + + return self._handle_outage(last_error) + + def __call__(self) -> str: + """Sync entry point so the SDK's ``api_key_auth`` callable hook works. + + When invoked from inside a running event loop (the usual case for + async SDK methods), the exchange runs in the loop's default executor + so we don't reenter :func:`asyncio.run`. Otherwise we spin up a + temporary loop via :func:`asyncio.run`. + """ + try: + loop = asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(self.acquire()) + + # Inside a running loop - offload to a worker thread that drives its + # own event loop so we don't block the caller's loop on httpx IO. 
+ import concurrent.futures + + def _run_in_new_loop() -> str: + return asyncio.run(self.acquire()) + + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: + future = pool.submit(_run_in_new_loop) + return future.result() + + async def aclose(self) -> None: + """Close the private HTTP client, if one was created internally.""" + if self._owns_http_client and self._http_client is not None: + await self._http_client.aclose() + self._http_client = None diff --git a/src/unstructured_client/auth/legacy_api_key.py b/src/unstructured_client/auth/legacy_api_key.py new file mode 100644 index 00000000..a15b7235 --- /dev/null +++ b/src/unstructured_client/auth/legacy_api_key.py @@ -0,0 +1,88 @@ +"""``LegacyKeyExchange`` - transitional api-key -> JWT exchange. + +This mirrors :class:`ClientCredentials` but uses ``grant_type=api_key`` so +customers still on legacy api-tracking keys can get JWT-backed requests +without re-issuing credentials. The class is intentionally flagged as +transitional and will be removed once legacy keys are decommissioned. +""" + +from __future__ import annotations + +from typing import Any, Dict + +from .client_credentials import AsyncClientCredentials, ClientCredentials + + +class LegacyKeyExchange(ClientCredentials): + """Synchronous ``api_key`` grant callable (transitional). + + Accepts a legacy raw API key (validated by account-service against + api-tracking) and exchanges it for an account-service JWT. Caching, + refresh, and retry behavior are identical to + :class:`~unstructured_client.auth.ClientCredentials`. + + .. deprecated:: + Prefer :class:`~unstructured_client.auth.ClientCredentials` with a + ``uns_sk_...`` client secret. ``LegacyKeyExchange`` exists only to + bridge customers during the API Key Scoping rollout and will be + removed once legacy keys are retired. 
+ """ + + def __init__( + self, + api_key: str, + *, + server_url: str, + refresh_buffer_seconds: int = 60, + request_timeout_seconds: float = 30.0, + max_retries: int = 3, + http_client=None, + ) -> None: + if not api_key: + raise ValueError("api_key must be a non-empty string") + super().__init__( + client_secret=api_key, + server_url=server_url, + refresh_buffer_seconds=refresh_buffer_seconds, + request_timeout_seconds=request_timeout_seconds, + max_retries=max_retries, + http_client=http_client, + ) + self._api_key = api_key + + def _build_request_body(self) -> Dict[str, Any]: + return {"grant_type": "api_key", "api_key": self._api_key} + + +class AsyncLegacyKeyExchange(AsyncClientCredentials): + """Asynchronous twin of :class:`LegacyKeyExchange` (transitional). + + .. deprecated:: + Prefer :class:`~unstructured_client.auth.AsyncClientCredentials` with + a ``uns_sk_...`` client secret. + """ + + def __init__( + self, + api_key: str, + *, + server_url: str, + refresh_buffer_seconds: int = 60, + request_timeout_seconds: float = 30.0, + max_retries: int = 3, + http_client=None, + ) -> None: + if not api_key: + raise ValueError("api_key must be a non-empty string") + super().__init__( + client_secret=api_key, + server_url=server_url, + refresh_buffer_seconds=refresh_buffer_seconds, + request_timeout_seconds=request_timeout_seconds, + max_retries=max_retries, + http_client=http_client, + ) + self._api_key = api_key + + def _build_request_body(self) -> Dict[str, Any]: + return {"grant_type": "api_key", "api_key": self._api_key} diff --git a/src/unstructured_client/sdk.py b/src/unstructured_client/sdk.py index 3671fa6e..44991b53 100644 --- a/src/unstructured_client/sdk.py +++ b/src/unstructured_client/sdk.py @@ -87,8 +87,15 @@ def __init__( security: Any = None if callable(api_key_auth): - # pylint: disable=unnecessary-lambda-assignment - security = lambda: shared.Security(api_key_auth=api_key_auth()) + # Preserve a reference to the user-supplied callable on the + # 
security factory so custom hooks (e.g. the auth-header hook) + # can detect ClientCredentials / LegacyKeyExchange instances + # without reaching into lambda closures. + def _security_factory() -> shared.Security: + return shared.Security(api_key_auth=api_key_auth()) + + setattr(_security_factory, "__wrapped_callable__", api_key_auth) + security = _security_factory else: security = shared.Security(api_key_auth=api_key_auth) From 8dc719d577588938e630001af7ad6f33e2bb5041 Mon Sep 17 00:00:00 2001 From: Mateusz Kuprowski Date: Fri, 24 Apr 2026 16:31:29 +0200 Subject: [PATCH 2/4] Test fixes and readme update --- README.md | 96 ++++ .../test_client_credentials_e2e.py | 115 +++++ .../auth/test_async_client_credentials.py | 243 +++++----- .../unit/auth/test_auth_header_hook.py | 263 ++++++----- .../unit/auth/test_client_credentials.py | 426 +++++++++--------- .../unit/auth/test_legacy_key_exchange.py | 108 +++-- 6 files changed, 728 insertions(+), 523 deletions(-) create mode 100644 _test_unstructured_client/integration/test_client_credentials_e2e.py diff --git a/README.md b/README.md index 17bfc6f4..438cd8df 100644 --- a/README.md +++ b/README.md @@ -502,6 +502,102 @@ s = UnstructuredClient(debug_logger=logging.getLogger("unstructured_client")) +## Authentication with Client Secrets + +> Available from SDK version **X.Y.Z** (first release carrying the +> `unstructured_client.auth` module). + +If you are running against an Unstructured deployment that issues **client +secrets** (`uns_sk_...`) — e.g. Dedicated Instances or self-hosted +clusters with `DEPLOYMENT_MODE=dedicated` on account-service — the SDK +can transparently exchange that secret for a short-lived JWT, cache it, +refresh it before expiry, and send it on every request as +`Authorization: Bearer <jwt>`.
+ +### Synchronous usage + +```python +from unstructured_client import UnstructuredClient +from unstructured_client.auth import ClientCredentials + +client = UnstructuredClient( + api_key_auth=ClientCredentials( + client_secret="uns_sk_...", + server_url="https://accounts.unstructuredapp.io", # account-service base URL + ), + server_url="https://platform.unstructuredapp.io", # platform-api / core-product +) + +# Every operation automatically carries Authorization: Bearer <jwt>. +client.general.partition(...) +``` + +### Asynchronous usage + +```python +import asyncio +from unstructured_client import UnstructuredClient +from unstructured_client.auth import AsyncClientCredentials + +async def main() -> None: + auth = AsyncClientCredentials( + client_secret="uns_sk_...", + server_url="https://accounts.unstructuredapp.io", + ) + async with UnstructuredClient(api_key_auth=auth) as client: + await client.general.partition_async(...) + +asyncio.run(main()) +``` + +### Legacy API-key bridge + +For deployments still using legacy api-tracking keys, the same machinery +is available through `LegacyKeyExchange` / `AsyncLegacyKeyExchange`. It +hits the same `/auth/token-exchange` endpoint with +`grant_type=api_key` and is intentionally transitional — migrate to +`ClientCredentials` once client secrets are provisioned. + +```python +from unstructured_client.auth import LegacyKeyExchange + +client = UnstructuredClient( + api_key_auth=LegacyKeyExchange( + api_key="your-legacy-uns_ak-key", + server_url="https://accounts.unstructuredapp.io", + ), +) +``` + +### Behavior and tuning + +- **Caching:** JWTs are held in-memory and reused until + `refresh_buffer_seconds` (default **60s**) before absolute expiry. +- **Concurrency:** sync callers share a `threading.Lock`, async callers + share an `asyncio.Lock`. Ten concurrent requests on a cold cache drive + exactly one exchange. +- **Retries:** 5xx and network errors retry with exponential backoff + (default `max_retries=3`).
`400` / `401` fail fast with + `TokenExchangeError` / `InvalidCredentialError`. +- **Outage fallback:** if account-service is unreachable *and* a cached + token is still within its absolute TTL, the cached token is returned + and a warning is logged on `unstructured-client.auth`. +- **Disabled exchange:** when the server responds with + `token_exchange_enabled=False`, the call raises + `TokenExchangeDisabledError` — that deployment expects the plain + `api_key_auth="..."` string form instead. + +### Backward compatibility + +Passing a plain string still works exactly as before: + +```python +client = UnstructuredClient(api_key_auth="your-key") +``` + +In that case the SDK sends `unstructured-api-key: your-key` without any +token exchange, identical to pre-`auth` SDK versions. + ### Maturity This SDK is in beta, and there may be breaking changes between versions without a major version update. Therefore, we recommend pinning usage diff --git a/_test_unstructured_client/integration/test_client_credentials_e2e.py b/_test_unstructured_client/integration/test_client_credentials_e2e.py new file mode 100644 index 00000000..d4b9ba63 --- /dev/null +++ b/_test_unstructured_client/integration/test_client_credentials_e2e.py @@ -0,0 +1,115 @@ +"""End-to-end integration test for :class:`ClientCredentials` and +:class:`LegacyKeyExchange`. + +This test is **opt-in**: it only runs when every required env var is set. +Point it at a deployment (e.g. the ``unsightly-koala`` dedicated-instance +test cluster) that has ``DEPLOYMENT_MODE=dedicated`` and a valid client +secret provisioned via account-service. + +Required env vars +----------------- + +- ``UNS_ACCOUNTS_URL`` base URL of account-service (e.g. 
+ ``https://accounts.unsightly-koala.example``) +- ``UNS_CLIENT_SECRET`` ``uns_sk_...`` client secret +- ``UNS_PLATFORM_API_URL`` platform-api base URL to hit after exchange + +Optional: + +- ``UNS_LEGACY_API_KEY`` if set, the LegacyKeyExchange path is also + exercised against the same platform-api. + +What it verifies +---------------- + +1. The SDK can bootstrap a :class:`ClientCredentials` and successfully + exchange the secret for a JWT against real account-service. +2. A real downstream call (``jobs.list_jobs``) goes through with + ``Authorization: Bearer`` and returns 2xx. +3. Re-using the same client does not trigger a second exchange (cache + hit) because the first JWT is still within its TTL. +""" + +from __future__ import annotations + +import os + +import pytest + +from unstructured_client import UnstructuredClient +from unstructured_client.auth import ClientCredentials, LegacyKeyExchange +from unstructured_client.models import operations + +ACCOUNTS_URL = os.getenv("UNS_ACCOUNTS_URL") +CLIENT_SECRET = os.getenv("UNS_CLIENT_SECRET") +PLATFORM_API_URL = os.getenv("UNS_PLATFORM_API_URL") +LEGACY_API_KEY = os.getenv("UNS_LEGACY_API_KEY") + + +_REASON = ( + "Opt-in E2E: set UNS_ACCOUNTS_URL, UNS_CLIENT_SECRET, and " + "UNS_PLATFORM_API_URL to run against a real dedicated-instance " + "deployment (e.g. unsightly-koala)." 
+) + + +pytestmark = pytest.mark.skipif( + not (ACCOUNTS_URL and CLIENT_SECRET and PLATFORM_API_URL), + reason=_REASON, +) + + +def _list_jobs(session: UnstructuredClient) -> None: + """Lightweight read request that only needs an authenticated identity.""" + session.jobs.list_jobs(request=operations.ListJobsRequest()) + + +def test_client_credentials_exchange_and_list_jobs(): + cc = ClientCredentials( + client_secret=CLIENT_SECRET, # type: ignore[arg-type] + server_url=ACCOUNTS_URL, # type: ignore[arg-type] + ) + try: + session = UnstructuredClient( + api_key_auth=cc, + server_url=PLATFORM_API_URL, + timeout_ms=60_000, + ) + + _list_jobs(session) + + # Cached exchange: internal cache now holds a JWT; a second call + # should not trigger a new exchange unless we crossed the refresh + # buffer, which is unlikely across two sequential requests. + before_cache = cc._cache # type: ignore[attr-defined] + assert before_cache is not None, "expected cache to be populated after first call" + + _list_jobs(session) + + after_cache = cc._cache # type: ignore[attr-defined] + assert after_cache is before_cache, ( + "ClientCredentials re-exchanged within TTL; cache should be reused" + ) + finally: + cc.close() + + +@pytest.mark.skipif( + LEGACY_API_KEY is None, + reason="Set UNS_LEGACY_API_KEY to also exercise the LegacyKeyExchange path.", +) +def test_legacy_key_exchange_and_list_jobs(): + lke = LegacyKeyExchange( + api_key=LEGACY_API_KEY, # type: ignore[arg-type] + server_url=ACCOUNTS_URL, # type: ignore[arg-type] + ) + try: + session = UnstructuredClient( + api_key_auth=lke, + server_url=PLATFORM_API_URL, + timeout_ms=60_000, + ) + _list_jobs(session) + assert lke._cache is not None # type: ignore[attr-defined] + finally: + lke.close() diff --git a/_test_unstructured_client/unit/auth/test_async_client_credentials.py b/_test_unstructured_client/unit/auth/test_async_client_credentials.py index 7f533247..cf1568bd 100644 --- 
a/_test_unstructured_client/unit/auth/test_async_client_credentials.py +++ b/_test_unstructured_client/unit/auth/test_async_client_credentials.py @@ -45,123 +45,126 @@ def _now() -> float: return state -class DescribeAsyncClientCredentials: - @pytest.mark.asyncio - async def it_exchanges_then_caches(self, fake_clock): - transport = AsyncScriptedTransport( - [exchange_response(access_token="jwt-1", expires_in=900)] - ) - http_client = httpx.AsyncClient(transport=transport) - acc = AsyncClientCredentials( - client_secret=SECRET, - server_url=SERVER_URL, - http_client=http_client, - ) - - first = await acc.acquire() - second = await acc.acquire() - - assert first == second == "jwt-1" - assert len(transport.requests) == 1 - assert body_of(transport.requests[0]) == { - "grant_type": "client_credentials", - "client_secret": SECRET, - } - - @pytest.mark.asyncio - async def it_raises_invalid_credential_on_401(self, fake_clock): - transport = AsyncScriptedTransport( - [httpx.Response(401, json={"detail": "bad"})] - ) - http_client = httpx.AsyncClient(transport=transport) - acc = AsyncClientCredentials( - client_secret=SECRET, - server_url=SERVER_URL, - http_client=http_client, - max_retries=5, - ) - - with pytest.raises(InvalidCredentialError): - await acc.acquire() - - @pytest.mark.asyncio - async def it_retries_5xx_then_succeeds(self, fake_clock): - transport = AsyncScriptedTransport( - [ - httpx.Response(500), - httpx.Response(502), - exchange_response(access_token="jwt-1", expires_in=900), - ] - ) - http_client = httpx.AsyncClient(transport=transport) - acc = AsyncClientCredentials( - client_secret=SECRET, - server_url=SERVER_URL, - http_client=http_client, - max_retries=3, - ) - - assert await acc.acquire() == "jwt-1" - assert len(transport.requests) == 3 - - @pytest.mark.asyncio - async def it_serializes_concurrent_acquires(self, fake_clock): - """Ten concurrent ``acquire()`` calls must share one exchange.""" - transport = AsyncScriptedTransport( - 
[exchange_response(access_token="jwt-1", expires_in=900)] - ) - http_client = httpx.AsyncClient(transport=transport) - acc = AsyncClientCredentials( - client_secret=SECRET, - server_url=SERVER_URL, - http_client=http_client, - ) - - results: List[str] = await asyncio.gather(*(acc.acquire() for _ in range(10))) - - assert results == ["jwt-1"] * 10 - assert len(transport.requests) == 1 - - @pytest.mark.asyncio - async def it_raises_outage_error_without_cached_token(self, fake_clock): - transport = AsyncScriptedTransport([httpx.Response(500)] * 4) - http_client = httpx.AsyncClient(transport=transport) - acc = AsyncClientCredentials( - client_secret=SECRET, - server_url=SERVER_URL, - http_client=http_client, - max_retries=3, - ) - - with pytest.raises(TokenExchangeError): - await acc.acquire() - - def it_sync_call_works_outside_running_loop(self, fake_clock): - """``__call__`` is the SDK entry point; must work without a loop.""" - transport = AsyncScriptedTransport( - [exchange_response(access_token="jwt-1", expires_in=900)] - ) - http_client = httpx.AsyncClient(transport=transport) - acc = AsyncClientCredentials( - client_secret=SECRET, - server_url=SERVER_URL, - http_client=http_client, - ) - - assert acc() == "jwt-1" - - @pytest.mark.asyncio - async def it_sync_call_works_inside_running_loop(self, fake_clock): - """Driving __call__ from a running loop offloads to a worker thread.""" - transport = AsyncScriptedTransport( - [exchange_response(access_token="jwt-1", expires_in=900)] - ) - http_client = httpx.AsyncClient(transport=transport) - acc = AsyncClientCredentials( - client_secret=SECRET, - server_url=SERVER_URL, - http_client=http_client, - ) - - token = await asyncio.to_thread(acc) - assert token == "jwt-1" +@pytest.mark.asyncio +async def test_exchanges_then_caches(fake_clock): + transport = AsyncScriptedTransport( + [exchange_response(access_token="jwt-1", expires_in=900)] + ) + http_client = httpx.AsyncClient(transport=transport) + acc = 
AsyncClientCredentials( + client_secret=SECRET, + server_url=SERVER_URL, + http_client=http_client, + ) + + first = await acc.acquire() + second = await acc.acquire() + + assert first == second == "jwt-1" + assert len(transport.requests) == 1 + assert body_of(transport.requests[0]) == { + "grant_type": "client_credentials", + "client_secret": SECRET, + } + + +@pytest.mark.asyncio +async def test_raises_invalid_credential_on_401(fake_clock): + transport = AsyncScriptedTransport([httpx.Response(401, json={"detail": "bad"})]) + http_client = httpx.AsyncClient(transport=transport) + acc = AsyncClientCredentials( + client_secret=SECRET, + server_url=SERVER_URL, + http_client=http_client, + max_retries=5, + ) + + with pytest.raises(InvalidCredentialError): + await acc.acquire() + + +@pytest.mark.asyncio +async def test_retries_5xx_then_succeeds(fake_clock): + transport = AsyncScriptedTransport( + [ + httpx.Response(500), + httpx.Response(502), + exchange_response(access_token="jwt-1", expires_in=900), + ] + ) + http_client = httpx.AsyncClient(transport=transport) + acc = AsyncClientCredentials( + client_secret=SECRET, + server_url=SERVER_URL, + http_client=http_client, + max_retries=3, + ) + + assert await acc.acquire() == "jwt-1" + assert len(transport.requests) == 3 + + +@pytest.mark.asyncio +async def test_serializes_concurrent_acquires(fake_clock): + """Ten concurrent ``acquire()`` calls must share one exchange.""" + transport = AsyncScriptedTransport( + [exchange_response(access_token="jwt-1", expires_in=900)] + ) + http_client = httpx.AsyncClient(transport=transport) + acc = AsyncClientCredentials( + client_secret=SECRET, + server_url=SERVER_URL, + http_client=http_client, + ) + + results: List[str] = await asyncio.gather(*(acc.acquire() for _ in range(10))) + + assert results == ["jwt-1"] * 10 + assert len(transport.requests) == 1 + + +@pytest.mark.asyncio +async def test_raises_outage_error_without_cached_token(fake_clock): + transport = 
AsyncScriptedTransport([httpx.Response(500)] * 4) + http_client = httpx.AsyncClient(transport=transport) + acc = AsyncClientCredentials( + client_secret=SECRET, + server_url=SERVER_URL, + http_client=http_client, + max_retries=3, + ) + + with pytest.raises(TokenExchangeError): + await acc.acquire() + + +def test_sync_call_works_outside_running_loop(fake_clock): + """``__call__`` is the SDK entry point; must work without a loop.""" + transport = AsyncScriptedTransport( + [exchange_response(access_token="jwt-1", expires_in=900)] + ) + http_client = httpx.AsyncClient(transport=transport) + acc = AsyncClientCredentials( + client_secret=SECRET, + server_url=SERVER_URL, + http_client=http_client, + ) + + assert acc() == "jwt-1" + + +@pytest.mark.asyncio +async def test_sync_call_works_inside_running_loop(fake_clock): + """Driving __call__ from a running loop offloads to a worker thread.""" + transport = AsyncScriptedTransport( + [exchange_response(access_token="jwt-1", expires_in=900)] + ) + http_client = httpx.AsyncClient(transport=transport) + acc = AsyncClientCredentials( + client_secret=SECRET, + server_url=SERVER_URL, + http_client=http_client, + ) + + token = await asyncio.to_thread(acc) + assert token == "jwt-1" diff --git a/_test_unstructured_client/unit/auth/test_auth_header_hook.py b/_test_unstructured_client/unit/auth/test_auth_header_hook.py index 1477cd5a..d55de91f 100644 --- a/_test_unstructured_client/unit/auth/test_auth_header_hook.py +++ b/_test_unstructured_client/unit/auth/test_auth_header_hook.py @@ -7,10 +7,7 @@ from __future__ import annotations -from typing import Optional - import httpx -import pytest from unstructured_client import UnstructuredClient from unstructured_client._hooks.custom.auth_header_hook import ( @@ -28,7 +25,9 @@ def _make_request(headers: dict) -> httpx.Request: - return httpx.Request("GET", "https://api.example.test/api/v1/jobs/", headers=headers) + return httpx.Request( + "GET", "https://api.example.test/api/v1/jobs/", 
headers=headers + ) def _make_hook_ctx(security_source) -> BeforeRequestContext: @@ -42,171 +41,163 @@ def _make_hook_ctx(security_source) -> BeforeRequestContext: return BeforeRequestContext(inner) -class DescribeAuthHeaderHookDirect: - def it_rewrites_header_when_source_is_client_credentials(self): - transport = ScriptedTransport([exchange_response()]) - cc = ClientCredentials( - client_secret=SECRET, - server_url=ACCOUNTS_URL, - http_client=httpx.Client(transport=transport), - ) +def test_rewrites_header_when_source_is_client_credentials(): + transport = ScriptedTransport([exchange_response()]) + cc = ClientCredentials( + client_secret=SECRET, + server_url=ACCOUNTS_URL, + http_client=httpx.Client(transport=transport), + ) - # Simulate what sdk.py builds: a factory with __wrapped_callable__ - def factory(): - return None + def factory(): + return None - setattr(factory, "__wrapped_callable__", cc) + setattr(factory, "__wrapped_callable__", cc) - hook = AuthHeaderBeforeRequestHook() - request = _make_request({"unstructured-api-key": "jwt-value"}) + hook = AuthHeaderBeforeRequestHook() + request = _make_request({"unstructured-api-key": "jwt-value"}) - result = hook.before_request(_make_hook_ctx(factory), request) + result = hook.before_request(_make_hook_ctx(factory), request) - assert isinstance(result, httpx.Request) - assert result.headers.get("Authorization") == "Bearer jwt-value" - assert "unstructured-api-key" not in result.headers + assert isinstance(result, httpx.Request) + assert result.headers.get("Authorization") == "Bearer jwt-value" + assert "unstructured-api-key" not in result.headers - def it_rewrites_header_when_source_is_legacy_key_exchange(self): - transport = ScriptedTransport([exchange_response()]) - lke = LegacyKeyExchange( - api_key="legacy", - server_url=ACCOUNTS_URL, - http_client=httpx.Client(transport=transport), - ) - def factory(): - return None +def test_rewrites_header_when_source_is_legacy_key_exchange(): + transport = 
ScriptedTransport([exchange_response()]) + lke = LegacyKeyExchange( + api_key="legacy", + server_url=ACCOUNTS_URL, + http_client=httpx.Client(transport=transport), + ) - setattr(factory, "__wrapped_callable__", lke) + def factory(): + return None - hook = AuthHeaderBeforeRequestHook() - request = _make_request({"unstructured-api-key": "jwt-value"}) + setattr(factory, "__wrapped_callable__", lke) - result = hook.before_request(_make_hook_ctx(factory), request) + hook = AuthHeaderBeforeRequestHook() + request = _make_request({"unstructured-api-key": "jwt-value"}) - assert isinstance(result, httpx.Request) - assert result.headers.get("Authorization") == "Bearer jwt-value" - assert "unstructured-api-key" not in result.headers + result = hook.before_request(_make_hook_ctx(factory), request) - def it_is_noop_for_plain_string_security_source(self): - # When api_key_auth is a string, sdk.py passes a `shared.Security` - # instance as `security_source`, not a callable. The hook must - # leave the request untouched. 
- from unstructured_client.models import shared + assert isinstance(result, httpx.Request) + assert result.headers.get("Authorization") == "Bearer jwt-value" + assert "unstructured-api-key" not in result.headers - hook = AuthHeaderBeforeRequestHook() - request = _make_request({"unstructured-api-key": FAKE_KEY}) - result = hook.before_request( - _make_hook_ctx(shared.Security(api_key_auth=FAKE_KEY)), - request, - ) +def test_is_noop_for_plain_string_security_source(): + from unstructured_client.models import shared - assert isinstance(result, httpx.Request) - assert result.headers.get("unstructured-api-key") == FAKE_KEY - assert "Authorization" not in result.headers + hook = AuthHeaderBeforeRequestHook() + request = _make_request({"unstructured-api-key": FAKE_KEY}) - def it_is_noop_for_arbitrary_user_callable(self): - def user_callable() -> str: - return "whatever" + result = hook.before_request( + _make_hook_ctx(shared.Security(api_key_auth=FAKE_KEY)), + request, + ) - def factory(): - return None + assert isinstance(result, httpx.Request) + assert result.headers.get("unstructured-api-key") == FAKE_KEY + assert "Authorization" not in result.headers - setattr(factory, "__wrapped_callable__", user_callable) - hook = AuthHeaderBeforeRequestHook() - request = _make_request({"unstructured-api-key": "whatever"}) +def test_is_noop_for_arbitrary_user_callable(): + def user_callable() -> str: + return "whatever" - result = hook.before_request(_make_hook_ctx(factory), request) + def factory(): + return None - assert isinstance(result, httpx.Request) - assert result.headers.get("unstructured-api-key") == "whatever" - assert "Authorization" not in result.headers + setattr(factory, "__wrapped_callable__", user_callable) - def it_is_noop_when_security_source_is_none(self): - hook = AuthHeaderBeforeRequestHook() - request = _make_request({"unstructured-api-key": FAKE_KEY}) + hook = AuthHeaderBeforeRequestHook() + request = _make_request({"unstructured-api-key": "whatever"}) - 
result = hook.before_request(_make_hook_ctx(None), request) + result = hook.before_request(_make_hook_ctx(factory), request) - assert isinstance(result, httpx.Request) - assert result.headers.get("unstructured-api-key") == FAKE_KEY - assert "Authorization" not in result.headers + assert isinstance(result, httpx.Request) + assert result.headers.get("unstructured-api-key") == "whatever" + assert "Authorization" not in result.headers -class DescribeAuthHeaderHookIntegration: - """End-to-end: instantiate :class:`UnstructuredClient` with a - :class:`ClientCredentials` and assert the outgoing request to the - downstream API carries ``Authorization: Bearer `` and no - ``unstructured-api-key``. - """ +def test_is_noop_when_security_source_is_none(): + hook = AuthHeaderBeforeRequestHook() + request = _make_request({"unstructured-api-key": FAKE_KEY}) - def it_sends_bearer_header_for_client_credentials(self): - exchange_transport = ScriptedTransport( - [exchange_response(access_token="jwt-abc", expires_in=900)] - ) - exchange_http_client = httpx.Client(transport=exchange_transport) - cc = ClientCredentials( - client_secret=SECRET, - server_url=ACCOUNTS_URL, - http_client=exchange_http_client, - ) + result = hook.before_request(_make_hook_ctx(None), request) - captured: dict = {} + assert isinstance(result, httpx.Request) + assert result.headers.get("unstructured-api-key") == FAKE_KEY + assert "Authorization" not in result.headers - def _mock(request: httpx.Request) -> httpx.Response: - captured["headers"] = dict(request.headers) - return httpx.Response(200, json={}) - downstream_transport = httpx.MockTransport(_mock) - downstream_client = httpx.Client(transport=downstream_transport) +def test_integration_sends_bearer_header_for_client_credentials(): + """End-to-end: UnstructuredClient + ClientCredentials -> ``Authorization: Bearer``.""" + exchange_transport = ScriptedTransport( + [exchange_response(access_token="jwt-abc", expires_in=900)] + ) + exchange_http_client = 
httpx.Client(transport=exchange_transport) + cc = ClientCredentials( + client_secret=SECRET, + server_url=ACCOUNTS_URL, + http_client=exchange_http_client, + ) - session = UnstructuredClient( - api_key_auth=cc, - client=downstream_client, - server_url=SERVER_URL, - ) + captured: dict = {} - try: - # Any operation triggers a request; cancel_job is lightweight. - from unstructured_client.models import operations - - session.jobs.cancel_job( - request=operations.CancelJobRequest(job_id="test-job-id"), - ) - except Exception: # noqa: BLE001 - # The mocked 200 with empty JSON won't unmarshal correctly, but - # by then the request already fired and the header was captured. - pass - - headers = captured.get("headers", {}) - assert headers.get("authorization") == "Bearer jwt-abc" - assert "unstructured-api-key" not in {k.lower() for k in headers} - - def it_leaves_legacy_path_unchanged_for_plain_string(self): - captured: dict = {} - - def _mock(request: httpx.Request) -> httpx.Response: - captured["headers"] = dict(request.headers) - return httpx.Response(200, json={}) - - client = httpx.Client(transport=httpx.MockTransport(_mock)) - session = UnstructuredClient( - api_key_auth=FAKE_KEY, - client=client, - server_url=SERVER_URL, + def _mock(request: httpx.Request) -> httpx.Response: + captured["headers"] = dict(request.headers) + return httpx.Response(200, json={}) + + downstream_client = httpx.Client(transport=httpx.MockTransport(_mock)) + + session = UnstructuredClient( + api_key_auth=cc, + client=downstream_client, + server_url=SERVER_URL, + ) + + try: + from unstructured_client.models import operations + + session.jobs.cancel_job( + request=operations.CancelJobRequest(job_id="test-job-id"), ) + except Exception: # noqa: BLE001 + # Mocked 200 with empty JSON won't deserialize correctly, but by then + # the outgoing headers were already captured. 
+ pass + + headers = captured.get("headers", {}) + assert headers.get("authorization") == "Bearer jwt-abc" + assert "unstructured-api-key" not in {k.lower() for k in headers} - try: - from unstructured_client.models import operations - session.jobs.cancel_job( - request=operations.CancelJobRequest(job_id="test-job-id"), - ) - except Exception: # noqa: BLE001 - pass +def test_integration_leaves_legacy_path_unchanged_for_plain_string(): + captured: dict = {} + + def _mock(request: httpx.Request) -> httpx.Response: + captured["headers"] = dict(request.headers) + return httpx.Response(200, json={}) + + client = httpx.Client(transport=httpx.MockTransport(_mock)) + session = UnstructuredClient( + api_key_auth=FAKE_KEY, + client=client, + server_url=SERVER_URL, + ) + + try: + from unstructured_client.models import operations + + session.jobs.cancel_job( + request=operations.CancelJobRequest(job_id="test-job-id"), + ) + except Exception: # noqa: BLE001 + pass - headers = captured.get("headers", {}) - assert headers.get("unstructured-api-key") == FAKE_KEY - assert "authorization" not in {k.lower() for k in headers} + headers = captured.get("headers", {}) + assert headers.get("unstructured-api-key") == FAKE_KEY + assert "authorization" not in {k.lower() for k in headers} diff --git a/_test_unstructured_client/unit/auth/test_client_credentials.py b/_test_unstructured_client/unit/auth/test_client_credentials.py index 7e606527..c7d983eb 100644 --- a/_test_unstructured_client/unit/auth/test_client_credentials.py +++ b/_test_unstructured_client/unit/auth/test_client_credentials.py @@ -1,14 +1,14 @@ """Unit tests for :class:`unstructured_client.auth.ClientCredentials`. -Uses :class:`httpx.MockTransport` to script the ``/auth/token-exchange`` -endpoint; no real network IO. +Uses a scripted :class:`httpx.MockTransport` to stand in for the +``/auth/token-exchange`` endpoint; no real network IO. 
""" from __future__ import annotations import logging import threading -from typing import List +from typing import List, Tuple import httpx import pytest @@ -35,8 +35,7 @@ def _make_client_credentials( *, refresh_buffer_seconds: int = 60, max_retries: int = 3, -) -> tuple[ClientCredentials, ScriptedTransport]: - """Build a :class:`ClientCredentials` wired to a scripted transport.""" +) -> Tuple[ClientCredentials, ScriptedTransport]: transport = ScriptedTransport(steps) http_client = httpx.Client(transport=transport) cc = ClientCredentials( @@ -73,249 +72,252 @@ def _now() -> float: return state -class DescribeClientCredentialsFirstExchange: - def it_posts_client_credentials_body(self, fake_clock): - cc, transport = _make_client_credentials( - [exchange_response(access_token="jwt-1", expires_in=900)] - ) +def test_posts_client_credentials_body(fake_clock): + cc, transport = _make_client_credentials( + [exchange_response(access_token="jwt-1", expires_in=900)] + ) - token = cc() - - assert token == "jwt-1" - assert len(transport.requests) == 1 - req = transport.requests[0] - assert req.method == "POST" - assert req.url.path == "/auth/token-exchange" - assert req.headers["content-type"] == "application/json" - assert body_of(req) == { - "grant_type": "client_credentials", - "client_secret": SECRET, - } - - def it_strips_trailing_slash_from_server_url(self, fake_clock): - transport = ScriptedTransport([exchange_response()]) - http_client = httpx.Client(transport=transport) - cc = ClientCredentials( - client_secret=SECRET, - server_url=f"{SERVER_URL}/", - http_client=http_client, - ) + token = cc() - cc() + assert token == "jwt-1" + assert len(transport.requests) == 1 + req = transport.requests[0] + assert req.method == "POST" + assert req.url.path == "/auth/token-exchange" + assert req.headers["content-type"] == "application/json" + assert body_of(req) == { + "grant_type": "client_credentials", + "client_secret": SECRET, + } + + +def 
test_strips_trailing_slash_from_server_url(fake_clock): + transport = ScriptedTransport([exchange_response()]) + http_client = httpx.Client(transport=transport) + cc = ClientCredentials( + client_secret=SECRET, + server_url=f"{SERVER_URL}/", + http_client=http_client, + ) - assert str(transport.requests[0].url).endswith("/auth/token-exchange") - assert "//auth/token-exchange" not in str(transport.requests[0].url) + cc() + assert str(transport.requests[0].url).endswith("/auth/token-exchange") + assert "//auth/token-exchange" not in str(transport.requests[0].url) -class DescribeClientCredentialsCaching: - def it_returns_cached_jwt_within_ttl(self, fake_clock): - cc, transport = _make_client_credentials( - [exchange_response(access_token="jwt-1", expires_in=900)] - ) - first = cc() - second = cc() - third = cc() +def test_returns_cached_jwt_within_ttl(fake_clock): + cc, transport = _make_client_credentials( + [exchange_response(access_token="jwt-1", expires_in=900)] + ) - assert first == second == third == "jwt-1" - assert len(transport.requests) == 1 + first = cc() + second = cc() + third = cc() - def it_refreshes_when_within_buffer_of_expiry(self, fake_clock): - cc, transport = _make_client_credentials( - [ - exchange_response(access_token="jwt-1", expires_in=900), - exchange_response(access_token="jwt-2", expires_in=900), - ], - refresh_buffer_seconds=60, - ) + assert first == second == third == "jwt-1" + assert len(transport.requests) == 1 - assert cc() == "jwt-1" - fake_clock["now"] += 900 - 59 # within the 60s refresh buffer - assert cc() == "jwt-2" - assert len(transport.requests) == 2 - def it_does_not_refresh_outside_buffer(self, fake_clock): - cc, transport = _make_client_credentials( - [exchange_response(access_token="jwt-1", expires_in=900)], - refresh_buffer_seconds=60, - ) +def test_refreshes_when_within_buffer_of_expiry(fake_clock): + cc, transport = _make_client_credentials( + [ + exchange_response(access_token="jwt-1", expires_in=900), + 
exchange_response(access_token="jwt-2", expires_in=900), + ], + refresh_buffer_seconds=60, + ) + + assert cc() == "jwt-1" + fake_clock["now"] += 900 - 59 # within the 60s refresh buffer + assert cc() == "jwt-2" + assert len(transport.requests) == 2 + +def test_does_not_refresh_outside_buffer(fake_clock): + cc, transport = _make_client_credentials( + [exchange_response(access_token="jwt-1", expires_in=900)], + refresh_buffer_seconds=60, + ) + + cc() + fake_clock["now"] += 900 - 120 # still 120s from expiry + cc() + + assert len(transport.requests) == 1 + + +def test_raises_invalid_credential_on_401_without_retry(fake_clock): + cc, transport = _make_client_credentials( + [httpx.Response(401, json={"detail": "invalid"})], + max_retries=5, + ) + + with pytest.raises(InvalidCredentialError): cc() - fake_clock["now"] += 900 - 120 # still 120s from expiry + assert len(transport.requests) == 1 + + +def test_raises_on_400_without_retry(fake_clock): + cc, transport = _make_client_credentials( + [httpx.Response(400, json={"detail": "bad"})], + max_retries=5, + ) + + with pytest.raises(TokenExchangeError): cc() + assert len(transport.requests) == 1 - assert len(transport.requests) == 1 +def test_raises_disabled_when_server_opts_out(fake_clock): + cc, transport = _make_client_credentials( + [exchange_response(access_token=None, expires_in=0, token_exchange_enabled=False)] + ) -class DescribeClientCredentialsErrors: - def it_raises_invalid_credential_on_401_without_retry(self, fake_clock): - cc, transport = _make_client_credentials( - [httpx.Response(401, json={"detail": "invalid"})], - max_retries=5, - ) + with pytest.raises(TokenExchangeDisabledError): + cc() + assert len(transport.requests) == 1 - with pytest.raises(InvalidCredentialError): - cc() - assert len(transport.requests) == 1 - def it_raises_on_400_without_retry(self, fake_clock): - cc, transport = _make_client_credentials( - [httpx.Response(400, json={"detail": "bad"})], - max_retries=5, - ) +def 
test_retries_5xx_then_succeeds(fake_clock): + cc, transport = _make_client_credentials( + [ + httpx.Response(503, json={}), + httpx.Response(500, json={}), + exchange_response(access_token="jwt-1", expires_in=900), + ], + max_retries=3, + ) - with pytest.raises(TokenExchangeError): - cc() - assert len(transport.requests) == 1 + assert cc() == "jwt-1" + assert len(transport.requests) == 3 - def it_raises_disabled_when_server_opts_out(self, fake_clock): - cc, transport = _make_client_credentials( - [exchange_response(access_token=None, expires_in=0, token_exchange_enabled=False)] - ) - with pytest.raises(TokenExchangeDisabledError): - cc() - assert len(transport.requests) == 1 +def test_retries_network_errors_then_succeeds(fake_clock): + cc, transport = _make_client_credentials( + [ + httpx.ConnectError("refused"), + httpx.ReadTimeout("slow"), + exchange_response(access_token="jwt-1", expires_in=900), + ], + max_retries=3, + ) + assert cc() == "jwt-1" + assert len(transport.requests) == 3 -class DescribeClientCredentialsRetry: - def it_retries_5xx_then_succeeds(self, fake_clock): - cc, transport = _make_client_credentials( - [ - httpx.Response(503, json={}), - httpx.Response(500, json={}), - exchange_response(access_token="jwt-1", expires_in=900), - ], - max_retries=3, - ) - assert cc() == "jwt-1" - assert len(transport.requests) == 3 - - def it_retries_network_errors_then_succeeds(self, fake_clock): - cc, transport = _make_client_credentials( - [ - httpx.ConnectError("refused"), - httpx.ReadTimeout("slow"), - exchange_response(access_token="jwt-1", expires_in=900), - ], - max_retries=3, - ) +def test_raises_when_retries_exhausted_without_cached_token(fake_clock): + cc, transport = _make_client_credentials( + [httpx.Response(500, json={})] * 4, + max_retries=3, + ) - assert cc() == "jwt-1" - assert len(transport.requests) == 3 + with pytest.raises(TokenExchangeError): + cc() + assert len(transport.requests) == 4 + + +def 
test_serves_cached_jwt_during_outage_when_still_within_ttl(fake_clock, caplog): + cc, transport = _make_client_credentials( + [ + exchange_response(access_token="jwt-1", expires_in=900), + httpx.Response(500, json={}), + httpx.Response(502, json={}), + httpx.Response(503, json={}), + httpx.Response(504, json={}), + ], + max_retries=3, + refresh_buffer_seconds=60, + ) - def it_raises_when_retries_exhausted_without_cached_token(self, fake_clock): - cc, transport = _make_client_credentials( - [httpx.Response(500, json={})] * 4, - max_retries=3, - ) + assert cc() == "jwt-1" + fake_clock["now"] += 900 - 30 # past refresh buffer but before absolute expiry - with pytest.raises(TokenExchangeError): - cc() - assert len(transport.requests) == 4 - - def it_serves_cached_jwt_during_outage_when_still_within_ttl( - self, fake_clock, caplog - ): - cc, transport = _make_client_credentials( - [ - exchange_response(access_token="jwt-1", expires_in=900), - httpx.Response(500, json={}), - httpx.Response(502, json={}), - httpx.Response(503, json={}), - httpx.Response(504, json={}), - ], - max_retries=3, - refresh_buffer_seconds=60, - ) + caplog.set_level(logging.WARNING, logger="unstructured-client.auth") + assert cc() == "jwt-1" + assert any( + "serving cached JWT" in record.getMessage() for record in caplog.records + ) - assert cc() == "jwt-1" - fake_clock["now"] += 900 - 30 # past refresh buffer but before absolute expiry - caplog.set_level(logging.WARNING, logger="unstructured-client.auth") - assert cc() == "jwt-1" - assert any( - "serving cached JWT" in record.getMessage() for record in caplog.records - ) +def test_raises_when_cached_token_has_fully_expired(fake_clock): + cc, transport = _make_client_credentials( + [ + exchange_response(access_token="jwt-1", expires_in=900), + httpx.Response(500, json={}), + httpx.Response(500, json={}), + httpx.Response(500, json={}), + httpx.Response(500, json={}), + ], + max_retries=3, + refresh_buffer_seconds=60, + ) - def 
it_raises_when_cached_token_has_fully_expired(self, fake_clock): - cc, transport = _make_client_credentials( - [ - exchange_response(access_token="jwt-1", expires_in=900), - httpx.Response(500, json={}), - httpx.Response(500, json={}), - httpx.Response(500, json={}), - httpx.Response(500, json={}), - ], - max_retries=3, - refresh_buffer_seconds=60, - ) + assert cc() == "jwt-1" + fake_clock["now"] += 1000 # past absolute expiry - assert cc() == "jwt-1" - fake_clock["now"] += 1000 # past absolute expiry + with pytest.raises(TokenExchangeError): + cc() - with pytest.raises(TokenExchangeError): - cc() +def test_collapses_concurrent_calls_into_one_exchange(): + """Ten threads calling ``cc()`` concurrently must drive a single exchange.""" + barrier = threading.Barrier(10) + transport = ScriptedTransport( + [exchange_response(access_token="jwt-1", expires_in=900)] + ) + http_client = httpx.Client(transport=transport) + cc = ClientCredentials( + client_secret=SECRET, + server_url=SERVER_URL, + http_client=http_client, + ) + + results: List[str] = [] + + def _worker(): + barrier.wait() + results.append(cc()) -class DescribeClientCredentialsConcurrency: - def it_collapses_concurrent_calls_into_one_exchange(self): - """Ten threads calling `cc()` at once must produce exactly one - exchange HTTP request.""" - barrier = threading.Barrier(10) - transport = ScriptedTransport( - [exchange_response(access_token="jwt-1", expires_in=900)] + threads = [threading.Thread(target=_worker) for _ in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert results == ["jwt-1"] * 10 + assert len(transport.requests) == 1 + + +@pytest.mark.parametrize("bad_secret", ["", None]) +def test_rejects_empty_secret(bad_secret): + with pytest.raises(ValueError): + ClientCredentials( + client_secret=bad_secret, # type: ignore[arg-type] + server_url=SERVER_URL, ) - http_client = httpx.Client(transport=transport) - cc = ClientCredentials( + + +def test_rejects_empty_server_url(): 
+ with pytest.raises(ValueError): + ClientCredentials(client_secret=SECRET, server_url="") + + +@pytest.mark.parametrize("bad_buffer", [-1, -100]) +def test_rejects_negative_refresh_buffer(bad_buffer): + with pytest.raises(ValueError): + ClientCredentials( client_secret=SECRET, server_url=SERVER_URL, - http_client=http_client, + refresh_buffer_seconds=bad_buffer, ) - results: List[str] = [] - - def _worker(): - barrier.wait() - results.append(cc()) - - threads = [threading.Thread(target=_worker) for _ in range(10)] - for t in threads: - t.start() - for t in threads: - t.join() - - assert results == ["jwt-1"] * 10 - assert len(transport.requests) == 1 - - -class DescribeClientCredentialsConstructionValidation: - @pytest.mark.parametrize("bad_secret", ["", None]) - def it_rejects_empty_secret(self, bad_secret): - with pytest.raises(ValueError): - ClientCredentials( - client_secret=bad_secret, # type: ignore[arg-type] - server_url=SERVER_URL, - ) - - def it_rejects_empty_server_url(self): - with pytest.raises(ValueError): - ClientCredentials(client_secret=SECRET, server_url="") - - @pytest.mark.parametrize("bad_buffer", [-1, -100]) - def it_rejects_negative_refresh_buffer(self, bad_buffer): - with pytest.raises(ValueError): - ClientCredentials( - client_secret=SECRET, - server_url=SERVER_URL, - refresh_buffer_seconds=bad_buffer, - ) - - def it_rejects_negative_max_retries(self): - with pytest.raises(ValueError): - ClientCredentials( - client_secret=SECRET, - server_url=SERVER_URL, - max_retries=-1, - ) + +def test_rejects_negative_max_retries(): + with pytest.raises(ValueError): + ClientCredentials( + client_secret=SECRET, + server_url=SERVER_URL, + max_retries=-1, + ) diff --git a/_test_unstructured_client/unit/auth/test_legacy_key_exchange.py b/_test_unstructured_client/unit/auth/test_legacy_key_exchange.py index bf52591d..88bd9f54 100644 --- a/_test_unstructured_client/unit/auth/test_legacy_key_exchange.py +++ 
b/_test_unstructured_client/unit/auth/test_legacy_key_exchange.py @@ -28,6 +28,10 @@ LEGACY_KEY = "uns_ak_legacy_example" +async def _noop_async_sleep(*_args, **_kwargs): + return None + + @pytest.fixture(autouse=True) def _no_sleep(monkeypatch): monkeypatch.setattr( @@ -40,60 +44,54 @@ def _no_sleep(monkeypatch): ) -async def _noop_async_sleep(*_args, **_kwargs): - return None +def test_sends_grant_type_api_key_and_api_key_field(): + transport = ScriptedTransport([exchange_response(access_token="jwt-1")]) + http_client = httpx.Client(transport=transport) + lke = LegacyKeyExchange( + api_key=LEGACY_KEY, + server_url=SERVER_URL, + http_client=http_client, + ) + + assert lke() == "jwt-1" + + req = transport.requests[0] + assert body_of(req) == {"grant_type": "api_key", "api_key": LEGACY_KEY} + + +def test_propagates_401_as_invalid_credential(): + transport = ScriptedTransport([httpx.Response(401, json={"detail": "bad"})]) + http_client = httpx.Client(transport=transport) + lke = LegacyKeyExchange( + api_key=LEGACY_KEY, + server_url=SERVER_URL, + http_client=http_client, + ) + + with pytest.raises(InvalidCredentialError): + lke() + + +@pytest.mark.asyncio +async def test_async_sends_grant_type_api_key_and_api_key_field(): + transport = AsyncScriptedTransport([exchange_response(access_token="jwt-1")]) + http_client = httpx.AsyncClient(transport=transport) + alke = AsyncLegacyKeyExchange( + api_key=LEGACY_KEY, + server_url=SERVER_URL, + http_client=http_client, + ) + + token = await alke.acquire() + + assert token == "jwt-1" + assert body_of(transport.requests[0]) == { + "grant_type": "api_key", + "api_key": LEGACY_KEY, + } -class DescribeLegacyKeyExchangeBody: - def it_sends_grant_type_api_key_and_api_key_field(self): - transport = ScriptedTransport([exchange_response(access_token="jwt-1")]) - http_client = httpx.Client(transport=transport) - lke = LegacyKeyExchange( - api_key=LEGACY_KEY, - server_url=SERVER_URL, - http_client=http_client, - ) - - assert lke() == 
"jwt-1" - - req = transport.requests[0] - assert body_of(req) == {"grant_type": "api_key", "api_key": LEGACY_KEY} - - def it_propagates_401_as_invalid_credential(self): - transport = ScriptedTransport([httpx.Response(401, json={"detail": "bad"})]) - http_client = httpx.Client(transport=transport) - lke = LegacyKeyExchange( - api_key=LEGACY_KEY, - server_url=SERVER_URL, - http_client=http_client, - ) - - with pytest.raises(InvalidCredentialError): - lke() - - -class DescribeAsyncLegacyKeyExchangeBody: - @pytest.mark.asyncio - async def it_sends_grant_type_api_key_and_api_key_field(self): - transport = AsyncScriptedTransport([exchange_response(access_token="jwt-1")]) - http_client = httpx.AsyncClient(transport=transport) - alke = AsyncLegacyKeyExchange( - api_key=LEGACY_KEY, - server_url=SERVER_URL, - http_client=http_client, - ) - - token = await alke.acquire() - - assert token == "jwt-1" - assert body_of(transport.requests[0]) == { - "grant_type": "api_key", - "api_key": LEGACY_KEY, - } - - -class DescribeLegacyKeyExchangeConstruction: - @pytest.mark.parametrize("bad", ["", None]) - def it_rejects_empty_api_key(self, bad): - with pytest.raises(ValueError): - LegacyKeyExchange(api_key=bad, server_url=SERVER_URL) # type: ignore[arg-type] +@pytest.mark.parametrize("bad", ["", None]) +def test_rejects_empty_api_key(bad): + with pytest.raises(ValueError): + LegacyKeyExchange(api_key=bad, server_url=SERVER_URL) # type: ignore[arg-type] From 908d41887ad2cfdad15256a6a447758616f94745 Mon Sep 17 00:00:00 2001 From: Mateusz Kuprowski Date: Fri, 24 Apr 2026 17:03:16 +0200 Subject: [PATCH 3/4] Lint fixes --- src/unstructured_client/_hooks/custom/auth_header_hook.py | 6 ++---- src/unstructured_client/auth/_base.py | 1 - src/unstructured_client/auth/client_credentials.py | 7 +++---- 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/src/unstructured_client/_hooks/custom/auth_header_hook.py b/src/unstructured_client/_hooks/custom/auth_header_hook.py index 
7ee02618..64f849f5 100644 --- a/src/unstructured_client/_hooks/custom/auth_header_hook.py +++ b/src/unstructured_client/_hooks/custom/auth_header_hook.py @@ -17,6 +17,7 @@ import httpx from unstructured_client._hooks.types import BeforeRequestContext, BeforeRequestHook +from unstructured_client.auth._base import _ExchangeCallableBase class AuthHeaderBeforeRequestHook(BeforeRequestHook): @@ -43,11 +44,8 @@ def _is_exchange_callable(security_source: object) -> bool: token-exchange callables. The SDK wraps a user-supplied callable into an internal factory and - attaches ``__wrapped_callable__`` to it (see ``sdk.py``). We import - the base class lazily to avoid any cycle at module load. + attaches ``__wrapped_callable__`` to it (see ``sdk.py``). """ - from unstructured_client.auth._base import _ExchangeCallableBase - if security_source is None: return False diff --git a/src/unstructured_client/auth/_base.py b/src/unstructured_client/auth/_base.py index e03db337..c5c4d25f 100644 --- a/src/unstructured_client/auth/_base.py +++ b/src/unstructured_client/auth/_base.py @@ -16,7 +16,6 @@ from __future__ import annotations -import asyncio import logging import random import threading diff --git a/src/unstructured_client/auth/client_credentials.py b/src/unstructured_client/auth/client_credentials.py index 5223b6b9..98988395 100644 --- a/src/unstructured_client/auth/client_credentials.py +++ b/src/unstructured_client/auth/client_credentials.py @@ -19,13 +19,14 @@ from __future__ import annotations import asyncio +import concurrent.futures import time from typing import Any, Dict, Optional import httpx from ._base import _ExchangeCallableBase -from ._exceptions import InvalidCredentialError, TokenExchangeError +from ._exceptions import TokenExchangeError class ClientCredentials(_ExchangeCallableBase): @@ -254,14 +255,12 @@ def __call__(self) -> str: temporary loop via :func:`asyncio.run`. 
""" try: - loop = asyncio.get_running_loop() + asyncio.get_running_loop() except RuntimeError: return asyncio.run(self.acquire()) # Inside a running loop - offload to a worker thread that drives its # own event loop so we don't block the caller's loop on httpx IO. - import concurrent.futures - def _run_in_new_loop() -> str: return asyncio.run(self.acquire()) From 4fb09ac4eccc147cb52061fcf8e5c58b635c52de Mon Sep 17 00:00:00 2001 From: Mateusz Kuprowski Date: Mon, 27 Apr 2026 16:39:00 +0200 Subject: [PATCH 4/4] Code review fixes --- README.md | 10 +- .../test_client_credentials_e2e.py | 12 +- .../auth/test_async_client_credentials.py | 92 ++++++++ .../unit/auth/test_auth_header_hook.py | 34 +++ .../unit/auth/test_legacy_key_exchange.py | 50 +++++ .../_hooks/custom/auth_header_hook.py | 17 +- src/unstructured_client/_version.py | 4 +- src/unstructured_client/auth/__init__.py | 2 - src/unstructured_client/auth/_base.py | 10 +- .../auth/client_credentials.py | 201 +++++++++++++++--- .../auth/legacy_api_key.py | 17 +- 11 files changed, 396 insertions(+), 53 deletions(-) diff --git a/README.md b/README.md index 438cd8df..673c3104 100644 --- a/README.md +++ b/README.md @@ -504,7 +504,7 @@ s = UnstructuredClient(debug_logger=logging.getLogger("unstructured_client")) ## Authentication with Client Secrets -> Available from SDK version **X.Y.Z** (first release carrying the +> Available from SDK version **0.44.0** (first release carrying the > `unstructured_client.auth` module). If you are running against an Unstructured deployment that issues **client @@ -586,6 +586,14 @@ client = UnstructuredClient( `token_exchange_enabled=False`, the call raises `TokenExchangeDisabledError` — that deployment expects the plain `api_key_auth="..."` string form instead. 
+- **Custom HTTP client:** pass `http_client=httpx.Client(...)` (or + `httpx.AsyncClient(...)` for the async variant) to share a connection + pool, route through a corporate proxy, pin a custom CA bundle for mTLS, + or otherwise control how the SDK reaches account-service. When the + argument is omitted, the auth callable lazily creates and owns a private + `httpx.Client`; that client is closed automatically when the auth + instance is garbage-collected, or you can call `close()` / + `aclose()` explicitly. ### Backward compatibility diff --git a/_test_unstructured_client/integration/test_client_credentials_e2e.py b/_test_unstructured_client/integration/test_client_credentials_e2e.py index d4b9ba63..2575c481 100644 --- a/_test_unstructured_client/integration/test_client_credentials_e2e.py +++ b/_test_unstructured_client/integration/test_client_credentials_e2e.py @@ -2,15 +2,15 @@ :class:`LegacyKeyExchange`. This test is **opt-in**: it only runs when every required env var is set. -Point it at a deployment (e.g. the ``unsightly-koala`` dedicated-instance -test cluster) that has ``DEPLOYMENT_MODE=dedicated`` and a valid client -secret provisioned via account-service. +Point it at any deployment that has ``DEPLOYMENT_MODE=dedicated`` (or any +other configuration that accepts ``/auth/token-exchange``) and a valid +client secret provisioned via account-service. Required env vars ----------------- - ``UNS_ACCOUNTS_URL`` base URL of account-service (e.g. - ``https://accounts.unsightly-koala.example``) + ``https://accounts..example``) - ``UNS_CLIENT_SECRET`` ``uns_sk_...`` client secret - ``UNS_PLATFORM_API_URL`` platform-api base URL to hit after exchange @@ -48,8 +48,8 @@ _REASON = ( "Opt-in E2E: set UNS_ACCOUNTS_URL, UNS_CLIENT_SECRET, and " - "UNS_PLATFORM_API_URL to run against a real dedicated-instance " - "deployment (e.g. unsightly-koala)." + "UNS_PLATFORM_API_URL to run against a real deployment that supports " + "/auth/token-exchange." 
) diff --git a/_test_unstructured_client/unit/auth/test_async_client_credentials.py b/_test_unstructured_client/unit/auth/test_async_client_credentials.py index cf1568bd..9447fef5 100644 --- a/_test_unstructured_client/unit/auth/test_async_client_credentials.py +++ b/_test_unstructured_client/unit/auth/test_async_client_credentials.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +import threading from typing import List import httpx @@ -168,3 +169,94 @@ async def test_sync_call_works_inside_running_loop(fake_clock): token = await asyncio.to_thread(acc) assert token == "jwt-1" + + +def test_sync_call_re_exchanges_after_cache_lapse_across_loops(fake_clock): + """Regression: ``__call__`` must not crash on the second exchange. + + Each invocation outside a running loop spins a fresh event loop via + ``asyncio.run``. The internal ``asyncio.Lock`` is created lazily inside + ``acquire`` so it binds to whatever loop is current; if it were created + in ``__init__`` it would stay bound to the first loop and ``async with`` + would raise ``RuntimeError`` on the second exchange. 
+ """ + transport = AsyncScriptedTransport( + [ + exchange_response(access_token="jwt-1", expires_in=900), + exchange_response(access_token="jwt-2", expires_in=900), + ] + ) + http_client = httpx.AsyncClient(transport=transport) + acc = AsyncClientCredentials( + client_secret=SECRET, + server_url=SERVER_URL, + http_client=http_client, + refresh_buffer_seconds=60, + ) + + assert acc() == "jwt-1" + fake_clock["now"] += 900 - 30 # past refresh buffer + assert acc() == "jwt-2" + assert len(transport.requests) == 2 + + +@pytest.mark.asyncio +async def test_acquire_re_uses_correct_async_lock_after_loop_change(fake_clock): + """The lazy per-loop lock must refresh when the running loop changes.""" + transport = AsyncScriptedTransport( + [ + exchange_response(access_token="jwt-1", expires_in=900), + exchange_response(access_token="jwt-2", expires_in=900), + ] + ) + http_client = httpx.AsyncClient(transport=transport) + acc = AsyncClientCredentials( + client_secret=SECRET, + server_url=SERVER_URL, + http_client=http_client, + refresh_buffer_seconds=60, + ) + + first = await acc.acquire() + assert first == "jwt-1" + first_loop = acc._async_lock_loop # type: ignore[attr-defined] + + fake_clock["now"] += 900 - 30 # past refresh buffer + + # Run the next exchange on a *different* event loop driven from a worker + # thread; the lazy-init code path must re-bind the lock to that loop. 
+ second = await asyncio.to_thread(lambda: asyncio.run(acc.acquire())) + assert second == "jwt-2" + + second_loop = acc._async_lock_loop # type: ignore[attr-defined] + assert second_loop is not None + assert second_loop is not first_loop + + +def test_sync_call_coalesces_concurrent_threads_into_one_exchange(fake_clock): + """Two OS threads racing into ``__call__`` must share one exchange.""" + barrier = threading.Barrier(2) + transport = AsyncScriptedTransport( + [exchange_response(access_token="jwt-1", expires_in=900)] + ) + http_client = httpx.AsyncClient(transport=transport) + acc = AsyncClientCredentials( + client_secret=SECRET, + server_url=SERVER_URL, + http_client=http_client, + ) + + results = [] + + def _worker(): + barrier.wait() + results.append(acc()) + + threads = [threading.Thread(target=_worker) for _ in range(2)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert results == ["jwt-1", "jwt-1"] + assert len(transport.requests) == 1 diff --git a/_test_unstructured_client/unit/auth/test_auth_header_hook.py b/_test_unstructured_client/unit/auth/test_auth_header_hook.py index d55de91f..ef1340e7 100644 --- a/_test_unstructured_client/unit/auth/test_auth_header_hook.py +++ b/_test_unstructured_client/unit/auth/test_auth_header_hook.py @@ -175,6 +175,40 @@ def _mock(request: httpx.Request) -> httpx.Response: assert "unstructured-api-key" not in {k.lower() for k in headers} +def test_security_factory_exposes_wrapped_callable_for_hook_detection(): + """Speakeasy-regen guard: the security factory built by ``UnstructuredClient`` + must expose ``__wrapped_callable__`` so :class:`AuthHeaderBeforeRequestHook` + can recognize our token-exchange callables. + + If a future Speakeasy regeneration strips the hand-edited block in + ``sdk.py``, this test will fail loudly instead of letting the auth-header + rewrite silently break for every user. 
+ """ + transport = ScriptedTransport([exchange_response()]) + cc = ClientCredentials( + client_secret=SECRET, + server_url=ACCOUNTS_URL, + http_client=httpx.Client(transport=transport), + ) + + session = UnstructuredClient( + api_key_auth=cc, + client=httpx.Client(transport=httpx.MockTransport(lambda r: httpx.Response(200))), + server_url=SERVER_URL, + ) + + security_factory = session.sdk_configuration.security + assert callable(security_factory), ( + "When api_key_auth is callable, sdk.py must wrap it in a factory." + ) + wrapped = getattr(security_factory, "__wrapped_callable__", None) + assert wrapped is cc, ( + "sdk.py must attach __wrapped_callable__ pointing at the original " + "user callable; without it AuthHeaderBeforeRequestHook cannot detect " + "ClientCredentials / LegacyKeyExchange instances." + ) + + def test_integration_leaves_legacy_path_unchanged_for_plain_string(): captured: dict = {} diff --git a/_test_unstructured_client/unit/auth/test_legacy_key_exchange.py b/_test_unstructured_client/unit/auth/test_legacy_key_exchange.py index 88bd9f54..75eb1039 100644 --- a/_test_unstructured_client/unit/auth/test_legacy_key_exchange.py +++ b/_test_unstructured_client/unit/auth/test_legacy_key_exchange.py @@ -95,3 +95,53 @@ async def test_async_sends_grant_type_api_key_and_api_key_field(): def test_rejects_empty_api_key(bad): with pytest.raises(ValueError): LegacyKeyExchange(api_key=bad, server_url=SERVER_URL) # type: ignore[arg-type] + + +def test_caches_jwt_within_ttl_for_legacy_path(monkeypatch): + """Caching applies to the legacy api-key path just like client-secrets.""" + state = {"now": 1_000_000.0} + monkeypatch.setattr("unstructured_client.auth._base.time.monotonic", lambda: state["now"]) + monkeypatch.setattr( + "unstructured_client.auth.client_credentials.time.monotonic", lambda: state["now"] + ) + + transport = ScriptedTransport([exchange_response(access_token="jwt-1", expires_in=900)]) + http_client = httpx.Client(transport=transport) + lke = 
LegacyKeyExchange( + api_key=LEGACY_KEY, + server_url=SERVER_URL, + http_client=http_client, + refresh_buffer_seconds=60, + ) + + assert lke() == "jwt-1" + assert lke() == "jwt-1" + assert lke() == "jwt-1" + assert len(transport.requests) == 1 + + +def test_retries_5xx_on_legacy_path_then_succeeds(monkeypatch): + """5xx exponential-backoff retries also apply to the legacy api-key path.""" + state = {"now": 1_000_000.0} + monkeypatch.setattr("unstructured_client.auth._base.time.monotonic", lambda: state["now"]) + monkeypatch.setattr( + "unstructured_client.auth.client_credentials.time.monotonic", lambda: state["now"] + ) + + transport = ScriptedTransport( + [ + httpx.Response(500), + httpx.Response(503), + exchange_response(access_token="jwt-1", expires_in=900), + ] + ) + http_client = httpx.Client(transport=transport) + lke = LegacyKeyExchange( + api_key=LEGACY_KEY, + server_url=SERVER_URL, + http_client=http_client, + max_retries=3, + ) + + assert lke() == "jwt-1" + assert len(transport.requests) == 3 diff --git a/src/unstructured_client/_hooks/custom/auth_header_hook.py b/src/unstructured_client/_hooks/custom/auth_header_hook.py index 64f849f5..fae412a5 100644 --- a/src/unstructured_client/_hooks/custom/auth_header_hook.py +++ b/src/unstructured_client/_hooks/custom/auth_header_hook.py @@ -1,11 +1,12 @@ """Before-request hook that promotes exchanged JWTs to ``Authorization: Bearer``. -Speakeasy's generated ``Security`` model places ``api_key_auth`` in the +The generated ``Security`` model places ``api_key_auth`` in the ``unstructured-api-key`` header. When the user supplies a token-exchange -callable (``ClientCredentials`` or ``LegacyKeyExchange``) the value is a JWT -and must be sent as ``Authorization: Bearer `` so the service-side -``utic-jwt-auth`` validator picks it up (see ``core-product`` auth_context -and ``platform-api`` public_api/dependencies). 
+callable from :mod:`unstructured_client.auth` (such as +:class:`~unstructured_client.auth.ClientCredentials` or +:class:`~unstructured_client.auth.LegacyKeyExchange`) the value is a JWT +and must be sent as ``Authorization: Bearer `` so the server-side +JWT validator accepts it. Plain-string ``api_key_auth`` is untouched. """ @@ -43,8 +44,10 @@ def _is_exchange_callable(security_source: object) -> bool: """Return True when ``security_source`` was built from one of our token-exchange callables. - The SDK wraps a user-supplied callable into an internal factory and - attaches ``__wrapped_callable__`` to it (see ``sdk.py``). + ``UnstructuredClient.__init__`` wraps a user-supplied callable into an + internal security factory and attaches ``__wrapped_callable__`` to it + so this hook can detect token-exchange instances without reaching + into the lambda's closure. """ if security_source is None: return False diff --git a/src/unstructured_client/_version.py b/src/unstructured_client/_version.py index e3a41815..6b0c1dcb 100644 --- a/src/unstructured_client/_version.py +++ b/src/unstructured_client/_version.py @@ -3,10 +3,10 @@ import importlib.metadata __title__: str = "unstructured-client" -__version__: str = "0.43.3" +__version__: str = "0.44.0" __openapi_doc_version__: str = "1.2.31" __gen_version__: str = "2.680.0" -__user_agent__: str = "speakeasy-sdk/python 0.43.3 2.680.0 1.2.31 unstructured-client" +__user_agent__: str = "speakeasy-sdk/python 0.44.0 2.680.0 1.2.31 unstructured-client" try: if __package__ is not None: diff --git a/src/unstructured_client/auth/__init__.py b/src/unstructured_client/auth/__init__.py index 0cffd986..48bd7eec 100644 --- a/src/unstructured_client/auth/__init__.py +++ b/src/unstructured_client/auth/__init__.py @@ -19,7 +19,6 @@ as the ``unstructured-api-key`` header. 
""" -from ._base import _ExchangeCallableBase from ._exceptions import ( InvalidCredentialError, TokenExchangeDisabledError, @@ -36,5 +35,4 @@ "LegacyKeyExchange", "TokenExchangeDisabledError", "TokenExchangeError", - "_ExchangeCallableBase", ] diff --git a/src/unstructured_client/auth/_base.py b/src/unstructured_client/auth/_base.py index c5c4d25f..5c152a21 100644 --- a/src/unstructured_client/auth/_base.py +++ b/src/unstructured_client/auth/_base.py @@ -165,12 +165,16 @@ def _raise_for_status(self, response: httpx.Response) -> None: """Map HTTP status to auth-specific exceptions before retry decisions.""" if response.status_code == 401: raise InvalidCredentialError( - "Account-service rejected the credential (401). Check that the " - "client secret / API key is correct and not revoked.", + f"Account-service rejected the credential (401) at " + f"{self._exchange_url}. Check that the client secret / API " + f"key is correct and not revoked.", ) if response.status_code == 400: raise TokenExchangeError( - f"Account-service rejected the token-exchange request (400): " + f"Account-service rejected the token-exchange request (400) " + f"at {self._exchange_url}. This commonly indicates the " + f"server_url is pointing somewhere other than account-service " + f"(for example, at platform-api). 
Response body: " f"{response.text[:500]}", ) diff --git a/src/unstructured_client/auth/client_credentials.py b/src/unstructured_client/auth/client_credentials.py index 98988395..d675f198 100644 --- a/src/unstructured_client/auth/client_credentials.py +++ b/src/unstructured_client/auth/client_credentials.py @@ -20,7 +20,9 @@ import asyncio import concurrent.futures +import threading import time +import weakref from typing import Any, Dict, Optional import httpx @@ -29,6 +31,39 @@ from ._exceptions import TokenExchangeError +def _close_httpx_client(client: Optional[httpx.Client]) -> None: + """Best-effort sync close used by :func:`weakref.finalize`.""" + if client is None: + return + try: + client.close() + except Exception: # noqa: BLE001 - finalize must never raise + pass + + +def _close_async_httpx_client(client: Optional[httpx.AsyncClient]) -> None: + """Best-effort async close from :func:`weakref.finalize`. + + ``AsyncClient.aclose`` is a coroutine, so we run it on a fresh event loop + on a worker thread to avoid touching whatever loop (if any) the user is + currently running. + """ + if client is None: + return + + def _run() -> None: + try: + asyncio.run(client.aclose()) + except Exception: # noqa: BLE001 + pass + + try: + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: + pool.submit(_run).result(timeout=5) + except Exception: # noqa: BLE001 - finalize must never raise + pass + + class ClientCredentials(_ExchangeCallableBase): """Synchronous ``client_credentials`` grant callable. @@ -73,6 +108,28 @@ def __init__( self._client_secret = client_secret self._http_client = http_client self._owns_http_client = http_client is None + # Separate init lock so :meth:`_get_http_client` can atomically create + # the lazy client even though :meth:`_exchange` already holds the + # outer ``self._lock`` (a non-reentrant ``threading.Lock``). 
+ self._http_client_init_lock = threading.Lock() + self._finalizer: Optional[weakref.finalize] = None + if self._owns_http_client: + # Register a finalizer that closes a lazily-created private + # ``httpx.Client`` if the user never calls :meth:`close`. The + # closure binds an attribute lookup (``self._http_client``) at + # finalize-time via a small accessor so the finalizer doesn't + # itself keep ``self`` alive. + owner_ref = weakref.ref(self) + + def _finalize() -> None: + owner = owner_ref() + if owner is None: + return + # pylint: disable=protected-access + _close_httpx_client(owner._http_client) + owner._http_client = None + + self._finalizer = weakref.finalize(self, _finalize) def _build_request_body(self) -> Dict[str, Any]: return { @@ -81,9 +138,18 @@ def _build_request_body(self) -> Dict[str, Any]: } def _get_http_client(self) -> httpx.Client: - if self._http_client is None: - self._http_client = httpx.Client(timeout=self._request_timeout_seconds) - return self._http_client + """Return the lazily-initialized private ``httpx.Client``. + + Atomic across threads: a dedicated init lock guarantees that only a + single private client is ever created even if two callers race here + before any cache exists. + """ + if self._http_client is not None: + return self._http_client + with self._http_client_init_lock: + if self._http_client is None: + self._http_client = httpx.Client(timeout=self._request_timeout_seconds) + return self._http_client def __call__(self) -> str: """Return a valid JWT, performing an exchange only when necessary.""" @@ -145,16 +211,23 @@ def close(self) -> None: if self._owns_http_client and self._http_client is not None: self._http_client.close() self._http_client = None + if self._finalizer is not None: + # We've already cleaned up; detach the finalizer so it doesn't + # double-close. + self._finalizer.detach() + self._finalizer = None class AsyncClientCredentials(_ExchangeCallableBase): """Asynchronous twin of :class:`ClientCredentials`. 
- The synchronous wrapper (:meth:`__call__`) runs the async exchange via - :func:`asyncio.run` when invoked from a non-async context, so it can still - be plugged into the SDK's sync-only ``api_key_auth`` callable hook. When - already inside a running loop, it uses that loop's executor to avoid - deadlocking. + Async callers should ``await acquire()`` for a non-blocking exchange. + + The synchronous :meth:`__call__` exists so ``AsyncClientCredentials`` can + still be plugged into the SDK's sync-only ``api_key_auth`` callable hook. + Note that calling :meth:`__call__` from inside a running event loop blocks + that loop while the exchange runs on a worker thread - prefer + :meth:`acquire` in async-native code. """ def __init__( @@ -178,7 +251,29 @@ def __init__( self._client_secret = client_secret self._http_client = http_client self._owns_http_client = http_client is None - self._async_lock = asyncio.Lock() + self._http_client_init_lock = threading.Lock() + # The async lock is created lazily inside a coroutine so it binds to + # the *currently running* event loop. We refresh it whenever the + # running loop changes (e.g. a fresh ``asyncio.run()`` invocation + # after the cache lapses). A dedicated init lock guards the lazy + # field assignment so it stays decoupled from ``self._lock`` (the + # sync-entry coalescing lock) and can't deadlock with it. 
+ self._async_lock: Optional[asyncio.Lock] = None + self._async_lock_loop: Optional[asyncio.AbstractEventLoop] = None + self._async_lock_init_lock = threading.Lock() + self._finalizer: Optional[weakref.finalize] = None + if self._owns_http_client: + owner_ref = weakref.ref(self) + + def _finalize() -> None: + owner = owner_ref() + if owner is None: + return + # pylint: disable=protected-access + _close_async_httpx_client(owner._http_client) + owner._http_client = None + + self._finalizer = weakref.finalize(self, _finalize) def _build_request_body(self) -> Dict[str, Any]: return { @@ -187,9 +282,33 @@ def _build_request_body(self) -> Dict[str, Any]: } def _get_http_client(self) -> httpx.AsyncClient: - if self._http_client is None: - self._http_client = httpx.AsyncClient(timeout=self._request_timeout_seconds) - return self._http_client + """Return the lazily-initialized private ``httpx.AsyncClient``. + + Atomic across threads: a dedicated init lock guarantees only a single + private client is ever created. + """ + if self._http_client is not None: + return self._http_client + with self._http_client_init_lock: + if self._http_client is None: + self._http_client = httpx.AsyncClient( + timeout=self._request_timeout_seconds + ) + return self._http_client + + def _get_async_lock(self) -> asyncio.Lock: + """Return an :class:`asyncio.Lock` bound to the *currently running* loop. + + A dedicated threading lock guards the lazy field assignment so + multiple OS threads concurrently driving their own event loops + can't race to replace the lock. + """ + loop = asyncio.get_running_loop() + with self._async_lock_init_lock: + if self._async_lock is None or self._async_lock_loop is not loop: + self._async_lock = asyncio.Lock() + self._async_lock_loop = loop + return self._async_lock async def acquire(self) -> str: """Async variant of ``__call__``. 
Returns a valid JWT.""" @@ -198,7 +317,7 @@ async def acquire(self) -> str: if cached is not None: return cached - async with self._async_lock: + async with self._get_async_lock(): now = time.monotonic() cached = self._cached_token_if_fresh(now) if cached is not None: @@ -249,27 +368,59 @@ async def _exchange(self) -> str: def __call__(self) -> str: """Sync entry point so the SDK's ``api_key_auth`` callable hook works. - When invoked from inside a running event loop (the usual case for - async SDK methods), the exchange runs in the loop's default executor - so we don't reenter :func:`asyncio.run`. Otherwise we spin up a - temporary loop via :func:`asyncio.run`. + Async-native code should ``await acquire()`` instead - it does not + block the caller's event loop and does not pay the cost of spinning + up a new event loop on a worker thread. + + Concurrency notes: + + * Outside a running loop, ``asyncio.run(self.acquire())`` drives a + fresh event loop. The threading lock coalesces concurrent OS + threads onto a single exchange so we don't fire N HTTP calls. + * Inside a running loop, we offload to a worker thread that runs a + private event loop. ``future.result()`` then blocks the caller's + loop until the exchange completes. This is unavoidable while we + have to return a sync value to Speakeasy's security factory; if + you need a non-blocking path, await :meth:`acquire` directly. """ + now = time.monotonic() + cached = self._cached_token_if_fresh(now) + if cached is not None: + return cached + try: asyncio.get_running_loop() + inside_running_loop = True except RuntimeError: - return asyncio.run(self.acquire()) + inside_running_loop = False - # Inside a running loop - offload to a worker thread that drives its - # own event loop so we don't block the caller's loop on httpx IO. - def _run_in_new_loop() -> str: - return asyncio.run(self.acquire()) + # Coalesce concurrent OS threads onto a single in-flight exchange. 
+ # ``self._lock`` is non-reentrant; ``acquire()`` and helpers that run + # under it use dedicated init locks (``_http_client_init_lock`` and + # ``_async_lock_init_lock``) to avoid re-entering this lock. + with self._lock: + now = time.monotonic() + cached = self._cached_token_if_fresh(now) + if cached is not None: + return cached - with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: - future = pool.submit(_run_in_new_loop) - return future.result() + if not inside_running_loop: + return asyncio.run(self.acquire()) + + # Inside a running loop. ``future.result()`` blocks the caller's + # loop while the worker thread runs the exchange (see docstring). + def _run_in_new_loop() -> str: + return asyncio.run(self.acquire()) + + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: + future = pool.submit(_run_in_new_loop) + return future.result() async def aclose(self) -> None: """Close the private HTTP client, if one was created internally.""" if self._owns_http_client and self._http_client is not None: await self._http_client.aclose() self._http_client = None + if self._finalizer is not None: + self._finalizer.detach() + self._finalizer = None diff --git a/src/unstructured_client/auth/legacy_api_key.py b/src/unstructured_client/auth/legacy_api_key.py index a15b7235..098cbec9 100644 --- a/src/unstructured_client/auth/legacy_api_key.py +++ b/src/unstructured_client/auth/legacy_api_key.py @@ -8,7 +8,9 @@ from __future__ import annotations -from typing import Any, Dict +from typing import Any, Dict, Optional + +import httpx from .client_credentials import AsyncClientCredentials, ClientCredentials @@ -21,6 +23,9 @@ class LegacyKeyExchange(ClientCredentials): refresh, and retry behavior are identical to :class:`~unstructured_client.auth.ClientCredentials`. + The legacy key is stored once on the inherited ``_client_secret`` slot; + only the request body shape differs from the parent class. + .. 
deprecated:: Prefer :class:`~unstructured_client.auth.ClientCredentials` with a ``uns_sk_...`` client secret. ``LegacyKeyExchange`` exists only to @@ -36,7 +41,7 @@ def __init__( refresh_buffer_seconds: int = 60, request_timeout_seconds: float = 30.0, max_retries: int = 3, - http_client=None, + http_client: Optional[httpx.Client] = None, ) -> None: if not api_key: raise ValueError("api_key must be a non-empty string") @@ -48,10 +53,9 @@ def __init__( max_retries=max_retries, http_client=http_client, ) - self._api_key = api_key def _build_request_body(self) -> Dict[str, Any]: - return {"grant_type": "api_key", "api_key": self._api_key} + return {"grant_type": "api_key", "api_key": self._client_secret} class AsyncLegacyKeyExchange(AsyncClientCredentials): @@ -70,7 +74,7 @@ def __init__( refresh_buffer_seconds: int = 60, request_timeout_seconds: float = 30.0, max_retries: int = 3, - http_client=None, + http_client: Optional[httpx.AsyncClient] = None, ) -> None: if not api_key: raise ValueError("api_key must be a non-empty string") @@ -82,7 +86,6 @@ def __init__( max_retries=max_retries, http_client=http_client, ) - self._api_key = api_key def _build_request_body(self) -> Dict[str, Any]: - return {"grant_type": "api_key", "api_key": self._api_key} + return {"grant_type": "api_key", "api_key": self._client_secret}