diff options
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/testing/client')
-rw-r--r-- | venv/lib/python3.11/site-packages/litestar/testing/client/__init__.py | 36 | ||||
-rw-r--r-- | venv/lib/python3.11/site-packages/litestar/testing/client/__pycache__/__init__.cpython-311.pyc | bin | 0 -> 2091 bytes | |||
-rw-r--r-- | venv/lib/python3.11/site-packages/litestar/testing/client/__pycache__/async_client.cpython-311.pyc | bin | 0 -> 17343 bytes | |||
-rw-r--r-- | venv/lib/python3.11/site-packages/litestar/testing/client/__pycache__/base.cpython-311.pyc | bin | 0 -> 9397 bytes | |||
-rw-r--r-- | venv/lib/python3.11/site-packages/litestar/testing/client/__pycache__/sync_client.cpython-311.pyc | bin | 0 -> 19167 bytes | |||
-rw-r--r-- | venv/lib/python3.11/site-packages/litestar/testing/client/async_client.py | 534 | ||||
-rw-r--r-- | venv/lib/python3.11/site-packages/litestar/testing/client/base.py | 180 | ||||
-rw-r--r-- | venv/lib/python3.11/site-packages/litestar/testing/client/sync_client.py | 593 |
8 files changed, 1343 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/testing/client/__init__.py b/venv/lib/python3.11/site-packages/litestar/testing/client/__init__.py new file mode 100644 index 0000000..5d03a7a --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/client/__init__.py @@ -0,0 +1,36 @@ +"""Some code in this module was adapted from https://github.com/encode/starlette/blob/master/starlette/testclient.py. + +Copyright © 2018, [Encode OSS Ltd](https://www.encode.io/). +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" + +from .async_client import AsyncTestClient +from .base import BaseTestClient +from .sync_client import TestClient + +__all__ = ("TestClient", "AsyncTestClient", "BaseTestClient") diff --git a/venv/lib/python3.11/site-packages/litestar/testing/client/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/testing/client/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..18ad148 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/client/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/testing/client/__pycache__/async_client.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/testing/client/__pycache__/async_client.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..1ccc805 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/client/__pycache__/async_client.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/testing/client/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/testing/client/__pycache__/base.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..87d5de7 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/client/__pycache__/base.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/testing/client/__pycache__/sync_client.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/testing/client/__pycache__/sync_client.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..29f0576 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/client/__pycache__/sync_client.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/testing/client/async_client.py b/venv/lib/python3.11/site-packages/litestar/testing/client/async_client.py new file mode 100644 index 0000000..cf66f12 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/client/async_client.py @@ -0,0 +1,534 @@ +from __future__ import annotations + +from contextlib import AsyncExitStack +from typing import TYPE_CHECKING, Any, Generic, Mapping, TypeVar + +from httpx import USE_CLIENT_DEFAULT, AsyncClient, Response + +from litestar import HttpMethod +from litestar.testing.client.base import BaseTestClient +from litestar.testing.life_span_handler import LifeSpanHandler +from litestar.testing.transport import TestClientTransport +from litestar.types import AnyIOBackend, ASGIApp + +if TYPE_CHECKING: + from httpx._client import UseClientDefault + from httpx._types import ( + AuthTypes, + CookieTypes, + HeaderTypes, + QueryParamTypes, + RequestContent, + RequestData, + RequestFiles, + TimeoutTypes, + URLTypes, + ) + from typing_extensions import Self + + from litestar.middleware.session.base import BaseBackendConfig + + +T = TypeVar("T", bound=ASGIApp) + + +class AsyncTestClient(AsyncClient, BaseTestClient, Generic[T]): # type: ignore[misc] + lifespan_handler: LifeSpanHandler[Any] + exit_stack: AsyncExitStack + + def __init__( + self, + app: T, + base_url: str = "http://testserver.local", + raise_server_exceptions: bool = True, + root_path: str = "", + backend: AnyIOBackend = "asyncio", + backend_options: Mapping[str, Any] | None = None, + session_config: BaseBackendConfig | None = None, + timeout: float | None = None, + cookies: CookieTypes | None = None, + ) -> None: + """An Async client implementation providing a context manager for testing applications asynchronously. + + Args: + app: The instance of :class:`Litestar <litestar.app.Litestar>` under test. + base_url: URL scheme and domain for test request paths, e.g. ``http://testserver``. + raise_server_exceptions: Flag for the underlying test client to raise server exceptions instead of + wrapping them in an HTTP response. + root_path: Path prefix for requests. + backend: The async backend to use, options are "asyncio" or "trio". + backend_options: 'anyio' options. + session_config: Configuration for Session Middleware class to create raw session cookies for request to the + route handlers. + timeout: Request timeout + cookies: Cookies to set on the client. + """ + BaseTestClient.__init__( + self, + app=app, + base_url=base_url, + backend=backend, + backend_options=backend_options, + session_config=session_config, + cookies=cookies, + ) + AsyncClient.__init__( + self, + base_url=base_url, + headers={"user-agent": "testclient"}, + follow_redirects=True, + cookies=cookies, + transport=TestClientTransport( # type: ignore [arg-type] + client=self, + raise_server_exceptions=raise_server_exceptions, + root_path=root_path, + ), + timeout=timeout, + ) + + async def __aenter__(self) -> Self: + async with AsyncExitStack() as stack: + self.blocking_portal = portal = stack.enter_context(self.portal()) + self.lifespan_handler = LifeSpanHandler(client=self) + + @stack.callback + def reset_portal() -> None: + delattr(self, "blocking_portal") + + @stack.callback + def wait_shutdown() -> None: + portal.call(self.lifespan_handler.wait_shutdown) + + self.exit_stack = stack.pop_all() + return self + + async def __aexit__(self, *args: Any) -> None: + await self.exit_stack.aclose() + + async def request( + self, + method: str, + url: URLTypes, + *, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: Any | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault | None = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a request. + + Args: + method: An HTTP method. + url: URL or path for the request. + content: Request content. + data: Form encoded data. + files: Multipart files to send. + json: JSON data to send. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return await AsyncClient.request( + self, + url=self.base_url.join(url), + method=method.value if isinstance(method, HttpMethod) else method, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + async def get( # type: ignore [override] + self, + url: URLTypes, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a GET request. + + Args: + url: URL or path for the request. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return await AsyncClient.get( + self, + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + async def options( + self, + url: URLTypes, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends an OPTIONS request. + + Args: + url: URL or path for the request. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return await AsyncClient.options( + self, + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + async def head( + self, + url: URLTypes, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a HEAD request. + + Args: + url: URL or path for the request. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return await AsyncClient.head( + self, + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + async def post( + self, + url: URLTypes, + *, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: Any | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a POST request. + + Args: + url: URL or path for the request. + content: Request content. + data: Form encoded data. + files: Multipart files to send. + json: JSON data to send. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return await AsyncClient.post( + self, + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + async def put( + self, + url: URLTypes, + *, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: Any | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a PUT request. + + Args: + url: URL or path for the request. + content: Request content. + data: Form encoded data. + files: Multipart files to send. + json: JSON data to send. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return await AsyncClient.put( + self, + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + async def patch( + self, + url: URLTypes, + *, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: Any | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a PATCH request. + + Args: + url: URL or path for the request. + content: Request content. + data: Form encoded data. + files: Multipart files to send. + json: JSON data to send. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return await AsyncClient.patch( + self, + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + async def delete( + self, + url: URLTypes, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a DELETE request. + + Args: + url: URL or path for the request. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return await AsyncClient.delete( + self, + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + async def get_session_data(self) -> dict[str, Any]: + """Get session data. + + Returns: + A dictionary containing session data. + + Examples: + .. code-block:: python + + from litestar import Litestar, post + from litestar.middleware.session.memory_backend import MemoryBackendConfig + + session_config = MemoryBackendConfig() + + + @post(path="/test") + def set_session_data(request: Request) -> None: + request.session["foo"] == "bar" + + + app = Litestar( + route_handlers=[set_session_data], middleware=[session_config.middleware] + ) + + async with AsyncTestClient(app=app, session_config=session_config) as client: + await client.post("/test") + assert await client.get_session_data() == {"foo": "bar"} + + """ + return await super()._get_session_data() + + async def set_session_data(self, data: dict[str, Any]) -> None: + """Set session data. + + Args: + data: Session data + + Returns: + None + + Examples: + .. code-block:: python + + from litestar import Litestar, get + from litestar.middleware.session.memory_backend import MemoryBackendConfig + + session_config = MemoryBackendConfig() + + + @get(path="/test") + def get_session_data(request: Request) -> Dict[str, Any]: + return request.session + + + app = Litestar( + route_handlers=[get_session_data], middleware=[session_config.middleware] + ) + + async with AsyncTestClient(app=app, session_config=session_config) as client: + await client.set_session_data({"foo": "bar"}) + assert await client.get("/test").json() == {"foo": "bar"} + + """ + return await super()._set_session_data(data) diff --git a/venv/lib/python3.11/site-packages/litestar/testing/client/base.py b/venv/lib/python3.11/site-packages/litestar/testing/client/base.py new file mode 100644 index 0000000..3c25be1 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/client/base.py @@ -0,0 +1,180 @@ +from __future__ import annotations + +from contextlib import contextmanager +from http.cookiejar import CookieJar +from typing import TYPE_CHECKING, Any, Generator, Generic, Mapping, TypeVar, cast +from warnings import warn + +from anyio.from_thread import BlockingPortal, start_blocking_portal +from httpx import Cookies, Request, Response + +from litestar import Litestar +from litestar.connection import ASGIConnection +from litestar.datastructures import MutableScopeHeaders +from litestar.enums import ScopeType +from litestar.exceptions import ( + ImproperlyConfiguredException, +) +from litestar.types import AnyIOBackend, ASGIApp, HTTPResponseStartEvent +from litestar.utils.scope.state import ScopeState + +if TYPE_CHECKING: + from httpx._types import CookieTypes + + from litestar.middleware.session.base import BaseBackendConfig, BaseSessionBackend + from litestar.types.asgi_types import HTTPScope, Receive, Scope, Send + +T = TypeVar("T", bound=ASGIApp) + + +def fake_http_send_message(headers: MutableScopeHeaders) -> HTTPResponseStartEvent: + headers.setdefault("content-type", "application/text") + return HTTPResponseStartEvent(type="http.response.start", status=200, headers=headers.headers) + + +def fake_asgi_connection(app: ASGIApp, cookies: dict[str, str]) -> ASGIConnection[Any, Any, Any, Any]: + scope: HTTPScope = { + "type": ScopeType.HTTP, + "path": "/", + "raw_path": b"/", + "root_path": "", + "scheme": "http", + "query_string": b"", + "client": ("testclient", 50000), + "server": ("testserver", 80), + "headers": [], + "method": "GET", + "http_version": "1.1", + "extensions": {"http.response.template": {}}, + "app": app, # type: ignore[typeddict-item] + "state": {}, + "path_params": {}, + "route_handler": None, # type: ignore[typeddict-item] + "asgi": {"version": "3.0", "spec_version": "2.1"}, + "auth": None, + "session": None, + "user": None, + } + ScopeState.from_scope(scope).cookies = cookies + return ASGIConnection[Any, Any, Any, Any](scope=scope) + + +def _wrap_app_to_add_state(app: ASGIApp) -> ASGIApp: + """Wrap an ASGI app to add state to the scope. + + Litestar depends on `state` being present in the ASGI connection scope. Scope state is optional in the ASGI spec, + however, the Litestar app always ensures it is present so that it can be depended on internally. + + When the ASGI app that is passed to the test client is _not_ a Litestar app, we need to add + state to the scope, because httpx does not do this for us. + + This assists us in testing Litestar components that rely on state being present in the scope, without having + to create a Litestar app for every test case. + + Args: + app: The ASGI app to wrap. + + Returns: + The wrapped ASGI app. + """ + + async def wrapped(scope: Scope, receive: Receive, send: Send) -> None: + scope["state"] = {} + await app(scope, receive, send) + + return wrapped + + +class BaseTestClient(Generic[T]): + __test__ = False + blocking_portal: BlockingPortal + + __slots__ = ( + "app", + "base_url", + "backend", + "backend_options", + "session_config", + "_session_backend", + "cookies", + ) + + def __init__( + self, + app: T, + base_url: str = "http://testserver.local", + backend: AnyIOBackend = "asyncio", + backend_options: Mapping[str, Any] | None = None, + session_config: BaseBackendConfig | None = None, + cookies: CookieTypes | None = None, + ) -> None: + if "." not in base_url: + warn( + f"The base_url {base_url!r} might cause issues. Try adding a domain name such as .local: " + f"'{base_url}.local'", + UserWarning, + stacklevel=1, + ) + + self._session_backend: BaseSessionBackend | None = None + if session_config: + self._session_backend = session_config._backend_class(config=session_config) + + if not isinstance(app, Litestar): + app = _wrap_app_to_add_state(app) # type: ignore[assignment] + + self.app = cast("T", app) # type: ignore[redundant-cast] # pyright needs this + + self.base_url = base_url + self.backend = backend + self.backend_options = backend_options + self.cookies = cookies + + @property + def session_backend(self) -> BaseSessionBackend[Any]: + if not self._session_backend: + raise ImproperlyConfiguredException( + "Session has not been initialized for this TestClient instance. You can" + "do so by passing a configuration object to TestClient: TestClient(app=app, session_config=...)" + ) + return self._session_backend + + @contextmanager + def portal(self) -> Generator[BlockingPortal, None, None]: + """Get a BlockingPortal. + + Returns: + A contextmanager for a BlockingPortal. + """ + if hasattr(self, "blocking_portal"): + yield self.blocking_portal + else: + with start_blocking_portal( + backend=self.backend, backend_options=dict(self.backend_options or {}) + ) as portal: + yield portal + + async def _set_session_data(self, data: dict[str, Any]) -> None: + mutable_headers = MutableScopeHeaders() + connection = fake_asgi_connection( + app=self.app, + cookies=dict(self.cookies), # type: ignore[arg-type] + ) + session_id = self.session_backend.get_session_id(connection) + connection._connection_state.session_id = session_id # pyright: ignore [reportGeneralTypeIssues] + await self.session_backend.store_in_message( + scope_session=data, message=fake_http_send_message(mutable_headers), connection=connection + ) + response = Response(200, request=Request("GET", self.base_url), headers=mutable_headers.headers) + + cookies = Cookies(CookieJar()) + cookies.extract_cookies(response) + self.cookies.update(cookies) # type: ignore[union-attr] + + async def _get_session_data(self) -> dict[str, Any]: + return await self.session_backend.load_from_connection( + connection=fake_asgi_connection( + app=self.app, + cookies=dict(self.cookies), # type: ignore[arg-type] + ), + ) diff --git a/venv/lib/python3.11/site-packages/litestar/testing/client/sync_client.py b/venv/lib/python3.11/site-packages/litestar/testing/client/sync_client.py new file mode 100644 index 0000000..d907056 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/client/sync_client.py @@ -0,0 +1,593 @@ +from __future__ import annotations + +from contextlib import ExitStack +from typing import TYPE_CHECKING, Any, Generic, Mapping, Sequence, TypeVar +from urllib.parse import urljoin + +from httpx import USE_CLIENT_DEFAULT, Client, Response + +from litestar import HttpMethod +from litestar.testing.client.base import BaseTestClient +from litestar.testing.life_span_handler import LifeSpanHandler +from litestar.testing.transport import ConnectionUpgradeExceptionError, TestClientTransport +from litestar.types import AnyIOBackend, ASGIApp + +if TYPE_CHECKING: + from httpx._client import UseClientDefault + from httpx._types import ( + AuthTypes, + CookieTypes, + HeaderTypes, + QueryParamTypes, + RequestContent, + RequestData, + RequestFiles, + TimeoutTypes, + URLTypes, + ) + from typing_extensions import Self + + from litestar.middleware.session.base import BaseBackendConfig + from litestar.testing.websocket_test_session import WebSocketTestSession + + +T = TypeVar("T", bound=ASGIApp) + + +class TestClient(Client, BaseTestClient, Generic[T]): # type: ignore[misc] + lifespan_handler: LifeSpanHandler[Any] + exit_stack: ExitStack + + def __init__( + self, + app: T, + base_url: str = "http://testserver.local", + raise_server_exceptions: bool = True, + root_path: str = "", + backend: AnyIOBackend = "asyncio", + backend_options: Mapping[str, Any] | None = None, + session_config: BaseBackendConfig | None = None, + timeout: float | None = None, + cookies: CookieTypes | None = None, + ) -> None: + """A client implementation providing a context manager for testing applications. + + Args: + app: The instance of :class:`Litestar <litestar.app.Litestar>` under test. + base_url: URL scheme and domain for test request paths, e.g. ``http://testserver``. + raise_server_exceptions: Flag for the underlying test client to raise server exceptions instead of + wrapping them in an HTTP response. + root_path: Path prefix for requests. + backend: The async backend to use, options are "asyncio" or "trio". + backend_options: ``anyio`` options. + session_config: Configuration for Session Middleware class to create raw session cookies for request to the + route handlers. + timeout: Request timeout + cookies: Cookies to set on the client. + """ + BaseTestClient.__init__( + self, + app=app, + base_url=base_url, + backend=backend, + backend_options=backend_options, + session_config=session_config, + cookies=cookies, + ) + + Client.__init__( + self, + base_url=base_url, + headers={"user-agent": "testclient"}, + follow_redirects=True, + cookies=cookies, + transport=TestClientTransport( # type: ignore[arg-type] + client=self, + raise_server_exceptions=raise_server_exceptions, + root_path=root_path, + ), + timeout=timeout, + ) + + def __enter__(self) -> Self: + with ExitStack() as stack: + self.blocking_portal = portal = stack.enter_context(self.portal()) + self.lifespan_handler = LifeSpanHandler(client=self) + + @stack.callback + def reset_portal() -> None: + delattr(self, "blocking_portal") + + @stack.callback + def wait_shutdown() -> None: + portal.call(self.lifespan_handler.wait_shutdown) + + self.exit_stack = stack.pop_all() + + return self + + def __exit__(self, *args: Any) -> None: + self.exit_stack.close() + + def request( + self, + method: str, + url: URLTypes, + *, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: Any | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault | None = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a request. + + Args: + method: An HTTP method. + url: URL or path for the request. + content: Request content. + data: Form encoded data. + files: Multipart files to send. + json: JSON data to send. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return Client.request( + self, + url=self.base_url.join(url), + method=method.value if isinstance(method, HttpMethod) else method, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + def get( + self, + url: URLTypes, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a GET request. + + Args: + url: URL or path for the request. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return Client.get( + self, + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + def options( + self, + url: URLTypes, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends an OPTIONS request. + + Args: + url: URL or path for the request. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return Client.options( + self, + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + def head( + self, + url: URLTypes, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a HEAD request. + + Args: + url: URL or path for the request. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return Client.head( + self, + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + def post( + self, + url: URLTypes, + *, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: Any | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a POST request. + + Args: + url: URL or path for the request. + content: Request content. + data: Form encoded data. + files: Multipart files to send. + json: JSON data to send. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return Client.post( + self, + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + def put( + self, + url: URLTypes, + *, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: Any | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a PUT request. + + Args: + url: URL or path for the request. + content: Request content. + data: Form encoded data. + files: Multipart files to send. + json: JSON data to send. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return Client.put( + self, + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + def patch( + self, + url: URLTypes, + *, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: Any | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a PATCH request. + + Args: + url: URL or path for the request. + content: Request content. + data: Form encoded data. + files: Multipart files to send. + json: JSON data to send. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return Client.patch( + self, + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + def delete( + self, + url: URLTypes, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a DELETE request. + + Args: + url: URL or path for the request. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return Client.delete( + self, + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + def websocket_connect( + self, + url: str, + subprotocols: Sequence[str] | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> WebSocketTestSession: + """Sends a GET request to establish a websocket connection. + + Args: + url: Request URL. + subprotocols: Websocket subprotocols. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + A `WebSocketTestSession <litestar.testing.WebSocketTestSession>` instance. + """ + url = urljoin("ws://testserver", url) + default_headers: dict[str, str] = {} + default_headers.setdefault("connection", "upgrade") + default_headers.setdefault("sec-websocket-key", "testserver==") + default_headers.setdefault("sec-websocket-version", "13") + if subprotocols is not None: + default_headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols)) + try: + Client.request( + self, + "GET", + url, + headers={**dict(headers or {}), **default_headers}, # type: ignore[misc] + params=params, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + except ConnectionUpgradeExceptionError as exc: + return exc.session + + raise RuntimeError("Expected WebSocket upgrade") # pragma: no cover + + def set_session_data(self, data: dict[str, Any]) -> None: + """Set session data. + + Args: + data: Session data + + Returns: + None + + Examples: + .. code-block:: python + + from litestar import Litestar, get + from litestar.middleware.session.memory_backend import MemoryBackendConfig + + session_config = MemoryBackendConfig() + + + @get(path="/test") + def get_session_data(request: Request) -> Dict[str, Any]: + return request.session + + + app = Litestar( + route_handlers=[get_session_data], middleware=[session_config.middleware] + ) + + with TestClient(app=app, session_config=session_config) as client: + client.set_session_data({"foo": "bar"}) + assert client.get("/test").json() == {"foo": "bar"} + + """ + with self.portal() as portal: + portal.call(self._set_session_data, data) + + def get_session_data(self) -> dict[str, Any]: + """Get session data. + + Returns: + A dictionary containing session data. + + Examples: + .. code-block:: python + + from litestar import Litestar, post + from litestar.middleware.session.memory_backend import MemoryBackendConfig + + session_config = MemoryBackendConfig() + + + @post(path="/test") + def set_session_data(request: Request) -> None: + request.session["foo"] == "bar" + + + app = Litestar( + route_handlers=[set_session_data], middleware=[session_config.middleware] + ) + + with TestClient(app=app, session_config=session_config) as client: + client.post("/test") + assert client.get_session_data() == {"foo": "bar"} + + """ + with self.portal() as portal: + return portal.call(self._get_session_data) |