summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/polyfactory/utils/model_coverage.py
blob: 6fc39714806a0bfab5237a41bb6c1cbcb3362b4c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Iterator, Mapping, MutableSequence
from typing import AbstractSet, Any, Generic, Set, TypeVar, cast

from typing_extensions import ParamSpec

from polyfactory.exceptions import ParameterException


class CoverageContainerBase(ABC):
    """Abstract interface for objects that hand out coverage values for one field.

    Each call to ``next_value()`` returns a single field value. ``is_done()``
    reports whether the container has already provided every coverage example
    it holds.
    """

    @abstractmethod
    def next_value(self) -> Any:
        """Return the next value for the field."""

    @abstractmethod
    def is_done(self) -> bool:
        """Return whether every coverage example of this container has been provided."""


# Value type produced by the generic coverage containers (``next_value() -> T``).
T = TypeVar("T")


class CoverageContainer(CoverageContainerBase, Generic[T]):
    """Coverage container backed by a fixed collection of values.

    Successive ``next_value()`` calls walk the collection in order. Once the
    end is reached they wrap around to the start, so calling it more times than
    there are items yields duplicate examples.

    An item that is itself a coverage container is drained in place. Its values
    are returned one at a time, and this container only steps past it after the
    nested container reports ``is_done()``. The nested values are therefore
    effectively merged into this container's sequence.
    """

    def __init__(self, instances: Iterable[T]) -> None:
        self._instances = list(instances)
        self._pos = 0
        if len(self._instances) == 0:
            msg = "CoverageContainer must have at least one instance"
            raise ValueError(msg)

    def next_value(self) -> T:
        """Return the value at the current position, wrapping around past the end."""
        current = self._instances[self._pos % len(self._instances)]

        if not isinstance(current, CoverageContainerBase):
            self._pos += 1
            return current

        nested_value = current.next_value()
        # Stay on the nested container until it has handed out all of its values.
        if current.is_done():
            self._pos += 1
        return cast(T, nested_value)

    def is_done(self) -> bool:
        """Return True once the position has reached the end of the collection at least once."""
        return len(self._instances) <= self._pos

    def __repr__(self) -> str:
        done = self.is_done()
        return f"CoverageContainer(instances={self._instances}, is_done={done})"


# Parameter specification of the callable wrapped by CoverageContainerCallable,
# so the captured *args/**kwargs are type-checked against that callable.
P = ParamSpec("P")


class CoverageContainerCallable(CoverageContainerBase, Generic[T]):
    """A coverage container that wraps a callable.

    Each ``next_value()`` call invokes the wrapped callable with the positional
    and keyword arguments captured at construction time. ``is_done()`` always
    returns True, so a ``CoverageContainer`` holding this object advances past
    it after taking a single value.
    """

    def __init__(self, func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> None:
        self._func = func
        self._args = args
        self._kwargs = kwargs

    def next_value(self) -> T:
        """Call the wrapped function with the stored arguments and return its result.

        Raises:
            ParameterException: If the wrapped callable raises any exception.
                The original exception is chained as ``__cause__``.
        """
        try:
            return self._func(*self._args, **self._kwargs)
        except Exception as e:  # noqa: BLE001
            msg = f"Unsupported type: {self._func!r}\n\nEither extend the providers map or add a factory function for this type."
            raise ParameterException(msg) from e

    def is_done(self) -> bool:
        """Always True: this container has no finite set of examples to exhaust."""
        return True

    def __repr__(self) -> str:
        # Mirrors CoverageContainer.__repr__. That method formats its instances
        # list, which uses repr() of each element, so without this a wrapped
        # callable shows up as an opaque ``<... object at 0x...>``.
        return f"CoverageContainerCallable(func={self._func!r}, args={self._args!r}, kwargs={self._kwargs!r})"


def _resolve_next(unresolved: Any) -> tuple[Any, bool]:  # noqa: C901
    """Resolve one coverage step of ``unresolved``.

    Walks ``unresolved`` recursively and replaces every coverage container with
    the value returned by its ``next_value()``:

    * mappings are rebuilt as a plain ``dict`` (each value is resolved before
      its key);
    * tuples are rebuilt as ``tuple`` and other mutable sequences as ``list``;
    * ``set`` instances are rebuilt with their own type via ``.add``;
    * other abstract sets are rebuilt with ``type(x)().union(...)``;
    * ``str``, ``bytes`` and ``bytearray`` are treated as scalar values;
    * anything else is returned unchanged.

    Args:
        unresolved: A plain value, a coverage container, or an arbitrarily
            nested collection of them.

    Returns:
        A ``(resolved, done)`` pair. ``done`` is True only if every coverage
        container visited reported ``is_done()`` after producing its value.
    """
    if isinstance(unresolved, CoverageContainerBase):
        result, done = _resolve_next(unresolved.next_value())
        # Completion requires both the container and whatever it just
        # produced (which may itself contain containers) to be exhausted.
        return result, unresolved.is_done() and done

    # bytearray is registered as a collections.abc.MutableSequence. Without
    # this guard, a bytearray value would be exploded into a list of ints by
    # the sequence branch below. str and bytes are listed for clarity; plain
    # instances of them match none of the collection branches anyway.
    if isinstance(unresolved, (str, bytes, bytearray)):
        return unresolved, True

    if isinstance(unresolved, Mapping):
        result = {}
        done_status = True
        for key, value in unresolved.items():
            val_resolved, val_done = _resolve_next(value)
            key_resolved, key_done = _resolve_next(key)
            result[key_resolved] = val_resolved
            done_status = done_status and val_done and key_done
        return result, done_status

    if isinstance(unresolved, (tuple, MutableSequence)):
        result = []
        done_status = True
        for value in unresolved:
            resolved, done = _resolve_next(value)
            result.append(resolved)
            done_status = done_status and done
        if isinstance(unresolved, tuple):
            result = tuple(result)
        return result, done_status

    # ``Set`` is typing.Set, i.e. the builtin ``set``. Rebuilding with .add on
    # ``type(unresolved)()`` keeps set subclasses intact; set.union would
    # return a plain set.
    if isinstance(unresolved, Set):
        result = type(unresolved)()
        done_status = True
        for value in unresolved:
            resolved, done = _resolve_next(value)
            result.add(resolved)
            done_status = done_status and done
        return result, done_status

    # Remaining abstract sets (e.g. frozenset) are immutable, so they are
    # rebuilt through union rather than .add.
    if issubclass(type(unresolved), AbstractSet):
        result = type(unresolved)()
        done_status = True
        resolved_values = []
        for value in unresolved:
            resolved, done = _resolve_next(value)
            resolved_values.append(resolved)
            done_status = done_status and done
        return result.union(resolved_values), done_status

    return unresolved, True


def resolve_kwargs_coverage(kwargs: dict[str, Any]) -> Iterator[dict[str, Any]]:
    """Yield one resolved kwargs dict per coverage step.

    Each step resolves every coverage container inside ``kwargs`` once. The
    generator stops right after yielding the dict from the step on which every
    container reported that it is done. At least one dict is always yielded.
    """
    while True:
        step_kwargs, exhausted = _resolve_next(kwargs)
        yield step_kwargs
        if exhausted:
            return