diff options
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/datastructures/headers.py')
-rw-r--r-- | venv/lib/python3.11/site-packages/litestar/datastructures/headers.py | 534 |
1 files changed, 534 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/datastructures/headers.py b/venv/lib/python3.11/site-packages/litestar/datastructures/headers.py new file mode 100644 index 0000000..f3e9bd7 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/datastructures/headers.py @@ -0,0 +1,534 @@ +import re +from abc import ABC, abstractmethod +from contextlib import suppress +from copy import copy +from dataclasses import dataclass +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Dict, + Iterable, + Iterator, + List, + Mapping, + MutableMapping, + Optional, + Pattern, + Tuple, + Union, + cast, +) + +from multidict import CIMultiDict, CIMultiDictProxy, MultiMapping +from typing_extensions import get_type_hints + +from litestar._multipart import parse_content_header +from litestar.datastructures.multi_dicts import MultiMixin +from litestar.dto.base_dto import AbstractDTO +from litestar.exceptions import ImproperlyConfiguredException, ValidationException +from litestar.types.empty import Empty +from litestar.typing import FieldDefinition +from litestar.utils.dataclass import simple_asdict +from litestar.utils.scope.state import ScopeState + +if TYPE_CHECKING: + from litestar.types.asgi_types import ( + HeaderScope, + Message, + RawHeaders, + RawHeadersList, + Scope, + ) + +__all__ = ("Accept", "CacheControlHeader", "ETag", "Header", "Headers", "MutableScopeHeaders") + +ETAG_RE = re.compile(r'([Ww]/)?"(.+)"') +PRINTABLE_ASCII_RE: Pattern[str] = re.compile(r"^[ -~]+$") + + +def _encode_headers(headers: Iterable[Tuple[str, str]]) -> "RawHeadersList": + return [(key.lower().encode("latin-1"), value.encode("latin-1")) for key, value in headers] + + +class Headers(CIMultiDictProxy[str], MultiMixin[str]): + """An immutable, case-insensitive multi dict for HTTP headers.""" + + def __init__(self, headers: Optional[Union[Mapping[str, str], "RawHeaders", MultiMapping]] = None) -> None: + """Initialize ``Headers``. + + Args: + headers: Initial value. + """ + if not isinstance(headers, MultiMapping): + headers_: Union[Mapping[str, str], List[Tuple[str, str]]] = {} + if headers: + if isinstance(headers, Mapping): + headers_ = headers # pyright: ignore + else: + headers_ = [(key.decode("latin-1"), value.decode("latin-1")) for key, value in headers] + + super().__init__(CIMultiDict(headers_)) + else: + super().__init__(headers) + self._header_list: Optional[RawHeadersList] = None + + @classmethod + def from_scope(cls, scope: "Scope") -> "Headers": + """Create headers from a send-message. + + Args: + scope: The ASGI connection scope. + + Returns: + Headers + + Raises: + ValueError: If the message does not have a ``headers`` key + """ + connection_state = ScopeState.from_scope(scope) + if (headers := connection_state.headers) is Empty: + headers = connection_state.headers = cls(scope["headers"]) + return headers + + def to_header_list(self) -> "RawHeadersList": + """Raw header value. + + Returns: + A list of tuples contain the header and header-value as bytes + """ + # Since ``Headers`` are immutable, this can be cached + if not self._header_list: + self._header_list = _encode_headers((key, value) for key in set(self) for value in self.getall(key)) + return self._header_list + + +class MutableScopeHeaders(MutableMapping): + """A case-insensitive, multidict-like structure that can be used to mutate headers within a + :class:`Scope <.types.Scope>` + """ + + def __init__(self, scope: Optional["HeaderScope"] = None) -> None: + """Initialize ``MutableScopeHeaders`` from a ``HeaderScope``. + + Args: + scope: The ASGI connection scope. + """ + self.headers: RawHeadersList + if scope is not None: + if not isinstance(scope["headers"], list): + scope["headers"] = list(scope["headers"]) + + self.headers = cast("RawHeadersList", scope["headers"]) + else: + self.headers = [] + + @classmethod + def from_message(cls, message: "Message") -> "MutableScopeHeaders": + """Construct a header from a message object. + + Args: + message: :class:`Message <.types.Message>`. + + Returns: + MutableScopeHeaders. + + Raises: + ValueError: If the message does not have a ``headers`` key. + """ + if "headers" not in message: + raise ValueError(f"Invalid message type: {message['type']!r}") + + return cls(cast("HeaderScope", message)) + + def add(self, key: str, value: str) -> None: + """Add a header to the scope. + + Notes: + - This method keeps duplicates. + + Args: + key: Header key. + value: Header value. + + Returns: + None. + """ + self.headers.append((key.lower().encode("latin-1"), value.encode("latin-1"))) + + def getall(self, key: str, default: Optional[List[str]] = None) -> List[str]: + """Get all values of a header. + + Args: + key: Header key. + default: Default value to return if ``name`` is not found. + + Returns: + A list of strings. + + Raises: + KeyError: if no header for ``name`` was found and ``default`` is not given. + """ + name = key.lower() + values = [ + header_value.decode("latin-1") + for header_name, header_value in self.headers + if header_name.decode("latin-1").lower() == name + ] + if not values: + if default: + return default + raise KeyError + return values + + def extend_header_value(self, key: str, value: str) -> None: + """Extend a multivalued header. + + Notes: + - A multivalues header is a header that can take a comma separated list. + - If the header previously did not exist, it will be added. + + Args: + key: Header key. + value: Header value to add, + + Returns: + None + """ + existing = self.get(key) + if existing is not None: + value = ",".join([*existing.split(","), value]) + self[key] = value + + def __getitem__(self, key: str) -> str: + """Get the first header matching ``name``""" + name = key.lower() + for header in self.headers: + if header[0].decode("latin-1").lower() == name: + return header[1].decode("latin-1") + raise KeyError + + def _find_indices(self, key: str) -> List[int]: + name = key.lower() + return [i for i, (name_, _) in enumerate(self.headers) if name_.decode("latin-1").lower() == name] + + def __setitem__(self, key: str, value: str) -> None: + """Set a header in the scope, overwriting duplicates.""" + name_encoded = key.lower().encode("latin-1") + value_encoded = value.encode("latin-1") + if indices := self._find_indices(key): + for i in indices[1:]: + del self.headers[i] + self.headers[indices[0]] = (name_encoded, value_encoded) + else: + self.headers.append((name_encoded, value_encoded)) + + def __delitem__(self, key: str) -> None: + """Delete all headers matching ``name``""" + indices = self._find_indices(key) + for i in indices[::-1]: + del self.headers[i] + + def __len__(self) -> int: + """Return the length of the internally stored headers, including duplicates.""" + return len(self.headers) + + def __iter__(self) -> Iterator[str]: + """Create an iterator of header names including duplicates.""" + return iter(h[0].decode("latin-1") for h in self.headers) + + +@dataclass +class Header(ABC): + """An abstract type for HTTP headers.""" + + HEADER_NAME: ClassVar[str] = "" + + documentation_only: bool = False + """Defines the header instance as for OpenAPI documentation purpose only.""" + + @abstractmethod + def _get_header_value(self) -> str: + """Get the header value as string.""" + raise NotImplementedError + + @classmethod + @abstractmethod + def from_header(cls, header_value: str) -> "Header": + """Construct a header from its string representation.""" + + def to_header(self, include_header_name: bool = False) -> str: + """Get the header as string. + + Args: + include_header_name: should include the header name in the return value. If set to false + the return value will only include the header value. if set to true the return value + will be: ``<header name>: <header value>``. Defaults to false. + """ + + if not self.HEADER_NAME: + raise ImproperlyConfiguredException("Missing header name") + + return (f"{self.HEADER_NAME}: " if include_header_name else "") + self._get_header_value() + + +@dataclass +class CacheControlHeader(Header): + """A ``cache-control`` header.""" + + HEADER_NAME: ClassVar[str] = "cache-control" + + max_age: Optional[int] = None + """Accessor for the ``max-age`` directive.""" + s_maxage: Optional[int] = None + """Accessor for the ``s-maxage`` directive.""" + no_cache: Optional[bool] = None + """Accessor for the ``no-cache`` directive.""" + no_store: Optional[bool] = None + """Accessor for the ``no-store`` directive.""" + private: Optional[bool] = None + """Accessor for the ``private`` directive.""" + public: Optional[bool] = None + """Accessor for the ``public`` directive.""" + no_transform: Optional[bool] = None + """Accessor for the ``no-transform`` directive.""" + must_revalidate: Optional[bool] = None + """Accessor for the ``must-revalidate`` directive.""" + proxy_revalidate: Optional[bool] = None + """Accessor for the ``proxy-revalidate`` directive.""" + must_understand: Optional[bool] = None + """Accessor for the ``must-understand`` directive.""" + immutable: Optional[bool] = None + """Accessor for the ``immutable`` directive.""" + stale_while_revalidate: Optional[int] = None + """Accessor for the ``stale-while-revalidate`` directive.""" + + _field_definitions: ClassVar[Optional[Dict[str, FieldDefinition]]] = None + + def _get_header_value(self) -> str: + """Get the header value as string.""" + + cc_items = [ + key.replace("_", "-") if isinstance(value, bool) else f"{key.replace('_', '-')}={value}" + for key, value in simple_asdict(self, exclude_none=True, exclude={"documentation_only"}).items() + ] + return ", ".join(cc_items) + + @classmethod + def from_header(cls, header_value: str) -> "CacheControlHeader": + """Create a ``CacheControlHeader`` instance from the header value. + + Args: + header_value: the header value as string + + Returns: + An instance of ``CacheControlHeader`` + """ + + cc_items = [v.strip() for v in header_value.split(",")] + kwargs: Dict[str, Any] = {} + field_definitions = cls._get_field_definitions() + for cc_item in cc_items: + key_value = cc_item.split("=") + key_value[0] = key_value[0].replace("-", "_") + if len(key_value) == 1: + kwargs[key_value[0]] = True + elif len(key_value) == 2: + key, value = key_value + if key not in field_definitions: + raise ImproperlyConfiguredException("Invalid cache-control header") + kwargs[key] = cls._convert_to_type(value, field_definition=field_definitions[key]) + else: + raise ImproperlyConfiguredException("Invalid cache-control header value") + + try: + return CacheControlHeader(**kwargs) + except TypeError as exc: + raise ImproperlyConfiguredException from exc + + @classmethod + def prevent_storing(cls) -> "CacheControlHeader": + """Create a ``cache-control`` header with the ``no-store`` directive which indicates that any caches of any kind + (private or shared) should not store this response. + """ + + return cls(no_store=True) + + @classmethod + def _get_field_definitions(cls) -> Dict[str, FieldDefinition]: + """Get the type annotations for the ``CacheControlHeader`` class properties. + + This is needed due to the conversion from pydantic models to dataclasses. Dataclasses do not support + automatic conversion of types like pydantic models do. + + Returns: + A dictionary of type annotations + + """ + + if cls._field_definitions is None: + cls._field_definitions = {} + for key, value in get_type_hints(cls, include_extras=True).items(): + definition = FieldDefinition.from_kwarg(annotation=value, name=key) + # resolve_model_type so that field_definition.raw has the real raw type e.g. <class 'bool'> + cls._field_definitions[key] = AbstractDTO.resolve_model_type(definition) + return cls._field_definitions + + @classmethod + def _convert_to_type(cls, value: str, field_definition: FieldDefinition) -> Any: + """Convert the value to the expected type. + + Args: + value: the value of the cache-control directive + field_definition: the field definition for the value to convert + + Returns: + The value converted to the expected type + """ + # bool values shouldn't be initiated since they should have been caught earlier in the from_header method and + # set with a value of True + expected_type = field_definition.raw + if expected_type is bool: + raise ImproperlyConfiguredException("Invalid cache-control header value") + return expected_type(value) + + +@dataclass +class ETag(Header): + """An ``etag`` header.""" + + HEADER_NAME: ClassVar[str] = "etag" + + weak: bool = False + value: Optional[str] = None # only ASCII characters + + def _get_header_value(self) -> str: + value = f'"{self.value}"' + return f"W/{value}" if self.weak else value + + @classmethod + def from_header(cls, header_value: str) -> "ETag": + """Construct an ``etag`` header from its string representation. + + Note that this will unquote etag-values + """ + match = ETAG_RE.match(header_value) + if not match: + raise ImproperlyConfiguredException + weak, value = match.group(1, 2) + try: + return cls(weak=bool(weak), value=value) + except ValueError as exc: + raise ImproperlyConfiguredException from exc + + def __post_init__(self) -> None: + if self.documentation_only is False and self.value is None: + raise ValidationException("value must be set if documentation_only is false") + if self.value and not PRINTABLE_ASCII_RE.fullmatch(self.value): + raise ValidationException("value must only contain ASCII printable characters") + + +class MediaTypeHeader: + """A helper class for ``Accept`` header parsing.""" + + __slots__ = ("maintype", "subtype", "params", "_params_str") + + def __init__(self, type_str: str) -> None: + # preserve the original parameters, because the order might be + # changed in the dict + self._params_str = "".join(type_str.partition(";")[1:]) + + full_type, self.params = parse_content_header(type_str) + self.maintype, _, self.subtype = full_type.partition("/") + + def __str__(self) -> str: + return f"{self.maintype}/{self.subtype}{self._params_str}" + + @property + def priority(self) -> Tuple[int, int]: + # Use fixed point values with two decimals to avoid problems + # when comparing float values + quality = 100 + if "q" in self.params: + with suppress(ValueError): + quality = int(100 * float(self.params["q"])) + + if self.maintype == "*": + specificity = 0 + elif self.subtype == "*": + specificity = 1 + elif not self.params or ("q" in self.params and len(self.params) == 1): + # no params or 'q' is the only one which we ignore + specificity = 2 + else: + specificity = 3 + + return quality, specificity + + def match(self, other: "MediaTypeHeader") -> bool: + return next( + (False for key, value in self.params.items() if key != "q" and value != other.params.get(key)), + False + if self.subtype != "*" and other.subtype != "*" and self.subtype != other.subtype + else self.maintype == "*" or other.maintype == "*" or self.maintype == other.maintype, + ) + + +class Accept: + """An ``Accept`` header.""" + + __slots__ = ("_accepted_types",) + + def __init__(self, accept_value: str) -> None: + self._accepted_types = [MediaTypeHeader(t) for t in accept_value.split(",")] + self._accepted_types.sort(key=lambda t: t.priority, reverse=True) + + def __len__(self) -> int: + return len(self._accepted_types) + + def __getitem__(self, key: int) -> str: + return str(self._accepted_types[key]) + + def __iter__(self) -> Iterator[str]: + return map(str, self._accepted_types) + + def best_match(self, provided_types: List[str], default: Optional[str] = None) -> Optional[str]: + """Find the best matching media type for the request. + + Args: + provided_types: A list of media types that can be provided as a response. These types + can contain a wildcard ``*`` character in the main- or subtype part. + default: The media type that is returned if none of the provided types match. + + Returns: + The best matching media type. If the matching provided type contains wildcard characters, + they are replaced with the corresponding part of the accepted type. Otherwise the + provided type is returned as-is. + """ + types = [MediaTypeHeader(t) for t in provided_types] + + for accepted in self._accepted_types: + for provided in types: + if provided.match(accepted): + # Return the accepted type with wildcards replaced + # by concrete parts from the provided type + result = copy(provided) + if result.subtype == "*": + result.subtype = accepted.subtype + if result.maintype == "*": + result.maintype = accepted.maintype + return str(result) + return default + + def accepts(self, media_type: str) -> bool: + """Check if the request accepts the specified media type. + + If multiple media types can be provided, it is better to use :func:`best_match`. + + Args: + media_type: The media type to check for. + + Returns: + True if the request accepts ``media_type``. + """ + return self.best_match([media_type]) == media_type |