from __future__ import annotations

from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Any, Generic, Mapping, TypeVar

from httpx import USE_CLIENT_DEFAULT, AsyncClient, Response

from litestar import HttpMethod
from litestar.testing.client.base import BaseTestClient
from litestar.testing.life_span_handler import LifeSpanHandler
from litestar.testing.transport import TestClientTransport
from litestar.types import AnyIOBackend, ASGIApp

if TYPE_CHECKING:
    from httpx._client import UseClientDefault
    from httpx._types import (
        AuthTypes,
        CookieTypes,
        HeaderTypes,
        QueryParamTypes,
        RequestContent,
        RequestData,
        RequestFiles,
        TimeoutTypes,
        URLTypes,
    )
    from typing_extensions import Self

    from litestar.middleware.session.base import BaseBackendConfig


T = TypeVar("T", bound=ASGIApp)


def _extensions_to_dict(extensions: Mapping[str, Any] | None) -> dict[str, Any] | None:
    """Copy request ``extensions`` into a plain ``dict`` for httpx, preserving ``None``.

    Args:
        extensions: Optional mapping of request extensions supplied by the caller.

    Returns:
        ``None`` when ``extensions`` is ``None``, otherwise a new ``dict`` with the same items.
    """
    return None if extensions is None else dict(extensions)


class AsyncTestClient(AsyncClient, BaseTestClient, Generic[T]):  # type: ignore[misc]
    # Assigned in ``__aenter__``.
    lifespan_handler: LifeSpanHandler[Any]
    # Holds the cleanup callbacks registered in ``__aenter__``; closed in ``__aexit__``.
    exit_stack: AsyncExitStack

    def __init__(
        self,
        app: T,
        base_url: str = "http://testserver.local",
        raise_server_exceptions: bool = True,
        root_path: str = "",
        backend: AnyIOBackend = "asyncio",
        backend_options: Mapping[str, Any] | None = None,
        session_config: BaseBackendConfig | None = None,
        timeout: float | None = None,
        cookies: CookieTypes | None = None,
    ) -> None:
        """An Async client implementation providing a context manager for testing applications asynchronously.

        Args:
            app: The instance of :class:`Litestar <litestar.app.Litestar>` under test.
            base_url: URL scheme and domain for test request paths, e.g. ``http://testserver.local``.
            raise_server_exceptions: Flag for the underlying test client to raise server exceptions instead of
                wrapping them in an HTTP response.
            root_path: Path prefix for requests.
            backend: The async backend to use, options are "asyncio" or "trio".
            backend_options: Options passed to the ``anyio`` backend.
            session_config: Configuration for Session Middleware class to create raw session cookies for request to the
                route handlers.
            timeout: Default request timeout, forwarded to :class:`httpx.AsyncClient`.
            cookies: Cookies to set on the client.
        """
        BaseTestClient.__init__(
            self,
            app=app,
            base_url=base_url,
            backend=backend,
            backend_options=backend_options,
            session_config=session_config,
            cookies=cookies,
        )
        AsyncClient.__init__(
            self,
            base_url=base_url,
            headers={"user-agent": "testclient"},
            follow_redirects=True,
            cookies=cookies,
            # Requests are dispatched in-process to the app instead of over the network.
            transport=TestClientTransport(  # type: ignore [arg-type]
                client=self,
                raise_server_exceptions=raise_server_exceptions,
                root_path=root_path,
            ),
            timeout=timeout,
        )

    async def __aenter__(self) -> Self:
        """Enter the portal context, create the lifespan handler and register the teardown callbacks.

        Returns:
            This client instance.
        """
        async with AsyncExitStack() as stack:
            self.blocking_portal = portal = stack.enter_context(self.portal())
            self.lifespan_handler = LifeSpanHandler(client=self)

            # Exit-stack callbacks run in LIFO order: ``wait_shutdown`` runs first (while the
            # portal is still open), then ``reset_portal``, then the portal context itself exits.
            @stack.callback
            def reset_portal() -> None:
                delattr(self, "blocking_portal")

            @stack.callback
            def wait_shutdown() -> None:
                portal.call(self.lifespan_handler.wait_shutdown)

            # Move all registered cleanups onto a fresh stack so leaving this ``async with``
            # does not run them; they run later in ``__aexit__``. If anything above raises,
            # they have not been popped yet and are cleaned up immediately.
            self.exit_stack = stack.pop_all()

        return self

    async def __aexit__(self, *args: Any) -> None:
        """Run the cleanup callbacks registered in ``__aenter__``."""
        await self.exit_stack.aclose()

    async def request(
        self,
        method: str | HttpMethod,
        url: URLTypes,
        *,
        content: RequestContent | None = None,
        data: RequestData | None = None,
        files: RequestFiles | None = None,
        json: Any | None = None,
        params: QueryParamTypes | None = None,
        headers: HeaderTypes | None = None,
        cookies: CookieTypes | None = None,
        auth: AuthTypes | UseClientDefault | None = USE_CLIENT_DEFAULT,
        follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT,
        timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
        extensions: Mapping[str, Any] | None = None,
    ) -> Response:
        """Sends a request.

        Args:
            method: An HTTP method, either as a string or as an :class:`HttpMethod` member.
            url: URL or path for the request.
            content: Request content.
            data: Form encoded data.
            files: Multipart files to send.
            json: JSON data to send.
            params: Query parameters.
            headers: Request headers.
            cookies: Request cookies.
            auth: Auth headers.
            follow_redirects: Whether to follow redirects.
            timeout: Request timeout.
            extensions: Dictionary of ASGI extensions.

        Returns:
            An HTTPX Response.
        """
        return await AsyncClient.request(
            self,
            url=self.base_url.join(url),
            # httpx expects a plain method string, so unwrap enum members.
            method=method.value if isinstance(method, HttpMethod) else method,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=_extensions_to_dict(extensions),
        )

    async def get(  # type: ignore [override]
        self,
        url: URLTypes,
        *,
        params: QueryParamTypes | None = None,
        headers: HeaderTypes | None = None,
        cookies: CookieTypes | None = None,
        auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT,
        follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT,
        timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
        extensions: Mapping[str, Any] | None = None,
    ) -> Response:
        """Sends a GET request.

        Args:
            url: URL or path for the request.
            params: Query parameters.
            headers: Request headers.
            cookies: Request cookies.
            auth: Auth headers.
            follow_redirects: Whether to follow redirects.
            timeout: Request timeout.
            extensions: Dictionary of ASGI extensions.

        Returns:
            An HTTPX Response.
        """
        return await AsyncClient.get(
            self,
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=_extensions_to_dict(extensions),
        )

    async def options(
        self,
        url: URLTypes,
        *,
        params: QueryParamTypes | None = None,
        headers: HeaderTypes | None = None,
        cookies: CookieTypes | None = None,
        auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT,
        follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT,
        timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
        extensions: Mapping[str, Any] | None = None,
    ) -> Response:
        """Sends an OPTIONS request.

        Args:
            url: URL or path for the request.
            params: Query parameters.
            headers: Request headers.
            cookies: Request cookies.
            auth: Auth headers.
            follow_redirects: Whether to follow redirects.
            timeout: Request timeout.
            extensions: Dictionary of ASGI extensions.

        Returns:
            An HTTPX Response.
        """
        return await AsyncClient.options(
            self,
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=_extensions_to_dict(extensions),
        )

    async def head(
        self,
        url: URLTypes,
        *,
        params: QueryParamTypes | None = None,
        headers: HeaderTypes | None = None,
        cookies: CookieTypes | None = None,
        auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT,
        follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT,
        timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
        extensions: Mapping[str, Any] | None = None,
    ) -> Response:
        """Sends a HEAD request.

        Args:
            url: URL or path for the request.
            params: Query parameters.
            headers: Request headers.
            cookies: Request cookies.
            auth: Auth headers.
            follow_redirects: Whether to follow redirects.
            timeout: Request timeout.
            extensions: Dictionary of ASGI extensions.

        Returns:
            An HTTPX Response.
        """
        return await AsyncClient.head(
            self,
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=_extensions_to_dict(extensions),
        )

    async def post(
        self,
        url: URLTypes,
        *,
        content: RequestContent | None = None,
        data: RequestData | None = None,
        files: RequestFiles | None = None,
        json: Any | None = None,
        params: QueryParamTypes | None = None,
        headers: HeaderTypes | None = None,
        cookies: CookieTypes | None = None,
        auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT,
        follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT,
        timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
        extensions: Mapping[str, Any] | None = None,
    ) -> Response:
        """Sends a POST request.

        Args:
            url: URL or path for the request.
            content: Request content.
            data: Form encoded data.
            files: Multipart files to send.
            json: JSON data to send.
            params: Query parameters.
            headers: Request headers.
            cookies: Request cookies.
            auth: Auth headers.
            follow_redirects: Whether to follow redirects.
            timeout: Request timeout.
            extensions: Dictionary of ASGI extensions.

        Returns:
            An HTTPX Response.
        """
        return await AsyncClient.post(
            self,
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=_extensions_to_dict(extensions),
        )

    async def put(
        self,
        url: URLTypes,
        *,
        content: RequestContent | None = None,
        data: RequestData | None = None,
        files: RequestFiles | None = None,
        json: Any | None = None,
        params: QueryParamTypes | None = None,
        headers: HeaderTypes | None = None,
        cookies: CookieTypes | None = None,
        auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT,
        follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT,
        timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
        extensions: Mapping[str, Any] | None = None,
    ) -> Response:
        """Sends a PUT request.

        Args:
            url: URL or path for the request.
            content: Request content.
            data: Form encoded data.
            files: Multipart files to send.
            json: JSON data to send.
            params: Query parameters.
            headers: Request headers.
            cookies: Request cookies.
            auth: Auth headers.
            follow_redirects: Whether to follow redirects.
            timeout: Request timeout.
            extensions: Dictionary of ASGI extensions.

        Returns:
            An HTTPX Response.
        """
        return await AsyncClient.put(
            self,
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=_extensions_to_dict(extensions),
        )

    async def patch(
        self,
        url: URLTypes,
        *,
        content: RequestContent | None = None,
        data: RequestData | None = None,
        files: RequestFiles | None = None,
        json: Any | None = None,
        params: QueryParamTypes | None = None,
        headers: HeaderTypes | None = None,
        cookies: CookieTypes | None = None,
        auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT,
        follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT,
        timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
        extensions: Mapping[str, Any] | None = None,
    ) -> Response:
        """Sends a PATCH request.

        Args:
            url: URL or path for the request.
            content: Request content.
            data: Form encoded data.
            files: Multipart files to send.
            json: JSON data to send.
            params: Query parameters.
            headers: Request headers.
            cookies: Request cookies.
            auth: Auth headers.
            follow_redirects: Whether to follow redirects.
            timeout: Request timeout.
            extensions: Dictionary of ASGI extensions.

        Returns:
            An HTTPX Response.
        """
        return await AsyncClient.patch(
            self,
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=_extensions_to_dict(extensions),
        )

    async def delete(
        self,
        url: URLTypes,
        *,
        params: QueryParamTypes | None = None,
        headers: HeaderTypes | None = None,
        cookies: CookieTypes | None = None,
        auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT,
        follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT,
        timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
        extensions: Mapping[str, Any] | None = None,
    ) -> Response:
        """Sends a DELETE request.

        Args:
            url: URL or path for the request.
            params: Query parameters.
            headers: Request headers.
            cookies: Request cookies.
            auth: Auth headers.
            follow_redirects: Whether to follow redirects.
            timeout: Request timeout.
            extensions: Dictionary of ASGI extensions.

        Returns:
            An HTTPX Response.
        """
        return await AsyncClient.delete(
            self,
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=follow_redirects,
            timeout=timeout,
            extensions=_extensions_to_dict(extensions),
        )

    async def get_session_data(self) -> dict[str, Any]:
        """Get session data.

        Returns:
            A dictionary containing session data.

        Examples:
            .. code-block:: python

                from litestar import Litestar, Request, post
                from litestar.middleware.session.memory_backend import MemoryBackendConfig

                session_config = MemoryBackendConfig()


                @post(path="/test")
                def set_session_data(request: Request) -> None:
                    request.session["foo"] = "bar"


                app = Litestar(
                    route_handlers=[set_session_data], middleware=[session_config.middleware]
                )

                async with AsyncTestClient(app=app, session_config=session_config) as client:
                    await client.post("/test")
                    assert await client.get_session_data() == {"foo": "bar"}

        """
        return await super()._get_session_data()

    async def set_session_data(self, data: dict[str, Any]) -> None:
        """Set session data.

        Args:
            data: Session data

        Returns:
            None

        Examples:
            .. code-block:: python

                from typing import Any, Dict

                from litestar import Litestar, Request, get
                from litestar.middleware.session.memory_backend import MemoryBackendConfig

                session_config = MemoryBackendConfig()


                @get(path="/test")
                def get_session_data(request: Request) -> Dict[str, Any]:
                    return request.session


                app = Litestar(
                    route_handlers=[get_session_data], middleware=[session_config.middleware]
                )

                async with AsyncTestClient(app=app, session_config=session_config) as client:
                    await client.set_session_data({"foo": "bar"})
                    response = await client.get("/test")
                    assert response.json() == {"foo": "bar"}

        """
        return await super()._set_session_data(data)