diff options
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/datastructures/headers.py')
-rw-r--r-- | venv/lib/python3.11/site-packages/litestar/datastructures/headers.py | 534 |
1 files changed, 0 insertions, 534 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/datastructures/headers.py b/venv/lib/python3.11/site-packages/litestar/datastructures/headers.py deleted file mode 100644 index f3e9bd7..0000000 --- a/venv/lib/python3.11/site-packages/litestar/datastructures/headers.py +++ /dev/null @@ -1,534 +0,0 @@ -import re -from abc import ABC, abstractmethod -from contextlib import suppress -from copy import copy -from dataclasses import dataclass -from typing import ( - TYPE_CHECKING, - Any, - ClassVar, - Dict, - Iterable, - Iterator, - List, - Mapping, - MutableMapping, - Optional, - Pattern, - Tuple, - Union, - cast, -) - -from multidict import CIMultiDict, CIMultiDictProxy, MultiMapping -from typing_extensions import get_type_hints - -from litestar._multipart import parse_content_header -from litestar.datastructures.multi_dicts import MultiMixin -from litestar.dto.base_dto import AbstractDTO -from litestar.exceptions import ImproperlyConfiguredException, ValidationException -from litestar.types.empty import Empty -from litestar.typing import FieldDefinition -from litestar.utils.dataclass import simple_asdict -from litestar.utils.scope.state import ScopeState - -if TYPE_CHECKING: - from litestar.types.asgi_types import ( - HeaderScope, - Message, - RawHeaders, - RawHeadersList, - Scope, - ) - -__all__ = ("Accept", "CacheControlHeader", "ETag", "Header", "Headers", "MutableScopeHeaders") - -ETAG_RE = re.compile(r'([Ww]/)?"(.+)"') -PRINTABLE_ASCII_RE: Pattern[str] = re.compile(r"^[ -~]+$") - - -def _encode_headers(headers: Iterable[Tuple[str, str]]) -> "RawHeadersList": - return [(key.lower().encode("latin-1"), value.encode("latin-1")) for key, value in headers] - - -class Headers(CIMultiDictProxy[str], MultiMixin[str]): - """An immutable, case-insensitive multi dict for HTTP headers.""" - - def __init__(self, headers: Optional[Union[Mapping[str, str], "RawHeaders", MultiMapping]] = None) -> None: - """Initialize ``Headers``. - - Args: - headers: Initial value. - """ - if not isinstance(headers, MultiMapping): - headers_: Union[Mapping[str, str], List[Tuple[str, str]]] = {} - if headers: - if isinstance(headers, Mapping): - headers_ = headers # pyright: ignore - else: - headers_ = [(key.decode("latin-1"), value.decode("latin-1")) for key, value in headers] - - super().__init__(CIMultiDict(headers_)) - else: - super().__init__(headers) - self._header_list: Optional[RawHeadersList] = None - - @classmethod - def from_scope(cls, scope: "Scope") -> "Headers": - """Create headers from a send-message. - - Args: - scope: The ASGI connection scope. - - Returns: - Headers - - Raises: - ValueError: If the message does not have a ``headers`` key - """ - connection_state = ScopeState.from_scope(scope) - if (headers := connection_state.headers) is Empty: - headers = connection_state.headers = cls(scope["headers"]) - return headers - - def to_header_list(self) -> "RawHeadersList": - """Raw header value. - - Returns: - A list of tuples contain the header and header-value as bytes - """ - # Since ``Headers`` are immutable, this can be cached - if not self._header_list: - self._header_list = _encode_headers((key, value) for key in set(self) for value in self.getall(key)) - return self._header_list - - -class MutableScopeHeaders(MutableMapping): - """A case-insensitive, multidict-like structure that can be used to mutate headers within a - :class:`Scope <.types.Scope>` - """ - - def __init__(self, scope: Optional["HeaderScope"] = None) -> None: - """Initialize ``MutableScopeHeaders`` from a ``HeaderScope``. - - Args: - scope: The ASGI connection scope. - """ - self.headers: RawHeadersList - if scope is not None: - if not isinstance(scope["headers"], list): - scope["headers"] = list(scope["headers"]) - - self.headers = cast("RawHeadersList", scope["headers"]) - else: - self.headers = [] - - @classmethod - def from_message(cls, message: "Message") -> "MutableScopeHeaders": - """Construct a header from a message object. - - Args: - message: :class:`Message <.types.Message>`. - - Returns: - MutableScopeHeaders. - - Raises: - ValueError: If the message does not have a ``headers`` key. - """ - if "headers" not in message: - raise ValueError(f"Invalid message type: {message['type']!r}") - - return cls(cast("HeaderScope", message)) - - def add(self, key: str, value: str) -> None: - """Add a header to the scope. - - Notes: - - This method keeps duplicates. - - Args: - key: Header key. - value: Header value. - - Returns: - None. - """ - self.headers.append((key.lower().encode("latin-1"), value.encode("latin-1"))) - - def getall(self, key: str, default: Optional[List[str]] = None) -> List[str]: - """Get all values of a header. - - Args: - key: Header key. - default: Default value to return if ``name`` is not found. - - Returns: - A list of strings. - - Raises: - KeyError: if no header for ``name`` was found and ``default`` is not given. - """ - name = key.lower() - values = [ - header_value.decode("latin-1") - for header_name, header_value in self.headers - if header_name.decode("latin-1").lower() == name - ] - if not values: - if default: - return default - raise KeyError - return values - - def extend_header_value(self, key: str, value: str) -> None: - """Extend a multivalued header. - - Notes: - - A multivalues header is a header that can take a comma separated list. - - If the header previously did not exist, it will be added. - - Args: - key: Header key. - value: Header value to add, - - Returns: - None - """ - existing = self.get(key) - if existing is not None: - value = ",".join([*existing.split(","), value]) - self[key] = value - - def __getitem__(self, key: str) -> str: - """Get the first header matching ``name``""" - name = key.lower() - for header in self.headers: - if header[0].decode("latin-1").lower() == name: - return header[1].decode("latin-1") - raise KeyError - - def _find_indices(self, key: str) -> List[int]: - name = key.lower() - return [i for i, (name_, _) in enumerate(self.headers) if name_.decode("latin-1").lower() == name] - - def __setitem__(self, key: str, value: str) -> None: - """Set a header in the scope, overwriting duplicates.""" - name_encoded = key.lower().encode("latin-1") - value_encoded = value.encode("latin-1") - if indices := self._find_indices(key): - for i in indices[1:]: - del self.headers[i] - self.headers[indices[0]] = (name_encoded, value_encoded) - else: - self.headers.append((name_encoded, value_encoded)) - - def __delitem__(self, key: str) -> None: - """Delete all headers matching ``name``""" - indices = self._find_indices(key) - for i in indices[::-1]: - del self.headers[i] - - def __len__(self) -> int: - """Return the length of the internally stored headers, including duplicates.""" - return len(self.headers) - - def __iter__(self) -> Iterator[str]: - """Create an iterator of header names including duplicates.""" - return iter(h[0].decode("latin-1") for h in self.headers) - - -@dataclass -class Header(ABC): - """An abstract type for HTTP headers.""" - - HEADER_NAME: ClassVar[str] = "" - - documentation_only: bool = False - """Defines the header instance as for OpenAPI documentation purpose only.""" - - @abstractmethod - def _get_header_value(self) -> str: - """Get the header value as string.""" - raise NotImplementedError - - @classmethod - @abstractmethod - def from_header(cls, header_value: str) -> "Header": - """Construct a header from its string representation.""" - - def to_header(self, include_header_name: bool = False) -> str: - """Get the header as string. - - Args: - include_header_name: should include the header name in the return value. If set to false - the return value will only include the header value. if set to true the return value - will be: ``<header name>: <header value>``. Defaults to false. - """ - - if not self.HEADER_NAME: - raise ImproperlyConfiguredException("Missing header name") - - return (f"{self.HEADER_NAME}: " if include_header_name else "") + self._get_header_value() - - -@dataclass -class CacheControlHeader(Header): - """A ``cache-control`` header.""" - - HEADER_NAME: ClassVar[str] = "cache-control" - - max_age: Optional[int] = None - """Accessor for the ``max-age`` directive.""" - s_maxage: Optional[int] = None - """Accessor for the ``s-maxage`` directive.""" - no_cache: Optional[bool] = None - """Accessor for the ``no-cache`` directive.""" - no_store: Optional[bool] = None - """Accessor for the ``no-store`` directive.""" - private: Optional[bool] = None - """Accessor for the ``private`` directive.""" - public: Optional[bool] = None - """Accessor for the ``public`` directive.""" - no_transform: Optional[bool] = None - """Accessor for the ``no-transform`` directive.""" - must_revalidate: Optional[bool] = None - """Accessor for the ``must-revalidate`` directive.""" - proxy_revalidate: Optional[bool] = None - """Accessor for the ``proxy-revalidate`` directive.""" - must_understand: Optional[bool] = None - """Accessor for the ``must-understand`` directive.""" - immutable: Optional[bool] = None - """Accessor for the ``immutable`` directive.""" - stale_while_revalidate: Optional[int] = None - """Accessor for the ``stale-while-revalidate`` directive.""" - - _field_definitions: ClassVar[Optional[Dict[str, FieldDefinition]]] = None - - def _get_header_value(self) -> str: - """Get the header value as string.""" - - cc_items = [ - key.replace("_", "-") if isinstance(value, bool) else f"{key.replace('_', '-')}={value}" - for key, value in simple_asdict(self, exclude_none=True, exclude={"documentation_only"}).items() - ] - return ", ".join(cc_items) - - @classmethod - def from_header(cls, header_value: str) -> "CacheControlHeader": - """Create a ``CacheControlHeader`` instance from the header value. - - Args: - header_value: the header value as string - - Returns: - An instance of ``CacheControlHeader`` - """ - - cc_items = [v.strip() for v in header_value.split(",")] - kwargs: Dict[str, Any] = {} - field_definitions = cls._get_field_definitions() - for cc_item in cc_items: - key_value = cc_item.split("=") - key_value[0] = key_value[0].replace("-", "_") - if len(key_value) == 1: - kwargs[key_value[0]] = True - elif len(key_value) == 2: - key, value = key_value - if key not in field_definitions: - raise ImproperlyConfiguredException("Invalid cache-control header") - kwargs[key] = cls._convert_to_type(value, field_definition=field_definitions[key]) - else: - raise ImproperlyConfiguredException("Invalid cache-control header value") - - try: - return CacheControlHeader(**kwargs) - except TypeError as exc: - raise ImproperlyConfiguredException from exc - - @classmethod - def prevent_storing(cls) -> "CacheControlHeader": - """Create a ``cache-control`` header with the ``no-store`` directive which indicates that any caches of any kind - (private or shared) should not store this response. - """ - - return cls(no_store=True) - - @classmethod - def _get_field_definitions(cls) -> Dict[str, FieldDefinition]: - """Get the type annotations for the ``CacheControlHeader`` class properties. - - This is needed due to the conversion from pydantic models to dataclasses. Dataclasses do not support - automatic conversion of types like pydantic models do. - - Returns: - A dictionary of type annotations - - """ - - if cls._field_definitions is None: - cls._field_definitions = {} - for key, value in get_type_hints(cls, include_extras=True).items(): - definition = FieldDefinition.from_kwarg(annotation=value, name=key) - # resolve_model_type so that field_definition.raw has the real raw type e.g. <class 'bool'> - cls._field_definitions[key] = AbstractDTO.resolve_model_type(definition) - return cls._field_definitions - - @classmethod - def _convert_to_type(cls, value: str, field_definition: FieldDefinition) -> Any: - """Convert the value to the expected type. - - Args: - value: the value of the cache-control directive - field_definition: the field definition for the value to convert - - Returns: - The value converted to the expected type - """ - # bool values shouldn't be initiated since they should have been caught earlier in the from_header method and - # set with a value of True - expected_type = field_definition.raw - if expected_type is bool: - raise ImproperlyConfiguredException("Invalid cache-control header value") - return expected_type(value) - - -@dataclass -class ETag(Header): - """An ``etag`` header.""" - - HEADER_NAME: ClassVar[str] = "etag" - - weak: bool = False - value: Optional[str] = None # only ASCII characters - - def _get_header_value(self) -> str: - value = f'"{self.value}"' - return f"W/{value}" if self.weak else value - - @classmethod - def from_header(cls, header_value: str) -> "ETag": - """Construct an ``etag`` header from its string representation. - - Note that this will unquote etag-values - """ - match = ETAG_RE.match(header_value) - if not match: - raise ImproperlyConfiguredException - weak, value = match.group(1, 2) - try: - return cls(weak=bool(weak), value=value) - except ValueError as exc: - raise ImproperlyConfiguredException from exc - - def __post_init__(self) -> None: - if self.documentation_only is False and self.value is None: - raise ValidationException("value must be set if documentation_only is false") - if self.value and not PRINTABLE_ASCII_RE.fullmatch(self.value): - raise ValidationException("value must only contain ASCII printable characters") - - -class MediaTypeHeader: - """A helper class for ``Accept`` header parsing.""" - - __slots__ = ("maintype", "subtype", "params", "_params_str") - - def __init__(self, type_str: str) -> None: - # preserve the original parameters, because the order might be - # changed in the dict - self._params_str = "".join(type_str.partition(";")[1:]) - - full_type, self.params = parse_content_header(type_str) - self.maintype, _, self.subtype = full_type.partition("/") - - def __str__(self) -> str: - return f"{self.maintype}/{self.subtype}{self._params_str}" - - @property - def priority(self) -> Tuple[int, int]: - # Use fixed point values with two decimals to avoid problems - # when comparing float values - quality = 100 - if "q" in self.params: - with suppress(ValueError): - quality = int(100 * float(self.params["q"])) - - if self.maintype == "*": - specificity = 0 - elif self.subtype == "*": - specificity = 1 - elif not self.params or ("q" in self.params and len(self.params) == 1): - # no params or 'q' is the only one which we ignore - specificity = 2 - else: - specificity = 3 - - return quality, specificity - - def match(self, other: "MediaTypeHeader") -> bool: - return next( - (False for key, value in self.params.items() if key != "q" and value != other.params.get(key)), - False - if self.subtype != "*" and other.subtype != "*" and self.subtype != other.subtype - else self.maintype == "*" or other.maintype == "*" or self.maintype == other.maintype, - ) - - -class Accept: - """An ``Accept`` header.""" - - __slots__ = ("_accepted_types",) - - def __init__(self, accept_value: str) -> None: - self._accepted_types = [MediaTypeHeader(t) for t in accept_value.split(",")] - self._accepted_types.sort(key=lambda t: t.priority, reverse=True) - - def __len__(self) -> int: - return len(self._accepted_types) - - def __getitem__(self, key: int) -> str: - return str(self._accepted_types[key]) - - def __iter__(self) -> Iterator[str]: - return map(str, self._accepted_types) - - def best_match(self, provided_types: List[str], default: Optional[str] = None) -> Optional[str]: - """Find the best matching media type for the request. - - Args: - provided_types: A list of media types that can be provided as a response. These types - can contain a wildcard ``*`` character in the main- or subtype part. - default: The media type that is returned if none of the provided types match. - - Returns: - The best matching media type. If the matching provided type contains wildcard characters, - they are replaced with the corresponding part of the accepted type. Otherwise the - provided type is returned as-is. - """ - types = [MediaTypeHeader(t) for t in provided_types] - - for accepted in self._accepted_types: - for provided in types: - if provided.match(accepted): - # Return the accepted type with wildcards replaced - # by concrete parts from the provided type - result = copy(provided) - if result.subtype == "*": - result.subtype = accepted.subtype - if result.maintype == "*": - result.maintype = accepted.maintype - return str(result) - return default - - def accepts(self, media_type: str) -> bool: - """Check if the request accepts the specified media type. - - If multiple media types can be provided, it is better to use :func:`best_match`. - - Args: - media_type: The media type to check for. - - Returns: - True if the request accepts ``media_type``. - """ - return self.best_match([media_type]) == media_type |