summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/sqlalchemy/testing/exclusions.py
diff options
context:
space:
mode:
authorcyfraeviolae <cyfraeviolae>2024-04-03 03:10:44 -0400
committercyfraeviolae <cyfraeviolae>2024-04-03 03:10:44 -0400
commit6d7ba58f880be618ade07f8ea080fe8c4bf8a896 (patch)
treeb1c931051ffcebd2bd9d61d98d6233ffa289bbce /venv/lib/python3.11/site-packages/sqlalchemy/testing/exclusions.py
parent4f884c9abc32990b4061a1bb6997b4b37e58ea0b (diff)
venv
Diffstat (limited to 'venv/lib/python3.11/site-packages/sqlalchemy/testing/exclusions.py')
-rw-r--r--venv/lib/python3.11/site-packages/sqlalchemy/testing/exclusions.py435
1 files changed, 435 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/testing/exclusions.py b/venv/lib/python3.11/site-packages/sqlalchemy/testing/exclusions.py
new file mode 100644
index 0000000..addc4b7
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/sqlalchemy/testing/exclusions.py
@@ -0,0 +1,435 @@
+# testing/exclusions.py
+# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
+
+import contextlib
+import operator
+import re
+import sys
+
+from . import config
+from .. import util
+from ..util import decorator
+from ..util.compat import inspect_getfullargspec
+
+
+def skip_if(predicate, reason=None):
+ rule = compound()
+ pred = _as_predicate(predicate, reason)
+ rule.skips.add(pred)
+ return rule
+
+
+def fails_if(predicate, reason=None):
+ rule = compound()
+ pred = _as_predicate(predicate, reason)
+ rule.fails.add(pred)
+ return rule
+
+
+class compound:
+ def __init__(self):
+ self.fails = set()
+ self.skips = set()
+
+ def __add__(self, other):
+ return self.add(other)
+
+ def as_skips(self):
+ rule = compound()
+ rule.skips.update(self.skips)
+ rule.skips.update(self.fails)
+ return rule
+
+ def add(self, *others):
+ copy = compound()
+ copy.fails.update(self.fails)
+ copy.skips.update(self.skips)
+
+ for other in others:
+ copy.fails.update(other.fails)
+ copy.skips.update(other.skips)
+ return copy
+
+ def not_(self):
+ copy = compound()
+ copy.fails.update(NotPredicate(fail) for fail in self.fails)
+ copy.skips.update(NotPredicate(skip) for skip in self.skips)
+ return copy
+
+ @property
+ def enabled(self):
+ return self.enabled_for_config(config._current)
+
+ def enabled_for_config(self, config):
+ for predicate in self.skips.union(self.fails):
+ if predicate(config):
+ return False
+ else:
+ return True
+
+ def matching_config_reasons(self, config):
+ return [
+ predicate._as_string(config)
+ for predicate in self.skips.union(self.fails)
+ if predicate(config)
+ ]
+
+ def _extend(self, other):
+ self.skips.update(other.skips)
+ self.fails.update(other.fails)
+
+ def __call__(self, fn):
+ if hasattr(fn, "_sa_exclusion_extend"):
+ fn._sa_exclusion_extend._extend(self)
+ return fn
+
+ @decorator
+ def decorate(fn, *args, **kw):
+ return self._do(config._current, fn, *args, **kw)
+
+ decorated = decorate(fn)
+ decorated._sa_exclusion_extend = self
+ return decorated
+
+ @contextlib.contextmanager
+ def fail_if(self):
+ all_fails = compound()
+ all_fails.fails.update(self.skips.union(self.fails))
+
+ try:
+ yield
+ except Exception as ex:
+ all_fails._expect_failure(config._current, ex)
+ else:
+ all_fails._expect_success(config._current)
+
+ def _do(self, cfg, fn, *args, **kw):
+ for skip in self.skips:
+ if skip(cfg):
+ msg = "'%s' : %s" % (
+ config.get_current_test_name(),
+ skip._as_string(cfg),
+ )
+ config.skip_test(msg)
+
+ try:
+ return_value = fn(*args, **kw)
+ except Exception as ex:
+ self._expect_failure(cfg, ex, name=fn.__name__)
+ else:
+ self._expect_success(cfg, name=fn.__name__)
+ return return_value
+
+ def _expect_failure(self, config, ex, name="block"):
+ for fail in self.fails:
+ if fail(config):
+ print(
+ "%s failed as expected (%s): %s "
+ % (name, fail._as_string(config), ex)
+ )
+ break
+ else:
+ raise ex.with_traceback(sys.exc_info()[2])
+
+ def _expect_success(self, config, name="block"):
+ if not self.fails:
+ return
+
+ for fail in self.fails:
+ if fail(config):
+ raise AssertionError(
+ "Unexpected success for '%s' (%s)"
+ % (
+ name,
+ " and ".join(
+ fail._as_string(config) for fail in self.fails
+ ),
+ )
+ )
+
+
+def only_if(predicate, reason=None):
+ predicate = _as_predicate(predicate)
+ return skip_if(NotPredicate(predicate), reason)
+
+
+def succeeds_if(predicate, reason=None):
+ predicate = _as_predicate(predicate)
+ return fails_if(NotPredicate(predicate), reason)
+
+
+class Predicate:
+ @classmethod
+ def as_predicate(cls, predicate, description=None):
+ if isinstance(predicate, compound):
+ return cls.as_predicate(predicate.enabled_for_config, description)
+ elif isinstance(predicate, Predicate):
+ if description and predicate.description is None:
+ predicate.description = description
+ return predicate
+ elif isinstance(predicate, (list, set)):
+ return OrPredicate(
+ [cls.as_predicate(pred) for pred in predicate], description
+ )
+ elif isinstance(predicate, tuple):
+ return SpecPredicate(*predicate)
+ elif isinstance(predicate, str):
+ tokens = re.match(
+ r"([\+\w]+)\s*(?:(>=|==|!=|<=|<|>)\s*([\d\.]+))?", predicate
+ )
+ if not tokens:
+ raise ValueError(
+ "Couldn't locate DB name in predicate: %r" % predicate
+ )
+ db = tokens.group(1)
+ op = tokens.group(2)
+ spec = (
+ tuple(int(d) for d in tokens.group(3).split("."))
+ if tokens.group(3)
+ else None
+ )
+
+ return SpecPredicate(db, op, spec, description=description)
+ elif callable(predicate):
+ return LambdaPredicate(predicate, description)
+ else:
+ assert False, "unknown predicate type: %s" % predicate
+
+ def _format_description(self, config, negate=False):
+ bool_ = self(config)
+ if negate:
+ bool_ = not negate
+ return self.description % {
+ "driver": (
+ config.db.url.get_driver_name() if config else "<no driver>"
+ ),
+ "database": (
+ config.db.url.get_backend_name() if config else "<no database>"
+ ),
+ "doesnt_support": "doesn't support" if bool_ else "does support",
+ "does_support": "does support" if bool_ else "doesn't support",
+ }
+
+ def _as_string(self, config=None, negate=False):
+ raise NotImplementedError()
+
+
+class BooleanPredicate(Predicate):
+ def __init__(self, value, description=None):
+ self.value = value
+ self.description = description or "boolean %s" % value
+
+ def __call__(self, config):
+ return self.value
+
+ def _as_string(self, config, negate=False):
+ return self._format_description(config, negate=negate)
+
+
+class SpecPredicate(Predicate):
+ def __init__(self, db, op=None, spec=None, description=None):
+ self.db = db
+ self.op = op
+ self.spec = spec
+ self.description = description
+
+ _ops = {
+ "<": operator.lt,
+ ">": operator.gt,
+ "==": operator.eq,
+ "!=": operator.ne,
+ "<=": operator.le,
+ ">=": operator.ge,
+ "in": operator.contains,
+ "between": lambda val, pair: val >= pair[0] and val <= pair[1],
+ }
+
+ def __call__(self, config):
+ if config is None:
+ return False
+
+ engine = config.db
+
+ if "+" in self.db:
+ dialect, driver = self.db.split("+")
+ else:
+ dialect, driver = self.db, None
+
+ if dialect and engine.name != dialect:
+ return False
+ if driver is not None and engine.driver != driver:
+ return False
+
+ if self.op is not None:
+ assert driver is None, "DBAPI version specs not supported yet"
+
+ version = _server_version(engine)
+ oper = (
+ hasattr(self.op, "__call__") and self.op or self._ops[self.op]
+ )
+ return oper(version, self.spec)
+ else:
+ return True
+
+ def _as_string(self, config, negate=False):
+ if self.description is not None:
+ return self._format_description(config)
+ elif self.op is None:
+ if negate:
+ return "not %s" % self.db
+ else:
+ return "%s" % self.db
+ else:
+ if negate:
+ return "not %s %s %s" % (self.db, self.op, self.spec)
+ else:
+ return "%s %s %s" % (self.db, self.op, self.spec)
+
+
+class LambdaPredicate(Predicate):
+ def __init__(self, lambda_, description=None, args=None, kw=None):
+ spec = inspect_getfullargspec(lambda_)
+ if not spec[0]:
+ self.lambda_ = lambda db: lambda_()
+ else:
+ self.lambda_ = lambda_
+ self.args = args or ()
+ self.kw = kw or {}
+ if description:
+ self.description = description
+ elif lambda_.__doc__:
+ self.description = lambda_.__doc__
+ else:
+ self.description = "custom function"
+
+ def __call__(self, config):
+ return self.lambda_(config)
+
+ def _as_string(self, config, negate=False):
+ return self._format_description(config)
+
+
+class NotPredicate(Predicate):
+ def __init__(self, predicate, description=None):
+ self.predicate = predicate
+ self.description = description
+
+ def __call__(self, config):
+ return not self.predicate(config)
+
+ def _as_string(self, config, negate=False):
+ if self.description:
+ return self._format_description(config, not negate)
+ else:
+ return self.predicate._as_string(config, not negate)
+
+
+class OrPredicate(Predicate):
+ def __init__(self, predicates, description=None):
+ self.predicates = predicates
+ self.description = description
+
+ def __call__(self, config):
+ for pred in self.predicates:
+ if pred(config):
+ return True
+ return False
+
+ def _eval_str(self, config, negate=False):
+ if negate:
+ conjunction = " and "
+ else:
+ conjunction = " or "
+ return conjunction.join(
+ p._as_string(config, negate=negate) for p in self.predicates
+ )
+
+ def _negation_str(self, config):
+ if self.description is not None:
+ return "Not " + self._format_description(config)
+ else:
+ return self._eval_str(config, negate=True)
+
+ def _as_string(self, config, negate=False):
+ if negate:
+ return self._negation_str(config)
+ else:
+ if self.description is not None:
+ return self._format_description(config)
+ else:
+ return self._eval_str(config)
+
+
+_as_predicate = Predicate.as_predicate
+
+
+def _is_excluded(db, op, spec):
+ return SpecPredicate(db, op, spec)(config._current)
+
+
+def _server_version(engine):
+ """Return a server_version_info tuple."""
+
+ # force metadata to be retrieved
+ conn = engine.connect()
+ version = getattr(engine.dialect, "server_version_info", None)
+ if version is None:
+ version = ()
+ conn.close()
+ return version
+
+
+def db_spec(*dbs):
+ return OrPredicate([Predicate.as_predicate(db) for db in dbs])
+
+
+def open(): # noqa
+ return skip_if(BooleanPredicate(False, "mark as execute"))
+
+
+def closed():
+ return skip_if(BooleanPredicate(True, "marked as skip"))
+
+
+def fails(reason=None):
+ return fails_if(BooleanPredicate(True, reason or "expected to fail"))
+
+
+def future():
+ return fails_if(BooleanPredicate(True, "Future feature"))
+
+
+def fails_on(db, reason=None):
+ return fails_if(db, reason)
+
+
+def fails_on_everything_except(*dbs):
+ return succeeds_if(OrPredicate([Predicate.as_predicate(db) for db in dbs]))
+
+
+def skip(db, reason=None):
+ return skip_if(db, reason)
+
+
+def only_on(dbs, reason=None):
+ return only_if(
+ OrPredicate(
+ [Predicate.as_predicate(db, reason) for db in util.to_list(dbs)]
+ )
+ )
+
+
+def exclude(db, op, spec, reason=None):
+ return skip_if(SpecPredicate(db, op, spec), reason)
+
+
+def against(config, *queries):
+ assert queries, "no queries sent!"
+ return OrPredicate([Predicate.as_predicate(query) for query in queries])(
+ config
+ )