Skip to content

Commit

Permalink
Allow additional flexibility in specifying environment values
Browse files Browse the repository at this point in the history
This commit allows the 'env' option to accept sequences of either
'key=value' strings or (key, value) tuples, in addition to accepting
a dict. The 'name=value' syntax was previously required to be sent
as a list.

As part of a previous commit, keys and values can now be either
Unicode strings or byte strings, where Unicode strings will be
encoded as UTF-8 in channel requests.
  • Loading branch information
ronf committed Sep 17, 2024
1 parent 2ea5a4e commit b244bb3
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 21 deletions.
6 changes: 3 additions & 3 deletions asyncssh/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@

from .logging import SSHLogger

from .misc import ChannelOpenError, EnvIter, MaybeAwait, ProtocolError
from .misc import ChannelOpenError, EnvMap, MaybeAwait, ProtocolError
from .misc import TermModes, TermSize, TermSizeArg
from .misc import decode_env, encode_env, get_symbol_names, map_handler_name

Expand Down Expand Up @@ -1497,8 +1497,8 @@ def __init__(self, conn: 'SSHServerConnection',

super().__init__(conn, loop, encoding, errors, window, max_pktsize)

env_option = cast(EnvIter, conn.get_key_option('environment', {}))
self._env = dict(encode_env(env_option))
env_opt = cast(EnvMap, conn.get_key_option('environment', {}))
self._env = dict(encode_env(env_opt))

self._allow_pty = allow_pty
self._line_editor = line_editor
Expand Down
12 changes: 6 additions & 6 deletions asyncssh/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@

from .mac import get_mac_algs, get_default_mac_algs

from .misc import BytesOrStr, BytesOrStrDict, DefTuple, Env, EnvList, FilePath
from .misc import BytesOrStr, BytesOrStrDict, DefTuple, Env, EnvSeq, FilePath
from .misc import HostPort, IPNetwork, MaybeAwait, OptExcInfo, Options, SockAddr
from .misc import ChannelListenError, ChannelOpenError, CompressionError
from .misc import DisconnectError, ConnectionLost, HostKeyNotVerifiable
Expand Down Expand Up @@ -4071,7 +4071,7 @@ async def create_session(self, session_factory: SSHClientSessionFactory,
command: DefTuple[Optional[str]] = (), *,
subsystem: DefTuple[Optional[str]]= (),
env: DefTuple[Env] = (),
send_env: DefTuple[Optional[EnvList]] = (),
send_env: DefTuple[Optional[EnvSeq]] = (),
request_pty: DefTuple[Union[bool, str]] = (),
term_type: DefTuple[Optional[str]] = (),
term_size: DefTuple[TermSizeArg] = (),
Expand Down Expand Up @@ -5590,7 +5590,7 @@ def session_factory() -> SSHTunTapSession:

@async_context_manager
async def start_sftp_client(self, env: DefTuple[Env] = (),
send_env: DefTuple[Optional[EnvList]] = (),
send_env: DefTuple[Optional[EnvSeq]] = (),
path_encoding: Optional[str] = 'utf-8',
path_errors = 'strict',
sftp_version = MIN_SFTP_VERSION) -> SFTPClient:
Expand Down Expand Up @@ -7770,7 +7770,7 @@ class SSHClientConnectionOptions(SSHConnectionOptions):
command: Optional[str]
subsystem: Optional[str]
env: Env
send_env: Optional[EnvList]
send_env: Optional[EnvSeq]
request_pty: _RequestPTY
term_type: Optional[str]
term_size: TermSizeArg
Expand Down Expand Up @@ -7837,7 +7837,7 @@ def prepare(self, # type: ignore
pkcs11_pin: Optional[str] = None,
command: DefTuple[Optional[str]] = (),
subsystem: Optional[str] = None, env: DefTuple[Env] = (),
send_env: DefTuple[Optional[EnvList]] = (),
send_env: DefTuple[Optional[EnvSeq]] = (),
request_pty: DefTuple[_RequestPTY] = (),
term_type: Optional[str] = None,
term_size: TermSizeArg = None,
Expand Down Expand Up @@ -8058,7 +8058,7 @@ def prepare(self, # type: ignore

self.env = cast(Env, env if env != () else config.get('SetEnv'))

self.send_env = cast(Optional[EnvList], send_env if send_env != () else
self.send_env = cast(Optional[EnvSeq], send_env if send_env != () else
config.get('SendEnv'))

self.request_pty = cast(_RequestPTY, request_pty if request_pty != ()
Expand Down
26 changes: 14 additions & 12 deletions asyncssh/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,10 @@ async def wait_closed(self) -> None:
IPNetwork = Union[ipaddress.IPv4Network, ipaddress.IPv6Network]
SockAddr = Union[Tuple[str, int], Tuple[str, int, int, int]]

EnvDict = Mapping[BytesOrStr, BytesOrStr]
EnvIter = Iterator[Tuple[BytesOrStr, BytesOrStr]]
EnvList = Sequence[BytesOrStr]
Env = Optional[Union[EnvDict, EnvIter, EnvList]]
EnvMap = Mapping[BytesOrStr, BytesOrStr]
EnvItems = Sequence[Tuple[BytesOrStr, BytesOrStr]]
EnvSeq = Sequence[BytesOrStr]
Env = Optional[Union[EnvMap, EnvItems, EnvSeq]]

# Define a version of randrange which is based on SystemRandom(), so that
# we get back numbers suitable for cryptographic use.
Expand All @@ -130,29 +130,31 @@ async def wait_closed(self) -> None:
def encode_env(env: Env) -> Iterator[Tuple[bytes, bytes]]:
"""Convert environemnt dict or list to bytes-based dictionary"""

env = cast(Sequence[Tuple[BytesOrStr, BytesOrStr]],
env.items() if isinstance(env, dict) else env)

try:
if isinstance(env, list):
for item in env:
for item in env:
if isinstance(item, (bytes, str)):
if isinstance(item, str):
item = item.encode('utf-8')

yield item.split(b'=', 1)
else:
env = cast(EnvIter, env.items() if isinstance(env, dict) else env)
key_bytes, value_bytes = item.split(b'=', 1)
else:
key, value = item

for key, value in env:
key_bytes = key.encode('utf-8') \
if isinstance(key, str) else key

value_bytes = value.encode('utf-8') \
if isinstance(value, str) else value

yield key_bytes, value_bytes
yield key_bytes, value_bytes
except (TypeError, ValueError) as exc:
raise ValueError('Invalid environment value: %s' % exc) from None


def lookup_env(patterns: EnvList) -> Iterator[Tuple[bytes, bytes]]:
def lookup_env(patterns: EnvSeq) -> Iterator[Tuple[bytes, bytes]]:
"""Look up environemnt variables with wildcard matches"""

for pattern in patterns:
Expand Down
13 changes: 13 additions & 0 deletions tests/test_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1291,6 +1291,19 @@ async def test_env_list_binary(self):
result = ''.join(session.recv_buf[None])
self.assertEqual(result, 'test\\xff\n')

@asynctest
async def test_env_tuple(self):
"""Test setting environment using a tuple of name=value strings"""

async with self.connect() as conn:
chan, session = await _create_session(conn, 'env',
env=('TEST=test',))

await chan.wait_closed()

result = ''.join(session.recv_buf[None])
self.assertEqual(result, 'test\n')

@asynctest
async def test_invalid_env_list(self):
"""Test setting environment using an invalid string"""
Expand Down

0 comments on commit b244bb3

Please sign in to comment.