summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/polyfactory/collection_extender.py
diff options
context:
space:
mode:
authorcyfraeviolae <cyfraeviolae>2024-04-03 03:10:44 -0400
committercyfraeviolae <cyfraeviolae>2024-04-03 03:10:44 -0400
commit6d7ba58f880be618ade07f8ea080fe8c4bf8a896 (patch)
treeb1c931051ffcebd2bd9d61d98d6233ffa289bbce /venv/lib/python3.11/site-packages/polyfactory/collection_extender.py
parent4f884c9abc32990b4061a1bb6997b4b37e58ea0b (diff)
venv
Diffstat (limited to 'venv/lib/python3.11/site-packages/polyfactory/collection_extender.py')
-rw-r--r--venv/lib/python3.11/site-packages/polyfactory/collection_extender.py89
1 files changed, 89 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/polyfactory/collection_extender.py b/venv/lib/python3.11/site-packages/polyfactory/collection_extender.py
new file mode 100644
index 0000000..6377125
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/polyfactory/collection_extender.py
@@ -0,0 +1,89 @@
+from __future__ import annotations
+
+import random
+from abc import ABC, abstractmethod
+from collections import deque
+from typing import Any
+
+from polyfactory.utils.predicates import is_safe_subclass
+
+
+class CollectionExtender(ABC):
+ __types__: tuple[type, ...]
+
+ @staticmethod
+ @abstractmethod
+ def _extend_type_args(type_args: tuple[Any, ...], number_of_args: int) -> tuple[Any, ...]:
+ raise NotImplementedError
+
+ @classmethod
+ def _subclass_for_type(cls, annotation_alias: Any) -> type[CollectionExtender]:
+ return next(
+ (
+ subclass
+ for subclass in cls.__subclasses__()
+ if any(is_safe_subclass(annotation_alias, t) for t in subclass.__types__)
+ ),
+ FallbackExtender,
+ )
+
+ @classmethod
+ def extend_type_args(
+ cls,
+ annotation_alias: Any,
+ type_args: tuple[Any, ...],
+ number_of_args: int,
+ ) -> tuple[Any, ...]:
+ return cls._subclass_for_type(annotation_alias)._extend_type_args(type_args, number_of_args)
+
+
+class TupleExtender(CollectionExtender):
+ __types__ = (tuple,)
+
+ @staticmethod
+ def _extend_type_args(type_args: tuple[Any, ...], number_of_args: int) -> tuple[Any, ...]:
+ if not type_args:
+ return type_args
+ if type_args[-1] is not ...:
+ return type_args
+ type_to_extend = type_args[-2]
+ return type_args[:-2] + (type_to_extend,) * number_of_args
+
+
+class ListLikeExtender(CollectionExtender):
+ __types__ = (list, deque)
+
+ @staticmethod
+ def _extend_type_args(type_args: tuple[Any, ...], number_of_args: int) -> tuple[Any, ...]:
+ if not type_args:
+ return type_args
+ return tuple(random.choice(type_args) for _ in range(number_of_args))
+
+
+class SetExtender(CollectionExtender):
+ __types__ = (set, frozenset)
+
+ @staticmethod
+ def _extend_type_args(type_args: tuple[Any, ...], number_of_args: int) -> tuple[Any, ...]:
+ if not type_args:
+ return type_args
+ return tuple(random.choice(type_args) for _ in range(number_of_args))
+
+
+class DictExtender(CollectionExtender):
+ __types__ = (dict,)
+
+ @staticmethod
+ def _extend_type_args(type_args: tuple[Any, ...], number_of_args: int) -> tuple[Any, ...]:
+ return type_args * number_of_args
+
+
+class FallbackExtender(CollectionExtender):
+ __types__ = ()
+
+ @staticmethod
+ def _extend_type_args(
+ type_args: tuple[Any, ...],
+ number_of_args: int, # noqa: ARG004
+ ) -> tuple[Any, ...]: # - investigate @guacs
+ return type_args