summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/uvicorn/protocols
diff options
context:
space:
mode:
authorcyfraeviolae <cyfraeviolae>2024-04-03 03:10:44 -0400
committercyfraeviolae <cyfraeviolae>2024-04-03 03:10:44 -0400
commit6d7ba58f880be618ade07f8ea080fe8c4bf8a896 (patch)
treeb1c931051ffcebd2bd9d61d98d6233ffa289bbce /venv/lib/python3.11/site-packages/uvicorn/protocols
parent4f884c9abc32990b4061a1bb6997b4b37e58ea0b (diff)
venv
Diffstat (limited to 'venv/lib/python3.11/site-packages/uvicorn/protocols')
-rw-r--r--venv/lib/python3.11/site-packages/uvicorn/protocols/__init__.py0
-rw-r--r--venv/lib/python3.11/site-packages/uvicorn/protocols/__pycache__/__init__.cpython-311.pycbin0 -> 201 bytes
-rw-r--r--venv/lib/python3.11/site-packages/uvicorn/protocols/__pycache__/utils.cpython-311.pycbin0 -> 3585 bytes
-rw-r--r--venv/lib/python3.11/site-packages/uvicorn/protocols/http/__init__.py0
-rw-r--r--venv/lib/python3.11/site-packages/uvicorn/protocols/http/__pycache__/__init__.cpython-311.pycbin0 -> 206 bytes
-rw-r--r--venv/lib/python3.11/site-packages/uvicorn/protocols/http/__pycache__/auto.cpython-311.pycbin0 -> 705 bytes
-rw-r--r--venv/lib/python3.11/site-packages/uvicorn/protocols/http/__pycache__/flow_control.cpython-311.pycbin0 -> 3410 bytes
-rw-r--r--venv/lib/python3.11/site-packages/uvicorn/protocols/http/__pycache__/h11_impl.cpython-311.pycbin0 -> 27291 bytes
-rw-r--r--venv/lib/python3.11/site-packages/uvicorn/protocols/http/__pycache__/httptools_impl.cpython-311.pycbin0 -> 30083 bytes
-rw-r--r--venv/lib/python3.11/site-packages/uvicorn/protocols/http/auto.py15
-rw-r--r--venv/lib/python3.11/site-packages/uvicorn/protocols/http/flow_control.py64
-rw-r--r--venv/lib/python3.11/site-packages/uvicorn/protocols/http/h11_impl.py547
-rw-r--r--venv/lib/python3.11/site-packages/uvicorn/protocols/http/httptools_impl.py575
-rw-r--r--venv/lib/python3.11/site-packages/uvicorn/protocols/utils.py57
-rw-r--r--venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/__init__.py0
-rw-r--r--venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/__pycache__/__init__.cpython-311.pycbin0 -> 212 bytes
-rw-r--r--venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/__pycache__/auto.cpython-311.pycbin0 -> 924 bytes
-rw-r--r--venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/__pycache__/websockets_impl.cpython-311.pycbin0 -> 22107 bytes
-rw-r--r--venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/__pycache__/wsproto_impl.cpython-311.pycbin0 -> 22222 bytes
-rw-r--r--venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/auto.py21
-rw-r--r--venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/websockets_impl.py388
-rw-r--r--venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/wsproto_impl.py377
22 files changed, 2044 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/__init__.py b/venv/lib/python3.11/site-packages/uvicorn/protocols/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/__init__.py
diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/uvicorn/protocols/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000..5ff273e
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/__pycache__/__init__.cpython-311.pyc
Binary files differ
diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/__pycache__/utils.cpython-311.pyc b/venv/lib/python3.11/site-packages/uvicorn/protocols/__pycache__/utils.cpython-311.pyc
new file mode 100644
index 0000000..da51e43
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/__pycache__/utils.cpython-311.pyc
Binary files differ
diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/http/__init__.py b/venv/lib/python3.11/site-packages/uvicorn/protocols/http/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/http/__init__.py
diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/http/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/uvicorn/protocols/http/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000..c56186b
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/http/__pycache__/__init__.cpython-311.pyc
Binary files differ
diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/http/__pycache__/auto.cpython-311.pyc b/venv/lib/python3.11/site-packages/uvicorn/protocols/http/__pycache__/auto.cpython-311.pyc
new file mode 100644
index 0000000..a598e8a
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/http/__pycache__/auto.cpython-311.pyc
Binary files differ
diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/http/__pycache__/flow_control.cpython-311.pyc b/venv/lib/python3.11/site-packages/uvicorn/protocols/http/__pycache__/flow_control.cpython-311.pyc
new file mode 100644
index 0000000..a925cb4
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/http/__pycache__/flow_control.cpython-311.pyc
Binary files differ
diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/http/__pycache__/h11_impl.cpython-311.pyc b/venv/lib/python3.11/site-packages/uvicorn/protocols/http/__pycache__/h11_impl.cpython-311.pyc
new file mode 100644
index 0000000..2379bb6
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/http/__pycache__/h11_impl.cpython-311.pyc
Binary files differ
diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/http/__pycache__/httptools_impl.cpython-311.pyc b/venv/lib/python3.11/site-packages/uvicorn/protocols/http/__pycache__/httptools_impl.cpython-311.pyc
new file mode 100644
index 0000000..ba76184
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/http/__pycache__/httptools_impl.cpython-311.pyc
Binary files differ
diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/http/auto.py b/venv/lib/python3.11/site-packages/uvicorn/protocols/http/auto.py
new file mode 100644
index 0000000..a14bec1
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/http/auto.py
@@ -0,0 +1,15 @@
+from __future__ import annotations
+
+import asyncio
+
+AutoHTTPProtocol: type[asyncio.Protocol]
+try:
+ import httptools # noqa
+except ImportError: # pragma: no cover
+ from uvicorn.protocols.http.h11_impl import H11Protocol
+
+ AutoHTTPProtocol = H11Protocol
+else: # pragma: no cover
+ from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol
+
+ AutoHTTPProtocol = HttpToolsProtocol
diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/http/flow_control.py b/venv/lib/python3.11/site-packages/uvicorn/protocols/http/flow_control.py
new file mode 100644
index 0000000..893a26c
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/http/flow_control.py
@@ -0,0 +1,64 @@
+import asyncio
+
+from uvicorn._types import (
+ ASGIReceiveCallable,
+ ASGISendCallable,
+ HTTPResponseBodyEvent,
+ HTTPResponseStartEvent,
+ Scope,
+)
+
+CLOSE_HEADER = (b"connection", b"close")
+
+HIGH_WATER_LIMIT = 65536
+
+
+class FlowControl:
+ def __init__(self, transport: asyncio.Transport) -> None:
+ self._transport = transport
+ self.read_paused = False
+ self.write_paused = False
+ self._is_writable_event = asyncio.Event()
+ self._is_writable_event.set()
+
+ async def drain(self) -> None:
+ await self._is_writable_event.wait()
+
+ def pause_reading(self) -> None:
+ if not self.read_paused:
+ self.read_paused = True
+ self._transport.pause_reading()
+
+ def resume_reading(self) -> None:
+ if self.read_paused:
+ self.read_paused = False
+ self._transport.resume_reading()
+
+ def pause_writing(self) -> None:
+ if not self.write_paused:
+ self.write_paused = True
+ self._is_writable_event.clear()
+
+ def resume_writing(self) -> None:
+ if self.write_paused:
+ self.write_paused = False
+ self._is_writable_event.set()
+
+
+async def service_unavailable(scope: "Scope", receive: "ASGIReceiveCallable", send: "ASGISendCallable") -> None:
+ response_start: "HTTPResponseStartEvent" = {
+ "type": "http.response.start",
+ "status": 503,
+ "headers": [
+ (b"content-type", b"text/plain; charset=utf-8"),
+ (b"connection", b"close"),
+ ],
+ }
+ await send(response_start)
+
+ response_body: "HTTPResponseBodyEvent" = {
+ "type": "http.response.body",
+ "body": b"Service Unavailable",
+ "more_body": False,
+ }
+ await send(response_body)
diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/http/h11_impl.py b/venv/lib/python3.11/site-packages/uvicorn/protocols/http/h11_impl.py
new file mode 100644
index 0000000..d0f2b2a
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/http/h11_impl.py
@@ -0,0 +1,547 @@
+from __future__ import annotations
+
+import asyncio
+import http
+import logging
+from typing import Any, Callable, Literal, cast
+from urllib.parse import unquote
+
+import h11
+from h11._connection import DEFAULT_MAX_INCOMPLETE_EVENT_SIZE
+
+from uvicorn._types import (
+ ASGI3Application,
+ ASGIReceiveEvent,
+ ASGISendEvent,
+ HTTPRequestEvent,
+ HTTPResponseBodyEvent,
+ HTTPResponseStartEvent,
+ HTTPScope,
+)
+from uvicorn.config import Config
+from uvicorn.logging import TRACE_LOG_LEVEL
+from uvicorn.protocols.http.flow_control import (
+ CLOSE_HEADER,
+ HIGH_WATER_LIMIT,
+ FlowControl,
+ service_unavailable,
+)
+from uvicorn.protocols.utils import (
+ get_client_addr,
+ get_local_addr,
+ get_path_with_query_string,
+ get_remote_addr,
+ is_ssl,
+)
+from uvicorn.server import ServerState
+
+
+def _get_status_phrase(status_code: int) -> bytes:
+ try:
+ return http.HTTPStatus(status_code).phrase.encode()
+ except ValueError:
+ return b""
+
+
+STATUS_PHRASES = {status_code: _get_status_phrase(status_code) for status_code in range(100, 600)}
+
+
+class H11Protocol(asyncio.Protocol):
+ def __init__(
+ self,
+ config: Config,
+ server_state: ServerState,
+ app_state: dict[str, Any],
+ _loop: asyncio.AbstractEventLoop | None = None,
+ ) -> None:
+ if not config.loaded:
+ config.load()
+
+ self.config = config
+ self.app = config.loaded_app
+ self.loop = _loop or asyncio.get_event_loop()
+ self.logger = logging.getLogger("uvicorn.error")
+ self.access_logger = logging.getLogger("uvicorn.access")
+ self.access_log = self.access_logger.hasHandlers()
+ self.conn = h11.Connection(
+ h11.SERVER,
+ config.h11_max_incomplete_event_size
+ if config.h11_max_incomplete_event_size is not None
+ else DEFAULT_MAX_INCOMPLETE_EVENT_SIZE,
+ )
+ self.ws_protocol_class = config.ws_protocol_class
+ self.root_path = config.root_path
+ self.limit_concurrency = config.limit_concurrency
+ self.app_state = app_state
+
+ # Timeouts
+ self.timeout_keep_alive_task: asyncio.TimerHandle | None = None
+ self.timeout_keep_alive = config.timeout_keep_alive
+
+ # Shared server state
+ self.server_state = server_state
+ self.connections = server_state.connections
+ self.tasks = server_state.tasks
+
+ # Per-connection state
+ self.transport: asyncio.Transport = None # type: ignore[assignment]
+ self.flow: FlowControl = None # type: ignore[assignment]
+ self.server: tuple[str, int] | None = None
+ self.client: tuple[str, int] | None = None
+ self.scheme: Literal["http", "https"] | None = None
+
+ # Per-request state
+ self.scope: HTTPScope = None # type: ignore[assignment]
+ self.headers: list[tuple[bytes, bytes]] = None # type: ignore[assignment]
+ self.cycle: RequestResponseCycle = None # type: ignore[assignment]
+
+ # Protocol interface
+ def connection_made( # type: ignore[override]
+ self, transport: asyncio.Transport
+ ) -> None:
+ self.connections.add(self)
+
+ self.transport = transport
+ self.flow = FlowControl(transport)
+ self.server = get_local_addr(transport)
+ self.client = get_remote_addr(transport)
+ self.scheme = "https" if is_ssl(transport) else "http"
+
+ if self.logger.level <= TRACE_LOG_LEVEL:
+ prefix = "%s:%d - " % self.client if self.client else ""
+ self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection made", prefix)
+
+ def connection_lost(self, exc: Exception | None) -> None:
+ self.connections.discard(self)
+
+ if self.logger.level <= TRACE_LOG_LEVEL:
+ prefix = "%s:%d - " % self.client if self.client else ""
+ self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection lost", prefix)
+
+ if self.cycle and not self.cycle.response_complete:
+ self.cycle.disconnected = True
+ if self.conn.our_state != h11.ERROR:
+ event = h11.ConnectionClosed()
+ try:
+ self.conn.send(event)
+ except h11.LocalProtocolError:
+ # Premature client disconnect
+ pass
+
+ if self.cycle is not None:
+ self.cycle.message_event.set()
+ if self.flow is not None:
+ self.flow.resume_writing()
+ if exc is None:
+ self.transport.close()
+ self._unset_keepalive_if_required()
+
+ def eof_received(self) -> None:
+ pass
+
+ def _unset_keepalive_if_required(self) -> None:
+ if self.timeout_keep_alive_task is not None:
+ self.timeout_keep_alive_task.cancel()
+ self.timeout_keep_alive_task = None
+
+ def _get_upgrade(self) -> bytes | None:
+ connection = []
+ upgrade = None
+ for name, value in self.headers:
+ if name == b"connection":
+ connection = [token.lower().strip() for token in value.split(b",")]
+ if name == b"upgrade":
+ upgrade = value.lower()
+ if b"upgrade" in connection:
+ return upgrade
+ return None
+
+ def _should_upgrade_to_ws(self) -> bool:
+ if self.ws_protocol_class is None:
+ if self.config.ws == "auto":
+ msg = "Unsupported upgrade request."
+ self.logger.warning(msg)
+ msg = "No supported WebSocket library detected. Please use \"pip install 'uvicorn[standard]'\", or install 'websockets' or 'wsproto' manually." # noqa: E501
+ self.logger.warning(msg)
+ return False
+ return True
+
+ def data_received(self, data: bytes) -> None:
+ self._unset_keepalive_if_required()
+
+ self.conn.receive_data(data)
+ self.handle_events()
+
+ def handle_events(self) -> None:
+ while True:
+ try:
+ event = self.conn.next_event()
+ except h11.RemoteProtocolError:
+ msg = "Invalid HTTP request received."
+ self.logger.warning(msg)
+ self.send_400_response(msg)
+ return
+
+ if event is h11.NEED_DATA:
+ break
+
+ elif event is h11.PAUSED:
+ # This case can occur in HTTP pipelining, so we need to
+ # stop reading any more data, and ensure that at the end
+ # of the active request/response cycle we handle any
+ # events that have been buffered up.
+ self.flow.pause_reading()
+ break
+
+ elif isinstance(event, h11.Request):
+ self.headers = [(key.lower(), value) for key, value in event.headers]
+ raw_path, _, query_string = event.target.partition(b"?")
+ path = unquote(raw_path.decode("ascii"))
+ full_path = self.root_path + path
+ full_raw_path = self.root_path.encode("ascii") + raw_path
+ self.scope = {
+ "type": "http",
+ "asgi": {
+ "version": self.config.asgi_version,
+ "spec_version": "2.4",
+ },
+ "http_version": event.http_version.decode("ascii"),
+ "server": self.server,
+ "client": self.client,
+ "scheme": self.scheme, # type: ignore[typeddict-item]
+ "method": event.method.decode("ascii"),
+ "root_path": self.root_path,
+ "path": full_path,
+ "raw_path": full_raw_path,
+ "query_string": query_string,
+ "headers": self.headers,
+ "state": self.app_state.copy(),
+ }
+
+ upgrade = self._get_upgrade()
+ if upgrade == b"websocket" and self._should_upgrade_to_ws():
+ self.handle_websocket_upgrade(event)
+ return
+
+ # Handle 503 responses when 'limit_concurrency' is exceeded.
+ if self.limit_concurrency is not None and (
+ len(self.connections) >= self.limit_concurrency or len(self.tasks) >= self.limit_concurrency
+ ):
+ app = service_unavailable
+ message = "Exceeded concurrency limit."
+ self.logger.warning(message)
+ else:
+ app = self.app
+
+ # When starting to process a request, disable the keep-alive
+ # timeout. Normally we disable this when receiving data from
+ # client and set back when finishing processing its request.
+ # However, for pipelined requests processing finishes after
+ # already receiving the next request and thus the timer may
+ # be set here, which we don't want.
+ self._unset_keepalive_if_required()
+
+ self.cycle = RequestResponseCycle(
+ scope=self.scope,
+ conn=self.conn,
+ transport=self.transport,
+ flow=self.flow,
+ logger=self.logger,
+ access_logger=self.access_logger,
+ access_log=self.access_log,
+ default_headers=self.server_state.default_headers,
+ message_event=asyncio.Event(),
+ on_response=self.on_response_complete,
+ )
+ task = self.loop.create_task(self.cycle.run_asgi(app))
+ task.add_done_callback(self.tasks.discard)
+ self.tasks.add(task)
+
+ elif isinstance(event, h11.Data):
+ if self.conn.our_state is h11.DONE:
+ continue
+ self.cycle.body += event.data
+ if len(self.cycle.body) > HIGH_WATER_LIMIT:
+ self.flow.pause_reading()
+ self.cycle.message_event.set()
+
+ elif isinstance(event, h11.EndOfMessage):
+ if self.conn.our_state is h11.DONE:
+ self.transport.resume_reading()
+ self.conn.start_next_cycle()
+ continue
+ self.cycle.more_body = False
+ self.cycle.message_event.set()
+
+ def handle_websocket_upgrade(self, event: h11.Request) -> None:
+ if self.logger.level <= TRACE_LOG_LEVEL:
+ prefix = "%s:%d - " % self.client if self.client else ""
+ self.logger.log(TRACE_LOG_LEVEL, "%sUpgrading to WebSocket", prefix)
+
+ self.connections.discard(self)
+ output = [event.method, b" ", event.target, b" HTTP/1.1\r\n"]
+ for name, value in self.headers:
+ output += [name, b": ", value, b"\r\n"]
+ output.append(b"\r\n")
+ protocol = self.ws_protocol_class( # type: ignore[call-arg, misc]
+ config=self.config,
+ server_state=self.server_state,
+ app_state=self.app_state,
+ )
+ protocol.connection_made(self.transport)
+ protocol.data_received(b"".join(output))
+ self.transport.set_protocol(protocol)
+
+ def send_400_response(self, msg: str) -> None:
+ reason = STATUS_PHRASES[400]
+ headers: list[tuple[bytes, bytes]] = [
+ (b"content-type", b"text/plain; charset=utf-8"),
+ (b"connection", b"close"),
+ ]
+ event = h11.Response(status_code=400, headers=headers, reason=reason)
+ output = self.conn.send(event)
+ self.transport.write(output)
+
+ output = self.conn.send(event=h11.Data(data=msg.encode("ascii")))
+ self.transport.write(output)
+
+ output = self.conn.send(event=h11.EndOfMessage())
+ self.transport.write(output)
+
+ self.transport.close()
+
+ def on_response_complete(self) -> None:
+ self.server_state.total_requests += 1
+
+ if self.transport.is_closing():
+ return
+
+ # Set a short Keep-Alive timeout.
+ self._unset_keepalive_if_required()
+
+ self.timeout_keep_alive_task = self.loop.call_later(self.timeout_keep_alive, self.timeout_keep_alive_handler)
+
+ # Unpause data reads if needed.
+ self.flow.resume_reading()
+
+ # Unblock any pipelined events.
+ if self.conn.our_state is h11.DONE and self.conn.their_state is h11.DONE:
+ self.conn.start_next_cycle()
+ self.handle_events()
+
+ def shutdown(self) -> None:
+ """
+ Called by the server to commence a graceful shutdown.
+ """
+ if self.cycle is None or self.cycle.response_complete:
+ event = h11.ConnectionClosed()
+ self.conn.send(event)
+ self.transport.close()
+ else:
+ self.cycle.keep_alive = False
+
+ def pause_writing(self) -> None:
+ """
+ Called by the transport when the write buffer exceeds the high water mark.
+ """
+ self.flow.pause_writing()
+
+ def resume_writing(self) -> None:
+ """
+ Called by the transport when the write buffer drops below the low water mark.
+ """
+ self.flow.resume_writing()
+
+ def timeout_keep_alive_handler(self) -> None:
+ """
+ Called on a keep-alive connection if no new data is received after a short
+ delay.
+ """
+ if not self.transport.is_closing():
+ event = h11.ConnectionClosed()
+ self.conn.send(event)
+ self.transport.close()
+
+
+class RequestResponseCycle:
+ def __init__(
+ self,
+ scope: HTTPScope,
+ conn: h11.Connection,
+ transport: asyncio.Transport,
+ flow: FlowControl,
+ logger: logging.Logger,
+ access_logger: logging.Logger,
+ access_log: bool,
+ default_headers: list[tuple[bytes, bytes]],
+ message_event: asyncio.Event,
+ on_response: Callable[..., None],
+ ) -> None:
+ self.scope = scope
+ self.conn = conn
+ self.transport = transport
+ self.flow = flow
+ self.logger = logger
+ self.access_logger = access_logger
+ self.access_log = access_log
+ self.default_headers = default_headers
+ self.message_event = message_event
+ self.on_response = on_response
+
+ # Connection state
+ self.disconnected = False
+ self.keep_alive = True
+ self.waiting_for_100_continue = conn.they_are_waiting_for_100_continue
+
+ # Request state
+ self.body = b""
+ self.more_body = True
+
+ # Response state
+ self.response_started = False
+ self.response_complete = False
+
+ # ASGI exception wrapper
+ async def run_asgi(self, app: ASGI3Application) -> None:
+ try:
+ result = await app( # type: ignore[func-returns-value]
+ self.scope, self.receive, self.send
+ )
+ except BaseException as exc:
+ msg = "Exception in ASGI application\n"
+ self.logger.error(msg, exc_info=exc)
+ if not self.response_started:
+ await self.send_500_response()
+ else:
+ self.transport.close()
+ else:
+ if result is not None:
+ msg = "ASGI callable should return None, but returned '%s'."
+ self.logger.error(msg, result)
+ self.transport.close()
+ elif not self.response_started and not self.disconnected:
+ msg = "ASGI callable returned without starting response."
+ self.logger.error(msg)
+ await self.send_500_response()
+ elif not self.response_complete and not self.disconnected:
+ msg = "ASGI callable returned without completing response."
+ self.logger.error(msg)
+ self.transport.close()
+ finally:
+ self.on_response = lambda: None
+
+ async def send_500_response(self) -> None:
+ response_start_event: HTTPResponseStartEvent = {
+ "type": "http.response.start",
+ "status": 500,
+ "headers": [
+ (b"content-type", b"text/plain; charset=utf-8"),
+ (b"connection", b"close"),
+ ],
+ }
+ await self.send(response_start_event)
+ response_body_event: HTTPResponseBodyEvent = {
+ "type": "http.response.body",
+ "body": b"Internal Server Error",
+ "more_body": False,
+ }
+ await self.send(response_body_event)
+
+ # ASGI interface
+ async def send(self, message: ASGISendEvent) -> None:
+ message_type = message["type"]
+
+ if self.flow.write_paused and not self.disconnected:
+ await self.flow.drain()
+
+ if self.disconnected:
+ return
+
+ if not self.response_started:
+ # Sending response status line and headers
+ if message_type != "http.response.start":
+ msg = "Expected ASGI message 'http.response.start', but got '%s'."
+ raise RuntimeError(msg % message_type)
+ message = cast("HTTPResponseStartEvent", message)
+
+ self.response_started = True
+ self.waiting_for_100_continue = False
+
+ status = message["status"]
+ headers = self.default_headers + list(message.get("headers", []))
+
+ if CLOSE_HEADER in self.scope["headers"] and CLOSE_HEADER not in headers:
+ headers = headers + [CLOSE_HEADER]
+
+ if self.access_log:
+ self.access_logger.info(
+ '%s - "%s %s HTTP/%s" %d',
+ get_client_addr(self.scope),
+ self.scope["method"],
+ get_path_with_query_string(self.scope),
+ self.scope["http_version"],
+ status,
+ )
+
+ # Write response status line and headers
+ reason = STATUS_PHRASES[status]
+ response = h11.Response(status_code=status, headers=headers, reason=reason)
+ output = self.conn.send(event=response)
+ self.transport.write(output)
+
+ elif not self.response_complete:
+ # Sending response body
+ if message_type != "http.response.body":
+ msg = "Expected ASGI message 'http.response.body', but got '%s'."
+ raise RuntimeError(msg % message_type)
+ message = cast("HTTPResponseBodyEvent", message)
+
+ body = message.get("body", b"")
+ more_body = message.get("more_body", False)
+
+ # Write response body
+ data = b"" if self.scope["method"] == "HEAD" else body
+ output = self.conn.send(event=h11.Data(data=data))
+ self.transport.write(output)
+
+ # Handle response completion
+ if not more_body:
+ self.response_complete = True
+ self.message_event.set()
+ output = self.conn.send(event=h11.EndOfMessage())
+ self.transport.write(output)
+
+ else:
+ # Response already sent
+ msg = "Unexpected ASGI message '%s' sent, after response already completed."
+ raise RuntimeError(msg % message_type)
+
+ if self.response_complete:
+ if self.conn.our_state is h11.MUST_CLOSE or not self.keep_alive:
+ self.conn.send(event=h11.ConnectionClosed())
+ self.transport.close()
+ self.on_response()
+
+ async def receive(self) -> ASGIReceiveEvent:
+ if self.waiting_for_100_continue and not self.transport.is_closing():
+ headers: list[tuple[str, str]] = []
+ event = h11.InformationalResponse(status_code=100, headers=headers, reason="Continue")
+ output = self.conn.send(event=event)
+ self.transport.write(output)
+ self.waiting_for_100_continue = False
+
+ if not self.disconnected and not self.response_complete:
+ self.flow.resume_reading()
+ await self.message_event.wait()
+ self.message_event.clear()
+
+ if self.disconnected or self.response_complete:
+ return {"type": "http.disconnect"}
+
+ message: HTTPRequestEvent = {
+ "type": "http.request",
+ "body": self.body,
+ "more_body": self.more_body,
+ }
+ self.body = b""
+ return message
diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/http/httptools_impl.py b/venv/lib/python3.11/site-packages/uvicorn/protocols/http/httptools_impl.py
new file mode 100644
index 0000000..997c6bb
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/http/httptools_impl.py
@@ -0,0 +1,575 @@
+from __future__ import annotations
+
+import asyncio
+import http
+import logging
+import re
+import urllib
+from asyncio.events import TimerHandle
+from collections import deque
+from typing import Any, Callable, Literal, cast
+
+import httptools
+
+from uvicorn._types import (
+ ASGI3Application,
+ ASGIReceiveEvent,
+ ASGISendEvent,
+ HTTPRequestEvent,
+ HTTPResponseBodyEvent,
+ HTTPResponseStartEvent,
+ HTTPScope,
+)
+from uvicorn.config import Config
+from uvicorn.logging import TRACE_LOG_LEVEL
+from uvicorn.protocols.http.flow_control import (
+ CLOSE_HEADER,
+ HIGH_WATER_LIMIT,
+ FlowControl,
+ service_unavailable,
+)
+from uvicorn.protocols.utils import (
+ get_client_addr,
+ get_local_addr,
+ get_path_with_query_string,
+ get_remote_addr,
+ is_ssl,
+)
+from uvicorn.server import ServerState
+
+HEADER_RE = re.compile(b'[\x00-\x1F\x7F()<>@,;:[]={} \t\\"]')
+HEADER_VALUE_RE = re.compile(b"[\x00-\x1F\x7F]")
+
+
+def _get_status_line(status_code: int) -> bytes:
+ try:
+ phrase = http.HTTPStatus(status_code).phrase.encode()
+ except ValueError:
+ phrase = b""
+ return b"".join([b"HTTP/1.1 ", str(status_code).encode(), b" ", phrase, b"\r\n"])
+
+
+STATUS_LINE = {status_code: _get_status_line(status_code) for status_code in range(100, 600)}
+
+
+class HttpToolsProtocol(asyncio.Protocol):
+ def __init__(
+ self,
+ config: Config,
+ server_state: ServerState,
+ app_state: dict[str, Any],
+ _loop: asyncio.AbstractEventLoop | None = None,
+ ) -> None:
+ if not config.loaded:
+ config.load()
+
+ self.config = config
+ self.app = config.loaded_app
+ self.loop = _loop or asyncio.get_event_loop()
+ self.logger = logging.getLogger("uvicorn.error")
+ self.access_logger = logging.getLogger("uvicorn.access")
+ self.access_log = self.access_logger.hasHandlers()
+ self.parser = httptools.HttpRequestParser(self)
+ self.ws_protocol_class = config.ws_protocol_class
+ self.root_path = config.root_path
+ self.limit_concurrency = config.limit_concurrency
+ self.app_state = app_state
+
+ # Timeouts
+ self.timeout_keep_alive_task: TimerHandle | None = None
+ self.timeout_keep_alive = config.timeout_keep_alive
+
+ # Global state
+ self.server_state = server_state
+ self.connections = server_state.connections
+ self.tasks = server_state.tasks
+
+ # Per-connection state
+ self.transport: asyncio.Transport = None # type: ignore[assignment]
+ self.flow: FlowControl = None # type: ignore[assignment]
+ self.server: tuple[str, int] | None = None
+ self.client: tuple[str, int] | None = None
+ self.scheme: Literal["http", "https"] | None = None
+ self.pipeline: deque[tuple[RequestResponseCycle, ASGI3Application]] = deque()
+
+ # Per-request state
+ self.scope: HTTPScope = None # type: ignore[assignment]
+ self.headers: list[tuple[bytes, bytes]] = None # type: ignore[assignment]
+ self.expect_100_continue = False
+ self.cycle: RequestResponseCycle = None # type: ignore[assignment]
+
+ # Protocol interface
+ def connection_made( # type: ignore[override]
+ self, transport: asyncio.Transport
+ ) -> None:
+ self.connections.add(self)
+
+ self.transport = transport
+ self.flow = FlowControl(transport)
+ self.server = get_local_addr(transport)
+ self.client = get_remote_addr(transport)
+ self.scheme = "https" if is_ssl(transport) else "http"
+
+ if self.logger.level <= TRACE_LOG_LEVEL:
+ prefix = "%s:%d - " % self.client if self.client else ""
+ self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection made", prefix)
+
+ def connection_lost(self, exc: Exception | None) -> None:
+ self.connections.discard(self)
+
+ if self.logger.level <= TRACE_LOG_LEVEL:
+ prefix = "%s:%d - " % self.client if self.client else ""
+ self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection lost", prefix)
+
+ if self.cycle and not self.cycle.response_complete:
+ self.cycle.disconnected = True
+ if self.cycle is not None:
+ self.cycle.message_event.set()
+ if self.flow is not None:
+ self.flow.resume_writing()
+ if exc is None:
+ self.transport.close()
+ self._unset_keepalive_if_required()
+
+ self.parser = None
+
+ def eof_received(self) -> None:
+ pass
+
+ def _unset_keepalive_if_required(self) -> None:
+ if self.timeout_keep_alive_task is not None:
+ self.timeout_keep_alive_task.cancel()
+ self.timeout_keep_alive_task = None
+
+ def _get_upgrade(self) -> bytes | None:
+ connection = []
+ upgrade = None
+ for name, value in self.headers:
+ if name == b"connection":
+ connection = [token.lower().strip() for token in value.split(b",")]
+ if name == b"upgrade":
+ upgrade = value.lower()
+ if b"upgrade" in connection:
+ return upgrade
+ return None
+
+ def _should_upgrade_to_ws(self, upgrade: bytes | None) -> bool:
+ if upgrade == b"websocket" and self.ws_protocol_class is not None:
+ return True
+ if self.config.ws == "auto":
+ msg = "Unsupported upgrade request."
+ self.logger.warning(msg)
+ msg = "No supported WebSocket library detected. Please use \"pip install 'uvicorn[standard]'\", or install 'websockets' or 'wsproto' manually." # noqa: E501
+ self.logger.warning(msg)
+ return False
+
+ def _should_upgrade(self) -> bool:
+ upgrade = self._get_upgrade()
+ return self._should_upgrade_to_ws(upgrade)
+
+ def data_received(self, data: bytes) -> None:
+ self._unset_keepalive_if_required()
+
+ try:
+ self.parser.feed_data(data)
+ except httptools.HttpParserError:
+ msg = "Invalid HTTP request received."
+ self.logger.warning(msg)
+ self.send_400_response(msg)
+ return
+ except httptools.HttpParserUpgrade:
+ upgrade = self._get_upgrade()
+ if self._should_upgrade_to_ws(upgrade):
+ self.handle_websocket_upgrade()
+
+ def handle_websocket_upgrade(self) -> None:
+ if self.logger.level <= TRACE_LOG_LEVEL:
+ prefix = "%s:%d - " % self.client if self.client else ""
+ self.logger.log(TRACE_LOG_LEVEL, "%sUpgrading to WebSocket", prefix)
+
+ self.connections.discard(self)
+ method = self.scope["method"].encode()
+ output = [method, b" ", self.url, b" HTTP/1.1\r\n"]
+ for name, value in self.scope["headers"]:
+ output += [name, b": ", value, b"\r\n"]
+ output.append(b"\r\n")
+ protocol = self.ws_protocol_class( # type: ignore[call-arg, misc]
+ config=self.config,
+ server_state=self.server_state,
+ app_state=self.app_state,
+ )
+ protocol.connection_made(self.transport)
+ protocol.data_received(b"".join(output))
+ self.transport.set_protocol(protocol)
+
+ def send_400_response(self, msg: str) -> None:
+ content = [STATUS_LINE[400]]
+ for name, value in self.server_state.default_headers:
+ content.extend([name, b": ", value, b"\r\n"])
+ content.extend(
+ [
+ b"content-type: text/plain; charset=utf-8\r\n",
+ b"content-length: " + str(len(msg)).encode("ascii") + b"\r\n",
+ b"connection: close\r\n",
+ b"\r\n",
+ msg.encode("ascii"),
+ ]
+ )
+ self.transport.write(b"".join(content))
+ self.transport.close()
+
+ def on_message_begin(self) -> None:
+ self.url = b""
+ self.expect_100_continue = False
+ self.headers = []
+ self.scope = { # type: ignore[typeddict-item]
+ "type": "http",
+ "asgi": {"version": self.config.asgi_version, "spec_version": "2.4"},
+ "http_version": "1.1",
+ "server": self.server,
+ "client": self.client,
+ "scheme": self.scheme, # type: ignore[typeddict-item]
+ "root_path": self.root_path,
+ "headers": self.headers,
+ "state": self.app_state.copy(),
+ }
+
+ # Parser callbacks
+ def on_url(self, url: bytes) -> None:
+ self.url += url
+
+ def on_header(self, name: bytes, value: bytes) -> None:
+ name = name.lower()
+ if name == b"expect" and value.lower() == b"100-continue":
+ self.expect_100_continue = True
+ self.headers.append((name, value))
+
+ def on_headers_complete(self) -> None:
+ http_version = self.parser.get_http_version()
+ method = self.parser.get_method()
+ self.scope["method"] = method.decode("ascii")
+ if http_version != "1.1":
+ self.scope["http_version"] = http_version
+ if self.parser.should_upgrade() and self._should_upgrade():
+ return
+ parsed_url = httptools.parse_url(self.url)
+ raw_path = parsed_url.path
+ path = raw_path.decode("ascii")
+ if "%" in path:
+ path = urllib.parse.unquote(path)
+ full_path = self.root_path + path
+ full_raw_path = self.root_path.encode("ascii") + raw_path
+ self.scope["path"] = full_path
+ self.scope["raw_path"] = full_raw_path
+ self.scope["query_string"] = parsed_url.query or b""
+
+ # Handle 503 responses when 'limit_concurrency' is exceeded.
+ if self.limit_concurrency is not None and (
+ len(self.connections) >= self.limit_concurrency or len(self.tasks) >= self.limit_concurrency
+ ):
+ app = service_unavailable
+ message = "Exceeded concurrency limit."
+ self.logger.warning(message)
+ else:
+ app = self.app
+
+ existing_cycle = self.cycle
+ self.cycle = RequestResponseCycle(
+ scope=self.scope,
+ transport=self.transport,
+ flow=self.flow,
+ logger=self.logger,
+ access_logger=self.access_logger,
+ access_log=self.access_log,
+ default_headers=self.server_state.default_headers,
+ message_event=asyncio.Event(),
+ expect_100_continue=self.expect_100_continue,
+ keep_alive=http_version != "1.0",
+ on_response=self.on_response_complete,
+ )
+ if existing_cycle is None or existing_cycle.response_complete:
+ # Standard case - start processing the request.
+ task = self.loop.create_task(self.cycle.run_asgi(app))
+ task.add_done_callback(self.tasks.discard)
+ self.tasks.add(task)
+ else:
+ # Pipelined HTTP requests need to be queued up.
+ self.flow.pause_reading()
+ self.pipeline.appendleft((self.cycle, app))
+
+ def on_body(self, body: bytes) -> None:
+ if (self.parser.should_upgrade() and self._should_upgrade()) or self.cycle.response_complete:
+ return
+ self.cycle.body += body
+ if len(self.cycle.body) > HIGH_WATER_LIMIT:
+ self.flow.pause_reading()
+ self.cycle.message_event.set()
+
+ def on_message_complete(self) -> None:
+ if (self.parser.should_upgrade() and self._should_upgrade()) or self.cycle.response_complete:
+ return
+ self.cycle.more_body = False
+ self.cycle.message_event.set()
+
+ def on_response_complete(self) -> None:
+ # Callback for pipelined HTTP requests to be started.
+ self.server_state.total_requests += 1
+
+ if self.transport.is_closing():
+ return
+
+ self._unset_keepalive_if_required()
+
+ # Unpause data reads if needed.
+ self.flow.resume_reading()
+
+ # Unblock any pipelined events. If there are none, arm the
+ # Keep-Alive timeout instead.
+ if self.pipeline:
+ cycle, app = self.pipeline.pop()
+ task = self.loop.create_task(cycle.run_asgi(app))
+ task.add_done_callback(self.tasks.discard)
+ self.tasks.add(task)
+ else:
+ self.timeout_keep_alive_task = self.loop.call_later(
+ self.timeout_keep_alive, self.timeout_keep_alive_handler
+ )
+
+ def shutdown(self) -> None:
+ """
+ Called by the server to commence a graceful shutdown.
+ """
+ if self.cycle is None or self.cycle.response_complete:
+ self.transport.close()
+ else:
+ self.cycle.keep_alive = False
+
+ def pause_writing(self) -> None:
+ """
+ Called by the transport when the write buffer exceeds the high water mark.
+ """
+ self.flow.pause_writing()
+
+ def resume_writing(self) -> None:
+ """
+ Called by the transport when the write buffer drops below the low water mark.
+ """
+ self.flow.resume_writing()
+
+ def timeout_keep_alive_handler(self) -> None:
+ """
+ Called on a keep-alive connection if no new data is received after a short
+ delay.
+ """
+ if not self.transport.is_closing():
+ self.transport.close()
+
+
+class RequestResponseCycle:
+ def __init__(
+ self,
+ scope: HTTPScope,
+ transport: asyncio.Transport,
+ flow: FlowControl,
+ logger: logging.Logger,
+ access_logger: logging.Logger,
+ access_log: bool,
+ default_headers: list[tuple[bytes, bytes]],
+ message_event: asyncio.Event,
+ expect_100_continue: bool,
+ keep_alive: bool,
+ on_response: Callable[..., None],
+ ):
+ self.scope = scope
+ self.transport = transport
+ self.flow = flow
+ self.logger = logger
+ self.access_logger = access_logger
+ self.access_log = access_log
+ self.default_headers = default_headers
+ self.message_event = message_event
+ self.on_response = on_response
+
+ # Connection state
+ self.disconnected = False
+ self.keep_alive = keep_alive
+ self.waiting_for_100_continue = expect_100_continue
+
+ # Request state
+ self.body = b""
+ self.more_body = True
+
+ # Response state
+ self.response_started = False
+ self.response_complete = False
+ self.chunked_encoding: bool | None = None
+ self.expected_content_length = 0
+
+ # ASGI exception wrapper
+ async def run_asgi(self, app: ASGI3Application) -> None:
+ try:
+ result = await app( # type: ignore[func-returns-value]
+ self.scope, self.receive, self.send
+ )
+ except BaseException as exc:
+ msg = "Exception in ASGI application\n"
+ self.logger.error(msg, exc_info=exc)
+ if not self.response_started:
+ await self.send_500_response()
+ else:
+ self.transport.close()
+ else:
+ if result is not None:
+ msg = "ASGI callable should return None, but returned '%s'."
+ self.logger.error(msg, result)
+ self.transport.close()
+ elif not self.response_started and not self.disconnected:
+ msg = "ASGI callable returned without starting response."
+ self.logger.error(msg)
+ await self.send_500_response()
+ elif not self.response_complete and not self.disconnected:
+ msg = "ASGI callable returned without completing response."
+ self.logger.error(msg)
+ self.transport.close()
+ finally:
+ self.on_response = lambda: None
+
+ async def send_500_response(self) -> None:
+ response_start_event: HTTPResponseStartEvent = {
+ "type": "http.response.start",
+ "status": 500,
+ "headers": [
+ (b"content-type", b"text/plain; charset=utf-8"),
+ (b"connection", b"close"),
+ ],
+ }
+ await self.send(response_start_event)
+ response_body_event: HTTPResponseBodyEvent = {
+ "type": "http.response.body",
+ "body": b"Internal Server Error",
+ "more_body": False,
+ }
+ await self.send(response_body_event)
+
+ # ASGI interface
+ async def send(self, message: ASGISendEvent) -> None:
+ message_type = message["type"]
+
+ if self.flow.write_paused and not self.disconnected:
+ await self.flow.drain()
+
+ if self.disconnected:
+ return
+
+ if not self.response_started:
+ # Sending response status line and headers
+ if message_type != "http.response.start":
+ msg = "Expected ASGI message 'http.response.start', but got '%s'."
+ raise RuntimeError(msg % message_type)
+ message = cast("HTTPResponseStartEvent", message)
+
+ self.response_started = True
+ self.waiting_for_100_continue = False
+
+ status_code = message["status"]
+ headers = self.default_headers + list(message.get("headers", []))
+
+ if CLOSE_HEADER in self.scope["headers"] and CLOSE_HEADER not in headers:
+ headers = headers + [CLOSE_HEADER]
+
+ if self.access_log:
+ self.access_logger.info(
+ '%s - "%s %s HTTP/%s" %d',
+ get_client_addr(self.scope),
+ self.scope["method"],
+ get_path_with_query_string(self.scope),
+ self.scope["http_version"],
+ status_code,
+ )
+
+ # Write response status line and headers
+ content = [STATUS_LINE[status_code]]
+
+ for name, value in headers:
+ if HEADER_RE.search(name):
+ raise RuntimeError("Invalid HTTP header name.")
+ if HEADER_VALUE_RE.search(value):
+ raise RuntimeError("Invalid HTTP header value.")
+
+ name = name.lower()
+ if name == b"content-length" and self.chunked_encoding is None:
+ self.expected_content_length = int(value.decode())
+ self.chunked_encoding = False
+ elif name == b"transfer-encoding" and value.lower() == b"chunked":
+ self.expected_content_length = 0
+ self.chunked_encoding = True
+ elif name == b"connection" and value.lower() == b"close":
+ self.keep_alive = False
+ content.extend([name, b": ", value, b"\r\n"])
+
+ if self.chunked_encoding is None and self.scope["method"] != "HEAD" and status_code not in (204, 304):
+ # Neither content-length nor transfer-encoding specified
+ self.chunked_encoding = True
+ content.append(b"transfer-encoding: chunked\r\n")
+
+ content.append(b"\r\n")
+ self.transport.write(b"".join(content))
+
+ elif not self.response_complete:
+ # Sending response body
+ if message_type != "http.response.body":
+ msg = "Expected ASGI message 'http.response.body', but got '%s'."
+ raise RuntimeError(msg % message_type)
+
+ body = cast(bytes, message.get("body", b""))
+ more_body = message.get("more_body", False)
+
+ # Write response body
+ if self.scope["method"] == "HEAD":
+ self.expected_content_length = 0
+ elif self.chunked_encoding:
+ if body:
+ content = [b"%x\r\n" % len(body), body, b"\r\n"]
+ else:
+ content = []
+ if not more_body:
+ content.append(b"0\r\n\r\n")
+ self.transport.write(b"".join(content))
+ else:
+ num_bytes = len(body)
+ if num_bytes > self.expected_content_length:
+ raise RuntimeError("Response content longer than Content-Length")
+ else:
+ self.expected_content_length -= num_bytes
+ self.transport.write(body)
+
+ # Handle response completion
+ if not more_body:
+ if self.expected_content_length != 0:
+ raise RuntimeError("Response content shorter than Content-Length")
+ self.response_complete = True
+ self.message_event.set()
+ if not self.keep_alive:
+ self.transport.close()
+ self.on_response()
+
+ else:
+ # Response already sent
+ msg = "Unexpected ASGI message '%s' sent, after response already completed."
+ raise RuntimeError(msg % message_type)
+
+ async def receive(self) -> ASGIReceiveEvent:
+ if self.waiting_for_100_continue and not self.transport.is_closing():
+ self.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n")
+ self.waiting_for_100_continue = False
+
+ if not self.disconnected and not self.response_complete:
+ self.flow.resume_reading()
+ await self.message_event.wait()
+ self.message_event.clear()
+
+ if self.disconnected or self.response_complete:
+ return {"type": "http.disconnect"}
+ message: HTTPRequestEvent = {"type": "http.request", "body": self.body, "more_body": self.more_body}
+ self.body = b""
+ return message
diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/utils.py b/venv/lib/python3.11/site-packages/uvicorn/protocols/utils.py
new file mode 100644
index 0000000..4e65806
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/utils.py
@@ -0,0 +1,57 @@
+from __future__ import annotations
+
+import asyncio
+import urllib.parse
+
+from uvicorn._types import WWWScope
+
+
+class ClientDisconnected(IOError):
+ ...
+
+
+def get_remote_addr(transport: asyncio.Transport) -> tuple[str, int] | None:
+ socket_info = transport.get_extra_info("socket")
+ if socket_info is not None:
+ try:
+ info = socket_info.getpeername()
+ return (str(info[0]), int(info[1])) if isinstance(info, tuple) else None
+ except OSError: # pragma: no cover
+ # This case appears to inconsistently occur with uvloop
+ # bound to a unix domain socket.
+ return None
+
+ info = transport.get_extra_info("peername")
+ if info is not None and isinstance(info, (list, tuple)) and len(info) == 2:
+ return (str(info[0]), int(info[1]))
+ return None
+
+
+def get_local_addr(transport: asyncio.Transport) -> tuple[str, int] | None:
+ socket_info = transport.get_extra_info("socket")
+ if socket_info is not None:
+ info = socket_info.getsockname()
+
+ return (str(info[0]), int(info[1])) if isinstance(info, tuple) else None
+ info = transport.get_extra_info("sockname")
+ if info is not None and isinstance(info, (list, tuple)) and len(info) == 2:
+ return (str(info[0]), int(info[1]))
+ return None
+
+
+def is_ssl(transport: asyncio.Transport) -> bool:
+ return bool(transport.get_extra_info("sslcontext"))
+
+
+def get_client_addr(scope: WWWScope) -> str:
+ client = scope.get("client")
+ if not client:
+ return ""
+ return "%s:%d" % client
+
+
+def get_path_with_query_string(scope: WWWScope) -> str:
+ path_with_query_string = urllib.parse.quote(scope["path"])
+ if scope["query_string"]:
+ path_with_query_string = "{}?{}".format(path_with_query_string, scope["query_string"].decode("ascii"))
+ return path_with_query_string
diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/__init__.py b/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/__init__.py
diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000..d216ab9
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/__pycache__/__init__.cpython-311.pyc
Binary files differ
diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/__pycache__/auto.cpython-311.pyc b/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/__pycache__/auto.cpython-311.pyc
new file mode 100644
index 0000000..e8c06fa
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/__pycache__/auto.cpython-311.pyc
Binary files differ
diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/__pycache__/websockets_impl.cpython-311.pyc b/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/__pycache__/websockets_impl.cpython-311.pyc
new file mode 100644
index 0000000..334a441
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/__pycache__/websockets_impl.cpython-311.pyc
Binary files differ
diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/__pycache__/wsproto_impl.cpython-311.pyc b/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/__pycache__/wsproto_impl.cpython-311.pyc
new file mode 100644
index 0000000..3b1911e
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/__pycache__/wsproto_impl.cpython-311.pyc
Binary files differ
diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/auto.py b/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/auto.py
new file mode 100644
index 0000000..08fd136
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/auto.py
@@ -0,0 +1,21 @@
+from __future__ import annotations
+
+import asyncio
+import typing
+
+AutoWebSocketsProtocol: typing.Callable[..., asyncio.Protocol] | None
+try:
+ import websockets # noqa
+except ImportError: # pragma: no cover
+ try:
+ import wsproto # noqa
+ except ImportError:
+ AutoWebSocketsProtocol = None
+ else:
+ from uvicorn.protocols.websockets.wsproto_impl import WSProtocol
+
+ AutoWebSocketsProtocol = WSProtocol
+else:
+ from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol
+
+ AutoWebSocketsProtocol = WebSocketProtocol
diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/websockets_impl.py b/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/websockets_impl.py
new file mode 100644
index 0000000..6d098d5
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/websockets_impl.py
@@ -0,0 +1,388 @@
+from __future__ import annotations
+
+import asyncio
+import http
+import logging
+from typing import Any, Literal, Optional, Sequence, cast
+from urllib.parse import unquote
+
+import websockets
+from websockets.datastructures import Headers
+from websockets.exceptions import ConnectionClosed
+from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory
+from websockets.legacy.server import HTTPResponse
+from websockets.server import WebSocketServerProtocol
+from websockets.typing import Subprotocol
+
+from uvicorn._types import (
+ ASGISendEvent,
+ WebSocketAcceptEvent,
+ WebSocketCloseEvent,
+ WebSocketConnectEvent,
+ WebSocketDisconnectEvent,
+ WebSocketReceiveEvent,
+ WebSocketResponseBodyEvent,
+ WebSocketResponseStartEvent,
+ WebSocketScope,
+ WebSocketSendEvent,
+)
+from uvicorn.config import Config
+from uvicorn.logging import TRACE_LOG_LEVEL
+from uvicorn.protocols.utils import (
+ ClientDisconnected,
+ get_local_addr,
+ get_path_with_query_string,
+ get_remote_addr,
+ is_ssl,
+)
+from uvicorn.server import ServerState
+
+
+class Server:
+ closing = False
+
+ def register(self, ws: WebSocketServerProtocol) -> None:
+ pass
+
+ def unregister(self, ws: WebSocketServerProtocol) -> None:
+ pass
+
+ def is_serving(self) -> bool:
+ return not self.closing
+
+
+class WebSocketProtocol(WebSocketServerProtocol):
+ extra_headers: list[tuple[str, str]]
+
+ def __init__(
+ self,
+ config: Config,
+ server_state: ServerState,
+ app_state: dict[str, Any],
+ _loop: asyncio.AbstractEventLoop | None = None,
+ ):
+ if not config.loaded:
+ config.load()
+
+ self.config = config
+ self.app = config.loaded_app
+ self.loop = _loop or asyncio.get_event_loop()
+ self.root_path = config.root_path
+ self.app_state = app_state
+
+ # Shared server state
+ self.connections = server_state.connections
+ self.tasks = server_state.tasks
+
+ # Connection state
+ self.transport: asyncio.Transport = None # type: ignore[assignment]
+ self.server: tuple[str, int] | None = None
+ self.client: tuple[str, int] | None = None
+ self.scheme: Literal["wss", "ws"] = None # type: ignore[assignment]
+
+ # Connection events
+ self.scope: WebSocketScope
+ self.handshake_started_event = asyncio.Event()
+ self.handshake_completed_event = asyncio.Event()
+ self.closed_event = asyncio.Event()
+ self.initial_response: HTTPResponse | None = None
+ self.connect_sent = False
+ self.lost_connection_before_handshake = False
+ self.accepted_subprotocol: Subprotocol | None = None
+
+ self.ws_server: Server = Server() # type: ignore[assignment]
+
+ extensions = []
+ if self.config.ws_per_message_deflate:
+ extensions.append(ServerPerMessageDeflateFactory())
+
+ super().__init__(
+ ws_handler=self.ws_handler,
+ ws_server=self.ws_server, # type: ignore[arg-type]
+ max_size=self.config.ws_max_size,
+ max_queue=self.config.ws_max_queue,
+ ping_interval=self.config.ws_ping_interval,
+ ping_timeout=self.config.ws_ping_timeout,
+ extensions=extensions,
+ logger=logging.getLogger("uvicorn.error"),
+ )
+ self.server_header = None
+ self.extra_headers = [
+ (name.decode("latin-1"), value.decode("latin-1")) for name, value in server_state.default_headers
+ ]
+
+ def connection_made( # type: ignore[override]
+ self, transport: asyncio.Transport
+ ) -> None:
+ self.connections.add(self)
+ self.transport = transport
+ self.server = get_local_addr(transport)
+ self.client = get_remote_addr(transport)
+ self.scheme = "wss" if is_ssl(transport) else "ws"
+
+ if self.logger.isEnabledFor(TRACE_LOG_LEVEL):
+ prefix = "%s:%d - " % self.client if self.client else ""
+ self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix)
+
+ super().connection_made(transport)
+
+ def connection_lost(self, exc: Exception | None) -> None:
+ self.connections.remove(self)
+
+ if self.logger.isEnabledFor(TRACE_LOG_LEVEL):
+ prefix = "%s:%d - " % self.client if self.client else ""
+ self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix)
+
+ self.lost_connection_before_handshake = not self.handshake_completed_event.is_set()
+ self.handshake_completed_event.set()
+ super().connection_lost(exc)
+ if exc is None:
+ self.transport.close()
+
+ def shutdown(self) -> None:
+ self.ws_server.closing = True
+ if self.handshake_completed_event.is_set():
+ self.fail_connection(1012)
+ else:
+ self.send_500_response()
+ self.transport.close()
+
+ def on_task_complete(self, task: asyncio.Task) -> None:
+ self.tasks.discard(task)
+
+ async def process_request(self, path: str, headers: Headers) -> HTTPResponse | None:
+ """
+ This hook is called to determine if the websocket should return
+ an HTTP response and close.
+
+ Our behavior here is to start the ASGI application, and then wait
+ for either `accept` or `close` in order to determine if we should
+ close the connection.
+ """
+ path_portion, _, query_string = path.partition("?")
+
+ websockets.legacy.handshake.check_request(headers)
+
+ subprotocols = []
+ for header in headers.get_all("Sec-WebSocket-Protocol"):
+ subprotocols.extend([token.strip() for token in header.split(",")])
+
+ asgi_headers = [
+ (name.encode("ascii"), value.encode("ascii", errors="surrogateescape"))
+ for name, value in headers.raw_items()
+ ]
+ path = unquote(path_portion)
+ full_path = self.root_path + path
+ full_raw_path = self.root_path.encode("ascii") + path_portion.encode("ascii")
+
+ self.scope = {
+ "type": "websocket",
+ "asgi": {"version": self.config.asgi_version, "spec_version": "2.4"},
+ "http_version": "1.1",
+ "scheme": self.scheme,
+ "server": self.server,
+ "client": self.client,
+ "root_path": self.root_path,
+ "path": full_path,
+ "raw_path": full_raw_path,
+ "query_string": query_string.encode("ascii"),
+ "headers": asgi_headers,
+ "subprotocols": subprotocols,
+ "state": self.app_state.copy(),
+ "extensions": {"websocket.http.response": {}},
+ }
+ task = self.loop.create_task(self.run_asgi())
+ task.add_done_callback(self.on_task_complete)
+ self.tasks.add(task)
+ await self.handshake_started_event.wait()
+ return self.initial_response
+
+ def process_subprotocol(
+ self, headers: Headers, available_subprotocols: Sequence[Subprotocol] | None
+ ) -> Subprotocol | None:
+ """
+ We override the standard 'process_subprotocol' behavior here so that
+ we return whatever subprotocol is sent in the 'accept' message.
+ """
+ return self.accepted_subprotocol
+
+ def send_500_response(self) -> None:
+ msg = b"Internal Server Error"
+ content = [
+ b"HTTP/1.1 500 Internal Server Error\r\n" b"content-type: text/plain; charset=utf-8\r\n",
+ b"content-length: " + str(len(msg)).encode("ascii") + b"\r\n",
+ b"connection: close\r\n",
+ b"\r\n",
+ msg,
+ ]
+ self.transport.write(b"".join(content))
+ # Allow handler task to terminate cleanly, as websockets doesn't cancel it by
+ # itself (see https://github.com/encode/uvicorn/issues/920)
+ self.handshake_started_event.set()
+
+ async def ws_handler( # type: ignore[override]
+ self, protocol: WebSocketServerProtocol, path: str
+ ) -> Any:
+ """
+ This is the main handler function for the 'websockets' implementation
+ to call into. We just wait for close then return, and instead allow
+ 'send' and 'receive' events to drive the flow.
+ """
+ self.handshake_completed_event.set()
+ await self.wait_closed()
+
+ async def run_asgi(self) -> None:
+ """
+ Wrapper around the ASGI callable, handling exceptions and unexpected
+ termination states.
+ """
+ try:
+ result = await self.app(self.scope, self.asgi_receive, self.asgi_send)
+ except ClientDisconnected:
+ self.closed_event.set()
+ self.transport.close()
+ except BaseException as exc:
+ self.closed_event.set()
+ msg = "Exception in ASGI application\n"
+ self.logger.error(msg, exc_info=exc)
+ if not self.handshake_started_event.is_set():
+ self.send_500_response()
+ else:
+ await self.handshake_completed_event.wait()
+ self.transport.close()
+ else:
+ self.closed_event.set()
+ if not self.handshake_started_event.is_set():
+ msg = "ASGI callable returned without sending handshake."
+ self.logger.error(msg)
+ self.send_500_response()
+ self.transport.close()
+ elif result is not None:
+ msg = "ASGI callable should return None, but returned '%s'."
+ self.logger.error(msg, result)
+ await self.handshake_completed_event.wait()
+ self.transport.close()
+
+ async def asgi_send(self, message: ASGISendEvent) -> None:
+ message_type = message["type"]
+
+ if not self.handshake_started_event.is_set():
+ if message_type == "websocket.accept":
+ message = cast("WebSocketAcceptEvent", message)
+ self.logger.info(
+ '%s - "WebSocket %s" [accepted]',
+ self.scope["client"],
+ get_path_with_query_string(self.scope),
+ )
+ self.initial_response = None
+ self.accepted_subprotocol = cast(Optional[Subprotocol], message.get("subprotocol"))
+ if "headers" in message:
+ self.extra_headers.extend(
+ # ASGI spec requires bytes
+ # But for compatibility we need to convert it to strings
+ (name.decode("latin-1"), value.decode("latin-1"))
+ for name, value in message["headers"]
+ )
+ self.handshake_started_event.set()
+
+ elif message_type == "websocket.close":
+ message = cast("WebSocketCloseEvent", message)
+ self.logger.info(
+ '%s - "WebSocket %s" 403',
+ self.scope["client"],
+ get_path_with_query_string(self.scope),
+ )
+ self.initial_response = (http.HTTPStatus.FORBIDDEN, [], b"")
+ self.handshake_started_event.set()
+ self.closed_event.set()
+
+ elif message_type == "websocket.http.response.start":
+ message = cast("WebSocketResponseStartEvent", message)
+ self.logger.info(
+ '%s - "WebSocket %s" %d',
+ self.scope["client"],
+ get_path_with_query_string(self.scope),
+ message["status"],
+ )
+ # websockets requires the status to be an enum. look it up.
+ status = http.HTTPStatus(message["status"])
+ headers = [
+ (name.decode("latin-1"), value.decode("latin-1")) for name, value in message.get("headers", [])
+ ]
+ self.initial_response = (status, headers, b"")
+ self.handshake_started_event.set()
+
+ else:
+ msg = (
+ "Expected ASGI message 'websocket.accept', 'websocket.close', "
+ "or 'websocket.http.response.start' but got '%s'."
+ )
+ raise RuntimeError(msg % message_type)
+
+ elif not self.closed_event.is_set() and self.initial_response is None:
+ await self.handshake_completed_event.wait()
+
+ try:
+ if message_type == "websocket.send":
+ message = cast("WebSocketSendEvent", message)
+ bytes_data = message.get("bytes")
+ text_data = message.get("text")
+ data = text_data if bytes_data is None else bytes_data
+ await self.send(data) # type: ignore[arg-type]
+
+ elif message_type == "websocket.close":
+ message = cast("WebSocketCloseEvent", message)
+ code = message.get("code", 1000)
+ reason = message.get("reason", "") or ""
+ await self.close(code, reason)
+ self.closed_event.set()
+
+ else:
+ msg = "Expected ASGI message 'websocket.send' or 'websocket.close'," " but got '%s'."
+ raise RuntimeError(msg % message_type)
+ except ConnectionClosed as exc:
+ raise ClientDisconnected from exc
+
+ elif self.initial_response is not None:
+ if message_type == "websocket.http.response.body":
+ message = cast("WebSocketResponseBodyEvent", message)
+ body = self.initial_response[2] + message["body"]
+ self.initial_response = self.initial_response[:2] + (body,)
+ if not message.get("more_body", False):
+ self.closed_event.set()
+ else:
+ msg = "Expected ASGI message 'websocket.http.response.body' " "but got '%s'."
+ raise RuntimeError(msg % message_type)
+
+ else:
+ msg = "Unexpected ASGI message '%s', after sending 'websocket.close' " "or response already completed."
+ raise RuntimeError(msg % message_type)
+
+ async def asgi_receive(
+ self,
+ ) -> WebSocketDisconnectEvent | WebSocketConnectEvent | WebSocketReceiveEvent:
+ if not self.connect_sent:
+ self.connect_sent = True
+ return {"type": "websocket.connect"}
+
+ await self.handshake_completed_event.wait()
+
+ if self.lost_connection_before_handshake:
+ # If the handshake failed or the app closed before handshake completion,
+ # use 1006 Abnormal Closure.
+ return {"type": "websocket.disconnect", "code": 1006}
+
+ if self.closed_event.is_set():
+ return {"type": "websocket.disconnect", "code": 1005}
+
+ try:
+ data = await self.recv()
+ except ConnectionClosed as exc:
+ self.closed_event.set()
+ if self.ws_server.closing:
+ return {"type": "websocket.disconnect", "code": 1012}
+ return {"type": "websocket.disconnect", "code": exc.code}
+
+ if isinstance(data, str):
+ return {"type": "websocket.receive", "text": data}
+ return {"type": "websocket.receive", "bytes": data}
diff --git a/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/wsproto_impl.py b/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/wsproto_impl.py
new file mode 100644
index 0000000..c926252
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/uvicorn/protocols/websockets/wsproto_impl.py
@@ -0,0 +1,377 @@
+from __future__ import annotations
+
+import asyncio
+import logging
+import typing
+from typing import Literal
+from urllib.parse import unquote
+
+import wsproto
+from wsproto import ConnectionType, events
+from wsproto.connection import ConnectionState
+from wsproto.extensions import Extension, PerMessageDeflate
+from wsproto.utilities import LocalProtocolError, RemoteProtocolError
+
+from uvicorn._types import (
+ ASGISendEvent,
+ WebSocketAcceptEvent,
+ WebSocketCloseEvent,
+ WebSocketEvent,
+ WebSocketResponseBodyEvent,
+ WebSocketResponseStartEvent,
+ WebSocketScope,
+ WebSocketSendEvent,
+)
+from uvicorn.config import Config
+from uvicorn.logging import TRACE_LOG_LEVEL
+from uvicorn.protocols.utils import (
+ ClientDisconnected,
+ get_local_addr,
+ get_path_with_query_string,
+ get_remote_addr,
+ is_ssl,
+)
+from uvicorn.server import ServerState
+
+
+class WSProtocol(asyncio.Protocol):
+ def __init__(
+ self,
+ config: Config,
+ server_state: ServerState,
+ app_state: dict[str, typing.Any],
+ _loop: asyncio.AbstractEventLoop | None = None,
+ ) -> None:
+ if not config.loaded:
+ config.load()
+
+ self.config = config
+ self.app = config.loaded_app
+ self.loop = _loop or asyncio.get_event_loop()
+ self.logger = logging.getLogger("uvicorn.error")
+ self.root_path = config.root_path
+ self.app_state = app_state
+
+ # Shared server state
+ self.connections = server_state.connections
+ self.tasks = server_state.tasks
+ self.default_headers = server_state.default_headers
+
+ # Connection state
+ self.transport: asyncio.Transport = None # type: ignore[assignment]
+ self.server: tuple[str, int] | None = None
+ self.client: tuple[str, int] | None = None
+ self.scheme: Literal["wss", "ws"] = None # type: ignore[assignment]
+
+ # WebSocket state
+ self.queue: asyncio.Queue[WebSocketEvent] = asyncio.Queue()
+ self.handshake_complete = False
+ self.close_sent = False
+
+ # Rejection state
+ self.response_started = False
+
+ self.conn = wsproto.WSConnection(connection_type=ConnectionType.SERVER)
+
+ self.read_paused = False
+ self.writable = asyncio.Event()
+ self.writable.set()
+
+ # Buffers
+ self.bytes = b""
+ self.text = ""
+
+ # Protocol interface
+
+ def connection_made( # type: ignore[override]
+ self, transport: asyncio.Transport
+ ) -> None:
+ self.connections.add(self)
+ self.transport = transport
+ self.server = get_local_addr(transport)
+ self.client = get_remote_addr(transport)
+ self.scheme = "wss" if is_ssl(transport) else "ws"
+
+ if self.logger.level <= TRACE_LOG_LEVEL:
+ prefix = "%s:%d - " % self.client if self.client else ""
+ self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix)
+
+ def connection_lost(self, exc: Exception | None) -> None:
+ code = 1005 if self.handshake_complete else 1006
+ self.queue.put_nowait({"type": "websocket.disconnect", "code": code})
+ self.connections.remove(self)
+
+ if self.logger.level <= TRACE_LOG_LEVEL:
+ prefix = "%s:%d - " % self.client if self.client else ""
+ self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix)
+
+ self.handshake_complete = True
+ if exc is None:
+ self.transport.close()
+
+ def eof_received(self) -> None:
+ pass
+
+ def data_received(self, data: bytes) -> None:
+ try:
+ self.conn.receive_data(data)
+ except RemoteProtocolError as err:
+ # TODO: Remove `type: ignore` when wsproto fixes the type annotation.
+ self.transport.write(self.conn.send(err.event_hint)) # type: ignore[arg-type] # noqa: E501
+ self.transport.close()
+ else:
+ self.handle_events()
+
+ def handle_events(self) -> None:
+ for event in self.conn.events():
+ if isinstance(event, events.Request):
+ self.handle_connect(event)
+ elif isinstance(event, events.TextMessage):
+ self.handle_text(event)
+ elif isinstance(event, events.BytesMessage):
+ self.handle_bytes(event)
+ elif isinstance(event, events.CloseConnection):
+ self.handle_close(event)
+ elif isinstance(event, events.Ping):
+ self.handle_ping(event)
+
+ def pause_writing(self) -> None:
+ """
+ Called by the transport when the write buffer exceeds the high water mark.
+ """
+ self.writable.clear()
+
+ def resume_writing(self) -> None:
+ """
+ Called by the transport when the write buffer drops below the low water mark.
+ """
+ self.writable.set()
+
+ def shutdown(self) -> None:
+ if self.handshake_complete:
+ self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012})
+ output = self.conn.send(wsproto.events.CloseConnection(code=1012))
+ self.transport.write(output)
+ else:
+ self.send_500_response()
+ self.transport.close()
+
+ def on_task_complete(self, task: asyncio.Task) -> None:
+ self.tasks.discard(task)
+
+ # Event handlers
+
+ def handle_connect(self, event: events.Request) -> None:
+ headers = [(b"host", event.host.encode())]
+ headers += [(key.lower(), value) for key, value in event.extra_headers]
+ raw_path, _, query_string = event.target.partition("?")
+ path = unquote(raw_path)
+ full_path = self.root_path + path
+ full_raw_path = self.root_path.encode("ascii") + raw_path.encode("ascii")
+ self.scope: WebSocketScope = {
+ "type": "websocket",
+ "asgi": {"version": self.config.asgi_version, "spec_version": "2.4"},
+ "http_version": "1.1",
+ "scheme": self.scheme,
+ "server": self.server,
+ "client": self.client,
+ "root_path": self.root_path,
+ "path": full_path,
+ "raw_path": full_raw_path,
+ "query_string": query_string.encode("ascii"),
+ "headers": headers,
+ "subprotocols": event.subprotocols,
+ "state": self.app_state.copy(),
+ "extensions": {"websocket.http.response": {}},
+ }
+ self.queue.put_nowait({"type": "websocket.connect"})
+ task = self.loop.create_task(self.run_asgi())
+ task.add_done_callback(self.on_task_complete)
+ self.tasks.add(task)
+
+ def handle_text(self, event: events.TextMessage) -> None:
+ self.text += event.data
+ if event.message_finished:
+ self.queue.put_nowait({"type": "websocket.receive", "text": self.text})
+ self.text = ""
+ if not self.read_paused:
+ self.read_paused = True
+ self.transport.pause_reading()
+
+ def handle_bytes(self, event: events.BytesMessage) -> None:
+ self.bytes += event.data
+ # todo: we may want to guard the size of self.bytes and self.text
+ if event.message_finished:
+ self.queue.put_nowait({"type": "websocket.receive", "bytes": self.bytes})
+ self.bytes = b""
+ if not self.read_paused:
+ self.read_paused = True
+ self.transport.pause_reading()
+
+ def handle_close(self, event: events.CloseConnection) -> None:
+ if self.conn.state == ConnectionState.REMOTE_CLOSING:
+ self.transport.write(self.conn.send(event.response()))
+ self.queue.put_nowait({"type": "websocket.disconnect", "code": event.code})
+ self.transport.close()
+
+ def handle_ping(self, event: events.Ping) -> None:
+ self.transport.write(self.conn.send(event.response()))
+
+ def send_500_response(self) -> None:
+ if self.response_started or self.handshake_complete:
+ return # we cannot send responses anymore
+ headers = [
+ (b"content-type", b"text/plain; charset=utf-8"),
+ (b"connection", b"close"),
+ ]
+ output = self.conn.send(wsproto.events.RejectConnection(status_code=500, headers=headers, has_body=True))
+ output += self.conn.send(wsproto.events.RejectData(data=b"Internal Server Error"))
+ self.transport.write(output)
+
+ async def run_asgi(self) -> None:
+ try:
+ result = await self.app(self.scope, self.receive, self.send)
+ except ClientDisconnected:
+ self.transport.close()
+ except BaseException:
+ self.logger.exception("Exception in ASGI application\n")
+ self.send_500_response()
+ self.transport.close()
+ else:
+ if not self.handshake_complete:
+ msg = "ASGI callable returned without completing handshake."
+ self.logger.error(msg)
+ self.send_500_response()
+ self.transport.close()
+ elif result is not None:
+ msg = "ASGI callable should return None, but returned '%s'."
+ self.logger.error(msg, result)
+ self.transport.close()
+
+ async def send(self, message: ASGISendEvent) -> None:
+ await self.writable.wait()
+
+ message_type = message["type"]
+
+ if not self.handshake_complete:
+ if message_type == "websocket.accept":
+ message = typing.cast(WebSocketAcceptEvent, message)
+ self.logger.info(
+ '%s - "WebSocket %s" [accepted]',
+ self.scope["client"],
+ get_path_with_query_string(self.scope),
+ )
+ subprotocol = message.get("subprotocol")
+ extra_headers = self.default_headers + list(message.get("headers", []))
+ extensions: list[Extension] = []
+ if self.config.ws_per_message_deflate:
+ extensions.append(PerMessageDeflate())
+ if not self.transport.is_closing():
+ self.handshake_complete = True
+ output = self.conn.send(
+ wsproto.events.AcceptConnection(
+ subprotocol=subprotocol,
+ extensions=extensions,
+ extra_headers=extra_headers,
+ )
+ )
+ self.transport.write(output)
+
+ elif message_type == "websocket.close":
+ self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})
+ self.logger.info(
+ '%s - "WebSocket %s" 403',
+ self.scope["client"],
+ get_path_with_query_string(self.scope),
+ )
+ self.handshake_complete = True
+ self.close_sent = True
+ event = events.RejectConnection(status_code=403, headers=[])
+ output = self.conn.send(event)
+ self.transport.write(output)
+ self.transport.close()
+
+ elif message_type == "websocket.http.response.start":
+ message = typing.cast(WebSocketResponseStartEvent, message)
+ # ensure status code is in the valid range
+ if not (100 <= message["status"] < 600):
+ msg = "Invalid HTTP status code '%d' in response."
+ raise RuntimeError(msg % message["status"])
+ self.logger.info(
+ '%s - "WebSocket %s" %d',
+ self.scope["client"],
+ get_path_with_query_string(self.scope),
+ message["status"],
+ )
+ self.handshake_complete = True
+ event = events.RejectConnection(
+ status_code=message["status"],
+ headers=list(message["headers"]),
+ has_body=True,
+ )
+ output = self.conn.send(event)
+ self.transport.write(output)
+ self.response_started = True
+
+ else:
+ msg = (
+ "Expected ASGI message 'websocket.accept', 'websocket.close' "
+ "or 'websocket.http.response.start' "
+ "but got '%s'."
+ )
+ raise RuntimeError(msg % message_type)
+
+ elif not self.close_sent and not self.response_started:
+ try:
+ if message_type == "websocket.send":
+ message = typing.cast(WebSocketSendEvent, message)
+ bytes_data = message.get("bytes")
+ text_data = message.get("text")
+ data = text_data if bytes_data is None else bytes_data
+ output = self.conn.send(wsproto.events.Message(data=data)) # type: ignore
+ if not self.transport.is_closing():
+ self.transport.write(output)
+
+ elif message_type == "websocket.close":
+ message = typing.cast(WebSocketCloseEvent, message)
+ self.close_sent = True
+ code = message.get("code", 1000)
+ reason = message.get("reason", "") or ""
+ self.queue.put_nowait({"type": "websocket.disconnect", "code": code})
+ output = self.conn.send(wsproto.events.CloseConnection(code=code, reason=reason))
+ if not self.transport.is_closing():
+ self.transport.write(output)
+ self.transport.close()
+
+ else:
+ msg = "Expected ASGI message 'websocket.send' or 'websocket.close'," " but got '%s'."
+ raise RuntimeError(msg % message_type)
+ except LocalProtocolError as exc:
+ raise ClientDisconnected from exc
+ elif self.response_started:
+ if message_type == "websocket.http.response.body":
+ message = typing.cast("WebSocketResponseBodyEvent", message)
+ body_finished = not message.get("more_body", False)
+ reject_data = events.RejectData(data=message["body"], body_finished=body_finished)
+ output = self.conn.send(reject_data)
+ self.transport.write(output)
+
+ if body_finished:
+ self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})
+ self.close_sent = True
+ self.transport.close()
+
+ else:
+ msg = "Expected ASGI message 'websocket.http.response.body' " "but got '%s'."
+ raise RuntimeError(msg % message_type)
+
+ else:
+ msg = "Unexpected ASGI message '%s', after sending 'websocket.close'."
+ raise RuntimeError(msg % message_type)
+
+ async def receive(self) -> WebSocketEvent:
+ message = await self.queue.get()
+ if self.read_paused and self.queue.empty():
+ self.read_paused = False
+ self.transport.resume_reading()
+ return message