Skip to content

Commit

Permalink
Pass terminal size through a server process when redirecting to a TTY
Browse files Browse the repository at this point in the history
This commit will pass the terminal size received in an SSHServerProcess
through when stdin is redirected to a local TTY. It wil also pass
through terminal size changes recived in the SSHServerProcess.
  • Loading branch information
ronf committed May 28, 2024
1 parent 4ecce1e commit 1bbd845
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 13 deletions.
2 changes: 1 addition & 1 deletion asyncssh/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@
from .logging import SSHLogger

from .misc import ChannelOpenError, MaybeAwait, ProtocolError
from .misc import TermModes, TermSize, TermSizeArg
from .misc import get_symbol_names, map_handler_name

from .packet import Boolean, Byte, String, UInt32, SSHPacket, SSHPacketHandler

from .session import TermModes, TermSize, TermSizeArg
from .session import SSHSession, SSHClientSession, SSHServerSession
from .session import SSHTCPSession, SSHUNIXSession, SSHTunTapSession
from .session import SSHSessionFactory, SSHClientSessionFactory
Expand Down
4 changes: 2 additions & 2 deletions asyncssh/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@
from .misc import KeyExchangeFailed, IllegalUserName, MACError
from .misc import PasswordChangeRequired, PermissionDenied, ProtocolError
from .misc import ProtocolNotSupported, ServiceNotAvailable
from .misc import TermModesArg, TermSizeArg
from .misc import async_context_manager, construct_disc_error
from .misc import get_symbol_names, ip_address, map_handler_name
from .misc import parse_byte_count, parse_time_interval
Expand Down Expand Up @@ -146,8 +147,7 @@

from .server import SSHServer

from .session import DataType, TermModesArg, TermSizeArg
from .session import SSHClientSession, SSHServerSession
from .session import DataType, SSHClientSession, SSHServerSession
from .session import SSHTCPSession, SSHUNIXSession, SSHTunTapSession
from .session import SSHClientSessionFactory, SSHTCPSessionFactory
from .session import SSHUNIXSessionFactory, SSHTunTapSessionFactory
Expand Down
25 changes: 25 additions & 0 deletions asyncssh/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import ipaddress
import re
import socket
import sys

from pathlib import Path, PurePath
from random import SystemRandom
Expand All @@ -41,6 +42,16 @@
from .constants import DISC_PROTOCOL_ERROR, DISC_PROTOCOL_VERSION_NOT_SUPPORTED
from .constants import DISC_SERVICE_NOT_AVAILABLE

if sys.platform != 'win32': # pragma: no branch
import fcntl
import struct
import termios

TermModes = Mapping[int, int]
TermModesArg = Optional[TermModes]
TermSize = Tuple[int, int, int, int]
TermSizeArg = Union[None, Tuple[int, int], TermSize]


class _Hash(Protocol):
"""Protocol for hashing data"""
Expand Down Expand Up @@ -331,6 +342,14 @@ async def maybe_wait_closed(writer: '_SupportsWaitClosed') -> None:
pass


def set_terminal_size(tty: IO, width: int, height: int,
pixwidth: int, pixheight: int) -> None:
"""Set the terminal size of a TTY"""

fcntl.ioctl(tty, termios.TIOCSWINSZ,
struct.pack('hhhh', height, width, pixwidth, pixheight))


class Options:
"""Container for configuration options"""

Expand Down Expand Up @@ -764,6 +783,12 @@ def __init__(self, width: int, height: int, pixwidth: int, pixheight: int):
self.pixwidth = pixwidth
self.pixheight = pixheight

@property
def term_size(self) -> TermSize:
"""Return terminal size as a tuple of 4 integers"""

return self.width, self.height, self.pixwidth, self.pixheight


_disc_error_map = {
DISC_PROTOCOL_ERROR: ProtocolError,
Expand Down
20 changes: 16 additions & 4 deletions asyncssh/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,16 @@

from .logging import SSHLogger

from .misc import BytesOrStr, Error, MaybeAwait
from .misc import ProtocolError, Record, open_file
from .misc import BytesOrStr, Error, MaybeAwait, TermModes, TermSize
from .misc import ProtocolError, Record, open_file, set_terminal_size
from .misc import BreakReceived, SignalReceived, TerminalSizeChanged

from .session import DataType, TermModes, TermSize
from .session import DataType

from .stream import SSHReader, SSHWriter, SSHStreamSession
from .stream import SSHClientStreamSession, SSHServerStreamSession
from .stream import SFTPServerFactory


_AnyStrContra = TypeVar('_AnyStrContra', bytes, str, contravariant=True)

_File = Union[IO[bytes], '_AsyncFileProtocol[bytes]']
Expand Down Expand Up @@ -406,13 +405,20 @@ def __init__(self, process: 'SSHProcess[AnyStr]', datatype: DataType,
self._process: 'SSHProcess[AnyStr]' = process
self._datatype = datatype
self._transport: Optional[asyncio.WriteTransport] = None
self._tty: Optional[IO] = None
self._close_event = asyncio.Event()

def connection_made(self, transport: asyncio.BaseTransport) -> None:
"""Handle a newly opened pipe"""

self._transport = cast(asyncio.WriteTransport, transport)

pipe = transport.get_extra_info('pipe')

if isinstance(self._process, SSHServerProcess) and pipe.isatty():
self._tty = pipe
set_terminal_size(pipe, *self._process.term_size)

def connection_lost(self, exc: Optional[Exception]) -> None:
"""Handle closing of the pipe"""

Expand All @@ -434,6 +440,12 @@ def write(self, data: AnyStr) -> None:
assert self._transport is not None
self._transport.write(self.encode(data))

def write_exception(self, exc: Exception) -> None:
"""Write terminal size changes to the pipe if it is a TTY"""

if isinstance(exc, TerminalSizeChanged) and self._tty:
set_terminal_size(self._tty, *exc.term_size)

def write_eof(self) -> None:
"""Write EOF to the pipe"""

Expand Down
7 changes: 1 addition & 6 deletions asyncssh/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"""SSH session handlers"""

from typing import TYPE_CHECKING, Any, AnyStr, Callable, Generic
from typing import Mapping, Optional, Tuple, Union
from typing import Mapping, Optional, Tuple


if TYPE_CHECKING:
Expand All @@ -31,11 +31,6 @@

DataType = Optional[int]

TermModes = Mapping[int, int]
TermModesArg = Optional[TermModes]
TermSize = Tuple[int, int, int, int]
TermSizeArg = Union[None, Tuple[int, int], TermSize]


class SSHSession(Generic[AnyStr]):
"""SSH session handler"""
Expand Down
53 changes: 53 additions & 0 deletions tests/test_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,18 @@
from .server import ServerTestCase
from .util import asynctest, echo

if sys.platform != 'win32': # pragma: no branch
import fcntl
import struct
import termios

try:
import aiofiles
_aiofiles_available = True
except ImportError: # pragma: no cover
_aiofiles_available = False


async def _handle_client(process):
"""Handle a new client request"""

Expand Down Expand Up @@ -100,6 +106,23 @@ async def _handle_client(process):
except asyncssh.TerminalSizeChanged as exc:
process.exit_with_signal('ABRT', False,
'%sx%s' % (exc.width, exc.height))
elif action == 'term_size_tty':
master, slave = os.openpty()
await process.redirect_stdin(master)
process.stdout.write(b'\n')

await process.stdin.readline()
size = fcntl.ioctl(slave, termios.TIOCGWINSZ, 8*b'\0')
height, width, _, _ = struct.unpack('hhhh', size)
process.stdout.write(('%sx%s' % (width, height)).encode())
os.close(slave)
elif action == 'term_size_nontty':
rpipe, wpipe = os.pipe()
await process.redirect_stdin(wpipe)
process.stdout.write(b'\n')

await process.stdin.readline()
os.close(rpipe)
elif action == 'timeout':
process.channel.set_encoding('utf-8')
process.stdout.write('Sleeping')
Expand Down Expand Up @@ -648,6 +671,36 @@ async def test_forward_terminal_size(self):

self.assertEqual(result.exit_signal[2], '80x24')

@unittest.skipIf(sys.platform == 'win32',
'skip fcntl/termios test on Windows')
@asynctest
async def test_forward_terminal_size_tty(self):
"""Test forwarding a terminal size change to a remote tty"""

async with self.connect() as conn:
process = await conn.create_process('term_size_tty',
term_type='ansi')
await process.stdout.readline()
process.change_terminal_size(80, 24)
process.stdin.write_eof()
result = await process.wait()

self.assertEqual(result.stdout, '80x24')

@asynctest
async def test_forward_terminal_size_nontty(self):
"""Test forwarding a terminal size change to a remote non-tty"""

async with self.connect() as conn:
process = await conn.create_process('term_size_nontty',
term_type='ansi')
await process.stdout.readline()
process.change_terminal_size(80, 24)
process.stdin.write_eof()
result = await process.wait()

self.assertEqual(result.stdout, '')

@asynctest
async def test_forward_break(self):
"""Test forwarding a break"""
Expand Down

0 comments on commit 1bbd845

Please sign in to comment.