diff --git a/Modules/urllib3/contrib/__init__.py b/Modules/urllib3/contrib/__init__.py index 8b13789..e69de29 100644 --- a/Modules/urllib3/contrib/__init__.py +++ b/Modules/urllib3/contrib/__init__.py @@ -1 +0,0 @@ - diff --git a/Modules/urllib3/contrib/pyopenssl.py b/Modules/urllib3/contrib/pyopenssl.py new file mode 100644 index 0000000..3987d63 --- /dev/null +++ b/Modules/urllib3/contrib/pyopenssl.py @@ -0,0 +1,548 @@ +""" +Module for using pyOpenSSL as a TLS backend. This module was relevant before +the standard library ``ssl`` module supported SNI, but now that we've dropped +support for Python 2.7 all relevant Python versions support SNI so +**this module is no longer recommended**. + +This needs the following packages installed: + +* `pyOpenSSL`_ (tested with 16.0.0) +* `cryptography`_ (minimum 1.3.4, from pyopenssl) +* `idna`_ (minimum 2.0) + +However, pyOpenSSL depends on cryptography, so while we use all three directly here we +end up having relatively few packages required. + +You can install them with the following command: + +.. code-block:: bash + + $ python -m pip install pyopenssl cryptography idna + +To activate certificate checking, call +:func:`~urllib3.contrib.pyopenssl.inject_into_urllib3` from your Python code +before you begin making HTTP requests. This can be done in a ``sitecustomize`` +module, or at any other time before your application begins using ``urllib3``, +like this: + +.. code-block:: python + + try: + import urllib3.contrib.pyopenssl + urllib3.contrib.pyopenssl.inject_into_urllib3() + except ImportError: + pass + +.. _pyopenssl: https://www.pyopenssl.org +.. _cryptography: https://cryptography.io +.. 
_idna: https://github.com/kjd/idna +""" + +from __future__ import annotations + +import OpenSSL.SSL # type: ignore[import] +from cryptography import x509 + +try: + from cryptography.x509 import UnsupportedExtension # type: ignore[attr-defined] +except ImportError: + # UnsupportedExtension is gone in cryptography >= 2.1.0 + class UnsupportedExtension(Exception): # type: ignore[no-redef] + pass + + +import logging +import ssl +import typing +from io import BytesIO +from socket import socket as socket_cls +from socket import timeout + +from .. import util + +if typing.TYPE_CHECKING: + from OpenSSL.crypto import X509 # type: ignore[import] + + +__all__ = ["inject_into_urllib3", "extract_from_urllib3"] + +# Map from urllib3 to PyOpenSSL compatible parameter-values. +_openssl_versions = { + util.ssl_.PROTOCOL_TLS: OpenSSL.SSL.SSLv23_METHOD, # type: ignore[attr-defined] + util.ssl_.PROTOCOL_TLS_CLIENT: OpenSSL.SSL.SSLv23_METHOD, # type: ignore[attr-defined] + ssl.PROTOCOL_TLSv1: OpenSSL.SSL.TLSv1_METHOD, +} + +if hasattr(ssl, "PROTOCOL_TLSv1_1") and hasattr(OpenSSL.SSL, "TLSv1_1_METHOD"): + _openssl_versions[ssl.PROTOCOL_TLSv1_1] = OpenSSL.SSL.TLSv1_1_METHOD + +if hasattr(ssl, "PROTOCOL_TLSv1_2") and hasattr(OpenSSL.SSL, "TLSv1_2_METHOD"): + _openssl_versions[ssl.PROTOCOL_TLSv1_2] = OpenSSL.SSL.TLSv1_2_METHOD + + +_stdlib_to_openssl_verify = { + ssl.CERT_NONE: OpenSSL.SSL.VERIFY_NONE, + ssl.CERT_OPTIONAL: OpenSSL.SSL.VERIFY_PEER, + ssl.CERT_REQUIRED: OpenSSL.SSL.VERIFY_PEER + + OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT, +} +_openssl_to_stdlib_verify = {v: k for k, v in _stdlib_to_openssl_verify.items()} + +# The SSLvX values are the most likely to be missing in the future +# but we check them all just to be sure. 
+_OP_NO_SSLv2_OR_SSLv3: int = getattr(OpenSSL.SSL, "OP_NO_SSLv2", 0) | getattr( + OpenSSL.SSL, "OP_NO_SSLv3", 0 +) +_OP_NO_TLSv1: int = getattr(OpenSSL.SSL, "OP_NO_TLSv1", 0) +_OP_NO_TLSv1_1: int = getattr(OpenSSL.SSL, "OP_NO_TLSv1_1", 0) +_OP_NO_TLSv1_2: int = getattr(OpenSSL.SSL, "OP_NO_TLSv1_2", 0) +_OP_NO_TLSv1_3: int = getattr(OpenSSL.SSL, "OP_NO_TLSv1_3", 0) + +_openssl_to_ssl_minimum_version: dict[int, int] = { + ssl.TLSVersion.MINIMUM_SUPPORTED: _OP_NO_SSLv2_OR_SSLv3, + ssl.TLSVersion.TLSv1: _OP_NO_SSLv2_OR_SSLv3, + ssl.TLSVersion.TLSv1_1: _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1, + ssl.TLSVersion.TLSv1_2: _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1 | _OP_NO_TLSv1_1, + ssl.TLSVersion.TLSv1_3: ( + _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1 | _OP_NO_TLSv1_1 | _OP_NO_TLSv1_2 + ), + ssl.TLSVersion.MAXIMUM_SUPPORTED: ( + _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1 | _OP_NO_TLSv1_1 | _OP_NO_TLSv1_2 + ), +} +_openssl_to_ssl_maximum_version: dict[int, int] = { + ssl.TLSVersion.MINIMUM_SUPPORTED: ( + _OP_NO_SSLv2_OR_SSLv3 + | _OP_NO_TLSv1 + | _OP_NO_TLSv1_1 + | _OP_NO_TLSv1_2 + | _OP_NO_TLSv1_3 + ), + ssl.TLSVersion.TLSv1: ( + _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1_1 | _OP_NO_TLSv1_2 | _OP_NO_TLSv1_3 + ), + ssl.TLSVersion.TLSv1_1: _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1_2 | _OP_NO_TLSv1_3, + ssl.TLSVersion.TLSv1_2: _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1_3, + ssl.TLSVersion.TLSv1_3: _OP_NO_SSLv2_OR_SSLv3, + ssl.TLSVersion.MAXIMUM_SUPPORTED: _OP_NO_SSLv2_OR_SSLv3, +} + +# OpenSSL will only write 16K at a time +SSL_WRITE_BLOCKSIZE = 16384 + +orig_util_SSLContext = util.ssl_.SSLContext + + +log = logging.getLogger(__name__) + + +def inject_into_urllib3() -> None: + "Monkey-patch urllib3 with PyOpenSSL-backed SSL-support." 
+ + _validate_dependencies_met() + + util.SSLContext = PyOpenSSLContext # type: ignore[assignment] + util.ssl_.SSLContext = PyOpenSSLContext # type: ignore[assignment] + util.IS_PYOPENSSL = True + util.ssl_.IS_PYOPENSSL = True + + +def extract_from_urllib3() -> None: + "Undo monkey-patching by :func:`inject_into_urllib3`." + + util.SSLContext = orig_util_SSLContext + util.ssl_.SSLContext = orig_util_SSLContext + util.IS_PYOPENSSL = False + util.ssl_.IS_PYOPENSSL = False + + +def _validate_dependencies_met() -> None: + """ + Verifies that PyOpenSSL's package-level dependencies have been met. + Throws `ImportError` if they are not met. + """ + # Method added in `cryptography==1.1`; not available in older versions + from cryptography.x509.extensions import Extensions + + if getattr(Extensions, "get_extension_for_class", None) is None: + raise ImportError( + "'cryptography' module missing required functionality. " + "Try upgrading to v1.3.4 or newer." + ) + + # pyOpenSSL 0.14 and above use cryptography for OpenSSL bindings. The _x509 + # attribute is only present on those versions. + from OpenSSL.crypto import X509 + + x509 = X509() + if getattr(x509, "_x509", None) is None: + raise ImportError( + "'pyOpenSSL' module missing required functionality. " + "Try upgrading to v0.14 or newer." + ) + + +def _dnsname_to_stdlib(name: str) -> str | None: + """ + Converts a dNSName SubjectAlternativeName field to the form used by the + standard library on the given Python version. + + Cryptography produces a dNSName as a unicode string that was idna-decoded + from ASCII bytes. We need to idna-encode that string to get it back, and + then on Python 3 we also need to convert to unicode via UTF-8 (the stdlib + uses PyUnicode_FromStringAndSize on it, which decodes via UTF-8). + + If the name cannot be idna-encoded then we return None signalling that + the name given should be skipped. 
+ """ + + def idna_encode(name: str) -> bytes | None: + """ + Borrowed wholesale from the Python Cryptography Project. It turns out + that we can't just safely call `idna.encode`: it can explode for + wildcard names. This avoids that problem. + """ + import idna + + try: + for prefix in ["*.", "."]: + if name.startswith(prefix): + name = name[len(prefix) :] + return prefix.encode("ascii") + idna.encode(name) + return idna.encode(name) + except idna.core.IDNAError: + return None + + # Don't send IPv6 addresses through the IDNA encoder. + if ":" in name: + return name + + encoded_name = idna_encode(name) + if encoded_name is None: + return None + return encoded_name.decode("utf-8") + + +def get_subj_alt_name(peer_cert: X509) -> list[tuple[str, str]]: + """ + Given an PyOpenSSL certificate, provides all the subject alternative names. + """ + cert = peer_cert.to_cryptography() + + # We want to find the SAN extension. Ask Cryptography to locate it (it's + # faster than looping in Python) + try: + ext = cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value + except x509.ExtensionNotFound: + # No such extension, return the empty list. + return [] + except ( + x509.DuplicateExtension, + UnsupportedExtension, + x509.UnsupportedGeneralNameType, + UnicodeError, + ) as e: + # A problem has been found with the quality of the certificate. Assume + # no SAN field is present. + log.warning( + "A problem was encountered with the certificate that prevented " + "urllib3 from finding the SubjectAlternativeName field. This can " + "affect certificate validation. The error was %s", + e, + ) + return [] + + # We want to return dNSName and iPAddress fields. We need to cast the IPs + # back to strings because the match_hostname function wants them as + # strings. + # Sadly the DNS names need to be idna encoded and then, on Python 3, UTF-8 + # decoded. 
This is pretty frustrating, but that's what the standard library + # does with certificates, and so we need to attempt to do the same. + # We also want to skip over names which cannot be idna encoded. + names = [ + ("DNS", name) + for name in map(_dnsname_to_stdlib, ext.get_values_for_type(x509.DNSName)) + if name is not None + ] + names.extend( + ("IP Address", str(name)) for name in ext.get_values_for_type(x509.IPAddress) + ) + + return names + + +class WrappedSocket: + """API-compatibility wrapper for Python OpenSSL's Connection-class.""" + + def __init__( + self, + connection: OpenSSL.SSL.Connection, + socket: socket_cls, + suppress_ragged_eofs: bool = True, + ) -> None: + self.connection = connection + self.socket = socket + self.suppress_ragged_eofs = suppress_ragged_eofs + self._io_refs = 0 + self._closed = False + + def fileno(self) -> int: + return self.socket.fileno() + + # Copy-pasted from Python 3.5 source code + def _decref_socketios(self) -> None: + if self._io_refs > 0: + self._io_refs -= 1 + if self._closed: + self.close() + + def recv(self, *args: typing.Any, **kwargs: typing.Any) -> bytes: + try: + data = self.connection.recv(*args, **kwargs) + except OpenSSL.SSL.SysCallError as e: + if self.suppress_ragged_eofs and e.args == (-1, "Unexpected EOF"): + return b"" + else: + raise OSError(e.args[0], str(e)) from e + except OpenSSL.SSL.ZeroReturnError: + if self.connection.get_shutdown() == OpenSSL.SSL.RECEIVED_SHUTDOWN: + return b"" + else: + raise + except OpenSSL.SSL.WantReadError as e: + if not util.wait_for_read(self.socket, self.socket.gettimeout()): + raise timeout("The read operation timed out") from e + else: + return self.recv(*args, **kwargs) + + # TLS 1.3 post-handshake authentication + except OpenSSL.SSL.Error as e: + raise ssl.SSLError(f"read error: {e!r}") from e + else: + return data # type: ignore[no-any-return] + + def recv_into(self, *args: typing.Any, **kwargs: typing.Any) -> int: + try: + return self.connection.recv_into(*args, 
**kwargs) # type: ignore[no-any-return] + except OpenSSL.SSL.SysCallError as e: + if self.suppress_ragged_eofs and e.args == (-1, "Unexpected EOF"): + return 0 + else: + raise OSError(e.args[0], str(e)) from e + except OpenSSL.SSL.ZeroReturnError: + if self.connection.get_shutdown() == OpenSSL.SSL.RECEIVED_SHUTDOWN: + return 0 + else: + raise + except OpenSSL.SSL.WantReadError as e: + if not util.wait_for_read(self.socket, self.socket.gettimeout()): + raise timeout("The read operation timed out") from e + else: + return self.recv_into(*args, **kwargs) + + # TLS 1.3 post-handshake authentication + except OpenSSL.SSL.Error as e: + raise ssl.SSLError(f"read error: {e!r}") from e + + def settimeout(self, timeout: float) -> None: + return self.socket.settimeout(timeout) + + def _send_until_done(self, data: bytes) -> int: + while True: + try: + return self.connection.send(data) # type: ignore[no-any-return] + except OpenSSL.SSL.WantWriteError as e: + if not util.wait_for_write(self.socket, self.socket.gettimeout()): + raise timeout() from e + continue + except OpenSSL.SSL.SysCallError as e: + raise OSError(e.args[0], str(e)) from e + + def sendall(self, data: bytes) -> None: + total_sent = 0 + while total_sent < len(data): + sent = self._send_until_done( + data[total_sent : total_sent + SSL_WRITE_BLOCKSIZE] + ) + total_sent += sent + + def shutdown(self) -> None: + # FIXME rethrow compatible exceptions should we ever use this + self.connection.shutdown() + + def close(self) -> None: + self._closed = True + if self._io_refs <= 0: + self._real_close() + + def _real_close(self) -> None: + try: + return self.connection.close() # type: ignore[no-any-return] + except OpenSSL.SSL.Error: + return + + def getpeercert( + self, binary_form: bool = False + ) -> dict[str, list[typing.Any]] | None: + x509 = self.connection.get_peer_certificate() + + if not x509: + return x509 # type: ignore[no-any-return] + + if binary_form: + return 
OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_ASN1, x509) # type: ignore[no-any-return] + + return { + "subject": ((("commonName", x509.get_subject().CN),),), # type: ignore[dict-item] + "subjectAltName": get_subj_alt_name(x509), + } + + def version(self) -> str: + return self.connection.get_protocol_version_name() # type: ignore[no-any-return] + + +WrappedSocket.makefile = socket_cls.makefile # type: ignore[attr-defined] + + +class PyOpenSSLContext: + """ + I am a wrapper class for the PyOpenSSL ``Context`` object. I am responsible + for translating the interface of the standard library ``SSLContext`` object + to calls into PyOpenSSL. + """ + + def __init__(self, protocol: int) -> None: + self.protocol = _openssl_versions[protocol] + self._ctx = OpenSSL.SSL.Context(self.protocol) + self._options = 0 + self.check_hostname = False + self._minimum_version: int = ssl.TLSVersion.MINIMUM_SUPPORTED + self._maximum_version: int = ssl.TLSVersion.MAXIMUM_SUPPORTED + + @property + def options(self) -> int: + return self._options + + @options.setter + def options(self, value: int) -> None: + self._options = value + self._set_ctx_options() + + @property + def verify_mode(self) -> int: + return _openssl_to_stdlib_verify[self._ctx.get_verify_mode()] + + @verify_mode.setter + def verify_mode(self, value: ssl.VerifyMode) -> None: + self._ctx.set_verify(_stdlib_to_openssl_verify[value], _verify_callback) + + def set_default_verify_paths(self) -> None: + self._ctx.set_default_verify_paths() + + def set_ciphers(self, ciphers: bytes | str) -> None: + if isinstance(ciphers, str): + ciphers = ciphers.encode("utf-8") + self._ctx.set_cipher_list(ciphers) + + def load_verify_locations( + self, + cafile: str | None = None, + capath: str | None = None, + cadata: bytes | None = None, + ) -> None: + if cafile is not None: + cafile = cafile.encode("utf-8") # type: ignore[assignment] + if capath is not None: + capath = capath.encode("utf-8") # type: ignore[assignment] + try: + 
self._ctx.load_verify_locations(cafile, capath) + if cadata is not None: + self._ctx.load_verify_locations(BytesIO(cadata)) + except OpenSSL.SSL.Error as e: + raise ssl.SSLError(f"unable to load trusted certificates: {e!r}") from e + + def load_cert_chain( + self, + certfile: str, + keyfile: str | None = None, + password: str | None = None, + ) -> None: + try: + self._ctx.use_certificate_chain_file(certfile) + if password is not None: + if not isinstance(password, bytes): + password = password.encode("utf-8") # type: ignore[assignment] + self._ctx.set_passwd_cb(lambda *_: password) + self._ctx.use_privatekey_file(keyfile or certfile) + except OpenSSL.SSL.Error as e: + raise ssl.SSLError(f"Unable to load certificate chain: {e!r}") from e + + def set_alpn_protocols(self, protocols: list[bytes | str]) -> None: + protocols = [util.util.to_bytes(p, "ascii") for p in protocols] + return self._ctx.set_alpn_protos(protocols) # type: ignore[no-any-return] + + def wrap_socket( + self, + sock: socket_cls, + server_side: bool = False, + do_handshake_on_connect: bool = True, + suppress_ragged_eofs: bool = True, + server_hostname: bytes | str | None = None, + ) -> WrappedSocket: + cnx = OpenSSL.SSL.Connection(self._ctx, sock) + + # If server_hostname is an IP, don't use it for SNI, per RFC6066 Section 3 + if server_hostname and not util.ssl_.is_ipaddress(server_hostname): + if isinstance(server_hostname, str): + server_hostname = server_hostname.encode("utf-8") + cnx.set_tlsext_host_name(server_hostname) + + cnx.set_connect_state() + + while True: + try: + cnx.do_handshake() + except OpenSSL.SSL.WantReadError as e: + if not util.wait_for_read(sock, sock.gettimeout()): + raise timeout("select timed out") from e + continue + except OpenSSL.SSL.Error as e: + raise ssl.SSLError(f"bad handshake: {e!r}") from e + break + + return WrappedSocket(cnx, sock) + + def _set_ctx_options(self) -> None: + self._ctx.set_options( + self._options + | 
_openssl_to_ssl_minimum_version[self._minimum_version] + | _openssl_to_ssl_maximum_version[self._maximum_version] + ) + + @property + def minimum_version(self) -> int: + return self._minimum_version + + @minimum_version.setter + def minimum_version(self, minimum_version: int) -> None: + self._minimum_version = minimum_version + self._set_ctx_options() + + @property + def maximum_version(self) -> int: + return self._maximum_version + + @maximum_version.setter + def maximum_version(self, maximum_version: int) -> None: + self._maximum_version = maximum_version + self._set_ctx_options() + + +def _verify_callback( + cnx: OpenSSL.SSL.Connection, + x509: X509, + err_no: int, + err_depth: int, + return_code: int, +) -> bool: + return err_no == 0 diff --git a/Modules/urllib3/contrib/socks.py b/Modules/urllib3/contrib/socks.py new file mode 100644 index 0000000..6c3bb76 --- /dev/null +++ b/Modules/urllib3/contrib/socks.py @@ -0,0 +1,230 @@ +""" +This module contains provisional support for SOCKS proxies from within +urllib3. This module supports SOCKS4, SOCKS4A (an extension of SOCKS4), and +SOCKS5. To enable its functionality, either install PySocks or install this +module with the ``socks`` extra. + +The SOCKS implementation supports the full range of urllib3 features. It also +supports the following SOCKS features: + +- SOCKS4A (``proxy_url='socks4a://...``) +- SOCKS4 (``proxy_url='socks4://...``) +- SOCKS5 with remote DNS (``proxy_url='socks5h://...``) +- SOCKS5 with local DNS (``proxy_url='socks5://...``) +- Usernames and passwords for the SOCKS proxy + +.. note:: + It is recommended to use ``socks5h://`` or ``socks4a://`` schemes in + your ``proxy_url`` to ensure that DNS resolution is done from the remote + server instead of client-side when connecting to a domain name. + +SOCKS4 supports IPv4 and domain names with the SOCKS4A extension. SOCKS5 +supports IPv4, IPv6, and domain names. 
+ +When connecting to a SOCKS4 proxy the ``username`` portion of the ``proxy_url`` +will be sent as the ``userid`` section of the SOCKS request: + +.. code-block:: python + + proxy_url="socks4a://@proxy-host" + +When connecting to a SOCKS5 proxy the ``username`` and ``password`` portion +of the ``proxy_url`` will be sent as the username/password to authenticate +with the proxy: + +.. code-block:: python + + proxy_url="socks5h://:@proxy-host" + +""" + +from __future__ import annotations + +try: + import socks # type: ignore[import] +except ImportError: + import warnings + + from ..exceptions import DependencyWarning + + warnings.warn( + ( + "SOCKS support in urllib3 requires the installation of optional " + "dependencies: specifically, PySocks. For more information, see " + "https://urllib3.readthedocs.io/en/latest/contrib.html#socks-proxies" + ), + DependencyWarning, + ) + raise + +import typing +from socket import timeout as SocketTimeout + +from ..connection import HTTPConnection, HTTPSConnection +from ..connectionpool import HTTPConnectionPool, HTTPSConnectionPool +from ..exceptions import ConnectTimeoutError, NewConnectionError +from ..poolmanager import PoolManager +from ..util.url import parse_url + +try: + import ssl +except ImportError: + ssl = None # type: ignore[assignment] + +from typing import TypedDict + + +class _TYPE_SOCKS_OPTIONS(TypedDict): + socks_version: int + proxy_host: str | None + proxy_port: str | None + username: str | None + password: str | None + rdns: bool + + +class SOCKSConnection(HTTPConnection): + """ + A plain-text HTTP connection that connects via a SOCKS proxy. + """ + + def __init__( + self, + _socks_options: _TYPE_SOCKS_OPTIONS, + *args: typing.Any, + **kwargs: typing.Any, + ) -> None: + self._socks_options = _socks_options + super().__init__(*args, **kwargs) + + def _new_conn(self) -> socks.socksocket: + """ + Establish a new connection via the SOCKS proxy. 
+ """ + extra_kw: dict[str, typing.Any] = {} + if self.source_address: + extra_kw["source_address"] = self.source_address + + if self.socket_options: + extra_kw["socket_options"] = self.socket_options + + try: + conn = socks.create_connection( + (self.host, self.port), + proxy_type=self._socks_options["socks_version"], + proxy_addr=self._socks_options["proxy_host"], + proxy_port=self._socks_options["proxy_port"], + proxy_username=self._socks_options["username"], + proxy_password=self._socks_options["password"], + proxy_rdns=self._socks_options["rdns"], + timeout=self.timeout, + **extra_kw, + ) + + except SocketTimeout as e: + raise ConnectTimeoutError( + self, + f"Connection to {self.host} timed out. (connect timeout={self.timeout})", + ) from e + + except socks.ProxyError as e: + # This is fragile as hell, but it seems to be the only way to raise + # useful errors here. + if e.socket_err: + error = e.socket_err + if isinstance(error, SocketTimeout): + raise ConnectTimeoutError( + self, + f"Connection to {self.host} timed out. (connect timeout={self.timeout})", + ) from e + else: + # Adding `from e` messes with coverage somehow, so it's omitted. + # See #2386. + raise NewConnectionError( + self, f"Failed to establish a new connection: {error}" + ) + else: + raise NewConnectionError( + self, f"Failed to establish a new connection: {e}" + ) from e + + except OSError as e: # Defensive: PySocks should catch all these. + raise NewConnectionError( + self, f"Failed to establish a new connection: {e}" + ) from e + + return conn + + +# We don't need to duplicate the Verified/Unverified distinction from +# urllib3/connection.py here because the HTTPSConnection will already have been +# correctly set to either the Verified or Unverified form by that module. This +# means the SOCKSHTTPSConnection will automatically be the correct type. 
class SOCKSHTTPSConnection(SOCKSConnection, HTTPSConnection):
    """
    SOCKS-proxied counterpart of ``HTTPSConnection``: combines
    ``SOCKSConnection`` with ``HTTPSConnection`` and adds nothing of its own.
    """


class SOCKSHTTPConnectionPool(HTTPConnectionPool):
    """An ``HTTPConnectionPool`` whose connection class is ``SOCKSConnection``."""

    ConnectionCls = SOCKSConnection


class SOCKSHTTPSConnectionPool(HTTPSConnectionPool):
    """An ``HTTPSConnectionPool`` whose connection class is ``SOCKSHTTPSConnection``."""

    ConnectionCls = SOCKSHTTPSConnection


class SOCKSProxyManager(PoolManager):
    """
    A version of the urllib3 ProxyManager that routes connections via the
    defined SOCKS proxy.

    :param proxy_url:
        URL of the SOCKS proxy. Its scheme selects the protocol and where DNS
        resolution happens: ``socks5`` / ``socks4`` (``rdns=False``) or
        ``socks5h`` / ``socks4a`` (``rdns=True``).
    :param username:
        Username for the proxy. When both ``username`` and ``password`` are
        ``None`` the credentials are taken from the userinfo part of
        ``proxy_url`` instead.
    :param password:
        Password for the proxy (see ``username``).
    :param num_pools:
        Passed through unchanged to ``PoolManager``.
    :param headers:
        Passed through unchanged to ``PoolManager``.
    :param connection_pool_kw:
        Extra keyword arguments passed to ``PoolManager``; a ``_socks_options``
        entry is added to them by this constructor.
    :raises ValueError:
        If the scheme of ``proxy_url`` is not one of the four SOCKS schemes.
    """

    pool_classes_by_scheme = {
        "http": SOCKSHTTPConnectionPool,
        "https": SOCKSHTTPSConnectionPool,
    }

    def __init__(
        self,
        proxy_url: str,
        username: str | None = None,
        password: str | None = None,
        num_pools: int = 10,
        headers: typing.Mapping[str, str] | None = None,
        **connection_pool_kw: typing.Any,
    ) -> None:
        parsed = parse_url(proxy_url)

        # Explicitly passed credentials always win; the URL's userinfo is only
        # consulted when neither was supplied. Empty userinfo is ignored.
        if username is None and password is None and parsed.auth:
            # Split on the FIRST colon only, so that
            #   * "user"        -> username="user", password=None
            #     (the SOCKS4 ``userid`` case, which has no password), and
            #   * "user:pa:ss"  -> username="user", password="pa:ss"
            # are honoured instead of being silently dropped.
            user_part, colon, password_part = parsed.auth.partition(":")
            username = user_part
            password = password_part if colon else None

        # scheme -> (PySocks proxy type, resolve hostnames on the proxy?)
        scheme_table = {
            "socks5": (socks.PROXY_TYPE_SOCKS5, False),
            "socks5h": (socks.PROXY_TYPE_SOCKS5, True),
            "socks4": (socks.PROXY_TYPE_SOCKS4, False),
            "socks4a": (socks.PROXY_TYPE_SOCKS4, True),
        }
        version_and_rdns = scheme_table.get(parsed.scheme)
        if version_and_rdns is None:
            raise ValueError(f"Unable to determine SOCKS version from {proxy_url}")
        socks_version, rdns = version_and_rdns

        self.proxy_url = proxy_url

        socks_options = {
            "socks_version": socks_version,
            "proxy_host": parsed.host,
            "proxy_port": parsed.port,
            "username": username,
            "password": password,
            "rdns": rdns,
        }
        connection_pool_kw["_socks_options"] = socks_options

        super().__init__(num_pools, headers, **connection_pool_kw)

        # Re-assert the SOCKS pool classes on the instance after the parent
        # constructor has run.
        # NOTE(review): presumably PoolManager.__init__ sets its own
        # ``pool_classes_by_scheme`` on the instance - confirm.
        self.pool_classes_by_scheme = SOCKSProxyManager.pool_classes_by_scheme