diff options
| author | cyfraeviolae <cyfraeviolae> | 2024-04-03 03:10:44 -0400 | 
|---|---|---|
| committer | cyfraeviolae <cyfraeviolae> | 2024-04-03 03:10:44 -0400 | 
| commit | 6d7ba58f880be618ade07f8ea080fe8c4bf8a896 (patch) | |
| tree | b1c931051ffcebd2bd9d61d98d6233ffa289bbce /venv/lib/python3.11/site-packages/polyfactory | |
| parent | 4f884c9abc32990b4061a1bb6997b4b37e58ea0b (diff) | |
venv
Diffstat (limited to 'venv/lib/python3.11/site-packages/polyfactory')
73 files changed, 4816 insertions, 0 deletions
| diff --git a/venv/lib/python3.11/site-packages/polyfactory/__init__.py b/venv/lib/python3.11/site-packages/polyfactory/__init__.py new file mode 100644 index 0000000..0a2269e --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/__init__.py @@ -0,0 +1,16 @@ +from .exceptions import ConfigurationException +from .factories import BaseFactory +from .fields import Fixture, Ignore, PostGenerated, Require, Use +from .persistence import AsyncPersistenceProtocol, SyncPersistenceProtocol + +__all__ = ( +    "AsyncPersistenceProtocol", +    "BaseFactory", +    "ConfigurationException", +    "Fixture", +    "Ignore", +    "PostGenerated", +    "Require", +    "SyncPersistenceProtocol", +    "Use", +) diff --git a/venv/lib/python3.11/site-packages/polyfactory/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/__pycache__/__init__.cpython-311.pycBinary files differ new file mode 100644 index 0000000..415b3cc --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/__pycache__/collection_extender.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/__pycache__/collection_extender.cpython-311.pycBinary files differ new file mode 100644 index 0000000..5421b30 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/__pycache__/collection_extender.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/__pycache__/constants.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/__pycache__/constants.cpython-311.pycBinary files differ new file mode 100644 index 0000000..c170008 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/__pycache__/constants.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/__pycache__/decorators.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/__pycache__/decorators.cpython-311.pycBinary files differ new file mode 100644 index 0000000..e0c5378 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/__pycache__/decorators.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/__pycache__/exceptions.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/__pycache__/exceptions.cpython-311.pycBinary files differ new file mode 100644 index 0000000..2d049c8 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/__pycache__/exceptions.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/__pycache__/field_meta.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/__pycache__/field_meta.cpython-311.pycBinary files differ new file mode 100644 index 0000000..8575103 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/__pycache__/field_meta.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/__pycache__/fields.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/__pycache__/fields.cpython-311.pycBinary files differ new file mode 100644 index 0000000..0c7fc3a --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/__pycache__/fields.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/__pycache__/persistence.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/__pycache__/persistence.cpython-311.pycBinary files differ new file mode 100644 index 0000000..6d40b2e --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/__pycache__/persistence.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/__pycache__/pytest_plugin.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/__pycache__/pytest_plugin.cpython-311.pycBinary files differ new file mode 100644 index 0000000..63cd9ca --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/__pycache__/pytest_plugin.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/collection_extender.py b/venv/lib/python3.11/site-packages/polyfactory/collection_extender.py new file mode 100644 index 0000000..6377125 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/collection_extender.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +import random +from abc import ABC, abstractmethod +from collections import deque +from typing import Any + +from polyfactory.utils.predicates import is_safe_subclass + + +class CollectionExtender(ABC): +    __types__: tuple[type, ...] + +    @staticmethod +    @abstractmethod +    def _extend_type_args(type_args: tuple[Any, ...], number_of_args: int) -> tuple[Any, ...]: +        raise NotImplementedError + +    @classmethod +    def _subclass_for_type(cls, annotation_alias: Any) -> type[CollectionExtender]: +        return next( +            ( +                subclass +                for subclass in cls.__subclasses__() +                if any(is_safe_subclass(annotation_alias, t) for t in subclass.__types__) +            ), +            FallbackExtender, +        ) + +    @classmethod +    def extend_type_args( +        cls, +        annotation_alias: Any, +        type_args: tuple[Any, ...], +        number_of_args: int, +    ) -> tuple[Any, ...]: +        return cls._subclass_for_type(annotation_alias)._extend_type_args(type_args, number_of_args) + + +class TupleExtender(CollectionExtender): +    __types__ = (tuple,) + +    @staticmethod +    def _extend_type_args(type_args: tuple[Any, ...], number_of_args: int) -> tuple[Any, ...]: +        if not type_args: +            return type_args +        if type_args[-1] is not ...: +            return type_args +        type_to_extend = type_args[-2] +        return type_args[:-2] + (type_to_extend,) * number_of_args + + +class ListLikeExtender(CollectionExtender): +    __types__ = (list, deque) + +    @staticmethod +    def _extend_type_args(type_args: tuple[Any, ...], number_of_args: int) -> tuple[Any, ...]: +        if not type_args: +            return type_args +        return tuple(random.choice(type_args) for _ in range(number_of_args)) + + +class SetExtender(CollectionExtender): +    __types__ = (set, frozenset) + +    @staticmethod +    def _extend_type_args(type_args: tuple[Any, ...], number_of_args: int) -> tuple[Any, ...]: +        if not type_args: +            return type_args +        return tuple(random.choice(type_args) for _ in range(number_of_args)) + + +class DictExtender(CollectionExtender): +    __types__ = (dict,) + +    @staticmethod +    def _extend_type_args(type_args: tuple[Any, ...], number_of_args: int) -> tuple[Any, ...]: +        return type_args * number_of_args + + +class FallbackExtender(CollectionExtender): +    __types__ = () + +    @staticmethod +    def _extend_type_args( +        type_args: tuple[Any, ...], +        number_of_args: int,  # noqa: ARG004 +    ) -> tuple[Any, ...]:  # - investigate @guacs +        return type_args diff --git a/venv/lib/python3.11/site-packages/polyfactory/constants.py b/venv/lib/python3.11/site-packages/polyfactory/constants.py new file mode 100644 index 0000000..c0e6d50 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/constants.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +import sys +from collections import abc, defaultdict, deque +from random import Random +from typing import ( +    DefaultDict, +    Deque, +    Dict, +    FrozenSet, +    Iterable, +    List, +    Mapping, +    Sequence, +    Set, +    Tuple, +    Union, +) + +try: +    from types import UnionType +except ImportError: +    UnionType = Union  # type: ignore[misc,assignment] + +PY_38 = sys.version_info.major == 3 and sys.version_info.minor == 8  # noqa: PLR2004 + +# Mapping of type annotations into concrete types. This is used to normalize python <= 3.9 annotations. +INSTANTIABLE_TYPE_MAPPING = { +    DefaultDict: defaultdict, +    Deque: deque, +    Dict: dict, +    FrozenSet: frozenset, +    Iterable: list, +    List: list, +    Mapping: dict, +    Sequence: list, +    Set: set, +    Tuple: tuple, +    abc.Iterable: list, +    abc.Mapping: dict, +    abc.Sequence: list, +    abc.Set: set, +    UnionType: Union, +} + + +if not PY_38: +    TYPE_MAPPING = INSTANTIABLE_TYPE_MAPPING +else: +    # For 3.8, we have to keep the types from typing since dict[str] syntax is not supported in 3.8. +    TYPE_MAPPING = { +        DefaultDict: DefaultDict, +        Deque: Deque, +        Dict: Dict, +        FrozenSet: FrozenSet, +        Iterable: List, +        List: List, +        Mapping: Dict, +        Sequence: List, +        Set: Set, +        Tuple: Tuple, +        abc.Iterable: List, +        abc.Mapping: Dict, +        abc.Sequence: List, +        abc.Set: Set, +    } + + +DEFAULT_RANDOM = Random() +RANDOMIZE_COLLECTION_LENGTH = False +MIN_COLLECTION_LENGTH = 0 +MAX_COLLECTION_LENGTH = 5 diff --git a/venv/lib/python3.11/site-packages/polyfactory/decorators.py b/venv/lib/python3.11/site-packages/polyfactory/decorators.py new file mode 100644 index 0000000..88c1021 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/decorators.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +import contextlib +import inspect +from typing import Any, Callable + +from polyfactory import PostGenerated + + +class post_generated:  # noqa: N801 +    """Descriptor class for wrapping a classmethod into a ``PostGenerated`` field.""" + +    __slots__ = ("method", "cache") + +    def __init__(self, method: Callable | classmethod) -> None: +        if not isinstance(method, classmethod): +            msg = "post_generated decorator can only be used on classmethods" +            raise TypeError(msg) +        self.method = method +        self.cache: dict[type, PostGenerated] = {} + +    def __get__(self, obj: Any, objtype: type) -> PostGenerated: +        with contextlib.suppress(KeyError): +            return self.cache[objtype] +        fn = self.method.__func__  # pyright: ignore[reportFunctionMemberAccess] +        fn_args = inspect.getfullargspec(fn).args[1:] + +        def new_fn(name: str, values: dict[str, Any]) -> Any:  # noqa: ARG001  - investigate @guacs +            return fn(objtype, **{arg: values[arg] for arg in fn_args if arg in values}) + +        return self.cache.setdefault(objtype, PostGenerated(new_fn)) diff --git a/venv/lib/python3.11/site-packages/polyfactory/exceptions.py b/venv/lib/python3.11/site-packages/polyfactory/exceptions.py new file mode 100644 index 0000000..53f1271 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/exceptions.py @@ -0,0 +1,18 @@ +class FactoryException(Exception): +    """Base Factory error class""" + + +class ConfigurationException(FactoryException): +    """Configuration Error class - used for misconfiguration""" + + +class ParameterException(FactoryException): +    """Parameter exception - used when wrong parameters are used""" + + +class MissingBuildKwargException(FactoryException): +    """Missing Build Kwarg exception - used when a required build kwarg is not provided""" + + +class MissingDependencyException(FactoryException, ImportError): +    """Missing dependency exception - used when a dependency is not installed""" diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__init__.py b/venv/lib/python3.11/site-packages/polyfactory/factories/__init__.py new file mode 100644 index 0000000..c8a9b92 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__init__.py @@ -0,0 +1,5 @@ +from polyfactory.factories.base import BaseFactory +from polyfactory.factories.dataclass_factory import DataclassFactory +from polyfactory.factories.typed_dict_factory import TypedDictFactory + +__all__ = ("BaseFactory", "TypedDictFactory", "DataclassFactory") diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/__init__.cpython-311.pycBinary files differ new file mode 100644 index 0000000..0ebad18 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/attrs_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/attrs_factory.cpython-311.pycBinary files differ new file mode 100644 index 0000000..b5ca945 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/attrs_factory.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/base.cpython-311.pycBinary files differ new file mode 100644 index 0000000..2ee4cdb --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/base.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/beanie_odm_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/beanie_odm_factory.cpython-311.pycBinary files differ new file mode 100644 index 0000000..0fc27ab --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/beanie_odm_factory.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/dataclass_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/dataclass_factory.cpython-311.pycBinary files differ new file mode 100644 index 0000000..6e59c83 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/dataclass_factory.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/msgspec_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/msgspec_factory.cpython-311.pycBinary files differ new file mode 100644 index 0000000..5e528d3 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/msgspec_factory.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/odmantic_odm_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/odmantic_odm_factory.cpython-311.pycBinary files differ new file mode 100644 index 0000000..6f45693 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/odmantic_odm_factory.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/pydantic_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/pydantic_factory.cpython-311.pycBinary files differ new file mode 100644 index 0000000..ca70ca8 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/pydantic_factory.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/sqlalchemy_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/sqlalchemy_factory.cpython-311.pycBinary files differ new file mode 100644 index 0000000..31d93d5 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/sqlalchemy_factory.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/typed_dict_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/typed_dict_factory.cpython-311.pycBinary files differ new file mode 100644 index 0000000..a1ec583 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/typed_dict_factory.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/attrs_factory.py b/venv/lib/python3.11/site-packages/polyfactory/factories/attrs_factory.py new file mode 100644 index 0000000..00ffa03 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/attrs_factory.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +from inspect import isclass +from typing import TYPE_CHECKING, Generic, TypeVar + +from polyfactory.exceptions import MissingDependencyException +from polyfactory.factories.base import BaseFactory +from polyfactory.field_meta import FieldMeta, Null + +if TYPE_CHECKING: +    from typing import Any, TypeGuard + + +try: +    import attrs +    from attr._make import Factory +    from attrs import AttrsInstance +except ImportError as ex: +    msg = "attrs is not installed" +    raise MissingDependencyException(msg) from ex + + +T = TypeVar("T", bound=AttrsInstance) + + +class AttrsFactory(Generic[T], BaseFactory[T]): +    """Base factory for attrs classes.""" + +    __model__: type[T] + +    __is_base_factory__ = True + +    @classmethod +    def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]: +        return isclass(value) and hasattr(value, "__attrs_attrs__") + +    @classmethod +    def get_model_fields(cls) -> list[FieldMeta]: +        field_metas: list[FieldMeta] = [] +        none_type = type(None) + +        cls.resolve_types(cls.__model__) +        fields = attrs.fields(cls.__model__) + +        for field in fields: +            if not field.init: +                continue + +            annotation = none_type if field.type is None else field.type + +            default = field.default +            if isinstance(default, Factory): +                # The default value is not currently being used when generating +                # the field values. When that is implemented, this would need +                # to be handled differently since the `default.factory` could +                # take a `self` argument. +                default_value = default.factory +            elif default is None: +                default_value = Null +            else: +                default_value = default + +            field_metas.append( +                FieldMeta.from_type( +                    annotation=annotation, +                    name=field.alias, +                    default=default_value, +                    random=cls.__random__, +                ), +            ) + +        return field_metas + +    @classmethod +    def resolve_types(cls, model: type[T], **kwargs: Any) -> None: +        """Resolve any strings and forward annotations in type annotations. + +        :param model: The model to resolve the type annotations for. +        :param kwargs: Any parameters that need to be passed to `attrs.resolve_types`. +        """ + +        attrs.resolve_types(model, **kwargs) diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/base.py b/venv/lib/python3.11/site-packages/polyfactory/factories/base.py new file mode 100644 index 0000000..60fe7a7 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/base.py @@ -0,0 +1,1127 @@ +from __future__ import annotations + +import copy +from abc import ABC, abstractmethod +from collections import Counter, abc, deque +from contextlib import suppress +from datetime import date, datetime, time, timedelta +from decimal import Decimal +from enum import EnumMeta +from functools import partial +from importlib import import_module +from ipaddress import ( +    IPv4Address, +    IPv4Interface, +    IPv4Network, +    IPv6Address, +    IPv6Interface, +    IPv6Network, +    ip_address, +    ip_interface, +    ip_network, +) +from os.path import realpath +from pathlib import Path +from random import Random +from typing import ( +    TYPE_CHECKING, +    Any, +    Callable, +    ClassVar, +    Collection, +    Generic, +    Iterable, +    Mapping, +    Sequence, +    Type, +    TypedDict, +    TypeVar, +    cast, +) +from uuid import UUID + +from faker import Faker +from typing_extensions import get_args, get_origin, get_original_bases + +from polyfactory.constants import ( +    DEFAULT_RANDOM, +    MAX_COLLECTION_LENGTH, +    MIN_COLLECTION_LENGTH, +    RANDOMIZE_COLLECTION_LENGTH, +) +from polyfactory.exceptions import ConfigurationException, MissingBuildKwargException, ParameterException +from polyfactory.field_meta import Null +from polyfactory.fields import Fixture, Ignore, PostGenerated, Require, Use +from polyfactory.utils.helpers import ( +    flatten_annotation, +    get_collection_type, +    unwrap_annotation, +    unwrap_args, +    unwrap_optional, +) +from polyfactory.utils.model_coverage import CoverageContainer, CoverageContainerCallable, resolve_kwargs_coverage +from polyfactory.utils.predicates import get_type_origin, is_any, is_literal, is_optional, is_safe_subclass, is_union +from polyfactory.utils.types import NoneType +from polyfactory.value_generators.complex_types import handle_collection_type, handle_collection_type_coverage +from polyfactory.value_generators.constrained_collections import ( +    handle_constrained_collection, +    handle_constrained_mapping, +) +from polyfactory.value_generators.constrained_dates import handle_constrained_date +from polyfactory.value_generators.constrained_numbers import ( +    handle_constrained_decimal, +    handle_constrained_float, +    handle_constrained_int, +) +from polyfactory.value_generators.constrained_path import handle_constrained_path +from polyfactory.value_generators.constrained_strings import handle_constrained_string_or_bytes +from polyfactory.value_generators.constrained_url import handle_constrained_url +from polyfactory.value_generators.constrained_uuid import handle_constrained_uuid +from polyfactory.value_generators.primitives import create_random_boolean, create_random_bytes, create_random_string + +if TYPE_CHECKING: +    from typing_extensions import TypeGuard + +    from polyfactory.field_meta import Constraints, FieldMeta +    from polyfactory.persistence import AsyncPersistenceProtocol, SyncPersistenceProtocol + + +T = TypeVar("T") +F = TypeVar("F", bound="BaseFactory[Any]") + + +class BuildContext(TypedDict): +    seen_models: set[type] + + +def _get_build_context(build_context: BuildContext | None) -> BuildContext: +    if build_context is None: +        return {"seen_models": set()} + +    return copy.deepcopy(build_context) + + +class BaseFactory(ABC, Generic[T]): +    """Base Factory class - this class holds the main logic of the library""" + +    # configuration attributes +    __model__: type[T] +    """ +    The model for the factory. +    This attribute is required for non-base factories and an exception will be raised if it's not set. Can be automatically inferred from the factory generic argument. +    """ +    __check_model__: bool = False +    """ +    Flag dictating whether to check if fields defined on the factory exists on the model or not. +    If 'True', checks will be done against Use, PostGenerated, Ignore, Require constructs fields only. +    """ +    __allow_none_optionals__: ClassVar[bool] = True +    """ +    Flag dictating whether to allow 'None' for optional values. +    If 'True', 'None' will be randomly generated as a value for optional model fields +    """ +    __sync_persistence__: type[SyncPersistenceProtocol[T]] | SyncPersistenceProtocol[T] | None = None +    """A sync persistence handler. Can be a class or a class instance.""" +    __async_persistence__: type[AsyncPersistenceProtocol[T]] | AsyncPersistenceProtocol[T] | None = None +    """An async persistence handler. Can be a class or a class instance.""" +    __set_as_default_factory_for_type__ = False +    """ +    Flag dictating whether to set as the default factory for the given type. +    If 'True' the factory will be used instead of dynamically generating a factory for the type. +    """ +    __is_base_factory__: bool = False +    """ +    Flag dictating whether the factory is a 'base' factory. Base factories are registered globally as handlers for types. +    For example, the 'DataclassFactory', 'TypedDictFactory' and 'ModelFactory' are all base factories. +    """ +    __base_factory_overrides__: dict[Any, type[BaseFactory[Any]]] | None = None +    """ +    A base factory to override with this factory. If this value is set, the given factory will replace the given base factory. + +    Note: this value can only be set when '__is_base_factory__' is 'True'. +    """ +    __faker__: ClassVar["Faker"] = Faker() +    """ +    A faker instance to use. Can be a user provided value. +    """ +    __random__: ClassVar["Random"] = DEFAULT_RANDOM +    """ +    An instance of 'random.Random' to use. +    """ +    __random_seed__: ClassVar[int] +    """ +    An integer to seed the factory's Faker and Random instances with. +    This attribute can be used to control random generation. +    """ +    __randomize_collection_length__: ClassVar[bool] = RANDOMIZE_COLLECTION_LENGTH +    """ +    Flag dictating whether to randomize collections lengths. +    """ +    __min_collection_length__: ClassVar[int] = MIN_COLLECTION_LENGTH +    """ +    An integer value that defines minimum length of a collection. +    """ +    __max_collection_length__: ClassVar[int] = MAX_COLLECTION_LENGTH +    """ +    An integer value that defines maximum length of a collection. +    """ +    __use_defaults__: ClassVar[bool] = False +    """ +    Flag indicating whether to use the default value on a specific field, if provided. +    """ + +    __config_keys__: tuple[str, ...] = ( +        "__check_model__", +        "__allow_none_optionals__", +        "__set_as_default_factory_for_type__", +        "__faker__", +        "__random__", +        "__randomize_collection_length__", +        "__min_collection_length__", +        "__max_collection_length__", +        "__use_defaults__", +    ) +    """Keys to be considered as config values to pass on to dynamically created factories.""" + +    # cached attributes +    _fields_metadata: list[FieldMeta] +    # BaseFactory only attributes +    _factory_type_mapping: ClassVar[dict[Any, type[BaseFactory[Any]]]] +    _base_factories: ClassVar[list[type[BaseFactory[Any]]]] + +    # Non-public attributes +    _extra_providers: dict[Any, Callable[[], Any]] | None = None + +    def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None:  # noqa: C901 +        super().__init_subclass__(*args, **kwargs) + +        if not hasattr(BaseFactory, "_base_factories"): +            BaseFactory._base_factories = [] + +        if not hasattr(BaseFactory, "_factory_type_mapping"): +            BaseFactory._factory_type_mapping = {} + +        if cls.__min_collection_length__ > cls.__max_collection_length__: +            msg = "Minimum collection length shouldn't be greater than maximum collection length" +            raise ConfigurationException( +                msg, +            ) + +        if "__is_base_factory__" not in cls.__dict__ or not cls.__is_base_factory__: +            model: type[T] | None = getattr(cls, "__model__", None) or cls._infer_model_type() +            if not model: +                msg = f"required configuration attribute '__model__' is not set on {cls.__name__}" +                raise ConfigurationException( +                    msg, +                ) +            cls.__model__ = model +            if not cls.is_supported_type(model): +                for factory in BaseFactory._base_factories: +                    if factory.is_supported_type(model): +                        msg = f"{cls.__name__} does not support {model.__name__}, but this type is supported by the {factory.__name__} base factory class. To resolve this error, subclass the factory from {factory.__name__} instead of {cls.__name__}" +                        raise ConfigurationException( +                            msg, +                        ) +                    msg = f"Model type {model.__name__} is not supported. To support it, register an appropriate base factory and subclass it for your factory." +                    raise ConfigurationException( +                        msg, +                    ) +            if cls.__check_model__: +                cls._check_declared_fields_exist_in_model() +        else: +            BaseFactory._base_factories.append(cls) + +        random_seed = getattr(cls, "__random_seed__", None) +        if random_seed is not None: +            cls.seed_random(random_seed) + +        if cls.__set_as_default_factory_for_type__ and hasattr(cls, "__model__"): +            BaseFactory._factory_type_mapping[cls.__model__] = cls + +    @classmethod +    def _infer_model_type(cls: type[F]) -> type[T] | None: +        """Return model type inferred from class declaration. +        class Foo(ModelFactory[MyModel]):  # <<< MyModel +            ... + +        If more than one base class and/or generic arguments specified return None. + +        :returns: Inferred model type or None +        """ + +        factory_bases: Iterable[type[BaseFactory[T]]] = ( +            b for b in get_original_bases(cls) if get_origin(b) and issubclass(get_origin(b), BaseFactory) +        ) +        generic_args: Sequence[type[T]] = [ +            arg for factory_base in factory_bases for arg in get_args(factory_base) if not isinstance(arg, TypeVar) +        ] +        if len(generic_args) != 1: +            return None + +        return generic_args[0] + +    @classmethod +    def _get_sync_persistence(cls) -> SyncPersistenceProtocol[T]: +        """Return a SyncPersistenceHandler if defined for the factory, otherwise raises a ConfigurationException. + +        :raises: ConfigurationException +        :returns: SyncPersistenceHandler +        """ +        if cls.__sync_persistence__: +            return cls.__sync_persistence__() if callable(cls.__sync_persistence__) else cls.__sync_persistence__ +        msg = "A '__sync_persistence__' handler must be defined in the factory to use this method" +        raise ConfigurationException( +            msg, +        ) + +    @classmethod +    def _get_async_persistence(cls) -> AsyncPersistenceProtocol[T]: +        """Return a AsyncPersistenceHandler if defined for the factory, otherwise raises a ConfigurationException. + +        :raises: ConfigurationException +        :returns: AsyncPersistenceHandler +        """ +        if cls.__async_persistence__: +            return cls.__async_persistence__() if callable(cls.__async_persistence__) else cls.__async_persistence__ +        msg = "An '__async_persistence__' handler must be defined in the factory to use this method" +        raise ConfigurationException( +            msg, +        ) + +    @classmethod +    def _handle_factory_field( +        cls, +        field_value: Any, +        build_context: BuildContext, +        field_build_parameters: Any | None = None, +    ) -> Any: +        """Handle a value defined on the factory class itself. + +        :param field_value: A value defined as an attribute on the factory class. +        :param field_build_parameters: Any build parameters passed to the factory as kwarg values. + +        :returns: An arbitrary value correlating with the given field_meta value. +        """ +        if is_safe_subclass(field_value, BaseFactory): +            if isinstance(field_build_parameters, Mapping): +                return field_value.build(_build_context=build_context, **field_build_parameters) + +            if isinstance(field_build_parameters, Sequence): +                return [ +                    field_value.build(_build_context=build_context, **parameter) for parameter in field_build_parameters +                ] + +            return field_value.build(_build_context=build_context) + +        if isinstance(field_value, Use): +            return field_value.to_value() + +        if isinstance(field_value, Fixture): +            return field_value.to_value() + +        return field_value() if callable(field_value) else field_value + +    @classmethod +    def _handle_factory_field_coverage( +        cls, +        field_value: Any, +        field_build_parameters: Any | None = None, +        build_context: BuildContext | None = None, +    ) -> Any: +        """Handle a value defined on the factory class itself. + +        :param field_value: A value defined as an attribute on the factory class. +        :param field_build_parameters: Any build parameters passed to the factory as kwarg values. + +        :returns: An arbitrary value correlating with the given field_meta value. +        """ +        if is_safe_subclass(field_value, BaseFactory): +            if isinstance(field_build_parameters, Mapping): +                return CoverageContainer(field_value.coverage(_build_context=build_context, **field_build_parameters)) + +            if isinstance(field_build_parameters, Sequence): +                return [ +                    CoverageContainer(field_value.coverage(_build_context=build_context, **parameter)) +                    for parameter in field_build_parameters +                ] + +            return CoverageContainer(field_value.coverage()) + +        if isinstance(field_value, Use): +            return field_value.to_value() + +        if isinstance(field_value, Fixture): +            return CoverageContainerCallable(field_value.to_value) + +        return CoverageContainerCallable(field_value) if callable(field_value) else field_value + +    @classmethod +    def _get_config(cls) -> dict[str, Any]: +        return { +            **{key: getattr(cls, key) for key in cls.__config_keys__}, +            "_extra_providers": cls.get_provider_map(), +        } + +    @classmethod +    def _get_or_create_factory(cls, model: type) -> type[BaseFactory[Any]]: +        """Get a factory from registered factories or generate a factory dynamically. + +        :param model: A model type. +        :returns: A Factory sub-class. + +        """ +        if factory := BaseFactory._factory_type_mapping.get(model): +            return factory + +        config = cls._get_config() + +        if cls.__base_factory_overrides__: +            for model_ancestor in model.mro(): +                if factory := cls.__base_factory_overrides__.get(model_ancestor): +                    return factory.create_factory(model, **config) + +        for factory in reversed(BaseFactory._base_factories): +            if factory.is_supported_type(model): +                return factory.create_factory(model, **config) + +        msg = f"unsupported model type {model.__name__}" +        raise ParameterException(msg)  # pragma: no cover + +    # Public Methods + +    @classmethod +    def is_factory_type(cls, annotation: Any) -> bool: +        """Determine whether a given field is annotated with a type that is supported by a base factory. + +        :param annotation: A type annotation. +        :returns: Boolean dictating whether the annotation is a factory type +        """ +        return any(factory.is_supported_type(annotation) for factory in BaseFactory._base_factories) + +    @classmethod +    def is_batch_factory_type(cls, annotation: Any) -> bool: +        """Determine whether a given field is annotated with a sequence of supported factory types. + +        :param annotation: A type annotation. +        :returns: Boolean dictating whether the annotation is a batch factory type +        """ +        origin = get_type_origin(annotation) or annotation +        if is_safe_subclass(origin, Sequence) and (args := unwrap_args(annotation, random=cls.__random__)): +            return len(args) == 1 and BaseFactory.is_factory_type(annotation=args[0]) +        return False + +    @classmethod +    def extract_field_build_parameters(cls, field_meta: FieldMeta, build_args: dict[str, Any]) -> Any: +        """Extract from the build kwargs any build parameters passed for a given field meta - if it is a factory type. + +        :param field_meta: A field meta instance. +        :param build_args: Any kwargs passed to the factory. +        :returns: Any values +        """ +        if build_arg := build_args.get(field_meta.name): +            annotation = unwrap_optional(field_meta.annotation) +            if ( +                BaseFactory.is_factory_type(annotation=annotation) +                and isinstance(build_arg, Mapping) +                and not BaseFactory.is_factory_type(annotation=type(build_arg)) +            ): +                return build_args.pop(field_meta.name) + +            if ( +                BaseFactory.is_batch_factory_type(annotation=annotation) +                and isinstance(build_arg, Sequence) +                and not any(BaseFactory.is_factory_type(annotation=type(value)) for value in build_arg) +            ): +                return build_args.pop(field_meta.name) +        return None + +    @classmethod +    @abstractmethod +    def is_supported_type(cls, value: Any) -> "TypeGuard[type[T]]":  # pragma: no cover +        """Determine whether the given value is supported by the factory. + +        :param value: An arbitrary value. +        :returns: A typeguard +        """ +        raise NotImplementedError + +    @classmethod +    def seed_random(cls, seed: int) -> None: +        """Seed faker and random with the given integer. + +        :param seed: An integer to set as seed. +        :returns: 'None' + +        """ +        cls.__random__ = Random(seed) +        cls.__faker__.seed_instance(seed) + +    @classmethod +    def is_ignored_type(cls, value: Any) -> bool: +        """Check whether a given value is an ignored type. + +        :param value: An arbitrary value. + +        :notes: +            - This method is meant to be overwritten by extension factories and other subclasses + +        :returns: A boolean determining whether the value should be ignored. + +        """ +        return value is None + +    @classmethod +    def get_provider_map(cls) -> dict[Any, Callable[[], Any]]: +        """Map types to callables. + +        :notes: +            - This method is distinct to allow overriding. + + +        :returns: a dictionary mapping types to callables. + +        """ + +        def _create_generic_fn() -> Callable: +            """Return a generic lambda""" +            return lambda *args: None + +        return { +            Any: lambda: None, +            # primitives +            object: object, +            float: cls.__faker__.pyfloat, +            int: cls.__faker__.pyint, +            bool: cls.__faker__.pybool, +            str: cls.__faker__.pystr, +            bytes: partial(create_random_bytes, cls.__random__), +            # built-in objects +            dict: cls.__faker__.pydict, +            tuple: cls.__faker__.pytuple, +            list: cls.__faker__.pylist, +            set: cls.__faker__.pyset, +            frozenset: lambda: frozenset(cls.__faker__.pylist()), +            deque: lambda: deque(cls.__faker__.pylist()), +            # standard library objects +            Path: lambda: Path(realpath(__file__)), +            Decimal: cls.__faker__.pydecimal, +            UUID: lambda: UUID(cls.__faker__.uuid4()), +            # datetime +            datetime: cls.__faker__.date_time_between, +            date: cls.__faker__.date_this_decade, +            time: cls.__faker__.time_object, +            timedelta: cls.__faker__.time_delta, +            # ip addresses +            IPv4Address: lambda: ip_address(cls.__faker__.ipv4()), +            IPv4Interface: lambda: ip_interface(cls.__faker__.ipv4()), +            IPv4Network: lambda: ip_network(cls.__faker__.ipv4(network=True)), +            IPv6Address: lambda: ip_address(cls.__faker__.ipv6()), +            IPv6Interface: lambda: ip_interface(cls.__faker__.ipv6()), +            IPv6Network: lambda: ip_network(cls.__faker__.ipv6(network=True)), +            # types +            Callable: _create_generic_fn, +            abc.Callable: _create_generic_fn, +            Counter: lambda: Counter(cls.__faker__.pystr()), +            **(cls._extra_providers or {}), +        } + +    @classmethod +    def create_factory( +        cls: type[F], +        model: type[T] | None = None, +        bases: tuple[type[BaseFactory[Any]], ...] | None = None, +        **kwargs: Any, +    ) -> type[F]: +        """Generate a factory for the given type dynamically. + +        :param model: A type to model. Defaults to current factory __model__ if any. +            Otherwise, raise an error +        :param bases: Base classes to use when generating the new class. +        :param kwargs: Any kwargs. + +        :returns: A 'ModelFactory' subclass. + +        """ +        if model is None: +            try: +                model = cls.__model__ +            except AttributeError as ex: +                msg = "A 'model' argument is required when creating a new factory from a base one" +                raise TypeError(msg) from ex +        return cast( +            "Type[F]", +            type( +                f"{model.__name__}Factory",  # pyright: ignore[reportOptionalMemberAccess] +                (*(bases or ()), cls), +                {"__model__": model, **kwargs}, +            ), +        ) + +    @classmethod +    def get_constrained_field_value(cls, annotation: Any, field_meta: FieldMeta) -> Any:  # noqa: C901, PLR0911, PLR0912 +        try: +            constraints = cast("Constraints", field_meta.constraints) +            if is_safe_subclass(annotation, float): +                return handle_constrained_float( +                    random=cls.__random__, +                    multiple_of=cast("Any", constraints.get("multiple_of")), +                    gt=cast("Any", constraints.get("gt")), +                    ge=cast("Any", constraints.get("ge")), +                    lt=cast("Any", constraints.get("lt")), +                    le=cast("Any", constraints.get("le")), +                ) + +            if is_safe_subclass(annotation, int): +                return handle_constrained_int( +                    random=cls.__random__, +                    multiple_of=cast("Any", constraints.get("multiple_of")), +                    gt=cast("Any", constraints.get("gt")), +                    ge=cast("Any", constraints.get("ge")), +                    lt=cast("Any", constraints.get("lt")), +                    le=cast("Any", constraints.get("le")), +                ) + +            if is_safe_subclass(annotation, Decimal): +                return handle_constrained_decimal( +                    random=cls.__random__, +                    decimal_places=cast("Any", constraints.get("decimal_places")), +                    max_digits=cast("Any", constraints.get("max_digits")), +                    multiple_of=cast("Any", constraints.get("multiple_of")), +                    gt=cast("Any", constraints.get("gt")), +                    ge=cast("Any", constraints.get("ge")), +                    lt=cast("Any", constraints.get("lt")), +                    le=cast("Any", constraints.get("le")), +                ) + +            if url_constraints := constraints.get("url"): +                return handle_constrained_url(constraints=url_constraints) + +            if is_safe_subclass(annotation, str) or is_safe_subclass(annotation, bytes): +                return handle_constrained_string_or_bytes( +                    random=cls.__random__, +                    t_type=str if is_safe_subclass(annotation, str) else bytes, +                    lower_case=constraints.get("lower_case") or False, +                    upper_case=constraints.get("upper_case") or False, +                    min_length=constraints.get("min_length"), +                    max_length=constraints.get("max_length"), +                    pattern=constraints.get("pattern"), +                ) + +            try: +                collection_type = get_collection_type(annotation) +            except ValueError: +                collection_type = None +            if collection_type is not None: +                if collection_type == dict: +                    return handle_constrained_mapping( +                        factory=cls, +                        field_meta=field_meta, +                        min_items=constraints.get("min_length"), +                        max_items=constraints.get("max_length"), +                    ) +                return handle_constrained_collection( +                    collection_type=collection_type,  # type: ignore[type-var] +                    factory=cls, +                    field_meta=field_meta.children[0] if field_meta.children else field_meta, +                    item_type=constraints.get("item_type"), +                    max_items=constraints.get("max_length"), +                    min_items=constraints.get("min_length"), +                    unique_items=constraints.get("unique_items", False), +                ) + +            if is_safe_subclass(annotation, date): +                return handle_constrained_date( +                    faker=cls.__faker__, +                    ge=cast("Any", constraints.get("ge")), +                    gt=cast("Any", constraints.get("gt")), +                    le=cast("Any", constraints.get("le")), +                    lt=cast("Any", constraints.get("lt")), +                    tz=cast("Any", constraints.get("tz")), +                ) + +            if is_safe_subclass(annotation, UUID) and (uuid_version := constraints.get("uuid_version")): +                return handle_constrained_uuid( +                    uuid_version=uuid_version, +                    faker=cls.__faker__, +                ) + +            if is_safe_subclass(annotation, Path) and (path_constraint := constraints.get("path_type")): +                return handle_constrained_path(constraint=path_constraint, faker=cls.__faker__) +        except TypeError as e: +            raise ParameterException from e + +        msg = f"received constraints for unsupported type {annotation}" +        raise ParameterException(msg) + +    @classmethod +    def get_field_value(  # noqa: C901, PLR0911, PLR0912 +        cls, +        field_meta: FieldMeta, +        field_build_parameters: Any | None = None, +        build_context: BuildContext | None = None, +    ) -> Any: +        """Return a field value on the subclass if existing, otherwise returns a mock value. + +        :param field_meta: FieldMeta instance. +        :param field_build_parameters: Any build parameters passed to the factory as kwarg values. +        :param build_context: BuildContext data for current build. + +        :returns: An arbitrary value. + +        """ +        build_context = _get_build_context(build_context) +        if cls.is_ignored_type(field_meta.annotation): +            return None + +        if field_build_parameters is None and cls.should_set_none_value(field_meta=field_meta): +            return None + +        unwrapped_annotation = unwrap_annotation(field_meta.annotation, random=cls.__random__) + +        if is_literal(annotation=unwrapped_annotation) and (literal_args := get_args(unwrapped_annotation)): +            return cls.__random__.choice(literal_args) + +        if isinstance(unwrapped_annotation, EnumMeta): +            return cls.__random__.choice(list(unwrapped_annotation)) + +        if field_meta.constraints: +            return cls.get_constrained_field_value(annotation=unwrapped_annotation, field_meta=field_meta) + +        if is_union(field_meta.annotation) and field_meta.children: +            seen_models = build_context["seen_models"] +            children = [child for child in field_meta.children if child.annotation not in seen_models] + +            # `None` is removed from the children when creating FieldMeta so when `children` +            # is empty, it must mean that the field meta is an optional type. +            if children: +                return cls.get_field_value(cls.__random__.choice(children), field_build_parameters, build_context) + +        if BaseFactory.is_factory_type(annotation=unwrapped_annotation): +            if not field_build_parameters and unwrapped_annotation in build_context["seen_models"]: +                return None if is_optional(field_meta.annotation) else Null + +            return cls._get_or_create_factory(model=unwrapped_annotation).build( +                _build_context=build_context, +                **(field_build_parameters if isinstance(field_build_parameters, Mapping) else {}), +            ) + +        if BaseFactory.is_batch_factory_type(annotation=unwrapped_annotation): +            factory = cls._get_or_create_factory(model=field_meta.type_args[0]) +            if isinstance(field_build_parameters, Sequence): +                return [ +                    factory.build(_build_context=build_context, **field_parameters) +                    for field_parameters in field_build_parameters +                ] + +            if field_meta.type_args[0] in build_context["seen_models"]: +                return [] + +            if not cls.__randomize_collection_length__: +                return [factory.build(_build_context=build_context)] + +            batch_size = cls.__random__.randint(cls.__min_collection_length__, cls.__max_collection_length__) +            return factory.batch(size=batch_size, _build_context=build_context) + +        if (origin := get_type_origin(unwrapped_annotation)) and is_safe_subclass(origin, Collection): +            if cls.__randomize_collection_length__: +                collection_type = get_collection_type(unwrapped_annotation) +                if collection_type != dict: +                    return handle_constrained_collection( +                        collection_type=collection_type,  # type: ignore[type-var] +                        factory=cls, +                        item_type=Any, +                        field_meta=field_meta.children[0] if field_meta.children else field_meta, +                        min_items=cls.__min_collection_length__, +                        max_items=cls.__max_collection_length__, +                    ) +                return handle_constrained_mapping( +                    factory=cls, +                    field_meta=field_meta, +                    min_items=cls.__min_collection_length__, +                    max_items=cls.__max_collection_length__, +                ) + +            return handle_collection_type(field_meta, origin, cls) + +        if is_any(unwrapped_annotation) or isinstance(unwrapped_annotation, TypeVar): +            return create_random_string(cls.__random__, min_length=1, max_length=10) + +        if provider := cls.get_provider_map().get(unwrapped_annotation): +            return provider() + +        if callable(unwrapped_annotation): +            # if value is a callable we can try to naively call it. +            # this will work for callables that do not require any parameters passed +            with suppress(Exception): +                return unwrapped_annotation() + +        msg = f"Unsupported type: {unwrapped_annotation!r}\n\nEither extend the providers map or add a factory function for this type." +        raise ParameterException( +            msg, +        ) + +    @classmethod +    def get_field_value_coverage(  # noqa: C901 +        cls, +        field_meta: FieldMeta, +        field_build_parameters: Any | None = None, +        build_context: BuildContext | None = None, +    ) -> Iterable[Any]: +        """Return a field value on the subclass if existing, otherwise returns a mock value. + +        :param field_meta: FieldMeta instance. +        :param field_build_parameters: Any build parameters passed to the factory as kwarg values. +        :param build_context: BuildContext data for current build. + +        :returns: An iterable of values. + +        """ +        if cls.is_ignored_type(field_meta.annotation): +            return [None] + +        for unwrapped_annotation in flatten_annotation(field_meta.annotation): +            if unwrapped_annotation in (None, NoneType): +                yield None + +            elif is_literal(annotation=unwrapped_annotation) and (literal_args := get_args(unwrapped_annotation)): +                yield CoverageContainer(literal_args) + +            elif isinstance(unwrapped_annotation, EnumMeta): +                yield CoverageContainer(list(unwrapped_annotation)) + +            elif field_meta.constraints: +                yield CoverageContainerCallable( +                    cls.get_constrained_field_value, +                    annotation=unwrapped_annotation, +                    field_meta=field_meta, +                ) + +            elif BaseFactory.is_factory_type(annotation=unwrapped_annotation): +                yield CoverageContainer( +                    cls._get_or_create_factory(model=unwrapped_annotation).coverage( +                        _build_context=build_context, +                        **(field_build_parameters if isinstance(field_build_parameters, Mapping) else {}), +                    ), +                ) + +            elif (origin := get_type_origin(unwrapped_annotation)) and issubclass(origin, Collection): +                yield handle_collection_type_coverage(field_meta, origin, cls) + +            elif is_any(unwrapped_annotation) or isinstance(unwrapped_annotation, TypeVar): +                yield create_random_string(cls.__random__, min_length=1, max_length=10) + +            elif provider := cls.get_provider_map().get(unwrapped_annotation): +                yield CoverageContainerCallable(provider) + +            elif callable(unwrapped_annotation): +                # if value is a callable we can try to naively call it. +                # this will work for callables that do not require any parameters passed +                yield CoverageContainerCallable(unwrapped_annotation) +            else: +                msg = f"Unsupported type: {unwrapped_annotation!r}\n\nEither extend the providers map or add a factory function for this type." +                raise ParameterException( +                    msg, +                ) + +    @classmethod +    def should_set_none_value(cls, field_meta: FieldMeta) -> bool: +        """Determine whether a given model field_meta should be set to None. + +        :param field_meta: Field metadata. + +        :notes: +            - This method is distinct to allow overriding. + +        :returns: A boolean determining whether 'None' should be set for the given field_meta. + +        """ +        return ( +            cls.__allow_none_optionals__ +            and is_optional(field_meta.annotation) +            and create_random_boolean(random=cls.__random__) +        ) + +    @classmethod +    def should_use_default_value(cls, field_meta: FieldMeta) -> bool: +        """Determine whether to use the default value for the given field. + +        :param field_meta: FieldMeta instance. + +        :notes: +            - This method is distinct to allow overriding. + +        :returns: A boolean determining whether the default value should be used for the given field_meta. + +        """ +        return cls.__use_defaults__ and field_meta.default is not Null + +    @classmethod +    def should_set_field_value(cls, field_meta: FieldMeta, **kwargs: Any) -> bool: +        """Determine whether to set a value for a given field_name. + +        :param field_meta: FieldMeta instance. +        :param kwargs: Any kwargs passed to the factory. + +        :notes: +            - This method is distinct to allow overriding. + +        :returns: A boolean determining whether a value should be set for the given field_meta. + +        """ +        return not field_meta.name.startswith("_") and field_meta.name not in kwargs + +    @classmethod +    @abstractmethod +    def get_model_fields(cls) -> list[FieldMeta]:  # pragma: no cover +        """Retrieve a list of fields from the factory's model. + + +        :returns: A list of field MetaData instances. + +        """ +        raise NotImplementedError + +    @classmethod +    def get_factory_fields(cls) -> list[tuple[str, Any]]: +        """Retrieve a list of fields from the factory. + +        Trying to be smart about what should be considered a field on the model, +        ignoring dunder methods and some parent class attributes. + +        :returns: A list of tuples made of field name and field definition +        """ +        factory_fields = cls.__dict__.items() +        return [ +            (field_name, field_value) +            for field_name, field_value in factory_fields +            if not (field_name.startswith("__") or field_name == "_abc_impl") +        ] + +    @classmethod +    def _check_declared_fields_exist_in_model(cls) -> None: +        model_fields_names = {field_meta.name for field_meta in cls.get_model_fields()} +        factory_fields = cls.get_factory_fields() + +        for field_name, field_value in factory_fields: +            if field_name in model_fields_names: +                continue + +            error_message = ( +                f"{field_name} is declared on the factory {cls.__name__}" +                f" but it is not part of the model {cls.__model__.__name__}" +            ) +            if isinstance(field_value, (Use, PostGenerated, Ignore, Require)): +                raise ConfigurationException(error_message) + +    @classmethod +    def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]: +        """Process the given kwargs and generate values for the factory's model. + +        :param kwargs: Any build kwargs. + +        :returns: A dictionary of build results. + +        """ +        _build_context = _get_build_context(kwargs.pop("_build_context", None)) +        _build_context["seen_models"].add(cls.__model__) + +        result: dict[str, Any] = {**kwargs} +        generate_post: dict[str, PostGenerated] = {} + +        for field_meta in cls.get_model_fields(): +            field_build_parameters = cls.extract_field_build_parameters(field_meta=field_meta, build_args=kwargs) +            if cls.should_set_field_value(field_meta, **kwargs) and not cls.should_use_default_value(field_meta): +                if hasattr(cls, field_meta.name) and not hasattr(BaseFactory, field_meta.name): +                    field_value = getattr(cls, field_meta.name) +                    if isinstance(field_value, Ignore): +                        continue + +                    if isinstance(field_value, Require) and field_meta.name not in kwargs: +                        msg = f"Require kwarg {field_meta.name} is missing" +                        raise MissingBuildKwargException(msg) + +                    if isinstance(field_value, PostGenerated): +                        generate_post[field_meta.name] = field_value +                        continue + +                    result[field_meta.name] = cls._handle_factory_field( +                        field_value=field_value, +                        field_build_parameters=field_build_parameters, +                        build_context=_build_context, +                    ) +                    continue + +                field_result = cls.get_field_value( +                    field_meta, +                    field_build_parameters=field_build_parameters, +                    build_context=_build_context, +                ) +                if field_result is Null: +                    continue + +                result[field_meta.name] = field_result + +        for field_name, post_generator in generate_post.items(): +            result[field_name] = post_generator.to_value(field_name, result) + +        return result + +    @classmethod +    def process_kwargs_coverage(cls, **kwargs: Any) -> abc.Iterable[dict[str, Any]]: +        """Process the given kwargs and generate values for the factory's model. + +        :param kwargs: Any build kwargs. +        :param build_context: BuildContext data for current build. + +        :returns: A dictionary of build results. + +        """ +        _build_context = _get_build_context(kwargs.pop("_build_context", None)) +        _build_context["seen_models"].add(cls.__model__) + +        result: dict[str, Any] = {**kwargs} +        generate_post: dict[str, PostGenerated] = {} + +        for field_meta in cls.get_model_fields(): +            field_build_parameters = cls.extract_field_build_parameters(field_meta=field_meta, build_args=kwargs) + +            if cls.should_set_field_value(field_meta, **kwargs): +                if hasattr(cls, field_meta.name) and not hasattr(BaseFactory, field_meta.name): +                    field_value = getattr(cls, field_meta.name) +                    if isinstance(field_value, Ignore): +                        continue + +                    if isinstance(field_value, Require) and field_meta.name not in kwargs: +                        msg = f"Require kwarg {field_meta.name} is missing" +                        raise MissingBuildKwargException(msg) + +                    if isinstance(field_value, PostGenerated): +                        generate_post[field_meta.name] = field_value +                        continue + +                    result[field_meta.name] = cls._handle_factory_field_coverage( +                        field_value=field_value, +                        field_build_parameters=field_build_parameters, +                        build_context=_build_context, +                    ) +                    continue + +                result[field_meta.name] = CoverageContainer( +                    cls.get_field_value_coverage( +                        field_meta, +                        field_build_parameters=field_build_parameters, +                        build_context=_build_context, +                    ), +                ) + +        for resolved in resolve_kwargs_coverage(result): +            for field_name, post_generator in generate_post.items(): +                resolved[field_name] = post_generator.to_value(field_name, resolved) +            yield resolved + +    @classmethod +    def build(cls, **kwargs: Any) -> T: +        """Build an instance of the factory's __model__ + +        :param kwargs: Any kwargs. If field names are set in kwargs, their values will be used. + +        :returns: An instance of type T. + +        """ +        return cast("T", cls.__model__(**cls.process_kwargs(**kwargs))) + +    @classmethod +    def batch(cls, size: int, **kwargs: Any) -> list[T]: +        """Build a batch of size n of the factory's Meta.model. + +        :param size: Size of the batch. +        :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used. + +        :returns: A list of instances of type T. + +        """ +        return [cls.build(**kwargs) for _ in range(size)] + +    @classmethod +    def coverage(cls, **kwargs: Any) -> abc.Iterator[T]: +        """Build a batch of the factory's Meta.model will full coverage of the sub-types of the model. + +        :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used. + +        :returns: A iterator of instances of type T. + +        """ +        for data in cls.process_kwargs_coverage(**kwargs): +            instance = cls.__model__(**data) +            yield cast("T", instance) + +    @classmethod +    def create_sync(cls, **kwargs: Any) -> T: +        """Build and persists synchronously a single model instance. + +        :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used. + +        :returns: An instance of type T. + +        """ +        return cls._get_sync_persistence().save(data=cls.build(**kwargs)) + +    @classmethod +    def create_batch_sync(cls, size: int, **kwargs: Any) -> list[T]: +        """Build and persists synchronously a batch of n size model instances. + +        :param size: Size of the batch. +        :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used. + +        :returns: A list of instances of type T. + +        """ +        return cls._get_sync_persistence().save_many(data=cls.batch(size, **kwargs)) + +    @classmethod +    async def create_async(cls, **kwargs: Any) -> T: +        """Build and persists asynchronously a single model instance. + +        :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used. + +        :returns: An instance of type T. +        """ +        return await cls._get_async_persistence().save(data=cls.build(**kwargs)) + +    @classmethod +    async def create_batch_async(cls, size: int, **kwargs: Any) -> list[T]: +        """Build and persists asynchronously a batch of n size model instances. + + +        :param size: Size of the batch. +        :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used. + +        :returns: A list of instances of type T. +        """ +        return await cls._get_async_persistence().save_many(data=cls.batch(size, **kwargs)) + + +def _register_builtin_factories() -> None: +    """This function is used to register the base factories, if present. + +    :returns: None +    """ +    import polyfactory.factories.dataclass_factory +    import polyfactory.factories.typed_dict_factory  # noqa: F401 + +    for module in [ +        "polyfactory.factories.pydantic_factory", +        "polyfactory.factories.beanie_odm_factory", +        "polyfactory.factories.odmantic_odm_factory", +        "polyfactory.factories.msgspec_factory", +        # `AttrsFactory` is not being registered by default since not all versions of `attrs` are supported. +        # Issue: https://github.com/litestar-org/polyfactory/issues/356 +        # "polyfactory.factories.attrs_factory", +    ]: +        try: +            import_module(module) +        except ImportError:  # noqa: PERF203 +            continue + + +_register_builtin_factories() diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/beanie_odm_factory.py b/venv/lib/python3.11/site-packages/polyfactory/factories/beanie_odm_factory.py new file mode 100644 index 0000000..ddd3169 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/beanie_odm_factory.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Generic, TypeVar + +from typing_extensions import get_args + +from polyfactory.exceptions import MissingDependencyException +from polyfactory.factories.pydantic_factory import ModelFactory +from polyfactory.persistence import AsyncPersistenceProtocol +from polyfactory.utils.predicates import is_safe_subclass + +if TYPE_CHECKING: +    from typing_extensions import TypeGuard + +    from polyfactory.factories.base import BuildContext +    from polyfactory.field_meta import FieldMeta + +try: +    from beanie import Document +except ImportError as e: +    msg = "beanie is not installed" +    raise MissingDependencyException(msg) from e + +T = TypeVar("T", bound=Document) + + +class BeaniePersistenceHandler(Generic[T], AsyncPersistenceProtocol[T]): +    """Persistence Handler using beanie logic""" + +    async def save(self, data: T) -> T: +        """Persist a single instance in mongoDB.""" +        return await data.insert()  # type: ignore[no-any-return] + +    async def save_many(self, data: list[T]) -> list[T]: +        """Persist multiple instances in mongoDB. + +        .. note:: we cannot use the ``.insert_many`` method from Beanie here because it doesn't +            return the created instances +        """ +        return [await doc.insert() for doc in data]  # pyright: ignore[reportGeneralTypeIssues] + + +class BeanieDocumentFactory(Generic[T], ModelFactory[T]): +    """Base factory for Beanie Documents""" + +    __async_persistence__ = BeaniePersistenceHandler +    __is_base_factory__ = True + +    @classmethod +    def is_supported_type(cls, value: Any) -> "TypeGuard[type[T]]": +        """Determine whether the given value is supported by the factory. + +        :param value: An arbitrary value. +        :returns: A typeguard +        """ +        return is_safe_subclass(value, Document) + +    @classmethod +    def get_field_value( +        cls, +        field_meta: "FieldMeta", +        field_build_parameters: Any | None = None, +        build_context: BuildContext | None = None, +    ) -> Any: +        """Return a field value on the subclass if existing, otherwise returns a mock value. + +        :param field_meta: FieldMeta instance. +        :param field_build_parameters: Any build parameters passed to the factory as kwarg values. + +        :returns: An arbitrary value. + +        """ +        if hasattr(field_meta.annotation, "__name__"): +            if "Indexed " in field_meta.annotation.__name__: +                base_type = field_meta.annotation.__bases__[0] +                field_meta.annotation = base_type + +            if "Link" in field_meta.annotation.__name__: +                link_class = get_args(field_meta.annotation)[0] +                field_meta.annotation = link_class +                field_meta.annotation = link_class + +        return super().get_field_value( +            field_meta=field_meta, +            field_build_parameters=field_build_parameters, +            build_context=build_context, +        ) diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/dataclass_factory.py b/venv/lib/python3.11/site-packages/polyfactory/factories/dataclass_factory.py new file mode 100644 index 0000000..01cfbe7 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/dataclass_factory.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +from dataclasses import MISSING, fields, is_dataclass +from typing import Any, Generic + +from typing_extensions import TypeGuard, get_type_hints + +from polyfactory.factories.base import BaseFactory, T +from polyfactory.field_meta import FieldMeta, Null + + +class DataclassFactory(Generic[T], BaseFactory[T]): +    """Dataclass base factory""" + +    __is_base_factory__ = True + +    @classmethod +    def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]: +        """Determine whether the given value is supported by the factory. + +        :param value: An arbitrary value. +        :returns: A typeguard +        """ +        return bool(is_dataclass(value)) + +    @classmethod +    def get_model_fields(cls) -> list["FieldMeta"]: +        """Retrieve a list of fields from the factory's model. + + +        :returns: A list of field MetaData instances. + +        """ +        fields_meta: list["FieldMeta"] = [] + +        model_type_hints = get_type_hints(cls.__model__, include_extras=True) + +        for field in fields(cls.__model__):  # type: ignore[arg-type] +            if not field.init: +                continue + +            if field.default_factory and field.default_factory is not MISSING: +                default_value = field.default_factory() +            elif field.default is not MISSING: +                default_value = field.default +            else: +                default_value = Null + +            fields_meta.append( +                FieldMeta.from_type( +                    annotation=model_type_hints[field.name], +                    name=field.name, +                    default=default_value, +                    random=cls.__random__, +                ), +            ) + +        return fields_meta diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/msgspec_factory.py b/venv/lib/python3.11/site-packages/polyfactory/factories/msgspec_factory.py new file mode 100644 index 0000000..1b579ae --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/msgspec_factory.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +from inspect import isclass +from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar + +from typing_extensions import get_type_hints + +from polyfactory.exceptions import MissingDependencyException +from polyfactory.factories.base import BaseFactory +from polyfactory.field_meta import FieldMeta, Null +from polyfactory.value_generators.constrained_numbers import handle_constrained_int +from polyfactory.value_generators.primitives import create_random_bytes + +if TYPE_CHECKING: +    from typing_extensions import TypeGuard + +try: +    import msgspec +    from msgspec.structs import fields +except ImportError as e: +    msg = "msgspec is not installed" +    raise MissingDependencyException(msg) from e + +T = TypeVar("T", bound=msgspec.Struct) + + +class MsgspecFactory(Generic[T], BaseFactory[T]): +    """Base factory for msgspec Structs.""" + +    __is_base_factory__ = True + +    @classmethod +    def get_provider_map(cls) -> dict[Any, Callable[[], Any]]: +        def get_msgpack_ext() -> msgspec.msgpack.Ext: +            code = handle_constrained_int(cls.__random__, ge=-128, le=127) +            data = create_random_bytes(cls.__random__) +            return msgspec.msgpack.Ext(code, data) + +        msgspec_provider_map = {msgspec.UnsetType: lambda: msgspec.UNSET, msgspec.msgpack.Ext: get_msgpack_ext} + +        provider_map = super().get_provider_map() +        provider_map.update(msgspec_provider_map) + +        return provider_map + +    @classmethod +    def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]: +        return isclass(value) and hasattr(value, "__struct_fields__") + +    @classmethod +    def get_model_fields(cls) -> list[FieldMeta]: +        fields_meta: list[FieldMeta] = [] + +        type_hints = get_type_hints(cls.__model__, include_extras=True) +        for field in fields(cls.__model__): +            annotation = type_hints[field.name] +            if field.default is not msgspec.NODEFAULT: +                default_value = field.default +            elif field.default_factory is not msgspec.NODEFAULT: +                default_value = field.default_factory() +            else: +                default_value = Null + +            fields_meta.append( +                FieldMeta.from_type( +                    annotation=annotation, +                    name=field.name, +                    default=default_value, +                    random=cls.__random__, +                ), +            ) +        return fields_meta diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/odmantic_odm_factory.py b/venv/lib/python3.11/site-packages/polyfactory/factories/odmantic_odm_factory.py new file mode 100644 index 0000000..1b3367a --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/odmantic_odm_factory.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +import decimal +from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union + +from polyfactory.exceptions import MissingDependencyException +from polyfactory.factories.pydantic_factory import ModelFactory +from polyfactory.utils.predicates import is_safe_subclass +from polyfactory.value_generators.primitives import create_random_bytes + +try: +    from bson.decimal128 import Decimal128, create_decimal128_context +    from odmantic import EmbeddedModel, Model +    from odmantic import bson as odbson + +except ImportError as e: +    msg = "odmantic is not installed" +    raise MissingDependencyException(msg) from e + +T = TypeVar("T", bound=Union[Model, EmbeddedModel]) + +if TYPE_CHECKING: +    from typing_extensions import TypeGuard + + +class OdmanticModelFactory(Generic[T], ModelFactory[T]): +    """Base factory for odmantic models""" + +    __is_base_factory__ = True + +    @classmethod +    def is_supported_type(cls, value: Any) -> "TypeGuard[type[T]]": +        """Determine whether the given value is supported by the factory. + +        :param value: An arbitrary value. +        :returns: A typeguard +        """ +        return is_safe_subclass(value, (Model, EmbeddedModel)) + +    @classmethod +    def get_provider_map(cls) -> dict[Any, Callable[[], Any]]: +        provider_map = super().get_provider_map() +        provider_map.update( +            { +                odbson.Int64: lambda: odbson.Int64.validate(cls.__faker__.pyint()), +                odbson.Decimal128: lambda: _to_decimal128(cls.__faker__.pydecimal()), +                odbson.Binary: lambda: odbson.Binary.validate(create_random_bytes(cls.__random__)), +                odbson._datetime: lambda: odbson._datetime.validate(cls.__faker__.date_time_between()), +                # bson.Regex and bson._Pattern not supported as there is no way to generate +                # a random regular expression with Faker +                # bson.Regex: +                # bson._Pattern: +            }, +        ) +        return provider_map + + +def _to_decimal128(value: decimal.Decimal) -> Decimal128: +    with decimal.localcontext(create_decimal128_context()) as ctx: +        return Decimal128(ctx.create_decimal(value)) diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/pydantic_factory.py b/venv/lib/python3.11/site-packages/polyfactory/factories/pydantic_factory.py new file mode 100644 index 0000000..a6028b1 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/pydantic_factory.py @@ -0,0 +1,554 @@ +from __future__ import annotations + +from contextlib import suppress +from datetime import timezone +from functools import partial +from os.path import realpath +from pathlib import Path +from typing import TYPE_CHECKING, Any, ClassVar, ForwardRef, Generic, Mapping, Tuple, TypeVar, cast +from uuid import NAMESPACE_DNS, uuid1, uuid3, uuid5 + +from typing_extensions import Literal, get_args, get_origin + +from polyfactory.collection_extender import CollectionExtender +from polyfactory.constants import DEFAULT_RANDOM +from polyfactory.exceptions import MissingDependencyException +from polyfactory.factories.base import BaseFactory +from polyfactory.field_meta import Constraints, FieldMeta, Null +from polyfactory.utils.deprecation import check_for_deprecated_parameters +from polyfactory.utils.helpers import unwrap_new_type, unwrap_optional +from polyfactory.utils.predicates import is_optional, is_safe_subclass, is_union +from polyfactory.utils.types import NoneType +from polyfactory.value_generators.primitives import create_random_bytes + +try: +    import pydantic +    from pydantic import VERSION, Json +    from pydantic.fields import FieldInfo +except ImportError as e: +    msg = "pydantic is not installed" +    raise MissingDependencyException(msg) from e + +try: +    # pydantic v1 +    from pydantic import (  # noqa: I001 +        UUID1, +        UUID3, +        UUID4, +        UUID5, +        AmqpDsn, +        AnyHttpUrl, +        AnyUrl, +        DirectoryPath, +        FilePath, +        HttpUrl, +        KafkaDsn, +        PostgresDsn, +        RedisDsn, +    ) +    from pydantic import BaseModel as BaseModelV1 +    from pydantic.color import Color +    from pydantic.fields import (  # type: ignore[attr-defined] +        DeferredType,  # pyright: ignore[reportGeneralTypeIssues] +        ModelField,  # pyright: ignore[reportGeneralTypeIssues] +        Undefined,  # pyright: ignore[reportGeneralTypeIssues] +    ) + +    # Keep this import last to prevent warnings from pydantic if pydantic v2 +    # is installed. +    from pydantic import PyObject + +    # prevent unbound variable warnings +    BaseModelV2 = BaseModelV1 +    UndefinedV2 = Undefined +except ImportError: +    # pydantic v2 + +    # v2 specific imports +    from pydantic import BaseModel as BaseModelV2 +    from pydantic_core import PydanticUndefined as UndefinedV2 +    from pydantic_core import to_json + +    from pydantic.v1 import (  # v1 compat imports +        UUID1, +        UUID3, +        UUID4, +        UUID5, +        AmqpDsn, +        AnyHttpUrl, +        AnyUrl, +        DirectoryPath, +        FilePath, +        HttpUrl, +        KafkaDsn, +        PostgresDsn, +        PyObject, +        RedisDsn, +    ) +    from pydantic.v1 import BaseModel as BaseModelV1  # type: ignore[assignment] +    from pydantic.v1.color import Color  # type: ignore[assignment] +    from pydantic.v1.fields import DeferredType, ModelField, Undefined + + +if TYPE_CHECKING: +    from random import Random +    from typing import Callable, Sequence + +    from typing_extensions import NotRequired, TypeGuard + +T = TypeVar("T", bound="BaseModelV1 | BaseModelV2") + +_IS_PYDANTIC_V1 = VERSION.startswith("1") + + +class PydanticConstraints(Constraints): +    """Metadata regarding a Pydantic type constraints, if any""" + +    json: NotRequired[bool] + + +class PydanticFieldMeta(FieldMeta): +    """Field meta subclass capable of handling pydantic ModelFields""" + +    def __init__( +        self, +        *, +        name: str, +        annotation: type, +        random: Random | None = None, +        default: Any = ..., +        children: list[FieldMeta] | None = None, +        constraints: PydanticConstraints | None = None, +    ) -> None: +        super().__init__( +            name=name, +            annotation=annotation, +            random=random, +            default=default, +            children=children, +            constraints=constraints, +        ) + +    @classmethod +    def from_field_info( +        cls, +        field_name: str, +        field_info: FieldInfo, +        use_alias: bool, +        random: Random | None, +        randomize_collection_length: bool | None = None, +        min_collection_length: int | None = None, +        max_collection_length: int | None = None, +    ) -> PydanticFieldMeta: +        """Create an instance from a pydantic field info. + +        :param field_name: The name of the field. +        :param field_info: A pydantic FieldInfo instance. +        :param use_alias: Whether to use the field alias. +        :param random: A random.Random instance. +        :param randomize_collection_length: Whether to randomize collection length. +        :param min_collection_length: Minimum collection length. +        :param max_collection_length: Maximum collection length. + +        :returns: A PydanticFieldMeta instance. +        """ +        check_for_deprecated_parameters( +            "2.11.0", +            parameters=( +                ("randomize_collection_length", randomize_collection_length), +                ("min_collection_length", min_collection_length), +                ("max_collection_length", max_collection_length), +            ), +        ) +        if callable(field_info.default_factory): +            default_value = field_info.default_factory() +        else: +            default_value = field_info.default if field_info.default is not UndefinedV2 else Null + +        annotation = unwrap_new_type(field_info.annotation) +        children: list[FieldMeta,] | None = None +        name = field_info.alias if field_info.alias and use_alias else field_name + +        constraints: PydanticConstraints +        # pydantic v2 does not always propagate metadata for Union types +        if is_union(annotation): +            constraints = {} +            children = [] +            for arg in get_args(annotation): +                if arg is NoneType: +                    continue +                child_field_info = FieldInfo.from_annotation(arg) +                merged_field_info = FieldInfo.merge_field_infos(field_info, child_field_info) +                children.append( +                    cls.from_field_info( +                        field_name="", +                        field_info=merged_field_info, +                        use_alias=use_alias, +                        random=random, +                    ), +                ) +        else: +            metadata, is_json = [], False +            for m in field_info.metadata: +                if not is_json and isinstance(m, Json):  # type: ignore[misc] +                    is_json = True +                elif m is not None: +                    metadata.append(m) + +            constraints = cast( +                PydanticConstraints, +                cls.parse_constraints(metadata=metadata) if metadata else {}, +            ) + +            if "url" in constraints: +                # pydantic uses a sentinel value for url constraints +                annotation = str + +            if is_json: +                constraints["json"] = True + +        return PydanticFieldMeta.from_type( +            annotation=annotation, +            children=children, +            constraints=cast("Constraints", {k: v for k, v in constraints.items() if v is not None}) or None, +            default=default_value, +            name=name, +            random=random or DEFAULT_RANDOM, +        ) + +    @classmethod +    def from_model_field(  # pragma: no cover +        cls, +        model_field: ModelField,  # pyright: ignore[reportGeneralTypeIssues] +        use_alias: bool, +        randomize_collection_length: bool | None = None, +        min_collection_length: int | None = None, +        max_collection_length: int | None = None, +        random: Random = DEFAULT_RANDOM, +    ) -> PydanticFieldMeta: +        """Create an instance from a pydantic model field. +        :param model_field: A pydantic ModelField. +        :param use_alias: Whether to use the field alias. +        :param randomize_collection_length: A boolean flag whether to randomize collections lengths +        :param min_collection_length: Minimum number of elements in randomized collection +        :param max_collection_length: Maximum number of elements in randomized collection +        :param random: An instance of random.Random. + +        :returns: A PydanticFieldMeta instance. + +        """ +        check_for_deprecated_parameters( +            "2.11.0", +            parameters=( +                ("randomize_collection_length", randomize_collection_length), +                ("min_collection_length", min_collection_length), +                ("max_collection_length", max_collection_length), +            ), +        ) + +        if model_field.default is not Undefined: +            default_value = model_field.default +        elif callable(model_field.default_factory): +            default_value = model_field.default_factory() +        else: +            default_value = model_field.default if model_field.default is not Undefined else Null + +        name = model_field.alias if model_field.alias and use_alias else model_field.name + +        outer_type = unwrap_new_type(model_field.outer_type_) +        annotation = ( +            model_field.outer_type_ +            if isinstance(model_field.annotation, (DeferredType, ForwardRef)) +            else unwrap_new_type(model_field.annotation) +        ) + +        constraints = cast( +            "Constraints", +            { +                "ge": getattr(outer_type, "ge", model_field.field_info.ge), +                "gt": getattr(outer_type, "gt", model_field.field_info.gt), +                "le": getattr(outer_type, "le", model_field.field_info.le), +                "lt": getattr(outer_type, "lt", model_field.field_info.lt), +                "min_length": ( +                    getattr(outer_type, "min_length", model_field.field_info.min_length) +                    or getattr(outer_type, "min_items", model_field.field_info.min_items) +                ), +                "max_length": ( +                    getattr(outer_type, "max_length", model_field.field_info.max_length) +                    or getattr(outer_type, "max_items", model_field.field_info.max_items) +                ), +                "pattern": getattr(outer_type, "regex", model_field.field_info.regex), +                "unique_items": getattr(outer_type, "unique_items", model_field.field_info.unique_items), +                "decimal_places": getattr(outer_type, "decimal_places", None), +                "max_digits": getattr(outer_type, "max_digits", None), +                "multiple_of": getattr(outer_type, "multiple_of", None), +                "upper_case": getattr(outer_type, "to_upper", None), +                "lower_case": getattr(outer_type, "to_lower", None), +                "item_type": getattr(outer_type, "item_type", None), +            }, +        ) + +        # pydantic v1 has constraints set for these values, but we generate them using faker +        if unwrap_optional(annotation) in ( +            AnyUrl, +            HttpUrl, +            KafkaDsn, +            PostgresDsn, +            RedisDsn, +            AmqpDsn, +            AnyHttpUrl, +        ): +            constraints = {} + +        if model_field.field_info.const and ( +            default_value is None or isinstance(default_value, (int, bool, str, bytes)) +        ): +            annotation = Literal[default_value]  # pyright: ignore  # noqa: PGH003 + +        children: list[FieldMeta] = [] + +        # Refer #412. +        args = get_args(model_field.annotation) +        if is_optional(model_field.annotation) and len(args) == 2:  # noqa: PLR2004 +            child_annotation = args[0] if args[0] is not NoneType else args[1] +            children.append(PydanticFieldMeta.from_type(child_annotation)) +        elif model_field.key_field or model_field.sub_fields: +            fields_to_iterate = ( +                ([model_field.key_field, *model_field.sub_fields]) +                if model_field.key_field is not None +                else model_field.sub_fields +            ) +            type_args = tuple( +                ( +                    sub_field.outer_type_ +                    if isinstance(sub_field.annotation, DeferredType) +                    else unwrap_new_type(sub_field.annotation) +                ) +                for sub_field in fields_to_iterate +            ) +            type_arg_to_sub_field = dict(zip(type_args, fields_to_iterate)) +            if get_origin(outer_type) in (tuple, Tuple) and get_args(outer_type)[-1] == Ellipsis: +                # pydantic removes ellipses from Tuples in sub_fields +                type_args += (...,) +            extended_type_args = CollectionExtender.extend_type_args(annotation, type_args, 1) +            children.extend( +                PydanticFieldMeta.from_model_field( +                    model_field=type_arg_to_sub_field[arg], +                    use_alias=use_alias, +                    random=random, +                ) +                for arg in extended_type_args +            ) + +        return PydanticFieldMeta( +            name=name, +            random=random or DEFAULT_RANDOM, +            annotation=annotation, +            children=children or None, +            default=default_value, +            constraints=cast("PydanticConstraints", {k: v for k, v in constraints.items() if v is not None}) or None, +        ) + +    if not _IS_PYDANTIC_V1: + +        @classmethod +        def get_constraints_metadata(cls, annotation: Any) -> Sequence[Any]: +            metadata = [] +            for m in super().get_constraints_metadata(annotation): +                if isinstance(m, FieldInfo): +                    metadata.extend(m.metadata) +                else: +                    metadata.append(m) + +            return metadata + + +class ModelFactory(Generic[T], BaseFactory[T]): +    """Base factory for pydantic models""" + +    __forward_ref_resolution_type_mapping__: ClassVar[Mapping[str, type]] = {} +    __is_base_factory__ = True + +    def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None: +        super().__init_subclass__(*args, **kwargs) + +        if ( +            getattr(cls, "__model__", None) +            and _is_pydantic_v1_model(cls.__model__) +            and hasattr(cls.__model__, "update_forward_refs") +        ): +            with suppress(NameError):  # pragma: no cover +                cls.__model__.update_forward_refs(**cls.__forward_ref_resolution_type_mapping__)  # type: ignore[attr-defined] + +    @classmethod +    def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]: +        """Determine whether the given value is supported by the factory. + +        :param value: An arbitrary value. +        :returns: A typeguard +        """ + +        return _is_pydantic_v1_model(value) or _is_pydantic_v2_model(value) + +    @classmethod +    def get_model_fields(cls) -> list["FieldMeta"]: +        """Retrieve a list of fields from the factory's model. + + +        :returns: A list of field MetaData instances. + +        """ +        if "_fields_metadata" not in cls.__dict__: +            if _is_pydantic_v1_model(cls.__model__): +                cls._fields_metadata = [ +                    PydanticFieldMeta.from_model_field( +                        field, +                        use_alias=not cls.__model__.__config__.allow_population_by_field_name,  # type: ignore[attr-defined] +                        random=cls.__random__, +                    ) +                    for field in cls.__model__.__fields__.values() +                ] +            else: +                cls._fields_metadata = [ +                    PydanticFieldMeta.from_field_info( +                        field_info=field_info, +                        field_name=field_name, +                        random=cls.__random__, +                        use_alias=not cls.__model__.model_config.get(  # pyright: ignore[reportGeneralTypeIssues] +                            "populate_by_name", +                            False, +                        ), +                    ) +                    for field_name, field_info in cls.__model__.model_fields.items()  # pyright: ignore[reportGeneralTypeIssues] +                ] +        return cls._fields_metadata + +    @classmethod +    def get_constrained_field_value(cls, annotation: Any, field_meta: FieldMeta) -> Any: +        constraints = cast(PydanticConstraints, field_meta.constraints) +        if constraints.pop("json", None): +            value = cls.get_field_value(field_meta) +            return to_json(value)  # pyright: ignore[reportUnboundVariable] + +        return super().get_constrained_field_value(annotation, field_meta) + +    @classmethod +    def build( +        cls, +        factory_use_construct: bool = False, +        **kwargs: Any, +    ) -> T: +        """Build an instance of the factory's __model__ + +        :param factory_use_construct: A boolean that determines whether validations will be made when instantiating the +                model. This is supported only for pydantic models. +        :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used. + +        :returns: An instance of type T. + +        """ +        processed_kwargs = cls.process_kwargs(**kwargs) + +        if factory_use_construct: +            if _is_pydantic_v1_model(cls.__model__): +                return cls.__model__.construct(**processed_kwargs)  # type: ignore[return-value] +            return cls.__model__.model_construct(**processed_kwargs)  # type: ignore[return-value] + +        return cls.__model__(**processed_kwargs)  # type: ignore[return-value] + +    @classmethod +    def is_custom_root_field(cls, field_meta: FieldMeta) -> bool: +        """Determine whether the field is a custom root field. + +        :param field_meta: FieldMeta instance. + +        :returns: A boolean determining whether the field is a custom root. + +        """ +        return field_meta.name == "__root__" + +    @classmethod +    def should_set_field_value(cls, field_meta: FieldMeta, **kwargs: Any) -> bool: +        """Determine whether to set a value for a given field_name. +        This is an override of BaseFactory.should_set_field_value. + +        :param field_meta: FieldMeta instance. +        :param kwargs: Any kwargs passed to the factory. + +        :returns: A boolean determining whether a value should be set for the given field_meta. + +        """ +        return field_meta.name not in kwargs and ( +            not field_meta.name.startswith("_") or cls.is_custom_root_field(field_meta) +        ) + +    @classmethod +    def get_provider_map(cls) -> dict[Any, Callable[[], Any]]: +        mapping = { +            pydantic.ByteSize: cls.__faker__.pyint, +            pydantic.PositiveInt: cls.__faker__.pyint, +            pydantic.NegativeFloat: lambda: cls.__random__.uniform(-100, -1), +            pydantic.NegativeInt: lambda: cls.__faker__.pyint() * -1, +            pydantic.PositiveFloat: cls.__faker__.pyint, +            pydantic.NonPositiveFloat: lambda: cls.__random__.uniform(-100, 0), +            pydantic.NonNegativeInt: cls.__faker__.pyint, +            pydantic.StrictInt: cls.__faker__.pyint, +            pydantic.StrictBool: cls.__faker__.pybool, +            pydantic.StrictBytes: partial(create_random_bytes, cls.__random__), +            pydantic.StrictFloat: cls.__faker__.pyfloat, +            pydantic.StrictStr: cls.__faker__.pystr, +            pydantic.EmailStr: cls.__faker__.free_email, +            pydantic.NameEmail: cls.__faker__.free_email, +            pydantic.Json: cls.__faker__.json, +            pydantic.PaymentCardNumber: cls.__faker__.credit_card_number, +            pydantic.AnyUrl: cls.__faker__.url, +            pydantic.AnyHttpUrl: cls.__faker__.url, +            pydantic.HttpUrl: cls.__faker__.url, +            pydantic.SecretBytes: partial(create_random_bytes, cls.__random__), +            pydantic.SecretStr: cls.__faker__.pystr, +            pydantic.IPvAnyAddress: cls.__faker__.ipv4, +            pydantic.IPvAnyInterface: cls.__faker__.ipv4, +            pydantic.IPvAnyNetwork: lambda: cls.__faker__.ipv4(network=True), +            pydantic.PastDate: cls.__faker__.past_date, +            pydantic.FutureDate: cls.__faker__.future_date, +        } + +        # v1 only values +        mapping.update( +            { +                PyObject: lambda: "decimal.Decimal", +                AmqpDsn: lambda: "amqps://example.com", +                KafkaDsn: lambda: "kafka://localhost:9092", +                PostgresDsn: lambda: "postgresql://user:secret@localhost", +                RedisDsn: lambda: "redis://localhost:6379/0", +                FilePath: lambda: Path(realpath(__file__)), +                DirectoryPath: lambda: Path(realpath(__file__)).parent, +                UUID1: uuid1, +                UUID3: lambda: uuid3(NAMESPACE_DNS, cls.__faker__.pystr()), +                UUID4: cls.__faker__.uuid4, +                UUID5: lambda: uuid5(NAMESPACE_DNS, cls.__faker__.pystr()), +                Color: cls.__faker__.hex_color,  # pyright: ignore[reportGeneralTypeIssues] +            }, +        ) + +        if not _IS_PYDANTIC_V1: +            mapping.update( +                { +                    # pydantic v2 specific types +                    pydantic.PastDatetime: cls.__faker__.past_datetime, +                    pydantic.FutureDatetime: cls.__faker__.future_datetime, +                    pydantic.AwareDatetime: partial(cls.__faker__.date_time, timezone.utc), +                    pydantic.NaiveDatetime: cls.__faker__.date_time, +                }, +            ) + +        mapping.update(super().get_provider_map()) +        return mapping + + +def _is_pydantic_v1_model(model: Any) -> TypeGuard[BaseModelV1]: +    return is_safe_subclass(model, BaseModelV1) + + +def _is_pydantic_v2_model(model: Any) -> TypeGuard[BaseModelV2]: +    return not _IS_PYDANTIC_V1 and is_safe_subclass(model, BaseModelV2) diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/sqlalchemy_factory.py b/venv/lib/python3.11/site-packages/polyfactory/factories/sqlalchemy_factory.py new file mode 100644 index 0000000..ad8873f --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/sqlalchemy_factory.py @@ -0,0 +1,186 @@ +from __future__ import annotations + +from datetime import date, datetime +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, List, TypeVar, Union + +from polyfactory.exceptions import MissingDependencyException +from polyfactory.factories.base import BaseFactory +from polyfactory.field_meta import FieldMeta +from polyfactory.persistence import AsyncPersistenceProtocol, SyncPersistenceProtocol + +try: +    from sqlalchemy import Column, inspect, types +    from sqlalchemy.dialects import mysql, postgresql +    from sqlalchemy.exc import NoInspectionAvailable +    from sqlalchemy.orm import InstanceState, Mapper +except ImportError as e: +    msg = "sqlalchemy is not installed" +    raise MissingDependencyException(msg) from e + +if TYPE_CHECKING: +    from sqlalchemy.ext.asyncio import AsyncSession +    from sqlalchemy.orm import Session +    from typing_extensions import TypeGuard + + +T = TypeVar("T") + + +class SQLASyncPersistence(SyncPersistenceProtocol[T]): +    def __init__(self, session: Session) -> None: +        """Sync persistence handler for SQLAFactory.""" +        self.session = session + +    def save(self, data: T) -> T: +        self.session.add(data) +        self.session.commit() +        return data + +    def save_many(self, data: list[T]) -> list[T]: +        self.session.add_all(data) +        self.session.commit() +        return data + + +class SQLAASyncPersistence(AsyncPersistenceProtocol[T]): +    def __init__(self, session: AsyncSession) -> None: +        """Async persistence handler for SQLAFactory.""" +        self.session = session + +    async def save(self, data: T) -> T: +        self.session.add(data) +        await self.session.commit() +        return data + +    async def save_many(self, data: list[T]) -> list[T]: +        self.session.add_all(data) +        await self.session.commit() +        return data + + +class SQLAlchemyFactory(Generic[T], BaseFactory[T]): +    """Base factory for SQLAlchemy models.""" + +    __is_base_factory__ = True + +    __set_primary_key__: ClassVar[bool] = True +    """Configuration to consider primary key columns as a field or not.""" +    __set_foreign_keys__: ClassVar[bool] = True +    """Configuration to consider columns with foreign keys as a field or not.""" +    __set_relationships__: ClassVar[bool] = False +    """Configuration to consider relationships property as a model field or not.""" + +    __session__: ClassVar[Session | Callable[[], Session] | None] = None +    __async_session__: ClassVar[AsyncSession | Callable[[], AsyncSession] | None] = None + +    __config_keys__ = ( +        *BaseFactory.__config_keys__, +        "__set_primary_key__", +        "__set_foreign_keys__", +        "__set_relationships__", +    ) + +    @classmethod +    def get_sqlalchemy_types(cls) -> dict[Any, Callable[[], Any]]: +        """Get mapping of types where column type.""" +        return { +            types.TupleType: cls.__faker__.pytuple, +            mysql.YEAR: lambda: cls.__random__.randint(1901, 2155), +            postgresql.CIDR: lambda: cls.__faker__.ipv4(network=False), +            postgresql.DATERANGE: lambda: (cls.__faker__.past_date(), date.today()),  # noqa: DTZ011 +            postgresql.INET: lambda: cls.__faker__.ipv4(network=True), +            postgresql.INT4RANGE: lambda: tuple(sorted([cls.__faker__.pyint(), cls.__faker__.pyint()])), +            postgresql.INT8RANGE: lambda: tuple(sorted([cls.__faker__.pyint(), cls.__faker__.pyint()])), +            postgresql.MACADDR: lambda: cls.__faker__.hexify(text="^^:^^:^^:^^:^^:^^", upper=True), +            postgresql.NUMRANGE: lambda: tuple(sorted([cls.__faker__.pyint(), cls.__faker__.pyint()])), +            postgresql.TSRANGE: lambda: (cls.__faker__.past_datetime(), datetime.now()),  # noqa: DTZ005 +            postgresql.TSTZRANGE: lambda: (cls.__faker__.past_datetime(), datetime.now()),  # noqa: DTZ005 +        } + +    @classmethod +    def get_provider_map(cls) -> dict[Any, Callable[[], Any]]: +        providers_map = super().get_provider_map() +        providers_map.update(cls.get_sqlalchemy_types()) +        return providers_map + +    @classmethod +    def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]: +        try: +            inspected = inspect(value) +        except NoInspectionAvailable: +            return False +        return isinstance(inspected, (Mapper, InstanceState)) + +    @classmethod +    def should_column_be_set(cls, column: Column) -> bool: +        if not cls.__set_primary_key__ and column.primary_key: +            return False + +        return bool(cls.__set_foreign_keys__ or not column.foreign_keys) + +    @classmethod +    def get_type_from_column(cls, column: Column) -> type: +        column_type = type(column.type) +        if column_type in cls.get_sqlalchemy_types(): +            annotation = column_type +        elif issubclass(column_type, types.ARRAY): +            annotation = List[column.type.item_type.python_type]  # type: ignore[assignment,name-defined] +        else: +            annotation = ( +                column.type.impl.python_type  # pyright: ignore[reportGeneralTypeIssues] +                if hasattr(column.type, "impl") +                else column.type.python_type +            ) + +        if column.nullable: +            annotation = Union[annotation, None]  # type: ignore[assignment] + +        return annotation + +    @classmethod +    def get_model_fields(cls) -> list[FieldMeta]: +        fields_meta: list[FieldMeta] = [] + +        table: Mapper = inspect(cls.__model__)  # type: ignore[assignment] +        fields_meta.extend( +            FieldMeta.from_type( +                annotation=cls.get_type_from_column(column), +                name=name, +                random=cls.__random__, +            ) +            for name, column in table.columns.items() +            if cls.should_column_be_set(column) +        ) +        if cls.__set_relationships__: +            for name, relationship in table.relationships.items(): +                class_ = relationship.entity.class_ +                annotation = class_ if not relationship.uselist else List[class_]  # type: ignore[valid-type] +                fields_meta.append( +                    FieldMeta.from_type( +                        name=name, +                        annotation=annotation, +                        random=cls.__random__, +                    ), +                ) + +        return fields_meta + +    @classmethod +    def _get_sync_persistence(cls) -> SyncPersistenceProtocol[T]: +        if cls.__session__ is not None: +            return ( +                SQLASyncPersistence(cls.__session__()) +                if callable(cls.__session__) +                else SQLASyncPersistence(cls.__session__) +            ) +        return super()._get_sync_persistence() + +    @classmethod +    def _get_async_persistence(cls) -> AsyncPersistenceProtocol[T]: +        if cls.__async_session__ is not None: +            return ( +                SQLAASyncPersistence(cls.__async_session__()) +                if callable(cls.__async_session__) +                else SQLAASyncPersistence(cls.__async_session__) +            ) +        return super()._get_async_persistence() diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/typed_dict_factory.py b/venv/lib/python3.11/site-packages/polyfactory/factories/typed_dict_factory.py new file mode 100644 index 0000000..2a3ea1b --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/factories/typed_dict_factory.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from typing import Any, Generic, TypeVar, get_args + +from typing_extensions import (  # type: ignore[attr-defined] +    NotRequired, +    Required, +    TypeGuard, +    _TypedDictMeta,  # pyright: ignore[reportGeneralTypeIssues] +    get_origin, +    get_type_hints, +    is_typeddict, +) + +from polyfactory.constants import DEFAULT_RANDOM +from polyfactory.factories.base import BaseFactory +from polyfactory.field_meta import FieldMeta, Null + +TypedDictT = TypeVar("TypedDictT", bound=_TypedDictMeta) + + +class TypedDictFactory(Generic[TypedDictT], BaseFactory[TypedDictT]): +    """TypedDict base factory""" + +    __is_base_factory__ = True + +    @classmethod +    def is_supported_type(cls, value: Any) -> TypeGuard[type[TypedDictT]]: +        """Determine whether the given value is supported by the factory. + +        :param value: An arbitrary value. +        :returns: A typeguard +        """ +        return is_typeddict(value) + +    @classmethod +    def get_model_fields(cls) -> list["FieldMeta"]: +        """Retrieve a list of fields from the factory's model. + + +        :returns: A list of field MetaData instances. + +        """ +        model_type_hints = get_type_hints(cls.__model__, include_extras=True) + +        field_metas: list[FieldMeta] = [] +        for field_name, annotation in model_type_hints.items(): +            origin = get_origin(annotation) +            if origin in (Required, NotRequired): +                annotation = get_args(annotation)[0]  # noqa: PLW2901 + +            field_metas.append( +                FieldMeta.from_type( +                    annotation=annotation, +                    random=DEFAULT_RANDOM, +                    name=field_name, +                    default=getattr(cls.__model__, field_name, Null), +                ), +            ) + +        return field_metas diff --git a/venv/lib/python3.11/site-packages/polyfactory/field_meta.py b/venv/lib/python3.11/site-packages/polyfactory/field_meta.py new file mode 100644 index 0000000..d6288fd --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/field_meta.py @@ -0,0 +1,233 @@ +from __future__ import annotations + +from dataclasses import asdict, is_dataclass +from typing import TYPE_CHECKING, Any, Literal, Mapping, Pattern, TypedDict, cast + +from typing_extensions import get_args, get_origin + +from polyfactory.collection_extender import CollectionExtender +from polyfactory.constants import DEFAULT_RANDOM, TYPE_MAPPING +from polyfactory.utils.deprecation import check_for_deprecated_parameters +from polyfactory.utils.helpers import ( +    get_annotation_metadata, +    unwrap_annotated, +    unwrap_new_type, +) +from polyfactory.utils.predicates import is_annotated +from polyfactory.utils.types import NoneType + +if TYPE_CHECKING: +    import datetime +    from decimal import Decimal +    from random import Random +    from typing import Sequence + +    from typing_extensions import NotRequired, Self + + +class Null: +    """Sentinel class for empty values""" + + +class UrlConstraints(TypedDict): +    max_length: NotRequired[int] +    allowed_schemes: NotRequired[list[str]] +    host_required: NotRequired[bool] +    default_host: NotRequired[str] +    default_port: NotRequired[int] +    default_path: NotRequired[str] + + +class Constraints(TypedDict): +    """Metadata regarding a type constraints, if any""" + +    allow_inf_nan: NotRequired[bool] +    decimal_places: NotRequired[int] +    ge: NotRequired[int | float | Decimal] +    gt: NotRequired[int | float | Decimal] +    item_type: NotRequired[Any] +    le: NotRequired[int | float | Decimal] +    lower_case: NotRequired[bool] +    lt: NotRequired[int | float | Decimal] +    max_digits: NotRequired[int] +    max_length: NotRequired[int] +    min_length: NotRequired[int] +    multiple_of: NotRequired[int | float | Decimal] +    path_type: NotRequired[Literal["file", "dir", "new"]] +    pattern: NotRequired[str | Pattern] +    tz: NotRequired[datetime.tzinfo] +    unique_items: NotRequired[bool] +    upper_case: NotRequired[bool] +    url: NotRequired[UrlConstraints] +    uuid_version: NotRequired[Literal[1, 3, 4, 5]] + + +class FieldMeta: +    """Factory field metadata container. This class is used to store the data about a field of a factory's model.""" + +    __slots__ = ("name", "annotation", "random", "children", "default", "constraints") + +    annotation: Any +    random: Random +    children: list[FieldMeta] | None +    default: Any +    name: str +    constraints: Constraints | None + +    def __init__( +        self, +        *, +        name: str, +        annotation: type, +        random: Random | None = None, +        default: Any = Null, +        children: list[FieldMeta] | None = None, +        constraints: Constraints | None = None, +    ) -> None: +        """Create a factory field metadata instance.""" +        self.annotation = annotation +        self.random = random or DEFAULT_RANDOM +        self.children = children +        self.default = default +        self.name = name +        self.constraints = constraints + +    @property +    def type_args(self) -> tuple[Any, ...]: +        """Return the normalized type args of the annotation, if any. + +        :returns: a tuple of types. +        """ +        return tuple(TYPE_MAPPING[arg] if arg in TYPE_MAPPING else arg for arg in get_args(self.annotation)) + +    @classmethod +    def from_type( +        cls, +        annotation: Any, +        random: Random = DEFAULT_RANDOM, +        name: str = "", +        default: Any = Null, +        constraints: Constraints | None = None, +        randomize_collection_length: bool | None = None, +        min_collection_length: int | None = None, +        max_collection_length: int | None = None, +        children: list[FieldMeta] | None = None, +    ) -> Self: +        """Builder method to create a FieldMeta from a type annotation. + +        :param annotation: A type annotation. +        :param random: An instance of random.Random. +        :param name: Field name +        :param default: Default value, if any. +        :param constraints: A dictionary of constraints, if any. +        :param randomize_collection_length: A boolean flag whether to randomize collections lengths +        :param min_collection_length: Minimum number of elements in randomized collection +        :param max_collection_length: Maximum number of elements in randomized collection + +        :returns: A field meta instance. +        """ +        check_for_deprecated_parameters( +            "2.11.0", +            parameters=( +                ("randomize_collection_length", randomize_collection_length), +                ("min_collection_length", min_collection_length), +                ("max_collection_length", max_collection_length), +            ), +        ) + +        annotated = is_annotated(annotation) +        if not constraints and annotated: +            metadata = cls.get_constraints_metadata(annotation) +            constraints = cls.parse_constraints(metadata) + +        if annotated: +            annotation = get_args(annotation)[0] +        elif (origin := get_origin(annotation)) and origin in TYPE_MAPPING:  # pragma: no cover +            container = TYPE_MAPPING[origin] +            annotation = container[get_args(annotation)]  # type: ignore[index] + +        field = cls( +            annotation=annotation, +            random=random, +            name=name, +            default=default, +            children=children, +            constraints=constraints, +        ) + +        if field.type_args and not field.children: +            number_of_args = 1 +            extended_type_args = CollectionExtender.extend_type_args(field.annotation, field.type_args, number_of_args) +            field.children = [ +                cls.from_type( +                    annotation=unwrap_new_type(arg), +                    random=random, +                ) +                for arg in extended_type_args +                if arg is not NoneType +            ] +        return field + +    @classmethod +    def parse_constraints(cls, metadata: Sequence[Any]) -> "Constraints": +        constraints = {} + +        for value in metadata: +            if is_annotated(value): +                _, inner_metadata = unwrap_annotated(value, random=DEFAULT_RANDOM) +                constraints.update(cast("dict[str, Any]", cls.parse_constraints(metadata=inner_metadata))) +            elif func := getattr(value, "func", None): +                if func is str.islower: +                    constraints["lower_case"] = True +                elif func is str.isupper: +                    constraints["upper_case"] = True +                elif func is str.isascii: +                    constraints["pattern"] = "[[:ascii:]]" +                elif func is str.isdigit: +                    constraints["pattern"] = "[[:digit:]]" +            elif is_dataclass(value) and (value_dict := asdict(value)) and ("allowed_schemes" in value_dict): +                constraints["url"] = {k: v for k, v in value_dict.items() if v is not None} +            # This is to support `Constraints`, but we can't do a isinstance with `Constraints` since isinstance +            # checks with `TypedDict` is not supported. +            elif isinstance(value, Mapping): +                constraints.update(value) +            else: +                constraints.update( +                    { +                        k: v +                        for k, v in { +                            "allow_inf_nan": getattr(value, "allow_inf_nan", None), +                            "decimal_places": getattr(value, "decimal_places", None), +                            "ge": getattr(value, "ge", None), +                            "gt": getattr(value, "gt", None), +                            "item_type": getattr(value, "item_type", None), +                            "le": getattr(value, "le", None), +                            "lower_case": getattr(value, "to_lower", None), +                            "lt": getattr(value, "lt", None), +                            "max_digits": getattr(value, "max_digits", None), +                            "max_length": getattr(value, "max_length", getattr(value, "max_length", None)), +                            "min_length": getattr(value, "min_length", getattr(value, "min_items", None)), +                            "multiple_of": getattr(value, "multiple_of", None), +                            "path_type": getattr(value, "path_type", None), +                            "pattern": getattr(value, "regex", getattr(value, "pattern", None)), +                            "tz": getattr(value, "tz", None), +                            "unique_items": getattr(value, "unique_items", None), +                            "upper_case": getattr(value, "to_upper", None), +                            "uuid_version": getattr(value, "uuid_version", None), +                        }.items() +                        if v is not None +                    }, +                ) +        return cast("Constraints", constraints) + +    @classmethod +    def get_constraints_metadata(cls, annotation: Any) -> Sequence[Any]: +        """Get the metadatas of the constraints from the given annotation. + +        :param annotation: A type annotation. +        :param random: An instance of random.Random. + +        :returns: A list of the metadata in the annotation. +        """ + +        return get_annotation_metadata(annotation) diff --git a/venv/lib/python3.11/site-packages/polyfactory/fields.py b/venv/lib/python3.11/site-packages/polyfactory/fields.py new file mode 100644 index 0000000..a3b19ef --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/fields.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +from typing import Any, Callable, Generic, TypedDict, TypeVar, cast + +from typing_extensions import ParamSpec + +from polyfactory.exceptions import ParameterException + +T = TypeVar("T") +P = ParamSpec("P") + + +class WrappedCallable(TypedDict): +    """A ref storing a callable. This class is a utility meant to prevent binding of methods.""" + +    value: Callable + + +class Require: +    """A factory field that marks an attribute as a required build-time kwarg.""" + + +class Ignore: +    """A factory field that marks an attribute as ignored.""" + + +class Use(Generic[P, T]): +    """Factory field used to wrap a callable. + +    The callable will be invoked whenever building the given factory attribute. + + +    """ + +    __slots__ = ("fn", "kwargs", "args") + +    def __init__(self, fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> None: +        """Wrap a callable. + +        :param fn: A callable to wrap. +        :param args: Any args to pass to the callable. +        :param kwargs: Any kwargs to pass to the callable. +        """ +        self.fn: WrappedCallable = {"value": fn} +        self.kwargs = kwargs +        self.args = args + +    def to_value(self) -> T: +        """Invoke the callable. + +        :returns: The output of the callable. + + +        """ +        return cast("T", self.fn["value"](*self.args, **self.kwargs)) + + +class PostGenerated: +    """Factory field that allows generating values after other fields are generated by the factory.""" + +    __slots__ = ("fn", "kwargs", "args") + +    def __init__(self, fn: Callable, *args: Any, **kwargs: Any) -> None: +        """Designate field as post-generated. + +        :param fn: A callable. +        :param args: Args for the callable. +        :param kwargs: Kwargs for the callable. +        """ +        self.fn: WrappedCallable = {"value": fn} +        self.kwargs = kwargs +        self.args = args + +    def to_value(self, name: str, values: dict[str, Any]) -> Any: +        """Invoke the post-generation callback passing to it the build results. + +        :param name: Field name. +        :param values: Generated values. + +        :returns: An arbitrary value. +        """ +        return self.fn["value"](name, values, *self.args, **self.kwargs) + + +class Fixture: +    """Factory field to create a pytest fixture from a factory.""" + +    __slots__ = ("ref", "size", "kwargs") + +    def __init__(self, fixture: Callable, size: int | None = None, **kwargs: Any) -> None: +        """Create a fixture from a factory. + +        :param fixture: A factory that was registered as a fixture. +        :param size: Optional batch size. +        :param kwargs: Any build kwargs. +        """ +        self.ref: WrappedCallable = {"value": fixture} +        self.size = size +        self.kwargs = kwargs + +    def to_value(self) -> Any: +        """Call the factory's build or batch method. + +        :raises: ParameterException + +        :returns: The build result. +        """ +        from polyfactory.pytest_plugin import FactoryFixture + +        if factory := FactoryFixture.factory_class_map.get(self.ref["value"]): +            if self.size is not None: +                return factory.batch(self.size, **self.kwargs) +            return factory.build(**self.kwargs) + +        msg = "fixture has not been registered using the register_factory decorator" +        raise ParameterException(msg) diff --git a/venv/lib/python3.11/site-packages/polyfactory/persistence.py b/venv/lib/python3.11/site-packages/polyfactory/persistence.py new file mode 100644 index 0000000..7aa510b --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/persistence.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from typing import Protocol, TypeVar, runtime_checkable + +T = TypeVar("T") + + +@runtime_checkable +class SyncPersistenceProtocol(Protocol[T]): +    """Protocol for sync persistence""" + +    def save(self, data: T) -> T: +        """Persist a single instance synchronously. + +        :param data: A single instance to persist. + +        :returns: The persisted result. + +        """ +        ... + +    def save_many(self, data: list[T]) -> list[T]: +        """Persist multiple instances synchronously. + +        :param data: A list of instances to persist. + +        :returns: The persisted result + +        """ +        ... + + +@runtime_checkable +class AsyncPersistenceProtocol(Protocol[T]): +    """Protocol for async persistence""" + +    async def save(self, data: T) -> T: +        """Persist a single instance asynchronously. + +        :param data: A single instance to persist. + +        :returns: The persisted result. +        """ +        ... + +    async def save_many(self, data: list[T]) -> list[T]: +        """Persist multiple instances asynchronously. + +        :param data: A list of instances to persist. + +        :returns: The persisted result +        """ +        ... diff --git a/venv/lib/python3.11/site-packages/polyfactory/py.typed b/venv/lib/python3.11/site-packages/polyfactory/py.typed new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/py.typed diff --git a/venv/lib/python3.11/site-packages/polyfactory/pytest_plugin.py b/venv/lib/python3.11/site-packages/polyfactory/pytest_plugin.py new file mode 100644 index 0000000..0ca9196 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/pytest_plugin.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +import re +from typing import ( +    Any, +    Callable, +    ClassVar, +    Literal, +    Union, +) + +from pytest import Config, fixture  # noqa: PT013 + +from polyfactory.exceptions import ParameterException +from polyfactory.factories.base import BaseFactory +from polyfactory.utils.predicates import is_safe_subclass + +Scope = Union[ +    Literal["session", "package", "module", "class", "function"], +    Callable[[str, Config], Literal["session", "package", "module", "class", "function"]], +] + + +split_pattern_1 = re.compile(r"([A-Z]+)([A-Z][a-z])") +split_pattern_2 = re.compile(r"([a-z\d])([A-Z])") + + +def _get_fixture_name(name: str) -> str: +    """From inflection.underscore. + +    :param name: str: A name. + +    :returns: Normalized fixture name. + +    """ +    name = re.sub(split_pattern_1, r"\1_\2", name) +    name = re.sub(split_pattern_2, r"\1_\2", name) +    name = name.replace("-", "_") +    return name.lower() + + +class FactoryFixture: +    """Decorator that creates a pytest fixture from a factory""" + +    __slots__ = ("scope", "autouse", "name") + +    factory_class_map: ClassVar[dict[Callable, type[BaseFactory[Any]]]] = {} + +    def __init__( +        self, +        scope: Scope = "function", +        autouse: bool = False, +        name: str | None = None, +    ) -> None: +        """Create a factory fixture decorator + +        :param scope: Fixture scope +        :param autouse: Autouse the fixture +        :param name: Fixture name +        """ +        self.scope = scope +        self.autouse = autouse +        self.name = name + +    def __call__(self, factory: type[BaseFactory[Any]]) -> Any: +        if not is_safe_subclass(factory, BaseFactory): +            msg = f"{factory.__name__} is not a BaseFactory subclass." +            raise ParameterException(msg) + +        fixture_name = self.name or _get_fixture_name(factory.__name__) +        fixture_register = fixture( +            scope=self.scope,  # pyright: ignore[reportGeneralTypeIssues] +            name=fixture_name, +            autouse=self.autouse, +        ) + +        def _factory_fixture() -> type[BaseFactory[Any]]: +            """The wrapped factory""" +            return factory + +        _factory_fixture.__doc__ = factory.__doc__ +        marker = fixture_register(_factory_fixture) +        self.factory_class_map[marker] = factory +        return marker + + +def register_fixture( +    factory: type[BaseFactory[Any]] | None = None, +    *, +    scope: Scope = "function", +    autouse: bool = False, +    name: str | None = None, +) -> Any: +    """A decorator that allows registering model factories as fixtures. + +    :param factory: An optional factory class to decorate. +    :param scope: Pytest scope. +    :param autouse: Auto use fixture. +    :param name: Fixture name. + +    :returns: A fixture factory instance. +    """ +    factory_fixture = FactoryFixture(scope=scope, autouse=autouse, name=name) +    return factory_fixture(factory) if factory else factory_fixture diff --git a/venv/lib/python3.11/site-packages/polyfactory/utils/__init__.py b/venv/lib/python3.11/site-packages/polyfactory/utils/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/utils/__init__.py diff --git a/venv/lib/python3.11/site-packages/polyfactory/utils/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/utils/__pycache__/__init__.cpython-311.pycBinary files differ new file mode 100644 index 0000000..6297806 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/utils/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/utils/__pycache__/deprecation.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/utils/__pycache__/deprecation.cpython-311.pycBinary files differ new file mode 100644 index 0000000..8f43f83 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/utils/__pycache__/deprecation.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/utils/__pycache__/helpers.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/utils/__pycache__/helpers.cpython-311.pycBinary files differ new file mode 100644 index 0000000..982e068 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/utils/__pycache__/helpers.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/utils/__pycache__/model_coverage.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/utils/__pycache__/model_coverage.cpython-311.pycBinary files differ new file mode 100644 index 0000000..22c475b --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/utils/__pycache__/model_coverage.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/utils/__pycache__/predicates.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/utils/__pycache__/predicates.cpython-311.pycBinary files differ new file mode 100644 index 0000000..c1908b6 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/utils/__pycache__/predicates.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/utils/__pycache__/types.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/utils/__pycache__/types.cpython-311.pycBinary files differ new file mode 100644 index 0000000..59b2b6a --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/utils/__pycache__/types.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/utils/deprecation.py b/venv/lib/python3.11/site-packages/polyfactory/utils/deprecation.py new file mode 100644 index 0000000..576c4ac --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/utils/deprecation.py @@ -0,0 +1,149 @@ +from __future__ import annotations + +import inspect +from functools import wraps +from typing import Any, Callable, Literal, TypeVar +from warnings import warn + +from typing_extensions import ParamSpec + +__all__ = ("deprecated", "warn_deprecation", "check_for_deprecated_parameters") + + +T = TypeVar("T") +P = ParamSpec("P") +DeprecatedKind = Literal["function", "method", "classmethod", "attribute", "property", "class", "parameter", "import"] + + +def warn_deprecation( +    version: str, +    deprecated_name: str, +    kind: DeprecatedKind, +    *, +    removal_in: str | None = None, +    alternative: str | None = None, +    info: str | None = None, +    pending: bool = False, +) -> None: +    """Warn about a call to a (soon to be) deprecated function. + +    Args: +        version: Polyfactory version where the deprecation will occur. +        deprecated_name: Name of the deprecated function. +        removal_in: Polyfactory version where the deprecated function will be removed. +        alternative: Name of a function that should be used instead. +        info: Additional information. +        pending: Use ``PendingDeprecationWarning`` instead of ``DeprecationWarning``. +        kind: Type of the deprecated thing. +    """ +    parts = [] + +    if kind == "import": +        access_type = "Import of" +    elif kind in {"function", "method"}: +        access_type = "Call to" +    else: +        access_type = "Use of" + +    if pending: +        parts.append(f"{access_type} {kind} awaiting deprecation {deprecated_name!r}") +    else: +        parts.append(f"{access_type} deprecated {kind} {deprecated_name!r}") + +    parts.extend( +        ( +            f"Deprecated in polyfactory {version}", +            f"This {kind} will be removed in {removal_in or 'the next major version'}", +        )  # noqa: COM812 +    ) +    if alternative: +        parts.append(f"Use {alternative!r} instead") + +    if info: +        parts.append(info) + +    text = ". ".join(parts) +    warning_class = PendingDeprecationWarning if pending else DeprecationWarning + +    warn(text, warning_class, stacklevel=2) + + +def deprecated( +    version: str, +    *, +    removal_in: str | None = None, +    alternative: str | None = None, +    info: str | None = None, +    pending: bool = False, +    kind: Literal["function", "method", "classmethod", "property"] | None = None, +) -> Callable[[Callable[P, T]], Callable[P, T]]: +    """Create a decorator wrapping a function, method or property with a warning call about a (pending) deprecation. + +    Args: +        version: Polyfactory version where the deprecation will occur. +        removal_in: Polyfactory version where the deprecated function will be removed. +        alternative: Name of a function that should be used instead. +        info: Additional information. +        pending: Use ``PendingDeprecationWarning`` instead of ``DeprecationWarning``. +        kind: Type of the deprecated callable. If ``None``, will use ``inspect`` to figure +            out if it's a function or method. + +    Returns: +        A decorator wrapping the function call with a warning +    """ + +    def decorator(func: Callable[P, T]) -> Callable[P, T]: +        @wraps(func) +        def wrapped(*args: P.args, **kwargs: P.kwargs) -> T: +            warn_deprecation( +                version=version, +                deprecated_name=func.__name__, +                info=info, +                alternative=alternative, +                pending=pending, +                removal_in=removal_in, +                kind=kind or ("method" if inspect.ismethod(func) else "function"), +            ) +            return func(*args, **kwargs) + +        return wrapped + +    return decorator + + +def check_for_deprecated_parameters( +    version: str, +    *, +    parameters: tuple[tuple[str, Any], ...], +    default_value: Any = None, +    removal_in: str | None = None, +    alternative: str | None = None, +    info: str | None = None, +    pending: bool = False, +) -> None: +    """Warn about a call to a (soon to be) deprecated argument to a function. + +    Args: +        version: Polyfactory version where the deprecation will occur. +        parameters: Parameters to trigger warning if used. +        default_value: Default value for parameter to detect if supplied or not. +        removal_in: Polyfactory version where the deprecated function will be removed. +        alternative: Name of a function that should be used instead. +        info: Additional information. +        pending: Use ``PendingDeprecationWarning`` instead of ``DeprecationWarning``. +        kind: Type of the deprecated callable. If ``None``, will use ``inspect`` to figure +            out if it's a function or method. +    """ +    for parameter_name, value in parameters: +        if value == default_value: +            continue + +        warn_deprecation( +            version=version, +            deprecated_name=parameter_name, +            info=info, +            alternative=alternative, +            pending=pending, +            removal_in=removal_in, +            kind="parameter", +        ) diff --git a/venv/lib/python3.11/site-packages/polyfactory/utils/helpers.py b/venv/lib/python3.11/site-packages/polyfactory/utils/helpers.py new file mode 100644 index 0000000..f9924bb --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/utils/helpers.py @@ -0,0 +1,196 @@ +from __future__ import annotations + +import sys +from typing import TYPE_CHECKING, Any, Mapping + +from typing_extensions import TypeAliasType, get_args, get_origin + +from polyfactory.constants import TYPE_MAPPING +from polyfactory.utils.predicates import is_annotated, is_new_type, is_optional, is_safe_subclass, is_union +from polyfactory.utils.types import NoneType + +if TYPE_CHECKING: +    from random import Random +    from typing import Sequence + + +def unwrap_new_type(annotation: Any) -> Any: +    """Return base type if given annotation is a type derived with NewType, otherwise annotation. + +    :param annotation: A type annotation, possibly one created using 'types.NewType' + +    :returns: The unwrapped annotation. +    """ +    while is_new_type(annotation): +        annotation = annotation.__supertype__ + +    return annotation + + +def unwrap_union(annotation: Any, random: Random) -> Any: +    """Unwraps union types - recursively. + +    :param annotation: A type annotation, possibly a type union. +    :param random: An instance of random.Random. +    :returns: A type annotation +    """ +    while is_union(annotation): +        args = list(get_args(annotation)) +        annotation = random.choice(args) +    return annotation + + +def unwrap_optional(annotation: Any) -> Any: +    """Unwraps optional union types - recursively. + +    :param annotation: A type annotation, possibly an optional union. + +    :returns: A type annotation +    """ +    while is_optional(annotation): +        annotation = next(arg for arg in get_args(annotation) if arg not in (NoneType, None)) +    return annotation + + +def unwrap_annotation(annotation: Any, random: Random) -> Any: +    """Unwraps an annotation. + +    :param annotation: A type annotation. +    :param random: An instance of random.Random. + +    :returns: The unwrapped annotation. + +    """ +    while ( +        is_optional(annotation) +        or is_union(annotation) +        or is_new_type(annotation) +        or is_annotated(annotation) +        or isinstance(annotation, TypeAliasType) +    ): +        if is_new_type(annotation): +            annotation = unwrap_new_type(annotation) +        elif is_optional(annotation): +            annotation = unwrap_optional(annotation) +        elif is_annotated(annotation): +            annotation = unwrap_annotated(annotation, random=random)[0] +        elif isinstance(annotation, TypeAliasType): +            annotation = annotation.__value__ +        else: +            annotation = unwrap_union(annotation, random=random) + +    return annotation + + +def flatten_annotation(annotation: Any) -> list[Any]: +    """Flattens an annotation. + +    :param annotation: A type annotation. + +    :returns: The flattened annotations. +    """ +    flat = [] +    if is_new_type(annotation): +        flat.extend(flatten_annotation(unwrap_new_type(annotation))) +    elif is_optional(annotation): +        for a in get_args(annotation): +            flat.extend(flatten_annotation(a)) +    elif is_annotated(annotation): +        flat.extend(flatten_annotation(get_args(annotation)[0])) +    elif is_union(annotation): +        for a in get_args(annotation): +            flat.extend(flatten_annotation(a)) +    else: +        flat.append(annotation) + +    return flat + + +def unwrap_args(annotation: Any, random: Random) -> tuple[Any, ...]: +    """Unwrap the annotation and return any type args. + +    :param annotation: A type annotation +    :param random: An instance of random.Random. + +    :returns: A tuple of type args. + +    """ + +    return get_args(unwrap_annotation(annotation=annotation, random=random)) + + +def unwrap_annotated(annotation: Any, random: Random) -> tuple[Any, list[Any]]: +    """Unwrap an annotated type and return a tuple of type args and optional metadata. + +    :param annotation: An annotated type annotation +    :param random: An instance of random.Random. + +    :returns: A tuple of type args. + +    """ +    args = [arg for arg in get_args(annotation) if arg is not None] +    return unwrap_annotation(args[0], random=random), args[1:] + + +def normalize_annotation(annotation: Any, random: Random) -> Any: +    """Normalize an annotation. + +    :param annotation: A type annotation. + +    :returns: A normalized type annotation. + +    """ +    if is_new_type(annotation): +        annotation = unwrap_new_type(annotation) + +    if is_annotated(annotation): +        annotation = unwrap_annotated(annotation, random=random)[0] + +    # we have to maintain compatibility with the older non-subscriptable typings. +    if sys.version_info <= (3, 9):  # pragma: no cover +        return annotation + +    origin = get_origin(annotation) or annotation + +    if origin in TYPE_MAPPING: +        origin = TYPE_MAPPING[origin] + +    if args := get_args(annotation): +        args = tuple(normalize_annotation(arg, random=random) for arg in args) +        return origin[args] if origin is not type else annotation + +    return origin + + +def get_annotation_metadata(annotation: Any) -> Sequence[Any]: +    """Get the metadata in the annotation. + +    :param annotation: A type annotation. + +    :returns: The metadata. +    """ + +    return get_args(annotation)[1:] + + +def get_collection_type(annotation: Any) -> type[list | tuple | set | frozenset | dict]: +    """Get the collection type from the annotation. + +    :param annotation: A type annotation. + +    :returns: The collection type. +    """ + +    if is_safe_subclass(annotation, list): +        return list +    if is_safe_subclass(annotation, Mapping): +        return dict +    if is_safe_subclass(annotation, tuple): +        return tuple +    if is_safe_subclass(annotation, set): +        return set +    if is_safe_subclass(annotation, frozenset): +        return frozenset + +    msg = f"Unknown collection type - {annotation}" +    raise ValueError(msg) diff --git a/venv/lib/python3.11/site-packages/polyfactory/utils/model_coverage.py b/venv/lib/python3.11/site-packages/polyfactory/utils/model_coverage.py new file mode 100644 index 0000000..6fc3971 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/utils/model_coverage.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import Callable, Iterable, Iterator, Mapping, MutableSequence +from typing import AbstractSet, Any, Generic, Set, TypeVar, cast + +from typing_extensions import ParamSpec + +from polyfactory.exceptions import ParameterException + + +class CoverageContainerBase(ABC): +    """Base class for coverage container implementations. + +    A coverage container is a wrapper providing values for a particular field. Coverage containers return field values and +    track a "done" state to indicate that all coverage examples have been generated. +    """ + +    @abstractmethod +    def next_value(self) -> Any: +        """Provide the next value""" +        ... + +    @abstractmethod +    def is_done(self) -> bool: +        """Indicate if this container has provided every coverage example it has""" +        ... + + +T = TypeVar("T") + + +class CoverageContainer(CoverageContainerBase, Generic[T]): +    """A coverage container that wraps a collection of values. + +    When calling ``next_value()`` a greater number of times than the length of the given collection will cause duplicate +    examples to be returned (wraps around). + +    If there are any coverage containers within the given collection, the values from those containers are essentially merged +    into the parent container. +    """ + +    def __init__(self, instances: Iterable[T]) -> None: +        self._pos = 0 +        self._instances = list(instances) +        if not self._instances: +            msg = "CoverageContainer must have at least one instance" +            raise ValueError(msg) + +    def next_value(self) -> T: +        value = self._instances[self._pos % len(self._instances)] +        if isinstance(value, CoverageContainerBase): +            result = value.next_value() +            if value.is_done(): +                # Only move onto the next instance if the sub-container is done +                self._pos += 1 +            return cast(T, result) + +        self._pos += 1 +        return value + +    def is_done(self) -> bool: +        return self._pos >= len(self._instances) + +    def __repr__(self) -> str: +        return f"CoverageContainer(instances={self._instances}, is_done={self.is_done()})" + + +P = ParamSpec("P") + + +class CoverageContainerCallable(CoverageContainerBase, Generic[T]): +    """A coverage container that wraps a callable. + +    When calling ``next_value()`` the wrapped callable is called to provide a value. +    """ + +    def __init__(self, func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> None: +        self._func = func +        self._args = args +        self._kwargs = kwargs + +    def next_value(self) -> T: +        try: +            return self._func(*self._args, **self._kwargs) +        except Exception as e:  # noqa: BLE001 +            msg = f"Unsupported type: {self._func!r}\n\nEither extend the providers map or add a factory function for this type." +            raise ParameterException(msg) from e + +    def is_done(self) -> bool: +        return True + + +def _resolve_next(unresolved: Any) -> tuple[Any, bool]:  # noqa: C901 +    if isinstance(unresolved, CoverageContainerBase): +        result, done = _resolve_next(unresolved.next_value()) +        return result, unresolved.is_done() and done + +    if isinstance(unresolved, Mapping): +        result = {} +        done_status = True +        for key, value in unresolved.items(): +            val_resolved, val_done = _resolve_next(value) +            key_resolved, key_done = _resolve_next(key) +            result[key_resolved] = val_resolved +            done_status = done_status and val_done and key_done +        return result, done_status + +    if isinstance(unresolved, (tuple, MutableSequence)): +        result = [] +        done_status = True +        for value in unresolved: +            resolved, done = _resolve_next(value) +            result.append(resolved) +            done_status = done_status and done +        if isinstance(unresolved, tuple): +            result = tuple(result) +        return result, done_status + +    if isinstance(unresolved, Set): +        result = type(unresolved)() +        done_status = True +        for value in unresolved: +            resolved, done = _resolve_next(value) +            result.add(resolved) +            done_status = done_status and done +        return result, done_status + +    if issubclass(type(unresolved), AbstractSet): +        result = type(unresolved)() +        done_status = True +        resolved_values = [] +        for value in unresolved: +            resolved, done = _resolve_next(value) +            resolved_values.append(resolved) +            done_status = done_status and done +        return result.union(resolved_values), done_status + +    return unresolved, True + + +def resolve_kwargs_coverage(kwargs: dict[str, Any]) -> Iterator[dict[str, Any]]: +    done = False +    while not done: +        resolved, done = _resolve_next(kwargs) +        yield resolved diff --git a/venv/lib/python3.11/site-packages/polyfactory/utils/predicates.py b/venv/lib/python3.11/site-packages/polyfactory/utils/predicates.py new file mode 100644 index 0000000..895e380 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/utils/predicates.py @@ -0,0 +1,134 @@ +from __future__ import annotations + +from inspect import isclass +from typing import Any, Literal, NewType, Optional, TypeVar, get_args + +from typing_extensions import Annotated, NotRequired, ParamSpec, Required, TypeGuard, _AnnotatedAlias, get_origin + +from polyfactory.constants import TYPE_MAPPING +from polyfactory.utils.types import UNION_TYPES, NoneType + +P = ParamSpec("P") +T = TypeVar("T") + + +def is_safe_subclass(annotation: Any, class_or_tuple: type[T] | tuple[type[T], ...]) -> "TypeGuard[type[T]]": +    """Determine whether a given annotation is a subclass of a give type + +    :param annotation: A type annotation. +    :param class_or_tuple: A potential super class or classes. + +    :returns: A typeguard +    """ +    origin = get_type_origin(annotation) +    if not origin and not isclass(annotation): +        return False +    try: +        return issubclass(origin or annotation, class_or_tuple) +    except TypeError:  # pragma: no cover +        return False + + +def is_any(annotation: Any) -> "TypeGuard[Any]": +    """Determine whether a given annotation is 'typing.Any'. + +    :param annotation: A type annotation. + +    :returns: A typeguard. +    """ +    return ( +        annotation is Any +        or getattr(annotation, "_name", "") == "typing.Any" +        or (get_origin(annotation) in UNION_TYPES and Any in get_args(annotation)) +    ) + + +def is_dict_key_or_value_type(annotation: Any) -> "TypeGuard[Any]": +    """Determine whether a given annotation is a valid dict key or value type: +    ``typing.KT`` or ``typing.VT``. + +    :returns: A typeguard. +    """ +    return str(annotation) in {"~KT", "~VT"} + + +def is_union(annotation: Any) -> "TypeGuard[Any]": +    """Determine whether a given annotation is 'typing.Union'. + +    :param annotation: A type annotation. + +    :returns: A typeguard. +    """ +    return get_type_origin(annotation) in UNION_TYPES + + +def is_optional(annotation: Any) -> "TypeGuard[Any | None]": +    """Determine whether a given annotation is 'typing.Optional'. + +    :param annotation: A type annotation. + +    :returns: A typeguard. +    """ +    origin = get_type_origin(annotation) +    return origin is Optional or (get_origin(annotation) in UNION_TYPES and NoneType in get_args(annotation)) + + +def is_literal(annotation: Any) -> bool: +    """Determine whether a given annotation is 'typing.Literal'. + +    :param annotation: A type annotation. + +    :returns: A boolean. +    """ +    return ( +        get_type_origin(annotation) is Literal +        or repr(annotation).startswith("typing.Literal") +        or repr(annotation).startswith("typing_extensions.Literal") +    ) + + +def is_new_type(annotation: Any) -> "TypeGuard[type[NewType]]": +    """Determine whether a given annotation is 'typing.NewType'. + +    :param annotation: A type annotation. + +    :returns: A typeguard. +    """ +    return hasattr(annotation, "__supertype__") + + +def is_annotated(annotation: Any) -> bool: +    """Determine whether a given annotation is 'typing.Annotated'. + +    :param annotation: A type annotation. + +    :returns: A boolean. +    """ +    return get_origin(annotation) is Annotated or ( +        isinstance(annotation, _AnnotatedAlias) and getattr(annotation, "__args__", None) is not None +    ) + + +def is_any_annotated(annotation: Any) -> bool: +    """Determine whether any of the types in the given annotation is +    `typing.Annotated`. + +    :param annotation: A type annotation. + +    :returns: A boolean +    """ + +    return any(is_annotated(arg) or hasattr(arg, "__args__") and is_any_annotated(arg) for arg in get_args(annotation)) + + +def get_type_origin(annotation: Any) -> Any: +    """Get the type origin of an annotation - safely. + +    :param annotation: A type annotation. + +    :returns: A type annotation. +    """ +    origin = get_origin(annotation) +    if origin in (Annotated, Required, NotRequired): +        origin = get_args(annotation)[0] +    return mapped_type if (mapped_type := TYPE_MAPPING.get(origin)) else origin diff --git a/venv/lib/python3.11/site-packages/polyfactory/utils/types.py b/venv/lib/python3.11/site-packages/polyfactory/utils/types.py new file mode 100644 index 0000000..413f5dd --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/utils/types.py @@ -0,0 +1,12 @@ +from typing import Union + +try: +    from types import NoneType, UnionType + +    UNION_TYPES = {UnionType, Union} +except ImportError: +    UNION_TYPES = {Union} + +    NoneType = type(None)  # type: ignore[misc,assignment] + +__all__ = ("NoneType", "UNION_TYPES") diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/__init__.py b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__init__.py diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/__init__.cpython-311.pycBinary files differ new file mode 100644 index 0000000..c95bde0 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/complex_types.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/complex_types.cpython-311.pycBinary files differ new file mode 100644 index 0000000..b99a526 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/complex_types.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/constrained_collections.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/constrained_collections.cpython-311.pycBinary files differ new file mode 100644 index 0000000..bb33c8a --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/constrained_collections.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/constrained_dates.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/constrained_dates.cpython-311.pycBinary files differ new file mode 100644 index 0000000..f649c36 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/constrained_dates.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/constrained_numbers.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/constrained_numbers.cpython-311.pycBinary files differ new file mode 100644 index 0000000..926ca52 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/constrained_numbers.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/constrained_path.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/constrained_path.cpython-311.pycBinary files differ new file mode 100644 index 0000000..5808e81 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/constrained_path.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/constrained_strings.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/constrained_strings.cpython-311.pycBinary files differ new file mode 100644 index 0000000..53efa2f --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/constrained_strings.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/constrained_url.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/constrained_url.cpython-311.pycBinary files differ new file mode 100644 index 0000000..4dd4ad5 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/constrained_url.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/constrained_uuid.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/constrained_uuid.cpython-311.pycBinary files differ new file mode 100644 index 0000000..e653a72 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/constrained_uuid.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/primitives.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/primitives.cpython-311.pycBinary files differ new file mode 100644 index 0000000..97787ad --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/primitives.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/regex.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/regex.cpython-311.pycBinary files differ new file mode 100644 index 0000000..319510e --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/__pycache__/regex.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/complex_types.py b/venv/lib/python3.11/site-packages/polyfactory/value_generators/complex_types.py new file mode 100644 index 0000000..2706891 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/complex_types.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, AbstractSet, Any, Iterable, MutableMapping, MutableSequence, Set, Tuple, cast + +from typing_extensions import is_typeddict + +from polyfactory.constants import INSTANTIABLE_TYPE_MAPPING, PY_38 +from polyfactory.field_meta import FieldMeta +from polyfactory.utils.model_coverage import CoverageContainer + +if TYPE_CHECKING: +    from polyfactory.factories.base import BaseFactory + + +def handle_collection_type(field_meta: FieldMeta, container_type: type, factory: type[BaseFactory[Any]]) -> Any: +    """Handle generation of container types recursively. + +    :param container_type: A type that can accept type arguments. +    :param factory: A factory. +    :param field_meta: A field meta instance. + +    :returns: A built result. +    """ + +    if PY_38 and container_type in INSTANTIABLE_TYPE_MAPPING: +        container_type = INSTANTIABLE_TYPE_MAPPING[container_type]  # type: ignore[assignment] + +    container = container_type() +    if not field_meta.children: +        return container + +    if issubclass(container_type, MutableMapping) or is_typeddict(container_type): +        for key_field_meta, value_field_meta in cast( +            Iterable[Tuple[FieldMeta, FieldMeta]], +            zip(field_meta.children[::2], field_meta.children[1::2]), +        ): +            key = factory.get_field_value(key_field_meta) +            value = factory.get_field_value(value_field_meta) +            container[key] = value +        return container + +    if issubclass(container_type, MutableSequence): +        container.extend([factory.get_field_value(subfield_meta) for subfield_meta in field_meta.children]) +        return container + +    if issubclass(container_type, Set): +        for subfield_meta in field_meta.children: +            container.add(factory.get_field_value(subfield_meta)) +        return container + +    if issubclass(container_type, AbstractSet): +        return container.union(handle_collection_type(field_meta, set, factory)) + +    if issubclass(container_type, tuple): +        return container_type(map(factory.get_field_value, field_meta.children)) + +    msg = f"Unsupported container type: {container_type}" +    raise NotImplementedError(msg) + + +def handle_collection_type_coverage( +    field_meta: FieldMeta, +    container_type: type, +    factory: type[BaseFactory[Any]], +) -> Any: +    """Handle coverage generation of container types recursively. + +    :param container_type: A type that can accept type arguments. +    :param factory: A factory. +    :param field_meta: A field meta instance. + +    :returns: An unresolved built result. +    """ +    container = container_type() +    if not field_meta.children: +        return container + +    if issubclass(container_type, MutableMapping) or is_typeddict(container_type): +        for key_field_meta, value_field_meta in cast( +            Iterable[Tuple[FieldMeta, FieldMeta]], +            zip(field_meta.children[::2], field_meta.children[1::2]), +        ): +            key = CoverageContainer(factory.get_field_value_coverage(key_field_meta)) +            value = CoverageContainer(factory.get_field_value_coverage(value_field_meta)) +            container[key] = value +        return container + +    if issubclass(container_type, MutableSequence): +        container_instance = container_type() +        for subfield_meta in field_meta.children: +            container_instance.extend(factory.get_field_value_coverage(subfield_meta)) + +        return container_instance + +    if issubclass(container_type, Set): +        set_instance = container_type() +        for subfield_meta in field_meta.children: +            set_instance = set_instance.union(factory.get_field_value_coverage(subfield_meta)) + +        return set_instance + +    if issubclass(container_type, AbstractSet): +        return container.union(handle_collection_type_coverage(field_meta, set, factory)) + +    if issubclass(container_type, tuple): +        return container_type( +            CoverageContainer(factory.get_field_value_coverage(subfield_meta)) for subfield_meta in field_meta.children +        ) + +    msg = f"Unsupported container type: {container_type}" +    raise NotImplementedError(msg) diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/constrained_collections.py b/venv/lib/python3.11/site-packages/polyfactory/value_generators/constrained_collections.py new file mode 100644 index 0000000..405d02d --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/constrained_collections.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable, List, Mapping, TypeVar, cast + +from polyfactory.exceptions import ParameterException +from polyfactory.field_meta import FieldMeta + +if TYPE_CHECKING: +    from polyfactory.factories.base import BaseFactory + +T = TypeVar("T", list, set, frozenset) + + +def handle_constrained_collection( +    collection_type: Callable[..., T], +    factory: type[BaseFactory[Any]], +    field_meta: FieldMeta, +    item_type: Any, +    max_items: int | None = None, +    min_items: int | None = None, +    unique_items: bool = False, +) -> T: +    """Generate a constrained list or set. + +    :param collection_type: A type that can accept type arguments. +    :param factory: A factory. +    :param field_meta: A field meta instance. +    :param item_type: Type of the collection items. +    :param max_items: Maximal number of items. +    :param min_items: Minimal number of items. +    :param unique_items: Whether the items should be unique. + +    :returns: A collection value. +    """ +    min_items = abs(min_items if min_items is not None else (max_items or 0)) +    max_items = abs(max_items if max_items is not None else min_items + 1) + +    if max_items < min_items: +        msg = "max_items must be larger or equal to min_items" +        raise ParameterException(msg) + +    collection: set[T] | list[T] = set() if (collection_type in (frozenset, set) or unique_items) else [] + +    try: +        length = factory.__random__.randint(min_items, max_items) or 1 +        while len(collection) < length: +            value = factory.get_field_value(field_meta) +            if isinstance(collection, set): +                collection.add(value) +            else: +                collection.append(value) +        return collection_type(collection) +    except TypeError as e: +        msg = f"cannot generate a constrained collection of type: {item_type}" +        raise ParameterException(msg) from e + + +def handle_constrained_mapping( +    factory: type[BaseFactory[Any]], +    field_meta: FieldMeta, +    max_items: int | None = None, +    min_items: int | None = None, +) -> Mapping[Any, Any]: +    """Generate a constrained mapping. + +    :param factory: A factory. +    :param field_meta: A field meta instance. +    :param max_items: Maximal number of items. +    :param min_items: Minimal number of items. + +    :returns: A mapping instance. +    """ +    min_items = abs(min_items if min_items is not None else (max_items or 0)) +    max_items = abs(max_items if max_items is not None else min_items + 1) + +    if max_items < min_items: +        msg = "max_items must be larger or equal to min_items" +        raise ParameterException(msg) + +    length = factory.__random__.randint(min_items, max_items) or 1 + +    collection: dict[Any, Any] = {} + +    children = cast(List[FieldMeta], field_meta.children) +    key_field_meta = children[0] +    value_field_meta = children[1] +    while len(collection) < length: +        key = factory.get_field_value(key_field_meta) +        value = factory.get_field_value(value_field_meta) +        collection[key] = value + +    return collection diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/constrained_dates.py b/venv/lib/python3.11/site-packages/polyfactory/value_generators/constrained_dates.py new file mode 100644 index 0000000..4e92601 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/constrained_dates.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from datetime import date, datetime, timedelta, timezone, tzinfo +from typing import TYPE_CHECKING, cast + +if TYPE_CHECKING: +    from faker import Faker + + +def handle_constrained_date( +    faker: Faker, +    ge: date | None = None, +    gt: date | None = None, +    le: date | None = None, +    lt: date | None = None, +    tz: tzinfo = timezone.utc, +) -> date: +    """Generates a date value fulfilling the expected constraints. + +    :param faker: An instance of faker. +    :param lt: Less than value. +    :param le: Less than or equal value. +    :param gt: Greater than value. +    :param ge: Greater than or equal value. +    :param tz: A timezone. + +    :returns: A date instance. +    """ +    start_date = datetime.now(tz=tz).date() - timedelta(days=100) +    if ge: +        start_date = ge +    elif gt: +        start_date = gt + timedelta(days=1) + +    end_date = datetime.now(tz=timezone.utc).date() + timedelta(days=100) +    if le: +        end_date = le +    elif lt: +        end_date = lt - timedelta(days=1) + +    return cast("date", faker.date_between(start_date=start_date, end_date=end_date)) diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/constrained_numbers.py b/venv/lib/python3.11/site-packages/polyfactory/value_generators/constrained_numbers.py new file mode 100644 index 0000000..23516ce --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/constrained_numbers.py @@ -0,0 +1,440 @@ +from __future__ import annotations + +from decimal import Decimal +from sys import float_info +from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast + +from polyfactory.exceptions import ParameterException +from polyfactory.value_generators.primitives import create_random_decimal, create_random_float, create_random_integer + +if TYPE_CHECKING: +    from random import Random + +T = TypeVar("T", Decimal, int, float) + + +class NumberGeneratorProtocol(Protocol[T]): +    """Protocol for custom callables used to generate numerical values""" + +    def __call__(self, random: "Random", minimum: T | None = None, maximum: T | None = None) -> T: +        """Signature of the callable. + +        :param random: An instance of random. +        :param minimum: A minimum value. +        :param maximum: A maximum value. +        :return: The generated numeric value. +        """ +        ... + + +def almost_equal_floats(value_1: float, value_2: float, *, delta: float = 1e-8) -> bool: +    """Return True if two floats are almost equal + +    :param value_1: A float value. +    :param value_2: A float value. +    :param delta: A minimal delta. + +    :returns: Boolean dictating whether the floats can be considered equal - given python's problematic comparison of floats. +    """ +    return abs(value_1 - value_2) <= delta + + +def is_multiply_of_multiple_of_in_range( +    minimum: T, +    maximum: T, +    multiple_of: T, +) -> bool: +    """Determine if at least one multiply of `multiple_of` lies in the given range. + +    :param minimum: T: A minimum value. +    :param maximum: T: A maximum value. +    :param multiple_of: T: A value to use as a base for multiplication. + +    :returns: Boolean dictating whether at least one multiply of `multiple_of` lies in the given range between minimum and maximum. +    """ + +    # if the range has infinity on one of its ends then infinite number of multipliers +    # can be found within the range + +    # if we were given floats and multiple_of is really close to zero then it doesn't make sense +    # to continue trying to check the range +    if ( +        isinstance(minimum, float) +        and isinstance(multiple_of, float) +        and minimum / multiple_of in [float("+inf"), float("-inf")] +    ): +        return False + +    multiplier = round(minimum / multiple_of) +    step = 1 if multiple_of > 0 else -1 + +    # since rounding can go either up or down we may end up in a situation when +    # minimum is less or equal to `multiplier * multiple_of` +    # or when it is greater than `multiplier * multiple_of` +    # (in this case minimum is less than `(multiplier + 1)* multiple_of`). So we need to check +    # that any of two values is inside the given range. ASCII graphic below explain this +    # +    #                minimum +    # -----------------+-------+-----------------------------------+---------------------------- +    # +    # +    #                                minimum +    # -------------------------+--------+--------------------------+---------------------------- +    # +    # since `multiple_of` can be a negative number adding +1 to `multiplier` drives `(multiplier + 1) * multiple_of`` +    # away from `minimum` to the -infinity. It looks like this: +    #                                                                               minimum +    # -----------------------+--------------------------------+------------------------+-------- +    # +    # so for negative `multiple_of` we want to subtract 1 from multiplier +    for multiply in [multiplier * multiple_of, (multiplier + step) * multiple_of]: +        multiply_float = float(multiply) +        if ( +            almost_equal_floats(multiply_float, float(minimum)) +            or almost_equal_floats(multiply_float, float(maximum)) +            or minimum < multiply < maximum +        ): +            return True + +    return False + + +def passes_pydantic_multiple_validator(value: T, multiple_of: T) -> bool: +    """Determine whether a given value passes the pydantic multiple_of validation. + +    :param value: A numeric value. +    :param multiple_of: Another numeric value. + +    :returns: Boolean dictating whether value is a multiple of value. + +    """ +    if multiple_of == 0: +        return True +    mod = float(value) / float(multiple_of) % 1 +    return almost_equal_floats(mod, 0.0) or almost_equal_floats(mod, 1.0) + + +def get_increment(t_type: type[T]) -> T: +    """Get a small increment base to add to constrained values, i.e. lt/gt entries. + +    :param t_type: A value of type T. + +    :returns: An increment T. +    """ +    values: dict[Any, Any] = { +        int: 1, +        float: float_info.epsilon, +        Decimal: Decimal("0.001"), +    } +    return cast("T", values[t_type]) + + +def get_value_or_none( +    t_type: type[T], +    lt: T | None = None, +    le: T | None = None, +    gt: T | None = None, +    ge: T | None = None, +) -> tuple[T | None, T | None]: +    """Return an optional value. + +    :param equal_value: An GE/LE value. +    :param constrained: An GT/LT value. +    :param increment: increment + +    :returns: Optional T. +    """ +    if ge is not None: +        minimum_value = ge +    elif gt is not None: +        minimum_value = gt + get_increment(t_type) +    else: +        minimum_value = None + +    if le is not None: +        maximum_value = le +    elif lt is not None: +        maximum_value = lt - get_increment(t_type) +    else: +        maximum_value = None +    return minimum_value, maximum_value + + +def get_constrained_number_range( +    t_type: type[T], +    random: Random, +    lt: T | None = None, +    le: T | None = None, +    gt: T | None = None, +    ge: T | None = None, +    multiple_of: T | None = None, +) -> tuple[T | None, T | None]: +    """Return the minimum and maximum values given a field_meta's constraints. + +    :param t_type: A primitive constructor - int, float or Decimal. +    :param random: An instance of Random. +    :param lt: Less than value. +    :param le: Less than or equal value. +    :param gt: Greater than value. +    :param ge: Greater than or equal value. +    :param multiple_of: Multiple of value. + +    :returns: a tuple of optional minimum and maximum values. +    """ +    seed = t_type(random.random() * 10) +    minimum, maximum = get_value_or_none(lt=lt, le=le, gt=gt, ge=ge, t_type=t_type) + +    if minimum is not None and maximum is not None and maximum < minimum: +        msg = "maximum value must be greater than minimum value" +        raise ParameterException(msg) + +    if multiple_of is None: +        if minimum is not None and maximum is None: +            return ( +                (minimum, seed) if minimum == 0 else (minimum, minimum + seed) +            )  # pyright: ignore[reportGeneralTypeIssues] +        if maximum is not None and minimum is None: +            return maximum - seed, maximum +    else: +        if multiple_of == 0.0:  # TODO: investigate @guacs # noqa: PLR2004, FIX002 +            msg = "multiple_of can not be zero" +            raise ParameterException(msg) +        if ( +            minimum is not None +            and maximum is not None +            and not is_multiply_of_multiple_of_in_range(minimum=minimum, maximum=maximum, multiple_of=multiple_of) +        ): +            msg = "given range should include at least one multiply of multiple_of" +            raise ParameterException(msg) + +    return minimum, maximum + + +def generate_constrained_number( +    random: Random, +    minimum: T | None, +    maximum: T | None, +    multiple_of: T | None, +    method: "NumberGeneratorProtocol[T]", +) -> T: +    """Generate a constrained number, output depends on the passed in callbacks. + +    :param random: An instance of random. +    :param minimum: A minimum value. +    :param maximum: A maximum value. +    :param multiple_of: A multiple of value. +    :param method: A function that generates numbers of type T. + +    :returns: A value of type T. +    """ +    if minimum is None or maximum is None: +        return multiple_of if multiple_of is not None else method(random=random) +    if multiple_of is None: +        return method(random=random, minimum=minimum, maximum=maximum) +    if multiple_of >= minimum: +        return multiple_of +    result = minimum +    while not passes_pydantic_multiple_validator(result, multiple_of): +        result = round(method(random=random, minimum=minimum, maximum=maximum) / multiple_of) * multiple_of +    return result + + +def handle_constrained_int( +    random: Random, +    multiple_of: int | None = None, +    gt: int | None = None, +    ge: int | None = None, +    lt: int | None = None, +    le: int | None = None, +) -> int: +    """Handle constrained integers. + +    :param random: An instance of Random. +    :param lt: Less than value. +    :param le: Less than or equal value. +    :param gt: Greater than value. +    :param ge: Greater than or equal value. +    :param multiple_of: Multiple of value. + +    :returns: An integer. + +    """ + +    minimum, maximum = get_constrained_number_range( +        gt=gt, +        ge=ge, +        lt=lt, +        le=le, +        t_type=int, +        multiple_of=multiple_of, +        random=random, +    ) +    return generate_constrained_number( +        random=random, +        minimum=minimum, +        maximum=maximum, +        multiple_of=multiple_of, +        method=create_random_integer, +    ) + + +def handle_constrained_float( +    random: Random, +    multiple_of: float | None = None, +    gt: float | None = None, +    ge: float | None = None, +    lt: float | None = None, +    le: float | None = None, +) -> float: +    """Handle constrained floats. + +    :param random: An instance of Random. +    :param lt: Less than value. +    :param le: Less than or equal value. +    :param gt: Greater than value. +    :param ge: Greater than or equal value. +    :param multiple_of: Multiple of value. + +    :returns: A float. +    """ + +    minimum, maximum = get_constrained_number_range( +        gt=gt, +        ge=ge, +        lt=lt, +        le=le, +        t_type=float, +        multiple_of=multiple_of, +        random=random, +    ) + +    return generate_constrained_number( +        random=random, +        minimum=minimum, +        maximum=maximum, +        multiple_of=multiple_of, +        method=create_random_float, +    ) + + +def validate_max_digits( +    max_digits: int, +    minimum: Decimal | None, +    decimal_places: int | None, +) -> None: +    """Validate that max digits is greater than minimum and decimal places. + +    :param max_digits: The maximal number of digits for the decimal. +    :param minimum: Minimal value. +    :param decimal_places: Number of decimal places + +    :returns: 'None' + +    """ +    if max_digits <= 0: +        msg = "max_digits must be greater than 0" +        raise ParameterException(msg) + +    if minimum is not None: +        min_str = str(minimum).split(".")[1] if "." in str(minimum) else str(minimum) + +        if max_digits <= len(min_str): +            msg = "minimum is greater than max_digits" +            raise ParameterException(msg) + +    if decimal_places is not None and max_digits <= decimal_places: +        msg = "max_digits must be greater than decimal places" +        raise ParameterException(msg) + + +def handle_decimal_length( +    generated_decimal: Decimal, +    decimal_places: int | None, +    max_digits: int | None, +) -> Decimal: +    """Handle the length of the decimal. + +    :param generated_decimal: A decimal value. +    :param decimal_places: Number of decimal places. +    :param max_digits: Maximal number of digits. + +    """ +    string_number = str(generated_decimal) +    sign = "-" if "-" in string_number else "+" +    string_number = string_number.replace("-", "") +    whole_numbers, decimals = string_number.split(".") + +    if ( +        max_digits is not None +        and decimal_places is not None +        and len(whole_numbers) + decimal_places > max_digits +        or (max_digits is None or decimal_places is None) +        and max_digits is not None +    ): +        max_decimals = max_digits - len(whole_numbers) +    elif max_digits is not None: +        max_decimals = decimal_places  # type: ignore[assignment] +    else: +        max_decimals = cast("int", decimal_places) + +    if max_decimals < 0:  # pyright: ignore[reportOptionalOperand] +        return Decimal(sign + whole_numbers[:max_decimals]) + +    decimals = decimals[:max_decimals] +    return Decimal(sign + whole_numbers + "." + decimals[:decimal_places]) + + +def handle_constrained_decimal( +    random: Random, +    multiple_of: Decimal | None = None, +    decimal_places: int | None = None, +    max_digits: int | None = None, +    gt: Decimal | None = None, +    ge: Decimal | None = None, +    lt: Decimal | None = None, +    le: Decimal | None = None, +) -> Decimal: +    """Handle a constrained decimal. + +    :param random: An instance of Random. +    :param multiple_of: Multiple of value. +    :param decimal_places: Number of decimal places. +    :param max_digits: Maximal number of digits. +    :param lt: Less than value. +    :param le: Less than or equal value. +    :param gt: Greater than value. +    :param ge: Greater than or equal value. + +    :returns: A decimal. + +    """ + +    minimum, maximum = get_constrained_number_range( +        gt=gt, +        ge=ge, +        lt=lt, +        le=le, +        multiple_of=multiple_of, +        t_type=Decimal, +        random=random, +    ) + +    if max_digits is not None: +        validate_max_digits(max_digits=max_digits, minimum=minimum, decimal_places=decimal_places) + +    generated_decimal = generate_constrained_number( +        random=random, +        minimum=minimum, +        maximum=maximum, +        multiple_of=multiple_of, +        method=create_random_decimal, +    ) + +    if max_digits is not None or decimal_places is not None: +        return handle_decimal_length( +            generated_decimal=generated_decimal, +            max_digits=max_digits, +            decimal_places=decimal_places, +        ) + +    return generated_decimal diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/constrained_path.py b/venv/lib/python3.11/site-packages/polyfactory/value_generators/constrained_path.py new file mode 100644 index 0000000..debaf86 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/constrained_path.py @@ -0,0 +1,13 @@ +from os.path import realpath +from pathlib import Path +from typing import Literal, cast + +from faker import Faker + + +def handle_constrained_path(constraint: Literal["file", "dir", "new"], faker: Faker) -> Path: +    if constraint == "new": +        return cast("Path", faker.file_path(depth=1, category=None, extension=None)) +    if constraint == "file": +        return Path(realpath(__file__)) +    return Path(realpath(__file__)).parent diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/constrained_strings.py b/venv/lib/python3.11/site-packages/polyfactory/value_generators/constrained_strings.py new file mode 100644 index 0000000..c7da72b --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/constrained_strings.py @@ -0,0 +1,138 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable, Pattern, TypeVar, Union, cast + +from polyfactory.exceptions import ParameterException +from polyfactory.value_generators.primitives import create_random_bytes, create_random_string +from polyfactory.value_generators.regex import RegexFactory + +T = TypeVar("T", bound=Union[bytes, str]) + +if TYPE_CHECKING: +    from random import Random + + +def _validate_length( +    min_length: int | None = None, +    max_length: int | None = None, +) -> None: +    """Validate the length parameters make sense. + +    :param min_length: Minimum length. +    :param max_length: Maximum length. + +    :raises: ParameterException. + +    :returns: None. +    """ +    if min_length is not None and min_length < 0: +        msg = "min_length must be greater or equal to 0" +        raise ParameterException(msg) + +    if max_length is not None and max_length < 0: +        msg = "max_length must be greater or equal to 0" +        raise ParameterException(msg) + +    if max_length is not None and min_length is not None and max_length < min_length: +        msg = "max_length must be greater than min_length" +        raise ParameterException(msg) + + +def _generate_pattern( +    random: Random, +    pattern: str | Pattern, +    lower_case: bool = False, +    upper_case: bool = False, +    min_length: int | None = None, +    max_length: int | None = None, +) -> str: +    """Generate a regex. + +    :param random: An instance of random. +    :param pattern: A regex or string pattern. +    :param lower_case: Whether to lowercase the result. +    :param upper_case: Whether to uppercase the result. +    :param min_length: A minimum length. +    :param max_length: A maximum length. + +    :returns: A string matching the given pattern. +    """ +    regex_factory = RegexFactory(random=random) +    result = regex_factory(pattern) +    if min_length: +        while len(result) < min_length: +            result += regex_factory(pattern) + +    if max_length is not None and len(result) > max_length: +        result = result[:max_length] + +    if lower_case: +        result = result.lower() + +    if upper_case: +        result = result.upper() + +    return result + + +def handle_constrained_string_or_bytes( +    random: Random, +    t_type: Callable[[], T], +    lower_case: bool = False, +    upper_case: bool = False, +    min_length: int | None = None, +    max_length: int | None = None, +    pattern: str | Pattern | None = None, +) -> T: +    """Handle constrained string or bytes, for example - pydantic `constr` or `conbytes`. + +    :param random: An instance of random. +    :param t_type: A type (str or bytes) +    :param lower_case: Whether to lowercase the result. +    :param upper_case: Whether to uppercase the result. +    :param min_length: A minimum length. +    :param max_length: A maximum length. +    :param pattern: A regex or string pattern. + +    :returns: A value of type T. +    """ +    _validate_length(min_length=min_length, max_length=max_length) + +    if max_length == 0: +        return t_type() + +    if pattern: +        return cast( +            "T", +            _generate_pattern( +                random=random, +                pattern=pattern, +                lower_case=lower_case, +                upper_case=upper_case, +                min_length=min_length, +                max_length=max_length, +            ), +        ) + +    if t_type is str: +        return cast( +            "T", +            create_random_string( +                min_length=min_length, +                max_length=max_length, +                lower_case=lower_case, +                upper_case=upper_case, +                random=random, +            ), +        ) + +    return cast( +        "T", +        create_random_bytes( +            min_length=min_length, +            max_length=max_length, +            lower_case=lower_case, +            upper_case=upper_case, +            random=random, +        ), +    ) diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/constrained_url.py b/venv/lib/python3.11/site-packages/polyfactory/value_generators/constrained_url.py new file mode 100644 index 0000000..d29555e --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/constrained_url.py @@ -0,0 +1,10 @@ +from polyfactory.field_meta import UrlConstraints + + +def handle_constrained_url(constraints: UrlConstraints) -> str: +    schema = (constraints.get("allowed_schemes") or ["http", "https"])[0] +    default_host = constraints.get("default_host") or "localhost" +    default_port = constraints.get("default_port") or 80 +    default_path = constraints.get("default_path") or "" + +    return f"{schema}://{default_host}:{default_port}{default_path}" diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/constrained_uuid.py b/venv/lib/python3.11/site-packages/polyfactory/value_generators/constrained_uuid.py new file mode 100644 index 0000000..053f047 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/constrained_uuid.py @@ -0,0 +1,31 @@ +from typing import Literal, cast +from uuid import NAMESPACE_DNS, UUID, uuid1, uuid3, uuid5 + +from faker import Faker + +UUID_VERSION_1 = 1 +UUID_VERSION_3 = 3 +UUID_VERSION_4 = 4 +UUID_VERSION_5 = 5 + + +def handle_constrained_uuid(uuid_version: Literal[1, 3, 4, 5], faker: Faker) -> UUID: +    """Generate a UUID based on the version specified. + +    Args: +        uuid_version: The version of the UUID to generate. +        faker: The Faker instance to use. + +    Returns: +        The generated UUID. +    """ +    if uuid_version == UUID_VERSION_1: +        return uuid1() +    if uuid_version == UUID_VERSION_3: +        return uuid3(NAMESPACE_DNS, faker.pystr()) +    if uuid_version == UUID_VERSION_4: +        return cast("UUID", faker.uuid4()) +    if uuid_version == UUID_VERSION_5: +        return uuid5(NAMESPACE_DNS, faker.pystr()) +    msg = f"Unknown UUID version: {uuid_version}" +    raise ValueError(msg) diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/primitives.py b/venv/lib/python3.11/site-packages/polyfactory/value_generators/primitives.py new file mode 100644 index 0000000..2cf6b41 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/primitives.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +from binascii import hexlify +from decimal import Decimal +from typing import TYPE_CHECKING + +if TYPE_CHECKING: +    from random import Random + + +def create_random_float( +    random: Random, +    minimum: Decimal | float | None = None, +    maximum: Decimal | float | None = None, +) -> float: +    """Generate a random float given the constraints. + +    :param random: An instance of random. +    :param minimum: A minimum value +    :param maximum: A maximum value. + +    :returns: A random float. +    """ +    if minimum is None: +        minimum = float(random.randint(0, 100)) if maximum is None else float(maximum) - 100.0 +    if maximum is None: +        maximum = float(minimum) + 1.0 * 2.0 if minimum >= 0 else float(minimum) + 1.0 / 2.0 +    return random.uniform(float(minimum), float(maximum)) + + +def create_random_integer(random: Random, minimum: int | None = None, maximum: int | None = None) -> int: +    """Generate a random int given the constraints. + +    :param random: An instance of random. +    :param minimum: A minimum value +    :param maximum: A maximum value. + +    :returns: A random integer. +    """ +    return round(create_random_float(random=random, minimum=minimum, maximum=maximum)) + + +def create_random_decimal( +    random: Random, +    minimum: Decimal | None = None, +    maximum: Decimal | None = None, +) -> Decimal: +    """Generate a random Decimal given the constraints. + +    :param random: An instance of random. +    :param minimum: A minimum value +    :param maximum: A maximum value. + +    :returns: A random decimal. +    """ +    return Decimal(str(create_random_float(random=random, minimum=minimum, maximum=maximum))) + + +def create_random_bytes( +    random: Random, +    min_length: int | None = None, +    max_length: int | None = None, +    lower_case: bool = False, +    upper_case: bool = False, +) -> bytes: +    """Generate a random bytes given the constraints. + +    :param random: An instance of random. +    :param min_length: A minimum length. +    :param max_length: A maximum length. +    :param lower_case: Whether to lowercase the result. +    :param upper_case: Whether to uppercase the result. + +    :returns: A random byte-string. +    """ +    if min_length is None: +        min_length = 0 +    if max_length is None: +        max_length = min_length + 1 * 2 + +    length = random.randint(min_length, max_length) +    result = b"" if length == 0 else hexlify(random.getrandbits(length * 8).to_bytes(length, "little")) + +    if lower_case: +        result = result.lower() +    elif upper_case: +        result = result.upper() + +    if max_length and len(result) > max_length: +        end = random.randint(min_length or 0, max_length) +        return result[:end] + +    return result + + +def create_random_string( +    random: Random, +    min_length: int | None = None, +    max_length: int | None = None, +    lower_case: bool = False, +    upper_case: bool = False, +) -> str: +    """Generate a random string given the constraints. + +    :param random: An instance of random. +    :param min_length: A minimum length. +    :param max_length: A maximum length. +    :param lower_case: Whether to lowercase the result. +    :param upper_case: Whether to uppercase the result. + +    :returns: A random string. +    """ +    return create_random_bytes( +        random=random, +        min_length=min_length, +        max_length=max_length, +        lower_case=lower_case, +        upper_case=upper_case, +    ).decode("utf-8") + + +def create_random_boolean(random: Random) -> bool: +    """Generate a random boolean value. + +    :param random: An instance of random. + +    :returns: A random boolean. +    """ +    return bool(random.getrandbits(1)) diff --git a/venv/lib/python3.11/site-packages/polyfactory/value_generators/regex.py b/venv/lib/python3.11/site-packages/polyfactory/value_generators/regex.py new file mode 100644 index 0000000..eab9bd0 --- /dev/null +++ b/venv/lib/python3.11/site-packages/polyfactory/value_generators/regex.py @@ -0,0 +1,150 @@ +"""The code in this files is adapted from https://github.com/crdoconnor/xeger/blob/master/xeger/xeger.py.Which in turn +adapted it from https://bitbucket.org/leapfrogdevelopment/rstr/. + +Copyright (C) 2015, Colm O'Connor +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: +    * Redistributions of source code must retain the above copyright +      notice, this list of conditions and the following disclaimer. +    * Redistributions in binary form must reproduce the above copyright +      notice, this list of conditions and the following disclaimer in the +      documentation and/or other materials provided with the distribution. +    * Neither the name of the Leapfrog Direct Response, LLC, including +      its subsidiaries and affiliates nor the names of its +      contributors, may be used to endorse or promote products derived +      from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL LEAPFROG DIRECT +RESPONSE, LLC, INCLUDING ITS SUBSIDIARIES AND AFFILIATES, BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR +BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE +OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN +IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" +from __future__ import annotations + +from itertools import chain +from string import ( +    ascii_letters, +    ascii_lowercase, +    ascii_uppercase, +    digits, +    printable, +    punctuation, +    whitespace, +) +from typing import TYPE_CHECKING, Any, Pattern + +try:  # >=3.11 +    from re._parser import SubPattern, parse +except ImportError:  # < 3.11 +    from sre_parse import SubPattern, parse  # pylint: disable=deprecated-module + +if TYPE_CHECKING: +    from random import Random + +_alphabets = { +    "printable": printable, +    "letters": ascii_letters, +    "uppercase": ascii_uppercase, +    "lowercase": ascii_lowercase, +    "digits": digits, +    "punctuation": punctuation, +    "nondigits": ascii_letters + punctuation, +    "nonletters": digits + punctuation, +    "whitespace": whitespace, +    "nonwhitespace": printable.strip(), +    "normal": ascii_letters + digits + " ", +    "word": ascii_letters + digits + "_", +    "nonword": "".join(set(printable).difference(ascii_letters + digits + "_")), +    "postalsafe": ascii_letters + digits + " .-#/", +    "urlsafe": ascii_letters + digits + "-._~", +    "domainsafe": ascii_letters + digits + "-", +} + +_categories = { +    "category_digit": _alphabets["digits"], +    "category_not_digit": _alphabets["nondigits"], +    "category_space": _alphabets["whitespace"], +    "category_not_space": _alphabets["nonwhitespace"], +    "category_word": _alphabets["word"], +    "category_not_word": _alphabets["nonword"], +} + + +class RegexFactory: +    """Factory for regexes.""" + +    def __init__(self, random: Random, limit: int = 10) -> None: +        """Create a RegexFactory""" +        self._limit = limit +        self._cache: dict[str, Any] = {} +        self._random = random + +        self._cases = { +            "literal": chr, +            "not_literal": lambda x: self._random.choice(printable.replace(chr(x), "")), +            "at": lambda x: "", +            "in": self._handle_in, +            "any": lambda x: self._random.choice(printable.replace("\n", "")), +            "range": lambda x: [chr(i) for i in range(x[0], x[1] + 1)], +            "category": lambda x: _categories[str(x).lower()], +            "branch": lambda x: "".join(self._handle_state(i) for i in self._random.choice(x[1])), +            "subpattern": self._handle_group, +            "assert": lambda x: "".join(self._handle_state(i) for i in x[1]), +            "assert_not": lambda x: "", +            "groupref": lambda x: self._cache[x], +            "min_repeat": lambda x: self._handle_repeat(*x), +            "max_repeat": lambda x: self._handle_repeat(*x), +            "negate": lambda x: [False], +        } + +    def __call__(self, string_or_regex: str | Pattern) -> str: +        """Generate a string matching a regex. + +        :param string_or_regex: A string or pattern. + +        :return: The generated string. +        """ +        pattern = string_or_regex.pattern if isinstance(string_or_regex, Pattern) else string_or_regex +        parsed = parse(pattern) +        result = self._build_string(parsed)  # pyright: ignore[reportGeneralTypeIssues] +        self._cache.clear() +        return result + +    def _build_string(self, parsed: SubPattern) -> str: +        return "".join([self._handle_state(state) for state in parsed])  # pyright:ignore[reportGeneralTypeIssues] + +    def _handle_state(self, state: tuple[SubPattern, tuple[Any, ...]]) -> Any: +        opcode, value = state +        return self._cases[str(opcode).lower()](value)  # type: ignore[no-untyped-call] + +    def _handle_group(self, value: tuple[Any, ...]) -> str: +        result = "".join(self._handle_state(i) for i in value[3]) +        if value[0]: +            self._cache[value[0]] = result +        return result + +    def _handle_in(self, value: tuple[Any, ...]) -> Any: +        candidates = list(chain(*(self._handle_state(i) for i in value))) +        if candidates and candidates[0] is False: +            candidates = list(set(printable).difference(candidates[1:])) +            return self._random.choice(candidates) +        return self._random.choice(candidates) + +    def _handle_repeat(self, start_range: int, end_range: Any, value: SubPattern) -> str: +        end_range = min(end_range, self._limit) + +        result = [ +            "".join(self._handle_state(v) for v in list(value))  # pyright: ignore[reportGeneralTypeIssues] +            for _ in range(self._random.randint(start_range, max(start_range, end_range))) +        ] +        return "".join(result) | 
