Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

91 changes: 66 additions & 25 deletions packages/client/src/client/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ import type {
OAuthClientMetadata,
OAuthMetadata,
OAuthProtectedResourceMetadata,
OAuthTokens
OAuthTokens,
UserAgentProvider
} from '@modelcontextprotocol/core';
import {
checkResourceAllowed,
createUserAgentProvider,
InvalidClientError,
InvalidClientMetadataError,
InvalidGrantError,
Expand Down Expand Up @@ -361,18 +363,23 @@ export async function auth(
scope?: string;
resourceMetadataUrl?: URL;
fetchFn?: FetchLike;
userAgentProvider?: UserAgentProvider;
}
): Promise<AuthResult> {
const optionsWithDefaults = {
...options,
userAgentProvider: options.userAgentProvider ?? createUserAgentProvider()
};
try {
return await authInternal(provider, options);
return await authInternal(provider, optionsWithDefaults);
} catch (error) {
// Handle recoverable error types by invalidating credentials and retrying
if (error instanceof InvalidClientError || error instanceof UnauthorizedClientError) {
await provider.invalidateCredentials?.('all');
return await authInternal(provider, options);
return await authInternal(provider, optionsWithDefaults);
} else if (error instanceof InvalidGrantError) {
await provider.invalidateCredentials?.('tokens');
return await authInternal(provider, options);
return await authInternal(provider, optionsWithDefaults);
}

// Throw otherwise
Expand All @@ -387,20 +394,22 @@ async function authInternal(
authorizationCode,
scope,
resourceMetadataUrl,
fetchFn
fetchFn,
userAgentProvider
}: {
serverUrl: string | URL;
authorizationCode?: string;
scope?: string;
resourceMetadataUrl?: URL;
fetchFn?: FetchLike;
userAgentProvider: UserAgentProvider;
}
): Promise<AuthResult> {
let resourceMetadata: OAuthProtectedResourceMetadata | undefined;
let authorizationServerUrl: string | URL | undefined;

try {
resourceMetadata = await discoverOAuthProtectedResourceMetadata(serverUrl, { resourceMetadataUrl }, fetchFn);
resourceMetadata = await discoverOAuthProtectedResourceMetadata(serverUrl, userAgentProvider, { resourceMetadataUrl }, fetchFn);
if (resourceMetadata.authorization_servers && resourceMetadata.authorization_servers.length > 0) {
authorizationServerUrl = resourceMetadata.authorization_servers[0];
}
Expand All @@ -418,7 +427,7 @@ async function authInternal(

const resource: URL | undefined = await selectResourceURL(serverUrl, provider, resourceMetadata);

const metadata = await discoverAuthorizationServerMetadata(authorizationServerUrl, {
const metadata = await discoverAuthorizationServerMetadata(authorizationServerUrl, userAgentProvider, {
fetchFn
});

Expand Down Expand Up @@ -455,6 +464,7 @@ async function authInternal(
const fullInformation = await registerClient(authorizationServerUrl, {
metadata,
clientMetadata: provider.clientMetadata,
userAgentProvider,
fetchFn
});

Expand All @@ -472,7 +482,8 @@ async function authInternal(
metadata,
resource,
authorizationCode,
fetchFn
fetchFn,
userAgentProvider
});

await provider.saveTokens(tokens);
Expand All @@ -491,7 +502,8 @@ async function authInternal(
refreshToken: tokens.refresh_token,
resource,
addClientAuthentication: provider.addClientAuthentication,
fetchFn
fetchFn,
userAgentProvider
});

await provider.saveTokens(newTokens);
Expand Down Expand Up @@ -661,10 +673,11 @@ export function extractResourceMetadataUrl(res: Response): URL | undefined {
*/
export async function discoverOAuthProtectedResourceMetadata(
serverUrl: string | URL,
userAgentProvider: UserAgentProvider,
opts?: { protocolVersion?: string; resourceMetadataUrl?: string | URL },
fetchFn: FetchLike = fetch
): Promise<OAuthProtectedResourceMetadata> {
const response = await discoverMetadataWithFallback(serverUrl, 'oauth-protected-resource', fetchFn, {
const response = await discoverMetadataWithFallback(serverUrl, 'oauth-protected-resource', userAgentProvider, fetchFn, {
protocolVersion: opts?.protocolVersion,
metadataUrl: opts?.resourceMetadataUrl
});
Expand Down Expand Up @@ -720,9 +733,15 @@ function buildWellKnownPath(
/**
* Tries to discover OAuth metadata at a specific URL
*/
async function tryMetadataDiscovery(url: URL, protocolVersion: string, fetchFn: FetchLike = fetch): Promise<Response | undefined> {
async function tryMetadataDiscovery(
url: URL,
protocolVersion: string,
userAgentProvider: UserAgentProvider,
fetchFn: FetchLike = fetch
): Promise<Response | undefined> {
const headers = {
'MCP-Protocol-Version': protocolVersion
'MCP-Protocol-Version': protocolVersion,
'User-Agent': await userAgentProvider()
};
return await fetchWithCorsRetry(url, headers, fetchFn);
}
Expand All @@ -740,6 +759,7 @@ function shouldAttemptFallback(response: Response | undefined, pathname: string)
async function discoverMetadataWithFallback(
serverUrl: string | URL,
wellKnownType: 'oauth-authorization-server' | 'oauth-protected-resource',
userAgentProvider: UserAgentProvider,
fetchFn: FetchLike,
opts?: { protocolVersion?: string; metadataUrl?: string | URL; metadataServerUrl?: string | URL }
): Promise<Response | undefined> {
Expand All @@ -756,12 +776,12 @@ async function discoverMetadataWithFallback(
url.search = issuer.search;
}

let response = await tryMetadataDiscovery(url, protocolVersion, fetchFn);
let response = await tryMetadataDiscovery(url, protocolVersion, userAgentProvider, fetchFn);

// If path-aware discovery fails with 404 and we're not already at root, try fallback to root discovery
if (!opts?.metadataUrl && shouldAttemptFallback(response, issuer.pathname)) {
const rootUrl = new URL(`/.well-known/${wellKnownType}`, issuer);
response = await tryMetadataDiscovery(rootUrl, protocolVersion, fetchFn);
response = await tryMetadataDiscovery(rootUrl, protocolVersion, userAgentProvider, fetchFn);
}

return response;
Expand All @@ -777,6 +797,7 @@ async function discoverMetadataWithFallback(
*/
export async function discoverOAuthMetadata(
issuer: string | URL,
userAgentProvider: UserAgentProvider,
{
authorizationServerUrl,
protocolVersion
Expand All @@ -797,7 +818,7 @@ export async function discoverOAuthMetadata(
}
protocolVersion ??= LATEST_PROTOCOL_VERSION;

const response = await discoverMetadataWithFallback(authorizationServerUrl, 'oauth-authorization-server', fetchFn, {
const response = await discoverMetadataWithFallback(authorizationServerUrl, 'oauth-authorization-server', userAgentProvider, fetchFn, {
protocolVersion,
metadataServerUrl: authorizationServerUrl
});
Expand Down Expand Up @@ -889,6 +910,7 @@ export function buildDiscoveryUrls(authorizationServerUrl: string | URL): { url:
*/
export async function discoverAuthorizationServerMetadata(
authorizationServerUrl: string | URL,
userAgentProvider: UserAgentProvider,
{
fetchFn = fetch,
protocolVersion = LATEST_PROTOCOL_VERSION
Expand All @@ -899,7 +921,8 @@ export async function discoverAuthorizationServerMetadata(
): Promise<AuthorizationServerMetadata | undefined> {
const headers = {
'MCP-Protocol-Version': protocolVersion,
Accept: 'application/json'
Accept: 'application/json',
'User-Agent': await userAgentProvider()
};

// Get the list of URLs to try
Expand Down Expand Up @@ -1047,14 +1070,16 @@ async function executeTokenRequest(
clientInformation,
addClientAuthentication,
resource,
fetchFn
fetchFn,
userAgentProvider
}: {
metadata?: AuthorizationServerMetadata;
tokenRequestParams: URLSearchParams;
clientInformation?: OAuthClientInformationMixed;
addClientAuthentication?: OAuthClientProvider['addClientAuthentication'];
resource?: URL;
fetchFn?: FetchLike;
userAgentProvider?: UserAgentProvider;
}
): Promise<OAuthTokens> {
const tokenUrl = metadata?.token_endpoint ? new URL(metadata.token_endpoint) : new URL('/token', authorizationServerUrl);
Expand All @@ -1064,6 +1089,10 @@ async function executeTokenRequest(
Accept: 'application/json'
});

if (userAgentProvider) {
headers.set('User-Agent', await userAgentProvider());
}

if (resource) {
tokenRequestParams.set('resource', resource.href);
}
Expand Down Expand Up @@ -1111,7 +1140,8 @@ export async function exchangeAuthorization(
redirectUri,
resource,
addClientAuthentication,
fetchFn
fetchFn,
userAgentProvider
}: {
metadata?: AuthorizationServerMetadata;
clientInformation: OAuthClientInformationMixed;
Expand All @@ -1121,6 +1151,7 @@ export async function exchangeAuthorization(
resource?: URL;
addClientAuthentication?: OAuthClientProvider['addClientAuthentication'];
fetchFn?: FetchLike;
userAgentProvider?: UserAgentProvider;
}
): Promise<OAuthTokens> {
const tokenRequestParams = prepareAuthorizationCodeRequest(authorizationCode, codeVerifier, redirectUri);
Expand All @@ -1131,7 +1162,8 @@ export async function exchangeAuthorization(
clientInformation,
addClientAuthentication,
resource,
fetchFn
fetchFn,
userAgentProvider: userAgentProvider ?? createUserAgentProvider()
});
}

Expand All @@ -1155,14 +1187,16 @@ export async function refreshAuthorization(
refreshToken,
resource,
addClientAuthentication,
fetchFn
fetchFn,
userAgentProvider
}: {
metadata?: AuthorizationServerMetadata;
clientInformation: OAuthClientInformationMixed;
refreshToken: string;
resource?: URL;
addClientAuthentication?: OAuthClientProvider['addClientAuthentication'];
fetchFn?: FetchLike;
userAgentProvider?: UserAgentProvider;
}
): Promise<OAuthTokens> {
const tokenRequestParams = new URLSearchParams({
Expand All @@ -1176,7 +1210,8 @@ export async function refreshAuthorization(
clientInformation,
addClientAuthentication,
resource,
fetchFn
fetchFn,
userAgentProvider: userAgentProvider ?? createUserAgentProvider()
});

// Preserve original refresh token if server didn't return a new one
Expand Down Expand Up @@ -1216,13 +1251,15 @@ export async function fetchToken(
metadata,
resource,
authorizationCode,
fetchFn
fetchFn,
userAgentProvider
}: {
metadata?: AuthorizationServerMetadata;
resource?: URL;
/** Authorization code for the default authorization_code grant flow */
authorizationCode?: string;
fetchFn?: FetchLike;
userAgentProvider?: UserAgentProvider;
} = {}
): Promise<OAuthTokens> {
const scope = provider.clientMetadata.scope;
Expand Down Expand Up @@ -1253,7 +1290,8 @@ export async function fetchToken(
clientInformation: clientInformation ?? undefined,
addClientAuthentication: provider.addClientAuthentication,
resource,
fetchFn
fetchFn,
userAgentProvider
});
}

Expand All @@ -1265,11 +1303,13 @@ export async function registerClient(
{
metadata,
clientMetadata,
fetchFn
fetchFn,
userAgentProvider
}: {
metadata?: AuthorizationServerMetadata;
clientMetadata: OAuthClientMetadata;
fetchFn?: FetchLike;
userAgentProvider: UserAgentProvider;
}
): Promise<OAuthClientInformationFull> {
let registrationUrl: URL;
Expand All @@ -1287,7 +1327,8 @@ export async function registerClient(
const response = await (fetchFn ?? fetch)(registrationUrl, {
method: 'POST',
headers: {
'Content-Type': 'application/json'
'Content-Type': 'application/json',
'User-Agent': await userAgentProvider()
},
body: JSON.stringify(clientMetadata)
});
Expand Down
10 changes: 7 additions & 3 deletions packages/client/src/client/middleware.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import type { FetchLike } from '@modelcontextprotocol/core';
import type { FetchLike, UserAgentProvider } from '@modelcontextprotocol/core';
import { createUserAgentProvider } from '@modelcontextprotocol/core';

import type { OAuthClientProvider } from './auth.js';
import { auth, extractWWWAuthenticateParams, UnauthorizedError } from './auth.js';
Expand Down Expand Up @@ -33,11 +34,13 @@ export type Middleware = (next: FetchLike) => FetchLike;
*
* @param provider - OAuth client provider for authentication
* @param baseUrl - Base URL for OAuth server discovery (defaults to request URL domain)
* @param userAgentProvider - User agent provider for the connection.
* @returns A fetch middleware function
*/
export const withOAuth =
(provider: OAuthClientProvider, baseUrl?: string | URL): Middleware =>
(provider: OAuthClientProvider, baseUrl?: string | URL, userAgentProvider?: UserAgentProvider): Middleware =>
next => {
const uaProvider = userAgentProvider ?? createUserAgentProvider();
return async (input, init) => {
const makeRequest = async (): Promise<Response> => {
const headers = new Headers(init?.headers);
Expand Down Expand Up @@ -65,7 +68,8 @@ export const withOAuth =
serverUrl,
resourceMetadataUrl,
scope,
fetchFn: next
fetchFn: next,
userAgentProvider: uaProvider
});

if (result === 'REDIRECT') {
Expand Down
Loading
Loading