summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/polyfactory/factories
diff options
context:
space:
mode:
Diffstat (limited to 'venv/lib/python3.11/site-packages/polyfactory/factories')
-rw-r--r--venv/lib/python3.11/site-packages/polyfactory/factories/__init__.py5
-rw-r--r--venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/__init__.cpython-311.pycbin0 -> 507 bytes
-rw-r--r--venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/attrs_factory.cpython-311.pycbin0 -> 3571 bytes
-rw-r--r--venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/base.cpython-311.pycbin0 -> 57046 bytes
-rw-r--r--venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/beanie_odm_factory.cpython-311.pycbin0 -> 4879 bytes
-rw-r--r--venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/dataclass_factory.cpython-311.pycbin0 -> 2741 bytes
-rw-r--r--venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/msgspec_factory.cpython-311.pycbin0 -> 4502 bytes
-rw-r--r--venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/odmantic_odm_factory.cpython-311.pycbin0 -> 4800 bytes
-rw-r--r--venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/pydantic_factory.cpython-311.pycbin0 -> 29116 bytes
-rw-r--r--venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/sqlalchemy_factory.cpython-311.pycbin0 -> 13671 bytes
-rw-r--r--venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/typed_dict_factory.cpython-311.pycbin0 -> 2898 bytes
-rw-r--r--venv/lib/python3.11/site-packages/polyfactory/factories/attrs_factory.py82
-rw-r--r--venv/lib/python3.11/site-packages/polyfactory/factories/base.py1127
-rw-r--r--venv/lib/python3.11/site-packages/polyfactory/factories/beanie_odm_factory.py87
-rw-r--r--venv/lib/python3.11/site-packages/polyfactory/factories/dataclass_factory.py58
-rw-r--r--venv/lib/python3.11/site-packages/polyfactory/factories/msgspec_factory.py72
-rw-r--r--venv/lib/python3.11/site-packages/polyfactory/factories/odmantic_odm_factory.py60
-rw-r--r--venv/lib/python3.11/site-packages/polyfactory/factories/pydantic_factory.py554
-rw-r--r--venv/lib/python3.11/site-packages/polyfactory/factories/sqlalchemy_factory.py186
-rw-r--r--venv/lib/python3.11/site-packages/polyfactory/factories/typed_dict_factory.py61
20 files changed, 2292 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__init__.py b/venv/lib/python3.11/site-packages/polyfactory/factories/__init__.py
new file mode 100644
index 0000000..c8a9b92
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__init__.py
@@ -0,0 +1,5 @@
+from polyfactory.factories.base import BaseFactory
+from polyfactory.factories.dataclass_factory import DataclassFactory
+from polyfactory.factories.typed_dict_factory import TypedDictFactory
+
+__all__ = ("BaseFactory", "TypedDictFactory", "DataclassFactory")
diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000..0ebad18
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/__init__.cpython-311.pyc
Binary files differ
diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/attrs_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/attrs_factory.cpython-311.pyc
new file mode 100644
index 0000000..b5ca945
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/attrs_factory.cpython-311.pyc
Binary files differ
diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/base.cpython-311.pyc
new file mode 100644
index 0000000..2ee4cdb
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/base.cpython-311.pyc
Binary files differ
diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/beanie_odm_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/beanie_odm_factory.cpython-311.pyc
new file mode 100644
index 0000000..0fc27ab
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/beanie_odm_factory.cpython-311.pyc
Binary files differ
diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/dataclass_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/dataclass_factory.cpython-311.pyc
new file mode 100644
index 0000000..6e59c83
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/dataclass_factory.cpython-311.pyc
Binary files differ
diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/msgspec_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/msgspec_factory.cpython-311.pyc
new file mode 100644
index 0000000..5e528d3
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/msgspec_factory.cpython-311.pyc
Binary files differ
diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/odmantic_odm_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/odmantic_odm_factory.cpython-311.pyc
new file mode 100644
index 0000000..6f45693
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/odmantic_odm_factory.cpython-311.pyc
Binary files differ
diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/pydantic_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/pydantic_factory.cpython-311.pyc
new file mode 100644
index 0000000..ca70ca8
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/pydantic_factory.cpython-311.pyc
Binary files differ
diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/sqlalchemy_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/sqlalchemy_factory.cpython-311.pyc
new file mode 100644
index 0000000..31d93d5
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/sqlalchemy_factory.cpython-311.pyc
Binary files differ
diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/typed_dict_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/typed_dict_factory.cpython-311.pyc
new file mode 100644
index 0000000..a1ec583
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/typed_dict_factory.cpython-311.pyc
Binary files differ
diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/attrs_factory.py b/venv/lib/python3.11/site-packages/polyfactory/factories/attrs_factory.py
new file mode 100644
index 0000000..00ffa03
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/polyfactory/factories/attrs_factory.py
@@ -0,0 +1,82 @@
+from __future__ import annotations
+
+from inspect import isclass
+from typing import TYPE_CHECKING, Generic, TypeVar
+
+from polyfactory.exceptions import MissingDependencyException
+from polyfactory.factories.base import BaseFactory
+from polyfactory.field_meta import FieldMeta, Null
+
+if TYPE_CHECKING:
+ from typing import Any, TypeGuard
+
+
+try:
+ import attrs
+ from attr._make import Factory
+ from attrs import AttrsInstance
+except ImportError as ex:
+ msg = "attrs is not installed"
+ raise MissingDependencyException(msg) from ex
+
+
+T = TypeVar("T", bound=AttrsInstance)
+
+
+class AttrsFactory(Generic[T], BaseFactory[T]):
+ """Base factory for attrs classes."""
+
+ __model__: type[T]
+
+ __is_base_factory__ = True
+
+ @classmethod
+ def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]:
+ return isclass(value) and hasattr(value, "__attrs_attrs__")
+
+ @classmethod
+ def get_model_fields(cls) -> list[FieldMeta]:
+ field_metas: list[FieldMeta] = []
+ none_type = type(None)
+
+ cls.resolve_types(cls.__model__)
+ fields = attrs.fields(cls.__model__)
+
+ for field in fields:
+ if not field.init:
+ continue
+
+ annotation = none_type if field.type is None else field.type
+
+ default = field.default
+ if isinstance(default, Factory):
+ # The default value is not currently being used when generating
+ # the field values. When that is implemented, this would need
+ # to be handled differently since the `default.factory` could
+ # take a `self` argument.
+ default_value = default.factory
+ elif default is None:
+ default_value = Null
+ else:
+ default_value = default
+
+ field_metas.append(
+ FieldMeta.from_type(
+ annotation=annotation,
+ name=field.alias,
+ default=default_value,
+ random=cls.__random__,
+ ),
+ )
+
+ return field_metas
+
+ @classmethod
+ def resolve_types(cls, model: type[T], **kwargs: Any) -> None:
+ """Resolve any strings and forward annotations in type annotations.
+
+ :param model: The model to resolve the type annotations for.
+ :param kwargs: Any parameters that need to be passed to `attrs.resolve_types`.
+ """
+
+ attrs.resolve_types(model, **kwargs)
diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/base.py b/venv/lib/python3.11/site-packages/polyfactory/factories/base.py
new file mode 100644
index 0000000..60fe7a7
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/polyfactory/factories/base.py
@@ -0,0 +1,1127 @@
+from __future__ import annotations
+
+import copy
+from abc import ABC, abstractmethod
+from collections import Counter, abc, deque
+from contextlib import suppress
+from datetime import date, datetime, time, timedelta
+from decimal import Decimal
+from enum import EnumMeta
+from functools import partial
+from importlib import import_module
+from ipaddress import (
+ IPv4Address,
+ IPv4Interface,
+ IPv4Network,
+ IPv6Address,
+ IPv6Interface,
+ IPv6Network,
+ ip_address,
+ ip_interface,
+ ip_network,
+)
+from os.path import realpath
+from pathlib import Path
+from random import Random
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ ClassVar,
+ Collection,
+ Generic,
+ Iterable,
+ Mapping,
+ Sequence,
+ Type,
+ TypedDict,
+ TypeVar,
+ cast,
+)
+from uuid import UUID
+
+from faker import Faker
+from typing_extensions import get_args, get_origin, get_original_bases
+
+from polyfactory.constants import (
+ DEFAULT_RANDOM,
+ MAX_COLLECTION_LENGTH,
+ MIN_COLLECTION_LENGTH,
+ RANDOMIZE_COLLECTION_LENGTH,
+)
+from polyfactory.exceptions import ConfigurationException, MissingBuildKwargException, ParameterException
+from polyfactory.field_meta import Null
+from polyfactory.fields import Fixture, Ignore, PostGenerated, Require, Use
+from polyfactory.utils.helpers import (
+ flatten_annotation,
+ get_collection_type,
+ unwrap_annotation,
+ unwrap_args,
+ unwrap_optional,
+)
+from polyfactory.utils.model_coverage import CoverageContainer, CoverageContainerCallable, resolve_kwargs_coverage
+from polyfactory.utils.predicates import get_type_origin, is_any, is_literal, is_optional, is_safe_subclass, is_union
+from polyfactory.utils.types import NoneType
+from polyfactory.value_generators.complex_types import handle_collection_type, handle_collection_type_coverage
+from polyfactory.value_generators.constrained_collections import (
+ handle_constrained_collection,
+ handle_constrained_mapping,
+)
+from polyfactory.value_generators.constrained_dates import handle_constrained_date
+from polyfactory.value_generators.constrained_numbers import (
+ handle_constrained_decimal,
+ handle_constrained_float,
+ handle_constrained_int,
+)
+from polyfactory.value_generators.constrained_path import handle_constrained_path
+from polyfactory.value_generators.constrained_strings import handle_constrained_string_or_bytes
+from polyfactory.value_generators.constrained_url import handle_constrained_url
+from polyfactory.value_generators.constrained_uuid import handle_constrained_uuid
+from polyfactory.value_generators.primitives import create_random_boolean, create_random_bytes, create_random_string
+
+if TYPE_CHECKING:
+ from typing_extensions import TypeGuard
+
+ from polyfactory.field_meta import Constraints, FieldMeta
+ from polyfactory.persistence import AsyncPersistenceProtocol, SyncPersistenceProtocol
+
+
+T = TypeVar("T")
+F = TypeVar("F", bound="BaseFactory[Any]")
+
+
+class BuildContext(TypedDict):
+ seen_models: set[type]
+
+
+def _get_build_context(build_context: BuildContext | None) -> BuildContext:
+ if build_context is None:
+ return {"seen_models": set()}
+
+ return copy.deepcopy(build_context)
+
+
+class BaseFactory(ABC, Generic[T]):
+ """Base Factory class - this class holds the main logic of the library"""
+
+ # configuration attributes
+ __model__: type[T]
+ """
+ The model for the factory.
+ This attribute is required for non-base factories and an exception will be raised if it's not set. Can be automatically inferred from the factory generic argument.
+ """
+ __check_model__: bool = False
+ """
+ Flag dictating whether to check if fields defined on the factory exists on the model or not.
+ If 'True', checks will be done against Use, PostGenerated, Ignore, Require constructs fields only.
+ """
+ __allow_none_optionals__: ClassVar[bool] = True
+ """
+ Flag dictating whether to allow 'None' for optional values.
+ If 'True', 'None' will be randomly generated as a value for optional model fields
+ """
+ __sync_persistence__: type[SyncPersistenceProtocol[T]] | SyncPersistenceProtocol[T] | None = None
+ """A sync persistence handler. Can be a class or a class instance."""
+ __async_persistence__: type[AsyncPersistenceProtocol[T]] | AsyncPersistenceProtocol[T] | None = None
+ """An async persistence handler. Can be a class or a class instance."""
+ __set_as_default_factory_for_type__ = False
+ """
+ Flag dictating whether to set as the default factory for the given type.
+ If 'True' the factory will be used instead of dynamically generating a factory for the type.
+ """
+ __is_base_factory__: bool = False
+ """
+ Flag dictating whether the factory is a 'base' factory. Base factories are registered globally as handlers for types.
+ For example, the 'DataclassFactory', 'TypedDictFactory' and 'ModelFactory' are all base factories.
+ """
+ __base_factory_overrides__: dict[Any, type[BaseFactory[Any]]] | None = None
+ """
+ A base factory to override with this factory. If this value is set, the given factory will replace the given base factory.
+
+ Note: this value can only be set when '__is_base_factory__' is 'True'.
+ """
+ __faker__: ClassVar["Faker"] = Faker()
+ """
+ A faker instance to use. Can be a user provided value.
+ """
+ __random__: ClassVar["Random"] = DEFAULT_RANDOM
+ """
+ An instance of 'random.Random' to use.
+ """
+ __random_seed__: ClassVar[int]
+ """
+ An integer to seed the factory's Faker and Random instances with.
+ This attribute can be used to control random generation.
+ """
+ __randomize_collection_length__: ClassVar[bool] = RANDOMIZE_COLLECTION_LENGTH
+ """
+ Flag dictating whether to randomize collections lengths.
+ """
+ __min_collection_length__: ClassVar[int] = MIN_COLLECTION_LENGTH
+ """
+ An integer value that defines minimum length of a collection.
+ """
+ __max_collection_length__: ClassVar[int] = MAX_COLLECTION_LENGTH
+ """
+ An integer value that defines maximum length of a collection.
+ """
+ __use_defaults__: ClassVar[bool] = False
+ """
+ Flag indicating whether to use the default value on a specific field, if provided.
+ """
+
+ __config_keys__: tuple[str, ...] = (
+ "__check_model__",
+ "__allow_none_optionals__",
+ "__set_as_default_factory_for_type__",
+ "__faker__",
+ "__random__",
+ "__randomize_collection_length__",
+ "__min_collection_length__",
+ "__max_collection_length__",
+ "__use_defaults__",
+ )
+ """Keys to be considered as config values to pass on to dynamically created factories."""
+
+ # cached attributes
+ _fields_metadata: list[FieldMeta]
+ # BaseFactory only attributes
+ _factory_type_mapping: ClassVar[dict[Any, type[BaseFactory[Any]]]]
+ _base_factories: ClassVar[list[type[BaseFactory[Any]]]]
+
+ # Non-public attributes
+ _extra_providers: dict[Any, Callable[[], Any]] | None = None
+
+ def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None: # noqa: C901
+ super().__init_subclass__(*args, **kwargs)
+
+ if not hasattr(BaseFactory, "_base_factories"):
+ BaseFactory._base_factories = []
+
+ if not hasattr(BaseFactory, "_factory_type_mapping"):
+ BaseFactory._factory_type_mapping = {}
+
+ if cls.__min_collection_length__ > cls.__max_collection_length__:
+ msg = "Minimum collection length shouldn't be greater than maximum collection length"
+ raise ConfigurationException(
+ msg,
+ )
+
+ if "__is_base_factory__" not in cls.__dict__ or not cls.__is_base_factory__:
+ model: type[T] | None = getattr(cls, "__model__", None) or cls._infer_model_type()
+ if not model:
+ msg = f"required configuration attribute '__model__' is not set on {cls.__name__}"
+ raise ConfigurationException(
+ msg,
+ )
+ cls.__model__ = model
+ if not cls.is_supported_type(model):
+ for factory in BaseFactory._base_factories:
+ if factory.is_supported_type(model):
+ msg = f"{cls.__name__} does not support {model.__name__}, but this type is supported by the {factory.__name__} base factory class. To resolve this error, subclass the factory from {factory.__name__} instead of {cls.__name__}"
+ raise ConfigurationException(
+ msg,
+ )
+ msg = f"Model type {model.__name__} is not supported. To support it, register an appropriate base factory and subclass it for your factory."
+ raise ConfigurationException(
+ msg,
+ )
+ if cls.__check_model__:
+ cls._check_declared_fields_exist_in_model()
+ else:
+ BaseFactory._base_factories.append(cls)
+
+ random_seed = getattr(cls, "__random_seed__", None)
+ if random_seed is not None:
+ cls.seed_random(random_seed)
+
+ if cls.__set_as_default_factory_for_type__ and hasattr(cls, "__model__"):
+ BaseFactory._factory_type_mapping[cls.__model__] = cls
+
+ @classmethod
+ def _infer_model_type(cls: type[F]) -> type[T] | None:
+ """Return model type inferred from class declaration.
+ class Foo(ModelFactory[MyModel]): # <<< MyModel
+ ...
+
+ If more than one base class and/or generic arguments specified return None.
+
+ :returns: Inferred model type or None
+ """
+
+ factory_bases: Iterable[type[BaseFactory[T]]] = (
+ b for b in get_original_bases(cls) if get_origin(b) and issubclass(get_origin(b), BaseFactory)
+ )
+ generic_args: Sequence[type[T]] = [
+ arg for factory_base in factory_bases for arg in get_args(factory_base) if not isinstance(arg, TypeVar)
+ ]
+ if len(generic_args) != 1:
+ return None
+
+ return generic_args[0]
+
+ @classmethod
+ def _get_sync_persistence(cls) -> SyncPersistenceProtocol[T]:
+ """Return a SyncPersistenceHandler if defined for the factory, otherwise raises a ConfigurationException.
+
+ :raises: ConfigurationException
+ :returns: SyncPersistenceHandler
+ """
+ if cls.__sync_persistence__:
+ return cls.__sync_persistence__() if callable(cls.__sync_persistence__) else cls.__sync_persistence__
+ msg = "A '__sync_persistence__' handler must be defined in the factory to use this method"
+ raise ConfigurationException(
+ msg,
+ )
+
+ @classmethod
+ def _get_async_persistence(cls) -> AsyncPersistenceProtocol[T]:
+ """Return a AsyncPersistenceHandler if defined for the factory, otherwise raises a ConfigurationException.
+
+ :raises: ConfigurationException
+ :returns: AsyncPersistenceHandler
+ """
+ if cls.__async_persistence__:
+ return cls.__async_persistence__() if callable(cls.__async_persistence__) else cls.__async_persistence__
+ msg = "An '__async_persistence__' handler must be defined in the factory to use this method"
+ raise ConfigurationException(
+ msg,
+ )
+
+ @classmethod
+ def _handle_factory_field(
+ cls,
+ field_value: Any,
+ build_context: BuildContext,
+ field_build_parameters: Any | None = None,
+ ) -> Any:
+ """Handle a value defined on the factory class itself.
+
+ :param field_value: A value defined as an attribute on the factory class.
+ :param field_build_parameters: Any build parameters passed to the factory as kwarg values.
+
+ :returns: An arbitrary value correlating with the given field_meta value.
+ """
+ if is_safe_subclass(field_value, BaseFactory):
+ if isinstance(field_build_parameters, Mapping):
+ return field_value.build(_build_context=build_context, **field_build_parameters)
+
+ if isinstance(field_build_parameters, Sequence):
+ return [
+ field_value.build(_build_context=build_context, **parameter) for parameter in field_build_parameters
+ ]
+
+ return field_value.build(_build_context=build_context)
+
+ if isinstance(field_value, Use):
+ return field_value.to_value()
+
+ if isinstance(field_value, Fixture):
+ return field_value.to_value()
+
+ return field_value() if callable(field_value) else field_value
+
+ @classmethod
+ def _handle_factory_field_coverage(
+ cls,
+ field_value: Any,
+ field_build_parameters: Any | None = None,
+ build_context: BuildContext | None = None,
+ ) -> Any:
+ """Handle a value defined on the factory class itself.
+
+ :param field_value: A value defined as an attribute on the factory class.
+ :param field_build_parameters: Any build parameters passed to the factory as kwarg values.
+
+ :returns: An arbitrary value correlating with the given field_meta value.
+ """
+ if is_safe_subclass(field_value, BaseFactory):
+ if isinstance(field_build_parameters, Mapping):
+ return CoverageContainer(field_value.coverage(_build_context=build_context, **field_build_parameters))
+
+ if isinstance(field_build_parameters, Sequence):
+ return [
+ CoverageContainer(field_value.coverage(_build_context=build_context, **parameter))
+ for parameter in field_build_parameters
+ ]
+
+ return CoverageContainer(field_value.coverage())
+
+ if isinstance(field_value, Use):
+ return field_value.to_value()
+
+ if isinstance(field_value, Fixture):
+ return CoverageContainerCallable(field_value.to_value)
+
+ return CoverageContainerCallable(field_value) if callable(field_value) else field_value
+
+ @classmethod
+ def _get_config(cls) -> dict[str, Any]:
+ return {
+ **{key: getattr(cls, key) for key in cls.__config_keys__},
+ "_extra_providers": cls.get_provider_map(),
+ }
+
+ @classmethod
+ def _get_or_create_factory(cls, model: type) -> type[BaseFactory[Any]]:
+ """Get a factory from registered factories or generate a factory dynamically.
+
+ :param model: A model type.
+ :returns: A Factory sub-class.
+
+ """
+ if factory := BaseFactory._factory_type_mapping.get(model):
+ return factory
+
+ config = cls._get_config()
+
+ if cls.__base_factory_overrides__:
+ for model_ancestor in model.mro():
+ if factory := cls.__base_factory_overrides__.get(model_ancestor):
+ return factory.create_factory(model, **config)
+
+ for factory in reversed(BaseFactory._base_factories):
+ if factory.is_supported_type(model):
+ return factory.create_factory(model, **config)
+
+ msg = f"unsupported model type {model.__name__}"
+ raise ParameterException(msg) # pragma: no cover
+
+ # Public Methods
+
+ @classmethod
+ def is_factory_type(cls, annotation: Any) -> bool:
+ """Determine whether a given field is annotated with a type that is supported by a base factory.
+
+ :param annotation: A type annotation.
+ :returns: Boolean dictating whether the annotation is a factory type
+ """
+ return any(factory.is_supported_type(annotation) for factory in BaseFactory._base_factories)
+
+ @classmethod
+ def is_batch_factory_type(cls, annotation: Any) -> bool:
+ """Determine whether a given field is annotated with a sequence of supported factory types.
+
+ :param annotation: A type annotation.
+ :returns: Boolean dictating whether the annotation is a batch factory type
+ """
+ origin = get_type_origin(annotation) or annotation
+ if is_safe_subclass(origin, Sequence) and (args := unwrap_args(annotation, random=cls.__random__)):
+ return len(args) == 1 and BaseFactory.is_factory_type(annotation=args[0])
+ return False
+
+ @classmethod
+ def extract_field_build_parameters(cls, field_meta: FieldMeta, build_args: dict[str, Any]) -> Any:
+ """Extract from the build kwargs any build parameters passed for a given field meta - if it is a factory type.
+
+ :param field_meta: A field meta instance.
+ :param build_args: Any kwargs passed to the factory.
+ :returns: Any values
+ """
+ if build_arg := build_args.get(field_meta.name):
+ annotation = unwrap_optional(field_meta.annotation)
+ if (
+ BaseFactory.is_factory_type(annotation=annotation)
+ and isinstance(build_arg, Mapping)
+ and not BaseFactory.is_factory_type(annotation=type(build_arg))
+ ):
+ return build_args.pop(field_meta.name)
+
+ if (
+ BaseFactory.is_batch_factory_type(annotation=annotation)
+ and isinstance(build_arg, Sequence)
+ and not any(BaseFactory.is_factory_type(annotation=type(value)) for value in build_arg)
+ ):
+ return build_args.pop(field_meta.name)
+ return None
+
+ @classmethod
+ @abstractmethod
+ def is_supported_type(cls, value: Any) -> "TypeGuard[type[T]]": # pragma: no cover
+ """Determine whether the given value is supported by the factory.
+
+ :param value: An arbitrary value.
+ :returns: A typeguard
+ """
+ raise NotImplementedError
+
+ @classmethod
+ def seed_random(cls, seed: int) -> None:
+ """Seed faker and random with the given integer.
+
+ :param seed: An integer to set as seed.
+ :returns: 'None'
+
+ """
+ cls.__random__ = Random(seed)
+ cls.__faker__.seed_instance(seed)
+
+ @classmethod
+ def is_ignored_type(cls, value: Any) -> bool:
+ """Check whether a given value is an ignored type.
+
+ :param value: An arbitrary value.
+
+ :notes:
+ - This method is meant to be overwritten by extension factories and other subclasses
+
+ :returns: A boolean determining whether the value should be ignored.
+
+ """
+ return value is None
+
+ @classmethod
+ def get_provider_map(cls) -> dict[Any, Callable[[], Any]]:
+ """Map types to callables.
+
+ :notes:
+ - This method is distinct to allow overriding.
+
+
+ :returns: a dictionary mapping types to callables.
+
+ """
+
+ def _create_generic_fn() -> Callable:
+ """Return a generic lambda"""
+ return lambda *args: None
+
+ return {
+ Any: lambda: None,
+ # primitives
+ object: object,
+ float: cls.__faker__.pyfloat,
+ int: cls.__faker__.pyint,
+ bool: cls.__faker__.pybool,
+ str: cls.__faker__.pystr,
+ bytes: partial(create_random_bytes, cls.__random__),
+ # built-in objects
+ dict: cls.__faker__.pydict,
+ tuple: cls.__faker__.pytuple,
+ list: cls.__faker__.pylist,
+ set: cls.__faker__.pyset,
+ frozenset: lambda: frozenset(cls.__faker__.pylist()),
+ deque: lambda: deque(cls.__faker__.pylist()),
+ # standard library objects
+ Path: lambda: Path(realpath(__file__)),
+ Decimal: cls.__faker__.pydecimal,
+ UUID: lambda: UUID(cls.__faker__.uuid4()),
+ # datetime
+ datetime: cls.__faker__.date_time_between,
+ date: cls.__faker__.date_this_decade,
+ time: cls.__faker__.time_object,
+ timedelta: cls.__faker__.time_delta,
+ # ip addresses
+ IPv4Address: lambda: ip_address(cls.__faker__.ipv4()),
+ IPv4Interface: lambda: ip_interface(cls.__faker__.ipv4()),
+ IPv4Network: lambda: ip_network(cls.__faker__.ipv4(network=True)),
+ IPv6Address: lambda: ip_address(cls.__faker__.ipv6()),
+ IPv6Interface: lambda: ip_interface(cls.__faker__.ipv6()),
+ IPv6Network: lambda: ip_network(cls.__faker__.ipv6(network=True)),
+ # types
+ Callable: _create_generic_fn,
+ abc.Callable: _create_generic_fn,
+ Counter: lambda: Counter(cls.__faker__.pystr()),
+ **(cls._extra_providers or {}),
+ }
+
+ @classmethod
+ def create_factory(
+ cls: type[F],
+ model: type[T] | None = None,
+ bases: tuple[type[BaseFactory[Any]], ...] | None = None,
+ **kwargs: Any,
+ ) -> type[F]:
+ """Generate a factory for the given type dynamically.
+
+ :param model: A type to model. Defaults to current factory __model__ if any.
+ Otherwise, raise an error
+ :param bases: Base classes to use when generating the new class.
+ :param kwargs: Any kwargs.
+
+ :returns: A 'ModelFactory' subclass.
+
+ """
+ if model is None:
+ try:
+ model = cls.__model__
+ except AttributeError as ex:
+ msg = "A 'model' argument is required when creating a new factory from a base one"
+ raise TypeError(msg) from ex
+ return cast(
+ "Type[F]",
+ type(
+ f"{model.__name__}Factory", # pyright: ignore[reportOptionalMemberAccess]
+ (*(bases or ()), cls),
+ {"__model__": model, **kwargs},
+ ),
+ )
+
+ @classmethod
+ def get_constrained_field_value(cls, annotation: Any, field_meta: FieldMeta) -> Any: # noqa: C901, PLR0911, PLR0912
+ try:
+ constraints = cast("Constraints", field_meta.constraints)
+ if is_safe_subclass(annotation, float):
+ return handle_constrained_float(
+ random=cls.__random__,
+ multiple_of=cast("Any", constraints.get("multiple_of")),
+ gt=cast("Any", constraints.get("gt")),
+ ge=cast("Any", constraints.get("ge")),
+ lt=cast("Any", constraints.get("lt")),
+ le=cast("Any", constraints.get("le")),
+ )
+
+ if is_safe_subclass(annotation, int):
+ return handle_constrained_int(
+ random=cls.__random__,
+ multiple_of=cast("Any", constraints.get("multiple_of")),
+ gt=cast("Any", constraints.get("gt")),
+ ge=cast("Any", constraints.get("ge")),
+ lt=cast("Any", constraints.get("lt")),
+ le=cast("Any", constraints.get("le")),
+ )
+
+ if is_safe_subclass(annotation, Decimal):
+ return handle_constrained_decimal(
+ random=cls.__random__,
+ decimal_places=cast("Any", constraints.get("decimal_places")),
+ max_digits=cast("Any", constraints.get("max_digits")),
+ multiple_of=cast("Any", constraints.get("multiple_of")),
+ gt=cast("Any", constraints.get("gt")),
+ ge=cast("Any", constraints.get("ge")),
+ lt=cast("Any", constraints.get("lt")),
+ le=cast("Any", constraints.get("le")),
+ )
+
+ if url_constraints := constraints.get("url"):
+ return handle_constrained_url(constraints=url_constraints)
+
+ if is_safe_subclass(annotation, str) or is_safe_subclass(annotation, bytes):
+ return handle_constrained_string_or_bytes(
+ random=cls.__random__,
+ t_type=str if is_safe_subclass(annotation, str) else bytes,
+ lower_case=constraints.get("lower_case") or False,
+ upper_case=constraints.get("upper_case") or False,
+ min_length=constraints.get("min_length"),
+ max_length=constraints.get("max_length"),
+ pattern=constraints.get("pattern"),
+ )
+
+ try:
+ collection_type = get_collection_type(annotation)
+ except ValueError:
+ collection_type = None
+ if collection_type is not None:
+ if collection_type == dict:
+ return handle_constrained_mapping(
+ factory=cls,
+ field_meta=field_meta,
+ min_items=constraints.get("min_length"),
+ max_items=constraints.get("max_length"),
+ )
+ return handle_constrained_collection(
+ collection_type=collection_type, # type: ignore[type-var]
+ factory=cls,
+ field_meta=field_meta.children[0] if field_meta.children else field_meta,
+ item_type=constraints.get("item_type"),
+ max_items=constraints.get("max_length"),
+ min_items=constraints.get("min_length"),
+ unique_items=constraints.get("unique_items", False),
+ )
+
+ if is_safe_subclass(annotation, date):
+ return handle_constrained_date(
+ faker=cls.__faker__,
+ ge=cast("Any", constraints.get("ge")),
+ gt=cast("Any", constraints.get("gt")),
+ le=cast("Any", constraints.get("le")),
+ lt=cast("Any", constraints.get("lt")),
+ tz=cast("Any", constraints.get("tz")),
+ )
+
+ if is_safe_subclass(annotation, UUID) and (uuid_version := constraints.get("uuid_version")):
+ return handle_constrained_uuid(
+ uuid_version=uuid_version,
+ faker=cls.__faker__,
+ )
+
+ if is_safe_subclass(annotation, Path) and (path_constraint := constraints.get("path_type")):
+ return handle_constrained_path(constraint=path_constraint, faker=cls.__faker__)
+ except TypeError as e:
+ raise ParameterException from e
+
+ msg = f"received constraints for unsupported type {annotation}"
+ raise ParameterException(msg)
+
    @classmethod
    def get_field_value(  # noqa: C901, PLR0911, PLR0912
        cls,
        field_meta: FieldMeta,
        field_build_parameters: Any | None = None,
        build_context: BuildContext | None = None,
    ) -> Any:
        """Return a field value on the subclass if existing, otherwise returns a mock value.

        Dispatches on the (unwrapped) annotation, in order: ignored types,
        optional-``None``, literals, enums, constrained fields, unions, nested
        factory models, batch (model collection) fields, generic collections,
        ``Any``/``TypeVar``, the provider map, and finally a naive call of the
        annotation itself.

        :param field_meta: FieldMeta instance.
        :param field_build_parameters: Any build parameters passed to the factory as kwarg values.
        :param build_context: BuildContext data for current build.

        :returns: An arbitrary value.

        """
        build_context = _get_build_context(build_context)
        if cls.is_ignored_type(field_meta.annotation):
            return None

        # Randomly choose None for optional fields, but never when the caller
        # supplied explicit build parameters for this field.
        if field_build_parameters is None and cls.should_set_none_value(field_meta=field_meta):
            return None

        # Strip wrappers (Optional/Annotated/NewType) down to the core type.
        unwrapped_annotation = unwrap_annotation(field_meta.annotation, random=cls.__random__)

        if is_literal(annotation=unwrapped_annotation) and (literal_args := get_args(unwrapped_annotation)):
            return cls.__random__.choice(literal_args)

        if isinstance(unwrapped_annotation, EnumMeta):
            return cls.__random__.choice(list(unwrapped_annotation))

        if field_meta.constraints:
            return cls.get_constrained_field_value(annotation=unwrapped_annotation, field_meta=field_meta)

        if is_union(field_meta.annotation) and field_meta.children:
            seen_models = build_context["seen_models"]
            # Skip union members currently being built to break recursion cycles.
            children = [child for child in field_meta.children if child.annotation not in seen_models]

            # `None` is removed from the children when creating FieldMeta so when `children`
            # is empty, it must mean that the field meta is an optional type.
            if children:
                return cls.get_field_value(cls.__random__.choice(children), field_build_parameters, build_context)

        if BaseFactory.is_factory_type(annotation=unwrapped_annotation):
            # Recursive model reference: yield None if the field is optional,
            # otherwise Null so the caller drops the field entirely.
            if not field_build_parameters and unwrapped_annotation in build_context["seen_models"]:
                return None if is_optional(field_meta.annotation) else Null

            return cls._get_or_create_factory(model=unwrapped_annotation).build(
                _build_context=build_context,
                **(field_build_parameters if isinstance(field_build_parameters, Mapping) else {}),
            )

        if BaseFactory.is_batch_factory_type(annotation=unwrapped_annotation):
            factory = cls._get_or_create_factory(model=field_meta.type_args[0])
            # A sequence of kwargs mappings builds one instance per mapping.
            if isinstance(field_build_parameters, Sequence):
                return [
                    factory.build(_build_context=build_context, **field_parameters)
                    for field_parameters in field_build_parameters
                ]

            # Recursive model reference inside a collection: empty collection.
            if field_meta.type_args[0] in build_context["seen_models"]:
                return []

            if not cls.__randomize_collection_length__:
                return [factory.build(_build_context=build_context)]

            batch_size = cls.__random__.randint(cls.__min_collection_length__, cls.__max_collection_length__)
            return factory.batch(size=batch_size, _build_context=build_context)

        if (origin := get_type_origin(unwrapped_annotation)) and is_safe_subclass(origin, Collection):
            if cls.__randomize_collection_length__:
                collection_type = get_collection_type(unwrapped_annotation)
                if collection_type != dict:
                    return handle_constrained_collection(
                        collection_type=collection_type,  # type: ignore[type-var]
                        factory=cls,
                        item_type=Any,
                        field_meta=field_meta.children[0] if field_meta.children else field_meta,
                        min_items=cls.__min_collection_length__,
                        max_items=cls.__max_collection_length__,
                    )
                return handle_constrained_mapping(
                    factory=cls,
                    field_meta=field_meta,
                    min_items=cls.__min_collection_length__,
                    max_items=cls.__max_collection_length__,
                )

            return handle_collection_type(field_meta, origin, cls)

        if is_any(unwrapped_annotation) or isinstance(unwrapped_annotation, TypeVar):
            return create_random_string(cls.__random__, min_length=1, max_length=10)

        if provider := cls.get_provider_map().get(unwrapped_annotation):
            return provider()

        if callable(unwrapped_annotation):
            # if value is a callable we can try to naively call it.
            # this will work for callables that do not require any parameters passed
            with suppress(Exception):
                return unwrapped_annotation()

        msg = f"Unsupported type: {unwrapped_annotation!r}\n\nEither extend the providers map or add a factory function for this type."
        raise ParameterException(
            msg,
        )
+
    @classmethod
    def get_field_value_coverage(  # noqa: C901
        cls,
        field_meta: FieldMeta,
        field_build_parameters: Any | None = None,
        build_context: BuildContext | None = None,
    ) -> Iterable[Any]:
        """Yield values covering each member of the field's (possibly union) annotation.

        :param field_meta: FieldMeta instance.
        :param field_build_parameters: Any build parameters passed to the factory as kwarg values.
        :param build_context: BuildContext data for current build.

        :returns: An iterable of values.

        """
        if cls.is_ignored_type(field_meta.annotation):
            return [None]

        # Unions/optionals are flattened so every member type gets covered.
        for unwrapped_annotation in flatten_annotation(field_meta.annotation):
            if unwrapped_annotation in (None, NoneType):
                yield None

            elif is_literal(annotation=unwrapped_annotation) and (literal_args := get_args(unwrapped_annotation)):
                yield CoverageContainer(literal_args)

            elif isinstance(unwrapped_annotation, EnumMeta):
                yield CoverageContainer(list(unwrapped_annotation))

            elif field_meta.constraints:
                # Deferred so constrained-value generation happens per expansion.
                yield CoverageContainerCallable(
                    cls.get_constrained_field_value,
                    annotation=unwrapped_annotation,
                    field_meta=field_meta,
                )

            elif BaseFactory.is_factory_type(annotation=unwrapped_annotation):
                yield CoverageContainer(
                    cls._get_or_create_factory(model=unwrapped_annotation).coverage(
                        _build_context=build_context,
                        **(field_build_parameters if isinstance(field_build_parameters, Mapping) else {}),
                    ),
                )

            elif (origin := get_type_origin(unwrapped_annotation)) and issubclass(origin, Collection):
                yield handle_collection_type_coverage(field_meta, origin, cls)

            elif is_any(unwrapped_annotation) or isinstance(unwrapped_annotation, TypeVar):
                yield create_random_string(cls.__random__, min_length=1, max_length=10)

            elif provider := cls.get_provider_map().get(unwrapped_annotation):
                yield CoverageContainerCallable(provider)

            elif callable(unwrapped_annotation):
                # if value is a callable we can try to naively call it.
                # this will work for callables that do not require any parameters passed
                yield CoverageContainerCallable(unwrapped_annotation)
            else:
                msg = f"Unsupported type: {unwrapped_annotation!r}\n\nEither extend the providers map or add a factory function for this type."
                raise ParameterException(
                    msg,
                )
+
+ @classmethod
+ def should_set_none_value(cls, field_meta: FieldMeta) -> bool:
+ """Determine whether a given model field_meta should be set to None.
+
+ :param field_meta: Field metadata.
+
+ :notes:
+ - This method is distinct to allow overriding.
+
+ :returns: A boolean determining whether 'None' should be set for the given field_meta.
+
+ """
+ return (
+ cls.__allow_none_optionals__
+ and is_optional(field_meta.annotation)
+ and create_random_boolean(random=cls.__random__)
+ )
+
+ @classmethod
+ def should_use_default_value(cls, field_meta: FieldMeta) -> bool:
+ """Determine whether to use the default value for the given field.
+
+ :param field_meta: FieldMeta instance.
+
+ :notes:
+ - This method is distinct to allow overriding.
+
+ :returns: A boolean determining whether the default value should be used for the given field_meta.
+
+ """
+ return cls.__use_defaults__ and field_meta.default is not Null
+
+ @classmethod
+ def should_set_field_value(cls, field_meta: FieldMeta, **kwargs: Any) -> bool:
+ """Determine whether to set a value for a given field_name.
+
+ :param field_meta: FieldMeta instance.
+ :param kwargs: Any kwargs passed to the factory.
+
+ :notes:
+ - This method is distinct to allow overriding.
+
+ :returns: A boolean determining whether a value should be set for the given field_meta.
+
+ """
+ return not field_meta.name.startswith("_") and field_meta.name not in kwargs
+
+ @classmethod
+ @abstractmethod
+ def get_model_fields(cls) -> list[FieldMeta]: # pragma: no cover
+ """Retrieve a list of fields from the factory's model.
+
+
+ :returns: A list of field MetaData instances.
+
+ """
+ raise NotImplementedError
+
+ @classmethod
+ def get_factory_fields(cls) -> list[tuple[str, Any]]:
+ """Retrieve a list of fields from the factory.
+
+ Trying to be smart about what should be considered a field on the model,
+ ignoring dunder methods and some parent class attributes.
+
+ :returns: A list of tuples made of field name and field definition
+ """
+ factory_fields = cls.__dict__.items()
+ return [
+ (field_name, field_value)
+ for field_name, field_value in factory_fields
+ if not (field_name.startswith("__") or field_name == "_abc_impl")
+ ]
+
+ @classmethod
+ def _check_declared_fields_exist_in_model(cls) -> None:
+ model_fields_names = {field_meta.name for field_meta in cls.get_model_fields()}
+ factory_fields = cls.get_factory_fields()
+
+ for field_name, field_value in factory_fields:
+ if field_name in model_fields_names:
+ continue
+
+ error_message = (
+ f"{field_name} is declared on the factory {cls.__name__}"
+ f" but it is not part of the model {cls.__model__.__name__}"
+ )
+ if isinstance(field_value, (Use, PostGenerated, Ignore, Require)):
+ raise ConfigurationException(error_message)
+
    @classmethod
    def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]:
        """Process the given kwargs and generate values for the factory's model.

        Caller-supplied kwargs always win; factory-declared attributes
        (``Use``/``Ignore``/``Require``/``PostGenerated``) are honoured next;
        remaining fields are generated via :meth:`get_field_value`.

        :param kwargs: Any build kwargs.

        :returns: A dictionary of build results.

        """
        _build_context = _get_build_context(kwargs.pop("_build_context", None))
        # Record this model so nested fields referencing it can detect recursion.
        _build_context["seen_models"].add(cls.__model__)

        result: dict[str, Any] = {**kwargs}
        generate_post: dict[str, PostGenerated] = {}

        for field_meta in cls.get_model_fields():
            field_build_parameters = cls.extract_field_build_parameters(field_meta=field_meta, build_args=kwargs)
            if cls.should_set_field_value(field_meta, **kwargs) and not cls.should_use_default_value(field_meta):
                # An attribute declared on the factory itself overrides generation.
                if hasattr(cls, field_meta.name) and not hasattr(BaseFactory, field_meta.name):
                    field_value = getattr(cls, field_meta.name)
                    if isinstance(field_value, Ignore):
                        continue

                    if isinstance(field_value, Require) and field_meta.name not in kwargs:
                        msg = f"Require kwarg {field_meta.name} is missing"
                        raise MissingBuildKwargException(msg)

                    # PostGenerated fields are resolved after all other fields.
                    if isinstance(field_value, PostGenerated):
                        generate_post[field_meta.name] = field_value
                        continue

                    result[field_meta.name] = cls._handle_factory_field(
                        field_value=field_value,
                        field_build_parameters=field_build_parameters,
                        build_context=_build_context,
                    )
                    continue

                field_result = cls.get_field_value(
                    field_meta,
                    field_build_parameters=field_build_parameters,
                    build_context=_build_context,
                )
                # Null signals "omit this field" (e.g. unresolvable recursion).
                if field_result is Null:
                    continue

                result[field_meta.name] = field_result

        for field_name, post_generator in generate_post.items():
            result[field_name] = post_generator.to_value(field_name, result)

        return result
+
    @classmethod
    def process_kwargs_coverage(cls, **kwargs: Any) -> abc.Iterable[dict[str, Any]]:
        """Process the given kwargs and generate coverage value sets for the factory's model.

        Like :meth:`process_kwargs`, but field values are wrapped in coverage
        containers and expanded into one kwargs dict per covered combination.

        :param kwargs: Any build kwargs.

        :returns: An iterable of dictionaries of build results.

        """
        _build_context = _get_build_context(kwargs.pop("_build_context", None))
        # Record this model so nested fields referencing it can detect recursion.
        _build_context["seen_models"].add(cls.__model__)

        result: dict[str, Any] = {**kwargs}
        generate_post: dict[str, PostGenerated] = {}

        for field_meta in cls.get_model_fields():
            field_build_parameters = cls.extract_field_build_parameters(field_meta=field_meta, build_args=kwargs)

            if cls.should_set_field_value(field_meta, **kwargs):
                # An attribute declared on the factory itself overrides generation.
                if hasattr(cls, field_meta.name) and not hasattr(BaseFactory, field_meta.name):
                    field_value = getattr(cls, field_meta.name)
                    if isinstance(field_value, Ignore):
                        continue

                    if isinstance(field_value, Require) and field_meta.name not in kwargs:
                        msg = f"Require kwarg {field_meta.name} is missing"
                        raise MissingBuildKwargException(msg)

                    # PostGenerated fields are resolved after expansion, per result set.
                    if isinstance(field_value, PostGenerated):
                        generate_post[field_meta.name] = field_value
                        continue

                    result[field_meta.name] = cls._handle_factory_field_coverage(
                        field_value=field_value,
                        field_build_parameters=field_build_parameters,
                        build_context=_build_context,
                    )
                    continue

                result[field_meta.name] = CoverageContainer(
                    cls.get_field_value_coverage(
                        field_meta,
                        field_build_parameters=field_build_parameters,
                        build_context=_build_context,
                    ),
                )

        for resolved in resolve_kwargs_coverage(result):
            for field_name, post_generator in generate_post.items():
                resolved[field_name] = post_generator.to_value(field_name, resolved)
            yield resolved
+
+ @classmethod
+ def build(cls, **kwargs: Any) -> T:
+ """Build an instance of the factory's __model__
+
+ :param kwargs: Any kwargs. If field names are set in kwargs, their values will be used.
+
+ :returns: An instance of type T.
+
+ """
+ return cast("T", cls.__model__(**cls.process_kwargs(**kwargs)))
+
+ @classmethod
+ def batch(cls, size: int, **kwargs: Any) -> list[T]:
+ """Build a batch of size n of the factory's Meta.model.
+
+ :param size: Size of the batch.
+ :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used.
+
+ :returns: A list of instances of type T.
+
+ """
+ return [cls.build(**kwargs) for _ in range(size)]
+
+ @classmethod
+ def coverage(cls, **kwargs: Any) -> abc.Iterator[T]:
+ """Build a batch of the factory's Meta.model will full coverage of the sub-types of the model.
+
+ :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used.
+
+ :returns: A iterator of instances of type T.
+
+ """
+ for data in cls.process_kwargs_coverage(**kwargs):
+ instance = cls.__model__(**data)
+ yield cast("T", instance)
+
+ @classmethod
+ def create_sync(cls, **kwargs: Any) -> T:
+ """Build and persists synchronously a single model instance.
+
+ :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used.
+
+ :returns: An instance of type T.
+
+ """
+ return cls._get_sync_persistence().save(data=cls.build(**kwargs))
+
+ @classmethod
+ def create_batch_sync(cls, size: int, **kwargs: Any) -> list[T]:
+ """Build and persists synchronously a batch of n size model instances.
+
+ :param size: Size of the batch.
+ :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used.
+
+ :returns: A list of instances of type T.
+
+ """
+ return cls._get_sync_persistence().save_many(data=cls.batch(size, **kwargs))
+
+ @classmethod
+ async def create_async(cls, **kwargs: Any) -> T:
+ """Build and persists asynchronously a single model instance.
+
+ :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used.
+
+ :returns: An instance of type T.
+ """
+ return await cls._get_async_persistence().save(data=cls.build(**kwargs))
+
+ @classmethod
+ async def create_batch_async(cls, size: int, **kwargs: Any) -> list[T]:
+ """Build and persists asynchronously a batch of n size model instances.
+
+
+ :param size: Size of the batch.
+ :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used.
+
+ :returns: A list of instances of type T.
+ """
+ return await cls._get_async_persistence().save_many(data=cls.batch(size, **kwargs))
+
+
def _register_builtin_factories() -> None:
    """Import the built-in factory modules so they self-register, if installed.

    :returns: None
    """
    # These two depend only on the standard library, so they always import.
    import polyfactory.factories.dataclass_factory
    import polyfactory.factories.typed_dict_factory  # noqa: F401

    optional_modules = (
        "polyfactory.factories.pydantic_factory",
        "polyfactory.factories.beanie_odm_factory",
        "polyfactory.factories.odmantic_odm_factory",
        "polyfactory.factories.msgspec_factory",
        # `AttrsFactory` is not being registered by default since not all versions of `attrs` are supported.
        # Issue: https://github.com/litestar-org/polyfactory/issues/356
        # "polyfactory.factories.attrs_factory",
    )
    for module in optional_modules:
        # The third-party dependency backing each module may be absent.
        try:
            import_module(module)
        except ImportError:  # noqa: PERF203
            continue


_register_builtin_factories()
diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/beanie_odm_factory.py b/venv/lib/python3.11/site-packages/polyfactory/factories/beanie_odm_factory.py
new file mode 100644
index 0000000..ddd3169
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/polyfactory/factories/beanie_odm_factory.py
@@ -0,0 +1,87 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Generic, TypeVar
+
+from typing_extensions import get_args
+
+from polyfactory.exceptions import MissingDependencyException
+from polyfactory.factories.pydantic_factory import ModelFactory
+from polyfactory.persistence import AsyncPersistenceProtocol
+from polyfactory.utils.predicates import is_safe_subclass
+
+if TYPE_CHECKING:
+ from typing_extensions import TypeGuard
+
+ from polyfactory.factories.base import BuildContext
+ from polyfactory.field_meta import FieldMeta
+
+try:
+ from beanie import Document
+except ImportError as e:
+ msg = "beanie is not installed"
+ raise MissingDependencyException(msg) from e
+
+T = TypeVar("T", bound=Document)
+
+
class BeaniePersistenceHandler(Generic[T], AsyncPersistenceProtocol[T]):
    """Persistence Handler using beanie logic"""

    async def save(self, data: T) -> T:
        """Persist a single instance in mongoDB."""
        inserted = await data.insert()
        return inserted  # type: ignore[no-any-return]

    async def save_many(self, data: list[T]) -> list[T]:
        """Persist multiple instances in mongoDB.

        .. note:: we cannot use the ``.insert_many`` method from Beanie here because it doesn't
            return the created instances
        """
        inserted_documents: list[T] = []
        for document in data:
            inserted_documents.append(await document.insert())  # pyright: ignore[reportGeneralTypeIssues]
        return inserted_documents
+
+
class BeanieDocumentFactory(Generic[T], ModelFactory[T]):
    """Base factory for Beanie Documents"""

    __async_persistence__ = BeaniePersistenceHandler
    __is_base_factory__ = True

    @classmethod
    def is_supported_type(cls, value: Any) -> "TypeGuard[type[T]]":
        """Determine whether the given value is supported by the factory.

        :param value: An arbitrary value.
        :returns: A typeguard
        """
        return is_safe_subclass(value, Document)

    @classmethod
    def get_field_value(
        cls,
        field_meta: "FieldMeta",
        field_build_parameters: Any | None = None,
        build_context: BuildContext | None = None,
    ) -> Any:
        """Return a field value on the subclass if existing, otherwise returns a mock value.

        Unwraps beanie's ``Indexed`` and ``Link`` annotations to their underlying
        types before delegating to the pydantic model factory.

        :param field_meta: FieldMeta instance.
        :param field_build_parameters: Any build parameters passed to the factory as kwarg values.
        :param build_context: BuildContext data for current build.

        :returns: An arbitrary value.

        """
        if hasattr(field_meta.annotation, "__name__"):
            # beanie's Indexed(T) produces a dynamic subclass whose name contains
            # "Indexed " (trailing space intentional); build the wrapped base type.
            if "Indexed " in field_meta.annotation.__name__:
                base_type = field_meta.annotation.__bases__[0]
                field_meta.annotation = base_type

            # Link[Model] fields are built as the linked model itself.
            # (fix: the original assigned `link_class` twice - dead duplicate removed)
            if "Link" in field_meta.annotation.__name__:
                link_class = get_args(field_meta.annotation)[0]
                field_meta.annotation = link_class

        return super().get_field_value(
            field_meta=field_meta,
            field_build_parameters=field_build_parameters,
            build_context=build_context,
        )
diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/dataclass_factory.py b/venv/lib/python3.11/site-packages/polyfactory/factories/dataclass_factory.py
new file mode 100644
index 0000000..01cfbe7
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/polyfactory/factories/dataclass_factory.py
@@ -0,0 +1,58 @@
+from __future__ import annotations
+
+from dataclasses import MISSING, fields, is_dataclass
+from typing import Any, Generic
+
+from typing_extensions import TypeGuard, get_type_hints
+
+from polyfactory.factories.base import BaseFactory, T
+from polyfactory.field_meta import FieldMeta, Null
+
+
class DataclassFactory(Generic[T], BaseFactory[T]):
    """Dataclass base factory"""

    __is_base_factory__ = True

    @classmethod
    def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]:
        """Determine whether the given value is supported by the factory.

        :param value: An arbitrary value.
        :returns: A typeguard
        """
        return bool(is_dataclass(value))

    @classmethod
    def get_model_fields(cls) -> list["FieldMeta"]:
        """Retrieve a list of fields from the factory's model.

        :returns: A list of field MetaData instances.

        """
        fields_meta: list["FieldMeta"] = []

        # Resolve string annotations and keep Annotated metadata for constraints.
        model_type_hints = get_type_hints(cls.__model__, include_extras=True)

        for field in fields(cls.__model__):  # type: ignore[arg-type]
            # Fields excluded from __init__ cannot be passed to the constructor.
            if not field.init:
                continue

            # MISSING is dataclasses' "no default" sentinel; Null is this
            # library's equivalent. (fix: dropped the redundant truthiness test
            # on default_factory - MISSING and callables are always truthy, and
            # it could wrongly skip a factory object defining a falsy __bool__.)
            if field.default_factory is not MISSING:
                default_value = field.default_factory()
            elif field.default is not MISSING:
                default_value = field.default
            else:
                default_value = Null

            fields_meta.append(
                FieldMeta.from_type(
                    annotation=model_type_hints[field.name],
                    name=field.name,
                    default=default_value,
                    random=cls.__random__,
                ),
            )

        return fields_meta
diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/msgspec_factory.py b/venv/lib/python3.11/site-packages/polyfactory/factories/msgspec_factory.py
new file mode 100644
index 0000000..1b579ae
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/polyfactory/factories/msgspec_factory.py
@@ -0,0 +1,72 @@
+from __future__ import annotations
+
+from inspect import isclass
+from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar
+
+from typing_extensions import get_type_hints
+
+from polyfactory.exceptions import MissingDependencyException
+from polyfactory.factories.base import BaseFactory
+from polyfactory.field_meta import FieldMeta, Null
+from polyfactory.value_generators.constrained_numbers import handle_constrained_int
+from polyfactory.value_generators.primitives import create_random_bytes
+
+if TYPE_CHECKING:
+ from typing_extensions import TypeGuard
+
+try:
+ import msgspec
+ from msgspec.structs import fields
+except ImportError as e:
+ msg = "msgspec is not installed"
+ raise MissingDependencyException(msg) from e
+
+T = TypeVar("T", bound=msgspec.Struct)
+
+
class MsgspecFactory(Generic[T], BaseFactory[T]):
    """Base factory for msgspec Structs."""

    __is_base_factory__ = True

    @classmethod
    def get_provider_map(cls) -> dict[Any, Callable[[], Any]]:
        """Return the provider map, extended with msgspec-specific types."""

        def get_msgpack_ext() -> msgspec.msgpack.Ext:
            # Ext type codes must fit in a signed byte.
            code = handle_constrained_int(cls.__random__, ge=-128, le=127)
            data = create_random_bytes(cls.__random__)
            return msgspec.msgpack.Ext(code, data)

        provider_map = super().get_provider_map()
        provider_map[msgspec.UnsetType] = lambda: msgspec.UNSET
        provider_map[msgspec.msgpack.Ext] = get_msgpack_ext
        return provider_map

    @classmethod
    def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]:
        """Determine whether ``value`` is a msgspec Struct type."""
        return isclass(value) and hasattr(value, "__struct_fields__")

    @classmethod
    def get_model_fields(cls) -> list[FieldMeta]:
        """Retrieve a list of fields from the factory's Struct model."""
        type_hints = get_type_hints(cls.__model__, include_extras=True)

        fields_meta: list[FieldMeta] = []
        for field in fields(cls.__model__):
            # NODEFAULT is msgspec's "no default" sentinel; Null is this library's.
            if field.default is not msgspec.NODEFAULT:
                default_value: Any = field.default
            elif field.default_factory is not msgspec.NODEFAULT:
                default_value = field.default_factory()
            else:
                default_value = Null

            fields_meta.append(
                FieldMeta.from_type(
                    annotation=type_hints[field.name],
                    name=field.name,
                    default=default_value,
                    random=cls.__random__,
                ),
            )
        return fields_meta
diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/odmantic_odm_factory.py b/venv/lib/python3.11/site-packages/polyfactory/factories/odmantic_odm_factory.py
new file mode 100644
index 0000000..1b3367a
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/polyfactory/factories/odmantic_odm_factory.py
@@ -0,0 +1,60 @@
+from __future__ import annotations
+
+import decimal
+from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union
+
+from polyfactory.exceptions import MissingDependencyException
+from polyfactory.factories.pydantic_factory import ModelFactory
+from polyfactory.utils.predicates import is_safe_subclass
+from polyfactory.value_generators.primitives import create_random_bytes
+
+try:
+ from bson.decimal128 import Decimal128, create_decimal128_context
+ from odmantic import EmbeddedModel, Model
+ from odmantic import bson as odbson
+
+except ImportError as e:
+ msg = "odmantic is not installed"
+ raise MissingDependencyException(msg) from e
+
+T = TypeVar("T", bound=Union[Model, EmbeddedModel])
+
+if TYPE_CHECKING:
+ from typing_extensions import TypeGuard
+
+
class OdmanticModelFactory(Generic[T], ModelFactory[T]):
    """Base factory for odmantic models"""

    __is_base_factory__ = True

    @classmethod
    def is_supported_type(cls, value: Any) -> "TypeGuard[type[T]]":
        """Determine whether the given value is supported by the factory.

        :param value: An arbitrary value.
        :returns: A typeguard
        """
        return is_safe_subclass(value, (Model, EmbeddedModel))

    @classmethod
    def get_provider_map(cls) -> dict[Any, Callable[[], Any]]:
        """Return the provider map, extended with odmantic's bson field types."""
        bson_providers: dict[Any, Callable[[], Any]] = {
            odbson.Int64: lambda: odbson.Int64.validate(cls.__faker__.pyint()),
            odbson.Decimal128: lambda: _to_decimal128(cls.__faker__.pydecimal()),
            odbson.Binary: lambda: odbson.Binary.validate(create_random_bytes(cls.__random__)),
            odbson._datetime: lambda: odbson._datetime.validate(cls.__faker__.date_time_between()),
            # bson.Regex and bson._Pattern not supported as there is no way to generate
            # a random regular expression with Faker
            # bson.Regex:
            # bson._Pattern:
        }
        provider_map = super().get_provider_map()
        provider_map.update(bson_providers)
        return provider_map
+
+
def _to_decimal128(value: decimal.Decimal) -> Decimal128:
    """Convert a standard-library ``Decimal`` to a bson ``Decimal128``.

    Uses bson's decimal128 arithmetic context so the value is coerced the same
    way it would be when stored.
    """
    with decimal.localcontext(create_decimal128_context()) as ctx:
        return Decimal128(ctx.create_decimal(value))
diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/pydantic_factory.py b/venv/lib/python3.11/site-packages/polyfactory/factories/pydantic_factory.py
new file mode 100644
index 0000000..a6028b1
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/polyfactory/factories/pydantic_factory.py
@@ -0,0 +1,554 @@
+from __future__ import annotations
+
+from contextlib import suppress
+from datetime import timezone
+from functools import partial
+from os.path import realpath
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, ClassVar, ForwardRef, Generic, Mapping, Tuple, TypeVar, cast
+from uuid import NAMESPACE_DNS, uuid1, uuid3, uuid5
+
+from typing_extensions import Literal, get_args, get_origin
+
+from polyfactory.collection_extender import CollectionExtender
+from polyfactory.constants import DEFAULT_RANDOM
+from polyfactory.exceptions import MissingDependencyException
+from polyfactory.factories.base import BaseFactory
+from polyfactory.field_meta import Constraints, FieldMeta, Null
+from polyfactory.utils.deprecation import check_for_deprecated_parameters
+from polyfactory.utils.helpers import unwrap_new_type, unwrap_optional
+from polyfactory.utils.predicates import is_optional, is_safe_subclass, is_union
+from polyfactory.utils.types import NoneType
+from polyfactory.value_generators.primitives import create_random_bytes
+
+try:
+ import pydantic
+ from pydantic import VERSION, Json
+ from pydantic.fields import FieldInfo
+except ImportError as e:
+ msg = "pydantic is not installed"
+ raise MissingDependencyException(msg) from e
+
+try:
+ # pydantic v1
+ from pydantic import ( # noqa: I001
+ UUID1,
+ UUID3,
+ UUID4,
+ UUID5,
+ AmqpDsn,
+ AnyHttpUrl,
+ AnyUrl,
+ DirectoryPath,
+ FilePath,
+ HttpUrl,
+ KafkaDsn,
+ PostgresDsn,
+ RedisDsn,
+ )
+ from pydantic import BaseModel as BaseModelV1
+ from pydantic.color import Color
+ from pydantic.fields import ( # type: ignore[attr-defined]
+ DeferredType, # pyright: ignore[reportGeneralTypeIssues]
+ ModelField, # pyright: ignore[reportGeneralTypeIssues]
+ Undefined, # pyright: ignore[reportGeneralTypeIssues]
+ )
+
+ # Keep this import last to prevent warnings from pydantic if pydantic v2
+ # is installed.
+ from pydantic import PyObject
+
+ # prevent unbound variable warnings
+ BaseModelV2 = BaseModelV1
+ UndefinedV2 = Undefined
+except ImportError:
+ # pydantic v2
+
+ # v2 specific imports
+ from pydantic import BaseModel as BaseModelV2
+ from pydantic_core import PydanticUndefined as UndefinedV2
+ from pydantic_core import to_json
+
+ from pydantic.v1 import ( # v1 compat imports
+ UUID1,
+ UUID3,
+ UUID4,
+ UUID5,
+ AmqpDsn,
+ AnyHttpUrl,
+ AnyUrl,
+ DirectoryPath,
+ FilePath,
+ HttpUrl,
+ KafkaDsn,
+ PostgresDsn,
+ PyObject,
+ RedisDsn,
+ )
+ from pydantic.v1 import BaseModel as BaseModelV1 # type: ignore[assignment]
+ from pydantic.v1.color import Color # type: ignore[assignment]
+ from pydantic.v1.fields import DeferredType, ModelField, Undefined
+
+
+if TYPE_CHECKING:
+ from random import Random
+ from typing import Callable, Sequence
+
+ from typing_extensions import NotRequired, TypeGuard
+
+T = TypeVar("T", bound="BaseModelV1 | BaseModelV2")
+
+_IS_PYDANTIC_V1 = VERSION.startswith("1")
+
+
class PydanticConstraints(Constraints):
    """Metadata regarding a Pydantic type constraints, if any"""

    # Set to True when the annotation carried pydantic's `Json` marker
    # (see `PydanticFieldMeta.from_field_info`).
    json: NotRequired[bool]
+
+
+class PydanticFieldMeta(FieldMeta):
+ """Field meta subclass capable of handling pydantic ModelFields"""
+
    def __init__(
        self,
        *,
        name: str,
        annotation: type,
        random: Random | None = None,
        default: Any = ...,
        children: list[FieldMeta] | None = None,
        constraints: PydanticConstraints | None = None,
    ) -> None:
        """Forward to ``FieldMeta.__init__``; overridden only to narrow the
        ``constraints`` parameter type to ``PydanticConstraints``.

        :param name: Field name.
        :param annotation: Field type annotation.
        :param random: A ``random.Random`` instance.
        :param default: Field default value (``...`` meaning "no default").
        :param children: Child field metadata (e.g. for container members).
        :param constraints: Pydantic-specific constraint metadata.
        """
        super().__init__(
            name=name,
            annotation=annotation,
            random=random,
            default=default,
            children=children,
            constraints=constraints,
        )
+
    @classmethod
    def from_field_info(
        cls,
        field_name: str,
        field_info: FieldInfo,
        use_alias: bool,
        random: Random | None,
        randomize_collection_length: bool | None = None,
        min_collection_length: int | None = None,
        max_collection_length: int | None = None,
    ) -> PydanticFieldMeta:
        """Create an instance from a pydantic field info.

        :param field_name: The name of the field.
        :param field_info: A pydantic FieldInfo instance.
        :param use_alias: Whether to use the field alias.
        :param random: A random.Random instance.
        :param randomize_collection_length: (deprecated) Whether to randomize collection length.
        :param min_collection_length: (deprecated) Minimum collection length.
        :param max_collection_length: (deprecated) Maximum collection length.

        :returns: A PydanticFieldMeta instance.
        """
        # The collection-length parameters are deprecated since 2.11.0.
        check_for_deprecated_parameters(
            "2.11.0",
            parameters=(
                ("randomize_collection_length", randomize_collection_length),
                ("min_collection_length", min_collection_length),
                ("max_collection_length", max_collection_length),
            ),
        )
        if callable(field_info.default_factory):
            default_value = field_info.default_factory()
        else:
            # UndefinedV2 is pydantic's "no default" sentinel; Null is ours.
            default_value = field_info.default if field_info.default is not UndefinedV2 else Null

        annotation = unwrap_new_type(field_info.annotation)
        children: list[FieldMeta,] | None = None
        name = field_info.alias if field_info.alias and use_alias else field_name

        constraints: PydanticConstraints
        # pydantic v2 does not always propagate metadata for Union types
        if is_union(annotation):
            constraints = {}
            children = []
            for arg in get_args(annotation):
                if arg is NoneType:
                    continue
                # Merge the parent's constraint metadata into each union member
                # so per-member FieldMeta still carries the constraints.
                child_field_info = FieldInfo.from_annotation(arg)
                merged_field_info = FieldInfo.merge_field_infos(field_info, child_field_info)
                children.append(
                    cls.from_field_info(
                        field_name="",
                        field_info=merged_field_info,
                        use_alias=use_alias,
                        random=random,
                    ),
                )
        else:
            # Split the Json marker out of the remaining constraint metadata.
            metadata, is_json = [], False
            for m in field_info.metadata:
                if not is_json and isinstance(m, Json):  # type: ignore[misc]
                    is_json = True
                elif m is not None:
                    metadata.append(m)

            constraints = cast(
                PydanticConstraints,
                cls.parse_constraints(metadata=metadata) if metadata else {},
            )

            if "url" in constraints:
                # pydantic uses a sentinel value for url constraints
                annotation = str

            if is_json:
                constraints["json"] = True

        return PydanticFieldMeta.from_type(
            annotation=annotation,
            children=children,
            constraints=cast("Constraints", {k: v for k, v in constraints.items() if v is not None}) or None,
            default=default_value,
            name=name,
            random=random or DEFAULT_RANDOM,
        )
+
    @classmethod
    def from_model_field(  # pragma: no cover
        cls,
        model_field: ModelField,  # pyright: ignore[reportGeneralTypeIssues]
        use_alias: bool,
        randomize_collection_length: bool | None = None,
        min_collection_length: int | None = None,
        max_collection_length: int | None = None,
        random: Random = DEFAULT_RANDOM,
    ) -> PydanticFieldMeta:
        """Create an instance from a pydantic (v1) model field.

        :param model_field: A pydantic ModelField.
        :param use_alias: Whether to use the field alias.
        :param randomize_collection_length: A boolean flag whether to randomize collections lengths
        :param min_collection_length: Minimum number of elements in randomized collection
        :param max_collection_length: Maximum number of elements in randomized collection
        :param random: An instance of random.Random.

        :returns: A PydanticFieldMeta instance.

        """
        # The three collection-length parameters are deprecated since 2.11.0; this
        # only emits warnings for any that were explicitly passed.
        check_for_deprecated_parameters(
            "2.11.0",
            parameters=(
                ("randomize_collection_length", randomize_collection_length),
                ("min_collection_length", min_collection_length),
                ("max_collection_length", max_collection_length),
            ),
        )

        # Resolve the default: an explicit default wins over a default_factory.
        # NOTE: the final branch can only ever yield ``Null`` because
        # ``model_field.default is not Undefined`` was already handled above.
        if model_field.default is not Undefined:
            default_value = model_field.default
        elif callable(model_field.default_factory):
            default_value = model_field.default_factory()
        else:
            default_value = model_field.default if model_field.default is not Undefined else Null

        name = model_field.alias if model_field.alias and use_alias else model_field.name

        outer_type = unwrap_new_type(model_field.outer_type_)
        # Fall back to the outer type when the annotation is deferred or still a
        # forward reference and therefore not usable for generation.
        annotation = (
            model_field.outer_type_
            if isinstance(model_field.annotation, (DeferredType, ForwardRef))
            else unwrap_new_type(model_field.annotation)
        )

        # Harvest pydantic v1 constraints: constrained types (conint, constr, ...)
        # carry them as class attributes on the outer type; otherwise they live on
        # ``field_info``. ``min_items``/``max_items`` are folded into the generic
        # ``min_length``/``max_length`` keys.
        constraints = cast(
            "Constraints",
            {
                "ge": getattr(outer_type, "ge", model_field.field_info.ge),
                "gt": getattr(outer_type, "gt", model_field.field_info.gt),
                "le": getattr(outer_type, "le", model_field.field_info.le),
                "lt": getattr(outer_type, "lt", model_field.field_info.lt),
                "min_length": (
                    getattr(outer_type, "min_length", model_field.field_info.min_length)
                    or getattr(outer_type, "min_items", model_field.field_info.min_items)
                ),
                "max_length": (
                    getattr(outer_type, "max_length", model_field.field_info.max_length)
                    or getattr(outer_type, "max_items", model_field.field_info.max_items)
                ),
                "pattern": getattr(outer_type, "regex", model_field.field_info.regex),
                "unique_items": getattr(outer_type, "unique_items", model_field.field_info.unique_items),
                "decimal_places": getattr(outer_type, "decimal_places", None),
                "max_digits": getattr(outer_type, "max_digits", None),
                "multiple_of": getattr(outer_type, "multiple_of", None),
                "upper_case": getattr(outer_type, "to_upper", None),
                "lower_case": getattr(outer_type, "to_lower", None),
                "item_type": getattr(outer_type, "item_type", None),
            },
        )

        # pydantic v1 has constraints set for these values, but we generate them using faker
        if unwrap_optional(annotation) in (
            AnyUrl,
            HttpUrl,
            KafkaDsn,
            PostgresDsn,
            RedisDsn,
            AmqpDsn,
            AnyHttpUrl,
        ):
            constraints = {}

        # A ``const`` field must equal its default, so model it as a Literal type
        # (only for defaults that Literal supports).
        if model_field.field_info.const and (
            default_value is None or isinstance(default_value, (int, bool, str, bytes))
        ):
            annotation = Literal[default_value]  # pyright: ignore # noqa: PGH003

        children: list[FieldMeta] = []

        # Refer #412.
        args = get_args(model_field.annotation)
        if is_optional(model_field.annotation) and len(args) == 2:  # noqa: PLR2004
            # Optional[X]: the sole child is the non-None member of the union.
            child_annotation = args[0] if args[0] is not NoneType else args[1]
            children.append(PydanticFieldMeta.from_type(child_annotation))
        elif model_field.key_field or model_field.sub_fields:
            # Mappings expose their key type as ``key_field``; other generics only
            # have ``sub_fields``.
            fields_to_iterate = (
                ([model_field.key_field, *model_field.sub_fields])
                if model_field.key_field is not None
                else model_field.sub_fields
            )
            type_args = tuple(
                (
                    sub_field.outer_type_
                    if isinstance(sub_field.annotation, DeferredType)
                    else unwrap_new_type(sub_field.annotation)
                )
                for sub_field in fields_to_iterate
            )
            type_arg_to_sub_field = dict(zip(type_args, fields_to_iterate))
            if get_origin(outer_type) in (tuple, Tuple) and get_args(outer_type)[-1] == Ellipsis:
                # pydantic removes ellipses from Tuples in sub_fields
                type_args += (...,)
            extended_type_args = CollectionExtender.extend_type_args(annotation, type_args, 1)
            children.extend(
                PydanticFieldMeta.from_model_field(
                    model_field=type_arg_to_sub_field[arg],
                    use_alias=use_alias,
                    random=random,
                )
                for arg in extended_type_args
            )

        return PydanticFieldMeta(
            name=name,
            random=random or DEFAULT_RANDOM,
            annotation=annotation,
            children=children or None,
            default=default_value,
            # Drop None-valued constraints; an empty dict collapses to None.
            constraints=cast("PydanticConstraints", {k: v for k, v in constraints.items() if v is not None}) or None,
        )
+
+ if not _IS_PYDANTIC_V1:
+
+ @classmethod
+ def get_constraints_metadata(cls, annotation: Any) -> Sequence[Any]:
+ metadata = []
+ for m in super().get_constraints_metadata(annotation):
+ if isinstance(m, FieldInfo):
+ metadata.extend(m.metadata)
+ else:
+ metadata.append(m)
+
+ return metadata
+
+
class ModelFactory(Generic[T], BaseFactory[T]):
    """Base factory for pydantic models"""

    # Optional name -> type mapping used to resolve forward references on v1 models.
    __forward_ref_resolution_type_mapping__: ClassVar[Mapping[str, type]] = {}
    __is_base_factory__ = True

    def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None:
        super().__init_subclass__(*args, **kwargs)

        # For pydantic v1 models, eagerly resolve forward references so that field
        # introspection works; still-unresolvable names are silently ignored.
        if (
            getattr(cls, "__model__", None)
            and _is_pydantic_v1_model(cls.__model__)
            and hasattr(cls.__model__, "update_forward_refs")
        ):
            with suppress(NameError):  # pragma: no cover
                cls.__model__.update_forward_refs(**cls.__forward_ref_resolution_type_mapping__)  # type: ignore[attr-defined]

    @classmethod
    def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]:
        """Determine whether the given value is supported by the factory.

        :param value: An arbitrary value.
        :returns: A typeguard
        """

        return _is_pydantic_v1_model(value) or _is_pydantic_v2_model(value)

    @classmethod
    def get_model_fields(cls) -> list["FieldMeta"]:
        """Retrieve a list of fields from the factory's model.

        :returns: A list of field MetaData instances.

        """
        # Cache per factory class: checking ``cls.__dict__`` (instead of getattr)
        # ensures subclasses do not inherit a parent factory's cached metadata.
        if "_fields_metadata" not in cls.__dict__:
            if _is_pydantic_v1_model(cls.__model__):
                cls._fields_metadata = [
                    PydanticFieldMeta.from_model_field(
                        field,
                        # Aliases are only needed when the model cannot be populated
                        # by field name.
                        use_alias=not cls.__model__.__config__.allow_population_by_field_name,  # type: ignore[attr-defined]
                        random=cls.__random__,
                    )
                    for field in cls.__model__.__fields__.values()
                ]
            else:
                # pydantic v2 path: introspect via model_fields / model_config.
                cls._fields_metadata = [
                    PydanticFieldMeta.from_field_info(
                        field_info=field_info,
                        field_name=field_name,
                        random=cls.__random__,
                        use_alias=not cls.__model__.model_config.get(  # pyright: ignore[reportGeneralTypeIssues]
                            "populate_by_name",
                            False,
                        ),
                    )
                    for field_name, field_info in cls.__model__.model_fields.items()  # pyright: ignore[reportGeneralTypeIssues]
                ]
        return cls._fields_metadata

    @classmethod
    def get_constrained_field_value(cls, annotation: Any, field_meta: FieldMeta) -> Any:
        """Generate a value honoring the field's constraints.

        Handles the pydantic-only ``json`` pseudo-constraint here (serializing the
        generated value to JSON); everything else is delegated to the base class.
        """
        constraints = cast(PydanticConstraints, field_meta.constraints)
        # ``pop`` mutates field_meta.constraints so the base implementation never
        # sees the pydantic-specific "json" key.
        if constraints.pop("json", None):
            value = cls.get_field_value(field_meta)
            return to_json(value)  # pyright: ignore[reportUnboundVariable]

        return super().get_constrained_field_value(annotation, field_meta)

    @classmethod
    def build(
        cls,
        factory_use_construct: bool = False,
        **kwargs: Any,
    ) -> T:
        """Build an instance of the factory's __model__

        :param factory_use_construct: A boolean that determines whether validations will be made when instantiating the
            model. This is supported only for pydantic models.
        :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used.

        :returns: An instance of type T.

        """
        processed_kwargs = cls.process_kwargs(**kwargs)

        # construct/model_construct bypass validation for v1/v2 respectively.
        if factory_use_construct:
            if _is_pydantic_v1_model(cls.__model__):
                return cls.__model__.construct(**processed_kwargs)  # type: ignore[return-value]
            return cls.__model__.model_construct(**processed_kwargs)  # type: ignore[return-value]

        return cls.__model__(**processed_kwargs)  # type: ignore[return-value]

    @classmethod
    def is_custom_root_field(cls, field_meta: FieldMeta) -> bool:
        """Determine whether the field is a custom root field.

        :param field_meta: FieldMeta instance.

        :returns: A boolean determining whether the field is a custom root.

        """
        return field_meta.name == "__root__"

    @classmethod
    def should_set_field_value(cls, field_meta: FieldMeta, **kwargs: Any) -> bool:
        """Determine whether to set a value for a given field_name.
        This is an override of BaseFactory.should_set_field_value.

        :param field_meta: FieldMeta instance.
        :param kwargs: Any kwargs passed to the factory.

        :returns: A boolean determining whether a value should be set for the given field_meta.

        """
        # Skip fields supplied via kwargs and private fields — except the v1
        # custom-root field "__root__", which must still be generated.
        return field_meta.name not in kwargs and (
            not field_meta.name.startswith("_") or cls.is_custom_root_field(field_meta)
        )

    @classmethod
    def get_provider_map(cls) -> dict[Any, Callable[[], Any]]:
        """Map pydantic-specific types to value generators, merged over the base map."""
        mapping = {
            pydantic.ByteSize: cls.__faker__.pyint,
            pydantic.PositiveInt: cls.__faker__.pyint,
            pydantic.NegativeFloat: lambda: cls.__random__.uniform(-100, -1),
            pydantic.NegativeInt: lambda: cls.__faker__.pyint() * -1,
            # NOTE(review): pyint yields ints for the float types below; pydantic
            # coerces int -> float on validation — confirm this is intended.
            pydantic.PositiveFloat: cls.__faker__.pyint,
            pydantic.NonPositiveFloat: lambda: cls.__random__.uniform(-100, 0),
            pydantic.NonNegativeInt: cls.__faker__.pyint,
            pydantic.StrictInt: cls.__faker__.pyint,
            pydantic.StrictBool: cls.__faker__.pybool,
            pydantic.StrictBytes: partial(create_random_bytes, cls.__random__),
            pydantic.StrictFloat: cls.__faker__.pyfloat,
            pydantic.StrictStr: cls.__faker__.pystr,
            pydantic.EmailStr: cls.__faker__.free_email,
            pydantic.NameEmail: cls.__faker__.free_email,
            pydantic.Json: cls.__faker__.json,
            pydantic.PaymentCardNumber: cls.__faker__.credit_card_number,
            pydantic.AnyUrl: cls.__faker__.url,
            pydantic.AnyHttpUrl: cls.__faker__.url,
            pydantic.HttpUrl: cls.__faker__.url,
            pydantic.SecretBytes: partial(create_random_bytes, cls.__random__),
            pydantic.SecretStr: cls.__faker__.pystr,
            pydantic.IPvAnyAddress: cls.__faker__.ipv4,
            pydantic.IPvAnyInterface: cls.__faker__.ipv4,
            pydantic.IPvAnyNetwork: lambda: cls.__faker__.ipv4(network=True),
            pydantic.PastDate: cls.__faker__.past_date,
            pydantic.FutureDate: cls.__faker__.future_date,
        }

        # v1 only values
        mapping.update(
            {
                PyObject: lambda: "decimal.Decimal",
                AmqpDsn: lambda: "amqps://example.com",
                KafkaDsn: lambda: "kafka://localhost:9092",
                PostgresDsn: lambda: "postgresql://user:secret@localhost",
                RedisDsn: lambda: "redis://localhost:6379/0",
                # Use this module's own file/directory as guaranteed-existing paths.
                FilePath: lambda: Path(realpath(__file__)),
                DirectoryPath: lambda: Path(realpath(__file__)).parent,
                UUID1: uuid1,
                UUID3: lambda: uuid3(NAMESPACE_DNS, cls.__faker__.pystr()),
                UUID4: cls.__faker__.uuid4,
                UUID5: lambda: uuid5(NAMESPACE_DNS, cls.__faker__.pystr()),
                Color: cls.__faker__.hex_color,  # pyright: ignore[reportGeneralTypeIssues]
            },
        )

        if not _IS_PYDANTIC_V1:
            mapping.update(
                {
                    # pydantic v2 specific types
                    pydantic.PastDatetime: cls.__faker__.past_datetime,
                    pydantic.FutureDatetime: cls.__faker__.future_datetime,
                    pydantic.AwareDatetime: partial(cls.__faker__.date_time, timezone.utc),
                    pydantic.NaiveDatetime: cls.__faker__.date_time,
                },
            )

        # Base providers are merged last and win on key collisions.
        mapping.update(super().get_provider_map())
        return mapping
+
+
def _is_pydantic_v1_model(model: Any) -> TypeGuard[BaseModelV1]:
    """Typeguard: ``model`` is a pydantic v1 ``BaseModel`` subclass."""
    return is_safe_subclass(model, BaseModelV1)
+
+
def _is_pydantic_v2_model(model: Any) -> TypeGuard[BaseModelV2]:
    """Typeguard: pydantic v2 is installed and ``model`` is a v2 ``BaseModel`` subclass."""
    return not _IS_PYDANTIC_V1 and is_safe_subclass(model, BaseModelV2)
diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/sqlalchemy_factory.py b/venv/lib/python3.11/site-packages/polyfactory/factories/sqlalchemy_factory.py
new file mode 100644
index 0000000..ad8873f
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/polyfactory/factories/sqlalchemy_factory.py
@@ -0,0 +1,186 @@
+from __future__ import annotations
+
+from datetime import date, datetime
+from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, List, TypeVar, Union
+
+from polyfactory.exceptions import MissingDependencyException
+from polyfactory.factories.base import BaseFactory
+from polyfactory.field_meta import FieldMeta
+from polyfactory.persistence import AsyncPersistenceProtocol, SyncPersistenceProtocol
+
+try:
+ from sqlalchemy import Column, inspect, types
+ from sqlalchemy.dialects import mysql, postgresql
+ from sqlalchemy.exc import NoInspectionAvailable
+ from sqlalchemy.orm import InstanceState, Mapper
+except ImportError as e:
+ msg = "sqlalchemy is not installed"
+ raise MissingDependencyException(msg) from e
+
+if TYPE_CHECKING:
+ from sqlalchemy.ext.asyncio import AsyncSession
+ from sqlalchemy.orm import Session
+ from typing_extensions import TypeGuard
+
+
+T = TypeVar("T")
+
+
class SQLASyncPersistence(SyncPersistenceProtocol[T]):
    """Synchronous persistence backed by a SQLAlchemy ``Session``."""

    def __init__(self, session: Session) -> None:
        """Sync persistence handler for SQLAFactory."""
        self.session = session

    def save(self, data: T) -> T:
        """Persist a single instance: add it to the session and commit."""
        self.session.add(data)
        self.session.commit()
        return data

    def save_many(self, data: list[T]) -> list[T]:
        """Persist multiple instances in a single commit."""
        self.session.add_all(data)
        self.session.commit()
        return data
+
+
class SQLAASyncPersistence(AsyncPersistenceProtocol[T]):
    """Asynchronous persistence backed by a SQLAlchemy ``AsyncSession``."""

    def __init__(self, session: AsyncSession) -> None:
        """Async persistence handler for SQLAFactory."""
        self.session = session

    async def save(self, data: T) -> T:
        """Persist a single instance: add it to the session and commit."""
        self.session.add(data)
        await self.session.commit()
        return data

    async def save_many(self, data: list[T]) -> list[T]:
        """Persist multiple instances in a single commit."""
        self.session.add_all(data)
        await self.session.commit()
        return data
+
+
class SQLAlchemyFactory(Generic[T], BaseFactory[T]):
    """Base factory for SQLAlchemy models."""

    __is_base_factory__ = True

    __set_primary_key__: ClassVar[bool] = True
    """Configuration to consider primary key columns as a field or not."""
    __set_foreign_keys__: ClassVar[bool] = True
    """Configuration to consider columns with foreign keys as a field or not."""
    __set_relationships__: ClassVar[bool] = False
    """Configuration to consider relationships property as a model field or not."""

    # Sessions (or zero-arg session factories) used by the persistence handlers below.
    __session__: ClassVar[Session | Callable[[], Session] | None] = None
    __async_session__: ClassVar[AsyncSession | Callable[[], AsyncSession] | None] = None

    __config_keys__ = (
        *BaseFactory.__config_keys__,
        "__set_primary_key__",
        "__set_foreign_keys__",
        "__set_relationships__",
    )

    @classmethod
    def get_sqlalchemy_types(cls) -> dict[Any, Callable[[], Any]]:
        """Return a mapping from SQLAlchemy column types to value generators."""
        # NOTE(review): CIDR is generated as a host address (network=False) while
        # INET gets a network (network=True); relative to the PostgreSQL type
        # semantics these look swapped — confirm intent before changing.
        return {
            types.TupleType: cls.__faker__.pytuple,
            mysql.YEAR: lambda: cls.__random__.randint(1901, 2155),
            postgresql.CIDR: lambda: cls.__faker__.ipv4(network=False),
            postgresql.DATERANGE: lambda: (cls.__faker__.past_date(), date.today()),  # noqa: DTZ011
            postgresql.INET: lambda: cls.__faker__.ipv4(network=True),
            postgresql.INT4RANGE: lambda: tuple(sorted([cls.__faker__.pyint(), cls.__faker__.pyint()])),
            postgresql.INT8RANGE: lambda: tuple(sorted([cls.__faker__.pyint(), cls.__faker__.pyint()])),
            postgresql.MACADDR: lambda: cls.__faker__.hexify(text="^^:^^:^^:^^:^^:^^", upper=True),
            postgresql.NUMRANGE: lambda: tuple(sorted([cls.__faker__.pyint(), cls.__faker__.pyint()])),
            postgresql.TSRANGE: lambda: (cls.__faker__.past_datetime(), datetime.now()),  # noqa: DTZ005
            postgresql.TSTZRANGE: lambda: (cls.__faker__.past_datetime(), datetime.now()),  # noqa: DTZ005
        }

    @classmethod
    def get_provider_map(cls) -> dict[Any, Callable[[], Any]]:
        """Merge the SQLAlchemy type providers over the base provider map."""
        providers_map = super().get_provider_map()
        providers_map.update(cls.get_sqlalchemy_types())
        return providers_map

    @classmethod
    def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]:
        """Typeguard: ``value`` is a mapped SQLAlchemy class or instance."""
        try:
            inspected = inspect(value)
        except NoInspectionAvailable:
            return False
        return isinstance(inspected, (Mapper, InstanceState))

    @classmethod
    def should_column_be_set(cls, column: Column) -> bool:
        """Decide whether a value should be generated for ``column``,
        honoring the primary-key/foreign-key configuration flags."""
        if not cls.__set_primary_key__ and column.primary_key:
            return False

        return bool(cls.__set_foreign_keys__ or not column.foreign_keys)

    @classmethod
    def get_type_from_column(cls, column: Column) -> type:
        """Derive a Python annotation for ``column``, wrapping in Optional when nullable."""
        column_type = type(column.type)
        if column_type in cls.get_sqlalchemy_types():
            # Dialect-specific types are generated directly via get_sqlalchemy_types.
            annotation = column_type
        elif issubclass(column_type, types.ARRAY):
            annotation = List[column.type.item_type.python_type]  # type: ignore[assignment,name-defined]
        else:
            # TypeDecorator-style columns expose the underlying type via ``impl``.
            annotation = (
                column.type.impl.python_type  # pyright: ignore[reportGeneralTypeIssues]
                if hasattr(column.type, "impl")
                else column.type.python_type
            )

        if column.nullable:
            annotation = Union[annotation, None]  # type: ignore[assignment]

        return annotation

    @classmethod
    def get_model_fields(cls) -> list[FieldMeta]:
        """Retrieve a list of fields from the mapped columns (and, when enabled,
        relationships) of the factory's model."""
        fields_meta: list[FieldMeta] = []

        table: Mapper = inspect(cls.__model__)  # type: ignore[assignment]
        fields_meta.extend(
            FieldMeta.from_type(
                annotation=cls.get_type_from_column(column),
                name=name,
                random=cls.__random__,
            )
            for name, column in table.columns.items()
            if cls.should_column_be_set(column)
        )
        if cls.__set_relationships__:
            for name, relationship in table.relationships.items():
                class_ = relationship.entity.class_
                # One-to-many/many-to-many relationships are modeled as lists.
                annotation = class_ if not relationship.uselist else List[class_]  # type: ignore[valid-type]
                fields_meta.append(
                    FieldMeta.from_type(
                        name=name,
                        annotation=annotation,
                        random=cls.__random__,
                    ),
                )

        return fields_meta

    @classmethod
    def _get_sync_persistence(cls) -> SyncPersistenceProtocol[T]:
        """Return a sync persistence handler from ``__session__`` (calling it when
        it is a factory), or fall back to the base implementation."""
        if cls.__session__ is not None:
            return (
                SQLASyncPersistence(cls.__session__())
                if callable(cls.__session__)
                else SQLASyncPersistence(cls.__session__)
            )
        return super()._get_sync_persistence()

    @classmethod
    def _get_async_persistence(cls) -> AsyncPersistenceProtocol[T]:
        """Return an async persistence handler from ``__async_session__`` (calling
        it when it is a factory), or fall back to the base implementation."""
        if cls.__async_session__ is not None:
            return (
                SQLAASyncPersistence(cls.__async_session__())
                if callable(cls.__async_session__)
                else SQLAASyncPersistence(cls.__async_session__)
            )
        return super()._get_async_persistence()
diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/typed_dict_factory.py b/venv/lib/python3.11/site-packages/polyfactory/factories/typed_dict_factory.py
new file mode 100644
index 0000000..2a3ea1b
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/polyfactory/factories/typed_dict_factory.py
@@ -0,0 +1,61 @@
+from __future__ import annotations
+
+from typing import Any, Generic, TypeVar, get_args
+
+from typing_extensions import ( # type: ignore[attr-defined]
+ NotRequired,
+ Required,
+ TypeGuard,
+ _TypedDictMeta, # pyright: ignore[reportGeneralTypeIssues]
+ get_origin,
+ get_type_hints,
+ is_typeddict,
+)
+
+from polyfactory.constants import DEFAULT_RANDOM
+from polyfactory.factories.base import BaseFactory
+from polyfactory.field_meta import FieldMeta, Null
+
+TypedDictT = TypeVar("TypedDictT", bound=_TypedDictMeta)
+
+
class TypedDictFactory(Generic[TypedDictT], BaseFactory[TypedDictT]):
    """TypedDict base factory"""

    __is_base_factory__ = True

    @classmethod
    def is_supported_type(cls, value: Any) -> TypeGuard[type[TypedDictT]]:
        """Determine whether the given value is supported by the factory.

        :param value: An arbitrary value.
        :returns: A typeguard
        """
        return is_typeddict(value)

    @classmethod
    def get_model_fields(cls) -> list["FieldMeta"]:
        """Retrieve a list of fields from the factory's model.

        :returns: A list of field MetaData instances.

        """
        model_type_hints = get_type_hints(cls.__model__, include_extras=True)

        field_metas: list[FieldMeta] = []
        for field_name, annotation in model_type_hints.items():
            # Unwrap Required[...]/NotRequired[...] markers down to the inner type.
            origin = get_origin(annotation)
            if origin in (Required, NotRequired):
                annotation = get_args(annotation)[0]  # noqa: PLW2901

            field_metas.append(
                FieldMeta.from_type(
                    annotation=annotation,
                    # Use the factory's configured RNG (not the module-level default)
                    # so seeding a factory also makes its TypedDict fields
                    # reproducible — consistent with the other factories.
                    random=cls.__random__,
                    name=field_name,
                    default=getattr(cls.__model__, field_name, Null),
                ),
            )

        return field_metas