diff --git a/src/sshwire.rs b/src/sshwire.rs index 74ae14b..5bfd0e2 100644 --- a/src/sshwire.rs +++ b/src/sshwire.rs @@ -33,7 +33,7 @@ pub trait SSHSink { /// A generic source for a packet, used similarly to `serde::Deserializer` pub trait SSHSource<'de> { fn take(&mut self, len: usize) -> WireResult<&'de [u8]>; - fn pos(&self) -> usize; + fn remaining(&self) -> usize; fn ctx(&mut self) -> &mut ParseContext; } @@ -184,7 +184,6 @@ struct EncodeBytes<'a> { } impl<'a> SSHSink for EncodeBytes<'a> { - #[inline] fn push(&mut self, v: &[u8]) -> WireResult<()> { if v.len() > self.target.len() { return Err(WireError::NoRoom); @@ -235,8 +234,8 @@ impl<'de> SSHSource<'de> for DecodeBytes<'de> { Ok(t) } - fn pos(&self) -> usize { - usize::MAX - self.input.len() + fn remaining(&self) -> usize { + self.input.len() } fn ctx(&mut self) -> &mut ParseContext { @@ -414,23 +413,22 @@ impl<'de, B: SSHDecode<'de>> SSHDecode<'de> for Blob { fn dec(s: &mut S) -> WireResult where S: sshwire::SSHSource<'de> { let len = u32::dec(s)? as usize; - let pos1 = s.pos(); + let rem1 = s.remaining(); let inner = SSHDecode::dec(s)?; - let pos2 = s.pos(); + let rem2 = s.remaining(); // Sanity check the length matched - let used_len = pos2 - pos1; + let used_len = rem1 - rem2; if used_len != len { - let extra = len.checked_sub(used_len).ok_or(WireError::SSHProto)?; - if s.ctx().seen_unknown { // Skip over unconsumed bytes in the blob. // This can occur with Unknown variants + let extra = len.checked_sub(used_len).ok_or(WireError::SSHProto)?; s.take(extra)?; } else { trace!("SSH blob length differs. \ - Expected {} bytes, got {} bytes {}..{}", - len, used_len, pos1, pos2); + Expected {} bytes, got {} remaining {}, {}", + len, used_len, rem1, rem2); return Err(WireError::SSHProto) } }