summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/anyio/streams/tls.py
diff options
context:
space:
mode:
Diffstat (limited to 'venv/lib/python3.11/site-packages/anyio/streams/tls.py')
-rw-r--r--venv/lib/python3.11/site-packages/anyio/streams/tls.py338
1 files changed, 0 insertions, 338 deletions
diff --git a/venv/lib/python3.11/site-packages/anyio/streams/tls.py b/venv/lib/python3.11/site-packages/anyio/streams/tls.py
deleted file mode 100644
index e913eed..0000000
--- a/venv/lib/python3.11/site-packages/anyio/streams/tls.py
+++ /dev/null
@@ -1,338 +0,0 @@
-from __future__ import annotations
-
-import logging
-import re
-import ssl
-import sys
-from collections.abc import Callable, Mapping
-from dataclasses import dataclass
-from functools import wraps
-from typing import Any, Tuple, TypeVar
-
-from .. import (
- BrokenResourceError,
- EndOfStream,
- aclose_forcefully,
- get_cancelled_exc_class,
-)
-from .._core._typedattr import TypedAttributeSet, typed_attribute
-from ..abc import AnyByteStream, ByteStream, Listener, TaskGroup
-
-if sys.version_info >= (3, 11):
- from typing import TypeVarTuple, Unpack
-else:
- from typing_extensions import TypeVarTuple, Unpack
-
-T_Retval = TypeVar("T_Retval")
-PosArgsT = TypeVarTuple("PosArgsT")
-_PCTRTT = Tuple[Tuple[str, str], ...]
-_PCTRTTT = Tuple[_PCTRTT, ...]
-
-
-class TLSAttribute(TypedAttributeSet):
- """Contains Transport Layer Security related attributes."""
-
- #: the selected ALPN protocol
- alpn_protocol: str | None = typed_attribute()
- #: the channel binding for type ``tls-unique``
- channel_binding_tls_unique: bytes = typed_attribute()
- #: the selected cipher
- cipher: tuple[str, str, int] = typed_attribute()
- #: the peer certificate in dictionary form (see :meth:`ssl.SSLSocket.getpeercert`
- # for more information)
- peer_certificate: None | (dict[str, str | _PCTRTTT | _PCTRTT]) = typed_attribute()
- #: the peer certificate in binary form
- peer_certificate_binary: bytes | None = typed_attribute()
- #: ``True`` if this is the server side of the connection
- server_side: bool = typed_attribute()
- #: ciphers shared by the client during the TLS handshake (``None`` if this is the
- #: client side)
- shared_ciphers: list[tuple[str, str, int]] | None = typed_attribute()
- #: the :class:`~ssl.SSLObject` used for encryption
- ssl_object: ssl.SSLObject = typed_attribute()
- #: ``True`` if this stream does (and expects) a closing TLS handshake when the
- #: stream is being closed
- standard_compatible: bool = typed_attribute()
- #: the TLS protocol version (e.g. ``TLSv1.2``)
- tls_version: str = typed_attribute()
-
-
-@dataclass(eq=False)
-class TLSStream(ByteStream):
- """
- A stream wrapper that encrypts all sent data and decrypts received data.
-
- This class has no public initializer; use :meth:`wrap` instead.
- All extra attributes from :class:`~TLSAttribute` are supported.
-
- :var AnyByteStream transport_stream: the wrapped stream
-
- """
-
- transport_stream: AnyByteStream
- standard_compatible: bool
- _ssl_object: ssl.SSLObject
- _read_bio: ssl.MemoryBIO
- _write_bio: ssl.MemoryBIO
-
- @classmethod
- async def wrap(
- cls,
- transport_stream: AnyByteStream,
- *,
- server_side: bool | None = None,
- hostname: str | None = None,
- ssl_context: ssl.SSLContext | None = None,
- standard_compatible: bool = True,
- ) -> TLSStream:
- """
- Wrap an existing stream with Transport Layer Security.
-
- This performs a TLS handshake with the peer.
-
- :param transport_stream: a bytes-transporting stream to wrap
- :param server_side: ``True`` if this is the server side of the connection,
- ``False`` if this is the client side (if omitted, will be set to ``False``
- if ``hostname`` has been provided, ``False`` otherwise). Used only to create
- a default context when an explicit context has not been provided.
- :param hostname: host name of the peer (if host name checking is desired)
- :param ssl_context: the SSLContext object to use (if not provided, a secure
- default will be created)
- :param standard_compatible: if ``False``, skip the closing handshake when
- closing the connection, and don't raise an exception if the peer does the
- same
- :raises ~ssl.SSLError: if the TLS handshake fails
-
- """
- if server_side is None:
- server_side = not hostname
-
- if not ssl_context:
- purpose = (
- ssl.Purpose.CLIENT_AUTH if server_side else ssl.Purpose.SERVER_AUTH
- )
- ssl_context = ssl.create_default_context(purpose)
-
- # Re-enable detection of unexpected EOFs if it was disabled by Python
- if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"):
- ssl_context.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF
-
- bio_in = ssl.MemoryBIO()
- bio_out = ssl.MemoryBIO()
- ssl_object = ssl_context.wrap_bio(
- bio_in, bio_out, server_side=server_side, server_hostname=hostname
- )
- wrapper = cls(
- transport_stream=transport_stream,
- standard_compatible=standard_compatible,
- _ssl_object=ssl_object,
- _read_bio=bio_in,
- _write_bio=bio_out,
- )
- await wrapper._call_sslobject_method(ssl_object.do_handshake)
- return wrapper
-
- async def _call_sslobject_method(
- self, func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT]
- ) -> T_Retval:
- while True:
- try:
- result = func(*args)
- except ssl.SSLWantReadError:
- try:
- # Flush any pending writes first
- if self._write_bio.pending:
- await self.transport_stream.send(self._write_bio.read())
-
- data = await self.transport_stream.receive()
- except EndOfStream:
- self._read_bio.write_eof()
- except OSError as exc:
- self._read_bio.write_eof()
- self._write_bio.write_eof()
- raise BrokenResourceError from exc
- else:
- self._read_bio.write(data)
- except ssl.SSLWantWriteError:
- await self.transport_stream.send(self._write_bio.read())
- except ssl.SSLSyscallError as exc:
- self._read_bio.write_eof()
- self._write_bio.write_eof()
- raise BrokenResourceError from exc
- except ssl.SSLError as exc:
- self._read_bio.write_eof()
- self._write_bio.write_eof()
- if (
- isinstance(exc, ssl.SSLEOFError)
- or "UNEXPECTED_EOF_WHILE_READING" in exc.strerror
- ):
- if self.standard_compatible:
- raise BrokenResourceError from exc
- else:
- raise EndOfStream from None
-
- raise
- else:
- # Flush any pending writes first
- if self._write_bio.pending:
- await self.transport_stream.send(self._write_bio.read())
-
- return result
-
- async def unwrap(self) -> tuple[AnyByteStream, bytes]:
- """
- Does the TLS closing handshake.
-
- :return: a tuple of (wrapped byte stream, bytes left in the read buffer)
-
- """
- await self._call_sslobject_method(self._ssl_object.unwrap)
- self._read_bio.write_eof()
- self._write_bio.write_eof()
- return self.transport_stream, self._read_bio.read()
-
- async def aclose(self) -> None:
- if self.standard_compatible:
- try:
- await self.unwrap()
- except BaseException:
- await aclose_forcefully(self.transport_stream)
- raise
-
- await self.transport_stream.aclose()
-
- async def receive(self, max_bytes: int = 65536) -> bytes:
- data = await self._call_sslobject_method(self._ssl_object.read, max_bytes)
- if not data:
- raise EndOfStream
-
- return data
-
- async def send(self, item: bytes) -> None:
- await self._call_sslobject_method(self._ssl_object.write, item)
-
- async def send_eof(self) -> None:
- tls_version = self.extra(TLSAttribute.tls_version)
- match = re.match(r"TLSv(\d+)(?:\.(\d+))?", tls_version)
- if match:
- major, minor = int(match.group(1)), int(match.group(2) or 0)
- if (major, minor) < (1, 3):
- raise NotImplementedError(
- f"send_eof() requires at least TLSv1.3; current "
- f"session uses {tls_version}"
- )
-
- raise NotImplementedError(
- "send_eof() has not yet been implemented for TLS streams"
- )
-
- @property
- def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
- return {
- **self.transport_stream.extra_attributes,
- TLSAttribute.alpn_protocol: self._ssl_object.selected_alpn_protocol,
- TLSAttribute.channel_binding_tls_unique: (
- self._ssl_object.get_channel_binding
- ),
- TLSAttribute.cipher: self._ssl_object.cipher,
- TLSAttribute.peer_certificate: lambda: self._ssl_object.getpeercert(False),
- TLSAttribute.peer_certificate_binary: lambda: self._ssl_object.getpeercert(
- True
- ),
- TLSAttribute.server_side: lambda: self._ssl_object.server_side,
- TLSAttribute.shared_ciphers: lambda: self._ssl_object.shared_ciphers()
- if self._ssl_object.server_side
- else None,
- TLSAttribute.standard_compatible: lambda: self.standard_compatible,
- TLSAttribute.ssl_object: lambda: self._ssl_object,
- TLSAttribute.tls_version: self._ssl_object.version,
- }
-
-
-@dataclass(eq=False)
-class TLSListener(Listener[TLSStream]):
- """
- A convenience listener that wraps another listener and auto-negotiates a TLS session
- on every accepted connection.
-
- If the TLS handshake times out or raises an exception,
- :meth:`handle_handshake_error` is called to do whatever post-mortem processing is
- deemed necessary.
-
- Supports only the :attr:`~TLSAttribute.standard_compatible` extra attribute.
-
- :param Listener listener: the listener to wrap
- :param ssl_context: the SSL context object
- :param standard_compatible: a flag passed through to :meth:`TLSStream.wrap`
- :param handshake_timeout: time limit for the TLS handshake
- (passed to :func:`~anyio.fail_after`)
- """
-
- listener: Listener[Any]
- ssl_context: ssl.SSLContext
- standard_compatible: bool = True
- handshake_timeout: float = 30
-
- @staticmethod
- async def handle_handshake_error(exc: BaseException, stream: AnyByteStream) -> None:
- """
- Handle an exception raised during the TLS handshake.
-
- This method does 3 things:
-
- #. Forcefully closes the original stream
- #. Logs the exception (unless it was a cancellation exception) using the
- ``anyio.streams.tls`` logger
- #. Reraises the exception if it was a base exception or a cancellation exception
-
- :param exc: the exception
- :param stream: the original stream
-
- """
- await aclose_forcefully(stream)
-
- # Log all except cancellation exceptions
- if not isinstance(exc, get_cancelled_exc_class()):
- # CPython (as of 3.11.5) returns incorrect `sys.exc_info()` here when using
- # any asyncio implementation, so we explicitly pass the exception to log
- # (https://github.com/python/cpython/issues/108668). Trio does not have this
- # issue because it works around the CPython bug.
- logging.getLogger(__name__).exception(
- "Error during TLS handshake", exc_info=exc
- )
-
- # Only reraise base exceptions and cancellation exceptions
- if not isinstance(exc, Exception) or isinstance(exc, get_cancelled_exc_class()):
- raise
-
- async def serve(
- self,
- handler: Callable[[TLSStream], Any],
- task_group: TaskGroup | None = None,
- ) -> None:
- @wraps(handler)
- async def handler_wrapper(stream: AnyByteStream) -> None:
- from .. import fail_after
-
- try:
- with fail_after(self.handshake_timeout):
- wrapped_stream = await TLSStream.wrap(
- stream,
- ssl_context=self.ssl_context,
- standard_compatible=self.standard_compatible,
- )
- except BaseException as exc:
- await self.handle_handshake_error(exc, stream)
- else:
- await handler(wrapped_stream)
-
- await self.listener.serve(handler_wrapper, task_group)
-
- async def aclose(self) -> None:
- await self.listener.aclose()
-
- @property
- def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
- return {
- TLSAttribute.standard_compatible: lambda: self.standard_compatible,
- }