summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/litestar/testing/transport.py
diff options
context:
space:
mode:
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/testing/transport.py')
-rw-r--r--venv/lib/python3.11/site-packages/litestar/testing/transport.py192
1 files changed, 192 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/testing/transport.py b/venv/lib/python3.11/site-packages/litestar/testing/transport.py
new file mode 100644
index 0000000..ffa76a4
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/testing/transport.py
@@ -0,0 +1,192 @@
+from __future__ import annotations
+
+from io import BytesIO
+from types import GeneratorType
+from typing import TYPE_CHECKING, Any, Generic, TypedDict, TypeVar, Union, cast
+from urllib.parse import unquote
+
+from anyio import Event
+from httpx import ByteStream, Response
+
+from litestar.status_codes import HTTP_500_INTERNAL_SERVER_ERROR
+from litestar.testing.websocket_test_session import WebSocketTestSession
+
+if TYPE_CHECKING:
+ from httpx import Request
+
+ from litestar.testing.client import AsyncTestClient, TestClient
+ from litestar.types import (
+ HTTPDisconnectEvent,
+ HTTPRequestEvent,
+ Message,
+ Receive,
+ ReceiveMessage,
+ Send,
+ WebSocketScope,
+ )
+
+
+T = TypeVar("T", bound=Union["AsyncTestClient", "TestClient"])
+
+
+class ConnectionUpgradeExceptionError(Exception):
+ def __init__(self, session: WebSocketTestSession) -> None:
+ self.session = session
+
+
+class SendReceiveContext(TypedDict):
+ request_complete: bool
+ response_complete: Event
+ raw_kwargs: dict[str, Any]
+ response_started: bool
+ template: str | None
+ context: Any | None
+
+
+class TestClientTransport(Generic[T]):
+ def __init__(
+ self,
+ client: T,
+ raise_server_exceptions: bool = True,
+ root_path: str = "",
+ ) -> None:
+ self.client = client
+ self.raise_server_exceptions = raise_server_exceptions
+ self.root_path = root_path
+
+ @staticmethod
+ def create_receive(request: Request, context: SendReceiveContext) -> Receive:
+ async def receive() -> ReceiveMessage:
+ if context["request_complete"]:
+ if not context["response_complete"].is_set():
+ await context["response_complete"].wait()
+ disconnect_event: HTTPDisconnectEvent = {"type": "http.disconnect"}
+ return disconnect_event
+
+ body = cast("Union[bytes, str, GeneratorType]", (request.read() or b""))
+ request_event: HTTPRequestEvent = {"type": "http.request", "body": b"", "more_body": False}
+ if isinstance(body, GeneratorType): # pragma: no cover
+ try:
+ chunk = body.send(None)
+ request_event["body"] = chunk if isinstance(chunk, bytes) else chunk.encode("utf-8")
+ request_event["more_body"] = True
+ except StopIteration:
+ context["request_complete"] = True
+ else:
+ context["request_complete"] = True
+ request_event["body"] = body if isinstance(body, bytes) else body.encode("utf-8")
+ return request_event
+
+ return receive
+
+ @staticmethod
+ def create_send(request: Request, context: SendReceiveContext) -> Send:
+ async def send(message: Message) -> None:
+ if message["type"] == "http.response.start":
+ assert not context["response_started"], 'Received multiple "http.response.start" messages.' # noqa: S101
+ context["raw_kwargs"]["status_code"] = message["status"]
+ context["raw_kwargs"]["headers"] = [
+ (k.decode("utf-8"), v.decode("utf-8")) for k, v in message.get("headers", [])
+ ]
+ context["response_started"] = True
+ elif message["type"] == "http.response.body":
+ assert context["response_started"], 'Received "http.response.body" without "http.response.start".' # noqa: S101
+ assert not context[ # noqa: S101
+ "response_complete"
+ ].is_set(), 'Received "http.response.body" after response completed.'
+ body = message.get("body", b"")
+ more_body = message.get("more_body", False)
+ if request.method != "HEAD":
+ context["raw_kwargs"]["stream"].write(body)
+ if not more_body:
+ context["raw_kwargs"]["stream"].seek(0)
+ context["response_complete"].set()
+ elif message["type"] == "http.response.template": # type: ignore[comparison-overlap] # pragma: no cover
+ context["template"] = message["template"] # type: ignore[unreachable]
+ context["context"] = message["context"]
+
+ return send
+
+ def parse_request(self, request: Request) -> dict[str, Any]:
+ scheme = request.url.scheme
+ netloc = unquote(request.url.netloc.decode(encoding="ascii"))
+ path = request.url.path
+ raw_path = request.url.raw_path
+ query = request.url.query.decode(encoding="ascii")
+ default_port = 433 if scheme in {"https", "wss"} else 80
+
+ if ":" in netloc:
+ host, port_string = netloc.split(":", 1)
+ port = int(port_string)
+ else:
+ host = netloc
+ port = default_port
+
+ host_header = request.headers.pop("host", host if port == default_port else f"{host}:{port}")
+
+ headers = [(k.lower().encode(), v.encode()) for k, v in (("host", host_header), *request.headers.items())]
+
+ return {
+ "type": "websocket" if scheme in {"ws", "wss"} else "http",
+ "path": unquote(path),
+ "raw_path": raw_path,
+ "root_path": self.root_path,
+ "scheme": scheme,
+ "query_string": query.encode(),
+ "headers": headers,
+ "client": ("testclient", 50000),
+ "server": (host, port),
+ }
+
+ def handle_request(self, request: Request) -> Response:
+ scope = self.parse_request(request=request)
+ if scope["type"] == "websocket":
+ scope.update(
+ subprotocols=[value.strip() for value in request.headers.get("sec-websocket-protocol", "").split(",")]
+ )
+ session = WebSocketTestSession(client=self.client, scope=cast("WebSocketScope", scope)) # type: ignore[arg-type]
+ raise ConnectionUpgradeExceptionError(session)
+
+ scope.update(method=request.method, http_version="1.1", extensions={"http.response.template": {}})
+
+ raw_kwargs: dict[str, Any] = {"stream": BytesIO()}
+
+ try:
+ with self.client.portal() as portal:
+ response_complete = portal.call(Event)
+ context: SendReceiveContext = {
+ "response_complete": response_complete,
+ "request_complete": False,
+ "raw_kwargs": raw_kwargs,
+ "response_started": False,
+ "template": None,
+ "context": None,
+ }
+ portal.call(
+ self.client.app,
+ scope,
+ self.create_receive(request=request, context=context),
+ self.create_send(request=request, context=context),
+ )
+ except BaseException as exc: # noqa: BLE001
+ if self.raise_server_exceptions:
+ raise exc
+ return Response(
+ status_code=HTTP_500_INTERNAL_SERVER_ERROR, headers=[], stream=ByteStream(b""), request=request
+ )
+ else:
+ if not context["response_started"]: # pragma: no cover
+ if self.raise_server_exceptions:
+ assert context["response_started"], "TestClient did not receive any response." # noqa: S101
+ return Response(
+ status_code=HTTP_500_INTERNAL_SERVER_ERROR, headers=[], stream=ByteStream(b""), request=request
+ )
+
+ stream = ByteStream(raw_kwargs.pop("stream", BytesIO()).read())
+ response = Response(**raw_kwargs, stream=stream, request=request)
+ setattr(response, "template", context["template"])
+ setattr(response, "context", context["context"])
+ return response
+
+ async def handle_async_request(self, request: Request) -> Response:
+ return self.handle_request(request=request)