From b244bb33af6ccb7e8e15b26fa4d23cea1b4489ac Mon Sep 17 00:00:00 2001 From: Ron Frederick Date: Tue, 17 Sep 2024 06:22:54 -0700 Subject: [PATCH] Allow additional flexibility in specifying environment values This commit allows the 'env' option to accept sequences of either 'key=value' strings or (key, value) tuples, in addition to accepting a dict. The 'name=value' syntax was previously required to be sent as a list. As part of a previous commit, keys and values can now be either Unicode strings or byte strings, where Unicode strings will be encoded as UTF-8 in channel requests. --- asyncssh/channel.py | 6 +++--- asyncssh/connection.py | 12 ++++++------ asyncssh/misc.py | 26 ++++++++++++++------------ tests/test_channel.py | 13 +++++++++++++ 4 files changed, 36 insertions(+), 21 deletions(-) diff --git a/asyncssh/channel.py b/asyncssh/channel.py index fb98862..57b40ed 100644 --- a/asyncssh/channel.py +++ b/asyncssh/channel.py @@ -45,7 +45,7 @@ from .logging import SSHLogger -from .misc import ChannelOpenError, EnvIter, MaybeAwait, ProtocolError +from .misc import ChannelOpenError, EnvMap, MaybeAwait, ProtocolError from .misc import TermModes, TermSize, TermSizeArg from .misc import decode_env, encode_env, get_symbol_names, map_handler_name @@ -1497,8 +1497,8 @@ def __init__(self, conn: 'SSHServerConnection', super().__init__(conn, loop, encoding, errors, window, max_pktsize) - env_option = cast(EnvIter, conn.get_key_option('environment', {})) - self._env = dict(encode_env(env_option)) + env_opt = cast(EnvMap, conn.get_key_option('environment', {})) + self._env = dict(encode_env(env_opt)) self._allow_pty = allow_pty self._line_editor = line_editor diff --git a/asyncssh/connection.py b/asyncssh/connection.py index 4596f93..326e0e0 100644 --- a/asyncssh/connection.py +++ b/asyncssh/connection.py @@ -105,7 +105,7 @@ from .mac import get_mac_algs, get_default_mac_algs -from .misc import BytesOrStr, BytesOrStrDict, DefTuple, Env, EnvList, FilePath +from .misc import BytesOrStr, BytesOrStrDict, DefTuple, Env, EnvSeq, FilePath from .misc import HostPort, IPNetwork, MaybeAwait, OptExcInfo, Options, SockAddr from .misc import ChannelListenError, ChannelOpenError, CompressionError from .misc import DisconnectError, ConnectionLost, HostKeyNotVerifiable @@ -4071,7 +4071,7 @@ async def create_session(self, session_factory: SSHClientSessionFactory, command: DefTuple[Optional[str]] = (), *, subsystem: DefTuple[Optional[str]]= (), env: DefTuple[Env] = (), - send_env: DefTuple[Optional[EnvList]] = (), + send_env: DefTuple[Optional[EnvSeq]] = (), request_pty: DefTuple[Union[bool, str]] = (), term_type: DefTuple[Optional[str]] = (), term_size: DefTuple[TermSizeArg] = (), @@ -5590,7 +5590,7 @@ def session_factory() -> SSHTunTapSession: @async_context_manager async def start_sftp_client(self, env: DefTuple[Env] = (), - send_env: DefTuple[Optional[EnvList]] = (), + send_env: DefTuple[Optional[EnvSeq]] = (), path_encoding: Optional[str] = 'utf-8', path_errors = 'strict', sftp_version = MIN_SFTP_VERSION) -> SFTPClient: @@ -7770,7 +7770,7 @@ class SSHClientConnectionOptions(SSHConnectionOptions): command: Optional[str] subsystem: Optional[str] env: Env - send_env: Optional[EnvList] + send_env: Optional[EnvSeq] request_pty: _RequestPTY term_type: Optional[str] term_size: TermSizeArg @@ -7837,7 +7837,7 @@ def prepare(self, # type: ignore pkcs11_pin: Optional[str] = None, command: DefTuple[Optional[str]] = (), subsystem: Optional[str] = None, env: DefTuple[Env] = (), - send_env: DefTuple[Optional[EnvList]] = (), + send_env: DefTuple[Optional[EnvSeq]] = (), request_pty: DefTuple[_RequestPTY] = (), term_type: Optional[str] = None, term_size: TermSizeArg = None, @@ -8058,7 +8058,7 @@ def prepare(self, # type: ignore self.env = cast(Env, env if env != () else config.get('SetEnv')) - self.send_env = cast(Optional[EnvList], send_env if send_env != () else + self.send_env = cast(Optional[EnvSeq], send_env if send_env != () else config.get('SendEnv')) self.request_pty = cast(_RequestPTY, request_pty if request_pty != () diff --git a/asyncssh/misc.py b/asyncssh/misc.py index 42810a5..9e6edee 100644 --- a/asyncssh/misc.py +++ b/asyncssh/misc.py @@ -111,10 +111,10 @@ async def wait_closed(self) -> None: IPNetwork = Union[ipaddress.IPv4Network, ipaddress.IPv6Network] SockAddr = Union[Tuple[str, int], Tuple[str, int, int, int]] -EnvDict = Mapping[BytesOrStr, BytesOrStr] -EnvIter = Iterator[Tuple[BytesOrStr, BytesOrStr]] -EnvList = Sequence[BytesOrStr] -Env = Optional[Union[EnvDict, EnvIter, EnvList]] +EnvMap = Mapping[BytesOrStr, BytesOrStr] +EnvItems = Sequence[Tuple[BytesOrStr, BytesOrStr]] +EnvSeq = Sequence[BytesOrStr] +Env = Optional[Union[EnvMap, EnvItems, EnvSeq]] # Define a version of randrange which is based on SystemRandom(), so that # we get back numbers suitable for cryptographic use. @@ -130,29 +130,31 @@ async def wait_closed(self) -> None: def encode_env(env: Env) -> Iterator[Tuple[bytes, bytes]]: """Convert environemnt dict or list to bytes-based dictionary""" + env = cast(Sequence[Tuple[BytesOrStr, BytesOrStr]], + env.items() if isinstance(env, dict) else env) + try: - if isinstance(env, list): - for item in env: + for item in env: + if isinstance(item, (bytes, str)): if isinstance(item, str): item = item.encode('utf-8') - yield item.split(b'=', 1) - else: - env = cast(EnvIter, env.items() if isinstance(env, dict) else env) + key_bytes, value_bytes = item.split(b'=', 1) + else: + key, value = item - for key, value in env: key_bytes = key.encode('utf-8') \ if isinstance(key, str) else key value_bytes = value.encode('utf-8') \ if isinstance(value, str) else value - yield key_bytes, value_bytes + yield key_bytes, value_bytes except (TypeError, ValueError) as exc: raise ValueError('Invalid environment value: %s' % exc) from None -def lookup_env(patterns: EnvList) -> Iterator[Tuple[bytes, bytes]]: +def lookup_env(patterns: EnvSeq) -> Iterator[Tuple[bytes, bytes]]: """Look up environemnt variables with wildcard matches""" for pattern in patterns: diff --git a/tests/test_channel.py b/tests/test_channel.py index e2195ff..6850de6 100644 --- a/tests/test_channel.py +++ b/tests/test_channel.py @@ -1291,6 +1291,19 @@ async def test_env_list_binary(self): result = ''.join(session.recv_buf[None]) self.assertEqual(result, 'test\\xff\n') + @asynctest + async def test_env_tuple(self): + """Test setting environment using a tuple of name=value strings""" + + async with self.connect() as conn: + chan, session = await _create_session(conn, 'env', + env=('TEST=test',)) + + await chan.wait_closed() + + result = ''.join(session.recv_buf[None]) + self.assertEqual(result, 'test\n') + @asynctest async def test_invalid_env_list(self): """Test setting environment using an invalid string"""