|
5 | 5 | import threading |
6 | 6 | from typing import Any |
7 | 7 | from typing import AsyncGenerator |
| 8 | +from typing import Dict |
8 | 9 | from typing import Generator |
9 | 10 | from typing import List |
10 | 11 | from typing import Optional |
| 12 | +from typing import cast |
11 | 13 |
|
12 | 14 | from aiohttp import FormData |
13 | 15 | from aiohttp.multipart import MultipartWriter |
14 | 16 | from aiohttp.test_utils import TestClient |
15 | 17 | import pytest |
| 18 | +import yaml |
16 | 19 |
|
17 | 20 |
|
18 | 21 | class DummyHandler(BaseHTTPRequestHandler): |
@@ -79,13 +82,24 @@ def vcr_provider_map(dummy_server: HTTPServer) -> Generator[str, None, None]: |
79 | 82 | yield provider_map |
80 | 83 |
|
81 | 84 |
|
| 85 | +@pytest.fixture |
| 86 | +def vcr_ignore_headers() -> Generator[str, None, None]: |
| 87 | + yield "foo-bar,user-super-secret-api-key" |
| 88 | + |
| 89 | + |
82 | 90 | @pytest.fixture |
83 | 91 | async def vcr_test_name(agent: TestClient[Any, Any]) -> AsyncGenerator[None, None]: |
84 | 92 | await agent.post("/vcr/test/start", json={"test_name": "test_name_prefix"}) |
85 | 93 | yield |
86 | 94 | await agent.post("/vcr/test/stop") |
87 | 95 |
|
88 | 96 |
|
| 97 | +def get_recorded_request_from_yaml(file_path: str) -> Dict[str, Any]: |
| 98 | + with open(file_path, "r") as file: |
| 99 | + content = yaml.load(file, Loader=yaml.UnsafeLoader) |
| 100 | + return cast(Dict[str, Any], content["interactions"][0]) |
| 101 | + |
| 102 | + |
89 | 103 | async def test_vcr_proxy_make_cassette(agent: TestClient[Any, Any], vcr_cassettes_directory: str) -> None: |
90 | 104 | resp = await agent.post("/vcr/custom/serve", json={"foo": "bar"}) |
91 | 105 |
|
@@ -177,3 +191,32 @@ async def test_vcr_proxy_with_multipart_form_data(agent: TestClient[Any, Any], v |
177 | 191 |
|
178 | 192 | cassette_files = get_cassettes_for_provider("custom", vcr_cassettes_directory) |
179 | 193 | assert len(cassette_files) == 1 |
| 194 | + |
| 195 | + |
| 196 | +async def test_vcr_proxy_does_not_record_ignored_headers( |
| 197 | + agent: TestClient[Any, Any], vcr_cassettes_directory: str |
| 198 | +) -> None: |
| 199 | + resp = await agent.post( |
| 200 | + "/vcr/custom/serve", |
| 201 | + json={"foo": "bar"}, |
| 202 | + headers={ |
| 203 | + "User-Super-Secret-Api-Key": "secret", |
| 204 | + "Foo-Bar": "foo", |
| 205 | + "Authorization": "test", |
| 206 | + "Please-Record-Header": "test", |
| 207 | + }, |
| 208 | + ) |
| 209 | + |
| 210 | + assert resp.status == 200 |
| 211 | + assert await resp.text() == "OK" |
| 212 | + |
| 213 | + cassette_files = get_cassettes_for_provider("custom", vcr_cassettes_directory) |
| 214 | + assert len(cassette_files) == 1 |
| 215 | + |
| 216 | + cassette_file = cassette_files[0] |
| 217 | + recorded_request = get_recorded_request_from_yaml(os.path.join(vcr_cassettes_directory, "custom", cassette_file)) |
| 218 | + |
| 219 | + assert recorded_request["request"]["headers"]["Please-Record-Header"] == ["test"] |
| 220 | + assert "User-Super-Secret-Api-Key" not in recorded_request["request"]["headers"] |
| 221 | + assert "Foo-Bar" not in recorded_request["request"]["headers"] |
| 222 | + assert "Authorization" not in recorded_request["request"]["headers"] |
0 commit comments