summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/litestar/testing/client/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/testing/client/base.py')
-rw-r--r--venv/lib/python3.11/site-packages/litestar/testing/client/base.py180
1 files changed, 180 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/testing/client/base.py b/venv/lib/python3.11/site-packages/litestar/testing/client/base.py
new file mode 100644
index 0000000..3c25be1
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/testing/client/base.py
@@ -0,0 +1,180 @@
+from __future__ import annotations
+
+from contextlib import contextmanager
+from http.cookiejar import CookieJar
+from typing import TYPE_CHECKING, Any, Generator, Generic, Mapping, TypeVar, cast
+from warnings import warn
+
+from anyio.from_thread import BlockingPortal, start_blocking_portal
+from httpx import Cookies, Request, Response
+
+from litestar import Litestar
+from litestar.connection import ASGIConnection
+from litestar.datastructures import MutableScopeHeaders
+from litestar.enums import ScopeType
+from litestar.exceptions import (
+ ImproperlyConfiguredException,
+)
+from litestar.types import AnyIOBackend, ASGIApp, HTTPResponseStartEvent
+from litestar.utils.scope.state import ScopeState
+
+if TYPE_CHECKING:
+ from httpx._types import CookieTypes
+
+ from litestar.middleware.session.base import BaseBackendConfig, BaseSessionBackend
+ from litestar.types.asgi_types import HTTPScope, Receive, Scope, Send
+
+T = TypeVar("T", bound=ASGIApp)
+
+
+def fake_http_send_message(headers: MutableScopeHeaders) -> HTTPResponseStartEvent:
+ headers.setdefault("content-type", "application/text")
+ return HTTPResponseStartEvent(type="http.response.start", status=200, headers=headers.headers)
+
+
+def fake_asgi_connection(app: ASGIApp, cookies: dict[str, str]) -> ASGIConnection[Any, Any, Any, Any]:
+ scope: HTTPScope = {
+ "type": ScopeType.HTTP,
+ "path": "/",
+ "raw_path": b"/",
+ "root_path": "",
+ "scheme": "http",
+ "query_string": b"",
+ "client": ("testclient", 50000),
+ "server": ("testserver", 80),
+ "headers": [],
+ "method": "GET",
+ "http_version": "1.1",
+ "extensions": {"http.response.template": {}},
+ "app": app, # type: ignore[typeddict-item]
+ "state": {},
+ "path_params": {},
+ "route_handler": None, # type: ignore[typeddict-item]
+ "asgi": {"version": "3.0", "spec_version": "2.1"},
+ "auth": None,
+ "session": None,
+ "user": None,
+ }
+ ScopeState.from_scope(scope).cookies = cookies
+ return ASGIConnection[Any, Any, Any, Any](scope=scope)
+
+
+def _wrap_app_to_add_state(app: ASGIApp) -> ASGIApp:
+ """Wrap an ASGI app to add state to the scope.
+
+ Litestar depends on `state` being present in the ASGI connection scope. Scope state is optional in the ASGI spec,
+ however, the Litestar app always ensures it is present so that it can be depended on internally.
+
+ When the ASGI app that is passed to the test client is _not_ a Litestar app, we need to add
+ state to the scope, because httpx does not do this for us.
+
+ This assists us in testing Litestar components that rely on state being present in the scope, without having
+ to create a Litestar app for every test case.
+
+ Args:
+ app: The ASGI app to wrap.
+
+ Returns:
+ The wrapped ASGI app.
+ """
+
+ async def wrapped(scope: Scope, receive: Receive, send: Send) -> None:
+ scope["state"] = {}
+ await app(scope, receive, send)
+
+ return wrapped
+
+
+class BaseTestClient(Generic[T]):
+ __test__ = False
+ blocking_portal: BlockingPortal
+
+ __slots__ = (
+ "app",
+ "base_url",
+ "backend",
+ "backend_options",
+ "session_config",
+ "_session_backend",
+ "cookies",
+ )
+
+ def __init__(
+ self,
+ app: T,
+ base_url: str = "http://testserver.local",
+ backend: AnyIOBackend = "asyncio",
+ backend_options: Mapping[str, Any] | None = None,
+ session_config: BaseBackendConfig | None = None,
+ cookies: CookieTypes | None = None,
+ ) -> None:
+ if "." not in base_url:
+ warn(
+ f"The base_url {base_url!r} might cause issues. Try adding a domain name such as .local: "
+ f"'{base_url}.local'",
+ UserWarning,
+ stacklevel=1,
+ )
+
+ self._session_backend: BaseSessionBackend | None = None
+ if session_config:
+ self._session_backend = session_config._backend_class(config=session_config)
+
+ if not isinstance(app, Litestar):
+ app = _wrap_app_to_add_state(app) # type: ignore[assignment]
+
+ self.app = cast("T", app) # type: ignore[redundant-cast] # pyright needs this
+
+ self.base_url = base_url
+ self.backend = backend
+ self.backend_options = backend_options
+ self.cookies = cookies
+
+ @property
+ def session_backend(self) -> BaseSessionBackend[Any]:
+ if not self._session_backend:
+ raise ImproperlyConfiguredException(
+ "Session has not been initialized for this TestClient instance. You can"
+ "do so by passing a configuration object to TestClient: TestClient(app=app, session_config=...)"
+ )
+ return self._session_backend
+
+ @contextmanager
+ def portal(self) -> Generator[BlockingPortal, None, None]:
+ """Get a BlockingPortal.
+
+ Returns:
+ A contextmanager for a BlockingPortal.
+ """
+ if hasattr(self, "blocking_portal"):
+ yield self.blocking_portal
+ else:
+ with start_blocking_portal(
+ backend=self.backend, backend_options=dict(self.backend_options or {})
+ ) as portal:
+ yield portal
+
+ async def _set_session_data(self, data: dict[str, Any]) -> None:
+ mutable_headers = MutableScopeHeaders()
+ connection = fake_asgi_connection(
+ app=self.app,
+ cookies=dict(self.cookies), # type: ignore[arg-type]
+ )
+ session_id = self.session_backend.get_session_id(connection)
+ connection._connection_state.session_id = session_id # pyright: ignore [reportGeneralTypeIssues]
+ await self.session_backend.store_in_message(
+ scope_session=data, message=fake_http_send_message(mutable_headers), connection=connection
+ )
+ response = Response(200, request=Request("GET", self.base_url), headers=mutable_headers.headers)
+
+ cookies = Cookies(CookieJar())
+ cookies.extract_cookies(response)
+ self.cookies.update(cookies) # type: ignore[union-attr]
+
+ async def _get_session_data(self) -> dict[str, Any]:
+ return await self.session_backend.load_from_connection(
+ connection=fake_asgi_connection(
+ app=self.app,
+ cookies=dict(self.cookies), # type: ignore[arg-type]
+ ),
+ )