Skip to content

Commit 7c52958

Browse files
committed
use aiohttp_server instead
1 parent 4afd7c8 commit 7c52958

File tree

2 files changed

+27
-48
lines changed

2 files changed

+27
-48
lines changed

ddapm_test_agent/vcr_proxy.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import hashlib
23
import json
34
import logging
@@ -214,8 +215,11 @@ async def proxy_request(
214215
auth = AWS4Auth(aws_access_key, AWS_SECRET_ACCESS_KEY, AWS_REGION, AWS_SERVICES[provider])
215216
request_kwargs["auth"] = auth
216217

217-
with get_vcr(provider, vcr_cassettes_directory, vcr_ignore_headers).use_cassette(cassette_file_name):
218-
provider_response = requests.request(**request_kwargs)
218+
def _make_request():
219+
with get_vcr(provider, vcr_cassettes_directory, vcr_ignore_headers).use_cassette(cassette_file_name):
220+
return requests.request(**request_kwargs)
221+
222+
provider_response = await asyncio.to_thread(_make_request)
219223

220224
# Extract content type without charset
221225
content_type = provider_response.headers.get("content-type", "")

tests/test_vcr_proxy.py

Lines changed: 21 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,31 @@
1-
from http.server import BaseHTTPRequestHandler
2-
from http.server import HTTPServer
31
import os
4-
import socket
5-
import threading
62
from typing import Any
73
from typing import AsyncGenerator
4+
from typing import Awaitable
5+
from typing import Callable
86
from typing import Dict
97
from typing import Generator
108
from typing import List
119
from typing import Optional
1210
from typing import cast
1311

1412
from aiohttp import FormData
13+
from aiohttp import web
1514
from aiohttp.multipart import MultipartWriter
1615
from aiohttp.test_utils import TestClient
16+
from aiohttp.test_utils import TestServer
1717
import pytest
1818
import yaml
1919

2020

21-
class DummyHandler(BaseHTTPRequestHandler):
22-
def do_POST(self):
23-
content_length = int(self.headers.get("Content-Length", 0))
24-
if content_length > 0:
25-
self.rfile.read(content_length)
21+
async def serve_handler(request: web.Request) -> web.Response:
22+
response_headers = {}
2623

27-
if self.path == "/serve":
28-
self.send_response(200)
29-
self.send_header("Content-type", "text/plain")
30-
self.send_header("Connection", "close") # Ensure connection is closed cleanly
31-
pass_through_value = self.headers.get("Pass-Through-Header-Value")
32-
if pass_through_value:
33-
self.send_header("Pass-Through-Header-Value", pass_through_value)
34-
self.end_headers()
35-
self.wfile.write(b"OK")
36-
else:
37-
self.send_response(404)
38-
self.send_header("Connection", "close")
39-
self.end_headers()
24+
pass_through_value = request.headers.get("Pass-Through-Header-Value")
25+
if pass_through_value:
26+
response_headers["Pass-Through-Header-Value"] = pass_through_value
27+
28+
return web.Response(status=200, text="OK", headers=response_headers)
4029

4130

4231
class CustomFormData(FormData):
@@ -51,34 +40,20 @@ def get_cassettes_for_provider(provider: str, vcr_cassettes_directory: str) -> L
5140
return [f for f in os.listdir(custom_dir) if os.path.isfile(os.path.join(custom_dir, f))]
5241

5342

54-
@pytest.fixture(scope="session")
55-
def dummy_server() -> Generator[HTTPServer, None, None]:
56-
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
57-
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
58-
sock.bind(("", 0))
59-
port = sock.getsockname()[1]
60-
sock.close()
61-
62-
server = HTTPServer(("127.0.0.1", port), DummyHandler)
63-
server.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
64-
65-
server_thread = threading.Thread(target=server.serve_forever)
66-
server_thread.daemon = True
67-
server_thread.start()
68-
69-
yield server
43+
@pytest.fixture
44+
async def dummy_server(aiohttp_server: Callable[[web.Application], Awaitable[TestServer]]) -> TestServer:
45+
app = web.Application()
46+
app.router.add_post("/serve", serve_handler)
7047

71-
server.shutdown()
72-
server.server_close()
73-
server_thread.join(timeout=2)
48+
server = await aiohttp_server(app)
49+
return server
7450

7551

7652
@pytest.fixture
77-
def vcr_provider_map(dummy_server: HTTPServer) -> Generator[str, None, None]:
78-
host, port = dummy_server.server_address
79-
# Ensure host is a string, not bytes
80-
host_str = host.decode() if isinstance(host, bytes) else host
81-
provider_map = f"custom=http://{host_str}:{port}"
53+
def vcr_provider_map(dummy_server: TestServer) -> Generator[str, None, None]:
54+
host = dummy_server.host
55+
port = dummy_server.port
56+
provider_map = f"custom=http://{host}:{port}"
8257
yield provider_map
8358

8459

0 commit comments

Comments
 (0)