summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/litestar/contrib/prometheus/middleware.py
blob: 50bc7cb03f088097908f734e8dad32de2ba8dffe (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
from __future__ import annotations

import time
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, ClassVar, cast

from litestar.connection.request import Request
from litestar.enums import ScopeType
from litestar.exceptions import MissingDependencyException
from litestar.middleware.base import AbstractMiddleware

__all__ = ("PrometheusMiddleware",)

from litestar.status_codes import HTTP_500_INTERNAL_SERVER_ERROR

try:
    import prometheus_client  # noqa: F401
except ImportError as e:
    raise MissingDependencyException("prometheus_client", "prometheus-client", "prometheus") from e

from prometheus_client import Counter, Gauge, Histogram

if TYPE_CHECKING:
    from prometheus_client.metrics import MetricWrapperBase

    from litestar.contrib.prometheus import PrometheusConfig
    from litestar.types import ASGIApp, Message, Receive, Scope, Send


class PrometheusMiddleware(AbstractMiddleware):
    """Prometheus Middleware."""

    _metrics: ClassVar[dict[str, MetricWrapperBase]] = {}

    def __init__(self, app: ASGIApp, config: PrometheusConfig) -> None:
        """Middleware that adds Prometheus instrumentation to the application.

        Args:
            app: The ``next`` ASGI app to call.
            config: An instance of :class:`PrometheusConfig <.contrib.prometheus.PrometheusConfig>`
        """
        super().__init__(app=app, scopes=config.scopes, exclude=config.exclude, exclude_opt_key=config.exclude_opt_key)
        self._config = config
        self._kwargs: dict[str, Any] = {}

        if self._config.buckets is not None:
            self._kwargs["buckets"] = self._config.buckets

    def request_count(self, labels: dict[str, str | int | float]) -> Counter:
        metric_name = f"{self._config.prefix}_requests_total"

        if metric_name not in PrometheusMiddleware._metrics:
            PrometheusMiddleware._metrics[metric_name] = Counter(
                name=metric_name,
                documentation="Total requests",
                labelnames=[*labels.keys()],
            )

        return cast("Counter", PrometheusMiddleware._metrics[metric_name])

    def request_time(self, labels: dict[str, str | int | float]) -> Histogram:
        metric_name = f"{self._config.prefix}_request_duration_seconds"

        if metric_name not in PrometheusMiddleware._metrics:
            PrometheusMiddleware._metrics[metric_name] = Histogram(
                name=metric_name,
                documentation="Request duration, in seconds",
                labelnames=[*labels.keys()],
                **self._kwargs,
            )
        return cast("Histogram", PrometheusMiddleware._metrics[metric_name])

    def requests_in_progress(self, labels: dict[str, str | int | float]) -> Gauge:
        metric_name = f"{self._config.prefix}_requests_in_progress"

        if metric_name not in PrometheusMiddleware._metrics:
            PrometheusMiddleware._metrics[metric_name] = Gauge(
                name=metric_name,
                documentation="Total requests currently in progress",
                labelnames=[*labels.keys()],
                multiprocess_mode="livesum",
            )
        return cast("Gauge", PrometheusMiddleware._metrics[metric_name])

    def requests_error_count(self, labels: dict[str, str | int | float]) -> Counter:
        metric_name = f"{self._config.prefix}_requests_error_total"

        if metric_name not in PrometheusMiddleware._metrics:
            PrometheusMiddleware._metrics[metric_name] = Counter(
                name=metric_name,
                documentation="Total errors in requests",
                labelnames=[*labels.keys()],
            )
        return cast("Counter", PrometheusMiddleware._metrics[metric_name])

    def _get_extra_labels(self, request: Request[Any, Any, Any]) -> dict[str, str]:
        """Get extra labels provided by the config and if they are callable, parse them.

        Args:
        request: The request object.

        Returns:
        A dictionary of extra labels.
        """

        return {k: str(v(request) if callable(v) else v) for k, v in (self._config.labels or {}).items()}

    def _get_default_labels(self, request: Request[Any, Any, Any]) -> dict[str, str | int | float]:
        """Get default label values from the request.

        Args:
            request: The request object.

        Returns:
            A dictionary of default labels.
        """

        return {
            "method": request.method if request.scope["type"] == ScopeType.HTTP else request.scope["type"],
            "path": request.url.path,
            "status_code": 200,
            "app_name": self._config.app_name,
        }

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """ASGI callable.

        Args:
            scope: The ASGI connection scope.
            receive: The ASGI receive function.
            send: The ASGI send function.

        Returns:
            None
        """

        request = Request[Any, Any, Any](scope, receive)

        if self._config.excluded_http_methods and request.method in self._config.excluded_http_methods:
            await self.app(scope, receive, send)
            return

        labels = {**self._get_default_labels(request), **self._get_extra_labels(request)}

        request_span = {"start_time": time.perf_counter(), "end_time": 0, "duration": 0, "status_code": 200}

        wrapped_send = self._get_wrapped_send(send, request_span)

        self.requests_in_progress(labels).labels(*labels.values()).inc()

        try:
            await self.app(scope, receive, wrapped_send)
        finally:
            extra: dict[str, Any] = {}
            if self._config.exemplars:
                extra["exemplar"] = self._config.exemplars(request)

            self.requests_in_progress(labels).labels(*labels.values()).dec()

            labels["status_code"] = request_span["status_code"]
            label_values = [*labels.values()]

            if request_span["status_code"] >= HTTP_500_INTERNAL_SERVER_ERROR:
                self.requests_error_count(labels).labels(*label_values).inc(**extra)

            self.request_count(labels).labels(*label_values).inc(**extra)
            self.request_time(labels).labels(*label_values).observe(request_span["duration"], **extra)

    def _get_wrapped_send(self, send: Send, request_span: dict[str, float]) -> Callable:
        @wraps(send)
        async def wrapped_send(message: Message) -> None:
            if message["type"] == "http.response.start":
                request_span["status_code"] = message["status"]

            if message["type"] == "http.response.body":
                end = time.perf_counter()
                request_span["duration"] = end - request_span["start_time"]
                request_span["end_time"] = end
            await send(message)

        return wrapped_send