Skip to content

Commit e617b1f

Browse files
committed
fix: preserve oauth authorization endpoint query
1 parent 19fe9fa commit e617b1f

2 files changed

Lines changed: 48 additions & 2 deletions

File tree

src/mcp/client/auth/oauth2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from collections.abc import AsyncGenerator, Awaitable, Callable
1313
from dataclasses import dataclass, field
1414
from typing import Any, Protocol
15-
from urllib.parse import quote, urlencode, urljoin, urlparse
15+
from urllib.parse import quote, urljoin, urlparse
1616

1717
import anyio
1818
import httpx
@@ -353,7 +353,7 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]:
353353
if "offline_access" in self.context.client_metadata.scope.split():
354354
auth_params["prompt"] = "consent"
355355

356-
authorization_url = f"{auth_endpoint}?{urlencode(auth_params)}"
356+
authorization_url = str(httpx.URL(auth_endpoint).copy_merge_params(auth_params))
357357
await self.context.redirect_handler(authorization_url)
358358

359359
# Wait for callback

tests/client/test_auth.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1167,6 +1167,52 @@ async def test_auth_flow_with_no_tokens(self, oauth_provider: OAuthClientProvide
11671167
assert oauth_provider.context.current_tokens.access_token == "new_access_token"
11681168
assert oauth_provider.context.token_expiry_time is not None
11691169

1170+
@pytest.mark.anyio
1171+
async def test_authorization_endpoint_preserves_existing_query_params(
1172+
self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
1173+
):
1174+
captured_auth_url: str | None = None
1175+
captured_state: str | None = None
1176+
1177+
async def redirect_handler(url: str) -> None:
1178+
nonlocal captured_auth_url, captured_state
1179+
captured_auth_url = url
1180+
captured_state = parse_qs(urlparse(url).query)["state"][0]
1181+
1182+
async def callback_handler() -> tuple[str, str | None]:
1183+
return "test_auth_code", captured_state
1184+
1185+
provider = OAuthClientProvider(
1186+
server_url="https://api.example.com/v1/mcp",
1187+
client_metadata=client_metadata,
1188+
storage=mock_storage,
1189+
redirect_handler=redirect_handler,
1190+
callback_handler=callback_handler,
1191+
)
1192+
provider.context.oauth_metadata = OAuthMetadata(
1193+
issuer=AnyHttpUrl("https://auth.example.com"),
1194+
authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize?prompt=select_account"),
1195+
token_endpoint=AnyHttpUrl("https://auth.example.com/token"),
1196+
)
1197+
provider.context.client_info = OAuthClientInformationFull(
1198+
client_id="test_client",
1199+
client_secret="test_secret",
1200+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
1201+
)
1202+
1203+
auth_code, code_verifier = await provider._perform_authorization_code_grant()
1204+
1205+
assert auth_code == "test_auth_code"
1206+
assert code_verifier
1207+
assert captured_auth_url is not None
1208+
parsed = urlparse(captured_auth_url)
1209+
params = parse_qs(parsed.query)
1210+
assert f"{parsed.scheme}://{parsed.netloc}{parsed.path}" == "https://auth.example.com/authorize"
1211+
assert params["prompt"] == ["select_account"]
1212+
assert params["response_type"] == ["code"]
1213+
assert params["client_id"] == ["test_client"]
1214+
assert params["redirect_uri"] == ["http://localhost:3030/callback"]
1215+
11701216
@pytest.mark.anyio
11711217
async def test_auth_flow_no_unnecessary_retry_after_oauth(
11721218
self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage, valid_tokens: OAuthToken

0 commit comments

Comments
 (0)