diff options
author | cyfraeviolae <cyfraeviolae> | 2024-04-03 03:10:44 -0400 |
---|---|---|
committer | cyfraeviolae <cyfraeviolae> | 2024-04-03 03:10:44 -0400 |
commit | 6d7ba58f880be618ade07f8ea080fe8c4bf8a896 (patch) | |
tree | b1c931051ffcebd2bd9d61d98d6233ffa289bbce /venv/lib/python3.11/site-packages/polyfactory/collection_extender.py | |
parent | 4f884c9abc32990b4061a1bb6997b4b37e58ea0b (diff) |
venv
Diffstat (limited to 'venv/lib/python3.11/site-packages/polyfactory/collection_extender.py')
-rw-r--r-- | venv/lib/python3.11/site-packages/polyfactory/collection_extender.py | 89 |
1 files changed, 89 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/polyfactory/collection_extender.py b/venv/lib/python3.11/site-packages/polyfactory/collection_extender.py new file mode 100644 index 0000000..6377125 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/collection_extender.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +import random +from abc import ABC, abstractmethod +from collections import deque +from typing import Any + +from polyfactory.utils.predicates import is_safe_subclass + + +class CollectionExtender(ABC): + __types__: tuple[type, ...] + + @staticmethod + @abstractmethod + def _extend_type_args(type_args: tuple[Any, ...], number_of_args: int) -> tuple[Any, ...]: + raise NotImplementedError + + @classmethod + def _subclass_for_type(cls, annotation_alias: Any) -> type[CollectionExtender]: + return next( + ( + subclass + for subclass in cls.__subclasses__() + if any(is_safe_subclass(annotation_alias, t) for t in subclass.__types__) + ), + FallbackExtender, + ) + + @classmethod + def extend_type_args( + cls, + annotation_alias: Any, + type_args: tuple[Any, ...], + number_of_args: int, + ) -> tuple[Any, ...]: + return cls._subclass_for_type(annotation_alias)._extend_type_args(type_args, number_of_args) + + +class TupleExtender(CollectionExtender): + __types__ = (tuple,) + + @staticmethod + def _extend_type_args(type_args: tuple[Any, ...], number_of_args: int) -> tuple[Any, ...]: + if not type_args: + return type_args + if type_args[-1] is not ...: + return type_args + type_to_extend = type_args[-2] + return type_args[:-2] + (type_to_extend,) * number_of_args + + +class ListLikeExtender(CollectionExtender): + __types__ = (list, deque) + + @staticmethod + def _extend_type_args(type_args: tuple[Any, ...], number_of_args: int) -> tuple[Any, ...]: + if not type_args: + return type_args + return tuple(random.choice(type_args) for _ in range(number_of_args)) + + +class SetExtender(CollectionExtender): + __types__ = (set, frozenset) + + @staticmethod + def _extend_type_args(type_args: tuple[Any, ...], number_of_args: int) -> tuple[Any, ...]: + if not type_args: + return type_args + return tuple(random.choice(type_args) for _ in range(number_of_args)) + + +class DictExtender(CollectionExtender): + __types__ = (dict,) + + @staticmethod + def _extend_type_args(type_args: tuple[Any, ...], number_of_args: int) -> tuple[Any, ...]: + return type_args * number_of_args + + +class FallbackExtender(CollectionExtender): + __types__ = () + + @staticmethod + def _extend_type_args( + type_args: tuple[Any, ...], + number_of_args: int, # noqa: ARG004 + ) -> tuple[Any, ...]: # - investigate @guacs + return type_args |