diff options
| author | cyfraeviolae <cyfraeviolae> | 2024-04-03 03:10:44 -0400 | 
|---|---|---|
| committer | cyfraeviolae <cyfraeviolae> | 2024-04-03 03:10:44 -0400 | 
| commit | 6d7ba58f880be618ade07f8ea080fe8c4bf8a896 (patch) | |
| tree | b1c931051ffcebd2bd9d61d98d6233ffa289bbce /venv/lib/python3.11/site-packages/polyfactory/factories | |
| parent | 4f884c9abc32990b4061a1bb6997b4b37e58ea0b (diff) | |
venv
Diffstat (limited to 'venv/lib/python3.11/site-packages/polyfactory/factories')
20 files changed, 2292 insertions, 0 deletions
| diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__init__.py b/venv/lib/python3.11/site-packages/polyfactory/factories/__init__.py new file mode 100644 index 0000000..c8a9b92 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__init__.py @@ -0,0 +1,5 @@ +from polyfactory.factories.base import BaseFactory +from polyfactory.factories.dataclass_factory import DataclassFactory +from polyfactory.factories.typed_dict_factory import TypedDictFactory + +__all__ = ("BaseFactory", "TypedDictFactory", "DataclassFactory") diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/__init__.cpython-311.pycBinary files differ new file mode 100644 index 0000000..0ebad18 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/attrs_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/attrs_factory.cpython-311.pycBinary files differ new file mode 100644 index 0000000..b5ca945 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/attrs_factory.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/base.cpython-311.pycBinary files differ new file mode 100644 index 0000000..2ee4cdb --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/base.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/beanie_odm_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/beanie_odm_factory.cpython-311.pycBinary files differ new file mode 100644 index 0000000..0fc27ab --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/beanie_odm_factory.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/dataclass_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/dataclass_factory.cpython-311.pycBinary files differ new file mode 100644 index 0000000..6e59c83 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/dataclass_factory.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/msgspec_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/msgspec_factory.cpython-311.pycBinary files differ new file mode 100644 index 0000000..5e528d3 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/msgspec_factory.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/odmantic_odm_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/odmantic_odm_factory.cpython-311.pycBinary files differ new file mode 100644 index 0000000..6f45693 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/odmantic_odm_factory.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/pydantic_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/pydantic_factory.cpython-311.pycBinary files differ new file mode 100644 index 0000000..ca70ca8 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/pydantic_factory.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/sqlalchemy_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/sqlalchemy_factory.cpython-311.pycBinary files differ new file mode 100644 index 0000000..31d93d5 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/sqlalchemy_factory.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/typed_dict_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/typed_dict_factory.cpython-311.pycBinary files differ new file mode 100644 index 0000000..a1ec583 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/typed_dict_factory.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/attrs_factory.py b/venv/lib/python3.11/site-packages/polyfactory/factories/attrs_factory.py new file mode 100644 index 0000000..00ffa03 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/attrs_factory.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +from inspect import isclass +from typing import TYPE_CHECKING, Generic, TypeVar + +from polyfactory.exceptions import MissingDependencyException +from polyfactory.factories.base import BaseFactory +from polyfactory.field_meta import FieldMeta, Null + +if TYPE_CHECKING: +    from typing import Any, TypeGuard + + +try: +    import attrs +    from attr._make import Factory +    from attrs import AttrsInstance +except ImportError as ex: +    msg = "attrs is not installed" +    raise MissingDependencyException(msg) from ex + + +T = TypeVar("T", bound=AttrsInstance) + + +class AttrsFactory(Generic[T], BaseFactory[T]): +    """Base factory for attrs classes.""" + +    __model__: type[T] + +    __is_base_factory__ = True + +    @classmethod +    def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]: +        return isclass(value) and hasattr(value, "__attrs_attrs__") + +    @classmethod +    def get_model_fields(cls) -> list[FieldMeta]: +        field_metas: list[FieldMeta] = [] +        none_type = type(None) + +        cls.resolve_types(cls.__model__) +        fields = attrs.fields(cls.__model__) + +        for field in fields: +            if not field.init: +                continue + +            annotation = none_type if field.type is None else field.type + +            default = field.default +            if isinstance(default, Factory): +                # The default value is not currently being used when generating +                # the field values. When that is implemented, this would need +                # to be handled differently since the `default.factory` could +                # take a `self` argument. +                default_value = default.factory +            elif default is None: +                default_value = Null +            else: +                default_value = default + +            field_metas.append( +                FieldMeta.from_type( +                    annotation=annotation, +                    name=field.alias, +                    default=default_value, +                    random=cls.__random__, +                ), +            ) + +        return field_metas + +    @classmethod +    def resolve_types(cls, model: type[T], **kwargs: Any) -> None: +        """Resolve any strings and forward annotations in type annotations. + +        :param model: The model to resolve the type annotations for. +        :param kwargs: Any parameters that need to be passed to `attrs.resolve_types`. +        """ + +        attrs.resolve_types(model, **kwargs) diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/base.py b/venv/lib/python3.11/site-packages/polyfactory/factories/base.py new file mode 100644 index 0000000..60fe7a7 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/base.py @@ -0,0 +1,1127 @@ +from __future__ import annotations + +import copy +from abc import ABC, abstractmethod +from collections import Counter, abc, deque +from contextlib import suppress +from datetime import date, datetime, time, timedelta +from decimal import Decimal +from enum import EnumMeta +from functools import partial +from importlib import import_module +from ipaddress import ( +    IPv4Address, +    IPv4Interface, +    IPv4Network, +    IPv6Address, +    IPv6Interface, +    IPv6Network, +    ip_address, +    ip_interface, +    ip_network, +) +from os.path import realpath +from pathlib import Path +from random import Random +from typing import ( +    TYPE_CHECKING, +    Any, +    Callable, +    ClassVar, +    Collection, +    Generic, +    Iterable, +    Mapping, +    Sequence, +    Type, +    TypedDict, +    TypeVar, +    cast, +) +from uuid import UUID + +from faker import Faker +from typing_extensions import get_args, get_origin, get_original_bases + +from polyfactory.constants import ( +    DEFAULT_RANDOM, +    MAX_COLLECTION_LENGTH, +    MIN_COLLECTION_LENGTH, +    RANDOMIZE_COLLECTION_LENGTH, +) +from polyfactory.exceptions import ConfigurationException, MissingBuildKwargException, ParameterException +from polyfactory.field_meta import Null +from polyfactory.fields import Fixture, Ignore, PostGenerated, Require, Use +from polyfactory.utils.helpers import ( +    flatten_annotation, +    get_collection_type, +    unwrap_annotation, +    unwrap_args, +    unwrap_optional, +) +from polyfactory.utils.model_coverage import CoverageContainer, CoverageContainerCallable, resolve_kwargs_coverage +from polyfactory.utils.predicates import get_type_origin, is_any, is_literal, is_optional, is_safe_subclass, is_union +from polyfactory.utils.types import NoneType +from polyfactory.value_generators.complex_types import handle_collection_type, handle_collection_type_coverage +from polyfactory.value_generators.constrained_collections import ( +    handle_constrained_collection, +    handle_constrained_mapping, +) +from polyfactory.value_generators.constrained_dates import handle_constrained_date +from polyfactory.value_generators.constrained_numbers import ( +    handle_constrained_decimal, +    handle_constrained_float, +    handle_constrained_int, +) +from polyfactory.value_generators.constrained_path import handle_constrained_path +from polyfactory.value_generators.constrained_strings import handle_constrained_string_or_bytes +from polyfactory.value_generators.constrained_url import handle_constrained_url +from polyfactory.value_generators.constrained_uuid import handle_constrained_uuid +from polyfactory.value_generators.primitives import create_random_boolean, create_random_bytes, create_random_string + +if TYPE_CHECKING: +    from typing_extensions import TypeGuard + +    from polyfactory.field_meta import Constraints, FieldMeta +    from polyfactory.persistence import AsyncPersistenceProtocol, SyncPersistenceProtocol + + +T = TypeVar("T") +F = TypeVar("F", bound="BaseFactory[Any]") + + +class BuildContext(TypedDict): +    seen_models: set[type] + + +def _get_build_context(build_context: BuildContext | None) -> BuildContext: +    if build_context is None: +        return {"seen_models": set()} + +    return copy.deepcopy(build_context) + + +class BaseFactory(ABC, Generic[T]): +    """Base Factory class - this class holds the main logic of the library""" + +    # configuration attributes +    __model__: type[T] +    """ +    The model for the factory. +    This attribute is required for non-base factories and an exception will be raised if it's not set. Can be automatically inferred from the factory generic argument. +    """ +    __check_model__: bool = False +    """ +    Flag dictating whether to check if fields defined on the factory exists on the model or not. +    If 'True', checks will be done against Use, PostGenerated, Ignore, Require constructs fields only. +    """ +    __allow_none_optionals__: ClassVar[bool] = True +    """ +    Flag dictating whether to allow 'None' for optional values. +    If 'True', 'None' will be randomly generated as a value for optional model fields +    """ +    __sync_persistence__: type[SyncPersistenceProtocol[T]] | SyncPersistenceProtocol[T] | None = None +    """A sync persistence handler. Can be a class or a class instance.""" +    __async_persistence__: type[AsyncPersistenceProtocol[T]] | AsyncPersistenceProtocol[T] | None = None +    """An async persistence handler. Can be a class or a class instance.""" +    __set_as_default_factory_for_type__ = False +    """ +    Flag dictating whether to set as the default factory for the given type. +    If 'True' the factory will be used instead of dynamically generating a factory for the type. +    """ +    __is_base_factory__: bool = False +    """ +    Flag dictating whether the factory is a 'base' factory. Base factories are registered globally as handlers for types. +    For example, the 'DataclassFactory', 'TypedDictFactory' and 'ModelFactory' are all base factories. +    """ +    __base_factory_overrides__: dict[Any, type[BaseFactory[Any]]] | None = None +    """ +    A base factory to override with this factory. If this value is set, the given factory will replace the given base factory. + +    Note: this value can only be set when '__is_base_factory__' is 'True'. +    """ +    __faker__: ClassVar["Faker"] = Faker() +    """ +    A faker instance to use. Can be a user provided value. +    """ +    __random__: ClassVar["Random"] = DEFAULT_RANDOM +    """ +    An instance of 'random.Random' to use. +    """ +    __random_seed__: ClassVar[int] +    """ +    An integer to seed the factory's Faker and Random instances with. +    This attribute can be used to control random generation. +    """ +    __randomize_collection_length__: ClassVar[bool] = RANDOMIZE_COLLECTION_LENGTH +    """ +    Flag dictating whether to randomize collections lengths. +    """ +    __min_collection_length__: ClassVar[int] = MIN_COLLECTION_LENGTH +    """ +    An integer value that defines minimum length of a collection. +    """ +    __max_collection_length__: ClassVar[int] = MAX_COLLECTION_LENGTH +    """ +    An integer value that defines maximum length of a collection. +    """ +    __use_defaults__: ClassVar[bool] = False +    """ +    Flag indicating whether to use the default value on a specific field, if provided. +    """ + +    __config_keys__: tuple[str, ...] = ( +        "__check_model__", +        "__allow_none_optionals__", +        "__set_as_default_factory_for_type__", +        "__faker__", +        "__random__", +        "__randomize_collection_length__", +        "__min_collection_length__", +        "__max_collection_length__", +        "__use_defaults__", +    ) +    """Keys to be considered as config values to pass on to dynamically created factories.""" + +    # cached attributes +    _fields_metadata: list[FieldMeta] +    # BaseFactory only attributes +    _factory_type_mapping: ClassVar[dict[Any, type[BaseFactory[Any]]]] +    _base_factories: ClassVar[list[type[BaseFactory[Any]]]] + +    # Non-public attributes +    _extra_providers: dict[Any, Callable[[], Any]] | None = None + +    def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None:  # noqa: C901 +        super().__init_subclass__(*args, **kwargs) + +        if not hasattr(BaseFactory, "_base_factories"): +            BaseFactory._base_factories = [] + +        if not hasattr(BaseFactory, "_factory_type_mapping"): +            BaseFactory._factory_type_mapping = {} + +        if cls.__min_collection_length__ > cls.__max_collection_length__: +            msg = "Minimum collection length shouldn't be greater than maximum collection length" +            raise ConfigurationException( +                msg, +            ) + +        if "__is_base_factory__" not in cls.__dict__ or not cls.__is_base_factory__: +            model: type[T] | None = getattr(cls, "__model__", None) or cls._infer_model_type() +            if not model: +                msg = f"required configuration attribute '__model__' is not set on {cls.__name__}" +                raise ConfigurationException( +                    msg, +                ) +            cls.__model__ = model +            if not cls.is_supported_type(model): +                for factory in BaseFactory._base_factories: +                    if factory.is_supported_type(model): +                        msg = f"{cls.__name__} does not support {model.__name__}, but this type is supported by the {factory.__name__} base factory class. To resolve this error, subclass the factory from {factory.__name__} instead of {cls.__name__}" +                        raise ConfigurationException( +                            msg, +                        ) +                    msg = f"Model type {model.__name__} is not supported. To support it, register an appropriate base factory and subclass it for your factory." +                    raise ConfigurationException( +                        msg, +                    ) +            if cls.__check_model__: +                cls._check_declared_fields_exist_in_model() +        else: +            BaseFactory._base_factories.append(cls) + +        random_seed = getattr(cls, "__random_seed__", None) +        if random_seed is not None: +            cls.seed_random(random_seed) + +        if cls.__set_as_default_factory_for_type__ and hasattr(cls, "__model__"): +            BaseFactory._factory_type_mapping[cls.__model__] = cls + +    @classmethod +    def _infer_model_type(cls: type[F]) -> type[T] | None: +        """Return model type inferred from class declaration. +        class Foo(ModelFactory[MyModel]):  # <<< MyModel +            ... + +        If more than one base class and/or generic arguments specified return None. + +        :returns: Inferred model type or None +        """ + +        factory_bases: Iterable[type[BaseFactory[T]]] = ( +            b for b in get_original_bases(cls) if get_origin(b) and issubclass(get_origin(b), BaseFactory) +        ) +        generic_args: Sequence[type[T]] = [ +            arg for factory_base in factory_bases for arg in get_args(factory_base) if not isinstance(arg, TypeVar) +        ] +        if len(generic_args) != 1: +            return None + +        return generic_args[0] + +    @classmethod +    def _get_sync_persistence(cls) -> SyncPersistenceProtocol[T]: +        """Return a SyncPersistenceHandler if defined for the factory, otherwise raises a ConfigurationException. + +        :raises: ConfigurationException +        :returns: SyncPersistenceHandler +        """ +        if cls.__sync_persistence__: +            return cls.__sync_persistence__() if callable(cls.__sync_persistence__) else cls.__sync_persistence__ +        msg = "A '__sync_persistence__' handler must be defined in the factory to use this method" +        raise ConfigurationException( +            msg, +        ) + +    @classmethod +    def _get_async_persistence(cls) -> AsyncPersistenceProtocol[T]: +        """Return a AsyncPersistenceHandler if defined for the factory, otherwise raises a ConfigurationException. + +        :raises: ConfigurationException +        :returns: AsyncPersistenceHandler +        """ +        if cls.__async_persistence__: +            return cls.__async_persistence__() if callable(cls.__async_persistence__) else cls.__async_persistence__ +        msg = "An '__async_persistence__' handler must be defined in the factory to use this method" +        raise ConfigurationException( +            msg, +        ) + +    @classmethod +    def _handle_factory_field( +        cls, +        field_value: Any, +        build_context: BuildContext, +        field_build_parameters: Any | None = None, +    ) -> Any: +        """Handle a value defined on the factory class itself. + +        :param field_value: A value defined as an attribute on the factory class. +        :param field_build_parameters: Any build parameters passed to the factory as kwarg values. + +        :returns: An arbitrary value correlating with the given field_meta value. +        """ +        if is_safe_subclass(field_value, BaseFactory): +            if isinstance(field_build_parameters, Mapping): +                return field_value.build(_build_context=build_context, **field_build_parameters) + +            if isinstance(field_build_parameters, Sequence): +                return [ +                    field_value.build(_build_context=build_context, **parameter) for parameter in field_build_parameters +                ] + +            return field_value.build(_build_context=build_context) + +        if isinstance(field_value, Use): +            return field_value.to_value() + +        if isinstance(field_value, Fixture): +            return field_value.to_value() + +        return field_value() if callable(field_value) else field_value + +    @classmethod +    def _handle_factory_field_coverage( +        cls, +        field_value: Any, +        field_build_parameters: Any | None = None, +        build_context: BuildContext | None = None, +    ) -> Any: +        """Handle a value defined on the factory class itself. + +        :param field_value: A value defined as an attribute on the factory class. +        :param field_build_parameters: Any build parameters passed to the factory as kwarg values. + +        :returns: An arbitrary value correlating with the given field_meta value. +        """ +        if is_safe_subclass(field_value, BaseFactory): +            if isinstance(field_build_parameters, Mapping): +                return CoverageContainer(field_value.coverage(_build_context=build_context, **field_build_parameters)) + +            if isinstance(field_build_parameters, Sequence): +                return [ +                    CoverageContainer(field_value.coverage(_build_context=build_context, **parameter)) +                    for parameter in field_build_parameters +                ] + +            return CoverageContainer(field_value.coverage()) + +        if isinstance(field_value, Use): +            return field_value.to_value() + +        if isinstance(field_value, Fixture): +            return CoverageContainerCallable(field_value.to_value) + +        return CoverageContainerCallable(field_value) if callable(field_value) else field_value + +    @classmethod +    def _get_config(cls) -> dict[str, Any]: +        return { +            **{key: getattr(cls, key) for key in cls.__config_keys__}, +            "_extra_providers": cls.get_provider_map(), +        } + +    @classmethod +    def _get_or_create_factory(cls, model: type) -> type[BaseFactory[Any]]: +        """Get a factory from registered factories or generate a factory dynamically. + +        :param model: A model type. +        :returns: A Factory sub-class. + +        """ +        if factory := BaseFactory._factory_type_mapping.get(model): +            return factory + +        config = cls._get_config() + +        if cls.__base_factory_overrides__: +            for model_ancestor in model.mro(): +                if factory := cls.__base_factory_overrides__.get(model_ancestor): +                    return factory.create_factory(model, **config) + +        for factory in reversed(BaseFactory._base_factories): +            if factory.is_supported_type(model): +                return factory.create_factory(model, **config) + +        msg = f"unsupported model type {model.__name__}" +        raise ParameterException(msg)  # pragma: no cover + +    # Public Methods + +    @classmethod +    def is_factory_type(cls, annotation: Any) -> bool: +        """Determine whether a given field is annotated with a type that is supported by a base factory. + +        :param annotation: A type annotation. +        :returns: Boolean dictating whether the annotation is a factory type +        """ +        return any(factory.is_supported_type(annotation) for factory in BaseFactory._base_factories) + +    @classmethod +    def is_batch_factory_type(cls, annotation: Any) -> bool: +        """Determine whether a given field is annotated with a sequence of supported factory types. + +        :param annotation: A type annotation. +        :returns: Boolean dictating whether the annotation is a batch factory type +        """ +        origin = get_type_origin(annotation) or annotation +        if is_safe_subclass(origin, Sequence) and (args := unwrap_args(annotation, random=cls.__random__)): +            return len(args) == 1 and BaseFactory.is_factory_type(annotation=args[0]) +        return False + +    @classmethod +    def extract_field_build_parameters(cls, field_meta: FieldMeta, build_args: dict[str, Any]) -> Any: +        """Extract from the build kwargs any build parameters passed for a given field meta - if it is a factory type. + +        :param field_meta: A field meta instance. +        :param build_args: Any kwargs passed to the factory. +        :returns: Any values +        """ +        if build_arg := build_args.get(field_meta.name): +            annotation = unwrap_optional(field_meta.annotation) +            if ( +                BaseFactory.is_factory_type(annotation=annotation) +                and isinstance(build_arg, Mapping) +                and not BaseFactory.is_factory_type(annotation=type(build_arg)) +            ): +                return build_args.pop(field_meta.name) + +            if ( +                BaseFactory.is_batch_factory_type(annotation=annotation) +                and isinstance(build_arg, Sequence) +                and not any(BaseFactory.is_factory_type(annotation=type(value)) for value in build_arg) +            ): +                return build_args.pop(field_meta.name) +        return None + +    @classmethod +    @abstractmethod +    def is_supported_type(cls, value: Any) -> "TypeGuard[type[T]]":  # pragma: no cover +        """Determine whether the given value is supported by the factory. + +        :param value: An arbitrary value. +        :returns: A typeguard +        """ +        raise NotImplementedError + +    @classmethod +    def seed_random(cls, seed: int) -> None: +        """Seed faker and random with the given integer. + +        :param seed: An integer to set as seed. +        :returns: 'None' + +        """ +        cls.__random__ = Random(seed) +        cls.__faker__.seed_instance(seed) + +    @classmethod +    def is_ignored_type(cls, value: Any) -> bool: +        """Check whether a given value is an ignored type. + +        :param value: An arbitrary value. + +        :notes: +            - This method is meant to be overwritten by extension factories and other subclasses + +        :returns: A boolean determining whether the value should be ignored. + +        """ +        return value is None + +    @classmethod +    def get_provider_map(cls) -> dict[Any, Callable[[], Any]]: +        """Map types to callables. + +        :notes: +            - This method is distinct to allow overriding. + + +        :returns: a dictionary mapping types to callables. + +        """ + +        def _create_generic_fn() -> Callable: +            """Return a generic lambda""" +            return lambda *args: None + +        return { +            Any: lambda: None, +            # primitives +            object: object, +            float: cls.__faker__.pyfloat, +            int: cls.__faker__.pyint, +            bool: cls.__faker__.pybool, +            str: cls.__faker__.pystr, +            bytes: partial(create_random_bytes, cls.__random__), +            # built-in objects +            dict: cls.__faker__.pydict, +            tuple: cls.__faker__.pytuple, +            list: cls.__faker__.pylist, +            set: cls.__faker__.pyset, +            frozenset: lambda: frozenset(cls.__faker__.pylist()), +            deque: lambda: deque(cls.__faker__.pylist()), +            # standard library objects +            Path: lambda: Path(realpath(__file__)), +            Decimal: cls.__faker__.pydecimal, +            UUID: lambda: UUID(cls.__faker__.uuid4()), +            # datetime +            datetime: cls.__faker__.date_time_between, +            date: cls.__faker__.date_this_decade, +            time: cls.__faker__.time_object, +            timedelta: cls.__faker__.time_delta, +            # ip addresses +            IPv4Address: lambda: ip_address(cls.__faker__.ipv4()), +            IPv4Interface: lambda: ip_interface(cls.__faker__.ipv4()), +            IPv4Network: lambda: ip_network(cls.__faker__.ipv4(network=True)), +            IPv6Address: lambda: ip_address(cls.__faker__.ipv6()), +            IPv6Interface: lambda: ip_interface(cls.__faker__.ipv6()), +            IPv6Network: lambda: ip_network(cls.__faker__.ipv6(network=True)), +            # types +            Callable: _create_generic_fn, +            abc.Callable: _create_generic_fn, +            Counter: lambda: Counter(cls.__faker__.pystr()), +            **(cls._extra_providers or {}), +        } + +    @classmethod +    def create_factory( +        cls: type[F], +        model: type[T] | None = None, +        bases: tuple[type[BaseFactory[Any]], ...] | None = None, +        **kwargs: Any, +    ) -> type[F]: +        """Generate a factory for the given type dynamically. + +        :param model: A type to model. Defaults to current factory __model__ if any. +            Otherwise, raise an error +        :param bases: Base classes to use when generating the new class. +        :param kwargs: Any kwargs. + +        :returns: A 'ModelFactory' subclass. + +        """ +        if model is None: +            try: +                model = cls.__model__ +            except AttributeError as ex: +                msg = "A 'model' argument is required when creating a new factory from a base one" +                raise TypeError(msg) from ex +        return cast( +            "Type[F]", +            type( +                f"{model.__name__}Factory",  # pyright: ignore[reportOptionalMemberAccess] +                (*(bases or ()), cls), +                {"__model__": model, **kwargs}, +            ), +        ) + +    @classmethod +    def get_constrained_field_value(cls, annotation: Any, field_meta: FieldMeta) -> Any:  # noqa: C901, PLR0911, PLR0912 +        try: +            constraints = cast("Constraints", field_meta.constraints) +            if is_safe_subclass(annotation, float): +                return handle_constrained_float( +                    random=cls.__random__, +                    multiple_of=cast("Any", constraints.get("multiple_of")), +                    gt=cast("Any", constraints.get("gt")), +                    ge=cast("Any", constraints.get("ge")), +                    lt=cast("Any", constraints.get("lt")), +                    le=cast("Any", constraints.get("le")), +                ) + +            if is_safe_subclass(annotation, int): +                return handle_constrained_int( +                    random=cls.__random__, +                    multiple_of=cast("Any", constraints.get("multiple_of")), +                    gt=cast("Any", constraints.get("gt")), +                    ge=cast("Any", constraints.get("ge")), +                    lt=cast("Any", constraints.get("lt")), +                    le=cast("Any", constraints.get("le")), +                ) + +            if is_safe_subclass(annotation, Decimal): +                return handle_constrained_decimal( +                    random=cls.__random__, +                    decimal_places=cast("Any", constraints.get("decimal_places")), +                    max_digits=cast("Any", constraints.get("max_digits")), +                    multiple_of=cast("Any", constraints.get("multiple_of")), +                    gt=cast("Any", constraints.get("gt")), +                    ge=cast("Any", constraints.get("ge")), +                    lt=cast("Any", constraints.get("lt")), +                    le=cast("Any", constraints.get("le")), +                ) + +            if url_constraints := constraints.get("url"): +                return handle_constrained_url(constraints=url_constraints) + +            if is_safe_subclass(annotation, str) or is_safe_subclass(annotation, bytes): +                return handle_constrained_string_or_bytes( +                    random=cls.__random__, +                    t_type=str if is_safe_subclass(annotation, str) else bytes, +                    lower_case=constraints.get("lower_case") or False, +                    upper_case=constraints.get("upper_case") or False, +                    min_length=constraints.get("min_length"), +                    max_length=constraints.get("max_length"), +                    pattern=constraints.get("pattern"), +                ) + +            try: +                collection_type = get_collection_type(annotation) +            except ValueError: +                collection_type = None +            if collection_type is not None: +                if collection_type == dict: +                    return handle_constrained_mapping( +                        factory=cls, +                        field_meta=field_meta, +                        min_items=constraints.get("min_length"), +                        max_items=constraints.get("max_length"), +                    ) +                return handle_constrained_collection( +                    collection_type=collection_type,  # type: ignore[type-var] +                    factory=cls, +                    field_meta=field_meta.children[0] if field_meta.children else field_meta, +                    item_type=constraints.get("item_type"), +                    max_items=constraints.get("max_length"), +                    min_items=constraints.get("min_length"), +                    unique_items=constraints.get("unique_items", False), +                ) + +            if is_safe_subclass(annotation, date): +                return handle_constrained_date( +                    faker=cls.__faker__, +                    ge=cast("Any", constraints.get("ge")), +                    gt=cast("Any", constraints.get("gt")), +                    le=cast("Any", constraints.get("le")), +                    lt=cast("Any", constraints.get("lt")), +                    tz=cast("Any", constraints.get("tz")), +                ) + +            if is_safe_subclass(annotation, UUID) and (uuid_version := constraints.get("uuid_version")): +                return handle_constrained_uuid( +                    uuid_version=uuid_version, +                    faker=cls.__faker__, +                ) + +            if is_safe_subclass(annotation, Path) and (path_constraint := constraints.get("path_type")): +                return handle_constrained_path(constraint=path_constraint, faker=cls.__faker__) +        except TypeError as e: +            raise ParameterException from e + +        msg = f"received constraints for unsupported type {annotation}" +        raise ParameterException(msg) + +    @classmethod +    def get_field_value(  # noqa: C901, PLR0911, PLR0912 +        cls, +        field_meta: FieldMeta, +        field_build_parameters: Any | None = None, +        build_context: BuildContext | None = None, +    ) -> Any: +        """Return a field value on the subclass if existing, otherwise returns a mock value. + +        :param field_meta: FieldMeta instance. +        :param field_build_parameters: Any build parameters passed to the factory as kwarg values. +        :param build_context: BuildContext data for current build. + +        :returns: An arbitrary value. + +        """ +        build_context = _get_build_context(build_context) +        if cls.is_ignored_type(field_meta.annotation): +            return None + +        if field_build_parameters is None and cls.should_set_none_value(field_meta=field_meta): +            return None + +        unwrapped_annotation = unwrap_annotation(field_meta.annotation, random=cls.__random__) + +        if is_literal(annotation=unwrapped_annotation) and (literal_args := get_args(unwrapped_annotation)): +            return cls.__random__.choice(literal_args) + +        if isinstance(unwrapped_annotation, EnumMeta): +            return cls.__random__.choice(list(unwrapped_annotation)) + +        if field_meta.constraints: +            return cls.get_constrained_field_value(annotation=unwrapped_annotation, field_meta=field_meta) + +        if is_union(field_meta.annotation) and field_meta.children: +            seen_models = build_context["seen_models"] +            children = [child for child in field_meta.children if child.annotation not in seen_models] + +            # `None` is removed from the children when creating FieldMeta so when `children` +            # is empty, it must mean that the field meta is an optional type. +            if children: +                return cls.get_field_value(cls.__random__.choice(children), field_build_parameters, build_context) + +        if BaseFactory.is_factory_type(annotation=unwrapped_annotation): +            if not field_build_parameters and unwrapped_annotation in build_context["seen_models"]: +                return None if is_optional(field_meta.annotation) else Null + +            return cls._get_or_create_factory(model=unwrapped_annotation).build( +                _build_context=build_context, +                **(field_build_parameters if isinstance(field_build_parameters, Mapping) else {}), +            ) + +        if BaseFactory.is_batch_factory_type(annotation=unwrapped_annotation): +            factory = cls._get_or_create_factory(model=field_meta.type_args[0]) +            if isinstance(field_build_parameters, Sequence): +                return [ +                    factory.build(_build_context=build_context, **field_parameters) +                    for field_parameters in field_build_parameters +                ] + +            if field_meta.type_args[0] in build_context["seen_models"]: +                return [] + +            if not cls.__randomize_collection_length__: +                return [factory.build(_build_context=build_context)] + +            batch_size = cls.__random__.randint(cls.__min_collection_length__, cls.__max_collection_length__) +            return factory.batch(size=batch_size, _build_context=build_context) + +        if (origin := get_type_origin(unwrapped_annotation)) and is_safe_subclass(origin, Collection): +            if cls.__randomize_collection_length__: +                collection_type = get_collection_type(unwrapped_annotation) +                if collection_type != dict: +                    return handle_constrained_collection( +                        collection_type=collection_type,  # type: ignore[type-var] +                        factory=cls, +                        item_type=Any, +                        field_meta=field_meta.children[0] if field_meta.children else field_meta, +                        min_items=cls.__min_collection_length__, +                        max_items=cls.__max_collection_length__, +                    ) +                return handle_constrained_mapping( +                    factory=cls, +                    field_meta=field_meta, +                    min_items=cls.__min_collection_length__, +                    max_items=cls.__max_collection_length__, +                ) + +            return handle_collection_type(field_meta, origin, cls) + +        if is_any(unwrapped_annotation) or isinstance(unwrapped_annotation, TypeVar): +            return create_random_string(cls.__random__, min_length=1, max_length=10) + +        if provider := cls.get_provider_map().get(unwrapped_annotation): +            return provider() + +        if callable(unwrapped_annotation): +            # if value is a callable we can try to naively call it. +            # this will work for callables that do not require any parameters passed +            with suppress(Exception): +                return unwrapped_annotation() + +        msg = f"Unsupported type: {unwrapped_annotation!r}\n\nEither extend the providers map or add a factory function for this type." +        raise ParameterException( +            msg, +        ) + +    @classmethod +    def get_field_value_coverage(  # noqa: C901 +        cls, +        field_meta: FieldMeta, +        field_build_parameters: Any | None = None, +        build_context: BuildContext | None = None, +    ) -> Iterable[Any]: +        """Return a field value on the subclass if existing, otherwise returns a mock value. + +        :param field_meta: FieldMeta instance. +        :param field_build_parameters: Any build parameters passed to the factory as kwarg values. +        :param build_context: BuildContext data for current build. + +        :returns: An iterable of values. + +        """ +        if cls.is_ignored_type(field_meta.annotation): +            return [None] + +        for unwrapped_annotation in flatten_annotation(field_meta.annotation): +            if unwrapped_annotation in (None, NoneType): +                yield None + +            elif is_literal(annotation=unwrapped_annotation) and (literal_args := get_args(unwrapped_annotation)): +                yield CoverageContainer(literal_args) + +            elif isinstance(unwrapped_annotation, EnumMeta): +                yield CoverageContainer(list(unwrapped_annotation)) + +            elif field_meta.constraints: +                yield CoverageContainerCallable( +                    cls.get_constrained_field_value, +                    annotation=unwrapped_annotation, +                    field_meta=field_meta, +                ) + +            elif BaseFactory.is_factory_type(annotation=unwrapped_annotation): +                yield CoverageContainer( +                    cls._get_or_create_factory(model=unwrapped_annotation).coverage( +                        _build_context=build_context, +                        **(field_build_parameters if isinstance(field_build_parameters, Mapping) else {}), +                    ), +                ) + +            elif (origin := get_type_origin(unwrapped_annotation)) and issubclass(origin, Collection): +                yield handle_collection_type_coverage(field_meta, origin, cls) + +            elif is_any(unwrapped_annotation) or isinstance(unwrapped_annotation, TypeVar): +                yield create_random_string(cls.__random__, min_length=1, max_length=10) + +            elif provider := cls.get_provider_map().get(unwrapped_annotation): +                yield CoverageContainerCallable(provider) + +            elif callable(unwrapped_annotation): +                # if value is a callable we can try to naively call it. +                # this will work for callables that do not require any parameters passed +                yield CoverageContainerCallable(unwrapped_annotation) +            else: +                msg = f"Unsupported type: {unwrapped_annotation!r}\n\nEither extend the providers map or add a factory function for this type." +                raise ParameterException( +                    msg, +                ) + +    @classmethod +    def should_set_none_value(cls, field_meta: FieldMeta) -> bool: +        """Determine whether a given model field_meta should be set to None. + +        :param field_meta: Field metadata. + +        :notes: +            - This method is distinct to allow overriding. + +        :returns: A boolean determining whether 'None' should be set for the given field_meta. + +        """ +        return ( +            cls.__allow_none_optionals__ +            and is_optional(field_meta.annotation) +            and create_random_boolean(random=cls.__random__) +        ) + +    @classmethod +    def should_use_default_value(cls, field_meta: FieldMeta) -> bool: +        """Determine whether to use the default value for the given field. + +        :param field_meta: FieldMeta instance. + +        :notes: +            - This method is distinct to allow overriding. + +        :returns: A boolean determining whether the default value should be used for the given field_meta. + +        """ +        return cls.__use_defaults__ and field_meta.default is not Null + +    @classmethod +    def should_set_field_value(cls, field_meta: FieldMeta, **kwargs: Any) -> bool: +        """Determine whether to set a value for a given field_name. + +        :param field_meta: FieldMeta instance. +        :param kwargs: Any kwargs passed to the factory. + +        :notes: +            - This method is distinct to allow overriding. + +        :returns: A boolean determining whether a value should be set for the given field_meta. + +        """ +        return not field_meta.name.startswith("_") and field_meta.name not in kwargs + +    @classmethod +    @abstractmethod +    def get_model_fields(cls) -> list[FieldMeta]:  # pragma: no cover +        """Retrieve a list of fields from the factory's model. + + +        :returns: A list of field MetaData instances. + +        """ +        raise NotImplementedError + +    @classmethod +    def get_factory_fields(cls) -> list[tuple[str, Any]]: +        """Retrieve a list of fields from the factory. + +        Trying to be smart about what should be considered a field on the model, +        ignoring dunder methods and some parent class attributes. + +        :returns: A list of tuples made of field name and field definition +        """ +        factory_fields = cls.__dict__.items() +        return [ +            (field_name, field_value) +            for field_name, field_value in factory_fields +            if not (field_name.startswith("__") or field_name == "_abc_impl") +        ] + +    @classmethod +    def _check_declared_fields_exist_in_model(cls) -> None: +        model_fields_names = {field_meta.name for field_meta in cls.get_model_fields()} +        factory_fields = cls.get_factory_fields() + +        for field_name, field_value in factory_fields: +            if field_name in model_fields_names: +                continue + +            error_message = ( +                f"{field_name} is declared on the factory {cls.__name__}" +                f" but it is not part of the model {cls.__model__.__name__}" +            ) +            if isinstance(field_value, (Use, PostGenerated, Ignore, Require)): +                raise ConfigurationException(error_message) + +    @classmethod +    def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]: +        """Process the given kwargs and generate values for the factory's model. + +        :param kwargs: Any build kwargs. + +        :returns: A dictionary of build results. + +        """ +        _build_context = _get_build_context(kwargs.pop("_build_context", None)) +        _build_context["seen_models"].add(cls.__model__) + +        result: dict[str, Any] = {**kwargs} +        generate_post: dict[str, PostGenerated] = {} + +        for field_meta in cls.get_model_fields(): +            field_build_parameters = cls.extract_field_build_parameters(field_meta=field_meta, build_args=kwargs) +            if cls.should_set_field_value(field_meta, **kwargs) and not cls.should_use_default_value(field_meta): +                if hasattr(cls, field_meta.name) and not hasattr(BaseFactory, field_meta.name): +                    field_value = getattr(cls, field_meta.name) +                    if isinstance(field_value, Ignore): +                        continue + +                    if isinstance(field_value, Require) and field_meta.name not in kwargs: +                        msg = f"Require kwarg {field_meta.name} is missing" +                        raise MissingBuildKwargException(msg) + +                    if isinstance(field_value, PostGenerated): +                        generate_post[field_meta.name] = field_value +                        continue + +                    result[field_meta.name] = cls._handle_factory_field( +                        field_value=field_value, +                        field_build_parameters=field_build_parameters, +                        build_context=_build_context, +                    ) +                    continue + +                field_result = cls.get_field_value( +                    field_meta, +                    field_build_parameters=field_build_parameters, +                    build_context=_build_context, +                ) +                if field_result is Null: +                    continue + +                result[field_meta.name] = field_result + +        for field_name, post_generator in generate_post.items(): +            result[field_name] = post_generator.to_value(field_name, result) + +        return result + +    @classmethod +    def process_kwargs_coverage(cls, **kwargs: Any) -> abc.Iterable[dict[str, Any]]: +        """Process the given kwargs and generate values for the factory's model. + +        :param kwargs: Any build kwargs. +        :param build_context: BuildContext data for current build. + +        :returns: A dictionary of build results. + +        """ +        _build_context = _get_build_context(kwargs.pop("_build_context", None)) +        _build_context["seen_models"].add(cls.__model__) + +        result: dict[str, Any] = {**kwargs} +        generate_post: dict[str, PostGenerated] = {} + +        for field_meta in cls.get_model_fields(): +            field_build_parameters = cls.extract_field_build_parameters(field_meta=field_meta, build_args=kwargs) + +            if cls.should_set_field_value(field_meta, **kwargs): +                if hasattr(cls, field_meta.name) and not hasattr(BaseFactory, field_meta.name): +                    field_value = getattr(cls, field_meta.name) +                    if isinstance(field_value, Ignore): +                        continue + +                    if isinstance(field_value, Require) and field_meta.name not in kwargs: +                        msg = f"Require kwarg {field_meta.name} is missing" +                        raise MissingBuildKwargException(msg) + +                    if isinstance(field_value, PostGenerated): +                        generate_post[field_meta.name] = field_value +                        continue + +                    result[field_meta.name] = cls._handle_factory_field_coverage( +                        field_value=field_value, +                        field_build_parameters=field_build_parameters, +                        build_context=_build_context, +                    ) +                    continue + +                result[field_meta.name] = CoverageContainer( +                    cls.get_field_value_coverage( +                        field_meta, +                        field_build_parameters=field_build_parameters, +                        build_context=_build_context, +                    ), +                ) + +        for resolved in resolve_kwargs_coverage(result): +            for field_name, post_generator in generate_post.items(): +                resolved[field_name] = post_generator.to_value(field_name, resolved) +            yield resolved + +    @classmethod +    def build(cls, **kwargs: Any) -> T: +        """Build an instance of the factory's __model__ + +        :param kwargs: Any kwargs. If field names are set in kwargs, their values will be used. + +        :returns: An instance of type T. + +        """ +        return cast("T", cls.__model__(**cls.process_kwargs(**kwargs))) + +    @classmethod +    def batch(cls, size: int, **kwargs: Any) -> list[T]: +        """Build a batch of size n of the factory's Meta.model. + +        :param size: Size of the batch. +        :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used. + +        :returns: A list of instances of type T. + +        """ +        return [cls.build(**kwargs) for _ in range(size)] + +    @classmethod +    def coverage(cls, **kwargs: Any) -> abc.Iterator[T]: +        """Build a batch of the factory's Meta.model will full coverage of the sub-types of the model. + +        :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used. + +        :returns: A iterator of instances of type T. + +        """ +        for data in cls.process_kwargs_coverage(**kwargs): +            instance = cls.__model__(**data) +            yield cast("T", instance) + +    @classmethod +    def create_sync(cls, **kwargs: Any) -> T: +        """Build and persists synchronously a single model instance. + +        :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used. + +        :returns: An instance of type T. + +        """ +        return cls._get_sync_persistence().save(data=cls.build(**kwargs)) + +    @classmethod +    def create_batch_sync(cls, size: int, **kwargs: Any) -> list[T]: +        """Build and persists synchronously a batch of n size model instances. + +        :param size: Size of the batch. +        :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used. + +        :returns: A list of instances of type T. + +        """ +        return cls._get_sync_persistence().save_many(data=cls.batch(size, **kwargs)) + +    @classmethod +    async def create_async(cls, **kwargs: Any) -> T: +        """Build and persists asynchronously a single model instance. + +        :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used. + +        :returns: An instance of type T. +        """ +        return await cls._get_async_persistence().save(data=cls.build(**kwargs)) + +    @classmethod +    async def create_batch_async(cls, size: int, **kwargs: Any) -> list[T]: +        """Build and persists asynchronously a batch of n size model instances. + + +        :param size: Size of the batch. +        :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used. + +        :returns: A list of instances of type T. +        """ +        return await cls._get_async_persistence().save_many(data=cls.batch(size, **kwargs)) + + +def _register_builtin_factories() -> None: +    """This function is used to register the base factories, if present. + +    :returns: None +    """ +    import polyfactory.factories.dataclass_factory +    import polyfactory.factories.typed_dict_factory  # noqa: F401 + +    for module in [ +        "polyfactory.factories.pydantic_factory", +        "polyfactory.factories.beanie_odm_factory", +        "polyfactory.factories.odmantic_odm_factory", +        "polyfactory.factories.msgspec_factory", +        # `AttrsFactory` is not being registered by default since not all versions of `attrs` are supported. +        # Issue: https://github.com/litestar-org/polyfactory/issues/356 +        # "polyfactory.factories.attrs_factory", +    ]: +        try: +            import_module(module) +        except ImportError:  # noqa: PERF203 +            continue + + +_register_builtin_factories() diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/beanie_odm_factory.py b/venv/lib/python3.11/site-packages/polyfactory/factories/beanie_odm_factory.py new file mode 100644 index 0000000..ddd3169 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/beanie_odm_factory.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Generic, TypeVar + +from typing_extensions import get_args + +from polyfactory.exceptions import MissingDependencyException +from polyfactory.factories.pydantic_factory import ModelFactory +from polyfactory.persistence import AsyncPersistenceProtocol +from polyfactory.utils.predicates import is_safe_subclass + +if TYPE_CHECKING: +    from typing_extensions import TypeGuard + +    from polyfactory.factories.base import BuildContext +    from polyfactory.field_meta import FieldMeta + +try: +    from beanie import Document +except ImportError as e: +    msg = "beanie is not installed" +    raise MissingDependencyException(msg) from e + +T = TypeVar("T", bound=Document) + + +class BeaniePersistenceHandler(Generic[T], AsyncPersistenceProtocol[T]): +    """Persistence Handler using beanie logic""" + +    async def save(self, data: T) -> T: +        """Persist a single instance in mongoDB.""" +        return await data.insert()  # type: ignore[no-any-return] + +    async def save_many(self, data: list[T]) -> list[T]: +        """Persist multiple instances in mongoDB. + +        .. note:: we cannot use the ``.insert_many`` method from Beanie here because it doesn't +            return the created instances +        """ +        return [await doc.insert() for doc in data]  # pyright: ignore[reportGeneralTypeIssues] + + +class BeanieDocumentFactory(Generic[T], ModelFactory[T]): +    """Base factory for Beanie Documents""" + +    __async_persistence__ = BeaniePersistenceHandler +    __is_base_factory__ = True + +    @classmethod +    def is_supported_type(cls, value: Any) -> "TypeGuard[type[T]]": +        """Determine whether the given value is supported by the factory. + +        :param value: An arbitrary value. +        :returns: A typeguard +        """ +        return is_safe_subclass(value, Document) + +    @classmethod +    def get_field_value( +        cls, +        field_meta: "FieldMeta", +        field_build_parameters: Any | None = None, +        build_context: BuildContext | None = None, +    ) -> Any: +        """Return a field value on the subclass if existing, otherwise returns a mock value. + +        :param field_meta: FieldMeta instance. +        :param field_build_parameters: Any build parameters passed to the factory as kwarg values. + +        :returns: An arbitrary value. + +        """ +        if hasattr(field_meta.annotation, "__name__"): +            if "Indexed " in field_meta.annotation.__name__: +                base_type = field_meta.annotation.__bases__[0] +                field_meta.annotation = base_type + +            if "Link" in field_meta.annotation.__name__: +                link_class = get_args(field_meta.annotation)[0] +                field_meta.annotation = link_class +                field_meta.annotation = link_class + +        return super().get_field_value( +            field_meta=field_meta, +            field_build_parameters=field_build_parameters, +            build_context=build_context, +        ) diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/dataclass_factory.py b/venv/lib/python3.11/site-packages/polyfactory/factories/dataclass_factory.py new file mode 100644 index 0000000..01cfbe7 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/dataclass_factory.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +from dataclasses import MISSING, fields, is_dataclass +from typing import Any, Generic + +from typing_extensions import TypeGuard, get_type_hints + +from polyfactory.factories.base import BaseFactory, T +from polyfactory.field_meta import FieldMeta, Null + + +class DataclassFactory(Generic[T], BaseFactory[T]): +    """Dataclass base factory""" + +    __is_base_factory__ = True + +    @classmethod +    def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]: +        """Determine whether the given value is supported by the factory. + +        :param value: An arbitrary value. +        :returns: A typeguard +        """ +        return bool(is_dataclass(value)) + +    @classmethod +    def get_model_fields(cls) -> list["FieldMeta"]: +        """Retrieve a list of fields from the factory's model. + + +        :returns: A list of field MetaData instances. + +        """ +        fields_meta: list["FieldMeta"] = [] + +        model_type_hints = get_type_hints(cls.__model__, include_extras=True) + +        for field in fields(cls.__model__):  # type: ignore[arg-type] +            if not field.init: +                continue + +            if field.default_factory and field.default_factory is not MISSING: +                default_value = field.default_factory() +            elif field.default is not MISSING: +                default_value = field.default +            else: +                default_value = Null + +            fields_meta.append( +                FieldMeta.from_type( +                    annotation=model_type_hints[field.name], +                    name=field.name, +                    default=default_value, +                    random=cls.__random__, +                ), +            ) + +        return fields_meta diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/msgspec_factory.py b/venv/lib/python3.11/site-packages/polyfactory/factories/msgspec_factory.py new file mode 100644 index 0000000..1b579ae --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/msgspec_factory.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +from inspect import isclass +from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar + +from typing_extensions import get_type_hints + +from polyfactory.exceptions import MissingDependencyException +from polyfactory.factories.base import BaseFactory +from polyfactory.field_meta import FieldMeta, Null +from polyfactory.value_generators.constrained_numbers import handle_constrained_int +from polyfactory.value_generators.primitives import create_random_bytes + +if TYPE_CHECKING: +    from typing_extensions import TypeGuard + +try: +    import msgspec +    from msgspec.structs import fields +except ImportError as e: +    msg = "msgspec is not installed" +    raise MissingDependencyException(msg) from e + +T = TypeVar("T", bound=msgspec.Struct) + + +class MsgspecFactory(Generic[T], BaseFactory[T]): +    """Base factory for msgspec Structs.""" + +    __is_base_factory__ = True + +    @classmethod +    def get_provider_map(cls) -> dict[Any, Callable[[], Any]]: +        def get_msgpack_ext() -> msgspec.msgpack.Ext: +            code = handle_constrained_int(cls.__random__, ge=-128, le=127) +            data = create_random_bytes(cls.__random__) +            return msgspec.msgpack.Ext(code, data) + +        msgspec_provider_map = {msgspec.UnsetType: lambda: msgspec.UNSET, msgspec.msgpack.Ext: get_msgpack_ext} + +        provider_map = super().get_provider_map() +        provider_map.update(msgspec_provider_map) + +        return provider_map + +    @classmethod +    def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]: +        return isclass(value) and hasattr(value, "__struct_fields__") + +    @classmethod +    def get_model_fields(cls) -> list[FieldMeta]: +        fields_meta: list[FieldMeta] = [] + +        type_hints = get_type_hints(cls.__model__, include_extras=True) +        for field in fields(cls.__model__): +            annotation = type_hints[field.name] +            if field.default is not msgspec.NODEFAULT: +                default_value = field.default +            elif field.default_factory is not msgspec.NODEFAULT: +                default_value = field.default_factory() +            else: +                default_value = Null + +            fields_meta.append( +                FieldMeta.from_type( +                    annotation=annotation, +                    name=field.name, +                    default=default_value, +                    random=cls.__random__, +                ), +            ) +        return fields_meta diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/odmantic_odm_factory.py b/venv/lib/python3.11/site-packages/polyfactory/factories/odmantic_odm_factory.py new file mode 100644 index 0000000..1b3367a --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/odmantic_odm_factory.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +import decimal +from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union + +from polyfactory.exceptions import MissingDependencyException +from polyfactory.factories.pydantic_factory import ModelFactory +from polyfactory.utils.predicates import is_safe_subclass +from polyfactory.value_generators.primitives import create_random_bytes + +try: +    from bson.decimal128 import Decimal128, create_decimal128_context +    from odmantic import EmbeddedModel, Model +    from odmantic import bson as odbson + +except ImportError as e: +    msg = "odmantic is not installed" +    raise MissingDependencyException(msg) from e + +T = TypeVar("T", bound=Union[Model, EmbeddedModel]) + +if TYPE_CHECKING: +    from typing_extensions import TypeGuard + + +class OdmanticModelFactory(Generic[T], ModelFactory[T]): +    """Base factory for odmantic models""" + +    __is_base_factory__ = True + +    @classmethod +    def is_supported_type(cls, value: Any) -> "TypeGuard[type[T]]": +        """Determine whether the given value is supported by the factory. + +        :param value: An arbitrary value. +        :returns: A typeguard +        """ +        return is_safe_subclass(value, (Model, EmbeddedModel)) + +    @classmethod +    def get_provider_map(cls) -> dict[Any, Callable[[], Any]]: +        provider_map = super().get_provider_map() +        provider_map.update( +            { +                odbson.Int64: lambda: odbson.Int64.validate(cls.__faker__.pyint()), +                odbson.Decimal128: lambda: _to_decimal128(cls.__faker__.pydecimal()), +                odbson.Binary: lambda: odbson.Binary.validate(create_random_bytes(cls.__random__)), +                odbson._datetime: lambda: odbson._datetime.validate(cls.__faker__.date_time_between()), +                # bson.Regex and bson._Pattern not supported as there is no way to generate +                # a random regular expression with Faker +                # bson.Regex: +                # bson._Pattern: +            }, +        ) +        return provider_map + + +def _to_decimal128(value: decimal.Decimal) -> Decimal128: +    with decimal.localcontext(create_decimal128_context()) as ctx: +        return Decimal128(ctx.create_decimal(value)) diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/pydantic_factory.py b/venv/lib/python3.11/site-packages/polyfactory/factories/pydantic_factory.py new file mode 100644 index 0000000..a6028b1 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/pydantic_factory.py @@ -0,0 +1,554 @@ +from __future__ import annotations + +from contextlib import suppress +from datetime import timezone +from functools import partial +from os.path import realpath +from pathlib import Path +from typing import TYPE_CHECKING, Any, ClassVar, ForwardRef, Generic, Mapping, Tuple, TypeVar, cast +from uuid import NAMESPACE_DNS, uuid1, uuid3, uuid5 + +from typing_extensions import Literal, get_args, get_origin + +from polyfactory.collection_extender import CollectionExtender +from polyfactory.constants import DEFAULT_RANDOM +from polyfactory.exceptions import MissingDependencyException +from polyfactory.factories.base import BaseFactory +from polyfactory.field_meta import Constraints, FieldMeta, Null +from polyfactory.utils.deprecation import check_for_deprecated_parameters +from polyfactory.utils.helpers import unwrap_new_type, unwrap_optional +from polyfactory.utils.predicates import is_optional, is_safe_subclass, is_union +from polyfactory.utils.types import NoneType +from polyfactory.value_generators.primitives import create_random_bytes + +try: +    import pydantic +    from pydantic import VERSION, Json +    from pydantic.fields import FieldInfo +except ImportError as e: +    msg = "pydantic is not installed" +    raise MissingDependencyException(msg) from e + +try: +    # pydantic v1 +    from pydantic import (  # noqa: I001 +        UUID1, +        UUID3, +        UUID4, +        UUID5, +        AmqpDsn, +        AnyHttpUrl, +        AnyUrl, +        DirectoryPath, +        FilePath, +        HttpUrl, +        KafkaDsn, +        PostgresDsn, +        RedisDsn, +    ) +    from pydantic import BaseModel as BaseModelV1 +    from pydantic.color import Color +    from pydantic.fields import (  # type: ignore[attr-defined] +        DeferredType,  # pyright: ignore[reportGeneralTypeIssues] +        ModelField,  # pyright: ignore[reportGeneralTypeIssues] +        Undefined,  # pyright: ignore[reportGeneralTypeIssues] +    ) + +    # Keep this import last to prevent warnings from pydantic if pydantic v2 +    # is installed. +    from pydantic import PyObject + +    # prevent unbound variable warnings +    BaseModelV2 = BaseModelV1 +    UndefinedV2 = Undefined +except ImportError: +    # pydantic v2 + +    # v2 specific imports +    from pydantic import BaseModel as BaseModelV2 +    from pydantic_core import PydanticUndefined as UndefinedV2 +    from pydantic_core import to_json + +    from pydantic.v1 import (  # v1 compat imports +        UUID1, +        UUID3, +        UUID4, +        UUID5, +        AmqpDsn, +        AnyHttpUrl, +        AnyUrl, +        DirectoryPath, +        FilePath, +        HttpUrl, +        KafkaDsn, +        PostgresDsn, +        PyObject, +        RedisDsn, +    ) +    from pydantic.v1 import BaseModel as BaseModelV1  # type: ignore[assignment] +    from pydantic.v1.color import Color  # type: ignore[assignment] +    from pydantic.v1.fields import DeferredType, ModelField, Undefined + + +if TYPE_CHECKING: +    from random import Random +    from typing import Callable, Sequence + +    from typing_extensions import NotRequired, TypeGuard + +T = TypeVar("T", bound="BaseModelV1 | BaseModelV2") + +_IS_PYDANTIC_V1 = VERSION.startswith("1") + + +class PydanticConstraints(Constraints): +    """Metadata regarding a Pydantic type constraints, if any""" + +    json: NotRequired[bool] + + +class PydanticFieldMeta(FieldMeta): +    """Field meta subclass capable of handling pydantic ModelFields""" + +    def __init__( +        self, +        *, +        name: str, +        annotation: type, +        random: Random | None = None, +        default: Any = ..., +        children: list[FieldMeta] | None = None, +        constraints: PydanticConstraints | None = None, +    ) -> None: +        super().__init__( +            name=name, +            annotation=annotation, +            random=random, +            default=default, +            children=children, +            constraints=constraints, +        ) + +    @classmethod +    def from_field_info( +        cls, +        field_name: str, +        field_info: FieldInfo, +        use_alias: bool, +        random: Random | None, +        randomize_collection_length: bool | None = None, +        min_collection_length: int | None = None, +        max_collection_length: int | None = None, +    ) -> PydanticFieldMeta: +        """Create an instance from a pydantic field info. + +        :param field_name: The name of the field. +        :param field_info: A pydantic FieldInfo instance. +        :param use_alias: Whether to use the field alias. +        :param random: A random.Random instance. +        :param randomize_collection_length: Whether to randomize collection length. +        :param min_collection_length: Minimum collection length. +        :param max_collection_length: Maximum collection length. + +        :returns: A PydanticFieldMeta instance. +        """ +        check_for_deprecated_parameters( +            "2.11.0", +            parameters=( +                ("randomize_collection_length", randomize_collection_length), +                ("min_collection_length", min_collection_length), +                ("max_collection_length", max_collection_length), +            ), +        ) +        if callable(field_info.default_factory): +            default_value = field_info.default_factory() +        else: +            default_value = field_info.default if field_info.default is not UndefinedV2 else Null + +        annotation = unwrap_new_type(field_info.annotation) +        children: list[FieldMeta,] | None = None +        name = field_info.alias if field_info.alias and use_alias else field_name + +        constraints: PydanticConstraints +        # pydantic v2 does not always propagate metadata for Union types +        if is_union(annotation): +            constraints = {} +            children = [] +            for arg in get_args(annotation): +                if arg is NoneType: +                    continue +                child_field_info = FieldInfo.from_annotation(arg) +                merged_field_info = FieldInfo.merge_field_infos(field_info, child_field_info) +                children.append( +                    cls.from_field_info( +                        field_name="", +                        field_info=merged_field_info, +                        use_alias=use_alias, +                        random=random, +                    ), +                ) +        else: +            metadata, is_json = [], False +            for m in field_info.metadata: +                if not is_json and isinstance(m, Json):  # type: ignore[misc] +                    is_json = True +                elif m is not None: +                    metadata.append(m) + +            constraints = cast( +                PydanticConstraints, +                cls.parse_constraints(metadata=metadata) if metadata else {}, +            ) + +            if "url" in constraints: +                # pydantic uses a sentinel value for url constraints +                annotation = str + +            if is_json: +                constraints["json"] = True + +        return PydanticFieldMeta.from_type( +            annotation=annotation, +            children=children, +            constraints=cast("Constraints", {k: v for k, v in constraints.items() if v is not None}) or None, +            default=default_value, +            name=name, +            random=random or DEFAULT_RANDOM, +        ) + +    @classmethod +    def from_model_field(  # pragma: no cover +        cls, +        model_field: ModelField,  # pyright: ignore[reportGeneralTypeIssues] +        use_alias: bool, +        randomize_collection_length: bool | None = None, +        min_collection_length: int | None = None, +        max_collection_length: int | None = None, +        random: Random = DEFAULT_RANDOM, +    ) -> PydanticFieldMeta: +        """Create an instance from a pydantic model field. +        :param model_field: A pydantic ModelField. +        :param use_alias: Whether to use the field alias. +        :param randomize_collection_length: A boolean flag whether to randomize collections lengths +        :param min_collection_length: Minimum number of elements in randomized collection +        :param max_collection_length: Maximum number of elements in randomized collection +        :param random: An instance of random.Random. + +        :returns: A PydanticFieldMeta instance. + +        """ +        check_for_deprecated_parameters( +            "2.11.0", +            parameters=( +                ("randomize_collection_length", randomize_collection_length), +                ("min_collection_length", min_collection_length), +                ("max_collection_length", max_collection_length), +            ), +        ) + +        if model_field.default is not Undefined: +            default_value = model_field.default +        elif callable(model_field.default_factory): +            default_value = model_field.default_factory() +        else: +            default_value = model_field.default if model_field.default is not Undefined else Null + +        name = model_field.alias if model_field.alias and use_alias else model_field.name + +        outer_type = unwrap_new_type(model_field.outer_type_) +        annotation = ( +            model_field.outer_type_ +            if isinstance(model_field.annotation, (DeferredType, ForwardRef)) +            else unwrap_new_type(model_field.annotation) +        ) + +        constraints = cast( +            "Constraints", +            { +                "ge": getattr(outer_type, "ge", model_field.field_info.ge), +                "gt": getattr(outer_type, "gt", model_field.field_info.gt), +                "le": getattr(outer_type, "le", model_field.field_info.le), +                "lt": getattr(outer_type, "lt", model_field.field_info.lt), +                "min_length": ( +                    getattr(outer_type, "min_length", model_field.field_info.min_length) +                    or getattr(outer_type, "min_items", model_field.field_info.min_items) +                ), +                "max_length": ( +                    getattr(outer_type, "max_length", model_field.field_info.max_length) +                    or getattr(outer_type, "max_items", model_field.field_info.max_items) +                ), +                "pattern": getattr(outer_type, "regex", model_field.field_info.regex), +                "unique_items": getattr(outer_type, "unique_items", model_field.field_info.unique_items), +                "decimal_places": getattr(outer_type, "decimal_places", None), +                "max_digits": getattr(outer_type, "max_digits", None), +                "multiple_of": getattr(outer_type, "multiple_of", None), +                "upper_case": getattr(outer_type, "to_upper", None), +                "lower_case": getattr(outer_type, "to_lower", None), +                "item_type": getattr(outer_type, "item_type", None), +            }, +        ) + +        # pydantic v1 has constraints set for these values, but we generate them using faker +        if unwrap_optional(annotation) in ( +            AnyUrl, +            HttpUrl, +            KafkaDsn, +            PostgresDsn, +            RedisDsn, +            AmqpDsn, +            AnyHttpUrl, +        ): +            constraints = {} + +        if model_field.field_info.const and ( +            default_value is None or isinstance(default_value, (int, bool, str, bytes)) +        ): +            annotation = Literal[default_value]  # pyright: ignore  # noqa: PGH003 + +        children: list[FieldMeta] = [] + +        # Refer #412. +        args = get_args(model_field.annotation) +        if is_optional(model_field.annotation) and len(args) == 2:  # noqa: PLR2004 +            child_annotation = args[0] if args[0] is not NoneType else args[1] +            children.append(PydanticFieldMeta.from_type(child_annotation)) +        elif model_field.key_field or model_field.sub_fields: +            fields_to_iterate = ( +                ([model_field.key_field, *model_field.sub_fields]) +                if model_field.key_field is not None +                else model_field.sub_fields +            ) +            type_args = tuple( +                ( +                    sub_field.outer_type_ +                    if isinstance(sub_field.annotation, DeferredType) +                    else unwrap_new_type(sub_field.annotation) +                ) +                for sub_field in fields_to_iterate +            ) +            type_arg_to_sub_field = dict(zip(type_args, fields_to_iterate)) +            if get_origin(outer_type) in (tuple, Tuple) and get_args(outer_type)[-1] == Ellipsis: +                # pydantic removes ellipses from Tuples in sub_fields +                type_args += (...,) +            extended_type_args = CollectionExtender.extend_type_args(annotation, type_args, 1) +            children.extend( +                PydanticFieldMeta.from_model_field( +                    model_field=type_arg_to_sub_field[arg], +                    use_alias=use_alias, +                    random=random, +                ) +                for arg in extended_type_args +            ) + +        return PydanticFieldMeta( +            name=name, +            random=random or DEFAULT_RANDOM, +            annotation=annotation, +            children=children or None, +            default=default_value, +            constraints=cast("PydanticConstraints", {k: v for k, v in constraints.items() if v is not None}) or None, +        ) + +    if not _IS_PYDANTIC_V1: + +        @classmethod +        def get_constraints_metadata(cls, annotation: Any) -> Sequence[Any]: +            metadata = [] +            for m in super().get_constraints_metadata(annotation): +                if isinstance(m, FieldInfo): +                    metadata.extend(m.metadata) +                else: +                    metadata.append(m) + +            return metadata + + +class ModelFactory(Generic[T], BaseFactory[T]): +    """Base factory for pydantic models""" + +    __forward_ref_resolution_type_mapping__: ClassVar[Mapping[str, type]] = {} +    __is_base_factory__ = True + +    def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None: +        super().__init_subclass__(*args, **kwargs) + +        if ( +            getattr(cls, "__model__", None) +            and _is_pydantic_v1_model(cls.__model__) +            and hasattr(cls.__model__, "update_forward_refs") +        ): +            with suppress(NameError):  # pragma: no cover +                cls.__model__.update_forward_refs(**cls.__forward_ref_resolution_type_mapping__)  # type: ignore[attr-defined] + +    @classmethod +    def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]: +        """Determine whether the given value is supported by the factory. + +        :param value: An arbitrary value. +        :returns: A typeguard +        """ + +        return _is_pydantic_v1_model(value) or _is_pydantic_v2_model(value) + +    @classmethod +    def get_model_fields(cls) -> list["FieldMeta"]: +        """Retrieve a list of fields from the factory's model. + + +        :returns: A list of field MetaData instances. + +        """ +        if "_fields_metadata" not in cls.__dict__: +            if _is_pydantic_v1_model(cls.__model__): +                cls._fields_metadata = [ +                    PydanticFieldMeta.from_model_field( +                        field, +                        use_alias=not cls.__model__.__config__.allow_population_by_field_name,  # type: ignore[attr-defined] +                        random=cls.__random__, +                    ) +                    for field in cls.__model__.__fields__.values() +                ] +            else: +                cls._fields_metadata = [ +                    PydanticFieldMeta.from_field_info( +                        field_info=field_info, +                        field_name=field_name, +                        random=cls.__random__, +                        use_alias=not cls.__model__.model_config.get(  # pyright: ignore[reportGeneralTypeIssues] +                            "populate_by_name", +                            False, +                        ), +                    ) +                    for field_name, field_info in cls.__model__.model_fields.items()  # pyright: ignore[reportGeneralTypeIssues] +                ] +        return cls._fields_metadata + +    @classmethod +    def get_constrained_field_value(cls, annotation: Any, field_meta: FieldMeta) -> Any: +        constraints = cast(PydanticConstraints, field_meta.constraints) +        if constraints.pop("json", None): +            value = cls.get_field_value(field_meta) +            return to_json(value)  # pyright: ignore[reportUnboundVariable] + +        return super().get_constrained_field_value(annotation, field_meta) + +    @classmethod +    def build( +        cls, +        factory_use_construct: bool = False, +        **kwargs: Any, +    ) -> T: +        """Build an instance of the factory's __model__ + +        :param factory_use_construct: A boolean that determines whether validations will be made when instantiating the +                model. This is supported only for pydantic models. +        :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used. + +        :returns: An instance of type T. + +        """ +        processed_kwargs = cls.process_kwargs(**kwargs) + +        if factory_use_construct: +            if _is_pydantic_v1_model(cls.__model__): +                return cls.__model__.construct(**processed_kwargs)  # type: ignore[return-value] +            return cls.__model__.model_construct(**processed_kwargs)  # type: ignore[return-value] + +        return cls.__model__(**processed_kwargs)  # type: ignore[return-value] + +    @classmethod +    def is_custom_root_field(cls, field_meta: FieldMeta) -> bool: +        """Determine whether the field is a custom root field. + +        :param field_meta: FieldMeta instance. + +        :returns: A boolean determining whether the field is a custom root. + +        """ +        return field_meta.name == "__root__" + +    @classmethod +    def should_set_field_value(cls, field_meta: FieldMeta, **kwargs: Any) -> bool: +        """Determine whether to set a value for a given field_name. +        This is an override of BaseFactory.should_set_field_value. + +        :param field_meta: FieldMeta instance. +        :param kwargs: Any kwargs passed to the factory. + +        :returns: A boolean determining whether a value should be set for the given field_meta. + +        """ +        return field_meta.name not in kwargs and ( +            not field_meta.name.startswith("_") or cls.is_custom_root_field(field_meta) +        ) + +    @classmethod +    def get_provider_map(cls) -> dict[Any, Callable[[], Any]]: +        mapping = { +            pydantic.ByteSize: cls.__faker__.pyint, +            pydantic.PositiveInt: cls.__faker__.pyint, +            pydantic.NegativeFloat: lambda: cls.__random__.uniform(-100, -1), +            pydantic.NegativeInt: lambda: cls.__faker__.pyint() * -1, +            pydantic.PositiveFloat: cls.__faker__.pyint, +            pydantic.NonPositiveFloat: lambda: cls.__random__.uniform(-100, 0), +            pydantic.NonNegativeInt: cls.__faker__.pyint, +            pydantic.StrictInt: cls.__faker__.pyint, +            pydantic.StrictBool: cls.__faker__.pybool, +            pydantic.StrictBytes: partial(create_random_bytes, cls.__random__), +            pydantic.StrictFloat: cls.__faker__.pyfloat, +            pydantic.StrictStr: cls.__faker__.pystr, +            pydantic.EmailStr: cls.__faker__.free_email, +            pydantic.NameEmail: cls.__faker__.free_email, +            pydantic.Json: cls.__faker__.json, +            pydantic.PaymentCardNumber: cls.__faker__.credit_card_number, +            pydantic.AnyUrl: cls.__faker__.url, +            pydantic.AnyHttpUrl: cls.__faker__.url, +            pydantic.HttpUrl: cls.__faker__.url, +            pydantic.SecretBytes: partial(create_random_bytes, cls.__random__), +            pydantic.SecretStr: cls.__faker__.pystr, +            pydantic.IPvAnyAddress: cls.__faker__.ipv4, +            pydantic.IPvAnyInterface: cls.__faker__.ipv4, +            pydantic.IPvAnyNetwork: lambda: cls.__faker__.ipv4(network=True), +            pydantic.PastDate: cls.__faker__.past_date, +            pydantic.FutureDate: cls.__faker__.future_date, +        } + +        # v1 only values +        mapping.update( +            { +                PyObject: lambda: "decimal.Decimal", +                AmqpDsn: lambda: "amqps://example.com", +                KafkaDsn: lambda: "kafka://localhost:9092", +                PostgresDsn: lambda: "postgresql://user:secret@localhost", +                RedisDsn: lambda: "redis://localhost:6379/0", +                FilePath: lambda: Path(realpath(__file__)), +                DirectoryPath: lambda: Path(realpath(__file__)).parent, +                UUID1: uuid1, +                UUID3: lambda: uuid3(NAMESPACE_DNS, cls.__faker__.pystr()), +                UUID4: cls.__faker__.uuid4, +                UUID5: lambda: uuid5(NAMESPACE_DNS, cls.__faker__.pystr()), +                Color: cls.__faker__.hex_color,  # pyright: ignore[reportGeneralTypeIssues] +            }, +        ) + +        if not _IS_PYDANTIC_V1: +            mapping.update( +                { +                    # pydantic v2 specific types +                    pydantic.PastDatetime: cls.__faker__.past_datetime, +                    pydantic.FutureDatetime: cls.__faker__.future_datetime, +                    pydantic.AwareDatetime: partial(cls.__faker__.date_time, timezone.utc), +                    pydantic.NaiveDatetime: cls.__faker__.date_time, +                }, +            ) + +        mapping.update(super().get_provider_map()) +        return mapping + + +def _is_pydantic_v1_model(model: Any) -> TypeGuard[BaseModelV1]: +    return is_safe_subclass(model, BaseModelV1) + + +def _is_pydantic_v2_model(model: Any) -> TypeGuard[BaseModelV2]: +    return not _IS_PYDANTIC_V1 and is_safe_subclass(model, BaseModelV2) diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/sqlalchemy_factory.py b/venv/lib/python3.11/site-packages/polyfactory/factories/sqlalchemy_factory.py new file mode 100644 index 0000000..ad8873f --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/sqlalchemy_factory.py @@ -0,0 +1,186 @@ +from __future__ import annotations + +from datetime import date, datetime +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, List, TypeVar, Union + +from polyfactory.exceptions import MissingDependencyException +from polyfactory.factories.base import BaseFactory +from polyfactory.field_meta import FieldMeta +from polyfactory.persistence import AsyncPersistenceProtocol, SyncPersistenceProtocol + +try: +    from sqlalchemy import Column, inspect, types +    from sqlalchemy.dialects import mysql, postgresql +    from sqlalchemy.exc import NoInspectionAvailable +    from sqlalchemy.orm import InstanceState, Mapper +except ImportError as e: +    msg = "sqlalchemy is not installed" +    raise MissingDependencyException(msg) from e + +if TYPE_CHECKING: +    from sqlalchemy.ext.asyncio import AsyncSession +    from sqlalchemy.orm import Session +    from typing_extensions import TypeGuard + + +T = TypeVar("T") + + +class SQLASyncPersistence(SyncPersistenceProtocol[T]): +    def __init__(self, session: Session) -> None: +        """Sync persistence handler for SQLAFactory.""" +        self.session = session + +    def save(self, data: T) -> T: +        self.session.add(data) +        self.session.commit() +        return data + +    def save_many(self, data: list[T]) -> list[T]: +        self.session.add_all(data) +        self.session.commit() +        return data + + +class SQLAASyncPersistence(AsyncPersistenceProtocol[T]): +    def __init__(self, session: AsyncSession) -> None: +        """Async persistence handler for SQLAFactory.""" +        self.session = session + +    async def save(self, data: T) -> T: +        self.session.add(data) +        await self.session.commit() +        return data + +    async def save_many(self, data: list[T]) -> list[T]: +        self.session.add_all(data) +        await self.session.commit() +        return data + + +class SQLAlchemyFactory(Generic[T], BaseFactory[T]): +    """Base factory for SQLAlchemy models.""" + +    __is_base_factory__ = True + +    __set_primary_key__: ClassVar[bool] = True +    """Configuration to consider primary key columns as a field or not.""" +    __set_foreign_keys__: ClassVar[bool] = True +    """Configuration to consider columns with foreign keys as a field or not.""" +    __set_relationships__: ClassVar[bool] = False +    """Configuration to consider relationships property as a model field or not.""" + +    __session__: ClassVar[Session | Callable[[], Session] | None] = None +    __async_session__: ClassVar[AsyncSession | Callable[[], AsyncSession] | None] = None + +    __config_keys__ = ( +        *BaseFactory.__config_keys__, +        "__set_primary_key__", +        "__set_foreign_keys__", +        "__set_relationships__", +    ) + +    @classmethod +    def get_sqlalchemy_types(cls) -> dict[Any, Callable[[], Any]]: +        """Get mapping of types where column type.""" +        return { +            types.TupleType: cls.__faker__.pytuple, +            mysql.YEAR: lambda: cls.__random__.randint(1901, 2155), +            postgresql.CIDR: lambda: cls.__faker__.ipv4(network=False), +            postgresql.DATERANGE: lambda: (cls.__faker__.past_date(), date.today()),  # noqa: DTZ011 +            postgresql.INET: lambda: cls.__faker__.ipv4(network=True), +            postgresql.INT4RANGE: lambda: tuple(sorted([cls.__faker__.pyint(), cls.__faker__.pyint()])), +            postgresql.INT8RANGE: lambda: tuple(sorted([cls.__faker__.pyint(), cls.__faker__.pyint()])), +            postgresql.MACADDR: lambda: cls.__faker__.hexify(text="^^:^^:^^:^^:^^:^^", upper=True), +            postgresql.NUMRANGE: lambda: tuple(sorted([cls.__faker__.pyint(), cls.__faker__.pyint()])), +            postgresql.TSRANGE: lambda: (cls.__faker__.past_datetime(), datetime.now()),  # noqa: DTZ005 +            postgresql.TSTZRANGE: lambda: (cls.__faker__.past_datetime(), datetime.now()),  # noqa: DTZ005 +        } + +    @classmethod +    def get_provider_map(cls) -> dict[Any, Callable[[], Any]]: +        providers_map = super().get_provider_map() +        providers_map.update(cls.get_sqlalchemy_types()) +        return providers_map + +    @classmethod +    def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]: +        try: +            inspected = inspect(value) +        except NoInspectionAvailable: +            return False +        return isinstance(inspected, (Mapper, InstanceState)) + +    @classmethod +    def should_column_be_set(cls, column: Column) -> bool: +        if not cls.__set_primary_key__ and column.primary_key: +            return False + +        return bool(cls.__set_foreign_keys__ or not column.foreign_keys) + +    @classmethod +    def get_type_from_column(cls, column: Column) -> type: +        column_type = type(column.type) +        if column_type in cls.get_sqlalchemy_types(): +            annotation = column_type +        elif issubclass(column_type, types.ARRAY): +            annotation = List[column.type.item_type.python_type]  # type: ignore[assignment,name-defined] +        else: +            annotation = ( +                column.type.impl.python_type  # pyright: ignore[reportGeneralTypeIssues] +                if hasattr(column.type, "impl") +                else column.type.python_type +            ) + +        if column.nullable: +            annotation = Union[annotation, None]  # type: ignore[assignment] + +        return annotation + +    @classmethod +    def get_model_fields(cls) -> list[FieldMeta]: +        fields_meta: list[FieldMeta] = [] + +        table: Mapper = inspect(cls.__model__)  # type: ignore[assignment] +        fields_meta.extend( +            FieldMeta.from_type( +                annotation=cls.get_type_from_column(column), +                name=name, +                random=cls.__random__, +            ) +            for name, column in table.columns.items() +            if cls.should_column_be_set(column) +        ) +        if cls.__set_relationships__: +            for name, relationship in table.relationships.items(): +                class_ = relationship.entity.class_ +                annotation = class_ if not relationship.uselist else List[class_]  # type: ignore[valid-type] +                fields_meta.append( +                    FieldMeta.from_type( +                        name=name, +                        annotation=annotation, +                        random=cls.__random__, +                    ), +                ) + +        return fields_meta + +    @classmethod +    def _get_sync_persistence(cls) -> SyncPersistenceProtocol[T]: +        if cls.__session__ is not None: +            return ( +                SQLASyncPersistence(cls.__session__()) +                if callable(cls.__session__) +                else SQLASyncPersistence(cls.__session__) +            ) +        return super()._get_sync_persistence() + +    @classmethod +    def _get_async_persistence(cls) -> AsyncPersistenceProtocol[T]: +        if cls.__async_session__ is not None: +            return ( +                SQLAASyncPersistence(cls.__async_session__()) +                if callable(cls.__async_session__) +                else SQLAASyncPersistence(cls.__async_session__) +            ) +        return super()._get_async_persistence() diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/typed_dict_factory.py b/venv/lib/python3.11/site-packages/polyfactory/factories/typed_dict_factory.py new file mode 100644 index 0000000..2a3ea1b --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/typed_dict_factory.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from typing import Any, Generic, TypeVar, get_args + +from typing_extensions import (  # type: ignore[attr-defined] +    NotRequired, +    Required, +    TypeGuard, +    _TypedDictMeta,  # pyright: ignore[reportGeneralTypeIssues] +    get_origin, +    get_type_hints, +    is_typeddict, +) + +from polyfactory.constants import DEFAULT_RANDOM +from polyfactory.factories.base import BaseFactory +from polyfactory.field_meta import FieldMeta, Null + +TypedDictT = TypeVar("TypedDictT", bound=_TypedDictMeta) + + +class TypedDictFactory(Generic[TypedDictT], BaseFactory[TypedDictT]): +    """TypedDict base factory""" + +    __is_base_factory__ = True + +    @classmethod +    def is_supported_type(cls, value: Any) -> TypeGuard[type[TypedDictT]]: +        """Determine whether the given value is supported by the factory. + +        :param value: An arbitrary value. +        :returns: A typeguard +        """ +        return is_typeddict(value) + +    @classmethod +    def get_model_fields(cls) -> list["FieldMeta"]: +        """Retrieve a list of fields from the factory's model. + + +        :returns: A list of field MetaData instances. + +        """ +        model_type_hints = get_type_hints(cls.__model__, include_extras=True) + +        field_metas: list[FieldMeta] = [] +        for field_name, annotation in model_type_hints.items(): +            origin = get_origin(annotation) +            if origin in (Required, NotRequired): +                annotation = get_args(annotation)[0]  # noqa: PLW2901 + +            field_metas.append( +                FieldMeta.from_type( +                    annotation=annotation, +                    random=DEFAULT_RANDOM, +                    name=field_name, +                    default=getattr(cls.__model__, field_name, Null), +                ), +            ) + +        return field_metas | 
