Skip to content

Commit

Permalink
Add support for passphrase to be an awaitable
Browse files Browse the repository at this point in the history
This commit expands the passphrase argument in SSHClientConnectionOptions
and SSHServerConnectionOptions to accept an awaitable, in addition to the
callable which was previously added. Options evaluation is still done in
a separate executor thread to avoid blocking the asyncio event loop, but
it's now possible to pass in an awaitable which will run as a task in the
event loop to determine the passphrase to use if you need to do "blocking"
async operations to determine this value.
  • Loading branch information
ronf committed Mar 28, 2024
1 parent 5277c51 commit a93224f
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 30 deletions.
38 changes: 23 additions & 15 deletions asyncssh/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6977,7 +6977,7 @@ async def construct(cls, options: Optional['_OptionsSelf'] = None,
loop = asyncio.get_event_loop()

return cast(_OptionsSelf, await loop.run_in_executor(
None, functools.partial(cls, options, **kwargs)))
None, functools.partial(cls, options, loop=loop, **kwargs)))

# pylint: disable=arguments-differ
def prepare(self, config: SSHConfig, # type: ignore
Expand Down Expand Up @@ -7255,10 +7255,12 @@ class SSHClientConnectionOptions(SSHConnectionOptions):
A list of optional certificates which can be paired with the
provided client keys.
:param passphrase: (optional)
The passphrase to use to decrypt client keys when loading them,
if they are encrypted. If this is not specified, only unencrypted
client keys can be loaded. If the keys passed into client_keys
are already loaded, this argument is ignored.
The passphrase to use to decrypt client keys if they are
encrypted, or a `callable` or coroutine which takes a filename
as a parameter and returns the passphrase to use to decrypt
that file. If not specified, only unencrypted client keys can
be loaded. If the keys passed into client_keys are already
loaded, this argument is ignored.
:param ignore_encrypted: (optional)
Whether or not to ignore encrypted keys when no passphrase is
specified. This defaults to `True` when keys are specified via
Expand Down Expand Up @@ -7605,7 +7607,9 @@ class SSHClientConnectionOptions(SSHConnectionOptions):
max_pktsize: int

# pylint: disable=arguments-differ
def prepare(self, last_config: Optional[SSHConfig] = None, # type: ignore
def prepare(self, # type: ignore
loop: Optional[asyncio.AbstractEventLoop] = None,
last_config: Optional[SSHConfig] = None,
config: DefTuple[ConfigPaths] = None, reload: bool = False,
client_factory: Optional[_ClientFactory] = None,
client_version: _VersionArg = (), host: str = '',
Expand Down Expand Up @@ -7761,7 +7765,7 @@ def prepare(self, last_config: Optional[SSHConfig] = None, # type: ignore

self.client_host_keypairs = \
load_keypairs(cast(KeyPairListArg, client_host_keys),
passphrase, client_host_certs)
passphrase, client_host_certs, loop=loop)

self.client_host_keysign = client_host_keysign
self.client_host = client_host
Expand Down Expand Up @@ -7839,7 +7843,8 @@ def prepare(self, last_config: Optional[SSHConfig] = None, # type: ignore
if client_keys:
self.client_keys = \
load_keypairs(cast(KeyPairListArg, client_keys), passphrase,
client_certs, identities_only, ignore_encrypted)
client_certs, identities_only, ignore_encrypted,
loop=loop)
elif client_keys is not None:
self.client_keys = load_default_keypairs(passphrase, client_certs)
else:
Expand Down Expand Up @@ -7914,11 +7919,12 @@ class SSHServerConnectionOptions(SSHConnectionOptions):
A list of optional certificates which can be paired with the
provided server host keys.
:param passphrase: (optional)
The passphrase to use to decrypt server host keys when loading
them, if they are encrypted. If this is not specified, only
unencrypted server host keys can be loaded. If the keys passed
into server_host_keys are already loaded, this argument is
ignored.
The passphrase to use to decrypt server host keys if they are
encrypted, or a `callable` or coroutine which takes a filename
as a parameter and returns the passphrase to use to decrypt
that file. If not specified, only unencrypted server host keys
can be loaded. If the keys passed into server_host_keys are
already loaded, this argument is ignored.
:param known_client_hosts: (optional)
A list of client hosts which should be trusted to perform
host-based client authentication. If this is not specified,
Expand Down Expand Up @@ -8224,7 +8230,9 @@ class SSHServerConnectionOptions(SSHConnectionOptions):
max_pktsize: int

# pylint: disable=arguments-differ
def prepare(self, last_config: Optional[SSHConfig] = None, # type: ignore
def prepare(self, # type: ignore
loop: Optional[asyncio.AbstractEventLoop] = None,
last_config: Optional[SSHConfig] = None,
config: DefTuple[ConfigPaths] = None, reload: bool = False,
accept_addr: str = '', accept_port: int = 0,
username: str = '', client_host: str = '',
Expand Down Expand Up @@ -8320,7 +8328,7 @@ def prepare(self, last_config: Optional[SSHConfig] = None, # type: ignore
config.get('HostCertificate', ()))

server_keys = load_keypairs(server_host_keys, passphrase,
server_host_certs)
server_host_certs, loop=loop)

self.server_host_keys = OrderedDict()

Expand Down
15 changes: 13 additions & 2 deletions asyncssh/public_key.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2013-2023 by Ron Frederick <[email protected]> and others.
# Copyright (c) 2013-2024 by Ron Frederick <[email protected]> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
Expand All @@ -20,7 +20,9 @@

"""SSH asymmetric encryption handlers"""

import asyncio
import binascii
import inspect
import os
import re
import time
Expand Down Expand Up @@ -3472,7 +3474,8 @@ def load_keypairs(
keylist: KeyPairListArg, passphrase: Optional[BytesOrStr] = None,
certlist: CertListArg = (), skip_public: bool = False,
ignore_encrypted: bool = False,
unsafe_skip_rsa_key_validation: Optional[bool] = None) -> \
unsafe_skip_rsa_key_validation: Optional[bool] = None,
loop: Optional[asyncio.AbstractEventLoop] = None) -> \
Sequence[SSHKeyPair]:
"""Load SSH private keys and optional matching certificates
Expand Down Expand Up @@ -3521,6 +3524,10 @@ def load_keypairs(
else:
resolved_passphrase = passphrase

if loop and inspect.isawaitable(resolved_passphrase):
resolved_passphrase = asyncio.run_coroutine_threadsafe(
resolved_passphrase, loop).result()

priv_keys = read_private_key_list(keylist, resolved_passphrase,
unsafe_skip_rsa_key_validation)

Expand Down Expand Up @@ -3559,6 +3566,10 @@ def load_keypairs(
else:
resolved_passphrase = passphrase

if loop and inspect.isawaitable(resolved_passphrase):
resolved_passphrase = asyncio.run_coroutine_threadsafe(
resolved_passphrase, loop).result()

if allow_certs:
key, certs_to_load = read_private_key_and_certs(
key_to_load, resolved_passphrase,
Expand Down
50 changes: 50 additions & 0 deletions tests/test_connection_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -1090,6 +1090,56 @@ async def test_encrypted_client_key(self):
passphrase='passphrase'):
pass

@asynctest
async def test_encrypted_client_key_callable(self):
"""Test public key auth with callable passphrase"""

def _passphrase(filename):
self.assertEqual(filename, 'ckey_encrypted')
return 'passphrase'

async with self.connect(username='ckey', client_keys='ckey_encrypted',
passphrase=_passphrase):
pass

@asynctest
async def test_encrypted_client_key_awaitable(self):
"""Test public key auth with awaitable passphrase"""

async def _passphrase(filename):
self.assertEqual(filename, 'ckey_encrypted')
return 'passphrase'

async with self.connect(username='ckey', client_keys='ckey_encrypted',
passphrase=_passphrase):
pass

@asynctest
async def test_encrypted_client_key_list_callable(self):
"""Test public key auth with callable passphrase"""

def _passphrase(filename):
self.assertEqual(filename, 'ckey_encrypted')
return 'passphrase'

async with self.connect(username='ckey',
client_keys=['ckey_encrypted'],
passphrase=_passphrase):
pass

@asynctest
async def test_encrypted_client_key_list_awaitable(self):
"""Test public key auth with awaitable passphrase"""

async def _passphrase(filename):
self.assertEqual(filename, 'ckey_encrypted')
return 'passphrase'

async with self.connect(username='ckey',
client_keys=['ckey_encrypted'],
passphrase=_passphrase):
pass

@asynctest
async def test_encrypted_client_key_bad_passphrase(self):
"""Test wrong passphrase for encrypted client key"""
Expand Down
13 changes: 0 additions & 13 deletions tests/test_public_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,10 +252,6 @@ def validate_x509(self, cert, user_principal=None):
def check_private(self, format_name, passphrase=None):
"""Check for a private key match"""

def _passphrase(filename):
self.assertEqual(filename, 'new')
return passphrase

newkey = asyncssh.read_private_key('new', passphrase)
algorithm = newkey.get_algorithm()
keydata = newkey.export_private_key()
Expand All @@ -279,9 +275,6 @@ def _passphrase(filename):
keypair = asyncssh.load_keypairs('new', passphrase)[0]
self.assertEqual(keypair.public_data, pubdata)

keypair = asyncssh.load_keypairs('new', _passphrase)[0]
self.assertEqual(keypair.public_data, pubdata)

keypair = asyncssh.load_keypairs([newkey])[0]
self.assertEqual(keypair.public_data, pubdata)

Expand All @@ -297,15 +290,9 @@ def _passphrase(filename):
keypair = asyncssh.load_keypairs(['new'], passphrase)[0]
self.assertEqual(keypair.public_data, pubdata)

keypair = asyncssh.load_keypairs(['new'], _passphrase)[0]
self.assertEqual(keypair.public_data, pubdata)

keypair = asyncssh.load_keypairs([('new', None)], passphrase)[0]
self.assertEqual(keypair.public_data, pubdata)

keypair = asyncssh.load_keypairs([('new', None)], _passphrase)[0]
self.assertEqual(keypair.public_data, pubdata)

keypair = asyncssh.load_keypairs(Path('new'), passphrase)[0]
self.assertEqual(keypair.public_data, pubdata)

Expand Down

0 comments on commit a93224f

Please sign in to comment.