summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/litestar/datastructures/headers.py
diff options
context:
space:
mode:
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/datastructures/headers.py')
-rw-r--r--venv/lib/python3.11/site-packages/litestar/datastructures/headers.py534
1 files changed, 534 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/datastructures/headers.py b/venv/lib/python3.11/site-packages/litestar/datastructures/headers.py
new file mode 100644
index 0000000..f3e9bd7
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/datastructures/headers.py
@@ -0,0 +1,534 @@
+import re
+from abc import ABC, abstractmethod
+from contextlib import suppress
+from copy import copy
+from dataclasses import dataclass
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ ClassVar,
+ Dict,
+ Iterable,
+ Iterator,
+ List,
+ Mapping,
+ MutableMapping,
+ Optional,
+ Pattern,
+ Tuple,
+ Union,
+ cast,
+)
+
+from multidict import CIMultiDict, CIMultiDictProxy, MultiMapping
+from typing_extensions import get_type_hints
+
+from litestar._multipart import parse_content_header
+from litestar.datastructures.multi_dicts import MultiMixin
+from litestar.dto.base_dto import AbstractDTO
+from litestar.exceptions import ImproperlyConfiguredException, ValidationException
+from litestar.types.empty import Empty
+from litestar.typing import FieldDefinition
+from litestar.utils.dataclass import simple_asdict
+from litestar.utils.scope.state import ScopeState
+
+if TYPE_CHECKING:
+ from litestar.types.asgi_types import (
+ HeaderScope,
+ Message,
+ RawHeaders,
+ RawHeadersList,
+ Scope,
+ )
+
+__all__ = ("Accept", "CacheControlHeader", "ETag", "Header", "Headers", "MutableScopeHeaders")
+
+ETAG_RE = re.compile(r'([Ww]/)?"(.+)"')
+PRINTABLE_ASCII_RE: Pattern[str] = re.compile(r"^[ -~]+$")
+
+
+def _encode_headers(headers: Iterable[Tuple[str, str]]) -> "RawHeadersList":
+ return [(key.lower().encode("latin-1"), value.encode("latin-1")) for key, value in headers]
+
+
+class Headers(CIMultiDictProxy[str], MultiMixin[str]):
+ """An immutable, case-insensitive multi dict for HTTP headers."""
+
+ def __init__(self, headers: Optional[Union[Mapping[str, str], "RawHeaders", MultiMapping]] = None) -> None:
+ """Initialize ``Headers``.
+
+ Args:
+ headers: Initial value.
+ """
+ if not isinstance(headers, MultiMapping):
+ headers_: Union[Mapping[str, str], List[Tuple[str, str]]] = {}
+ if headers:
+ if isinstance(headers, Mapping):
+ headers_ = headers # pyright: ignore
+ else:
+ headers_ = [(key.decode("latin-1"), value.decode("latin-1")) for key, value in headers]
+
+ super().__init__(CIMultiDict(headers_))
+ else:
+ super().__init__(headers)
+ self._header_list: Optional[RawHeadersList] = None
+
+ @classmethod
+ def from_scope(cls, scope: "Scope") -> "Headers":
+ """Create headers from a send-message.
+
+ Args:
+ scope: The ASGI connection scope.
+
+ Returns:
+ Headers
+
+ Raises:
+ ValueError: If the message does not have a ``headers`` key
+ """
+ connection_state = ScopeState.from_scope(scope)
+ if (headers := connection_state.headers) is Empty:
+ headers = connection_state.headers = cls(scope["headers"])
+ return headers
+
+ def to_header_list(self) -> "RawHeadersList":
+ """Raw header value.
+
+ Returns:
+ A list of tuples contain the header and header-value as bytes
+ """
+ # Since ``Headers`` are immutable, this can be cached
+ if not self._header_list:
+ self._header_list = _encode_headers((key, value) for key in set(self) for value in self.getall(key))
+ return self._header_list
+
+
+class MutableScopeHeaders(MutableMapping):
+ """A case-insensitive, multidict-like structure that can be used to mutate headers within a
+ :class:`Scope <.types.Scope>`
+ """
+
+ def __init__(self, scope: Optional["HeaderScope"] = None) -> None:
+ """Initialize ``MutableScopeHeaders`` from a ``HeaderScope``.
+
+ Args:
+ scope: The ASGI connection scope.
+ """
+ self.headers: RawHeadersList
+ if scope is not None:
+ if not isinstance(scope["headers"], list):
+ scope["headers"] = list(scope["headers"])
+
+ self.headers = cast("RawHeadersList", scope["headers"])
+ else:
+ self.headers = []
+
+ @classmethod
+ def from_message(cls, message: "Message") -> "MutableScopeHeaders":
+ """Construct a header from a message object.
+
+ Args:
+ message: :class:`Message <.types.Message>`.
+
+ Returns:
+ MutableScopeHeaders.
+
+ Raises:
+ ValueError: If the message does not have a ``headers`` key.
+ """
+ if "headers" not in message:
+ raise ValueError(f"Invalid message type: {message['type']!r}")
+
+ return cls(cast("HeaderScope", message))
+
+ def add(self, key: str, value: str) -> None:
+ """Add a header to the scope.
+
+ Notes:
+ - This method keeps duplicates.
+
+ Args:
+ key: Header key.
+ value: Header value.
+
+ Returns:
+ None.
+ """
+ self.headers.append((key.lower().encode("latin-1"), value.encode("latin-1")))
+
+ def getall(self, key: str, default: Optional[List[str]] = None) -> List[str]:
+ """Get all values of a header.
+
+ Args:
+ key: Header key.
+ default: Default value to return if ``name`` is not found.
+
+ Returns:
+ A list of strings.
+
+ Raises:
+ KeyError: if no header for ``name`` was found and ``default`` is not given.
+ """
+ name = key.lower()
+ values = [
+ header_value.decode("latin-1")
+ for header_name, header_value in self.headers
+ if header_name.decode("latin-1").lower() == name
+ ]
+ if not values:
+ if default:
+ return default
+ raise KeyError
+ return values
+
+ def extend_header_value(self, key: str, value: str) -> None:
+ """Extend a multivalued header.
+
+ Notes:
+ - A multivalues header is a header that can take a comma separated list.
+ - If the header previously did not exist, it will be added.
+
+ Args:
+ key: Header key.
+ value: Header value to add,
+
+ Returns:
+ None
+ """
+ existing = self.get(key)
+ if existing is not None:
+ value = ",".join([*existing.split(","), value])
+ self[key] = value
+
+ def __getitem__(self, key: str) -> str:
+ """Get the first header matching ``name``"""
+ name = key.lower()
+ for header in self.headers:
+ if header[0].decode("latin-1").lower() == name:
+ return header[1].decode("latin-1")
+ raise KeyError
+
+ def _find_indices(self, key: str) -> List[int]:
+ name = key.lower()
+ return [i for i, (name_, _) in enumerate(self.headers) if name_.decode("latin-1").lower() == name]
+
+ def __setitem__(self, key: str, value: str) -> None:
+ """Set a header in the scope, overwriting duplicates."""
+ name_encoded = key.lower().encode("latin-1")
+ value_encoded = value.encode("latin-1")
+ if indices := self._find_indices(key):
+ for i in indices[1:]:
+ del self.headers[i]
+ self.headers[indices[0]] = (name_encoded, value_encoded)
+ else:
+ self.headers.append((name_encoded, value_encoded))
+
+ def __delitem__(self, key: str) -> None:
+ """Delete all headers matching ``name``"""
+ indices = self._find_indices(key)
+ for i in indices[::-1]:
+ del self.headers[i]
+
+ def __len__(self) -> int:
+ """Return the length of the internally stored headers, including duplicates."""
+ return len(self.headers)
+
+ def __iter__(self) -> Iterator[str]:
+ """Create an iterator of header names including duplicates."""
+ return iter(h[0].decode("latin-1") for h in self.headers)
+
+
+@dataclass
+class Header(ABC):
+ """An abstract type for HTTP headers."""
+
+ HEADER_NAME: ClassVar[str] = ""
+
+ documentation_only: bool = False
+ """Defines the header instance as for OpenAPI documentation purpose only."""
+
+ @abstractmethod
+ def _get_header_value(self) -> str:
+ """Get the header value as string."""
+ raise NotImplementedError
+
+ @classmethod
+ @abstractmethod
+ def from_header(cls, header_value: str) -> "Header":
+ """Construct a header from its string representation."""
+
+ def to_header(self, include_header_name: bool = False) -> str:
+ """Get the header as string.
+
+ Args:
+ include_header_name: should include the header name in the return value. If set to false
+ the return value will only include the header value. if set to true the return value
+ will be: ``<header name>: <header value>``. Defaults to false.
+ """
+
+ if not self.HEADER_NAME:
+ raise ImproperlyConfiguredException("Missing header name")
+
+ return (f"{self.HEADER_NAME}: " if include_header_name else "") + self._get_header_value()
+
+
+@dataclass
+class CacheControlHeader(Header):
+ """A ``cache-control`` header."""
+
+ HEADER_NAME: ClassVar[str] = "cache-control"
+
+ max_age: Optional[int] = None
+ """Accessor for the ``max-age`` directive."""
+ s_maxage: Optional[int] = None
+ """Accessor for the ``s-maxage`` directive."""
+ no_cache: Optional[bool] = None
+ """Accessor for the ``no-cache`` directive."""
+ no_store: Optional[bool] = None
+ """Accessor for the ``no-store`` directive."""
+ private: Optional[bool] = None
+ """Accessor for the ``private`` directive."""
+ public: Optional[bool] = None
+ """Accessor for the ``public`` directive."""
+ no_transform: Optional[bool] = None
+ """Accessor for the ``no-transform`` directive."""
+ must_revalidate: Optional[bool] = None
+ """Accessor for the ``must-revalidate`` directive."""
+ proxy_revalidate: Optional[bool] = None
+ """Accessor for the ``proxy-revalidate`` directive."""
+ must_understand: Optional[bool] = None
+ """Accessor for the ``must-understand`` directive."""
+ immutable: Optional[bool] = None
+ """Accessor for the ``immutable`` directive."""
+ stale_while_revalidate: Optional[int] = None
+ """Accessor for the ``stale-while-revalidate`` directive."""
+
+ _field_definitions: ClassVar[Optional[Dict[str, FieldDefinition]]] = None
+
+ def _get_header_value(self) -> str:
+ """Get the header value as string."""
+
+ cc_items = [
+ key.replace("_", "-") if isinstance(value, bool) else f"{key.replace('_', '-')}={value}"
+ for key, value in simple_asdict(self, exclude_none=True, exclude={"documentation_only"}).items()
+ ]
+ return ", ".join(cc_items)
+
+ @classmethod
+ def from_header(cls, header_value: str) -> "CacheControlHeader":
+ """Create a ``CacheControlHeader`` instance from the header value.
+
+ Args:
+ header_value: the header value as string
+
+ Returns:
+ An instance of ``CacheControlHeader``
+ """
+
+ cc_items = [v.strip() for v in header_value.split(",")]
+ kwargs: Dict[str, Any] = {}
+ field_definitions = cls._get_field_definitions()
+ for cc_item in cc_items:
+ key_value = cc_item.split("=")
+ key_value[0] = key_value[0].replace("-", "_")
+ if len(key_value) == 1:
+ kwargs[key_value[0]] = True
+ elif len(key_value) == 2:
+ key, value = key_value
+ if key not in field_definitions:
+ raise ImproperlyConfiguredException("Invalid cache-control header")
+ kwargs[key] = cls._convert_to_type(value, field_definition=field_definitions[key])
+ else:
+ raise ImproperlyConfiguredException("Invalid cache-control header value")
+
+ try:
+ return CacheControlHeader(**kwargs)
+ except TypeError as exc:
+ raise ImproperlyConfiguredException from exc
+
+ @classmethod
+ def prevent_storing(cls) -> "CacheControlHeader":
+ """Create a ``cache-control`` header with the ``no-store`` directive which indicates that any caches of any kind
+ (private or shared) should not store this response.
+ """
+
+ return cls(no_store=True)
+
+ @classmethod
+ def _get_field_definitions(cls) -> Dict[str, FieldDefinition]:
+ """Get the type annotations for the ``CacheControlHeader`` class properties.
+
+ This is needed due to the conversion from pydantic models to dataclasses. Dataclasses do not support
+ automatic conversion of types like pydantic models do.
+
+ Returns:
+ A dictionary of type annotations
+
+ """
+
+ if cls._field_definitions is None:
+ cls._field_definitions = {}
+ for key, value in get_type_hints(cls, include_extras=True).items():
+ definition = FieldDefinition.from_kwarg(annotation=value, name=key)
+ # resolve_model_type so that field_definition.raw has the real raw type e.g. <class 'bool'>
+ cls._field_definitions[key] = AbstractDTO.resolve_model_type(definition)
+ return cls._field_definitions
+
+ @classmethod
+ def _convert_to_type(cls, value: str, field_definition: FieldDefinition) -> Any:
+ """Convert the value to the expected type.
+
+ Args:
+ value: the value of the cache-control directive
+ field_definition: the field definition for the value to convert
+
+ Returns:
+ The value converted to the expected type
+ """
+ # bool values shouldn't be initiated since they should have been caught earlier in the from_header method and
+ # set with a value of True
+ expected_type = field_definition.raw
+ if expected_type is bool:
+ raise ImproperlyConfiguredException("Invalid cache-control header value")
+ return expected_type(value)
+
+
+@dataclass
+class ETag(Header):
+ """An ``etag`` header."""
+
+ HEADER_NAME: ClassVar[str] = "etag"
+
+ weak: bool = False
+ value: Optional[str] = None # only ASCII characters
+
+ def _get_header_value(self) -> str:
+ value = f'"{self.value}"'
+ return f"W/{value}" if self.weak else value
+
+ @classmethod
+ def from_header(cls, header_value: str) -> "ETag":
+ """Construct an ``etag`` header from its string representation.
+
+ Note that this will unquote etag-values
+ """
+ match = ETAG_RE.match(header_value)
+ if not match:
+ raise ImproperlyConfiguredException
+ weak, value = match.group(1, 2)
+ try:
+ return cls(weak=bool(weak), value=value)
+ except ValueError as exc:
+ raise ImproperlyConfiguredException from exc
+
+ def __post_init__(self) -> None:
+ if self.documentation_only is False and self.value is None:
+ raise ValidationException("value must be set if documentation_only is false")
+ if self.value and not PRINTABLE_ASCII_RE.fullmatch(self.value):
+ raise ValidationException("value must only contain ASCII printable characters")
+
+
+class MediaTypeHeader:
+ """A helper class for ``Accept`` header parsing."""
+
+ __slots__ = ("maintype", "subtype", "params", "_params_str")
+
+ def __init__(self, type_str: str) -> None:
+ # preserve the original parameters, because the order might be
+ # changed in the dict
+ self._params_str = "".join(type_str.partition(";")[1:])
+
+ full_type, self.params = parse_content_header(type_str)
+ self.maintype, _, self.subtype = full_type.partition("/")
+
+ def __str__(self) -> str:
+ return f"{self.maintype}/{self.subtype}{self._params_str}"
+
+ @property
+ def priority(self) -> Tuple[int, int]:
+ # Use fixed point values with two decimals to avoid problems
+ # when comparing float values
+ quality = 100
+ if "q" in self.params:
+ with suppress(ValueError):
+ quality = int(100 * float(self.params["q"]))
+
+ if self.maintype == "*":
+ specificity = 0
+ elif self.subtype == "*":
+ specificity = 1
+ elif not self.params or ("q" in self.params and len(self.params) == 1):
+ # no params or 'q' is the only one which we ignore
+ specificity = 2
+ else:
+ specificity = 3
+
+ return quality, specificity
+
+ def match(self, other: "MediaTypeHeader") -> bool:
+ return next(
+ (False for key, value in self.params.items() if key != "q" and value != other.params.get(key)),
+ False
+ if self.subtype != "*" and other.subtype != "*" and self.subtype != other.subtype
+ else self.maintype == "*" or other.maintype == "*" or self.maintype == other.maintype,
+ )
+
+
+class Accept:
+ """An ``Accept`` header."""
+
+ __slots__ = ("_accepted_types",)
+
+ def __init__(self, accept_value: str) -> None:
+ self._accepted_types = [MediaTypeHeader(t) for t in accept_value.split(",")]
+ self._accepted_types.sort(key=lambda t: t.priority, reverse=True)
+
+ def __len__(self) -> int:
+ return len(self._accepted_types)
+
+ def __getitem__(self, key: int) -> str:
+ return str(self._accepted_types[key])
+
+ def __iter__(self) -> Iterator[str]:
+ return map(str, self._accepted_types)
+
+ def best_match(self, provided_types: List[str], default: Optional[str] = None) -> Optional[str]:
+ """Find the best matching media type for the request.
+
+ Args:
+ provided_types: A list of media types that can be provided as a response. These types
+ can contain a wildcard ``*`` character in the main- or subtype part.
+ default: The media type that is returned if none of the provided types match.
+
+ Returns:
+ The best matching media type. If the matching provided type contains wildcard characters,
+ they are replaced with the corresponding part of the accepted type. Otherwise the
+ provided type is returned as-is.
+ """
+ types = [MediaTypeHeader(t) for t in provided_types]
+
+ for accepted in self._accepted_types:
+ for provided in types:
+ if provided.match(accepted):
+ # Return the accepted type with wildcards replaced
+ # by concrete parts from the provided type
+ result = copy(provided)
+ if result.subtype == "*":
+ result.subtype = accepted.subtype
+ if result.maintype == "*":
+ result.maintype = accepted.maintype
+ return str(result)
+ return default
+
+ def accepts(self, media_type: str) -> bool:
+ """Check if the request accepts the specified media type.
+
+ If multiple media types can be provided, it is better to use :func:`best_match`.
+
+ Args:
+ media_type: The media type to check for.
+
+ Returns:
+ True if the request accepts ``media_type``.
+ """
+ return self.best_match([media_type]) == media_type