import os
import platform
import socket
import ssl
import typing
import _ssl
from ._ssl_constants import (
_original_SSLContext,
_original_super_SSLContext,
_truststore_SSLContext_dunder_class,
_truststore_SSLContext_super_class,
)
if platform.system() == "Windows":
from ._windows import _configure_context, _verify_peercerts_impl
elif platform.system() == "Darwin":
from ._macos import _configure_context, _verify_peercerts_impl
else:
from ._openssl import _configure_context, _verify_peercerts_impl
if typing.TYPE_CHECKING:
from pip._vendor.typing_extensions import Buffer
# Type aliases used by SSLContext.load_cert_chain: ``_StrOrBytesPath`` is the
# annotated type of its certfile/keyfile path arguments and ``_PasswordType``
# that of its ``password`` argument (a str/bytes value, or a zero-argument
# callable returning one). These are evaluated at import time, so the
# ``X | Y`` union syntax and ``typing.TypeAlias`` require Python 3.10+.
_StrOrBytesPath: typing.TypeAlias = str | bytes | os.PathLike[str] | os.PathLike[bytes]
_PasswordType: typing.TypeAlias = str | bytes | typing.Callable[[], str | bytes]
def inject_into_ssl() -> None:
    """Install this module's ``SSLContext`` in place of :class:`ssl.SSLContext`.

    The stdlib ``ssl`` module is patched first. Then, if the vendored
    ``pip._vendor.urllib3.util.ssl_`` module can be imported, its own
    ``SSLContext`` attribute is replaced too; an ``ImportError`` while
    importing it is ignored and only the stdlib module stays patched.
    """
    replacement = SSLContext
    setattr(ssl, "SSLContext", replacement)
    try:
        import pip._vendor.urllib3.util.ssl_ as vendored_urllib3_ssl
    except ImportError:
        # No vendored urllib3 available: nothing else to patch.
        return
    setattr(vendored_urllib3_ssl, "SSLContext", replacement)
def extract_from_ssl() -> None:
    """Undo :func:`inject_into_ssl` by restoring ``_original_SSLContext``.

    ``ssl.SSLContext`` is reset first; the vendored
    ``pip._vendor.urllib3.util.ssl_`` module's ``SSLContext`` attribute is
    reset as well when that module can be imported (an ``ImportError`` is
    ignored).
    """
    original = _original_SSLContext
    setattr(ssl, "SSLContext", original)
    try:
        import pip._vendor.urllib3.util.ssl_ as vendored_urllib3_ssl
    except ImportError:
        # No vendored urllib3 available: only the stdlib module is restored.
        return
    setattr(vendored_urllib3_ssl, "SSLContext", original)
class SSLContext(_truststore_SSLContext_super_class):
    """SSLContext replacement that adds truststore peer-certificate verification.

    Each instance owns an inner ``_original_SSLContext`` stored in
    ``self._ctx``; configuration methods and properties delegate to it.
    ``wrap_socket`` runs ``_verify_peercerts`` on the peer chain right after
    wrapping, and objects returned by ``wrap_bio`` run it after each
    successful ``do_handshake``. ``cert_store_stats`` and ``get_ca_certs`` are
    not supported and raise ``NotImplementedError``.
    """

    @property
    def __class__(self) -> type:
        # Report ``_truststore_SSLContext_dunder_class`` as this object's class
        # when it is set (truthy), otherwise this class. ``isinstance`` falls
        # back to ``obj.__class__`` when ``type(obj)`` does not match, so this
        # lets instances pass isinstance checks against that class.
        # NOTE(review): presumably it is only set on interpreters where this
        # class does not subclass ``ssl.SSLContext`` — confirm in _ssl_constants.
        return _truststore_SSLContext_dunder_class or SSLContext

    def __init__(self, protocol: int | None = None) -> None:
        """Create the inner context and install the verifying SSLObject class.

        :param protocol: Passed unchanged to ``_original_SSLContext``.
        """
        self._ctx = _original_SSLContext(protocol)

        class TruststoreSSLObject(ssl.SSLObject):
            # wrap_bio() does not verify anything itself: with memory BIOs the
            # handshake happens later, when the caller drives do_handshake(),
            # so verification is hooked in here instead.

            def do_handshake(self) -> None:
                ret = super().do_handshake()
                # Only reached when the handshake itself returned normally.
                _verify_peercerts(self, server_hostname=self.server_hostname)
                return ret

        # ``sslobject_class`` is the type that SSLContext.wrap_bio()
        # instantiates, so every object created by ``self._ctx.wrap_bio``
        # gets the verifying do_handshake above.
        self._ctx.sslobject_class = TruststoreSSLObject

    def wrap_socket(
        self,
        sock: socket.socket,
        server_side: bool = False,
        do_handshake_on_connect: bool = True,
        suppress_ragged_eofs: bool = True,
        server_hostname: str | None = None,
        session: ssl.SSLSession | None = None,
    ) -> ssl.SSLSocket:
        """Wrap ``sock`` with the inner context, then verify the peer chain.

        All arguments are forwarded unchanged to the inner context's
        ``wrap_socket``. If ``_verify_peercerts`` raises, the new socket is
        closed and the exception propagates to the caller.
        """
        # The inner wrap_socket runs while ``_configure_context`` is active on
        # ``self._ctx``. For an already-connected socket with
        # do_handshake_on_connect=True the handshake also happens in this call.
        # NOTE(review): what _configure_context changes is platform-specific
        # (imported from _windows/_macos/_openssl) and not visible here.
        with _configure_context(self._ctx):
            ssl_sock = self._ctx.wrap_socket(
                sock,
                server_side=server_side,
                server_hostname=server_hostname,
                do_handshake_on_connect=do_handshake_on_connect,
                suppress_ragged_eofs=suppress_ragged_eofs,
                session=session,
            )
        # Verification runs after the context manager has exited.
        # NOTE(review): if no handshake has happened yet (unconnected socket or
        # do_handshake_on_connect=False) the peer chain may not be available
        # here — confirm the intended behavior for those callers.
        try:
            _verify_peercerts(ssl_sock, server_hostname=server_hostname)
        except Exception:
            # Don't leak the wrapped socket when verification fails.
            ssl_sock.close()
            raise
        return ssl_sock

    def wrap_bio(
        self,
        incoming: ssl.MemoryBIO,
        outgoing: ssl.MemoryBIO,
        server_side: bool = False,
        server_hostname: str | None = None,
        session: ssl.SSLSession | None = None,
    ) -> ssl.SSLObject:
        """Create an ``ssl.SSLObject`` over the given memory BIOs.

        No verification happens here. The inner context's ``sslobject_class``
        (installed in ``__init__``) makes the returned object verify the peer
        chain after ``do_handshake``.
        """
        with _configure_context(self._ctx):
            ssl_obj = self._ctx.wrap_bio(
                incoming,
                outgoing,
                server_hostname=server_hostname,
                server_side=server_side,
                session=session,
            )
        return ssl_obj

    # The methods below are thin pass-throughs to the inner context.

    def load_verify_locations(
        self,
        cafile: str | bytes | os.PathLike[str] | os.PathLike[bytes] | None = None,
        capath: str | bytes | os.PathLike[str] | os.PathLike[bytes] | None = None,
        cadata: typing.Union[str, "Buffer", None] = None,
    ) -> None:
        return self._ctx.load_verify_locations(
            cafile=cafile, capath=capath, cadata=cadata
        )

    def load_cert_chain(
        self,
        certfile: _StrOrBytesPath,
        keyfile: _StrOrBytesPath | None = None,
        password: _PasswordType | None = None,
    ) -> None:
        return self._ctx.load_cert_chain(
            certfile=certfile, keyfile=keyfile, password=password
        )

    def load_default_certs(
        self, purpose: ssl.Purpose = ssl.Purpose.SERVER_AUTH
    ) -> None:
        return self._ctx.load_default_certs(purpose)

    def set_alpn_protocols(self, alpn_protocols: typing.Iterable[str]) -> None:
        return self._ctx.set_alpn_protocols(alpn_protocols)

    def set_npn_protocols(self, npn_protocols: typing.Iterable[str]) -> None:
        return self._ctx.set_npn_protocols(npn_protocols)

    # The double-underscore parameter name is mangled inside the class body;
    # type checkers treat such parameters as positional-only.
    def set_ciphers(self, __cipherlist: str) -> None:
        return self._ctx.set_ciphers(__cipherlist)

    def get_ciphers(self) -> typing.Any:
        return self._ctx.get_ciphers()

    def session_stats(self) -> dict[str, int]:
        return self._ctx.session_stats()

    def cert_store_stats(self) -> dict[str, int]:
        # Not supported by this context.
        raise NotImplementedError()

    # The overloads only describe the return type for type checkers; the
    # real implementation is the final definition and always raises.
    @typing.overload
    def get_ca_certs(
        self, binary_form: typing.Literal[False] = ...
    ) -> list[typing.Any]:
        ...

    @typing.overload
    def get_ca_certs(self, binary_form: typing.Literal[True] = ...) -> list[bytes]:
        ...

    @typing.overload
    def get_ca_certs(self, binary_form: bool = ...) -> typing.Any:
        ...

    def get_ca_certs(self, binary_form: bool = False) -> list[typing.Any] | list[bytes]:
        # Not supported by this context.
        raise NotImplementedError()

    # Properties below read from / write to the inner context.

    @property
    def check_hostname(self) -> bool:
        return self._ctx.check_hostname

    @check_hostname.setter
    def check_hostname(self, value: bool) -> None:
        self._ctx.check_hostname = value

    @property
    def hostname_checks_common_name(self) -> bool:
        return self._ctx.hostname_checks_common_name

    @hostname_checks_common_name.setter
    def hostname_checks_common_name(self, value: bool) -> None:
        self._ctx.hostname_checks_common_name = value

    @property
    def keylog_filename(self) -> str:
        return self._ctx.keylog_filename

    @keylog_filename.setter
    def keylog_filename(self, value: str) -> None:
        self._ctx.keylog_filename = value

    @property
    def maximum_version(self) -> ssl.TLSVersion:
        return self._ctx.maximum_version

    @maximum_version.setter
    def maximum_version(self, value: ssl.TLSVersion) -> None:
        # Assign through the base-class descriptor looked up on
        # ``_original_super_SSLContext`` rather than
        # ``self._ctx.maximum_version = value``.
        # NOTE(review): presumably this bypasses the Python-level property on
        # ``ssl.SSLContext``, whose setter may resolve the module-global
        # ``ssl.SSLContext`` name that inject_into_ssl() replaces — confirm
        # before changing. The minimum_version, options, verify_flags and
        # verify_mode setters below use the same pattern.
        _original_super_SSLContext.maximum_version.__set__(self._ctx, value)

    @property
    def minimum_version(self) -> ssl.TLSVersion:
        return self._ctx.minimum_version

    @minimum_version.setter
    def minimum_version(self, value: ssl.TLSVersion) -> None:
        # Same base-class descriptor bypass as the maximum_version setter.
        _original_super_SSLContext.minimum_version.__set__(self._ctx, value)

    @property
    def options(self) -> ssl.Options:
        return self._ctx.options

    @options.setter
    def options(self, value: ssl.Options) -> None:
        # Same base-class descriptor bypass as the maximum_version setter.
        _original_super_SSLContext.options.__set__(self._ctx, value)

    @property
    def post_handshake_auth(self) -> bool:
        return self._ctx.post_handshake_auth

    @post_handshake_auth.setter
    def post_handshake_auth(self, value: bool) -> None:
        self._ctx.post_handshake_auth = value

    # Read-only: no setter is defined for protocol or security_level.
    @property
    def protocol(self) -> ssl._SSLMethod:
        return self._ctx.protocol

    @property
    def security_level(self) -> int:
        return self._ctx.security_level

    @property
    def verify_flags(self) -> ssl.VerifyFlags:
        return self._ctx.verify_flags

    @verify_flags.setter
    def verify_flags(self, value: ssl.VerifyFlags) -> None:
        # Same base-class descriptor bypass as the maximum_version setter.
        _original_super_SSLContext.verify_flags.__set__(self._ctx, value)

    @property
    def verify_mode(self) -> ssl.VerifyMode:
        return self._ctx.verify_mode

    @verify_mode.setter
    def verify_mode(self, value: ssl.VerifyMode) -> None:
        # Same base-class descriptor bypass as the maximum_version setter.
        _original_super_SSLContext.verify_mode.__set__(self._ctx, value)
def _verify_peercerts(
sock_or_sslobj: ssl.SSLSocket | ssl.SSLObject, server_hostname: str | None
) -> None:
sslobj: ssl.SSLObject = sock_or_sslobj try:
while not hasattr(sslobj, "get_unverified_chain"):
sslobj = sslobj._sslobj except AttributeError:
pass
unverified_chain: typing.Sequence[_ssl.Certificate] = (
sslobj.get_unverified_chain() or () )
cert_bytes = [cert.public_bytes(_ssl.ENCODING_DER) for cert in unverified_chain]
_verify_peercerts_impl(
sock_or_sslobj.context, cert_bytes, server_hostname=server_hostname
)