summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/litestar/_kwargs/dependencies.py
blob: 88ffb07b1e95289de4e21628f03a1a313f66b255 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from __future__ import annotations

from typing import TYPE_CHECKING, Any

from litestar.utils.compat import async_next

# Names exported by ``from litestar._kwargs.dependencies import *``.
__all__ = (
    "Dependency",
    "create_dependency_batches",
    "map_dependencies_recursively",
    "resolve_dependency",
)


if TYPE_CHECKING:
    from litestar._kwargs.cleanup import DependencyCleanupGroup
    from litestar.connection import ASGIConnection
    from litestar.di import Provide


class Dependency:
    """A node in the dependency graph of a given combination of ``Route`` + ``RouteHandler``.

    Nodes are identified by ``key``: two instances with the same key compare equal and hash
    alike, so they collapse into a single entry when stored in sets or used as dict keys.
    """

    __slots__ = ("key", "provide", "dependencies")

    def __init__(self, key: str, provide: Provide, dependencies: list[Dependency]) -> None:
        """Initialize a dependency.

        Args:
            key: The dependency key
            provide: Provider
            dependencies: List of child nodes
        """
        self.key = key
        self.provide = provide
        self.dependencies = dependencies

    def __eq__(self, other: Any) -> bool:
        # Identity is the fast path; otherwise two nodes are equal when they share a key.
        return other is self or (isinstance(other, self.__class__) and other.key == self.key)

    def __hash__(self) -> int:
        # Must stay consistent with ``__eq__``, which only compares ``key``.
        return hash(self.key)

    def __repr__(self) -> str:
        # Show only the child keys (not full child reprs) so deep graphs stay readable.
        child_keys = [child.key for child in self.dependencies]
        return f"{type(self).__name__}(key={self.key!r}, dependencies={child_keys!r})"


async def resolve_dependency(
    dependency: Dependency,
    connection: ASGIConnection,
    kwargs: dict[str, Any],
    cleanup_group: DependencyCleanupGroup,
) -> None:
    """Run the provider of ``dependency`` and store its result in ``kwargs`` under ``dependency.key``.

    Every sub dependency this provider needs must already have been resolved into ``kwargs``.

    Args:
        dependency: The :class:`Dependency <litestar._kwargs.Dependency>` to resolve.
        connection: The current :class:`Request <litestar.connection.Request>` or
            :class:`WebSocket <litestar.connection.WebSocket>`.
        kwargs: Values available to the provider; the resolved value is written back into it.
        cleanup_group: Receives the generator produced by a generator-based provider so it can be
            finalized later.
    """
    provider = dependency.provide
    model = provider.signature_model

    # Only parse values out of ``kwargs`` when the provider's signature model declares fields.
    provider_kwargs: dict[str, Any] = {}
    if model._fields:
        provider_kwargs = model.parse_values_from_connection_kwargs(connection=connection, **kwargs)

    result = await provider(**provider_kwargs)

    # Register the generator with the cleanup group *before* advancing it, so it is tracked
    # even if advancing it raises.
    if provider.has_sync_generator_dependency:
        cleanup_group.add(result)
        result = next(result)
    elif provider.has_async_generator_dependency:
        cleanup_group.add(result)
        result = await async_next(result)

    kwargs[dependency.key] = result


def create_dependency_batches(expected_dependencies: set[Dependency]) -> list[set[Dependency]]:
    """Calculate batches for all dependencies, recursively.

    The first batch holds dependencies without sub dependencies. Every later batch holds
    dependencies whose sub dependencies all appear in earlier batches. Resolving the batches in
    order therefore satisfies every dependency's requirements, and the members of a single batch
    do not depend on one another.

    Args:
        expected_dependencies: A set of all direct :class:`Dependencies <litestar._kwargs.Dependency>`.

    Returns:
        A list of batches (empty when ``expected_dependencies`` is empty).

    Raises:
        ValueError: If the dependency graph contains a cycle (including a dependency that lists
            itself), since no valid resolution order exists.
    """
    dependencies_to: dict[Dependency, set[Dependency]] = {}
    for dependency in expected_dependencies:
        if dependency not in dependencies_to:
            map_dependencies_recursively(dependency, dependencies_to)

    batches: list[set[Dependency]] = []
    while dependencies_to:
        current_batch = {
            dependency
            for dependency, remaining_sub_dependencies in dependencies_to.items()
            if not remaining_sub_dependencies
        }

        if not current_batch:
            # Each remaining dependency still waits on another remaining one, so the graph has a
            # cycle. Without this guard the loop would never terminate.
            unresolved = ", ".join(sorted(str(dependency.key) for dependency in dependencies_to))
            raise ValueError(f"Circular dependency detected among: {unresolved}")

        for dependency in current_batch:
            del dependencies_to[dependency]
        # One pass per batch: drop the now-resolved dependencies from everyone still waiting.
        for remaining_sub_dependencies in dependencies_to.values():
            remaining_sub_dependencies.difference_update(current_batch)

        batches.append(current_batch)

    return batches


def map_dependencies_recursively(dependency: Dependency, dependencies_to: dict[Dependency, set[Dependency]]) -> None:
    """Recursively map dependencies to their sub dependencies.

    Each visited dependency is recorded in ``dependencies_to`` with a fresh set of its direct
    sub dependencies. The entry is written before the function recurses, so a dependency that is
    already mapped is not visited again. That also stops the recursion on cyclic graphs.

    Args:
        dependency: The current dependency to map.
        dependencies_to: A map of dependency to its sub dependencies, updated in place.
    """
    dependencies_to[dependency] = set(dependency.dependencies)
    for sub in dependency.dependencies:
        if sub not in dependencies_to:
            map_dependencies_recursively(sub, dependencies_to)