diff options
Diffstat (limited to 'venv/lib/python3.11/site-packages/websockets/client.py')
| -rw-r--r-- | venv/lib/python3.11/site-packages/websockets/client.py | 360 | 
1 files changed, 360 insertions, 0 deletions
| diff --git a/venv/lib/python3.11/site-packages/websockets/client.py b/venv/lib/python3.11/site-packages/websockets/client.py new file mode 100644 index 0000000..b2f6220 --- /dev/null +++ b/venv/lib/python3.11/site-packages/websockets/client.py @@ -0,0 +1,360 @@ +from __future__ import annotations + +import warnings +from typing import Any, Generator, List, Optional, Sequence + +from .datastructures import Headers, MultipleValuesError +from .exceptions import ( +    InvalidHandshake, +    InvalidHeader, +    InvalidHeaderValue, +    InvalidStatus, +    InvalidUpgrade, +    NegotiationError, +) +from .extensions import ClientExtensionFactory, Extension +from .headers import ( +    build_authorization_basic, +    build_extension, +    build_host, +    build_subprotocol, +    parse_connection, +    parse_extension, +    parse_subprotocol, +    parse_upgrade, +) +from .http11 import Request, Response +from .protocol import CLIENT, CONNECTING, OPEN, Protocol, State +from .typing import ( +    ConnectionOption, +    ExtensionHeader, +    LoggerLike, +    Origin, +    Subprotocol, +    UpgradeProtocol, +) +from .uri import WebSocketURI +from .utils import accept_key, generate_key + + +# See #940 for why lazy_import isn't used here for backwards compatibility. +# See #1400 for why listing compatibility imports in __all__ helps PyCharm. +from .legacy.client import *  # isort:skip  # noqa: I001 +from .legacy.client import __all__ as legacy__all__ + + +__all__ = ["ClientProtocol"] + legacy__all__ + + +class ClientProtocol(Protocol): +    """ +    Sans-I/O implementation of a WebSocket client connection. + +    Args: +        wsuri: URI of the WebSocket server, parsed +            with :func:`~websockets.uri.parse_uri`. +        origin: value of the ``Origin`` header. This is useful when connecting +            to a server that validates the ``Origin`` header to defend against +            Cross-Site WebSocket Hijacking attacks. +        extensions: list of supported extensions, in order in which they +            should be tried. +        subprotocols: list of supported subprotocols, in order of decreasing +            preference. +        state: initial state of the WebSocket connection. +        max_size: maximum size of incoming messages in bytes; +            :obj:`None` disables the limit. +        logger: logger for this connection; +            defaults to ``logging.getLogger("websockets.client")``; +            see the :doc:`logging guide <../../topics/logging>` for details. + +    """ + +    def __init__( +        self, +        wsuri: WebSocketURI, +        *, +        origin: Optional[Origin] = None, +        extensions: Optional[Sequence[ClientExtensionFactory]] = None, +        subprotocols: Optional[Sequence[Subprotocol]] = None, +        state: State = CONNECTING, +        max_size: Optional[int] = 2**20, +        logger: Optional[LoggerLike] = None, +    ): +        super().__init__( +            side=CLIENT, +            state=state, +            max_size=max_size, +            logger=logger, +        ) +        self.wsuri = wsuri +        self.origin = origin +        self.available_extensions = extensions +        self.available_subprotocols = subprotocols +        self.key = generate_key() + +    def connect(self) -> Request: +        """ +        Create a handshake request to open a connection. + +        You must send the handshake request with :meth:`send_request`. + +        You can modify it before sending it, for example to add HTTP headers. + +        Returns: +            Request: WebSocket handshake request event to send to the server. + +        """ +        headers = Headers() + +        headers["Host"] = build_host( +            self.wsuri.host, self.wsuri.port, self.wsuri.secure +        ) + +        if self.wsuri.user_info: +            headers["Authorization"] = build_authorization_basic(*self.wsuri.user_info) + +        if self.origin is not None: +            headers["Origin"] = self.origin + +        headers["Upgrade"] = "websocket" +        headers["Connection"] = "Upgrade" +        headers["Sec-WebSocket-Key"] = self.key +        headers["Sec-WebSocket-Version"] = "13" + +        if self.available_extensions is not None: +            extensions_header = build_extension( +                [ +                    (extension_factory.name, extension_factory.get_request_params()) +                    for extension_factory in self.available_extensions +                ] +            ) +            headers["Sec-WebSocket-Extensions"] = extensions_header + +        if self.available_subprotocols is not None: +            protocol_header = build_subprotocol(self.available_subprotocols) +            headers["Sec-WebSocket-Protocol"] = protocol_header + +        return Request(self.wsuri.resource_name, headers) + +    def process_response(self, response: Response) -> None: +        """ +        Check a handshake response. + +        Args: +            request: WebSocket handshake response received from the server. + +        Raises: +            InvalidHandshake: if the handshake response is invalid. + +        """ + +        if response.status_code != 101: +            raise InvalidStatus(response) + +        headers = response.headers + +        connection: List[ConnectionOption] = sum( +            [parse_connection(value) for value in headers.get_all("Connection")], [] +        ) + +        if not any(value.lower() == "upgrade" for value in connection): +            raise InvalidUpgrade( +                "Connection", ", ".join(connection) if connection else None +            ) + +        upgrade: List[UpgradeProtocol] = sum( +            [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] +        ) + +        # For compatibility with non-strict implementations, ignore case when +        # checking the Upgrade header. It's supposed to be 'WebSocket'. +        if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"): +            raise InvalidUpgrade("Upgrade", ", ".join(upgrade) if upgrade else None) + +        try: +            s_w_accept = headers["Sec-WebSocket-Accept"] +        except KeyError as exc: +            raise InvalidHeader("Sec-WebSocket-Accept") from exc +        except MultipleValuesError as exc: +            raise InvalidHeader( +                "Sec-WebSocket-Accept", +                "more than one Sec-WebSocket-Accept header found", +            ) from exc + +        if s_w_accept != accept_key(self.key): +            raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept) + +        self.extensions = self.process_extensions(headers) + +        self.subprotocol = self.process_subprotocol(headers) + +    def process_extensions(self, headers: Headers) -> List[Extension]: +        """ +        Handle the Sec-WebSocket-Extensions HTTP response header. + +        Check that each extension is supported, as well as its parameters. + +        :rfc:`6455` leaves the rules up to the specification of each +        extension. + +        To provide this level of flexibility, for each extension accepted by +        the server, we check for a match with each extension available in the +        client configuration. If no match is found, an exception is raised. + +        If several variants of the same extension are accepted by the server, +        it may be configured several times, which won't make sense in general. +        Extensions must implement their own requirements. For this purpose, +        the list of previously accepted extensions is provided. + +        Other requirements, for example related to mandatory extensions or the +        order of extensions, may be implemented by overriding this method. + +        Args: +            headers: WebSocket handshake response headers. + +        Returns: +            List[Extension]: List of accepted extensions. + +        Raises: +            InvalidHandshake: to abort the handshake. + +        """ +        accepted_extensions: List[Extension] = [] + +        extensions = headers.get_all("Sec-WebSocket-Extensions") + +        if extensions: +            if self.available_extensions is None: +                raise InvalidHandshake("no extensions supported") + +            parsed_extensions: List[ExtensionHeader] = sum( +                [parse_extension(header_value) for header_value in extensions], [] +            ) + +            for name, response_params in parsed_extensions: +                for extension_factory in self.available_extensions: +                    # Skip non-matching extensions based on their name. +                    if extension_factory.name != name: +                        continue + +                    # Skip non-matching extensions based on their params. +                    try: +                        extension = extension_factory.process_response_params( +                            response_params, accepted_extensions +                        ) +                    except NegotiationError: +                        continue + +                    # Add matching extension to the final list. +                    accepted_extensions.append(extension) + +                    # Break out of the loop once we have a match. +                    break + +                # If we didn't break from the loop, no extension in our list +                # matched what the server sent. Fail the connection. +                else: +                    raise NegotiationError( +                        f"Unsupported extension: " +                        f"name = {name}, params = {response_params}" +                    ) + +        return accepted_extensions + +    def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]: +        """ +        Handle the Sec-WebSocket-Protocol HTTP response header. + +        If provided, check that it contains exactly one supported subprotocol. + +        Args: +            headers: WebSocket handshake response headers. + +        Returns: +           Optional[Subprotocol]: Subprotocol, if one was selected. + +        """ +        subprotocol: Optional[Subprotocol] = None + +        subprotocols = headers.get_all("Sec-WebSocket-Protocol") + +        if subprotocols: +            if self.available_subprotocols is None: +                raise InvalidHandshake("no subprotocols supported") + +            parsed_subprotocols: Sequence[Subprotocol] = sum( +                [parse_subprotocol(header_value) for header_value in subprotocols], [] +            ) + +            if len(parsed_subprotocols) > 1: +                subprotocols_display = ", ".join(parsed_subprotocols) +                raise InvalidHandshake(f"multiple subprotocols: {subprotocols_display}") + +            subprotocol = parsed_subprotocols[0] + +            if subprotocol not in self.available_subprotocols: +                raise NegotiationError(f"unsupported subprotocol: {subprotocol}") + +        return subprotocol + +    def send_request(self, request: Request) -> None: +        """ +        Send a handshake request to the server. + +        Args: +            request: WebSocket handshake request event. + +        """ +        if self.debug: +            self.logger.debug("> GET %s HTTP/1.1", request.path) +            for key, value in request.headers.raw_items(): +                self.logger.debug("> %s: %s", key, value) + +        self.writes.append(request.serialize()) + +    def parse(self) -> Generator[None, None, None]: +        if self.state is CONNECTING: +            try: +                response = yield from Response.parse( +                    self.reader.read_line, +                    self.reader.read_exact, +                    self.reader.read_to_eof, +                ) +            except Exception as exc: +                self.handshake_exc = exc +                self.parser = self.discard() +                next(self.parser)  # start coroutine +                yield + +            if self.debug: +                code, phrase = response.status_code, response.reason_phrase +                self.logger.debug("< HTTP/1.1 %d %s", code, phrase) +                for key, value in response.headers.raw_items(): +                    self.logger.debug("< %s: %s", key, value) +                if response.body is not None: +                    self.logger.debug("< [body] (%d bytes)", len(response.body)) + +            try: +                self.process_response(response) +            except InvalidHandshake as exc: +                response._exception = exc +                self.events.append(response) +                self.handshake_exc = exc +                self.parser = self.discard() +                next(self.parser)  # start coroutine +                yield + +            assert self.state is CONNECTING +            self.state = OPEN +            self.events.append(response) + +        yield from super().parse() + + +class ClientConnection(ClientProtocol): +    def __init__(self, *args: Any, **kwargs: Any) -> None: +        warnings.warn( +            "ClientConnection was renamed to ClientProtocol", +            DeprecationWarning, +        ) +        super().__init__(*args, **kwargs) | 
