diff options
author | cyfraeviolae <cyfraeviolae> | 2024-04-03 03:10:44 -0400 |
---|---|---|
committer | cyfraeviolae <cyfraeviolae> | 2024-04-03 03:10:44 -0400 |
commit | 6d7ba58f880be618ade07f8ea080fe8c4bf8a896 (patch) | |
tree | b1c931051ffcebd2bd9d61d98d6233ffa289bbce /venv/lib/python3.11/site-packages/litestar/testing | |
parent | 4f884c9abc32990b4061a1bb6997b4b37e58ea0b (diff) |
venv
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/testing')
20 files changed, 3003 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/testing/__init__.py b/venv/lib/python3.11/site-packages/litestar/testing/__init__.py new file mode 100644 index 0000000..55af446 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/__init__.py @@ -0,0 +1,16 @@ +from litestar.testing.client.async_client import AsyncTestClient +from litestar.testing.client.base import BaseTestClient +from litestar.testing.client.sync_client import TestClient +from litestar.testing.helpers import create_async_test_client, create_test_client +from litestar.testing.request_factory import RequestFactory +from litestar.testing.websocket_test_session import WebSocketTestSession + +__all__ = ( + "AsyncTestClient", + "BaseTestClient", + "create_async_test_client", + "create_test_client", + "RequestFactory", + "TestClient", + "WebSocketTestSession", +) diff --git a/venv/lib/python3.11/site-packages/litestar/testing/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/testing/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..77c7908 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/testing/__pycache__/helpers.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/testing/__pycache__/helpers.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..a85995f --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/__pycache__/helpers.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/testing/__pycache__/life_span_handler.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/testing/__pycache__/life_span_handler.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..b3ea9ef --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/__pycache__/life_span_handler.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/testing/__pycache__/request_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/testing/__pycache__/request_factory.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..9a66826 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/__pycache__/request_factory.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/testing/__pycache__/transport.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/testing/__pycache__/transport.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..78a1aa6 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/__pycache__/transport.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/testing/__pycache__/websocket_test_session.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/testing/__pycache__/websocket_test_session.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..6b00590 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/__pycache__/websocket_test_session.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/testing/client/__init__.py b/venv/lib/python3.11/site-packages/litestar/testing/client/__init__.py new file mode 100644 index 0000000..5d03a7a --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/client/__init__.py @@ -0,0 +1,36 @@ +"""Some code in this module was adapted from https://github.com/encode/starlette/blob/master/starlette/testclient.py. + +Copyright © 2018, [Encode OSS Ltd](https://www.encode.io/). +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" + +from .async_client import AsyncTestClient +from .base import BaseTestClient +from .sync_client import TestClient + +__all__ = ("TestClient", "AsyncTestClient", "BaseTestClient") diff --git a/venv/lib/python3.11/site-packages/litestar/testing/client/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/testing/client/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..18ad148 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/client/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/testing/client/__pycache__/async_client.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/testing/client/__pycache__/async_client.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..1ccc805 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/client/__pycache__/async_client.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/testing/client/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/testing/client/__pycache__/base.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..87d5de7 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/client/__pycache__/base.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/testing/client/__pycache__/sync_client.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/testing/client/__pycache__/sync_client.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..29f0576 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/client/__pycache__/sync_client.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/testing/client/async_client.py b/venv/lib/python3.11/site-packages/litestar/testing/client/async_client.py new file mode 100644 index 0000000..cf66f12 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/client/async_client.py @@ -0,0 +1,534 @@ +from __future__ import annotations + +from contextlib import AsyncExitStack +from typing import TYPE_CHECKING, Any, Generic, Mapping, TypeVar + +from httpx import USE_CLIENT_DEFAULT, AsyncClient, Response + +from litestar import HttpMethod +from litestar.testing.client.base import BaseTestClient +from litestar.testing.life_span_handler import LifeSpanHandler +from litestar.testing.transport import TestClientTransport +from litestar.types import AnyIOBackend, ASGIApp + +if TYPE_CHECKING: + from httpx._client import UseClientDefault + from httpx._types import ( + AuthTypes, + CookieTypes, + HeaderTypes, + QueryParamTypes, + RequestContent, + RequestData, + RequestFiles, + TimeoutTypes, + URLTypes, + ) + from typing_extensions import Self + + from litestar.middleware.session.base import BaseBackendConfig + + +T = TypeVar("T", bound=ASGIApp) + + +class AsyncTestClient(AsyncClient, BaseTestClient, Generic[T]): # type: ignore[misc] + lifespan_handler: LifeSpanHandler[Any] + exit_stack: AsyncExitStack + + def __init__( + self, + app: T, + base_url: str = "http://testserver.local", + raise_server_exceptions: bool = True, + root_path: str = "", + backend: AnyIOBackend = "asyncio", + backend_options: Mapping[str, Any] | None = None, + session_config: BaseBackendConfig | None = None, + timeout: float | None = None, + cookies: CookieTypes | None = None, + ) -> None: + """An Async client implementation providing a context manager for testing applications asynchronously. + + Args: + app: The instance of :class:`Litestar <litestar.app.Litestar>` under test. + base_url: URL scheme and domain for test request paths, e.g. ``http://testserver``. + raise_server_exceptions: Flag for the underlying test client to raise server exceptions instead of + wrapping them in an HTTP response. + root_path: Path prefix for requests. + backend: The async backend to use, options are "asyncio" or "trio". + backend_options: 'anyio' options. + session_config: Configuration for Session Middleware class to create raw session cookies for request to the + route handlers. + timeout: Request timeout + cookies: Cookies to set on the client. + """ + BaseTestClient.__init__( + self, + app=app, + base_url=base_url, + backend=backend, + backend_options=backend_options, + session_config=session_config, + cookies=cookies, + ) + AsyncClient.__init__( + self, + base_url=base_url, + headers={"user-agent": "testclient"}, + follow_redirects=True, + cookies=cookies, + transport=TestClientTransport( # type: ignore [arg-type] + client=self, + raise_server_exceptions=raise_server_exceptions, + root_path=root_path, + ), + timeout=timeout, + ) + + async def __aenter__(self) -> Self: + async with AsyncExitStack() as stack: + self.blocking_portal = portal = stack.enter_context(self.portal()) + self.lifespan_handler = LifeSpanHandler(client=self) + + @stack.callback + def reset_portal() -> None: + delattr(self, "blocking_portal") + + @stack.callback + def wait_shutdown() -> None: + portal.call(self.lifespan_handler.wait_shutdown) + + self.exit_stack = stack.pop_all() + return self + + async def __aexit__(self, *args: Any) -> None: + await self.exit_stack.aclose() + + async def request( + self, + method: str, + url: URLTypes, + *, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: Any | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault | None = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a request. + + Args: + method: An HTTP method. + url: URL or path for the request. + content: Request content. + data: Form encoded data. + files: Multipart files to send. + json: JSON data to send. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return await AsyncClient.request( + self, + url=self.base_url.join(url), + method=method.value if isinstance(method, HttpMethod) else method, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + async def get( # type: ignore [override] + self, + url: URLTypes, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a GET request. + + Args: + url: URL or path for the request. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return await AsyncClient.get( + self, + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + async def options( + self, + url: URLTypes, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends an OPTIONS request. + + Args: + url: URL or path for the request. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return await AsyncClient.options( + self, + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + async def head( + self, + url: URLTypes, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a HEAD request. + + Args: + url: URL or path for the request. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return await AsyncClient.head( + self, + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + async def post( + self, + url: URLTypes, + *, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: Any | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a POST request. + + Args: + url: URL or path for the request. + content: Request content. + data: Form encoded data. + files: Multipart files to send. + json: JSON data to send. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return await AsyncClient.post( + self, + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + async def put( + self, + url: URLTypes, + *, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: Any | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a PUT request. + + Args: + url: URL or path for the request. + content: Request content. + data: Form encoded data. + files: Multipart files to send. + json: JSON data to send. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return await AsyncClient.put( + self, + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + async def patch( + self, + url: URLTypes, + *, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: Any | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a PATCH request. + + Args: + url: URL or path for the request. + content: Request content. + data: Form encoded data. + files: Multipart files to send. + json: JSON data to send. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return await AsyncClient.patch( + self, + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + async def delete( + self, + url: URLTypes, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a DELETE request. + + Args: + url: URL or path for the request. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return await AsyncClient.delete( + self, + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + async def get_session_data(self) -> dict[str, Any]: + """Get session data. + + Returns: + A dictionary containing session data. + + Examples: + .. code-block:: python + + from litestar import Litestar, post + from litestar.middleware.session.memory_backend import MemoryBackendConfig + + session_config = MemoryBackendConfig() + + + @post(path="/test") + def set_session_data(request: Request) -> None: + request.session["foo"] == "bar" + + + app = Litestar( + route_handlers=[set_session_data], middleware=[session_config.middleware] + ) + + async with AsyncTestClient(app=app, session_config=session_config) as client: + await client.post("/test") + assert await client.get_session_data() == {"foo": "bar"} + + """ + return await super()._get_session_data() + + async def set_session_data(self, data: dict[str, Any]) -> None: + """Set session data. + + Args: + data: Session data + + Returns: + None + + Examples: + .. code-block:: python + + from litestar import Litestar, get + from litestar.middleware.session.memory_backend import MemoryBackendConfig + + session_config = MemoryBackendConfig() + + + @get(path="/test") + def get_session_data(request: Request) -> Dict[str, Any]: + return request.session + + + app = Litestar( + route_handlers=[get_session_data], middleware=[session_config.middleware] + ) + + async with AsyncTestClient(app=app, session_config=session_config) as client: + await client.set_session_data({"foo": "bar"}) + assert await client.get("/test").json() == {"foo": "bar"} + + """ + return await super()._set_session_data(data) diff --git a/venv/lib/python3.11/site-packages/litestar/testing/client/base.py b/venv/lib/python3.11/site-packages/litestar/testing/client/base.py new file mode 100644 index 0000000..3c25be1 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/client/base.py @@ -0,0 +1,180 @@ +from __future__ import annotations + +from contextlib import contextmanager +from http.cookiejar import CookieJar +from typing import TYPE_CHECKING, Any, Generator, Generic, Mapping, TypeVar, cast +from warnings import warn + +from anyio.from_thread import BlockingPortal, start_blocking_portal +from httpx import Cookies, Request, Response + +from litestar import Litestar +from litestar.connection import ASGIConnection +from litestar.datastructures import MutableScopeHeaders +from litestar.enums import ScopeType +from litestar.exceptions import ( + ImproperlyConfiguredException, +) +from litestar.types import AnyIOBackend, ASGIApp, HTTPResponseStartEvent +from litestar.utils.scope.state import ScopeState + +if TYPE_CHECKING: + from httpx._types import CookieTypes + + from litestar.middleware.session.base import BaseBackendConfig, BaseSessionBackend + from litestar.types.asgi_types import HTTPScope, Receive, Scope, Send + +T = TypeVar("T", bound=ASGIApp) + + +def fake_http_send_message(headers: MutableScopeHeaders) -> HTTPResponseStartEvent: + headers.setdefault("content-type", "application/text") + return HTTPResponseStartEvent(type="http.response.start", status=200, headers=headers.headers) + + +def fake_asgi_connection(app: ASGIApp, cookies: dict[str, str]) -> ASGIConnection[Any, Any, Any, Any]: + scope: HTTPScope = { + "type": ScopeType.HTTP, + "path": "/", + "raw_path": b"/", + "root_path": "", + "scheme": "http", + "query_string": b"", + "client": ("testclient", 50000), + "server": ("testserver", 80), + "headers": [], + "method": "GET", + "http_version": "1.1", + "extensions": {"http.response.template": {}}, + "app": app, # type: ignore[typeddict-item] + "state": {}, + "path_params": {}, + "route_handler": None, # type: ignore[typeddict-item] + "asgi": {"version": "3.0", "spec_version": "2.1"}, + "auth": None, + "session": None, + "user": None, + } + ScopeState.from_scope(scope).cookies = cookies + return ASGIConnection[Any, Any, Any, Any](scope=scope) + + +def _wrap_app_to_add_state(app: ASGIApp) -> ASGIApp: + """Wrap an ASGI app to add state to the scope. + + Litestar depends on `state` being present in the ASGI connection scope. Scope state is optional in the ASGI spec, + however, the Litestar app always ensures it is present so that it can be depended on internally. + + When the ASGI app that is passed to the test client is _not_ a Litestar app, we need to add + state to the scope, because httpx does not do this for us. + + This assists us in testing Litestar components that rely on state being present in the scope, without having + to create a Litestar app for every test case. + + Args: + app: The ASGI app to wrap. + + Returns: + The wrapped ASGI app. + """ + + async def wrapped(scope: Scope, receive: Receive, send: Send) -> None: + scope["state"] = {} + await app(scope, receive, send) + + return wrapped + + +class BaseTestClient(Generic[T]): + __test__ = False + blocking_portal: BlockingPortal + + __slots__ = ( + "app", + "base_url", + "backend", + "backend_options", + "session_config", + "_session_backend", + "cookies", + ) + + def __init__( + self, + app: T, + base_url: str = "http://testserver.local", + backend: AnyIOBackend = "asyncio", + backend_options: Mapping[str, Any] | None = None, + session_config: BaseBackendConfig | None = None, + cookies: CookieTypes | None = None, + ) -> None: + if "." not in base_url: + warn( + f"The base_url {base_url!r} might cause issues. Try adding a domain name such as .local: " + f"'{base_url}.local'", + UserWarning, + stacklevel=1, + ) + + self._session_backend: BaseSessionBackend | None = None + if session_config: + self._session_backend = session_config._backend_class(config=session_config) + + if not isinstance(app, Litestar): + app = _wrap_app_to_add_state(app) # type: ignore[assignment] + + self.app = cast("T", app) # type: ignore[redundant-cast] # pyright needs this + + self.base_url = base_url + self.backend = backend + self.backend_options = backend_options + self.cookies = cookies + + @property + def session_backend(self) -> BaseSessionBackend[Any]: + if not self._session_backend: + raise ImproperlyConfiguredException( + "Session has not been initialized for this TestClient instance. You can" + "do so by passing a configuration object to TestClient: TestClient(app=app, session_config=...)" + ) + return self._session_backend + + @contextmanager + def portal(self) -> Generator[BlockingPortal, None, None]: + """Get a BlockingPortal. + + Returns: + A contextmanager for a BlockingPortal. + """ + if hasattr(self, "blocking_portal"): + yield self.blocking_portal + else: + with start_blocking_portal( + backend=self.backend, backend_options=dict(self.backend_options or {}) + ) as portal: + yield portal + + async def _set_session_data(self, data: dict[str, Any]) -> None: + mutable_headers = MutableScopeHeaders() + connection = fake_asgi_connection( + app=self.app, + cookies=dict(self.cookies), # type: ignore[arg-type] + ) + session_id = self.session_backend.get_session_id(connection) + connection._connection_state.session_id = session_id # pyright: ignore [reportGeneralTypeIssues] + await self.session_backend.store_in_message( + scope_session=data, message=fake_http_send_message(mutable_headers), connection=connection + ) + response = Response(200, request=Request("GET", self.base_url), headers=mutable_headers.headers) + + cookies = Cookies(CookieJar()) + cookies.extract_cookies(response) + self.cookies.update(cookies) # type: ignore[union-attr] + + async def _get_session_data(self) -> dict[str, Any]: + return await self.session_backend.load_from_connection( + connection=fake_asgi_connection( + app=self.app, + cookies=dict(self.cookies), # type: ignore[arg-type] + ), + ) diff --git a/venv/lib/python3.11/site-packages/litestar/testing/client/sync_client.py b/venv/lib/python3.11/site-packages/litestar/testing/client/sync_client.py new file mode 100644 index 0000000..d907056 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/client/sync_client.py @@ -0,0 +1,593 @@ +from __future__ import annotations + +from contextlib import ExitStack +from typing import TYPE_CHECKING, Any, Generic, Mapping, Sequence, TypeVar +from urllib.parse import urljoin + +from httpx import USE_CLIENT_DEFAULT, Client, Response + +from litestar import HttpMethod +from litestar.testing.client.base import BaseTestClient +from litestar.testing.life_span_handler import LifeSpanHandler +from litestar.testing.transport import ConnectionUpgradeExceptionError, TestClientTransport +from litestar.types import AnyIOBackend, ASGIApp + +if TYPE_CHECKING: + from httpx._client import UseClientDefault + from httpx._types import ( + AuthTypes, + CookieTypes, + HeaderTypes, + QueryParamTypes, + RequestContent, + RequestData, + RequestFiles, + TimeoutTypes, + URLTypes, + ) + from typing_extensions import Self + + from litestar.middleware.session.base import BaseBackendConfig + from litestar.testing.websocket_test_session import WebSocketTestSession + + +T = TypeVar("T", bound=ASGIApp) + + +class TestClient(Client, BaseTestClient, Generic[T]): # type: ignore[misc] + lifespan_handler: LifeSpanHandler[Any] + exit_stack: ExitStack + + def __init__( + self, + app: T, + base_url: str = "http://testserver.local", + raise_server_exceptions: bool = True, + root_path: str = "", + backend: AnyIOBackend = "asyncio", + backend_options: Mapping[str, Any] | None = None, + session_config: BaseBackendConfig | None = None, + timeout: float | None = None, + cookies: CookieTypes | None = None, + ) -> None: + """A client implementation providing a context manager for testing applications. + + Args: + app: The instance of :class:`Litestar <litestar.app.Litestar>` under test. + base_url: URL scheme and domain for test request paths, e.g. ``http://testserver``. + raise_server_exceptions: Flag for the underlying test client to raise server exceptions instead of + wrapping them in an HTTP response. + root_path: Path prefix for requests. + backend: The async backend to use, options are "asyncio" or "trio". + backend_options: ``anyio`` options. + session_config: Configuration for Session Middleware class to create raw session cookies for request to the + route handlers. + timeout: Request timeout + cookies: Cookies to set on the client. + """ + BaseTestClient.__init__( + self, + app=app, + base_url=base_url, + backend=backend, + backend_options=backend_options, + session_config=session_config, + cookies=cookies, + ) + + Client.__init__( + self, + base_url=base_url, + headers={"user-agent": "testclient"}, + follow_redirects=True, + cookies=cookies, + transport=TestClientTransport( # type: ignore[arg-type] + client=self, + raise_server_exceptions=raise_server_exceptions, + root_path=root_path, + ), + timeout=timeout, + ) + + def __enter__(self) -> Self: + with ExitStack() as stack: + self.blocking_portal = portal = stack.enter_context(self.portal()) + self.lifespan_handler = LifeSpanHandler(client=self) + + @stack.callback + def reset_portal() -> None: + delattr(self, "blocking_portal") + + @stack.callback + def wait_shutdown() -> None: + portal.call(self.lifespan_handler.wait_shutdown) + + self.exit_stack = stack.pop_all() + + return self + + def __exit__(self, *args: Any) -> None: + self.exit_stack.close() + + def request( + self, + method: str, + url: URLTypes, + *, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: Any | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault | None = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a request. + + Args: + method: An HTTP method. + url: URL or path for the request. + content: Request content. + data: Form encoded data. + files: Multipart files to send. + json: JSON data to send. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return Client.request( + self, + url=self.base_url.join(url), + method=method.value if isinstance(method, HttpMethod) else method, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + def get( + self, + url: URLTypes, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a GET request. + + Args: + url: URL or path for the request. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return Client.get( + self, + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + def options( + self, + url: URLTypes, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends an OPTIONS request. + + Args: + url: URL or path for the request. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return Client.options( + self, + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + def head( + self, + url: URLTypes, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a HEAD request. + + Args: + url: URL or path for the request. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return Client.head( + self, + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + def post( + self, + url: URLTypes, + *, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: Any | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a POST request. + + Args: + url: URL or path for the request. + content: Request content. + data: Form encoded data. + files: Multipart files to send. + json: JSON data to send. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return Client.post( + self, + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + def put( + self, + url: URLTypes, + *, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: Any | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a PUT request. + + Args: + url: URL or path for the request. + content: Request content. + data: Form encoded data. + files: Multipart files to send. + json: JSON data to send. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return Client.put( + self, + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + def patch( + self, + url: URLTypes, + *, + content: RequestContent | None = None, + data: RequestData | None = None, + files: RequestFiles | None = None, + json: Any | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a PATCH request. + + Args: + url: URL or path for the request. + content: Request content. + data: Form encoded data. + files: Multipart files to send. + json: JSON data to send. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return Client.patch( + self, + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + def delete( + self, + url: URLTypes, + *, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> Response: + """Sends a DELETE request. + + Args: + url: URL or path for the request. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + An HTTPX Response. + """ + return Client.delete( + self, + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + + def websocket_connect( + self, + url: str, + subprotocols: Sequence[str] | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, + follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> WebSocketTestSession: + """Sends a GET request to establish a websocket connection. + + Args: + url: Request URL. + subprotocols: Websocket subprotocols. + params: Query parameters. + headers: Request headers. + cookies: Request cookies. + auth: Auth headers. + follow_redirects: Whether to follow redirects. + timeout: Request timeout. + extensions: Dictionary of ASGI extensions. + + Returns: + A `WebSocketTestSession <litestar.testing.WebSocketTestSession>` instance. + """ + url = urljoin("ws://testserver", url) + default_headers: dict[str, str] = {} + default_headers.setdefault("connection", "upgrade") + default_headers.setdefault("sec-websocket-key", "testserver==") + default_headers.setdefault("sec-websocket-version", "13") + if subprotocols is not None: + default_headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols)) + try: + Client.request( + self, + "GET", + url, + headers={**dict(headers or {}), **default_headers}, # type: ignore[misc] + params=params, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=None if extensions is None else dict(extensions), + ) + except ConnectionUpgradeExceptionError as exc: + return exc.session + + raise RuntimeError("Expected WebSocket upgrade") # pragma: no cover + + def set_session_data(self, data: dict[str, Any]) -> None: + """Set session data. + + Args: + data: Session data + + Returns: + None + + Examples: + .. code-block:: python + + from litestar import Litestar, get + from litestar.middleware.session.memory_backend import MemoryBackendConfig + + session_config = MemoryBackendConfig() + + + @get(path="/test") + def get_session_data(request: Request) -> Dict[str, Any]: + return request.session + + + app = Litestar( + route_handlers=[get_session_data], middleware=[session_config.middleware] + ) + + with TestClient(app=app, session_config=session_config) as client: + client.set_session_data({"foo": "bar"}) + assert client.get("/test").json() == {"foo": "bar"} + + """ + with self.portal() as portal: + portal.call(self._set_session_data, data) + + def get_session_data(self) -> dict[str, Any]: + """Get session data. + + Returns: + A dictionary containing session data. + + Examples: + .. code-block:: python + + from litestar import Litestar, post + from litestar.middleware.session.memory_backend import MemoryBackendConfig + + session_config = MemoryBackendConfig() + + + @post(path="/test") + def set_session_data(request: Request) -> None: + request.session["foo"] == "bar" + + + app = Litestar( + route_handlers=[set_session_data], middleware=[session_config.middleware] + ) + + with TestClient(app=app, session_config=session_config) as client: + client.post("/test") + assert client.get_session_data() == {"foo": "bar"} + + """ + with self.portal() as portal: + return portal.call(self._get_session_data) diff --git a/venv/lib/python3.11/site-packages/litestar/testing/helpers.py b/venv/lib/python3.11/site-packages/litestar/testing/helpers.py new file mode 100644 index 0000000..5ac59af --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/helpers.py @@ -0,0 +1,561 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable, Literal, Mapping, Sequence + +from litestar.app import DEFAULT_OPENAPI_CONFIG, Litestar +from litestar.controller import Controller +from litestar.events import SimpleEventEmitter +from litestar.testing.client import AsyncTestClient, TestClient +from litestar.types import Empty +from litestar.utils.predicates import is_class_and_subclass + +if TYPE_CHECKING: + from contextlib import AbstractAsyncContextManager + + from litestar import Request, Response, WebSocket + from litestar.config.allowed_hosts import AllowedHostsConfig + from litestar.config.app import ExperimentalFeatures + from litestar.config.compression import CompressionConfig + from litestar.config.cors import CORSConfig + from litestar.config.csrf import CSRFConfig + from litestar.config.response_cache import ResponseCacheConfig + from litestar.datastructures import CacheControlHeader, ETag, State + from litestar.dto import AbstractDTO + from litestar.events import BaseEventEmitterBackend, EventListener + from litestar.logging.config import BaseLoggingConfig + from litestar.middleware.session.base import BaseBackendConfig + from litestar.openapi.config import OpenAPIConfig + from litestar.openapi.spec import SecurityRequirement + from litestar.plugins import PluginProtocol + from litestar.static_files.config import StaticFilesConfig + from litestar.stores.base import Store + from litestar.stores.registry import StoreRegistry + from litestar.template.config import TemplateConfig + from litestar.types import ( + AfterExceptionHookHandler, + AfterRequestHookHandler, + AfterResponseHookHandler, + BeforeMessageSendHookHandler, + BeforeRequestHookHandler, + ControllerRouterHandler, + Dependencies, + EmptyType, + ExceptionHandlersMap, + Guard, + LifespanHook, + Middleware, + OnAppInitHandler, + ParametersMap, + ResponseCookies, + ResponseHeaders, + TypeEncodersMap, + ) + + +def create_test_client( + route_handlers: ControllerRouterHandler | Sequence[ControllerRouterHandler] | None = None, + *, + after_exception: Sequence[AfterExceptionHookHandler] | None = None, + after_request: AfterRequestHookHandler | None = None, + after_response: AfterResponseHookHandler | None = None, + allowed_hosts: Sequence[str] | AllowedHostsConfig | None = None, + backend: Literal["asyncio", "trio"] = "asyncio", + backend_options: Mapping[str, Any] | None = None, + base_url: str = "http://testserver.local", + before_request: BeforeRequestHookHandler | None = None, + before_send: Sequence[BeforeMessageSendHookHandler] | None = None, + cache_control: CacheControlHeader | None = None, + compression_config: CompressionConfig | None = None, + cors_config: CORSConfig | None = None, + csrf_config: CSRFConfig | None = None, + debug: bool = True, + dependencies: Dependencies | None = None, + dto: type[AbstractDTO] | None | EmptyType = Empty, + etag: ETag | None = None, + event_emitter_backend: type[BaseEventEmitterBackend] = SimpleEventEmitter, + exception_handlers: ExceptionHandlersMap | None = None, + guards: Sequence[Guard] | None = None, + include_in_schema: bool | EmptyType = Empty, + listeners: Sequence[EventListener] | None = None, + logging_config: BaseLoggingConfig | EmptyType | None = Empty, + middleware: Sequence[Middleware] | None = None, + multipart_form_part_limit: int = 1000, + on_app_init: Sequence[OnAppInitHandler] | None = None, + on_shutdown: Sequence[LifespanHook] | None = None, + on_startup: Sequence[LifespanHook] | None = None, + openapi_config: OpenAPIConfig | None = DEFAULT_OPENAPI_CONFIG, + opt: Mapping[str, Any] | None = None, + parameters: ParametersMap | None = None, + plugins: Sequence[PluginProtocol] | None = None, + lifespan: list[Callable[[Litestar], AbstractAsyncContextManager] | AbstractAsyncContextManager] | None = None, + raise_server_exceptions: bool = True, + pdb_on_exception: bool | None = None, + request_class: type[Request] | None = None, + response_cache_config: ResponseCacheConfig | None = None, + response_class: type[Response] | None = None, + response_cookies: ResponseCookies | None = None, + response_headers: ResponseHeaders | None = None, + return_dto: type[AbstractDTO] | None | EmptyType = Empty, + root_path: str = "", + security: Sequence[SecurityRequirement] | None = None, + session_config: BaseBackendConfig | None = None, + signature_namespace: Mapping[str, Any] | None = None, + signature_types: Sequence[Any] | None = None, + state: State | None = None, + static_files_config: Sequence[StaticFilesConfig] | None = None, + stores: StoreRegistry | dict[str, Store] | None = None, + tags: Sequence[str] | None = None, + template_config: TemplateConfig | None = None, + timeout: float | None = None, + type_encoders: TypeEncodersMap | None = None, + websocket_class: type[WebSocket] | None = None, + experimental_features: list[ExperimentalFeatures] | None = None, +) -> TestClient[Litestar]: + """Create a Litestar app instance and initializes it. + + :class:`TestClient <litestar.testing.TestClient>` with it. + + Notes: + - This function should be called as a context manager to ensure async startup and shutdown are + handled correctly. + + Examples: + .. code-block:: python + + from litestar import get + from litestar.testing import create_test_client + + + @get("/some-path") + def my_handler() -> dict[str, str]: + return {"hello": "world"} + + + def test_my_handler() -> None: + with create_test_client(my_handler) as client: + response = client.get("/some-path") + assert response.json() == {"hello": "world"} + + Args: + route_handlers: A single handler or a sequence of route handlers, which can include instances of + :class:`Router <litestar.router.Router>`, subclasses of :class:`Controller <.controller.Controller>` or + any function decorated by the route handler decorators. + backend: The async backend to use, options are "asyncio" or "trio". + backend_options: ``anyio`` options. + base_url: URL scheme and domain for test request paths, e.g. ``http://testserver``. + raise_server_exceptions: Flag for underlying the test client to raise server exceptions instead of wrapping them + in an HTTP response. + root_path: Path prefix for requests. + session_config: Configuration for Session Middleware class to create raw session cookies for request to the + route handlers. + after_exception: A sequence of :class:`exception hook handlers <.types.AfterExceptionHookHandler>`. This + hook is called after an exception occurs. In difference to exception handlers, it is not meant to + return a response - only to process the exception (e.g. log it, send it to Sentry etc.). + after_request: A sync or async function executed after the route handler function returned and the response + object has been resolved. Receives the response object. + after_response: A sync or async function called after the response has been awaited. It receives the + :class:`Request <.connection.Request>` object and should not return any values. + allowed_hosts: A sequence of allowed hosts, or an + :class:`AllowedHostsConfig <.config.allowed_hosts.AllowedHostsConfig>` instance. Enables the builtin + allowed hosts middleware. + before_request: A sync or async function called immediately before calling the route handler. Receives the + :class:`Request <.connection.Request>` instance and any non-``None`` return value is used for the + response, bypassing the route handler. + before_send: A sequence of :class:`before send hook handlers <.types.BeforeMessageSendHookHandler>`. Called + when the ASGI send function is called. + cache_control: A ``cache-control`` header of type + :class:`CacheControlHeader <litestar.datastructures.CacheControlHeader>` to add to route handlers of + this app. Can be overridden by route handlers. + compression_config: Configures compression behaviour of the application, this enabled a builtin or user + defined Compression middleware. + cors_config: If set, configures :class:`CORSMiddleware <.middleware.cors.CORSMiddleware>`. + csrf_config: If set, configures :class:`CSRFMiddleware <.middleware.csrf.CSRFMiddleware>`. + debug: If ``True``, app errors rendered as HTML with a stack trace. + dependencies: A string keyed mapping of dependency :class:`Providers <.di.Provide>`. + dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for (de)serializing and + validation of request data. + etag: An ``etag`` header of type :class:`ETag <.datastructures.ETag>` to add to route handlers of this app. + Can be overridden by route handlers. + event_emitter_backend: A subclass of + :class:`BaseEventEmitterBackend <.events.emitter.BaseEventEmitterBackend>`. + exception_handlers: A mapping of status codes and/or exception types to handler functions. + guards: A sequence of :class:`Guard <.types.Guard>` callables. + include_in_schema: A boolean flag dictating whether the route handler should be documented in the OpenAPI schema. + lifespan: A list of callables returning async context managers, wrapping the lifespan of the ASGI application + listeners: A sequence of :class:`EventListener <.events.listener.EventListener>`. + logging_config: A subclass of :class:`BaseLoggingConfig <.logging.config.BaseLoggingConfig>`. + middleware: A sequence of :class:`Middleware <.types.Middleware>`. + multipart_form_part_limit: The maximal number of allowed parts in a multipart/formdata request. This limit + is intended to protect from DoS attacks. + on_app_init: A sequence of :class:`OnAppInitHandler <.types.OnAppInitHandler>` instances. Handlers receive + an instance of :class:`AppConfig <.config.app.AppConfig>` that will have been initially populated with + the parameters passed to :class:`Litestar <litestar.app.Litestar>`, and must return an instance of same. + If more than one handler is registered they are called in the order they are provided. + on_shutdown: A sequence of :class:`LifespanHook <.types.LifespanHook>` called during application + shutdown. + on_startup: A sequence of :class:`LifespanHook <litestar.types.LifespanHook>` called during + application startup. + openapi_config: Defaults to :attr:`DEFAULT_OPENAPI_CONFIG` + opt: A string keyed mapping of arbitrary values that can be accessed in :class:`Guards <.types.Guard>` or + wherever you have access to :class:`Request <litestar.connection.request.Request>` or + :class:`ASGI Scope <.types.Scope>`. + parameters: A mapping of :class:`Parameter <.params.Parameter>` definitions available to all application + paths. + pdb_on_exception: Drop into the PDB when an exception occurs. + plugins: Sequence of plugins. + request_class: An optional subclass of :class:`Request <.connection.Request>` to use for http connections. + response_class: A custom subclass of :class:`Response <.response.Response>` to be used as the app's default + response. + response_cookies: A sequence of :class:`Cookie <.datastructures.Cookie>`. + response_headers: A string keyed mapping of :class:`ResponseHeader <.datastructures.ResponseHeader>` + response_cache_config: Configures caching behavior of the application. + return_dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for serializing + outbound response data. + route_handlers: A sequence of route handlers, which can include instances of + :class:`Router <.router.Router>`, subclasses of :class:`Controller <.controller.Controller>` or any + callable decorated by the route handler decorators. + security: A sequence of dicts that will be added to the schema of all route handlers in the application. + See + :data:`SecurityRequirement <.openapi.spec.SecurityRequirement>` for details. + signature_namespace: A mapping of names to types for use in forward reference resolution during signature modeling. + signature_types: A sequence of types for use in forward reference resolution during signature modeling. + These types will be added to the signature namespace using their ``__name__`` attribute. + state: An optional :class:`State <.datastructures.State>` for application state. + static_files_config: A sequence of :class:`StaticFilesConfig <.static_files.StaticFilesConfig>` + stores: Central registry of :class:`Store <.stores.base.Store>` that will be available throughout the + application. If this is a dictionary to it will be passed to a + :class:`StoreRegistry <.stores.registry.StoreRegistry>`. If it is a + :class:`StoreRegistry <.stores.registry.StoreRegistry>`, this instance will be used directly. + tags: A sequence of string tags that will be appended to the schema of all route handlers under the + application. + template_config: An instance of :class:`TemplateConfig <.template.TemplateConfig>` + timeout: Request timeout + type_encoders: A mapping of types to callables that transform them into types supported for serialization. + websocket_class: An optional subclass of :class:`WebSocket <.connection.WebSocket>` to use for websocket + connections. + experimental_features: An iterable of experimental features to enable + + + Returns: + An instance of :class:`TestClient <.testing.TestClient>` with a created app instance. + """ + route_handlers = () if route_handlers is None else route_handlers + if is_class_and_subclass(route_handlers, Controller) or not isinstance(route_handlers, Sequence): + route_handlers = (route_handlers,) + + app = Litestar( + after_exception=after_exception, + after_request=after_request, + after_response=after_response, + allowed_hosts=allowed_hosts, + before_request=before_request, + before_send=before_send, + cache_control=cache_control, + compression_config=compression_config, + cors_config=cors_config, + csrf_config=csrf_config, + debug=debug, + dependencies=dependencies, + dto=dto, + etag=etag, + lifespan=lifespan, + event_emitter_backend=event_emitter_backend, + exception_handlers=exception_handlers, + guards=guards, + include_in_schema=include_in_schema, + listeners=listeners, + logging_config=logging_config, + middleware=middleware, + multipart_form_part_limit=multipart_form_part_limit, + on_app_init=on_app_init, + on_shutdown=on_shutdown, + on_startup=on_startup, + openapi_config=openapi_config, + opt=opt, + parameters=parameters, + pdb_on_exception=pdb_on_exception, + plugins=plugins, + request_class=request_class, + response_cache_config=response_cache_config, + response_class=response_class, + response_cookies=response_cookies, + response_headers=response_headers, + return_dto=return_dto, + route_handlers=route_handlers, + security=security, + signature_namespace=signature_namespace, + signature_types=signature_types, + state=state, + static_files_config=static_files_config, + stores=stores, + tags=tags, + template_config=template_config, + type_encoders=type_encoders, + websocket_class=websocket_class, + experimental_features=experimental_features, + ) + + return TestClient[Litestar]( + app=app, + backend=backend, + backend_options=backend_options, + base_url=base_url, + raise_server_exceptions=raise_server_exceptions, + root_path=root_path, + session_config=session_config, + timeout=timeout, + ) + + +def create_async_test_client( + route_handlers: ControllerRouterHandler | Sequence[ControllerRouterHandler] | None = None, + *, + after_exception: Sequence[AfterExceptionHookHandler] | None = None, + after_request: AfterRequestHookHandler | None = None, + after_response: AfterResponseHookHandler | None = None, + allowed_hosts: Sequence[str] | AllowedHostsConfig | None = None, + backend: Literal["asyncio", "trio"] = "asyncio", + backend_options: Mapping[str, Any] | None = None, + base_url: str = "http://testserver.local", + before_request: BeforeRequestHookHandler | None = None, + before_send: Sequence[BeforeMessageSendHookHandler] | None = None, + cache_control: CacheControlHeader | None = None, + compression_config: CompressionConfig | None = None, + cors_config: CORSConfig | None = None, + csrf_config: CSRFConfig | None = None, + debug: bool = True, + dependencies: Dependencies | None = None, + dto: type[AbstractDTO] | None | EmptyType = Empty, + etag: ETag | None = None, + event_emitter_backend: type[BaseEventEmitterBackend] = SimpleEventEmitter, + exception_handlers: ExceptionHandlersMap | None = None, + guards: Sequence[Guard] | None = None, + include_in_schema: bool | EmptyType = Empty, + lifespan: list[Callable[[Litestar], AbstractAsyncContextManager] | AbstractAsyncContextManager] | None = None, + listeners: Sequence[EventListener] | None = None, + logging_config: BaseLoggingConfig | EmptyType | None = Empty, + middleware: Sequence[Middleware] | None = None, + multipart_form_part_limit: int = 1000, + on_app_init: Sequence[OnAppInitHandler] | None = None, + on_shutdown: Sequence[LifespanHook] | None = None, + on_startup: Sequence[LifespanHook] | None = None, + openapi_config: OpenAPIConfig | None = DEFAULT_OPENAPI_CONFIG, + opt: Mapping[str, Any] | None = None, + parameters: ParametersMap | None = None, + pdb_on_exception: bool | None = None, + plugins: Sequence[PluginProtocol] | None = None, + raise_server_exceptions: bool = True, + request_class: type[Request] | None = None, + response_cache_config: ResponseCacheConfig | None = None, + response_class: type[Response] | None = None, + response_cookies: ResponseCookies | None = None, + response_headers: ResponseHeaders | None = None, + return_dto: type[AbstractDTO] | None | EmptyType = Empty, + root_path: str = "", + security: Sequence[SecurityRequirement] | None = None, + session_config: BaseBackendConfig | None = None, + signature_namespace: Mapping[str, Any] | None = None, + signature_types: Sequence[Any] | None = None, + state: State | None = None, + static_files_config: Sequence[StaticFilesConfig] | None = None, + stores: StoreRegistry | dict[str, Store] | None = None, + tags: Sequence[str] | None = None, + template_config: TemplateConfig | None = None, + timeout: float | None = None, + type_encoders: TypeEncodersMap | None = None, + websocket_class: type[WebSocket] | None = None, + experimental_features: list[ExperimentalFeatures] | None = None, +) -> AsyncTestClient[Litestar]: + """Create a Litestar app instance and initializes it. + + :class:`AsyncTestClient <litestar.testing.AsyncTestClient>` with it. + + Notes: + - This function should be called as a context manager to ensure async startup and shutdown are + handled correctly. + + Examples: + .. code-block:: python + + from litestar import get + from litestar.testing import create_async_test_client + + + @get("/some-path") + def my_handler() -> dict[str, str]: + return {"hello": "world"} + + + async def test_my_handler() -> None: + async with create_async_test_client(my_handler) as client: + response = await client.get("/some-path") + assert response.json() == {"hello": "world"} + + Args: + route_handlers: A single handler or a sequence of route handlers, which can include instances of + :class:`Router <litestar.router.Router>`, subclasses of :class:`Controller <.controller.Controller>` or + any function decorated by the route handler decorators. + backend: The async backend to use, options are "asyncio" or "trio". + backend_options: ``anyio`` options. + base_url: URL scheme and domain for test request paths, e.g. ``http://testserver``. + raise_server_exceptions: Flag for underlying the test client to raise server exceptions instead of wrapping them + in an HTTP response. + root_path: Path prefix for requests. + session_config: Configuration for Session Middleware class to create raw session cookies for request to the + route handlers. + after_exception: A sequence of :class:`exception hook handlers <.types.AfterExceptionHookHandler>`. This + hook is called after an exception occurs. In difference to exception handlers, it is not meant to + return a response - only to process the exception (e.g. log it, send it to Sentry etc.). + after_request: A sync or async function executed after the route handler function returned and the response + object has been resolved. Receives the response object. + after_response: A sync or async function called after the response has been awaited. It receives the + :class:`Request <.connection.Request>` object and should not return any values. + allowed_hosts: A sequence of allowed hosts, or an + :class:`AllowedHostsConfig <.config.allowed_hosts.AllowedHostsConfig>` instance. Enables the builtin + allowed hosts middleware. + before_request: A sync or async function called immediately before calling the route handler. Receives the + :class:`Request <.connection.Request>` instance and any non-``None`` return value is used for the + response, bypassing the route handler. + before_send: A sequence of :class:`before send hook handlers <.types.BeforeMessageSendHookHandler>`. Called + when the ASGI send function is called. + cache_control: A ``cache-control`` header of type + :class:`CacheControlHeader <litestar.datastructures.CacheControlHeader>` to add to route handlers of + this app. Can be overridden by route handlers. + compression_config: Configures compression behaviour of the application, this enabled a builtin or user + defined Compression middleware. + cors_config: If set, configures :class:`CORSMiddleware <.middleware.cors.CORSMiddleware>`. + csrf_config: If set, configures :class:`CSRFMiddleware <.middleware.csrf.CSRFMiddleware>`. + debug: If ``True``, app errors rendered as HTML with a stack trace. + dependencies: A string keyed mapping of dependency :class:`Providers <.di.Provide>`. + dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for (de)serializing and + validation of request data. + etag: An ``etag`` header of type :class:`ETag <.datastructures.ETag>` to add to route handlers of this app. + Can be overridden by route handlers. + event_emitter_backend: A subclass of + :class:`BaseEventEmitterBackend <.events.emitter.BaseEventEmitterBackend>`. + exception_handlers: A mapping of status codes and/or exception types to handler functions. + guards: A sequence of :class:`Guard <.types.Guard>` callables. + include_in_schema: A boolean flag dictating whether the route handler should be documented in the OpenAPI schema. + lifespan: A list of callables returning async context managers, wrapping the lifespan of the ASGI application + listeners: A sequence of :class:`EventListener <.events.listener.EventListener>`. + logging_config: A subclass of :class:`BaseLoggingConfig <.logging.config.BaseLoggingConfig>`. + middleware: A sequence of :class:`Middleware <.types.Middleware>`. + multipart_form_part_limit: The maximal number of allowed parts in a multipart/formdata request. This limit + is intended to protect from DoS attacks. + on_app_init: A sequence of :class:`OnAppInitHandler <.types.OnAppInitHandler>` instances. Handlers receive + an instance of :class:`AppConfig <.config.app.AppConfig>` that will have been initially populated with + the parameters passed to :class:`Litestar <litestar.app.Litestar>`, and must return an instance of same. + If more than one handler is registered they are called in the order they are provided. + on_shutdown: A sequence of :class:`LifespanHook <.types.LifespanHook>` called during application + shutdown. + on_startup: A sequence of :class:`LifespanHook <litestar.types.LifespanHook>` called during + application startup. + openapi_config: Defaults to :attr:`DEFAULT_OPENAPI_CONFIG` + opt: A string keyed mapping of arbitrary values that can be accessed in :class:`Guards <.types.Guard>` or + wherever you have access to :class:`Request <litestar.connection.request.Request>` or + :class:`ASGI Scope <.types.Scope>`. + parameters: A mapping of :class:`Parameter <.params.Parameter>` definitions available to all application + paths. + pdb_on_exception: Drop into the PDB when an exception occurs. + plugins: Sequence of plugins. + request_class: An optional subclass of :class:`Request <.connection.Request>` to use for http connections. + response_class: A custom subclass of :class:`Response <.response.Response>` to be used as the app's default + response. + response_cookies: A sequence of :class:`Cookie <.datastructures.Cookie>`. + response_headers: A string keyed mapping of :class:`ResponseHeader <.datastructures.ResponseHeader>` + response_cache_config: Configures caching behavior of the application. + return_dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for serializing + outbound response data. + route_handlers: A sequence of route handlers, which can include instances of + :class:`Router <.router.Router>`, subclasses of :class:`Controller <.controller.Controller>` or any + callable decorated by the route handler decorators. + security: A sequence of dicts that will be added to the schema of all route handlers in the application. + See + :data:`SecurityRequirement <.openapi.spec.SecurityRequirement>` for details. + signature_namespace: A mapping of names to types for use in forward reference resolution during signature modeling. + signature_types: A sequence of types for use in forward reference resolution during signature modeling. + These types will be added to the signature namespace using their ``__name__`` attribute. + state: An optional :class:`State <.datastructures.State>` for application state. + static_files_config: A sequence of :class:`StaticFilesConfig <.static_files.StaticFilesConfig>` + stores: Central registry of :class:`Store <.stores.base.Store>` that will be available throughout the + application. If this is a dictionary to it will be passed to a + :class:`StoreRegistry <.stores.registry.StoreRegistry>`. If it is a + :class:`StoreRegistry <.stores.registry.StoreRegistry>`, this instance will be used directly. + tags: A sequence of string tags that will be appended to the schema of all route handlers under the + application. + template_config: An instance of :class:`TemplateConfig <.template.TemplateConfig>` + timeout: Request timeout + type_encoders: A mapping of types to callables that transform them into types supported for serialization. + websocket_class: An optional subclass of :class:`WebSocket <.connection.WebSocket>` to use for websocket + connections. + experimental_features: An iterable of experimental features to enable + + Returns: + An instance of :class:`AsyncTestClient <litestar.testing.AsyncTestClient>` with a created app instance. + """ + route_handlers = () if route_handlers is None else route_handlers + if is_class_and_subclass(route_handlers, Controller) or not isinstance(route_handlers, Sequence): + route_handlers = (route_handlers,) + + app = Litestar( + after_exception=after_exception, + after_request=after_request, + after_response=after_response, + allowed_hosts=allowed_hosts, + before_request=before_request, + before_send=before_send, + cache_control=cache_control, + compression_config=compression_config, + cors_config=cors_config, + csrf_config=csrf_config, + debug=debug, + dependencies=dependencies, + dto=dto, + etag=etag, + event_emitter_backend=event_emitter_backend, + exception_handlers=exception_handlers, + guards=guards, + include_in_schema=include_in_schema, + lifespan=lifespan, + listeners=listeners, + logging_config=logging_config, + middleware=middleware, + multipart_form_part_limit=multipart_form_part_limit, + on_app_init=on_app_init, + on_shutdown=on_shutdown, + on_startup=on_startup, + openapi_config=openapi_config, + opt=opt, + parameters=parameters, + pdb_on_exception=pdb_on_exception, + plugins=plugins, + request_class=request_class, + response_cache_config=response_cache_config, + response_class=response_class, + response_cookies=response_cookies, + response_headers=response_headers, + return_dto=return_dto, + route_handlers=route_handlers, + security=security, + signature_namespace=signature_namespace, + signature_types=signature_types, + state=state, + static_files_config=static_files_config, + stores=stores, + tags=tags, + template_config=template_config, + type_encoders=type_encoders, + websocket_class=websocket_class, + experimental_features=experimental_features, + ) + + return AsyncTestClient[Litestar]( + app=app, + backend=backend, + backend_options=backend_options, + base_url=base_url, + raise_server_exceptions=raise_server_exceptions, + root_path=root_path, + session_config=session_config, + timeout=timeout, + ) diff --git a/venv/lib/python3.11/site-packages/litestar/testing/life_span_handler.py b/venv/lib/python3.11/site-packages/litestar/testing/life_span_handler.py new file mode 100644 index 0000000..8ee7d22 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/life_span_handler.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +from math import inf +from typing import TYPE_CHECKING, Generic, Optional, TypeVar, cast + +from anyio import create_memory_object_stream +from anyio.streams.stapled import StapledObjectStream + +from litestar.testing.client.base import BaseTestClient + +if TYPE_CHECKING: + from litestar.types import ( + LifeSpanReceiveMessage, # noqa: F401 + LifeSpanSendMessage, + LifeSpanShutdownEvent, + LifeSpanStartupEvent, + ) + +T = TypeVar("T", bound=BaseTestClient) + + +class LifeSpanHandler(Generic[T]): + __slots__ = "stream_send", "stream_receive", "client", "task" + + def __init__(self, client: T) -> None: + self.client = client + self.stream_send = StapledObjectStream[Optional["LifeSpanSendMessage"]](*create_memory_object_stream(inf)) # type: ignore[arg-type] + self.stream_receive = StapledObjectStream["LifeSpanReceiveMessage"](*create_memory_object_stream(inf)) # type: ignore[arg-type] + + with self.client.portal() as portal: + self.task = portal.start_task_soon(self.lifespan) + portal.call(self.wait_startup) + + async def receive(self) -> LifeSpanSendMessage: + message = await self.stream_send.receive() + if message is None: + self.task.result() + return cast("LifeSpanSendMessage", message) + + async def wait_startup(self) -> None: + event: LifeSpanStartupEvent = {"type": "lifespan.startup"} + await self.stream_receive.send(event) + + message = await self.receive() + if message["type"] not in ( + "lifespan.startup.complete", + "lifespan.startup.failed", + ): + raise RuntimeError( + "Received unexpected ASGI message type. Expected 'lifespan.startup.complete' or " + f"'lifespan.startup.failed'. Got {message['type']!r}", + ) + if message["type"] == "lifespan.startup.failed": + await self.receive() + + async def wait_shutdown(self) -> None: + async with self.stream_send: + lifespan_shutdown_event: LifeSpanShutdownEvent = {"type": "lifespan.shutdown"} + await self.stream_receive.send(lifespan_shutdown_event) + + message = await self.receive() + if message["type"] not in ( + "lifespan.shutdown.complete", + "lifespan.shutdown.failed", + ): + raise RuntimeError( + "Received unexpected ASGI message type. Expected 'lifespan.shutdown.complete' or " + f"'lifespan.shutdown.failed'. Got {message['type']!r}", + ) + if message["type"] == "lifespan.shutdown.failed": + await self.receive() + + async def lifespan(self) -> None: + scope = {"type": "lifespan"} + try: + await self.client.app(scope, self.stream_receive.receive, self.stream_send.send) + finally: + await self.stream_send.send(None) diff --git a/venv/lib/python3.11/site-packages/litestar/testing/request_factory.py b/venv/lib/python3.11/site-packages/litestar/testing/request_factory.py new file mode 100644 index 0000000..ccb29c6 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/request_factory.py @@ -0,0 +1,565 @@ +from __future__ import annotations + +import json +from functools import partial +from typing import TYPE_CHECKING, Any, cast +from urllib.parse import urlencode + +from httpx._content import encode_json as httpx_encode_json +from httpx._content import encode_multipart_data, encode_urlencoded_data + +from litestar import delete, patch, post, put +from litestar.app import Litestar +from litestar.connection import Request +from litestar.enums import HttpMethod, ParamType, RequestEncodingType, ScopeType +from litestar.handlers.http_handlers import get +from litestar.serialization import decode_json, default_serializer, encode_json +from litestar.types import DataContainerType, HTTPScope, RouteHandlerType +from litestar.types.asgi_types import ASGIVersion +from litestar.utils import get_serializer_from_scope +from litestar.utils.scope.state import ScopeState + +if TYPE_CHECKING: + from httpx._types import FileTypes + + from litestar.datastructures.cookie import Cookie + from litestar.handlers.http_handlers import HTTPRouteHandler + +_decorator_http_method_map: dict[HttpMethod, type[HTTPRouteHandler]] = { + HttpMethod.GET: get, + HttpMethod.POST: post, + HttpMethod.DELETE: delete, + HttpMethod.PATCH: patch, + HttpMethod.PUT: put, +} + + +def _create_default_route_handler( + http_method: HttpMethod, handler_kwargs: dict[str, Any] | None, app: Litestar +) -> HTTPRouteHandler: + handler_decorator = _decorator_http_method_map[http_method] + + def _default_route_handler() -> None: ... + + handler = handler_decorator("/", sync_to_thread=False, **(handler_kwargs or {}))(_default_route_handler) + handler.owner = app + return handler + + +def _create_default_app() -> Litestar: + return Litestar(route_handlers=[]) + + +class RequestFactory: + """Factory to create :class:`Request <litestar.connection.Request>` instances.""" + + __slots__ = ( + "app", + "server", + "port", + "root_path", + "scheme", + "handler_kwargs", + "serializer", + ) + + def __init__( + self, + app: Litestar | None = None, + server: str = "test.org", + port: int = 3000, + root_path: str = "", + scheme: str = "http", + handler_kwargs: dict[str, Any] | None = None, + ) -> None: + """Initialize ``RequestFactory`` + + Args: + app: An instance of :class:`Litestar <litestar.app.Litestar>` to set as ``request.scope["app"]``. + server: The server's domain. + port: The server's port. + root_path: Root path for the server. + scheme: Scheme for the server. + handler_kwargs: Kwargs to pass to the route handler created for the request + + Examples: + .. code-block:: python + + from litestar import Litestar + from litestar.enums import RequestEncodingType + from litestar.testing import RequestFactory + + from tests import PersonFactory + + my_app = Litestar(route_handlers=[]) + my_server = "litestar.org" + + # Create a GET request + query_params = {"id": 1} + get_user_request = RequestFactory(app=my_app, server=my_server).get( + "/person", query_params=query_params + ) + + # Create a POST request + new_person = PersonFactory.build() + create_user_request = RequestFactory(app=my_app, server=my_server).post( + "/person", data=person + ) + + # Create a request with a special header + headers = {"header1": "value1"} + request_with_header = RequestFactory(app=my_app, server=my_server).get( + "/person", query_params=query_params, headers=headers + ) + + # Create a request with a media type + request_with_media_type = RequestFactory(app=my_app, server=my_server).post( + "/person", data=person, request_media_type=RequestEncodingType.MULTI_PART + ) + + """ + + self.app = app if app is not None else _create_default_app() + self.server = server + self.port = port + self.root_path = root_path + self.scheme = scheme + self.handler_kwargs = handler_kwargs + self.serializer = partial(default_serializer, type_encoders=self.app.type_encoders) + + def _create_scope( + self, + path: str, + http_method: HttpMethod, + session: dict[str, Any] | None = None, + user: Any = None, + auth: Any = None, + query_params: dict[str, str | list[str]] | None = None, + state: dict[str, Any] | None = None, + path_params: dict[str, str] | None = None, + http_version: str | None = "1.1", + route_handler: RouteHandlerType | None = None, + ) -> HTTPScope: + """Create the scope for the :class:`Request <litestar.connection.Request>`. + + Args: + path: The request's path. + http_method: The request's HTTP method. + session: A dictionary of session data. + user: A value for `request.scope["user"]`. + auth: A value for `request.scope["auth"]`. + query_params: A dictionary of values from which the request's query will be generated. + state: Arbitrary request state. + path_params: A string keyed dictionary of path parameter values. + http_version: HTTP version. Defaults to "1.1". + route_handler: A route handler instance or method. If not provided a default handler is set. + + Returns: + A dictionary that can be passed as a scope to the :class:`Request <litestar.connection.Request>` ctor. + """ + if session is None: + session = {} + + if state is None: + state = {} + + if path_params is None: + path_params = {} + + return HTTPScope( + type=ScopeType.HTTP, + method=http_method.value, + scheme=self.scheme, + server=(self.server, self.port), + root_path=self.root_path.rstrip("/"), + path=path, + headers=[], + app=self.app, + session=session, + user=user, + auth=auth, + query_string=urlencode(query_params, doseq=True).encode() if query_params else b"", + path_params=path_params, + client=(self.server, self.port), + state=state, + asgi=ASGIVersion(spec_version="3.0", version="3.0"), + http_version=http_version or "1.1", + raw_path=path.encode("ascii"), + route_handler=route_handler + or _create_default_route_handler(http_method, self.handler_kwargs, app=self.app), + extensions={}, + ) + + @classmethod + def _create_cookie_header(cls, headers: dict[str, str], cookies: list[Cookie] | str | None = None) -> None: + """Create the cookie header and add it to the ``headers`` dictionary. + + Args: + headers: A dictionary of headers, the cookie header will be added to it. + cookies: A string representing the cookie header or a list of "Cookie" instances. + This value can include multiple cookies. + """ + if not cookies: + return + + if isinstance(cookies, list): + cookie_header = "; ".join(cookie.to_header(header="") for cookie in cookies) + headers[ParamType.COOKIE] = cookie_header + elif isinstance(cookies, str): + headers[ParamType.COOKIE] = cookies + + def _build_headers( + self, + headers: dict[str, str] | None = None, + cookies: list[Cookie] | str | None = None, + ) -> list[tuple[bytes, bytes]]: + """Build a list of encoded headers that can be passed to the request scope. + + Args: + headers: A dictionary of headers. + cookies: A string representing the cookie header or a list of "Cookie" instances. + This value can include multiple cookies. + + Returns: + A list of encoded headers that can be passed to the request scope. + """ + headers = headers or {} + self._create_cookie_header(headers, cookies) + return [ + ((key.lower()).encode("latin-1", errors="ignore"), value.encode("latin-1", errors="ignore")) + for key, value in headers.items() + ] + + def _create_request_with_data( + self, + http_method: HttpMethod, + path: str, + headers: dict[str, str] | None = None, + cookies: list[Cookie] | str | None = None, + session: dict[str, Any] | None = None, + user: Any = None, + auth: Any = None, + request_media_type: RequestEncodingType = RequestEncodingType.JSON, + data: dict[str, Any] | DataContainerType | None = None, # pyright: ignore + files: dict[str, FileTypes] | list[tuple[str, FileTypes]] | None = None, + query_params: dict[str, str | list[str]] | None = None, + state: dict[str, Any] | None = None, + path_params: dict[str, str] | None = None, + http_version: str | None = "1.1", + route_handler: RouteHandlerType | None = None, + ) -> Request[Any, Any, Any]: + """Create a :class:`Request <litestar.connection.Request>` instance that has body (data) + + Args: + http_method: The request's HTTP method. + path: The request's path. + headers: A dictionary of headers. + cookies: A string representing the cookie header or a list of "Cookie" instances. + This value can include multiple cookies. + session: A dictionary of session data. + user: A value for `request.scope["user"]` + auth: A value for `request.scope["auth"]` + request_media_type: The 'Content-Type' header of the request. + data: A value for the request's body. Can be any supported serializable type. + files: A dictionary of files to be sent with the request. + query_params: A dictionary of values from which the request's query will be generated. + state: Arbitrary request state. + path_params: A string keyed dictionary of path parameter values. + http_version: HTTP version. Defaults to "1.1". + route_handler: A route handler instance or method. If not provided a default handler is set. + + Returns: + A :class:`Request <litestar.connection.Request>` instance + """ + scope = self._create_scope( + path=path, + http_method=http_method, + session=session, + user=user, + auth=auth, + query_params=query_params, + state=state, + path_params=path_params, + http_version=http_version, + route_handler=route_handler, + ) + + headers = headers or {} + body = b"" + if data: + data = json.loads(encode_json(data, serializer=get_serializer_from_scope(scope))) + + if request_media_type == RequestEncodingType.JSON: + encoding_headers, stream = httpx_encode_json(data) + elif request_media_type == RequestEncodingType.MULTI_PART: + encoding_headers, stream = encode_multipart_data( # type: ignore[assignment] + cast("dict[str, Any]", data), files=files or [], boundary=None + ) + else: + encoding_headers, stream = encode_urlencoded_data(decode_json(value=encode_json(data))) + headers.update(encoding_headers) + for chunk in stream: + body += chunk + ScopeState.from_scope(scope).body = body + self._create_cookie_header(headers, cookies) + scope["headers"] = self._build_headers(headers) + return Request(scope=scope) + + def get( + self, + path: str = "/", + headers: dict[str, str] | None = None, + cookies: list[Cookie] | str | None = None, + session: dict[str, Any] | None = None, + user: Any = None, + auth: Any = None, + query_params: dict[str, str | list[str]] | None = None, + state: dict[str, Any] | None = None, + path_params: dict[str, str] | None = None, + http_version: str | None = "1.1", + route_handler: RouteHandlerType | None = None, + ) -> Request[Any, Any, Any]: + """Create a GET :class:`Request <litestar.connection.Request>` instance. + + Args: + path: The request's path. + headers: A dictionary of headers. + cookies: A string representing the cookie header or a list of "Cookie" instances. + This value can include multiple cookies. + session: A dictionary of session data. + user: A value for `request.scope["user"]`. + auth: A value for `request.scope["auth"]`. + query_params: A dictionary of values from which the request's query will be generated. + state: Arbitrary request state. + path_params: A string keyed dictionary of path parameter values. + http_version: HTTP version. Defaults to "1.1". + route_handler: A route handler instance or method. If not provided a default handler is set. + + Returns: + A :class:`Request <litestar.connection.Request>` instance + """ + scope = self._create_scope( + path=path, + http_method=HttpMethod.GET, + session=session, + user=user, + auth=auth, + query_params=query_params, + state=state, + path_params=path_params, + http_version=http_version, + route_handler=route_handler, + ) + + scope["headers"] = self._build_headers(headers, cookies) + return Request(scope=scope) + + def post( + self, + path: str = "/", + headers: dict[str, str] | None = None, + cookies: list[Cookie] | str | None = None, + session: dict[str, Any] | None = None, + user: Any = None, + auth: Any = None, + request_media_type: RequestEncodingType = RequestEncodingType.JSON, + data: dict[str, Any] | DataContainerType | None = None, # pyright: ignore + query_params: dict[str, str | list[str]] | None = None, + state: dict[str, Any] | None = None, + path_params: dict[str, str] | None = None, + http_version: str | None = "1.1", + route_handler: RouteHandlerType | None = None, + ) -> Request[Any, Any, Any]: + """Create a POST :class:`Request <litestar.connection.Request>` instance. + + Args: + path: The request's path. + headers: A dictionary of headers. + cookies: A string representing the cookie header or a list of "Cookie" instances. + This value can include multiple cookies. + session: A dictionary of session data. + user: A value for `request.scope["user"]`. + auth: A value for `request.scope["auth"]`. + request_media_type: The 'Content-Type' header of the request. + data: A value for the request's body. Can be any supported serializable type. + query_params: A dictionary of values from which the request's query will be generated. + state: Arbitrary request state. + path_params: A string keyed dictionary of path parameter values. + http_version: HTTP version. Defaults to "1.1". + route_handler: A route handler instance or method. If not provided a default handler is set. + + Returns: + A :class:`Request <litestar.connection.Request>` instance + """ + return self._create_request_with_data( + auth=auth, + cookies=cookies, + data=data, + headers=headers, + http_method=HttpMethod.POST, + path=path, + query_params=query_params, + request_media_type=request_media_type, + session=session, + user=user, + state=state, + path_params=path_params, + http_version=http_version, + route_handler=route_handler, + ) + + def put( + self, + path: str = "/", + headers: dict[str, str] | None = None, + cookies: list[Cookie] | str | None = None, + session: dict[str, Any] | None = None, + user: Any = None, + auth: Any = None, + request_media_type: RequestEncodingType = RequestEncodingType.JSON, + data: dict[str, Any] | DataContainerType | None = None, # pyright: ignore + query_params: dict[str, str | list[str]] | None = None, + state: dict[str, Any] | None = None, + path_params: dict[str, str] | None = None, + http_version: str | None = "1.1", + route_handler: RouteHandlerType | None = None, + ) -> Request[Any, Any, Any]: + """Create a PUT :class:`Request <litestar.connection.Request>` instance. + + Args: + path: The request's path. + headers: A dictionary of headers. + cookies: A string representing the cookie header or a list of "Cookie" instances. + This value can include multiple cookies. + session: A dictionary of session data. + user: A value for `request.scope["user"]`. + auth: A value for `request.scope["auth"]`. + request_media_type: The 'Content-Type' header of the request. + data: A value for the request's body. Can be any supported serializable type. + query_params: A dictionary of values from which the request's query will be generated. + state: Arbitrary request state. + path_params: A string keyed dictionary of path parameter values. + http_version: HTTP version. Defaults to "1.1". + route_handler: A route handler instance or method. If not provided a default handler is set. + + Returns: + A :class:`Request <litestar.connection.Request>` instance + """ + return self._create_request_with_data( + auth=auth, + cookies=cookies, + data=data, + headers=headers, + http_method=HttpMethod.PUT, + path=path, + query_params=query_params, + request_media_type=request_media_type, + session=session, + user=user, + state=state, + path_params=path_params, + http_version=http_version, + route_handler=route_handler, + ) + + def patch( + self, + path: str = "/", + headers: dict[str, str] | None = None, + cookies: list[Cookie] | str | None = None, + session: dict[str, Any] | None = None, + user: Any = None, + auth: Any = None, + request_media_type: RequestEncodingType = RequestEncodingType.JSON, + data: dict[str, Any] | DataContainerType | None = None, # pyright: ignore + query_params: dict[str, str | list[str]] | None = None, + state: dict[str, Any] | None = None, + path_params: dict[str, str] | None = None, + http_version: str | None = "1.1", + route_handler: RouteHandlerType | None = None, + ) -> Request[Any, Any, Any]: + """Create a PATCH :class:`Request <litestar.connection.Request>` instance. + + Args: + path: The request's path. + headers: A dictionary of headers. + cookies: A string representing the cookie header or a list of "Cookie" instances. + This value can include multiple cookies. + session: A dictionary of session data. + user: A value for `request.scope["user"]`. + auth: A value for `request.scope["auth"]`. + request_media_type: The 'Content-Type' header of the request. + data: A value for the request's body. Can be any supported serializable type. + query_params: A dictionary of values from which the request's query will be generated. + state: Arbitrary request state. + path_params: A string keyed dictionary of path parameter values. + http_version: HTTP version. Defaults to "1.1". + route_handler: A route handler instance or method. If not provided a default handler is set. + + Returns: + A :class:`Request <litestar.connection.Request>` instance + """ + return self._create_request_with_data( + auth=auth, + cookies=cookies, + data=data, + headers=headers, + http_method=HttpMethod.PATCH, + path=path, + query_params=query_params, + request_media_type=request_media_type, + session=session, + user=user, + state=state, + path_params=path_params, + http_version=http_version, + route_handler=route_handler, + ) + + def delete( + self, + path: str = "/", + headers: dict[str, str] | None = None, + cookies: list[Cookie] | str | None = None, + session: dict[str, Any] | None = None, + user: Any = None, + auth: Any = None, + query_params: dict[str, str | list[str]] | None = None, + state: dict[str, Any] | None = None, + path_params: dict[str, str] | None = None, + http_version: str | None = "1.1", + route_handler: RouteHandlerType | None = None, + ) -> Request[Any, Any, Any]: + """Create a POST :class:`Request <litestar.connection.Request>` instance. + + Args: + path: The request's path. + headers: A dictionary of headers. + cookies: A string representing the cookie header or a list of "Cookie" instances. + This value can include multiple cookies. + session: A dictionary of session data. + user: A value for `request.scope["user"]`. + auth: A value for `request.scope["auth"]`. + query_params: A dictionary of values from which the request's query will be generated. + state: Arbitrary request state. + path_params: A string keyed dictionary of path parameter values. + http_version: HTTP version. Defaults to "1.1". + route_handler: A route handler instance or method. If not provided a default handler is set. + + Returns: + A :class:`Request <litestar.connection.Request>` instance + """ + scope = self._create_scope( + path=path, + http_method=HttpMethod.DELETE, + session=session, + user=user, + auth=auth, + query_params=query_params, + state=state, + path_params=path_params, + http_version=http_version, + route_handler=route_handler, + ) + scope["headers"] = self._build_headers(headers, cookies) + return Request(scope=scope) diff --git a/venv/lib/python3.11/site-packages/litestar/testing/transport.py b/venv/lib/python3.11/site-packages/litestar/testing/transport.py new file mode 100644 index 0000000..ffa76a4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/transport.py @@ -0,0 +1,192 @@ +from __future__ import annotations + +from io import BytesIO +from types import GeneratorType +from typing import TYPE_CHECKING, Any, Generic, TypedDict, TypeVar, Union, cast +from urllib.parse import unquote + +from anyio import Event +from httpx import ByteStream, Response + +from litestar.status_codes import HTTP_500_INTERNAL_SERVER_ERROR +from litestar.testing.websocket_test_session import WebSocketTestSession + +if TYPE_CHECKING: + from httpx import Request + + from litestar.testing.client import AsyncTestClient, TestClient + from litestar.types import ( + HTTPDisconnectEvent, + HTTPRequestEvent, + Message, + Receive, + ReceiveMessage, + Send, + WebSocketScope, + ) + + +T = TypeVar("T", bound=Union["AsyncTestClient", "TestClient"]) + + +class ConnectionUpgradeExceptionError(Exception): + def __init__(self, session: WebSocketTestSession) -> None: + self.session = session + + +class SendReceiveContext(TypedDict): + request_complete: bool + response_complete: Event + raw_kwargs: dict[str, Any] + response_started: bool + template: str | None + context: Any | None + + +class TestClientTransport(Generic[T]): + def __init__( + self, + client: T, + raise_server_exceptions: bool = True, + root_path: str = "", + ) -> None: + self.client = client + self.raise_server_exceptions = raise_server_exceptions + self.root_path = root_path + + @staticmethod + def create_receive(request: Request, context: SendReceiveContext) -> Receive: + async def receive() -> ReceiveMessage: + if context["request_complete"]: + if not context["response_complete"].is_set(): + await context["response_complete"].wait() + disconnect_event: HTTPDisconnectEvent = {"type": "http.disconnect"} + return disconnect_event + + body = cast("Union[bytes, str, GeneratorType]", (request.read() or b"")) + request_event: HTTPRequestEvent = {"type": "http.request", "body": b"", "more_body": False} + if isinstance(body, GeneratorType): # pragma: no cover + try: + chunk = body.send(None) + request_event["body"] = chunk if isinstance(chunk, bytes) else chunk.encode("utf-8") + request_event["more_body"] = True + except StopIteration: + context["request_complete"] = True + else: + context["request_complete"] = True + request_event["body"] = body if isinstance(body, bytes) else body.encode("utf-8") + return request_event + + return receive + + @staticmethod + def create_send(request: Request, context: SendReceiveContext) -> Send: + async def send(message: Message) -> None: + if message["type"] == "http.response.start": + assert not context["response_started"], 'Received multiple "http.response.start" messages.' # noqa: S101 + context["raw_kwargs"]["status_code"] = message["status"] + context["raw_kwargs"]["headers"] = [ + (k.decode("utf-8"), v.decode("utf-8")) for k, v in message.get("headers", []) + ] + context["response_started"] = True + elif message["type"] == "http.response.body": + assert context["response_started"], 'Received "http.response.body" without "http.response.start".' # noqa: S101 + assert not context[ # noqa: S101 + "response_complete" + ].is_set(), 'Received "http.response.body" after response completed.' + body = message.get("body", b"") + more_body = message.get("more_body", False) + if request.method != "HEAD": + context["raw_kwargs"]["stream"].write(body) + if not more_body: + context["raw_kwargs"]["stream"].seek(0) + context["response_complete"].set() + elif message["type"] == "http.response.template": # type: ignore[comparison-overlap] # pragma: no cover + context["template"] = message["template"] # type: ignore[unreachable] + context["context"] = message["context"] + + return send + + def parse_request(self, request: Request) -> dict[str, Any]: + scheme = request.url.scheme + netloc = unquote(request.url.netloc.decode(encoding="ascii")) + path = request.url.path + raw_path = request.url.raw_path + query = request.url.query.decode(encoding="ascii") + default_port = 433 if scheme in {"https", "wss"} else 80 + + if ":" in netloc: + host, port_string = netloc.split(":", 1) + port = int(port_string) + else: + host = netloc + port = default_port + + host_header = request.headers.pop("host", host if port == default_port else f"{host}:{port}") + + headers = [(k.lower().encode(), v.encode()) for k, v in (("host", host_header), *request.headers.items())] + + return { + "type": "websocket" if scheme in {"ws", "wss"} else "http", + "path": unquote(path), + "raw_path": raw_path, + "root_path": self.root_path, + "scheme": scheme, + "query_string": query.encode(), + "headers": headers, + "client": ("testclient", 50000), + "server": (host, port), + } + + def handle_request(self, request: Request) -> Response: + scope = self.parse_request(request=request) + if scope["type"] == "websocket": + scope.update( + subprotocols=[value.strip() for value in request.headers.get("sec-websocket-protocol", "").split(",")] + ) + session = WebSocketTestSession(client=self.client, scope=cast("WebSocketScope", scope)) # type: ignore[arg-type] + raise ConnectionUpgradeExceptionError(session) + + scope.update(method=request.method, http_version="1.1", extensions={"http.response.template": {}}) + + raw_kwargs: dict[str, Any] = {"stream": BytesIO()} + + try: + with self.client.portal() as portal: + response_complete = portal.call(Event) + context: SendReceiveContext = { + "response_complete": response_complete, + "request_complete": False, + "raw_kwargs": raw_kwargs, + "response_started": False, + "template": None, + "context": None, + } + portal.call( + self.client.app, + scope, + self.create_receive(request=request, context=context), + self.create_send(request=request, context=context), + ) + except BaseException as exc: # noqa: BLE001 + if self.raise_server_exceptions: + raise exc + return Response( + status_code=HTTP_500_INTERNAL_SERVER_ERROR, headers=[], stream=ByteStream(b""), request=request + ) + else: + if not context["response_started"]: # pragma: no cover + if self.raise_server_exceptions: + assert context["response_started"], "TestClient did not receive any response." # noqa: S101 + return Response( + status_code=HTTP_500_INTERNAL_SERVER_ERROR, headers=[], stream=ByteStream(b""), request=request + ) + + stream = ByteStream(raw_kwargs.pop("stream", BytesIO()).read()) + response = Response(**raw_kwargs, stream=stream, request=request) + setattr(response, "template", context["template"]) + setattr(response, "context", context["context"]) + return response + + async def handle_async_request(self, request: Request) -> Response: + return self.handle_request(request=request) diff --git a/venv/lib/python3.11/site-packages/litestar/testing/websocket_test_session.py b/venv/lib/python3.11/site-packages/litestar/testing/websocket_test_session.py new file mode 100644 index 0000000..292e8a9 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/testing/websocket_test_session.py @@ -0,0 +1,248 @@ +from __future__ import annotations + +from contextlib import ExitStack +from queue import Queue +from typing import TYPE_CHECKING, Any, Literal, cast + +from anyio import sleep + +from litestar.exceptions import WebSocketDisconnect +from litestar.serialization import decode_json, decode_msgpack, encode_json, encode_msgpack +from litestar.status_codes import WS_1000_NORMAL_CLOSURE + +if TYPE_CHECKING: + from litestar.testing.client.sync_client import TestClient + from litestar.types import ( + WebSocketConnectEvent, + WebSocketDisconnectEvent, + WebSocketReceiveMessage, + WebSocketScope, + WebSocketSendMessage, + ) + + +__all__ = ("WebSocketTestSession",) + + +class WebSocketTestSession: + exit_stack: ExitStack + + def __init__( + self, + client: TestClient[Any], + scope: WebSocketScope, + ) -> None: + self.client = client + self.scope = scope + self.accepted_subprotocol: str | None = None + self.receive_queue: Queue[WebSocketReceiveMessage] = Queue() + self.send_queue: Queue[WebSocketSendMessage | BaseException] = Queue() + self.extra_headers: list[tuple[bytes, bytes]] | None = None + + def __enter__(self) -> WebSocketTestSession: + self.exit_stack = ExitStack() + + portal = self.exit_stack.enter_context(self.client.portal()) + + try: + portal.start_task_soon(self.do_asgi_call) + event: WebSocketConnectEvent = {"type": "websocket.connect"} + self.receive_queue.put(event) + + message = self.receive(timeout=self.client.timeout.read) + self.accepted_subprotocol = cast("str | None", message.get("subprotocol", None)) + self.extra_headers = cast("list[tuple[bytes, bytes]] | None", message.get("headers", None)) + return self + except Exception: + self.exit_stack.close() + raise + + def __exit__(self, *args: Any) -> None: + try: + self.close() + finally: + self.exit_stack.close() + while not self.send_queue.empty(): + message = self.send_queue.get() + if isinstance(message, BaseException): + raise message + + async def do_asgi_call(self) -> None: + """The sub-thread in which the websocket session runs.""" + + async def receive() -> WebSocketReceiveMessage: + while self.receive_queue.empty(): + await sleep(0) + return self.receive_queue.get() + + async def send(message: WebSocketSendMessage) -> None: + if message["type"] == "websocket.accept": + headers = message.get("headers", []) + if headers: + headers_list = list(self.scope["headers"]) + headers_list.extend(headers) + self.scope["headers"] = headers_list + subprotocols = cast("str | None", message.get("subprotocols")) + if subprotocols: # pragma: no cover + self.scope["subprotocols"].append(subprotocols) + self.send_queue.put(message) + + try: + await self.client.app(self.scope, receive, send) + except BaseException as exc: + self.send_queue.put(exc) + raise + + def send(self, data: str | bytes, mode: Literal["text", "binary"] = "text", encoding: str = "utf-8") -> None: + """Sends a "receive" event. This is the inverse of the ASGI send method. + + Args: + data: Either a string or a byte string. + mode: The key to use - ``text`` or ``bytes`` + encoding: The encoding to use when encoding or decoding data. + + Returns: + None. + """ + if mode == "text": + data = data.decode(encoding) if isinstance(data, bytes) else data + text_event: WebSocketReceiveMessage = {"type": "websocket.receive", "text": data} # type: ignore[assignment] + self.receive_queue.put(text_event) + else: + data = data if isinstance(data, bytes) else data.encode(encoding) + binary_event: WebSocketReceiveMessage = {"type": "websocket.receive", "bytes": data} # type: ignore[assignment] + self.receive_queue.put(binary_event) + + def send_text(self, data: str, encoding: str = "utf-8") -> None: + """Sends the data using the ``text`` key. + + Args: + data: Data to send. + encoding: Encoding to use. + + Returns: + None + """ + self.send(data=data, encoding=encoding) + + def send_bytes(self, data: bytes, encoding: str = "utf-8") -> None: + """Sends the data using the ``bytes`` key. + + Args: + data: Data to send. + encoding: Encoding to use. + + Returns: + None + """ + self.send(data=data, mode="binary", encoding=encoding) + + def send_json(self, data: Any, mode: Literal["text", "binary"] = "text") -> None: + """Sends the given data as JSON. + + Args: + data: The data to send. + mode: Either ``text`` or ``binary`` + + Returns: + None + """ + self.send(encode_json(data), mode=mode) + + def send_msgpack(self, data: Any) -> None: + """Sends the given data as MessagePack. + + Args: + data: The data to send. + + Returns: + None + """ + self.send(encode_msgpack(data), mode="binary") + + def close(self, code: int = WS_1000_NORMAL_CLOSURE) -> None: + """Sends an 'websocket.disconnect' event. + + Args: + code: status code for closing the connection. + + Returns: + None. + """ + event: WebSocketDisconnectEvent = {"type": "websocket.disconnect", "code": code} + self.receive_queue.put(event) + + def receive(self, block: bool = True, timeout: float | None = None) -> WebSocketSendMessage: + """This is the base receive method. + + Args: + block: Block until a message is received + timeout: If ``block`` is ``True``, block at most ``timeout`` seconds + + Notes: + - you can use one of the other receive methods to extract the data from the message. + + Returns: + A websocket message. + """ + message = cast("WebSocketSendMessage", self.send_queue.get(block=block, timeout=timeout)) + + if isinstance(message, BaseException): + raise message + + if message["type"] == "websocket.close": + raise WebSocketDisconnect( + detail=cast("str", message.get("reason", "")), + code=message.get("code", WS_1000_NORMAL_CLOSURE), + ) + return message + + def receive_text(self, block: bool = True, timeout: float | None = None) -> str: + """Receive data in ``text`` mode and return a string + + Args: + block: Block until a message is received + timeout: If ``block`` is ``True``, block at most ``timeout`` seconds + + Returns: + A string value. + """ + message = self.receive(block=block, timeout=timeout) + return cast("str", message.get("text", "")) + + def receive_bytes(self, block: bool = True, timeout: float | None = None) -> bytes: + """Receive data in ``binary`` mode and return bytes + + Args: + block: Block until a message is received + timeout: If ``block`` is ``True``, block at most ``timeout`` seconds + + Returns: + A string value. + """ + message = self.receive(block=block, timeout=timeout) + return cast("bytes", message.get("bytes", b"")) + + def receive_json( + self, mode: Literal["text", "binary"] = "text", block: bool = True, timeout: float | None = None + ) -> Any: + """Receive data in either ``text`` or ``binary`` mode and decode it as JSON. + + Args: + mode: Either ``text`` or ``binary`` + block: Block until a message is received + timeout: If ``block`` is ``True``, block at most ``timeout`` seconds + + Returns: + An arbitrary value + """ + message = self.receive(block=block, timeout=timeout) + + if mode == "text": + return decode_json(cast("str", message.get("text", ""))) + + return decode_json(cast("bytes", message.get("bytes", b""))) + + def receive_msgpack(self, block: bool = True, timeout: float | None = None) -> Any: + message = self.receive(block=block, timeout=timeout) + return decode_msgpack(cast("bytes", message.get("bytes", b""))) |