diff options
author | cyfraeviolae <cyfraeviolae> | 2024-04-03 03:10:44 -0400 |
---|---|---|
committer | cyfraeviolae <cyfraeviolae> | 2024-04-03 03:10:44 -0400 |
commit | 6d7ba58f880be618ade07f8ea080fe8c4bf8a896 (patch) | |
tree | b1c931051ffcebd2bd9d61d98d6233ffa289bbce /venv/lib/python3.11/site-packages/uvicorn/protocols | |
parent | 4f884c9abc32990b4061a1bb6997b4b37e58ea0b (diff) |
venv
Diffstat (limited to 'venv/lib/python3.11/site-packages/uvicorn/protocols')
22 files changed, 2044 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/__init__.py b/venv/lib/python3.11/site-packages/uvicorn/protocols/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/__init__.py diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/uvicorn/protocols/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..5ff273e --- /dev/null +++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/__pycache__/utils.cpython-311.pyc b/venv/lib/python3.11/site-packages/uvicorn/protocols/__pycache__/utils.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..da51e43 --- /dev/null +++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/__pycache__/utils.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/http/__init__.py b/venv/lib/python3.11/site-packages/uvicorn/protocols/http/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/http/__init__.py diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/http/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/uvicorn/protocols/http/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..c56186b --- /dev/null +++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/http/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/http/__pycache__/auto.cpython-311.pyc b/venv/lib/python3.11/site-packages/uvicorn/protocols/http/__pycache__/auto.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..a598e8a --- /dev/null +++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/http/__pycache__/auto.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/http/__pycache__/flow_control.cpython-311.pyc b/venv/lib/python3.11/site-packages/uvicorn/protocols/http/__pycache__/flow_control.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..a925cb4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/http/__pycache__/flow_control.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/http/__pycache__/h11_impl.cpython-311.pyc b/venv/lib/python3.11/site-packages/uvicorn/protocols/http/__pycache__/h11_impl.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..2379bb6 --- /dev/null +++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/http/__pycache__/h11_impl.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/http/__pycache__/httptools_impl.cpython-311.pyc b/venv/lib/python3.11/site-packages/uvicorn/protocols/http/__pycache__/httptools_impl.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..ba76184 --- /dev/null +++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/http/__pycache__/httptools_impl.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/http/auto.py b/venv/lib/python3.11/site-packages/uvicorn/protocols/http/auto.py new file mode 100644 index 0000000..a14bec1 --- /dev/null +++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/http/auto.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +import asyncio + +AutoHTTPProtocol: type[asyncio.Protocol] +try: + import httptools # noqa +except ImportError: # pragma: no cover + from uvicorn.protocols.http.h11_impl import H11Protocol + + AutoHTTPProtocol = H11Protocol +else: # pragma: no cover + from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol + + AutoHTTPProtocol = HttpToolsProtocol diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/http/flow_control.py b/venv/lib/python3.11/site-packages/uvicorn/protocols/http/flow_control.py new file mode 100644 index 0000000..893a26c --- /dev/null +++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/http/flow_control.py @@ -0,0 +1,64 @@ +import asyncio + +from uvicorn._types import ( + ASGIReceiveCallable, + ASGISendCallable, + HTTPResponseBodyEvent, + HTTPResponseStartEvent, + Scope, +) + +CLOSE_HEADER = (b"connection", b"close") + +HIGH_WATER_LIMIT = 65536 + + +class FlowControl: + def __init__(self, transport: asyncio.Transport) -> None: + self._transport = transport + self.read_paused = False + self.write_paused = False + self._is_writable_event = asyncio.Event() + self._is_writable_event.set() + + async def drain(self) -> None: + await self._is_writable_event.wait() + + def pause_reading(self) -> None: + if not self.read_paused: + self.read_paused = True + self._transport.pause_reading() + + def resume_reading(self) -> None: + if self.read_paused: + self.read_paused = False + self._transport.resume_reading() + + def pause_writing(self) -> None: + if not self.write_paused: + self.write_paused = True + self._is_writable_event.clear() + + def resume_writing(self) -> None: + if self.write_paused: + self.write_paused = False + self._is_writable_event.set() + + +async def service_unavailable(scope: "Scope", receive: "ASGIReceiveCallable", send: "ASGISendCallable") -> None: + response_start: "HTTPResponseStartEvent" = { + "type": "http.response.start", + "status": 503, + "headers": [ + (b"content-type", b"text/plain; charset=utf-8"), + (b"connection", b"close"), + ], + } + await send(response_start) + + response_body: "HTTPResponseBodyEvent" = { + "type": "http.response.body", + "body": b"Service Unavailable", + "more_body": False, + } + await send(response_body) diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/http/h11_impl.py b/venv/lib/python3.11/site-packages/uvicorn/protocols/http/h11_impl.py new file mode 100644 index 0000000..d0f2b2a --- /dev/null +++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/http/h11_impl.py @@ -0,0 +1,547 @@ +from __future__ import annotations + +import asyncio +import http +import logging +from typing import Any, Callable, Literal, cast +from urllib.parse import unquote + +import h11 +from h11._connection import DEFAULT_MAX_INCOMPLETE_EVENT_SIZE + +from uvicorn._types import ( + ASGI3Application, + ASGIReceiveEvent, + ASGISendEvent, + HTTPRequestEvent, + HTTPResponseBodyEvent, + HTTPResponseStartEvent, + HTTPScope, +) +from uvicorn.config import Config +from uvicorn.logging import TRACE_LOG_LEVEL +from uvicorn.protocols.http.flow_control import ( + CLOSE_HEADER, + HIGH_WATER_LIMIT, + FlowControl, + service_unavailable, +) +from uvicorn.protocols.utils import ( + get_client_addr, + get_local_addr, + get_path_with_query_string, + get_remote_addr, + is_ssl, +) +from uvicorn.server import ServerState + + +def _get_status_phrase(status_code: int) -> bytes: + try: + return http.HTTPStatus(status_code).phrase.encode() + except ValueError: + return b"" + + +STATUS_PHRASES = {status_code: _get_status_phrase(status_code) for status_code in range(100, 600)} + + +class H11Protocol(asyncio.Protocol): + def __init__( + self, + config: Config, + server_state: ServerState, + app_state: dict[str, Any], + _loop: asyncio.AbstractEventLoop | None = None, + ) -> None: + if not config.loaded: + config.load() + + self.config = config + self.app = config.loaded_app + self.loop = _loop or asyncio.get_event_loop() + self.logger = logging.getLogger("uvicorn.error") + self.access_logger = logging.getLogger("uvicorn.access") + self.access_log = self.access_logger.hasHandlers() + self.conn = h11.Connection( + h11.SERVER, + config.h11_max_incomplete_event_size + if config.h11_max_incomplete_event_size is not None + else DEFAULT_MAX_INCOMPLETE_EVENT_SIZE, + ) + self.ws_protocol_class = config.ws_protocol_class + self.root_path = config.root_path + self.limit_concurrency = config.limit_concurrency + self.app_state = app_state + + # Timeouts + self.timeout_keep_alive_task: asyncio.TimerHandle | None = None + self.timeout_keep_alive = config.timeout_keep_alive + + # Shared server state + self.server_state = server_state + self.connections = server_state.connections + self.tasks = server_state.tasks + + # Per-connection state + self.transport: asyncio.Transport = None # type: ignore[assignment] + self.flow: FlowControl = None # type: ignore[assignment] + self.server: tuple[str, int] | None = None + self.client: tuple[str, int] | None = None + self.scheme: Literal["http", "https"] | None = None + + # Per-request state + self.scope: HTTPScope = None # type: ignore[assignment] + self.headers: list[tuple[bytes, bytes]] = None # type: ignore[assignment] + self.cycle: RequestResponseCycle = None # type: ignore[assignment] + + # Protocol interface + def connection_made( # type: ignore[override] + self, transport: asyncio.Transport + ) -> None: + self.connections.add(self) + + self.transport = transport + self.flow = FlowControl(transport) + self.server = get_local_addr(transport) + self.client = get_remote_addr(transport) + self.scheme = "https" if is_ssl(transport) else "http" + + if self.logger.level <= TRACE_LOG_LEVEL: + prefix = "%s:%d - " % self.client if self.client else "" + self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection made", prefix) + + def connection_lost(self, exc: Exception | None) -> None: + self.connections.discard(self) + + if self.logger.level <= TRACE_LOG_LEVEL: + prefix = "%s:%d - " % self.client if self.client else "" + self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection lost", prefix) + + if self.cycle and not self.cycle.response_complete: + self.cycle.disconnected = True + if self.conn.our_state != h11.ERROR: + event = h11.ConnectionClosed() + try: + self.conn.send(event) + except h11.LocalProtocolError: + # Premature client disconnect + pass + + if self.cycle is not None: + self.cycle.message_event.set() + if self.flow is not None: + self.flow.resume_writing() + if exc is None: + self.transport.close() + self._unset_keepalive_if_required() + + def eof_received(self) -> None: + pass + + def _unset_keepalive_if_required(self) -> None: + if self.timeout_keep_alive_task is not None: + self.timeout_keep_alive_task.cancel() + self.timeout_keep_alive_task = None + + def _get_upgrade(self) -> bytes | None: + connection = [] + upgrade = None + for name, value in self.headers: + if name == b"connection": + connection = [token.lower().strip() for token in value.split(b",")] + if name == b"upgrade": + upgrade = value.lower() + if b"upgrade" in connection: + return upgrade + return None + + def _should_upgrade_to_ws(self) -> bool: + if self.ws_protocol_class is None: + if self.config.ws == "auto": + msg = "Unsupported upgrade request." + self.logger.warning(msg) + msg = "No supported WebSocket library detected. Please use \"pip install 'uvicorn[standard]'\", or install 'websockets' or 'wsproto' manually." # noqa: E501 + self.logger.warning(msg) + return False + return True + + def data_received(self, data: bytes) -> None: + self._unset_keepalive_if_required() + + self.conn.receive_data(data) + self.handle_events() + + def handle_events(self) -> None: + while True: + try: + event = self.conn.next_event() + except h11.RemoteProtocolError: + msg = "Invalid HTTP request received." + self.logger.warning(msg) + self.send_400_response(msg) + return + + if event is h11.NEED_DATA: + break + + elif event is h11.PAUSED: + # This case can occur in HTTP pipelining, so we need to + # stop reading any more data, and ensure that at the end + # of the active request/response cycle we handle any + # events that have been buffered up. + self.flow.pause_reading() + break + + elif isinstance(event, h11.Request): + self.headers = [(key.lower(), value) for key, value in event.headers] + raw_path, _, query_string = event.target.partition(b"?") + path = unquote(raw_path.decode("ascii")) + full_path = self.root_path + path + full_raw_path = self.root_path.encode("ascii") + raw_path + self.scope = { + "type": "http", + "asgi": { + "version": self.config.asgi_version, + "spec_version": "2.4", + }, + "http_version": event.http_version.decode("ascii"), + "server": self.server, + "client": self.client, + "scheme": self.scheme, # type: ignore[typeddict-item] + "method": event.method.decode("ascii"), + "root_path": self.root_path, + "path": full_path, + "raw_path": full_raw_path, + "query_string": query_string, + "headers": self.headers, + "state": self.app_state.copy(), + } + + upgrade = self._get_upgrade() + if upgrade == b"websocket" and self._should_upgrade_to_ws(): + self.handle_websocket_upgrade(event) + return + + # Handle 503 responses when 'limit_concurrency' is exceeded. + if self.limit_concurrency is not None and ( + len(self.connections) >= self.limit_concurrency or len(self.tasks) >= self.limit_concurrency + ): + app = service_unavailable + message = "Exceeded concurrency limit." + self.logger.warning(message) + else: + app = self.app + + # When starting to process a request, disable the keep-alive + # timeout. Normally we disable this when receiving data from + # client and set back when finishing processing its request. + # However, for pipelined requests processing finishes after + # already receiving the next request and thus the timer may + # be set here, which we don't want. + self._unset_keepalive_if_required() + + self.cycle = RequestResponseCycle( + scope=self.scope, + conn=self.conn, + transport=self.transport, + flow=self.flow, + logger=self.logger, + access_logger=self.access_logger, + access_log=self.access_log, + default_headers=self.server_state.default_headers, + message_event=asyncio.Event(), + on_response=self.on_response_complete, + ) + task = self.loop.create_task(self.cycle.run_asgi(app)) + task.add_done_callback(self.tasks.discard) + self.tasks.add(task) + + elif isinstance(event, h11.Data): + if self.conn.our_state is h11.DONE: + continue + self.cycle.body += event.data + if len(self.cycle.body) > HIGH_WATER_LIMIT: + self.flow.pause_reading() + self.cycle.message_event.set() + + elif isinstance(event, h11.EndOfMessage): + if self.conn.our_state is h11.DONE: + self.transport.resume_reading() + self.conn.start_next_cycle() + continue + self.cycle.more_body = False + self.cycle.message_event.set() + + def handle_websocket_upgrade(self, event: h11.Request) -> None: + if self.logger.level <= TRACE_LOG_LEVEL: + prefix = "%s:%d - " % self.client if self.client else "" + self.logger.log(TRACE_LOG_LEVEL, "%sUpgrading to WebSocket", prefix) + + self.connections.discard(self) + output = [event.method, b" ", event.target, b" HTTP/1.1\r\n"] + for name, value in self.headers: + output += [name, b": ", value, b"\r\n"] + output.append(b"\r\n") + protocol = self.ws_protocol_class( # type: ignore[call-arg, misc] + config=self.config, + server_state=self.server_state, + app_state=self.app_state, + ) + protocol.connection_made(self.transport) + protocol.data_received(b"".join(output)) + self.transport.set_protocol(protocol) + + def send_400_response(self, msg: str) -> None: + reason = STATUS_PHRASES[400] + headers: list[tuple[bytes, bytes]] = [ + (b"content-type", b"text/plain; charset=utf-8"), + (b"connection", b"close"), + ] + event = h11.Response(status_code=400, headers=headers, reason=reason) + output = self.conn.send(event) + self.transport.write(output) + + output = self.conn.send(event=h11.Data(data=msg.encode("ascii"))) + self.transport.write(output) + + output = self.conn.send(event=h11.EndOfMessage()) + self.transport.write(output) + + self.transport.close() + + def on_response_complete(self) -> None: + self.server_state.total_requests += 1 + + if self.transport.is_closing(): + return + + # Set a short Keep-Alive timeout. + self._unset_keepalive_if_required() + + self.timeout_keep_alive_task = self.loop.call_later(self.timeout_keep_alive, self.timeout_keep_alive_handler) + + # Unpause data reads if needed. + self.flow.resume_reading() + + # Unblock any pipelined events. + if self.conn.our_state is h11.DONE and self.conn.their_state is h11.DONE: + self.conn.start_next_cycle() + self.handle_events() + + def shutdown(self) -> None: + """ + Called by the server to commence a graceful shutdown. + """ + if self.cycle is None or self.cycle.response_complete: + event = h11.ConnectionClosed() + self.conn.send(event) + self.transport.close() + else: + self.cycle.keep_alive = False + + def pause_writing(self) -> None: + """ + Called by the transport when the write buffer exceeds the high water mark. + """ + self.flow.pause_writing() + + def resume_writing(self) -> None: + """ + Called by the transport when the write buffer drops below the low water mark. + """ + self.flow.resume_writing() + + def timeout_keep_alive_handler(self) -> None: + """ + Called on a keep-alive connection if no new data is received after a short + delay. + """ + if not self.transport.is_closing(): + event = h11.ConnectionClosed() + self.conn.send(event) + self.transport.close() + + +class RequestResponseCycle: + def __init__( + self, + scope: HTTPScope, + conn: h11.Connection, + transport: asyncio.Transport, + flow: FlowControl, + logger: logging.Logger, + access_logger: logging.Logger, + access_log: bool, + default_headers: list[tuple[bytes, bytes]], + message_event: asyncio.Event, + on_response: Callable[..., None], + ) -> None: + self.scope = scope + self.conn = conn + self.transport = transport + self.flow = flow + self.logger = logger + self.access_logger = access_logger + self.access_log = access_log + self.default_headers = default_headers + self.message_event = message_event + self.on_response = on_response + + # Connection state + self.disconnected = False + self.keep_alive = True + self.waiting_for_100_continue = conn.they_are_waiting_for_100_continue + + # Request state + self.body = b"" + self.more_body = True + + # Response state + self.response_started = False + self.response_complete = False + + # ASGI exception wrapper + async def run_asgi(self, app: ASGI3Application) -> None: + try: + result = await app( # type: ignore[func-returns-value] + self.scope, self.receive, self.send + ) + except BaseException as exc: + msg = "Exception in ASGI application\n" + self.logger.error(msg, exc_info=exc) + if not self.response_started: + await self.send_500_response() + else: + self.transport.close() + else: + if result is not None: + msg = "ASGI callable should return None, but returned '%s'." + self.logger.error(msg, result) + self.transport.close() + elif not self.response_started and not self.disconnected: + msg = "ASGI callable returned without starting response." + self.logger.error(msg) + await self.send_500_response() + elif not self.response_complete and not self.disconnected: + msg = "ASGI callable returned without completing response." + self.logger.error(msg) + self.transport.close() + finally: + self.on_response = lambda: None + + async def send_500_response(self) -> None: + response_start_event: HTTPResponseStartEvent = { + "type": "http.response.start", + "status": 500, + "headers": [ + (b"content-type", b"text/plain; charset=utf-8"), + (b"connection", b"close"), + ], + } + await self.send(response_start_event) + response_body_event: HTTPResponseBodyEvent = { + "type": "http.response.body", + "body": b"Internal Server Error", + "more_body": False, + } + await self.send(response_body_event) + + # ASGI interface + async def send(self, message: ASGISendEvent) -> None: + message_type = message["type"] + + if self.flow.write_paused and not self.disconnected: + await self.flow.drain() + + if self.disconnected: + return + + if not self.response_started: + # Sending response status line and headers + if message_type != "http.response.start": + msg = "Expected ASGI message 'http.response.start', but got '%s'." + raise RuntimeError(msg % message_type) + message = cast("HTTPResponseStartEvent", message) + + self.response_started = True + self.waiting_for_100_continue = False + + status = message["status"] + headers = self.default_headers + list(message.get("headers", [])) + + if CLOSE_HEADER in self.scope["headers"] and CLOSE_HEADER not in headers: + headers = headers + [CLOSE_HEADER] + + if self.access_log: + self.access_logger.info( + '%s - "%s %s HTTP/%s" %d', + get_client_addr(self.scope), + self.scope["method"], + get_path_with_query_string(self.scope), + self.scope["http_version"], + status, + ) + + # Write response status line and headers + reason = STATUS_PHRASES[status] + response = h11.Response(status_code=status, headers=headers, reason=reason) + output = self.conn.send(event=response) + self.transport.write(output) + + elif not self.response_complete: + # Sending response body + if message_type != "http.response.body": + msg = "Expected ASGI message 'http.response.body', but got '%s'." + raise RuntimeError(msg % message_type) + message = cast("HTTPResponseBodyEvent", message) + + body = message.get("body", b"") + more_body = message.get("more_body", False) + + # Write response body + data = b"" if self.scope["method"] == "HEAD" else body + output = self.conn.send(event=h11.Data(data=data)) + self.transport.write(output) + + # Handle response completion + if not more_body: + self.response_complete = True + self.message_event.set() + output = self.conn.send(event=h11.EndOfMessage()) + self.transport.write(output) + + else: + # Response already sent + msg = "Unexpected ASGI message '%s' sent, after response already completed." + raise RuntimeError(msg % message_type) + + if self.response_complete: + if self.conn.our_state is h11.MUST_CLOSE or not self.keep_alive: + self.conn.send(event=h11.ConnectionClosed()) + self.transport.close() + self.on_response() + + async def receive(self) -> ASGIReceiveEvent: + if self.waiting_for_100_continue and not self.transport.is_closing(): + headers: list[tuple[str, str]] = [] + event = h11.InformationalResponse(status_code=100, headers=headers, reason="Continue") + output = self.conn.send(event=event) + self.transport.write(output) + self.waiting_for_100_continue = False + + if not self.disconnected and not self.response_complete: + self.flow.resume_reading() + await self.message_event.wait() + self.message_event.clear() + + if self.disconnected or self.response_complete: + return {"type": "http.disconnect"} + + message: HTTPRequestEvent = { + "type": "http.request", + "body": self.body, + "more_body": self.more_body, + } + self.body = b"" + return message diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/http/httptools_impl.py b/venv/lib/python3.11/site-packages/uvicorn/protocols/http/httptools_impl.py new file mode 100644 index 0000000..997c6bb --- /dev/null +++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/http/httptools_impl.py @@ -0,0 +1,575 @@ +from __future__ import annotations + +import asyncio +import http +import logging +import re +import urllib +from asyncio.events import TimerHandle +from collections import deque +from typing import Any, Callable, Literal, cast + +import httptools + +from uvicorn._types import ( + ASGI3Application, + ASGIReceiveEvent, + ASGISendEvent, + HTTPRequestEvent, + HTTPResponseBodyEvent, + HTTPResponseStartEvent, + HTTPScope, +) +from uvicorn.config import Config +from uvicorn.logging import TRACE_LOG_LEVEL +from uvicorn.protocols.http.flow_control import ( + CLOSE_HEADER, + HIGH_WATER_LIMIT, + FlowControl, + service_unavailable, +) +from uvicorn.protocols.utils import ( + get_client_addr, + get_local_addr, + get_path_with_query_string, + get_remote_addr, + is_ssl, +) +from uvicorn.server import ServerState + +HEADER_RE = re.compile(b'[\x00-\x1F\x7F()<>@,;:[]={} \t\\"]') +HEADER_VALUE_RE = re.compile(b"[\x00-\x1F\x7F]") + + +def _get_status_line(status_code: int) -> bytes: + try: + phrase = http.HTTPStatus(status_code).phrase.encode() + except ValueError: + phrase = b"" + return b"".join([b"HTTP/1.1 ", str(status_code).encode(), b" ", phrase, b"\r\n"]) + + +STATUS_LINE = {status_code: _get_status_line(status_code) for status_code in range(100, 600)} + + +class HttpToolsProtocol(asyncio.Protocol): + def __init__( + self, + config: Config, + server_state: ServerState, + app_state: dict[str, Any], + _loop: asyncio.AbstractEventLoop | None = None, + ) -> None: + if not config.loaded: + config.load() + + self.config = config + self.app = config.loaded_app + self.loop = _loop or asyncio.get_event_loop() + self.logger = logging.getLogger("uvicorn.error") + self.access_logger = logging.getLogger("uvicorn.access") + self.access_log = self.access_logger.hasHandlers() + self.parser = httptools.HttpRequestParser(self) + self.ws_protocol_class = config.ws_protocol_class + self.root_path = config.root_path + self.limit_concurrency = config.limit_concurrency + self.app_state = app_state + + # Timeouts + self.timeout_keep_alive_task: TimerHandle | None = None + self.timeout_keep_alive = config.timeout_keep_alive + + # Global state + self.server_state = server_state + self.connections = server_state.connections + self.tasks = server_state.tasks + + # Per-connection state + self.transport: asyncio.Transport = None # type: ignore[assignment] + self.flow: FlowControl = None # type: ignore[assignment] + self.server: tuple[str, int] | None = None + self.client: tuple[str, int] | None = None + self.scheme: Literal["http", "https"] | None = None + self.pipeline: deque[tuple[RequestResponseCycle, ASGI3Application]] = deque() + + # Per-request state + self.scope: HTTPScope = None # type: ignore[assignment] + self.headers: list[tuple[bytes, bytes]] = None # type: ignore[assignment] + self.expect_100_continue = False + self.cycle: RequestResponseCycle = None # type: ignore[assignment] + + # Protocol interface + def connection_made( # type: ignore[override] + self, transport: asyncio.Transport + ) -> None: + self.connections.add(self) + + self.transport = transport + self.flow = FlowControl(transport) + self.server = get_local_addr(transport) + self.client = get_remote_addr(transport) + self.scheme = "https" if is_ssl(transport) else "http" + + if self.logger.level <= TRACE_LOG_LEVEL: + prefix = "%s:%d - " % self.client if self.client else "" + self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection made", prefix) + + def connection_lost(self, exc: Exception | None) -> None: + self.connections.discard(self) + + if self.logger.level <= TRACE_LOG_LEVEL: + prefix = "%s:%d - " % self.client if self.client else "" + self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection lost", prefix) + + if self.cycle and not self.cycle.response_complete: + self.cycle.disconnected = True + if self.cycle is not None: + self.cycle.message_event.set() + if self.flow is not None: + self.flow.resume_writing() + if exc is None: + self.transport.close() + self._unset_keepalive_if_required() + + self.parser = None + + def eof_received(self) -> None: + pass + + def _unset_keepalive_if_required(self) -> None: + if self.timeout_keep_alive_task is not None: + self.timeout_keep_alive_task.cancel() + self.timeout_keep_alive_task = None + + def _get_upgrade(self) -> bytes | None: + connection = [] + upgrade = None + for name, value in self.headers: + if name == b"connection": + connection = [token.lower().strip() for token in value.split(b",")] + if name == b"upgrade": + upgrade = value.lower() + if b"upgrade" in connection: + return upgrade + return None + + def _should_upgrade_to_ws(self, upgrade: bytes | None) -> bool: + if upgrade == b"websocket" and self.ws_protocol_class is not None: + return True + if self.config.ws == "auto": + msg = "Unsupported upgrade request." + self.logger.warning(msg) + msg = "No supported WebSocket library detected. Please use \"pip install 'uvicorn[standard]'\", or install 'websockets' or 'wsproto' manually." # noqa: E501 + self.logger.warning(msg) + return False + + def _should_upgrade(self) -> bool: + upgrade = self._get_upgrade() + return self._should_upgrade_to_ws(upgrade) + + def data_received(self, data: bytes) -> None: + self._unset_keepalive_if_required() + + try: + self.parser.feed_data(data) + except httptools.HttpParserError: + msg = "Invalid HTTP request received." + self.logger.warning(msg) + self.send_400_response(msg) + return + except httptools.HttpParserUpgrade: + upgrade = self._get_upgrade() + if self._should_upgrade_to_ws(upgrade): + self.handle_websocket_upgrade() + + def handle_websocket_upgrade(self) -> None: + if self.logger.level <= TRACE_LOG_LEVEL: + prefix = "%s:%d - " % self.client if self.client else "" + self.logger.log(TRACE_LOG_LEVEL, "%sUpgrading to WebSocket", prefix) + + self.connections.discard(self) + method = self.scope["method"].encode() + output = [method, b" ", self.url, b" HTTP/1.1\r\n"] + for name, value in self.scope["headers"]: + output += [name, b": ", value, b"\r\n"] + output.append(b"\r\n") + protocol = self.ws_protocol_class( # type: ignore[call-arg, misc] + config=self.config, + server_state=self.server_state, + app_state=self.app_state, + ) + protocol.connection_made(self.transport) + protocol.data_received(b"".join(output)) + self.transport.set_protocol(protocol) + + def send_400_response(self, msg: str) -> None: + content = [STATUS_LINE[400]] + for name, value in self.server_state.default_headers: + content.extend([name, b": ", value, b"\r\n"]) + content.extend( + [ + b"content-type: text/plain; charset=utf-8\r\n", + b"content-length: " + str(len(msg)).encode("ascii") + b"\r\n", + b"connection: close\r\n", + b"\r\n", + msg.encode("ascii"), + ] + ) + self.transport.write(b"".join(content)) + self.transport.close() + + def on_message_begin(self) -> None: + self.url = b"" + self.expect_100_continue = False + self.headers = [] + self.scope = { # type: ignore[typeddict-item] + "type": "http", + "asgi": {"version": self.config.asgi_version, "spec_version": "2.4"}, + "http_version": "1.1", + "server": self.server, + "client": self.client, + "scheme": self.scheme, # type: ignore[typeddict-item] + "root_path": self.root_path, + "headers": self.headers, + "state": self.app_state.copy(), + } + + # Parser callbacks + def on_url(self, url: bytes) -> None: + self.url += url + + def on_header(self, name: bytes, value: bytes) -> None: + name = name.lower() + if name == b"expect" and value.lower() == b"100-continue": + self.expect_100_continue = True + self.headers.append((name, value)) + + def on_headers_complete(self) -> None: + http_version = self.parser.get_http_version() + method = self.parser.get_method() + self.scope["method"] = method.decode("ascii") + if http_version != "1.1": + self.scope["http_version"] = http_version + if self.parser.should_upgrade() and self._should_upgrade(): + return + parsed_url = httptools.parse_url(self.url) + raw_path = parsed_url.path + path = raw_path.decode("ascii") + if "%" in path: + path = urllib.parse.unquote(path) + full_path = self.root_path + path + full_raw_path = self.root_path.encode("ascii") + raw_path + self.scope["path"] = full_path + self.scope["raw_path"] = full_raw_path + self.scope["query_string"] = parsed_url.query or b"" + + # Handle 503 responses when 'limit_concurrency' is exceeded. + if self.limit_concurrency is not None and ( + len(self.connections) >= self.limit_concurrency or len(self.tasks) >= self.limit_concurrency + ): + app = service_unavailable + message = "Exceeded concurrency limit." + self.logger.warning(message) + else: + app = self.app + + existing_cycle = self.cycle + self.cycle = RequestResponseCycle( + scope=self.scope, + transport=self.transport, + flow=self.flow, + logger=self.logger, + access_logger=self.access_logger, + access_log=self.access_log, + default_headers=self.server_state.default_headers, + message_event=asyncio.Event(), + expect_100_continue=self.expect_100_continue, + keep_alive=http_version != "1.0", + on_response=self.on_response_complete, + ) + if existing_cycle is None or existing_cycle.response_complete: + # Standard case - start processing the request. + task = self.loop.create_task(self.cycle.run_asgi(app)) + task.add_done_callback(self.tasks.discard) + self.tasks.add(task) + else: + # Pipelined HTTP requests need to be queued up. + self.flow.pause_reading() + self.pipeline.appendleft((self.cycle, app)) + + def on_body(self, body: bytes) -> None: + if (self.parser.should_upgrade() and self._should_upgrade()) or self.cycle.response_complete: + return + self.cycle.body += body + if len(self.cycle.body) > HIGH_WATER_LIMIT: + self.flow.pause_reading() + self.cycle.message_event.set() + + def on_message_complete(self) -> None: + if (self.parser.should_upgrade() and self._should_upgrade()) or self.cycle.response_complete: + return + self.cycle.more_body = False + self.cycle.message_event.set() + + def on_response_complete(self) -> None: + # Callback for pipelined HTTP requests to be started. + self.server_state.total_requests += 1 + + if self.transport.is_closing(): + return + + self._unset_keepalive_if_required() + + # Unpause data reads if needed. + self.flow.resume_reading() + + # Unblock any pipelined events. If there are none, arm the + # Keep-Alive timeout instead. + if self.pipeline: + cycle, app = self.pipeline.pop() + task = self.loop.create_task(cycle.run_asgi(app)) + task.add_done_callback(self.tasks.discard) + self.tasks.add(task) + else: + self.timeout_keep_alive_task = self.loop.call_later( + self.timeout_keep_alive, self.timeout_keep_alive_handler + ) + + def shutdown(self) -> None: + """ + Called by the server to commence a graceful shutdown. + """ + if self.cycle is None or self.cycle.response_complete: + self.transport.close() + else: + self.cycle.keep_alive = False + + def pause_writing(self) -> None: + """ + Called by the transport when the write buffer exceeds the high water mark. + """ + self.flow.pause_writing() + + def resume_writing(self) -> None: + """ + Called by the transport when the write buffer drops below the low water mark. + """ + self.flow.resume_writing() + + def timeout_keep_alive_handler(self) -> None: + """ + Called on a keep-alive connection if no new data is received after a short + delay. + """ + if not self.transport.is_closing(): + self.transport.close() + + +class RequestResponseCycle: + def __init__( + self, + scope: HTTPScope, + transport: asyncio.Transport, + flow: FlowControl, + logger: logging.Logger, + access_logger: logging.Logger, + access_log: bool, + default_headers: list[tuple[bytes, bytes]], + message_event: asyncio.Event, + expect_100_continue: bool, + keep_alive: bool, + on_response: Callable[..., None], + ): + self.scope = scope + self.transport = transport + self.flow = flow + self.logger = logger + self.access_logger = access_logger + self.access_log = access_log + self.default_headers = default_headers + self.message_event = message_event + self.on_response = on_response + + # Connection state + self.disconnected = False + self.keep_alive = keep_alive + self.waiting_for_100_continue = expect_100_continue + + # Request state + self.body = b"" + self.more_body = True + + # Response state + self.response_started = False + self.response_complete = False + self.chunked_encoding: bool | None = None + self.expected_content_length = 0 + + # ASGI exception wrapper + async def run_asgi(self, app: ASGI3Application) -> None: + try: + result = await app( # type: ignore[func-returns-value] + self.scope, self.receive, self.send + ) + except BaseException as exc: + msg = "Exception in ASGI application\n" + self.logger.error(msg, exc_info=exc) + if not self.response_started: + await self.send_500_response() + else: + self.transport.close() + else: + if result is not None: + msg = "ASGI callable should return None, but returned '%s'." + self.logger.error(msg, result) + self.transport.close() + elif not self.response_started and not self.disconnected: + msg = "ASGI callable returned without starting response." + self.logger.error(msg) + await self.send_500_response() + elif not self.response_complete and not self.disconnected: + msg = "ASGI callable returned without completing response." + self.logger.error(msg) + self.transport.close() + finally: + self.on_response = lambda: None + + async def send_500_response(self) -> None: + response_start_event: HTTPResponseStartEvent = { + "type": "http.response.start", + "status": 500, + "headers": [ + (b"content-type", b"text/plain; charset=utf-8"), + (b"connection", b"close"), + ], + } + await self.send(response_start_event) + response_body_event: HTTPResponseBodyEvent = { + "type": "http.response.body", + "body": b"Internal Server Error", + "more_body": False, + } + await self.send(response_body_event) + + # ASGI interface + async def send(self, message: ASGISendEvent) -> None: + message_type = message["type"] + + if self.flow.write_paused and not self.disconnected: + await self.flow.drain() + + if self.disconnected: + return + + if not self.response_started: + # Sending response status line and headers + if message_type != "http.response.start": + msg = "Expected ASGI message 'http.response.start', but got '%s'." + raise RuntimeError(msg % message_type) + message = cast("HTTPResponseStartEvent", message) + + self.response_started = True + self.waiting_for_100_continue = False + + status_code = message["status"] + headers = self.default_headers + list(message.get("headers", [])) + + if CLOSE_HEADER in self.scope["headers"] and CLOSE_HEADER not in headers: + headers = headers + [CLOSE_HEADER] + + if self.access_log: + self.access_logger.info( + '%s - "%s %s HTTP/%s" %d', + get_client_addr(self.scope), + self.scope["method"], + get_path_with_query_string(self.scope), + self.scope["http_version"], + status_code, + ) + + # Write response status line and headers + content = [STATUS_LINE[status_code]] + + for name, value in headers: + if HEADER_RE.search(name): + raise RuntimeError("Invalid HTTP header name.") + if HEADER_VALUE_RE.search(value): + raise RuntimeError("Invalid HTTP header value.") + + name = name.lower() + if name == b"content-length" and self.chunked_encoding is None: + self.expected_content_length = int(value.decode()) + self.chunked_encoding = False + elif name == b"transfer-encoding" and value.lower() == b"chunked": + self.expected_content_length = 0 + self.chunked_encoding = True + elif name == b"connection" and value.lower() == b"close": + self.keep_alive = False + content.extend([name, b": ", value, b"\r\n"]) + + if self.chunked_encoding is None and self.scope["method"] != "HEAD" and status_code not in (204, 304): + # Neither content-length nor transfer-encoding specified + self.chunked_encoding = True + content.append(b"transfer-encoding: chunked\r\n") + + content.append(b"\r\n") + self.transport.write(b"".join(content)) + + elif not self.response_complete: + # Sending response body + if message_type != "http.response.body": + msg = "Expected ASGI message 'http.response.body', but got '%s'." + raise RuntimeError(msg % message_type) + + body = cast(bytes, message.get("body", b"")) + more_body = message.get("more_body", False) + + # Write response body + if self.scope["method"] == "HEAD": + self.expected_content_length = 0 + elif self.chunked_encoding: + if body: + content = [b"%x\r\n" % len(body), body, b"\r\n"] + else: + content = [] + if not more_body: + content.append(b"0\r\n\r\n") + self.transport.write(b"".join(content)) + else: + num_bytes = len(body) + if num_bytes > self.expected_content_length: + raise RuntimeError("Response content longer than Content-Length") + else: + self.expected_content_length -= num_bytes + self.transport.write(body) + + # Handle response completion + if not more_body: + if self.expected_content_length != 0: + raise RuntimeError("Response content shorter than Content-Length") + self.response_complete = True + self.message_event.set() + if not self.keep_alive: + self.transport.close() + self.on_response() + + else: + # Response already sent + msg = "Unexpected ASGI message '%s' sent, after response already completed." + raise RuntimeError(msg % message_type) + + async def receive(self) -> ASGIReceiveEvent: + if self.waiting_for_100_continue and not self.transport.is_closing(): + self.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n") + self.waiting_for_100_continue = False + + if not self.disconnected and not self.response_complete: + self.flow.resume_reading() + await self.message_event.wait() + self.message_event.clear() + + if self.disconnected or self.response_complete: + return {"type": "http.disconnect"} + message: HTTPRequestEvent = {"type": "http.request", "body": self.body, "more_body": self.more_body} + self.body = b"" + return message diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/utils.py b/venv/lib/python3.11/site-packages/uvicorn/protocols/utils.py new file mode 100644 index 0000000..4e65806 --- /dev/null +++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/utils.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +import asyncio +import urllib.parse + +from uvicorn._types import WWWScope + + +class ClientDisconnected(IOError): + ... + + +def get_remote_addr(transport: asyncio.Transport) -> tuple[str, int] | None: + socket_info = transport.get_extra_info("socket") + if socket_info is not None: + try: + info = socket_info.getpeername() + return (str(info[0]), int(info[1])) if isinstance(info, tuple) else None + except OSError: # pragma: no cover + # This case appears to inconsistently occur with uvloop + # bound to a unix domain socket. + return None + + info = transport.get_extra_info("peername") + if info is not None and isinstance(info, (list, tuple)) and len(info) == 2: + return (str(info[0]), int(info[1])) + return None + + +def get_local_addr(transport: asyncio.Transport) -> tuple[str, int] | None: + socket_info = transport.get_extra_info("socket") + if socket_info is not None: + info = socket_info.getsockname() + + return (str(info[0]), int(info[1])) if isinstance(info, tuple) else None + info = transport.get_extra_info("sockname") + if info is not None and isinstance(info, (list, tuple)) and len(info) == 2: + return (str(info[0]), int(info[1])) + return None + + +def is_ssl(transport: asyncio.Transport) -> bool: + return bool(transport.get_extra_info("sslcontext")) + + +def get_client_addr(scope: WWWScope) -> str: + client = scope.get("client") + if not client: + return "" + return "%s:%d" % client + + +def get_path_with_query_string(scope: WWWScope) -> str: + path_with_query_string = urllib.parse.quote(scope["path"]) + if scope["query_string"]: + path_with_query_string = "{}?{}".format(path_with_query_string, scope["query_string"].decode("ascii")) + return path_with_query_string diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/__init__.py b/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/__init__.py diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..d216ab9 --- /dev/null +++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/__pycache__/auto.cpython-311.pyc b/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/__pycache__/auto.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..e8c06fa --- /dev/null +++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/__pycache__/auto.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/__pycache__/websockets_impl.cpython-311.pyc b/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/__pycache__/websockets_impl.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..334a441 --- /dev/null +++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/__pycache__/websockets_impl.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/__pycache__/wsproto_impl.cpython-311.pyc b/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/__pycache__/wsproto_impl.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..3b1911e --- /dev/null +++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/__pycache__/wsproto_impl.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/auto.py b/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/auto.py new file mode 100644 index 0000000..08fd136 --- /dev/null +++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/auto.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +import asyncio +import typing + +AutoWebSocketsProtocol: typing.Callable[..., asyncio.Protocol] | None +try: + import websockets # noqa +except ImportError: # pragma: no cover + try: + import wsproto # noqa + except ImportError: + AutoWebSocketsProtocol = None + else: + from uvicorn.protocols.websockets.wsproto_impl import WSProtocol + + AutoWebSocketsProtocol = WSProtocol +else: + from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol + + AutoWebSocketsProtocol = WebSocketProtocol diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/websockets_impl.py b/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/websockets_impl.py new file mode 100644 index 0000000..6d098d5 --- /dev/null +++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/websockets_impl.py @@ -0,0 +1,388 @@ +from __future__ import annotations + +import asyncio +import http +import logging +from typing import Any, Literal, Optional, Sequence, cast +from urllib.parse import unquote + +import websockets +from websockets.datastructures import Headers +from websockets.exceptions import ConnectionClosed +from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory +from websockets.legacy.server import HTTPResponse +from websockets.server import WebSocketServerProtocol +from websockets.typing import Subprotocol + +from uvicorn._types import ( + ASGISendEvent, + WebSocketAcceptEvent, + WebSocketCloseEvent, + WebSocketConnectEvent, + WebSocketDisconnectEvent, + WebSocketReceiveEvent, + WebSocketResponseBodyEvent, + WebSocketResponseStartEvent, + WebSocketScope, + WebSocketSendEvent, +) +from uvicorn.config import Config +from uvicorn.logging import TRACE_LOG_LEVEL +from uvicorn.protocols.utils import ( + ClientDisconnected, + get_local_addr, + get_path_with_query_string, + get_remote_addr, + is_ssl, +) +from uvicorn.server import ServerState + + +class Server: + closing = False + + def register(self, ws: WebSocketServerProtocol) -> None: + pass + + def unregister(self, ws: WebSocketServerProtocol) -> None: + pass + + def is_serving(self) -> bool: + return not self.closing + + +class WebSocketProtocol(WebSocketServerProtocol): + extra_headers: list[tuple[str, str]] + + def __init__( + self, + config: Config, + server_state: ServerState, + app_state: dict[str, Any], + _loop: asyncio.AbstractEventLoop | None = None, + ): + if not config.loaded: + config.load() + + self.config = config + self.app = config.loaded_app + self.loop = _loop or asyncio.get_event_loop() + self.root_path = config.root_path + self.app_state = app_state + + # Shared server state + self.connections = server_state.connections + self.tasks = server_state.tasks + + # Connection state + self.transport: asyncio.Transport = None # type: ignore[assignment] + self.server: tuple[str, int] | None = None + self.client: tuple[str, int] | None = None + self.scheme: Literal["wss", "ws"] = None # type: ignore[assignment] + + # Connection events + self.scope: WebSocketScope + self.handshake_started_event = asyncio.Event() + self.handshake_completed_event = asyncio.Event() + self.closed_event = asyncio.Event() + self.initial_response: HTTPResponse | None = None + self.connect_sent = False + self.lost_connection_before_handshake = False + self.accepted_subprotocol: Subprotocol | None = None + + self.ws_server: Server = Server() # type: ignore[assignment] + + extensions = [] + if self.config.ws_per_message_deflate: + extensions.append(ServerPerMessageDeflateFactory()) + + super().__init__( + ws_handler=self.ws_handler, + ws_server=self.ws_server, # type: ignore[arg-type] + max_size=self.config.ws_max_size, + max_queue=self.config.ws_max_queue, + ping_interval=self.config.ws_ping_interval, + ping_timeout=self.config.ws_ping_timeout, + extensions=extensions, + logger=logging.getLogger("uvicorn.error"), + ) + self.server_header = None + self.extra_headers = [ + (name.decode("latin-1"), value.decode("latin-1")) for name, value in server_state.default_headers + ] + + def connection_made( # type: ignore[override] + self, transport: asyncio.Transport + ) -> None: + self.connections.add(self) + self.transport = transport + self.server = get_local_addr(transport) + self.client = get_remote_addr(transport) + self.scheme = "wss" if is_ssl(transport) else "ws" + + if self.logger.isEnabledFor(TRACE_LOG_LEVEL): + prefix = "%s:%d - " % self.client if self.client else "" + self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix) + + super().connection_made(transport) + + def connection_lost(self, exc: Exception | None) -> None: + self.connections.remove(self) + + if self.logger.isEnabledFor(TRACE_LOG_LEVEL): + prefix = "%s:%d - " % self.client if self.client else "" + self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix) + + self.lost_connection_before_handshake = not self.handshake_completed_event.is_set() + self.handshake_completed_event.set() + super().connection_lost(exc) + if exc is None: + self.transport.close() + + def shutdown(self) -> None: + self.ws_server.closing = True + if self.handshake_completed_event.is_set(): + self.fail_connection(1012) + else: + self.send_500_response() + self.transport.close() + + def on_task_complete(self, task: asyncio.Task) -> None: + self.tasks.discard(task) + + async def process_request(self, path: str, headers: Headers) -> HTTPResponse | None: + """ + This hook is called to determine if the websocket should return + an HTTP response and close. + + Our behavior here is to start the ASGI application, and then wait + for either `accept` or `close` in order to determine if we should + close the connection. + """ + path_portion, _, query_string = path.partition("?") + + websockets.legacy.handshake.check_request(headers) + + subprotocols = [] + for header in headers.get_all("Sec-WebSocket-Protocol"): + subprotocols.extend([token.strip() for token in header.split(",")]) + + asgi_headers = [ + (name.encode("ascii"), value.encode("ascii", errors="surrogateescape")) + for name, value in headers.raw_items() + ] + path = unquote(path_portion) + full_path = self.root_path + path + full_raw_path = self.root_path.encode("ascii") + path_portion.encode("ascii") + + self.scope = { + "type": "websocket", + "asgi": {"version": self.config.asgi_version, "spec_version": "2.4"}, + "http_version": "1.1", + "scheme": self.scheme, + "server": self.server, + "client": self.client, + "root_path": self.root_path, + "path": full_path, + "raw_path": full_raw_path, + "query_string": query_string.encode("ascii"), + "headers": asgi_headers, + "subprotocols": subprotocols, + "state": self.app_state.copy(), + "extensions": {"websocket.http.response": {}}, + } + task = self.loop.create_task(self.run_asgi()) + task.add_done_callback(self.on_task_complete) + self.tasks.add(task) + await self.handshake_started_event.wait() + return self.initial_response + + def process_subprotocol( + self, headers: Headers, available_subprotocols: Sequence[Subprotocol] | None + ) -> Subprotocol | None: + """ + We override the standard 'process_subprotocol' behavior here so that + we return whatever subprotocol is sent in the 'accept' message. + """ + return self.accepted_subprotocol + + def send_500_response(self) -> None: + msg = b"Internal Server Error" + content = [ + b"HTTP/1.1 500 Internal Server Error\r\n" b"content-type: text/plain; charset=utf-8\r\n", + b"content-length: " + str(len(msg)).encode("ascii") + b"\r\n", + b"connection: close\r\n", + b"\r\n", + msg, + ] + self.transport.write(b"".join(content)) + # Allow handler task to terminate cleanly, as websockets doesn't cancel it by + # itself (see https://github.com/encode/uvicorn/issues/920) + self.handshake_started_event.set() + + async def ws_handler( # type: ignore[override] + self, protocol: WebSocketServerProtocol, path: str + ) -> Any: + """ + This is the main handler function for the 'websockets' implementation + to call into. We just wait for close then return, and instead allow + 'send' and 'receive' events to drive the flow. + """ + self.handshake_completed_event.set() + await self.wait_closed() + + async def run_asgi(self) -> None: + """ + Wrapper around the ASGI callable, handling exceptions and unexpected + termination states. + """ + try: + result = await self.app(self.scope, self.asgi_receive, self.asgi_send) + except ClientDisconnected: + self.closed_event.set() + self.transport.close() + except BaseException as exc: + self.closed_event.set() + msg = "Exception in ASGI application\n" + self.logger.error(msg, exc_info=exc) + if not self.handshake_started_event.is_set(): + self.send_500_response() + else: + await self.handshake_completed_event.wait() + self.transport.close() + else: + self.closed_event.set() + if not self.handshake_started_event.is_set(): + msg = "ASGI callable returned without sending handshake." + self.logger.error(msg) + self.send_500_response() + self.transport.close() + elif result is not None: + msg = "ASGI callable should return None, but returned '%s'." + self.logger.error(msg, result) + await self.handshake_completed_event.wait() + self.transport.close() + + async def asgi_send(self, message: ASGISendEvent) -> None: + message_type = message["type"] + + if not self.handshake_started_event.is_set(): + if message_type == "websocket.accept": + message = cast("WebSocketAcceptEvent", message) + self.logger.info( + '%s - "WebSocket %s" [accepted]', + self.scope["client"], + get_path_with_query_string(self.scope), + ) + self.initial_response = None + self.accepted_subprotocol = cast(Optional[Subprotocol], message.get("subprotocol")) + if "headers" in message: + self.extra_headers.extend( + # ASGI spec requires bytes + # But for compatibility we need to convert it to strings + (name.decode("latin-1"), value.decode("latin-1")) + for name, value in message["headers"] + ) + self.handshake_started_event.set() + + elif message_type == "websocket.close": + message = cast("WebSocketCloseEvent", message) + self.logger.info( + '%s - "WebSocket %s" 403', + self.scope["client"], + get_path_with_query_string(self.scope), + ) + self.initial_response = (http.HTTPStatus.FORBIDDEN, [], b"") + self.handshake_started_event.set() + self.closed_event.set() + + elif message_type == "websocket.http.response.start": + message = cast("WebSocketResponseStartEvent", message) + self.logger.info( + '%s - "WebSocket %s" %d', + self.scope["client"], + get_path_with_query_string(self.scope), + message["status"], + ) + # websockets requires the status to be an enum. look it up. + status = http.HTTPStatus(message["status"]) + headers = [ + (name.decode("latin-1"), value.decode("latin-1")) for name, value in message.get("headers", []) + ] + self.initial_response = (status, headers, b"") + self.handshake_started_event.set() + + else: + msg = ( + "Expected ASGI message 'websocket.accept', 'websocket.close', " + "or 'websocket.http.response.start' but got '%s'." + ) + raise RuntimeError(msg % message_type) + + elif not self.closed_event.is_set() and self.initial_response is None: + await self.handshake_completed_event.wait() + + try: + if message_type == "websocket.send": + message = cast("WebSocketSendEvent", message) + bytes_data = message.get("bytes") + text_data = message.get("text") + data = text_data if bytes_data is None else bytes_data + await self.send(data) # type: ignore[arg-type] + + elif message_type == "websocket.close": + message = cast("WebSocketCloseEvent", message) + code = message.get("code", 1000) + reason = message.get("reason", "") or "" + await self.close(code, reason) + self.closed_event.set() + + else: + msg = "Expected ASGI message 'websocket.send' or 'websocket.close'," " but got '%s'." + raise RuntimeError(msg % message_type) + except ConnectionClosed as exc: + raise ClientDisconnected from exc + + elif self.initial_response is not None: + if message_type == "websocket.http.response.body": + message = cast("WebSocketResponseBodyEvent", message) + body = self.initial_response[2] + message["body"] + self.initial_response = self.initial_response[:2] + (body,) + if not message.get("more_body", False): + self.closed_event.set() + else: + msg = "Expected ASGI message 'websocket.http.response.body' " "but got '%s'." + raise RuntimeError(msg % message_type) + + else: + msg = "Unexpected ASGI message '%s', after sending 'websocket.close' " "or response already completed." + raise RuntimeError(msg % message_type) + + async def asgi_receive( + self, + ) -> WebSocketDisconnectEvent | WebSocketConnectEvent | WebSocketReceiveEvent: + if not self.connect_sent: + self.connect_sent = True + return {"type": "websocket.connect"} + + await self.handshake_completed_event.wait() + + if self.lost_connection_before_handshake: + # If the handshake failed or the app closed before handshake completion, + # use 1006 Abnormal Closure. + return {"type": "websocket.disconnect", "code": 1006} + + if self.closed_event.is_set(): + return {"type": "websocket.disconnect", "code": 1005} + + try: + data = await self.recv() + except ConnectionClosed as exc: + self.closed_event.set() + if self.ws_server.closing: + return {"type": "websocket.disconnect", "code": 1012} + return {"type": "websocket.disconnect", "code": exc.code} + + if isinstance(data, str): + return {"type": "websocket.receive", "text": data} + return {"type": "websocket.receive", "bytes": data} diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/wsproto_impl.py b/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/wsproto_impl.py new file mode 100644 index 0000000..c926252 --- /dev/null +++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/wsproto_impl.py @@ -0,0 +1,377 @@ +from __future__ import annotations + +import asyncio +import logging +import typing +from typing import Literal +from urllib.parse import unquote + +import wsproto +from wsproto import ConnectionType, events +from wsproto.connection import ConnectionState +from wsproto.extensions import Extension, PerMessageDeflate +from wsproto.utilities import LocalProtocolError, RemoteProtocolError + +from uvicorn._types import ( + ASGISendEvent, + WebSocketAcceptEvent, + WebSocketCloseEvent, + WebSocketEvent, + WebSocketResponseBodyEvent, + WebSocketResponseStartEvent, + WebSocketScope, + WebSocketSendEvent, +) +from uvicorn.config import Config +from uvicorn.logging import TRACE_LOG_LEVEL +from uvicorn.protocols.utils import ( + ClientDisconnected, + get_local_addr, + get_path_with_query_string, + get_remote_addr, + is_ssl, +) +from uvicorn.server import ServerState + + +class WSProtocol(asyncio.Protocol): + def __init__( + self, + config: Config, + server_state: ServerState, + app_state: dict[str, typing.Any], + _loop: asyncio.AbstractEventLoop | None = None, + ) -> None: + if not config.loaded: + config.load() + + self.config = config + self.app = config.loaded_app + self.loop = _loop or asyncio.get_event_loop() + self.logger = logging.getLogger("uvicorn.error") + self.root_path = config.root_path + self.app_state = app_state + + # Shared server state + self.connections = server_state.connections + self.tasks = server_state.tasks + self.default_headers = server_state.default_headers + + # Connection state + self.transport: asyncio.Transport = None # type: ignore[assignment] + self.server: tuple[str, int] | None = None + self.client: tuple[str, int] | None = None + self.scheme: Literal["wss", "ws"] = None # type: ignore[assignment] + + # WebSocket state + self.queue: asyncio.Queue[WebSocketEvent] = asyncio.Queue() + self.handshake_complete = False + self.close_sent = False + + # Rejection state + self.response_started = False + + self.conn = wsproto.WSConnection(connection_type=ConnectionType.SERVER) + + self.read_paused = False + self.writable = asyncio.Event() + self.writable.set() + + # Buffers + self.bytes = b"" + self.text = "" + + # Protocol interface + + def connection_made( # type: ignore[override] + self, transport: asyncio.Transport + ) -> None: + self.connections.add(self) + self.transport = transport + self.server = get_local_addr(transport) + self.client = get_remote_addr(transport) + self.scheme = "wss" if is_ssl(transport) else "ws" + + if self.logger.level <= TRACE_LOG_LEVEL: + prefix = "%s:%d - " % self.client if self.client else "" + self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix) + + def connection_lost(self, exc: Exception | None) -> None: + code = 1005 if self.handshake_complete else 1006 + self.queue.put_nowait({"type": "websocket.disconnect", "code": code}) + self.connections.remove(self) + + if self.logger.level <= TRACE_LOG_LEVEL: + prefix = "%s:%d - " % self.client if self.client else "" + self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix) + + self.handshake_complete = True + if exc is None: + self.transport.close() + + def eof_received(self) -> None: + pass + + def data_received(self, data: bytes) -> None: + try: + self.conn.receive_data(data) + except RemoteProtocolError as err: + # TODO: Remove `type: ignore` when wsproto fixes the type annotation. + self.transport.write(self.conn.send(err.event_hint)) # type: ignore[arg-type] # noqa: E501 + self.transport.close() + else: + self.handle_events() + + def handle_events(self) -> None: + for event in self.conn.events(): + if isinstance(event, events.Request): + self.handle_connect(event) + elif isinstance(event, events.TextMessage): + self.handle_text(event) + elif isinstance(event, events.BytesMessage): + self.handle_bytes(event) + elif isinstance(event, events.CloseConnection): + self.handle_close(event) + elif isinstance(event, events.Ping): + self.handle_ping(event) + + def pause_writing(self) -> None: + """ + Called by the transport when the write buffer exceeds the high water mark. + """ + self.writable.clear() + + def resume_writing(self) -> None: + """ + Called by the transport when the write buffer drops below the low water mark. + """ + self.writable.set() + + def shutdown(self) -> None: + if self.handshake_complete: + self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012}) + output = self.conn.send(wsproto.events.CloseConnection(code=1012)) + self.transport.write(output) + else: + self.send_500_response() + self.transport.close() + + def on_task_complete(self, task: asyncio.Task) -> None: + self.tasks.discard(task) + + # Event handlers + + def handle_connect(self, event: events.Request) -> None: + headers = [(b"host", event.host.encode())] + headers += [(key.lower(), value) for key, value in event.extra_headers] + raw_path, _, query_string = event.target.partition("?") + path = unquote(raw_path) + full_path = self.root_path + path + full_raw_path = self.root_path.encode("ascii") + raw_path.encode("ascii") + self.scope: WebSocketScope = { + "type": "websocket", + "asgi": {"version": self.config.asgi_version, "spec_version": "2.4"}, + "http_version": "1.1", + "scheme": self.scheme, + "server": self.server, + "client": self.client, + "root_path": self.root_path, + "path": full_path, + "raw_path": full_raw_path, + "query_string": query_string.encode("ascii"), + "headers": headers, + "subprotocols": event.subprotocols, + "state": self.app_state.copy(), + "extensions": {"websocket.http.response": {}}, + } + self.queue.put_nowait({"type": "websocket.connect"}) + task = self.loop.create_task(self.run_asgi()) + task.add_done_callback(self.on_task_complete) + self.tasks.add(task) + + def handle_text(self, event: events.TextMessage) -> None: + self.text += event.data + if event.message_finished: + self.queue.put_nowait({"type": "websocket.receive", "text": self.text}) + self.text = "" + if not self.read_paused: + self.read_paused = True + self.transport.pause_reading() + + def handle_bytes(self, event: events.BytesMessage) -> None: + self.bytes += event.data + # todo: we may want to guard the size of self.bytes and self.text + if event.message_finished: + self.queue.put_nowait({"type": "websocket.receive", "bytes": self.bytes}) + self.bytes = b"" + if not self.read_paused: + self.read_paused = True + self.transport.pause_reading() + + def handle_close(self, event: events.CloseConnection) -> None: + if self.conn.state == ConnectionState.REMOTE_CLOSING: + self.transport.write(self.conn.send(event.response())) + self.queue.put_nowait({"type": "websocket.disconnect", "code": event.code}) + self.transport.close() + + def handle_ping(self, event: events.Ping) -> None: + self.transport.write(self.conn.send(event.response())) + + def send_500_response(self) -> None: + if self.response_started or self.handshake_complete: + return # we cannot send responses anymore + headers = [ + (b"content-type", b"text/plain; charset=utf-8"), + (b"connection", b"close"), + ] + output = self.conn.send(wsproto.events.RejectConnection(status_code=500, headers=headers, has_body=True)) + output += self.conn.send(wsproto.events.RejectData(data=b"Internal Server Error")) + self.transport.write(output) + + async def run_asgi(self) -> None: + try: + result = await self.app(self.scope, self.receive, self.send) + except ClientDisconnected: + self.transport.close() + except BaseException: + self.logger.exception("Exception in ASGI application\n") + self.send_500_response() + self.transport.close() + else: + if not self.handshake_complete: + msg = "ASGI callable returned without completing handshake." + self.logger.error(msg) + self.send_500_response() + self.transport.close() + elif result is not None: + msg = "ASGI callable should return None, but returned '%s'." + self.logger.error(msg, result) + self.transport.close() + + async def send(self, message: ASGISendEvent) -> None: + await self.writable.wait() + + message_type = message["type"] + + if not self.handshake_complete: + if message_type == "websocket.accept": + message = typing.cast(WebSocketAcceptEvent, message) + self.logger.info( + '%s - "WebSocket %s" [accepted]', + self.scope["client"], + get_path_with_query_string(self.scope), + ) + subprotocol = message.get("subprotocol") + extra_headers = self.default_headers + list(message.get("headers", [])) + extensions: list[Extension] = [] + if self.config.ws_per_message_deflate: + extensions.append(PerMessageDeflate()) + if not self.transport.is_closing(): + self.handshake_complete = True + output = self.conn.send( + wsproto.events.AcceptConnection( + subprotocol=subprotocol, + extensions=extensions, + extra_headers=extra_headers, + ) + ) + self.transport.write(output) + + elif message_type == "websocket.close": + self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006}) + self.logger.info( + '%s - "WebSocket %s" 403', + self.scope["client"], + get_path_with_query_string(self.scope), + ) + self.handshake_complete = True + self.close_sent = True + event = events.RejectConnection(status_code=403, headers=[]) + output = self.conn.send(event) + self.transport.write(output) + self.transport.close() + + elif message_type == "websocket.http.response.start": + message = typing.cast(WebSocketResponseStartEvent, message) + # ensure status code is in the valid range + if not (100 <= message["status"] < 600): + msg = "Invalid HTTP status code '%d' in response." + raise RuntimeError(msg % message["status"]) + self.logger.info( + '%s - "WebSocket %s" %d', + self.scope["client"], + get_path_with_query_string(self.scope), + message["status"], + ) + self.handshake_complete = True + event = events.RejectConnection( + status_code=message["status"], + headers=list(message["headers"]), + has_body=True, + ) + output = self.conn.send(event) + self.transport.write(output) + self.response_started = True + + else: + msg = ( + "Expected ASGI message 'websocket.accept', 'websocket.close' " + "or 'websocket.http.response.start' " + "but got '%s'." + ) + raise RuntimeError(msg % message_type) + + elif not self.close_sent and not self.response_started: + try: + if message_type == "websocket.send": + message = typing.cast(WebSocketSendEvent, message) + bytes_data = message.get("bytes") + text_data = message.get("text") + data = text_data if bytes_data is None else bytes_data + output = self.conn.send(wsproto.events.Message(data=data)) # type: ignore + if not self.transport.is_closing(): + self.transport.write(output) + + elif message_type == "websocket.close": + message = typing.cast(WebSocketCloseEvent, message) + self.close_sent = True + code = message.get("code", 1000) + reason = message.get("reason", "") or "" + self.queue.put_nowait({"type": "websocket.disconnect", "code": code}) + output = self.conn.send(wsproto.events.CloseConnection(code=code, reason=reason)) + if not self.transport.is_closing(): + self.transport.write(output) + self.transport.close() + + else: + msg = "Expected ASGI message 'websocket.send' or 'websocket.close'," " but got '%s'." + raise RuntimeError(msg % message_type) + except LocalProtocolError as exc: + raise ClientDisconnected from exc + elif self.response_started: + if message_type == "websocket.http.response.body": + message = typing.cast("WebSocketResponseBodyEvent", message) + body_finished = not message.get("more_body", False) + reject_data = events.RejectData(data=message["body"], body_finished=body_finished) + output = self.conn.send(reject_data) + self.transport.write(output) + + if body_finished: + self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006}) + self.close_sent = True + self.transport.close() + + else: + msg = "Expected ASGI message 'websocket.http.response.body' " "but got '%s'." + raise RuntimeError(msg % message_type) + + else: + msg = "Unexpected ASGI message '%s', after sending 'websocket.close'." + raise RuntimeError(msg % message_type) + + async def receive(self) -> WebSocketEvent: + message = await self.queue.get() + if self.read_paused and self.queue.empty(): + self.read_paused = False + self.transport.resume_reading() + return message |