From 05acaa3a02a49c49c5646729a4de3bb2fdf480d0 Mon Sep 17 00:00:00 2001 From: Matt Johnston Date: Thu, 3 Aug 2023 00:13:43 +0800 Subject: [PATCH] Avoid fixed buffer for agent Add sshwire::ssh_write_vec() helper --- async/src/agent.rs | 22 +++++++++++----------- src/sshwire.rs | 9 +++++++++ 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/async/src/agent.rs b/async/src/agent.rs index 74cbea3..de3380b 100644 --- a/async/src/agent.rs +++ b/async/src/agent.rs @@ -20,8 +20,8 @@ use sshwire::{SSHEncodeEnum, SSHDecodeEnum}; use sunset::sign::{OwnedSig, SignKey}; use sunset::sshnames::*; -/* Must be sufficient for the list of all public keys */ -const BUF_SIZE: usize = 10240; +// Must be sufficient for the list of all public keys +const MAX_RESPONSE: usize = 200_000; #[derive(Debug, SSHEncode)] struct AgentSignRequest<'a> { @@ -115,17 +115,16 @@ impl AgentClient { let conn = UnixStream::connect(path).await?; Ok(Self { conn, - buf: vec![0u8; BUF_SIZE], + buf: vec![], }) } async fn request(&mut self, r: AgentRequest<'_>) -> Result { - let l = sshwire::write_ssh(&mut self.buf, &Blob(r))?; - let b = &self.buf[..l]; + let b = sshwire::write_ssh_vec(&Blob(r))?; trace!("agent request {:?}", b.hex_dump()); - self.conn.write_all(b).await?; + self.conn.write_all(&b).await?; self.response().await } @@ -133,12 +132,13 @@ impl AgentClient { let mut l = [0u8; 4]; self.conn.read_exact(&mut l).await?; let l = u32::from_be_bytes(l) as usize; - if l > BUF_SIZE { - return Err(Error::msg("Bad buffer size")); + if l > MAX_RESPONSE { + error!("Response is {l} bytes long"); + return Err(Error::msg("Too large response")); } - let b = &mut self.buf[..l]; - self.conn.read_exact(b).await?; - let r: AgentResponse = sshwire::read_ssh(b, None)?; + self.buf.resize(l, 0); + self.conn.read_exact(&mut self.buf).await?; + let r: AgentResponse = sshwire::read_ssh(&self.buf, None)?; Ok(r) } diff --git a/src/sshwire.rs b/src/sshwire.rs index e195123..1f5498a 100644 --- a/src/sshwire.rs +++ b/src/sshwire.rs @@ -140,6 +140,15 @@ pub fn write_ssh(target: &mut [u8], value: &dyn SSHEncode) -> Result Ok(s.pos) } +#[cfg(feature = "std")] +pub fn write_ssh_vec(value: &dyn SSHEncode) -> Result> { + let l = length_enc(value)? as usize; + let mut v = vec![0u8; l]; + let l = write_ssh(&mut v, value)?; + debug_assert_eq!(l, v.len()); + Ok(v) +} + /// Hashes the SSH wire format representation of `value`, with a `u32` length prefix. pub fn hash_ser_length(hash_ctx: &mut impl SSHWireDigestUpdate, value: &dyn SSHEncode) -> Result<()>