from __future__ import annotations

from contextlib import ExitStack
from typing import TYPE_CHECKING, Any, Generic, Mapping, Sequence, TypeVar
from urllib.parse import urljoin

from httpx import USE_CLIENT_DEFAULT, Client, Response

from litestar import HttpMethod
from litestar.testing.client.base import BaseTestClient
from litestar.testing.life_span_handler import LifeSpanHandler
from litestar.testing.transport import ConnectionUpgradeExceptionError, TestClientTransport
from litestar.types import AnyIOBackend, ASGIApp

if TYPE_CHECKING:
    from httpx._client import UseClientDefault
    from httpx._types import (
        AuthTypes,
        CookieTypes,
        HeaderTypes,
        QueryParamTypes,
        RequestContent,
        RequestData,
        RequestFiles,
        TimeoutTypes,
        URLTypes,
    )
    from typing_extensions import Self

    from litestar.middleware.session.base import BaseBackendConfig
    from litestar.testing.websocket_test_session import WebSocketTestSession


T = TypeVar("T", bound=ASGIApp)


class TestClient(Client, BaseTestClient, Generic[T]):  # type: ignore[misc]
    """Synchronous test client for an ASGI application.

    Combines :class:`httpx.Client` with ``BaseTestClient`` and sends every request through a
    ``TestClientTransport`` bound to this client instead of a real network connection.
    """

    # Created in ``__enter__``; used on exit to wait for application shutdown.
    lifespan_handler: LifeSpanHandler[Any]
    # Holds the cleanup callbacks registered in ``__enter__``; closed in ``__exit__``.
    exit_stack: ExitStack

    def __init__(
        self,
        app: T,
        base_url: str = "http://testserver.local",
        raise_server_exceptions: bool = True,
        root_path: str = "",
        backend: AnyIOBackend = "asyncio",
        backend_options: Mapping[str, Any] | None = None,
        session_config: BaseBackendConfig | None = None,
        timeout: float | None = None,
        cookies: CookieTypes | None = None,
    ) -> None:
        """A client implementation providing a context manager for testing applications.

        Args:
            app: The instance of :class:`Litestar <litestar.app.Litestar>` under test.
            base_url: URL scheme and domain for test request paths, e.g. ``http://testserver``.
            raise_server_exceptions: Flag for the underlying test client to raise server exceptions instead of
                wrapping them in an HTTP response.
            root_path: Path prefix for requests.
            backend: The async backend to use, options are "asyncio" or "trio".
            backend_options: ``anyio`` options.
            session_config: Configuration for Session Middleware class to create raw session cookies for request to
                the route handlers.
            timeout: Request timeout, forwarded to the underlying ``httpx`` client.
            cookies: Cookies to set on the client.
        """
        # Both bases are initialised explicitly because they take different argument sets.
        BaseTestClient.__init__(
            self,
            app=app,
            base_url=base_url,
            backend=backend,
            backend_options=backend_options,
            session_config=session_config,
            cookies=cookies,
        )
        Client.__init__(
            self,
            base_url=base_url,
            headers={"user-agent": "testclient"},
            follow_redirects=True,
            cookies=cookies,
            transport=TestClientTransport(  # type: ignore[arg-type]
                client=self,
                raise_server_exceptions=raise_server_exceptions,
                root_path=root_path,
            ),
            timeout=timeout,
        )

    def __enter__(self) -> Self:
        """Open a blocking portal, create the lifespan handler and register teardown callbacks.

        Returns:
            The client itself.
        """
        # The temporary stack guarantees that anything already entered is unwound if setup fails
        # before ``pop_all()`` hands ownership over to ``self.exit_stack``.
        with ExitStack() as stack:
            self.blocking_portal = portal = stack.enter_context(self.portal())
            self.lifespan_handler = LifeSpanHandler(client=self)

            # ExitStack unwinds in LIFO order: ``wait_shutdown`` runs first (while the portal is
            # still usable), then ``reset_portal``, and finally the portal context itself exits.
            @stack.callback
            def reset_portal() -> None:
                delattr(self, "blocking_portal")

            @stack.callback
            def wait_shutdown() -> None:
                portal.call(self.lifespan_handler.wait_shutdown)

            self.exit_stack = stack.pop_all()

        return self

    def __exit__(self, *args: Any) -> None:
        """Run the teardown callbacks registered in ``__enter__``.

        Args:
            *args: Exception details supplied by the ``with`` statement; not used.
        """
        self.exit_stack.close()

    def request(
        self,
        method: str,
        url: URLTypes,
        *,
        content: RequestContent | None = None,
        data: RequestData | None = None,
        files: RequestFiles | None = None,
        json: Any | None = None,
        params: QueryParamTypes | None = None,
        headers: HeaderTypes | None = None,
        cookies: CookieTypes | None = None,
        auth: AuthTypes | UseClientDefault | None = USE_CLIENT_DEFAULT,
        follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT,
        timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
        extensions: Mapping[str, Any] | None = None,
    ) -> Response:
        """Sends a request.

        Args:
            method: An HTTP method. An ``HttpMethod`` enum member is also accepted and is converted
                to its string value.
            url: URL or path for the request. It is joined onto the client's ``base_url``.
            content: Request content.
            data: Form encoded data.
            files: Multipart files to send.
            json: JSON data to send.
            params: Query parameters.
            headers: Request headers.
            cookies: Request cookies.
            auth: Auth headers.
            follow_redirects: Whether to follow redirects.
            timeout: Request timeout.
            extensions: Dictionary of ASGI extensions.

        Returns:
            An HTTPX Response.
        """
        return Client.request(
            self,
            url=self.base_url.join(url),
            method=method.value if isinstance(method, HttpMethod) else method,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            # Any mapping is accepted here; it is copied into a plain ``dict`` before forwarding.
            extensions=None if extensions is None else dict(extensions),
        )

    def get(
        self,
        url: URLTypes,
        *,
        params: QueryParamTypes | None = None,
        headers: HeaderTypes | None = None,
        cookies: CookieTypes | None = None,
        auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT,
        follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT,
        timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
        extensions: Mapping[str, Any] | None = None,
    ) -> Response:
        """Sends a GET request.

        Args:
            url: URL or path for the request.
            params: Query parameters.
            headers: Request headers.
            cookies: Request cookies.
            auth: Auth headers.
            follow_redirects: Whether to follow redirects.
            timeout: Request timeout.
            extensions: Dictionary of ASGI extensions.

        Returns:
            An HTTPX Response.
        """
        return Client.get(
            self,
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=None if extensions is None else dict(extensions),
        )

    def options(
        self,
        url: URLTypes,
        *,
        params: QueryParamTypes | None = None,
        headers: HeaderTypes | None = None,
        cookies: CookieTypes | None = None,
        auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT,
        follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT,
        timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
        extensions: Mapping[str, Any] | None = None,
    ) -> Response:
        """Sends an OPTIONS request.

        Args:
            url: URL or path for the request.
            params: Query parameters.
            headers: Request headers.
            cookies: Request cookies.
            auth: Auth headers.
            follow_redirects: Whether to follow redirects.
            timeout: Request timeout.
            extensions: Dictionary of ASGI extensions.

        Returns:
            An HTTPX Response.
        """
        return Client.options(
            self,
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=None if extensions is None else dict(extensions),
        )

    def head(
        self,
        url: URLTypes,
        *,
        params: QueryParamTypes | None = None,
        headers: HeaderTypes | None = None,
        cookies: CookieTypes | None = None,
        auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT,
        follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT,
        timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
        extensions: Mapping[str, Any] | None = None,
    ) -> Response:
        """Sends a HEAD request.

        Args:
            url: URL or path for the request.
            params: Query parameters.
            headers: Request headers.
            cookies: Request cookies.
            auth: Auth headers.
            follow_redirects: Whether to follow redirects.
            timeout: Request timeout.
            extensions: Dictionary of ASGI extensions.

        Returns:
            An HTTPX Response.
        """
        return Client.head(
            self,
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=None if extensions is None else dict(extensions),
        )

    def post(
        self,
        url: URLTypes,
        *,
        content: RequestContent | None = None,
        data: RequestData | None = None,
        files: RequestFiles | None = None,
        json: Any | None = None,
        params: QueryParamTypes | None = None,
        headers: HeaderTypes | None = None,
        cookies: CookieTypes | None = None,
        auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT,
        follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT,
        timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
        extensions: Mapping[str, Any] | None = None,
    ) -> Response:
        """Sends a POST request.

        Args:
            url: URL or path for the request.
            content: Request content.
            data: Form encoded data.
            files: Multipart files to send.
            json: JSON data to send.
            params: Query parameters.
            headers: Request headers.
            cookies: Request cookies.
            auth: Auth headers.
            follow_redirects: Whether to follow redirects.
            timeout: Request timeout.
            extensions: Dictionary of ASGI extensions.

        Returns:
            An HTTPX Response.
        """
        return Client.post(
            self,
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=None if extensions is None else dict(extensions),
        )

    def put(
        self,
        url: URLTypes,
        *,
        content: RequestContent | None = None,
        data: RequestData | None = None,
        files: RequestFiles | None = None,
        json: Any | None = None,
        params: QueryParamTypes | None = None,
        headers: HeaderTypes | None = None,
        cookies: CookieTypes | None = None,
        auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT,
        follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT,
        timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
        extensions: Mapping[str, Any] | None = None,
    ) -> Response:
        """Sends a PUT request.

        Args:
            url: URL or path for the request.
            content: Request content.
            data: Form encoded data.
            files: Multipart files to send.
            json: JSON data to send.
            params: Query parameters.
            headers: Request headers.
            cookies: Request cookies.
            auth: Auth headers.
            follow_redirects: Whether to follow redirects.
            timeout: Request timeout.
            extensions: Dictionary of ASGI extensions.

        Returns:
            An HTTPX Response.
        """
        return Client.put(
            self,
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=None if extensions is None else dict(extensions),
        )

    def patch(
        self,
        url: URLTypes,
        *,
        content: RequestContent | None = None,
        data: RequestData | None = None,
        files: RequestFiles | None = None,
        json: Any | None = None,
        params: QueryParamTypes | None = None,
        headers: HeaderTypes | None = None,
        cookies: CookieTypes | None = None,
        auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT,
        follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT,
        timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
        extensions: Mapping[str, Any] | None = None,
    ) -> Response:
        """Sends a PATCH request.

        Args:
            url: URL or path for the request.
            content: Request content.
            data: Form encoded data.
            files: Multipart files to send.
            json: JSON data to send.
            params: Query parameters.
            headers: Request headers.
            cookies: Request cookies.
            auth: Auth headers.
            follow_redirects: Whether to follow redirects.
            timeout: Request timeout.
            extensions: Dictionary of ASGI extensions.

        Returns:
            An HTTPX Response.
        """
        return Client.patch(
            self,
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=None if extensions is None else dict(extensions),
        )

    def delete(
        self,
        url: URLTypes,
        *,
        params: QueryParamTypes | None = None,
        headers: HeaderTypes | None = None,
        cookies: CookieTypes | None = None,
        auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT,
        follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT,
        timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
        extensions: Mapping[str, Any] | None = None,
    ) -> Response:
        """Sends a DELETE request.

        Args:
            url: URL or path for the request.
            params: Query parameters.
            headers: Request headers.
            cookies: Request cookies.
            auth: Auth headers.
            follow_redirects: Whether to follow redirects.
            timeout: Request timeout.
            extensions: Dictionary of ASGI extensions.

        Returns:
            An HTTPX Response.
        """
        return Client.delete(
            self,
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=None if extensions is None else dict(extensions),
        )

    def websocket_connect(
        self,
        url: str,
        subprotocols: Sequence[str] | None = None,
        params: QueryParamTypes | None = None,
        headers: HeaderTypes | None = None,
        cookies: CookieTypes | None = None,
        auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT,
        follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT,
        timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
        extensions: Mapping[str, Any] | None = None,
    ) -> WebSocketTestSession:
        """Sends a GET request to establish a websocket connection.

        Args:
            url: Request URL. It is resolved against the fixed origin ``ws://testserver``.
            subprotocols: Websocket subprotocols.
            params: Query parameters.
            headers: Request headers. The handshake headers set by this method replace any supplied
                header that uses the same key.
            cookies: Request cookies.
            auth: Auth headers.
            follow_redirects: Whether to follow redirects.
            timeout: Request timeout.
            extensions: Dictionary of ASGI extensions.

        Returns:
            A :class:`WebSocketTestSession <litestar.testing.WebSocketTestSession>` instance.

        Raises:
            RuntimeError: If the request completes without ``ConnectionUpgradeExceptionError`` being raised.
        """
        url = urljoin("ws://testserver", url)
        default_headers: dict[str, str] = {
            "connection": "upgrade",
            "sec-websocket-key": "testserver==",
            "sec-websocket-version": "13",
        }
        if subprotocols is not None:
            default_headers["sec-websocket-protocol"] = ", ".join(subprotocols)

        try:
            Client.request(
                self,
                "GET",
                url,
                # ``default_headers`` is unpacked last, so its entries win over same-key caller headers.
                headers={**dict(headers or {}), **default_headers},  # type: ignore[misc]
                params=params,
                cookies=cookies,
                auth=auth,
                follow_redirects=follow_redirects,
                timeout=timeout,
                extensions=None if extensions is None else dict(extensions),
            )
        except ConnectionUpgradeExceptionError as exc:
            # The session is delivered on the exception rather than as a return value of the request.
            return exc.session

        raise RuntimeError("Expected WebSocket upgrade")  # pragma: no cover

    def set_session_data(self, data: dict[str, Any]) -> None:
        """Set session data.

        Args:
            data: Session data

        Returns:
            None

        Examples:
            .. code-block:: python

                from typing import Any, Dict

                from litestar import Litestar, Request, get
                from litestar.middleware.session.memory_backend import MemoryBackendConfig

                session_config = MemoryBackendConfig()


                @get(path="/test")
                def get_session_data(request: Request) -> Dict[str, Any]:
                    return request.session


                app = Litestar(
                    route_handlers=[get_session_data], middleware=[session_config.middleware]
                )

                with TestClient(app=app, session_config=session_config) as client:
                    client.set_session_data({"foo": "bar"})
                    assert client.get("/test").json() == {"foo": "bar"}

        """
        # ``_set_session_data`` is invoked through a portal obtained from ``self.portal()``.
        with self.portal() as portal:
            portal.call(self._set_session_data, data)

    def get_session_data(self) -> dict[str, Any]:
        """Get session data.

        Returns:
            A dictionary containing session data.

        Examples:
            .. code-block:: python

                from litestar import Litestar, Request, post
                from litestar.middleware.session.memory_backend import MemoryBackendConfig

                session_config = MemoryBackendConfig()


                @post(path="/test")
                def set_session_data(request: Request) -> None:
                    request.session["foo"] = "bar"


                app = Litestar(
                    route_handlers=[set_session_data], middleware=[session_config.middleware]
                )

                with TestClient(app=app, session_config=session_config) as client:
                    client.post("/test")
                    assert client.get_session_data() == {"foo": "bar"}

        """
        # ``_get_session_data`` is invoked through a portal obtained from ``self.portal()``.
        with self.portal() as portal:
            return portal.call(self._get_session_data)