summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/litestar/middleware/allowed_hosts.py
blob: 0172176d9208afa1b9d47e2500f1359bd22e4368 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from __future__ import annotations

import re
from typing import TYPE_CHECKING, Pattern

from litestar.datastructures import URL, MutableScopeHeaders
from litestar.middleware.base import AbstractMiddleware
from litestar.response.base import ASGIResponse
from litestar.response.redirect import ASGIRedirectResponse
from litestar.status_codes import HTTP_400_BAD_REQUEST

__all__ = ("AllowedHostsMiddleware",)


if TYPE_CHECKING:
    from litestar.config.allowed_hosts import AllowedHostsConfig
    from litestar.types import ASGIApp, Receive, Scope, Send


class AllowedHostsMiddleware(AbstractMiddleware):
    """Middleware ensuring the host of a request originated in a trusted host.

    The configured host names are compiled into a single regular expression at
    construction time. Each request's host is then checked with ``fullmatch``;
    requests that do not match are either redirected to the ``www.`` variant
    (when enabled and applicable) or rejected with a 400 response.
    """

    def __init__(self, app: ASGIApp, config: AllowedHostsConfig) -> None:
        """Initialize ``AllowedHostsMiddleware``.

        Args:
            app: The ``next`` ASGI app to call.
            config: An instance of AllowedHostsConfig.
        """

        super().__init__(app=app, exclude=config.exclude, exclude_opt_key=config.exclude_opt_key, scopes=config.scopes)

        # ``None`` means "no restriction": every request is passed straight through.
        self.allowed_hosts_regex: Pattern | None = None
        # ``None`` means no www-redirect is configured.
        self.redirect_domains: Pattern | None = None

        # A bare "*" entry allows every host, so no pattern needs to be built.
        if any(host == "*" for host in config.allowed_hosts):
            return

        # Host names are literal text, not regular expressions: escape them so that
        # e.g. the dots in "example.com" do not match arbitrary characters (which
        # would let "exampleXcom" through). Only the leading "*." wildcard prefix is
        # stripped and translated into "any subdomain".
        allowed_hosts: set[str] = {
            rf".*\.{re.escape(host[2:])}$" if host.startswith("*.") else re.escape(host)
            for host in config.allowed_hosts
        }

        # Sorting keeps the compiled pattern deterministic regardless of set ordering.
        self.allowed_hosts_regex = re.compile("|".join(sorted(allowed_hosts)))  # pyright: ignore

        # For every allowed "www.<domain>" host, requests addressed to the bare
        # "<domain>" can be redirected to the www variant. Only the leading "www."
        # prefix is removed; the remainder is escaped for the same reason as above.
        if config.www_redirect and (
            redirect_domains := {re.escape(host[4:]) for host in config.allowed_hosts if host.startswith("www.")}
        ):
            self.redirect_domains = re.compile("|".join(sorted(redirect_domains)))  # pyright: ignore

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """ASGI callable.

        Passes the request to the next app when the host is allowed, redirects to the
        ``www.`` variant when the host matches a redirect domain, and otherwise
        responds with ``400 Bad Request``.

        Args:
            scope: The ASGI connection scope.
            receive: The ASGI receive function.
            send: The ASGI send function.

        Returns:
            None
        """
        if self.allowed_hosts_regex is None:
            await self.app(scope, receive, send)
            return

        headers = MutableScopeHeaders(scope=scope)
        # Prefer the "host" header and fall back to "x-forwarded-host"; the port
        # (everything after the first ":") is discarded before matching.
        # NOTE(review): a bracketed IPv6 literal such as "[::1]:8000" is truncated to
        # "[" by this split and will therefore never match - confirm whether IPv6
        # hosts need to be supported here.
        if host := headers.get("host", headers.get("x-forwarded-host", "")).split(":")[0]:
            if self.allowed_hosts_regex.fullmatch(host):
                await self.app(scope, receive, send)
                return

            if self.redirect_domains is not None and self.redirect_domains.fullmatch(host):
                url = URL.from_scope(scope)
                redirect_url = url.with_replacements(netloc=f"www.{url.netloc}")
                redirect_response = ASGIRedirectResponse(path=str(redirect_url))
                await redirect_response(scope, receive, send)
                return

        # Missing, empty, or non-matching host: reject the request.
        response = ASGIResponse(body=b'{"message":"invalid host header"}', status_code=HTTP_400_BAD_REQUEST)
        await response(scope, receive, send)