Skip to content

Commit

Permalink
Attempt to prohibit mutating a Context after its in use
Browse files Browse the repository at this point in the history
  • Loading branch information
alex committed Feb 2, 2025
1 parent 9b8c497 commit aa5f618
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 0 deletions.
17 changes: 17 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,23 @@ Changelog
Versions are year-based with a strict backward-compatibility policy.
The third digit is only for regressions.

UNRELEASED
----------

Backward-incompatible changes:
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Deprecations:
^^^^^^^^^^^^^

- Attempting using any methods that mutate an ``OpenSSL.SSL.Context`` after it
has been used to create an ``OpenSSL.SSL.Connection`` will emit a warning. In
a future release, this will raise an exception.

Changes:
^^^^^^^^


25.0.0 (2025-01-12)
-------------------

Expand Down
59 changes: 59 additions & 0 deletions src/OpenSSL/SSL.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,26 @@ class Session:
_session: Any


F = TypeVar("F", bound=Callable[..., Any])


def _require_not_used(f: F) -> F:
@wraps(f)
def inner(self: Context, *args: Any, **kwargs: Any) -> Any:
if self._used:
warnings.warn(
(
"Attempting to mutate a Context after a Connection was "
"created. In the future, this will raise an exception"
),
DeprecationWarning,
stacklevel=2,
)
return f(self, *args, **kwargs)

return typing.cast(F, inner)


class Context:
"""
:class:`OpenSSL.SSL.Context` instances define the parameters for setting
Expand Down Expand Up @@ -870,6 +890,7 @@ def __init__(self, method: int) -> None:
context = _ffi.gc(context, _lib.SSL_CTX_free)

self._context = context
self._used = False
self._passphrase_helper: _PassphraseHelper | None = None
self._passphrase_callback: _PassphraseCallback[Any] | None = None
self._passphrase_userdata: Any | None = None
Expand Down Expand Up @@ -898,6 +919,7 @@ def __init__(self, method: int) -> None:
self.set_min_proto_version(version)
self.set_max_proto_version(version)

@_require_not_used
def set_min_proto_version(self, version: int) -> None:
"""
Set the minimum supported protocol version. Setting the minimum
Expand All @@ -911,6 +933,7 @@ def set_min_proto_version(self, version: int) -> None:
_lib.SSL_CTX_set_min_proto_version(self._context, version) == 1
)

@_require_not_used
def set_max_proto_version(self, version: int) -> None:
"""
Set the maximum supported protocol version. Setting the maximum
Expand All @@ -924,6 +947,7 @@ def set_max_proto_version(self, version: int) -> None:
_lib.SSL_CTX_set_max_proto_version(self._context, version) == 1
)

@_require_not_used
def load_verify_locations(
self,
cafile: _StrOrBytesPath | None,
Expand Down Expand Up @@ -971,6 +995,7 @@ def wrapper(size: int, verify: bool, userdata: Any) -> bytes:
FILETYPE_PEM, wrapper, more_args=True, truncate=True
)

@_require_not_used
def set_passwd_cb(
self,
callback: _PassphraseCallback[_T],
Expand Down Expand Up @@ -1004,6 +1029,7 @@ def set_passwd_cb(
)
self._passphrase_userdata = userdata

@_require_not_used
def set_default_verify_paths(self) -> None:
"""
Specify that the platform provided CA certificates are to be used for
Expand Down Expand Up @@ -1079,6 +1105,7 @@ def _fallback_default_verify_paths(
self.load_verify_locations(None, capath)
break

@_require_not_used
def use_certificate_chain_file(self, certfile: _StrOrBytesPath) -> None:
"""
Load a certificate chain from a file.
Expand All @@ -1096,6 +1123,7 @@ def use_certificate_chain_file(self, certfile: _StrOrBytesPath) -> None:
if not result:
_raise_current_error()

@_require_not_used
def use_certificate_file(
self, certfile: _StrOrBytesPath, filetype: int = FILETYPE_PEM
) -> None:
Expand All @@ -1120,6 +1148,7 @@ def use_certificate_file(
if not use_result:
_raise_current_error()

@_require_not_used
def use_certificate(self, cert: X509 | x509.Certificate) -> None:
"""
Load a certificate from a X509 object
Expand All @@ -1144,6 +1173,7 @@ def use_certificate(self, cert: X509 | x509.Certificate) -> None:
if not use_result:
_raise_current_error()

@_require_not_used
def add_extra_chain_cert(self, certobj: X509 | x509.Certificate) -> None:
"""
Add certificate to chain
Expand Down Expand Up @@ -1176,6 +1206,7 @@ def _raise_passphrase_exception(self) -> None:

_raise_current_error()

@_require_not_used
def use_privatekey_file(
self, keyfile: _StrOrBytesPath, filetype: int = FILETYPE_PEM
) -> None:
Expand All @@ -1200,6 +1231,7 @@ def use_privatekey_file(
if not use_result:
self._raise_passphrase_exception()

@_require_not_used
def use_privatekey(self, pkey: _PrivateKey | PKey) -> None:
"""
Load a private key from a PKey object
Expand Down Expand Up @@ -1234,6 +1266,7 @@ def check_privatekey(self) -> None:
if not _lib.SSL_CTX_check_private_key(self._context):
_raise_current_error()

@_require_not_used
def load_client_ca(self, cafile: bytes) -> None:
"""
Load the trusted certificates that will be sent to the client. Does
Expand All @@ -1249,6 +1282,7 @@ def load_client_ca(self, cafile: bytes) -> None:
_openssl_assert(ca_list != _ffi.NULL)
_lib.SSL_CTX_set_client_CA_list(self._context, ca_list)

@_require_not_used
def set_session_id(self, buf: bytes) -> None:
"""
Set the session id to *buf* within which a session can be reused for
Expand All @@ -1266,6 +1300,7 @@ def set_session_id(self, buf: bytes) -> None:
== 1
)

@_require_not_used
def set_session_cache_mode(self, mode: int) -> int:
"""
Set the behavior of the session cache used by all connections using
Expand Down Expand Up @@ -1293,6 +1328,7 @@ def get_session_cache_mode(self) -> int:
"""
return _lib.SSL_CTX_get_session_cache_mode(self._context)

@_require_not_used
def set_verify(
self, mode: int, callback: _VerifyCallback | None = None
) -> None:
Expand Down Expand Up @@ -1330,6 +1366,7 @@ def set_verify(
self._verify_callback = self._verify_helper.callback
_lib.SSL_CTX_set_verify(self._context, mode, self._verify_callback)

@_require_not_used
def set_verify_depth(self, depth: int) -> None:
"""
Set the maximum depth for the certificate chain verification that shall
Expand Down Expand Up @@ -1361,6 +1398,7 @@ def get_verify_depth(self) -> int:
"""
return _lib.SSL_CTX_get_verify_depth(self._context)

@_require_not_used
def load_tmp_dh(self, dhfile: _StrOrBytesPath) -> None:
"""
Load parameters for Ephemeral Diffie-Hellman
Expand All @@ -1382,6 +1420,7 @@ def load_tmp_dh(self, dhfile: _StrOrBytesPath) -> None:
res = _lib.SSL_CTX_set_tmp_dh(self._context, dh)
_openssl_assert(res == 1)

@_require_not_used
def set_tmp_ecdh(self, curve: _EllipticCurve | ec.EllipticCurve) -> None:
"""
Select a curve to use for ECDHE key exchange.
Expand Down Expand Up @@ -1421,6 +1460,7 @@ def set_tmp_ecdh(self, curve: _EllipticCurve | ec.EllipticCurve) -> None:
ec = _ffi.gc(ec, _lib.EC_KEY_free)
_lib.SSL_CTX_set_tmp_ecdh(self._context, ec)

@_require_not_used
def set_cipher_list(self, cipher_list: bytes) -> None:
"""
Set the list of ciphers to be used in this context.
Expand Down Expand Up @@ -1460,6 +1500,7 @@ def set_cipher_list(self, cipher_list: bytes) -> None:
],
)

@_require_not_used
def set_client_ca_list(
self, certificate_authorities: Sequence[X509Name]
) -> None:
Expand Down Expand Up @@ -1497,6 +1538,7 @@ def set_client_ca_list(

_lib.SSL_CTX_set_client_CA_list(self._context, name_stack)

@_require_not_used
def add_client_ca(
self, certificate_authority: X509 | x509.Certificate
) -> None:
Expand Down Expand Up @@ -1531,6 +1573,7 @@ def add_client_ca(
)
_openssl_assert(add_result == 1)

@_require_not_used
def set_timeout(self, timeout: int) -> None:
"""
Set the timeout for newly created sessions for this Context object to
Expand All @@ -1554,6 +1597,7 @@ def get_timeout(self) -> int:
"""
return _lib.SSL_CTX_get_timeout(self._context)

@_require_not_used
def set_info_callback(
self, callback: Callable[[Connection, int, int], None]
) -> None:
Expand All @@ -1579,6 +1623,7 @@ def wrapper(ssl, where, return_code): # type: ignore[no-untyped-def]
_lib.SSL_CTX_set_info_callback(self._context, self._info_callback)

@_requires_keylog
@_require_not_used
def set_keylog_callback(
self, callback: Callable[[Connection, bytes], None]
) -> None:
Expand Down Expand Up @@ -1613,6 +1658,7 @@ def get_app_data(self) -> Any:
"""
return self._app_data

@_require_not_used
def set_app_data(self, data: Any) -> None:
"""
Set the application data (will be returned from get_app_data())
Expand All @@ -1639,6 +1685,7 @@ def get_cert_store(self) -> X509Store | None:
pystore._store = store
return pystore

@_require_not_used
def set_options(self, options: int) -> int:
"""
Add options. Options set before are not cleared!
Expand All @@ -1652,6 +1699,7 @@ def set_options(self, options: int) -> int:

return _lib.SSL_CTX_set_options(self._context, options)

@_require_not_used
def set_mode(self, mode: int) -> int:
"""
Add modes via bitmask. Modes set before are not cleared! This method
Expand All @@ -1665,6 +1713,7 @@ def set_mode(self, mode: int) -> int:

return _lib.SSL_CTX_set_mode(self._context, mode)

@_require_not_used
def set_tlsext_servername_callback(
self, callback: Callable[[Connection], None]
) -> None:
Expand All @@ -1690,6 +1739,7 @@ def wrapper(ssl, alert, arg): # type: ignore[no-untyped-def]
self._context, self._tlsext_servername_callback
)

@_require_not_used
def set_tlsext_use_srtp(self, profiles: bytes) -> None:
"""
Enable support for negotiating SRTP keying material.
Expand All @@ -1705,6 +1755,7 @@ def set_tlsext_use_srtp(self, profiles: bytes) -> None:
_lib.SSL_CTX_set_tlsext_use_srtp(self._context, profiles) == 0
)

@_require_not_used
def set_alpn_protos(self, protos: list[bytes]) -> None:
"""
Specify the protocols that the client is prepared to speak after the
Expand Down Expand Up @@ -1742,6 +1793,7 @@ def set_alpn_protos(self, protos: list[bytes]) -> None:
== 0
)

@_require_not_used
def set_alpn_select_callback(self, callback: _ALPNSelectCallback) -> None:
"""
Specify a callback function that will be called on the server when a
Expand Down Expand Up @@ -1786,6 +1838,7 @@ def _set_ocsp_callback(
rc = _lib.SSL_CTX_set_tlsext_status_arg(self._context, self._ocsp_data)
_openssl_assert(rc == 1)

@_require_not_used
def set_ocsp_server_callback(
self,
callback: _OCSPServerCallback[_T],
Expand All @@ -1808,6 +1861,7 @@ def set_ocsp_server_callback(
helper = _OCSPServerCallbackHelper(callback)
self._set_ocsp_callback(helper, data)

@_require_not_used
def set_ocsp_client_callback(
self,
callback: _OCSPClientCallback[_T],
Expand All @@ -1832,6 +1886,7 @@ def set_ocsp_client_callback(
helper = _OCSPClientCallbackHelper(callback)
self._set_ocsp_callback(helper, data)

@_require_not_used
def set_cookie_generate_callback(
self, callback: _CookieGenerateCallback
) -> None:
Expand All @@ -1841,6 +1896,7 @@ def set_cookie_generate_callback(
self._cookie_generate_helper.callback,
)

@_require_not_used
def set_cookie_verify_callback(
self, callback: _CookieVerifyCallback
) -> None:
Expand Down Expand Up @@ -1869,6 +1925,8 @@ def __init__(
if not isinstance(context, Context):
raise TypeError("context must be a Context instance")

context._used = True

ssl = _lib.SSL_new(context._context)
self._ssl = _ffi.gc(ssl, _lib.SSL_free)
# We set SSL_MODE_AUTO_RETRY to handle situations where OpenSSL returns
Expand Down Expand Up @@ -2000,6 +2058,7 @@ def set_context(self, context: Context) -> None:

_lib.SSL_set_SSL_CTX(self._ssl, context._context)
self._context = context
self._context._used = True

def get_servername(self) -> bytes | None:
"""
Expand Down

0 comments on commit aa5f618

Please sign in to comment.