Skip to content

Commit

Permalink
Avoid fixed buffer for agent
Browse files Browse the repository at this point in the history
Add sshwire::ssh_write_vec() helper
  • Loading branch information
mkj committed Aug 2, 2023
1 parent bd5f6ea commit 05acaa3
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 11 deletions.
22 changes: 11 additions & 11 deletions async/src/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ use sshwire::{SSHEncodeEnum, SSHDecodeEnum};
use sunset::sign::{OwnedSig, SignKey};
use sunset::sshnames::*;

/* Must be sufficient for the list of all public keys */
const BUF_SIZE: usize = 10240;
// Must be sufficient for the list of all public keys
const MAX_RESPONSE: usize = 200_000;

#[derive(Debug, SSHEncode)]
struct AgentSignRequest<'a> {
Expand Down Expand Up @@ -115,30 +115,30 @@ impl AgentClient {
let conn = UnixStream::connect(path).await?;
Ok(Self {
conn,
buf: vec![0u8; BUF_SIZE],
buf: vec![],
})
}

async fn request(&mut self, r: AgentRequest<'_>) -> Result<AgentResponse> {
let l = sshwire::write_ssh(&mut self.buf, &Blob(r))?;
let b = &self.buf[..l];
let b = sshwire::write_ssh_vec(&Blob(r))?;

trace!("agent request {:?}", b.hex_dump());

self.conn.write_all(b).await?;
self.conn.write_all(&b).await?;
self.response().await
}

async fn response(&mut self) -> Result<AgentResponse> {
let mut l = [0u8; 4];
self.conn.read_exact(&mut l).await?;
let l = u32::from_be_bytes(l) as usize;
if l > BUF_SIZE {
return Err(Error::msg("Bad buffer size"));
if l > MAX_RESPONSE {
error!("Response is {l} bytes long");
return Err(Error::msg("Too large response"));
}
let b = &mut self.buf[..l];
self.conn.read_exact(b).await?;
let r: AgentResponse = sshwire::read_ssh(b, None)?;
self.buf.resize(l, 0);
self.conn.read_exact(&mut self.buf).await?;
let r: AgentResponse = sshwire::read_ssh(&self.buf, None)?;
Ok(r)
}

Expand Down
9 changes: 9 additions & 0 deletions src/sshwire.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,15 @@ pub fn write_ssh(target: &mut [u8], value: &dyn SSHEncode) -> Result<usize>
Ok(s.pos)
}

#[cfg(feature = "std")]
pub fn write_ssh_vec(value: &dyn SSHEncode) -> Result<Vec<u8>> {
let l = length_enc(value)? as usize;
let mut v = vec![0u8; l];
let l = write_ssh(&mut v, value)?;
debug_assert_eq!(l, v.len());
Ok(v)
}

/// Hashes the SSH wire format representation of `value`, with a `u32` length prefix.
pub fn hash_ser_length(hash_ctx: &mut impl SSHWireDigestUpdate,
value: &dyn SSHEncode) -> Result<()>
Expand Down

0 comments on commit 05acaa3

Please sign in to comment.