diff options
Diffstat (limited to 'venv/lib/python3.11/site-packages/polyfactory')
73 files changed, 4816 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/polyfactory/__init__.py b/venv/lib/python3.11/site-packages/polyfactory/__init__.py new file mode 100644 index 0000000..0a2269e --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/__init__.py @@ -0,0 +1,16 @@ +from .exceptions import ConfigurationException +from .factories import BaseFactory +from .fields import Fixture, Ignore, PostGenerated, Require, Use +from .persistence import AsyncPersistenceProtocol, SyncPersistenceProtocol + +__all__ = ( + "AsyncPersistenceProtocol", + "BaseFactory", + "ConfigurationException", + "Fixture", + "Ignore", + "PostGenerated", + "Require", + "SyncPersistenceProtocol", + "Use", +) diff --git a/venv/lib/python3.11/site-packages/polyfactory/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..415b3cc --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/__pycache__/collection_extender.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/__pycache__/collection_extender.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..5421b30 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/__pycache__/collection_extender.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/__pycache__/constants.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/__pycache__/constants.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..c170008 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/__pycache__/constants.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/__pycache__/decorators.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/__pycache__/decorators.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..e0c5378 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/__pycache__/decorators.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/__pycache__/exceptions.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/__pycache__/exceptions.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..2d049c8 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/__pycache__/exceptions.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/__pycache__/field_meta.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/__pycache__/field_meta.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..8575103 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/__pycache__/field_meta.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/__pycache__/fields.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/__pycache__/fields.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..0c7fc3a --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/__pycache__/fields.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/__pycache__/persistence.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/__pycache__/persistence.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..6d40b2e --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/__pycache__/persistence.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/__pycache__/pytest_plugin.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/__pycache__/pytest_plugin.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..63cd9ca --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/__pycache__/pytest_plugin.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/collection_extender.py b/venv/lib/python3.11/site-packages/polyfactory/collection_extender.py new file mode 100644 index 0000000..6377125 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/collection_extender.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +import random +from abc import ABC, abstractmethod +from collections import deque +from typing import Any + +from polyfactory.utils.predicates import is_safe_subclass + + +class CollectionExtender(ABC): + __types__: tuple[type, ...] + + @staticmethod + @abstractmethod + def _extend_type_args(type_args: tuple[Any, ...], number_of_args: int) -> tuple[Any, ...]: + raise NotImplementedError + + @classmethod + def _subclass_for_type(cls, annotation_alias: Any) -> type[CollectionExtender]: + return next( + ( + subclass + for subclass in cls.__subclasses__() + if any(is_safe_subclass(annotation_alias, t) for t in subclass.__types__) + ), + FallbackExtender, + ) + + @classmethod + def extend_type_args( + cls, + annotation_alias: Any, + type_args: tuple[Any, ...], + number_of_args: int, + ) -> tuple[Any, ...]: + return cls._subclass_for_type(annotation_alias)._extend_type_args(type_args, number_of_args) + + +class TupleExtender(CollectionExtender): + __types__ = (tuple,) + + @staticmethod + def _extend_type_args(type_args: tuple[Any, ...], number_of_args: int) -> tuple[Any, ...]: + if not type_args: + return type_args + if type_args[-1] is not ...: + return type_args + type_to_extend = type_args[-2] + return type_args[:-2] + (type_to_extend,) * number_of_args + + +class ListLikeExtender(CollectionExtender): + __types__ = (list, deque) + + @staticmethod + def _extend_type_args(type_args: tuple[Any, ...], number_of_args: int) -> tuple[Any, ...]: + if not type_args: + return type_args + return tuple(random.choice(type_args) for _ in range(number_of_args)) + + +class SetExtender(CollectionExtender): + __types__ = (set, frozenset) + + @staticmethod + def _extend_type_args(type_args: tuple[Any, ...], number_of_args: int) -> tuple[Any, ...]: + if not type_args: + return type_args + return tuple(random.choice(type_args) for _ in range(number_of_args)) + + +class DictExtender(CollectionExtender): + __types__ = (dict,) + + @staticmethod + def _extend_type_args(type_args: tuple[Any, ...], number_of_args: int) -> tuple[Any, ...]: + return type_args * number_of_args + + +class FallbackExtender(CollectionExtender): + __types__ = () + + @staticmethod + def _extend_type_args( + type_args: tuple[Any, ...], + number_of_args: int, # noqa: ARG004 + ) -> tuple[Any, ...]: # - investigate @guacs + return type_args diff --git a/venv/lib/python3.11/site-packages/polyfactory/constants.py b/venv/lib/python3.11/site-packages/polyfactory/constants.py new file mode 100644 index 0000000..c0e6d50 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/constants.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +import sys +from collections import abc, defaultdict, deque +from random import Random +from typing import ( + DefaultDict, + Deque, + Dict, + FrozenSet, + Iterable, + List, + Mapping, + Sequence, + Set, + Tuple, + Union, +) + +try: + from types import UnionType +except ImportError: + UnionType = Union # type: ignore[misc,assignment] + +PY_38 = sys.version_info.major == 3 and sys.version_info.minor == 8 # noqa: PLR2004 + +# Mapping of type annotations into concrete types. This is used to normalize python <= 3.9 annotations. +INSTANTIABLE_TYPE_MAPPING = { + DefaultDict: defaultdict, + Deque: deque, + Dict: dict, + FrozenSet: frozenset, + Iterable: list, + List: list, + Mapping: dict, + Sequence: list, + Set: set, + Tuple: tuple, + abc.Iterable: list, + abc.Mapping: dict, + abc.Sequence: list, + abc.Set: set, + UnionType: Union, +} + + +if not PY_38: + TYPE_MAPPING = INSTANTIABLE_TYPE_MAPPING +else: + # For 3.8, we have to keep the types from typing since dict[str] syntax is not supported in 3.8. + TYPE_MAPPING = { + DefaultDict: DefaultDict, + Deque: Deque, + Dict: Dict, + FrozenSet: FrozenSet, + Iterable: List, + List: List, + Mapping: Dict, + Sequence: List, + Set: Set, + Tuple: Tuple, + abc.Iterable: List, + abc.Mapping: Dict, + abc.Sequence: List, + abc.Set: Set, + } + + +DEFAULT_RANDOM = Random() +RANDOMIZE_COLLECTION_LENGTH = False +MIN_COLLECTION_LENGTH = 0 +MAX_COLLECTION_LENGTH = 5 diff --git a/venv/lib/python3.11/site-packages/polyfactory/decorators.py b/venv/lib/python3.11/site-packages/polyfactory/decorators.py new file mode 100644 index 0000000..88c1021 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/decorators.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +import contextlib +import inspect +from typing import Any, Callable + +from polyfactory import PostGenerated + + +class post_generated: # noqa: N801 + """Descriptor class for wrapping a classmethod into a ``PostGenerated`` field.""" + + __slots__ = ("method", "cache") + + def __init__(self, method: Callable | classmethod) -> None: + if not isinstance(method, classmethod): + msg = "post_generated decorator can only be used on classmethods" + raise TypeError(msg) + self.method = method + self.cache: dict[type, PostGenerated] = {} + + def __get__(self, obj: Any, objtype: type) -> PostGenerated: + with contextlib.suppress(KeyError): + return self.cache[objtype] + fn = self.method.__func__ # pyright: ignore[reportFunctionMemberAccess] + fn_args = inspect.getfullargspec(fn).args[1:] + + def new_fn(name: str, values: dict[str, Any]) -> Any: # noqa: ARG001 - investigate @guacs + return fn(objtype, **{arg: values[arg] for arg in fn_args if arg in values}) + + return self.cache.setdefault(objtype, PostGenerated(new_fn)) diff --git a/venv/lib/python3.11/site-packages/polyfactory/exceptions.py b/venv/lib/python3.11/site-packages/polyfactory/exceptions.py new file mode 100644 index 0000000..53f1271 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/exceptions.py @@ -0,0 +1,18 @@ +class FactoryException(Exception): + """Base Factory error class""" + + +class ConfigurationException(FactoryException): + """Configuration Error class - used for misconfiguration""" + + +class ParameterException(FactoryException): + """Parameter exception - used when wrong parameters are used""" + + +class MissingBuildKwargException(FactoryException): + """Missing Build Kwarg exception - used when a required build kwarg is not provided""" + + +class MissingDependencyException(FactoryException, ImportError): + """Missing dependency exception - used when a dependency is not installed""" diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__init__.py b/venv/lib/python3.11/site-packages/polyfactory/factories/__init__.py new file mode 100644 index 0000000..c8a9b92 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__init__.py @@ -0,0 +1,5 @@ +from polyfactory.factories.base import BaseFactory +from polyfactory.factories.dataclass_factory import DataclassFactory +from polyfactory.factories.typed_dict_factory import TypedDictFactory + +__all__ = ("BaseFactory", "TypedDictFactory", "DataclassFactory") diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..0ebad18 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/attrs_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/attrs_factory.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..b5ca945 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/attrs_factory.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/base.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..2ee4cdb --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/base.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/beanie_odm_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/beanie_odm_factory.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..0fc27ab --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/beanie_odm_factory.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/dataclass_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/dataclass_factory.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..6e59c83 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/dataclass_factory.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/msgspec_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/msgspec_factory.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..5e528d3 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/msgspec_factory.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/odmantic_odm_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/odmantic_odm_factory.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..6f45693 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/odmantic_odm_factory.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/pydantic_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/pydantic_factory.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..ca70ca8 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/pydantic_factory.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/sqlalchemy_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/sqlalchemy_factory.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..31d93d5 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/sqlalchemy_factory.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/typed_dict_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/typed_dict_factory.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..a1ec583 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/typed_dict_factory.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/attrs_factory.py b/venv/lib/python3.11/site-packages/polyfactory/factories/attrs_factory.py new file mode 100644 index 0000000..00ffa03 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/attrs_factory.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +from inspect import isclass +from typing import TYPE_CHECKING, Generic, TypeVar + +from polyfactory.exceptions import MissingDependencyException +from polyfactory.factories.base import BaseFactory +from polyfactory.field_meta import FieldMeta, Null + +if TYPE_CHECKING: + from typing import Any, TypeGuard + + +try: + import attrs + from attr._make import Factory + from attrs import AttrsInstance +except ImportError as ex: + msg = "attrs is not installed" + raise MissingDependencyException(msg) from ex + + +T = TypeVar("T", bound=AttrsInstance) + + +class AttrsFactory(Generic[T], BaseFactory[T]): + """Base factory for attrs classes.""" + + __model__: type[T] + + __is_base_factory__ = True + + @classmethod + def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]: + return isclass(value) and hasattr(value, "__attrs_attrs__") + + @classmethod + def get_model_fields(cls) -> list[FieldMeta]: + field_metas: list[FieldMeta] = [] + none_type = type(None) + + cls.resolve_types(cls.__model__) + fields = attrs.fields(cls.__model__) + + for field in fields: + if not field.init: + continue + + annotation = none_type if field.type is None else field.type + + default = field.default + if isinstance(default, Factory): + # The default value is not currently being used when generating + # the field values. When that is implemented, this would need + # to be handled differently since the `default.factory` could + # take a `self` argument. + default_value = default.factory + elif default is None: + default_value = Null + else: + default_value = default + + field_metas.append( + FieldMeta.from_type( + annotation=annotation, + name=field.alias, + default=default_value, + random=cls.__random__, + ), + ) + + return field_metas + + @classmethod + def resolve_types(cls, model: type[T], **kwargs: Any) -> None: + """Resolve any strings and forward annotations in type annotations. + + :param model: The model to resolve the type annotations for. + :param kwargs: Any parameters that need to be passed to `attrs.resolve_types`. + """ + + attrs.resolve_types(model, **kwargs) diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/base.py b/venv/lib/python3.11/site-packages/polyfactory/factories/base.py new file mode 100644 index 0000000..60fe7a7 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/base.py @@ -0,0 +1,1127 @@ +from __future__ import annotations + +import copy +from abc import ABC, abstractmethod +from collections import Counter, abc, deque +from contextlib import suppress +from datetime import date, datetime, time, timedelta +from decimal import Decimal +from enum import EnumMeta +from functools import partial +from importlib import import_module +from ipaddress import ( + IPv4Address, + IPv4Interface, + IPv4Network, + IPv6Address, + IPv6Interface, + IPv6Network, + ip_address, + ip_interface, + ip_network, +) +from os.path import realpath +from pathlib import Path +from random import Random +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Collection, + Generic, + Iterable, + Mapping, + Sequence, + Type, + TypedDict, + TypeVar, + cast, +) +from uuid import UUID + +from faker import Faker +from typing_extensions import get_args, get_origin, get_original_bases + +from polyfactory.constants import ( + DEFAULT_RANDOM, + MAX_COLLECTION_LENGTH, + MIN_COLLECTION_LENGTH, + RANDOMIZE_COLLECTION_LENGTH, +) +from polyfactory.exceptions import ConfigurationException, MissingBuildKwargException, ParameterException +from polyfactory.field_meta import Null +from polyfactory.fields import Fixture, Ignore, PostGenerated, Require, Use +from polyfactory.utils.helpers import ( + flatten_annotation, + get_collection_type, + unwrap_annotation, + unwrap_args, + unwrap_optional, +) +from polyfactory.utils.model_coverage import CoverageContainer, CoverageContainerCallable, resolve_kwargs_coverage +from polyfactory.utils.predicates import get_type_origin, is_any, is_literal, is_optional, is_safe_subclass, is_union +from polyfactory.utils.types import NoneType +from polyfactory.value_generators.complex_types import handle_collection_type, handle_collection_type_coverage +from polyfactory.value_generators.constrained_collections import ( + handle_constrained_collection, + handle_constrained_mapping, +) +from polyfactory.value_generators.constrained_dates import handle_constrained_date +from polyfactory.value_generators.constrained_numbers import ( + handle_constrained_decimal, + handle_constrained_float, + handle_constrained_int, +) +from polyfactory.value_generators.constrained_path import handle_constrained_path +from polyfactory.value_generators.constrained_strings import handle_constrained_string_or_bytes +from polyfactory.value_generators.constrained_url import handle_constrained_url +from polyfactory.value_generators.constrained_uuid import handle_constrained_uuid +from polyfactory.value_generators.primitives import create_random_boolean, create_random_bytes, create_random_string + +if TYPE_CHECKING: + from typing_extensions import TypeGuard + + from polyfactory.field_meta import Constraints, FieldMeta + from polyfactory.persistence import AsyncPersistenceProtocol, SyncPersistenceProtocol + + +T = TypeVar("T") +F = TypeVar("F", bound="BaseFactory[Any]") + + +class BuildContext(TypedDict): + seen_models: set[type] + + +def _get_build_context(build_context: BuildContext | None) -> BuildContext: + if build_context is None: + return {"seen_models": set()} + + return copy.deepcopy(build_context) + + +class BaseFactory(ABC, Generic[T]): + """Base Factory class - this class holds the main logic of the library""" + + # configuration attributes + __model__: type[T] + """ + The model for the factory. + This attribute is required for non-base factories and an exception will be raised if it's not set. Can be automatically inferred from the factory generic argument. + """ + __check_model__: bool = False + """ + Flag dictating whether to check if fields defined on the factory exists on the model or not. + If 'True', checks will be done against Use, PostGenerated, Ignore, Require constructs fields only. + """ + __allow_none_optionals__: ClassVar[bool] = True + """ + Flag dictating whether to allow 'None' for optional values. + If 'True', 'None' will be randomly generated as a value for optional model fields + """ + __sync_persistence__: type[SyncPersistenceProtocol[T]] | SyncPersistenceProtocol[T] | None = None + """A sync persistence handler. Can be a class or a class instance.""" + __async_persistence__: type[AsyncPersistenceProtocol[T]] | AsyncPersistenceProtocol[T] | None = None + """An async persistence handler. Can be a class or a class instance.""" + __set_as_default_factory_for_type__ = False + """ + Flag dictating whether to set as the default factory for the given type. + If 'True' the factory will be used instead of dynamically generating a factory for the type. + """ + __is_base_factory__: bool = False + """ + Flag dictating whether the factory is a 'base' factory. Base factories are registered globally as handlers for types. + For example, the 'DataclassFactory', 'TypedDictFactory' and 'ModelFactory' are all base factories. + """ + __base_factory_overrides__: dict[Any, type[BaseFactory[Any]]] | None = None + """ + A base factory to override with this factory. If this value is set, the given factory will replace the given base factory. + + Note: this value can only be set when '__is_base_factory__' is 'True'. + """ + __faker__: ClassVar["Faker"] = Faker() + """ + A faker instance to use. Can be a user provided value. + """ + __random__: ClassVar["Random"] = DEFAULT_RANDOM + """ + An instance of 'random.Random' to use. + """ + __random_seed__: ClassVar[int] + """ + An integer to seed the factory's Faker and Random instances with. + This attribute can be used to control random generation. + """ + __randomize_collection_length__: ClassVar[bool] = RANDOMIZE_COLLECTION_LENGTH + """ + Flag dictating whether to randomize collections lengths. + """ + __min_collection_length__: ClassVar[int] = MIN_COLLECTION_LENGTH + """ + An integer value that defines minimum length of a collection. + """ + __max_collection_length__: ClassVar[int] = MAX_COLLECTION_LENGTH + """ + An integer value that defines maximum length of a collection. + """ + __use_defaults__: ClassVar[bool] = False + """ + Flag indicating whether to use the default value on a specific field, if provided. + """ + + __config_keys__: tuple[str, ...] = ( + "__check_model__", + "__allow_none_optionals__", + "__set_as_default_factory_for_type__", + "__faker__", + "__random__", + "__randomize_collection_length__", + "__min_collection_length__", + "__max_collection_length__", + "__use_defaults__", + ) + """Keys to be considered as config values to pass on to dynamically created factories.""" + + # cached attributes + _fields_metadata: list[FieldMeta] + # BaseFactory only attributes + _factory_type_mapping: ClassVar[dict[Any, type[BaseFactory[Any]]]] + _base_factories: ClassVar[list[type[BaseFactory[Any]]]] + + # Non-public attributes + _extra_providers: dict[Any, Callable[[], Any]] | None = None + + def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None: # noqa: C901 + super().__init_subclass__(*args, **kwargs) + + if not hasattr(BaseFactory, "_base_factories"): + BaseFactory._base_factories = [] + + if not hasattr(BaseFactory, "_factory_type_mapping"): + BaseFactory._factory_type_mapping = {} + + if cls.__min_collection_length__ > cls.__max_collection_length__: + msg = "Minimum collection length shouldn't be greater than maximum collection length" + raise ConfigurationException( + msg, + ) + + if "__is_base_factory__" not in cls.__dict__ or not cls.__is_base_factory__: + model: type[T] | None = getattr(cls, "__model__", None) or cls._infer_model_type() + if not model: + msg = f"required configuration attribute '__model__' is not set on {cls.__name__}" + raise ConfigurationException( + msg, + ) + cls.__model__ = model + if not cls.is_supported_type(model): + for factory in BaseFactory._base_factories: + if factory.is_supported_type(model): + msg = f"{cls.__name__} does not support {model.__name__}, but this type is supported by the {factory.__name__} base factory class. To resolve this error, subclass the factory from {factory.__name__} instead of {cls.__name__}" + raise ConfigurationException( + msg, + ) + msg = f"Model type {model.__name__} is not supported. To support it, register an appropriate base factory and subclass it for your factory." + raise ConfigurationException( + msg, + ) + if cls.__check_model__: + cls._check_declared_fields_exist_in_model() + else: + BaseFactory._base_factories.append(cls) + + random_seed = getattr(cls, "__random_seed__", None) + if random_seed is not None: + cls.seed_random(random_seed) + + if cls.__set_as_default_factory_for_type__ and hasattr(cls, "__model__"): + BaseFactory._factory_type_mapping[cls.__model__] = cls + + @classmethod + def _infer_model_type(cls: type[F]) -> type[T] | None: + """Return model type inferred from class declaration. + class Foo(ModelFactory[MyModel]): # <<< MyModel + ... + + If more than one base class and/or generic arguments specified return None. + + :returns: Inferred model type or None + """ + + factory_bases: Iterable[type[BaseFactory[T]]] = ( + b for b in get_original_bases(cls) if get_origin(b) and issubclass(get_origin(b), BaseFactory) + ) + generic_args: Sequence[type[T]] = [ + arg for factory_base in factory_bases for arg in get_args(factory_base) if not isinstance(arg, TypeVar) + ] + if len(generic_args) != 1: + return None + + return generic_args[0] + + @classmethod + def _get_sync_persistence(cls) -> SyncPersistenceProtocol[T]: + """Return a SyncPersistenceHandler if defined for the factory, otherwise raises a ConfigurationException. + + :raises: ConfigurationException + :returns: SyncPersistenceHandler + """ + if cls.__sync_persistence__: + return cls.__sync_persistence__() if callable(cls.__sync_persistence__) else cls.__sync_persistence__ + msg = "A '__sync_persistence__' handler must be defined in the factory to use this method" + raise ConfigurationException( + msg, + ) + + @classmethod + def _get_async_persistence(cls) -> AsyncPersistenceProtocol[T]: + """Return a AsyncPersistenceHandler if defined for the factory, otherwise raises a ConfigurationException. + + :raises: ConfigurationException + :returns: AsyncPersistenceHandler + """ + if cls.__async_persistence__: + return cls.__async_persistence__() if callable(cls.__async_persistence__) else cls.__async_persistence__ + msg = "An '__async_persistence__' handler must be defined in the factory to use this method" + raise ConfigurationException( + msg, + ) + + @classmethod + def _handle_factory_field( + cls, + field_value: Any, + build_context: BuildContext, + field_build_parameters: Any | None = None, + ) -> Any: + """Handle a value defined on the factory class itself. + + :param field_value: A value defined as an attribute on the factory class. + :param field_build_parameters: Any build parameters passed to the factory as kwarg values. + + :returns: An arbitrary value correlating with the given field_meta value. + """ + if is_safe_subclass(field_value, BaseFactory): + if isinstance(field_build_parameters, Mapping): + return field_value.build(_build_context=build_context, **field_build_parameters) + + if isinstance(field_build_parameters, Sequence): + return [ + field_value.build(_build_context=build_context, **parameter) for parameter in field_build_parameters + ] + + return field_value.build(_build_context=build_context) + + if isinstance(field_value, Use): + return field_value.to_value() + + if isinstance(field_value, Fixture): + return field_value.to_value() + + return field_value() if callable(field_value) else field_value + + @classmethod + def _handle_factory_field_coverage( + cls, + field_value: Any, + field_build_parameters: Any | None = None, + build_context: BuildContext | None = None, + ) -> Any: + """Handle a value defined on the factory class itself. + + :param field_value: A value defined as an attribute on the factory class. + :param field_build_parameters: Any build parameters passed to the factory as kwarg values. + + :returns: An arbitrary value correlating with the given field_meta value. + """ + if is_safe_subclass(field_value, BaseFactory): + if isinstance(field_build_parameters, Mapping): + return CoverageContainer(field_value.coverage(_build_context=build_context, **field_build_parameters)) + + if isinstance(field_build_parameters, Sequence): + return [ + CoverageContainer(field_value.coverage(_build_context=build_context, **parameter)) + for parameter in field_build_parameters + ] + + return CoverageContainer(field_value.coverage()) + + if isinstance(field_value, Use): + return field_value.to_value() + + if isinstance(field_value, Fixture): + return CoverageContainerCallable(field_value.to_value) + + return CoverageContainerCallable(field_value) if callable(field_value) else field_value + + @classmethod + def _get_config(cls) -> dict[str, Any]: + return { + **{key: getattr(cls, key) for key in cls.__config_keys__}, + "_extra_providers": cls.get_provider_map(), + } + + @classmethod + def _get_or_create_factory(cls, model: type) -> type[BaseFactory[Any]]: + """Get a factory from registered factories or generate a factory dynamically. + + :param model: A model type. + :returns: A Factory sub-class. + + """ + if factory := BaseFactory._factory_type_mapping.get(model): + return factory + + config = cls._get_config() + + if cls.__base_factory_overrides__: + for model_ancestor in model.mro(): + if factory := cls.__base_factory_overrides__.get(model_ancestor): + return factory.create_factory(model, **config) + + for factory in reversed(BaseFactory._base_factories): + if factory.is_supported_type(model): + return factory.create_factory(model, **config) + + msg = f"unsupported model type {model.__name__}" + raise ParameterException(msg) # pragma: no cover + + # Public Methods + + @classmethod + def is_factory_type(cls, annotation: Any) -> bool: + """Determine whether a given field is annotated with a type that is supported by a base factory. + + :param annotation: A type annotation. + :returns: Boolean dictating whether the annotation is a factory type + """ + return any(factory.is_supported_type(annotation) for factory in BaseFactory._base_factories) + + @classmethod + def is_batch_factory_type(cls, annotation: Any) -> bool: + """Determine whether a given field is annotated with a sequence of supported factory types. + + :param annotation: A type annotation. + :returns: Boolean dictating whether the annotation is a batch factory type + """ + origin = get_type_origin(annotation) or annotation + if is_safe_subclass(origin, Sequence) and (args := unwrap_args(annotation, random=cls.__random__)): + return len(args) == 1 and BaseFactory.is_factory_type(annotation=args[0]) + return False + + @classmethod + def extract_field_build_parameters(cls, field_meta: FieldMeta, build_args: dict[str, Any]) -> Any: + """Extract from the build kwargs any build parameters passed for a given field meta - if it is a factory type. + + :param field_meta: A field meta instance. + :param build_args: Any kwargs passed to the factory. + :returns: Any values + """ + if build_arg := build_args.get(field_meta.name): + annotation = unwrap_optional(field_meta.annotation) + if ( + BaseFactory.is_factory_type(annotation=annotation) + and isinstance(build_arg, Mapping) + and not BaseFactory.is_factory_type(annotation=type(build_arg)) + ): + return build_args.pop(field_meta.name) + + if ( + BaseFactory.is_batch_factory_type(annotation=annotation) + and isinstance(build_arg, Sequence) + and not any(BaseFactory.is_factory_type(annotation=type(value)) for value in build_arg) + ): + return build_args.pop(field_meta.name) + return None + + @classmethod + @abstractmethod + def is_supported_type(cls, value: Any) -> "TypeGuard[type[T]]": # pragma: no cover + """Determine whether the given value is supported by the factory. + + :param value: An arbitrary value. + :returns: A typeguard + """ + raise NotImplementedError + + @classmethod + def seed_random(cls, seed: int) -> None: + """Seed faker and random with the given integer. + + :param seed: An integer to set as seed. + :returns: 'None' + + """ + cls.__random__ = Random(seed) + cls.__faker__.seed_instance(seed) + + @classmethod + def is_ignored_type(cls, value: Any) -> bool: + """Check whether a given value is an ignored type. + + :param value: An arbitrary value. + + :notes: + - This method is meant to be overwritten by extension factories and other subclasses + + :returns: A boolean determining whether the value should be ignored. + + """ + return value is None + + @classmethod + def get_provider_map(cls) -> dict[Any, Callable[[], Any]]: + """Map types to callables. + + :notes: + - This method is distinct to allow overriding. + + + :returns: a dictionary mapping types to callables. + + """ + + def _create_generic_fn() -> Callable: + """Return a generic lambda""" + return lambda *args: None + + return { + Any: lambda: None, + # primitives + object: object, + float: cls.__faker__.pyfloat, + int: cls.__faker__.pyint, + bool: cls.__faker__.pybool, + str: cls.__faker__.pystr, + bytes: partial(create_random_bytes, cls.__random__), + # built-in objects + dict: cls.__faker__.pydict, + tuple: cls.__faker__.pytuple, + list: cls.__faker__.pylist, + set: cls.__faker__.pyset, + frozenset: lambda: frozenset(cls.__faker__.pylist()), + deque: lambda: deque(cls.__faker__.pylist()), + # standard library objects + Path: lambda: Path(realpath(__file__)), + Decimal: cls.__faker__.pydecimal, + UUID: lambda: UUID(cls.__faker__.uuid4()), + # datetime + datetime: cls.__faker__.date_time_between, + date: cls.__faker__.date_this_decade, + time: cls.__faker__.time_object, + timedelta: cls.__faker__.time_delta, + # ip addresses + IPv4Address: lambda: ip_address(cls.__faker__.ipv4()), + IPv4Interface: lambda: ip_interface(cls.__faker__.ipv4()), + IPv4Network: lambda: ip_network(cls.__faker__.ipv4(network=True)), + IPv6Address: lambda: ip_address(cls.__faker__.ipv6()), + IPv6Interface: lambda: ip_interface(cls.__faker__.ipv6()), + IPv6Network: lambda: ip_network(cls.__faker__.ipv6(network=True)), + # types + Callable: _create_generic_fn, + abc.Callable: _create_generic_fn, + Counter: lambda: Counter(cls.__faker__.pystr()), + **(cls._extra_providers or {}), + } + + @classmethod + def create_factory( + cls: type[F], + model: type[T] | None = None, + bases: tuple[type[BaseFactory[Any]], ...] | None = None, + **kwargs: Any, + ) -> type[F]: + """Generate a factory for the given type dynamically. + + :param model: A type to model. Defaults to current factory __model__ if any. + Otherwise, raise an error + :param bases: Base classes to use when generating the new class. + :param kwargs: Any kwargs. + + :returns: A 'ModelFactory' subclass. + + """ + if model is None: + try: + model = cls.__model__ + except AttributeError as ex: + msg = "A 'model' argument is required when creating a new factory from a base one" + raise TypeError(msg) from ex + return cast( + "Type[F]", + type( + f"{model.__name__}Factory", # pyright: ignore[reportOptionalMemberAccess] + (*(bases or ()), cls), + {"__model__": model, **kwargs}, + ), + ) + + @classmethod + def get_constrained_field_value(cls, annotation: Any, field_meta: FieldMeta) -> Any: # noqa: C901, PLR0911, PLR0912 + try: + constraints = cast("Constraints", field_meta.constraints) + if is_safe_subclass(annotation, float): + return handle_constrained_float( + random=cls.__random__, + multiple_of=cast("Any", constraints.get("multiple_of")), + gt=cast("Any", constraints.get("gt")), + ge=cast("Any", constraints.get("ge")), + lt=cast("Any", constraints.get("lt")), + le=cast("Any", constraints.get("le")), + ) + + if is_safe_subclass(annotation, int): + return handle_constrained_int( + random=cls.__random__, + multiple_of=cast("Any", constraints.get("multiple_of")), + gt=cast("Any", constraints.get("gt")), + ge=cast("Any", constraints.get("ge")), + lt=cast("Any", constraints.get("lt")), + le=cast("Any", constraints.get("le")), + ) + + if is_safe_subclass(annotation, Decimal): + return handle_constrained_decimal( + random=cls.__random__, + decimal_places=cast("Any", constraints.get("decimal_places")), + max_digits=cast("Any", constraints.get("max_digits")), + multiple_of=cast("Any", constraints.get("multiple_of")), + gt=cast("Any", constraints.get("gt")), + ge=cast("Any", constraints.get("ge")), + lt=cast("Any", constraints.get("lt")), + le=cast("Any", constraints.get("le")), + ) + + if url_constraints := constraints.get("url"): + return handle_constrained_url(constraints=url_constraints) + + if is_safe_subclass(annotation, str) or is_safe_subclass(annotation, bytes): + return handle_constrained_string_or_bytes( + random=cls.__random__, + t_type=str if is_safe_subclass(annotation, str) else bytes, + lower_case=constraints.get("lower_case") or False, + upper_case=constraints.get("upper_case") or False, + min_length=constraints.get("min_length"), + max_length=constraints.get("max_length"), + pattern=constraints.get("pattern"), + ) + + try: + collection_type = get_collection_type(annotation) + except ValueError: + collection_type = None + if collection_type is not None: + if collection_type == dict: + return handle_constrained_mapping( + factory=cls, + field_meta=field_meta, + min_items=constraints.get("min_length"), + max_items=constraints.get("max_length"), + ) + return handle_constrained_collection( + collection_type=collection_type, # type: ignore[type-var] + factory=cls, + field_meta=field_meta.children[0] if field_meta.children else field_meta, + item_type=constraints.get("item_type"), + max_items=constraints.get("max_length"), + min_items=constraints.get("min_length"), + unique_items=constraints.get("unique_items", False), + ) + + if is_safe_subclass(annotation, date): + return handle_constrained_date( + faker=cls.__faker__, + ge=cast("Any", constraints.get("ge")), + gt=cast("Any", constraints.get("gt")), + le=cast("Any", constraints.get("le")), + lt=cast("Any", constraints.get("lt")), + tz=cast("Any", constraints.get("tz")), + ) + + if is_safe_subclass(annotation, UUID) and (uuid_version := constraints.get("uuid_version")): + return handle_constrained_uuid( + uuid_version=uuid_version, + faker=cls.__faker__, + ) + + if is_safe_subclass(annotation, Path) and (path_constraint := constraints.get("path_type")): + return handle_constrained_path(constraint=path_constraint, faker=cls.__faker__) + except TypeError as e: + raise ParameterException from e + + msg = f"received constraints for unsupported type {annotation}" + raise ParameterException(msg) + + @classmethod + def get_field_value( # noqa: C901, PLR0911, PLR0912 + cls, + field_meta: FieldMeta, + field_build_parameters: Any | None = None, + build_context: BuildContext | None = None, + ) -> Any: + """Return a field value on the subclass if existing, otherwise returns a mock value. + + :param field_meta: FieldMeta instance. + :param field_build_parameters: Any build parameters passed to the factory as kwarg values. + :param build_context: BuildContext data for current build. + + :returns: An arbitrary value. + + """ + build_context = _get_build_context(build_context) + if cls.is_ignored_type(field_meta.annotation): + return None + + if field_build_parameters is None and cls.should_set_none_value(field_meta=field_meta): + return None + + unwrapped_annotation = unwrap_annotation(field_meta.annotation, random=cls.__random__) + + if is_literal(annotation=unwrapped_annotation) and (literal_args := get_args(unwrapped_annotation)): + return cls.__random__.choice(literal_args) + + if isinstance(unwrapped_annotation, EnumMeta): + return cls.__random__.choice(list(unwrapped_annotation)) + + if field_meta.constraints: + return cls.get_constrained_field_value(annotation=unwrapped_annotation, field_meta=field_meta) + + if is_union(field_meta.annotation) and field_meta.children: + seen_models = build_context["seen_models"] + children = [child for child in field_meta.children if child.annotation not in seen_models] + + # `None` is removed from the children when creating FieldMeta so when `children` + # is empty, it must mean that the field meta is an optional type. + if children: + return cls.get_field_value(cls.__random__.choice(children), field_build_parameters, build_context) + + if BaseFactory.is_factory_type(annotation=unwrapped_annotation): + if not field_build_parameters and unwrapped_annotation in build_context["seen_models"]: + return None if is_optional(field_meta.annotation) else Null + + return cls._get_or_create_factory(model=unwrapped_annotation).build( + _build_context=build_context, + **(field_build_parameters if isinstance(field_build_parameters, Mapping) else {}), + ) + + if BaseFactory.is_batch_factory_type(annotation=unwrapped_annotation): + factory = cls._get_or_create_factory(model=field_meta.type_args[0]) + if isinstance(field_build_parameters, Sequence): + return [ + factory.build(_build_context=build_context, **field_parameters) + for field_parameters in field_build_parameters + ] + + if field_meta.type_args[0] in build_context["seen_models"]: + return [] + + if not cls.__randomize_collection_length__: + return [factory.build(_build_context=build_context)] + + batch_size = cls.__random__.randint(cls.__min_collection_length__, cls.__max_collection_length__) + return factory.batch(size=batch_size, _build_context=build_context) + + if (origin := get_type_origin(unwrapped_annotation)) and is_safe_subclass(origin, Collection): + if cls.__randomize_collection_length__: + collection_type = get_collection_type(unwrapped_annotation) + if collection_type != dict: + return handle_constrained_collection( + collection_type=collection_type, # type: ignore[type-var] + factory=cls, + item_type=Any, + field_meta=field_meta.children[0] if field_meta.children else field_meta, + min_items=cls.__min_collection_length__, + max_items=cls.__max_collection_length__, + ) + return handle_constrained_mapping( + factory=cls, + field_meta=field_meta, + min_items=cls.__min_collection_length__, + max_items=cls.__max_collection_length__, + ) + + return handle_collection_type(field_meta, origin, cls) + + if is_any(unwrapped_annotation) or isinstance(unwrapped_annotation, TypeVar): + return create_random_string(cls.__random__, min_length=1, max_length=10) + + if provider := cls.get_provider_map().get(unwrapped_annotation): + return provider() + + if callable(unwrapped_annotation): + # if value is a callable we can try to naively call it. + # this will work for callables that do not require any parameters passed + with suppress(Exception): + return unwrapped_annotation() + + msg = f"Unsupported type: {unwrapped_annotation!r}\n\nEither extend the providers map or add a factory function for this type." + raise ParameterException( + msg, + ) + + @classmethod + def get_field_value_coverage( # noqa: C901 + cls, + field_meta: FieldMeta, + field_build_parameters: Any | None = None, + build_context: BuildContext | None = None, + ) -> Iterable[Any]: + """Return a field value on the subclass if existing, otherwise returns a mock value. + + :param field_meta: FieldMeta instance. + :param field_build_parameters: Any build parameters passed to the factory as kwarg values. + :param build_context: BuildContext data for current build. + + :returns: An iterable of values. + + """ + if cls.is_ignored_type(field_meta.annotation): + return [None] + + for unwrapped_annotation in flatten_annotation(field_meta.annotation): + if unwrapped_annotation in (None, NoneType): + yield None + + elif is_literal(annotation=unwrapped_annotation) and (literal_args := get_args(unwrapped_annotation)): + yield CoverageContainer(literal_args) + + elif isinstance(unwrapped_annotation, EnumMeta): + yield CoverageContainer(list(unwrapped_annotation)) + + elif field_meta.constraints: + yield CoverageContainerCallable( + cls.get_constrained_field_value, + annotation=unwrapped_annotation, + field_meta=field_meta, + ) + + elif BaseFactory.is_factory_type(annotation=unwrapped_annotation): + yield CoverageContainer( + cls._get_or_create_factory(model=unwrapped_annotation).coverage( + _build_context=build_context, + **(field_build_parameters if isinstance(field_build_parameters, Mapping) else {}), + ), + ) + + elif (origin := get_type_origin(unwrapped_annotation)) and issubclass(origin, Collection): + yield handle_collection_type_coverage(field_meta, origin, cls) + + elif is_any(unwrapped_annotation) or isinstance(unwrapped_annotation, TypeVar): + yield create_random_string(cls.__random__, min_length=1, max_length=10) + + elif provider := cls.get_provider_map().get(unwrapped_annotation): + yield CoverageContainerCallable(provider) + + elif callable(unwrapped_annotation): + # if value is a callable we can try to naively call it. + # this will work for callables that do not require any parameters passed + yield CoverageContainerCallable(unwrapped_annotation) + else: + msg = f"Unsupported type: {unwrapped_annotation!r}\n\nEither extend the providers map or add a factory function for this type." + raise ParameterException( + msg, + ) + + @classmethod + def should_set_none_value(cls, field_meta: FieldMeta) -> bool: + """Determine whether a given model field_meta should be set to None. + + :param field_meta: Field metadata. + + :notes: + - This method is distinct to allow overriding. + + :returns: A boolean determining whether 'None' should be set for the given field_meta. + + """ + return ( + cls.__allow_none_optionals__ + and is_optional(field_meta.annotation) + and create_random_boolean(random=cls.__random__) + ) + + @classmethod + def should_use_default_value(cls, field_meta: FieldMeta) -> bool: + """Determine whether to use the default value for the given field. + + :param field_meta: FieldMeta instance. + + :notes: + - This method is distinct to allow overriding. + + :returns: A boolean determining whether the default value should be used for the given field_meta. + + """ + return cls.__use_defaults__ and field_meta.default is not Null + + @classmethod + def should_set_field_value(cls, field_meta: FieldMeta, **kwargs: Any) -> bool: + """Determine whether to set a value for a given field_name. + + :param field_meta: FieldMeta instance. + :param kwargs: Any kwargs passed to the factory. + + :notes: + - This method is distinct to allow overriding. + + :returns: A boolean determining whether a value should be set for the given field_meta. + + """ + return not field_meta.name.startswith("_") and field_meta.name not in kwargs + + @classmethod + @abstractmethod + def get_model_fields(cls) -> list[FieldMeta]: # pragma: no cover + """Retrieve a list of fields from the factory's model. + + + :returns: A list of field MetaData instances. + + """ + raise NotImplementedError + + @classmethod + def get_factory_fields(cls) -> list[tuple[str, Any]]: + """Retrieve a list of fields from the factory. + + Trying to be smart about what should be considered a field on the model, + ignoring dunder methods and some parent class attributes. + + :returns: A list of tuples made of field name and field definition + """ + factory_fields = cls.__dict__.items() + return [ + (field_name, field_value) + for field_name, field_value in factory_fields + if not (field_name.startswith("__") or field_name == "_abc_impl") + ] + + @classmethod + def _check_declared_fields_exist_in_model(cls) -> None: + model_fields_names = {field_meta.name for field_meta in cls.get_model_fields()} + factory_fields = cls.get_factory_fields() + + for field_name, field_value in factory_fields: + if field_name in model_fields_names: + continue + + error_message = ( + f"{field_name} is declared on the factory {cls.__name__}" + f" but it is not part of the model {cls.__model__.__name__}" + ) + if isinstance(field_value, (Use, PostGenerated, Ignore, Require)): + raise ConfigurationException(error_message) + + @classmethod + def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]: + """Process the given kwargs and generate values for the factory's model. + + :param kwargs: Any build kwargs. + + :returns: A dictionary of build results. + + """ + _build_context = _get_build_context(kwargs.pop("_build_context", None)) + _build_context["seen_models"].add(cls.__model__) + + result: dict[str, Any] = {**kwargs} + generate_post: dict[str, PostGenerated] = {} + + for field_meta in cls.get_model_fields(): + field_build_parameters = cls.extract_field_build_parameters(field_meta=field_meta, build_args=kwargs) + if cls.should_set_field_value(field_meta, **kwargs) and not cls.should_use_default_value(field_meta): + if hasattr(cls, field_meta.name) and not hasattr(BaseFactory, field_meta.name): + field_value = getattr(cls, field_meta.name) + if isinstance(field_value, Ignore): + continue + + if isinstance(field_value, Require) and field_meta.name not in kwargs: + msg = f"Require kwarg {field_meta.name} is missing" + raise MissingBuildKwargException(msg) + + if isinstance(field_value, PostGenerated): + generate_post[field_meta.name] = field_value + continue + + result[field_meta.name] = cls._handle_factory_field( + field_value=field_value, + field_build_parameters=field_build_parameters, + build_context=_build_context, + ) + continue + + field_result = cls.get_field_value( + field_meta, + field_build_parameters=field_build_parameters, + build_context=_build_context, + ) + if field_result is Null: + continue + + result[field_meta.name] = field_result + + for field_name, post_generator in generate_post.items(): + result[field_name] = post_generator.to_value(field_name, result) + + return result + + @classmethod + def process_kwargs_coverage(cls, **kwargs: Any) -> abc.Iterable[dict[str, Any]]: + """Process the given kwargs and generate values for the factory's model. + + :param kwargs: Any build kwargs. + :param build_context: BuildContext data for current build. + + :returns: A dictionary of build results. + + """ + _build_context = _get_build_context(kwargs.pop("_build_context", None)) + _build_context["seen_models"].add(cls.__model__) + + result: dict[str, Any] = {**kwargs} + generate_post: dict[str, PostGenerated] = {} + + for field_meta in cls.get_model_fields(): + field_build_parameters = cls.extract_field_build_parameters(field_meta=field_meta, build_args=kwargs) + + if cls.should_set_field_value(field_meta, **kwargs): + if hasattr(cls, field_meta.name) and not hasattr(BaseFactory, field_meta.name): + field_value = getattr(cls, field_meta.name) + if isinstance(field_value, Ignore): + continue + + if isinstance(field_value, Require) and field_meta.name not in kwargs: + msg = f"Require kwarg {field_meta.name} is missing" + raise MissingBuildKwargException(msg) + + if isinstance(field_value, PostGenerated): + generate_post[field_meta.name] = field_value + continue + + result[field_meta.name] = cls._handle_factory_field_coverage( + field_value=field_value, + field_build_parameters=field_build_parameters, + build_context=_build_context, + ) + continue + + result[field_meta.name] = CoverageContainer( + cls.get_field_value_coverage( + field_meta, + field_build_parameters=field_build_parameters, + build_context=_build_context, + ), + ) + + for resolved in resolve_kwargs_coverage(result): + for field_name, post_generator in generate_post.items(): + resolved[field_name] = post_generator.to_value(field_name, resolved) + yield resolved + + @classmethod + def build(cls, **kwargs: Any) -> T: + """Build an instance of the factory's __model__ + + :param kwargs: Any kwargs. If field names are set in kwargs, their values will be used. + + :returns: An instance of type T. + + """ + return cast("T", cls.__model__(**cls.process_kwargs(**kwargs))) + + @classmethod + def batch(cls, size: int, **kwargs: Any) -> list[T]: + """Build a batch of size n of the factory's Meta.model. + + :param size: Size of the batch. + :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used. + + :returns: A list of instances of type T. + + """ + return [cls.build(**kwargs) for _ in range(size)] + + @classmethod + def coverage(cls, **kwargs: Any) -> abc.Iterator[T]: + """Build a batch of the factory's Meta.model will full coverage of the sub-types of the model. + + :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used. + + :returns: A iterator of instances of type T. + + """ + for data in cls.process_kwargs_coverage(**kwargs): + instance = cls.__model__(**data) + yield cast("T", instance) + + @classmethod + def create_sync(cls, **kwargs: Any) -> T: + """Build and persists synchronously a single model instance. + + :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used. + + :returns: An instance of type T. + + """ + return cls._get_sync_persistence().save(data=cls.build(**kwargs)) + + @classmethod + def create_batch_sync(cls, size: int, **kwargs: Any) -> list[T]: + """Build and persists synchronously a batch of n size model instances. + + :param size: Size of the batch. + :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used. + + :returns: A list of instances of type T. + + """ + return cls._get_sync_persistence().save_many(data=cls.batch(size, **kwargs)) + + @classmethod + async def create_async(cls, **kwargs: Any) -> T: + """Build and persists asynchronously a single model instance. + + :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used. + + :returns: An instance of type T. + """ + return await cls._get_async_persistence().save(data=cls.build(**kwargs)) + + @classmethod + async def create_batch_async(cls, size: int, **kwargs: Any) -> list[T]: + """Build and persists asynchronously a batch of n size model instances. + + + :param size: Size of the batch. + :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used. + + :returns: A list of instances of type T. + """ + return await cls._get_async_persistence().save_many(data=cls.batch(size, **kwargs)) + + +def _register_builtin_factories() -> None: + """This function is used to register the base factories, if present. + + :returns: None + """ + import polyfactory.factories.dataclass_factory + import polyfactory.factories.typed_dict_factory # noqa: F401 + + for module in [ + "polyfactory.factories.pydantic_factory", + "polyfactory.factories.beanie_odm_factory", + "polyfactory.factories.odmantic_odm_factory", + "polyfactory.factories.msgspec_factory", + # `AttrsFactory` is not being registered by default since not all versions of `attrs` are supported. + # Issue: https://github.com/litestar-org/polyfactory/issues/356 + # "polyfactory.factories.attrs_factory", + ]: + try: + import_module(module) + except ImportError: # noqa: PERF203 + continue + + +_register_builtin_factories() diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/beanie_odm_factory.py b/venv/lib/python3.11/site-packages/polyfactory/factories/beanie_odm_factory.py new file mode 100644 index 0000000..ddd3169 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/beanie_odm_factory.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Generic, TypeVar + +from typing_extensions import get_args + +from polyfactory.exceptions import MissingDependencyException +from polyfactory.factories.pydantic_factory import ModelFactory +from polyfactory.persistence import AsyncPersistenceProtocol +from polyfactory.utils.predicates import is_safe_subclass + +if TYPE_CHECKING: + from typing_extensions import TypeGuard + + from polyfactory.factories.base import BuildContext + from polyfactory.field_meta import FieldMeta + +try: + from beanie import Document +except ImportError as e: + msg = "beanie is not installed" + raise MissingDependencyException(msg) from e + +T = TypeVar("T", bound=Document) + + +class BeaniePersistenceHandler(Generic[T], AsyncPersistenceProtocol[T]): + """Persistence Handler using beanie logic""" + + async def save(self, data: T) -> T: + """Persist a single instance in mongoDB.""" + return await data.insert() # type: ignore[no-any-return] + + async def save_many(self, data: list[T]) -> list[T]: + """Persist multiple instances in mongoDB. + + .. note:: we cannot use the ``.insert_many`` method from Beanie here because it doesn't + return the created instances + """ + return [await doc.insert() for doc in data] # pyright: ignore[reportGeneralTypeIssues] + + +class BeanieDocumentFactory(Generic[T], ModelFactory[T]): + """Base factory for Beanie Documents""" + + __async_persistence__ = BeaniePersistenceHandler + __is_base_factory__ = True + + @classmethod + def is_supported_type(cls, value: Any) -> "TypeGuard[type[T]]": + """Determine whether the given value is supported by the factory. + + :param value: An arbitrary value. + :returns: A typeguard + """ + return is_safe_subclass(value, Document) + + @classmethod + def get_field_value( + cls, + field_meta: "FieldMeta", + field_build_parameters: Any | None = None, + build_context: BuildContext | None = None, + ) -> Any: + """Return a field value on the subclass if existing, otherwise returns a mock value. + + :param field_meta: FieldMeta instance. + :param field_build_parameters: Any build parameters passed to the factory as kwarg values. + + :returns: An arbitrary value. + + """ + if hasattr(field_meta.annotation, "__name__"): + if "Indexed " in field_meta.annotation.__name__: + base_type = field_meta.annotation.__bases__[0] + field_meta.annotation = base_type + + if "Link" in field_meta.annotation.__name__: + link_class = get_args(field_meta.annotation)[0] + field_meta.annotation = link_class + field_meta.annotation = link_class + + return super().get_field_value( + field_meta=field_meta, + field_build_parameters=field_build_parameters, + build_context=build_context, + ) diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/dataclass_factory.py b/venv/lib/python3.11/site-packages/polyfactory/factories/dataclass_factory.py new file mode 100644 index 0000000..01cfbe7 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/dataclass_factory.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +from dataclasses import MISSING, fields, is_dataclass +from typing import Any, Generic + +from typing_extensions import TypeGuard, get_type_hints + +from polyfactory.factories.base import BaseFactory, T +from polyfactory.field_meta import FieldMeta, Null + + +class DataclassFactory(Generic[T], BaseFactory[T]): + """Dataclass base factory""" + + __is_base_factory__ = True + + @classmethod + def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]: + """Determine whether the given value is supported by the factory. + + :param value: An arbitrary value. + :returns: A typeguard + """ + return bool(is_dataclass(value)) + + @classmethod + def get_model_fields(cls) -> list["FieldMeta"]: + """Retrieve a list of fields from the factory's model. + + + :returns: A list of field MetaData instances. + + """ + fields_meta: list["FieldMeta"] = [] + + model_type_hints = get_type_hints(cls.__model__, include_extras=True) + + for field in fields(cls.__model__): # type: ignore[arg-type] + if not field.init: + continue + + if field.default_factory and field.default_factory is not MISSING: + default_value = field.default_factory() + elif field.default is not MISSING: + default_value = field.default + else: + default_value = Null + + fields_meta.append( + FieldMeta.from_type( + annotation=model_type_hints[field.name], + name=field.name, + default=default_value, + random=cls.__random__, + ), + ) + + return fields_meta diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/msgspec_factory.py b/venv/lib/python3.11/site-packages/polyfactory/factories/msgspec_factory.py new file mode 100644 index 0000000..1b579ae --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/msgspec_factory.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +from inspect import isclass +from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar + +from typing_extensions import get_type_hints + +from polyfactory.exceptions import MissingDependencyException +from polyfactory.factories.base import BaseFactory +from polyfactory.field_meta import FieldMeta, Null +from polyfactory.value_generators.constrained_numbers import handle_constrained_int +from polyfactory.value_generators.primitives import create_random_bytes + +if TYPE_CHECKING: + from typing_extensions import TypeGuard + +try: + import msgspec + from msgspec.structs import fields +except ImportError as e: + msg = "msgspec is not installed" + raise MissingDependencyException(msg) from e + +T = TypeVar("T", bound=msgspec.Struct) + + +class MsgspecFactory(Generic[T], BaseFactory[T]): + """Base factory for msgspec Structs.""" + + __is_base_factory__ = True + + @classmethod + def get_provider_map(cls) -> dict[Any, Callable[[], Any]]: + def get_msgpack_ext() -> msgspec.msgpack.Ext: + code = handle_constrained_int(cls.__random__, ge=-128, le=127) + data = create_random_bytes(cls.__random__) + return msgspec.msgpack.Ext(code, data) + + msgspec_provider_map = {msgspec.UnsetType: lambda: msgspec.UNSET, msgspec.msgpack.Ext: get_msgpack_ext} + + provider_map = super().get_provider_map() + provider_map.update(msgspec_provider_map) + + return provider_map + + @classmethod + def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]: + return isclass(value) and hasattr(value, "__struct_fields__") + + @classmethod + def get_model_fields(cls) -> list[FieldMeta]: + fields_meta: list[FieldMeta] = [] + + type_hints = get_type_hints(cls.__model__, include_extras=True) + for field in fields(cls.__model__): + annotation = type_hints[field.name] + if field.default is not msgspec.NODEFAULT: + default_value = field.default + elif field.default_factory is not msgspec.NODEFAULT: + default_value = field.default_factory() + else: + default_value = Null + + fields_meta.append( + FieldMeta.from_type( + annotation=annotation, + name=field.name, + default=default_value, + random=cls.__random__, + ), + ) + return fields_meta diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/odmantic_odm_factory.py b/venv/lib/python3.11/site-packages/polyfactory/factories/odmantic_odm_factory.py new file mode 100644 index 0000000..1b3367a --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/odmantic_odm_factory.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +import decimal +from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union + +from polyfactory.exceptions import MissingDependencyException +from polyfactory.factories.pydantic_factory import ModelFactory +from polyfactory.utils.predicates import is_safe_subclass +from polyfactory.value_generators.primitives import create_random_bytes + +try: + from bson.decimal128 import Decimal128, create_decimal128_context + from odmantic import EmbeddedModel, Model + from odmantic import bson as odbson + +except ImportError as e: + msg = "odmantic is not installed" + raise MissingDependencyException(msg) from e + +T = TypeVar("T", bound=Union[Model, EmbeddedModel]) + +if TYPE_CHECKING: + from typing_extensions import TypeGuard + + +class OdmanticModelFactory(Generic[T], ModelFactory[T]): + """Base factory for odmantic models""" + + __is_base_factory__ = True + + @classmethod + def is_supported_type(cls, value: Any) -> "TypeGuard[type[T]]": + """Determine whether the given value is supported by the factory. + + :param value: An arbitrary value. + :returns: A typeguard + """ + return is_safe_subclass(value, (Model, EmbeddedModel)) + + @classmethod + def get_provider_map(cls) -> dict[Any, Callable[[], Any]]: + provider_map = super().get_provider_map() + provider_map.update( + { + odbson.Int64: lambda: odbson.Int64.validate(cls.__faker__.pyint()), + odbson.Decimal128: lambda: _to_decimal128(cls.__faker__.pydecimal()), + odbson.Binary: lambda: odbson.Binary.validate(create_random_bytes(cls.__random__)), + odbson._datetime: lambda: odbson._datetime.validate(cls.__faker__.date_time_between()), + # bson.Regex and bson._Pattern not supported as there is no way to generate + # a random regular expression with Faker + # bson.Regex: + # bson._Pattern: + }, + ) + return provider_map + + +def _to_decimal128(value: decimal.Decimal) -> Decimal128: + with decimal.localcontext(create_decimal128_context()) as ctx: + return Decimal128(ctx.create_decimal(value)) diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/pydantic_factory.py b/venv/lib/python3.11/site-packages/polyfactory/factories/pydantic_factory.py new file mode 100644 index 0000000..a6028b1 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/pydantic_factory.py @@ -0,0 +1,554 @@ +from __future__ import annotations + +from contextlib import suppress +from datetime import timezone +from functools import partial +from os.path import realpath +from pathlib import Path +from typing import TYPE_CHECKING, Any, ClassVar, ForwardRef, Generic, Mapping, Tuple, TypeVar, cast +from uuid import NAMESPACE_DNS, uuid1, uuid3, uuid5 + +from typing_extensions import Literal, get_args, get_origin + +from polyfactory.collection_extender import CollectionExtender +from polyfactory.constants import DEFAULT_RANDOM +from polyfactory.exceptions import MissingDependencyException +from polyfactory.factories.base import BaseFactory +from polyfactory.field_meta import Constraints, FieldMeta, Null +from polyfactory.utils.deprecation import check_for_deprecated_parameters +from polyfactory.utils.helpers import unwrap_new_type, unwrap_optional +from polyfactory.utils.predicates import is_optional, is_safe_subclass, is_union +from polyfactory.utils.types import NoneType +from polyfactory.value_generators.primitives import create_random_bytes + +try: + import pydantic + from pydantic import VERSION, Json + from pydantic.fields import FieldInfo +except ImportError as e: + msg = "pydantic is not installed" + raise MissingDependencyException(msg) from e + +try: + # pydantic v1 + from pydantic import ( # noqa: I001 + UUID1, + UUID3, + UUID4, + UUID5, + AmqpDsn, + AnyHttpUrl, + AnyUrl, + DirectoryPath, + FilePath, + HttpUrl, + KafkaDsn, + PostgresDsn, + RedisDsn, + ) + from pydantic import BaseModel as BaseModelV1 + from pydantic.color import Color + from pydantic.fields import ( # type: ignore[attr-defined] + DeferredType, # pyright: ignore[reportGeneralTypeIssues] + ModelField, # pyright: ignore[reportGeneralTypeIssues] + Undefined, # pyright: ignore[reportGeneralTypeIssues] + ) + + # Keep this import last to prevent warnings from pydantic if pydantic v2 + # is installed. + from pydantic import PyObject + + # prevent unbound variable warnings + BaseModelV2 = BaseModelV1 + UndefinedV2 = Undefined +except ImportError: + # pydantic v2 + + # v2 specific imports + from pydantic import BaseModel as BaseModelV2 + from pydantic_core import PydanticUndefined as UndefinedV2 + from pydantic_core import to_json + + from pydantic.v1 import ( # v1 compat imports + UUID1, + UUID3, + UUID4, + UUID5, + AmqpDsn, + AnyHttpUrl, + AnyUrl, + DirectoryPath, + FilePath, + HttpUrl, + KafkaDsn, + PostgresDsn, + PyObject, + RedisDsn, + ) + from pydantic.v1 import BaseModel as BaseModelV1 # type: ignore[assignment] + from pydantic.v1.color import Color # type: ignore[assignment] + from pydantic.v1.fields import DeferredType, ModelField, Undefined + + +if TYPE_CHECKING: + from random import Random + from typing import Callable, Sequence + + from typing_extensions import NotRequired, TypeGuard + +T = TypeVar("T", bound="BaseModelV1 | BaseModelV2") + +_IS_PYDANTIC_V1 = VERSION.startswith("1") + + +class PydanticConstraints(Constraints): + """Metadata regarding a Pydantic type constraints, if any""" + + json: NotRequired[bool] + + +class PydanticFieldMeta(FieldMeta): + """Field meta subclass capable of handling pydantic ModelFields""" + + def __init__( + self, + *, + name: str, + annotation: type, + random: Random | None = None, + default: Any = ..., + children: list[FieldMeta] | None = None, + constraints: PydanticConstraints | None = None, + ) -> None: + super().__init__( + name=name, + annotation=annotation, + random=random, + default=default, + children=children, + constraints=constraints, + ) + + @classmethod + def from_field_info( + cls, + field_name: str, + field_info: FieldInfo, + use_alias: bool, + random: Random | None, + randomize_collection_length: bool | None = None, + min_collection_length: int | None = None, + max_collection_length: int | None = None, + ) -> PydanticFieldMeta: + """Create an instance from a pydantic field info. + + :param field_name: The name of the field. + :param field_info: A pydantic FieldInfo instance. + :param use_alias: Whether to use the field alias. + :param random: A random.Random instance. + :param randomize_collection_length: Whether to randomize collection length. + :param min_collection_length: Minimum collection length. + :param max_collection_length: Maximum collection length. + + :returns: A PydanticFieldMeta instance. + """ + check_for_deprecated_parameters( + "2.11.0", + parameters=( + ("randomize_collection_length", randomize_collection_length), + ("min_collection_length", min_collection_length), + ("max_collection_length", max_collection_length), + ), + ) + if callable(field_info.default_factory): + default_value = field_info.default_factory() + else: + default_value = field_info.default if field_info.default is not UndefinedV2 else Null + + annotation = unwrap_new_type(field_info.annotation) + children: list[FieldMeta,] | None = None + name = field_info.alias if field_info.alias and use_alias else field_name + + constraints: PydanticConstraints + # pydantic v2 does not always propagate metadata for Union types + if is_union(annotation): + constraints = {} + children = [] + for arg in get_args(annotation): + if arg is NoneType: + continue + child_field_info = FieldInfo.from_annotation(arg) + merged_field_info = FieldInfo.merge_field_infos(field_info, child_field_info) + children.append( + cls.from_field_info( + field_name="", + field_info=merged_field_info, + use_alias=use_alias, + random=random, + ), + ) + else: + metadata, is_json = [], False + for m in field_info.metadata: + if not is_json and isinstance(m, Json): # type: ignore[misc] + is_json = True + elif m is not None: + metadata.append(m) + + constraints = cast( + PydanticConstraints, + cls.parse_constraints(metadata=metadata) if metadata else {}, + ) + + if "url" in constraints: + # pydantic uses a sentinel value for url constraints + annotation = str + + if is_json: + constraints["json"] = True + + return PydanticFieldMeta.from_type( + annotation=annotation, + children=children, + constraints=cast("Constraints", {k: v for k, v in constraints.items() if v is not None}) or None, + default=default_value, + name=name, + random=random or DEFAULT_RANDOM, + ) + + @classmethod + def from_model_field( # pragma: no cover + cls, + model_field: ModelField, # pyright: ignore[reportGeneralTypeIssues] + use_alias: bool, + randomize_collection_length: bool | None = None, + min_collection_length: int | None = None, + max_collection_length: int | None = None, + random: Random = DEFAULT_RANDOM, + ) -> PydanticFieldMeta: + """Create an instance from a pydantic model field. + :param model_field: A pydantic ModelField. + :param use_alias: Whether to use the field alias. + :param randomize_collection_length: A boolean flag whether to randomize collections lengths + :param min_collection_length: Minimum number of elements in randomized collection + :param max_collection_length: Maximum number of elements in randomized collection + :param random: An instance of random.Random. + + :returns: A PydanticFieldMeta instance. + + """ + check_for_deprecated_parameters( + "2.11.0", + parameters=( + ("randomize_collection_length", randomize_collection_length), + ("min_collection_length", min_collection_length), + ("max_collection_length", max_collection_length), + ), + ) + + if model_field.default is not Undefined: + default_value = model_field.default + elif callable(model_field.default_factory): + default_value = model_field.default_factory() + else: + default_value = model_field.default if model_field.default is not Undefined else Null + + name = model_field.alias if model_field.alias and use_alias else model_field.name + + outer_type = unwrap_new_type(model_field.outer_type_) + annotation = ( + model_field.outer_type_ + if isinstance(model_field.annotation, (DeferredType, ForwardRef)) + else unwrap_new_type(model_field.annotation) + ) + + constraints = cast( + "Constraints", + { + "ge": getattr(outer_type, "ge", model_field.field_info.ge), + "gt": getattr(outer_type, "gt", model_field.field_info.gt), + "le": getattr(outer_type, "le", model_field.field_info.le), + "lt": getattr(outer_type, "lt", model_field.field_info.lt), + "min_length": ( + getattr(outer_type, "min_length", model_field.field_info.min_length) + or getattr(outer_type, "min_items", model_field.field_info.min_items) + ), + "max_length": ( + getattr(outer_type, "max_length", model_field.field_info.max_length) + or getattr(outer_type, "max_items", model_field.field_info.max_items) + ), + "pattern": getattr(outer_type, "regex", model_field.field_info.regex), + "unique_items": getattr(outer_type, "unique_items", model_field.field_info.unique_items), + "decimal_places": getattr(outer_type, "decimal_places", None), + "max_digits": getattr(outer_type, "max_digits", None), + "multiple_of": getattr(outer_type, "multiple_of", None), + "upper_case": getattr(outer_type, "to_upper", None), + "lower_case": getattr(outer_type, "to_lower", None), + "item_type": getattr(outer_type, "item_type", None), + }, + ) + + # pydantic v1 has constraints set for these values, but we generate them using faker + if unwrap_optional(annotation) in ( + AnyUrl, + HttpUrl, + KafkaDsn, + PostgresDsn, + RedisDsn, + AmqpDsn, + AnyHttpUrl, + ): + constraints = {} + + if model_field.field_info.const and ( + default_value is None or isinstance(default_value, (int, bool, str, bytes)) + ): + annotation = Literal[default_value] # pyright: ignore # noqa: PGH003 + + children: list[FieldMeta] = [] + + # Refer #412. + args = get_args(model_field.annotation) + if is_optional(model_field.annotation) and len(args) == 2: # noqa: PLR2004 + child_annotation = args[0] if args[0] is not NoneType else args[1] + children.append(PydanticFieldMeta.from_type(child_annotation)) + elif model_field.key_field or model_field.sub_fields: + fields_to_iterate = ( + ([model_field.key_field, *model_field.sub_fields]) + if model_field.key_field is not None + else model_field.sub_fields + ) + type_args = tuple( + ( + sub_field.outer_type_ + if isinstance(sub_field.annotation, DeferredType) + else unwrap_new_type(sub_field.annotation) + ) + for sub_field in fields_to_iterate + ) + type_arg_to_sub_field = dict(zip(type_args, fields_to_iterate)) + if get_origin(outer_type) in (tuple, Tuple) and get_args(outer_type)[-1] == Ellipsis: + # pydantic removes ellipses from Tuples in sub_fields + type_args += (...,) + extended_type_args = CollectionExtender.extend_type_args(annotation, type_args, 1) + children.extend( + PydanticFieldMeta.from_model_field( + model_field=type_arg_to_sub_field[arg], + use_alias=use_alias, + random=random, + ) + for arg in extended_type_args + ) + + return PydanticFieldMeta( + name=name, + random=random or DEFAULT_RANDOM, + annotation=annotation, + children=children or None, + default=default_value, + constraints=cast("PydanticConstraints", {k: v for k, v in constraints.items() if v is not None}) or None, + ) + + if not _IS_PYDANTIC_V1: + + @classmethod + def get_constraints_metadata(cls, annotation: Any) -> Sequence[Any]: + metadata = [] + for m in super().get_constraints_metadata(annotation): + if isinstance(m, FieldInfo): + metadata.extend(m.metadata) + else: + metadata.append(m) + + return metadata + + +class ModelFactory(Generic[T], BaseFactory[T]): + """Base factory for pydantic models""" + + __forward_ref_resolution_type_mapping__: ClassVar[Mapping[str, type]] = {} + __is_base_factory__ = True + + def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None: + super().__init_subclass__(*args, **kwargs) + + if ( + getattr(cls, "__model__", None) + and _is_pydantic_v1_model(cls.__model__) + and hasattr(cls.__model__, "update_forward_refs") + ): + with suppress(NameError): # pragma: no cover + cls.__model__.update_forward_refs(**cls.__forward_ref_resolution_type_mapping__) # type: ignore[attr-defined] + + @classmethod + def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]: + """Determine whether the given value is supported by the factory. + + :param value: An arbitrary value. + :returns: A typeguard + """ + + return _is_pydantic_v1_model(value) or _is_pydantic_v2_model(value) + + @classmethod + def get_model_fields(cls) -> list["FieldMeta"]: + """Retrieve a list of fields from the factory's model. + + + :returns: A list of field MetaData instances. + + """ + if "_fields_metadata" not in cls.__dict__: + if _is_pydantic_v1_model(cls.__model__): + cls._fields_metadata = [ + PydanticFieldMeta.from_model_field( + field, + use_alias=not cls.__model__.__config__.allow_population_by_field_name, # type: ignore[attr-defined] + random=cls.__random__, + ) + for field in cls.__model__.__fields__.values() + ] + else: + cls._fields_metadata = [ + PydanticFieldMeta.from_field_info( + field_info=field_info, + field_name=field_name, + random=cls.__random__, + use_alias=not cls.__model__.model_config.get( # pyright: ignore[reportGeneralTypeIssues] + "populate_by_name", + False, + ), + ) + for field_name, field_info in cls.__model__.model_fields.items() # pyright: ignore[reportGeneralTypeIssues] + ] + return cls._fields_metadata + + @classmethod + def get_constrained_field_value(cls, annotation: Any, field_meta: FieldMeta) -> Any: + constraints = cast(PydanticConstraints, field_meta.constraints) + if constraints.pop("json", None): + value = cls.get_field_value(field_meta) + return to_json(value) # pyright: ignore[reportUnboundVariable] + + return super().get_constrained_field_value(annotation, field_meta) + + @classmethod + def build( + cls, + factory_use_construct: bool = False, + **kwargs: Any, + ) -> T: + """Build an instance of the factory's __model__ + + :param factory_use_construct: A boolean that determines whether validations will be made when instantiating the + model. This is supported only for pydantic models. + :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used. + + :returns: An instance of type T. + + """ + processed_kwargs = cls.process_kwargs(**kwargs) + + if factory_use_construct: + if _is_pydantic_v1_model(cls.__model__): + return cls.__model__.construct(**processed_kwargs) # type: ignore[return-value] + return cls.__model__.model_construct(**processed_kwargs) # type: ignore[return-value] + + return cls.__model__(**processed_kwargs) # type: ignore[return-value] + + @classmethod + def is_custom_root_field(cls, field_meta: FieldMeta) -> bool: + """Determine whether the field is a custom root field. + + :param field_meta: FieldMeta instance. + + :returns: A boolean determining whether the field is a custom root. + + """ + return field_meta.name == "__root__" + + @classmethod + def should_set_field_value(cls, field_meta: FieldMeta, **kwargs: Any) -> bool: + """Determine whether to set a value for a given field_name. + This is an override of BaseFactory.should_set_field_value. + + :param field_meta: FieldMeta instance. + :param kwargs: Any kwargs passed to the factory. + + :returns: A boolean determining whether a value should be set for the given field_meta. + + """ + return field_meta.name not in kwargs and ( + not field_meta.name.startswith("_") or cls.is_custom_root_field(field_meta) + ) + + @classmethod + def get_provider_map(cls) -> dict[Any, Callable[[], Any]]: + mapping = { + pydantic.ByteSize: cls.__faker__.pyint, + pydantic.PositiveInt: cls.__faker__.pyint, + pydantic.NegativeFloat: lambda: cls.__random__.uniform(-100, -1), + pydantic.NegativeInt: lambda: cls.__faker__.pyint() * -1, + pydantic.PositiveFloat: cls.__faker__.pyint, + pydantic.NonPositiveFloat: lambda: cls.__random__.uniform(-100, 0), + pydantic.NonNegativeInt: cls.__faker__.pyint, + pydantic.StrictInt: cls.__faker__.pyint, + pydantic.StrictBool: cls.__faker__.pybool, + pydantic.StrictBytes: partial(create_random_bytes, cls.__random__), + pydantic.StrictFloat: cls.__faker__.pyfloat, + pydantic.StrictStr: cls.__faker__.pystr, + pydantic.EmailStr: cls.__faker__.free_email, + pydantic.NameEmail: cls.__faker__.free_email, + pydantic.Json: cls.__faker__.json, + pydantic.PaymentCardNumber: cls.__faker__.credit_card_number, + pydantic.AnyUrl: cls.__faker__.url, + pydantic.AnyHttpUrl: cls.__faker__.url, + pydantic.HttpUrl: cls.__faker__.url, + pydantic.SecretBytes: partial(create_random_bytes, cls.__random__), + pydantic.SecretStr: cls.__faker__.pystr, + pydantic.IPvAnyAddress: cls.__faker__.ipv4, + pydantic.IPvAnyInterface: cls.__faker__.ipv4, + pydantic.IPvAnyNetwork: lambda: cls.__faker__.ipv4(network=True), + pydantic.PastDate: cls.__faker__.past_date, + pydantic.FutureDate: cls.__faker__.future_date, + } + + # v1 only values + mapping.update( + { + PyObject: lambda: "decimal.Decimal", + AmqpDsn: lambda: "amqps://example.com", + KafkaDsn: lambda: "kafka://localhost:9092", + PostgresDsn: lambda: "postgresql://user:secret@localhost", + RedisDsn: lambda: "redis://localhost:6379/0", + FilePath: lambda: Path(realpath(__file__)), + DirectoryPath: lambda: Path(realpath(__file__)).parent, + UUID1: uuid1, + UUID3: lambda: uuid3(NAMESPACE_DNS, cls.__faker__.pystr()), + UUID4: cls.__faker__.uuid4, + UUID5: lambda: uuid5(NAMESPACE_DNS, cls.__faker__.pystr()), + Color: cls.__faker__.hex_color, # pyright: ignore[reportGeneralTypeIssues] + }, + ) + + if not _IS_PYDANTIC_V1: + mapping.update( + { + # pydantic v2 specific types + pydantic.PastDatetime: cls.__faker__.past_datetime, + pydantic.FutureDatetime: cls.__faker__.future_datetime, + pydantic.AwareDatetime: partial(cls.__faker__.date_time, timezone.utc), + pydantic.NaiveDatetime: cls.__faker__.date_time, + }, + ) + + mapping.update(super().get_provider_map()) + return mapping + + +def _is_pydantic_v1_model(model: Any) -> TypeGuard[BaseModelV1]: + return is_safe_subclass(model, BaseModelV1) + + +def _is_pydantic_v2_model(model: Any) -> TypeGuard[BaseModelV2]: + return not _IS_PYDANTIC_V1 and is_safe_subclass(model, BaseModelV2) diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/sqlalchemy_factory.py b/venv/lib/python3.11/site-packages/polyfactory/factories/sqlalchemy_factory.py new file mode 100644 index 0000000..ad8873f --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/sqlalchemy_factory.py @@ -0,0 +1,186 @@ +from __future__ import annotations + +from datetime import date, datetime +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, List, TypeVar, Union + +from polyfactory.exceptions import MissingDependencyException +from polyfactory.factories.base import BaseFactory +from polyfactory.field_meta import FieldMeta +from polyfactory.persistence import AsyncPersistenceProtocol, SyncPersistenceProtocol + +try: + from sqlalchemy import Column, inspect, types + from sqlalchemy.dialects import mysql, postgresql + from sqlalchemy.exc import NoInspectionAvailable + from sqlalchemy.orm import InstanceState, Mapper +except ImportError as e: + msg = "sqlalchemy is not installed" + raise MissingDependencyException(msg) from e + +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncSession + from sqlalchemy.orm import Session + from typing_extensions import TypeGuard + + +T = TypeVar("T") + + +class SQLASyncPersistence(SyncPersistenceProtocol[T]): + def __init__(self, session: Session) -> None: + """Sync persistence handler for SQLAFactory.""" + self.session = session + + def save(self, data: T) -> T: + self.session.add(data) + self.session.commit() + return data + + def save_many(self, data: list[T]) -> list[T]: + self.session.add_all(data) + self.session.commit() + return data + + +class SQLAASyncPersistence(AsyncPersistenceProtocol[T]): + def __init__(self, session: AsyncSession) -> None: + """Async persistence handler for SQLAFactory.""" + self.session = session + + async def save(self, data: T) -> T: + self.session.add(data) + await self.session.commit() + return data + + async def save_many(self, data: list[T]) -> list[T]: + self.session.add_all(data) + await self.session.commit() + return data + + +class SQLAlchemyFactory(Generic[T], BaseFactory[T]): + """Base factory for SQLAlchemy models.""" + + __is_base_factory__ = True + + __set_primary_key__: ClassVar[bool] = True + """Configuration to consider primary key columns as a field or not.""" + __set_foreign_keys__: ClassVar[bool] = True + """Configuration to consider columns with foreign keys as a field or not.""" + __set_relationships__: ClassVar[bool] = False + """Configuration to consider relationships property as a model field or not.""" + + __session__: ClassVar[Session | Callable[[], Session] | None] = None + __async_session__: ClassVar[AsyncSession | Callable[[], AsyncSession] | None] = None + + __config_keys__ = ( + *BaseFactory.__config_keys__, + "__set_primary_key__", + "__set_foreign_keys__", + "__set_relationships__", + ) + + @classmethod + def get_sqlalchemy_types(cls) -> dict[Any, Callable[[], Any]]: + """Get mapping of types where column type.""" + return { + types.TupleType: cls.__faker__.pytuple, + mysql.YEAR: lambda: cls.__random__.randint(1901, 2155), + postgresql.CIDR: lambda: cls.__faker__.ipv4(network=False), + postgresql.DATERANGE: lambda: (cls.__faker__.past_date(), date.today()), # noqa: DTZ011 + postgresql.INET: lambda: cls.__faker__.ipv4(network=True), + postgresql.INT4RANGE: lambda: tuple(sorted([cls.__faker__.pyint(), cls.__faker__.pyint()])), + postgresql.INT8RANGE: lambda: tuple(sorted([cls.__faker__.pyint(), cls.__faker__.pyint()])), + postgresql.MACADDR: lambda: cls.__faker__.hexify(text="^^:^^:^^:^^:^^:^^", upper=True), + postgresql.NUMRANGE: lambda: tuple(sorted([cls.__faker__.pyint(), cls.__faker__.pyint()])), + postgresql.TSRANGE: lambda: (cls.__faker__.past_datetime(), datetime.now()), # noqa: DTZ005 + postgresql.TSTZRANGE: lambda: (cls.__faker__.past_datetime(), datetime.now()), # noqa: DTZ005 + } + + @classmethod + def get_provider_map(cls) -> dict[Any, Callable[[], Any]]: + providers_map = super().get_provider_map() + providers_map.update(cls.get_sqlalchemy_types()) + return providers_map + + @classmethod + def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]: + try: + inspected = inspect(value) + except NoInspectionAvailable: + return False + return isinstance(inspected, (Mapper, InstanceState)) + + @classmethod + def should_column_be_set(cls, column: Column) -> bool: + if not cls.__set_primary_key__ and column.primary_key: + return False + + return bool(cls.__set_foreign_keys__ or not column.foreign_keys) + + @classmethod + def get_type_from_column(cls, column: Column) -> type: + column_type = type(column.type) + if column_type in cls.get_sqlalchemy_types(): + annotation = column_type + elif issubclass(column_type, types.ARRAY): + annotation = List[column.type.item_type.python_type] # type: ignore[assignment,name-defined] + else: + annotation = ( + column.type.impl.python_type # pyright: ignore[reportGeneralTypeIssues] + if hasattr(column.type, "impl") + else column.type.python_type + ) + + if column.nullable: + annotation = Union[annotation, None] # type: ignore[assignment] + + return annotation + + @classmethod + def get_model_fields(cls) -> list[FieldMeta]: + fields_meta: list[FieldMeta] = [] + + table: Mapper = inspect(cls.__model__) # type: ignore[assignment] + fields_meta.extend( + FieldMeta.from_type( + annotation=cls.get_type_from_column(column), + name=name, + random=cls.__random__, + ) + for name, column in table.columns.items() + if cls.should_column_be_set(column) + ) + if cls.__set_relationships__: + for name, relationship in table.relationships.items(): + class_ = relationship.entity.class_ + annotation = class_ if not relationship.uselist else List[class_] # type: ignore[valid-type] + fields_meta.append( + FieldMeta.from_type( + name=name, + annotation=annotation, + random=cls.__random__, + ), + ) + + return fields_meta + + @classmethod + def _get_sync_persistence(cls) -> SyncPersistenceProtocol[T]: + if cls.__session__ is not None: + return ( + SQLASyncPersistence(cls.__session__()) + if callable(cls.__session__) + else SQLASyncPersistence(cls.__session__) + ) + return super()._get_sync_persistence() + + @classmethod + def _get_async_persistence(cls) -> AsyncPersistenceProtocol[T]: + if cls.__async_session__ is not None: + return ( + SQLAASyncPersistence(cls.__async_session__()) + if callable(cls.__async_session__) + else SQLAASyncPersistence(cls.__async_session__) + ) + return super()._get_async_persistence() diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/typed_dict_factory.py b/venv/lib/python3.11/site-packages/polyfactory/factories/typed_dict_factory.py new file mode 100644 index 0000000..2a3ea1b --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/typed_dict_factory.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from typing import Any, Generic, TypeVar, get_args + +from typing_extensions import ( # type: ignore[attr-defined] + NotRequired, + Required, + TypeGuard, + _TypedDictMeta, # pyright: ignore[reportGeneralTypeIssues] + get_origin, + get_type_hints, + is_typeddict, +) + +from polyfactory.constants import DEFAULT_RANDOM +from polyfactory.factories.base import BaseFactory +from polyfactory.field_meta import FieldMeta, Null + +TypedDictT = TypeVar("TypedDictT", bound=_TypedDictMeta) + + +class TypedDictFactory(Generic[TypedDictT], BaseFactory[TypedDictT]): + """TypedDict base factory""" + + __is_base_factory__ = True + + @classmethod + def is_supported_type(cls, value: Any) -> TypeGuard[type[TypedDictT]]: + """Determine whether the given value is supported by the factory. + + :param value: An arbitrary value. + :returns: A typeguard + """ + return is_typeddict(value) + + @classmethod + def get_model_fields(cls) -> list["FieldMeta"]: + """Retrieve a list of fields from the factory's model. + + + :returns: A list of field MetaData instances. + + """ + model_type_hints = get_type_hints(cls.__model__, include_extras=True) + + field_metas: list[FieldMeta] = [] + for field_name, annotation in model_type_hints.items(): + origin = get_origin(annotation) + if origin in (Required, NotRequired): + annotation = get_args(annotation)[0] # noqa: PLW2901 + + field_metas.append( + FieldMeta.from_type( + annotation=annotation, + random=DEFAULT_RANDOM, + name=field_name, + default=getattr(cls.__model__, field_name, Null), + ), + ) + + return field_metas diff --git a/venv/lib/python3.11/site-packages/polyfactory/field_meta.py b/venv/lib/python3.11/site-packages/polyfactory/field_meta.py new file mode 100644 index 0000000..d6288fd --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/field_meta.py @@ -0,0 +1,233 @@ +from __future__ import annotations + +from dataclasses import asdict, is_dataclass +from typing import TYPE_CHECKING, Any, Literal, Mapping, Pattern, TypedDict, cast + +from typing_extensions import get_args, get_origin + +from polyfactory.collection_extender import CollectionExtender +from polyfactory.constants import DEFAULT_RANDOM, TYPE_MAPPING +from polyfactory.utils.deprecation import check_for_deprecated_parameters +from polyfactory.utils.helpers import ( + get_annotation_metadata, + unwrap_annotated, + unwrap_new_type, +) +from polyfactory.utils.predicates import is_annotated +from polyfactory.utils.types import NoneType + +if TYPE_CHECKING: + import datetime + from decimal import Decimal + from random import Random + from typing import Sequence + + from typing_extensions import NotRequired, Self + + +class Null: + """Sentinel class for empty values""" + + +class UrlConstraints(TypedDict): + max_length: NotRequired[int] + allowed_schemes: NotRequired[list[str]] + host_required: NotRequired[bool] + default_host: NotRequired[str] + default_port: NotRequired[int] + default_path: NotRequired[str] + + +class Constraints(TypedDict): + """Metadata regarding a type constraints, if any""" + + allow_inf_nan: NotRequired[bool] + decimal_places: NotRequired[int] + ge: NotRequired[int | float | Decimal] + gt: NotRequired[int | float | Decimal] + item_type: NotRequired[Any] + le: NotRequired[int | float | Decimal] + lower_case: NotRequired[bool] + lt: NotRequired[int | float | Decimal] + max_digits: NotRequired[int] + max_length: NotRequired[int] + min_length: NotRequired[int] + multiple_of: NotRequired[int | float | Decimal] + path_type: NotRequired[Literal["file", "dir", "new"]] + pattern: NotRequired[str | Pattern] + tz: NotRequired[datetime.tzinfo] + unique_items: NotRequired[bool] + upper_case: NotRequired[bool] + url: NotRequired[UrlConstraints] + uuid_version: NotRequired[Literal[1, 3, 4, 5]] + + +class FieldMeta: + """Factory field metadata container. This class is used to store the data about a field of a factory's model.""" + + __slots__ = ("name", "annotation", "random", "children", "default", "constraints") + + annotation: Any + random: Random + children: list[FieldMeta] | None + default: Any + name: str + constraints: Constraints | None + + def __init__( + self, + *, + name: str, + annotation: type, + random: Random | None = None, + default: Any = Null, + children: list[FieldMeta] | None = None, + constraints: Constraints | None = None, + ) -> None: + """Create a factory field metadata instance.""" + self.annotation = annotation + self.random = random or DEFAULT_RANDOM + self.children = children + self.default = default + self.name = name + self.constraints = constraints + + @property + def type_args(self) -> tuple[Any, ...]: + """Return the normalized type args of the annotation, if any. + + :returns: a tuple of types. + """ + return tuple(TYPE_MAPPING[arg] if arg in TYPE_MAPPING else arg for arg in get_args(self.annotation)) + + @classmethod + def from_type( + cls, + annotation: Any, + random: Random = DEFAULT_RANDOM, + name: str = "", + default: Any = Null, + constraints: Constraints | None = None, + randomize_collection_length: bool | None = None, + min_collection_length: int | None = None, + max_collection_length: int | None = None, + children: list[FieldMeta] | None = None, + ) -> Self: + """Builder method to create a FieldMeta from a type annotation. + + :param annotation: A type annotation. + :param random: An instance of random.Random. + :param name: Field name + :param default: Default value, if any. + :param constraints: A dictionary of constraints, if any. + :param randomize_collection_length: A boolean flag whether to randomize collections lengths + :param min_collection_length: Minimum number of elements in randomized collection + :param max_collection_length: Maximum number of elements in randomized collection + + :returns: A field meta instance. + """ + check_for_deprecated_parameters( + "2.11.0", + parameters=( + ("randomize_collection_length", randomize_collection_length), + ("min_collection_length", min_collection_length), + ("max_collection_length", max_collection_length), + ), + ) + + annotated = is_annotated(annotation) + if not constraints and annotated: + metadata = cls.get_constraints_metadata(annotation) + constraints = cls.parse_constraints(metadata) + + if annotated: + annotation = get_args(annotation)[0] + elif (origin := get_origin(annotation)) and origin in TYPE_MAPPING: # pragma: no cover + container = TYPE_MAPPING[origin] + annotation = container[get_args(annotation)] # type: ignore[index] + + field = cls( + annotation=annotation, + random=random, + name=name, + default=default, + children=children, + constraints=constraints, + ) + + if field.type_args and not field.children: + number_of_args = 1 + extended_type_args = CollectionExtender.extend_type_args(field.annotation, field.type_args, number_of_args) + field.children = [ + cls.from_type( + annotation=unwrap_new_type(arg), + random=random, + ) + for arg in extended_type_args + if arg is not NoneType + ] + return field + + @classmethod + def parse_constraints(cls, metadata: Sequence[Any]) -> "Constraints": + constraints = {} + + for value in metadata: + if is_annotated(value): + _, inner_metadata = unwrap_annotated(value, random=DEFAULT_RANDOM) + constraints.update(cast("dict[str, Any]", cls.parse_constraints(metadata=inner_metadata))) + elif func := getattr(value, "func", None): + if func is str.islower: + constraints["lower_case"] = True + elif func is str.isupper: + constraints["upper_case"] = True + elif func is str.isascii: + constraints["pattern"] = "[[:ascii:]]" + elif func is str.isdigit: + constraints["pattern"] = "[[:digit:]]" + elif is_dataclass(value) and (value_dict := asdict(value)) and ("allowed_schemes" in value_dict): + constraints["url"] = {k: v for k, v in value_dict.items() if v is not None} + # This is to support `Constraints`, but we can't do a isinstance with `Constraints` since isinstance + # checks with `TypedDict` is not supported. + elif isinstance(value, Mapping): + constraints.update(value) + else: + constraints.update( + { + k: v + for k, v in { + "allow_inf_nan": getattr(value, "allow_inf_nan", None), + "decimal_places": getattr(value, "decimal_places", None), + "ge": getattr(value, "ge", None), + "gt": getattr(value, "gt", None), + "item_type": getattr(value, "item_type", None), + "le": getattr(value, "le", None), + "lower_case": getattr(value, "to_lower", None), + "lt": getattr(value, "lt", None), + "max_digits": getattr(value, "max_digits", None), + "max_length": getattr(value, "max_length", getattr(value, "max_length", None)), + "min_length": getattr(value, "min_length", getattr(value, "min_items", None)), + "multiple_of": getattr(value, "multiple_of", None), + "path_type": getattr(value, "path_type", None), + "pattern": getattr(value, "regex", getattr(value, "pattern", None)), + "tz": getattr(value, "tz", None), + "unique_items": getattr(value, "unique_items", None), + "upper_case": getattr(value, "to_upper", None), + "uuid_version": getattr(value, "uuid_version", None), + }.items() + if v is not None + }, + ) + return cast("Constraints", constraints) + + @classmethod + def get_constraints_metadata(cls, annotation: Any) -> Sequence[Any]: + """Get the metadatas of the constraints from the given annotation. + + :param annotation: A type annotation. + :param random: An instance of random.Random. + + :returns: A list of the metadata in the annotation. + """ + + return get_annotation_metadata(annotation) diff --git a/venv/lib/python3.11/site-packages/polyfactory/fields.py b/venv/lib/python3.11/site-packages/polyfactory/fields.py new file mode 100644 index 0000000..a3b19ef --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/fields.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +from typing import Any, Callable, Generic, TypedDict, TypeVar, cast + +from typing_extensions import ParamSpec + +from polyfactory.exceptions import ParameterException + +T = TypeVar("T") +P = ParamSpec("P") + + +class WrappedCallable(TypedDict): + """A ref storing a callable. This class is a utility meant to prevent binding of methods.""" + + value: Callable + + +class Require: + """A factory field that marks an attribute as a required build-time kwarg.""" + + +class Ignore: + """A factory field that marks an attribute as ignored.""" + + +class Use(Generic[P, T]): + """Factory field used to wrap a callable. + + The callable will be invoked whenever building the given factory attribute. + + + """ + + __slots__ = ("fn", "kwargs", "args") + + def __init__(self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> None: + """Wrap a callable. + + :param fn: A callable to wrap. + :param args: Any args to pass to the callable. + :param kwargs: Any kwargs to pass to the callable. + """ + self.fn: WrappedCallable = {"value": fn} + self.kwargs = kwargs + self.args = args + + def to_value(self) -> T: + """Invoke the callable. + + :returns: The output of the callable. + + + """ + return cast("T", self.fn["value"](*self.args, **self.kwargs)) + + +class PostGenerated: + """Factory field that allows generating values after other fields are generated by the factory.""" + + __slots__ = ("fn", "kwargs", "args") + + def __init__(self, fn: Callable, *args: Any, **kwargs: Any) -> None: + """Designate field as post-generated. + + :param fn: A callable. + :param args: Args for the callable. + :param kwargs: Kwargs for the callable. + """ + self.fn: WrappedCallable = {"value": fn} + self.kwargs = kwargs + self.args = args + + def to_value(self, name: str, values: dict[str, Any]) -> Any: + """Invoke the post-generation callback passing to it the build results. + + :param name: Field name. + :param values: Generated values. + + :returns: An arbitrary value. + """ + return self.fn["value"](name, values, *self.args, **self.kwargs) + + +class Fixture: + """Factory field to create a pytest fixture from a factory.""" + + __slots__ = ("ref", "size", "kwargs") + + def __init__(self, fixture: Callable, size: int | None = None, **kwargs: Any) -> None: + """Create a fixture from a factory. + + :param fixture: A factory that was registered as a fixture. + :param size: Optional batch size. + :param kwargs: Any build kwargs. + """ + self.ref: WrappedCallable = {"value": fixture} + self.size = size + self.kwargs = kwargs + + def to_value(self) -> Any: + """Call the factory's build or batch method. + + :raises: ParameterException + + :returns: The build result. + """ + from polyfactory.pytest_plugin import FactoryFixture + + if factory := FactoryFixture.factory_class_map.get(self.ref["value"]): + if self.size is not None: + return factory.batch(self.size, **self.kwargs) + return factory.build(**self.kwargs) + + msg = "fixture has not been registered using the register_factory decorator" + raise ParameterException(msg) diff --git a/venv/lib/python3.11/site-packages/polyfactory/persistence.py b/venv/lib/python3.11/site-packages/polyfactory/persistence.py new file mode 100644 index 0000000..7aa510b --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/persistence.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from typing import Protocol, TypeVar, runtime_checkable + +T = TypeVar("T") + + +@runtime_checkable +class SyncPersistenceProtocol(Protocol[T]): + """Protocol for sync persistence""" + + def save(self, data: T) -> T: + """Persist a single instance synchronously. + + :param data: A single instance to persist. + + :returns: The persisted result. + + """ + ... + + def save_many(self, data: list[T]) -> list[T]: + """Persist multiple instances synchronously. + + :param data: A list of instances to persist. + + :returns: The persisted result + + """ + ... + + +@runtime_checkable +class AsyncPersistenceProtocol(Protocol[T]): + """Protocol for async persistence""" + + async def save(self, data: T) -> T: + """Persist a single instance asynchronously. + + :param data: A single instance to persist. + + :returns: The persisted result. + """ + ... + + async def save_many(self, data: list[T]) -> list[T]: + """Persist multiple instances asynchronously. + + :param data: A list of instances to persist. + + :returns: The persisted result + """ + ... diff --git a/venv/lib/python3.11/site-packages/polyfactory/py.typed b/venv/lib/python3.11/site-packages/polyfactory/py.typed new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/py.typed diff --git a/venv/lib/python3.11/site-packages/polyfactory/pytest_plugin.py b/venv/lib/python3.11/site-packages/polyfactory/pytest_plugin.py new file mode 100644 index 0000000..0ca9196 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/pytest_plugin.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +import re +from typing import ( + Any, + Callable, + ClassVar, + Literal, + Union, +) + +from pytest import Config, fixture # noqa: PT013 + +from polyfactory.exceptions import ParameterException +from polyfactory.factories.base import BaseFactory +from polyfactory.utils.predicates import is_safe_subclass + +Scope = Union[ + Literal["session", "package", "module", "class", "function"], + Callable[[str, Config], Literal["session", "package", "module", "class", "function"]], +] + + +split_pattern_1 = re.compile(r"([A-Z]+)([A-Z][a-z])") +split_pattern_2 = re.compile(r"([a-z\d])([A-Z])") + + +def _get_fixture_name(name: str) -> str: + """From inflection.underscore. + + :param name: str: A name. + + :returns: Normalized fixture name. + + """ + name = re.sub(split_pattern_1, r"\1_\2", name) + name = re.sub(split_pattern_2, r"\1_\2", name) + name = name.replace("-", "_") + return name.lower() + + +class FactoryFixture: + """Decorator that creates a pytest fixture from a factory""" + + __slots__ = ("scope", "autouse", "name") + + factory_class_map: ClassVar[dict[Callable, type[BaseFactory[Any]]]] = {} + + def __init__( + self, + scope: Scope = "function", + autouse: bool = False, + name: str | None = None, + ) -> None: + """Create a factory fixture decorator + + :param scope: Fixture scope + :param autouse: Autouse the fixture + :param name: Fixture name + """ + self.scope = scope + self.autouse = autouse + self.name = name + + def __call__(self, factory: type[BaseFactory[Any]]) -> Any: + if not is_safe_subclass(factory, BaseFactory): + msg = f"{factory.__name__} is not a BaseFactory subclass." + raise ParameterException(msg) + + fixture_name = self.name or _get_fixture_name(factory.__name__) + fixture_register = fixture( + scope=self.scope, # pyright: ignore[reportGeneralTypeIssues] + name=fixture_name, + autouse=self.autouse, + ) + + def _factory_fixture() -> type[BaseFactory[Any]]: + """The wrapped factory""" + return factory + + _factory_fixture.__doc__ = factory.__doc__ + marker = fixture_register(_factory_fixture) + self.factory_class_map[marker] = factory + return marker + + +def register_fixture( + factory: type[BaseFactory[Any]] | None = None, + *, + scope: Scope = "function", + autouse: bool = False, + name: str | None = None, +) -> Any: + """A decorator that allows registering model factories as fixtures. + + :param factory: An optional factory class to decorate. + :param scope: Pytest scope. + :param autouse: Auto use fixture. + :param name: Fixture name. + + :returns: A fixture factory instance. + """ + factory_fixture = FactoryFixture(scope=scope, autouse=autouse, name=name) + return factory_fixture(factory) if factory else factory_fixture diff --git a/venv/lib/python3.11/site-packages/polyfactory/utils/__init__.py b/venv/lib/python3.11/site-packages/polyfactory/utils/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/utils/__init__.py diff --git a/venv/lib/python3.11/site-packages/polyfactory/utils/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/utils/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..6297806 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/utils/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/utils/__pycache__/deprecation.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/utils/__pycache__/deprecation.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..8f43f83 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/utils/__pycache__/deprecation.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/utils/__pycache__/helpers.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/utils/__pycache__/helpers.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..982e068 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/utils/__pycache__/helpers.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/utils/__pycache__/model_coverage.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/utils/__pycache__/model_coverage.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..22c475b --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/utils/__pycache__/model_coverage.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/utils/__pycache__/predicates.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/utils/__pycache__/predicates.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..c1908b6 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/utils/__pycache__/predicates.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/utils/__pycache__/types.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/utils/__pycache__/types.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..59b2b6a --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/utils/__pycache__/types.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/utils/deprecation.py b/venv/lib/python3.11/site-packages/polyfactory/utils/deprecation.py new file mode 100644 index 0000000..576c4ac --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/utils/deprecation.py @@ -0,0 +1,149 @@ +from __future__ import annotations + +import inspect +from functools import wraps +from typing import Any, Callable, Literal, TypeVar +from warnings import warn + +from typing_extensions import ParamSpec + +__all__ = ("deprecated", "warn_deprecation", "check_for_deprecated_parameters") + + +T = TypeVar("T") +P = ParamSpec("P") +DeprecatedKind = Literal["function", "method", "classmethod", "attribute", "property", "class", "parameter", "import"] + + +def warn_deprecation( + version: str, + deprecated_name: str, + kind: DeprecatedKind, + *, + removal_in: str | None = None, + alternative: str | None = None, + info: str | None = None, + pending: bool = False, +) -> None: + """Warn about a call to a (soon to be) deprecated function. + + Args: + version: Polyfactory version where the deprecation will occur. + deprecated_name: Name of the deprecated function. + removal_in: Polyfactory version where the deprecated function will be removed. + alternative: Name of a function that should be used instead. + info: Additional information. + pending: Use ``PendingDeprecationWarning`` instead of ``DeprecationWarning``. + kind: Type of the deprecated thing. + """ + parts = [] + + if kind == "import": + access_type = "Import of" + elif kind in {"function", "method"}: + access_type = "Call to" + else: + access_type = "Use of" + + if pending: + parts.append(f"{access_type} {kind} awaiting deprecation {deprecated_name!r}") + else: + parts.append(f"{access_type} deprecated {kind} {deprecated_name!r}") + + parts.extend( + ( + f"Deprecated in polyfactory {version}", + f"This {kind} will be removed in {removal_in or 'the next major version'}", + ) # noqa: COM812 + ) + if alternative: + parts.append(f"Use {alternative!r} instead") + + if info: + parts.append(info) + + text = ". ".join(parts) + warning_class = PendingDeprecationWarning if pending else DeprecationWarning + + warn(text, warning_class, stacklevel=2) + + +def deprecated( + version: str, + *, + removal_in: str | None = None, + alternative: str | None = None, + info: str | None = None, + pending: bool = False, + kind: Literal["function", "method", "classmethod", "property"] | None = None, +) -> Callable[[Callable[P, T]], Callable[P, T]]: + """Create a decorator wrapping a function, method or property with a warning call about a (pending) deprecation. + + Args: + version: Polyfactory version where the deprecation will occur. + removal_in: Polyfactory version where the deprecated function will be removed. + alternative: Name of a function that should be used instead. + info: Additional information. + pending: Use ``PendingDeprecationWarning`` instead of ``DeprecationWarning``. + kind: Type of the deprecated callable. If ``None``, will use ``inspect`` to figure + out if it's a function or method. + + Returns: + A decorator wrapping the function call with a warning + """ + + def decorator(func: Callable[P, T]) -> Callable[P, T]: + @wraps(func) + def wrapped(*args: P.args, **kwargs: P.kwargs) -> T: + warn_deprecation( + version=version, + deprecated_name=func.__name__, + info=info, + alternative=alternative, + pending=pending, + removal_in=removal_in, + kind=kind or ("method" if inspect.ismethod(func) else "function"), + ) + return func(*args, **kwargs) + + return wrapped + + return decorator + + +def check_for_deprecated_parameters( + version: str, + *, + parameters: tuple[tuple[str, Any], ...], + default_value: Any = None, + removal_in: str | None = None, + alternative: str | None = None, + info: str | None = None, + pending: bool = False, +) -> None: + """Warn about a call to a (soon to be) deprecated argument to a function. + + Args: + version: Polyfactory version where the deprecation will occur. + parameters: Parameters to trigger warning if used. + default_value: Default value for parameter to detect if supplied or not. + removal_in: Polyfactory version where the deprecated function will be removed. + alternative: Name of a function that should be used instead. + info: Additional information. + pending: Use ``PendingDeprecationWarning`` instead of ``DeprecationWarning``. + kind: Type of the deprecated callable. If ``None``, will use ``inspect`` to figure + out if it's a function or method. + """ + for parameter_name, value in parameters: + if value == default_value: + continue + + warn_deprecation( + version=version, + deprecated_name=parameter_name, + info=info, + alternative=alternative, + pending=pending, + removal_in=removal_in, + kind="parameter", + ) diff --git a/venv/lib/python3.11/site-packages/polyfactory/utils/helpers.py b/venv/lib/python3.11/site-packages/polyfactory/utils/helpers.py new file mode 100644 index 0000000..f9924bb --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/utils/helpers.py @@ -0,0 +1,196 @@ +from __future__ import annotations + +import sys +from typing import TYPE_CHECKING, Any, Mapping + +from typing_extensions import TypeAliasType, get_args, get_origin + +from polyfactory.constants import TYPE_MAPPING +from polyfactory.utils.predicates import is_annotated, is_new_type, is_optional, is_safe_subclass, is_union +from polyfactory.utils.types import NoneType + +if TYPE_CHECKING: + from random import Random + from typing import Sequence + + +def unwrap_new_type(annotation: Any) -> Any: + """Return base type if given annotation is a type derived with NewType, otherwise annotation. + + :param annotation: A type annotation, possibly one created using 'types.NewType' + + :returns: The unwrapped annotation. + """ + while is_new_type(annotation): + annotation = annotation.__supertype__ + + return annotation + + +def unwrap_union(annotation: Any, random: Random) -> Any: + """Unwraps union types - recursively. + + :param annotation: A type annotation, possibly a type union. + :param random: An instance of random.Random. + :returns: A type annotation + """ + while is_union(annotation): + args = list(get_args(annotation)) + annotation = random.choice(args) + return annotation + + +def unwrap_optional(annotation: Any) -> Any: + """Unwraps optional union types - recursively. + + :param annotation: A type annotation, possibly an optional union. + + :returns: A type annotation + """ + while is_optional(annotation): + annotation = next(arg for arg in get_args(annotation) if arg not in (NoneType, None)) + return annotation + + +def unwrap_annotation(annotation: Any, random: Random) -> Any: + """Unwraps an annotation. + + :param annotation: A type annotation. + :param random: An instance of random.Random. + + :returns: The unwrapped annotation. + + """ + while ( + is_optional(annotation) + or is_union(annotation) + or is_new_type(annotation) + or is_annotated(annotation) + or isinstance(annotation, TypeAliasType) + ): + if is_new_type(annotation): + annotation = unwrap_new_type(annotation) + elif is_optional(annotation): + annotation = unwrap_optional(annotation) + elif is_annotated(annotation): + annotation = unwrap_annotated(annotation, random=random)[0] + elif isinstance(annotation, TypeAliasType): + annotation = annotation.__value__ + else: + annotation = unwrap_union(annotation, random=random) + + return annotation + + +def flatten_annotation(annotation: Any) -> list[Any]: + """Flattens an annotation. + + :param annotation: A type annotation. + + :returns: The flattened annotations. + """ + flat = [] + if is_new_type(annotation): + flat.extend(flatten_annotation(unwrap_new_type(annotation))) + elif is_optional(annotation): + for a in get_args(annotation): + flat.extend(flatten_annotation(a)) + elif is_annotated(annotation): + flat.extend(flatten_annotation(get_args(annotation)[0])) + elif is_union(annotation): + for a in get_args(annotation): + flat.extend(flatten_annotation(a)) + else: + flat.append(annotation) + + return flat + + +def unwrap_args(annotation: Any, random: Random) -> tuple[Any, ...]: + """Unwrap the annotation and return any type args. + + :param annotation: A type annotation + :param random: An instance of random.Random. + + :returns: A tuple of type args. + + """ + + return get_args(unwrap_annotation(annotation=annotation, random=random)) + + +def unwrap_annotated(annotation: Any, random: Random) -> tuple[Any, list[Any]]: + """Unwrap an annotated type and return a tuple of type args and optional metadata. + + :param annotation: An annotated type annotation + :param random: An instance of random.Random. + + :returns: A tuple of type args. + + """ + args = [arg for arg in get_args(annotation) if arg is not None] + return unwrap_annotation(args[0], random=random), args[1:] + + +def normalize_annotation(annotation: Any, random: Random) -> Any: + """Normalize an annotation. + + :param annotation: A type annotation. + + :returns: A normalized type annotation. + + """ + if is_new_type(annotation): + annotation = unwrap_new_type(annotation) + + if is_annotated(annotation): + annotation = unwrap_annotated(annotation, random=random)[0] + + # we have to maintain compatibility with the older non-subscriptable typings. + if sys.version_info <= (3, 9): # pragma: no cover + return annotation + + origin = get_origin(annotation) or annotation + + if origin in TYPE_MAPPING: + origin = TYPE_MAPPING[origin] + + if args := get_args(annotation): + args = tuple(normalize_annotation(arg, random=random) for arg in args) + return origin[args] if origin is not type else annotation + + return origin + + +def get_annotation_metadata(annotation: Any) -> Sequence[Any]: + """Get the metadata in the annotation. + + :param annotation: A type annotation. + + :returns: The metadata. + """ + + return get_args(annotation)[1:] + + +def get_collection_type(annotation: Any) -> type[list | tuple | set | frozenset | dict]: + """Get the collection type from the annotation. + + :param annotation: A type annotation. + + :returns: The collection type. + """ + + if is_safe_subclass(annotation, list): + return list + if is_safe_subclass(annotation, Mapping): + return dict + if is_safe_subclass(annotation, tuple): + return tuple + if is_safe_subclass(annotation, set): + return set + if is_safe_subclass(annotation, frozenset): + return frozenset + + msg = f"Unknown collection type - {annotation}" + raise ValueError(msg) diff --git a/venv/lib/python3.11/site-packages/polyfactory/utils/model_coverage.py b/venv/lib/python3.11/site-packages/polyfactory/utils/model_coverage.py new file mode 100644 index 0000000..6fc3971 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/utils/model_coverage.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import Callable, Iterable, Iterator, Mapping, MutableSequence +from typing import AbstractSet, Any, Generic, Set, TypeVar, cast + +from typing_extensions import ParamSpec + +from polyfactory.exceptions import ParameterException + + +class CoverageContainerBase(ABC): + """Base class for coverage container implementations. + + A coverage container is a wrapper providing values for a particular field. Coverage containers return field values and + track a "done" state to indicate that all coverage examples have been generated. + """ + + @abstractmethod + def next_value(self) -> Any: + """Provide the next value""" + ... + + @abstractmethod + def is_done(self) -> bool: + """Indicate if this container has provided every coverage example it has""" + ... + + +T = TypeVar("T") + + +class CoverageContainer(CoverageContainerBase, Generic[T]): + """A coverage container that wraps a collection of values. + + When calling ``next_value()`` a greater number of times than the length of the given collection will cause duplicate + examples to be returned (wraps around). + + If there are any coverage containers within the given collection, the values from those containers are essentially merged + into the parent container. + """ + + def __init__(self, instances: Iterable[T]) -> None: + self._pos = 0 + self._instances = list(instances) + if not self._instances: + msg = "CoverageContainer must have at least one instance" + raise ValueError(msg) + + def next_value(self) -> T: + value = self._instances[self._pos % len(self._instances)] + if isinstance(value, CoverageContainerBase): + result = value.next_value() + if value.is_done(): + # Only move onto the next instance if the sub-container is done + self._pos += 1 + return cast(T, result) + + self._pos += 1 + return value + + def is_done(self) -> bool: + return self._pos >= len(self._instances) + + def __repr__(self) -> str: + return f"CoverageContainer(instances={self._instances}, is_done={self.is_done()})" + + +P = ParamSpec("P") + + +class CoverageContainerCallable(CoverageContainerBase, Generic[T]): + """A coverage container that wraps a callable. + + When calling ``next_value()`` the wrapped callable is called to provide a value. + """ + + def __init__(self, func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> None: + self._func = func + self._args = args + self._kwargs = kwargs + + def next_value(self) -> T: + try: + return self._func(*self._args, **self._kwargs) + except Exception as e: # noqa: BLE001 + msg = f"Unsupported type: {self._func!r}\n\nEither extend the providers map or add a factory function for this type." + raise ParameterException(msg) from e + + def is_done(self) -> bool: + return True + + +def _resolve_next(unresolved: Any) -> tuple[Any, bool]: # noqa: C901 + if isinstance(unresolved, CoverageContainerBase): + result, done = _resolve_next(unresolved.next_value()) + return result, unresolved.is_done() and done + + if isinstance(unresolved, Mapping): + result = {} + done_status = True + for key, value in unresolved.items(): + val_resolved, val_done = _resolve_next(value) + key_resolved, key_done = _resolve_next(key) + result[key_resolved] = val_resolved + done_status = done_status and val_done and key_done + return result, done_status + + if isinstance(unresolved, (tuple, MutableSequence)): + result = [] + done_status = True + for value in unresolved: + resolved, done = _resolve_next(value) + result.append(resolved) + done_status = done_status and done + if isinstance(unresolved, tuple): + result = tuple(result) + return result, done_status + + if isinstance(unresolved, Set): + result = type(unresolved)() + done_status = True + for value in unresolved: + resolved, done = _resolve_next(value) + result.add(resolved) + done_status = done_status and done + return result, done_status + + if issubclass(type(unresolved), AbstractSet): + result = type(unresolved)() + done_status = True + resolved_values = [] + for value in unresolved: + resolved, done = _resolve_next(value) + resolved_values.append(resolved) + done_status = done_status and done + return result.union(resolved_values), done_status + + return unresolved, True + + +def resolve_kwargs_coverage(kwargs: dict[str, Any]) -> Iterator[dict[str, Any]]: + done = False + while not done: + resolved, done = _resolve_next(kwargs) + yield resolved diff --git a/venv/lib/python3.11/site-packages/polyfactory/utils/predicates.py b/venv/lib/python3.11/site-packages/polyfactory/utils/predicates.py new file mode 100644 index 0000000..895e380 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/utils/predicates.py @@ -0,0 +1,134 @@ +from __future__ import annotations + +from inspect import isclass +from typing import Any, Literal, NewType, Optional, TypeVar, get_args + +from typing_extensions import Annotated, NotRequired, ParamSpec, Required, TypeGuard, _AnnotatedAlias, get_origin + +from polyfactory.constants import TYPE_MAPPING +from polyfactory.utils.types import UNION_TYPES, NoneType + +P = ParamSpec("P") +T = TypeVar("T") + + +def is_safe_subclass(annotation: Any, class_or_tuple: type[T] | tuple[type[T], ...]) -> "TypeGuard[type[T]]": + """Determine whether a given annotation is a subclass of a give type + + :param annotation: A type annotation. + :param class_or_tuple: A potential super class or classes. + + :returns: A typeguard + """ + origin = get_type_origin(annotation) + if not origin and not isclass(annotation): + return False + try: + return issubclass(origin or annotation, class_or_tuple) + except TypeError: # pragma: no cover + return False + + +def is_any(annotation: Any) -> "TypeGuard[Any]": + """Determine whether a given annotation is 'typing.Any'. + + :param annotation: A type annotation. + + :returns: A typeguard. + """ + return ( + annotation is Any + or getattr(annotation, "_name", "") == "typing.Any" + or (get_origin(annotation) in UNION_TYPES and Any in get_args(annotation)) + ) + + +def is_dict_key_or_value_type(annotation: Any) -> "TypeGuard[Any]": + """Determine whether a given annotation is a valid dict key or value type: + ``typing.KT`` or ``typing.VT``. + + :returns: A typeguard. + """ + return str(annotation) in {"~KT", "~VT"} + + +def is_union(annotation: Any) -> "TypeGuard[Any]": + """Determine whether a given annotation is 'typing.Union'. + + :param annotation: A type annotation. + + :returns: A typeguard. + """ + return get_type_origin(annotation) in UNION_TYPES + + +def is_optional(annotation: Any) -> "TypeGuard[Any | None]": + """Determine whether a given annotation is 'typing.Optional'. + + :param annotation: A type annotation. + + :returns: A typeguard. + """ + origin = get_type_origin(annotation) + return origin is Optional or (get_origin(annotation) in UNION_TYPES and NoneType in get_args(annotation)) + + +def is_literal(annotation: Any) -> bool: + """Determine whether a given annotation is 'typing.Literal'. + + :param annotation: A type annotation. + + :returns: A boolean. + """ + return ( + get_type_origin(annotation) is Literal + or repr(annotation).startswith("typing.Literal") + or repr(annotation).startswith("typing_extensions.Literal") + ) + + +def is_new_type(annotation: Any) -> "TypeGuard[type[NewType]]": + """Determine whether a given annotation is 'typing.NewType'. + + :param annotation: A type annotation. + + :returns: A typeguard. + """ + return hasattr(annotation, "__supertype__") + + +def is_annotated(annotation: Any) -> bool: + """Determine whether a given annotation is 'typing.Annotated'. + + :param annotation: A type annotation. + + :returns: A boolean. + """ + return get_origin(annotation) is Annotated or ( + isinstance(annotation, _AnnotatedAlias) and getattr(annotation, "__args__", None) is not None + ) + + +def is_any_annotated(annotation: Any) -> bool: + """Determine whether any of the types in the given annotation is + `typing.Annotated`. + + :param annotation: A type annotation. + + :returns: A boolean + """ + + return any(is_annotated(arg) or hasattr(arg, "__args__") and is_any_annotated(arg) for arg in get_args(annotation)) + + +def get_type_origin(annotation: Any) -> Any: + """Get the type origin of an annotation - safely. + + :param annotation: A type annotation. + + :returns: A type annotation. + """ + origin = get_origin(annotation) + if origin in (Annotated, Required, NotRequired): + origin = get_args(annotation)[0] + return mapped_type if (mapped_type := TYPE_MAPPING.get(origin)) else origin diff --git a/venv/lib/python3.11/site-packages/polyfactory/utils/types.py b/venv/lib/python3.11/site-packages/polyfactory/utils/types.py new file mode 100644 index 0000000..413f5dd --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/utils/types.py @@ -0,0 +1,12 @@ +from typing import Union + +try: + from types import NoneType, UnionType + + UNION_TYPES = {UnionType, Union} +except ImportError: + UNION_TYPES = {Union} + + NoneType = type(None) # type: ignore[misc,assignment] + +__all__ = ("NoneType", "UNION_TYPES") diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/__init__.py b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__init__.py diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..c95bde0 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/complex_types.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/complex_types.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..b99a526 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/complex_types.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/constrained_collections.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/constrained_collections.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..bb33c8a --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/constrained_collections.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/constrained_dates.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/constrained_dates.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..f649c36 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/constrained_dates.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/constrained_numbers.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/constrained_numbers.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..926ca52 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/constrained_numbers.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/constrained_path.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/constrained_path.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..5808e81 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/constrained_path.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/constrained_strings.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/constrained_strings.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..53efa2f --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/constrained_strings.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/constrained_url.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/constrained_url.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..4dd4ad5 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/constrained_url.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/constrained_uuid.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/constrained_uuid.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..e653a72 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/constrained_uuid.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/primitives.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/primitives.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..97787ad --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/primitives.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/regex.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/regex.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..319510e --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/regex.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/complex_types.py b/venv/lib/python3.11/site-packages/polyfactory/value_generators/complex_types.py new file mode 100644 index 0000000..2706891 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/complex_types.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, AbstractSet, Any, Iterable, MutableMapping, MutableSequence, Set, Tuple, cast + +from typing_extensions import is_typeddict + +from polyfactory.constants import INSTANTIABLE_TYPE_MAPPING, PY_38 +from polyfactory.field_meta import FieldMeta +from polyfactory.utils.model_coverage import CoverageContainer + +if TYPE_CHECKING: + from polyfactory.factories.base import BaseFactory + + +def handle_collection_type(field_meta: FieldMeta, container_type: type, factory: type[BaseFactory[Any]]) -> Any: + """Handle generation of container types recursively. + + :param container_type: A type that can accept type arguments. + :param factory: A factory. + :param field_meta: A field meta instance. + + :returns: A built result. + """ + + if PY_38 and container_type in INSTANTIABLE_TYPE_MAPPING: + container_type = INSTANTIABLE_TYPE_MAPPING[container_type] # type: ignore[assignment] + + container = container_type() + if not field_meta.children: + return container + + if issubclass(container_type, MutableMapping) or is_typeddict(container_type): + for key_field_meta, value_field_meta in cast( + Iterable[Tuple[FieldMeta, FieldMeta]], + zip(field_meta.children[::2], field_meta.children[1::2]), + ): + key = factory.get_field_value(key_field_meta) + value = factory.get_field_value(value_field_meta) + container[key] = value + return container + + if issubclass(container_type, MutableSequence): + container.extend([factory.get_field_value(subfield_meta) for subfield_meta in field_meta.children]) + return container + + if issubclass(container_type, Set): + for subfield_meta in field_meta.children: + container.add(factory.get_field_value(subfield_meta)) + return container + + if issubclass(container_type, AbstractSet): + return container.union(handle_collection_type(field_meta, set, factory)) + + if issubclass(container_type, tuple): + return container_type(map(factory.get_field_value, field_meta.children)) + + msg = f"Unsupported container type: {container_type}" + raise NotImplementedError(msg) + + +def handle_collection_type_coverage( + field_meta: FieldMeta, + container_type: type, + factory: type[BaseFactory[Any]], +) -> Any: + """Handle coverage generation of container types recursively. + + :param container_type: A type that can accept type arguments. + :param factory: A factory. + :param field_meta: A field meta instance. + + :returns: An unresolved built result. + """ + container = container_type() + if not field_meta.children: + return container + + if issubclass(container_type, MutableMapping) or is_typeddict(container_type): + for key_field_meta, value_field_meta in cast( + Iterable[Tuple[FieldMeta, FieldMeta]], + zip(field_meta.children[::2], field_meta.children[1::2]), + ): + key = CoverageContainer(factory.get_field_value_coverage(key_field_meta)) + value = CoverageContainer(factory.get_field_value_coverage(value_field_meta)) + container[key] = value + return container + + if issubclass(container_type, MutableSequence): + container_instance = container_type() + for subfield_meta in field_meta.children: + container_instance.extend(factory.get_field_value_coverage(subfield_meta)) + + return container_instance + + if issubclass(container_type, Set): + set_instance = container_type() + for subfield_meta in field_meta.children: + set_instance = set_instance.union(factory.get_field_value_coverage(subfield_meta)) + + return set_instance + + if issubclass(container_type, AbstractSet): + return container.union(handle_collection_type_coverage(field_meta, set, factory)) + + if issubclass(container_type, tuple): + return container_type( + CoverageContainer(factory.get_field_value_coverage(subfield_meta)) for subfield_meta in field_meta.children + ) + + msg = f"Unsupported container type: {container_type}" + raise NotImplementedError(msg) diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/constrained_collections.py b/venv/lib/python3.11/site-packages/polyfactory/value_generators/constrained_collections.py new file mode 100644 index 0000000..405d02d --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/constrained_collections.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable, List, Mapping, TypeVar, cast + +from polyfactory.exceptions import ParameterException +from polyfactory.field_meta import FieldMeta + +if TYPE_CHECKING: + from polyfactory.factories.base import BaseFactory + +T = TypeVar("T", list, set, frozenset) + + +def handle_constrained_collection( + collection_type: Callable[..., T], + factory: type[BaseFactory[Any]], + field_meta: FieldMeta, + item_type: Any, + max_items: int | None = None, + min_items: int | None = None, + unique_items: bool = False, +) -> T: + """Generate a constrained list or set. + + :param collection_type: A type that can accept type arguments. + :param factory: A factory. + :param field_meta: A field meta instance. + :param item_type: Type of the collection items. + :param max_items: Maximal number of items. + :param min_items: Minimal number of items. + :param unique_items: Whether the items should be unique. + + :returns: A collection value. + """ + min_items = abs(min_items if min_items is not None else (max_items or 0)) + max_items = abs(max_items if max_items is not None else min_items + 1) + + if max_items < min_items: + msg = "max_items must be larger or equal to min_items" + raise ParameterException(msg) + + collection: set[T] | list[T] = set() if (collection_type in (frozenset, set) or unique_items) else [] + + try: + length = factory.__random__.randint(min_items, max_items) or 1 + while len(collection) < length: + value = factory.get_field_value(field_meta) + if isinstance(collection, set): + collection.add(value) + else: + collection.append(value) + return collection_type(collection) + except TypeError as e: + msg = f"cannot generate a constrained collection of type: {item_type}" + raise ParameterException(msg) from e + + +def handle_constrained_mapping( + factory: type[BaseFactory[Any]], + field_meta: FieldMeta, + max_items: int | None = None, + min_items: int | None = None, +) -> Mapping[Any, Any]: + """Generate a constrained mapping. + + :param factory: A factory. + :param field_meta: A field meta instance. + :param max_items: Maximal number of items. + :param min_items: Minimal number of items. + + :returns: A mapping instance. + """ + min_items = abs(min_items if min_items is not None else (max_items or 0)) + max_items = abs(max_items if max_items is not None else min_items + 1) + + if max_items < min_items: + msg = "max_items must be larger or equal to min_items" + raise ParameterException(msg) + + length = factory.__random__.randint(min_items, max_items) or 1 + + collection: dict[Any, Any] = {} + + children = cast(List[FieldMeta], field_meta.children) + key_field_meta = children[0] + value_field_meta = children[1] + while len(collection) < length: + key = factory.get_field_value(key_field_meta) + value = factory.get_field_value(value_field_meta) + collection[key] = value + + return collection diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/constrained_dates.py b/venv/lib/python3.11/site-packages/polyfactory/value_generators/constrained_dates.py new file mode 100644 index 0000000..4e92601 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/constrained_dates.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from datetime import date, datetime, timedelta, timezone, tzinfo +from typing import TYPE_CHECKING, cast + +if TYPE_CHECKING: + from faker import Faker + + +def handle_constrained_date( + faker: Faker, + ge: date | None = None, + gt: date | None = None, + le: date | None = None, + lt: date | None = None, + tz: tzinfo = timezone.utc, +) -> date: + """Generates a date value fulfilling the expected constraints. + + :param faker: An instance of faker. + :param lt: Less than value. + :param le: Less than or equal value. + :param gt: Greater than value. + :param ge: Greater than or equal value. + :param tz: A timezone. + + :returns: A date instance. + """ + start_date = datetime.now(tz=tz).date() - timedelta(days=100) + if ge: + start_date = ge + elif gt: + start_date = gt + timedelta(days=1) + + end_date = datetime.now(tz=timezone.utc).date() + timedelta(days=100) + if le: + end_date = le + elif lt: + end_date = lt - timedelta(days=1) + + return cast("date", faker.date_between(start_date=start_date, end_date=end_date)) diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/constrained_numbers.py b/venv/lib/python3.11/site-packages/polyfactory/value_generators/constrained_numbers.py new file mode 100644 index 0000000..23516ce --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/constrained_numbers.py @@ -0,0 +1,440 @@ +from __future__ import annotations + +from decimal import Decimal +from sys import float_info +from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast + +from polyfactory.exceptions import ParameterException +from polyfactory.value_generators.primitives import create_random_decimal, create_random_float, create_random_integer + +if TYPE_CHECKING: + from random import Random + +T = TypeVar("T", Decimal, int, float) + + +class NumberGeneratorProtocol(Protocol[T]): + """Protocol for custom callables used to generate numerical values""" + + def __call__(self, random: "Random", minimum: T | None = None, maximum: T | None = None) -> T: + """Signature of the callable. + + :param random: An instance of random. + :param minimum: A minimum value. + :param maximum: A maximum value. + :return: The generated numeric value. + """ + ... + + +def almost_equal_floats(value_1: float, value_2: float, *, delta: float = 1e-8) -> bool: + """Return True if two floats are almost equal + + :param value_1: A float value. + :param value_2: A float value. + :param delta: A minimal delta. + + :returns: Boolean dictating whether the floats can be considered equal - given python's problematic comparison of floats. + """ + return abs(value_1 - value_2) <= delta + + +def is_multiply_of_multiple_of_in_range( + minimum: T, + maximum: T, + multiple_of: T, +) -> bool: + """Determine if at least one multiply of `multiple_of` lies in the given range. + + :param minimum: T: A minimum value. + :param maximum: T: A maximum value. + :param multiple_of: T: A value to use as a base for multiplication. + + :returns: Boolean dictating whether at least one multiply of `multiple_of` lies in the given range between minimum and maximum. + """ + + # if the range has infinity on one of its ends then infinite number of multipliers + # can be found within the range + + # if we were given floats and multiple_of is really close to zero then it doesn't make sense + # to continue trying to check the range + if ( + isinstance(minimum, float) + and isinstance(multiple_of, float) + and minimum / multiple_of in [float("+inf"), float("-inf")] + ): + return False + + multiplier = round(minimum / multiple_of) + step = 1 if multiple_of > 0 else -1 + + # since rounding can go either up or down we may end up in a situation when + # minimum is less or equal to `multiplier * multiple_of` + # or when it is greater than `multiplier * multiple_of` + # (in this case minimum is less than `(multiplier + 1)* multiple_of`). So we need to check + # that any of two values is inside the given range. ASCII graphic below explain this + # + # minimum + # -----------------+-------+-----------------------------------+---------------------------- + # + # + # minimum + # -------------------------+--------+--------------------------+---------------------------- + # + # since `multiple_of` can be a negative number adding +1 to `multiplier` drives `(multiplier + 1) * multiple_of`` + # away from `minimum` to the -infinity. It looks like this: + # minimum + # -----------------------+--------------------------------+------------------------+-------- + # + # so for negative `multiple_of` we want to subtract 1 from multiplier + for multiply in [multiplier * multiple_of, (multiplier + step) * multiple_of]: + multiply_float = float(multiply) + if ( + almost_equal_floats(multiply_float, float(minimum)) + or almost_equal_floats(multiply_float, float(maximum)) + or minimum < multiply < maximum + ): + return True + + return False + + +def passes_pydantic_multiple_validator(value: T, multiple_of: T) -> bool: + """Determine whether a given value passes the pydantic multiple_of validation. + + :param value: A numeric value. + :param multiple_of: Another numeric value. + + :returns: Boolean dictating whether value is a multiple of value. + + """ + if multiple_of == 0: + return True + mod = float(value) / float(multiple_of) % 1 + return almost_equal_floats(mod, 0.0) or almost_equal_floats(mod, 1.0) + + +def get_increment(t_type: type[T]) -> T: + """Get a small increment base to add to constrained values, i.e. lt/gt entries. + + :param t_type: A value of type T. + + :returns: An increment T. + """ + values: dict[Any, Any] = { + int: 1, + float: float_info.epsilon, + Decimal: Decimal("0.001"), + } + return cast("T", values[t_type]) + + +def get_value_or_none( + t_type: type[T], + lt: T | None = None, + le: T | None = None, + gt: T | None = None, + ge: T | None = None, +) -> tuple[T | None, T | None]: + """Return an optional value. + + :param equal_value: An GE/LE value. + :param constrained: An GT/LT value. + :param increment: increment + + :returns: Optional T. + """ + if ge is not None: + minimum_value = ge + elif gt is not None: + minimum_value = gt + get_increment(t_type) + else: + minimum_value = None + + if le is not None: + maximum_value = le + elif lt is not None: + maximum_value = lt - get_increment(t_type) + else: + maximum_value = None + return minimum_value, maximum_value + + +def get_constrained_number_range( + t_type: type[T], + random: Random, + lt: T | None = None, + le: T | None = None, + gt: T | None = None, + ge: T | None = None, + multiple_of: T | None = None, +) -> tuple[T | None, T | None]: + """Return the minimum and maximum values given a field_meta's constraints. + + :param t_type: A primitive constructor - int, float or Decimal. + :param random: An instance of Random. + :param lt: Less than value. + :param le: Less than or equal value. + :param gt: Greater than value. + :param ge: Greater than or equal value. + :param multiple_of: Multiple of value. + + :returns: a tuple of optional minimum and maximum values. + """ + seed = t_type(random.random() * 10) + minimum, maximum = get_value_or_none(lt=lt, le=le, gt=gt, ge=ge, t_type=t_type) + + if minimum is not None and maximum is not None and maximum < minimum: + msg = "maximum value must be greater than minimum value" + raise ParameterException(msg) + + if multiple_of is None: + if minimum is not None and maximum is None: + return ( + (minimum, seed) if minimum == 0 else (minimum, minimum + seed) + ) # pyright: ignore[reportGeneralTypeIssues] + if maximum is not None and minimum is None: + return maximum - seed, maximum + else: + if multiple_of == 0.0: # TODO: investigate @guacs # noqa: PLR2004, FIX002 + msg = "multiple_of can not be zero" + raise ParameterException(msg) + if ( + minimum is not None + and maximum is not None + and not is_multiply_of_multiple_of_in_range(minimum=minimum, maximum=maximum, multiple_of=multiple_of) + ): + msg = "given range should include at least one multiply of multiple_of" + raise ParameterException(msg) + + return minimum, maximum + + +def generate_constrained_number( + random: Random, + minimum: T | None, + maximum: T | None, + multiple_of: T | None, + method: "NumberGeneratorProtocol[T]", +) -> T: + """Generate a constrained number, output depends on the passed in callbacks. + + :param random: An instance of random. + :param minimum: A minimum value. + :param maximum: A maximum value. + :param multiple_of: A multiple of value. + :param method: A function that generates numbers of type T. + + :returns: A value of type T. + """ + if minimum is None or maximum is None: + return multiple_of if multiple_of is not None else method(random=random) + if multiple_of is None: + return method(random=random, minimum=minimum, maximum=maximum) + if multiple_of >= minimum: + return multiple_of + result = minimum + while not passes_pydantic_multiple_validator(result, multiple_of): + result = round(method(random=random, minimum=minimum, maximum=maximum) / multiple_of) * multiple_of + return result + + +def handle_constrained_int( + random: Random, + multiple_of: int | None = None, + gt: int | None = None, + ge: int | None = None, + lt: int | None = None, + le: int | None = None, +) -> int: + """Handle constrained integers. + + :param random: An instance of Random. + :param lt: Less than value. + :param le: Less than or equal value. + :param gt: Greater than value. + :param ge: Greater than or equal value. + :param multiple_of: Multiple of value. + + :returns: An integer. + + """ + + minimum, maximum = get_constrained_number_range( + gt=gt, + ge=ge, + lt=lt, + le=le, + t_type=int, + multiple_of=multiple_of, + random=random, + ) + return generate_constrained_number( + random=random, + minimum=minimum, + maximum=maximum, + multiple_of=multiple_of, + method=create_random_integer, + ) + + +def handle_constrained_float( + random: Random, + multiple_of: float | None = None, + gt: float | None = None, + ge: float | None = None, + lt: float | None = None, + le: float | None = None, +) -> float: + """Handle constrained floats. + + :param random: An instance of Random. + :param lt: Less than value. + :param le: Less than or equal value. + :param gt: Greater than value. + :param ge: Greater than or equal value. + :param multiple_of: Multiple of value. + + :returns: A float. + """ + + minimum, maximum = get_constrained_number_range( + gt=gt, + ge=ge, + lt=lt, + le=le, + t_type=float, + multiple_of=multiple_of, + random=random, + ) + + return generate_constrained_number( + random=random, + minimum=minimum, + maximum=maximum, + multiple_of=multiple_of, + method=create_random_float, + ) + + +def validate_max_digits( + max_digits: int, + minimum: Decimal | None, + decimal_places: int | None, +) -> None: + """Validate that max digits is greater than minimum and decimal places. + + :param max_digits: The maximal number of digits for the decimal. + :param minimum: Minimal value. + :param decimal_places: Number of decimal places + + :returns: 'None' + + """ + if max_digits <= 0: + msg = "max_digits must be greater than 0" + raise ParameterException(msg) + + if minimum is not None: + min_str = str(minimum).split(".")[1] if "." in str(minimum) else str(minimum) + + if max_digits <= len(min_str): + msg = "minimum is greater than max_digits" + raise ParameterException(msg) + + if decimal_places is not None and max_digits <= decimal_places: + msg = "max_digits must be greater than decimal places" + raise ParameterException(msg) + + +def handle_decimal_length( + generated_decimal: Decimal, + decimal_places: int | None, + max_digits: int | None, +) -> Decimal: + """Handle the length of the decimal. + + :param generated_decimal: A decimal value. + :param decimal_places: Number of decimal places. + :param max_digits: Maximal number of digits. + + """ + string_number = str(generated_decimal) + sign = "-" if "-" in string_number else "+" + string_number = string_number.replace("-", "") + whole_numbers, decimals = string_number.split(".") + + if ( + max_digits is not None + and decimal_places is not None + and len(whole_numbers) + decimal_places > max_digits + or (max_digits is None or decimal_places is None) + and max_digits is not None + ): + max_decimals = max_digits - len(whole_numbers) + elif max_digits is not None: + max_decimals = decimal_places # type: ignore[assignment] + else: + max_decimals = cast("int", decimal_places) + + if max_decimals < 0: # pyright: ignore[reportOptionalOperand] + return Decimal(sign + whole_numbers[:max_decimals]) + + decimals = decimals[:max_decimals] + return Decimal(sign + whole_numbers + "." + decimals[:decimal_places]) + + +def handle_constrained_decimal( + random: Random, + multiple_of: Decimal | None = None, + decimal_places: int | None = None, + max_digits: int | None = None, + gt: Decimal | None = None, + ge: Decimal | None = None, + lt: Decimal | None = None, + le: Decimal | None = None, +) -> Decimal: + """Handle a constrained decimal. + + :param random: An instance of Random. + :param multiple_of: Multiple of value. + :param decimal_places: Number of decimal places. + :param max_digits: Maximal number of digits. + :param lt: Less than value. + :param le: Less than or equal value. + :param gt: Greater than value. + :param ge: Greater than or equal value. + + :returns: A decimal. + + """ + + minimum, maximum = get_constrained_number_range( + gt=gt, + ge=ge, + lt=lt, + le=le, + multiple_of=multiple_of, + t_type=Decimal, + random=random, + ) + + if max_digits is not None: + validate_max_digits(max_digits=max_digits, minimum=minimum, decimal_places=decimal_places) + + generated_decimal = generate_constrained_number( + random=random, + minimum=minimum, + maximum=maximum, + multiple_of=multiple_of, + method=create_random_decimal, + ) + + if max_digits is not None or decimal_places is not None: + return handle_decimal_length( + generated_decimal=generated_decimal, + max_digits=max_digits, + decimal_places=decimal_places, + ) + + return generated_decimal diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/constrained_path.py b/venv/lib/python3.11/site-packages/polyfactory/value_generators/constrained_path.py new file mode 100644 index 0000000..debaf86 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/constrained_path.py @@ -0,0 +1,13 @@ +from os.path import realpath +from pathlib import Path +from typing import Literal, cast + +from faker import Faker + + +def handle_constrained_path(constraint: Literal["file", "dir", "new"], faker: Faker) -> Path: + if constraint == "new": + return cast("Path", faker.file_path(depth=1, category=None, extension=None)) + if constraint == "file": + return Path(realpath(__file__)) + return Path(realpath(__file__)).parent diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/constrained_strings.py b/venv/lib/python3.11/site-packages/polyfactory/value_generators/constrained_strings.py new file mode 100644 index 0000000..c7da72b --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/constrained_strings.py @@ -0,0 +1,138 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable, Pattern, TypeVar, Union, cast + +from polyfactory.exceptions import ParameterException +from polyfactory.value_generators.primitives import create_random_bytes, create_random_string +from polyfactory.value_generators.regex import RegexFactory + +T = TypeVar("T", bound=Union[bytes, str]) + +if TYPE_CHECKING: + from random import Random + + +def _validate_length( + min_length: int | None = None, + max_length: int | None = None, +) -> None: + """Validate the length parameters make sense. + + :param min_length: Minimum length. + :param max_length: Maximum length. + + :raises: ParameterException. + + :returns: None. + """ + if min_length is not None and min_length < 0: + msg = "min_length must be greater or equal to 0" + raise ParameterException(msg) + + if max_length is not None and max_length < 0: + msg = "max_length must be greater or equal to 0" + raise ParameterException(msg) + + if max_length is not None and min_length is not None and max_length < min_length: + msg = "max_length must be greater than min_length" + raise ParameterException(msg) + + +def _generate_pattern( + random: Random, + pattern: str | Pattern, + lower_case: bool = False, + upper_case: bool = False, + min_length: int | None = None, + max_length: int | None = None, +) -> str: + """Generate a regex. + + :param random: An instance of random. + :param pattern: A regex or string pattern. + :param lower_case: Whether to lowercase the result. + :param upper_case: Whether to uppercase the result. + :param min_length: A minimum length. + :param max_length: A maximum length. + + :returns: A string matching the given pattern. + """ + regex_factory = RegexFactory(random=random) + result = regex_factory(pattern) + if min_length: + while len(result) < min_length: + result += regex_factory(pattern) + + if max_length is not None and len(result) > max_length: + result = result[:max_length] + + if lower_case: + result = result.lower() + + if upper_case: + result = result.upper() + + return result + + +def handle_constrained_string_or_bytes( + random: Random, + t_type: Callable[[], T], + lower_case: bool = False, + upper_case: bool = False, + min_length: int | None = None, + max_length: int | None = None, + pattern: str | Pattern | None = None, +) -> T: + """Handle constrained string or bytes, for example - pydantic `constr` or `conbytes`. + + :param random: An instance of random. + :param t_type: A type (str or bytes) + :param lower_case: Whether to lowercase the result. + :param upper_case: Whether to uppercase the result. + :param min_length: A minimum length. + :param max_length: A maximum length. + :param pattern: A regex or string pattern. + + :returns: A value of type T. + """ + _validate_length(min_length=min_length, max_length=max_length) + + if max_length == 0: + return t_type() + + if pattern: + return cast( + "T", + _generate_pattern( + random=random, + pattern=pattern, + lower_case=lower_case, + upper_case=upper_case, + min_length=min_length, + max_length=max_length, + ), + ) + + if t_type is str: + return cast( + "T", + create_random_string( + min_length=min_length, + max_length=max_length, + lower_case=lower_case, + upper_case=upper_case, + random=random, + ), + ) + + return cast( + "T", + create_random_bytes( + min_length=min_length, + max_length=max_length, + lower_case=lower_case, + upper_case=upper_case, + random=random, + ), + ) diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/constrained_url.py b/venv/lib/python3.11/site-packages/polyfactory/value_generators/constrained_url.py new file mode 100644 index 0000000..d29555e --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/constrained_url.py @@ -0,0 +1,10 @@ +from polyfactory.field_meta import UrlConstraints + + +def handle_constrained_url(constraints: UrlConstraints) -> str: + schema = (constraints.get("allowed_schemes") or ["http", "https"])[0] + default_host = constraints.get("default_host") or "localhost" + default_port = constraints.get("default_port") or 80 + default_path = constraints.get("default_path") or "" + + return f"{schema}://{default_host}:{default_port}{default_path}" diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/constrained_uuid.py b/venv/lib/python3.11/site-packages/polyfactory/value_generators/constrained_uuid.py new file mode 100644 index 0000000..053f047 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/constrained_uuid.py @@ -0,0 +1,31 @@ +from typing import Literal, cast +from uuid import NAMESPACE_DNS, UUID, uuid1, uuid3, uuid5 + +from faker import Faker + +UUID_VERSION_1 = 1 +UUID_VERSION_3 = 3 +UUID_VERSION_4 = 4 +UUID_VERSION_5 = 5 + + +def handle_constrained_uuid(uuid_version: Literal[1, 3, 4, 5], faker: Faker) -> UUID: + """Generate a UUID based on the version specified. + + Args: + uuid_version: The version of the UUID to generate. + faker: The Faker instance to use. + + Returns: + The generated UUID. + """ + if uuid_version == UUID_VERSION_1: + return uuid1() + if uuid_version == UUID_VERSION_3: + return uuid3(NAMESPACE_DNS, faker.pystr()) + if uuid_version == UUID_VERSION_4: + return cast("UUID", faker.uuid4()) + if uuid_version == UUID_VERSION_5: + return uuid5(NAMESPACE_DNS, faker.pystr()) + msg = f"Unknown UUID version: {uuid_version}" + raise ValueError(msg) diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/primitives.py b/venv/lib/python3.11/site-packages/polyfactory/value_generators/primitives.py new file mode 100644 index 0000000..2cf6b41 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/primitives.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +from binascii import hexlify +from decimal import Decimal +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from random import Random + + +def create_random_float( + random: Random, + minimum: Decimal | float | None = None, + maximum: Decimal | float | None = None, +) -> float: + """Generate a random float given the constraints. + + :param random: An instance of random. + :param minimum: A minimum value + :param maximum: A maximum value. + + :returns: A random float. + """ + if minimum is None: + minimum = float(random.randint(0, 100)) if maximum is None else float(maximum) - 100.0 + if maximum is None: + maximum = float(minimum) + 1.0 * 2.0 if minimum >= 0 else float(minimum) + 1.0 / 2.0 + return random.uniform(float(minimum), float(maximum)) + + +def create_random_integer(random: Random, minimum: int | None = None, maximum: int | None = None) -> int: + """Generate a random int given the constraints. + + :param random: An instance of random. + :param minimum: A minimum value + :param maximum: A maximum value. + + :returns: A random integer. + """ + return round(create_random_float(random=random, minimum=minimum, maximum=maximum)) + + +def create_random_decimal( + random: Random, + minimum: Decimal | None = None, + maximum: Decimal | None = None, +) -> Decimal: + """Generate a random Decimal given the constraints. + + :param random: An instance of random. + :param minimum: A minimum value + :param maximum: A maximum value. + + :returns: A random decimal. + """ + return Decimal(str(create_random_float(random=random, minimum=minimum, maximum=maximum))) + + +def create_random_bytes( + random: Random, + min_length: int | None = None, + max_length: int | None = None, + lower_case: bool = False, + upper_case: bool = False, +) -> bytes: + """Generate a random bytes given the constraints. + + :param random: An instance of random. + :param min_length: A minimum length. + :param max_length: A maximum length. + :param lower_case: Whether to lowercase the result. + :param upper_case: Whether to uppercase the result. + + :returns: A random byte-string. + """ + if min_length is None: + min_length = 0 + if max_length is None: + max_length = min_length + 1 * 2 + + length = random.randint(min_length, max_length) + result = b"" if length == 0 else hexlify(random.getrandbits(length * 8).to_bytes(length, "little")) + + if lower_case: + result = result.lower() + elif upper_case: + result = result.upper() + + if max_length and len(result) > max_length: + end = random.randint(min_length or 0, max_length) + return result[:end] + + return result + + +def create_random_string( + random: Random, + min_length: int | None = None, + max_length: int | None = None, + lower_case: bool = False, + upper_case: bool = False, +) -> str: + """Generate a random string given the constraints. + + :param random: An instance of random. + :param min_length: A minimum length. + :param max_length: A maximum length. + :param lower_case: Whether to lowercase the result. + :param upper_case: Whether to uppercase the result. + + :returns: A random string. + """ + return create_random_bytes( + random=random, + min_length=min_length, + max_length=max_length, + lower_case=lower_case, + upper_case=upper_case, + ).decode("utf-8") + + +def create_random_boolean(random: Random) -> bool: + """Generate a random boolean value. + + :param random: An instance of random. + + :returns: A random boolean. + """ + return bool(random.getrandbits(1)) diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/regex.py b/venv/lib/python3.11/site-packages/polyfactory/value_generators/regex.py new file mode 100644 index 0000000..eab9bd0 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/regex.py @@ -0,0 +1,150 @@ +"""The code in this files is adapted from https://github.com/crdoconnor/xeger/blob/master/xeger/xeger.py.Which in turn +adapted it from https://bitbucket.org/leapfrogdevelopment/rstr/. + +Copyright (C) 2015, Colm O'Connor +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the Leapfrog Direct Response, LLC, including + its subsidiaries and affiliates nor the names of its + contributors, may be used to endorse or promote products derived + from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL LEAPFROG DIRECT +RESPONSE, LLC, INCLUDING ITS SUBSIDIARIES AND AFFILIATES, BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR +BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE +OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN +IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" +from __future__ import annotations + +from itertools import chain +from string import ( + ascii_letters, + ascii_lowercase, + ascii_uppercase, + digits, + printable, + punctuation, + whitespace, +) +from typing import TYPE_CHECKING, Any, Pattern + +try: # >=3.11 + from re._parser import SubPattern, parse +except ImportError: # < 3.11 + from sre_parse import SubPattern, parse # pylint: disable=deprecated-module + +if TYPE_CHECKING: + from random import Random + +_alphabets = { + "printable": printable, + "letters": ascii_letters, + "uppercase": ascii_uppercase, + "lowercase": ascii_lowercase, + "digits": digits, + "punctuation": punctuation, + "nondigits": ascii_letters + punctuation, + "nonletters": digits + punctuation, + "whitespace": whitespace, + "nonwhitespace": printable.strip(), + "normal": ascii_letters + digits + " ", + "word": ascii_letters + digits + "_", + "nonword": "".join(set(printable).difference(ascii_letters + digits + "_")), + "postalsafe": ascii_letters + digits + " .-#/", + "urlsafe": ascii_letters + digits + "-._~", + "domainsafe": ascii_letters + digits + "-", +} + +_categories = { + "category_digit": _alphabets["digits"], + "category_not_digit": _alphabets["nondigits"], + "category_space": _alphabets["whitespace"], + "category_not_space": _alphabets["nonwhitespace"], + "category_word": _alphabets["word"], + "category_not_word": _alphabets["nonword"], +} + + +class RegexFactory: + """Factory for regexes.""" + + def __init__(self, random: Random, limit: int = 10) -> None: + """Create a RegexFactory""" + self._limit = limit + self._cache: dict[str, Any] = {} + self._random = random + + self._cases = { + "literal": chr, + "not_literal": lambda x: self._random.choice(printable.replace(chr(x), "")), + "at": lambda x: "", + "in": self._handle_in, + "any": lambda x: self._random.choice(printable.replace("\n", "")), + "range": lambda x: [chr(i) for i in range(x[0], x[1] + 1)], + "category": lambda x: _categories[str(x).lower()], + "branch": lambda x: "".join(self._handle_state(i) for i in self._random.choice(x[1])), + "subpattern": self._handle_group, + "assert": lambda x: "".join(self._handle_state(i) for i in x[1]), + "assert_not": lambda x: "", + "groupref": lambda x: self._cache[x], + "min_repeat": lambda x: self._handle_repeat(*x), + "max_repeat": lambda x: self._handle_repeat(*x), + "negate": lambda x: [False], + } + + def __call__(self, string_or_regex: str | Pattern) -> str: + """Generate a string matching a regex. + + :param string_or_regex: A string or pattern. + + :return: The generated string. + """ + pattern = string_or_regex.pattern if isinstance(string_or_regex, Pattern) else string_or_regex + parsed = parse(pattern) + result = self._build_string(parsed) # pyright: ignore[reportGeneralTypeIssues] + self._cache.clear() + return result + + def _build_string(self, parsed: SubPattern) -> str: + return "".join([self._handle_state(state) for state in parsed]) # pyright:ignore[reportGeneralTypeIssues] + + def _handle_state(self, state: tuple[SubPattern, tuple[Any, ...]]) -> Any: + opcode, value = state + return self._cases[str(opcode).lower()](value) # type: ignore[no-untyped-call] + + def _handle_group(self, value: tuple[Any, ...]) -> str: + result = "".join(self._handle_state(i) for i in value[3]) + if value[0]: + self._cache[value[0]] = result + return result + + def _handle_in(self, value: tuple[Any, ...]) -> Any: + candidates = list(chain(*(self._handle_state(i) for i in value))) + if candidates and candidates[0] is False: + candidates = list(set(printable).difference(candidates[1:])) + return self._random.choice(candidates) + return self._random.choice(candidates) + + def _handle_repeat(self, start_range: int, end_range: Any, value: SubPattern) -> str: + end_range = min(end_range, self._limit) + + result = [ + "".join(self._handle_state(v) for v in list(value)) # pyright: ignore[reportGeneralTypeIssues] + for _ in range(self._random.randint(start_range, max(start_range, end_range))) + ] + return "".join(result) |