@@ -1167,6 +1167,52 @@ async def test_auth_flow_with_no_tokens(self, oauth_provider: OAuthClientProvide
11671167 assert oauth_provider .context .current_tokens .access_token == "new_access_token"
11681168 assert oauth_provider .context .token_expiry_time is not None
11691169
1170+ @pytest .mark .anyio
1171+ async def test_authorization_endpoint_preserves_existing_query_params (
1172+ self , client_metadata : OAuthClientMetadata , mock_storage : MockTokenStorage
1173+ ):
1174+ captured_auth_url : str | None = None
1175+ captured_state : str | None = None
1176+
1177+ async def redirect_handler (url : str ) -> None :
1178+ nonlocal captured_auth_url , captured_state
1179+ captured_auth_url = url
1180+ captured_state = parse_qs (urlparse (url ).query )["state" ][0 ]
1181+
1182+ async def callback_handler () -> tuple [str , str | None ]:
1183+ return "test_auth_code" , captured_state
1184+
1185+ provider = OAuthClientProvider (
1186+ server_url = "https://api.example.com/v1/mcp" ,
1187+ client_metadata = client_metadata ,
1188+ storage = mock_storage ,
1189+ redirect_handler = redirect_handler ,
1190+ callback_handler = callback_handler ,
1191+ )
1192+ provider .context .oauth_metadata = OAuthMetadata (
1193+ issuer = AnyHttpUrl ("https://auth.example.com" ),
1194+ authorization_endpoint = AnyHttpUrl ("https://auth.example.com/authorize?prompt=select_account" ),
1195+ token_endpoint = AnyHttpUrl ("https://auth.example.com/token" ),
1196+ )
1197+ provider .context .client_info = OAuthClientInformationFull (
1198+ client_id = "test_client" ,
1199+ client_secret = "test_secret" ,
1200+ redirect_uris = [AnyUrl ("http://localhost:3030/callback" )],
1201+ )
1202+
1203+ auth_code , code_verifier = await provider ._perform_authorization_code_grant ()
1204+
1205+ assert auth_code == "test_auth_code"
1206+ assert code_verifier
1207+ assert captured_auth_url is not None
1208+ parsed = urlparse (captured_auth_url )
1209+ params = parse_qs (parsed .query )
1210+ assert f"{ parsed .scheme } ://{ parsed .netloc } { parsed .path } " == "https://auth.example.com/authorize"
1211+ assert params ["prompt" ] == ["select_account" ]
1212+ assert params ["response_type" ] == ["code" ]
1213+ assert params ["client_id" ] == ["test_client" ]
1214+ assert params ["redirect_uri" ] == ["http://localhost:3030/callback" ]
1215+
11701216 @pytest .mark .anyio
11711217 async def test_auth_flow_no_unnecessary_retry_after_oauth (
11721218 self , oauth_provider : OAuthClientProvider , mock_storage : MockTokenStorage , valid_tokens : OAuthToken
0 commit comments