From 6d7ba58f880be618ade07f8ea080fe8c4bf8a896 Mon Sep 17 00:00:00 2001 From: cyfraeviolae Date: Wed, 3 Apr 2024 03:10:44 -0400 Subject: venv --- .../litestar/testing/client/sync_client.py | 593 +++++++++++++++++++++ 1 file changed, 593 insertions(+) create mode 100644 venv/lib/python3.11/site-packages/litestar/testing/client/sync_client.py (limited to 'venv/lib/python3.11/site-packages/litestar/testing/client/sync_client.py') diff --git a/venv/lib/python3.11/site-packages/litestar/testing/client/sync_client.py b/venv/lib/python3.11/site-packages/litestar/testing/client/sync_client.py new file mode 100644 index 0000000..d907056 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/client/sync_client.py @@ -0,0 +1,593 @@ +from __future__ import annotations + +from contextlib import ExitStack +from typing import TYPE_CHECKING, Any, Generic, Mapping, Sequence, TypeVar +from urllib.parse import urljoin + +from httpx import USE_CLIENT_DEFAULT, Client, Response + +from litestar import HttpMethod +from litestar.testing.client.base import BaseTestClient +from litestar.testing.life_span_handler import LifeSpanHandler +from litestar.testing.transport import ConnectionUpgradeExceptionError, TestClientTransport +from litestar.types import AnyIOBackend, ASGIApp + +if TYPE_CHECKING: + from httpx._client import UseClientDefault + from httpx._types import ( + AuthTypes, + CookieTypes, + HeaderTypes, + QueryParamTypes, + RequestContent, + RequestData, + RequestFiles, + TimeoutTypes, + URLTypes, + ) + from typing_extensions import Self + + from litestar.middleware.session.base import BaseBackendConfig + from litestar.testing.websocket_test_session import WebSocketTestSession + + +T = TypeVar("T", bound=ASGIApp) + + +class TestClient(Client, BaseTestClient, Generic[T]): # type: ignore[misc] + lifespan_handler: LifeSpanHandler[Any] + exit_stack: ExitStack + + def __init__( + self, + app: T, + base_url: str = "http://testserver.local", + raise_server_exceptions: bool = True, + root_path: str = "", + backend: AnyIOBackend = "asyncio", + backend_options: Mapping[str, Any] | None = None, + session_config: BaseBackendConfig | None = None, + timeout: float | None = None, + cookies: CookieTypes | None = None, + ) -> None: + """A client implementation providing a context manager for testing applications. + + Args: + app: The instance of :class:`Litestar ` under test. + base_url: URL scheme and domain for test request paths, e.g. ``http://testserver``. + raise_server_exceptions: Flag for the underlying test client to raise server exceptions instead of + wrapping them in an HTTP response. + root_path: Path prefix for requests. + backend: The async backend to use, options are "asyncio" or "trio". + backend_options: ``anyio`` options. + session_config: Configuration for Session Middleware class to create raw session cookies for request to the + route handlers. + timeout: Request timeout + cookies: Cookies to set on the client. + """ + BaseTestClient.__init__( + self, + app=app, + base_url=base_url, + backend=backend, + backend_options=backend_options, + session_config=session_config, + cookies=cookies, + ) + + Client.__init__( + self, + base_url=base_url, + headers={"user-agent": "testclient"}, + follow_redirects=True, + cookies=cookies, + transport=TestClientTransport( # type: ignore[arg-type] + client=self, + raise_server_exceptions=raise_server_exceptions, + root_path=root_path, + ), + timeout=timeout, + ) + + def __enter__(self) -> Self: + with ExitStack() as stack: + self.blocking_portal = portal = stack.enter_context(self.portal()) + self.lifespan_handler = LifeSpanHandler(client=self) + + @stack.callback + def reset_portal() -> None: + delattr(self, "blocking_portal") + + @stack.callback + def wait_shutdown() -> None: + portal.call(self.lifespan_handler.wait_shutdown) + + self.exit_stack = stack.pop_all() + + return self + + def __exit__(self, *args: Any) -> None: + self.exit_stack.close() + + def request( + self, + method: str, + url: URLTypes, + *, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: Any | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault | None = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a request. + + Args: + method: An HTTP method. + url: URL or path for the request. + content: Request content. + data: Form encoded data. + files: Multipart files to send. + json: JSON data to send. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return Client.request( + self, + url=self.base_url.join(url), + method=method.value if isinstance(method, HttpMethod) else method, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + def get( + self, + url: URLTypes, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a GET request. + + Args: + url: URL or path for the request. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return Client.get( + self, + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + def options( + self, + url: URLTypes, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends an OPTIONS request. + + Args: + url: URL or path for the request. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return Client.options( + self, + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + def head( + self, + url: URLTypes, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a HEAD request. + + Args: + url: URL or path for the request. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return Client.head( + self, + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + def post( + self, + url: URLTypes, + *, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: Any | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a POST request. + + Args: + url: URL or path for the request. + content: Request content. + data: Form encoded data. + files: Multipart files to send. + json: JSON data to send. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return Client.post( + self, + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + def put( + self, + url: URLTypes, + *, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: Any | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a PUT request. + + Args: + url: URL or path for the request. + content: Request content. + data: Form encoded data. + files: Multipart files to send. + json: JSON data to send. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return Client.put( + self, + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + def patch( + self, + url: URLTypes, + *, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: Any | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a PATCH request. + + Args: + url: URL or path for the request. + content: Request content. + data: Form encoded data. + files: Multipart files to send. + json: JSON data to send. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return Client.patch( + self, + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + def delete( + self, + url: URLTypes, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a DELETE request. + + Args: + url: URL or path for the request. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return Client.delete( + self, + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + def websocket_connect( + self, + url: str, + subprotocols: Sequence[str] | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> WebSocketTestSession: + """Sends a GET request to establish a websocket connection. + + Args: + url: Request URL. + subprotocols: Websocket subprotocols. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + A `WebSocketTestSession ` instance. + """ + url = urljoin("ws://testserver", url) + default_headers: dict[str, str] = {} + default_headers.setdefault("connection", "upgrade") + default_headers.setdefault("sec-websocket-key", "testserver==") + default_headers.setdefault("sec-websocket-version", "13") + if subprotocols is not None: + default_headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols)) + try: + Client.request( + self, + "GET", + url, + headers={**dict(headers or {}), **default_headers}, # type: ignore[misc] + params=params, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + except ConnectionUpgradeExceptionError as exc: + return exc.session + + raise RuntimeError("Expected WebSocket upgrade") # pragma: no cover + + def set_session_data(self, data: dict[str, Any]) -> None: + """Set session data. + + Args: + data: Session data + + Returns: + None + + Examples: + .. code-block:: python + + from litestar import Litestar, get + from litestar.middleware.session.memory_backend import MemoryBackendConfig + + session_config = MemoryBackendConfig() + + + @get(path="/test") + def get_session_data(request: Request) -> Dict[str, Any]: + return request.session + + + app = Litestar( + route_handlers=[get_session_data], middleware=[session_config.middleware] + ) + + with TestClient(app=app, session_config=session_config) as client: + client.set_session_data({"foo": "bar"}) + assert client.get("/test").json() == {"foo": "bar"} + + """ + with self.portal() as portal: + portal.call(self._set_session_data, data) + + def get_session_data(self) -> dict[str, Any]: + """Get session data. + + Returns: + A dictionary containing session data. + + Examples: + .. code-block:: python + + from litestar import Litestar, post + from litestar.middleware.session.memory_backend import MemoryBackendConfig + + session_config = MemoryBackendConfig() + + + @post(path="/test") + def set_session_data(request: Request) -> None: + request.session["foo"] == "bar" + + + app = Litestar( + route_handlers=[set_session_data], middleware=[session_config.middleware] + ) + + with TestClient(app=app, session_config=session_config) as client: + client.post("/test") + assert client.get_session_data() == {"foo": "bar"} + + """ + with self.portal() as portal: + return portal.call(self._get_session_data) -- cgit v1.2.3