From bc42edc5527d2c08482db4337b7e5357f4fe425a Mon Sep 17 00:00:00 2001 From: Kevin Guthrie Date: Fri, 2 Feb 2024 16:17:48 -0500 Subject: [PATCH] Introduce and use read_uninit and write_uninit duplicated from openssl-0.10.61 and tokio-openssl-0.6.4 --- boring/src/ssl/mod.rs | 85 ++++++++++++++++++++++++----------------- tokio-boring/src/lib.rs | 11 ++---- 2 files changed, 54 insertions(+), 42 deletions(-) diff --git a/boring/src/ssl/mod.rs b/boring/src/ssl/mod.rs index 2f27b2ed..ff11cf5c 100644 --- a/boring/src/ssl/mod.rs +++ b/boring/src/ssl/mod.rs @@ -62,7 +62,6 @@ use foreign_types::{ForeignType, ForeignTypeRef, Opaque}; use libc::{c_char, c_int, c_long, c_uchar, c_uint, c_void}; use once_cell::sync::Lazy; use std::any::TypeId; -use std::cmp; use std::collections::HashMap; use std::convert::TryInto; use std::ffi::{CStr, CString}; @@ -70,7 +69,7 @@ use std::fmt; use std::io; use std::io::prelude::*; use std::marker::PhantomData; -use std::mem::{self, ManuallyDrop}; +use std::mem::{self, ManuallyDrop, MaybeUninit}; use std::ops::{Deref, DerefMut}; use std::panic::resume_unwind; use std::path::Path; @@ -2694,16 +2693,6 @@ impl SslRef { unsafe { ffi::SSL_get_rbio(self.as_ptr()) } } - fn read(&mut self, buf: &mut [u8]) -> c_int { - let len = cmp::min(c_int::max_value() as usize, buf.len()) as c_int; - unsafe { ffi::SSL_read(self.as_ptr(), buf.as_ptr() as *mut c_void, len) } - } - - fn write(&mut self, buf: &[u8]) -> c_int { - let len = cmp::min(c_int::max_value() as usize, buf.len()) as c_int; - unsafe { ffi::SSL_write(self.as_ptr(), buf.as_ptr() as *const c_void, len) } - } - #[cfg(feature = "kx-safe-default")] fn set_curves_list(&mut self, curves: &str) -> Result<(), ErrorStack> { let curves = CString::new(curves).unwrap(); @@ -3708,6 +3697,30 @@ impl SslStream { Self::new_base(ssl, stream) } + /// Like `read`, but takes a possibly-uninitialized slice. + /// + /// # Safety + /// + /// No portion of `buf` will be de-initialized by this method. If the method returns `Ok(n)`, + /// then the first `n` bytes of `buf` are guaranteed to be initialized. + pub fn read_uninit(&mut self, buf: &mut [MaybeUninit]) -> io::Result { + loop { + match self.ssl_read_uninit(buf) { + Ok(n) => return Ok(n), + Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => return Ok(0), + Err(ref e) if e.code() == ErrorCode::SYSCALL && e.io_error().is_none() => { + return Ok(0); + } + Err(ref e) if e.code() == ErrorCode::WANT_READ && e.io_error().is_none() => {} + Err(e) => { + return Err(e + .into_io_error() + .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e))); + } + } + } + } + /// Like `read`, but returns an `ssl::Error` rather than an `io::Error`. /// /// It is particularly useful with a nonblocking socket, where the error value will identify if @@ -3717,16 +3730,28 @@ impl SslStream { /// /// [`SSL_read`]: https://www.openssl.org/docs/manmaster/man3/SSL_read.html pub fn ssl_read(&mut self, buf: &mut [u8]) -> Result { - // The interpretation of the return code here is a little odd with a - // zero-length write. OpenSSL will likely correctly report back to us - // that it read zero bytes, but zero is also the sentinel for "error". - // To avoid that confusion short-circuit that logic and return quickly - // if `buf` has a length of zero. + // SAFETY: `ssl_read_uninit` does not de-initialize the buffer. + unsafe { + self.ssl_read_uninit(slice::from_raw_parts_mut( + buf.as_mut_ptr().cast::>(), + buf.len(), + )) + } + } + + /// Like `read_ssl`, but takes a possibly-uninitialized slice. + /// + /// # Safety + /// + /// No portion of `buf` will be de-initialized by this method. If the method returns `Ok(n)`, + /// then the first `n` bytes of `buf` are guaranteed to be initialized. + pub fn ssl_read_uninit(&mut self, buf: &mut [MaybeUninit]) -> Result { if buf.is_empty() { return Ok(0); } - let ret = self.ssl.read(buf); + let len = usize::min(c_int::max_value() as usize, buf.len()) as c_int; + let ret = unsafe { ffi::SSL_read(self.ssl().as_ptr(), buf.as_mut_ptr().cast(), len) }; if ret > 0 { Ok(ret as usize) } else { @@ -3743,12 +3768,12 @@ impl SslStream { /// /// [`SSL_write`]: https://www.openssl.org/docs/manmaster/man3/SSL_write.html pub fn ssl_write(&mut self, buf: &[u8]) -> Result { - // See above for why we short-circuit on zero-length buffers if buf.is_empty() { return Ok(0); } - let ret = self.ssl.write(buf); + let len = usize::min(c_int::max_value() as usize, buf.len()) as c_int; + let ret = unsafe { ffi::SSL_write(self.ssl().as_ptr(), buf.as_ptr().cast(), len) }; if ret > 0 { Ok(ret as usize) } else { @@ -3919,20 +3944,12 @@ impl SslStream { impl Read for SslStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { - loop { - match self.ssl_read(buf) { - Ok(n) => return Ok(n), - Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => return Ok(0), - Err(ref e) if e.code() == ErrorCode::SYSCALL && e.io_error().is_none() => { - return Ok(0); - } - Err(ref e) if e.code() == ErrorCode::WANT_READ && e.io_error().is_none() => {} - Err(e) => { - return Err(e - .into_io_error() - .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e))); - } - } + // SAFETY: `read_uninit` does not de-initialize the buffer + unsafe { + self.read_uninit(slice::from_raw_parts_mut( + buf.as_mut_ptr().cast::>(), + buf.len(), + )) } } } diff --git a/tokio-boring/src/lib.rs b/tokio-boring/src/lib.rs index c8ad4f3b..f1593ed8 100644 --- a/tokio-boring/src/lib.rs +++ b/tokio-boring/src/lib.rs @@ -21,7 +21,7 @@ use boring_sys as ffi; use std::error::Error; use std::fmt; use std::future::Future; -use std::io::{self, Read, Write}; +use std::io::{self, Write}; use std::pin::Pin; use std::task::{Context, Poll}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; @@ -201,13 +201,8 @@ where buf: &mut ReadBuf, ) -> Poll> { self.run_in_context(ctx, |s| { - // This isn't really "proper", but rust-openssl doesn't currently expose a suitable interface even though - // OpenSSL itself doesn't require the buffer to be initialized. So this is good enough for now. - let slice = unsafe { - let buf = buf.unfilled_mut(); - std::slice::from_raw_parts_mut(buf.as_mut_ptr().cast::(), buf.len()) - }; - match cvt(s.read(slice))? { + // SAFETY: read_uninit does not de-initialize the buffer. + match cvt(s.read_uninit(unsafe { buf.unfilled_mut() }))? { Poll::Ready(nread) => { unsafe { buf.assume_init(nread);