1- from http .server import BaseHTTPRequestHandler
2- from http .server import HTTPServer
31import os
4- import socket
5- import threading
62from typing import Any
73from typing import AsyncGenerator
4+ from typing import Awaitable
5+ from typing import Callable
86from typing import Dict
97from typing import Generator
108from typing import List
119from typing import Optional
1210from typing import cast
1311
1412from aiohttp import FormData
13+ from aiohttp import web
1514from aiohttp .multipart import MultipartWriter
1615from aiohttp .test_utils import TestClient
16+ from aiohttp .test_utils import TestServer
1717import pytest
1818import yaml
1919
2020
21- class DummyHandler (BaseHTTPRequestHandler ):
22- def do_POST (self ):
23- content_length = int (self .headers .get ("Content-Length" , 0 ))
24- if content_length > 0 :
25- self .rfile .read (content_length )
21+ async def serve_handler (request : web .Request ) -> web .Response :
22+ response_headers = {}
2623
27- if self .path == "/serve" :
28- self .send_response (200 )
29- self .send_header ("Content-type" , "text/plain" )
30- self .send_header ("Connection" , "close" ) # Ensure connection is closed cleanly
31- pass_through_value = self .headers .get ("Pass-Through-Header-Value" )
32- if pass_through_value :
33- self .send_header ("Pass-Through-Header-Value" , pass_through_value )
34- self .end_headers ()
35- self .wfile .write (b"OK" )
36- else :
37- self .send_response (404 )
38- self .send_header ("Connection" , "close" )
39- self .end_headers ()
24+ pass_through_value = request .headers .get ("Pass-Through-Header-Value" )
25+ if pass_through_value :
26+ response_headers ["Pass-Through-Header-Value" ] = pass_through_value
27+
28+ return web .Response (status = 200 , text = "OK" , headers = response_headers )
4029
4130
4231class CustomFormData (FormData ):
@@ -51,34 +40,20 @@ def get_cassettes_for_provider(provider: str, vcr_cassettes_directory: str) -> L
5140 return [f for f in os .listdir (custom_dir ) if os .path .isfile (os .path .join (custom_dir , f ))]
5241
5342
54- @pytest .fixture (scope = "session" )
55- def dummy_server () -> Generator [HTTPServer , None , None ]:
56- sock = socket .socket (socket .AF_INET , socket .SOCK_STREAM )
57- sock .setsockopt (socket .SOL_SOCKET , socket .SO_REUSEADDR , 1 )
58- sock .bind (("" , 0 ))
59- port = sock .getsockname ()[1 ]
60- sock .close ()
61-
62- server = HTTPServer (("127.0.0.1" , port ), DummyHandler )
63- server .socket .setsockopt (socket .SOL_SOCKET , socket .SO_REUSEADDR , 1 )
64-
65- server_thread = threading .Thread (target = server .serve_forever )
66- server_thread .daemon = True
67- server_thread .start ()
68-
69- yield server
43+ @pytest .fixture
44+ async def dummy_server (aiohttp_server : Callable [[web .Application ], Awaitable [TestServer ]]) -> TestServer :
45+ app = web .Application ()
46+ app .router .add_post ("/serve" , serve_handler )
7047
71- server .shutdown ()
72- server .server_close ()
73- server_thread .join (timeout = 2 )
48+ server = await aiohttp_server (app )
49+ return server
7450
7551
7652@pytest .fixture
77- def vcr_provider_map (dummy_server : HTTPServer ) -> Generator [str , None , None ]:
78- host , port = dummy_server .server_address
79- # Ensure host is a string, not bytes
80- host_str = host .decode () if isinstance (host , bytes ) else host
81- provider_map = f"custom=http://{ host_str } :{ port } "
53+ def vcr_provider_map (dummy_server : TestServer ) -> Generator [str , None , None ]:
54+ host = dummy_server .host
55+ port = dummy_server .port
56+ provider_map = f"custom=http://{ host } :{ port } "
8257 yield provider_map
8358
8459
0 commit comments