summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/anyio/streams/tls.py
diff options
context:
space:
mode:
authorcyfraeviolae <cyfraeviolae>2024-04-03 03:10:44 -0400
committercyfraeviolae <cyfraeviolae>2024-04-03 03:10:44 -0400
commit6d7ba58f880be618ade07f8ea080fe8c4bf8a896 (patch)
treeb1c931051ffcebd2bd9d61d98d6233ffa289bbce /venv/lib/python3.11/site-packages/anyio/streams/tls.py
parent4f884c9abc32990b4061a1bb6997b4b37e58ea0b (diff)
venv
Diffstat (limited to 'venv/lib/python3.11/site-packages/anyio/streams/tls.py')
-rw-r--r--venv/lib/python3.11/site-packages/anyio/streams/tls.py338
1 files changed, 338 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/anyio/streams/tls.py b/venv/lib/python3.11/site-packages/anyio/streams/tls.py
new file mode 100644
index 0000000..e913eed
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/anyio/streams/tls.py
@@ -0,0 +1,338 @@
+from __future__ import annotations
+
+import logging
+import re
+import ssl
+import sys
+from collections.abc import Callable, Mapping
+from dataclasses import dataclass
+from functools import wraps
+from typing import Any, Tuple, TypeVar
+
+from .. import (
+ BrokenResourceError,
+ EndOfStream,
+ aclose_forcefully,
+ get_cancelled_exc_class,
+)
+from .._core._typedattr import TypedAttributeSet, typed_attribute
+from ..abc import AnyByteStream, ByteStream, Listener, TaskGroup
+
+if sys.version_info >= (3, 11):
+ from typing import TypeVarTuple, Unpack
+else:
+ from typing_extensions import TypeVarTuple, Unpack
+
+T_Retval = TypeVar("T_Retval")
+PosArgsT = TypeVarTuple("PosArgsT")
+_PCTRTT = Tuple[Tuple[str, str], ...]
+_PCTRTTT = Tuple[_PCTRTT, ...]
+
+
+class TLSAttribute(TypedAttributeSet):
+ """Contains Transport Layer Security related attributes."""
+
+ #: the selected ALPN protocol
+ alpn_protocol: str | None = typed_attribute()
+ #: the channel binding for type ``tls-unique``
+ channel_binding_tls_unique: bytes = typed_attribute()
+ #: the selected cipher
+ cipher: tuple[str, str, int] = typed_attribute()
+ #: the peer certificate in dictionary form (see :meth:`ssl.SSLSocket.getpeercert`
+ # for more information)
+ peer_certificate: None | (dict[str, str | _PCTRTTT | _PCTRTT]) = typed_attribute()
+ #: the peer certificate in binary form
+ peer_certificate_binary: bytes | None = typed_attribute()
+ #: ``True`` if this is the server side of the connection
+ server_side: bool = typed_attribute()
+ #: ciphers shared by the client during the TLS handshake (``None`` if this is the
+ #: client side)
+ shared_ciphers: list[tuple[str, str, int]] | None = typed_attribute()
+ #: the :class:`~ssl.SSLObject` used for encryption
+ ssl_object: ssl.SSLObject = typed_attribute()
+ #: ``True`` if this stream does (and expects) a closing TLS handshake when the
+ #: stream is being closed
+ standard_compatible: bool = typed_attribute()
+ #: the TLS protocol version (e.g. ``TLSv1.2``)
+ tls_version: str = typed_attribute()
+
+
+@dataclass(eq=False)
+class TLSStream(ByteStream):
+ """
+ A stream wrapper that encrypts all sent data and decrypts received data.
+
+ This class has no public initializer; use :meth:`wrap` instead.
+ All extra attributes from :class:`~TLSAttribute` are supported.
+
+ :var AnyByteStream transport_stream: the wrapped stream
+
+ """
+
+ transport_stream: AnyByteStream
+ standard_compatible: bool
+ _ssl_object: ssl.SSLObject
+ _read_bio: ssl.MemoryBIO
+ _write_bio: ssl.MemoryBIO
+
+ @classmethod
+ async def wrap(
+ cls,
+ transport_stream: AnyByteStream,
+ *,
+ server_side: bool | None = None,
+ hostname: str | None = None,
+ ssl_context: ssl.SSLContext | None = None,
+ standard_compatible: bool = True,
+ ) -> TLSStream:
+ """
+ Wrap an existing stream with Transport Layer Security.
+
+ This performs a TLS handshake with the peer.
+
+ :param transport_stream: a bytes-transporting stream to wrap
+ :param server_side: ``True`` if this is the server side of the connection,
+ ``False`` if this is the client side (if omitted, will be set to ``False``
+ if ``hostname`` has been provided, ``False`` otherwise). Used only to create
+ a default context when an explicit context has not been provided.
+ :param hostname: host name of the peer (if host name checking is desired)
+ :param ssl_context: the SSLContext object to use (if not provided, a secure
+ default will be created)
+ :param standard_compatible: if ``False``, skip the closing handshake when
+ closing the connection, and don't raise an exception if the peer does the
+ same
+ :raises ~ssl.SSLError: if the TLS handshake fails
+
+ """
+ if server_side is None:
+ server_side = not hostname
+
+ if not ssl_context:
+ purpose = (
+ ssl.Purpose.CLIENT_AUTH if server_side else ssl.Purpose.SERVER_AUTH
+ )
+ ssl_context = ssl.create_default_context(purpose)
+
+ # Re-enable detection of unexpected EOFs if it was disabled by Python
+ if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"):
+ ssl_context.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF
+
+ bio_in = ssl.MemoryBIO()
+ bio_out = ssl.MemoryBIO()
+ ssl_object = ssl_context.wrap_bio(
+ bio_in, bio_out, server_side=server_side, server_hostname=hostname
+ )
+ wrapper = cls(
+ transport_stream=transport_stream,
+ standard_compatible=standard_compatible,
+ _ssl_object=ssl_object,
+ _read_bio=bio_in,
+ _write_bio=bio_out,
+ )
+ await wrapper._call_sslobject_method(ssl_object.do_handshake)
+ return wrapper
+
+ async def _call_sslobject_method(
+ self, func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT]
+ ) -> T_Retval:
+ while True:
+ try:
+ result = func(*args)
+ except ssl.SSLWantReadError:
+ try:
+ # Flush any pending writes first
+ if self._write_bio.pending:
+ await self.transport_stream.send(self._write_bio.read())
+
+ data = await self.transport_stream.receive()
+ except EndOfStream:
+ self._read_bio.write_eof()
+ except OSError as exc:
+ self._read_bio.write_eof()
+ self._write_bio.write_eof()
+ raise BrokenResourceError from exc
+ else:
+ self._read_bio.write(data)
+ except ssl.SSLWantWriteError:
+ await self.transport_stream.send(self._write_bio.read())
+ except ssl.SSLSyscallError as exc:
+ self._read_bio.write_eof()
+ self._write_bio.write_eof()
+ raise BrokenResourceError from exc
+ except ssl.SSLError as exc:
+ self._read_bio.write_eof()
+ self._write_bio.write_eof()
+ if (
+ isinstance(exc, ssl.SSLEOFError)
+ or "UNEXPECTED_EOF_WHILE_READING" in exc.strerror
+ ):
+ if self.standard_compatible:
+ raise BrokenResourceError from exc
+ else:
+ raise EndOfStream from None
+
+ raise
+ else:
+ # Flush any pending writes first
+ if self._write_bio.pending:
+ await self.transport_stream.send(self._write_bio.read())
+
+ return result
+
+ async def unwrap(self) -> tuple[AnyByteStream, bytes]:
+ """
+ Does the TLS closing handshake.
+
+ :return: a tuple of (wrapped byte stream, bytes left in the read buffer)
+
+ """
+ await self._call_sslobject_method(self._ssl_object.unwrap)
+ self._read_bio.write_eof()
+ self._write_bio.write_eof()
+ return self.transport_stream, self._read_bio.read()
+
+ async def aclose(self) -> None:
+ if self.standard_compatible:
+ try:
+ await self.unwrap()
+ except BaseException:
+ await aclose_forcefully(self.transport_stream)
+ raise
+
+ await self.transport_stream.aclose()
+
+ async def receive(self, max_bytes: int = 65536) -> bytes:
+ data = await self._call_sslobject_method(self._ssl_object.read, max_bytes)
+ if not data:
+ raise EndOfStream
+
+ return data
+
+ async def send(self, item: bytes) -> None:
+ await self._call_sslobject_method(self._ssl_object.write, item)
+
+ async def send_eof(self) -> None:
+ tls_version = self.extra(TLSAttribute.tls_version)
+ match = re.match(r"TLSv(\d+)(?:\.(\d+))?", tls_version)
+ if match:
+ major, minor = int(match.group(1)), int(match.group(2) or 0)
+ if (major, minor) < (1, 3):
+ raise NotImplementedError(
+ f"send_eof() requires at least TLSv1.3; current "
+ f"session uses {tls_version}"
+ )
+
+ raise NotImplementedError(
+ "send_eof() has not yet been implemented for TLS streams"
+ )
+
+ @property
+ def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
+ return {
+ **self.transport_stream.extra_attributes,
+ TLSAttribute.alpn_protocol: self._ssl_object.selected_alpn_protocol,
+ TLSAttribute.channel_binding_tls_unique: (
+ self._ssl_object.get_channel_binding
+ ),
+ TLSAttribute.cipher: self._ssl_object.cipher,
+ TLSAttribute.peer_certificate: lambda: self._ssl_object.getpeercert(False),
+ TLSAttribute.peer_certificate_binary: lambda: self._ssl_object.getpeercert(
+ True
+ ),
+ TLSAttribute.server_side: lambda: self._ssl_object.server_side,
+ TLSAttribute.shared_ciphers: lambda: self._ssl_object.shared_ciphers()
+ if self._ssl_object.server_side
+ else None,
+ TLSAttribute.standard_compatible: lambda: self.standard_compatible,
+ TLSAttribute.ssl_object: lambda: self._ssl_object,
+ TLSAttribute.tls_version: self._ssl_object.version,
+ }
+
+
+@dataclass(eq=False)
+class TLSListener(Listener[TLSStream]):
+ """
+ A convenience listener that wraps another listener and auto-negotiates a TLS session
+ on every accepted connection.
+
+ If the TLS handshake times out or raises an exception,
+ :meth:`handle_handshake_error` is called to do whatever post-mortem processing is
+ deemed necessary.
+
+ Supports only the :attr:`~TLSAttribute.standard_compatible` extra attribute.
+
+ :param Listener listener: the listener to wrap
+ :param ssl_context: the SSL context object
+ :param standard_compatible: a flag passed through to :meth:`TLSStream.wrap`
+ :param handshake_timeout: time limit for the TLS handshake
+ (passed to :func:`~anyio.fail_after`)
+ """
+
+ listener: Listener[Any]
+ ssl_context: ssl.SSLContext
+ standard_compatible: bool = True
+ handshake_timeout: float = 30
+
+ @staticmethod
+ async def handle_handshake_error(exc: BaseException, stream: AnyByteStream) -> None:
+ """
+ Handle an exception raised during the TLS handshake.
+
+ This method does 3 things:
+
+ #. Forcefully closes the original stream
+ #. Logs the exception (unless it was a cancellation exception) using the
+ ``anyio.streams.tls`` logger
+ #. Reraises the exception if it was a base exception or a cancellation exception
+
+ :param exc: the exception
+ :param stream: the original stream
+
+ """
+ await aclose_forcefully(stream)
+
+ # Log all except cancellation exceptions
+ if not isinstance(exc, get_cancelled_exc_class()):
+ # CPython (as of 3.11.5) returns incorrect `sys.exc_info()` here when using
+ # any asyncio implementation, so we explicitly pass the exception to log
+ # (https://github.com/python/cpython/issues/108668). Trio does not have this
+ # issue because it works around the CPython bug.
+ logging.getLogger(__name__).exception(
+ "Error during TLS handshake", exc_info=exc
+ )
+
+ # Only reraise base exceptions and cancellation exceptions
+ if not isinstance(exc, Exception) or isinstance(exc, get_cancelled_exc_class()):
+ raise
+
+ async def serve(
+ self,
+ handler: Callable[[TLSStream], Any],
+ task_group: TaskGroup | None = None,
+ ) -> None:
+ @wraps(handler)
+ async def handler_wrapper(stream: AnyByteStream) -> None:
+ from .. import fail_after
+
+ try:
+ with fail_after(self.handshake_timeout):
+ wrapped_stream = await TLSStream.wrap(
+ stream,
+ ssl_context=self.ssl_context,
+ standard_compatible=self.standard_compatible,
+ )
+ except BaseException as exc:
+ await self.handle_handshake_error(exc, stream)
+ else:
+ await handler(wrapped_stream)
+
+ await self.listener.serve(handler_wrapper, task_group)
+
+ async def aclose(self) -> None:
+ await self.listener.aclose()
+
+ @property
+ def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
+ return {
+ TLSAttribute.standard_compatible: lambda: self.standard_compatible,
+ }