diff options
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/testing/transport.py')
-rw-r--r-- | venv/lib/python3.11/site-packages/litestar/testing/transport.py | 192 |
1 files changed, 0 insertions, 192 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/testing/transport.py b/venv/lib/python3.11/site-packages/litestar/testing/transport.py deleted file mode 100644 index ffa76a4..0000000 --- a/venv/lib/python3.11/site-packages/litestar/testing/transport.py +++ /dev/null @@ -1,192 +0,0 @@ -from __future__ import annotations - -from io import BytesIO -from types import GeneratorType -from typing import TYPE_CHECKING, Any, Generic, TypedDict, TypeVar, Union, cast -from urllib.parse import unquote - -from anyio import Event -from httpx import ByteStream, Response - -from litestar.status_codes import HTTP_500_INTERNAL_SERVER_ERROR -from litestar.testing.websocket_test_session import WebSocketTestSession - -if TYPE_CHECKING: - from httpx import Request - - from litestar.testing.client import AsyncTestClient, TestClient - from litestar.types import ( - HTTPDisconnectEvent, - HTTPRequestEvent, - Message, - Receive, - ReceiveMessage, - Send, - WebSocketScope, - ) - - -T = TypeVar("T", bound=Union["AsyncTestClient", "TestClient"]) - - -class ConnectionUpgradeExceptionError(Exception): - def __init__(self, session: WebSocketTestSession) -> None: - self.session = session - - -class SendReceiveContext(TypedDict): - request_complete: bool - response_complete: Event - raw_kwargs: dict[str, Any] - response_started: bool - template: str | None - context: Any | None - - -class TestClientTransport(Generic[T]): - def __init__( - self, - client: T, - raise_server_exceptions: bool = True, - root_path: str = "", - ) -> None: - self.client = client - self.raise_server_exceptions = raise_server_exceptions - self.root_path = root_path - - @staticmethod - def create_receive(request: Request, context: SendReceiveContext) -> Receive: - async def receive() -> ReceiveMessage: - if context["request_complete"]: - if not context["response_complete"].is_set(): - await context["response_complete"].wait() - disconnect_event: HTTPDisconnectEvent = {"type": "http.disconnect"} - return disconnect_event - - body = cast("Union[bytes, str, GeneratorType]", (request.read() or b"")) - request_event: HTTPRequestEvent = {"type": "http.request", "body": b"", "more_body": False} - if isinstance(body, GeneratorType): # pragma: no cover - try: - chunk = body.send(None) - request_event["body"] = chunk if isinstance(chunk, bytes) else chunk.encode("utf-8") - request_event["more_body"] = True - except StopIteration: - context["request_complete"] = True - else: - context["request_complete"] = True - request_event["body"] = body if isinstance(body, bytes) else body.encode("utf-8") - return request_event - - return receive - - @staticmethod - def create_send(request: Request, context: SendReceiveContext) -> Send: - async def send(message: Message) -> None: - if message["type"] == "http.response.start": - assert not context["response_started"], 'Received multiple "http.response.start" messages.' # noqa: S101 - context["raw_kwargs"]["status_code"] = message["status"] - context["raw_kwargs"]["headers"] = [ - (k.decode("utf-8"), v.decode("utf-8")) for k, v in message.get("headers", []) - ] - context["response_started"] = True - elif message["type"] == "http.response.body": - assert context["response_started"], 'Received "http.response.body" without "http.response.start".' # noqa: S101 - assert not context[ # noqa: S101 - "response_complete" - ].is_set(), 'Received "http.response.body" after response completed.' - body = message.get("body", b"") - more_body = message.get("more_body", False) - if request.method != "HEAD": - context["raw_kwargs"]["stream"].write(body) - if not more_body: - context["raw_kwargs"]["stream"].seek(0) - context["response_complete"].set() - elif message["type"] == "http.response.template": # type: ignore[comparison-overlap] # pragma: no cover - context["template"] = message["template"] # type: ignore[unreachable] - context["context"] = message["context"] - - return send - - def parse_request(self, request: Request) -> dict[str, Any]: - scheme = request.url.scheme - netloc = unquote(request.url.netloc.decode(encoding="ascii")) - path = request.url.path - raw_path = request.url.raw_path - query = request.url.query.decode(encoding="ascii") - default_port = 433 if scheme in {"https", "wss"} else 80 - - if ":" in netloc: - host, port_string = netloc.split(":", 1) - port = int(port_string) - else: - host = netloc - port = default_port - - host_header = request.headers.pop("host", host if port == default_port else f"{host}:{port}") - - headers = [(k.lower().encode(), v.encode()) for k, v in (("host", host_header), *request.headers.items())] - - return { - "type": "websocket" if scheme in {"ws", "wss"} else "http", - "path": unquote(path), - "raw_path": raw_path, - "root_path": self.root_path, - "scheme": scheme, - "query_string": query.encode(), - "headers": headers, - "client": ("testclient", 50000), - "server": (host, port), - } - - def handle_request(self, request: Request) -> Response: - scope = self.parse_request(request=request) - if scope["type"] == "websocket": - scope.update( - subprotocols=[value.strip() for value in request.headers.get("sec-websocket-protocol", "").split(",")] - ) - session = WebSocketTestSession(client=self.client, scope=cast("WebSocketScope", scope)) # type: ignore[arg-type] - raise ConnectionUpgradeExceptionError(session) - - scope.update(method=request.method, http_version="1.1", extensions={"http.response.template": {}}) - - raw_kwargs: dict[str, Any] = {"stream": BytesIO()} - - try: - with self.client.portal() as portal: - response_complete = portal.call(Event) - context: SendReceiveContext = { - "response_complete": response_complete, - "request_complete": False, - "raw_kwargs": raw_kwargs, - "response_started": False, - "template": None, - "context": None, - } - portal.call( - self.client.app, - scope, - self.create_receive(request=request, context=context), - self.create_send(request=request, context=context), - ) - except BaseException as exc: # noqa: BLE001 - if self.raise_server_exceptions: - raise exc - return Response( - status_code=HTTP_500_INTERNAL_SERVER_ERROR, headers=[], stream=ByteStream(b""), request=request - ) - else: - if not context["response_started"]: # pragma: no cover - if self.raise_server_exceptions: - assert context["response_started"], "TestClient did not receive any response." # noqa: S101 - return Response( - status_code=HTTP_500_INTERNAL_SERVER_ERROR, headers=[], stream=ByteStream(b""), request=request - ) - - stream = ByteStream(raw_kwargs.pop("stream", BytesIO()).read()) - response = Response(**raw_kwargs, stream=stream, request=request) - setattr(response, "template", context["template"]) - setattr(response, "context", context["context"]) - return response - - async def handle_async_request(self, request: Request) -> Response: - return self.handle_request(request=request) |