From 6d7ba58f880be618ade07f8ea080fe8c4bf8a896 Mon Sep 17 00:00:00 2001 From: cyfraeviolae Date: Wed, 3 Apr 2024 03:10:44 -0400 Subject: venv --- .../python3.11/site-packages/uvloop/_testbase.py | 550 +++++++++++++++++++++ 1 file changed, 550 insertions(+) create mode 100644 venv/lib/python3.11/site-packages/uvloop/_testbase.py (limited to 'venv/lib/python3.11/site-packages/uvloop/_testbase.py') diff --git a/venv/lib/python3.11/site-packages/uvloop/_testbase.py b/venv/lib/python3.11/site-packages/uvloop/_testbase.py new file mode 100644 index 0000000..c4a7595 --- /dev/null +++ b/venv/lib/python3.11/site-packages/uvloop/_testbase.py @@ -0,0 +1,550 @@ +"""Test utilities. Don't use outside of the uvloop project.""" + + +import asyncio +import asyncio.events +import collections +import contextlib +import gc +import logging +import os +import pprint +import re +import select +import socket +import ssl +import sys +import tempfile +import threading +import time +import unittest +import uvloop + + +class MockPattern(str): + def __eq__(self, other): + return bool(re.search(str(self), other, re.S)) + + +class TestCaseDict(collections.UserDict): + + def __init__(self, name): + super().__init__() + self.name = name + + def __setitem__(self, key, value): + if key in self.data: + raise RuntimeError('duplicate test {}.{}'.format( + self.name, key)) + super().__setitem__(key, value) + + +class BaseTestCaseMeta(type): + + @classmethod + def __prepare__(mcls, name, bases): + return TestCaseDict(name) + + def __new__(mcls, name, bases, dct): + for test_name in dct: + if not test_name.startswith('test_'): + continue + for base in bases: + if hasattr(base, test_name): + raise RuntimeError( + 'duplicate test {}.{} (also defined in {} ' + 'parent class)'.format( + name, test_name, base.__name__)) + + return super().__new__(mcls, name, bases, dict(dct)) + + +class BaseTestCase(unittest.TestCase, metaclass=BaseTestCaseMeta): + + def new_loop(self): + raise NotImplementedError + + def new_policy(self): + raise NotImplementedError + + def mock_pattern(self, str): + return MockPattern(str) + + async def wait_closed(self, obj): + if not isinstance(obj, asyncio.StreamWriter): + return + try: + await obj.wait_closed() + except (BrokenPipeError, ConnectionError): + pass + + def is_asyncio_loop(self): + return type(self.loop).__module__.startswith('asyncio.') + + def run_loop_briefly(self, *, delay=0.01): + self.loop.run_until_complete(asyncio.sleep(delay)) + + def loop_exception_handler(self, loop, context): + self.__unhandled_exceptions.append(context) + self.loop.default_exception_handler(context) + + def setUp(self): + self.loop = self.new_loop() + asyncio.set_event_loop_policy(self.new_policy()) + asyncio.set_event_loop(self.loop) + self._check_unclosed_resources_in_debug = True + + self.loop.set_exception_handler(self.loop_exception_handler) + self.__unhandled_exceptions = [] + + def tearDown(self): + self.loop.close() + + if self.__unhandled_exceptions: + print('Unexpected calls to loop.call_exception_handler():') + pprint.pprint(self.__unhandled_exceptions) + self.fail('unexpected calls to loop.call_exception_handler()') + return + + if not self._check_unclosed_resources_in_debug: + return + + # GC to show any resource warnings as the test completes + gc.collect() + gc.collect() + gc.collect() + + if getattr(self.loop, '_debug_cc', False): + gc.collect() + gc.collect() + gc.collect() + + self.assertEqual( + self.loop._debug_uv_handles_total, + self.loop._debug_uv_handles_freed, + 'not all uv_handle_t handles were freed') + + self.assertEqual( + self.loop._debug_cb_handles_count, 0, + 'not all callbacks (call_soon) are GCed') + + self.assertEqual( + self.loop._debug_cb_timer_handles_count, 0, + 'not all timer callbacks (call_later) are GCed') + + self.assertEqual( + self.loop._debug_stream_write_ctx_cnt, 0, + 'not all stream write contexts are GCed') + + for h_name, h_cnt in self.loop._debug_handles_current.items(): + with self.subTest('Alive handle after test', + handle_name=h_name): + self.assertEqual( + h_cnt, 0, + 'alive {} after test'.format(h_name)) + + for h_name, h_cnt in self.loop._debug_handles_total.items(): + with self.subTest('Total/closed handles', + handle_name=h_name): + self.assertEqual( + h_cnt, self.loop._debug_handles_closed[h_name], + 'total != closed for {}'.format(h_name)) + + asyncio.set_event_loop(None) + asyncio.set_event_loop_policy(None) + self.loop = None + + def skip_unclosed_handles_check(self): + self._check_unclosed_resources_in_debug = False + + def tcp_server(self, server_prog, *, + family=socket.AF_INET, + addr=None, + timeout=5, + backlog=1, + max_clients=10): + + if addr is None: + if family == socket.AF_UNIX: + with tempfile.NamedTemporaryFile() as tmp: + addr = tmp.name + else: + addr = ('127.0.0.1', 0) + + sock = socket.socket(family, socket.SOCK_STREAM) + + if timeout is None: + raise RuntimeError('timeout is required') + if timeout <= 0: + raise RuntimeError('only blocking sockets are supported') + sock.settimeout(timeout) + + try: + sock.bind(addr) + sock.listen(backlog) + except OSError as ex: + sock.close() + raise ex + + return TestThreadedServer( + self, sock, server_prog, timeout, max_clients) + + def tcp_client(self, client_prog, + family=socket.AF_INET, + timeout=10): + + sock = socket.socket(family, socket.SOCK_STREAM) + + if timeout is None: + raise RuntimeError('timeout is required') + if timeout <= 0: + raise RuntimeError('only blocking sockets are supported') + sock.settimeout(timeout) + + return TestThreadedClient( + self, sock, client_prog, timeout) + + def unix_server(self, *args, **kwargs): + return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs) + + def unix_client(self, *args, **kwargs): + return self.tcp_client(*args, family=socket.AF_UNIX, **kwargs) + + @contextlib.contextmanager + def unix_sock_name(self): + with tempfile.TemporaryDirectory() as td: + fn = os.path.join(td, 'sock') + try: + yield fn + finally: + try: + os.unlink(fn) + except OSError: + pass + + def _abort_socket_test(self, ex): + try: + self.loop.stop() + finally: + self.fail(ex) + + +def _cert_fullname(test_file_name, cert_file_name): + fullname = os.path.abspath(os.path.join( + os.path.dirname(test_file_name), 'certs', cert_file_name)) + assert os.path.isfile(fullname) + return fullname + + +@contextlib.contextmanager +def silence_long_exec_warning(): + + class Filter(logging.Filter): + def filter(self, record): + return not (record.msg.startswith('Executing') and + record.msg.endswith('seconds')) + + logger = logging.getLogger('asyncio') + filter = Filter() + logger.addFilter(filter) + try: + yield + finally: + logger.removeFilter(filter) + + +def find_free_port(start_from=50000): + for port in range(start_from, start_from + 500): + sock = socket.socket() + with sock: + try: + sock.bind(('', port)) + except socket.error: + continue + else: + return port + raise RuntimeError('could not find a free port') + + +class SSLTestCase: + + def _create_server_ssl_context(self, certfile, keyfile=None): + if hasattr(ssl, 'PROTOCOL_TLS'): + sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS) + else: + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.options |= ssl.OP_NO_SSLv2 + sslcontext.load_cert_chain(certfile, keyfile) + return sslcontext + + def _create_client_ssl_context(self, *, disable_verify=True): + sslcontext = ssl.create_default_context() + sslcontext.check_hostname = False + if disable_verify: + sslcontext.verify_mode = ssl.CERT_NONE + return sslcontext + + @contextlib.contextmanager + def _silence_eof_received_warning(self): + # TODO This warning has to be fixed in asyncio. + logger = logging.getLogger('asyncio') + filter = logging.Filter('has no effect when using ssl') + logger.addFilter(filter) + try: + yield + finally: + logger.removeFilter(filter) + + +class UVTestCase(BaseTestCase): + + implementation = 'uvloop' + + def new_loop(self): + return uvloop.new_event_loop() + + def new_policy(self): + return uvloop.EventLoopPolicy() + + +class AIOTestCase(BaseTestCase): + + implementation = 'asyncio' + + def setUp(self): + super().setUp() + + if sys.version_info < (3, 12): + watcher = asyncio.SafeChildWatcher() + watcher.attach_loop(self.loop) + asyncio.set_child_watcher(watcher) + + def tearDown(self): + if sys.version_info < (3, 12): + asyncio.set_child_watcher(None) + super().tearDown() + + def new_loop(self): + return asyncio.new_event_loop() + + def new_policy(self): + return asyncio.DefaultEventLoopPolicy() + + +def has_IPv6(): + server_sock = socket.socket(socket.AF_INET6) + with server_sock: + try: + server_sock.bind(('::1', 0)) + except OSError: + return False + else: + return True + + +has_IPv6 = has_IPv6() + + +############################################################################### +# Socket Testing Utilities +############################################################################### + + +class TestSocketWrapper: + + def __init__(self, sock): + self.__sock = sock + + def recv_all(self, n): + buf = b'' + while len(buf) < n: + data = self.recv(n - len(buf)) + if data == b'': + raise ConnectionAbortedError + buf += data + return buf + + def starttls(self, ssl_context, *, + server_side=False, + server_hostname=None, + do_handshake_on_connect=True): + + assert isinstance(ssl_context, ssl.SSLContext) + + ssl_sock = ssl_context.wrap_socket( + self.__sock, server_side=server_side, + server_hostname=server_hostname, + do_handshake_on_connect=do_handshake_on_connect) + + if server_side: + ssl_sock.do_handshake() + + self.__sock.close() + self.__sock = ssl_sock + + def __getattr__(self, name): + return getattr(self.__sock, name) + + def __repr__(self): + return '<{} {!r}>'.format(type(self).__name__, self.__sock) + + +class SocketThread(threading.Thread): + + def stop(self): + self._active = False + self.join() + + def __enter__(self): + self.start() + return self + + def __exit__(self, *exc): + self.stop() + + +class TestThreadedClient(SocketThread): + + def __init__(self, test, sock, prog, timeout): + threading.Thread.__init__(self, None, None, 'test-client') + self.daemon = True + + self._timeout = timeout + self._sock = sock + self._active = True + self._prog = prog + self._test = test + + def run(self): + try: + self._prog(TestSocketWrapper(self._sock)) + except (KeyboardInterrupt, SystemExit): + raise + except BaseException as ex: + self._test._abort_socket_test(ex) + + +class TestThreadedServer(SocketThread): + + def __init__(self, test, sock, prog, timeout, max_clients): + threading.Thread.__init__(self, None, None, 'test-server') + self.daemon = True + + self._clients = 0 + self._finished_clients = 0 + self._max_clients = max_clients + self._timeout = timeout + self._sock = sock + self._active = True + + self._prog = prog + + self._s1, self._s2 = socket.socketpair() + self._s1.setblocking(False) + + self._test = test + + def stop(self): + try: + if self._s2 and self._s2.fileno() != -1: + try: + self._s2.send(b'stop') + except OSError: + pass + finally: + super().stop() + + def run(self): + try: + with self._sock: + self._sock.setblocking(0) + self._run() + finally: + self._s1.close() + self._s2.close() + + def _run(self): + while self._active: + if self._clients >= self._max_clients: + return + + r, w, x = select.select( + [self._sock, self._s1], [], [], self._timeout) + + if self._s1 in r: + return + + if self._sock in r: + try: + conn, addr = self._sock.accept() + except BlockingIOError: + continue + except socket.timeout: + if not self._active: + return + else: + raise + else: + self._clients += 1 + conn.settimeout(self._timeout) + try: + with conn: + self._handle_client(conn) + except (KeyboardInterrupt, SystemExit): + raise + except BaseException as ex: + self._active = False + try: + raise + finally: + self._test._abort_socket_test(ex) + + def _handle_client(self, sock): + self._prog(TestSocketWrapper(sock)) + + @property + def addr(self): + return self._sock.getsockname() + + +############################################################################### +# A few helpers from asyncio/tests/testutils.py +############################################################################### + + +def run_briefly(loop): + async def once(): + pass + gen = once() + t = loop.create_task(gen) + # Don't log a warning if the task is not done after run_until_complete(). + # It occurs if the loop is stopped or if a task raises a BaseException. + t._log_destroy_pending = False + try: + loop.run_until_complete(t) + finally: + gen.close() + + +def run_until(loop, pred, timeout=30): + deadline = time.time() + timeout + while not pred(): + if timeout is not None: + timeout = deadline - time.time() + if timeout <= 0: + raise asyncio.futures.TimeoutError() + loop.run_until_complete(asyncio.tasks.sleep(0.001)) + + +@contextlib.contextmanager +def disable_logger(): + """Context manager to disable asyncio logger. + + For example, it can be used to ignore warnings in debug mode. + """ + old_level = asyncio.log.logger.level + try: + asyncio.log.logger.setLevel(logging.CRITICAL + 1) + yield + finally: + asyncio.log.logger.setLevel(old_level) -- cgit v1.2.3