summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/litestar/testing/client/sync_client.py
diff options
context:
space:
mode:
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/testing/client/sync_client.py')
-rw-r--r--venv/lib/python3.11/site-packages/litestar/testing/client/sync_client.py593
1 files changed, 593 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/testing/client/sync_client.py b/venv/lib/python3.11/site-packages/litestar/testing/client/sync_client.py
new file mode 100644
index 0000000..d907056
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/testing/client/sync_client.py
@@ -0,0 +1,593 @@
+from __future__ import annotations
+
+from contextlib import ExitStack
+from typing import TYPE_CHECKING, Any, Generic, Mapping, Sequence, TypeVar
+from urllib.parse import urljoin
+
+from httpx import USE_CLIENT_DEFAULT, Client, Response
+
+from litestar import HttpMethod
+from litestar.testing.client.base import BaseTestClient
+from litestar.testing.life_span_handler import LifeSpanHandler
+from litestar.testing.transport import ConnectionUpgradeExceptionError, TestClientTransport
+from litestar.types import AnyIOBackend, ASGIApp
+
+if TYPE_CHECKING:
+ from httpx._client import UseClientDefault
+ from httpx._types import (
+ AuthTypes,
+ CookieTypes,
+ HeaderTypes,
+ QueryParamTypes,
+ RequestContent,
+ RequestData,
+ RequestFiles,
+ TimeoutTypes,
+ URLTypes,
+ )
+ from typing_extensions import Self
+
+ from litestar.middleware.session.base import BaseBackendConfig
+ from litestar.testing.websocket_test_session import WebSocketTestSession
+
+
+T = TypeVar("T", bound=ASGIApp)
+
+
+class TestClient(Client, BaseTestClient, Generic[T]): # type: ignore[misc]
+ lifespan_handler: LifeSpanHandler[Any]
+ exit_stack: ExitStack
+
+ def __init__(
+ self,
+ app: T,
+ base_url: str = "http://testserver.local",
+ raise_server_exceptions: bool = True,
+ root_path: str = "",
+ backend: AnyIOBackend = "asyncio",
+ backend_options: Mapping[str, Any] | None = None,
+ session_config: BaseBackendConfig | None = None,
+ timeout: float | None = None,
+ cookies: CookieTypes | None = None,
+ ) -> None:
+ """A client implementation providing a context manager for testing applications.
+
+ Args:
+ app: The instance of :class:`Litestar <litestar.app.Litestar>` under test.
+ base_url: URL scheme and domain for test request paths, e.g. ``http://testserver``.
+ raise_server_exceptions: Flag for the underlying test client to raise server exceptions instead of
+ wrapping them in an HTTP response.
+ root_path: Path prefix for requests.
+ backend: The async backend to use, options are "asyncio" or "trio".
+ backend_options: ``anyio`` options.
+ session_config: Configuration for Session Middleware class to create raw session cookies for request to the
+ route handlers.
+ timeout: Request timeout
+ cookies: Cookies to set on the client.
+ """
+ BaseTestClient.__init__(
+ self,
+ app=app,
+ base_url=base_url,
+ backend=backend,
+ backend_options=backend_options,
+ session_config=session_config,
+ cookies=cookies,
+ )
+
+ Client.__init__(
+ self,
+ base_url=base_url,
+ headers={"user-agent": "testclient"},
+ follow_redirects=True,
+ cookies=cookies,
+ transport=TestClientTransport( # type: ignore[arg-type]
+ client=self,
+ raise_server_exceptions=raise_server_exceptions,
+ root_path=root_path,
+ ),
+ timeout=timeout,
+ )
+
+ def __enter__(self) -> Self:
+ with ExitStack() as stack:
+ self.blocking_portal = portal = stack.enter_context(self.portal())
+ self.lifespan_handler = LifeSpanHandler(client=self)
+
+ @stack.callback
+ def reset_portal() -> None:
+ delattr(self, "blocking_portal")
+
+ @stack.callback
+ def wait_shutdown() -> None:
+ portal.call(self.lifespan_handler.wait_shutdown)
+
+ self.exit_stack = stack.pop_all()
+
+ return self
+
+ def __exit__(self, *args: Any) -> None:
+ self.exit_stack.close()
+
+ def request(
+ self,
+ method: str,
+ url: URLTypes,
+ *,
+ content: RequestContent | None = None,
+ data: RequestData | None = None,
+ files: RequestFiles | None = None,
+ json: Any | None = None,
+ params: QueryParamTypes | None = None,
+ headers: HeaderTypes | None = None,
+ cookies: CookieTypes | None = None,
+ auth: AuthTypes | UseClientDefault | None = USE_CLIENT_DEFAULT,
+ follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT,
+ timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
+ extensions: Mapping[str, Any] | None = None,
+ ) -> Response:
+ """Sends a request.
+
+ Args:
+ method: An HTTP method.
+ url: URL or path for the request.
+ content: Request content.
+ data: Form encoded data.
+ files: Multipart files to send.
+ json: JSON data to send.
+ params: Query parameters.
+ headers: Request headers.
+ cookies: Request cookies.
+ auth: Auth headers.
+ follow_redirects: Whether to follow redirects.
+ timeout: Request timeout.
+ extensions: Dictionary of ASGI extensions.
+
+ Returns:
+ An HTTPX Response.
+ """
+ return Client.request(
+ self,
+ url=self.base_url.join(url),
+ method=method.value if isinstance(method, HttpMethod) else method,
+ content=content,
+ data=data,
+ files=files,
+ json=json,
+ params=params,
+ headers=headers,
+ cookies=cookies,
+ auth=auth,
+ follow_redirects=follow_redirects,
+ timeout=timeout,
+ extensions=None if extensions is None else dict(extensions),
+ )
+
+ def get(
+ self,
+ url: URLTypes,
+ *,
+ params: QueryParamTypes | None = None,
+ headers: HeaderTypes | None = None,
+ cookies: CookieTypes | None = None,
+ auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT,
+ follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT,
+ timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
+ extensions: Mapping[str, Any] | None = None,
+ ) -> Response:
+ """Sends a GET request.
+
+ Args:
+ url: URL or path for the request.
+ params: Query parameters.
+ headers: Request headers.
+ cookies: Request cookies.
+ auth: Auth headers.
+ follow_redirects: Whether to follow redirects.
+ timeout: Request timeout.
+ extensions: Dictionary of ASGI extensions.
+
+ Returns:
+ An HTTPX Response.
+ """
+ return Client.get(
+ self,
+ url,
+ params=params,
+ headers=headers,
+ cookies=cookies,
+ auth=auth,
+ follow_redirects=follow_redirects,
+ timeout=timeout,
+ extensions=None if extensions is None else dict(extensions),
+ )
+
+ def options(
+ self,
+ url: URLTypes,
+ *,
+ params: QueryParamTypes | None = None,
+ headers: HeaderTypes | None = None,
+ cookies: CookieTypes | None = None,
+ auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT,
+ follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT,
+ timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
+ extensions: Mapping[str, Any] | None = None,
+ ) -> Response:
+ """Sends an OPTIONS request.
+
+ Args:
+ url: URL or path for the request.
+ params: Query parameters.
+ headers: Request headers.
+ cookies: Request cookies.
+ auth: Auth headers.
+ follow_redirects: Whether to follow redirects.
+ timeout: Request timeout.
+ extensions: Dictionary of ASGI extensions.
+
+ Returns:
+ An HTTPX Response.
+ """
+ return Client.options(
+ self,
+ url,
+ params=params,
+ headers=headers,
+ cookies=cookies,
+ auth=auth,
+ follow_redirects=follow_redirects,
+ timeout=timeout,
+ extensions=None if extensions is None else dict(extensions),
+ )
+
+ def head(
+ self,
+ url: URLTypes,
+ *,
+ params: QueryParamTypes | None = None,
+ headers: HeaderTypes | None = None,
+ cookies: CookieTypes | None = None,
+ auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT,
+ follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT,
+ timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
+ extensions: Mapping[str, Any] | None = None,
+ ) -> Response:
+ """Sends a HEAD request.
+
+ Args:
+ url: URL or path for the request.
+ params: Query parameters.
+ headers: Request headers.
+ cookies: Request cookies.
+ auth: Auth headers.
+ follow_redirects: Whether to follow redirects.
+ timeout: Request timeout.
+ extensions: Dictionary of ASGI extensions.
+
+ Returns:
+ An HTTPX Response.
+ """
+ return Client.head(
+ self,
+ url,
+ params=params,
+ headers=headers,
+ cookies=cookies,
+ auth=auth,
+ follow_redirects=follow_redirects,
+ timeout=timeout,
+ extensions=None if extensions is None else dict(extensions),
+ )
+
+ def post(
+ self,
+ url: URLTypes,
+ *,
+ content: RequestContent | None = None,
+ data: RequestData | None = None,
+ files: RequestFiles | None = None,
+ json: Any | None = None,
+ params: QueryParamTypes | None = None,
+ headers: HeaderTypes | None = None,
+ cookies: CookieTypes | None = None,
+ auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT,
+ follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT,
+ timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
+ extensions: Mapping[str, Any] | None = None,
+ ) -> Response:
+ """Sends a POST request.
+
+ Args:
+ url: URL or path for the request.
+ content: Request content.
+ data: Form encoded data.
+ files: Multipart files to send.
+ json: JSON data to send.
+ params: Query parameters.
+ headers: Request headers.
+ cookies: Request cookies.
+ auth: Auth headers.
+ follow_redirects: Whether to follow redirects.
+ timeout: Request timeout.
+ extensions: Dictionary of ASGI extensions.
+
+ Returns:
+ An HTTPX Response.
+ """
+ return Client.post(
+ self,
+ url,
+ content=content,
+ data=data,
+ files=files,
+ json=json,
+ params=params,
+ headers=headers,
+ cookies=cookies,
+ auth=auth,
+ follow_redirects=follow_redirects,
+ timeout=timeout,
+ extensions=None if extensions is None else dict(extensions),
+ )
+
+ def put(
+ self,
+ url: URLTypes,
+ *,
+ content: RequestContent | None = None,
+ data: RequestData | None = None,
+ files: RequestFiles | None = None,
+ json: Any | None = None,
+ params: QueryParamTypes | None = None,
+ headers: HeaderTypes | None = None,
+ cookies: CookieTypes | None = None,
+ auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT,
+ follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT,
+ timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
+ extensions: Mapping[str, Any] | None = None,
+ ) -> Response:
+ """Sends a PUT request.
+
+ Args:
+ url: URL or path for the request.
+ content: Request content.
+ data: Form encoded data.
+ files: Multipart files to send.
+ json: JSON data to send.
+ params: Query parameters.
+ headers: Request headers.
+ cookies: Request cookies.
+ auth: Auth headers.
+ follow_redirects: Whether to follow redirects.
+ timeout: Request timeout.
+ extensions: Dictionary of ASGI extensions.
+
+ Returns:
+ An HTTPX Response.
+ """
+ return Client.put(
+ self,
+ url,
+ content=content,
+ data=data,
+ files=files,
+ json=json,
+ params=params,
+ headers=headers,
+ cookies=cookies,
+ auth=auth,
+ follow_redirects=follow_redirects,
+ timeout=timeout,
+ extensions=None if extensions is None else dict(extensions),
+ )
+
+ def patch(
+ self,
+ url: URLTypes,
+ *,
+ content: RequestContent | None = None,
+ data: RequestData | None = None,
+ files: RequestFiles | None = None,
+ json: Any | None = None,
+ params: QueryParamTypes | None = None,
+ headers: HeaderTypes | None = None,
+ cookies: CookieTypes | None = None,
+ auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT,
+ follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT,
+ timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
+ extensions: Mapping[str, Any] | None = None,
+ ) -> Response:
+ """Sends a PATCH request.
+
+ Args:
+ url: URL or path for the request.
+ content: Request content.
+ data: Form encoded data.
+ files: Multipart files to send.
+ json: JSON data to send.
+ params: Query parameters.
+ headers: Request headers.
+ cookies: Request cookies.
+ auth: Auth headers.
+ follow_redirects: Whether to follow redirects.
+ timeout: Request timeout.
+ extensions: Dictionary of ASGI extensions.
+
+ Returns:
+ An HTTPX Response.
+ """
+ return Client.patch(
+ self,
+ url,
+ content=content,
+ data=data,
+ files=files,
+ json=json,
+ params=params,
+ headers=headers,
+ cookies=cookies,
+ auth=auth,
+ follow_redirects=follow_redirects,
+ timeout=timeout,
+ extensions=None if extensions is None else dict(extensions),
+ )
+
+ def delete(
+ self,
+ url: URLTypes,
+ *,
+ params: QueryParamTypes | None = None,
+ headers: HeaderTypes | None = None,
+ cookies: CookieTypes | None = None,
+ auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT,
+ follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT,
+ timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
+ extensions: Mapping[str, Any] | None = None,
+ ) -> Response:
+ """Sends a DELETE request.
+
+ Args:
+ url: URL or path for the request.
+ params: Query parameters.
+ headers: Request headers.
+ cookies: Request cookies.
+ auth: Auth headers.
+ follow_redirects: Whether to follow redirects.
+ timeout: Request timeout.
+ extensions: Dictionary of ASGI extensions.
+
+ Returns:
+ An HTTPX Response.
+ """
+ return Client.delete(
+ self,
+ url,
+ params=params,
+ headers=headers,
+ cookies=cookies,
+ auth=auth,
+ follow_redirects=follow_redirects,
+ timeout=timeout,
+ extensions=None if extensions is None else dict(extensions),
+ )
+
+ def websocket_connect(
+ self,
+ url: str,
+ subprotocols: Sequence[str] | None = None,
+ params: QueryParamTypes | None = None,
+ headers: HeaderTypes | None = None,
+ cookies: CookieTypes | None = None,
+ auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT,
+ follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT,
+ timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
+ extensions: Mapping[str, Any] | None = None,
+ ) -> WebSocketTestSession:
+ """Sends a GET request to establish a websocket connection.
+
+ Args:
+ url: Request URL.
+ subprotocols: Websocket subprotocols.
+ params: Query parameters.
+ headers: Request headers.
+ cookies: Request cookies.
+ auth: Auth headers.
+ follow_redirects: Whether to follow redirects.
+ timeout: Request timeout.
+ extensions: Dictionary of ASGI extensions.
+
+ Returns:
+ A `WebSocketTestSession <litestar.testing.WebSocketTestSession>` instance.
+ """
+ url = urljoin("ws://testserver", url)
+ default_headers: dict[str, str] = {}
+ default_headers.setdefault("connection", "upgrade")
+ default_headers.setdefault("sec-websocket-key", "testserver==")
+ default_headers.setdefault("sec-websocket-version", "13")
+ if subprotocols is not None:
+ default_headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols))
+ try:
+ Client.request(
+ self,
+ "GET",
+ url,
+ headers={**dict(headers or {}), **default_headers}, # type: ignore[misc]
+ params=params,
+ cookies=cookies,
+ auth=auth,
+ follow_redirects=follow_redirects,
+ timeout=timeout,
+ extensions=None if extensions is None else dict(extensions),
+ )
+ except ConnectionUpgradeExceptionError as exc:
+ return exc.session
+
+ raise RuntimeError("Expected WebSocket upgrade") # pragma: no cover
+
+ def set_session_data(self, data: dict[str, Any]) -> None:
+ """Set session data.
+
+ Args:
+ data: Session data
+
+ Returns:
+ None
+
+ Examples:
+ .. code-block:: python
+
+ from litestar import Litestar, get
+ from litestar.middleware.session.memory_backend import MemoryBackendConfig
+
+ session_config = MemoryBackendConfig()
+
+
+ @get(path="/test")
+ def get_session_data(request: Request) -> Dict[str, Any]:
+ return request.session
+
+
+ app = Litestar(
+ route_handlers=[get_session_data], middleware=[session_config.middleware]
+ )
+
+ with TestClient(app=app, session_config=session_config) as client:
+ client.set_session_data({"foo": "bar"})
+ assert client.get("/test").json() == {"foo": "bar"}
+
+ """
+ with self.portal() as portal:
+ portal.call(self._set_session_data, data)
+
+ def get_session_data(self) -> dict[str, Any]:
+ """Get session data.
+
+ Returns:
+ A dictionary containing session data.
+
+ Examples:
+ .. code-block:: python
+
+ from litestar import Litestar, post
+ from litestar.middleware.session.memory_backend import MemoryBackendConfig
+
+ session_config = MemoryBackendConfig()
+
+
+ @post(path="/test")
+ def set_session_data(request: Request) -> None:
+ request.session["foo"] == "bar"
+
+
+ app = Litestar(
+ route_handlers=[set_session_data], middleware=[session_config.middleware]
+ )
+
+ with TestClient(app=app, session_config=session_config) as client:
+ client.post("/test")
+ assert client.get_session_data() == {"foo": "bar"}
+
+ """
+ with self.portal() as portal:
+ return portal.call(self._get_session_data)