summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/websockets/client.py
diff options
context:
space:
mode:
authorcyfraeviolae <cyfraeviolae>2024-04-03 03:10:44 -0400
committercyfraeviolae <cyfraeviolae>2024-04-03 03:10:44 -0400
commit6d7ba58f880be618ade07f8ea080fe8c4bf8a896 (patch)
treeb1c931051ffcebd2bd9d61d98d6233ffa289bbce /venv/lib/python3.11/site-packages/websockets/client.py
parent4f884c9abc32990b4061a1bb6997b4b37e58ea0b (diff)
venv
Diffstat (limited to 'venv/lib/python3.11/site-packages/websockets/client.py')
-rw-r--r--venv/lib/python3.11/site-packages/websockets/client.py360
1 files changed, 360 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/websockets/client.py b/venv/lib/python3.11/site-packages/websockets/client.py
new file mode 100644
index 0000000..b2f6220
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/websockets/client.py
@@ -0,0 +1,360 @@
+from __future__ import annotations
+
+import warnings
+from typing import Any, Generator, List, Optional, Sequence
+
+from .datastructures import Headers, MultipleValuesError
+from .exceptions import (
+ InvalidHandshake,
+ InvalidHeader,
+ InvalidHeaderValue,
+ InvalidStatus,
+ InvalidUpgrade,
+ NegotiationError,
+)
+from .extensions import ClientExtensionFactory, Extension
+from .headers import (
+ build_authorization_basic,
+ build_extension,
+ build_host,
+ build_subprotocol,
+ parse_connection,
+ parse_extension,
+ parse_subprotocol,
+ parse_upgrade,
+)
+from .http11 import Request, Response
+from .protocol import CLIENT, CONNECTING, OPEN, Protocol, State
+from .typing import (
+ ConnectionOption,
+ ExtensionHeader,
+ LoggerLike,
+ Origin,
+ Subprotocol,
+ UpgradeProtocol,
+)
+from .uri import WebSocketURI
+from .utils import accept_key, generate_key
+
+
+# See #940 for why lazy_import isn't used here for backwards compatibility.
+# See #1400 for why listing compatibility imports in __all__ helps PyCharm.
+from .legacy.client import * # isort:skip # noqa: I001
+from .legacy.client import __all__ as legacy__all__
+
+
+__all__ = ["ClientProtocol"] + legacy__all__
+
+
+class ClientProtocol(Protocol):
+ """
+ Sans-I/O implementation of a WebSocket client connection.
+
+ Args:
+ wsuri: URI of the WebSocket server, parsed
+ with :func:`~websockets.uri.parse_uri`.
+ origin: value of the ``Origin`` header. This is useful when connecting
+ to a server that validates the ``Origin`` header to defend against
+ Cross-Site WebSocket Hijacking attacks.
+ extensions: list of supported extensions, in order in which they
+ should be tried.
+ subprotocols: list of supported subprotocols, in order of decreasing
+ preference.
+ state: initial state of the WebSocket connection.
+ max_size: maximum size of incoming messages in bytes;
+ :obj:`None` disables the limit.
+ logger: logger for this connection;
+ defaults to ``logging.getLogger("websockets.client")``;
+ see the :doc:`logging guide <../../topics/logging>` for details.
+
+ """
+
+ def __init__(
+ self,
+ wsuri: WebSocketURI,
+ *,
+ origin: Optional[Origin] = None,
+ extensions: Optional[Sequence[ClientExtensionFactory]] = None,
+ subprotocols: Optional[Sequence[Subprotocol]] = None,
+ state: State = CONNECTING,
+ max_size: Optional[int] = 2**20,
+ logger: Optional[LoggerLike] = None,
+ ):
+ super().__init__(
+ side=CLIENT,
+ state=state,
+ max_size=max_size,
+ logger=logger,
+ )
+ self.wsuri = wsuri
+ self.origin = origin
+ self.available_extensions = extensions
+ self.available_subprotocols = subprotocols
+ self.key = generate_key()
+
+ def connect(self) -> Request:
+ """
+ Create a handshake request to open a connection.
+
+ You must send the handshake request with :meth:`send_request`.
+
+ You can modify it before sending it, for example to add HTTP headers.
+
+ Returns:
+ Request: WebSocket handshake request event to send to the server.
+
+ """
+ headers = Headers()
+
+ headers["Host"] = build_host(
+ self.wsuri.host, self.wsuri.port, self.wsuri.secure
+ )
+
+ if self.wsuri.user_info:
+ headers["Authorization"] = build_authorization_basic(*self.wsuri.user_info)
+
+ if self.origin is not None:
+ headers["Origin"] = self.origin
+
+ headers["Upgrade"] = "websocket"
+ headers["Connection"] = "Upgrade"
+ headers["Sec-WebSocket-Key"] = self.key
+ headers["Sec-WebSocket-Version"] = "13"
+
+ if self.available_extensions is not None:
+ extensions_header = build_extension(
+ [
+ (extension_factory.name, extension_factory.get_request_params())
+ for extension_factory in self.available_extensions
+ ]
+ )
+ headers["Sec-WebSocket-Extensions"] = extensions_header
+
+ if self.available_subprotocols is not None:
+ protocol_header = build_subprotocol(self.available_subprotocols)
+ headers["Sec-WebSocket-Protocol"] = protocol_header
+
+ return Request(self.wsuri.resource_name, headers)
+
+ def process_response(self, response: Response) -> None:
+ """
+ Check a handshake response.
+
+ Args:
+ request: WebSocket handshake response received from the server.
+
+ Raises:
+ InvalidHandshake: if the handshake response is invalid.
+
+ """
+
+ if response.status_code != 101:
+ raise InvalidStatus(response)
+
+ headers = response.headers
+
+ connection: List[ConnectionOption] = sum(
+ [parse_connection(value) for value in headers.get_all("Connection")], []
+ )
+
+ if not any(value.lower() == "upgrade" for value in connection):
+ raise InvalidUpgrade(
+ "Connection", ", ".join(connection) if connection else None
+ )
+
+ upgrade: List[UpgradeProtocol] = sum(
+ [parse_upgrade(value) for value in headers.get_all("Upgrade")], []
+ )
+
+ # For compatibility with non-strict implementations, ignore case when
+ # checking the Upgrade header. It's supposed to be 'WebSocket'.
+ if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"):
+ raise InvalidUpgrade("Upgrade", ", ".join(upgrade) if upgrade else None)
+
+ try:
+ s_w_accept = headers["Sec-WebSocket-Accept"]
+ except KeyError as exc:
+ raise InvalidHeader("Sec-WebSocket-Accept") from exc
+ except MultipleValuesError as exc:
+ raise InvalidHeader(
+ "Sec-WebSocket-Accept",
+ "more than one Sec-WebSocket-Accept header found",
+ ) from exc
+
+ if s_w_accept != accept_key(self.key):
+ raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept)
+
+ self.extensions = self.process_extensions(headers)
+
+ self.subprotocol = self.process_subprotocol(headers)
+
+ def process_extensions(self, headers: Headers) -> List[Extension]:
+ """
+ Handle the Sec-WebSocket-Extensions HTTP response header.
+
+ Check that each extension is supported, as well as its parameters.
+
+ :rfc:`6455` leaves the rules up to the specification of each
+ extension.
+
+ To provide this level of flexibility, for each extension accepted by
+ the server, we check for a match with each extension available in the
+ client configuration. If no match is found, an exception is raised.
+
+ If several variants of the same extension are accepted by the server,
+ it may be configured several times, which won't make sense in general.
+ Extensions must implement their own requirements. For this purpose,
+ the list of previously accepted extensions is provided.
+
+ Other requirements, for example related to mandatory extensions or the
+ order of extensions, may be implemented by overriding this method.
+
+ Args:
+ headers: WebSocket handshake response headers.
+
+ Returns:
+ List[Extension]: List of accepted extensions.
+
+ Raises:
+ InvalidHandshake: to abort the handshake.
+
+ """
+ accepted_extensions: List[Extension] = []
+
+ extensions = headers.get_all("Sec-WebSocket-Extensions")
+
+ if extensions:
+ if self.available_extensions is None:
+ raise InvalidHandshake("no extensions supported")
+
+ parsed_extensions: List[ExtensionHeader] = sum(
+ [parse_extension(header_value) for header_value in extensions], []
+ )
+
+ for name, response_params in parsed_extensions:
+ for extension_factory in self.available_extensions:
+ # Skip non-matching extensions based on their name.
+ if extension_factory.name != name:
+ continue
+
+ # Skip non-matching extensions based on their params.
+ try:
+ extension = extension_factory.process_response_params(
+ response_params, accepted_extensions
+ )
+ except NegotiationError:
+ continue
+
+ # Add matching extension to the final list.
+ accepted_extensions.append(extension)
+
+ # Break out of the loop once we have a match.
+ break
+
+ # If we didn't break from the loop, no extension in our list
+ # matched what the server sent. Fail the connection.
+ else:
+ raise NegotiationError(
+ f"Unsupported extension: "
+ f"name = {name}, params = {response_params}"
+ )
+
+ return accepted_extensions
+
+ def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]:
+ """
+ Handle the Sec-WebSocket-Protocol HTTP response header.
+
+ If provided, check that it contains exactly one supported subprotocol.
+
+ Args:
+ headers: WebSocket handshake response headers.
+
+ Returns:
+ Optional[Subprotocol]: Subprotocol, if one was selected.
+
+ """
+ subprotocol: Optional[Subprotocol] = None
+
+ subprotocols = headers.get_all("Sec-WebSocket-Protocol")
+
+ if subprotocols:
+ if self.available_subprotocols is None:
+ raise InvalidHandshake("no subprotocols supported")
+
+ parsed_subprotocols: Sequence[Subprotocol] = sum(
+ [parse_subprotocol(header_value) for header_value in subprotocols], []
+ )
+
+ if len(parsed_subprotocols) > 1:
+ subprotocols_display = ", ".join(parsed_subprotocols)
+ raise InvalidHandshake(f"multiple subprotocols: {subprotocols_display}")
+
+ subprotocol = parsed_subprotocols[0]
+
+ if subprotocol not in self.available_subprotocols:
+ raise NegotiationError(f"unsupported subprotocol: {subprotocol}")
+
+ return subprotocol
+
+ def send_request(self, request: Request) -> None:
+ """
+ Send a handshake request to the server.
+
+ Args:
+ request: WebSocket handshake request event.
+
+ """
+ if self.debug:
+ self.logger.debug("> GET %s HTTP/1.1", request.path)
+ for key, value in request.headers.raw_items():
+ self.logger.debug("> %s: %s", key, value)
+
+ self.writes.append(request.serialize())
+
+ def parse(self) -> Generator[None, None, None]:
+ if self.state is CONNECTING:
+ try:
+ response = yield from Response.parse(
+ self.reader.read_line,
+ self.reader.read_exact,
+ self.reader.read_to_eof,
+ )
+ except Exception as exc:
+ self.handshake_exc = exc
+ self.parser = self.discard()
+ next(self.parser) # start coroutine
+ yield
+
+ if self.debug:
+ code, phrase = response.status_code, response.reason_phrase
+ self.logger.debug("< HTTP/1.1 %d %s", code, phrase)
+ for key, value in response.headers.raw_items():
+ self.logger.debug("< %s: %s", key, value)
+ if response.body is not None:
+ self.logger.debug("< [body] (%d bytes)", len(response.body))
+
+ try:
+ self.process_response(response)
+ except InvalidHandshake as exc:
+ response._exception = exc
+ self.events.append(response)
+ self.handshake_exc = exc
+ self.parser = self.discard()
+ next(self.parser) # start coroutine
+ yield
+
+ assert self.state is CONNECTING
+ self.state = OPEN
+ self.events.append(response)
+
+ yield from super().parse()
+
+
+class ClientConnection(ClientProtocol):
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
+ warnings.warn(
+ "ClientConnection was renamed to ClientProtocol",
+ DeprecationWarning,
+ )
+ super().__init__(*args, **kwargs)