diff --git a/boring/src/ssl/connector.rs b/boring/src/ssl/connector.rs index 7be740d3..ca7a6bee 100644 --- a/boring/src/ssl/connector.rs +++ b/boring/src/ssl/connector.rs @@ -4,11 +4,13 @@ use std::ops::{Deref, DerefMut}; use crate::dh::Dh; use crate::error::ErrorStack; use crate::ssl::{ - HandshakeError, Ssl, SslContext, SslContextBuilder, SslContextRef, SslMethod, SslMode, - SslOptions, SslRef, SslStream, SslVerifyMode, + Ssl, SslContext, SslContextBuilder, SslContextRef, SslMethod, SslMode, SslOptions, SslRef, + SslVerifyMode, }; use crate::version; +use super::MidHandshakeSslStream; + const FFDHE_2048: &str = " -----BEGIN DH PARAMETERS----- MIIBCAKCAQEA//////////+t+FRYortKmq/cViAnPTzx2LnFg84tNpWp4TZBFGQz @@ -98,11 +100,15 @@ impl SslConnector { /// Initiates a client-side TLS session on a stream. /// /// The domain is used for SNI and hostname verification. - pub fn connect(&self, domain: &str, stream: S) -> Result, HandshakeError> + pub fn setup_connect( + &self, + domain: &str, + stream: S, + ) -> Result, ErrorStack> where S: Read + Write, { - self.configure()?.connect(domain, stream) + self.configure()?.setup_connect(domain, stream) } /// Returns a structure allowing for configuration of a single TLS session before connection. @@ -192,7 +198,13 @@ impl ConnectConfiguration { /// Initiates a client-side TLS session on a stream. /// /// The domain is used for SNI and hostname verification if enabled. - pub fn connect(mut self, domain: &str, stream: S) -> Result, HandshakeError> + /// + /// See [`Ssl::setup_connect`] for more details. + pub fn setup_connect( + mut self, + domain: &str, + stream: S, + ) -> Result, ErrorStack> where S: Read + Write, { @@ -210,7 +222,7 @@ impl ConnectConfiguration { setup_verify_hostname(&mut self.ssl, domain)?; } - self.ssl.connect(stream) + Ok(self.ssl.setup_connect(stream)) } } @@ -319,13 +331,16 @@ impl SslAcceptor { Ok(SslAcceptorBuilder(ctx)) } - /// Initiates a server-side TLS session on a stream. - pub fn accept(&self, stream: S) -> Result, HandshakeError> + /// Initiates a server-side TLS handshake on a stream. + /// + /// See [`Ssl::setup_accept`] for more details. + pub fn setup_accept(&self, stream: S) -> Result, ErrorStack> where S: Read + Write, { let ssl = Ssl::new(&self.0)?; - ssl.accept(stream) + + Ok(ssl.setup_accept(stream)) } /// Consumes the `SslAcceptor`, returning the inner raw `SslContext`. diff --git a/boring/src/ssl/error.rs b/boring/src/ssl/error.rs index 5fb91659..6f7af219 100644 --- a/boring/src/ssl/error.rs +++ b/boring/src/ssl/error.rs @@ -1,13 +1,10 @@ use crate::ffi; use libc::c_int; use std::error; -use std::error::Error as StdError; use std::fmt; use std::io; use crate::error::ErrorStack; -use crate::ssl::MidHandshakeSslStream; -use crate::x509::X509VerifyResult; /// An error code returned from SSL functions. #[derive(Debug, Copy, Clone, PartialEq, Eq)] @@ -81,6 +78,10 @@ impl Error { _ => None, } } + + pub fn would_block(&self) -> bool { + matches!(self.code, ErrorCode::WANT_READ | ErrorCode::WANT_WRITE) + } } impl From for Error { @@ -126,65 +127,3 @@ impl error::Error for Error { } } } - -/// An error or intermediate state after a TLS handshake attempt. -// FIXME overhaul -#[derive(Debug)] -pub enum HandshakeError { - /// Setup failed. - SetupFailure(ErrorStack), - /// The handshake failed. - Failure(MidHandshakeSslStream), - /// The handshake encountered a `WouldBlock` error midway through. - /// - /// This error will never be returned for blocking streams. - WouldBlock(MidHandshakeSslStream), -} - -impl StdError for HandshakeError { - fn source(&self) -> Option<&(dyn StdError + 'static)> { - match *self { - HandshakeError::SetupFailure(ref e) => Some(e), - HandshakeError::Failure(ref s) | HandshakeError::WouldBlock(ref s) => Some(s.error()), - } - } -} - -impl fmt::Display for HandshakeError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - HandshakeError::SetupFailure(ref e) => { - write!(f, "TLS stream setup failed {}", e) - } - HandshakeError::Failure(ref s) => fmt_mid_handshake_error(s, f, "TLS handshake failed"), - HandshakeError::WouldBlock(ref s) => { - fmt_mid_handshake_error(s, f, "TLS handshake interrupted") - } - } - } -} - -fn fmt_mid_handshake_error( - s: &MidHandshakeSslStream, - f: &mut fmt::Formatter, - prefix: &str, -) -> fmt::Result { - #[cfg(feature = "rpk")] - if s.ssl().ssl_context().is_rpk() { - write!(f, "{}", prefix)?; - return write!(f, " {}", s.error()); - } - - match s.ssl().verify_result() { - X509VerifyResult::OK => write!(f, "{}", prefix)?, - verify => write!(f, "{}: cert verification failed - {}", prefix, verify)?, - } - - write!(f, " {}", s.error()) -} - -impl From for HandshakeError { - fn from(e: ErrorStack) -> HandshakeError { - HandshakeError::SetupFailure(e) - } -} diff --git a/boring/src/ssl/mod.rs b/boring/src/ssl/mod.rs index 67e16b18..af8230de 100644 --- a/boring/src/ssl/mod.rs +++ b/boring/src/ssl/mod.rs @@ -15,7 +15,11 @@ //! let connector = SslConnector::builder(SslMethod::tls()).unwrap().build(); //! //! let stream = TcpStream::connect("google.com:443").unwrap(); -//! let mut stream = connector.connect("google.com", stream).unwrap(); +//! let mut stream = connector +//! .setup_connect("google.com", stream) +//! .unwrap() +//! .handshake() +//! .unwrap(); //! //! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap(); //! let mut res = vec![]; @@ -49,7 +53,12 @@ //! Ok(stream) => { //! let acceptor = acceptor.clone(); //! thread::spawn(move || { -//! let stream = acceptor.accept(stream).unwrap(); +//! let stream = acceptor +//! .setup_accept(stream) +//! .unwrap() +//! .handshake() +//! .unwrap(); +//! //! handle_client(stream); //! }); //! } @@ -98,7 +107,7 @@ use crate::{cvt, cvt_0i, cvt_n, cvt_p, init}; pub use crate::ssl::connector::{ ConnectConfiguration, SslAcceptor, SslAcceptorBuilder, SslConnector, SslConnectorBuilder, }; -pub use crate::ssl::error::{Error, ErrorCode, HandshakeError}; +pub use crate::ssl::error::{Error, ErrorCode}; mod bio; mod callbacks; @@ -2306,34 +2315,36 @@ impl Ssl { } } - /// Initiates a client-side TLS handshake. + /// Initiates a client-side TLS handshake, returning a [`MidHandshakeSslStream`]. /// - /// This corresponds to [`SSL_connect`]. + /// This method is guaranteed to return without calling any callback defined + /// in the internal [`Ssl`] or [`SslContext`]. /// - /// # Warning + /// See [`SslStreamBuilder::setup_connect`] for more details. /// - /// OpenSSL's default configuration is insecure. It is highly recommended to use - /// `SslConnector` rather than `Ssl` directly, as it manages that configuration. + /// # Warning /// - /// [`SSL_connect`]: https://www.openssl.org/docs/manmaster/man3/SSL_connect.html - pub fn connect(self, stream: S) -> Result, HandshakeError> + /// BoringSSL's default configuration is insecure. It is highly recommended to use + /// [`SslConnector`] rather than [`Ssl`] directly, as it manages that configuration. + pub fn setup_connect(self, stream: S) -> MidHandshakeSslStream where S: Read + Write, { - SslStreamBuilder::new(self, stream).connect() + SslStreamBuilder::new(self, stream).setup_connect() } /// Initiates a server-side TLS handshake. /// - /// This corresponds to [`SSL_accept`]. + /// This method is guaranteed to return without calling any callback defined + /// in the internal [`Ssl`] or [`SslContext`]. /// - /// # Warning + /// See [`SslStreamBuilder::setup_accept`] for more details. /// - /// OpenSSL's default configuration is insecure. It is highly recommended to use - /// `SslAcceptor` rather than `Ssl` directly, as it manages that configuration. + /// # Warning /// - /// [`SSL_accept`]: https://www.openssl.org/docs/manmaster/man3/SSL_accept.html - pub fn accept(self, stream: S) -> Result, HandshakeError> + /// BoringSSL's default configuration is insecure. It is highly recommended to use + /// [`SslAcceptor`] rather than [`Ssl`] directly, as it manages that configuration. + pub fn setup_accept(self, stream: S) -> MidHandshakeSslStream where S: Read + Write, { @@ -2352,7 +2363,7 @@ impl Ssl { } } - SslStreamBuilder::new(self, stream).accept() + SslStreamBuilder::new(self, stream).setup_accept() } } @@ -3156,18 +3167,14 @@ impl MidHandshakeSslStream { /// This corresponds to [`SSL_do_handshake`]. /// /// [`SSL_do_handshake`]: https://www.openssl.org/docs/manmaster/man3/SSL_do_handshake.html - pub fn handshake(mut self) -> Result, HandshakeError> { + pub fn handshake(mut self) -> Result, MidHandshakeSslStream> { let ret = unsafe { ffi::SSL_do_handshake(self.stream.ssl.as_ptr()) }; if ret > 0 { Ok(self.stream) } else { self.error = self.stream.make_error(ret); - match self.error.code() { - ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => { - Err(HandshakeError::WouldBlock(self)) - } - _ => Err(HandshakeError::Failure(self)), - } + + Err(self) } } } @@ -3450,7 +3457,7 @@ where /// This corresponds to [`SSL_set_connect_state`]. /// /// [`SSL_set_connect_state`]: https://www.openssl.org/docs/manmaster/man3/SSL_set_connect_state.html - pub fn set_connect_state(&mut self) { + fn set_connect_state(&mut self) { unsafe { ffi::SSL_set_connect_state(self.inner.ssl.as_ptr()) } } @@ -3459,82 +3466,45 @@ where /// This corresponds to [`SSL_set_accept_state`]. /// /// [`SSL_set_accept_state`]: https://www.openssl.org/docs/manmaster/man3/SSL_set_accept_state.html - pub fn set_accept_state(&mut self) { + fn set_accept_state(&mut self) { unsafe { ffi::SSL_set_accept_state(self.inner.ssl.as_ptr()) } } - /// See `Ssl::connect` - pub fn connect(self) -> Result, HandshakeError> { - let mut stream = self.inner; - let ret = unsafe { ffi::SSL_connect(stream.ssl.as_ptr()) }; - if ret > 0 { - Ok(stream) - } else { - let error = stream.make_error(ret); - match error.code() { - ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => { - Err(HandshakeError::WouldBlock(MidHandshakeSslStream { - stream, - error, - })) - } - _ => Err(HandshakeError::Failure(MidHandshakeSslStream { - stream, - error, - })), - } - } - } + /// Initiates a client-side TLS handshake, returning a [`MidHandshakeSslStream`]. + /// + /// The caller is then free to call [`MidHandshakeSslStream::handshake`] and retry + /// on blocking errors. + pub fn setup_connect(mut self) -> MidHandshakeSslStream { + self.set_connect_state(); - /// See `Ssl::accept` - pub fn accept(self) -> Result, HandshakeError> { - let mut stream = self.inner; - let ret = unsafe { ffi::SSL_accept(stream.ssl.as_ptr()) }; - if ret > 0 { - Ok(stream) - } else { - let error = stream.make_error(ret); - match error.code() { - ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => { - Err(HandshakeError::WouldBlock(MidHandshakeSslStream { - stream, - error, - })) - } - _ => Err(HandshakeError::Failure(MidHandshakeSslStream { - stream, - error, - })), - } + MidHandshakeSslStream { + stream: self.inner, + error: Error { + code: ErrorCode::WANT_WRITE, + cause: Some(InnerError::Io(io::Error::new( + io::ErrorKind::WouldBlock, + "connect handshake has not started yet", + ))), + }, } } - /// Initiates the handshake. - /// - /// This will fail if `set_accept_state` or `set_connect_state` was not called first. - /// - /// This corresponds to [`SSL_do_handshake`]. + /// Initiates a server-side TLS handshake, returning a [`MidHandshakeSslStream`]. /// - /// [`SSL_do_handshake`]: https://www.openssl.org/docs/manmaster/man3/SSL_do_handshake.html - pub fn handshake(self) -> Result, HandshakeError> { - let mut stream = self.inner; - let ret = unsafe { ffi::SSL_do_handshake(stream.ssl.as_ptr()) }; - if ret > 0 { - Ok(stream) - } else { - let error = stream.make_error(ret); - match error.code() { - ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => { - Err(HandshakeError::WouldBlock(MidHandshakeSslStream { - stream, - error, - })) - } - _ => Err(HandshakeError::Failure(MidHandshakeSslStream { - stream, - error, - })), - } + /// The caller is then free to call [`MidHandshakeSslStream::handshake`] and retry + /// on blocking errors. + pub fn setup_accept(mut self) -> MidHandshakeSslStream { + self.set_accept_state(); + + MidHandshakeSslStream { + stream: self.inner, + error: Error { + code: ErrorCode::WANT_READ, + cause: Some(InnerError::Io(io::Error::new( + io::ErrorKind::WouldBlock, + "accept handshake has not started yet", + ))), + }, } } } diff --git a/boring/src/ssl/test/mod.rs b/boring/src/ssl/test/mod.rs index e72dc086..3d76fe06 100644 --- a/boring/src/ssl/test/mod.rs +++ b/boring/src/ssl/test/mod.rs @@ -25,10 +25,9 @@ use crate::ssl; use crate::ssl::test::server::Server; use crate::ssl::SslVersion; use crate::ssl::{ - Error, ExtensionType, HandshakeError, MidHandshakeSslStream, ShutdownResult, ShutdownState, - Ssl, SslAcceptor, SslAcceptorBuilder, SslConnector, SslContext, SslContextBuilder, SslFiletype, - SslMethod, SslOptions, SslSessionCacheMode, SslStream, SslStreamBuilder, SslVerifyMode, - StatusType, + Error, ExtensionType, MidHandshakeSslStream, ShutdownResult, ShutdownState, Ssl, SslAcceptor, + SslAcceptorBuilder, SslConnector, SslContext, SslContextBuilder, SslFiletype, SslMethod, + SslOptions, SslSessionCacheMode, SslStream, SslStreamBuilder, SslVerifyMode, StatusType, }; use crate::x509::store::X509StoreBuilder; use crate::x509::verify::X509CheckFlags; @@ -317,7 +316,7 @@ fn test_connect_with_srtp_ctx() { .unwrap(); let mut ssl = Ssl::new(&ctx.build()).unwrap(); ssl.set_mtu(1500).unwrap(); - let mut stream = ssl.accept(stream).unwrap(); + let mut stream = ssl.setup_accept(stream).handshake().unwrap(); let mut buf = [0; 60]; stream @@ -336,7 +335,7 @@ fn test_connect_with_srtp_ctx() { .unwrap(); let mut ssl = Ssl::new(&ctx.build()).unwrap(); ssl.set_mtu(1500).unwrap(); - let mut stream = ssl.connect(stream).unwrap(); + let mut stream = ssl.setup_connect(stream).handshake().unwrap(); let mut buf = [1; 60]; { @@ -386,7 +385,7 @@ fn test_connect_with_srtp_ssl() { profilenames ); ssl.set_mtu(1500).unwrap(); - let mut stream = ssl.accept(stream).unwrap(); + let mut stream = ssl.setup_accept(stream).handshake().unwrap(); let mut buf = [0; 60]; stream @@ -405,7 +404,7 @@ fn test_connect_with_srtp_ssl() { ssl.set_tlsext_use_srtp("SRTP_AES128_CM_SHA1_80:SRTP_AES128_CM_SHA1_32") .unwrap(); ssl.set_mtu(1500).unwrap(); - let mut stream = ssl.connect(stream).unwrap(); + let mut stream = ssl.setup_connect(stream).handshake().unwrap(); let mut buf = [1; 60]; { @@ -584,7 +583,10 @@ fn write_panic() { let stream = ExplodingStream(server.connect_tcp()); let ctx = SslContext::builder(SslMethod::tls()).unwrap(); - let _ = Ssl::new(&ctx.build()).unwrap().connect(stream); + let _ = Ssl::new(&ctx.build()) + .unwrap() + .setup_connect(stream) + .handshake(); } #[test] @@ -615,7 +617,10 @@ fn read_panic() { let stream = ExplodingStream(server.connect_tcp()); let ctx = SslContext::builder(SslMethod::tls()).unwrap(); - let _ = Ssl::new(&ctx.build()).unwrap().connect(stream); + let _ = Ssl::new(&ctx.build()) + .unwrap() + .setup_connect(stream) + .handshake(); } #[test] @@ -646,7 +651,10 @@ fn flush_panic() { let stream = ExplodingStream(server.connect_tcp()); let ctx = SslContext::builder(SslMethod::tls()).unwrap(); - let _ = Ssl::new(&ctx.build()).unwrap().connect(stream); + let _ = Ssl::new(&ctx.build()) + .unwrap() + .setup_connect(stream) + .handshake(); } #[test] @@ -676,7 +684,7 @@ fn default_verify_paths() { }; let mut ssl = Ssl::new(&ctx).unwrap(); ssl.set_hostname("google.com").unwrap(); - let mut socket = ssl.connect(s).unwrap(); + let mut socket = ssl.setup_connect(s).handshake().unwrap(); socket.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap(); let mut result = vec![]; @@ -738,7 +746,12 @@ fn connector_valid_hostname() { connector.set_ca_file("test/root-ca.pem").unwrap(); let s = server.connect_tcp(); - let mut s = connector.build().connect("foobar.com", s).unwrap(); + let mut s = connector + .build() + .setup_connect("foobar.com", s) + .unwrap() + .handshake() + .unwrap(); s.read_exact(&mut [0]).unwrap(); } @@ -752,7 +765,12 @@ fn connector_invalid_hostname() { connector.set_ca_file("test/root-ca.pem").unwrap(); let s = server.connect_tcp(); - connector.build().connect("bogus.com", s).unwrap_err(); + connector + .build() + .setup_connect("bogus.com", s) + .unwrap() + .handshake() + .unwrap_err(); } #[test] @@ -768,7 +786,9 @@ fn connector_invalid_no_hostname_verification() { .configure() .unwrap() .verify_hostname(false) - .connect("bogus.com", s) + .setup_connect("bogus.com", s) + .unwrap() + .handshake() .unwrap(); s.read_exact(&mut [0]).unwrap(); } @@ -786,7 +806,9 @@ fn connector_no_hostname_still_verifies() { .configure() .unwrap() .verify_hostname(false) - .connect("fizzbuzz.com", s) + .setup_connect("fizzbuzz.com", s) + .unwrap() + .handshake() .is_err()); } @@ -803,7 +825,9 @@ fn connector_no_hostname_can_disable_verify() { .configure() .unwrap() .verify_hostname(false) - .connect("foobar.com", s) + .setup_connect("foobar.com", s) + .unwrap() + .handshake() .unwrap(); s.read_exact(&mut [0]).unwrap(); } @@ -820,7 +844,7 @@ fn test_mozilla_server(new: fn(SslMethod) -> Result Result SslStream { let socket = TcpStream::connect(self.addr).unwrap(); - let mut s = self.ssl.connect(socket).unwrap(); + let mut s = self.ssl.setup_connect(socket).handshake().unwrap(); s.read_exact(&mut [0]).unwrap(); s } pub fn connect_err(self) { let socket = TcpStream::connect(self.addr).unwrap(); - self.ssl.connect(socket).unwrap_err(); + self.ssl.setup_connect(socket).handshake().unwrap_err(); } } diff --git a/hyper-boring/src/lib.rs b/hyper-boring/src/lib.rs index d6e82c1b..7d1ad16d 100644 --- a/hyper-boring/src/lib.rs +++ b/hyper-boring/src/lib.rs @@ -245,7 +245,7 @@ where } let config = inner.setup_ssl(&uri, host)?; - let stream = tokio_boring::connect(config, host, conn).await?; + let stream = tokio_boring::connect(config, host, conn)?.await?; Ok(MaybeHttpsStream::Https(stream)) }; diff --git a/hyper-boring/src/test.rs b/hyper-boring/src/test.rs index 226a0487..73c9d664 100644 --- a/hyper-boring/src/test.rs +++ b/hyper-boring/src/test.rs @@ -43,7 +43,10 @@ async fn localhost() { for _ in 0..3 { let stream = listener.accept().await.unwrap().0; - let stream = tokio_boring::accept(&acceptor, stream).await.unwrap(); + let stream = tokio_boring::accept(&acceptor, stream) + .unwrap() + .await + .unwrap(); let service = service::service_fn(|_| async { Ok::<_, io::Error>(Response::new(Body::empty())) }); @@ -105,7 +108,10 @@ async fn alpn_h2() { let acceptor = acceptor.build(); let stream = listener.accept().await.unwrap().0; - let stream = tokio_boring::accept(&acceptor, stream).await.unwrap(); + let stream = tokio_boring::accept(&acceptor, stream) + .unwrap() + .await + .unwrap(); assert_eq!(stream.ssl().selected_alpn_protocol().unwrap(), b"h2"); let service = diff --git a/tokio-boring/examples/simple-async.rs b/tokio-boring/examples/simple-async.rs index f4a69a1c..d9b159de 100644 --- a/tokio-boring/examples/simple-async.rs +++ b/tokio-boring/examples/simple-async.rs @@ -11,6 +11,6 @@ async fn main() -> anyhow::Result<()> { ssl_builder.set_default_verify_paths()?; ssl_builder.set_verify(ssl::SslVerifyMode::PEER); let acceptor = ssl_builder.build(); - let _ssl_stream = tokio_boring::accept(&acceptor, tcp_stream).await?; + let _ssl_stream = tokio_boring::accept(&acceptor, tcp_stream)?.await?; Ok(()) } diff --git a/tokio-boring/src/lib.rs b/tokio-boring/src/lib.rs index a0dd58c5..6cff31bd 100644 --- a/tokio-boring/src/lib.rs +++ b/tokio-boring/src/lib.rs @@ -13,10 +13,12 @@ #![warn(missing_docs)] #![cfg_attr(docsrs, feature(doc_auto_cfg))] +use boring::error::ErrorStack; use boring::ssl::{ self, ConnectConfiguration, ErrorCode, MidHandshakeSslStream, ShutdownResult, SslAcceptor, SslRef, }; +use boring::x509::X509VerifyResult; use boring_sys as ffi; use std::error::Error; use std::fmt; @@ -27,40 +29,35 @@ use std::task::{Context, Poll}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; /// Asynchronously performs a client-side TLS handshake over the provided stream. -pub async fn connect( +pub fn connect( config: ConnectConfiguration, domain: &str, stream: S, -) -> Result, HandshakeError> +) -> Result, ErrorStack> where S: AsyncRead + AsyncWrite + Unpin, { - handshake(|s| config.connect(domain, s), stream).await + handshake(|s| config.setup_connect(domain, s), stream) } /// Asynchronously performs a server-side TLS handshake over the provided stream. -pub async fn accept(acceptor: &SslAcceptor, stream: S) -> Result, HandshakeError> +pub fn accept(acceptor: &SslAcceptor, stream: S) -> Result, ErrorStack> where S: AsyncRead + AsyncWrite + Unpin, { - handshake(|s| acceptor.accept(s), stream).await + handshake(|s| acceptor.setup_accept(s), stream) } -async fn handshake(f: F, stream: S) -> Result, HandshakeError> +fn handshake( + f: impl FnOnce(StreamWrapper) -> Result>, ErrorStack>, + stream: S, +) -> Result, ErrorStack> where - F: FnOnce( - StreamWrapper, - ) - -> Result>, ssl::HandshakeError>> - + Unpin, S: AsyncRead + AsyncWrite + Unpin, { - let start = StartHandshakeFuture(Some(StartHandshakeFutureInner { f, stream })); + let ongoing_handshake = Some(f(StreamWrapper { stream, context: 0 })?); - match start.await? { - StartedHandshake::Done(s) => Ok(s), - StartedHandshake::Mid(s) => HandshakeFuture(Some(s)).await, - } + Ok(HandshakeFuture(ongoing_handshake)) } struct StreamWrapper { @@ -266,47 +263,32 @@ where } /// The error type returned after a failed handshake. -pub struct HandshakeError(ssl::HandshakeError>); +pub struct HandshakeError(MidHandshakeSslStream>); impl HandshakeError { /// Returns a shared reference to the `Ssl` object associated with this error. - pub fn ssl(&self) -> Option<&SslRef> { - match &self.0 { - ssl::HandshakeError::Failure(s) => Some(s.ssl()), - _ => None, - } + pub fn ssl(&self) -> &SslRef { + self.0.ssl() } /// Converts error to the source data stream that was used for the handshake. - pub fn into_source_stream(self) -> Option { - match self.0 { - ssl::HandshakeError::Failure(s) => Some(s.into_source_stream().stream), - _ => None, - } + pub fn into_source_stream(self) -> S { + self.0.into_source_stream().stream } /// Returns a reference to the source data stream. - pub fn as_source_stream(&self) -> Option<&S> { - match &self.0 { - ssl::HandshakeError::Failure(s) => Some(&s.get_ref().stream), - _ => None, - } + pub fn as_source_stream(&self) -> &S { + &self.0.get_ref().stream } /// Returns the error code, if any. - pub fn code(&self) -> Option { - match &self.0 { - ssl::HandshakeError::Failure(s) => Some(s.error().code()), - _ => None, - } + pub fn code(&self) -> ErrorCode { + self.0.error().code() } /// Returns a reference to the inner I/O error, if any. pub fn as_io_error(&self) -> Option<&io::Error> { - match &self.0 { - ssl::HandshakeError::Failure(s) => s.error().io_error(), - _ => None, - } + self.0.error().io_error() } } @@ -321,67 +303,36 @@ where impl fmt::Display for HandshakeError { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt::Display::fmt(&self.0, fmt) - } -} + fmt.write_str("TLS handshake failed: ")?; -impl Error for HandshakeError -where - S: fmt::Debug, -{ - fn source(&self) -> Option<&(dyn Error + 'static)> { - self.0.source() - } -} + #[cfg(feature = "rpk")] + if self.0.ssl().ssl_context().is_rpk() { + return write!(fmt, " {}", self.0.error()); + } -enum StartedHandshake { - Done(SslStream), - Mid(MidHandshakeSslStream>), -} + let verify = self.0.ssl().verify_result(); -struct StartHandshakeFuture(Option>); + if verify != X509VerifyResult::OK { + write!(fmt, "cert verification failed - {verify}")?; + } -struct StartHandshakeFutureInner { - f: F, - stream: S, + write!(fmt, "TLS handshake failed: {}", self.0.error()) + } } -impl Future for StartHandshakeFuture +impl Error for HandshakeError where - F: FnOnce( - StreamWrapper, - ) - -> Result>, ssl::HandshakeError>> - + Unpin, - S: Unpin, + S: fmt::Debug, { - type Output = Result, HandshakeError>; - - fn poll( - mut self: Pin<&mut Self>, - ctx: &mut Context<'_>, - ) -> Poll, HandshakeError>> { - let inner = self.0.take().expect("future polled after completion"); - - let stream = StreamWrapper { - stream: inner.stream, - context: ctx as *mut _ as usize, - }; - match (inner.f)(stream) { - Ok(mut s) => { - s.get_mut().context = 0; - Poll::Ready(Ok(StartedHandshake::Done(SslStream(s)))) - } - Err(ssl::HandshakeError::WouldBlock(mut s)) => { - s.get_mut().context = 0; - Poll::Ready(Ok(StartedHandshake::Mid(s))) - } - Err(e) => Poll::Ready(Err(HandshakeError(e))), - } + fn source(&self) -> Option<&(dyn Error + 'static)> { + self.0.error().source() } } -struct HandshakeFuture(Option>>); +/// Future for an ongoing TLS handshake. +/// +/// See [`connect`] and [`accept`]. +pub struct HandshakeFuture(Option>>); impl Future for HandshakeFuture where @@ -389,21 +340,22 @@ where { type Output = Result, HandshakeError>; - fn poll( - mut self: Pin<&mut Self>, - ctx: &mut Context<'_>, - ) -> Poll, HandshakeError>> { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let mut s = self.0.take().expect("future polled after completion"); - s.get_mut().context = ctx as *mut _ as usize; + s.get_mut().context = cx as *mut _ as usize; + match s.handshake() { Ok(mut s) => { s.get_mut().context = 0; + Poll::Ready(Ok(SslStream(s))) } - Err(ssl::HandshakeError::WouldBlock(mut s)) => { - s.get_mut().context = 0; - self.0 = Some(s); + Err(mut handshake) if handshake.error().would_block() => { + handshake.get_mut().context = 0; + + self.0 = Some(handshake); + Poll::Pending } Err(e) => Poll::Ready(Err(HandshakeError(e))), diff --git a/tokio-boring/tests/client_server.rs b/tokio-boring/tests/client_server.rs index 72c5a040..872033f8 100644 --- a/tokio-boring/tests/client_server.rs +++ b/tokio-boring/tests/client_server.rs @@ -1,11 +1,12 @@ -use boring::ssl::{SslAcceptor, SslConnector, SslFiletype, SslMethod}; +use boring::ssl::{SslConnector, SslMethod}; use futures::future; -use std::future::Future; -use std::net::{SocketAddr, ToSocketAddrs}; -use std::pin::Pin; -use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use tokio::net::{TcpListener, TcpStream}; -use tokio_boring::{HandshakeError, SslStream}; +use std::net::ToSocketAddrs; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpStream; + +mod common; + +use self::common::{connect, create_server, with_trivial_client_server_exchange}; #[tokio::test] async fn google() { @@ -18,6 +19,7 @@ async fn google() { .configure() .unwrap(); let mut stream = tokio_boring::connect(config, "google.com", stream) + .unwrap() .await .unwrap(); @@ -33,92 +35,21 @@ async fn google() { assert!(response.ends_with("") || response.ends_with("")); } -fn create_server() -> ( - impl Future, HandshakeError>>, - SocketAddr, -) { - let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); - - listener.set_nonblocking(true).unwrap(); - - let listener = TcpListener::from_std(listener).unwrap(); - let addr = listener.local_addr().unwrap(); - - let server = async move { - let mut acceptor = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); - acceptor - .set_private_key_file("tests/key.pem", SslFiletype::PEM) - .unwrap(); - acceptor - .set_certificate_chain_file("tests/cert.pem") - .unwrap(); - let acceptor = acceptor.build(); - - let stream = listener.accept().await.unwrap().0; - - tokio_boring::accept(&acceptor, stream).await - }; - - (server, addr) -} - #[tokio::test] async fn server() { - let (stream, addr) = create_server(); - - let server = async { - let mut stream = stream.await.unwrap(); - let mut buf = [0; 4]; - stream.read_exact(&mut buf).await.unwrap(); - assert_eq!(&buf, b"asdf"); - - stream.write_all(b"jkl;").await.unwrap(); - - future::poll_fn(|ctx| Pin::new(&mut stream).poll_shutdown(ctx)) - .await - .unwrap(); - }; - - let client = async { - let mut connector = SslConnector::builder(SslMethod::tls()).unwrap(); - connector.set_ca_file("tests/cert.pem").unwrap(); - let config = connector.build().configure().unwrap(); - - let stream = TcpStream::connect(&addr).await.unwrap(); - let mut stream = tokio_boring::connect(config, "localhost", stream) - .await - .unwrap(); - - stream.write_all(b"asdf").await.unwrap(); - - let mut buf = vec![]; - stream.read_to_end(&mut buf).await.unwrap(); - assert_eq!(buf, b"jkl;"); - }; - - future::join(server, client).await; + with_trivial_client_server_exchange(|_| ()).await; } #[tokio::test] async fn handshake_error() { - let (stream, addr) = create_server(); + let (stream, addr) = create_server(|_| ()); let server = async { - let err = stream.await.unwrap_err(); - - assert!(err.into_source_stream().is_some()); + let _err = stream.await.unwrap_err(); }; let client = async { - let connector = SslConnector::builder(SslMethod::tls()).unwrap(); - let config = connector.build().configure().unwrap(); - let stream = TcpStream::connect(&addr).await.unwrap(); - - let err = tokio_boring::connect(config, "localhost", stream) - .await - .unwrap_err(); - - assert!(err.into_source_stream().is_some()); + let _err = connect(addr, |_| Ok(())).await.unwrap_err(); }; future::join(server, client).await; diff --git a/tokio-boring/tests/common/mod.rs b/tokio-boring/tests/common/mod.rs new file mode 100644 index 00000000..fdcf86fb --- /dev/null +++ b/tokio-boring/tests/common/mod.rs @@ -0,0 +1,98 @@ +#![allow(dead_code)] + +use boring::error::ErrorStack; +use boring::ssl::{ + SslAcceptor, SslAcceptorBuilder, SslConnector, SslConnectorBuilder, SslFiletype, SslMethod, +}; +use futures::future::{self, Future}; +use std::net::SocketAddr; +use std::pin::Pin; +use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream}; +use tokio_boring::{HandshakeError, SslStream}; + +pub(crate) fn create_server( + setup: impl FnOnce(&mut SslAcceptorBuilder), +) -> ( + impl Future, HandshakeError>>, + SocketAddr, +) { + let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + + listener.set_nonblocking(true).unwrap(); + + let listener = TcpListener::from_std(listener).unwrap(); + let addr = listener.local_addr().unwrap(); + + let server = async move { + let mut acceptor = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); + + acceptor + .set_private_key_file("tests/key.pem", SslFiletype::PEM) + .unwrap(); + + acceptor + .set_certificate_chain_file("tests/cert.pem") + .unwrap(); + + setup(&mut acceptor); + + let acceptor = acceptor.build(); + + let stream = listener.accept().await.unwrap().0; + + tokio_boring::accept(&acceptor, stream).unwrap().await + }; + + (server, addr) +} + +pub(crate) async fn connect( + addr: SocketAddr, + setup: impl FnOnce(&mut SslConnectorBuilder) -> Result<(), ErrorStack>, +) -> Result, HandshakeError> { + let mut connector = SslConnector::builder(SslMethod::tls()).unwrap(); + + setup(&mut connector).unwrap(); + + let config = connector.build().configure().unwrap(); + + let stream = TcpStream::connect(&addr).await.unwrap(); + + tokio_boring::connect(config, "localhost", stream) + .unwrap() + .await +} + +pub(crate) async fn with_trivial_client_server_exchange( + server_setup: impl FnOnce(&mut SslAcceptorBuilder), +) { + let (stream, addr) = create_server(server_setup); + + let server = async { + let mut stream = stream.await.unwrap(); + let mut buf = [0; 4]; + stream.read_exact(&mut buf).await.unwrap(); + assert_eq!(&buf, b"asdf"); + + stream.write_all(b"jkl;").await.unwrap(); + + future::poll_fn(|ctx| Pin::new(&mut stream).poll_shutdown(ctx)) + .await + .unwrap(); + }; + + let client = async { + let mut stream = connect(addr, |builder| builder.set_ca_file("tests/cert.pem")) + .await + .unwrap(); + + stream.write_all(b"asdf").await.unwrap(); + + let mut buf = vec![]; + stream.read_to_end(&mut buf).await.unwrap(); + assert_eq!(buf, b"jkl;"); + }; + + future::join(server, client).await; +} diff --git a/tokio-boring/tests/rpk.rs b/tokio-boring/tests/rpk.rs index 5492767a..48f53655 100644 --- a/tokio-boring/tests/rpk.rs +++ b/tokio-boring/tests/rpk.rs @@ -34,7 +34,7 @@ mod test_rpk { let stream = listener.accept().await.unwrap().0; - tokio_boring::accept(&acceptor, stream).await + tokio_boring::accept(&acceptor, stream).unwrap().await }; (server, addr) @@ -66,6 +66,7 @@ mod test_rpk { let stream = TcpStream::connect(&addr).await.unwrap(); let mut stream = tokio_boring::connect(config, "localhost", stream) + .unwrap() .await .unwrap(); @@ -97,6 +98,7 @@ mod test_rpk { let stream = TcpStream::connect(&addr).await.unwrap(); let err = tokio_boring::connect(config, "localhost", stream) + .unwrap() .await .unwrap_err();