Skip to content

Conversation

@keurcien
Copy link
Contributor

@keurcien keurcien commented Dec 12, 2025

Partial fix for #1318.

Motivation and Context

When storing OAuth tokens in a persistent storage, the OAuthClientProvider context would never refresh tokens, because it would consider tokens valid (by is_token_valid()) and proceed with the expired access_token until it meets a 401 response from the MCP server, forcing users to start a new OAuth flow.

The proposed change is small and only solve part of the problem. It's mainly inspired by the discussion from #1318 and the approach taken in the FastMCP library: https://github.com/jlowin/fastmcp/blob/main/src/fastmcp/client/auth/oauth.py

This fix works in the case where token.expires_in returned by the TokenStorage is well calculated (which is not the case if we just simply store the OAuthToken as is, cf. below), and when the token endpoint is the default MCP_SERVER_URL/token.

This PR could also be considered a draft, as it raises other questions about how client_metadata (in case where token endpoint is not obvious) and token.expires_at values should be stored.

How Has This Been Tested?

Here's a sample script to test the proposed change (tested with Notion MCP and Linear MCP). It uses a StoredToken class to hold an extra expires_at value, to return a correct token.expires_in value in get_tokens (cf. jlowin/fastmcp@f73b7b5)

Steps:

  • Run the script once to retrieve the OAuth token.
  • Rerun the script when the token expires: current implementation would trigger the 401 error. New implementation would silently refresh the token.
import asyncio
import json
from pathlib import Path
from urllib.parse import parse_qs, urlparse

import httpx
from datetime import datetime, timezone, timedelta
from pydantic import AnyUrl, BaseModel

from mcp import ClientSession
from mcp.client.auth import OAuthClientProvider, TokenStorage
from mcp.client.streamable_http import streamable_http_client
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken


SERVER_URL = "https://mcp.notion.com/mcp"


class StoredToken(BaseModel):
    token_payload: OAuthToken
    expires_at: datetime | None


class FileTokenStorage(TokenStorage):
    """File-backed token storage."""

    def __init__(self, tokens_file: Path = Path("tokens.json"), client_info_file: Path = Path("client_info.json")):
        self.tokens_file = tokens_file
        self.client_info_file = client_info_file

    async def get_tokens(self) -> OAuthToken | None:
        if not self.tokens_file.exists():
            return None

        try:
            stored_token = StoredToken.model_validate_json(self.tokens_file.read_text())
        except (OSError, json.JSONDecodeError):
            return None

        if stored_token.expires_at is not None:
            now = datetime.now(timezone.utc)

            if stored_token.token_payload.expires_in is not None:
                remaining = stored_token.expires_at - now
                stored_token.token_payload.expires_in = max(0, int(remaining.total_seconds()))
        
        return stored_token.token_payload

    async def set_tokens(self, tokens: OAuthToken) -> None:
        expires_at: datetime | None = None

        if tokens.expires_in is not None:
            expires_at = datetime.now(timezone.utc) + timedelta(seconds=tokens.expires_in)

        stored_token = StoredToken(token_payload=tokens, expires_at=expires_at)

        self.tokens_file.write_text(stored_token.model_dump_json())

    async def get_client_info(self) -> OAuthClientInformationFull | None:
        if not self.client_info_file.exists():
            return None

        try:
            return OAuthClientInformationFull.model_validate_json(self.client_info_file.read_text())
        except (OSError, json.JSONDecodeError):
            return None

    async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
        self.client_info_file.write_text(client_info.model_dump_json())


async def handle_redirect(auth_url: str) -> None:
    print(f"Visit: {auth_url}")


async def handle_callback() -> tuple[str, str | None]:
    callback_url = input("Paste callback URL: ")
    params = parse_qs(urlparse(callback_url).query)
    return params["code"][0], params.get("state", [None])[0]


async def main():
    """Run the OAuth client example."""
    oauth_auth = OAuthClientProvider(
        server_url=SERVER_URL,
        client_metadata=OAuthClientMetadata(
            client_name="Example MCP Client",
            redirect_uris=[AnyUrl("http://localhost:3000/callback")],
            grant_types=["authorization_code", "refresh_token"],
            response_types=["code"],
            scope="user",
        ),
        storage=FileTokenStorage(),
        redirect_handler=handle_redirect,
        callback_handler=handle_callback,
    )

    async with httpx.AsyncClient(auth=oauth_auth, follow_redirects=True) as custom_client:
        async with streamable_http_client(SERVER_URL, http_client=custom_client) as (read, write, _):
            async with ClientSession(read, write) as session:
                await session.initialize()

                tools = await session.list_tools()
                print(f"Available tools: {[tool.name for tool in tools.tools]}")

                resources = await session.list_resources()
                print(f"Available resources: {[r.uri for r in resources.resources]}")


def run():
    asyncio.run(main())


if __name__ == "__main__":
    run()

Breaking Changes

No breaking changes.

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation update

Checklist

  • I have read the MCP Documentation
  • My code follows the repository's style guidelines
  • New and existing tests pass locally
  • I have added appropriate error handling
  • I have added or updated documentation as needed

Additional context

The added tests complement the existing test_token_validity_check as it assumes the token_expiry_time gets set at initialization.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant