1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
|
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Iterator, Mapping, MutableSequence
from typing import AbstractSet, Any, Generic, Set, TypeVar, cast
from typing_extensions import ParamSpec
from polyfactory.exceptions import ParameterException
class CoverageContainerBase(ABC):
    """Abstract interface for objects that supply coverage values for one field.

    Concrete containers hand out field values through ``next_value()`` and report via
    ``is_done()`` whether all of their coverage examples have been generated.
    """

    @abstractmethod
    def next_value(self) -> Any:
        """Return the next coverage value."""

    @abstractmethod
    def is_done(self) -> bool:
        """Report whether every coverage example held by this container has been provided."""
# Type variable for the value type produced by the generic containers below.
T = TypeVar("T")


class CoverageContainer(CoverageContainerBase, Generic[T]):
    """Coverage container backed by a fixed, non-empty collection of values.

    Each ``next_value()`` call returns the element at the current position and then
    advances. The position wraps around, so calling ``next_value()`` more times than
    there are elements yields duplicate examples.

    An element that is itself a coverage container is drained in place: the position
    only moves past it once it reports ``is_done()``. This effectively merges the
    nested container's values into this one.
    """

    def __init__(self, instances: Iterable[T]) -> None:
        """Materialize ``instances`` into a list and start at its first element.

        Raises:
            ValueError: If ``instances`` yields no elements.
        """
        values = list(instances)
        if not values:
            msg = "CoverageContainer must have at least one instance"
            raise ValueError(msg)
        self._instances = values
        self._pos = 0

    def next_value(self) -> T:
        """Return the next value, delegating to a nested container when one is current."""
        current = self._instances[self._pos % len(self._instances)]
        if not isinstance(current, CoverageContainerBase):
            self._pos += 1
            return current

        nested_value = current.next_value()
        # Stay on a nested container until it has handed out all of its examples.
        if current.is_done():
            self._pos += 1
        return cast(T, nested_value)

    def is_done(self) -> bool:
        """Return ``True`` once the position has advanced past every element at least once."""
        return len(self._instances) <= self._pos

    def __repr__(self) -> str:
        done = self.is_done()
        return f"CoverageContainer(instances={self._instances}, is_done={done})"
# Parameter specification capturing the wrapped callable's argument signature.
P = ParamSpec("P")


class CoverageContainerCallable(CoverageContainerBase, Generic[T]):
    """A coverage container that wraps a callable.

    Every ``next_value()`` call invokes the wrapped callable with the positional and
    keyword arguments captured at construction time. ``is_done()`` always returns
    ``True`` because there is no finite set of examples to exhaust.
    """

    def __init__(self, func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> None:
        self._func = func
        self._args = args
        self._kwargs = kwargs

    def next_value(self) -> T:
        """Call the wrapped callable with the stored arguments and return its result.

        Raises:
            ParameterException: If the callable raises any ``Exception``. The original
                error is chained as ``__cause__``.
        """
        try:
            return self._func(*self._args, **self._kwargs)
        except Exception as e:  # noqa: BLE001
            msg = f"Unsupported type: {self._func!r}\n\nEither extend the providers map or add a factory function for this type."
            raise ParameterException(msg) from e

    def is_done(self) -> bool:
        """Always ``True``: a single call provides this container's only example."""
        return True

    def __repr__(self) -> str:
        # CoverageContainer's repr embeds the reprs of its instances, so give nested
        # callable containers a readable form instead of the default object address.
        return f"CoverageContainerCallable(func={self._func!r}, args={self._args!r}, kwargs={self._kwargs!r})"
def _resolve_all(items: Iterable[Any]) -> tuple[list[Any], bool]:
    """Resolve every element of ``items`` in iteration order.

    Returns the resolved elements as a list, paired with a flag that is truthy only if
    every element reported done. All elements are resolved even after one reports
    not-done, so every nested container advances on each pass.
    """
    resolved_items: list[Any] = []
    all_done = True
    for item in items:
        item_value, item_done = _resolve_next(item)
        resolved_items.append(item_value)
        all_done = all_done and item_done
    return resolved_items, all_done


def _resolve_next(unresolved: Any) -> tuple[Any, bool]:  # noqa: C901
    """Perform one resolution pass over ``unresolved``.

    Coverage containers are asked for their next value, which is then resolved
    recursively. Mappings, tuples, mutable sequences and sets are rebuilt with each
    element resolved. Any other value is returned unchanged.

    Returns:
        A ``(resolved, done)`` pair. ``done`` is truthy only when every container
        reached during this pass reports that it is done.
    """
    if isinstance(unresolved, CoverageContainerBase):
        inner, inner_done = _resolve_next(unresolved.next_value())
        # Query is_done() only after next_value() so it reflects the advanced state.
        return inner, unresolved.is_done() and inner_done

    if isinstance(unresolved, Mapping):
        resolved_map = {}
        all_done = True
        for raw_key, raw_value in unresolved.items():
            # The value is resolved before its key. Keep this order: resolution
            # advances containers, so swapping it could change the produced values.
            new_value, value_done = _resolve_next(raw_value)
            new_key, key_done = _resolve_next(raw_key)
            resolved_map[new_key] = new_value
            all_done = all_done and value_done and key_done
        return resolved_map, all_done

    if isinstance(unresolved, (tuple, MutableSequence)):
        items, all_done = _resolve_all(unresolved)
        # Tuples (including subclasses) become plain tuples; mutable sequences become lists.
        rebuilt_seq = tuple(items) if isinstance(unresolved, tuple) else items
        return rebuilt_seq, all_done

    if isinstance(unresolved, Set):
        rebuilt_set = type(unresolved)()
        items, all_done = _resolve_all(unresolved)
        for item in items:
            rebuilt_set.add(item)
        return rebuilt_set, all_done

    if issubclass(type(unresolved), AbstractSet):
        # Immutable abstract sets (e.g. frozenset) have no add(); build via union().
        empty = type(unresolved)()
        items, all_done = _resolve_all(unresolved)
        return empty.union(items), all_done

    return unresolved, True
def resolve_kwargs_coverage(kwargs: dict[str, Any]) -> Iterator[dict[str, Any]]:
    """Yield resolved versions of ``kwargs`` until its coverage containers are exhausted.

    Each iteration runs one resolution pass over ``kwargs`` and yields the result.
    Iteration stops right after the first pass that reports done, so at least one value
    is always yielded.
    """
    while True:
        resolved, finished = _resolve_next(kwargs)
        yield resolved
        if finished:
            break
|