diff options
author | cyfraeviolae <cyfraeviolae> | 2024-04-03 03:10:44 -0400 |
---|---|---|
committer | cyfraeviolae <cyfraeviolae> | 2024-04-03 03:10:44 -0400 |
commit | 6d7ba58f880be618ade07f8ea080fe8c4bf8a896 (patch) | |
tree | b1c931051ffcebd2bd9d61d98d6233ffa289bbce /venv/lib/python3.11/site-packages/websockets/legacy | |
parent | 4f884c9abc32990b4061a1bb6997b4b37e58ea0b (diff) |
venv
Diffstat (limited to 'venv/lib/python3.11/site-packages/websockets/legacy')
20 files changed, 4538 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/websockets/legacy/__init__.py b/venv/lib/python3.11/site-packages/websockets/legacy/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/venv/lib/python3.11/site-packages/websockets/legacy/__init__.py diff --git a/venv/lib/python3.11/site-packages/websockets/legacy/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/websockets/legacy/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..5384f2b --- /dev/null +++ b/venv/lib/python3.11/site-packages/websockets/legacy/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/websockets/legacy/__pycache__/async_timeout.cpython-311.pyc b/venv/lib/python3.11/site-packages/websockets/legacy/__pycache__/async_timeout.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..f3e18ca --- /dev/null +++ b/venv/lib/python3.11/site-packages/websockets/legacy/__pycache__/async_timeout.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/websockets/legacy/__pycache__/auth.cpython-311.pyc b/venv/lib/python3.11/site-packages/websockets/legacy/__pycache__/auth.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..e572023 --- /dev/null +++ b/venv/lib/python3.11/site-packages/websockets/legacy/__pycache__/auth.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/websockets/legacy/__pycache__/client.cpython-311.pyc b/venv/lib/python3.11/site-packages/websockets/legacy/__pycache__/client.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..8ab650d --- /dev/null +++ b/venv/lib/python3.11/site-packages/websockets/legacy/__pycache__/client.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/websockets/legacy/__pycache__/compatibility.cpython-311.pyc b/venv/lib/python3.11/site-packages/websockets/legacy/__pycache__/compatibility.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..e65d9ec --- /dev/null +++ b/venv/lib/python3.11/site-packages/websockets/legacy/__pycache__/compatibility.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/websockets/legacy/__pycache__/framing.cpython-311.pyc b/venv/lib/python3.11/site-packages/websockets/legacy/__pycache__/framing.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..47c4426 --- /dev/null +++ b/venv/lib/python3.11/site-packages/websockets/legacy/__pycache__/framing.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/websockets/legacy/__pycache__/handshake.cpython-311.pyc b/venv/lib/python3.11/site-packages/websockets/legacy/__pycache__/handshake.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..0db3ae7 --- /dev/null +++ b/venv/lib/python3.11/site-packages/websockets/legacy/__pycache__/handshake.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/websockets/legacy/__pycache__/http.cpython-311.pyc b/venv/lib/python3.11/site-packages/websockets/legacy/__pycache__/http.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..2199c73 --- /dev/null +++ b/venv/lib/python3.11/site-packages/websockets/legacy/__pycache__/http.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/websockets/legacy/__pycache__/protocol.cpython-311.pyc b/venv/lib/python3.11/site-packages/websockets/legacy/__pycache__/protocol.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..002cb1f --- /dev/null +++ b/venv/lib/python3.11/site-packages/websockets/legacy/__pycache__/protocol.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/websockets/legacy/__pycache__/server.cpython-311.pyc b/venv/lib/python3.11/site-packages/websockets/legacy/__pycache__/server.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..6319b6e --- /dev/null +++ b/venv/lib/python3.11/site-packages/websockets/legacy/__pycache__/server.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/websockets/legacy/async_timeout.py b/venv/lib/python3.11/site-packages/websockets/legacy/async_timeout.py new file mode 100644 index 0000000..8264094 --- /dev/null +++ b/venv/lib/python3.11/site-packages/websockets/legacy/async_timeout.py @@ -0,0 +1,265 @@ +# From https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py +# Licensed under the Apache License (Apache-2.0) + +import asyncio +import enum +import sys +import warnings +from types import TracebackType +from typing import Optional, Type + + +# From https://github.com/python/typing_extensions/blob/main/src/typing_extensions.py +# Licensed under the Python Software Foundation License (PSF-2.0) + +if sys.version_info >= (3, 11): + from typing import final +else: + # @final exists in 3.8+, but we backport it for all versions + # before 3.11 to keep support for the __final__ attribute. + # See https://bugs.python.org/issue46342 + def final(f): + """This decorator can be used to indicate to type checkers that + the decorated method cannot be overridden, and decorated class + cannot be subclassed. For example: + + class Base: + @final + def done(self) -> None: + ... + class Sub(Base): + def done(self) -> None: # Error reported by type checker + ... + @final + class Leaf: + ... + class Other(Leaf): # Error reported by type checker + ... + + There is no runtime checking of these properties. The decorator + sets the ``__final__`` attribute to ``True`` on the decorated object + to allow runtime introspection. + """ + try: + f.__final__ = True + except (AttributeError, TypeError): + # Skip the attribute silently if it is not writable. + # AttributeError happens if the object has __slots__ or a + # read-only property, TypeError if it's a builtin class. + pass + return f + + +# End https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py + +__version__ = "4.0.2" + + +__all__ = ("timeout", "timeout_at", "Timeout") + + +def timeout(delay: Optional[float]) -> "Timeout": + """timeout context manager. + + Useful in cases when you want to apply timeout logic around block + of code or in cases when asyncio.wait_for is not suitable. For example: + + >>> async with timeout(0.001): + ... async with aiohttp.get('https://github.com') as r: + ... await r.text() + + + delay - value in seconds or None to disable timeout logic + """ + loop = asyncio.get_running_loop() + if delay is not None: + deadline = loop.time() + delay # type: Optional[float] + else: + deadline = None + return Timeout(deadline, loop) + + +def timeout_at(deadline: Optional[float]) -> "Timeout": + """Schedule the timeout at absolute time. + + deadline argument points on the time in the same clock system + as loop.time(). + + Please note: it is not POSIX time but a time with + undefined starting base, e.g. the time of the system power on. + + >>> async with timeout_at(loop.time() + 10): + ... async with aiohttp.get('https://github.com') as r: + ... await r.text() + + + """ + loop = asyncio.get_running_loop() + return Timeout(deadline, loop) + + +class _State(enum.Enum): + INIT = "INIT" + ENTER = "ENTER" + TIMEOUT = "TIMEOUT" + EXIT = "EXIT" + + +@final +class Timeout: + # Internal class, please don't instantiate it directly + # Use timeout() and timeout_at() public factories instead. + # + # Implementation note: `async with timeout()` is preferred + # over `with timeout()`. + # While technically the Timeout class implementation + # doesn't need to be async at all, + # the `async with` statement explicitly points that + # the context manager should be used from async function context. + # + # This design allows to avoid many silly misusages. + # + # TimeoutError is raised immediately when scheduled + # if the deadline is passed. + # The purpose is to time out as soon as possible + # without waiting for the next await expression. + + __slots__ = ("_deadline", "_loop", "_state", "_timeout_handler") + + def __init__( + self, deadline: Optional[float], loop: asyncio.AbstractEventLoop + ) -> None: + self._loop = loop + self._state = _State.INIT + + self._timeout_handler = None # type: Optional[asyncio.Handle] + if deadline is None: + self._deadline = None # type: Optional[float] + else: + self.update(deadline) + + def __enter__(self) -> "Timeout": + warnings.warn( + "with timeout() is deprecated, use async with timeout() instead", + DeprecationWarning, + stacklevel=2, + ) + self._do_enter() + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Optional[bool]: + self._do_exit(exc_type) + return None + + async def __aenter__(self) -> "Timeout": + self._do_enter() + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Optional[bool]: + self._do_exit(exc_type) + return None + + @property + def expired(self) -> bool: + """Is timeout expired during execution?""" + return self._state == _State.TIMEOUT + + @property + def deadline(self) -> Optional[float]: + return self._deadline + + def reject(self) -> None: + """Reject scheduled timeout if any.""" + # cancel is maybe better name but + # task.cancel() raises CancelledError in asyncio world. + if self._state not in (_State.INIT, _State.ENTER): + raise RuntimeError(f"invalid state {self._state.value}") + self._reject() + + def _reject(self) -> None: + if self._timeout_handler is not None: + self._timeout_handler.cancel() + self._timeout_handler = None + + def shift(self, delay: float) -> None: + """Advance timeout on delay seconds. + + The delay can be negative. + + Raise RuntimeError if shift is called when deadline is not scheduled + """ + deadline = self._deadline + if deadline is None: + raise RuntimeError("cannot shift timeout if deadline is not scheduled") + self.update(deadline + delay) + + def update(self, deadline: float) -> None: + """Set deadline to absolute value. + + deadline argument points on the time in the same clock system + as loop.time(). + + If new deadline is in the past the timeout is raised immediately. + + Please note: it is not POSIX time but a time with + undefined starting base, e.g. the time of the system power on. + """ + if self._state == _State.EXIT: + raise RuntimeError("cannot reschedule after exit from context manager") + if self._state == _State.TIMEOUT: + raise RuntimeError("cannot reschedule expired timeout") + if self._timeout_handler is not None: + self._timeout_handler.cancel() + self._deadline = deadline + if self._state != _State.INIT: + self._reschedule() + + def _reschedule(self) -> None: + assert self._state == _State.ENTER + deadline = self._deadline + if deadline is None: + return + + now = self._loop.time() + if self._timeout_handler is not None: + self._timeout_handler.cancel() + + task = asyncio.current_task() + if deadline <= now: + self._timeout_handler = self._loop.call_soon(self._on_timeout, task) + else: + self._timeout_handler = self._loop.call_at(deadline, self._on_timeout, task) + + def _do_enter(self) -> None: + if self._state != _State.INIT: + raise RuntimeError(f"invalid state {self._state.value}") + self._state = _State.ENTER + self._reschedule() + + def _do_exit(self, exc_type: Optional[Type[BaseException]]) -> None: + if exc_type is asyncio.CancelledError and self._state == _State.TIMEOUT: + self._timeout_handler = None + raise asyncio.TimeoutError + # timeout has not expired + self._state = _State.EXIT + self._reject() + return None + + def _on_timeout(self, task: "asyncio.Task[None]") -> None: + task.cancel() + self._state = _State.TIMEOUT + # drop the reference early + self._timeout_handler = None + + +# End https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py diff --git a/venv/lib/python3.11/site-packages/websockets/legacy/auth.py b/venv/lib/python3.11/site-packages/websockets/legacy/auth.py new file mode 100644 index 0000000..d342583 --- /dev/null +++ b/venv/lib/python3.11/site-packages/websockets/legacy/auth.py @@ -0,0 +1,184 @@ +from __future__ import annotations + +import functools +import hmac +import http +from typing import Any, Awaitable, Callable, Iterable, Optional, Tuple, Union, cast + +from ..datastructures import Headers +from ..exceptions import InvalidHeader +from ..headers import build_www_authenticate_basic, parse_authorization_basic +from .server import HTTPResponse, WebSocketServerProtocol + + +__all__ = ["BasicAuthWebSocketServerProtocol", "basic_auth_protocol_factory"] + +Credentials = Tuple[str, str] + + +def is_credentials(value: Any) -> bool: + try: + username, password = value + except (TypeError, ValueError): + return False + else: + return isinstance(username, str) and isinstance(password, str) + + +class BasicAuthWebSocketServerProtocol(WebSocketServerProtocol): + """ + WebSocket server protocol that enforces HTTP Basic Auth. + + """ + + realm: str = "" + """ + Scope of protection. + + If provided, it should contain only ASCII characters because the + encoding of non-ASCII characters is undefined. + """ + + username: Optional[str] = None + """Username of the authenticated user.""" + + def __init__( + self, + *args: Any, + realm: Optional[str] = None, + check_credentials: Optional[Callable[[str, str], Awaitable[bool]]] = None, + **kwargs: Any, + ) -> None: + if realm is not None: + self.realm = realm # shadow class attribute + self._check_credentials = check_credentials + super().__init__(*args, **kwargs) + + async def check_credentials(self, username: str, password: str) -> bool: + """ + Check whether credentials are authorized. + + This coroutine may be overridden in a subclass, for example to + authenticate against a database or an external service. + + Args: + username: HTTP Basic Auth username. + password: HTTP Basic Auth password. + + Returns: + bool: :obj:`True` if the handshake should continue; + :obj:`False` if it should fail with an HTTP 401 error. + + """ + if self._check_credentials is not None: + return await self._check_credentials(username, password) + + return False + + async def process_request( + self, + path: str, + request_headers: Headers, + ) -> Optional[HTTPResponse]: + """ + Check HTTP Basic Auth and return an HTTP 401 response if needed. + + """ + try: + authorization = request_headers["Authorization"] + except KeyError: + return ( + http.HTTPStatus.UNAUTHORIZED, + [("WWW-Authenticate", build_www_authenticate_basic(self.realm))], + b"Missing credentials\n", + ) + + try: + username, password = parse_authorization_basic(authorization) + except InvalidHeader: + return ( + http.HTTPStatus.UNAUTHORIZED, + [("WWW-Authenticate", build_www_authenticate_basic(self.realm))], + b"Unsupported credentials\n", + ) + + if not await self.check_credentials(username, password): + return ( + http.HTTPStatus.UNAUTHORIZED, + [("WWW-Authenticate", build_www_authenticate_basic(self.realm))], + b"Invalid credentials\n", + ) + + self.username = username + + return await super().process_request(path, request_headers) + + +def basic_auth_protocol_factory( + realm: Optional[str] = None, + credentials: Optional[Union[Credentials, Iterable[Credentials]]] = None, + check_credentials: Optional[Callable[[str, str], Awaitable[bool]]] = None, + create_protocol: Optional[Callable[..., BasicAuthWebSocketServerProtocol]] = None, +) -> Callable[..., BasicAuthWebSocketServerProtocol]: + """ + Protocol factory that enforces HTTP Basic Auth. + + :func:`basic_auth_protocol_factory` is designed to integrate with + :func:`~websockets.server.serve` like this:: + + websockets.serve( + ..., + create_protocol=websockets.basic_auth_protocol_factory( + realm="my dev server", + credentials=("hello", "iloveyou"), + ) + ) + + Args: + realm: Scope of protection. It should contain only ASCII characters + because the encoding of non-ASCII characters is undefined. + Refer to section 2.2 of :rfc:`7235` for details. + credentials: Hard coded authorized credentials. It can be a + ``(username, password)`` pair or a list of such pairs. + check_credentials: Coroutine that verifies credentials. + It receives ``username`` and ``password`` arguments + and returns a :class:`bool`. One of ``credentials`` or + ``check_credentials`` must be provided but not both. + create_protocol: Factory that creates the protocol. By default, this + is :class:`BasicAuthWebSocketServerProtocol`. It can be replaced + by a subclass. + Raises: + TypeError: If the ``credentials`` or ``check_credentials`` argument is + wrong. + + """ + if (credentials is None) == (check_credentials is None): + raise TypeError("provide either credentials or check_credentials") + + if credentials is not None: + if is_credentials(credentials): + credentials_list = [cast(Credentials, credentials)] + elif isinstance(credentials, Iterable): + credentials_list = list(credentials) + if not all(is_credentials(item) for item in credentials_list): + raise TypeError(f"invalid credentials argument: {credentials}") + else: + raise TypeError(f"invalid credentials argument: {credentials}") + + credentials_dict = dict(credentials_list) + + async def check_credentials(username: str, password: str) -> bool: + try: + expected_password = credentials_dict[username] + except KeyError: + return False + return hmac.compare_digest(expected_password, password) + + if create_protocol is None: + create_protocol = BasicAuthWebSocketServerProtocol + + return functools.partial( + create_protocol, + realm=realm, + check_credentials=check_credentials, + ) diff --git a/venv/lib/python3.11/site-packages/websockets/legacy/client.py b/venv/lib/python3.11/site-packages/websockets/legacy/client.py new file mode 100644 index 0000000..4862252 --- /dev/null +++ b/venv/lib/python3.11/site-packages/websockets/legacy/client.py @@ -0,0 +1,705 @@ +from __future__ import annotations + +import asyncio +import functools +import logging +import random +import urllib.parse +import warnings +from types import TracebackType +from typing import ( + Any, + AsyncIterator, + Callable, + Generator, + List, + Optional, + Sequence, + Tuple, + Type, + cast, +) + +from ..datastructures import Headers, HeadersLike +from ..exceptions import ( + InvalidHandshake, + InvalidHeader, + InvalidMessage, + InvalidStatusCode, + NegotiationError, + RedirectHandshake, + SecurityError, +) +from ..extensions import ClientExtensionFactory, Extension +from ..extensions.permessage_deflate import enable_client_permessage_deflate +from ..headers import ( + build_authorization_basic, + build_extension, + build_host, + build_subprotocol, + parse_extension, + parse_subprotocol, + validate_subprotocols, +) +from ..http import USER_AGENT +from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol +from ..uri import WebSocketURI, parse_uri +from .compatibility import asyncio_timeout +from .handshake import build_request, check_response +from .http import read_response +from .protocol import WebSocketCommonProtocol + + +__all__ = ["connect", "unix_connect", "WebSocketClientProtocol"] + + +class WebSocketClientProtocol(WebSocketCommonProtocol): + """ + WebSocket client connection. + + :class:`WebSocketClientProtocol` provides :meth:`recv` and :meth:`send` + coroutines for receiving and sending messages. + + It supports asynchronous iteration to receive incoming messages:: + + async for message in websocket: + await process(message) + + The iterator exits normally when the connection is closed with close code + 1000 (OK) or 1001 (going away) or without a close code. It raises + a :exc:`~websockets.exceptions.ConnectionClosedError` when the connection + is closed with any other code. + + See :func:`connect` for the documentation of ``logger``, ``origin``, + ``extensions``, ``subprotocols``, ``extra_headers``, and + ``user_agent_header``. + + See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the + documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``, + ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``. + + """ + + is_client = True + side = "client" + + def __init__( + self, + *, + logger: Optional[LoggerLike] = None, + origin: Optional[Origin] = None, + extensions: Optional[Sequence[ClientExtensionFactory]] = None, + subprotocols: Optional[Sequence[Subprotocol]] = None, + extra_headers: Optional[HeadersLike] = None, + user_agent_header: Optional[str] = USER_AGENT, + **kwargs: Any, + ) -> None: + if logger is None: + logger = logging.getLogger("websockets.client") + super().__init__(logger=logger, **kwargs) + self.origin = origin + self.available_extensions = extensions + self.available_subprotocols = subprotocols + self.extra_headers = extra_headers + self.user_agent_header = user_agent_header + + def write_http_request(self, path: str, headers: Headers) -> None: + """ + Write request line and headers to the HTTP request. + + """ + self.path = path + self.request_headers = headers + + if self.debug: + self.logger.debug("> GET %s HTTP/1.1", path) + for key, value in headers.raw_items(): + self.logger.debug("> %s: %s", key, value) + + # Since the path and headers only contain ASCII characters, + # we can keep this simple. + request = f"GET {path} HTTP/1.1\r\n" + request += str(headers) + + self.transport.write(request.encode()) + + async def read_http_response(self) -> Tuple[int, Headers]: + """ + Read status line and headers from the HTTP response. + + If the response contains a body, it may be read from ``self.reader`` + after this coroutine returns. + + Raises: + InvalidMessage: If the HTTP message is malformed or isn't an + HTTP/1.1 GET response. + + """ + try: + status_code, reason, headers = await read_response(self.reader) + except Exception as exc: + raise InvalidMessage("did not receive a valid HTTP response") from exc + + if self.debug: + self.logger.debug("< HTTP/1.1 %d %s", status_code, reason) + for key, value in headers.raw_items(): + self.logger.debug("< %s: %s", key, value) + + self.response_headers = headers + + return status_code, self.response_headers + + @staticmethod + def process_extensions( + headers: Headers, + available_extensions: Optional[Sequence[ClientExtensionFactory]], + ) -> List[Extension]: + """ + Handle the Sec-WebSocket-Extensions HTTP response header. + + Check that each extension is supported, as well as its parameters. + + Return the list of accepted extensions. + + Raise :exc:`~websockets.exceptions.InvalidHandshake` to abort the + connection. + + :rfc:`6455` leaves the rules up to the specification of each + :extension. + + To provide this level of flexibility, for each extension accepted by + the server, we check for a match with each extension available in the + client configuration. If no match is found, an exception is raised. + + If several variants of the same extension are accepted by the server, + it may be configured several times, which won't make sense in general. + Extensions must implement their own requirements. For this purpose, + the list of previously accepted extensions is provided. + + Other requirements, for example related to mandatory extensions or the + order of extensions, may be implemented by overriding this method. + + """ + accepted_extensions: List[Extension] = [] + + header_values = headers.get_all("Sec-WebSocket-Extensions") + + if header_values: + if available_extensions is None: + raise InvalidHandshake("no extensions supported") + + parsed_header_values: List[ExtensionHeader] = sum( + [parse_extension(header_value) for header_value in header_values], [] + ) + + for name, response_params in parsed_header_values: + for extension_factory in available_extensions: + # Skip non-matching extensions based on their name. + if extension_factory.name != name: + continue + + # Skip non-matching extensions based on their params. + try: + extension = extension_factory.process_response_params( + response_params, accepted_extensions + ) + except NegotiationError: + continue + + # Add matching extension to the final list. + accepted_extensions.append(extension) + + # Break out of the loop once we have a match. + break + + # If we didn't break from the loop, no extension in our list + # matched what the server sent. Fail the connection. + else: + raise NegotiationError( + f"Unsupported extension: " + f"name = {name}, params = {response_params}" + ) + + return accepted_extensions + + @staticmethod + def process_subprotocol( + headers: Headers, available_subprotocols: Optional[Sequence[Subprotocol]] + ) -> Optional[Subprotocol]: + """ + Handle the Sec-WebSocket-Protocol HTTP response header. + + Check that it contains exactly one supported subprotocol. + + Return the selected subprotocol. + + """ + subprotocol: Optional[Subprotocol] = None + + header_values = headers.get_all("Sec-WebSocket-Protocol") + + if header_values: + if available_subprotocols is None: + raise InvalidHandshake("no subprotocols supported") + + parsed_header_values: Sequence[Subprotocol] = sum( + [parse_subprotocol(header_value) for header_value in header_values], [] + ) + + if len(parsed_header_values) > 1: + subprotocols = ", ".join(parsed_header_values) + raise InvalidHandshake(f"multiple subprotocols: {subprotocols}") + + subprotocol = parsed_header_values[0] + + if subprotocol not in available_subprotocols: + raise NegotiationError(f"unsupported subprotocol: {subprotocol}") + + return subprotocol + + async def handshake( + self, + wsuri: WebSocketURI, + origin: Optional[Origin] = None, + available_extensions: Optional[Sequence[ClientExtensionFactory]] = None, + available_subprotocols: Optional[Sequence[Subprotocol]] = None, + extra_headers: Optional[HeadersLike] = None, + ) -> None: + """ + Perform the client side of the opening handshake. + + Args: + wsuri: URI of the WebSocket server. + origin: Value of the ``Origin`` header. + extensions: List of supported extensions, in order in which they + should be negotiated and run. + subprotocols: List of supported subprotocols, in order of decreasing + preference. + extra_headers: Arbitrary HTTP headers to add to the handshake request. + + Raises: + InvalidHandshake: If the handshake fails. + + """ + request_headers = Headers() + + request_headers["Host"] = build_host(wsuri.host, wsuri.port, wsuri.secure) + + if wsuri.user_info: + request_headers["Authorization"] = build_authorization_basic( + *wsuri.user_info + ) + + if origin is not None: + request_headers["Origin"] = origin + + key = build_request(request_headers) + + if available_extensions is not None: + extensions_header = build_extension( + [ + (extension_factory.name, extension_factory.get_request_params()) + for extension_factory in available_extensions + ] + ) + request_headers["Sec-WebSocket-Extensions"] = extensions_header + + if available_subprotocols is not None: + protocol_header = build_subprotocol(available_subprotocols) + request_headers["Sec-WebSocket-Protocol"] = protocol_header + + if self.extra_headers is not None: + request_headers.update(self.extra_headers) + + if self.user_agent_header is not None: + request_headers.setdefault("User-Agent", self.user_agent_header) + + self.write_http_request(wsuri.resource_name, request_headers) + + status_code, response_headers = await self.read_http_response() + if status_code in (301, 302, 303, 307, 308): + if "Location" not in response_headers: + raise InvalidHeader("Location") + raise RedirectHandshake(response_headers["Location"]) + elif status_code != 101: + raise InvalidStatusCode(status_code, response_headers) + + check_response(response_headers, key) + + self.extensions = self.process_extensions( + response_headers, available_extensions + ) + + self.subprotocol = self.process_subprotocol( + response_headers, available_subprotocols + ) + + self.connection_open() + + +class Connect: + """ + Connect to the WebSocket server at ``uri``. + + Awaiting :func:`connect` yields a :class:`WebSocketClientProtocol` which + can then be used to send and receive messages. + + :func:`connect` can be used as a asynchronous context manager:: + + async with websockets.connect(...) as websocket: + ... + + The connection is closed automatically when exiting the context. + + :func:`connect` can be used as an infinite asynchronous iterator to + reconnect automatically on errors:: + + async for websocket in websockets.connect(...): + try: + ... + except websockets.ConnectionClosed: + continue + + The connection is closed automatically after each iteration of the loop. + + If an error occurs while establishing the connection, :func:`connect` + retries with exponential backoff. The backoff delay starts at three + seconds and increases up to one minute. + + If an error occurs in the body of the loop, you can handle the exception + and :func:`connect` will reconnect with the next iteration; or you can + let the exception bubble up and break out of the loop. This lets you + decide which errors trigger a reconnection and which errors are fatal. + + Args: + uri: URI of the WebSocket server. + create_protocol: Factory for the :class:`asyncio.Protocol` managing + the connection. It defaults to :class:`WebSocketClientProtocol`. + Set it to a wrapper or a subclass to customize connection handling. + logger: Logger for this client. + It defaults to ``logging.getLogger("websockets.client")``. + See the :doc:`logging guide <../../topics/logging>` for details. + compression: The "permessage-deflate" extension is enabled by default. + Set ``compression`` to :obj:`None` to disable it. See the + :doc:`compression guide <../../topics/compression>` for details. + origin: Value of the ``Origin`` header, for servers that require it. + extensions: List of supported extensions, in order in which they + should be negotiated and run. + subprotocols: List of supported subprotocols, in order of decreasing + preference. + extra_headers: Arbitrary HTTP headers to add to the handshake request. + user_agent_header: Value of the ``User-Agent`` request header. + It defaults to ``"Python/x.y.z websockets/X.Y"``. + Setting it to :obj:`None` removes the header. + open_timeout: Timeout for opening the connection in seconds. + :obj:`None` disables the timeout. + + See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the + documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``, + ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``. + + Any other keyword arguments are passed the event loop's + :meth:`~asyncio.loop.create_connection` method. + + For example: + + * You can set ``ssl`` to a :class:`~ssl.SSLContext` to enforce TLS + settings. When connecting to a ``wss://`` URI, if ``ssl`` isn't + provided, a TLS context is created + with :func:`~ssl.create_default_context`. + + * You can set ``host`` and ``port`` to connect to a different host and + port from those found in ``uri``. This only changes the destination of + the TCP connection. The host name from ``uri`` is still used in the TLS + handshake for secure connections and in the ``Host`` header. + + Raises: + InvalidURI: If ``uri`` isn't a valid WebSocket URI. + OSError: If the TCP connection fails. + InvalidHandshake: If the opening handshake fails. + ~asyncio.TimeoutError: If the opening handshake times out. + + """ + + MAX_REDIRECTS_ALLOWED = 10 + + def __init__( + self, + uri: str, + *, + create_protocol: Optional[Callable[..., WebSocketClientProtocol]] = None, + logger: Optional[LoggerLike] = None, + compression: Optional[str] = "deflate", + origin: Optional[Origin] = None, + extensions: Optional[Sequence[ClientExtensionFactory]] = None, + subprotocols: Optional[Sequence[Subprotocol]] = None, + extra_headers: Optional[HeadersLike] = None, + user_agent_header: Optional[str] = USER_AGENT, + open_timeout: Optional[float] = 10, + ping_interval: Optional[float] = 20, + ping_timeout: Optional[float] = 20, + close_timeout: Optional[float] = None, + max_size: Optional[int] = 2**20, + max_queue: Optional[int] = 2**5, + read_limit: int = 2**16, + write_limit: int = 2**16, + **kwargs: Any, + ) -> None: + # Backwards compatibility: close_timeout used to be called timeout. + timeout: Optional[float] = kwargs.pop("timeout", None) + if timeout is None: + timeout = 10 + else: + warnings.warn("rename timeout to close_timeout", DeprecationWarning) + # If both are specified, timeout is ignored. + if close_timeout is None: + close_timeout = timeout + + # Backwards compatibility: create_protocol used to be called klass. + klass: Optional[Type[WebSocketClientProtocol]] = kwargs.pop("klass", None) + if klass is None: + klass = WebSocketClientProtocol + else: + warnings.warn("rename klass to create_protocol", DeprecationWarning) + # If both are specified, klass is ignored. + if create_protocol is None: + create_protocol = klass + + # Backwards compatibility: recv() used to return None on closed connections + legacy_recv: bool = kwargs.pop("legacy_recv", False) + + # Backwards compatibility: the loop parameter used to be supported. + _loop: Optional[asyncio.AbstractEventLoop] = kwargs.pop("loop", None) + if _loop is None: + loop = asyncio.get_event_loop() + else: + loop = _loop + warnings.warn("remove loop argument", DeprecationWarning) + + wsuri = parse_uri(uri) + if wsuri.secure: + kwargs.setdefault("ssl", True) + elif kwargs.get("ssl") is not None: + raise ValueError( + "connect() received a ssl argument for a ws:// URI, " + "use a wss:// URI to enable TLS" + ) + + if compression == "deflate": + extensions = enable_client_permessage_deflate(extensions) + elif compression is not None: + raise ValueError(f"unsupported compression: {compression}") + + if subprotocols is not None: + validate_subprotocols(subprotocols) + + factory = functools.partial( + create_protocol, + logger=logger, + origin=origin, + extensions=extensions, + subprotocols=subprotocols, + extra_headers=extra_headers, + user_agent_header=user_agent_header, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + close_timeout=close_timeout, + max_size=max_size, + max_queue=max_queue, + read_limit=read_limit, + write_limit=write_limit, + host=wsuri.host, + port=wsuri.port, + secure=wsuri.secure, + legacy_recv=legacy_recv, + loop=_loop, + ) + + if kwargs.pop("unix", False): + path: Optional[str] = kwargs.pop("path", None) + create_connection = functools.partial( + loop.create_unix_connection, factory, path, **kwargs + ) + else: + host: Optional[str] + port: Optional[int] + if kwargs.get("sock") is None: + host, port = wsuri.host, wsuri.port + else: + # If sock is given, host and port shouldn't be specified. + host, port = None, None + if kwargs.get("ssl"): + kwargs.setdefault("server_hostname", wsuri.host) + # If host and port are given, override values from the URI. + host = kwargs.pop("host", host) + port = kwargs.pop("port", port) + create_connection = functools.partial( + loop.create_connection, factory, host, port, **kwargs + ) + + self.open_timeout = open_timeout + if logger is None: + logger = logging.getLogger("websockets.client") + self.logger = logger + + # This is a coroutine function. + self._create_connection = create_connection + self._uri = uri + self._wsuri = wsuri + + def handle_redirect(self, uri: str) -> None: + # Update the state of this instance to connect to a new URI. + old_uri = self._uri + old_wsuri = self._wsuri + new_uri = urllib.parse.urljoin(old_uri, uri) + new_wsuri = parse_uri(new_uri) + + # Forbid TLS downgrade. + if old_wsuri.secure and not new_wsuri.secure: + raise SecurityError("redirect from WSS to WS") + + same_origin = ( + old_wsuri.host == new_wsuri.host and old_wsuri.port == new_wsuri.port + ) + + # Rewrite the host and port arguments for cross-origin redirects. + # This preserves connection overrides with the host and port + # arguments if the redirect points to the same host and port. + if not same_origin: + # Replace the host and port argument passed to the protocol factory. + factory = self._create_connection.args[0] + factory = functools.partial( + factory.func, + *factory.args, + **dict(factory.keywords, host=new_wsuri.host, port=new_wsuri.port), + ) + # Replace the host and port argument passed to create_connection. + self._create_connection = functools.partial( + self._create_connection.func, + *(factory, new_wsuri.host, new_wsuri.port), + **self._create_connection.keywords, + ) + + # Set the new WebSocket URI. This suffices for same-origin redirects. + self._uri = new_uri + self._wsuri = new_wsuri + + # async for ... in connect(...): + + BACKOFF_MIN = 1.92 + BACKOFF_MAX = 60.0 + BACKOFF_FACTOR = 1.618 + BACKOFF_INITIAL = 5 + + async def __aiter__(self) -> AsyncIterator[WebSocketClientProtocol]: + backoff_delay = self.BACKOFF_MIN + while True: + try: + async with self as protocol: + yield protocol + except Exception: + # Add a random initial delay between 0 and 5 seconds. + # See 7.2.3. Recovering from Abnormal Closure in RFC 6544. + if backoff_delay == self.BACKOFF_MIN: + initial_delay = random.random() * self.BACKOFF_INITIAL + self.logger.info( + "! connect failed; reconnecting in %.1f seconds", + initial_delay, + exc_info=True, + ) + await asyncio.sleep(initial_delay) + else: + self.logger.info( + "! connect failed again; retrying in %d seconds", + int(backoff_delay), + exc_info=True, + ) + await asyncio.sleep(int(backoff_delay)) + # Increase delay with truncated exponential backoff. + backoff_delay = backoff_delay * self.BACKOFF_FACTOR + backoff_delay = min(backoff_delay, self.BACKOFF_MAX) + continue + else: + # Connection succeeded - reset backoff delay + backoff_delay = self.BACKOFF_MIN + + # async with connect(...) as ...: + + async def __aenter__(self) -> WebSocketClientProtocol: + return await self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + await self.protocol.close() + + # ... = await connect(...) + + def __await__(self) -> Generator[Any, None, WebSocketClientProtocol]: + # Create a suitable iterator by calling __await__ on a coroutine. + return self.__await_impl_timeout__().__await__() + + async def __await_impl_timeout__(self) -> WebSocketClientProtocol: + async with asyncio_timeout(self.open_timeout): + return await self.__await_impl__() + + async def __await_impl__(self) -> WebSocketClientProtocol: + for redirects in range(self.MAX_REDIRECTS_ALLOWED): + _transport, _protocol = await self._create_connection() + protocol = cast(WebSocketClientProtocol, _protocol) + try: + await protocol.handshake( + self._wsuri, + origin=protocol.origin, + available_extensions=protocol.available_extensions, + available_subprotocols=protocol.available_subprotocols, + extra_headers=protocol.extra_headers, + ) + except RedirectHandshake as exc: + protocol.fail_connection() + await protocol.wait_closed() + self.handle_redirect(exc.uri) + # Avoid leaking a connected socket when the handshake fails. + except (Exception, asyncio.CancelledError): + protocol.fail_connection() + await protocol.wait_closed() + raise + else: + self.protocol = protocol + return protocol + else: + raise SecurityError("too many redirects") + + # ... = yield from connect(...) - remove when dropping Python < 3.10 + + __iter__ = __await__ + + +connect = Connect + + +def unix_connect( + path: Optional[str] = None, + uri: str = "ws://localhost/", + **kwargs: Any, +) -> Connect: + """ + Similar to :func:`connect`, but for connecting to a Unix socket. + + This function builds upon the event loop's + :meth:`~asyncio.loop.create_unix_connection` method. + + It is only available on Unix. + + It's mainly useful for debugging servers listening on Unix sockets. + + Args: + path: File system path to the Unix socket. + uri: URI of the WebSocket server; the host is used in the TLS + handshake for secure connections and in the ``Host`` header. + + """ + return connect(uri=uri, path=path, unix=True, **kwargs) diff --git a/venv/lib/python3.11/site-packages/websockets/legacy/compatibility.py b/venv/lib/python3.11/site-packages/websockets/legacy/compatibility.py new file mode 100644 index 0000000..6bd01e7 --- /dev/null +++ b/venv/lib/python3.11/site-packages/websockets/legacy/compatibility.py @@ -0,0 +1,12 @@ +from __future__ import annotations + +import sys + + +__all__ = ["asyncio_timeout"] + + +if sys.version_info[:2] >= (3, 11): + from asyncio import timeout as asyncio_timeout # noqa: F401 +else: + from .async_timeout import timeout as asyncio_timeout # noqa: F401 diff --git a/venv/lib/python3.11/site-packages/websockets/legacy/framing.py b/venv/lib/python3.11/site-packages/websockets/legacy/framing.py new file mode 100644 index 0000000..b77b869 --- /dev/null +++ b/venv/lib/python3.11/site-packages/websockets/legacy/framing.py @@ -0,0 +1,176 @@ +from __future__ import annotations + +import struct +from typing import Any, Awaitable, Callable, NamedTuple, Optional, Sequence, Tuple + +from .. import extensions, frames +from ..exceptions import PayloadTooBig, ProtocolError + + +try: + from ..speedups import apply_mask +except ImportError: + from ..utils import apply_mask + + +class Frame(NamedTuple): + fin: bool + opcode: frames.Opcode + data: bytes + rsv1: bool = False + rsv2: bool = False + rsv3: bool = False + + @property + def new_frame(self) -> frames.Frame: + return frames.Frame( + self.opcode, + self.data, + self.fin, + self.rsv1, + self.rsv2, + self.rsv3, + ) + + def __str__(self) -> str: + return str(self.new_frame) + + def check(self) -> None: + return self.new_frame.check() + + @classmethod + async def read( + cls, + reader: Callable[[int], Awaitable[bytes]], + *, + mask: bool, + max_size: Optional[int] = None, + extensions: Optional[Sequence[extensions.Extension]] = None, + ) -> Frame: + """ + Read a WebSocket frame. + + Args: + reader: Coroutine that reads exactly the requested number of + bytes, unless the end of file is reached. + mask: Whether the frame should be masked i.e. whether the read + happens on the server side. + max_size: Maximum payload size in bytes. + extensions: List of extensions, applied in reverse order. + + Raises: + PayloadTooBig: If the frame exceeds ``max_size``. + ProtocolError: If the frame contains incorrect values. + + """ + + # Read the header. + data = await reader(2) + head1, head2 = struct.unpack("!BB", data) + + # While not Pythonic, this is marginally faster than calling bool(). + fin = True if head1 & 0b10000000 else False + rsv1 = True if head1 & 0b01000000 else False + rsv2 = True if head1 & 0b00100000 else False + rsv3 = True if head1 & 0b00010000 else False + + try: + opcode = frames.Opcode(head1 & 0b00001111) + except ValueError as exc: + raise ProtocolError("invalid opcode") from exc + + if (True if head2 & 0b10000000 else False) != mask: + raise ProtocolError("incorrect masking") + + length = head2 & 0b01111111 + if length == 126: + data = await reader(2) + (length,) = struct.unpack("!H", data) + elif length == 127: + data = await reader(8) + (length,) = struct.unpack("!Q", data) + if max_size is not None and length > max_size: + raise PayloadTooBig(f"over size limit ({length} > {max_size} bytes)") + if mask: + mask_bits = await reader(4) + + # Read the data. + data = await reader(length) + if mask: + data = apply_mask(data, mask_bits) + + new_frame = frames.Frame(opcode, data, fin, rsv1, rsv2, rsv3) + + if extensions is None: + extensions = [] + for extension in reversed(extensions): + new_frame = extension.decode(new_frame, max_size=max_size) + + new_frame.check() + + return cls( + new_frame.fin, + new_frame.opcode, + new_frame.data, + new_frame.rsv1, + new_frame.rsv2, + new_frame.rsv3, + ) + + def write( + self, + write: Callable[[bytes], Any], + *, + mask: bool, + extensions: Optional[Sequence[extensions.Extension]] = None, + ) -> None: + """ + Write a WebSocket frame. + + Args: + frame: Frame to write. + write: Function that writes bytes. + mask: Whether the frame should be masked i.e. whether the write + happens on the client side. + extensions: List of extensions, applied in order. + + Raises: + ProtocolError: If the frame contains incorrect values. + + """ + # The frame is written in a single call to write in order to prevent + # TCP fragmentation. See #68 for details. This also makes it safe to + # send frames concurrently from multiple coroutines. + write(self.new_frame.serialize(mask=mask, extensions=extensions)) + + +# Backwards compatibility with previously documented public APIs +from ..frames import ( # noqa: E402, F401, I001 + Close, + prepare_ctrl as encode_data, + prepare_data, +) + + +def parse_close(data: bytes) -> Tuple[int, str]: + """ + Parse the payload from a close frame. + + Returns: + Close code and reason. + + Raises: + ProtocolError: If data is ill-formed. + UnicodeDecodeError: If the reason isn't valid UTF-8. + + """ + close = Close.parse(data) + return close.code, close.reason + + +def serialize_close(code: int, reason: str) -> bytes: + """ + Serialize the payload for a close frame. + + """ + return Close(code, reason).serialize() diff --git a/venv/lib/python3.11/site-packages/websockets/legacy/handshake.py b/venv/lib/python3.11/site-packages/websockets/legacy/handshake.py new file mode 100644 index 0000000..ad8faf0 --- /dev/null +++ b/venv/lib/python3.11/site-packages/websockets/legacy/handshake.py @@ -0,0 +1,165 @@ +from __future__ import annotations + +import base64 +import binascii +from typing import List + +from ..datastructures import Headers, MultipleValuesError +from ..exceptions import InvalidHeader, InvalidHeaderValue, InvalidUpgrade +from ..headers import parse_connection, parse_upgrade +from ..typing import ConnectionOption, UpgradeProtocol +from ..utils import accept_key as accept, generate_key + + +__all__ = ["build_request", "check_request", "build_response", "check_response"] + + +def build_request(headers: Headers) -> str: + """ + Build a handshake request to send to the server. + + Update request headers passed in argument. + + Args: + headers: Handshake request headers. + + Returns: + str: ``key`` that must be passed to :func:`check_response`. + + """ + key = generate_key() + headers["Upgrade"] = "websocket" + headers["Connection"] = "Upgrade" + headers["Sec-WebSocket-Key"] = key + headers["Sec-WebSocket-Version"] = "13" + return key + + +def check_request(headers: Headers) -> str: + """ + Check a handshake request received from the client. + + This function doesn't verify that the request is an HTTP/1.1 or higher GET + request and doesn't perform ``Host`` and ``Origin`` checks. These controls + are usually performed earlier in the HTTP request handling code. They're + the responsibility of the caller. + + Args: + headers: Handshake request headers. + + Returns: + str: ``key`` that must be passed to :func:`build_response`. + + Raises: + InvalidHandshake: If the handshake request is invalid. + Then, the server must return a 400 Bad Request error. + + """ + connection: List[ConnectionOption] = sum( + [parse_connection(value) for value in headers.get_all("Connection")], [] + ) + + if not any(value.lower() == "upgrade" for value in connection): + raise InvalidUpgrade("Connection", ", ".join(connection)) + + upgrade: List[UpgradeProtocol] = sum( + [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] + ) + + # For compatibility with non-strict implementations, ignore case when + # checking the Upgrade header. The RFC always uses "websocket", except + # in section 11.2. (IANA registration) where it uses "WebSocket". + if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"): + raise InvalidUpgrade("Upgrade", ", ".join(upgrade)) + + try: + s_w_key = headers["Sec-WebSocket-Key"] + except KeyError as exc: + raise InvalidHeader("Sec-WebSocket-Key") from exc + except MultipleValuesError as exc: + raise InvalidHeader( + "Sec-WebSocket-Key", "more than one Sec-WebSocket-Key header found" + ) from exc + + try: + raw_key = base64.b64decode(s_w_key.encode(), validate=True) + except binascii.Error as exc: + raise InvalidHeaderValue("Sec-WebSocket-Key", s_w_key) from exc + if len(raw_key) != 16: + raise InvalidHeaderValue("Sec-WebSocket-Key", s_w_key) + + try: + s_w_version = headers["Sec-WebSocket-Version"] + except KeyError as exc: + raise InvalidHeader("Sec-WebSocket-Version") from exc + except MultipleValuesError as exc: + raise InvalidHeader( + "Sec-WebSocket-Version", "more than one Sec-WebSocket-Version header found" + ) from exc + + if s_w_version != "13": + raise InvalidHeaderValue("Sec-WebSocket-Version", s_w_version) + + return s_w_key + + +def build_response(headers: Headers, key: str) -> None: + """ + Build a handshake response to send to the client. + + Update response headers passed in argument. + + Args: + headers: Handshake response headers. + key: Returned by :func:`check_request`. + + """ + headers["Upgrade"] = "websocket" + headers["Connection"] = "Upgrade" + headers["Sec-WebSocket-Accept"] = accept(key) + + +def check_response(headers: Headers, key: str) -> None: + """ + Check a handshake response received from the server. + + This function doesn't verify that the response is an HTTP/1.1 or higher + response with a 101 status code. These controls are the responsibility of + the caller. + + Args: + headers: Handshake response headers. + key: Returned by :func:`build_request`. + + Raises: + InvalidHandshake: If the handshake response is invalid. + + """ + connection: List[ConnectionOption] = sum( + [parse_connection(value) for value in headers.get_all("Connection")], [] + ) + + if not any(value.lower() == "upgrade" for value in connection): + raise InvalidUpgrade("Connection", " ".join(connection)) + + upgrade: List[UpgradeProtocol] = sum( + [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] + ) + + # For compatibility with non-strict implementations, ignore case when + # checking the Upgrade header. The RFC always uses "websocket", except + # in section 11.2. (IANA registration) where it uses "WebSocket". + if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"): + raise InvalidUpgrade("Upgrade", ", ".join(upgrade)) + + try: + s_w_accept = headers["Sec-WebSocket-Accept"] + except KeyError as exc: + raise InvalidHeader("Sec-WebSocket-Accept") from exc + except MultipleValuesError as exc: + raise InvalidHeader( + "Sec-WebSocket-Accept", "more than one Sec-WebSocket-Accept header found" + ) from exc + + if s_w_accept != accept(key): + raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept) diff --git a/venv/lib/python3.11/site-packages/websockets/legacy/http.py b/venv/lib/python3.11/site-packages/websockets/legacy/http.py new file mode 100644 index 0000000..2ac7f70 --- /dev/null +++ b/venv/lib/python3.11/site-packages/websockets/legacy/http.py @@ -0,0 +1,201 @@ +from __future__ import annotations + +import asyncio +import re +from typing import Tuple + +from ..datastructures import Headers +from ..exceptions import SecurityError + + +__all__ = ["read_request", "read_response"] + +MAX_HEADERS = 128 +MAX_LINE = 8192 + + +def d(value: bytes) -> str: + """ + Decode a bytestring for interpolating into an error message. + + """ + return value.decode(errors="backslashreplace") + + +# See https://www.rfc-editor.org/rfc/rfc7230.html#appendix-B. + +# Regex for validating header names. + +_token_re = re.compile(rb"[-!#$%&\'*+.^_`|~0-9a-zA-Z]+") + +# Regex for validating header values. + +# We don't attempt to support obsolete line folding. + +# Include HTAB (\x09), SP (\x20), VCHAR (\x21-\x7e), obs-text (\x80-\xff). + +# The ABNF is complicated because it attempts to express that optional +# whitespace is ignored. We strip whitespace and don't revalidate that. + +# See also https://www.rfc-editor.org/errata_search.php?rfc=7230&eid=4189 + +_value_re = re.compile(rb"[\x09\x20-\x7e\x80-\xff]*") + + +async def read_request(stream: asyncio.StreamReader) -> Tuple[str, Headers]: + """ + Read an HTTP/1.1 GET request and return ``(path, headers)``. + + ``path`` isn't URL-decoded or validated in any way. + + ``path`` and ``headers`` are expected to contain only ASCII characters. + Other characters are represented with surrogate escapes. + + :func:`read_request` doesn't attempt to read the request body because + WebSocket handshake requests don't have one. If the request contains a + body, it may be read from ``stream`` after this coroutine returns. + + Args: + stream: Input to read the request from. + + Raises: + EOFError: If the connection is closed without a full HTTP request. + SecurityError: If the request exceeds a security limit. + ValueError: If the request isn't well formatted. + + """ + # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.1 + + # Parsing is simple because fixed values are expected for method and + # version and because path isn't checked. Since WebSocket software tends + # to implement HTTP/1.1 strictly, there's little need for lenient parsing. + + try: + request_line = await read_line(stream) + except EOFError as exc: + raise EOFError("connection closed while reading HTTP request line") from exc + + try: + method, raw_path, version = request_line.split(b" ", 2) + except ValueError: # not enough values to unpack (expected 3, got 1-2) + raise ValueError(f"invalid HTTP request line: {d(request_line)}") from None + + if method != b"GET": + raise ValueError(f"unsupported HTTP method: {d(method)}") + if version != b"HTTP/1.1": + raise ValueError(f"unsupported HTTP version: {d(version)}") + path = raw_path.decode("ascii", "surrogateescape") + + headers = await read_headers(stream) + + return path, headers + + +async def read_response(stream: asyncio.StreamReader) -> Tuple[int, str, Headers]: + """ + Read an HTTP/1.1 response and return ``(status_code, reason, headers)``. + + ``reason`` and ``headers`` are expected to contain only ASCII characters. + Other characters are represented with surrogate escapes. + + :func:`read_request` doesn't attempt to read the response body because + WebSocket handshake responses don't have one. If the response contains a + body, it may be read from ``stream`` after this coroutine returns. + + Args: + stream: Input to read the response from. + + Raises: + EOFError: If the connection is closed without a full HTTP response. + SecurityError: If the response exceeds a security limit. + ValueError: If the response isn't well formatted. + + """ + # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.1.2 + + # As in read_request, parsing is simple because a fixed value is expected + # for version, status_code is a 3-digit number, and reason can be ignored. + + try: + status_line = await read_line(stream) + except EOFError as exc: + raise EOFError("connection closed while reading HTTP status line") from exc + + try: + version, raw_status_code, raw_reason = status_line.split(b" ", 2) + except ValueError: # not enough values to unpack (expected 3, got 1-2) + raise ValueError(f"invalid HTTP status line: {d(status_line)}") from None + + if version != b"HTTP/1.1": + raise ValueError(f"unsupported HTTP version: {d(version)}") + try: + status_code = int(raw_status_code) + except ValueError: # invalid literal for int() with base 10 + raise ValueError(f"invalid HTTP status code: {d(raw_status_code)}") from None + if not 100 <= status_code < 1000: + raise ValueError(f"unsupported HTTP status code: {d(raw_status_code)}") + if not _value_re.fullmatch(raw_reason): + raise ValueError(f"invalid HTTP reason phrase: {d(raw_reason)}") + reason = raw_reason.decode() + + headers = await read_headers(stream) + + return status_code, reason, headers + + +async def read_headers(stream: asyncio.StreamReader) -> Headers: + """ + Read HTTP headers from ``stream``. + + Non-ASCII characters are represented with surrogate escapes. + + """ + # https://www.rfc-editor.org/rfc/rfc7230.html#section-3.2 + + # We don't attempt to support obsolete line folding. + + headers = Headers() + for _ in range(MAX_HEADERS + 1): + try: + line = await read_line(stream) + except EOFError as exc: + raise EOFError("connection closed while reading HTTP headers") from exc + if line == b"": + break + + try: + raw_name, raw_value = line.split(b":", 1) + except ValueError: # not enough values to unpack (expected 2, got 1) + raise ValueError(f"invalid HTTP header line: {d(line)}") from None + if not _token_re.fullmatch(raw_name): + raise ValueError(f"invalid HTTP header name: {d(raw_name)}") + raw_value = raw_value.strip(b" \t") + if not _value_re.fullmatch(raw_value): + raise ValueError(f"invalid HTTP header value: {d(raw_value)}") + + name = raw_name.decode("ascii") # guaranteed to be ASCII at this point + value = raw_value.decode("ascii", "surrogateescape") + headers[name] = value + + else: + raise SecurityError("too many HTTP headers") + + return headers + + +async def read_line(stream: asyncio.StreamReader) -> bytes: + """ + Read a single line from ``stream``. + + CRLF is stripped from the return value. + + """ + # Security: this is bounded by the StreamReader's limit (default = 32Â KiB). + line = await stream.readline() + # Security: this guarantees header values are small (hard-coded = 8Â KiB) + if len(line) > MAX_LINE: + raise SecurityError("line too long") + # Not mandatory but safe - https://www.rfc-editor.org/rfc/rfc7230.html#section-3.5 + if not line.endswith(b"\r\n"): + raise EOFError("line without CRLF") + return line[:-2] diff --git a/venv/lib/python3.11/site-packages/websockets/legacy/protocol.py b/venv/lib/python3.11/site-packages/websockets/legacy/protocol.py new file mode 100644 index 0000000..19cee0e --- /dev/null +++ b/venv/lib/python3.11/site-packages/websockets/legacy/protocol.py @@ -0,0 +1,1645 @@ +from __future__ import annotations + +import asyncio +import codecs +import collections +import logging +import random +import ssl +import struct +import sys +import time +import uuid +import warnings +from typing import ( + Any, + AsyncIterable, + AsyncIterator, + Awaitable, + Callable, + Deque, + Dict, + Iterable, + List, + Mapping, + Optional, + Tuple, + Union, + cast, +) + +from ..datastructures import Headers +from ..exceptions import ( + ConnectionClosed, + ConnectionClosedError, + ConnectionClosedOK, + InvalidState, + PayloadTooBig, + ProtocolError, +) +from ..extensions import Extension +from ..frames import ( + OK_CLOSE_CODES, + OP_BINARY, + OP_CLOSE, + OP_CONT, + OP_PING, + OP_PONG, + OP_TEXT, + Close, + CloseCode, + Opcode, + prepare_ctrl, + prepare_data, +) +from ..protocol import State +from ..typing import Data, LoggerLike, Subprotocol +from .compatibility import asyncio_timeout +from .framing import Frame + + +__all__ = ["WebSocketCommonProtocol", "broadcast"] + + +# In order to ensure consistency, the code always checks the current value of +# WebSocketCommonProtocol.state before assigning a new value and never yields +# between the check and the assignment. + + +class WebSocketCommonProtocol(asyncio.Protocol): + """ + WebSocket connection. + + :class:`WebSocketCommonProtocol` provides APIs shared between WebSocket + servers and clients. You shouldn't use it directly. Instead, use + :class:`~websockets.client.WebSocketClientProtocol` or + :class:`~websockets.server.WebSocketServerProtocol`. + + This documentation focuses on low-level details that aren't covered in the + documentation of :class:`~websockets.client.WebSocketClientProtocol` and + :class:`~websockets.server.WebSocketServerProtocol` for the sake of + simplicity. + + Once the connection is open, a Ping_ frame is sent every ``ping_interval`` + seconds. This serves as a keepalive. It helps keeping the connection open, + especially in the presence of proxies with short timeouts on inactive + connections. Set ``ping_interval`` to :obj:`None` to disable this behavior. + + .. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 + + If the corresponding Pong_ frame isn't received within ``ping_timeout`` + seconds, the connection is considered unusable and is closed with code 1011. + This ensures that the remote endpoint remains responsive. Set + ``ping_timeout`` to :obj:`None` to disable this behavior. + + .. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 + + See the discussion of :doc:`timeouts <../../topics/timeouts>` for details. + + The ``close_timeout`` parameter defines a maximum wait time for completing + the closing handshake and terminating the TCP connection. For legacy + reasons, :meth:`close` completes in at most ``5 * close_timeout`` seconds + for clients and ``4 * close_timeout`` for servers. + + ``close_timeout`` is a parameter of the protocol because websockets usually + calls :meth:`close` implicitly upon exit: + + * on the client side, when using :func:`~websockets.client.connect` as a + context manager; + * on the server side, when the connection handler terminates. + + To apply a timeout to any other API, wrap it in :func:`~asyncio.timeout` or + :func:`~asyncio.wait_for`. + + The ``max_size`` parameter enforces the maximum size for incoming messages + in bytes. The default value is 1Â MiB. If a larger message is received, + :meth:`recv` will raise :exc:`~websockets.exceptions.ConnectionClosedError` + and the connection will be closed with code 1009. + + The ``max_queue`` parameter sets the maximum length of the queue that + holds incoming messages. The default value is ``32``. Messages are added + to an in-memory queue when they're received; then :meth:`recv` pops from + that queue. In order to prevent excessive memory consumption when + messages are received faster than they can be processed, the queue must + be bounded. If the queue fills up, the protocol stops processing incoming + data until :meth:`recv` is called. In this situation, various receive + buffers (at least in :mod:`asyncio` and in the OS) will fill up, then the + TCP receive window will shrink, slowing down transmission to avoid packet + loss. + + Since Python can use up to 4 bytes of memory to represent a single + character, each connection may use up to ``4 * max_size * max_queue`` + bytes of memory to store incoming messages. By default, this is 128Â MiB. + You may want to lower the limits, depending on your application's + requirements. + + The ``read_limit`` argument sets the high-water limit of the buffer for + incoming bytes. The low-water limit is half the high-water limit. The + default value is 64Â KiB, half of asyncio's default (based on the current + implementation of :class:`~asyncio.StreamReader`). + + The ``write_limit`` argument sets the high-water limit of the buffer for + outgoing bytes. The low-water limit is a quarter of the high-water limit. + The default value is 64Â KiB, equal to asyncio's default (based on the + current implementation of ``FlowControlMixin``). + + See the discussion of :doc:`memory usage <../../topics/memory>` for details. + + Args: + logger: Logger for this server. + It defaults to ``logging.getLogger("websockets.protocol")``. + See the :doc:`logging guide <../../topics/logging>` for details. + ping_interval: Delay between keepalive pings in seconds. + :obj:`None` disables keepalive pings. + ping_timeout: Timeout for keepalive pings in seconds. + :obj:`None` disables timeouts. + close_timeout: Timeout for closing the connection in seconds. + For legacy reasons, the actual timeout is 4 or 5 times larger. + max_size: Maximum size of incoming messages in bytes. + :obj:`None` disables the limit. + max_queue: Maximum number of incoming messages in receive buffer. + :obj:`None` disables the limit. + read_limit: High-water mark of read buffer in bytes. + write_limit: High-water mark of write buffer in bytes. + + """ + + # There are only two differences between the client-side and server-side + # behavior: masking the payload and closing the underlying TCP connection. + # Set is_client = True/False and side = "client"/"server" to pick a side. + is_client: bool + side: str = "undefined" + + def __init__( + self, + *, + logger: Optional[LoggerLike] = None, + ping_interval: Optional[float] = 20, + ping_timeout: Optional[float] = 20, + close_timeout: Optional[float] = None, + max_size: Optional[int] = 2**20, + max_queue: Optional[int] = 2**5, + read_limit: int = 2**16, + write_limit: int = 2**16, + # The following arguments are kept only for backwards compatibility. + host: Optional[str] = None, + port: Optional[int] = None, + secure: Optional[bool] = None, + legacy_recv: bool = False, + loop: Optional[asyncio.AbstractEventLoop] = None, + timeout: Optional[float] = None, + ) -> None: + if legacy_recv: # pragma: no cover + warnings.warn("legacy_recv is deprecated", DeprecationWarning) + + # Backwards compatibility: close_timeout used to be called timeout. + if timeout is None: + timeout = 10 + else: + warnings.warn("rename timeout to close_timeout", DeprecationWarning) + # If both are specified, timeout is ignored. + if close_timeout is None: + close_timeout = timeout + + # Backwards compatibility: the loop parameter used to be supported. + if loop is None: + loop = asyncio.get_event_loop() + else: + warnings.warn("remove loop argument", DeprecationWarning) + + self.ping_interval = ping_interval + self.ping_timeout = ping_timeout + self.close_timeout = close_timeout + self.max_size = max_size + self.max_queue = max_queue + self.read_limit = read_limit + self.write_limit = write_limit + + # Unique identifier. For logs. + self.id: uuid.UUID = uuid.uuid4() + """Unique identifier of the connection. Useful in logs.""" + + # Logger or LoggerAdapter for this connection. + if logger is None: + logger = logging.getLogger("websockets.protocol") + self.logger: LoggerLike = logging.LoggerAdapter(logger, {"websocket": self}) + """Logger for this connection.""" + + # Track if DEBUG is enabled. Shortcut logging calls if it isn't. + self.debug = logger.isEnabledFor(logging.DEBUG) + + self.loop = loop + + self._host = host + self._port = port + self._secure = secure + self.legacy_recv = legacy_recv + + # Configure read buffer limits. The high-water limit is defined by + # ``self.read_limit``. The ``limit`` argument controls the line length + # limit and half the buffer limit of :class:`~asyncio.StreamReader`. + # That's why it must be set to half of ``self.read_limit``. + self.reader = asyncio.StreamReader(limit=read_limit // 2, loop=loop) + + # Copied from asyncio.FlowControlMixin + self._paused = False + self._drain_waiter: Optional[asyncio.Future[None]] = None + + self._drain_lock = asyncio.Lock() + + # This class implements the data transfer and closing handshake, which + # are shared between the client-side and the server-side. + # Subclasses implement the opening handshake and, on success, execute + # :meth:`connection_open` to change the state to OPEN. + self.state = State.CONNECTING + if self.debug: + self.logger.debug("= connection is CONNECTING") + + # HTTP protocol parameters. + self.path: str + """Path of the opening handshake request.""" + self.request_headers: Headers + """Opening handshake request headers.""" + self.response_headers: Headers + """Opening handshake response headers.""" + + # WebSocket protocol parameters. + self.extensions: List[Extension] = [] + self.subprotocol: Optional[Subprotocol] = None + """Subprotocol, if one was negotiated.""" + + # Close code and reason, set when a close frame is sent or received. + self.close_rcvd: Optional[Close] = None + self.close_sent: Optional[Close] = None + self.close_rcvd_then_sent: Optional[bool] = None + + # Completed when the connection state becomes CLOSED. Translates the + # :meth:`connection_lost` callback to a :class:`~asyncio.Future` + # that can be awaited. (Other :class:`~asyncio.Protocol` callbacks are + # translated by ``self.stream_reader``). + self.connection_lost_waiter: asyncio.Future[None] = loop.create_future() + + # Queue of received messages. + self.messages: Deque[Data] = collections.deque() + self._pop_message_waiter: Optional[asyncio.Future[None]] = None + self._put_message_waiter: Optional[asyncio.Future[None]] = None + + # Protect sending fragmented messages. + self._fragmented_message_waiter: Optional[asyncio.Future[None]] = None + + # Mapping of ping IDs to pong waiters, in chronological order. + self.pings: Dict[bytes, Tuple[asyncio.Future[float], float]] = {} + + self.latency: float = 0 + """ + Latency of the connection, in seconds. + + This value is updated after sending a ping frame and receiving a + matching pong frame. Before the first ping, :attr:`latency` is ``0``. + + By default, websockets enables a :ref:`keepalive <keepalive>` mechanism + that sends ping frames automatically at regular intervals. You can also + send ping frames and measure latency with :meth:`ping`. + """ + + # Task running the data transfer. + self.transfer_data_task: asyncio.Task[None] + + # Exception that occurred during data transfer, if any. + self.transfer_data_exc: Optional[BaseException] = None + + # Task sending keepalive pings. + self.keepalive_ping_task: asyncio.Task[None] + + # Task closing the TCP connection. + self.close_connection_task: asyncio.Task[None] + + # Copied from asyncio.FlowControlMixin + async def _drain_helper(self) -> None: # pragma: no cover + if self.connection_lost_waiter.done(): + raise ConnectionResetError("Connection lost") + if not self._paused: + return + waiter = self._drain_waiter + assert waiter is None or waiter.cancelled() + waiter = self.loop.create_future() + self._drain_waiter = waiter + await waiter + + # Copied from asyncio.StreamWriter + async def _drain(self) -> None: # pragma: no cover + if self.reader is not None: + exc = self.reader.exception() + if exc is not None: + raise exc + if self.transport is not None: + if self.transport.is_closing(): + # Yield to the event loop so connection_lost() may be + # called. Without this, _drain_helper() would return + # immediately, and code that calls + # write(...); yield from drain() + # in a loop would never call connection_lost(), so it + # would not see an error when the socket is closed. + await asyncio.sleep(0) + await self._drain_helper() + + def connection_open(self) -> None: + """ + Callback when the WebSocket opening handshake completes. + + Enter the OPEN state and start the data transfer phase. + + """ + # 4.1. The WebSocket Connection is Established. + assert self.state is State.CONNECTING + self.state = State.OPEN + if self.debug: + self.logger.debug("= connection is OPEN") + # Start the task that receives incoming WebSocket messages. + self.transfer_data_task = self.loop.create_task(self.transfer_data()) + # Start the task that sends pings at regular intervals. + self.keepalive_ping_task = self.loop.create_task(self.keepalive_ping()) + # Start the task that eventually closes the TCP connection. + self.close_connection_task = self.loop.create_task(self.close_connection()) + + @property + def host(self) -> Optional[str]: + alternative = "remote_address" if self.is_client else "local_address" + warnings.warn(f"use {alternative}[0] instead of host", DeprecationWarning) + return self._host + + @property + def port(self) -> Optional[int]: + alternative = "remote_address" if self.is_client else "local_address" + warnings.warn(f"use {alternative}[1] instead of port", DeprecationWarning) + return self._port + + @property + def secure(self) -> Optional[bool]: + warnings.warn("don't use secure", DeprecationWarning) + return self._secure + + # Public API + + @property + def local_address(self) -> Any: + """ + Local address of the connection. + + For IPv4 connections, this is a ``(host, port)`` tuple. + + The format of the address depends on the address family; + see :meth:`~socket.socket.getsockname`. + + :obj:`None` if the TCP connection isn't established yet. + + """ + try: + transport = self.transport + except AttributeError: + return None + else: + return transport.get_extra_info("sockname") + + @property + def remote_address(self) -> Any: + """ + Remote address of the connection. + + For IPv4 connections, this is a ``(host, port)`` tuple. + + The format of the address depends on the address family; + see :meth:`~socket.socket.getpeername`. + + :obj:`None` if the TCP connection isn't established yet. + + """ + try: + transport = self.transport + except AttributeError: + return None + else: + return transport.get_extra_info("peername") + + @property + def open(self) -> bool: + """ + :obj:`True` when the connection is open; :obj:`False` otherwise. + + This attribute may be used to detect disconnections. However, this + approach is discouraged per the EAFP_ principle. Instead, you should + handle :exc:`~websockets.exceptions.ConnectionClosed` exceptions. + + .. _EAFP: https://docs.python.org/3/glossary.html#term-eafp + + """ + return self.state is State.OPEN and not self.transfer_data_task.done() + + @property + def closed(self) -> bool: + """ + :obj:`True` when the connection is closed; :obj:`False` otherwise. + + Be aware that both :attr:`open` and :attr:`closed` are :obj:`False` + during the opening and closing sequences. + + """ + return self.state is State.CLOSED + + @property + def close_code(self) -> Optional[int]: + """ + WebSocket close code, defined in `section 7.1.5 of RFC 6455`_. + + .. _section 7.1.5 of RFC 6455: + https://www.rfc-editor.org/rfc/rfc6455.html#section-7.1.5 + + :obj:`None` if the connection isn't closed yet. + + """ + if self.state is not State.CLOSED: + return None + elif self.close_rcvd is None: + return CloseCode.ABNORMAL_CLOSURE + else: + return self.close_rcvd.code + + @property + def close_reason(self) -> Optional[str]: + """ + WebSocket close reason, defined in `section 7.1.6 of RFC 6455`_. + + .. _section 7.1.6 of RFC 6455: + https://www.rfc-editor.org/rfc/rfc6455.html#section-7.1.6 + + :obj:`None` if the connection isn't closed yet. + + """ + if self.state is not State.CLOSED: + return None + elif self.close_rcvd is None: + return "" + else: + return self.close_rcvd.reason + + async def __aiter__(self) -> AsyncIterator[Data]: + """ + Iterate on incoming messages. + + The iterator exits normally when the connection is closed with the close + code 1000 (OK) or 1001 (going away) or without a close code. + + It raises a :exc:`~websockets.exceptions.ConnectionClosedError` + exception when the connection is closed with any other code. + + """ + try: + while True: + yield await self.recv() + except ConnectionClosedOK: + return + + async def recv(self) -> Data: + """ + Receive the next message. + + When the connection is closed, :meth:`recv` raises + :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it raises + :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal + connection closure and + :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. This is how you detect the end of the + message stream. + + Canceling :meth:`recv` is safe. There's no risk of losing the next + message. The next invocation of :meth:`recv` will return it. + + This makes it possible to enforce a timeout by wrapping :meth:`recv` in + :func:`~asyncio.timeout` or :func:`~asyncio.wait_for`. + + Returns: + Data: A string (:class:`str`) for a Text_ frame. A bytestring + (:class:`bytes`) for a Binary_ frame. + + .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + + Raises: + ConnectionClosed: When the connection is closed. + RuntimeError: If two coroutines call :meth:`recv` concurrently. + + """ + if self._pop_message_waiter is not None: + raise RuntimeError( + "cannot call recv while another coroutine " + "is already waiting for the next message" + ) + + # Don't await self.ensure_open() here: + # - messages could be available in the queue even if the connection + # is closed; + # - messages could be received before the closing frame even if the + # connection is closing. + + # Wait until there's a message in the queue (if necessary) or the + # connection is closed. + while len(self.messages) <= 0: + pop_message_waiter: asyncio.Future[None] = self.loop.create_future() + self._pop_message_waiter = pop_message_waiter + try: + # If asyncio.wait() is canceled, it doesn't cancel + # pop_message_waiter and self.transfer_data_task. + await asyncio.wait( + [pop_message_waiter, self.transfer_data_task], + return_when=asyncio.FIRST_COMPLETED, + ) + finally: + self._pop_message_waiter = None + + # If asyncio.wait(...) exited because self.transfer_data_task + # completed before receiving a new message, raise a suitable + # exception (or return None if legacy_recv is enabled). + if not pop_message_waiter.done(): + if self.legacy_recv: + return None # type: ignore + else: + # Wait until the connection is closed to raise + # ConnectionClosed with the correct code and reason. + await self.ensure_open() + + # Pop a message from the queue. + message = self.messages.popleft() + + # Notify transfer_data(). + if self._put_message_waiter is not None: + self._put_message_waiter.set_result(None) + self._put_message_waiter = None + + return message + + async def send( + self, + message: Union[Data, Iterable[Data], AsyncIterable[Data]], + ) -> None: + """ + Send a message. + + A string (:class:`str`) is sent as a Text_ frame. A bytestring or + bytes-like object (:class:`bytes`, :class:`bytearray`, or + :class:`memoryview`) is sent as a Binary_ frame. + + .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + + :meth:`send` also accepts an iterable or an asynchronous iterable of + strings, bytestrings, or bytes-like objects to enable fragmentation_. + Each item is treated as a message fragment and sent in its own frame. + All items must be of the same type, or else :meth:`send` will raise a + :exc:`TypeError` and the connection will be closed. + + .. _fragmentation: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.4 + + :meth:`send` rejects dict-like objects because this is often an error. + (If you want to send the keys of a dict-like object as fragments, call + its :meth:`~dict.keys` method and pass the result to :meth:`send`.) + + Canceling :meth:`send` is discouraged. Instead, you should close the + connection with :meth:`close`. Indeed, there are only two situations + where :meth:`send` may yield control to the event loop and then get + canceled; in both cases, :meth:`close` has the same effect and is + more clear: + + 1. The write buffer is full. If you don't want to wait until enough + data is sent, your only alternative is to close the connection. + :meth:`close` will likely time out then abort the TCP connection. + 2. ``message`` is an asynchronous iterator that yields control. + Stopping in the middle of a fragmented message will cause a + protocol error and the connection will be closed. + + When the connection is closed, :meth:`send` raises + :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it + raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal + connection closure and + :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. + + Args: + message (Union[Data, Iterable[Data], AsyncIterable[Data]): message + to send. + + Raises: + ConnectionClosed: When the connection is closed. + TypeError: If ``message`` doesn't have a supported type. + + """ + await self.ensure_open() + + # While sending a fragmented message, prevent sending other messages + # until all fragments are sent. + while self._fragmented_message_waiter is not None: + await asyncio.shield(self._fragmented_message_waiter) + + # Unfragmented message -- this case must be handled first because + # strings and bytes-like objects are iterable. + + if isinstance(message, (str, bytes, bytearray, memoryview)): + opcode, data = prepare_data(message) + await self.write_frame(True, opcode, data) + + # Catch a common mistake -- passing a dict to send(). + + elif isinstance(message, Mapping): + raise TypeError("data is a dict-like object") + + # Fragmented message -- regular iterator. + + elif isinstance(message, Iterable): + # Work around https://github.com/python/mypy/issues/6227 + message = cast(Iterable[Data], message) + + iter_message = iter(message) + try: + fragment = next(iter_message) + except StopIteration: + return + opcode, data = prepare_data(fragment) + + self._fragmented_message_waiter = asyncio.Future() + try: + # First fragment. + await self.write_frame(False, opcode, data) + + # Other fragments. + for fragment in iter_message: + confirm_opcode, data = prepare_data(fragment) + if confirm_opcode != opcode: + raise TypeError("data contains inconsistent types") + await self.write_frame(False, OP_CONT, data) + + # Final fragment. + await self.write_frame(True, OP_CONT, b"") + + except (Exception, asyncio.CancelledError): + # We're half-way through a fragmented message and we can't + # complete it. This makes the connection unusable. + self.fail_connection(CloseCode.INTERNAL_ERROR) + raise + + finally: + self._fragmented_message_waiter.set_result(None) + self._fragmented_message_waiter = None + + # Fragmented message -- asynchronous iterator + + elif isinstance(message, AsyncIterable): + # Implement aiter_message = aiter(message) without aiter + # Work around https://github.com/python/mypy/issues/5738 + aiter_message = cast( + Callable[[AsyncIterable[Data]], AsyncIterator[Data]], + type(message).__aiter__, + )(message) + try: + # Implement fragment = anext(aiter_message) without anext + # Work around https://github.com/python/mypy/issues/5738 + fragment = await cast( + Callable[[AsyncIterator[Data]], Awaitable[Data]], + type(aiter_message).__anext__, + )(aiter_message) + except StopAsyncIteration: + return + opcode, data = prepare_data(fragment) + + self._fragmented_message_waiter = asyncio.Future() + try: + # First fragment. + await self.write_frame(False, opcode, data) + + # Other fragments. + async for fragment in aiter_message: + confirm_opcode, data = prepare_data(fragment) + if confirm_opcode != opcode: + raise TypeError("data contains inconsistent types") + await self.write_frame(False, OP_CONT, data) + + # Final fragment. + await self.write_frame(True, OP_CONT, b"") + + except (Exception, asyncio.CancelledError): + # We're half-way through a fragmented message and we can't + # complete it. This makes the connection unusable. + self.fail_connection(CloseCode.INTERNAL_ERROR) + raise + + finally: + self._fragmented_message_waiter.set_result(None) + self._fragmented_message_waiter = None + + else: + raise TypeError("data must be str, bytes-like, or iterable") + + async def close( + self, + code: int = CloseCode.NORMAL_CLOSURE, + reason: str = "", + ) -> None: + """ + Perform the closing handshake. + + :meth:`close` waits for the other end to complete the handshake and + for the TCP connection to terminate. As a consequence, there's no need + to await :meth:`wait_closed` after :meth:`close`. + + :meth:`close` is idempotent: it doesn't do anything once the + connection is closed. + + Wrapping :func:`close` in :func:`~asyncio.create_task` is safe, given + that errors during connection termination aren't particularly useful. + + Canceling :meth:`close` is discouraged. If it takes too long, you can + set a shorter ``close_timeout``. If you don't want to wait, let the + Python process exit, then the OS will take care of closing the TCP + connection. + + Args: + code: WebSocket close code. + reason: WebSocket close reason. + + """ + try: + async with asyncio_timeout(self.close_timeout): + await self.write_close_frame(Close(code, reason)) + except asyncio.TimeoutError: + # If the close frame cannot be sent because the send buffers + # are full, the closing handshake won't complete anyway. + # Fail the connection to shut down faster. + self.fail_connection() + + # If no close frame is received within the timeout, asyncio_timeout() + # cancels the data transfer task and raises TimeoutError. + + # If close() is called multiple times concurrently and one of these + # calls hits the timeout, the data transfer task will be canceled. + # Other calls will receive a CancelledError here. + + try: + # If close() is canceled during the wait, self.transfer_data_task + # is canceled before the timeout elapses. + async with asyncio_timeout(self.close_timeout): + await self.transfer_data_task + except (asyncio.TimeoutError, asyncio.CancelledError): + pass + + # Wait for the close connection task to close the TCP connection. + await asyncio.shield(self.close_connection_task) + + async def wait_closed(self) -> None: + """ + Wait until the connection is closed. + + This coroutine is identical to the :attr:`closed` attribute, except it + can be awaited. + + This can make it easier to detect connection termination, regardless + of its cause, in tasks that interact with the WebSocket connection. + + """ + await asyncio.shield(self.connection_lost_waiter) + + async def ping(self, data: Optional[Data] = None) -> Awaitable[None]: + """ + Send a Ping_. + + .. _Ping: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.2 + + A ping may serve as a keepalive, as a check that the remote endpoint + received all messages up to this point, or to measure :attr:`latency`. + + Canceling :meth:`ping` is discouraged. If :meth:`ping` doesn't return + immediately, it means the write buffer is full. If you don't want to + wait, you should close the connection. + + Canceling the :class:`~asyncio.Future` returned by :meth:`ping` has no + effect. + + Args: + data (Optional[Data]): payload of the ping; a string will be + encoded to UTF-8; or :obj:`None` to generate a payload + containing four random bytes. + + Returns: + ~asyncio.Future[float]: A future that will be completed when the + corresponding pong is received. You can ignore it if you don't + intend to wait. The result of the future is the latency of the + connection in seconds. + + :: + + pong_waiter = await ws.ping() + # only if you want to wait for the corresponding pong + latency = await pong_waiter + + Raises: + ConnectionClosed: When the connection is closed. + RuntimeError: If another ping was sent with the same data and + the corresponding pong wasn't received yet. + + """ + await self.ensure_open() + + if data is not None: + data = prepare_ctrl(data) + + # Protect against duplicates if a payload is explicitly set. + if data in self.pings: + raise RuntimeError("already waiting for a pong with the same data") + + # Generate a unique random payload otherwise. + while data is None or data in self.pings: + data = struct.pack("!I", random.getrandbits(32)) + + pong_waiter = self.loop.create_future() + # Resolution of time.monotonic() may be too low on Windows. + ping_timestamp = time.perf_counter() + self.pings[data] = (pong_waiter, ping_timestamp) + + await self.write_frame(True, OP_PING, data) + + return asyncio.shield(pong_waiter) + + async def pong(self, data: Data = b"") -> None: + """ + Send a Pong_. + + .. _Pong: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.5.3 + + An unsolicited pong may serve as a unidirectional heartbeat. + + Canceling :meth:`pong` is discouraged. If :meth:`pong` doesn't return + immediately, it means the write buffer is full. If you don't want to + wait, you should close the connection. + + Args: + data (Data): Payload of the pong. A string will be encoded to + UTF-8. + + Raises: + ConnectionClosed: When the connection is closed. + + """ + await self.ensure_open() + + data = prepare_ctrl(data) + + await self.write_frame(True, OP_PONG, data) + + # Private methods - no guarantees. + + def connection_closed_exc(self) -> ConnectionClosed: + exc: ConnectionClosed + if ( + self.close_rcvd is not None + and self.close_rcvd.code in OK_CLOSE_CODES + and self.close_sent is not None + and self.close_sent.code in OK_CLOSE_CODES + ): + exc = ConnectionClosedOK( + self.close_rcvd, + self.close_sent, + self.close_rcvd_then_sent, + ) + else: + exc = ConnectionClosedError( + self.close_rcvd, + self.close_sent, + self.close_rcvd_then_sent, + ) + # Chain to the exception that terminated data transfer, if any. + exc.__cause__ = self.transfer_data_exc + return exc + + async def ensure_open(self) -> None: + """ + Check that the WebSocket connection is open. + + Raise :exc:`~websockets.exceptions.ConnectionClosed` if it isn't. + + """ + # Handle cases from most common to least common for performance. + if self.state is State.OPEN: + # If self.transfer_data_task exited without a closing handshake, + # self.close_connection_task may be closing the connection, going + # straight from OPEN to CLOSED. + if self.transfer_data_task.done(): + await asyncio.shield(self.close_connection_task) + raise self.connection_closed_exc() + else: + return + + if self.state is State.CLOSED: + raise self.connection_closed_exc() + + if self.state is State.CLOSING: + # If we started the closing handshake, wait for its completion to + # get the proper close code and reason. self.close_connection_task + # will complete within 4 or 5 * close_timeout after close(). The + # CLOSING state also occurs when failing the connection. In that + # case self.close_connection_task will complete even faster. + await asyncio.shield(self.close_connection_task) + raise self.connection_closed_exc() + + # Control may only reach this point in buggy third-party subclasses. + assert self.state is State.CONNECTING + raise InvalidState("WebSocket connection isn't established yet") + + async def transfer_data(self) -> None: + """ + Read incoming messages and put them in a queue. + + This coroutine runs in a task until the closing handshake is started. + + """ + try: + while True: + message = await self.read_message() + + # Exit the loop when receiving a close frame. + if message is None: + break + + # Wait until there's room in the queue (if necessary). + if self.max_queue is not None: + while len(self.messages) >= self.max_queue: + self._put_message_waiter = self.loop.create_future() + try: + await asyncio.shield(self._put_message_waiter) + finally: + self._put_message_waiter = None + + # Put the message in the queue. + self.messages.append(message) + + # Notify recv(). + if self._pop_message_waiter is not None: + self._pop_message_waiter.set_result(None) + self._pop_message_waiter = None + + except asyncio.CancelledError as exc: + self.transfer_data_exc = exc + # If fail_connection() cancels this task, avoid logging the error + # twice and failing the connection again. + raise + + except ProtocolError as exc: + self.transfer_data_exc = exc + self.fail_connection(CloseCode.PROTOCOL_ERROR) + + except (ConnectionError, TimeoutError, EOFError, ssl.SSLError) as exc: + # Reading data with self.reader.readexactly may raise: + # - most subclasses of ConnectionError if the TCP connection + # breaks, is reset, or is aborted; + # - TimeoutError if the TCP connection times out; + # - IncompleteReadError, a subclass of EOFError, if fewer + # bytes are available than requested; + # - ssl.SSLError if the other side infringes the TLS protocol. + self.transfer_data_exc = exc + self.fail_connection(CloseCode.ABNORMAL_CLOSURE) + + except UnicodeDecodeError as exc: + self.transfer_data_exc = exc + self.fail_connection(CloseCode.INVALID_DATA) + + except PayloadTooBig as exc: + self.transfer_data_exc = exc + self.fail_connection(CloseCode.MESSAGE_TOO_BIG) + + except Exception as exc: + # This shouldn't happen often because exceptions expected under + # regular circumstances are handled above. If it does, consider + # catching and handling more exceptions. + self.logger.error("data transfer failed", exc_info=True) + + self.transfer_data_exc = exc + self.fail_connection(CloseCode.INTERNAL_ERROR) + + async def read_message(self) -> Optional[Data]: + """ + Read a single message from the connection. + + Re-assemble data frames if the message is fragmented. + + Return :obj:`None` when the closing handshake is started. + + """ + frame = await self.read_data_frame(max_size=self.max_size) + + # A close frame was received. + if frame is None: + return None + + if frame.opcode == OP_TEXT: + text = True + elif frame.opcode == OP_BINARY: + text = False + else: # frame.opcode == OP_CONT + raise ProtocolError("unexpected opcode") + + # Shortcut for the common case - no fragmentation + if frame.fin: + return frame.data.decode("utf-8") if text else frame.data + + # 5.4. Fragmentation + fragments: List[Data] = [] + max_size = self.max_size + if text: + decoder_factory = codecs.getincrementaldecoder("utf-8") + decoder = decoder_factory(errors="strict") + if max_size is None: + + def append(frame: Frame) -> None: + nonlocal fragments + fragments.append(decoder.decode(frame.data, frame.fin)) + + else: + + def append(frame: Frame) -> None: + nonlocal fragments, max_size + fragments.append(decoder.decode(frame.data, frame.fin)) + assert isinstance(max_size, int) + max_size -= len(frame.data) + + else: + if max_size is None: + + def append(frame: Frame) -> None: + nonlocal fragments + fragments.append(frame.data) + + else: + + def append(frame: Frame) -> None: + nonlocal fragments, max_size + fragments.append(frame.data) + assert isinstance(max_size, int) + max_size -= len(frame.data) + + append(frame) + + while not frame.fin: + frame = await self.read_data_frame(max_size=max_size) + if frame is None: + raise ProtocolError("incomplete fragmented message") + if frame.opcode != OP_CONT: + raise ProtocolError("unexpected opcode") + append(frame) + + return ("" if text else b"").join(fragments) + + async def read_data_frame(self, max_size: Optional[int]) -> Optional[Frame]: + """ + Read a single data frame from the connection. + + Process control frames received before the next data frame. + + Return :obj:`None` if a close frame is encountered before any data frame. + + """ + # 6.2. Receiving Data + while True: + frame = await self.read_frame(max_size) + + # 5.5. Control Frames + if frame.opcode == OP_CLOSE: + # 7.1.5. The WebSocket Connection Close Code + # 7.1.6. The WebSocket Connection Close Reason + self.close_rcvd = Close.parse(frame.data) + if self.close_sent is not None: + self.close_rcvd_then_sent = False + try: + # Echo the original data instead of re-serializing it with + # Close.serialize() because that fails when the close frame + # is empty and Close.parse() synthesizes a 1005 close code. + await self.write_close_frame(self.close_rcvd, frame.data) + except ConnectionClosed: + # Connection closed before we could echo the close frame. + pass + return None + + elif frame.opcode == OP_PING: + # Answer pings, unless connection is CLOSING. + if self.state is State.OPEN: + try: + await self.pong(frame.data) + except ConnectionClosed: + # Connection closed while draining write buffer. + pass + + elif frame.opcode == OP_PONG: + if frame.data in self.pings: + pong_timestamp = time.perf_counter() + # Sending a pong for only the most recent ping is legal. + # Acknowledge all previous pings too in that case. + ping_id = None + ping_ids = [] + for ping_id, (pong_waiter, ping_timestamp) in self.pings.items(): + ping_ids.append(ping_id) + if not pong_waiter.done(): + pong_waiter.set_result(pong_timestamp - ping_timestamp) + if ping_id == frame.data: + self.latency = pong_timestamp - ping_timestamp + break + else: + raise AssertionError("solicited pong not found in pings") + # Remove acknowledged pings from self.pings. + for ping_id in ping_ids: + del self.pings[ping_id] + + # 5.6. Data Frames + else: + return frame + + async def read_frame(self, max_size: Optional[int]) -> Frame: + """ + Read a single frame from the connection. + + """ + frame = await Frame.read( + self.reader.readexactly, + mask=not self.is_client, + max_size=max_size, + extensions=self.extensions, + ) + if self.debug: + self.logger.debug("< %s", frame) + return frame + + def write_frame_sync(self, fin: bool, opcode: int, data: bytes) -> None: + frame = Frame(fin, Opcode(opcode), data) + if self.debug: + self.logger.debug("> %s", frame) + frame.write( + self.transport.write, + mask=self.is_client, + extensions=self.extensions, + ) + + async def drain(self) -> None: + try: + # drain() cannot be called concurrently by multiple coroutines: + # http://bugs.python.org/issue29930. Remove this lock when no + # version of Python where this bugs exists is supported anymore. + async with self._drain_lock: + # Handle flow control automatically. + await self._drain() + except ConnectionError: + # Terminate the connection if the socket died. + self.fail_connection() + # Wait until the connection is closed to raise ConnectionClosed + # with the correct code and reason. + await self.ensure_open() + + async def write_frame( + self, fin: bool, opcode: int, data: bytes, *, _state: int = State.OPEN + ) -> None: + # Defensive assertion for protocol compliance. + if self.state is not _state: # pragma: no cover + raise InvalidState( + f"Cannot write to a WebSocket in the {self.state.name} state" + ) + self.write_frame_sync(fin, opcode, data) + await self.drain() + + async def write_close_frame( + self, close: Close, data: Optional[bytes] = None + ) -> None: + """ + Write a close frame if and only if the connection state is OPEN. + + This dedicated coroutine must be used for writing close frames to + ensure that at most one close frame is sent on a given connection. + + """ + # Test and set the connection state before sending the close frame to + # avoid sending two frames in case of concurrent calls. + if self.state is State.OPEN: + # 7.1.3. The WebSocket Closing Handshake is Started + self.state = State.CLOSING + if self.debug: + self.logger.debug("= connection is CLOSING") + + self.close_sent = close + if self.close_rcvd is not None: + self.close_rcvd_then_sent = True + if data is None: + data = close.serialize() + + # 7.1.2. Start the WebSocket Closing Handshake + await self.write_frame(True, OP_CLOSE, data, _state=State.CLOSING) + + async def keepalive_ping(self) -> None: + """ + Send a Ping frame and wait for a Pong frame at regular intervals. + + This coroutine exits when the connection terminates and one of the + following happens: + + - :meth:`ping` raises :exc:`ConnectionClosed`, or + - :meth:`close_connection` cancels :attr:`keepalive_ping_task`. + + """ + if self.ping_interval is None: + return + + try: + while True: + await asyncio.sleep(self.ping_interval) + + # ping() raises CancelledError if the connection is closed, + # when close_connection() cancels self.keepalive_ping_task. + + # ping() raises ConnectionClosed if the connection is lost, + # when connection_lost() calls abort_pings(). + + self.logger.debug("% sending keepalive ping") + pong_waiter = await self.ping() + + if self.ping_timeout is not None: + try: + async with asyncio_timeout(self.ping_timeout): + await pong_waiter + self.logger.debug("% received keepalive pong") + except asyncio.TimeoutError: + if self.debug: + self.logger.debug("! timed out waiting for keepalive pong") + self.fail_connection( + CloseCode.INTERNAL_ERROR, + "keepalive ping timeout", + ) + break + + except ConnectionClosed: + pass + + except Exception: + self.logger.error("keepalive ping failed", exc_info=True) + + async def close_connection(self) -> None: + """ + 7.1.1. Close the WebSocket Connection + + When the opening handshake succeeds, :meth:`connection_open` starts + this coroutine in a task. It waits for the data transfer phase to + complete then it closes the TCP connection cleanly. + + When the opening handshake fails, :meth:`fail_connection` does the + same. There's no data transfer phase in that case. + + """ + try: + # Wait for the data transfer phase to complete. + if hasattr(self, "transfer_data_task"): + try: + await self.transfer_data_task + except asyncio.CancelledError: + pass + + # Cancel the keepalive ping task. + if hasattr(self, "keepalive_ping_task"): + self.keepalive_ping_task.cancel() + + # A client should wait for a TCP close from the server. + if self.is_client and hasattr(self, "transfer_data_task"): + if await self.wait_for_connection_lost(): + return + if self.debug: + self.logger.debug("! timed out waiting for TCP close") + + # Half-close the TCP connection if possible (when there's no TLS). + if self.transport.can_write_eof(): + if self.debug: + self.logger.debug("x half-closing TCP connection") + # write_eof() doesn't document which exceptions it raises. + # "[Errno 107] Transport endpoint is not connected" happens + # but it isn't completely clear under which circumstances. + # uvloop can raise RuntimeError here. + try: + self.transport.write_eof() + except (OSError, RuntimeError): # pragma: no cover + pass + + if await self.wait_for_connection_lost(): + return + if self.debug: + self.logger.debug("! timed out waiting for TCP close") + + finally: + # The try/finally ensures that the transport never remains open, + # even if this coroutine is canceled (for example). + await self.close_transport() + + async def close_transport(self) -> None: + """ + Close the TCP connection. + + """ + # If connection_lost() was called, the TCP connection is closed. + # However, if TLS is enabled, the transport still needs closing. + # Else asyncio complains: ResourceWarning: unclosed transport. + if self.connection_lost_waiter.done() and self.transport.is_closing(): + return + + # Close the TCP connection. Buffers are flushed asynchronously. + if self.debug: + self.logger.debug("x closing TCP connection") + self.transport.close() + + if await self.wait_for_connection_lost(): + return + if self.debug: + self.logger.debug("! timed out waiting for TCP close") + + # Abort the TCP connection. Buffers are discarded. + if self.debug: + self.logger.debug("x aborting TCP connection") + # Due to a bug in coverage, this is erroneously reported as not covered. + self.transport.abort() # pragma: no cover + + # connection_lost() is called quickly after aborting. + await self.wait_for_connection_lost() + + async def wait_for_connection_lost(self) -> bool: + """ + Wait until the TCP connection is closed or ``self.close_timeout`` elapses. + + Return :obj:`True` if the connection is closed and :obj:`False` + otherwise. + + """ + if not self.connection_lost_waiter.done(): + try: + async with asyncio_timeout(self.close_timeout): + await asyncio.shield(self.connection_lost_waiter) + except asyncio.TimeoutError: + pass + # Re-check self.connection_lost_waiter.done() synchronously because + # connection_lost() could run between the moment the timeout occurs + # and the moment this coroutine resumes running. + return self.connection_lost_waiter.done() + + def fail_connection( + self, + code: int = CloseCode.ABNORMAL_CLOSURE, + reason: str = "", + ) -> None: + """ + 7.1.7. Fail the WebSocket Connection + + This requires: + + 1. Stopping all processing of incoming data, which means cancelling + :attr:`transfer_data_task`. The close code will be 1006 unless a + close frame was received earlier. + + 2. Sending a close frame with an appropriate code if the opening + handshake succeeded and the other side is likely to process it. + + 3. Closing the connection. :meth:`close_connection` takes care of + this once :attr:`transfer_data_task` exits after being canceled. + + (The specification describes these steps in the opposite order.) + + """ + if self.debug: + self.logger.debug("! failing connection with code %d", code) + + # Cancel transfer_data_task if the opening handshake succeeded. + # cancel() is idempotent and ignored if the task is done already. + if hasattr(self, "transfer_data_task"): + self.transfer_data_task.cancel() + + # Send a close frame when the state is OPEN (a close frame was already + # sent if it's CLOSING), except when failing the connection because of + # an error reading from or writing to the network. + # Don't send a close frame if the connection is broken. + if code != CloseCode.ABNORMAL_CLOSURE and self.state is State.OPEN: + close = Close(code, reason) + + # Write the close frame without draining the write buffer. + + # Keeping fail_connection() synchronous guarantees it can't + # get stuck and simplifies the implementation of the callers. + # Not drainig the write buffer is acceptable in this context. + + # This duplicates a few lines of code from write_close_frame(). + + self.state = State.CLOSING + if self.debug: + self.logger.debug("= connection is CLOSING") + + # If self.close_rcvd was set, the connection state would be + # CLOSING. Therefore self.close_rcvd isn't set and we don't + # have to set self.close_rcvd_then_sent. + assert self.close_rcvd is None + self.close_sent = close + + self.write_frame_sync(True, OP_CLOSE, close.serialize()) + + # Start close_connection_task if the opening handshake didn't succeed. + if not hasattr(self, "close_connection_task"): + self.close_connection_task = self.loop.create_task(self.close_connection()) + + def abort_pings(self) -> None: + """ + Raise ConnectionClosed in pending keepalive pings. + + They'll never receive a pong once the connection is closed. + + """ + assert self.state is State.CLOSED + exc = self.connection_closed_exc() + + for pong_waiter, _ping_timestamp in self.pings.values(): + pong_waiter.set_exception(exc) + # If the exception is never retrieved, it will be logged when ping + # is garbage-collected. This is confusing for users. + # Given that ping is done (with an exception), canceling it does + # nothing, but it prevents logging the exception. + pong_waiter.cancel() + + # asyncio.Protocol methods + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + """ + Configure write buffer limits. + + The high-water limit is defined by ``self.write_limit``. + + The low-water limit currently defaults to ``self.write_limit // 4`` in + :meth:`~asyncio.WriteTransport.set_write_buffer_limits`, which should + be all right for reasonable use cases of this library. + + This is the earliest point where we can get hold of the transport, + which means it's the best point for configuring it. + + """ + transport = cast(asyncio.Transport, transport) + transport.set_write_buffer_limits(self.write_limit) + self.transport = transport + + # Copied from asyncio.StreamReaderProtocol + self.reader.set_transport(transport) + + def connection_lost(self, exc: Optional[Exception]) -> None: + """ + 7.1.4. The WebSocket Connection is Closed. + + """ + self.state = State.CLOSED + self.logger.debug("= connection is CLOSED") + + self.abort_pings() + + # If self.connection_lost_waiter isn't pending, that's a bug, because: + # - it's set only here in connection_lost() which is called only once; + # - it must never be canceled. + self.connection_lost_waiter.set_result(None) + + if True: # pragma: no cover + # Copied from asyncio.StreamReaderProtocol + if self.reader is not None: + if exc is None: + self.reader.feed_eof() + else: + self.reader.set_exception(exc) + + # Copied from asyncio.FlowControlMixin + # Wake up the writer if currently paused. + if not self._paused: + return + waiter = self._drain_waiter + if waiter is None: + return + self._drain_waiter = None + if waiter.done(): + return + if exc is None: + waiter.set_result(None) + else: + waiter.set_exception(exc) + + def pause_writing(self) -> None: # pragma: no cover + assert not self._paused + self._paused = True + + def resume_writing(self) -> None: # pragma: no cover + assert self._paused + self._paused = False + + waiter = self._drain_waiter + if waiter is not None: + self._drain_waiter = None + if not waiter.done(): + waiter.set_result(None) + + def data_received(self, data: bytes) -> None: + self.reader.feed_data(data) + + def eof_received(self) -> None: + """ + Close the transport after receiving EOF. + + The WebSocket protocol has its own closing handshake: endpoints close + the TCP or TLS connection after sending and receiving a close frame. + + As a consequence, they never need to write after receiving EOF, so + there's no reason to keep the transport open by returning :obj:`True`. + + Besides, that doesn't work on TLS connections. + + """ + self.reader.feed_eof() + + +def broadcast( + websockets: Iterable[WebSocketCommonProtocol], + message: Data, + raise_exceptions: bool = False, +) -> None: + """ + Broadcast a message to several WebSocket connections. + + A string (:class:`str`) is sent as a Text_ frame. A bytestring or bytes-like + object (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) is sent + as a Binary_ frame. + + .. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + .. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + + :func:`broadcast` pushes the message synchronously to all connections even + if their write buffers are overflowing. There's no backpressure. + + If you broadcast messages faster than a connection can handle them, messages + will pile up in its write buffer until the connection times out. Keep + ``ping_interval`` and ``ping_timeout`` low to prevent excessive memory usage + from slow connections. + + Unlike :meth:`~websockets.server.WebSocketServerProtocol.send`, + :func:`broadcast` doesn't support sending fragmented messages. Indeed, + fragmentation is useful for sending large messages without buffering them in + memory, while :func:`broadcast` buffers one copy per connection as fast as + possible. + + :func:`broadcast` skips connections that aren't open in order to avoid + errors on connections where the closing handshake is in progress. + + :func:`broadcast` ignores failures to write the message on some connections. + It continues writing to other connections. On Python 3.11 and above, you + may set ``raise_exceptions`` to :obj:`True` to record failures and raise all + exceptions in a :pep:`654` :exc:`ExceptionGroup`. + + Args: + websockets: WebSocket connections to which the message will be sent. + message: Message to send. + raise_exceptions: Whether to raise an exception in case of failures. + + Raises: + TypeError: If ``message`` doesn't have a supported type. + + """ + if not isinstance(message, (str, bytes, bytearray, memoryview)): + raise TypeError("data must be str or bytes-like") + + if raise_exceptions: + if sys.version_info[:2] < (3, 11): # pragma: no cover + raise ValueError("raise_exceptions requires at least Python 3.11") + exceptions = [] + + opcode, data = prepare_data(message) + + for websocket in websockets: + if websocket.state is not State.OPEN: + continue + + if websocket._fragmented_message_waiter is not None: + if raise_exceptions: + exception = RuntimeError("sending a fragmented message") + exceptions.append(exception) + else: + websocket.logger.warning( + "skipped broadcast: sending a fragmented message", + ) + + try: + websocket.write_frame_sync(True, opcode, data) + except Exception as write_exception: + if raise_exceptions: + exception = RuntimeError("failed to write message") + exception.__cause__ = write_exception + exceptions.append(exception) + else: + websocket.logger.warning( + "skipped broadcast: failed to write message", + exc_info=True, + ) + + if raise_exceptions: + raise ExceptionGroup("skipped broadcast", exceptions) diff --git a/venv/lib/python3.11/site-packages/websockets/legacy/server.py b/venv/lib/python3.11/site-packages/websockets/legacy/server.py new file mode 100644 index 0000000..7c24dd7 --- /dev/null +++ b/venv/lib/python3.11/site-packages/websockets/legacy/server.py @@ -0,0 +1,1185 @@ +from __future__ import annotations + +import asyncio +import email.utils +import functools +import http +import inspect +import logging +import socket +import warnings +from types import TracebackType +from typing import ( + Any, + Awaitable, + Callable, + Generator, + Iterable, + List, + Optional, + Sequence, + Set, + Tuple, + Type, + Union, + cast, +) + +from ..datastructures import Headers, HeadersLike, MultipleValuesError +from ..exceptions import ( + AbortHandshake, + InvalidHandshake, + InvalidHeader, + InvalidMessage, + InvalidOrigin, + InvalidUpgrade, + NegotiationError, +) +from ..extensions import Extension, ServerExtensionFactory +from ..extensions.permessage_deflate import enable_server_permessage_deflate +from ..headers import ( + build_extension, + parse_extension, + parse_subprotocol, + validate_subprotocols, +) +from ..http import USER_AGENT +from ..protocol import State +from ..typing import ExtensionHeader, LoggerLike, Origin, StatusLike, Subprotocol +from .compatibility import asyncio_timeout +from .handshake import build_response, check_request +from .http import read_request +from .protocol import WebSocketCommonProtocol + + +__all__ = ["serve", "unix_serve", "WebSocketServerProtocol", "WebSocketServer"] + + +HeadersLikeOrCallable = Union[HeadersLike, Callable[[str, Headers], HeadersLike]] + +HTTPResponse = Tuple[StatusLike, HeadersLike, bytes] + + +class WebSocketServerProtocol(WebSocketCommonProtocol): + """ + WebSocket server connection. + + :class:`WebSocketServerProtocol` provides :meth:`recv` and :meth:`send` + coroutines for receiving and sending messages. + + It supports asynchronous iteration to receive messages:: + + async for message in websocket: + await process(message) + + The iterator exits normally when the connection is closed with close code + 1000 (OK) or 1001 (going away) or without a close code. It raises + a :exc:`~websockets.exceptions.ConnectionClosedError` when the connection + is closed with any other code. + + You may customize the opening handshake in a subclass by + overriding :meth:`process_request` or :meth:`select_subprotocol`. + + Args: + ws_server: WebSocket server that created this connection. + + See :func:`serve` for the documentation of ``ws_handler``, ``logger``, ``origins``, + ``extensions``, ``subprotocols``, ``extra_headers``, and ``server_header``. + + See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the + documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``, + ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``. + + """ + + is_client = False + side = "server" + + def __init__( + self, + ws_handler: Union[ + Callable[[WebSocketServerProtocol], Awaitable[Any]], + Callable[[WebSocketServerProtocol, str], Awaitable[Any]], # deprecated + ], + ws_server: WebSocketServer, + *, + logger: Optional[LoggerLike] = None, + origins: Optional[Sequence[Optional[Origin]]] = None, + extensions: Optional[Sequence[ServerExtensionFactory]] = None, + subprotocols: Optional[Sequence[Subprotocol]] = None, + extra_headers: Optional[HeadersLikeOrCallable] = None, + server_header: Optional[str] = USER_AGENT, + process_request: Optional[ + Callable[[str, Headers], Awaitable[Optional[HTTPResponse]]] + ] = None, + select_subprotocol: Optional[ + Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] + ] = None, + open_timeout: Optional[float] = 10, + **kwargs: Any, + ) -> None: + if logger is None: + logger = logging.getLogger("websockets.server") + super().__init__(logger=logger, **kwargs) + # For backwards compatibility with 6.0 or earlier. + if origins is not None and "" in origins: + warnings.warn("use None instead of '' in origins", DeprecationWarning) + origins = [None if origin == "" else origin for origin in origins] + # For backwards compatibility with 10.0 or earlier. Done here in + # addition to serve to trigger the deprecation warning on direct + # use of WebSocketServerProtocol. + self.ws_handler = remove_path_argument(ws_handler) + self.ws_server = ws_server + self.origins = origins + self.available_extensions = extensions + self.available_subprotocols = subprotocols + self.extra_headers = extra_headers + self.server_header = server_header + self._process_request = process_request + self._select_subprotocol = select_subprotocol + self.open_timeout = open_timeout + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + """ + Register connection and initialize a task to handle it. + + """ + super().connection_made(transport) + # Register the connection with the server before creating the handler + # task. Registering at the beginning of the handler coroutine would + # create a race condition between the creation of the task, which + # schedules its execution, and the moment the handler starts running. + self.ws_server.register(self) + self.handler_task = self.loop.create_task(self.handler()) + + async def handler(self) -> None: + """ + Handle the lifecycle of a WebSocket connection. + + Since this method doesn't have a caller able to handle exceptions, it + attempts to log relevant ones and guarantees that the TCP connection is + closed before exiting. + + """ + try: + try: + async with asyncio_timeout(self.open_timeout): + await self.handshake( + origins=self.origins, + available_extensions=self.available_extensions, + available_subprotocols=self.available_subprotocols, + extra_headers=self.extra_headers, + ) + except asyncio.TimeoutError: # pragma: no cover + raise + except ConnectionError: + raise + except Exception as exc: + if isinstance(exc, AbortHandshake): + status, headers, body = exc.status, exc.headers, exc.body + elif isinstance(exc, InvalidOrigin): + if self.debug: + self.logger.debug("! invalid origin", exc_info=True) + status, headers, body = ( + http.HTTPStatus.FORBIDDEN, + Headers(), + f"Failed to open a WebSocket connection: {exc}.\n".encode(), + ) + elif isinstance(exc, InvalidUpgrade): + if self.debug: + self.logger.debug("! invalid upgrade", exc_info=True) + status, headers, body = ( + http.HTTPStatus.UPGRADE_REQUIRED, + Headers([("Upgrade", "websocket")]), + ( + f"Failed to open a WebSocket connection: {exc}.\n" + f"\n" + f"You cannot access a WebSocket server directly " + f"with a browser. You need a WebSocket client.\n" + ).encode(), + ) + elif isinstance(exc, InvalidHandshake): + if self.debug: + self.logger.debug("! invalid handshake", exc_info=True) + status, headers, body = ( + http.HTTPStatus.BAD_REQUEST, + Headers(), + f"Failed to open a WebSocket connection: {exc}.\n".encode(), + ) + else: + self.logger.error("opening handshake failed", exc_info=True) + status, headers, body = ( + http.HTTPStatus.INTERNAL_SERVER_ERROR, + Headers(), + ( + b"Failed to open a WebSocket connection.\n" + b"See server log for more information.\n" + ), + ) + + headers.setdefault("Date", email.utils.formatdate(usegmt=True)) + if self.server_header is not None: + headers.setdefault("Server", self.server_header) + + headers.setdefault("Content-Length", str(len(body))) + headers.setdefault("Content-Type", "text/plain") + headers.setdefault("Connection", "close") + + self.write_http_response(status, headers, body) + self.logger.info( + "connection rejected (%d %s)", status.value, status.phrase + ) + await self.close_transport() + return + + try: + await self.ws_handler(self) + except Exception: + self.logger.error("connection handler failed", exc_info=True) + if not self.closed: + self.fail_connection(1011) + raise + + try: + await self.close() + except ConnectionError: + raise + except Exception: + self.logger.error("closing handshake failed", exc_info=True) + raise + + except Exception: + # Last-ditch attempt to avoid leaking connections on errors. + try: + self.transport.close() + except Exception: # pragma: no cover + pass + + finally: + # Unregister the connection with the server when the handler task + # terminates. Registration is tied to the lifecycle of the handler + # task because the server waits for tasks attached to registered + # connections before terminating. + self.ws_server.unregister(self) + self.logger.info("connection closed") + + async def read_http_request(self) -> Tuple[str, Headers]: + """ + Read request line and headers from the HTTP request. + + If the request contains a body, it may be read from ``self.reader`` + after this coroutine returns. + + Raises: + InvalidMessage: if the HTTP message is malformed or isn't an + HTTP/1.1 GET request. + + """ + try: + path, headers = await read_request(self.reader) + except asyncio.CancelledError: # pragma: no cover + raise + except Exception as exc: + raise InvalidMessage("did not receive a valid HTTP request") from exc + + if self.debug: + self.logger.debug("< GET %s HTTP/1.1", path) + for key, value in headers.raw_items(): + self.logger.debug("< %s: %s", key, value) + + self.path = path + self.request_headers = headers + + return path, headers + + def write_http_response( + self, status: http.HTTPStatus, headers: Headers, body: Optional[bytes] = None + ) -> None: + """ + Write status line and headers to the HTTP response. + + This coroutine is also able to write a response body. + + """ + self.response_headers = headers + + if self.debug: + self.logger.debug("> HTTP/1.1 %d %s", status.value, status.phrase) + for key, value in headers.raw_items(): + self.logger.debug("> %s: %s", key, value) + if body is not None: + self.logger.debug("> [body] (%d bytes)", len(body)) + + # Since the status line and headers only contain ASCII characters, + # we can keep this simple. + response = f"HTTP/1.1 {status.value} {status.phrase}\r\n" + response += str(headers) + + self.transport.write(response.encode()) + + if body is not None: + self.transport.write(body) + + async def process_request( + self, path: str, request_headers: Headers + ) -> Optional[HTTPResponse]: + """ + Intercept the HTTP request and return an HTTP response if appropriate. + + You may override this method in a :class:`WebSocketServerProtocol` + subclass, for example: + + * to return an HTTP 200 OK response on a given path; then a load + balancer can use this path for a health check; + * to authenticate the request and return an HTTP 401 Unauthorized or an + HTTP 403 Forbidden when authentication fails. + + You may also override this method with the ``process_request`` + argument of :func:`serve` and :class:`WebSocketServerProtocol`. This + is equivalent, except ``process_request`` won't have access to the + protocol instance, so it can't store information for later use. + + :meth:`process_request` is expected to complete quickly. If it may run + for a long time, then it should await :meth:`wait_closed` and exit if + :meth:`wait_closed` completes, or else it could prevent the server + from shutting down. + + Args: + path: request path, including optional query string. + request_headers: request headers. + + Returns: + Optional[Tuple[StatusLike, HeadersLike, bytes]]: :obj:`None` + to continue the WebSocket handshake normally. + + An HTTP response, represented by a 3-uple of the response status, + headers, and body, to abort the WebSocket handshake and return + that HTTP response instead. + + """ + if self._process_request is not None: + response = self._process_request(path, request_headers) + if isinstance(response, Awaitable): + return await response + else: + # For backwards compatibility with 7.0. + warnings.warn( + "declare process_request as a coroutine", DeprecationWarning + ) + return response + return None + + @staticmethod + def process_origin( + headers: Headers, origins: Optional[Sequence[Optional[Origin]]] = None + ) -> Optional[Origin]: + """ + Handle the Origin HTTP request header. + + Args: + headers: request headers. + origins: optional list of acceptable origins. + + Raises: + InvalidOrigin: if the origin isn't acceptable. + + """ + # "The user agent MUST NOT include more than one Origin header field" + # per https://www.rfc-editor.org/rfc/rfc6454.html#section-7.3. + try: + origin = cast(Optional[Origin], headers.get("Origin")) + except MultipleValuesError as exc: + raise InvalidHeader("Origin", "more than one Origin header found") from exc + if origins is not None: + if origin not in origins: + raise InvalidOrigin(origin) + return origin + + @staticmethod + def process_extensions( + headers: Headers, + available_extensions: Optional[Sequence[ServerExtensionFactory]], + ) -> Tuple[Optional[str], List[Extension]]: + """ + Handle the Sec-WebSocket-Extensions HTTP request header. + + Accept or reject each extension proposed in the client request. + Negotiate parameters for accepted extensions. + + Return the Sec-WebSocket-Extensions HTTP response header and the list + of accepted extensions. + + :rfc:`6455` leaves the rules up to the specification of each + :extension. + + To provide this level of flexibility, for each extension proposed by + the client, we check for a match with each extension available in the + server configuration. If no match is found, the extension is ignored. + + If several variants of the same extension are proposed by the client, + it may be accepted several times, which won't make sense in general. + Extensions must implement their own requirements. For this purpose, + the list of previously accepted extensions is provided. + + This process doesn't allow the server to reorder extensions. It can + only select a subset of the extensions proposed by the client. + + Other requirements, for example related to mandatory extensions or the + order of extensions, may be implemented by overriding this method. + + Args: + headers: request headers. + extensions: optional list of supported extensions. + + Raises: + InvalidHandshake: to abort the handshake with an HTTP 400 error. + + """ + response_header_value: Optional[str] = None + + extension_headers: List[ExtensionHeader] = [] + accepted_extensions: List[Extension] = [] + + header_values = headers.get_all("Sec-WebSocket-Extensions") + + if header_values and available_extensions: + parsed_header_values: List[ExtensionHeader] = sum( + [parse_extension(header_value) for header_value in header_values], [] + ) + + for name, request_params in parsed_header_values: + for ext_factory in available_extensions: + # Skip non-matching extensions based on their name. + if ext_factory.name != name: + continue + + # Skip non-matching extensions based on their params. + try: + response_params, extension = ext_factory.process_request_params( + request_params, accepted_extensions + ) + except NegotiationError: + continue + + # Add matching extension to the final list. + extension_headers.append((name, response_params)) + accepted_extensions.append(extension) + + # Break out of the loop once we have a match. + break + + # If we didn't break from the loop, no extension in our list + # matched what the client sent. The extension is declined. + + # Serialize extension header. + if extension_headers: + response_header_value = build_extension(extension_headers) + + return response_header_value, accepted_extensions + + # Not @staticmethod because it calls self.select_subprotocol() + def process_subprotocol( + self, headers: Headers, available_subprotocols: Optional[Sequence[Subprotocol]] + ) -> Optional[Subprotocol]: + """ + Handle the Sec-WebSocket-Protocol HTTP request header. + + Return Sec-WebSocket-Protocol HTTP response header, which is the same + as the selected subprotocol. + + Args: + headers: request headers. + available_subprotocols: optional list of supported subprotocols. + + Raises: + InvalidHandshake: to abort the handshake with an HTTP 400 error. + + """ + subprotocol: Optional[Subprotocol] = None + + header_values = headers.get_all("Sec-WebSocket-Protocol") + + if header_values and available_subprotocols: + parsed_header_values: List[Subprotocol] = sum( + [parse_subprotocol(header_value) for header_value in header_values], [] + ) + + subprotocol = self.select_subprotocol( + parsed_header_values, available_subprotocols + ) + + return subprotocol + + def select_subprotocol( + self, + client_subprotocols: Sequence[Subprotocol], + server_subprotocols: Sequence[Subprotocol], + ) -> Optional[Subprotocol]: + """ + Pick a subprotocol among those supported by the client and the server. + + If several subprotocols are available, select the preferred subprotocol + by giving equal weight to the preferences of the client and the server. + + If no subprotocol is available, proceed without a subprotocol. + + You may provide a ``select_subprotocol`` argument to :func:`serve` or + :class:`WebSocketServerProtocol` to override this logic. For example, + you could reject the handshake if the client doesn't support a + particular subprotocol, rather than accept the handshake without that + subprotocol. + + Args: + client_subprotocols: list of subprotocols offered by the client. + server_subprotocols: list of subprotocols available on the server. + + Returns: + Optional[Subprotocol]: Selected subprotocol, if a common subprotocol + was found. + + :obj:`None` to continue without a subprotocol. + + """ + if self._select_subprotocol is not None: + return self._select_subprotocol(client_subprotocols, server_subprotocols) + + subprotocols = set(client_subprotocols) & set(server_subprotocols) + if not subprotocols: + return None + return sorted( + subprotocols, + key=lambda p: client_subprotocols.index(p) + server_subprotocols.index(p), + )[0] + + async def handshake( + self, + origins: Optional[Sequence[Optional[Origin]]] = None, + available_extensions: Optional[Sequence[ServerExtensionFactory]] = None, + available_subprotocols: Optional[Sequence[Subprotocol]] = None, + extra_headers: Optional[HeadersLikeOrCallable] = None, + ) -> str: + """ + Perform the server side of the opening handshake. + + Args: + origins: list of acceptable values of the Origin HTTP header; + include :obj:`None` if the lack of an origin is acceptable. + extensions: list of supported extensions, in order in which they + should be tried. + subprotocols: list of supported subprotocols, in order of + decreasing preference. + extra_headers: arbitrary HTTP headers to add to the response when + the handshake succeeds. + + Returns: + str: path of the URI of the request. + + Raises: + InvalidHandshake: if the handshake fails. + + """ + path, request_headers = await self.read_http_request() + + # Hook for customizing request handling, for example checking + # authentication or treating some paths as plain HTTP endpoints. + early_response_awaitable = self.process_request(path, request_headers) + if isinstance(early_response_awaitable, Awaitable): + early_response = await early_response_awaitable + else: + # For backwards compatibility with 7.0. + warnings.warn("declare process_request as a coroutine", DeprecationWarning) + early_response = early_response_awaitable + + # The connection may drop while process_request is running. + if self.state is State.CLOSED: + # This subclass of ConnectionError is silently ignored in handler(). + raise BrokenPipeError("connection closed during opening handshake") + + # Change the response to a 503 error if the server is shutting down. + if not self.ws_server.is_serving(): + early_response = ( + http.HTTPStatus.SERVICE_UNAVAILABLE, + [], + b"Server is shutting down.\n", + ) + + if early_response is not None: + raise AbortHandshake(*early_response) + + key = check_request(request_headers) + + self.origin = self.process_origin(request_headers, origins) + + extensions_header, self.extensions = self.process_extensions( + request_headers, available_extensions + ) + + protocol_header = self.subprotocol = self.process_subprotocol( + request_headers, available_subprotocols + ) + + response_headers = Headers() + + build_response(response_headers, key) + + if extensions_header is not None: + response_headers["Sec-WebSocket-Extensions"] = extensions_header + + if protocol_header is not None: + response_headers["Sec-WebSocket-Protocol"] = protocol_header + + if callable(extra_headers): + extra_headers = extra_headers(path, self.request_headers) + if extra_headers is not None: + response_headers.update(extra_headers) + + response_headers.setdefault("Date", email.utils.formatdate(usegmt=True)) + if self.server_header is not None: + response_headers.setdefault("Server", self.server_header) + + self.write_http_response(http.HTTPStatus.SWITCHING_PROTOCOLS, response_headers) + + self.logger.info("connection open") + + self.connection_open() + + return path + + +class WebSocketServer: + """ + WebSocket server returned by :func:`serve`. + + This class provides the same interface as :class:`~asyncio.Server`, + notably the :meth:`~asyncio.Server.close` + and :meth:`~asyncio.Server.wait_closed` methods. + + It keeps track of WebSocket connections in order to close them properly + when shutting down. + + Args: + logger: Logger for this server. + It defaults to ``logging.getLogger("websockets.server")``. + See the :doc:`logging guide <../../topics/logging>` for details. + + """ + + def __init__(self, logger: Optional[LoggerLike] = None): + if logger is None: + logger = logging.getLogger("websockets.server") + self.logger = logger + + # Keep track of active connections. + self.websockets: Set[WebSocketServerProtocol] = set() + + # Task responsible for closing the server and terminating connections. + self.close_task: Optional[asyncio.Task[None]] = None + + # Completed when the server is closed and connections are terminated. + self.closed_waiter: asyncio.Future[None] + + def wrap(self, server: asyncio.base_events.Server) -> None: + """ + Attach to a given :class:`~asyncio.Server`. + + Since :meth:`~asyncio.loop.create_server` doesn't support injecting a + custom ``Server`` class, the easiest solution that doesn't rely on + private :mod:`asyncio` APIs is to: + + - instantiate a :class:`WebSocketServer` + - give the protocol factory a reference to that instance + - call :meth:`~asyncio.loop.create_server` with the factory + - attach the resulting :class:`~asyncio.Server` with this method + + """ + self.server = server + for sock in server.sockets: + if sock.family == socket.AF_INET: + name = "%s:%d" % sock.getsockname() + elif sock.family == socket.AF_INET6: + name = "[%s]:%d" % sock.getsockname()[:2] + elif sock.family == socket.AF_UNIX: + name = sock.getsockname() + # In the unlikely event that someone runs websockets over a + # protocol other than IP or Unix sockets, avoid crashing. + else: # pragma: no cover + name = str(sock.getsockname()) + self.logger.info("server listening on %s", name) + + # Initialized here because we need a reference to the event loop. + # This should be moved back to __init__ when dropping Python < 3.10. + self.closed_waiter = server.get_loop().create_future() + + def register(self, protocol: WebSocketServerProtocol) -> None: + """ + Register a connection with this server. + + """ + self.websockets.add(protocol) + + def unregister(self, protocol: WebSocketServerProtocol) -> None: + """ + Unregister a connection with this server. + + """ + self.websockets.remove(protocol) + + def close(self, close_connections: bool = True) -> None: + """ + Close the server. + + * Close the underlying :class:`~asyncio.Server`. + * When ``close_connections`` is :obj:`True`, which is the default, + close existing connections. Specifically: + + * Reject opening WebSocket connections with an HTTP 503 (service + unavailable) error. This happens when the server accepted the TCP + connection but didn't complete the opening handshake before closing. + * Close open WebSocket connections with close code 1001 (going away). + + * Wait until all connection handlers terminate. + + :meth:`close` is idempotent. + + """ + if self.close_task is None: + self.close_task = self.get_loop().create_task( + self._close(close_connections) + ) + + async def _close(self, close_connections: bool) -> None: + """ + Implementation of :meth:`close`. + + This calls :meth:`~asyncio.Server.close` on the underlying + :class:`~asyncio.Server` object to stop accepting new connections and + then closes open connections with close code 1001. + + """ + self.logger.info("server closing") + + # Stop accepting new connections. + self.server.close() + + # Wait until all accepted connections reach connection_made() and call + # register(). See https://bugs.python.org/issue34852 for details. + await asyncio.sleep(0) + + if close_connections: + # Close OPEN connections with close code 1001. After server.close(), + # handshake() closes OPENING connections with an HTTP 503 error. + close_tasks = [ + asyncio.create_task(websocket.close(1001)) + for websocket in self.websockets + if websocket.state is not State.CONNECTING + ] + # asyncio.wait doesn't accept an empty first argument. + if close_tasks: + await asyncio.wait(close_tasks) + + # Wait until all TCP connections are closed. + await self.server.wait_closed() + + # Wait until all connection handlers terminate. + # asyncio.wait doesn't accept an empty first argument. + if self.websockets: + await asyncio.wait( + [websocket.handler_task for websocket in self.websockets] + ) + + # Tell wait_closed() to return. + self.closed_waiter.set_result(None) + + self.logger.info("server closed") + + async def wait_closed(self) -> None: + """ + Wait until the server is closed. + + When :meth:`wait_closed` returns, all TCP connections are closed and + all connection handlers have returned. + + To ensure a fast shutdown, a connection handler should always be + awaiting at least one of: + + * :meth:`~WebSocketServerProtocol.recv`: when the connection is closed, + it raises :exc:`~websockets.exceptions.ConnectionClosedOK`; + * :meth:`~WebSocketServerProtocol.wait_closed`: when the connection is + closed, it returns. + + Then the connection handler is immediately notified of the shutdown; + it can clean up and exit. + + """ + await asyncio.shield(self.closed_waiter) + + def get_loop(self) -> asyncio.AbstractEventLoop: + """ + See :meth:`asyncio.Server.get_loop`. + + """ + return self.server.get_loop() + + def is_serving(self) -> bool: + """ + See :meth:`asyncio.Server.is_serving`. + + """ + return self.server.is_serving() + + async def start_serving(self) -> None: # pragma: no cover + """ + See :meth:`asyncio.Server.start_serving`. + + Typical use:: + + server = await serve(..., start_serving=False) + # perform additional setup here... + # ... then start the server + await server.start_serving() + + """ + await self.server.start_serving() + + async def serve_forever(self) -> None: # pragma: no cover + """ + See :meth:`asyncio.Server.serve_forever`. + + Typical use:: + + server = await serve(...) + # this coroutine doesn't return + # canceling it stops the server + await server.serve_forever() + + This is an alternative to using :func:`serve` as an asynchronous context + manager. Shutdown is triggered by canceling :meth:`serve_forever` + instead of exiting a :func:`serve` context. + + """ + await self.server.serve_forever() + + @property + def sockets(self) -> Iterable[socket.socket]: + """ + See :attr:`asyncio.Server.sockets`. + + """ + return self.server.sockets + + async def __aenter__(self) -> WebSocketServer: # pragma: no cover + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: # pragma: no cover + self.close() + await self.wait_closed() + + +class Serve: + """ + Start a WebSocket server listening on ``host`` and ``port``. + + Whenever a client connects, the server creates a + :class:`WebSocketServerProtocol`, performs the opening handshake, and + delegates to the connection handler, ``ws_handler``. + + The handler receives the :class:`WebSocketServerProtocol` and uses it to + send and receive messages. + + Once the handler completes, either normally or with an exception, the + server performs the closing handshake and closes the connection. + + Awaiting :func:`serve` yields a :class:`WebSocketServer`. This object + provides a :meth:`~WebSocketServer.close` method to shut down the server:: + + stop = asyncio.Future() # set this future to exit the server + + server = await serve(...) + await stop + await server.close() + + :func:`serve` can be used as an asynchronous context manager. Then, the + server is shut down automatically when exiting the context:: + + stop = asyncio.Future() # set this future to exit the server + + async with serve(...): + await stop + + Args: + ws_handler: Connection handler. It receives the WebSocket connection, + which is a :class:`WebSocketServerProtocol`, in argument. + host: Network interfaces the server binds to. + See :meth:`~asyncio.loop.create_server` for details. + port: TCP port the server listens on. + See :meth:`~asyncio.loop.create_server` for details. + create_protocol: Factory for the :class:`asyncio.Protocol` managing + the connection. It defaults to :class:`WebSocketServerProtocol`. + Set it to a wrapper or a subclass to customize connection handling. + logger: Logger for this server. + It defaults to ``logging.getLogger("websockets.server")``. + See the :doc:`logging guide <../../topics/logging>` for details. + compression: The "permessage-deflate" extension is enabled by default. + Set ``compression`` to :obj:`None` to disable it. See the + :doc:`compression guide <../../topics/compression>` for details. + origins: Acceptable values of the ``Origin`` header, for defending + against Cross-Site WebSocket Hijacking attacks. Include :obj:`None` + in the list if the lack of an origin is acceptable. + extensions: List of supported extensions, in order in which they + should be negotiated and run. + subprotocols: List of supported subprotocols, in order of decreasing + preference. + extra_headers (Union[HeadersLike, Callable[[str, Headers], HeadersLike]]): + Arbitrary HTTP headers to add to the response. This can be + a :data:`~websockets.datastructures.HeadersLike` or a callable + taking the request path and headers in arguments and returning + a :data:`~websockets.datastructures.HeadersLike`. + server_header: Value of the ``Server`` response header. + It defaults to ``"Python/x.y.z websockets/X.Y"``. + Setting it to :obj:`None` removes the header. + process_request (Optional[Callable[[str, Headers], \ + Awaitable[Optional[Tuple[StatusLike, HeadersLike, bytes]]]]]): + Intercept HTTP request before the opening handshake. + See :meth:`~WebSocketServerProtocol.process_request` for details. + select_subprotocol: Select a subprotocol supported by the client. + See :meth:`~WebSocketServerProtocol.select_subprotocol` for details. + open_timeout: Timeout for opening connections in seconds. + :obj:`None` disables the timeout. + + See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the + documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``, + ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``. + + Any other keyword arguments are passed the event loop's + :meth:`~asyncio.loop.create_server` method. + + For example: + + * You can set ``ssl`` to a :class:`~ssl.SSLContext` to enable TLS. + + * You can set ``sock`` to a :obj:`~socket.socket` that you created + outside of websockets. + + Returns: + WebSocketServer: WebSocket server. + + """ + + def __init__( + self, + ws_handler: Union[ + Callable[[WebSocketServerProtocol], Awaitable[Any]], + Callable[[WebSocketServerProtocol, str], Awaitable[Any]], # deprecated + ], + host: Optional[Union[str, Sequence[str]]] = None, + port: Optional[int] = None, + *, + create_protocol: Optional[Callable[..., WebSocketServerProtocol]] = None, + logger: Optional[LoggerLike] = None, + compression: Optional[str] = "deflate", + origins: Optional[Sequence[Optional[Origin]]] = None, + extensions: Optional[Sequence[ServerExtensionFactory]] = None, + subprotocols: Optional[Sequence[Subprotocol]] = None, + extra_headers: Optional[HeadersLikeOrCallable] = None, + server_header: Optional[str] = USER_AGENT, + process_request: Optional[ + Callable[[str, Headers], Awaitable[Optional[HTTPResponse]]] + ] = None, + select_subprotocol: Optional[ + Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol] + ] = None, + open_timeout: Optional[float] = 10, + ping_interval: Optional[float] = 20, + ping_timeout: Optional[float] = 20, + close_timeout: Optional[float] = None, + max_size: Optional[int] = 2**20, + max_queue: Optional[int] = 2**5, + read_limit: int = 2**16, + write_limit: int = 2**16, + **kwargs: Any, + ) -> None: + # Backwards compatibility: close_timeout used to be called timeout. + timeout: Optional[float] = kwargs.pop("timeout", None) + if timeout is None: + timeout = 10 + else: + warnings.warn("rename timeout to close_timeout", DeprecationWarning) + # If both are specified, timeout is ignored. + if close_timeout is None: + close_timeout = timeout + + # Backwards compatibility: create_protocol used to be called klass. + klass: Optional[Type[WebSocketServerProtocol]] = kwargs.pop("klass", None) + if klass is None: + klass = WebSocketServerProtocol + else: + warnings.warn("rename klass to create_protocol", DeprecationWarning) + # If both are specified, klass is ignored. + if create_protocol is None: + create_protocol = klass + + # Backwards compatibility: recv() used to return None on closed connections + legacy_recv: bool = kwargs.pop("legacy_recv", False) + + # Backwards compatibility: the loop parameter used to be supported. + _loop: Optional[asyncio.AbstractEventLoop] = kwargs.pop("loop", None) + if _loop is None: + loop = asyncio.get_event_loop() + else: + loop = _loop + warnings.warn("remove loop argument", DeprecationWarning) + + ws_server = WebSocketServer(logger=logger) + + secure = kwargs.get("ssl") is not None + + if compression == "deflate": + extensions = enable_server_permessage_deflate(extensions) + elif compression is not None: + raise ValueError(f"unsupported compression: {compression}") + + if subprotocols is not None: + validate_subprotocols(subprotocols) + + factory = functools.partial( + create_protocol, + # For backwards compatibility with 10.0 or earlier. Done here in + # addition to WebSocketServerProtocol to trigger the deprecation + # warning once per serve() call rather than once per connection. + remove_path_argument(ws_handler), + ws_server, + host=host, + port=port, + secure=secure, + open_timeout=open_timeout, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + close_timeout=close_timeout, + max_size=max_size, + max_queue=max_queue, + read_limit=read_limit, + write_limit=write_limit, + loop=_loop, + legacy_recv=legacy_recv, + origins=origins, + extensions=extensions, + subprotocols=subprotocols, + extra_headers=extra_headers, + server_header=server_header, + process_request=process_request, + select_subprotocol=select_subprotocol, + logger=logger, + ) + + if kwargs.pop("unix", False): + path: Optional[str] = kwargs.pop("path", None) + # unix_serve(path) must not specify host and port parameters. + assert host is None and port is None + create_server = functools.partial( + loop.create_unix_server, factory, path, **kwargs + ) + else: + create_server = functools.partial( + loop.create_server, factory, host, port, **kwargs + ) + + # This is a coroutine function. + self._create_server = create_server + self.ws_server = ws_server + + # async with serve(...) + + async def __aenter__(self) -> WebSocketServer: + return await self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + self.ws_server.close() + await self.ws_server.wait_closed() + + # await serve(...) + + def __await__(self) -> Generator[Any, None, WebSocketServer]: + # Create a suitable iterator by calling __await__ on a coroutine. + return self.__await_impl__().__await__() + + async def __await_impl__(self) -> WebSocketServer: + server = await self._create_server() + self.ws_server.wrap(server) + return self.ws_server + + # yield from serve(...) - remove when dropping Python < 3.10 + + __iter__ = __await__ + + +serve = Serve + + +def unix_serve( + ws_handler: Union[ + Callable[[WebSocketServerProtocol], Awaitable[Any]], + Callable[[WebSocketServerProtocol, str], Awaitable[Any]], # deprecated + ], + path: Optional[str] = None, + **kwargs: Any, +) -> Serve: + """ + Start a WebSocket server listening on a Unix socket. + + This function is identical to :func:`serve`, except the ``host`` and + ``port`` arguments are replaced by ``path``. It is only available on Unix. + + Unrecognized keyword arguments are passed the event loop's + :meth:`~asyncio.loop.create_unix_server` method. + + It's useful for deploying a server behind a reverse proxy such as nginx. + + Args: + path: File system path to the Unix socket. + + """ + return serve(ws_handler, path=path, unix=True, **kwargs) + + +def remove_path_argument( + ws_handler: Union[ + Callable[[WebSocketServerProtocol], Awaitable[Any]], + Callable[[WebSocketServerProtocol, str], Awaitable[Any]], + ] +) -> Callable[[WebSocketServerProtocol], Awaitable[Any]]: + try: + inspect.signature(ws_handler).bind(None) + except TypeError: + try: + inspect.signature(ws_handler).bind(None, "") + except TypeError: # pragma: no cover + # ws_handler accepts neither one nor two arguments; leave it alone. + pass + else: + # ws_handler accepts two arguments; activate backwards compatibility. + + # Enable deprecation warning and announce deprecation in 11.0. + # warnings.warn("remove second argument of ws_handler", DeprecationWarning) + + async def _ws_handler(websocket: WebSocketServerProtocol) -> Any: + return await cast( + Callable[[WebSocketServerProtocol, str], Awaitable[Any]], + ws_handler, + )(websocket, websocket.path) + + return _ws_handler + + return cast( + Callable[[WebSocketServerProtocol], Awaitable[Any]], + ws_handler, + ) |