summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/websockets/client.py
blob: b2f622042df8d4b560488cbbe217b42c7f135e44 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
from __future__ import annotations

import warnings
from typing import Any, Generator, List, Optional, Sequence

from .datastructures import Headers, MultipleValuesError
from .exceptions import (
    InvalidHandshake,
    InvalidHeader,
    InvalidHeaderValue,
    InvalidStatus,
    InvalidUpgrade,
    NegotiationError,
)
from .extensions import ClientExtensionFactory, Extension
from .headers import (
    build_authorization_basic,
    build_extension,
    build_host,
    build_subprotocol,
    parse_connection,
    parse_extension,
    parse_subprotocol,
    parse_upgrade,
)
from .http11 import Request, Response
from .protocol import CLIENT, CONNECTING, OPEN, Protocol, State
from .typing import (
    ConnectionOption,
    ExtensionHeader,
    LoggerLike,
    Origin,
    Subprotocol,
    UpgradeProtocol,
)
from .uri import WebSocketURI
from .utils import accept_key, generate_key


# See #940 for why lazy_import isn't used here for backwards compatibility.
# See #1400 for why listing compatibility imports in __all__ helps PyCharm.
from .legacy.client import *  # isort:skip  # noqa: I001
from .legacy.client import __all__ as legacy__all__


# Public names: the Sans-I/O client protocol followed by every name the
# legacy client module exports (re-exported above for compatibility).
__all__ = ["ClientProtocol", *legacy__all__]


class ClientProtocol(Protocol):
    """
    Sans-I/O implementation of a WebSocket client connection.

    Args:
        wsuri: URI of the WebSocket server, parsed
            with :func:`~websockets.uri.parse_uri`.
        origin: value of the ``Origin`` header. This is useful when connecting
            to a server that validates the ``Origin`` header to defend against
            Cross-Site WebSocket Hijacking attacks.
        extensions: list of supported extensions, in order in which they
            should be tried.
        subprotocols: list of supported subprotocols, in order of decreasing
            preference.
        state: initial state of the WebSocket connection.
        max_size: maximum size of incoming messages in bytes;
            :obj:`None` disables the limit.
        logger: logger for this connection;
            defaults to ``logging.getLogger("websockets.client")``;
            see the :doc:`logging guide <../../topics/logging>` for details.

    """

    def __init__(
        self,
        wsuri: WebSocketURI,
        *,
        origin: Optional[Origin] = None,
        extensions: Optional[Sequence[ClientExtensionFactory]] = None,
        subprotocols: Optional[Sequence[Subprotocol]] = None,
        state: State = CONNECTING,
        max_size: Optional[int] = 2**20,
        logger: Optional[LoggerLike] = None,
    ):
        super().__init__(
            side=CLIENT,
            state=state,
            max_size=max_size,
            logger=logger,
        )
        self.wsuri = wsuri
        self.origin = origin
        self.available_extensions = extensions
        self.available_subprotocols = subprotocols
        # Sent as Sec-WebSocket-Key; process_response() checks the server's
        # Sec-WebSocket-Accept against accept_key(self.key).
        self.key = generate_key()

    def connect(self) -> Request:
        """
        Create a handshake request to open a connection.

        You must send the handshake request with :meth:`send_request`.

        You can modify it before sending it, for example to add HTTP headers.

        Returns:
            Request: WebSocket handshake request event to send to the server.

        """
        headers = Headers()

        headers["Host"] = build_host(
            self.wsuri.host, self.wsuri.port, self.wsuri.secure
        )

        # Credentials embedded in the URI are sent as HTTP Basic auth.
        if self.wsuri.user_info:
            headers["Authorization"] = build_authorization_basic(*self.wsuri.user_info)

        if self.origin is not None:
            headers["Origin"] = self.origin

        headers["Upgrade"] = "websocket"
        headers["Connection"] = "Upgrade"
        headers["Sec-WebSocket-Key"] = self.key
        headers["Sec-WebSocket-Version"] = "13"

        if self.available_extensions is not None:
            extensions_header = build_extension(
                [
                    (extension_factory.name, extension_factory.get_request_params())
                    for extension_factory in self.available_extensions
                ]
            )
            headers["Sec-WebSocket-Extensions"] = extensions_header

        if self.available_subprotocols is not None:
            protocol_header = build_subprotocol(self.available_subprotocols)
            headers["Sec-WebSocket-Protocol"] = protocol_header

        return Request(self.wsuri.resource_name, headers)

    def process_response(self, response: Response) -> None:
        """
        Check a handshake response.

        On success, store the negotiated extensions in :attr:`extensions` and
        the selected subprotocol in :attr:`subprotocol`.

        Args:
            response: WebSocket handshake response received from the server.

        Raises:
            InvalidHandshake: if the handshake response is invalid.

        """

        if response.status_code != 101:
            raise InvalidStatus(response)

        headers = response.headers

        # A header may appear on several lines, each carrying a comma-separated
        # list; flatten all of them, preserving order.
        connection: List[ConnectionOption] = [
            option
            for value in headers.get_all("Connection")
            for option in parse_connection(value)
        ]

        if not any(value.lower() == "upgrade" for value in connection):
            raise InvalidUpgrade(
                "Connection", ", ".join(connection) if connection else None
            )

        upgrade: List[UpgradeProtocol] = [
            protocol
            for value in headers.get_all("Upgrade")
            for protocol in parse_upgrade(value)
        ]

        # For compatibility with non-strict implementations, ignore case when
        # checking the Upgrade header. It's supposed to be 'WebSocket'.
        if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"):
            raise InvalidUpgrade("Upgrade", ", ".join(upgrade) if upgrade else None)

        try:
            s_w_accept = headers["Sec-WebSocket-Accept"]
        except KeyError as exc:
            raise InvalidHeader("Sec-WebSocket-Accept") from exc
        except MultipleValuesError as exc:
            raise InvalidHeader(
                "Sec-WebSocket-Accept",
                "more than one Sec-WebSocket-Accept header found",
            ) from exc

        # The server must prove it read our Sec-WebSocket-Key.
        if s_w_accept != accept_key(self.key):
            raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept)

        self.extensions = self.process_extensions(headers)

        self.subprotocol = self.process_subprotocol(headers)

    def process_extensions(self, headers: Headers) -> List[Extension]:
        """
        Handle the Sec-WebSocket-Extensions HTTP response header.

        Check that each extension is supported, as well as its parameters.

        :rfc:`6455` leaves the rules up to the specification of each
        extension.

        To provide this level of flexibility, for each extension accepted by
        the server, we check for a match with each extension available in the
        client configuration. If no match is found, an exception is raised.

        If several variants of the same extension are accepted by the server,
        it may be configured several times, which won't make sense in general.
        Extensions must implement their own requirements. For this purpose,
        the list of previously accepted extensions is provided.

        Other requirements, for example related to mandatory extensions or the
        order of extensions, may be implemented by overriding this method.

        Args:
            headers: WebSocket handshake response headers.

        Returns:
            List[Extension]: List of accepted extensions.

        Raises:
            InvalidHandshake: to abort the handshake.

        """
        accepted_extensions: List[Extension] = []

        extensions = headers.get_all("Sec-WebSocket-Extensions")

        if extensions:
            if self.available_extensions is None:
                raise InvalidHandshake("no extensions supported")

            parsed_extensions: List[ExtensionHeader] = [
                item
                for header_value in extensions
                for item in parse_extension(header_value)
            ]

            for name, response_params in parsed_extensions:
                for extension_factory in self.available_extensions:
                    # Skip non-matching extensions based on their name.
                    if extension_factory.name != name:
                        continue

                    # Skip non-matching extensions based on their params.
                    try:
                        extension = extension_factory.process_response_params(
                            response_params, accepted_extensions
                        )
                    except NegotiationError:
                        continue

                    # Add matching extension to the final list.
                    accepted_extensions.append(extension)

                    # Break out of the loop once we have a match.
                    break

                # If we didn't break from the loop, no extension in our list
                # matched what the server sent. Fail the connection.
                else:
                    raise NegotiationError(
                        f"Unsupported extension: "
                        f"name = {name}, params = {response_params}"
                    )

        return accepted_extensions

    def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]:
        """
        Handle the Sec-WebSocket-Protocol HTTP response header.

        If provided, check that it contains exactly one supported subprotocol.

        Args:
            headers: WebSocket handshake response headers.

        Returns:
           Optional[Subprotocol]: Subprotocol, if one was selected.

        Raises:
            InvalidHandshake: if the server selected a subprotocol although
                none was offered, or selected more than one.
            NegotiationError: if the selected subprotocol isn't one of
                :attr:`available_subprotocols`.

        """
        subprotocol: Optional[Subprotocol] = None

        subprotocols = headers.get_all("Sec-WebSocket-Protocol")

        if subprotocols:
            if self.available_subprotocols is None:
                raise InvalidHandshake("no subprotocols supported")

            parsed_subprotocols: List[Subprotocol] = [
                item
                for header_value in subprotocols
                for item in parse_subprotocol(header_value)
            ]

            if len(parsed_subprotocols) > 1:
                subprotocols_display = ", ".join(parsed_subprotocols)
                raise InvalidHandshake(f"multiple subprotocols: {subprotocols_display}")

            subprotocol = parsed_subprotocols[0]

            if subprotocol not in self.available_subprotocols:
                raise NegotiationError(f"unsupported subprotocol: {subprotocol}")

        return subprotocol

    def send_request(self, request: Request) -> None:
        """
        Send a handshake request to the server.

        Args:
            request: WebSocket handshake request event.

        """
        if self.debug:
            self.logger.debug("> GET %s HTTP/1.1", request.path)
            for key, value in request.headers.raw_items():
                self.logger.debug("> %s: %s", key, value)

        self.writes.append(request.serialize())

    def parse(self) -> Generator[None, None, None]:
        """
        Parse incoming data: the handshake response first, then frames.

        While the connection is CONNECTING, read the HTTP response and check
        it with :meth:`process_response`. On failure, store the exception in
        :attr:`handshake_exc`, replace :attr:`parser` with the generator from
        :meth:`discard`, and suspend this generator for good. On success,
        switch to OPEN, queue the response as an event, and delegate to the
        base class parser.

        """
        if self.state is CONNECTING:
            try:
                response = yield from Response.parse(
                    self.reader.read_line,
                    self.reader.read_exact,
                    self.reader.read_to_eof,
                )
            except Exception as exc:
                self.handshake_exc = exc
                self.parser = self.discard()
                next(self.parser)  # start coroutine
                yield
                # This generator was superseded by self.parser; if it is ever
                # resumed, stop instead of using the unbound ``response``.
                return

            if self.debug:
                code, phrase = response.status_code, response.reason_phrase
                self.logger.debug("< HTTP/1.1 %d %s", code, phrase)
                for key, value in response.headers.raw_items():
                    self.logger.debug("< %s: %s", key, value)
                if response.body is not None:
                    self.logger.debug("< [body] (%d bytes)", len(response.body))

            try:
                self.process_response(response)
            except InvalidHandshake as exc:
                response._exception = exc
                self.events.append(response)
                self.handshake_exc = exc
                self.parser = self.discard()
                next(self.parser)  # start coroutine
                yield
                # Never fall through to the OPEN transition below after a
                # rejected handshake, even if this stale generator is resumed.
                return

            assert self.state is CONNECTING
            self.state = OPEN
            self.events.append(response)

        yield from super().parse()


class ClientConnection(ClientProtocol):
    """
    Deprecated alias of :class:`ClientProtocol`.

    Instantiating it emits a :class:`DeprecationWarning`, then forwards all
    arguments unchanged to :class:`ClientProtocol`.

    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # stacklevel=2 attributes the warning to the code that instantiates
        # ClientConnection instead of to this module, so the default warning
        # filters can show it to the caller and the reported location is
        # actionable.
        warnings.warn(
            "ClientConnection was renamed to ClientProtocol",
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)