summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/websockets/legacy/handshake.py
blob: ad8faf04045e7184bbff40bf7c17cbc1cb788f0c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
from __future__ import annotations

import base64
import binascii
from typing import List

from ..datastructures import Headers, MultipleValuesError
from ..exceptions import InvalidHeader, InvalidHeaderValue, InvalidUpgrade
from ..headers import parse_connection, parse_upgrade
from ..typing import ConnectionOption, UpgradeProtocol
from ..utils import accept_key as accept, generate_key


__all__ = ["build_request", "check_request", "build_response", "check_response"]


def build_request(headers: Headers) -> str:
    """
    Populate the headers of a client handshake request.

    The mapping passed in argument is modified in place: the ``Upgrade``,
    ``Connection``, ``Sec-WebSocket-Key`` and ``Sec-WebSocket-Version``
    headers are set on it.

    Args:
        headers: Handshake request headers.

    Returns:
        str: ``key`` that must be passed to :func:`check_response`.

    """
    # Generate the key before touching ``headers`` so that a failure here
    # leaves the caller's headers unmodified.
    key = generate_key()
    handshake_fields = (
        ("Upgrade", "websocket"),
        ("Connection", "Upgrade"),
        ("Sec-WebSocket-Key", key),
        ("Sec-WebSocket-Version", "13"),
    )
    for name, value in handshake_fields:
        headers[name] = value
    return key


def check_request(headers: Headers) -> str:
    """
    Validate the handshake request sent by a client.

    Only the WebSocket-specific headers are inspected. Verifying that the
    request is a GET over HTTP/1.1 or higher, as well as checking ``Host``
    and ``Origin``, is the responsibility of the caller; this is usually
    done earlier in the HTTP request handling code.

    The checks run in this order: ``Connection``, ``Upgrade``,
    ``Sec-WebSocket-Key``, then ``Sec-WebSocket-Version``. The first failing
    check raises.

    Args:
        headers: Handshake request headers.

    Returns:
        str: ``key`` that must be passed to :func:`build_response`.

    Raises:
        InvalidHandshake: If the handshake request is invalid.
            Then, the server must return a 400 Bad Request error.

    """
    # Flatten the options found in every Connection header into one list.
    connection: List[ConnectionOption] = [
        option
        for value in headers.get_all("Connection")
        for option in parse_connection(value)
    ]
    if all(option.lower() != "upgrade" for option in connection):
        raise InvalidUpgrade("Connection", ", ".join(connection))

    # Same flattening for the protocols listed in every Upgrade header.
    upgrade: List[UpgradeProtocol] = [
        protocol
        for value in headers.get_all("Upgrade")
        for protocol in parse_upgrade(value)
    ]
    # For compatibility with non-strict implementations, ignore case when
    # checking the Upgrade header. The RFC always uses "websocket", except
    # in section 11.2. (IANA registration) where it uses "WebSocket".
    if len(upgrade) != 1 or upgrade[0].lower() != "websocket":
        raise InvalidUpgrade("Upgrade", ", ".join(upgrade))

    try:
        client_key = headers["Sec-WebSocket-Key"]
    except KeyError as exc:
        raise InvalidHeader("Sec-WebSocket-Key") from exc
    except MultipleValuesError as exc:
        raise InvalidHeader(
            "Sec-WebSocket-Key", "more than one Sec-WebSocket-Key header found"
        ) from exc

    # The key must be strict base64 that decodes to exactly 16 bytes.
    try:
        decoded_key = base64.b64decode(client_key.encode(), validate=True)
    except binascii.Error as exc:
        raise InvalidHeaderValue("Sec-WebSocket-Key", client_key) from exc
    if len(decoded_key) != 16:
        raise InvalidHeaderValue("Sec-WebSocket-Key", client_key)

    try:
        version = headers["Sec-WebSocket-Version"]
    except KeyError as exc:
        raise InvalidHeader("Sec-WebSocket-Version") from exc
    except MultipleValuesError as exc:
        raise InvalidHeader(
            "Sec-WebSocket-Version", "more than one Sec-WebSocket-Version header found"
        ) from exc

    # Only version 13 of the protocol is accepted.
    if version != "13":
        raise InvalidHeaderValue("Sec-WebSocket-Version", version)

    return client_key


def build_response(headers: Headers, key: str) -> None:
    """
    Populate the headers of a server handshake response.

    The mapping passed in argument is modified in place: the ``Upgrade``,
    ``Connection`` and ``Sec-WebSocket-Accept`` headers are set on it.

    Args:
        headers: Handshake response headers.
        key: Returned by :func:`check_request`.

    """
    for name, value in (("Upgrade", "websocket"), ("Connection", "Upgrade")):
        headers[name] = value
    # The accept value is derived from the client's key.
    headers["Sec-WebSocket-Accept"] = accept(key)


def check_response(headers: Headers, key: str) -> None:
    """
    Check a handshake response received from the server.

    This function doesn't verify that the response is an HTTP/1.1 or higher
    response with a 101 status code. These controls are the responsibility of
    the caller.

    The checks run in this order: ``Connection``, ``Upgrade``, then
    ``Sec-WebSocket-Accept``. The first failing check raises.

    Args:
        headers: Handshake response headers.
        key: Returned by :func:`build_request`.

    Raises:
        InvalidHandshake: If the handshake response is invalid. The concrete
            exceptions raised here are ``InvalidUpgrade`` (bad ``Connection``
            or ``Upgrade`` header), ``InvalidHeader`` (missing or duplicated
            ``Sec-WebSocket-Accept`` header) and ``InvalidHeaderValue``
            (``Sec-WebSocket-Accept`` doesn't match ``key``).

    """
    # Flatten the options found in every Connection header into one list.
    connection: List[ConnectionOption] = [
        option
        for value in headers.get_all("Connection")
        for option in parse_connection(value)
    ]

    if not any(value.lower() == "upgrade" for value in connection):
        # Join with ", " like every other InvalidUpgrade raised in this
        # module, so the reported value reads as a comma-separated list.
        raise InvalidUpgrade("Connection", ", ".join(connection))

    # Same flattening for the protocols listed in every Upgrade header.
    upgrade: List[UpgradeProtocol] = [
        protocol
        for value in headers.get_all("Upgrade")
        for protocol in parse_upgrade(value)
    ]

    # For compatibility with non-strict implementations, ignore case when
    # checking the Upgrade header. The RFC always uses "websocket", except
    # in section 11.2. (IANA registration) where it uses "WebSocket".
    if not (len(upgrade) == 1 and upgrade[0].lower() == "websocket"):
        raise InvalidUpgrade("Upgrade", ", ".join(upgrade))

    try:
        s_w_accept = headers["Sec-WebSocket-Accept"]
    except KeyError as exc:
        raise InvalidHeader("Sec-WebSocket-Accept") from exc
    except MultipleValuesError as exc:
        raise InvalidHeader(
            "Sec-WebSocket-Accept", "more than one Sec-WebSocket-Accept header found"
        ) from exc

    # The server must echo back the accept value derived from our key.
    if s_w_accept != accept(key):
        raise InvalidHeaderValue("Sec-WebSocket-Accept", s_w_accept)