summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/polyfactory/collection_extender.py
blob: 6377125651cefec70588e6b8303ab6714de1e544 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from __future__ import annotations

import random
from abc import ABC, abstractmethod
from collections import deque
from typing import Any

from polyfactory.utils.predicates import is_safe_subclass


class CollectionExtender(ABC):
    """Base class for strategies that expand a collection's type arguments.

    Each concrete subclass lists the collection types it is responsible for in
    ``__types__`` and implements ``_extend_type_args``. ``extend_type_args``
    dispatches to the first direct subclass whose ``__types__`` matches the
    annotation, and to ``FallbackExtender`` when nothing matches.
    """

    # Collection types handled by a concrete extender; assigned per subclass.
    __types__: tuple[type, ...]

    @staticmethod
    @abstractmethod
    def _extend_type_args(type_args: tuple[Any, ...], number_of_args: int) -> tuple[Any, ...]:
        """Produce the type arguments to use for ``number_of_args`` generated items.

        Args:
            type_args: Type arguments taken from the collection annotation.
            number_of_args: Requested number of element type entries.

        Returns:
            The expanded tuple of type arguments.
        """
        raise NotImplementedError

    @classmethod
    def _subclass_for_type(cls, annotation_alias: Any) -> type[CollectionExtender]:
        """Find the extender responsible for ``annotation_alias``.

        Only the direct subclasses of ``cls`` are searched, in the order that
        ``cls.__subclasses__()`` reports them. The first one with any entry in
        ``__types__`` for which ``is_safe_subclass`` is true wins.

        Returns:
            The matching subclass, or ``FallbackExtender`` if none matches.
        """
        for candidate in cls.__subclasses__():
            handled_types = candidate.__types__
            if any(is_safe_subclass(annotation_alias, kind) for kind in handled_types):
                return candidate
        return FallbackExtender

    @classmethod
    def extend_type_args(
        cls,
        annotation_alias: Any,
        type_args: tuple[Any, ...],
        number_of_args: int,
    ) -> tuple[Any, ...]:
        """Expand ``type_args`` using the extender that matches ``annotation_alias``.

        Args:
            annotation_alias: The collection annotation used to pick an extender
                (matched via ``is_safe_subclass``).
            type_args: Type arguments taken from that annotation.
            number_of_args: Requested number of element type entries.

        Returns:
            Whatever the selected extender's ``_extend_type_args`` returns.
        """
        extender = cls._subclass_for_type(annotation_alias)
        return extender._extend_type_args(type_args, number_of_args)


class TupleExtender(CollectionExtender):
    """Extender for ``tuple`` annotations.

    A fixed-length tuple such as ``tuple[int, str]`` already names one type per
    position, so its args are returned unchanged. A variable-length tuple such
    as ``tuple[int, ...]`` ends with an ``Ellipsis`` marker. The element type
    just before the marker is repeated ``number_of_args`` times.
    """

    __types__ = (tuple,)

    @staticmethod
    def _extend_type_args(type_args: tuple[Any, ...], number_of_args: int) -> tuple[Any, ...]:
        """Expand a homogeneous ``tuple[T, ...]`` into ``number_of_args`` copies of ``T``.

        Args:
            type_args: Type arguments of the tuple annotation.
            number_of_args: How many element types to emit for the variadic form.

        Returns:
            ``type_args`` unchanged when it is empty or does not end in ``...``.
            Otherwise, everything before the element type followed by
            ``number_of_args`` copies of the element type. A lone ``(...,)``
            (such as from ``tuple[...]``) has no element type, so it yields ``()``.
        """
        if not type_args or type_args[-1] is not ...:
            return type_args
        if len(type_args) < 2:
            # The marker is present but there is no element type before it.
            # ``type_args[-2]`` would raise IndexError here, so emit no types.
            return ()
        element_type = type_args[-2]
        return type_args[:-2] + (element_type,) * number_of_args


class ListLikeExtender(CollectionExtender):
    """Extender for ``list`` and ``collections.deque`` annotations."""

    __types__ = (list, deque)

    @staticmethod
    def _extend_type_args(type_args: tuple[Any, ...], number_of_args: int) -> tuple[Any, ...]:
        """Draw ``number_of_args`` element types at random from ``type_args``.

        Each draw is an independent ``random.choice`` on the ``random`` module's
        shared generator, so the same type may be picked repeatedly. An empty
        ``type_args`` is returned as-is without drawing anything.
        """
        if type_args:
            picks = [random.choice(type_args) for _ in range(number_of_args)]
            return tuple(picks)
        return type_args


class SetExtender(CollectionExtender):
    """Extender for ``set`` and ``frozenset`` annotations."""

    __types__ = (set, frozenset)

    @staticmethod
    def _extend_type_args(type_args: tuple[Any, ...], number_of_args: int) -> tuple[Any, ...]:
        """Sample ``number_of_args`` member types (with replacement) from ``type_args``.

        Uses one ``random.choice`` call per requested entry on the ``random``
        module's shared generator. An empty ``type_args`` is passed through
        untouched.
        """
        if type_args:
            sampled = [random.choice(type_args) for _ in range(number_of_args)]
            return tuple(sampled)
        return type_args


class DictExtender(CollectionExtender):
    """Extender for ``dict`` annotations."""

    __types__ = (dict,)

    @staticmethod
    def _extend_type_args(type_args: tuple[Any, ...], number_of_args: int) -> tuple[Any, ...]:
        """Repeat the whole ``type_args`` group ``number_of_args`` times.

        For ``dict[K, V]`` the args are ``(K, V)``, so the result looks like
        ``(K, V, K, V, ...)``. Presumably the caller consumes it pairwise, one
        key/value pair per entry (verify against the caller). A non-positive
        ``number_of_args`` gives an empty tuple.
        """
        repeated = type_args * number_of_args
        return repeated


class FallbackExtender(CollectionExtender):
    """Pass-through extender for annotations no other extender claims.

    Its ``__types__`` is empty, so it never matches during dispatch. Within this
    module it is reached only as the default result of
    ``CollectionExtender._subclass_for_type``.
    """

    __types__ = ()

    @staticmethod
    def _extend_type_args(
        type_args: tuple[Any, ...],
        number_of_args: int,  # noqa: ARG004
    ) -> tuple[Any, ...]:
        """Return ``type_args`` untouched; ``number_of_args`` is ignored."""
        # NOTE(review): upstream tagged this method "investigate @guacs", so the
        # pass-through behaviour may be provisional.
        unchanged = type_args
        return unchanged