from __future__ import annotations

import io
import itertools
import sys
import typing

from .._models import Request, Response
from .._types import SyncByteStream
from .base import BaseTransport

if typing.TYPE_CHECKING:
    from _typeshed import OptExcInfo  # pragma: no cover
    from _typeshed.wsgi import WSGIApplication  # pragma: no cover


_T = typing.TypeVar("_T")


def _skip_leading_empty_chunks(body: typing.Iterable[_T]) -> typing.Iterable[_T]:
    """
    Return `body` with any falsy chunks at the start removed.

    The iterable is consumed up to and including its first truthy chunk, and
    that chunk is then chained back in front of the remaining items. If every
    chunk is falsy (or there are none), an empty list is returned.
    """
    body = iter(body)
    for chunk in body:
        if chunk:
            return itertools.chain([chunk], body)
    return []


class WSGIByteStream(SyncByteStream):
    """
    Byte stream wrapping the iterable returned by a WSGI application.
    """

    def __init__(self, result: typing.Iterable[bytes]) -> None:
        # Grab `close` from the app's own iterable before wrapping it: the
        # wrapper returned by `_skip_leading_empty_chunks` does not expose it.
        self._close = getattr(result, "close", None)
        # This consumes `result` up to its first non-empty chunk.
        self._result = _skip_leading_empty_chunks(result)

    def __iter__(self) -> typing.Iterator[bytes]:
        yield from self._result

    def close(self) -> None:
        """Call the wrapped iterable's `close()` method, if it has one."""
        if self._close is not None:
            self._close()


class WSGITransport(BaseTransport):
    """
    A custom transport that handles sending requests directly to a WSGI app.
    The simplest way to use this functionality is to use the `app` argument.

    ```
    client = httpx.Client(app=app)
    ```

    Alternatively, you can setup the transport instance explicitly.
    This allows you to include any additional configuration arguments specific
    to the WSGITransport class:

    ```
    transport = httpx.WSGITransport(
        app=app,
        script_name="/submount",
        remote_addr="1.2.3.4"
    )
    client = httpx.Client(transport=transport)
    ```

    Arguments:

    * `app` - The WSGI application.
    * `raise_app_exceptions` - Boolean indicating if exceptions in the application
       should be raised. Defaults to `True`. Can be set to `False` for use cases
       such as testing the content of a client 500 response.
    * `script_name` - The root path on which the WSGI application should be mounted.
    * `remote_addr` - A string indicating the client IP of incoming requests.
    * `wsgi_errors` - Text stream passed to the application as `wsgi.errors`.
       `sys.stderr` is used when this is not provided.
    """

    def __init__(
        self,
        app: WSGIApplication,
        raise_app_exceptions: bool = True,
        script_name: str = "",
        remote_addr: str = "127.0.0.1",
        wsgi_errors: typing.TextIO | None = None,
    ) -> None:
        self.app = app
        self.raise_app_exceptions = raise_app_exceptions
        self.script_name = script_name
        self.remote_addr = remote_addr
        self.wsgi_errors = wsgi_errors

    def handle_request(self, request: Request) -> Response:
        """
        Build a WSGI environ from `request`, call the application with it and
        return the result as a `Response`.

        If the application passed `exc_info` to `start_response` and
        `raise_app_exceptions` is true, that exception is re-raised instead of
        returning a response.
        """
        # The request body is read in full and exposed as an in-memory stream.
        request.read()
        wsgi_input = io.BytesIO(request.content)

        port = request.url.port or {"http": 80, "https": 443}[request.url.scheme]
        environ = {
            "wsgi.version": (1, 0),
            "wsgi.url_scheme": request.url.scheme,
            "wsgi.input": wsgi_input,
            "wsgi.errors": self.wsgi_errors or sys.stderr,
            "wsgi.multithread": True,
            "wsgi.multiprocess": False,
            "wsgi.run_once": False,
            "REQUEST_METHOD": request.method,
            "SCRIPT_NAME": self.script_name,
            "PATH_INFO": request.url.path,
            "QUERY_STRING": request.url.query.decode("ascii"),
            "SERVER_NAME": request.url.host,
            "SERVER_PORT": str(port),
            "SERVER_PROTOCOL": "HTTP/1.1",
            "REMOTE_ADDR": self.remote_addr,
        }
        for header_key, header_value in request.headers.raw:
            key = header_key.decode("ascii").upper().replace("-", "_")
            # Content-Type and Content-Length are the only headers that are
            # stored without the "HTTP_" prefix.
            if key not in ("CONTENT_TYPE", "CONTENT_LENGTH"):
                key = "HTTP_" + key
            # WSGI header values are ISO-8859-1 "native strings" (PEP 3333);
            # latin-1 maps every byte, so non-ASCII values no longer raise.
            environ[key] = header_value.decode("latin-1")

        seen_status = None
        seen_response_headers = None
        seen_exc_info = None

        def start_response(
            status: str,
            response_headers: list[tuple[str, str]],
            exc_info: OptExcInfo | None = None,
        ) -> typing.Callable[[bytes], typing.Any]:
            # Record what the application reported; the returned `write`
            # callable discards whatever is passed to it.
            nonlocal seen_status, seen_response_headers, seen_exc_info
            seen_status = status
            seen_response_headers = response_headers
            seen_exc_info = exc_info
            return lambda _: None

        result = self.app(environ, start_response)

        try:
            # Building the stream iterates `result` up to its first non-empty
            # chunk, which is what triggers `start_response` for applications
            # written as generators.
            stream = WSGIByteStream(result)

            assert seen_status is not None
            assert seen_response_headers is not None
            if seen_exc_info and seen_exc_info[0] and self.raise_app_exceptions:
                raise seen_exc_info[1]

            status_code = int(seen_status.split()[0])
            headers = [
                (key.encode("ascii"), value.encode("latin-1"))
                for key, value in seen_response_headers
            ]
        except BaseException:
            # No Response will take ownership of the stream on this path, so
            # release the application's iterable here before propagating.
            close = getattr(result, "close", None)
            if close is not None:
                close()
            raise

        return Response(status_code, headers=headers, stream=stream)