diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 8c1d9d6a..7afc1870 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -16,29 +16,60 @@ env: PYTHONUNBUFFERED: "1" FORCE_COLOR: "1" PYTHONIOENCODING: "utf8" - PYTHONDEVMODE: "1" - HATCH_VERBOSE: "1" jobs: - run: + changes: + name: Check for changed files + runs-on: ubuntu-latest + outputs: + source: ${{ steps.filter.outputs.source }} + tests: ${{ steps.filter.outputs.tests }} + steps: + - uses: actions/checkout@v2 + - uses: dorny/paths-filter@v3 + id: filter + with: + filters: | + source: + - 'src/**' + tests: + - 'tests/**' + + run-tests: + needs: changes + if: ${{ needs.changes.outputs.source == 'true' || needs.changes.outputs.tests == 'true' }} name: Python ${{ matrix.python-version }} on ${{ startsWith(matrix.os, 'macos-') && 'macOS' || startsWith(matrix.os, 'windows-') && 'Windows' || 'Linux' }} runs-on: ${{ matrix.os }} strategy: - fail-fast: false + fail-fast: true matrix: os: [ubuntu-latest, windows-latest, macos-latest] python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] - steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - name: Install Hatch - run: pip install hatch + uses: pypa/hatch@install - name: Run tests run: hatch test + + tests-pass: + runs-on: ubuntu-latest + name: All tests passed + if: always() + + needs: + - run-tests + + steps: + - name: Check whether all tests passed + uses: re-actors/alls-green@release/v1 + with: + jobs: ${{ toJSON(needs) }} + allowed-skips: ${{ toJSON(needs) }} diff --git a/.github/workflows/triage.yml b/.github/workflows/triage.yml new file mode 100644 index 00000000..1fdc7b63 --- /dev/null +++ b/.github/workflows/triage.yml @@ -0,0 +1,67 @@ +name: Triage +on: + pull_request: + types: + - "opened" + - "reopened" + - "synchronize" + - "labeled" + - "unlabeled" + +jobs: + changelog_check: + runs-on: ubuntu-latest + name: Check for changelog updates + steps: + - name: "Check if the source directory was changed" + uses: dorny/paths-filter@v3 + id: changes + with: + filters: | + src: + - 'src/**' + + - name: "Check for changelog updates" + if: steps.changes.outputs.src == 'true' + uses: brettcannon/check-for-changed-files@v1 + with: + file-pattern: | + CHANGELOG.md + skip-label: "skip changelog" + failure-message: "Missing a CHANGELOG.md update; please add one or apply the ${skip-label} label to the pull request" + + tests_check: + runs-on: ubuntu-latest + name: Check for updated tests + steps: + - name: "Check if the source directory was changed" + uses: dorny/paths-filter@v3 + id: changes + with: + filters: | + src: + - 'src/**' + + - name: "Check for test updates" + if: steps.changes.outputs.src == 'true' + uses: brettcannon/check-for-changed-files@v1 + with: + file-pattern: | + tests/* + skip-label: "skip tests" + failure-message: "Missing unit tests; please add some or apply the ${skip-label} label to the pull request" + + all_green: + runs-on: ubuntu-latest + name: PR has no missing information + if: always() + + needs: + - changelog_check + - tests_check + + steps: + - name: Check whether jobs passed + uses: re-actors/alls-green@release/v1 + with: + jobs: ${{ toJSON(needs) }} diff --git a/pyproject.toml b/pyproject.toml index bf1c30f3..ba895565 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ classifiers = [ "Programming Language :: Python :: 3.14", "Programming Language :: Python :: Implementation :: CPython", ] -dependencies = ["multidict~=6.5", "loguru~=0.7", "aiofiles~=24.1", "typing_extensions>=4"] +dependencies = ["loguru~=0.7", "aiofiles~=24.1", "typing_extensions>=4"] dynamic = ["version", "license"] [project.optional-dependencies] diff --git a/src/view/cache.py b/src/view/cache.py index f3a3ac2d..d530ad21 100644 --- a/src/view/cache.py +++ b/src/view/cache.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: from collections.abc import Callable - from multidict import CIMultiDict + from view.core.headers import HTTPHeaders from view.core.response import ( Response, @@ -47,7 +47,7 @@ async def __call__( @dataclass(slots=True, frozen=True) class _CachedResponse: body: bytes - headers: CIMultiDict[str] + headers: HTTPHeaders status: int last_reset: float diff --git a/src/view/core/app.py b/src/view/core/app.py index efe0653a..dcbd5ecd 100644 --- a/src/view/core/app.py +++ b/src/view/core/app.py @@ -33,7 +33,7 @@ from view.run.asgi import ASGIProtocol from view.run.wsgi import WSGIProtocol -__all__ = "BaseApp", "as_app", "App" +__all__ = "App", "BaseApp", "as_app" T = TypeVar("T") P = ParamSpec("P") @@ -143,7 +143,7 @@ def run( settings.run_app_on_any_server() except KeyboardInterrupt: logger.info("CTRL^C received, shutting down") - except Exception: # noqa + except Exception: # noqa: BLE001 logger.exception("Error in server lifecycle") finally: logger.info("Server finished") diff --git a/src/view/core/headers.py b/src/view/core/headers.py index a144dbd6..35e3cdc4 100644 --- a/src/view/core/headers.py +++ b/src/view/core/headers.py @@ -3,88 +3,157 @@ from collections.abc import Mapping from typing import TYPE_CHECKING, Any, TypeAlias -from multidict import CIMultiDict +from typing_extensions import Self +from view.core.multi_map import MultiMap from view.exceptions import InvalidTypeError if TYPE_CHECKING: from view.run.asgi import ASGIHeaders + from view.run.wsgi import WSGIHeaders __all__ = ( - "RequestHeaders", + "HTTPHeaders", "HeadersLike", - "as_multidict", - "asgi_as_multidict", - "multidict_as_asgi", - "wsgi_as_multidict", + "as_real_headers", + "asgi_to_headers", + "headers_to_asgi", + "wsgi_to_headers", ) -RequestHeaders: TypeAlias = CIMultiDict[str] + +class LowerStr(str): + """ + A string that always acts in lowercase. This is useful for case-insensitive + comparisons. + """ + + __slots__ = () + + def __new__(cls, data: object) -> Self: + return super().__new__(cls, cls._to_lower(data)) + + @staticmethod + def _to_lower(data: object) -> object: + if isinstance(data, str): + data = data.lower() + + return data + + def __contains__(self, key: str, /) -> bool: + return super().__contains__(key.lower()) + + def __eq__(self, string: object) -> bool: + return super().__eq__(self._to_lower(string)) + + def __ne__(self, value: object, /) -> bool: + return super().__ne__(self._to_lower(value)) + + def __hash__(self) -> int: + return hash(str(self)) + + +class HTTPHeaders(MultiMap[str, str]): + """ + Case-insensitive multi-map of HTTP headers. + """ + + def __getitem__(self, key: str, /) -> str: + return super().__getitem__(LowerStr(key)) + + def __contains__(self, key: object, /) -> bool: + return super().__contains__(LowerStr(key)) + + def __repr__(self) -> str: + return f"HTTPHeaders({self.as_sequence()})" + + def get_exactly_one(self, key: str) -> str: + return super().get_exactly_one(LowerStr(key)) + + def with_new_value(self, key: str, value: str) -> HTTPHeaders: + new_sequence = [*list(self.as_sequence()), (LowerStr(key), value)] + return type(self)(new_sequence) + + HeadersLike: TypeAlias = ( - RequestHeaders | Mapping[str, str] | Mapping[bytes, bytes] + HTTPHeaders | Mapping[str, str] | Mapping[bytes, bytes] ) -def as_multidict(headers: HeadersLike | None, /) -> RequestHeaders: +def as_real_headers(headers: HeadersLike | None, /) -> HTTPHeaders: """ Convenience function for casting a "header-like object" (or `None`) - to a `CIMultiDict`. + to a `MultiMap`. """ if headers is None: - return CIMultiDict[str]() + return HTTPHeaders() - if isinstance(headers, CIMultiDict): + if isinstance(headers, HTTPHeaders): return headers if __debug__ and not isinstance(headers, Mapping): raise InvalidTypeError(Mapping, headers) assert isinstance(headers, dict) - multidict = CIMultiDict[str]() + all_values: list[tuple[LowerStr, str]] = [] + for key, value in headers.items(): if isinstance(key, bytes): - key = key.decode("utf-8") # noqa + key = key.decode("utf-8") # noqa: PLW2901 if isinstance(value, bytes): - value = value.decode("utf-8") # noqa + value = value.decode("utf-8") # noqa: PLW2901 - multidict[key] = value + all_values.append((LowerStr(key), value)) - return multidict + return HTTPHeaders(all_values) -def wsgi_as_multidict(environ: Mapping[str, Any]) -> RequestHeaders: +def wsgi_to_headers(environ: Mapping[str, Any]) -> HTTPHeaders: """ - Convert WSGI headers (from the `environ`) to a case-insensitive multidict. + Convert WSGI headers (from the `environ`) to a case-insensitive multi-map. """ - headers = CIMultiDict[str]() + values: list[tuple[LowerStr, str]] = [] for key, value in environ.items(): if not key.startswith("HTTP_"): continue assert isinstance(value, str) - key = key.removeprefix("HTTP_").replace("_", "-").lower() # noqa - headers[key] = value + key = key.removeprefix("HTTP_").replace("_", "-").lower() # noqa: PLW2901 + values.append((LowerStr(key), value)) + + return HTTPHeaders(values) + + +def headers_to_wsgi(headers: HTTPHeaders) -> WSGIHeaders: + """ + Convert a case-insensitive multi-map to a WSGI header iterable. + """ + + wsgi_headers: WSGIHeaders = [] + for key, value in headers.items(): + wsgi_headers.append((str(key), value)) - return headers + return wsgi_headers -def asgi_as_multidict(headers: ASGIHeaders, /) -> RequestHeaders: +def asgi_to_headers(headers: ASGIHeaders, /) -> HTTPHeaders: """ - Convert ASGI headers to a case-insensitive multidict. + Convert ASGI headers to a case-insensitive multi-map. """ - multidict = CIMultiDict[str]() + values: list[tuple[LowerStr, str]] = [] for key, value in headers: - multidict[key.decode("utf-8")] = value.decode("utf-8") + lower_str = LowerStr(key.decode("utf-8")) + values.append((lower_str, value.decode("utf-8"))) - return multidict + return MultiMap(values) -def multidict_as_asgi(headers: RequestHeaders, /) -> ASGIHeaders: +def headers_to_asgi(headers: HTTPHeaders, /) -> ASGIHeaders: """ - Convert a case-insensitive multidict to an ASGI header iterable. + Convert a case-insensitive multi-map to an ASGI header iterable. """ asgi_headers: ASGIHeaders = [] diff --git a/src/view/core/multi_map.py b/src/view/core/multi_map.py new file mode 100644 index 00000000..6fbee710 --- /dev/null +++ b/src/view/core/multi_map.py @@ -0,0 +1,156 @@ +from __future__ import annotations + +from collections.abc import ( + ItemsView, + Iterable, + Iterator, + KeysView, + Mapping, + Sequence, + ValuesView, +) +from typing import Any, TypeVar + +from view.exceptions import ViewError + +__all__ = "HasMultipleValuesError", "MultiMap" + +KeyT = TypeVar("KeyT") +ValueT = TypeVar("ValueT") +T = TypeVar("T") + + +class HasMultipleValuesError(ViewError): + """ + Multiple values were found when they were explicitly disallowed. + """ + + def __init__(self, key: Any) -> None: + super().__init__(f"{key!r} has multiple values") + + +class MultiMap(Mapping[KeyT, ValueT]): + """ + Mapping of individual keys to one or many values. + """ + + __slots__ = ("_values",) + + def __init__(self, items: Iterable[tuple[KeyT, ValueT]] = ()) -> None: + self._values: dict[KeyT, list[ValueT]] = {} + + for key, value in items: + values = self._values.setdefault(key, []) + values.append(value) + + def __getitem__(self, key: KeyT, /) -> ValueT: + """ + Get the first value if it exists, or else raise a `KeyError`. + """ + + return self._values[key][0] + + def __len__(self) -> int: + return len(self._values) + + def __iter__(self) -> Iterator[KeyT]: + return iter(self._values) + + def __contains__(self, key: object, /) -> bool: + return key in self._values + + def __eq__(self, other: object, /) -> bool: + if isinstance(other, MultiMap): + return other._values == self._values + + if isinstance(other, dict): + return self._as_flat() == other + + return NotImplemented + + def __ne__(self, other: object, /) -> bool: + if isinstance(other, MultiMap): + return other._values != self._values + + return NotImplemented + + def __repr__(self) -> str: + return f"MultiMap({self.as_sequence()})" + + def __hash__(self) -> int: + return hash(self._values) + + def _as_flat(self) -> dict[KeyT, ValueT]: + """ + Turn this into a "flat" representation of the mapping in which all + keys have exactly one value. + """ + return {key: value[0] for key, value in self._values.items()} + + def keys(self) -> KeysView[KeyT]: + """ + Return a view of all the keys in this map. + """ + return self._values.keys() + + def values(self) -> ValuesView[ValueT]: + """ + Return a view of the first value for each key in the mapping. + """ + return self._as_flat().values() + + def many_values(self) -> ValuesView[Sequence[ValueT]]: + """ + Return a view of all values in the mapping. + """ + return self._values.values() + + def items(self) -> ItemsView[KeyT, ValueT]: + """ + Return a view of all items in the mapping, using the first value + for each key. + """ + return self._as_flat().items() + + def many_items(self) -> ItemsView[KeyT, Sequence[ValueT]]: + """ + Return a view of all items in the mapping. + """ + return self._values.items() + + def get_many(self, key: KeyT) -> Sequence[ValueT]: + """ + Get one or many values for a given key. + """ + return self._values[key] + + def get_exactly_one(self, key: KeyT) -> ValueT: + """ + Get precisely one value for a key. If more than one value is present, + then this raises a `HasMultipleValuesError`. + """ + value = self._values[key] + if len(value) != 1: + raise HasMultipleValuesError(key) + + return value[0] + + def as_sequence(self) -> Sequence[tuple[KeyT, ValueT]]: + """ + Return all the keys and values in a sequence of (key, value) tuples. + """ + result: list[tuple[KeyT, ValueT]] = [] + for key, values in self._values.items(): + for value in values: + result.append((key, value)) # noqa: PERF401 + + return result + + def with_new_value( + self, key: KeyT, value: ValueT + ) -> MultiMap[KeyT, ValueT]: + """ + Create a copy of this map with a new key and value included. + """ + new_sequence = [*list(self.as_sequence()), (key, value)] + return type(self)(new_sequence) diff --git a/src/view/core/request.py b/src/view/core/request.py index dfd35f08..e9b3dc2e 100644 --- a/src/view/core/request.py +++ b/src/view/core/request.py @@ -6,16 +6,15 @@ from enum import auto from typing import TYPE_CHECKING, Any -from multidict import MultiDict - from view.core.body import BodyMixin +from view.core.multi_map import MultiMap from view.core.router import normalize_route if TYPE_CHECKING: from collections.abc import Mapping from view.core.app import BaseApp - from view.core.headers import RequestHeaders + from view.core.headers import HTTPHeaders __all__ = "Method", "Request" @@ -119,13 +118,13 @@ class Request(BodyMixin): The HTTP method of the request. See `Method`. """ - headers: RequestHeaders + headers: HTTPHeaders """ A "multi-dictionary" containing the request headers. This is `dict`-like, but if a header has multiple values, it is represented by a list. """ - query_parameters: MultiDict[str] + query_parameters: MultiMap[str, str] """ The query string parameters of the HTTP request. """ @@ -141,18 +140,12 @@ def __post_init__(self) -> None: self.path = normalize_route(self.path) -def extract_query_parameters(query_string: str | bytes) -> MultiDict[str]: +def extract_query_parameters(query_string: str | bytes) -> MultiMap[str, str]: """ - Extract a query string from a URL and return it as a multidict. + Extract a query string from a URL and return it as a multi-map. """ if isinstance(query_string, bytes): query_string = query_string.decode("utf-8") assert isinstance(query_string, str), query_string - parsed = urllib.parse.parse_qsl(query_string) - result = MultiDict() - - for key, value in parsed: - result[key] = value - - return result + return MultiMap(urllib.parse.parse_qsl(query_string)) diff --git a/src/view/core/response.py b/src/view/core/response.py index 4059d66e..b2e3e4f4 100644 --- a/src/view/core/response.py +++ b/src/view/core/response.py @@ -11,13 +11,17 @@ import aiofiles from loguru import logger -from multidict import CIMultiDict from view.core.body import BodyMixin -from view.core.headers import HeadersLike, RequestHeaders, as_multidict +from view.core.headers import ( + HeadersLike, + HTTPHeaders, + LowerStr, + as_real_headers, +) from view.exceptions import InvalidTypeError, ViewError -__all__ = "Response", "ViewResult", "ResponseLike" +__all__ = "Response", "ResponseLike", "ViewResult" @dataclass(slots=True) @@ -27,7 +31,7 @@ class Response(BodyMixin): """ status_code: int - headers: CIMultiDict[str] + headers: HTTPHeaders def __post_init__(self) -> None: if __debug__: @@ -39,7 +43,7 @@ def __post_init__(self) -> None: f"{self.status_code!r} is not a valid HTTP status code" ) - async def as_tuple(self) -> tuple[bytes, int, RequestHeaders]: + async def as_tuple(self) -> tuple[bytes, int, HTTPHeaders]: """ Process the response as a tuple. This is mainly useful for assertions in testing. @@ -103,12 +107,14 @@ async def stream(): length = len(data) yield data - multidict = as_multidict(headers) - if "content-type" not in multidict: + multi_map = as_real_headers(headers) + if "content-type" not in multi_map: content_type = content_type or _guess_file_type(path) - multidict["content-type"] = content_type + multi_map = multi_map.with_new_value( + LowerStr("content-type"), content_type + ) - return cls(stream, status_code, multidict, path) + return cls(stream, status_code, multi_map, path) def _as_bytes(data: str | bytes) -> bytes: @@ -148,7 +154,7 @@ def from_content( async def stream() -> AsyncGenerator[bytes]: yield _as_bytes(content) - return cls(stream, status_code, as_multidict(headers), content) + return cls(stream, status_code, as_real_headers(headers), content) @dataclass(slots=True) @@ -173,7 +179,7 @@ async def stream() -> AsyncGenerator[bytes]: return cls( content=content, parsed_data=data, - headers=as_multidict(headers), + headers=as_real_headers(headers), status_code=status_code, receive_data=stream, ) @@ -211,10 +217,10 @@ def _wrap_response_tuple(response: _ResponseTuple) -> Response: # Ruff wants me to use a constant here, but I think this is clear enough # for lengths. - if len(response) > 2: # noqa + if len(response) > 2: # noqa: PLR2004 headers = response[2] - if __debug__ and len(response) > 3: # noqa + if __debug__ and len(response) > 3: # noqa: PLR2004 raise InvalidResponseError( f"Got excess data in response tuple {response[3:]!r}" ) @@ -244,7 +250,7 @@ async def stream() -> AsyncGenerator[bytes]: async for data in response: yield _as_bytes(data) - return Response(stream, status_code=200, headers=CIMultiDict()) + return Response(stream, status_code=200, headers=HTTPHeaders()) if isinstance(response, Generator): @@ -252,7 +258,7 @@ async def stream() -> AsyncGenerator[bytes]: for data in response: yield _as_bytes(data) - return Response(stream, status_code=200, headers=CIMultiDict()) + return Response(stream, status_code=200, headers=HTTPHeaders()) raise TypeError(f"Invalid response: {response!r}") diff --git a/src/view/dom/core.py b/src/view/dom/core.py index c85e7ee8..3b97eaf5 100644 --- a/src/view/dom/core.py +++ b/src/view/dom/core.py @@ -15,7 +15,7 @@ from queue import LifoQueue from typing import TYPE_CHECKING, ClassVar, ParamSpec, TypeAlias -from view.core.headers import as_multidict +from view.core.headers import as_real_headers from view.core.response import Response from view.exceptions import InvalidTypeError from view.javascript import SupportsJavaScript @@ -216,7 +216,7 @@ async def stream() -> AsyncIterator[bytes]: return Response( stream, status_code or 200, - as_multidict({"content-type": "text/html"}), + as_real_headers({"content-type": "text/html"}), ) return wrapper diff --git a/src/view/run/asgi.py b/src/view/run/asgi.py index 530adf72..734c6fd3 100644 --- a/src/view/run/asgi.py +++ b/src/view/run/asgi.py @@ -5,7 +5,7 @@ from typing_extensions import NotRequired -from view.core.headers import asgi_as_multidict, multidict_as_asgi +from view.core.headers import asgi_to_headers, headers_to_asgi from view.core.request import Method, Request, extract_query_parameters if TYPE_CHECKING: @@ -78,7 +78,7 @@ async def asgi( ) -> None: assert scope["type"] == "http" method = Method(scope["method"]) - headers = asgi_as_multidict(scope["headers"]) + headers = asgi_to_headers(scope["headers"]) async def receive_data() -> AsyncIterator[bytes]: more_body = True @@ -98,7 +98,7 @@ async def receive_data() -> AsyncIterator[bytes]: { "type": "http.response.start", "status": response.status_code, - "headers": multidict_as_asgi(response.headers), + "headers": headers_to_asgi(response.headers), } ) async for data in response.stream_body(): diff --git a/src/view/run/servers.py b/src/view/run/servers.py index 9dc4d97c..bba566d1 100644 --- a/src/view/run/servers.py +++ b/src/view/run/servers.py @@ -157,6 +157,6 @@ def run_app_on_any_server(self) -> None: ) from error # I'm not sure what Ruff is complaining about here - for start_server in servers.values(): # noqa: RET503 + for start_server in servers.values(): with suppress(ImportError): - return start_server() # noqa: RET503 + return start_server() diff --git a/src/view/run/wsgi.py b/src/view/run/wsgi.py index 7e3e8dcf..c3e9a1c1 100644 --- a/src/view/run/wsgi.py +++ b/src/view/run/wsgi.py @@ -4,7 +4,7 @@ from collections.abc import Callable, Iterable from typing import IO, TYPE_CHECKING, Any, TypeAlias -from view.core.headers import wsgi_as_multidict +from view.core.headers import headers_to_wsgi, wsgi_to_headers from view.core.request import Method, Request, extract_query_parameters from view.core.status_codes import STATUS_STRINGS @@ -13,7 +13,7 @@ __all__ = ("wsgi_for_app",) -WSGIHeaders: TypeAlias = list[tuple[str, str]] +WSGIHeaders: TypeAlias = Iterable[tuple[str, str]] # We can't use a TypedDict for the environment because it has arbitrary keys # for the headers. WSGIEnvironment: TypeAlias = dict[str, Any] @@ -52,15 +52,12 @@ async def stream(): path = environ["PATH_INFO"] assert isinstance(path, str) - headers = wsgi_as_multidict(environ) + headers = wsgi_to_headers(environ) parameters = extract_query_parameters(environ["QUERY_STRING"]) request = Request(stream, app, path, method, headers, parameters) response = loop.run_until_complete(app.process_request(request)) - wsgi_headers: WSGIHeaders = [] - for key, value in response.headers.items(): - # Multidict has a weird string subclass as the key for some reason - wsgi_headers.append((str(key), value)) + wsgi_headers: WSGIHeaders = headers_to_wsgi(response.headers) # WSGI is such a weird spec status_str = ( diff --git a/src/view/testing.py b/src/view/testing.py index 5d8e47fb..4d986e19 100644 --- a/src/view/testing.py +++ b/src/view/testing.py @@ -2,16 +2,15 @@ from typing import TYPE_CHECKING -from view.core.headers import HeadersLike, as_multidict +from view.core.headers import HeadersLike, as_real_headers from view.core.request import Method, Request, extract_query_parameters from view.core.status_codes import STATUS_STRINGS if TYPE_CHECKING: from collections.abc import AsyncGenerator, Awaitable - from multidict import CIMultiDict - from view.core.app import BaseApp + from view.core.headers import HTTPHeaders from view.core.response import Response __all__ = ("AppTestClient",) @@ -19,7 +18,7 @@ async def into_tuple( response_coro: Awaitable[Response], / -) -> tuple[bytes, int, CIMultiDict]: +) -> tuple[bytes, int, HTTPHeaders]: """ Convenience function for transferring a test client call into a tuple through a single ``await``. @@ -76,7 +75,7 @@ async def stream() -> AsyncGenerator[bytes]: app=self.app, path=path, method=method, - headers=as_multidict(headers), + headers=as_real_headers(headers), query_parameters=extract_query_parameters(query_string), ) return await self.app.process_request(request_data) diff --git a/tests/test_misc.py b/tests/test_misc.py index bf1ca7d2..b7093f5f 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -1,6 +1,7 @@ import pytest from view.core.app import App, as_app from view.exceptions import InvalidTypeError +from view.core.multi_map import HasMultipleValuesError, MultiMap def test_as_app_invalid(): @@ -16,3 +17,118 @@ def test_invalid_type_route(): with pytest.raises(InvalidTypeError): app.get("/")(object()) # type: ignore + + +def test_empty_multi_map(): + multi_map = MultiMap() + assert multi_map == {} + + with pytest.raises(KeyError): + multi_map["a"] + + with pytest.raises(KeyError): + multi_map[object()] + + with pytest.raises(KeyError): + multi_map[None] + + assert len(multi_map) == 0 + assert multi_map.as_sequence() == [] + + called = False + for _ in multi_map.keys(): + called = True + + assert called is False + + for _ in multi_map.values(): + called = True + + assert called is False + + for _ in multi_map.items(): + called = True + + assert called is False + + for _ in multi_map: + called = True + + assert called is False + + +def test_multi_map_no_duplicates(): + data = [('a', 1), ('b', 2), ('c', 3)] + multi_map = MultiMap(data) + + assert multi_map == {"a": 1, "b": 2, "c": 3} + assert len(multi_map) == 3 + assert multi_map.as_sequence() == data + + for key, value in data: + assert key in multi_map + assert multi_map[key] == value + assert multi_map.get_many(key) == [value] + assert multi_map.get(key) == value + assert multi_map.get_exactly_one(key) == value + assert key in multi_map.keys() + assert value in multi_map.values() + + called = 0 + for key in multi_map: + called += 1 + assert key in ("a", "b", "c") + + assert called == 3 + + + +def test_multi_map_with_duplicates(): + data = [('a', 1), ('a', 2), ('a', 3), ('b', 4)] + multi_map = MultiMap(data) + assert len(multi_map) == 2 + assert multi_map.as_sequence() == data + + assert multi_map == {"a": 1, "b": 4} + assert multi_map["a"] == 1 + assert multi_map.get_many("a") == [1, 2, 3] + + assert "a" in multi_map + assert "b" in multi_map + assert list(multi_map.keys()) == ['a', 'b'] + assert list(multi_map.values()) == [1, 4] + assert list(multi_map.items()) == [('a', 1), ('b', 4)] + assert list(multi_map.many_values()) == [[1, 2, 3], [4]] + assert list(multi_map.many_items()) == [('a', [1, 2, 3]), ('b', [4])] + + with pytest.raises(HasMultipleValuesError): + multi_map.get_exactly_one('a') + + assert multi_map.get_exactly_one("b") == 4 + + called = 0 + for key in multi_map: + called += 1 + assert key in ("a", "b") + + assert called == 2 + + +def test_multi_map_with_new_value(): + data = [('a', 1), ('b', 2), ('b', 3)] + multi_map = MultiMap(data) + assert len(multi_map) == 2 + + new_map = multi_map.with_new_value('b', 4) + assert len(new_map) == 2 + assert "b" in new_map + assert multi_map != new_map + assert new_map.get_many("b") == [2, 3, 4] + + new_map = new_map.with_new_value("c", 4) + assert len(new_map) == 3 + assert "c" in new_map + assert new_map != multi_map + assert new_map["c"] == 4 + assert new_map.get_exactly_one("c") == 4 + assert new_map.get_many("b") == [2, 3, 4] diff --git a/tests/test_requests.py b/tests/test_requests.py index 8991806f..7ebc4565 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -2,14 +2,14 @@ from collections.abc import AsyncIterator import pytest -from multidict import MultiDict from view.core.app import App, as_app from view.core.body import InvalidJSONError -from view.core.headers import as_multidict +from view.core.headers import as_real_headers from view.core.request import Method, Request from view.core.response import ResponseLike from view.core.router import DuplicateRouteError from view.core.status_codes import BadRequest +from view.core.multi_map import MultiMap from view.testing import AppTestClient, bad, into_tuple, ok @@ -61,8 +61,8 @@ async def stream_none() -> AsyncIterator[bytes]: app=app, path="/", method=Method.POST, - headers=as_multidict({"test": "42"}), - query_parameters=MultiDict(), + headers=as_real_headers({"test": "42"}), + query_parameters=MultiMap(), ) response = await app.process_request(manual_request) assert (await response.body()) == b"1" @@ -95,9 +95,9 @@ async def app(request: Request) -> ResponseLike: assert request.headers["foo"] == "42" return "1" elif request.path == "/many": - assert request.headers["Bar"] == "42" - assert request.headers["bar"] == "42" - assert request.headers["baR"] == "42" + assert request.headers["Bar"] == "24" + assert request.headers["bar"] == "24" + assert request.headers["baR"] == "24" assert request.headers["test"] == "123" return "2" else: @@ -318,8 +318,8 @@ async def test_request_query_parameters(): async def main(): request = app.current_request() assert request.query_parameters["foo"] == "bar" - # FIXME: Why doesn't multidict work? - # assert request.query_parameters["test"] == ["1", "2", "3"] + assert request.query_parameters["test"] == "1" + assert request.query_parameters.get_many("test") == ["1", "2", "3"] assert "noexist" not in request.query_parameters return "ok" diff --git a/tests/test_responses.py b/tests/test_responses.py index a49e0fcd..120b0ba4 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -4,7 +4,7 @@ import pytest from view.core.app import App, as_app -from view.core.headers import as_multidict +from view.core.headers import as_real_headers from view.core.request import Request from view.core.response import FileResponse, JSONResponse, Response, ResponseLike from view.core.status_codes import ( @@ -49,7 +49,7 @@ async def stream(): return Response( receive_data=stream, status_code=Success.CREATED, - headers=as_multidict({"hello": "world"}), + headers=as_real_headers({"hello": "world"}), ) client = AppTestClient(app) diff --git a/tests/test_servers.py b/tests/test_servers.py index 2eaa24f9..a7d9d055 100644 --- a/tests/test_servers.py +++ b/tests/test_servers.py @@ -32,10 +32,12 @@ async def index(): app.run(server_hint={server_name!r}) """ process = subprocess.Popen([sys.executable, "-c", code]) - time.sleep(2) - response = requests.get("http://localhost:5000") - assert response.text == "ok" - process.kill() + try: + time.sleep(2) + response = requests.get("http://localhost:5000") + assert response.text == "ok" + finally: + process.kill() @pytest.mark.parametrize("server_name", ServerSettings.AVAILABLE_SERVERS)