summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/websockets/legacy/server.py
diff options
context:
space:
mode:
Diffstat (limited to 'venv/lib/python3.11/site-packages/websockets/legacy/server.py')
-rw-r--r--venv/lib/python3.11/site-packages/websockets/legacy/server.py1185
1 files changed, 1185 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/websockets/legacy/server.py b/venv/lib/python3.11/site-packages/websockets/legacy/server.py
new file mode 100644
index 0000000..7c24dd7
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/websockets/legacy/server.py
@@ -0,0 +1,1185 @@
+from __future__ import annotations
+
+import asyncio
+import email.utils
+import functools
+import http
+import inspect
+import logging
+import socket
+import warnings
+from types import TracebackType
+from typing import (
+ Any,
+ Awaitable,
+ Callable,
+ Generator,
+ Iterable,
+ List,
+ Optional,
+ Sequence,
+ Set,
+ Tuple,
+ Type,
+ Union,
+ cast,
+)
+
+from ..datastructures import Headers, HeadersLike, MultipleValuesError
+from ..exceptions import (
+ AbortHandshake,
+ InvalidHandshake,
+ InvalidHeader,
+ InvalidMessage,
+ InvalidOrigin,
+ InvalidUpgrade,
+ NegotiationError,
+)
+from ..extensions import Extension, ServerExtensionFactory
+from ..extensions.permessage_deflate import enable_server_permessage_deflate
+from ..headers import (
+ build_extension,
+ parse_extension,
+ parse_subprotocol,
+ validate_subprotocols,
+)
+from ..http import USER_AGENT
+from ..protocol import State
+from ..typing import ExtensionHeader, LoggerLike, Origin, StatusLike, Subprotocol
+from .compatibility import asyncio_timeout
+from .handshake import build_response, check_request
+from .http import read_request
+from .protocol import WebSocketCommonProtocol
+
+
+__all__ = ["serve", "unix_serve", "WebSocketServerProtocol", "WebSocketServer"]
+
+
+HeadersLikeOrCallable = Union[HeadersLike, Callable[[str, Headers], HeadersLike]]
+
+HTTPResponse = Tuple[StatusLike, HeadersLike, bytes]
+
+
+class WebSocketServerProtocol(WebSocketCommonProtocol):
+ """
+ WebSocket server connection.
+
+ :class:`WebSocketServerProtocol` provides :meth:`recv` and :meth:`send`
+ coroutines for receiving and sending messages.
+
+ It supports asynchronous iteration to receive messages::
+
+ async for message in websocket:
+ await process(message)
+
+ The iterator exits normally when the connection is closed with close code
+ 1000 (OK) or 1001 (going away) or without a close code. It raises
+ a :exc:`~websockets.exceptions.ConnectionClosedError` when the connection
+ is closed with any other code.
+
+ You may customize the opening handshake in a subclass by
+ overriding :meth:`process_request` or :meth:`select_subprotocol`.
+
+ Args:
+ ws_server: WebSocket server that created this connection.
+
+ See :func:`serve` for the documentation of ``ws_handler``, ``logger``, ``origins``,
+ ``extensions``, ``subprotocols``, ``extra_headers``, and ``server_header``.
+
+ See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the
+ documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``,
+ ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``.
+
+ """
+
+ is_client = False
+ side = "server"
+
+ def __init__(
+ self,
+ ws_handler: Union[
+ Callable[[WebSocketServerProtocol], Awaitable[Any]],
+ Callable[[WebSocketServerProtocol, str], Awaitable[Any]], # deprecated
+ ],
+ ws_server: WebSocketServer,
+ *,
+ logger: Optional[LoggerLike] = None,
+ origins: Optional[Sequence[Optional[Origin]]] = None,
+ extensions: Optional[Sequence[ServerExtensionFactory]] = None,
+ subprotocols: Optional[Sequence[Subprotocol]] = None,
+ extra_headers: Optional[HeadersLikeOrCallable] = None,
+ server_header: Optional[str] = USER_AGENT,
+ process_request: Optional[
+ Callable[[str, Headers], Awaitable[Optional[HTTPResponse]]]
+ ] = None,
+ select_subprotocol: Optional[
+ Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol]
+ ] = None,
+ open_timeout: Optional[float] = 10,
+ **kwargs: Any,
+ ) -> None:
+ if logger is None:
+ logger = logging.getLogger("websockets.server")
+ super().__init__(logger=logger, **kwargs)
+ # For backwards compatibility with 6.0 or earlier.
+ if origins is not None and "" in origins:
+ warnings.warn("use None instead of '' in origins", DeprecationWarning)
+ origins = [None if origin == "" else origin for origin in origins]
+ # For backwards compatibility with 10.0 or earlier. Done here in
+ # addition to serve to trigger the deprecation warning on direct
+ # use of WebSocketServerProtocol.
+ self.ws_handler = remove_path_argument(ws_handler)
+ self.ws_server = ws_server
+ self.origins = origins
+ self.available_extensions = extensions
+ self.available_subprotocols = subprotocols
+ self.extra_headers = extra_headers
+ self.server_header = server_header
+ self._process_request = process_request
+ self._select_subprotocol = select_subprotocol
+ self.open_timeout = open_timeout
+
+ def connection_made(self, transport: asyncio.BaseTransport) -> None:
+ """
+ Register connection and initialize a task to handle it.
+
+ """
+ super().connection_made(transport)
+ # Register the connection with the server before creating the handler
+ # task. Registering at the beginning of the handler coroutine would
+ # create a race condition between the creation of the task, which
+ # schedules its execution, and the moment the handler starts running.
+ self.ws_server.register(self)
+ self.handler_task = self.loop.create_task(self.handler())
+
+ async def handler(self) -> None:
+ """
+ Handle the lifecycle of a WebSocket connection.
+
+ Since this method doesn't have a caller able to handle exceptions, it
+ attempts to log relevant ones and guarantees that the TCP connection is
+ closed before exiting.
+
+ """
+ try:
+ try:
+ async with asyncio_timeout(self.open_timeout):
+ await self.handshake(
+ origins=self.origins,
+ available_extensions=self.available_extensions,
+ available_subprotocols=self.available_subprotocols,
+ extra_headers=self.extra_headers,
+ )
+ except asyncio.TimeoutError: # pragma: no cover
+ raise
+ except ConnectionError:
+ raise
+ except Exception as exc:
+ if isinstance(exc, AbortHandshake):
+ status, headers, body = exc.status, exc.headers, exc.body
+ elif isinstance(exc, InvalidOrigin):
+ if self.debug:
+ self.logger.debug("! invalid origin", exc_info=True)
+ status, headers, body = (
+ http.HTTPStatus.FORBIDDEN,
+ Headers(),
+ f"Failed to open a WebSocket connection: {exc}.\n".encode(),
+ )
+ elif isinstance(exc, InvalidUpgrade):
+ if self.debug:
+ self.logger.debug("! invalid upgrade", exc_info=True)
+ status, headers, body = (
+ http.HTTPStatus.UPGRADE_REQUIRED,
+ Headers([("Upgrade", "websocket")]),
+ (
+ f"Failed to open a WebSocket connection: {exc}.\n"
+ f"\n"
+ f"You cannot access a WebSocket server directly "
+ f"with a browser. You need a WebSocket client.\n"
+ ).encode(),
+ )
+ elif isinstance(exc, InvalidHandshake):
+ if self.debug:
+ self.logger.debug("! invalid handshake", exc_info=True)
+ status, headers, body = (
+ http.HTTPStatus.BAD_REQUEST,
+ Headers(),
+ f"Failed to open a WebSocket connection: {exc}.\n".encode(),
+ )
+ else:
+ self.logger.error("opening handshake failed", exc_info=True)
+ status, headers, body = (
+ http.HTTPStatus.INTERNAL_SERVER_ERROR,
+ Headers(),
+ (
+ b"Failed to open a WebSocket connection.\n"
+ b"See server log for more information.\n"
+ ),
+ )
+
+ headers.setdefault("Date", email.utils.formatdate(usegmt=True))
+ if self.server_header is not None:
+ headers.setdefault("Server", self.server_header)
+
+ headers.setdefault("Content-Length", str(len(body)))
+ headers.setdefault("Content-Type", "text/plain")
+ headers.setdefault("Connection", "close")
+
+ self.write_http_response(status, headers, body)
+ self.logger.info(
+ "connection rejected (%d %s)", status.value, status.phrase
+ )
+ await self.close_transport()
+ return
+
+ try:
+ await self.ws_handler(self)
+ except Exception:
+ self.logger.error("connection handler failed", exc_info=True)
+ if not self.closed:
+ self.fail_connection(1011)
+ raise
+
+ try:
+ await self.close()
+ except ConnectionError:
+ raise
+ except Exception:
+ self.logger.error("closing handshake failed", exc_info=True)
+ raise
+
+ except Exception:
+ # Last-ditch attempt to avoid leaking connections on errors.
+ try:
+ self.transport.close()
+ except Exception: # pragma: no cover
+ pass
+
+ finally:
+ # Unregister the connection with the server when the handler task
+ # terminates. Registration is tied to the lifecycle of the handler
+ # task because the server waits for tasks attached to registered
+ # connections before terminating.
+ self.ws_server.unregister(self)
+ self.logger.info("connection closed")
+
+ async def read_http_request(self) -> Tuple[str, Headers]:
+ """
+ Read request line and headers from the HTTP request.
+
+ If the request contains a body, it may be read from ``self.reader``
+ after this coroutine returns.
+
+ Raises:
+ InvalidMessage: if the HTTP message is malformed or isn't an
+ HTTP/1.1 GET request.
+
+ """
+ try:
+ path, headers = await read_request(self.reader)
+ except asyncio.CancelledError: # pragma: no cover
+ raise
+ except Exception as exc:
+ raise InvalidMessage("did not receive a valid HTTP request") from exc
+
+ if self.debug:
+ self.logger.debug("< GET %s HTTP/1.1", path)
+ for key, value in headers.raw_items():
+ self.logger.debug("< %s: %s", key, value)
+
+ self.path = path
+ self.request_headers = headers
+
+ return path, headers
+
+ def write_http_response(
+ self, status: http.HTTPStatus, headers: Headers, body: Optional[bytes] = None
+ ) -> None:
+ """
+ Write status line and headers to the HTTP response.
+
+ This coroutine is also able to write a response body.
+
+ """
+ self.response_headers = headers
+
+ if self.debug:
+ self.logger.debug("> HTTP/1.1 %d %s", status.value, status.phrase)
+ for key, value in headers.raw_items():
+ self.logger.debug("> %s: %s", key, value)
+ if body is not None:
+ self.logger.debug("> [body] (%d bytes)", len(body))
+
+ # Since the status line and headers only contain ASCII characters,
+ # we can keep this simple.
+ response = f"HTTP/1.1 {status.value} {status.phrase}\r\n"
+ response += str(headers)
+
+ self.transport.write(response.encode())
+
+ if body is not None:
+ self.transport.write(body)
+
+ async def process_request(
+ self, path: str, request_headers: Headers
+ ) -> Optional[HTTPResponse]:
+ """
+ Intercept the HTTP request and return an HTTP response if appropriate.
+
+ You may override this method in a :class:`WebSocketServerProtocol`
+ subclass, for example:
+
+ * to return an HTTP 200 OK response on a given path; then a load
+ balancer can use this path for a health check;
+ * to authenticate the request and return an HTTP 401 Unauthorized or an
+ HTTP 403 Forbidden when authentication fails.
+
+ You may also override this method with the ``process_request``
+ argument of :func:`serve` and :class:`WebSocketServerProtocol`. This
+ is equivalent, except ``process_request`` won't have access to the
+ protocol instance, so it can't store information for later use.
+
+ :meth:`process_request` is expected to complete quickly. If it may run
+ for a long time, then it should await :meth:`wait_closed` and exit if
+ :meth:`wait_closed` completes, or else it could prevent the server
+ from shutting down.
+
+ Args:
+ path: request path, including optional query string.
+ request_headers: request headers.
+
+ Returns:
+ Optional[Tuple[StatusLike, HeadersLike, bytes]]: :obj:`None`
+ to continue the WebSocket handshake normally.
+
+ An HTTP response, represented by a 3-uple of the response status,
+ headers, and body, to abort the WebSocket handshake and return
+ that HTTP response instead.
+
+ """
+ if self._process_request is not None:
+ response = self._process_request(path, request_headers)
+ if isinstance(response, Awaitable):
+ return await response
+ else:
+ # For backwards compatibility with 7.0.
+ warnings.warn(
+ "declare process_request as a coroutine", DeprecationWarning
+ )
+ return response
+ return None
+
+ @staticmethod
+ def process_origin(
+ headers: Headers, origins: Optional[Sequence[Optional[Origin]]] = None
+ ) -> Optional[Origin]:
+ """
+ Handle the Origin HTTP request header.
+
+ Args:
+ headers: request headers.
+ origins: optional list of acceptable origins.
+
+ Raises:
+ InvalidOrigin: if the origin isn't acceptable.
+
+ """
+ # "The user agent MUST NOT include more than one Origin header field"
+ # per https://www.rfc-editor.org/rfc/rfc6454.html#section-7.3.
+ try:
+ origin = cast(Optional[Origin], headers.get("Origin"))
+ except MultipleValuesError as exc:
+ raise InvalidHeader("Origin", "more than one Origin header found") from exc
+ if origins is not None:
+ if origin not in origins:
+ raise InvalidOrigin(origin)
+ return origin
+
+ @staticmethod
+ def process_extensions(
+ headers: Headers,
+ available_extensions: Optional[Sequence[ServerExtensionFactory]],
+ ) -> Tuple[Optional[str], List[Extension]]:
+ """
+ Handle the Sec-WebSocket-Extensions HTTP request header.
+
+ Accept or reject each extension proposed in the client request.
+ Negotiate parameters for accepted extensions.
+
+ Return the Sec-WebSocket-Extensions HTTP response header and the list
+ of accepted extensions.
+
+ :rfc:`6455` leaves the rules up to the specification of each
+ :extension.
+
+ To provide this level of flexibility, for each extension proposed by
+ the client, we check for a match with each extension available in the
+ server configuration. If no match is found, the extension is ignored.
+
+ If several variants of the same extension are proposed by the client,
+ it may be accepted several times, which won't make sense in general.
+ Extensions must implement their own requirements. For this purpose,
+ the list of previously accepted extensions is provided.
+
+ This process doesn't allow the server to reorder extensions. It can
+ only select a subset of the extensions proposed by the client.
+
+ Other requirements, for example related to mandatory extensions or the
+ order of extensions, may be implemented by overriding this method.
+
+ Args:
+ headers: request headers.
+ extensions: optional list of supported extensions.
+
+ Raises:
+ InvalidHandshake: to abort the handshake with an HTTP 400 error.
+
+ """
+ response_header_value: Optional[str] = None
+
+ extension_headers: List[ExtensionHeader] = []
+ accepted_extensions: List[Extension] = []
+
+ header_values = headers.get_all("Sec-WebSocket-Extensions")
+
+ if header_values and available_extensions:
+ parsed_header_values: List[ExtensionHeader] = sum(
+ [parse_extension(header_value) for header_value in header_values], []
+ )
+
+ for name, request_params in parsed_header_values:
+ for ext_factory in available_extensions:
+ # Skip non-matching extensions based on their name.
+ if ext_factory.name != name:
+ continue
+
+ # Skip non-matching extensions based on their params.
+ try:
+ response_params, extension = ext_factory.process_request_params(
+ request_params, accepted_extensions
+ )
+ except NegotiationError:
+ continue
+
+ # Add matching extension to the final list.
+ extension_headers.append((name, response_params))
+ accepted_extensions.append(extension)
+
+ # Break out of the loop once we have a match.
+ break
+
+ # If we didn't break from the loop, no extension in our list
+ # matched what the client sent. The extension is declined.
+
+ # Serialize extension header.
+ if extension_headers:
+ response_header_value = build_extension(extension_headers)
+
+ return response_header_value, accepted_extensions
+
+ # Not @staticmethod because it calls self.select_subprotocol()
+ def process_subprotocol(
+ self, headers: Headers, available_subprotocols: Optional[Sequence[Subprotocol]]
+ ) -> Optional[Subprotocol]:
+ """
+ Handle the Sec-WebSocket-Protocol HTTP request header.
+
+ Return Sec-WebSocket-Protocol HTTP response header, which is the same
+ as the selected subprotocol.
+
+ Args:
+ headers: request headers.
+ available_subprotocols: optional list of supported subprotocols.
+
+ Raises:
+ InvalidHandshake: to abort the handshake with an HTTP 400 error.
+
+ """
+ subprotocol: Optional[Subprotocol] = None
+
+ header_values = headers.get_all("Sec-WebSocket-Protocol")
+
+ if header_values and available_subprotocols:
+ parsed_header_values: List[Subprotocol] = sum(
+ [parse_subprotocol(header_value) for header_value in header_values], []
+ )
+
+ subprotocol = self.select_subprotocol(
+ parsed_header_values, available_subprotocols
+ )
+
+ return subprotocol
+
+ def select_subprotocol(
+ self,
+ client_subprotocols: Sequence[Subprotocol],
+ server_subprotocols: Sequence[Subprotocol],
+ ) -> Optional[Subprotocol]:
+ """
+ Pick a subprotocol among those supported by the client and the server.
+
+ If several subprotocols are available, select the preferred subprotocol
+ by giving equal weight to the preferences of the client and the server.
+
+ If no subprotocol is available, proceed without a subprotocol.
+
+ You may provide a ``select_subprotocol`` argument to :func:`serve` or
+ :class:`WebSocketServerProtocol` to override this logic. For example,
+ you could reject the handshake if the client doesn't support a
+ particular subprotocol, rather than accept the handshake without that
+ subprotocol.
+
+ Args:
+ client_subprotocols: list of subprotocols offered by the client.
+ server_subprotocols: list of subprotocols available on the server.
+
+ Returns:
+ Optional[Subprotocol]: Selected subprotocol, if a common subprotocol
+ was found.
+
+ :obj:`None` to continue without a subprotocol.
+
+ """
+ if self._select_subprotocol is not None:
+ return self._select_subprotocol(client_subprotocols, server_subprotocols)
+
+ subprotocols = set(client_subprotocols) & set(server_subprotocols)
+ if not subprotocols:
+ return None
+ return sorted(
+ subprotocols,
+ key=lambda p: client_subprotocols.index(p) + server_subprotocols.index(p),
+ )[0]
+
+ async def handshake(
+ self,
+ origins: Optional[Sequence[Optional[Origin]]] = None,
+ available_extensions: Optional[Sequence[ServerExtensionFactory]] = None,
+ available_subprotocols: Optional[Sequence[Subprotocol]] = None,
+ extra_headers: Optional[HeadersLikeOrCallable] = None,
+ ) -> str:
+ """
+ Perform the server side of the opening handshake.
+
+ Args:
+ origins: list of acceptable values of the Origin HTTP header;
+ include :obj:`None` if the lack of an origin is acceptable.
+ extensions: list of supported extensions, in order in which they
+ should be tried.
+ subprotocols: list of supported subprotocols, in order of
+ decreasing preference.
+ extra_headers: arbitrary HTTP headers to add to the response when
+ the handshake succeeds.
+
+ Returns:
+ str: path of the URI of the request.
+
+ Raises:
+ InvalidHandshake: if the handshake fails.
+
+ """
+ path, request_headers = await self.read_http_request()
+
+ # Hook for customizing request handling, for example checking
+ # authentication or treating some paths as plain HTTP endpoints.
+ early_response_awaitable = self.process_request(path, request_headers)
+ if isinstance(early_response_awaitable, Awaitable):
+ early_response = await early_response_awaitable
+ else:
+ # For backwards compatibility with 7.0.
+ warnings.warn("declare process_request as a coroutine", DeprecationWarning)
+ early_response = early_response_awaitable
+
+ # The connection may drop while process_request is running.
+ if self.state is State.CLOSED:
+ # This subclass of ConnectionError is silently ignored in handler().
+ raise BrokenPipeError("connection closed during opening handshake")
+
+ # Change the response to a 503 error if the server is shutting down.
+ if not self.ws_server.is_serving():
+ early_response = (
+ http.HTTPStatus.SERVICE_UNAVAILABLE,
+ [],
+ b"Server is shutting down.\n",
+ )
+
+ if early_response is not None:
+ raise AbortHandshake(*early_response)
+
+ key = check_request(request_headers)
+
+ self.origin = self.process_origin(request_headers, origins)
+
+ extensions_header, self.extensions = self.process_extensions(
+ request_headers, available_extensions
+ )
+
+ protocol_header = self.subprotocol = self.process_subprotocol(
+ request_headers, available_subprotocols
+ )
+
+ response_headers = Headers()
+
+ build_response(response_headers, key)
+
+ if extensions_header is not None:
+ response_headers["Sec-WebSocket-Extensions"] = extensions_header
+
+ if protocol_header is not None:
+ response_headers["Sec-WebSocket-Protocol"] = protocol_header
+
+ if callable(extra_headers):
+ extra_headers = extra_headers(path, self.request_headers)
+ if extra_headers is not None:
+ response_headers.update(extra_headers)
+
+ response_headers.setdefault("Date", email.utils.formatdate(usegmt=True))
+ if self.server_header is not None:
+ response_headers.setdefault("Server", self.server_header)
+
+ self.write_http_response(http.HTTPStatus.SWITCHING_PROTOCOLS, response_headers)
+
+ self.logger.info("connection open")
+
+ self.connection_open()
+
+ return path
+
+
+class WebSocketServer:
+ """
+ WebSocket server returned by :func:`serve`.
+
+ This class provides the same interface as :class:`~asyncio.Server`,
+ notably the :meth:`~asyncio.Server.close`
+ and :meth:`~asyncio.Server.wait_closed` methods.
+
+ It keeps track of WebSocket connections in order to close them properly
+ when shutting down.
+
+ Args:
+ logger: Logger for this server.
+ It defaults to ``logging.getLogger("websockets.server")``.
+ See the :doc:`logging guide <../../topics/logging>` for details.
+
+ """
+
+ def __init__(self, logger: Optional[LoggerLike] = None):
+ if logger is None:
+ logger = logging.getLogger("websockets.server")
+ self.logger = logger
+
+ # Keep track of active connections.
+ self.websockets: Set[WebSocketServerProtocol] = set()
+
+ # Task responsible for closing the server and terminating connections.
+ self.close_task: Optional[asyncio.Task[None]] = None
+
+ # Completed when the server is closed and connections are terminated.
+ self.closed_waiter: asyncio.Future[None]
+
+ def wrap(self, server: asyncio.base_events.Server) -> None:
+ """
+ Attach to a given :class:`~asyncio.Server`.
+
+ Since :meth:`~asyncio.loop.create_server` doesn't support injecting a
+ custom ``Server`` class, the easiest solution that doesn't rely on
+ private :mod:`asyncio` APIs is to:
+
+ - instantiate a :class:`WebSocketServer`
+ - give the protocol factory a reference to that instance
+ - call :meth:`~asyncio.loop.create_server` with the factory
+ - attach the resulting :class:`~asyncio.Server` with this method
+
+ """
+ self.server = server
+ for sock in server.sockets:
+ if sock.family == socket.AF_INET:
+ name = "%s:%d" % sock.getsockname()
+ elif sock.family == socket.AF_INET6:
+ name = "[%s]:%d" % sock.getsockname()[:2]
+ elif sock.family == socket.AF_UNIX:
+ name = sock.getsockname()
+ # In the unlikely event that someone runs websockets over a
+ # protocol other than IP or Unix sockets, avoid crashing.
+ else: # pragma: no cover
+ name = str(sock.getsockname())
+ self.logger.info("server listening on %s", name)
+
+ # Initialized here because we need a reference to the event loop.
+ # This should be moved back to __init__ when dropping Python < 3.10.
+ self.closed_waiter = server.get_loop().create_future()
+
+ def register(self, protocol: WebSocketServerProtocol) -> None:
+ """
+ Register a connection with this server.
+
+ """
+ self.websockets.add(protocol)
+
+ def unregister(self, protocol: WebSocketServerProtocol) -> None:
+ """
+ Unregister a connection with this server.
+
+ """
+ self.websockets.remove(protocol)
+
+ def close(self, close_connections: bool = True) -> None:
+ """
+ Close the server.
+
+ * Close the underlying :class:`~asyncio.Server`.
+ * When ``close_connections`` is :obj:`True`, which is the default,
+ close existing connections. Specifically:
+
+ * Reject opening WebSocket connections with an HTTP 503 (service
+ unavailable) error. This happens when the server accepted the TCP
+ connection but didn't complete the opening handshake before closing.
+ * Close open WebSocket connections with close code 1001 (going away).
+
+ * Wait until all connection handlers terminate.
+
+ :meth:`close` is idempotent.
+
+ """
+ if self.close_task is None:
+ self.close_task = self.get_loop().create_task(
+ self._close(close_connections)
+ )
+
+ async def _close(self, close_connections: bool) -> None:
+ """
+ Implementation of :meth:`close`.
+
+ This calls :meth:`~asyncio.Server.close` on the underlying
+ :class:`~asyncio.Server` object to stop accepting new connections and
+ then closes open connections with close code 1001.
+
+ """
+ self.logger.info("server closing")
+
+ # Stop accepting new connections.
+ self.server.close()
+
+ # Wait until all accepted connections reach connection_made() and call
+ # register(). See https://bugs.python.org/issue34852 for details.
+ await asyncio.sleep(0)
+
+ if close_connections:
+ # Close OPEN connections with close code 1001. After server.close(),
+ # handshake() closes OPENING connections with an HTTP 503 error.
+ close_tasks = [
+ asyncio.create_task(websocket.close(1001))
+ for websocket in self.websockets
+ if websocket.state is not State.CONNECTING
+ ]
+ # asyncio.wait doesn't accept an empty first argument.
+ if close_tasks:
+ await asyncio.wait(close_tasks)
+
+ # Wait until all TCP connections are closed.
+ await self.server.wait_closed()
+
+ # Wait until all connection handlers terminate.
+ # asyncio.wait doesn't accept an empty first argument.
+ if self.websockets:
+ await asyncio.wait(
+ [websocket.handler_task for websocket in self.websockets]
+ )
+
+ # Tell wait_closed() to return.
+ self.closed_waiter.set_result(None)
+
+ self.logger.info("server closed")
+
+ async def wait_closed(self) -> None:
+ """
+ Wait until the server is closed.
+
+ When :meth:`wait_closed` returns, all TCP connections are closed and
+ all connection handlers have returned.
+
+ To ensure a fast shutdown, a connection handler should always be
+ awaiting at least one of:
+
+ * :meth:`~WebSocketServerProtocol.recv`: when the connection is closed,
+ it raises :exc:`~websockets.exceptions.ConnectionClosedOK`;
+ * :meth:`~WebSocketServerProtocol.wait_closed`: when the connection is
+ closed, it returns.
+
+ Then the connection handler is immediately notified of the shutdown;
+ it can clean up and exit.
+
+ """
+ await asyncio.shield(self.closed_waiter)
+
+ def get_loop(self) -> asyncio.AbstractEventLoop:
+ """
+ See :meth:`asyncio.Server.get_loop`.
+
+ """
+ return self.server.get_loop()
+
+ def is_serving(self) -> bool:
+ """
+ See :meth:`asyncio.Server.is_serving`.
+
+ """
+ return self.server.is_serving()
+
+ async def start_serving(self) -> None: # pragma: no cover
+ """
+ See :meth:`asyncio.Server.start_serving`.
+
+ Typical use::
+
+ server = await serve(..., start_serving=False)
+ # perform additional setup here...
+ # ... then start the server
+ await server.start_serving()
+
+ """
+ await self.server.start_serving()
+
+ async def serve_forever(self) -> None: # pragma: no cover
+ """
+ See :meth:`asyncio.Server.serve_forever`.
+
+ Typical use::
+
+ server = await serve(...)
+ # this coroutine doesn't return
+ # canceling it stops the server
+ await server.serve_forever()
+
+ This is an alternative to using :func:`serve` as an asynchronous context
+ manager. Shutdown is triggered by canceling :meth:`serve_forever`
+ instead of exiting a :func:`serve` context.
+
+ """
+ await self.server.serve_forever()
+
+ @property
+ def sockets(self) -> Iterable[socket.socket]:
+ """
+ See :attr:`asyncio.Server.sockets`.
+
+ """
+ return self.server.sockets
+
+ async def __aenter__(self) -> WebSocketServer: # pragma: no cover
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_value: Optional[BaseException],
+ traceback: Optional[TracebackType],
+ ) -> None: # pragma: no cover
+ self.close()
+ await self.wait_closed()
+
+
+class Serve:
+ """
+ Start a WebSocket server listening on ``host`` and ``port``.
+
+ Whenever a client connects, the server creates a
+ :class:`WebSocketServerProtocol`, performs the opening handshake, and
+ delegates to the connection handler, ``ws_handler``.
+
+ The handler receives the :class:`WebSocketServerProtocol` and uses it to
+ send and receive messages.
+
+ Once the handler completes, either normally or with an exception, the
+ server performs the closing handshake and closes the connection.
+
+ Awaiting :func:`serve` yields a :class:`WebSocketServer`. This object
+ provides a :meth:`~WebSocketServer.close` method to shut down the server::
+
+ stop = asyncio.Future() # set this future to exit the server
+
+ server = await serve(...)
+ await stop
+ await server.close()
+
+ :func:`serve` can be used as an asynchronous context manager. Then, the
+ server is shut down automatically when exiting the context::
+
+ stop = asyncio.Future() # set this future to exit the server
+
+ async with serve(...):
+ await stop
+
+ Args:
+ ws_handler: Connection handler. It receives the WebSocket connection,
+ which is a :class:`WebSocketServerProtocol`, in argument.
+ host: Network interfaces the server binds to.
+ See :meth:`~asyncio.loop.create_server` for details.
+ port: TCP port the server listens on.
+ See :meth:`~asyncio.loop.create_server` for details.
+ create_protocol: Factory for the :class:`asyncio.Protocol` managing
+ the connection. It defaults to :class:`WebSocketServerProtocol`.
+ Set it to a wrapper or a subclass to customize connection handling.
+ logger: Logger for this server.
+ It defaults to ``logging.getLogger("websockets.server")``.
+ See the :doc:`logging guide <../../topics/logging>` for details.
+ compression: The "permessage-deflate" extension is enabled by default.
+ Set ``compression`` to :obj:`None` to disable it. See the
+ :doc:`compression guide <../../topics/compression>` for details.
+ origins: Acceptable values of the ``Origin`` header, for defending
+ against Cross-Site WebSocket Hijacking attacks. Include :obj:`None`
+ in the list if the lack of an origin is acceptable.
+ extensions: List of supported extensions, in order in which they
+ should be negotiated and run.
+ subprotocols: List of supported subprotocols, in order of decreasing
+ preference.
+ extra_headers (Union[HeadersLike, Callable[[str, Headers], HeadersLike]]):
+ Arbitrary HTTP headers to add to the response. This can be
+ a :data:`~websockets.datastructures.HeadersLike` or a callable
+ taking the request path and headers in arguments and returning
+ a :data:`~websockets.datastructures.HeadersLike`.
+ server_header: Value of the ``Server`` response header.
+ It defaults to ``"Python/x.y.z websockets/X.Y"``.
+ Setting it to :obj:`None` removes the header.
+ process_request (Optional[Callable[[str, Headers], \
+ Awaitable[Optional[Tuple[StatusLike, HeadersLike, bytes]]]]]):
+ Intercept HTTP request before the opening handshake.
+ See :meth:`~WebSocketServerProtocol.process_request` for details.
+ select_subprotocol: Select a subprotocol supported by the client.
+ See :meth:`~WebSocketServerProtocol.select_subprotocol` for details.
+ open_timeout: Timeout for opening connections in seconds.
+ :obj:`None` disables the timeout.
+
+ See :class:`~websockets.legacy.protocol.WebSocketCommonProtocol` for the
+ documentation of ``ping_interval``, ``ping_timeout``, ``close_timeout``,
+ ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit``.
+
+ Any other keyword arguments are passed the event loop's
+ :meth:`~asyncio.loop.create_server` method.
+
+ For example:
+
+ * You can set ``ssl`` to a :class:`~ssl.SSLContext` to enable TLS.
+
+ * You can set ``sock`` to a :obj:`~socket.socket` that you created
+ outside of websockets.
+
+ Returns:
+ WebSocketServer: WebSocket server.
+
+ """
+
+ def __init__(
+ self,
+ ws_handler: Union[
+ Callable[[WebSocketServerProtocol], Awaitable[Any]],
+ Callable[[WebSocketServerProtocol, str], Awaitable[Any]], # deprecated
+ ],
+ host: Optional[Union[str, Sequence[str]]] = None,
+ port: Optional[int] = None,
+ *,
+ create_protocol: Optional[Callable[..., WebSocketServerProtocol]] = None,
+ logger: Optional[LoggerLike] = None,
+ compression: Optional[str] = "deflate",
+ origins: Optional[Sequence[Optional[Origin]]] = None,
+ extensions: Optional[Sequence[ServerExtensionFactory]] = None,
+ subprotocols: Optional[Sequence[Subprotocol]] = None,
+ extra_headers: Optional[HeadersLikeOrCallable] = None,
+ server_header: Optional[str] = USER_AGENT,
+ process_request: Optional[
+ Callable[[str, Headers], Awaitable[Optional[HTTPResponse]]]
+ ] = None,
+ select_subprotocol: Optional[
+ Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol]
+ ] = None,
+ open_timeout: Optional[float] = 10,
+ ping_interval: Optional[float] = 20,
+ ping_timeout: Optional[float] = 20,
+ close_timeout: Optional[float] = None,
+ max_size: Optional[int] = 2**20,
+ max_queue: Optional[int] = 2**5,
+ read_limit: int = 2**16,
+ write_limit: int = 2**16,
+ **kwargs: Any,
+ ) -> None:
+ # Backwards compatibility: close_timeout used to be called timeout.
+ timeout: Optional[float] = kwargs.pop("timeout", None)
+ if timeout is None:
+ timeout = 10
+ else:
+ warnings.warn("rename timeout to close_timeout", DeprecationWarning)
+ # If both are specified, timeout is ignored.
+ if close_timeout is None:
+ close_timeout = timeout
+
+ # Backwards compatibility: create_protocol used to be called klass.
+ klass: Optional[Type[WebSocketServerProtocol]] = kwargs.pop("klass", None)
+ if klass is None:
+ klass = WebSocketServerProtocol
+ else:
+ warnings.warn("rename klass to create_protocol", DeprecationWarning)
+ # If both are specified, klass is ignored.
+ if create_protocol is None:
+ create_protocol = klass
+
+ # Backwards compatibility: recv() used to return None on closed connections
+ legacy_recv: bool = kwargs.pop("legacy_recv", False)
+
+ # Backwards compatibility: the loop parameter used to be supported.
+ _loop: Optional[asyncio.AbstractEventLoop] = kwargs.pop("loop", None)
+ if _loop is None:
+ loop = asyncio.get_event_loop()
+ else:
+ loop = _loop
+ warnings.warn("remove loop argument", DeprecationWarning)
+
+ ws_server = WebSocketServer(logger=logger)
+
+ secure = kwargs.get("ssl") is not None
+
+ if compression == "deflate":
+ extensions = enable_server_permessage_deflate(extensions)
+ elif compression is not None:
+ raise ValueError(f"unsupported compression: {compression}")
+
+ if subprotocols is not None:
+ validate_subprotocols(subprotocols)
+
+ factory = functools.partial(
+ create_protocol,
+ # For backwards compatibility with 10.0 or earlier. Done here in
+ # addition to WebSocketServerProtocol to trigger the deprecation
+ # warning once per serve() call rather than once per connection.
+ remove_path_argument(ws_handler),
+ ws_server,
+ host=host,
+ port=port,
+ secure=secure,
+ open_timeout=open_timeout,
+ ping_interval=ping_interval,
+ ping_timeout=ping_timeout,
+ close_timeout=close_timeout,
+ max_size=max_size,
+ max_queue=max_queue,
+ read_limit=read_limit,
+ write_limit=write_limit,
+ loop=_loop,
+ legacy_recv=legacy_recv,
+ origins=origins,
+ extensions=extensions,
+ subprotocols=subprotocols,
+ extra_headers=extra_headers,
+ server_header=server_header,
+ process_request=process_request,
+ select_subprotocol=select_subprotocol,
+ logger=logger,
+ )
+
+ if kwargs.pop("unix", False):
+ path: Optional[str] = kwargs.pop("path", None)
+ # unix_serve(path) must not specify host and port parameters.
+ assert host is None and port is None
+ create_server = functools.partial(
+ loop.create_unix_server, factory, path, **kwargs
+ )
+ else:
+ create_server = functools.partial(
+ loop.create_server, factory, host, port, **kwargs
+ )
+
+ # This is a coroutine function.
+ self._create_server = create_server
+ self.ws_server = ws_server
+
+ # async with serve(...)
+
+ async def __aenter__(self) -> WebSocketServer:
+ return await self
+
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_value: Optional[BaseException],
+ traceback: Optional[TracebackType],
+ ) -> None:
+ self.ws_server.close()
+ await self.ws_server.wait_closed()
+
+ # await serve(...)
+
+ def __await__(self) -> Generator[Any, None, WebSocketServer]:
+ # Create a suitable iterator by calling __await__ on a coroutine.
+ return self.__await_impl__().__await__()
+
+ async def __await_impl__(self) -> WebSocketServer:
+ server = await self._create_server()
+ self.ws_server.wrap(server)
+ return self.ws_server
+
+ # yield from serve(...) - remove when dropping Python < 3.10
+
+ __iter__ = __await__
+
+
+serve = Serve
+
+
+def unix_serve(
+ ws_handler: Union[
+ Callable[[WebSocketServerProtocol], Awaitable[Any]],
+ Callable[[WebSocketServerProtocol, str], Awaitable[Any]], # deprecated
+ ],
+ path: Optional[str] = None,
+ **kwargs: Any,
+) -> Serve:
+ """
+ Start a WebSocket server listening on a Unix socket.
+
+ This function is identical to :func:`serve`, except the ``host`` and
+ ``port`` arguments are replaced by ``path``. It is only available on Unix.
+
+ Unrecognized keyword arguments are passed the event loop's
+ :meth:`~asyncio.loop.create_unix_server` method.
+
+ It's useful for deploying a server behind a reverse proxy such as nginx.
+
+ Args:
+ path: File system path to the Unix socket.
+
+ """
+ return serve(ws_handler, path=path, unix=True, **kwargs)
+
+
+def remove_path_argument(
+ ws_handler: Union[
+ Callable[[WebSocketServerProtocol], Awaitable[Any]],
+ Callable[[WebSocketServerProtocol, str], Awaitable[Any]],
+ ]
+) -> Callable[[WebSocketServerProtocol], Awaitable[Any]]:
+ try:
+ inspect.signature(ws_handler).bind(None)
+ except TypeError:
+ try:
+ inspect.signature(ws_handler).bind(None, "")
+ except TypeError: # pragma: no cover
+ # ws_handler accepts neither one nor two arguments; leave it alone.
+ pass
+ else:
+ # ws_handler accepts two arguments; activate backwards compatibility.
+
+ # Enable deprecation warning and announce deprecation in 11.0.
+ # warnings.warn("remove second argument of ws_handler", DeprecationWarning)
+
+ async def _ws_handler(websocket: WebSocketServerProtocol) -> Any:
+ return await cast(
+ Callable[[WebSocketServerProtocol, str], Awaitable[Any]],
+ ws_handler,
+ )(websocket, websocket.path)
+
+ return _ws_handler
+
+ return cast(
+ Callable[[WebSocketServerProtocol], Awaitable[Any]],
+ ws_handler,
+ )