summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/litestar/data_extractors.py
diff options
context:
space:
mode:
authorcyfraeviolae <cyfraeviolae>2024-04-03 03:10:44 -0400
committercyfraeviolae <cyfraeviolae>2024-04-03 03:10:44 -0400
commit6d7ba58f880be618ade07f8ea080fe8c4bf8a896 (patch)
treeb1c931051ffcebd2bd9d61d98d6233ffa289bbce /venv/lib/python3.11/site-packages/litestar/data_extractors.py
parent4f884c9abc32990b4061a1bb6997b4b37e58ea0b (diff)
venv
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/data_extractors.py')
-rw-r--r--venv/lib/python3.11/site-packages/litestar/data_extractors.py443
1 files changed, 443 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/data_extractors.py b/venv/lib/python3.11/site-packages/litestar/data_extractors.py
new file mode 100644
index 0000000..61993b4
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/data_extractors.py
@@ -0,0 +1,443 @@
+from __future__ import annotations
+
+import inspect
+from typing import TYPE_CHECKING, Any, Callable, Coroutine, Iterable, Literal, TypedDict, cast
+
+from litestar._parsers import parse_cookie_string
+from litestar.connection.request import Request
+from litestar.datastructures.upload_file import UploadFile
+from litestar.enums import HttpMethod, RequestEncodingType
+
+__all__ = (
+ "ConnectionDataExtractor",
+ "ExtractedRequestData",
+ "ExtractedResponseData",
+ "ResponseDataExtractor",
+ "RequestExtractorField",
+ "ResponseExtractorField",
+)
+
+
+if TYPE_CHECKING:
+ from litestar.connection import ASGIConnection
+ from litestar.types import Method
+ from litestar.types.asgi_types import HTTPResponseBodyEvent, HTTPResponseStartEvent
+
+
+def _obfuscate(values: dict[str, Any], fields_to_obfuscate: set[str]) -> dict[str, Any]:
+ """Obfuscate values in a dictionary, replacing values with `******`
+
+ Args:
+ values: A dictionary of strings
+ fields_to_obfuscate: keys to obfuscate
+
+ Returns:
+ A dictionary with obfuscated strings
+ """
+ return {key: "*****" if key.lower() in fields_to_obfuscate else value for key, value in values.items()}
+
+
+RequestExtractorField = Literal[
+ "path", "method", "content_type", "headers", "cookies", "query", "path_params", "body", "scheme", "client"
+]
+
+ResponseExtractorField = Literal["status_code", "headers", "body", "cookies"]
+
+
+class ExtractedRequestData(TypedDict, total=False):
+ """Dictionary representing extracted request data."""
+
+ body: Coroutine[Any, Any, Any]
+ client: tuple[str, int]
+ content_type: tuple[str, dict[str, str]]
+ cookies: dict[str, str]
+ headers: dict[str, str]
+ method: Method
+ path: str
+ path_params: dict[str, Any]
+ query: bytes | dict[str, Any]
+ scheme: str
+
+
+class ConnectionDataExtractor:
+ """Utility class to extract data from an :class:`ASGIConnection <litestar.connection.ASGIConnection>`,
+ :class:`Request <litestar.connection.Request>` or :class:`WebSocket <litestar.connection.WebSocket>` instance.
+ """
+
+ __slots__ = (
+ "connection_extractors",
+ "request_extractors",
+ "parse_body",
+ "parse_query",
+ "obfuscate_headers",
+ "obfuscate_cookies",
+ "skip_parse_malformed_body",
+ )
+
+ def __init__(
+ self,
+ extract_body: bool = True,
+ extract_client: bool = True,
+ extract_content_type: bool = True,
+ extract_cookies: bool = True,
+ extract_headers: bool = True,
+ extract_method: bool = True,
+ extract_path: bool = True,
+ extract_path_params: bool = True,
+ extract_query: bool = True,
+ extract_scheme: bool = True,
+ obfuscate_cookies: set[str] | None = None,
+ obfuscate_headers: set[str] | None = None,
+ parse_body: bool = False,
+ parse_query: bool = False,
+ skip_parse_malformed_body: bool = False,
+ ) -> None:
+ """Initialize ``ConnectionDataExtractor``
+
+ Args:
+ extract_body: Whether to extract body, (for requests only).
+ extract_client: Whether to extract the client (host, port) mapping.
+ extract_content_type: Whether to extract the content type and any options.
+ extract_cookies: Whether to extract cookies.
+ extract_headers: Whether to extract headers.
+ extract_method: Whether to extract the HTTP method, (for requests only).
+ extract_path: Whether to extract the path.
+ extract_path_params: Whether to extract path parameters.
+ extract_query: Whether to extract query parameters.
+ extract_scheme: Whether to extract the http scheme.
+ obfuscate_headers: headers keys to obfuscate. Obfuscated values are replaced with '*****'.
+ obfuscate_cookies: cookie keys to obfuscate. Obfuscated values are replaced with '*****'.
+ parse_body: Whether to parse the body value or return the raw byte string, (for requests only).
+ parse_query: Whether to parse query parameters or return the raw byte string.
+ skip_parse_malformed_body: Whether to skip parsing the body if it is malformed
+ """
+ self.parse_body = parse_body
+ self.parse_query = parse_query
+ self.skip_parse_malformed_body = skip_parse_malformed_body
+ self.obfuscate_headers = {h.lower() for h in (obfuscate_headers or set())}
+ self.obfuscate_cookies = {c.lower() for c in (obfuscate_cookies or set())}
+ self.connection_extractors: dict[str, Callable[[ASGIConnection[Any, Any, Any, Any]], Any]] = {}
+ self.request_extractors: dict[RequestExtractorField, Callable[[Request[Any, Any, Any]], Any]] = {}
+ if extract_scheme:
+ self.connection_extractors["scheme"] = self.extract_scheme
+ if extract_client:
+ self.connection_extractors["client"] = self.extract_client
+ if extract_path:
+ self.connection_extractors["path"] = self.extract_path
+ if extract_headers:
+ self.connection_extractors["headers"] = self.extract_headers
+ if extract_cookies:
+ self.connection_extractors["cookies"] = self.extract_cookies
+ if extract_query:
+ self.connection_extractors["query"] = self.extract_query
+ if extract_path_params:
+ self.connection_extractors["path_params"] = self.extract_path_params
+ if extract_method:
+ self.request_extractors["method"] = self.extract_method
+ if extract_content_type:
+ self.request_extractors["content_type"] = self.extract_content_type
+ if extract_body:
+ self.request_extractors["body"] = self.extract_body
+
+ def __call__(self, connection: ASGIConnection[Any, Any, Any, Any]) -> ExtractedRequestData:
+ """Extract data from the connection, returning a dictionary of values.
+
+ Notes:
+ - The value for ``body`` - if present - is an unresolved Coroutine and as such should be awaited by the receiver.
+
+ Args:
+ connection: An ASGI connection or its subclasses.
+
+ Returns:
+ A string keyed dictionary of extracted values.
+ """
+ extractors = (
+ {**self.connection_extractors, **self.request_extractors} # type: ignore[misc]
+ if isinstance(connection, Request)
+ else self.connection_extractors
+ )
+ return cast("ExtractedRequestData", {key: extractor(connection) for key, extractor in extractors.items()})
+
+ async def extract(
+ self, connection: ASGIConnection[Any, Any, Any, Any], fields: Iterable[str]
+ ) -> ExtractedRequestData:
+ extractors = (
+ {**self.connection_extractors, **self.request_extractors} # type: ignore[misc]
+ if isinstance(connection, Request)
+ else self.connection_extractors
+ )
+ data = {}
+ for key, extractor in extractors.items():
+ if key not in fields:
+ continue
+ if inspect.iscoroutinefunction(extractor):
+ value = await extractor(connection)
+ else:
+ value = extractor(connection)
+ data[key] = value
+ return cast("ExtractedRequestData", data)
+
+ @staticmethod
+ def extract_scheme(connection: ASGIConnection[Any, Any, Any, Any]) -> str:
+ """Extract the scheme from an ``ASGIConnection``
+
+ Args:
+ connection: An :class:`ASGIConnection <litestar.connection.ASGIConnection>` instance.
+
+ Returns:
+ The connection's scope["scheme"] value
+ """
+ return connection.scope["scheme"]
+
+ @staticmethod
+ def extract_client(connection: ASGIConnection[Any, Any, Any, Any]) -> tuple[str, int]:
+ """Extract the client from an ``ASGIConnection``
+
+ Args:
+ connection: An :class:`ASGIConnection <litestar.connection.ASGIConnection>` instance.
+
+ Returns:
+ The connection's scope["client"] value or a default value.
+ """
+ return connection.scope.get("client") or ("", 0)
+
+ @staticmethod
+ def extract_path(connection: ASGIConnection[Any, Any, Any, Any]) -> str:
+ """Extract the path from an ``ASGIConnection``
+
+ Args:
+ connection: An :class:`ASGIConnection <litestar.connection.ASGIConnection>` instance.
+
+ Returns:
+ The connection's scope["path"] value
+ """
+ return connection.scope["path"]
+
+ def extract_headers(self, connection: ASGIConnection[Any, Any, Any, Any]) -> dict[str, str]:
+ """Extract headers from an ``ASGIConnection``
+
+ Args:
+ connection: An :class:`ASGIConnection <litestar.connection.ASGIConnection>` instance.
+
+ Returns:
+ A dictionary with the connection's headers.
+ """
+ headers = {k.decode("latin-1"): v.decode("latin-1") for k, v in connection.scope["headers"]}
+ return _obfuscate(headers, self.obfuscate_headers) if self.obfuscate_headers else headers
+
+ def extract_cookies(self, connection: ASGIConnection[Any, Any, Any, Any]) -> dict[str, str]:
+ """Extract cookies from an ``ASGIConnection``
+
+ Args:
+ connection: An :class:`ASGIConnection <litestar.connection.ASGIConnection>` instance.
+
+ Returns:
+ A dictionary with the connection's cookies.
+ """
+ return _obfuscate(connection.cookies, self.obfuscate_cookies) if self.obfuscate_cookies else connection.cookies
+
+ def extract_query(self, connection: ASGIConnection[Any, Any, Any, Any]) -> Any:
+ """Extract query from an ``ASGIConnection``
+
+ Args:
+ connection: An :class:`ASGIConnection <litestar.connection.ASGIConnection>` instance.
+
+ Returns:
+ Either a dictionary with the connection's parsed query string or the raw query byte-string.
+ """
+ return connection.query_params.dict() if self.parse_query else connection.scope.get("query_string", b"")
+
+ @staticmethod
+ def extract_path_params(connection: ASGIConnection[Any, Any, Any, Any]) -> dict[str, Any]:
+ """Extract the path parameters from an ``ASGIConnection``
+
+ Args:
+ connection: An :class:`ASGIConnection <litestar.connection.ASGIConnection>` instance.
+
+ Returns:
+ A dictionary with the connection's path parameters.
+ """
+ return connection.path_params
+
+ @staticmethod
+ def extract_method(request: Request[Any, Any, Any]) -> Method:
+ """Extract the method from an ``ASGIConnection``
+
+ Args:
+ request: A :class:`Request <litestar.connection.Request>` instance.
+
+ Returns:
+ The request's scope["method"] value.
+ """
+ return request.scope["method"]
+
+ @staticmethod
+ def extract_content_type(request: Request[Any, Any, Any]) -> tuple[str, dict[str, str]]:
+ """Extract the content-type from an ``ASGIConnection``
+
+ Args:
+ request: A :class:`Request <litestar.connection.Request>` instance.
+
+ Returns:
+ A tuple containing the request's parsed 'Content-Type' header.
+ """
+ return request.content_type
+
+ async def extract_body(self, request: Request[Any, Any, Any]) -> Any:
+ """Extract the body from an ``ASGIConnection``
+
+ Args:
+ request: A :class:`Request <litestar.connection.Request>` instance.
+
+ Returns:
+ Either the parsed request body or the raw byte-string.
+ """
+ if request.method == HttpMethod.GET:
+ return None
+ if not self.parse_body:
+ return await request.body()
+ try:
+ request_encoding_type = request.content_type[0]
+ if request_encoding_type == RequestEncodingType.JSON:
+ return await request.json()
+ form_data = await request.form()
+ if request_encoding_type == RequestEncodingType.URL_ENCODED:
+ return dict(form_data)
+ return {
+ key: repr(value) if isinstance(value, UploadFile) else value for key, value in form_data.multi_items()
+ }
+ except Exception as exc:
+ if self.skip_parse_malformed_body:
+ return await request.body()
+ raise exc
+
+
+class ExtractedResponseData(TypedDict, total=False):
+ """Dictionary representing extracted response data."""
+
+ body: bytes
+ status_code: int
+ headers: dict[str, str]
+ cookies: dict[str, str]
+
+
+class ResponseDataExtractor:
+ """Utility class to extract data from a ``Message``"""
+
+ __slots__ = ("extractors", "parse_headers", "obfuscate_headers", "obfuscate_cookies")
+
+ def __init__(
+ self,
+ extract_body: bool = True,
+ extract_cookies: bool = True,
+ extract_headers: bool = True,
+ extract_status_code: bool = True,
+ obfuscate_cookies: set[str] | None = None,
+ obfuscate_headers: set[str] | None = None,
+ ) -> None:
+ """Initialize ``ResponseDataExtractor`` with options.
+
+ Args:
+ extract_body: Whether to extract the body.
+ extract_cookies: Whether to extract the cookies.
+ extract_headers: Whether to extract the headers.
+ extract_status_code: Whether to extract the status code.
+ obfuscate_cookies: cookie keys to obfuscate. Obfuscated values are replaced with '*****'.
+ obfuscate_headers: headers keys to obfuscate. Obfuscated values are replaced with '*****'.
+ """
+ self.obfuscate_headers = {h.lower() for h in (obfuscate_headers or set())}
+ self.obfuscate_cookies = {c.lower() for c in (obfuscate_cookies or set())}
+ self.extractors: dict[
+ ResponseExtractorField, Callable[[tuple[HTTPResponseStartEvent, HTTPResponseBodyEvent]], Any]
+ ] = {}
+ if extract_body:
+ self.extractors["body"] = self.extract_response_body
+ if extract_status_code:
+ self.extractors["status_code"] = self.extract_status_code
+ if extract_headers:
+ self.extractors["headers"] = self.extract_headers
+ if extract_cookies:
+ self.extractors["cookies"] = self.extract_cookies
+
+ def __call__(self, messages: tuple[HTTPResponseStartEvent, HTTPResponseBodyEvent]) -> ExtractedResponseData:
+ """Extract data from the response, returning a dictionary of values.
+
+ Args:
+ messages: A tuple containing
+ :class:`HTTPResponseStartEvent <litestar.types.asgi_types.HTTPResponseStartEvent>`
+ and :class:`HTTPResponseBodyEvent <litestar.types.asgi_types.HTTPResponseBodyEvent>`.
+
+ Returns:
+ A string keyed dictionary of extracted values.
+ """
+ return cast("ExtractedResponseData", {key: extractor(messages) for key, extractor in self.extractors.items()})
+
+ @staticmethod
+ def extract_response_body(messages: tuple[HTTPResponseStartEvent, HTTPResponseBodyEvent]) -> bytes:
+ """Extract the response body from a ``Message``
+
+ Args:
+ messages: A tuple containing
+ :class:`HTTPResponseStartEvent <litestar.types.asgi_types.HTTPResponseStartEvent>`
+ and :class:`HTTPResponseBodyEvent <litestar.types.asgi_types.HTTPResponseBodyEvent>`.
+
+ Returns:
+ The Response's body as a byte-string.
+ """
+ return messages[1]["body"]
+
+ @staticmethod
+ def extract_status_code(messages: tuple[HTTPResponseStartEvent, HTTPResponseBodyEvent]) -> int:
+ """Extract a status code from a ``Message``
+
+ Args:
+ messages: A tuple containing
+ :class:`HTTPResponseStartEvent <litestar.types.asgi_types.HTTPResponseStartEvent>`
+ and :class:`HTTPResponseBodyEvent <litestar.types.asgi_types.HTTPResponseBodyEvent>`.
+
+ Returns:
+ The Response's status-code.
+ """
+ return messages[0]["status"]
+
+ def extract_headers(self, messages: tuple[HTTPResponseStartEvent, HTTPResponseBodyEvent]) -> dict[str, str]:
+ """Extract headers from a ``Message``
+
+ Args:
+ messages: A tuple containing
+ :class:`HTTPResponseStartEvent <litestar.types.asgi_types.HTTPResponseStartEvent>`
+ and :class:`HTTPResponseBodyEvent <litestar.types.asgi_types.HTTPResponseBodyEvent>`.
+
+ Returns:
+ The Response's headers dict.
+ """
+ headers = {
+ key.decode("latin-1"): value.decode("latin-1")
+ for key, value in filter(lambda x: x[0].lower() != b"set-cookie", messages[0]["headers"])
+ }
+ return (
+ _obfuscate(
+ headers,
+ self.obfuscate_headers,
+ )
+ if self.obfuscate_headers
+ else headers
+ )
+
+ def extract_cookies(self, messages: tuple[HTTPResponseStartEvent, HTTPResponseBodyEvent]) -> dict[str, str]:
+ """Extract cookies from a ``Message``
+
+ Args:
+ messages: A tuple containing
+ :class:`HTTPResponseStartEvent <litestar.types.asgi_types.HTTPResponseStartEvent>`
+ and :class:`HTTPResponseBodyEvent <litestar.types.asgi_types.HTTPResponseBodyEvent>`.
+
+ Returns:
+ The Response's cookies dict.
+ """
+ if cookie_string := ";".join(
+ [x[1].decode("latin-1") for x in filter(lambda x: x[0].lower() == b"set-cookie", messages[0]["headers"])]
+ ):
+ parsed_cookies = parse_cookie_string(cookie_string)
+ return _obfuscate(parsed_cookies, self.obfuscate_cookies) if self.obfuscate_cookies else parsed_cookies
+ return {}