Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/oxia/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,16 @@
ComparisonType,
)
from oxia.defs import (
Authentication,
NotificationType,
Notification,
SequenceUpdates,
)

__all__ = [
'ex',
'auth',
'Authentication',
'ComparisonType',
'Client',
'Version',
Expand Down
54 changes: 54 additions & 0 deletions src/oxia/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright 2025 The Oxia Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Authentication implementations for the Oxia client.

Usage::

from oxia.auth import TokenAuthentication

client = oxia.Client('oxia.example.com:6648', tls=True,
authentication=TokenAuthentication('my-token'))
"""

from typing import Callable, Union

from oxia.defs import Authentication


class TokenAuthentication(Authentication):
"""Bearer-token authentication.

Emits an ``Authorization: Bearer <token>`` metadata header on every
RPC. The token can be static or supplied by a callable for
refresh-on-demand semantics.
"""

_AUTHORIZATION_KEY = "authorization"
_BEARER_PREFIX = "Bearer "

def __init__(self, token: Union[str, Callable[[], str]]):
"""
@param token: Either a static token string, or a zero-argument
callable that returns the current token. The callable is
invoked on every RPC.
"""
if callable(token):
self._token_supplier = token
else:
self._token_supplier = lambda: token

def generate_credentials(self) -> dict[str, str]:
return {self._AUTHORIZATION_KEY:
self._BEARER_PREFIX + self._token_supplier()}
9 changes: 8 additions & 1 deletion src/oxia/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def __init__(self, service_address: str,
session_timeout_ms: int = 30_000,
client_identifier: str = None,
request_timeout_ms: int = 30_000,
authentication: oxia.defs.Authentication = None,
):
"""Create a new Oxia client.

Expand All @@ -167,9 +168,15 @@ def __init__(self, service_address: str,
sequence updates, shard assignments) are not bounded.
Default is 30 000 ms. A ``grpc.RpcError`` with
``StatusCode.DEADLINE_EXCEEDED`` is raised on timeout.
@param authentication: Optional L{oxia.defs.Authentication}
implementation. If provided, its credentials are attached
as gRPC metadata on every outgoing RPC. See
L{oxia.auth.TokenAuthentication} for the bearer-token
implementation.
"""
self._closed = False
self._connections = ConnectionPool(request_timeout_ms=request_timeout_ms)
self._connections = ConnectionPool(request_timeout_ms=request_timeout_ms,
authentication=authentication)
self._service_discovery = ServiceDiscovery(service_address, self._connections, namespace)
self._session_manager = SessionManager(self._service_discovery, session_timeout_ms, client_identifier)

Expand Down
21 changes: 20 additions & 1 deletion src/oxia/defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC
from abc import ABC, abstractmethod
from enum import Enum
from typing import Iterator

Expand Down Expand Up @@ -62,3 +62,22 @@ class SequenceUpdates(Iterator[str], ABC):
def close(self):
"""Close the subscription and release resources."""
pass


class Authentication(ABC):
"""Pluggable authentication for Oxia client RPCs.

Implementations return a dict of metadata entries that will be
attached to every outgoing RPC (e.g. ``{"authorization": "Bearer <token>"}``).
``generate_credentials`` is called on every RPC, so dynamic
token-refresh schemes are supported by returning a fresh value each call.
"""

@abstractmethod
def generate_credentials(self) -> dict[str, str]:
"""Return the gRPC metadata to attach to each outgoing RPC.

Keys are interpreted as gRPC metadata header names (will be
lowercased by the transport). Values are arbitrary strings.
"""
...
11 changes: 9 additions & 2 deletions src/oxia/internal/connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,26 @@
from typing import Optional

import grpc
from oxia.internal.interceptors import RequestTimeoutInterceptor
from oxia.defs import Authentication
from oxia.internal.interceptors import (
AuthenticationInterceptor,
RequestTimeoutInterceptor,
)
from oxia.internal.proto.io.streamnative.oxia.proto import OxiaClientStub


class ConnectionPool:

def __init__(self, request_timeout_ms: Optional[int] = None):
def __init__(self, request_timeout_ms: Optional[int] = None,
authentication: Optional[Authentication] = None):
self._lock = threading.Lock()
self.connections = {}
self._interceptors = []
if request_timeout_ms is not None:
self._interceptors.append(
RequestTimeoutInterceptor(request_timeout_ms / 1000.0))
if authentication is not None:
self._interceptors.append(AuthenticationInterceptor(authentication))

def get(self, address) -> OxiaClientStub:
with self._lock:
Expand Down
28 changes: 28 additions & 0 deletions src/oxia/internal/interceptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

import grpc

from oxia.defs import Authentication

# Server streaming RPCs that are long-lived and should NOT carry a
# request timeout (they stay open for the lifetime of the client
# subscription).
Expand Down Expand Up @@ -50,3 +52,29 @@ def intercept_unary_unary(self, continuation, client_call_details, request):

def intercept_unary_stream(self, continuation, client_call_details, request):
return continuation(self._with_timeout(client_call_details), request)


class AuthenticationInterceptor(
grpc.UnaryUnaryClientInterceptor,
grpc.UnaryStreamClientInterceptor,
):
"""Attach the credentials produced by an L{oxia.defs.Authentication}
implementation to every outgoing RPC as gRPC metadata headers."""

def __init__(self, authentication: Authentication):
self._authentication = authentication

def _with_auth(self, client_call_details):
credentials = self._authentication.generate_credentials()
if not credentials:
return client_call_details
metadata = list(client_call_details.metadata or [])
for k, v in credentials.items():
metadata.append((k.lower(), v))
return client_call_details._replace(metadata=metadata)

def intercept_unary_unary(self, continuation, client_call_details, request):
return continuation(self._with_auth(client_call_details), request)

def intercept_unary_stream(self, continuation, client_call_details, request):
return continuation(self._with_auth(client_call_details), request)
143 changes: 143 additions & 0 deletions tests/auth_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
# Copyright 2025 The Oxia Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for the authentication subsystem."""

from collections import namedtuple

from oxia.auth import TokenAuthentication
from oxia.defs import Authentication
from oxia.internal.interceptors import AuthenticationInterceptor


CallDetails = namedtuple("CallDetails",
["method", "timeout", "metadata", "credentials",
"wait_for_ready", "compression"])


def _details(method="/io.streamnative.oxia.proto.OxiaClient/Write",
metadata=None):
return CallDetails(method, None, metadata, None, None, None)


class _Recorder:
def __init__(self):
self.last_details = None

def __call__(self, details, request):
self.last_details = details
return None


# ---------------------------------------------------------------------------
# TokenAuthentication
# ---------------------------------------------------------------------------

def test_token_authentication_static_token():
auth = TokenAuthentication("my-secret")
creds = auth.generate_credentials()
assert creds == {"authorization": "Bearer my-secret"}


def test_token_authentication_dynamic_supplier():
counter = [0]

def supplier():
counter[0] += 1
return f"token-{counter[0]}"

auth = TokenAuthentication(supplier)
assert auth.generate_credentials() == {"authorization": "Bearer token-1"}
assert auth.generate_credentials() == {"authorization": "Bearer token-2"}


def test_token_authentication_is_an_authentication():
assert isinstance(TokenAuthentication("x"), Authentication)


# ---------------------------------------------------------------------------
# AuthenticationInterceptor
# ---------------------------------------------------------------------------

def test_interceptor_adds_metadata_on_unary_unary():
auth = TokenAuthentication("tok")
interceptor = AuthenticationInterceptor(auth)
rec = _Recorder()
interceptor.intercept_unary_unary(rec, _details(), object())
assert ("authorization", "Bearer tok") in rec.last_details.metadata


def test_interceptor_adds_metadata_on_unary_stream():
auth = TokenAuthentication("tok")
interceptor = AuthenticationInterceptor(auth)
rec = _Recorder()
interceptor.intercept_unary_stream(rec, _details(), object())
assert ("authorization", "Bearer tok") in rec.last_details.metadata


def test_interceptor_preserves_existing_metadata():
auth = TokenAuthentication("tok")
interceptor = AuthenticationInterceptor(auth)
rec = _Recorder()
existing = [("x-trace-id", "abc123")]
interceptor.intercept_unary_unary(rec, _details(metadata=existing), object())
assert ("x-trace-id", "abc123") in rec.last_details.metadata
assert ("authorization", "Bearer tok") in rec.last_details.metadata


def test_interceptor_lowercases_header_keys():
"""gRPC metadata keys must be lowercase."""
class UppercaseAuth(Authentication):
def generate_credentials(self):
return {"Authorization": "Bearer x", "X-Custom-Header": "v"}

rec = _Recorder()
AuthenticationInterceptor(UppercaseAuth()).intercept_unary_unary(
rec, _details(), object())
keys = [k for k, _v in rec.last_details.metadata]
assert all(k == k.lower() for k in keys), \
f"metadata keys must be lowercase, got {keys}"


def test_interceptor_skips_empty_credentials():
class NoAuth(Authentication):
def generate_credentials(self):
return {}

rec = _Recorder()
original = _details()
AuthenticationInterceptor(NoAuth()).intercept_unary_unary(rec, original, object())
# No change to details
assert rec.last_details is original


# ---------------------------------------------------------------------------
# ConnectionPool wiring
# ---------------------------------------------------------------------------

def test_connection_pool_installs_auth_interceptor():
from oxia.internal.connection_pool import ConnectionPool

auth = TokenAuthentication("tok")
pool = ConnectionPool(authentication=auth)
assert any(isinstance(i, AuthenticationInterceptor)
for i in pool._interceptors)


def test_connection_pool_no_auth_interceptor_by_default():
from oxia.internal.connection_pool import ConnectionPool

pool = ConnectionPool()
assert not any(isinstance(i, AuthenticationInterceptor)
for i in pool._interceptors)
Loading