diff options
author | cyfraeviolae <cyfraeviolae> | 2024-04-03 03:17:55 -0400 |
---|---|---|
committer | cyfraeviolae <cyfraeviolae> | 2024-04-03 03:17:55 -0400 |
commit | 12cf076118570eebbff08c6b3090e0d4798447a1 (patch) | |
tree | 3ba25e17e3c3a5e82316558ba3864b955919ff72 /venv/lib/python3.11/site-packages/polyfactory/factories | |
parent | c45662ff3923b34614ddcc8feb9195541166dcc5 (diff) |
no venv
Diffstat (limited to 'venv/lib/python3.11/site-packages/polyfactory/factories')
20 files changed, 0 insertions, 2292 deletions
diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__init__.py b/venv/lib/python3.11/site-packages/polyfactory/factories/__init__.py deleted file mode 100644 index c8a9b92..0000000 --- a/venv/lib/python3.11/site-packages/polyfactory/factories/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from polyfactory.factories.base import BaseFactory -from polyfactory.factories.dataclass_factory import DataclassFactory -from polyfactory.factories.typed_dict_factory import TypedDictFactory - -__all__ = ("BaseFactory", "TypedDictFactory", "DataclassFactory") diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/__init__.cpython-311.pyc Binary files differdeleted file mode 100644 index 0ebad18..0000000 --- a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/__init__.cpython-311.pyc +++ /dev/null diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/attrs_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/attrs_factory.cpython-311.pyc Binary files differdeleted file mode 100644 index b5ca945..0000000 --- a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/attrs_factory.cpython-311.pyc +++ /dev/null diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/base.cpython-311.pyc Binary files differdeleted file mode 100644 index 2ee4cdb..0000000 --- a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/base.cpython-311.pyc +++ /dev/null diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/beanie_odm_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/beanie_odm_factory.cpython-311.pyc Binary files differdeleted file mode 100644 index 0fc27ab..0000000 --- a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/beanie_odm_factory.cpython-311.pyc +++ /dev/null diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/dataclass_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/dataclass_factory.cpython-311.pyc Binary files differdeleted file mode 100644 index 6e59c83..0000000 --- a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/dataclass_factory.cpython-311.pyc +++ /dev/null diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/msgspec_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/msgspec_factory.cpython-311.pyc Binary files differdeleted file mode 100644 index 5e528d3..0000000 --- a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/msgspec_factory.cpython-311.pyc +++ /dev/null diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/odmantic_odm_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/odmantic_odm_factory.cpython-311.pyc Binary files differdeleted file mode 100644 index 6f45693..0000000 --- a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/odmantic_odm_factory.cpython-311.pyc +++ /dev/null diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/pydantic_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/pydantic_factory.cpython-311.pyc Binary files differdeleted file mode 100644 index ca70ca8..0000000 --- a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/pydantic_factory.cpython-311.pyc +++ /dev/null diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/sqlalchemy_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/sqlalchemy_factory.cpython-311.pyc Binary files differdeleted file mode 100644 index 31d93d5..0000000 --- a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/sqlalchemy_factory.cpython-311.pyc +++ /dev/null diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/typed_dict_factory.cpython-311.pyc b/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/typed_dict_factory.cpython-311.pyc Binary files differdeleted file mode 100644 index a1ec583..0000000 --- a/venv/lib/python3.11/site-packages/polyfactory/factories/__pycache__/typed_dict_factory.cpython-311.pyc +++ /dev/null diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/attrs_factory.py b/venv/lib/python3.11/site-packages/polyfactory/factories/attrs_factory.py deleted file mode 100644 index 00ffa03..0000000 --- a/venv/lib/python3.11/site-packages/polyfactory/factories/attrs_factory.py +++ /dev/null @@ -1,82 +0,0 @@ -from __future__ import annotations - -from inspect import isclass -from typing import TYPE_CHECKING, Generic, TypeVar - -from polyfactory.exceptions import MissingDependencyException -from polyfactory.factories.base import BaseFactory -from polyfactory.field_meta import FieldMeta, Null - -if TYPE_CHECKING: - from typing import Any, TypeGuard - - -try: - import attrs - from attr._make import Factory - from attrs import AttrsInstance -except ImportError as ex: - msg = "attrs is not installed" - raise MissingDependencyException(msg) from ex - - -T = TypeVar("T", bound=AttrsInstance) - - -class AttrsFactory(Generic[T], BaseFactory[T]): - """Base factory for attrs classes.""" - - __model__: type[T] - - __is_base_factory__ = True - - @classmethod - def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]: - return isclass(value) and hasattr(value, "__attrs_attrs__") - - @classmethod - def get_model_fields(cls) -> list[FieldMeta]: - field_metas: list[FieldMeta] = [] - none_type = type(None) - - cls.resolve_types(cls.__model__) - fields = attrs.fields(cls.__model__) - - for field in fields: - if not field.init: - continue - - annotation = none_type if field.type is None else field.type - - default = field.default - if isinstance(default, Factory): - # The default value is not currently being used when generating - # the field values. When that is implemented, this would need - # to be handled differently since the `default.factory` could - # take a `self` argument. - default_value = default.factory - elif default is None: - default_value = Null - else: - default_value = default - - field_metas.append( - FieldMeta.from_type( - annotation=annotation, - name=field.alias, - default=default_value, - random=cls.__random__, - ), - ) - - return field_metas - - @classmethod - def resolve_types(cls, model: type[T], **kwargs: Any) -> None: - """Resolve any strings and forward annotations in type annotations. - - :param model: The model to resolve the type annotations for. - :param kwargs: Any parameters that need to be passed to `attrs.resolve_types`. - """ - - attrs.resolve_types(model, **kwargs) diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/base.py b/venv/lib/python3.11/site-packages/polyfactory/factories/base.py deleted file mode 100644 index 60fe7a7..0000000 --- a/venv/lib/python3.11/site-packages/polyfactory/factories/base.py +++ /dev/null @@ -1,1127 +0,0 @@ -from __future__ import annotations - -import copy -from abc import ABC, abstractmethod -from collections import Counter, abc, deque -from contextlib import suppress -from datetime import date, datetime, time, timedelta -from decimal import Decimal -from enum import EnumMeta -from functools import partial -from importlib import import_module -from ipaddress import ( - IPv4Address, - IPv4Interface, - IPv4Network, - IPv6Address, - IPv6Interface, - IPv6Network, - ip_address, - ip_interface, - ip_network, -) -from os.path import realpath -from pathlib import Path -from random import Random -from typing import ( - TYPE_CHECKING, - Any, - Callable, - ClassVar, - Collection, - Generic, - Iterable, - Mapping, - Sequence, - Type, - TypedDict, - TypeVar, - cast, -) -from uuid import UUID - -from faker import Faker -from typing_extensions import get_args, get_origin, get_original_bases - -from polyfactory.constants import ( - DEFAULT_RANDOM, - MAX_COLLECTION_LENGTH, - MIN_COLLECTION_LENGTH, - RANDOMIZE_COLLECTION_LENGTH, -) -from polyfactory.exceptions import ConfigurationException, MissingBuildKwargException, ParameterException -from polyfactory.field_meta import Null -from polyfactory.fields import Fixture, Ignore, PostGenerated, Require, Use -from polyfactory.utils.helpers import ( - flatten_annotation, - get_collection_type, - unwrap_annotation, - unwrap_args, - unwrap_optional, -) -from polyfactory.utils.model_coverage import CoverageContainer, CoverageContainerCallable, resolve_kwargs_coverage -from polyfactory.utils.predicates import get_type_origin, is_any, is_literal, is_optional, is_safe_subclass, is_union -from polyfactory.utils.types import NoneType -from polyfactory.value_generators.complex_types import handle_collection_type, handle_collection_type_coverage -from polyfactory.value_generators.constrained_collections import ( - handle_constrained_collection, - handle_constrained_mapping, -) -from polyfactory.value_generators.constrained_dates import handle_constrained_date -from polyfactory.value_generators.constrained_numbers import ( - handle_constrained_decimal, - handle_constrained_float, - handle_constrained_int, -) -from polyfactory.value_generators.constrained_path import handle_constrained_path -from polyfactory.value_generators.constrained_strings import handle_constrained_string_or_bytes -from polyfactory.value_generators.constrained_url import handle_constrained_url -from polyfactory.value_generators.constrained_uuid import handle_constrained_uuid -from polyfactory.value_generators.primitives import create_random_boolean, create_random_bytes, create_random_string - -if TYPE_CHECKING: - from typing_extensions import TypeGuard - - from polyfactory.field_meta import Constraints, FieldMeta - from polyfactory.persistence import AsyncPersistenceProtocol, SyncPersistenceProtocol - - -T = TypeVar("T") -F = TypeVar("F", bound="BaseFactory[Any]") - - -class BuildContext(TypedDict): - seen_models: set[type] - - -def _get_build_context(build_context: BuildContext | None) -> BuildContext: - if build_context is None: - return {"seen_models": set()} - - return copy.deepcopy(build_context) - - -class BaseFactory(ABC, Generic[T]): - """Base Factory class - this class holds the main logic of the library""" - - # configuration attributes - __model__: type[T] - """ - The model for the factory. - This attribute is required for non-base factories and an exception will be raised if it's not set. Can be automatically inferred from the factory generic argument. - """ - __check_model__: bool = False - """ - Flag dictating whether to check if fields defined on the factory exists on the model or not. - If 'True', checks will be done against Use, PostGenerated, Ignore, Require constructs fields only. - """ - __allow_none_optionals__: ClassVar[bool] = True - """ - Flag dictating whether to allow 'None' for optional values. - If 'True', 'None' will be randomly generated as a value for optional model fields - """ - __sync_persistence__: type[SyncPersistenceProtocol[T]] | SyncPersistenceProtocol[T] | None = None - """A sync persistence handler. Can be a class or a class instance.""" - __async_persistence__: type[AsyncPersistenceProtocol[T]] | AsyncPersistenceProtocol[T] | None = None - """An async persistence handler. Can be a class or a class instance.""" - __set_as_default_factory_for_type__ = False - """ - Flag dictating whether to set as the default factory for the given type. - If 'True' the factory will be used instead of dynamically generating a factory for the type. - """ - __is_base_factory__: bool = False - """ - Flag dictating whether the factory is a 'base' factory. Base factories are registered globally as handlers for types. - For example, the 'DataclassFactory', 'TypedDictFactory' and 'ModelFactory' are all base factories. - """ - __base_factory_overrides__: dict[Any, type[BaseFactory[Any]]] | None = None - """ - A base factory to override with this factory. If this value is set, the given factory will replace the given base factory. - - Note: this value can only be set when '__is_base_factory__' is 'True'. - """ - __faker__: ClassVar["Faker"] = Faker() - """ - A faker instance to use. Can be a user provided value. - """ - __random__: ClassVar["Random"] = DEFAULT_RANDOM - """ - An instance of 'random.Random' to use. - """ - __random_seed__: ClassVar[int] - """ - An integer to seed the factory's Faker and Random instances with. - This attribute can be used to control random generation. - """ - __randomize_collection_length__: ClassVar[bool] = RANDOMIZE_COLLECTION_LENGTH - """ - Flag dictating whether to randomize collections lengths. - """ - __min_collection_length__: ClassVar[int] = MIN_COLLECTION_LENGTH - """ - An integer value that defines minimum length of a collection. - """ - __max_collection_length__: ClassVar[int] = MAX_COLLECTION_LENGTH - """ - An integer value that defines maximum length of a collection. - """ - __use_defaults__: ClassVar[bool] = False - """ - Flag indicating whether to use the default value on a specific field, if provided. - """ - - __config_keys__: tuple[str, ...] = ( - "__check_model__", - "__allow_none_optionals__", - "__set_as_default_factory_for_type__", - "__faker__", - "__random__", - "__randomize_collection_length__", - "__min_collection_length__", - "__max_collection_length__", - "__use_defaults__", - ) - """Keys to be considered as config values to pass on to dynamically created factories.""" - - # cached attributes - _fields_metadata: list[FieldMeta] - # BaseFactory only attributes - _factory_type_mapping: ClassVar[dict[Any, type[BaseFactory[Any]]]] - _base_factories: ClassVar[list[type[BaseFactory[Any]]]] - - # Non-public attributes - _extra_providers: dict[Any, Callable[[], Any]] | None = None - - def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None: # noqa: C901 - super().__init_subclass__(*args, **kwargs) - - if not hasattr(BaseFactory, "_base_factories"): - BaseFactory._base_factories = [] - - if not hasattr(BaseFactory, "_factory_type_mapping"): - BaseFactory._factory_type_mapping = {} - - if cls.__min_collection_length__ > cls.__max_collection_length__: - msg = "Minimum collection length shouldn't be greater than maximum collection length" - raise ConfigurationException( - msg, - ) - - if "__is_base_factory__" not in cls.__dict__ or not cls.__is_base_factory__: - model: type[T] | None = getattr(cls, "__model__", None) or cls._infer_model_type() - if not model: - msg = f"required configuration attribute '__model__' is not set on {cls.__name__}" - raise ConfigurationException( - msg, - ) - cls.__model__ = model - if not cls.is_supported_type(model): - for factory in BaseFactory._base_factories: - if factory.is_supported_type(model): - msg = f"{cls.__name__} does not support {model.__name__}, but this type is supported by the {factory.__name__} base factory class. To resolve this error, subclass the factory from {factory.__name__} instead of {cls.__name__}" - raise ConfigurationException( - msg, - ) - msg = f"Model type {model.__name__} is not supported. To support it, register an appropriate base factory and subclass it for your factory." - raise ConfigurationException( - msg, - ) - if cls.__check_model__: - cls._check_declared_fields_exist_in_model() - else: - BaseFactory._base_factories.append(cls) - - random_seed = getattr(cls, "__random_seed__", None) - if random_seed is not None: - cls.seed_random(random_seed) - - if cls.__set_as_default_factory_for_type__ and hasattr(cls, "__model__"): - BaseFactory._factory_type_mapping[cls.__model__] = cls - - @classmethod - def _infer_model_type(cls: type[F]) -> type[T] | None: - """Return model type inferred from class declaration. - class Foo(ModelFactory[MyModel]): # <<< MyModel - ... - - If more than one base class and/or generic arguments specified return None. - - :returns: Inferred model type or None - """ - - factory_bases: Iterable[type[BaseFactory[T]]] = ( - b for b in get_original_bases(cls) if get_origin(b) and issubclass(get_origin(b), BaseFactory) - ) - generic_args: Sequence[type[T]] = [ - arg for factory_base in factory_bases for arg in get_args(factory_base) if not isinstance(arg, TypeVar) - ] - if len(generic_args) != 1: - return None - - return generic_args[0] - - @classmethod - def _get_sync_persistence(cls) -> SyncPersistenceProtocol[T]: - """Return a SyncPersistenceHandler if defined for the factory, otherwise raises a ConfigurationException. - - :raises: ConfigurationException - :returns: SyncPersistenceHandler - """ - if cls.__sync_persistence__: - return cls.__sync_persistence__() if callable(cls.__sync_persistence__) else cls.__sync_persistence__ - msg = "A '__sync_persistence__' handler must be defined in the factory to use this method" - raise ConfigurationException( - msg, - ) - - @classmethod - def _get_async_persistence(cls) -> AsyncPersistenceProtocol[T]: - """Return a AsyncPersistenceHandler if defined for the factory, otherwise raises a ConfigurationException. - - :raises: ConfigurationException - :returns: AsyncPersistenceHandler - """ - if cls.__async_persistence__: - return cls.__async_persistence__() if callable(cls.__async_persistence__) else cls.__async_persistence__ - msg = "An '__async_persistence__' handler must be defined in the factory to use this method" - raise ConfigurationException( - msg, - ) - - @classmethod - def _handle_factory_field( - cls, - field_value: Any, - build_context: BuildContext, - field_build_parameters: Any | None = None, - ) -> Any: - """Handle a value defined on the factory class itself. - - :param field_value: A value defined as an attribute on the factory class. - :param field_build_parameters: Any build parameters passed to the factory as kwarg values. - - :returns: An arbitrary value correlating with the given field_meta value. - """ - if is_safe_subclass(field_value, BaseFactory): - if isinstance(field_build_parameters, Mapping): - return field_value.build(_build_context=build_context, **field_build_parameters) - - if isinstance(field_build_parameters, Sequence): - return [ - field_value.build(_build_context=build_context, **parameter) for parameter in field_build_parameters - ] - - return field_value.build(_build_context=build_context) - - if isinstance(field_value, Use): - return field_value.to_value() - - if isinstance(field_value, Fixture): - return field_value.to_value() - - return field_value() if callable(field_value) else field_value - - @classmethod - def _handle_factory_field_coverage( - cls, - field_value: Any, - field_build_parameters: Any | None = None, - build_context: BuildContext | None = None, - ) -> Any: - """Handle a value defined on the factory class itself. - - :param field_value: A value defined as an attribute on the factory class. - :param field_build_parameters: Any build parameters passed to the factory as kwarg values. - - :returns: An arbitrary value correlating with the given field_meta value. - """ - if is_safe_subclass(field_value, BaseFactory): - if isinstance(field_build_parameters, Mapping): - return CoverageContainer(field_value.coverage(_build_context=build_context, **field_build_parameters)) - - if isinstance(field_build_parameters, Sequence): - return [ - CoverageContainer(field_value.coverage(_build_context=build_context, **parameter)) - for parameter in field_build_parameters - ] - - return CoverageContainer(field_value.coverage()) - - if isinstance(field_value, Use): - return field_value.to_value() - - if isinstance(field_value, Fixture): - return CoverageContainerCallable(field_value.to_value) - - return CoverageContainerCallable(field_value) if callable(field_value) else field_value - - @classmethod - def _get_config(cls) -> dict[str, Any]: - return { - **{key: getattr(cls, key) for key in cls.__config_keys__}, - "_extra_providers": cls.get_provider_map(), - } - - @classmethod - def _get_or_create_factory(cls, model: type) -> type[BaseFactory[Any]]: - """Get a factory from registered factories or generate a factory dynamically. - - :param model: A model type. - :returns: A Factory sub-class. - - """ - if factory := BaseFactory._factory_type_mapping.get(model): - return factory - - config = cls._get_config() - - if cls.__base_factory_overrides__: - for model_ancestor in model.mro(): - if factory := cls.__base_factory_overrides__.get(model_ancestor): - return factory.create_factory(model, **config) - - for factory in reversed(BaseFactory._base_factories): - if factory.is_supported_type(model): - return factory.create_factory(model, **config) - - msg = f"unsupported model type {model.__name__}" - raise ParameterException(msg) # pragma: no cover - - # Public Methods - - @classmethod - def is_factory_type(cls, annotation: Any) -> bool: - """Determine whether a given field is annotated with a type that is supported by a base factory. - - :param annotation: A type annotation. - :returns: Boolean dictating whether the annotation is a factory type - """ - return any(factory.is_supported_type(annotation) for factory in BaseFactory._base_factories) - - @classmethod - def is_batch_factory_type(cls, annotation: Any) -> bool: - """Determine whether a given field is annotated with a sequence of supported factory types. - - :param annotation: A type annotation. - :returns: Boolean dictating whether the annotation is a batch factory type - """ - origin = get_type_origin(annotation) or annotation - if is_safe_subclass(origin, Sequence) and (args := unwrap_args(annotation, random=cls.__random__)): - return len(args) == 1 and BaseFactory.is_factory_type(annotation=args[0]) - return False - - @classmethod - def extract_field_build_parameters(cls, field_meta: FieldMeta, build_args: dict[str, Any]) -> Any: - """Extract from the build kwargs any build parameters passed for a given field meta - if it is a factory type. - - :param field_meta: A field meta instance. - :param build_args: Any kwargs passed to the factory. - :returns: Any values - """ - if build_arg := build_args.get(field_meta.name): - annotation = unwrap_optional(field_meta.annotation) - if ( - BaseFactory.is_factory_type(annotation=annotation) - and isinstance(build_arg, Mapping) - and not BaseFactory.is_factory_type(annotation=type(build_arg)) - ): - return build_args.pop(field_meta.name) - - if ( - BaseFactory.is_batch_factory_type(annotation=annotation) - and isinstance(build_arg, Sequence) - and not any(BaseFactory.is_factory_type(annotation=type(value)) for value in build_arg) - ): - return build_args.pop(field_meta.name) - return None - - @classmethod - @abstractmethod - def is_supported_type(cls, value: Any) -> "TypeGuard[type[T]]": # pragma: no cover - """Determine whether the given value is supported by the factory. - - :param value: An arbitrary value. - :returns: A typeguard - """ - raise NotImplementedError - - @classmethod - def seed_random(cls, seed: int) -> None: - """Seed faker and random with the given integer. - - :param seed: An integer to set as seed. - :returns: 'None' - - """ - cls.__random__ = Random(seed) - cls.__faker__.seed_instance(seed) - - @classmethod - def is_ignored_type(cls, value: Any) -> bool: - """Check whether a given value is an ignored type. - - :param value: An arbitrary value. - - :notes: - - This method is meant to be overwritten by extension factories and other subclasses - - :returns: A boolean determining whether the value should be ignored. - - """ - return value is None - - @classmethod - def get_provider_map(cls) -> dict[Any, Callable[[], Any]]: - """Map types to callables. - - :notes: - - This method is distinct to allow overriding. - - - :returns: a dictionary mapping types to callables. - - """ - - def _create_generic_fn() -> Callable: - """Return a generic lambda""" - return lambda *args: None - - return { - Any: lambda: None, - # primitives - object: object, - float: cls.__faker__.pyfloat, - int: cls.__faker__.pyint, - bool: cls.__faker__.pybool, - str: cls.__faker__.pystr, - bytes: partial(create_random_bytes, cls.__random__), - # built-in objects - dict: cls.__faker__.pydict, - tuple: cls.__faker__.pytuple, - list: cls.__faker__.pylist, - set: cls.__faker__.pyset, - frozenset: lambda: frozenset(cls.__faker__.pylist()), - deque: lambda: deque(cls.__faker__.pylist()), - # standard library objects - Path: lambda: Path(realpath(__file__)), - Decimal: cls.__faker__.pydecimal, - UUID: lambda: UUID(cls.__faker__.uuid4()), - # datetime - datetime: cls.__faker__.date_time_between, - date: cls.__faker__.date_this_decade, - time: cls.__faker__.time_object, - timedelta: cls.__faker__.time_delta, - # ip addresses - IPv4Address: lambda: ip_address(cls.__faker__.ipv4()), - IPv4Interface: lambda: ip_interface(cls.__faker__.ipv4()), - IPv4Network: lambda: ip_network(cls.__faker__.ipv4(network=True)), - IPv6Address: lambda: ip_address(cls.__faker__.ipv6()), - IPv6Interface: lambda: ip_interface(cls.__faker__.ipv6()), - IPv6Network: lambda: ip_network(cls.__faker__.ipv6(network=True)), - # types - Callable: _create_generic_fn, - abc.Callable: _create_generic_fn, - Counter: lambda: Counter(cls.__faker__.pystr()), - **(cls._extra_providers or {}), - } - - @classmethod - def create_factory( - cls: type[F], - model: type[T] | None = None, - bases: tuple[type[BaseFactory[Any]], ...] | None = None, - **kwargs: Any, - ) -> type[F]: - """Generate a factory for the given type dynamically. - - :param model: A type to model. Defaults to current factory __model__ if any. - Otherwise, raise an error - :param bases: Base classes to use when generating the new class. - :param kwargs: Any kwargs. - - :returns: A 'ModelFactory' subclass. - - """ - if model is None: - try: - model = cls.__model__ - except AttributeError as ex: - msg = "A 'model' argument is required when creating a new factory from a base one" - raise TypeError(msg) from ex - return cast( - "Type[F]", - type( - f"{model.__name__}Factory", # pyright: ignore[reportOptionalMemberAccess] - (*(bases or ()), cls), - {"__model__": model, **kwargs}, - ), - ) - - @classmethod - def get_constrained_field_value(cls, annotation: Any, field_meta: FieldMeta) -> Any: # noqa: C901, PLR0911, PLR0912 - try: - constraints = cast("Constraints", field_meta.constraints) - if is_safe_subclass(annotation, float): - return handle_constrained_float( - random=cls.__random__, - multiple_of=cast("Any", constraints.get("multiple_of")), - gt=cast("Any", constraints.get("gt")), - ge=cast("Any", constraints.get("ge")), - lt=cast("Any", constraints.get("lt")), - le=cast("Any", constraints.get("le")), - ) - - if is_safe_subclass(annotation, int): - return handle_constrained_int( - random=cls.__random__, - multiple_of=cast("Any", constraints.get("multiple_of")), - gt=cast("Any", constraints.get("gt")), - ge=cast("Any", constraints.get("ge")), - lt=cast("Any", constraints.get("lt")), - le=cast("Any", constraints.get("le")), - ) - - if is_safe_subclass(annotation, Decimal): - return handle_constrained_decimal( - random=cls.__random__, - decimal_places=cast("Any", constraints.get("decimal_places")), - max_digits=cast("Any", constraints.get("max_digits")), - multiple_of=cast("Any", constraints.get("multiple_of")), - gt=cast("Any", constraints.get("gt")), - ge=cast("Any", constraints.get("ge")), - lt=cast("Any", constraints.get("lt")), - le=cast("Any", constraints.get("le")), - ) - - if url_constraints := constraints.get("url"): - return handle_constrained_url(constraints=url_constraints) - - if is_safe_subclass(annotation, str) or is_safe_subclass(annotation, bytes): - return handle_constrained_string_or_bytes( - random=cls.__random__, - t_type=str if is_safe_subclass(annotation, str) else bytes, - lower_case=constraints.get("lower_case") or False, - upper_case=constraints.get("upper_case") or False, - min_length=constraints.get("min_length"), - max_length=constraints.get("max_length"), - pattern=constraints.get("pattern"), - ) - - try: - collection_type = get_collection_type(annotation) - except ValueError: - collection_type = None - if collection_type is not None: - if collection_type == dict: - return handle_constrained_mapping( - factory=cls, - field_meta=field_meta, - min_items=constraints.get("min_length"), - max_items=constraints.get("max_length"), - ) - return handle_constrained_collection( - collection_type=collection_type, # type: ignore[type-var] - factory=cls, - field_meta=field_meta.children[0] if field_meta.children else field_meta, - item_type=constraints.get("item_type"), - max_items=constraints.get("max_length"), - min_items=constraints.get("min_length"), - unique_items=constraints.get("unique_items", False), - ) - - if is_safe_subclass(annotation, date): - return handle_constrained_date( - faker=cls.__faker__, - ge=cast("Any", constraints.get("ge")), - gt=cast("Any", constraints.get("gt")), - le=cast("Any", constraints.get("le")), - lt=cast("Any", constraints.get("lt")), - tz=cast("Any", constraints.get("tz")), - ) - - if is_safe_subclass(annotation, UUID) and (uuid_version := constraints.get("uuid_version")): - return handle_constrained_uuid( - uuid_version=uuid_version, - faker=cls.__faker__, - ) - - if is_safe_subclass(annotation, Path) and (path_constraint := constraints.get("path_type")): - return handle_constrained_path(constraint=path_constraint, faker=cls.__faker__) - except TypeError as e: - raise ParameterException from e - - msg = f"received constraints for unsupported type {annotation}" - raise ParameterException(msg) - - @classmethod - def get_field_value( # noqa: C901, PLR0911, PLR0912 - cls, - field_meta: FieldMeta, - field_build_parameters: Any | None = None, - build_context: BuildContext | None = None, - ) -> Any: - """Return a field value on the subclass if existing, otherwise returns a mock value. - - :param field_meta: FieldMeta instance. - :param field_build_parameters: Any build parameters passed to the factory as kwarg values. - :param build_context: BuildContext data for current build. - - :returns: An arbitrary value. - - """ - build_context = _get_build_context(build_context) - if cls.is_ignored_type(field_meta.annotation): - return None - - if field_build_parameters is None and cls.should_set_none_value(field_meta=field_meta): - return None - - unwrapped_annotation = unwrap_annotation(field_meta.annotation, random=cls.__random__) - - if is_literal(annotation=unwrapped_annotation) and (literal_args := get_args(unwrapped_annotation)): - return cls.__random__.choice(literal_args) - - if isinstance(unwrapped_annotation, EnumMeta): - return cls.__random__.choice(list(unwrapped_annotation)) - - if field_meta.constraints: - return cls.get_constrained_field_value(annotation=unwrapped_annotation, field_meta=field_meta) - - if is_union(field_meta.annotation) and field_meta.children: - seen_models = build_context["seen_models"] - children = [child for child in field_meta.children if child.annotation not in seen_models] - - # `None` is removed from the children when creating FieldMeta so when `children` - # is empty, it must mean that the field meta is an optional type. - if children: - return cls.get_field_value(cls.__random__.choice(children), field_build_parameters, build_context) - - if BaseFactory.is_factory_type(annotation=unwrapped_annotation): - if not field_build_parameters and unwrapped_annotation in build_context["seen_models"]: - return None if is_optional(field_meta.annotation) else Null - - return cls._get_or_create_factory(model=unwrapped_annotation).build( - _build_context=build_context, - **(field_build_parameters if isinstance(field_build_parameters, Mapping) else {}), - ) - - if BaseFactory.is_batch_factory_type(annotation=unwrapped_annotation): - factory = cls._get_or_create_factory(model=field_meta.type_args[0]) - if isinstance(field_build_parameters, Sequence): - return [ - factory.build(_build_context=build_context, **field_parameters) - for field_parameters in field_build_parameters - ] - - if field_meta.type_args[0] in build_context["seen_models"]: - return [] - - if not cls.__randomize_collection_length__: - return [factory.build(_build_context=build_context)] - - batch_size = cls.__random__.randint(cls.__min_collection_length__, cls.__max_collection_length__) - return factory.batch(size=batch_size, _build_context=build_context) - - if (origin := get_type_origin(unwrapped_annotation)) and is_safe_subclass(origin, Collection): - if cls.__randomize_collection_length__: - collection_type = get_collection_type(unwrapped_annotation) - if collection_type != dict: - return handle_constrained_collection( - collection_type=collection_type, # type: ignore[type-var] - factory=cls, - item_type=Any, - field_meta=field_meta.children[0] if field_meta.children else field_meta, - min_items=cls.__min_collection_length__, - max_items=cls.__max_collection_length__, - ) - return handle_constrained_mapping( - factory=cls, - field_meta=field_meta, - min_items=cls.__min_collection_length__, - max_items=cls.__max_collection_length__, - ) - - return handle_collection_type(field_meta, origin, cls) - - if is_any(unwrapped_annotation) or isinstance(unwrapped_annotation, TypeVar): - return create_random_string(cls.__random__, min_length=1, max_length=10) - - if provider := cls.get_provider_map().get(unwrapped_annotation): - return provider() - - if callable(unwrapped_annotation): - # if value is a callable we can try to naively call it. - # this will work for callables that do not require any parameters passed - with suppress(Exception): - return unwrapped_annotation() - - msg = f"Unsupported type: {unwrapped_annotation!r}\n\nEither extend the providers map or add a factory function for this type." - raise ParameterException( - msg, - ) - - @classmethod - def get_field_value_coverage( # noqa: C901 - cls, - field_meta: FieldMeta, - field_build_parameters: Any | None = None, - build_context: BuildContext | None = None, - ) -> Iterable[Any]: - """Return a field value on the subclass if existing, otherwise returns a mock value. - - :param field_meta: FieldMeta instance. - :param field_build_parameters: Any build parameters passed to the factory as kwarg values. - :param build_context: BuildContext data for current build. - - :returns: An iterable of values. - - """ - if cls.is_ignored_type(field_meta.annotation): - return [None] - - for unwrapped_annotation in flatten_annotation(field_meta.annotation): - if unwrapped_annotation in (None, NoneType): - yield None - - elif is_literal(annotation=unwrapped_annotation) and (literal_args := get_args(unwrapped_annotation)): - yield CoverageContainer(literal_args) - - elif isinstance(unwrapped_annotation, EnumMeta): - yield CoverageContainer(list(unwrapped_annotation)) - - elif field_meta.constraints: - yield CoverageContainerCallable( - cls.get_constrained_field_value, - annotation=unwrapped_annotation, - field_meta=field_meta, - ) - - elif BaseFactory.is_factory_type(annotation=unwrapped_annotation): - yield CoverageContainer( - cls._get_or_create_factory(model=unwrapped_annotation).coverage( - _build_context=build_context, - **(field_build_parameters if isinstance(field_build_parameters, Mapping) else {}), - ), - ) - - elif (origin := get_type_origin(unwrapped_annotation)) and issubclass(origin, Collection): - yield handle_collection_type_coverage(field_meta, origin, cls) - - elif is_any(unwrapped_annotation) or isinstance(unwrapped_annotation, TypeVar): - yield create_random_string(cls.__random__, min_length=1, max_length=10) - - elif provider := cls.get_provider_map().get(unwrapped_annotation): - yield CoverageContainerCallable(provider) - - elif callable(unwrapped_annotation): - # if value is a callable we can try to naively call it. - # this will work for callables that do not require any parameters passed - yield CoverageContainerCallable(unwrapped_annotation) - else: - msg = f"Unsupported type: {unwrapped_annotation!r}\n\nEither extend the providers map or add a factory function for this type." - raise ParameterException( - msg, - ) - - @classmethod - def should_set_none_value(cls, field_meta: FieldMeta) -> bool: - """Determine whether a given model field_meta should be set to None. - - :param field_meta: Field metadata. - - :notes: - - This method is distinct to allow overriding. - - :returns: A boolean determining whether 'None' should be set for the given field_meta. - - """ - return ( - cls.__allow_none_optionals__ - and is_optional(field_meta.annotation) - and create_random_boolean(random=cls.__random__) - ) - - @classmethod - def should_use_default_value(cls, field_meta: FieldMeta) -> bool: - """Determine whether to use the default value for the given field. - - :param field_meta: FieldMeta instance. - - :notes: - - This method is distinct to allow overriding. - - :returns: A boolean determining whether the default value should be used for the given field_meta. - - """ - return cls.__use_defaults__ and field_meta.default is not Null - - @classmethod - def should_set_field_value(cls, field_meta: FieldMeta, **kwargs: Any) -> bool: - """Determine whether to set a value for a given field_name. - - :param field_meta: FieldMeta instance. - :param kwargs: Any kwargs passed to the factory. - - :notes: - - This method is distinct to allow overriding. - - :returns: A boolean determining whether a value should be set for the given field_meta. - - """ - return not field_meta.name.startswith("_") and field_meta.name not in kwargs - - @classmethod - @abstractmethod - def get_model_fields(cls) -> list[FieldMeta]: # pragma: no cover - """Retrieve a list of fields from the factory's model. - - - :returns: A list of field MetaData instances. - - """ - raise NotImplementedError - - @classmethod - def get_factory_fields(cls) -> list[tuple[str, Any]]: - """Retrieve a list of fields from the factory. - - Trying to be smart about what should be considered a field on the model, - ignoring dunder methods and some parent class attributes. - - :returns: A list of tuples made of field name and field definition - """ - factory_fields = cls.__dict__.items() - return [ - (field_name, field_value) - for field_name, field_value in factory_fields - if not (field_name.startswith("__") or field_name == "_abc_impl") - ] - - @classmethod - def _check_declared_fields_exist_in_model(cls) -> None: - model_fields_names = {field_meta.name for field_meta in cls.get_model_fields()} - factory_fields = cls.get_factory_fields() - - for field_name, field_value in factory_fields: - if field_name in model_fields_names: - continue - - error_message = ( - f"{field_name} is declared on the factory {cls.__name__}" - f" but it is not part of the model {cls.__model__.__name__}" - ) - if isinstance(field_value, (Use, PostGenerated, Ignore, Require)): - raise ConfigurationException(error_message) - - @classmethod - def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]: - """Process the given kwargs and generate values for the factory's model. - - :param kwargs: Any build kwargs. - - :returns: A dictionary of build results. - - """ - _build_context = _get_build_context(kwargs.pop("_build_context", None)) - _build_context["seen_models"].add(cls.__model__) - - result: dict[str, Any] = {**kwargs} - generate_post: dict[str, PostGenerated] = {} - - for field_meta in cls.get_model_fields(): - field_build_parameters = cls.extract_field_build_parameters(field_meta=field_meta, build_args=kwargs) - if cls.should_set_field_value(field_meta, **kwargs) and not cls.should_use_default_value(field_meta): - if hasattr(cls, field_meta.name) and not hasattr(BaseFactory, field_meta.name): - field_value = getattr(cls, field_meta.name) - if isinstance(field_value, Ignore): - continue - - if isinstance(field_value, Require) and field_meta.name not in kwargs: - msg = f"Require kwarg {field_meta.name} is missing" - raise MissingBuildKwargException(msg) - - if isinstance(field_value, PostGenerated): - generate_post[field_meta.name] = field_value - continue - - result[field_meta.name] = cls._handle_factory_field( - field_value=field_value, - field_build_parameters=field_build_parameters, - build_context=_build_context, - ) - continue - - field_result = cls.get_field_value( - field_meta, - field_build_parameters=field_build_parameters, - build_context=_build_context, - ) - if field_result is Null: - continue - - result[field_meta.name] = field_result - - for field_name, post_generator in generate_post.items(): - result[field_name] = post_generator.to_value(field_name, result) - - return result - - @classmethod - def process_kwargs_coverage(cls, **kwargs: Any) -> abc.Iterable[dict[str, Any]]: - """Process the given kwargs and generate values for the factory's model. - - :param kwargs: Any build kwargs. - :param build_context: BuildContext data for current build. - - :returns: A dictionary of build results. - - """ - _build_context = _get_build_context(kwargs.pop("_build_context", None)) - _build_context["seen_models"].add(cls.__model__) - - result: dict[str, Any] = {**kwargs} - generate_post: dict[str, PostGenerated] = {} - - for field_meta in cls.get_model_fields(): - field_build_parameters = cls.extract_field_build_parameters(field_meta=field_meta, build_args=kwargs) - - if cls.should_set_field_value(field_meta, **kwargs): - if hasattr(cls, field_meta.name) and not hasattr(BaseFactory, field_meta.name): - field_value = getattr(cls, field_meta.name) - if isinstance(field_value, Ignore): - continue - - if isinstance(field_value, Require) and field_meta.name not in kwargs: - msg = f"Require kwarg {field_meta.name} is missing" - raise MissingBuildKwargException(msg) - - if isinstance(field_value, PostGenerated): - generate_post[field_meta.name] = field_value - continue - - result[field_meta.name] = cls._handle_factory_field_coverage( - field_value=field_value, - field_build_parameters=field_build_parameters, - build_context=_build_context, - ) - continue - - result[field_meta.name] = CoverageContainer( - cls.get_field_value_coverage( - field_meta, - field_build_parameters=field_build_parameters, - build_context=_build_context, - ), - ) - - for resolved in resolve_kwargs_coverage(result): - for field_name, post_generator in generate_post.items(): - resolved[field_name] = post_generator.to_value(field_name, resolved) - yield resolved - - @classmethod - def build(cls, **kwargs: Any) -> T: - """Build an instance of the factory's __model__ - - :param kwargs: Any kwargs. If field names are set in kwargs, their values will be used. - - :returns: An instance of type T. - - """ - return cast("T", cls.__model__(**cls.process_kwargs(**kwargs))) - - @classmethod - def batch(cls, size: int, **kwargs: Any) -> list[T]: - """Build a batch of size n of the factory's Meta.model. - - :param size: Size of the batch. - :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used. - - :returns: A list of instances of type T. - - """ - return [cls.build(**kwargs) for _ in range(size)] - - @classmethod - def coverage(cls, **kwargs: Any) -> abc.Iterator[T]: - """Build a batch of the factory's Meta.model will full coverage of the sub-types of the model. - - :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used. - - :returns: A iterator of instances of type T. - - """ - for data in cls.process_kwargs_coverage(**kwargs): - instance = cls.__model__(**data) - yield cast("T", instance) - - @classmethod - def create_sync(cls, **kwargs: Any) -> T: - """Build and persists synchronously a single model instance. - - :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used. - - :returns: An instance of type T. - - """ - return cls._get_sync_persistence().save(data=cls.build(**kwargs)) - - @classmethod - def create_batch_sync(cls, size: int, **kwargs: Any) -> list[T]: - """Build and persists synchronously a batch of n size model instances. - - :param size: Size of the batch. - :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used. - - :returns: A list of instances of type T. - - """ - return cls._get_sync_persistence().save_many(data=cls.batch(size, **kwargs)) - - @classmethod - async def create_async(cls, **kwargs: Any) -> T: - """Build and persists asynchronously a single model instance. - - :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used. - - :returns: An instance of type T. - """ - return await cls._get_async_persistence().save(data=cls.build(**kwargs)) - - @classmethod - async def create_batch_async(cls, size: int, **kwargs: Any) -> list[T]: - """Build and persists asynchronously a batch of n size model instances. - - - :param size: Size of the batch. - :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used. - - :returns: A list of instances of type T. - """ - return await cls._get_async_persistence().save_many(data=cls.batch(size, **kwargs)) - - -def _register_builtin_factories() -> None: - """This function is used to register the base factories, if present. - - :returns: None - """ - import polyfactory.factories.dataclass_factory - import polyfactory.factories.typed_dict_factory # noqa: F401 - - for module in [ - "polyfactory.factories.pydantic_factory", - "polyfactory.factories.beanie_odm_factory", - "polyfactory.factories.odmantic_odm_factory", - "polyfactory.factories.msgspec_factory", - # `AttrsFactory` is not being registered by default since not all versions of `attrs` are supported. - # Issue: https://github.com/litestar-org/polyfactory/issues/356 - # "polyfactory.factories.attrs_factory", - ]: - try: - import_module(module) - except ImportError: # noqa: PERF203 - continue - - -_register_builtin_factories() diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/beanie_odm_factory.py b/venv/lib/python3.11/site-packages/polyfactory/factories/beanie_odm_factory.py deleted file mode 100644 index ddd3169..0000000 --- a/venv/lib/python3.11/site-packages/polyfactory/factories/beanie_odm_factory.py +++ /dev/null @@ -1,87 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Generic, TypeVar - -from typing_extensions import get_args - -from polyfactory.exceptions import MissingDependencyException -from polyfactory.factories.pydantic_factory import ModelFactory -from polyfactory.persistence import AsyncPersistenceProtocol -from polyfactory.utils.predicates import is_safe_subclass - -if TYPE_CHECKING: - from typing_extensions import TypeGuard - - from polyfactory.factories.base import BuildContext - from polyfactory.field_meta import FieldMeta - -try: - from beanie import Document -except ImportError as e: - msg = "beanie is not installed" - raise MissingDependencyException(msg) from e - -T = TypeVar("T", bound=Document) - - -class BeaniePersistenceHandler(Generic[T], AsyncPersistenceProtocol[T]): - """Persistence Handler using beanie logic""" - - async def save(self, data: T) -> T: - """Persist a single instance in mongoDB.""" - return await data.insert() # type: ignore[no-any-return] - - async def save_many(self, data: list[T]) -> list[T]: - """Persist multiple instances in mongoDB. - - .. note:: we cannot use the ``.insert_many`` method from Beanie here because it doesn't - return the created instances - """ - return [await doc.insert() for doc in data] # pyright: ignore[reportGeneralTypeIssues] - - -class BeanieDocumentFactory(Generic[T], ModelFactory[T]): - """Base factory for Beanie Documents""" - - __async_persistence__ = BeaniePersistenceHandler - __is_base_factory__ = True - - @classmethod - def is_supported_type(cls, value: Any) -> "TypeGuard[type[T]]": - """Determine whether the given value is supported by the factory. - - :param value: An arbitrary value. - :returns: A typeguard - """ - return is_safe_subclass(value, Document) - - @classmethod - def get_field_value( - cls, - field_meta: "FieldMeta", - field_build_parameters: Any | None = None, - build_context: BuildContext | None = None, - ) -> Any: - """Return a field value on the subclass if existing, otherwise returns a mock value. - - :param field_meta: FieldMeta instance. - :param field_build_parameters: Any build parameters passed to the factory as kwarg values. - - :returns: An arbitrary value. - - """ - if hasattr(field_meta.annotation, "__name__"): - if "Indexed " in field_meta.annotation.__name__: - base_type = field_meta.annotation.__bases__[0] - field_meta.annotation = base_type - - if "Link" in field_meta.annotation.__name__: - link_class = get_args(field_meta.annotation)[0] - field_meta.annotation = link_class - field_meta.annotation = link_class - - return super().get_field_value( - field_meta=field_meta, - field_build_parameters=field_build_parameters, - build_context=build_context, - ) diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/dataclass_factory.py b/venv/lib/python3.11/site-packages/polyfactory/factories/dataclass_factory.py deleted file mode 100644 index 01cfbe7..0000000 --- a/venv/lib/python3.11/site-packages/polyfactory/factories/dataclass_factory.py +++ /dev/null @@ -1,58 +0,0 @@ -from __future__ import annotations - -from dataclasses import MISSING, fields, is_dataclass -from typing import Any, Generic - -from typing_extensions import TypeGuard, get_type_hints - -from polyfactory.factories.base import BaseFactory, T -from polyfactory.field_meta import FieldMeta, Null - - -class DataclassFactory(Generic[T], BaseFactory[T]): - """Dataclass base factory""" - - __is_base_factory__ = True - - @classmethod - def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]: - """Determine whether the given value is supported by the factory. - - :param value: An arbitrary value. - :returns: A typeguard - """ - return bool(is_dataclass(value)) - - @classmethod - def get_model_fields(cls) -> list["FieldMeta"]: - """Retrieve a list of fields from the factory's model. - - - :returns: A list of field MetaData instances. - - """ - fields_meta: list["FieldMeta"] = [] - - model_type_hints = get_type_hints(cls.__model__, include_extras=True) - - for field in fields(cls.__model__): # type: ignore[arg-type] - if not field.init: - continue - - if field.default_factory and field.default_factory is not MISSING: - default_value = field.default_factory() - elif field.default is not MISSING: - default_value = field.default - else: - default_value = Null - - fields_meta.append( - FieldMeta.from_type( - annotation=model_type_hints[field.name], - name=field.name, - default=default_value, - random=cls.__random__, - ), - ) - - return fields_meta diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/msgspec_factory.py b/venv/lib/python3.11/site-packages/polyfactory/factories/msgspec_factory.py deleted file mode 100644 index 1b579ae..0000000 --- a/venv/lib/python3.11/site-packages/polyfactory/factories/msgspec_factory.py +++ /dev/null @@ -1,72 +0,0 @@ -from __future__ import annotations - -from inspect import isclass -from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar - -from typing_extensions import get_type_hints - -from polyfactory.exceptions import MissingDependencyException -from polyfactory.factories.base import BaseFactory -from polyfactory.field_meta import FieldMeta, Null -from polyfactory.value_generators.constrained_numbers import handle_constrained_int -from polyfactory.value_generators.primitives import create_random_bytes - -if TYPE_CHECKING: - from typing_extensions import TypeGuard - -try: - import msgspec - from msgspec.structs import fields -except ImportError as e: - msg = "msgspec is not installed" - raise MissingDependencyException(msg) from e - -T = TypeVar("T", bound=msgspec.Struct) - - -class MsgspecFactory(Generic[T], BaseFactory[T]): - """Base factory for msgspec Structs.""" - - __is_base_factory__ = True - - @classmethod - def get_provider_map(cls) -> dict[Any, Callable[[], Any]]: - def get_msgpack_ext() -> msgspec.msgpack.Ext: - code = handle_constrained_int(cls.__random__, ge=-128, le=127) - data = create_random_bytes(cls.__random__) - return msgspec.msgpack.Ext(code, data) - - msgspec_provider_map = {msgspec.UnsetType: lambda: msgspec.UNSET, msgspec.msgpack.Ext: get_msgpack_ext} - - provider_map = super().get_provider_map() - provider_map.update(msgspec_provider_map) - - return provider_map - - @classmethod - def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]: - return isclass(value) and hasattr(value, "__struct_fields__") - - @classmethod - def get_model_fields(cls) -> list[FieldMeta]: - fields_meta: list[FieldMeta] = [] - - type_hints = get_type_hints(cls.__model__, include_extras=True) - for field in fields(cls.__model__): - annotation = type_hints[field.name] - if field.default is not msgspec.NODEFAULT: - default_value = field.default - elif field.default_factory is not msgspec.NODEFAULT: - default_value = field.default_factory() - else: - default_value = Null - - fields_meta.append( - FieldMeta.from_type( - annotation=annotation, - name=field.name, - default=default_value, - random=cls.__random__, - ), - ) - return fields_meta diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/odmantic_odm_factory.py b/venv/lib/python3.11/site-packages/polyfactory/factories/odmantic_odm_factory.py deleted file mode 100644 index 1b3367a..0000000 --- a/venv/lib/python3.11/site-packages/polyfactory/factories/odmantic_odm_factory.py +++ /dev/null @@ -1,60 +0,0 @@ -from __future__ import annotations - -import decimal -from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union - -from polyfactory.exceptions import MissingDependencyException -from polyfactory.factories.pydantic_factory import ModelFactory -from polyfactory.utils.predicates import is_safe_subclass -from polyfactory.value_generators.primitives import create_random_bytes - -try: - from bson.decimal128 import Decimal128, create_decimal128_context - from odmantic import EmbeddedModel, Model - from odmantic import bson as odbson - -except ImportError as e: - msg = "odmantic is not installed" - raise MissingDependencyException(msg) from e - -T = TypeVar("T", bound=Union[Model, EmbeddedModel]) - -if TYPE_CHECKING: - from typing_extensions import TypeGuard - - -class OdmanticModelFactory(Generic[T], ModelFactory[T]): - """Base factory for odmantic models""" - - __is_base_factory__ = True - - @classmethod - def is_supported_type(cls, value: Any) -> "TypeGuard[type[T]]": - """Determine whether the given value is supported by the factory. - - :param value: An arbitrary value. - :returns: A typeguard - """ - return is_safe_subclass(value, (Model, EmbeddedModel)) - - @classmethod - def get_provider_map(cls) -> dict[Any, Callable[[], Any]]: - provider_map = super().get_provider_map() - provider_map.update( - { - odbson.Int64: lambda: odbson.Int64.validate(cls.__faker__.pyint()), - odbson.Decimal128: lambda: _to_decimal128(cls.__faker__.pydecimal()), - odbson.Binary: lambda: odbson.Binary.validate(create_random_bytes(cls.__random__)), - odbson._datetime: lambda: odbson._datetime.validate(cls.__faker__.date_time_between()), - # bson.Regex and bson._Pattern not supported as there is no way to generate - # a random regular expression with Faker - # bson.Regex: - # bson._Pattern: - }, - ) - return provider_map - - -def _to_decimal128(value: decimal.Decimal) -> Decimal128: - with decimal.localcontext(create_decimal128_context()) as ctx: - return Decimal128(ctx.create_decimal(value)) diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/pydantic_factory.py b/venv/lib/python3.11/site-packages/polyfactory/factories/pydantic_factory.py deleted file mode 100644 index a6028b1..0000000 --- a/venv/lib/python3.11/site-packages/polyfactory/factories/pydantic_factory.py +++ /dev/null @@ -1,554 +0,0 @@ -from __future__ import annotations - -from contextlib import suppress -from datetime import timezone -from functools import partial -from os.path import realpath -from pathlib import Path -from typing import TYPE_CHECKING, Any, ClassVar, ForwardRef, Generic, Mapping, Tuple, TypeVar, cast -from uuid import NAMESPACE_DNS, uuid1, uuid3, uuid5 - -from typing_extensions import Literal, get_args, get_origin - -from polyfactory.collection_extender import CollectionExtender -from polyfactory.constants import DEFAULT_RANDOM -from polyfactory.exceptions import MissingDependencyException -from polyfactory.factories.base import BaseFactory -from polyfactory.field_meta import Constraints, FieldMeta, Null -from polyfactory.utils.deprecation import check_for_deprecated_parameters -from polyfactory.utils.helpers import unwrap_new_type, unwrap_optional -from polyfactory.utils.predicates import is_optional, is_safe_subclass, is_union -from polyfactory.utils.types import NoneType -from polyfactory.value_generators.primitives import create_random_bytes - -try: - import pydantic - from pydantic import VERSION, Json - from pydantic.fields import FieldInfo -except ImportError as e: - msg = "pydantic is not installed" - raise MissingDependencyException(msg) from e - -try: - # pydantic v1 - from pydantic import ( # noqa: I001 - UUID1, - UUID3, - UUID4, - UUID5, - AmqpDsn, - AnyHttpUrl, - AnyUrl, - DirectoryPath, - FilePath, - HttpUrl, - KafkaDsn, - PostgresDsn, - RedisDsn, - ) - from pydantic import BaseModel as BaseModelV1 - from pydantic.color import Color - from pydantic.fields import ( # type: ignore[attr-defined] - DeferredType, # pyright: ignore[reportGeneralTypeIssues] - ModelField, # pyright: ignore[reportGeneralTypeIssues] - Undefined, # pyright: ignore[reportGeneralTypeIssues] - ) - - # Keep this import last to prevent warnings from pydantic if pydantic v2 - # is installed. - from pydantic import PyObject - - # prevent unbound variable warnings - BaseModelV2 = BaseModelV1 - UndefinedV2 = Undefined -except ImportError: - # pydantic v2 - - # v2 specific imports - from pydantic import BaseModel as BaseModelV2 - from pydantic_core import PydanticUndefined as UndefinedV2 - from pydantic_core import to_json - - from pydantic.v1 import ( # v1 compat imports - UUID1, - UUID3, - UUID4, - UUID5, - AmqpDsn, - AnyHttpUrl, - AnyUrl, - DirectoryPath, - FilePath, - HttpUrl, - KafkaDsn, - PostgresDsn, - PyObject, - RedisDsn, - ) - from pydantic.v1 import BaseModel as BaseModelV1 # type: ignore[assignment] - from pydantic.v1.color import Color # type: ignore[assignment] - from pydantic.v1.fields import DeferredType, ModelField, Undefined - - -if TYPE_CHECKING: - from random import Random - from typing import Callable, Sequence - - from typing_extensions import NotRequired, TypeGuard - -T = TypeVar("T", bound="BaseModelV1 | BaseModelV2") - -_IS_PYDANTIC_V1 = VERSION.startswith("1") - - -class PydanticConstraints(Constraints): - """Metadata regarding a Pydantic type constraints, if any""" - - json: NotRequired[bool] - - -class PydanticFieldMeta(FieldMeta): - """Field meta subclass capable of handling pydantic ModelFields""" - - def __init__( - self, - *, - name: str, - annotation: type, - random: Random | None = None, - default: Any = ..., - children: list[FieldMeta] | None = None, - constraints: PydanticConstraints | None = None, - ) -> None: - super().__init__( - name=name, - annotation=annotation, - random=random, - default=default, - children=children, - constraints=constraints, - ) - - @classmethod - def from_field_info( - cls, - field_name: str, - field_info: FieldInfo, - use_alias: bool, - random: Random | None, - randomize_collection_length: bool | None = None, - min_collection_length: int | None = None, - max_collection_length: int | None = None, - ) -> PydanticFieldMeta: - """Create an instance from a pydantic field info. - - :param field_name: The name of the field. - :param field_info: A pydantic FieldInfo instance. - :param use_alias: Whether to use the field alias. - :param random: A random.Random instance. - :param randomize_collection_length: Whether to randomize collection length. - :param min_collection_length: Minimum collection length. - :param max_collection_length: Maximum collection length. - - :returns: A PydanticFieldMeta instance. - """ - check_for_deprecated_parameters( - "2.11.0", - parameters=( - ("randomize_collection_length", randomize_collection_length), - ("min_collection_length", min_collection_length), - ("max_collection_length", max_collection_length), - ), - ) - if callable(field_info.default_factory): - default_value = field_info.default_factory() - else: - default_value = field_info.default if field_info.default is not UndefinedV2 else Null - - annotation = unwrap_new_type(field_info.annotation) - children: list[FieldMeta,] | None = None - name = field_info.alias if field_info.alias and use_alias else field_name - - constraints: PydanticConstraints - # pydantic v2 does not always propagate metadata for Union types - if is_union(annotation): - constraints = {} - children = [] - for arg in get_args(annotation): - if arg is NoneType: - continue - child_field_info = FieldInfo.from_annotation(arg) - merged_field_info = FieldInfo.merge_field_infos(field_info, child_field_info) - children.append( - cls.from_field_info( - field_name="", - field_info=merged_field_info, - use_alias=use_alias, - random=random, - ), - ) - else: - metadata, is_json = [], False - for m in field_info.metadata: - if not is_json and isinstance(m, Json): # type: ignore[misc] - is_json = True - elif m is not None: - metadata.append(m) - - constraints = cast( - PydanticConstraints, - cls.parse_constraints(metadata=metadata) if metadata else {}, - ) - - if "url" in constraints: - # pydantic uses a sentinel value for url constraints - annotation = str - - if is_json: - constraints["json"] = True - - return PydanticFieldMeta.from_type( - annotation=annotation, - children=children, - constraints=cast("Constraints", {k: v for k, v in constraints.items() if v is not None}) or None, - default=default_value, - name=name, - random=random or DEFAULT_RANDOM, - ) - - @classmethod - def from_model_field( # pragma: no cover - cls, - model_field: ModelField, # pyright: ignore[reportGeneralTypeIssues] - use_alias: bool, - randomize_collection_length: bool | None = None, - min_collection_length: int | None = None, - max_collection_length: int | None = None, - random: Random = DEFAULT_RANDOM, - ) -> PydanticFieldMeta: - """Create an instance from a pydantic model field. - :param model_field: A pydantic ModelField. - :param use_alias: Whether to use the field alias. - :param randomize_collection_length: A boolean flag whether to randomize collections lengths - :param min_collection_length: Minimum number of elements in randomized collection - :param max_collection_length: Maximum number of elements in randomized collection - :param random: An instance of random.Random. - - :returns: A PydanticFieldMeta instance. - - """ - check_for_deprecated_parameters( - "2.11.0", - parameters=( - ("randomize_collection_length", randomize_collection_length), - ("min_collection_length", min_collection_length), - ("max_collection_length", max_collection_length), - ), - ) - - if model_field.default is not Undefined: - default_value = model_field.default - elif callable(model_field.default_factory): - default_value = model_field.default_factory() - else: - default_value = model_field.default if model_field.default is not Undefined else Null - - name = model_field.alias if model_field.alias and use_alias else model_field.name - - outer_type = unwrap_new_type(model_field.outer_type_) - annotation = ( - model_field.outer_type_ - if isinstance(model_field.annotation, (DeferredType, ForwardRef)) - else unwrap_new_type(model_field.annotation) - ) - - constraints = cast( - "Constraints", - { - "ge": getattr(outer_type, "ge", model_field.field_info.ge), - "gt": getattr(outer_type, "gt", model_field.field_info.gt), - "le": getattr(outer_type, "le", model_field.field_info.le), - "lt": getattr(outer_type, "lt", model_field.field_info.lt), - "min_length": ( - getattr(outer_type, "min_length", model_field.field_info.min_length) - or getattr(outer_type, "min_items", model_field.field_info.min_items) - ), - "max_length": ( - getattr(outer_type, "max_length", model_field.field_info.max_length) - or getattr(outer_type, "max_items", model_field.field_info.max_items) - ), - "pattern": getattr(outer_type, "regex", model_field.field_info.regex), - "unique_items": getattr(outer_type, "unique_items", model_field.field_info.unique_items), - "decimal_places": getattr(outer_type, "decimal_places", None), - "max_digits": getattr(outer_type, "max_digits", None), - "multiple_of": getattr(outer_type, "multiple_of", None), - "upper_case": getattr(outer_type, "to_upper", None), - "lower_case": getattr(outer_type, "to_lower", None), - "item_type": getattr(outer_type, "item_type", None), - }, - ) - - # pydantic v1 has constraints set for these values, but we generate them using faker - if unwrap_optional(annotation) in ( - AnyUrl, - HttpUrl, - KafkaDsn, - PostgresDsn, - RedisDsn, - AmqpDsn, - AnyHttpUrl, - ): - constraints = {} - - if model_field.field_info.const and ( - default_value is None or isinstance(default_value, (int, bool, str, bytes)) - ): - annotation = Literal[default_value] # pyright: ignore # noqa: PGH003 - - children: list[FieldMeta] = [] - - # Refer #412. - args = get_args(model_field.annotation) - if is_optional(model_field.annotation) and len(args) == 2: # noqa: PLR2004 - child_annotation = args[0] if args[0] is not NoneType else args[1] - children.append(PydanticFieldMeta.from_type(child_annotation)) - elif model_field.key_field or model_field.sub_fields: - fields_to_iterate = ( - ([model_field.key_field, *model_field.sub_fields]) - if model_field.key_field is not None - else model_field.sub_fields - ) - type_args = tuple( - ( - sub_field.outer_type_ - if isinstance(sub_field.annotation, DeferredType) - else unwrap_new_type(sub_field.annotation) - ) - for sub_field in fields_to_iterate - ) - type_arg_to_sub_field = dict(zip(type_args, fields_to_iterate)) - if get_origin(outer_type) in (tuple, Tuple) and get_args(outer_type)[-1] == Ellipsis: - # pydantic removes ellipses from Tuples in sub_fields - type_args += (...,) - extended_type_args = CollectionExtender.extend_type_args(annotation, type_args, 1) - children.extend( - PydanticFieldMeta.from_model_field( - model_field=type_arg_to_sub_field[arg], - use_alias=use_alias, - random=random, - ) - for arg in extended_type_args - ) - - return PydanticFieldMeta( - name=name, - random=random or DEFAULT_RANDOM, - annotation=annotation, - children=children or None, - default=default_value, - constraints=cast("PydanticConstraints", {k: v for k, v in constraints.items() if v is not None}) or None, - ) - - if not _IS_PYDANTIC_V1: - - @classmethod - def get_constraints_metadata(cls, annotation: Any) -> Sequence[Any]: - metadata = [] - for m in super().get_constraints_metadata(annotation): - if isinstance(m, FieldInfo): - metadata.extend(m.metadata) - else: - metadata.append(m) - - return metadata - - -class ModelFactory(Generic[T], BaseFactory[T]): - """Base factory for pydantic models""" - - __forward_ref_resolution_type_mapping__: ClassVar[Mapping[str, type]] = {} - __is_base_factory__ = True - - def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None: - super().__init_subclass__(*args, **kwargs) - - if ( - getattr(cls, "__model__", None) - and _is_pydantic_v1_model(cls.__model__) - and hasattr(cls.__model__, "update_forward_refs") - ): - with suppress(NameError): # pragma: no cover - cls.__model__.update_forward_refs(**cls.__forward_ref_resolution_type_mapping__) # type: ignore[attr-defined] - - @classmethod - def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]: - """Determine whether the given value is supported by the factory. - - :param value: An arbitrary value. - :returns: A typeguard - """ - - return _is_pydantic_v1_model(value) or _is_pydantic_v2_model(value) - - @classmethod - def get_model_fields(cls) -> list["FieldMeta"]: - """Retrieve a list of fields from the factory's model. - - - :returns: A list of field MetaData instances. - - """ - if "_fields_metadata" not in cls.__dict__: - if _is_pydantic_v1_model(cls.__model__): - cls._fields_metadata = [ - PydanticFieldMeta.from_model_field( - field, - use_alias=not cls.__model__.__config__.allow_population_by_field_name, # type: ignore[attr-defined] - random=cls.__random__, - ) - for field in cls.__model__.__fields__.values() - ] - else: - cls._fields_metadata = [ - PydanticFieldMeta.from_field_info( - field_info=field_info, - field_name=field_name, - random=cls.__random__, - use_alias=not cls.__model__.model_config.get( # pyright: ignore[reportGeneralTypeIssues] - "populate_by_name", - False, - ), - ) - for field_name, field_info in cls.__model__.model_fields.items() # pyright: ignore[reportGeneralTypeIssues] - ] - return cls._fields_metadata - - @classmethod - def get_constrained_field_value(cls, annotation: Any, field_meta: FieldMeta) -> Any: - constraints = cast(PydanticConstraints, field_meta.constraints) - if constraints.pop("json", None): - value = cls.get_field_value(field_meta) - return to_json(value) # pyright: ignore[reportUnboundVariable] - - return super().get_constrained_field_value(annotation, field_meta) - - @classmethod - def build( - cls, - factory_use_construct: bool = False, - **kwargs: Any, - ) -> T: - """Build an instance of the factory's __model__ - - :param factory_use_construct: A boolean that determines whether validations will be made when instantiating the - model. This is supported only for pydantic models. - :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used. - - :returns: An instance of type T. - - """ - processed_kwargs = cls.process_kwargs(**kwargs) - - if factory_use_construct: - if _is_pydantic_v1_model(cls.__model__): - return cls.__model__.construct(**processed_kwargs) # type: ignore[return-value] - return cls.__model__.model_construct(**processed_kwargs) # type: ignore[return-value] - - return cls.__model__(**processed_kwargs) # type: ignore[return-value] - - @classmethod - def is_custom_root_field(cls, field_meta: FieldMeta) -> bool: - """Determine whether the field is a custom root field. - - :param field_meta: FieldMeta instance. - - :returns: A boolean determining whether the field is a custom root. - - """ - return field_meta.name == "__root__" - - @classmethod - def should_set_field_value(cls, field_meta: FieldMeta, **kwargs: Any) -> bool: - """Determine whether to set a value for a given field_name. - This is an override of BaseFactory.should_set_field_value. - - :param field_meta: FieldMeta instance. - :param kwargs: Any kwargs passed to the factory. - - :returns: A boolean determining whether a value should be set for the given field_meta. - - """ - return field_meta.name not in kwargs and ( - not field_meta.name.startswith("_") or cls.is_custom_root_field(field_meta) - ) - - @classmethod - def get_provider_map(cls) -> dict[Any, Callable[[], Any]]: - mapping = { - pydantic.ByteSize: cls.__faker__.pyint, - pydantic.PositiveInt: cls.__faker__.pyint, - pydantic.NegativeFloat: lambda: cls.__random__.uniform(-100, -1), - pydantic.NegativeInt: lambda: cls.__faker__.pyint() * -1, - pydantic.PositiveFloat: cls.__faker__.pyint, - pydantic.NonPositiveFloat: lambda: cls.__random__.uniform(-100, 0), - pydantic.NonNegativeInt: cls.__faker__.pyint, - pydantic.StrictInt: cls.__faker__.pyint, - pydantic.StrictBool: cls.__faker__.pybool, - pydantic.StrictBytes: partial(create_random_bytes, cls.__random__), - pydantic.StrictFloat: cls.__faker__.pyfloat, - pydantic.StrictStr: cls.__faker__.pystr, - pydantic.EmailStr: cls.__faker__.free_email, - pydantic.NameEmail: cls.__faker__.free_email, - pydantic.Json: cls.__faker__.json, - pydantic.PaymentCardNumber: cls.__faker__.credit_card_number, - pydantic.AnyUrl: cls.__faker__.url, - pydantic.AnyHttpUrl: cls.__faker__.url, - pydantic.HttpUrl: cls.__faker__.url, - pydantic.SecretBytes: partial(create_random_bytes, cls.__random__), - pydantic.SecretStr: cls.__faker__.pystr, - pydantic.IPvAnyAddress: cls.__faker__.ipv4, - pydantic.IPvAnyInterface: cls.__faker__.ipv4, - pydantic.IPvAnyNetwork: lambda: cls.__faker__.ipv4(network=True), - pydantic.PastDate: cls.__faker__.past_date, - pydantic.FutureDate: cls.__faker__.future_date, - } - - # v1 only values - mapping.update( - { - PyObject: lambda: "decimal.Decimal", - AmqpDsn: lambda: "amqps://example.com", - KafkaDsn: lambda: "kafka://localhost:9092", - PostgresDsn: lambda: "postgresql://user:secret@localhost", - RedisDsn: lambda: "redis://localhost:6379/0", - FilePath: lambda: Path(realpath(__file__)), - DirectoryPath: lambda: Path(realpath(__file__)).parent, - UUID1: uuid1, - UUID3: lambda: uuid3(NAMESPACE_DNS, cls.__faker__.pystr()), - UUID4: cls.__faker__.uuid4, - UUID5: lambda: uuid5(NAMESPACE_DNS, cls.__faker__.pystr()), - Color: cls.__faker__.hex_color, # pyright: ignore[reportGeneralTypeIssues] - }, - ) - - if not _IS_PYDANTIC_V1: - mapping.update( - { - # pydantic v2 specific types - pydantic.PastDatetime: cls.__faker__.past_datetime, - pydantic.FutureDatetime: cls.__faker__.future_datetime, - pydantic.AwareDatetime: partial(cls.__faker__.date_time, timezone.utc), - pydantic.NaiveDatetime: cls.__faker__.date_time, - }, - ) - - mapping.update(super().get_provider_map()) - return mapping - - -def _is_pydantic_v1_model(model: Any) -> TypeGuard[BaseModelV1]: - return is_safe_subclass(model, BaseModelV1) - - -def _is_pydantic_v2_model(model: Any) -> TypeGuard[BaseModelV2]: - return not _IS_PYDANTIC_V1 and is_safe_subclass(model, BaseModelV2) diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/sqlalchemy_factory.py b/venv/lib/python3.11/site-packages/polyfactory/factories/sqlalchemy_factory.py deleted file mode 100644 index ad8873f..0000000 --- a/venv/lib/python3.11/site-packages/polyfactory/factories/sqlalchemy_factory.py +++ /dev/null @@ -1,186 +0,0 @@ -from __future__ import annotations - -from datetime import date, datetime -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, List, TypeVar, Union - -from polyfactory.exceptions import MissingDependencyException -from polyfactory.factories.base import BaseFactory -from polyfactory.field_meta import FieldMeta -from polyfactory.persistence import AsyncPersistenceProtocol, SyncPersistenceProtocol - -try: - from sqlalchemy import Column, inspect, types - from sqlalchemy.dialects import mysql, postgresql - from sqlalchemy.exc import NoInspectionAvailable - from sqlalchemy.orm import InstanceState, Mapper -except ImportError as e: - msg = "sqlalchemy is not installed" - raise MissingDependencyException(msg) from e - -if TYPE_CHECKING: - from sqlalchemy.ext.asyncio import AsyncSession - from sqlalchemy.orm import Session - from typing_extensions import TypeGuard - - -T = TypeVar("T") - - -class SQLASyncPersistence(SyncPersistenceProtocol[T]): - def __init__(self, session: Session) -> None: - """Sync persistence handler for SQLAFactory.""" - self.session = session - - def save(self, data: T) -> T: - self.session.add(data) - self.session.commit() - return data - - def save_many(self, data: list[T]) -> list[T]: - self.session.add_all(data) - self.session.commit() - return data - - -class SQLAASyncPersistence(AsyncPersistenceProtocol[T]): - def __init__(self, session: AsyncSession) -> None: - """Async persistence handler for SQLAFactory.""" - self.session = session - - async def save(self, data: T) -> T: - self.session.add(data) - await self.session.commit() - return data - - async def save_many(self, data: list[T]) -> list[T]: - self.session.add_all(data) - await self.session.commit() - return data - - -class SQLAlchemyFactory(Generic[T], BaseFactory[T]): - """Base factory for SQLAlchemy models.""" - - __is_base_factory__ = True - - __set_primary_key__: ClassVar[bool] = True - """Configuration to consider primary key columns as a field or not.""" - __set_foreign_keys__: ClassVar[bool] = True - """Configuration to consider columns with foreign keys as a field or not.""" - __set_relationships__: ClassVar[bool] = False - """Configuration to consider relationships property as a model field or not.""" - - __session__: ClassVar[Session | Callable[[], Session] | None] = None - __async_session__: ClassVar[AsyncSession | Callable[[], AsyncSession] | None] = None - - __config_keys__ = ( - *BaseFactory.__config_keys__, - "__set_primary_key__", - "__set_foreign_keys__", - "__set_relationships__", - ) - - @classmethod - def get_sqlalchemy_types(cls) -> dict[Any, Callable[[], Any]]: - """Get mapping of types where column type.""" - return { - types.TupleType: cls.__faker__.pytuple, - mysql.YEAR: lambda: cls.__random__.randint(1901, 2155), - postgresql.CIDR: lambda: cls.__faker__.ipv4(network=False), - postgresql.DATERANGE: lambda: (cls.__faker__.past_date(), date.today()), # noqa: DTZ011 - postgresql.INET: lambda: cls.__faker__.ipv4(network=True), - postgresql.INT4RANGE: lambda: tuple(sorted([cls.__faker__.pyint(), cls.__faker__.pyint()])), - postgresql.INT8RANGE: lambda: tuple(sorted([cls.__faker__.pyint(), cls.__faker__.pyint()])), - postgresql.MACADDR: lambda: cls.__faker__.hexify(text="^^:^^:^^:^^:^^:^^", upper=True), - postgresql.NUMRANGE: lambda: tuple(sorted([cls.__faker__.pyint(), cls.__faker__.pyint()])), - postgresql.TSRANGE: lambda: (cls.__faker__.past_datetime(), datetime.now()), # noqa: DTZ005 - postgresql.TSTZRANGE: lambda: (cls.__faker__.past_datetime(), datetime.now()), # noqa: DTZ005 - } - - @classmethod - def get_provider_map(cls) -> dict[Any, Callable[[], Any]]: - providers_map = super().get_provider_map() - providers_map.update(cls.get_sqlalchemy_types()) - return providers_map - - @classmethod - def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]: - try: - inspected = inspect(value) - except NoInspectionAvailable: - return False - return isinstance(inspected, (Mapper, InstanceState)) - - @classmethod - def should_column_be_set(cls, column: Column) -> bool: - if not cls.__set_primary_key__ and column.primary_key: - return False - - return bool(cls.__set_foreign_keys__ or not column.foreign_keys) - - @classmethod - def get_type_from_column(cls, column: Column) -> type: - column_type = type(column.type) - if column_type in cls.get_sqlalchemy_types(): - annotation = column_type - elif issubclass(column_type, types.ARRAY): - annotation = List[column.type.item_type.python_type] # type: ignore[assignment,name-defined] - else: - annotation = ( - column.type.impl.python_type # pyright: ignore[reportGeneralTypeIssues] - if hasattr(column.type, "impl") - else column.type.python_type - ) - - if column.nullable: - annotation = Union[annotation, None] # type: ignore[assignment] - - return annotation - - @classmethod - def get_model_fields(cls) -> list[FieldMeta]: - fields_meta: list[FieldMeta] = [] - - table: Mapper = inspect(cls.__model__) # type: ignore[assignment] - fields_meta.extend( - FieldMeta.from_type( - annotation=cls.get_type_from_column(column), - name=name, - random=cls.__random__, - ) - for name, column in table.columns.items() - if cls.should_column_be_set(column) - ) - if cls.__set_relationships__: - for name, relationship in table.relationships.items(): - class_ = relationship.entity.class_ - annotation = class_ if not relationship.uselist else List[class_] # type: ignore[valid-type] - fields_meta.append( - FieldMeta.from_type( - name=name, - annotation=annotation, - random=cls.__random__, - ), - ) - - return fields_meta - - @classmethod - def _get_sync_persistence(cls) -> SyncPersistenceProtocol[T]: - if cls.__session__ is not None: - return ( - SQLASyncPersistence(cls.__session__()) - if callable(cls.__session__) - else SQLASyncPersistence(cls.__session__) - ) - return super()._get_sync_persistence() - - @classmethod - def _get_async_persistence(cls) -> AsyncPersistenceProtocol[T]: - if cls.__async_session__ is not None: - return ( - SQLAASyncPersistence(cls.__async_session__()) - if callable(cls.__async_session__) - else SQLAASyncPersistence(cls.__async_session__) - ) - return super()._get_async_persistence() diff --git a/venv/lib/python3.11/site-packages/polyfactory/factories/typed_dict_factory.py b/venv/lib/python3.11/site-packages/polyfactory/factories/typed_dict_factory.py deleted file mode 100644 index 2a3ea1b..0000000 --- a/venv/lib/python3.11/site-packages/polyfactory/factories/typed_dict_factory.py +++ /dev/null @@ -1,61 +0,0 @@ -from __future__ import annotations - -from typing import Any, Generic, TypeVar, get_args - -from typing_extensions import ( # type: ignore[attr-defined] - NotRequired, - Required, - TypeGuard, - _TypedDictMeta, # pyright: ignore[reportGeneralTypeIssues] - get_origin, - get_type_hints, - is_typeddict, -) - -from polyfactory.constants import DEFAULT_RANDOM -from polyfactory.factories.base import BaseFactory -from polyfactory.field_meta import FieldMeta, Null - -TypedDictT = TypeVar("TypedDictT", bound=_TypedDictMeta) - - -class TypedDictFactory(Generic[TypedDictT], BaseFactory[TypedDictT]): - """TypedDict base factory""" - - __is_base_factory__ = True - - @classmethod - def is_supported_type(cls, value: Any) -> TypeGuard[type[TypedDictT]]: - """Determine whether the given value is supported by the factory. - - :param value: An arbitrary value. - :returns: A typeguard - """ - return is_typeddict(value) - - @classmethod - def get_model_fields(cls) -> list["FieldMeta"]: - """Retrieve a list of fields from the factory's model. - - - :returns: A list of field MetaData instances. - - """ - model_type_hints = get_type_hints(cls.__model__, include_extras=True) - - field_metas: list[FieldMeta] = [] - for field_name, annotation in model_type_hints.items(): - origin = get_origin(annotation) - if origin in (Required, NotRequired): - annotation = get_args(annotation)[0] # noqa: PLW2901 - - field_metas.append( - FieldMeta.from_type( - annotation=annotation, - random=DEFAULT_RANDOM, - name=field_name, - default=getattr(cls.__model__, field_name, Null), - ), - ) - - return field_metas |