diff options
| author | cyfraeviolae <cyfraeviolae> | 2024-04-03 03:10:44 -0400 | 
|---|---|---|
| committer | cyfraeviolae <cyfraeviolae> | 2024-04-03 03:10:44 -0400 | 
| commit | 6d7ba58f880be618ade07f8ea080fe8c4bf8a896 (patch) | |
| tree | b1c931051ffcebd2bd9d61d98d6233ffa289bbce /venv/lib/python3.11/site-packages/litestar/security/jwt/token.py | |
| parent | 4f884c9abc32990b4061a1bb6997b4b37e58ea0b (diff) | |
venv
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/security/jwt/token.py')
| -rw-r--r-- | venv/lib/python3.11/site-packages/litestar/security/jwt/token.py | 119 | 
1 files changed, 119 insertions, 0 deletions
| diff --git a/venv/lib/python3.11/site-packages/litestar/security/jwt/token.py b/venv/lib/python3.11/site-packages/litestar/security/jwt/token.py new file mode 100644 index 0000000..279111a --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/security/jwt/token.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +import dataclasses +from dataclasses import asdict, dataclass, field +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any + +from jose import JWSError, JWTError, jwt + +from litestar.exceptions import ImproperlyConfiguredException, NotAuthorizedException + +if TYPE_CHECKING: +    from typing_extensions import Self + + +__all__ = ("Token",) + + +def _normalize_datetime(value: datetime) -> datetime: +    """Convert the given value into UTC and strip microseconds. + +    Args: +        value: A datetime instance + +    Returns: +        A datetime instance +    """ +    if value.tzinfo is not None: +        value.astimezone(timezone.utc) + +    return value.replace(microsecond=0) + + +@dataclass +class Token: +    """JWT Token DTO.""" + +    exp: datetime +    """Expiration - datetime for token expiration.""" +    sub: str +    """Subject - usually a unique identifier of the user or equivalent entity.""" +    iat: datetime = field(default_factory=lambda: _normalize_datetime(datetime.now(timezone.utc))) +    """Issued at - should always be current now.""" +    iss: str | None = field(default=None) +    """Issuer - optional unique identifier for the issuer.""" +    aud: str | None = field(default=None) +    """Audience - intended audience.""" +    jti: str | None = field(default=None) +    """JWT ID - a unique identifier of the JWT between different issuers.""" +    extras: dict[str, Any] = field(default_factory=dict) +    """Extra fields that were found on the JWT token.""" + +    def __post_init__(self) -> None: +        if len(self.sub) < 1: +            raise ImproperlyConfiguredException("sub must be a string with a length greater than 0") + +        if isinstance(self.exp, datetime) and ( +            (exp := _normalize_datetime(self.exp)).timestamp() +            >= _normalize_datetime(datetime.now(timezone.utc)).timestamp() +        ): +            self.exp = exp +        else: +            raise ImproperlyConfiguredException("exp value must be a datetime in the future") + +        if isinstance(self.iat, datetime) and ( +            (iat := _normalize_datetime(self.iat)).timestamp() +            <= _normalize_datetime(datetime.now(timezone.utc)).timestamp() +        ): +            self.iat = iat +        else: +            raise ImproperlyConfiguredException("iat must be a current or past time") + +    @classmethod +    def decode(cls, encoded_token: str, secret: str | dict[str, str], algorithm: str) -> Self: +        """Decode a passed in token string and returns a Token instance. + +        Args: +            encoded_token: A base64 string containing an encoded JWT. +            secret: The secret with which the JWT is encoded. It may optionally be an individual JWK or JWS set dict +            algorithm: The algorithm used to encode the JWT. + +        Returns: +            A decoded Token instance. + +        Raises: +            NotAuthorizedException: If the token is invalid. +        """ +        try: +            payload = jwt.decode(token=encoded_token, key=secret, algorithms=[algorithm], options={"verify_aud": False}) +            exp = datetime.fromtimestamp(payload.pop("exp"), tz=timezone.utc) +            iat = datetime.fromtimestamp(payload.pop("iat"), tz=timezone.utc) +            field_names = {f.name for f in dataclasses.fields(Token)} +            extra_fields = payload.keys() - field_names +            extras = payload.pop("extras", {}) +            for key in extra_fields: +                extras[key] = payload.pop(key) +            return cls(exp=exp, iat=iat, **payload, extras=extras) +        except (KeyError, JWTError, ImproperlyConfiguredException) as e: +            raise NotAuthorizedException("Invalid token") from e + +    def encode(self, secret: str, algorithm: str) -> str: +        """Encode the token instance into a string. + +        Args: +            secret: The secret with which the JWT is encoded. +            algorithm: The algorithm used to encode the JWT. + +        Returns: +            An encoded token string. + +        Raises: +            ImproperlyConfiguredException: If encoding fails. +        """ +        try: +            return jwt.encode( +                claims={k: v for k, v in asdict(self).items() if v is not None}, key=secret, algorithm=algorithm +            ) +        except (JWTError, JWSError) as e: +            raise ImproperlyConfiguredException("Failed to encode token") from e | 
