summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/uvloop/_testbase.py
diff options
context:
space:
mode:
authorcyfraeviolae <cyfraeviolae>2024-04-03 03:10:44 -0400
committercyfraeviolae <cyfraeviolae>2024-04-03 03:10:44 -0400
commit6d7ba58f880be618ade07f8ea080fe8c4bf8a896 (patch)
treeb1c931051ffcebd2bd9d61d98d6233ffa289bbce /venv/lib/python3.11/site-packages/uvloop/_testbase.py
parent4f884c9abc32990b4061a1bb6997b4b37e58ea0b (diff)
venv
Diffstat (limited to 'venv/lib/python3.11/site-packages/uvloop/_testbase.py')
-rw-r--r--venv/lib/python3.11/site-packages/uvloop/_testbase.py550
1 files changed, 550 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/uvloop/_testbase.py b/venv/lib/python3.11/site-packages/uvloop/_testbase.py
new file mode 100644
index 0000000..c4a7595
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/uvloop/_testbase.py
@@ -0,0 +1,550 @@
+"""Test utilities. Don't use outside of the uvloop project."""
+
+
+import asyncio
+import asyncio.events
+import collections
+import contextlib
+import gc
+import logging
+import os
+import pprint
+import re
+import select
+import socket
+import ssl
+import sys
+import tempfile
+import threading
+import time
+import unittest
+import uvloop
+
+
+class MockPattern(str):
+ def __eq__(self, other):
+ return bool(re.search(str(self), other, re.S))
+
+
+class TestCaseDict(collections.UserDict):
+
+ def __init__(self, name):
+ super().__init__()
+ self.name = name
+
+ def __setitem__(self, key, value):
+ if key in self.data:
+ raise RuntimeError('duplicate test {}.{}'.format(
+ self.name, key))
+ super().__setitem__(key, value)
+
+
+class BaseTestCaseMeta(type):
+
+ @classmethod
+ def __prepare__(mcls, name, bases):
+ return TestCaseDict(name)
+
+ def __new__(mcls, name, bases, dct):
+ for test_name in dct:
+ if not test_name.startswith('test_'):
+ continue
+ for base in bases:
+ if hasattr(base, test_name):
+ raise RuntimeError(
+ 'duplicate test {}.{} (also defined in {} '
+ 'parent class)'.format(
+ name, test_name, base.__name__))
+
+ return super().__new__(mcls, name, bases, dict(dct))
+
+
+class BaseTestCase(unittest.TestCase, metaclass=BaseTestCaseMeta):
+
+ def new_loop(self):
+ raise NotImplementedError
+
+ def new_policy(self):
+ raise NotImplementedError
+
+ def mock_pattern(self, str):
+ return MockPattern(str)
+
+ async def wait_closed(self, obj):
+ if not isinstance(obj, asyncio.StreamWriter):
+ return
+ try:
+ await obj.wait_closed()
+ except (BrokenPipeError, ConnectionError):
+ pass
+
+ def is_asyncio_loop(self):
+ return type(self.loop).__module__.startswith('asyncio.')
+
+ def run_loop_briefly(self, *, delay=0.01):
+ self.loop.run_until_complete(asyncio.sleep(delay))
+
+ def loop_exception_handler(self, loop, context):
+ self.__unhandled_exceptions.append(context)
+ self.loop.default_exception_handler(context)
+
+ def setUp(self):
+ self.loop = self.new_loop()
+ asyncio.set_event_loop_policy(self.new_policy())
+ asyncio.set_event_loop(self.loop)
+ self._check_unclosed_resources_in_debug = True
+
+ self.loop.set_exception_handler(self.loop_exception_handler)
+ self.__unhandled_exceptions = []
+
+ def tearDown(self):
+ self.loop.close()
+
+ if self.__unhandled_exceptions:
+ print('Unexpected calls to loop.call_exception_handler():')
+ pprint.pprint(self.__unhandled_exceptions)
+ self.fail('unexpected calls to loop.call_exception_handler()')
+ return
+
+ if not self._check_unclosed_resources_in_debug:
+ return
+
+ # GC to show any resource warnings as the test completes
+ gc.collect()
+ gc.collect()
+ gc.collect()
+
+ if getattr(self.loop, '_debug_cc', False):
+ gc.collect()
+ gc.collect()
+ gc.collect()
+
+ self.assertEqual(
+ self.loop._debug_uv_handles_total,
+ self.loop._debug_uv_handles_freed,
+ 'not all uv_handle_t handles were freed')
+
+ self.assertEqual(
+ self.loop._debug_cb_handles_count, 0,
+ 'not all callbacks (call_soon) are GCed')
+
+ self.assertEqual(
+ self.loop._debug_cb_timer_handles_count, 0,
+ 'not all timer callbacks (call_later) are GCed')
+
+ self.assertEqual(
+ self.loop._debug_stream_write_ctx_cnt, 0,
+ 'not all stream write contexts are GCed')
+
+ for h_name, h_cnt in self.loop._debug_handles_current.items():
+ with self.subTest('Alive handle after test',
+ handle_name=h_name):
+ self.assertEqual(
+ h_cnt, 0,
+ 'alive {} after test'.format(h_name))
+
+ for h_name, h_cnt in self.loop._debug_handles_total.items():
+ with self.subTest('Total/closed handles',
+ handle_name=h_name):
+ self.assertEqual(
+ h_cnt, self.loop._debug_handles_closed[h_name],
+ 'total != closed for {}'.format(h_name))
+
+ asyncio.set_event_loop(None)
+ asyncio.set_event_loop_policy(None)
+ self.loop = None
+
+ def skip_unclosed_handles_check(self):
+ self._check_unclosed_resources_in_debug = False
+
+ def tcp_server(self, server_prog, *,
+ family=socket.AF_INET,
+ addr=None,
+ timeout=5,
+ backlog=1,
+ max_clients=10):
+
+ if addr is None:
+ if family == socket.AF_UNIX:
+ with tempfile.NamedTemporaryFile() as tmp:
+ addr = tmp.name
+ else:
+ addr = ('127.0.0.1', 0)
+
+ sock = socket.socket(family, socket.SOCK_STREAM)
+
+ if timeout is None:
+ raise RuntimeError('timeout is required')
+ if timeout <= 0:
+ raise RuntimeError('only blocking sockets are supported')
+ sock.settimeout(timeout)
+
+ try:
+ sock.bind(addr)
+ sock.listen(backlog)
+ except OSError as ex:
+ sock.close()
+ raise ex
+
+ return TestThreadedServer(
+ self, sock, server_prog, timeout, max_clients)
+
+ def tcp_client(self, client_prog,
+ family=socket.AF_INET,
+ timeout=10):
+
+ sock = socket.socket(family, socket.SOCK_STREAM)
+
+ if timeout is None:
+ raise RuntimeError('timeout is required')
+ if timeout <= 0:
+ raise RuntimeError('only blocking sockets are supported')
+ sock.settimeout(timeout)
+
+ return TestThreadedClient(
+ self, sock, client_prog, timeout)
+
+ def unix_server(self, *args, **kwargs):
+ return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs)
+
+ def unix_client(self, *args, **kwargs):
+ return self.tcp_client(*args, family=socket.AF_UNIX, **kwargs)
+
+ @contextlib.contextmanager
+ def unix_sock_name(self):
+ with tempfile.TemporaryDirectory() as td:
+ fn = os.path.join(td, 'sock')
+ try:
+ yield fn
+ finally:
+ try:
+ os.unlink(fn)
+ except OSError:
+ pass
+
+ def _abort_socket_test(self, ex):
+ try:
+ self.loop.stop()
+ finally:
+ self.fail(ex)
+
+
+def _cert_fullname(test_file_name, cert_file_name):
+ fullname = os.path.abspath(os.path.join(
+ os.path.dirname(test_file_name), 'certs', cert_file_name))
+ assert os.path.isfile(fullname)
+ return fullname
+
+
+@contextlib.contextmanager
+def silence_long_exec_warning():
+
+ class Filter(logging.Filter):
+ def filter(self, record):
+ return not (record.msg.startswith('Executing') and
+ record.msg.endswith('seconds'))
+
+ logger = logging.getLogger('asyncio')
+ filter = Filter()
+ logger.addFilter(filter)
+ try:
+ yield
+ finally:
+ logger.removeFilter(filter)
+
+
+def find_free_port(start_from=50000):
+ for port in range(start_from, start_from + 500):
+ sock = socket.socket()
+ with sock:
+ try:
+ sock.bind(('', port))
+ except socket.error:
+ continue
+ else:
+ return port
+ raise RuntimeError('could not find a free port')
+
+
+class SSLTestCase:
+
+ def _create_server_ssl_context(self, certfile, keyfile=None):
+ if hasattr(ssl, 'PROTOCOL_TLS'):
+ sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS)
+ else:
+ sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ sslcontext.options |= ssl.OP_NO_SSLv2
+ sslcontext.load_cert_chain(certfile, keyfile)
+ return sslcontext
+
+ def _create_client_ssl_context(self, *, disable_verify=True):
+ sslcontext = ssl.create_default_context()
+ sslcontext.check_hostname = False
+ if disable_verify:
+ sslcontext.verify_mode = ssl.CERT_NONE
+ return sslcontext
+
+ @contextlib.contextmanager
+ def _silence_eof_received_warning(self):
+ # TODO This warning has to be fixed in asyncio.
+ logger = logging.getLogger('asyncio')
+ filter = logging.Filter('has no effect when using ssl')
+ logger.addFilter(filter)
+ try:
+ yield
+ finally:
+ logger.removeFilter(filter)
+
+
+class UVTestCase(BaseTestCase):
+
+ implementation = 'uvloop'
+
+ def new_loop(self):
+ return uvloop.new_event_loop()
+
+ def new_policy(self):
+ return uvloop.EventLoopPolicy()
+
+
+class AIOTestCase(BaseTestCase):
+
+ implementation = 'asyncio'
+
+ def setUp(self):
+ super().setUp()
+
+ if sys.version_info < (3, 12):
+ watcher = asyncio.SafeChildWatcher()
+ watcher.attach_loop(self.loop)
+ asyncio.set_child_watcher(watcher)
+
+ def tearDown(self):
+ if sys.version_info < (3, 12):
+ asyncio.set_child_watcher(None)
+ super().tearDown()
+
+ def new_loop(self):
+ return asyncio.new_event_loop()
+
+ def new_policy(self):
+ return asyncio.DefaultEventLoopPolicy()
+
+
+def has_IPv6():
+ server_sock = socket.socket(socket.AF_INET6)
+ with server_sock:
+ try:
+ server_sock.bind(('::1', 0))
+ except OSError:
+ return False
+ else:
+ return True
+
+
+has_IPv6 = has_IPv6()
+
+
+###############################################################################
+# Socket Testing Utilities
+###############################################################################
+
+
+class TestSocketWrapper:
+
+ def __init__(self, sock):
+ self.__sock = sock
+
+ def recv_all(self, n):
+ buf = b''
+ while len(buf) < n:
+ data = self.recv(n - len(buf))
+ if data == b'':
+ raise ConnectionAbortedError
+ buf += data
+ return buf
+
+ def starttls(self, ssl_context, *,
+ server_side=False,
+ server_hostname=None,
+ do_handshake_on_connect=True):
+
+ assert isinstance(ssl_context, ssl.SSLContext)
+
+ ssl_sock = ssl_context.wrap_socket(
+ self.__sock, server_side=server_side,
+ server_hostname=server_hostname,
+ do_handshake_on_connect=do_handshake_on_connect)
+
+ if server_side:
+ ssl_sock.do_handshake()
+
+ self.__sock.close()
+ self.__sock = ssl_sock
+
+ def __getattr__(self, name):
+ return getattr(self.__sock, name)
+
+ def __repr__(self):
+ return '<{} {!r}>'.format(type(self).__name__, self.__sock)
+
+
+class SocketThread(threading.Thread):
+
+ def stop(self):
+ self._active = False
+ self.join()
+
+ def __enter__(self):
+ self.start()
+ return self
+
+ def __exit__(self, *exc):
+ self.stop()
+
+
+class TestThreadedClient(SocketThread):
+
+ def __init__(self, test, sock, prog, timeout):
+ threading.Thread.__init__(self, None, None, 'test-client')
+ self.daemon = True
+
+ self._timeout = timeout
+ self._sock = sock
+ self._active = True
+ self._prog = prog
+ self._test = test
+
+ def run(self):
+ try:
+ self._prog(TestSocketWrapper(self._sock))
+ except (KeyboardInterrupt, SystemExit):
+ raise
+ except BaseException as ex:
+ self._test._abort_socket_test(ex)
+
+
+class TestThreadedServer(SocketThread):
+
+ def __init__(self, test, sock, prog, timeout, max_clients):
+ threading.Thread.__init__(self, None, None, 'test-server')
+ self.daemon = True
+
+ self._clients = 0
+ self._finished_clients = 0
+ self._max_clients = max_clients
+ self._timeout = timeout
+ self._sock = sock
+ self._active = True
+
+ self._prog = prog
+
+ self._s1, self._s2 = socket.socketpair()
+ self._s1.setblocking(False)
+
+ self._test = test
+
+ def stop(self):
+ try:
+ if self._s2 and self._s2.fileno() != -1:
+ try:
+ self._s2.send(b'stop')
+ except OSError:
+ pass
+ finally:
+ super().stop()
+
+ def run(self):
+ try:
+ with self._sock:
+ self._sock.setblocking(0)
+ self._run()
+ finally:
+ self._s1.close()
+ self._s2.close()
+
+ def _run(self):
+ while self._active:
+ if self._clients >= self._max_clients:
+ return
+
+ r, w, x = select.select(
+ [self._sock, self._s1], [], [], self._timeout)
+
+ if self._s1 in r:
+ return
+
+ if self._sock in r:
+ try:
+ conn, addr = self._sock.accept()
+ except BlockingIOError:
+ continue
+ except socket.timeout:
+ if not self._active:
+ return
+ else:
+ raise
+ else:
+ self._clients += 1
+ conn.settimeout(self._timeout)
+ try:
+ with conn:
+ self._handle_client(conn)
+ except (KeyboardInterrupt, SystemExit):
+ raise
+ except BaseException as ex:
+ self._active = False
+ try:
+ raise
+ finally:
+ self._test._abort_socket_test(ex)
+
+ def _handle_client(self, sock):
+ self._prog(TestSocketWrapper(sock))
+
+ @property
+ def addr(self):
+ return self._sock.getsockname()
+
+
+###############################################################################
+# A few helpers from asyncio/tests/testutils.py
+###############################################################################
+
+
+def run_briefly(loop):
+ async def once():
+ pass
+ gen = once()
+ t = loop.create_task(gen)
+ # Don't log a warning if the task is not done after run_until_complete().
+ # It occurs if the loop is stopped or if a task raises a BaseException.
+ t._log_destroy_pending = False
+ try:
+ loop.run_until_complete(t)
+ finally:
+ gen.close()
+
+
+def run_until(loop, pred, timeout=30):
+ deadline = time.time() + timeout
+ while not pred():
+ if timeout is not None:
+ timeout = deadline - time.time()
+ if timeout <= 0:
+ raise asyncio.futures.TimeoutError()
+ loop.run_until_complete(asyncio.tasks.sleep(0.001))
+
+
+@contextlib.contextmanager
+def disable_logger():
+ """Context manager to disable asyncio logger.
+
+ For example, it can be used to ignore warnings in debug mode.
+ """
+ old_level = asyncio.log.logger.level
+ try:
+ asyncio.log.logger.setLevel(logging.CRITICAL + 1)
+ yield
+ finally:
+ asyncio.log.logger.setLevel(old_level)