diff options
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/testing/transport.py')
-rw-r--r-- | venv/lib/python3.11/site-packages/litestar/testing/transport.py | 192 |
1 files changed, 192 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/testing/transport.py b/venv/lib/python3.11/site-packages/litestar/testing/transport.py new file mode 100644 index 0000000..ffa76a4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/transport.py @@ -0,0 +1,192 @@ +from __future__ import annotations + +from io import BytesIO +from types import GeneratorType +from typing import TYPE_CHECKING, Any, Generic, TypedDict, TypeVar, Union, cast +from urllib.parse import unquote + +from anyio import Event +from httpx import ByteStream, Response + +from litestar.status_codes import HTTP_500_INTERNAL_SERVER_ERROR +from litestar.testing.websocket_test_session import WebSocketTestSession + +if TYPE_CHECKING: + from httpx import Request + + from litestar.testing.client import AsyncTestClient, TestClient + from litestar.types import ( + HTTPDisconnectEvent, + HTTPRequestEvent, + Message, + Receive, + ReceiveMessage, + Send, + WebSocketScope, + ) + + +T = TypeVar("T", bound=Union["AsyncTestClient", "TestClient"]) + + +class ConnectionUpgradeExceptionError(Exception): + def __init__(self, session: WebSocketTestSession) -> None: + self.session = session + + +class SendReceiveContext(TypedDict): + request_complete: bool + response_complete: Event + raw_kwargs: dict[str, Any] + response_started: bool + template: str | None + context: Any | None + + +class TestClientTransport(Generic[T]): + def __init__( + self, + client: T, + raise_server_exceptions: bool = True, + root_path: str = "", + ) -> None: + self.client = client + self.raise_server_exceptions = raise_server_exceptions + self.root_path = root_path + + @staticmethod + def create_receive(request: Request, context: SendReceiveContext) -> Receive: + async def receive() -> ReceiveMessage: + if context["request_complete"]: + if not context["response_complete"].is_set(): + await context["response_complete"].wait() + disconnect_event: HTTPDisconnectEvent = {"type": "http.disconnect"} + return disconnect_event + + body = cast("Union[bytes, str, GeneratorType]", (request.read() or b"")) + request_event: HTTPRequestEvent = {"type": "http.request", "body": b"", "more_body": False} + if isinstance(body, GeneratorType): # pragma: no cover + try: + chunk = body.send(None) + request_event["body"] = chunk if isinstance(chunk, bytes) else chunk.encode("utf-8") + request_event["more_body"] = True + except StopIteration: + context["request_complete"] = True + else: + context["request_complete"] = True + request_event["body"] = body if isinstance(body, bytes) else body.encode("utf-8") + return request_event + + return receive + + @staticmethod + def create_send(request: Request, context: SendReceiveContext) -> Send: + async def send(message: Message) -> None: + if message["type"] == "http.response.start": + assert not context["response_started"], 'Received multiple "http.response.start" messages.' # noqa: S101 + context["raw_kwargs"]["status_code"] = message["status"] + context["raw_kwargs"]["headers"] = [ + (k.decode("utf-8"), v.decode("utf-8")) for k, v in message.get("headers", []) + ] + context["response_started"] = True + elif message["type"] == "http.response.body": + assert context["response_started"], 'Received "http.response.body" without "http.response.start".' # noqa: S101 + assert not context[ # noqa: S101 + "response_complete" + ].is_set(), 'Received "http.response.body" after response completed.' + body = message.get("body", b"") + more_body = message.get("more_body", False) + if request.method != "HEAD": + context["raw_kwargs"]["stream"].write(body) + if not more_body: + context["raw_kwargs"]["stream"].seek(0) + context["response_complete"].set() + elif message["type"] == "http.response.template": # type: ignore[comparison-overlap] # pragma: no cover + context["template"] = message["template"] # type: ignore[unreachable] + context["context"] = message["context"] + + return send + + def parse_request(self, request: Request) -> dict[str, Any]: + scheme = request.url.scheme + netloc = unquote(request.url.netloc.decode(encoding="ascii")) + path = request.url.path + raw_path = request.url.raw_path + query = request.url.query.decode(encoding="ascii") + default_port = 433 if scheme in {"https", "wss"} else 80 + + if ":" in netloc: + host, port_string = netloc.split(":", 1) + port = int(port_string) + else: + host = netloc + port = default_port + + host_header = request.headers.pop("host", host if port == default_port else f"{host}:{port}") + + headers = [(k.lower().encode(), v.encode()) for k, v in (("host", host_header), *request.headers.items())] + + return { + "type": "websocket" if scheme in {"ws", "wss"} else "http", + "path": unquote(path), + "raw_path": raw_path, + "root_path": self.root_path, + "scheme": scheme, + "query_string": query.encode(), + "headers": headers, + "client": ("testclient", 50000), + "server": (host, port), + } + + def handle_request(self, request: Request) -> Response: + scope = self.parse_request(request=request) + if scope["type"] == "websocket": + scope.update( + subprotocols=[value.strip() for value in request.headers.get("sec-websocket-protocol", "").split(",")] + ) + session = WebSocketTestSession(client=self.client, scope=cast("WebSocketScope", scope)) # type: ignore[arg-type] + raise ConnectionUpgradeExceptionError(session) + + scope.update(method=request.method, http_version="1.1", extensions={"http.response.template": {}}) + + raw_kwargs: dict[str, Any] = {"stream": BytesIO()} + + try: + with self.client.portal() as portal: + response_complete = portal.call(Event) + context: SendReceiveContext = { + "response_complete": response_complete, + "request_complete": False, + "raw_kwargs": raw_kwargs, + "response_started": False, + "template": None, + "context": None, + } + portal.call( + self.client.app, + scope, + self.create_receive(request=request, context=context), + self.create_send(request=request, context=context), + ) + except BaseException as exc: # noqa: BLE001 + if self.raise_server_exceptions: + raise exc + return Response( + status_code=HTTP_500_INTERNAL_SERVER_ERROR, headers=[], stream=ByteStream(b""), request=request + ) + else: + if not context["response_started"]: # pragma: no cover + if self.raise_server_exceptions: + assert context["response_started"], "TestClient did not receive any response." # noqa: S101 + return Response( + status_code=HTTP_500_INTERNAL_SERVER_ERROR, headers=[], stream=ByteStream(b""), request=request + ) + + stream = ByteStream(raw_kwargs.pop("stream", BytesIO()).read()) + response = Response(**raw_kwargs, stream=stream, request=request) + setattr(response, "template", context["template"]) + setattr(response, "context", context["context"]) + return response + + async def handle_async_request(self, request: Request) -> Response: + return self.handle_request(request=request) |