From 4b167815bda803b0aea51ef56c722a4d19dc17bd Mon Sep 17 00:00:00 2001 From: Alessandro Ghedini Date: Fri, 23 Jun 2023 10:44:20 +0100 Subject: [PATCH 1/4] Introduce ssl::Error::would_block --- boring/src/ssl/error.rs | 4 ++++ boring/src/ssl/mod.rs | 50 +++++++++++++++++------------------------ 2 files changed, 25 insertions(+), 29 deletions(-) diff --git a/boring/src/ssl/error.rs b/boring/src/ssl/error.rs index 5fb91659..a7e6f9fc 100644 --- a/boring/src/ssl/error.rs +++ b/boring/src/ssl/error.rs @@ -81,6 +81,10 @@ impl Error { _ => None, } } + + pub fn would_block(&self) -> bool { + matches!(self.code, ErrorCode::WANT_READ | ErrorCode::WANT_WRITE) + } } impl From for Error { diff --git a/boring/src/ssl/mod.rs b/boring/src/ssl/mod.rs index 67e16b18..b354902a 100644 --- a/boring/src/ssl/mod.rs +++ b/boring/src/ssl/mod.rs @@ -3162,11 +3162,9 @@ impl MidHandshakeSslStream { Ok(self.stream) } else { self.error = self.stream.make_error(ret); - match self.error.code() { - ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => { - Err(HandshakeError::WouldBlock(self)) - } - _ => Err(HandshakeError::Failure(self)), + match self.error.would_block() { + true => Err(HandshakeError::WouldBlock(self)), + false => Err(HandshakeError::Failure(self)), } } } @@ -3471,14 +3469,12 @@ where Ok(stream) } else { let error = stream.make_error(ret); - match error.code() { - ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => { - Err(HandshakeError::WouldBlock(MidHandshakeSslStream { - stream, - error, - })) - } - _ => Err(HandshakeError::Failure(MidHandshakeSslStream { + match error.would_block() { + true => Err(HandshakeError::WouldBlock(MidHandshakeSslStream { + stream, + error, + })), + false => Err(HandshakeError::Failure(MidHandshakeSslStream { stream, error, })), @@ -3494,14 +3490,12 @@ where Ok(stream) } else { let error = stream.make_error(ret); - match error.code() { - ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => { - Err(HandshakeError::WouldBlock(MidHandshakeSslStream { - stream, - error, - })) - } - _ => Err(HandshakeError::Failure(MidHandshakeSslStream { + match error.would_block() { + true => Err(HandshakeError::WouldBlock(MidHandshakeSslStream { + stream, + error, + })), + false => Err(HandshakeError::Failure(MidHandshakeSslStream { stream, error, })), @@ -3523,14 +3517,12 @@ where Ok(stream) } else { let error = stream.make_error(ret); - match error.code() { - ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => { - Err(HandshakeError::WouldBlock(MidHandshakeSslStream { - stream, - error, - })) - } - _ => Err(HandshakeError::Failure(MidHandshakeSslStream { + match error.would_block() { + true => Err(HandshakeError::WouldBlock(MidHandshakeSslStream { + stream, + error, + })), + false => Err(HandshakeError::Failure(MidHandshakeSslStream { stream, error, })), From b27a6e76a1b6f5bfac5ce7cb22db24ed39342c3a Mon Sep 17 00:00:00 2001 From: Anthony Ramine Date: Thu, 3 Aug 2023 18:35:55 +0200 Subject: [PATCH 2/4] Introduce setup_accept and setup_connect These two new kinds of methods immediately return a MidHandshakeSslStream instead of actually initiating a handshake. This greatly simplifies loops around MidHandshakeSslStream::WouldBlock. --- boring/src/ssl/connector.rs | 70 +++++++++++++++-- boring/src/ssl/mod.rs | 146 ++++++++++++++++++++++++------------ tokio-boring/src/lib.rs | 83 +++++--------------- 3 files changed, 180 insertions(+), 119 deletions(-) diff --git a/boring/src/ssl/connector.rs b/boring/src/ssl/connector.rs index 7be740d3..93da579e 100644 --- a/boring/src/ssl/connector.rs +++ b/boring/src/ssl/connector.rs @@ -9,6 +9,8 @@ use crate::ssl::{ }; use crate::version; +use super::MidHandshakeSslStream; + const FFDHE_2048: &str = " -----BEGIN DH PARAMETERS----- MIIBCAKCAQEA//////////+t+FRYortKmq/cViAnPTzx2LnFg84tNpWp4TZBFGQz @@ -98,11 +100,30 @@ impl SslConnector { /// Initiates a client-side TLS session on a stream. /// /// The domain is used for SNI and hostname verification. + pub fn setup_connect( + &self, + domain: &str, + stream: S, + ) -> Result, ErrorStack> + where + S: Read + Write, + { + self.configure()?.setup_connect(domain, stream) + } + + /// Attempts a client-side TLS session on a stream. + /// + /// The domain is used for SNI and hostname verification. + /// + /// This is a convenience method which combines [`Self::setup_connect`] and + /// [`MidHandshakeSslStream::handshake`]. pub fn connect(&self, domain: &str, stream: S) -> Result, HandshakeError> where S: Read + Write, { - self.configure()?.connect(domain, stream) + self.setup_connect(domain, stream) + .map_err(HandshakeError::SetupFailure)? + .handshake() } /// Returns a structure allowing for configuration of a single TLS session before connection. @@ -192,7 +213,13 @@ impl ConnectConfiguration { /// Initiates a client-side TLS session on a stream. /// /// The domain is used for SNI and hostname verification if enabled. - pub fn connect(mut self, domain: &str, stream: S) -> Result, HandshakeError> + /// + /// See [`Ssl::setup_connect`] for more details. + pub fn setup_connect( + mut self, + domain: &str, + stream: S, + ) -> Result, ErrorStack> where S: Read + Write, { @@ -210,7 +237,22 @@ impl ConnectConfiguration { setup_verify_hostname(&mut self.ssl, domain)?; } - self.ssl.connect(stream) + Ok(self.ssl.setup_connect(stream)) + } + + /// Attempts a client-side TLS session on a stream. + /// + /// The domain is used for SNI and hostname verification if enabled. + /// + /// This is a convenience method which combines [`Self::setup_connect`] and + /// [`MidHandshakeSslStream::handshake`]. + pub fn connect(self, domain: &str, stream: S) -> Result, HandshakeError> + where + S: Read + Write, + { + self.setup_connect(domain, stream) + .map_err(HandshakeError::SetupFailure)? + .handshake() } } @@ -319,13 +361,29 @@ impl SslAcceptor { Ok(SslAcceptorBuilder(ctx)) } - /// Initiates a server-side TLS session on a stream. - pub fn accept(&self, stream: S) -> Result, HandshakeError> + /// Initiates a server-side TLS handshake on a stream. + /// + /// See [`Ssl::setup_accept`] for more details. + pub fn setup_accept(&self, stream: S) -> Result, ErrorStack> where S: Read + Write, { let ssl = Ssl::new(&self.0)?; - ssl.accept(stream) + + Ok(ssl.setup_accept(stream)) + } + + /// Attempts a server-side TLS handshake on a stream. + /// + /// This is a convenience method which combines [`Self::setup_accept`] and + /// [`MidHandshakeSslStream::handshake`]. + pub fn accept(&self, stream: S) -> Result, HandshakeError> + where + S: Read + Write, + { + self.setup_accept(stream) + .map_err(HandshakeError::SetupFailure)? + .handshake() } /// Consumes the `SslAcceptor`, returning the inner raw `SslContext`. diff --git a/boring/src/ssl/mod.rs b/boring/src/ssl/mod.rs index b354902a..40941176 100644 --- a/boring/src/ssl/mod.rs +++ b/boring/src/ssl/mod.rs @@ -2306,34 +2306,52 @@ impl Ssl { } } - /// Initiates a client-side TLS handshake. + /// Initiates a client-side TLS handshake, returning a [`MidHandshakeSslStream`]. /// - /// This corresponds to [`SSL_connect`]. + /// This method is guaranteed to return without calling any callback defined + /// in the internal [`Ssl`] or [`SslContext`]. + /// + /// See [`SslStreamBuilder::setup_connect`] for more details. /// /// # Warning /// - /// OpenSSL's default configuration is insecure. It is highly recommended to use - /// `SslConnector` rather than `Ssl` directly, as it manages that configuration. + /// BoringSSL's default configuration is insecure. It is highly recommended to use + /// [`SslConnector`] rather than [`Ssl`] directly, as it manages that configuration. + pub fn setup_connect(self, stream: S) -> MidHandshakeSslStream + where + S: Read + Write, + { + SslStreamBuilder::new(self, stream).setup_connect() + } + + /// Attempts a client-side TLS handshake. + /// + /// This is a convenience method which combines [`Self::setup_connect`] and + /// [`MidHandshakeSslStream::handshake`]. + /// + /// # Warning /// - /// [`SSL_connect`]: https://www.openssl.org/docs/manmaster/man3/SSL_connect.html + /// OpenSSL's default configuration is insecure. It is highly recommended to use + /// [`SslConnector`] rather than `Ssl` directly, as it manages that configuration. pub fn connect(self, stream: S) -> Result, HandshakeError> where S: Read + Write, { - SslStreamBuilder::new(self, stream).connect() + self.setup_connect(stream).handshake() } /// Initiates a server-side TLS handshake. /// - /// This corresponds to [`SSL_accept`]. + /// This method is guaranteed to return without calling any callback defined + /// in the internal [`Ssl`] or [`SslContext`]. /// - /// # Warning + /// See [`SslStreamBuilder::setup_accept`] for more details. /// - /// OpenSSL's default configuration is insecure. It is highly recommended to use - /// `SslAcceptor` rather than `Ssl` directly, as it manages that configuration. + /// # Warning /// - /// [`SSL_accept`]: https://www.openssl.org/docs/manmaster/man3/SSL_accept.html - pub fn accept(self, stream: S) -> Result, HandshakeError> + /// BoringSSL's default configuration is insecure. It is highly recommended to use + /// [`SslAcceptor`] rather than [`Ssl`] directly, as it manages that configuration. + pub fn setup_accept(self, stream: S) -> MidHandshakeSslStream where S: Read + Write, { @@ -2352,7 +2370,25 @@ impl Ssl { } } - SslStreamBuilder::new(self, stream).accept() + SslStreamBuilder::new(self, stream).setup_accept() + } + + /// Attempts a server-side TLS handshake. + /// + /// This is a convenience method which combines [`Self::setup_accept`] and + /// [`MidHandshakeSslStream::handshake`]. + /// + /// # Warning + /// + /// OpenSSL's default configuration is insecure. It is highly recommended to use + /// `SslAcceptor` rather than `Ssl` directly, as it manages that configuration. + /// + /// [`SSL_accept`]: https://www.openssl.org/docs/manmaster/man3/SSL_accept.html + pub fn accept(self, stream: S) -> Result, HandshakeError> + where + S: Read + Write, + { + self.setup_accept(stream).handshake() } } @@ -3461,46 +3497,60 @@ where unsafe { ffi::SSL_set_accept_state(self.inner.ssl.as_ptr()) } } - /// See `Ssl::connect` + /// Initiates a client-side TLS handshake, returning a [`MidHandshakeSslStream`]. + /// + /// This method calls [`Self::set_connect_state`] and returns without actually + /// initiating the handshake. The caller is then free to call + /// [`MidHandshakeSslStream`] and loop on [`HandshakeError::WouldBlock`]. + pub fn setup_connect(mut self) -> MidHandshakeSslStream { + self.set_connect_state(); + + MidHandshakeSslStream { + stream: self.inner, + error: Error { + code: ErrorCode::WANT_WRITE, + cause: Some(InnerError::Io(io::Error::new( + io::ErrorKind::WouldBlock, + "connect handshake has not started yet", + ))), + }, + } + } + + /// Attempts a client-side TLS handshake. + /// + /// This is a convenience method which combines [`Self::setup_connect`] and + /// [`MidHandshakeSslStream::handshake`]. pub fn connect(self) -> Result, HandshakeError> { - let mut stream = self.inner; - let ret = unsafe { ffi::SSL_connect(stream.ssl.as_ptr()) }; - if ret > 0 { - Ok(stream) - } else { - let error = stream.make_error(ret); - match error.would_block() { - true => Err(HandshakeError::WouldBlock(MidHandshakeSslStream { - stream, - error, - })), - false => Err(HandshakeError::Failure(MidHandshakeSslStream { - stream, - error, - })), - } + self.setup_connect().handshake() + } + + /// Initiates a server-side TLS handshake, returning a [`MidHandshakeSslStream`]. + /// + /// This method calls [`Self::set_accept_state`] and returns without actually + /// initiating the handshake. The caller is then free to call + /// [`MidHandshakeSslStream`] and loop on [`HandshakeError::WouldBlock`]. + pub fn setup_accept(mut self) -> MidHandshakeSslStream { + self.set_accept_state(); + + MidHandshakeSslStream { + stream: self.inner, + error: Error { + code: ErrorCode::WANT_READ, + cause: Some(InnerError::Io(io::Error::new( + io::ErrorKind::WouldBlock, + "accept handshake has not started yet", + ))), + }, } } - /// See `Ssl::accept` + /// Attempts a server-side TLS handshake. + /// + /// This is a convenience method which combines [`Self::setup_accept`] and + /// [`MidHandshakeSslStream::handshake`]. pub fn accept(self) -> Result, HandshakeError> { - let mut stream = self.inner; - let ret = unsafe { ffi::SSL_accept(stream.ssl.as_ptr()) }; - if ret > 0 { - Ok(stream) - } else { - let error = stream.make_error(ret); - match error.would_block() { - true => Err(HandshakeError::WouldBlock(MidHandshakeSslStream { - stream, - error, - })), - false => Err(HandshakeError::Failure(MidHandshakeSslStream { - stream, - error, - })), - } - } + self.setup_accept().handshake() } /// Initiates the handshake. diff --git a/tokio-boring/src/lib.rs b/tokio-boring/src/lib.rs index a0dd58c5..f594231d 100644 --- a/tokio-boring/src/lib.rs +++ b/tokio-boring/src/lib.rs @@ -13,6 +13,7 @@ #![warn(missing_docs)] #![cfg_attr(docsrs, feature(doc_auto_cfg))] +use boring::error::ErrorStack; use boring::ssl::{ self, ConnectConfiguration, ErrorCode, MidHandshakeSslStream, ShutdownResult, SslAcceptor, SslRef, @@ -35,7 +36,7 @@ pub async fn connect( where S: AsyncRead + AsyncWrite + Unpin, { - handshake(|s| config.connect(domain, s), stream).await + handshake(|s| config.setup_connect(domain, s), stream).await } /// Asynchronously performs a server-side TLS handshake over the provided stream. @@ -43,24 +44,22 @@ pub async fn accept(acceptor: &SslAcceptor, stream: S) -> Result where S: AsyncRead + AsyncWrite + Unpin, { - handshake(|s| acceptor.accept(s), stream).await + handshake(|s| acceptor.setup_accept(s), stream).await } -async fn handshake(f: F, stream: S) -> Result, HandshakeError> +async fn handshake( + f: impl FnOnce(StreamWrapper) -> Result>, ErrorStack>, + stream: S, +) -> Result, HandshakeError> where - F: FnOnce( - StreamWrapper, - ) - -> Result>, ssl::HandshakeError>> - + Unpin, S: AsyncRead + AsyncWrite + Unpin, { - let start = StartHandshakeFuture(Some(StartHandshakeFutureInner { f, stream })); + let ongoing_handshake = Some( + f(StreamWrapper { stream, context: 0 }) + .map_err(|err| HandshakeError(ssl::HandshakeError::SetupFailure(err)))?, + ); - match start.await? { - StartedHandshake::Done(s) => Ok(s), - StartedHandshake::Mid(s) => HandshakeFuture(Some(s)).await, - } + HandshakeFuture(ongoing_handshake).await } struct StreamWrapper { @@ -334,53 +333,6 @@ where } } -enum StartedHandshake { - Done(SslStream), - Mid(MidHandshakeSslStream>), -} - -struct StartHandshakeFuture(Option>); - -struct StartHandshakeFutureInner { - f: F, - stream: S, -} - -impl Future for StartHandshakeFuture -where - F: FnOnce( - StreamWrapper, - ) - -> Result>, ssl::HandshakeError>> - + Unpin, - S: Unpin, -{ - type Output = Result, HandshakeError>; - - fn poll( - mut self: Pin<&mut Self>, - ctx: &mut Context<'_>, - ) -> Poll, HandshakeError>> { - let inner = self.0.take().expect("future polled after completion"); - - let stream = StreamWrapper { - stream: inner.stream, - context: ctx as *mut _ as usize, - }; - match (inner.f)(stream) { - Ok(mut s) => { - s.get_mut().context = 0; - Poll::Ready(Ok(StartedHandshake::Done(SslStream(s)))) - } - Err(ssl::HandshakeError::WouldBlock(mut s)) => { - s.get_mut().context = 0; - Poll::Ready(Ok(StartedHandshake::Mid(s))) - } - Err(e) => Poll::Ready(Err(HandshakeError(e))), - } - } -} - struct HandshakeFuture(Option>>); impl Future for HandshakeFuture @@ -389,21 +341,22 @@ where { type Output = Result, HandshakeError>; - fn poll( - mut self: Pin<&mut Self>, - ctx: &mut Context<'_>, - ) -> Poll, HandshakeError>> { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let mut s = self.0.take().expect("future polled after completion"); - s.get_mut().context = ctx as *mut _ as usize; + s.get_mut().context = cx as *mut _ as usize; + match s.handshake() { Ok(mut s) => { s.get_mut().context = 0; + Poll::Ready(Ok(SslStream(s))) } Err(ssl::HandshakeError::WouldBlock(mut s)) => { s.get_mut().context = 0; + self.0 = Some(s); + Poll::Pending } Err(e) => Poll::Ready(Err(HandshakeError(e))), From 252d3448adb69db341557e60156a4b9003c77a49 Mon Sep 17 00:00:00 2001 From: Anthony Ramine Date: Fri, 4 Aug 2023 13:42:09 +0200 Subject: [PATCH 3/4] Introduce helper module in tokio-boring tests --- tokio-boring/tests/client_server.rs | 88 ++++---------------------- tokio-boring/tests/common/mod.rs | 96 +++++++++++++++++++++++++++++ 2 files changed, 107 insertions(+), 77 deletions(-) create mode 100644 tokio-boring/tests/common/mod.rs diff --git a/tokio-boring/tests/client_server.rs b/tokio-boring/tests/client_server.rs index 72c5a040..925f9875 100644 --- a/tokio-boring/tests/client_server.rs +++ b/tokio-boring/tests/client_server.rs @@ -1,11 +1,12 @@ -use boring::ssl::{SslAcceptor, SslConnector, SslFiletype, SslMethod}; +use boring::ssl::{SslConnector, SslMethod}; use futures::future; -use std::future::Future; -use std::net::{SocketAddr, ToSocketAddrs}; -use std::pin::Pin; -use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use tokio::net::{TcpListener, TcpStream}; -use tokio_boring::{HandshakeError, SslStream}; +use std::net::ToSocketAddrs; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpStream; + +mod common; + +use self::common::{connect, create_server, with_trivial_client_server_exchange}; #[tokio::test] async fn google() { @@ -33,75 +34,14 @@ async fn google() { assert!(response.ends_with("") || response.ends_with("")); } -fn create_server() -> ( - impl Future, HandshakeError>>, - SocketAddr, -) { - let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); - - listener.set_nonblocking(true).unwrap(); - - let listener = TcpListener::from_std(listener).unwrap(); - let addr = listener.local_addr().unwrap(); - - let server = async move { - let mut acceptor = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); - acceptor - .set_private_key_file("tests/key.pem", SslFiletype::PEM) - .unwrap(); - acceptor - .set_certificate_chain_file("tests/cert.pem") - .unwrap(); - let acceptor = acceptor.build(); - - let stream = listener.accept().await.unwrap().0; - - tokio_boring::accept(&acceptor, stream).await - }; - - (server, addr) -} - #[tokio::test] async fn server() { - let (stream, addr) = create_server(); - - let server = async { - let mut stream = stream.await.unwrap(); - let mut buf = [0; 4]; - stream.read_exact(&mut buf).await.unwrap(); - assert_eq!(&buf, b"asdf"); - - stream.write_all(b"jkl;").await.unwrap(); - - future::poll_fn(|ctx| Pin::new(&mut stream).poll_shutdown(ctx)) - .await - .unwrap(); - }; - - let client = async { - let mut connector = SslConnector::builder(SslMethod::tls()).unwrap(); - connector.set_ca_file("tests/cert.pem").unwrap(); - let config = connector.build().configure().unwrap(); - - let stream = TcpStream::connect(&addr).await.unwrap(); - let mut stream = tokio_boring::connect(config, "localhost", stream) - .await - .unwrap(); - - stream.write_all(b"asdf").await.unwrap(); - - let mut buf = vec![]; - stream.read_to_end(&mut buf).await.unwrap(); - assert_eq!(buf, b"jkl;"); - }; - - future::join(server, client).await; + with_trivial_client_server_exchange(|_| ()).await; } #[tokio::test] async fn handshake_error() { - let (stream, addr) = create_server(); + let (stream, addr) = create_server(|_| ()); let server = async { let err = stream.await.unwrap_err(); @@ -110,13 +50,7 @@ async fn handshake_error() { }; let client = async { - let connector = SslConnector::builder(SslMethod::tls()).unwrap(); - let config = connector.build().configure().unwrap(); - let stream = TcpStream::connect(&addr).await.unwrap(); - - let err = tokio_boring::connect(config, "localhost", stream) - .await - .unwrap_err(); + let err = connect(addr, |_| Ok(())).await.unwrap_err(); assert!(err.into_source_stream().is_some()); }; diff --git a/tokio-boring/tests/common/mod.rs b/tokio-boring/tests/common/mod.rs new file mode 100644 index 00000000..6ed394ef --- /dev/null +++ b/tokio-boring/tests/common/mod.rs @@ -0,0 +1,96 @@ +#![allow(dead_code)] + +use boring::error::ErrorStack; +use boring::ssl::{ + SslAcceptor, SslAcceptorBuilder, SslConnector, SslConnectorBuilder, SslFiletype, SslMethod, +}; +use futures::future::{self, Future}; +use std::net::SocketAddr; +use std::pin::Pin; +use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream}; +use tokio_boring::{HandshakeError, SslStream}; + +pub(crate) fn create_server( + setup: impl FnOnce(&mut SslAcceptorBuilder), +) -> ( + impl Future, HandshakeError>>, + SocketAddr, +) { + let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + + listener.set_nonblocking(true).unwrap(); + + let listener = TcpListener::from_std(listener).unwrap(); + let addr = listener.local_addr().unwrap(); + + let server = async move { + let mut acceptor = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); + + acceptor + .set_private_key_file("tests/key.pem", SslFiletype::PEM) + .unwrap(); + + acceptor + .set_certificate_chain_file("tests/cert.pem") + .unwrap(); + + setup(&mut acceptor); + + let acceptor = acceptor.build(); + + let stream = listener.accept().await.unwrap().0; + + tokio_boring::accept(&acceptor, stream).await + }; + + (server, addr) +} + +pub(crate) async fn connect( + addr: SocketAddr, + setup: impl FnOnce(&mut SslConnectorBuilder) -> Result<(), ErrorStack>, +) -> Result, HandshakeError> { + let mut connector = SslConnector::builder(SslMethod::tls()).unwrap(); + + setup(&mut connector).unwrap(); + + let config = connector.build().configure().unwrap(); + + let stream = TcpStream::connect(&addr).await.unwrap(); + + tokio_boring::connect(config, "localhost", stream).await +} + +pub(crate) async fn with_trivial_client_server_exchange( + server_setup: impl FnOnce(&mut SslAcceptorBuilder), +) { + let (stream, addr) = create_server(server_setup); + + let server = async { + let mut stream = stream.await.unwrap(); + let mut buf = [0; 4]; + stream.read_exact(&mut buf).await.unwrap(); + assert_eq!(&buf, b"asdf"); + + stream.write_all(b"jkl;").await.unwrap(); + + future::poll_fn(|ctx| Pin::new(&mut stream).poll_shutdown(ctx)) + .await + .unwrap(); + }; + + let client = async { + let mut stream = connect(addr, |builder| builder.set_ca_file("tests/cert.pem")) + .await + .unwrap(); + + stream.write_all(b"asdf").await.unwrap(); + + let mut buf = vec![]; + stream.read_to_end(&mut buf).await.unwrap(); + assert_eq!(buf, b"jkl;"); + }; + + future::join(server, client).await; +} From 7942e71ae8b7882d96236092b629fabc38423935 Mon Sep 17 00:00:00 2001 From: Anthony Ramine Date: Thu, 3 Aug 2023 19:49:35 +0200 Subject: [PATCH 4/4] Remove all connect and accept functions Those functions were blurring the line between setup failures (which are caused by the developer misusing the API) and actual failures encountered on the stream while trying to achieve the TLS handshake, especially on SslConnector and SslAcceptor. Removing them allows for the removal of HandshakeError, as the HandshakeError::SetupFailure variant becomes useless, and there is no real need to distinguish in that error type between Failure and WouldBlock when we can just check the error stored in MidHandshakeSslStream. This then allow us to simplify tokio_boring's own entry points, also making them distinguish between setup failures and failures on the stream. --- boring/src/ssl/connector.rs | 47 +---------- boring/src/ssl/error.rs | 65 --------------- boring/src/ssl/mod.rs | 114 +++++--------------------- boring/src/ssl/test/mod.rs | 70 +++++++++++----- boring/src/ssl/test/server.rs | 6 +- hyper-boring/src/lib.rs | 2 +- hyper-boring/src/test.rs | 10 ++- tokio-boring/examples/simple-async.rs | 2 +- tokio-boring/src/lib.rs | 85 ++++++++++--------- tokio-boring/tests/client_server.rs | 9 +- tokio-boring/tests/common/mod.rs | 6 +- tokio-boring/tests/rpk.rs | 4 +- 12 files changed, 137 insertions(+), 283 deletions(-) diff --git a/boring/src/ssl/connector.rs b/boring/src/ssl/connector.rs index 93da579e..ca7a6bee 100644 --- a/boring/src/ssl/connector.rs +++ b/boring/src/ssl/connector.rs @@ -4,8 +4,8 @@ use std::ops::{Deref, DerefMut}; use crate::dh::Dh; use crate::error::ErrorStack; use crate::ssl::{ - HandshakeError, Ssl, SslContext, SslContextBuilder, SslContextRef, SslMethod, SslMode, - SslOptions, SslRef, SslStream, SslVerifyMode, + Ssl, SslContext, SslContextBuilder, SslContextRef, SslMethod, SslMode, SslOptions, SslRef, + SslVerifyMode, }; use crate::version; @@ -111,21 +111,6 @@ impl SslConnector { self.configure()?.setup_connect(domain, stream) } - /// Attempts a client-side TLS session on a stream. - /// - /// The domain is used for SNI and hostname verification. - /// - /// This is a convenience method which combines [`Self::setup_connect`] and - /// [`MidHandshakeSslStream::handshake`]. - pub fn connect(&self, domain: &str, stream: S) -> Result, HandshakeError> - where - S: Read + Write, - { - self.setup_connect(domain, stream) - .map_err(HandshakeError::SetupFailure)? - .handshake() - } - /// Returns a structure allowing for configuration of a single TLS session before connection. pub fn configure(&self) -> Result { Ssl::new(&self.0).map(|ssl| ConnectConfiguration { @@ -239,21 +224,6 @@ impl ConnectConfiguration { Ok(self.ssl.setup_connect(stream)) } - - /// Attempts a client-side TLS session on a stream. - /// - /// The domain is used for SNI and hostname verification if enabled. - /// - /// This is a convenience method which combines [`Self::setup_connect`] and - /// [`MidHandshakeSslStream::handshake`]. - pub fn connect(self, domain: &str, stream: S) -> Result, HandshakeError> - where - S: Read + Write, - { - self.setup_connect(domain, stream) - .map_err(HandshakeError::SetupFailure)? - .handshake() - } } impl Deref for ConnectConfiguration { @@ -373,19 +343,6 @@ impl SslAcceptor { Ok(ssl.setup_accept(stream)) } - /// Attempts a server-side TLS handshake on a stream. - /// - /// This is a convenience method which combines [`Self::setup_accept`] and - /// [`MidHandshakeSslStream::handshake`]. - pub fn accept(&self, stream: S) -> Result, HandshakeError> - where - S: Read + Write, - { - self.setup_accept(stream) - .map_err(HandshakeError::SetupFailure)? - .handshake() - } - /// Consumes the `SslAcceptor`, returning the inner raw `SslContext`. pub fn into_context(self) -> SslContext { self.0 diff --git a/boring/src/ssl/error.rs b/boring/src/ssl/error.rs index a7e6f9fc..6f7af219 100644 --- a/boring/src/ssl/error.rs +++ b/boring/src/ssl/error.rs @@ -1,13 +1,10 @@ use crate::ffi; use libc::c_int; use std::error; -use std::error::Error as StdError; use std::fmt; use std::io; use crate::error::ErrorStack; -use crate::ssl::MidHandshakeSslStream; -use crate::x509::X509VerifyResult; /// An error code returned from SSL functions. #[derive(Debug, Copy, Clone, PartialEq, Eq)] @@ -130,65 +127,3 @@ impl error::Error for Error { } } } - -/// An error or intermediate state after a TLS handshake attempt. -// FIXME overhaul -#[derive(Debug)] -pub enum HandshakeError { - /// Setup failed. - SetupFailure(ErrorStack), - /// The handshake failed. - Failure(MidHandshakeSslStream), - /// The handshake encountered a `WouldBlock` error midway through. - /// - /// This error will never be returned for blocking streams. - WouldBlock(MidHandshakeSslStream), -} - -impl StdError for HandshakeError { - fn source(&self) -> Option<&(dyn StdError + 'static)> { - match *self { - HandshakeError::SetupFailure(ref e) => Some(e), - HandshakeError::Failure(ref s) | HandshakeError::WouldBlock(ref s) => Some(s.error()), - } - } -} - -impl fmt::Display for HandshakeError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - HandshakeError::SetupFailure(ref e) => { - write!(f, "TLS stream setup failed {}", e) - } - HandshakeError::Failure(ref s) => fmt_mid_handshake_error(s, f, "TLS handshake failed"), - HandshakeError::WouldBlock(ref s) => { - fmt_mid_handshake_error(s, f, "TLS handshake interrupted") - } - } - } -} - -fn fmt_mid_handshake_error( - s: &MidHandshakeSslStream, - f: &mut fmt::Formatter, - prefix: &str, -) -> fmt::Result { - #[cfg(feature = "rpk")] - if s.ssl().ssl_context().is_rpk() { - write!(f, "{}", prefix)?; - return write!(f, " {}", s.error()); - } - - match s.ssl().verify_result() { - X509VerifyResult::OK => write!(f, "{}", prefix)?, - verify => write!(f, "{}: cert verification failed - {}", prefix, verify)?, - } - - write!(f, " {}", s.error()) -} - -impl From for HandshakeError { - fn from(e: ErrorStack) -> HandshakeError { - HandshakeError::SetupFailure(e) - } -} diff --git a/boring/src/ssl/mod.rs b/boring/src/ssl/mod.rs index 40941176..af8230de 100644 --- a/boring/src/ssl/mod.rs +++ b/boring/src/ssl/mod.rs @@ -15,7 +15,11 @@ //! let connector = SslConnector::builder(SslMethod::tls()).unwrap().build(); //! //! let stream = TcpStream::connect("google.com:443").unwrap(); -//! let mut stream = connector.connect("google.com", stream).unwrap(); +//! let mut stream = connector +//! .setup_connect("google.com", stream) +//! .unwrap() +//! .handshake() +//! .unwrap(); //! //! stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap(); //! let mut res = vec![]; @@ -49,7 +53,12 @@ //! Ok(stream) => { //! let acceptor = acceptor.clone(); //! thread::spawn(move || { -//! let stream = acceptor.accept(stream).unwrap(); +//! let stream = acceptor +//! .setup_accept(stream) +//! .unwrap() +//! .handshake() +//! .unwrap(); +//! //! handle_client(stream); //! }); //! } @@ -98,7 +107,7 @@ use crate::{cvt, cvt_0i, cvt_n, cvt_p, init}; pub use crate::ssl::connector::{ ConnectConfiguration, SslAcceptor, SslAcceptorBuilder, SslConnector, SslConnectorBuilder, }; -pub use crate::ssl::error::{Error, ErrorCode, HandshakeError}; +pub use crate::ssl::error::{Error, ErrorCode}; mod bio; mod callbacks; @@ -2324,22 +2333,6 @@ impl Ssl { SslStreamBuilder::new(self, stream).setup_connect() } - /// Attempts a client-side TLS handshake. - /// - /// This is a convenience method which combines [`Self::setup_connect`] and - /// [`MidHandshakeSslStream::handshake`]. - /// - /// # Warning - /// - /// OpenSSL's default configuration is insecure. It is highly recommended to use - /// [`SslConnector`] rather than `Ssl` directly, as it manages that configuration. - pub fn connect(self, stream: S) -> Result, HandshakeError> - where - S: Read + Write, - { - self.setup_connect(stream).handshake() - } - /// Initiates a server-side TLS handshake. /// /// This method is guaranteed to return without calling any callback defined @@ -2372,24 +2365,6 @@ impl Ssl { SslStreamBuilder::new(self, stream).setup_accept() } - - /// Attempts a server-side TLS handshake. - /// - /// This is a convenience method which combines [`Self::setup_accept`] and - /// [`MidHandshakeSslStream::handshake`]. - /// - /// # Warning - /// - /// OpenSSL's default configuration is insecure. It is highly recommended to use - /// `SslAcceptor` rather than `Ssl` directly, as it manages that configuration. - /// - /// [`SSL_accept`]: https://www.openssl.org/docs/manmaster/man3/SSL_accept.html - pub fn accept(self, stream: S) -> Result, HandshakeError> - where - S: Read + Write, - { - self.setup_accept(stream).handshake() - } } impl fmt::Debug for SslRef { @@ -3192,16 +3167,14 @@ impl MidHandshakeSslStream { /// This corresponds to [`SSL_do_handshake`]. /// /// [`SSL_do_handshake`]: https://www.openssl.org/docs/manmaster/man3/SSL_do_handshake.html - pub fn handshake(mut self) -> Result, HandshakeError> { + pub fn handshake(mut self) -> Result, MidHandshakeSslStream> { let ret = unsafe { ffi::SSL_do_handshake(self.stream.ssl.as_ptr()) }; if ret > 0 { Ok(self.stream) } else { self.error = self.stream.make_error(ret); - match self.error.would_block() { - true => Err(HandshakeError::WouldBlock(self)), - false => Err(HandshakeError::Failure(self)), - } + + Err(self) } } } @@ -3484,7 +3457,7 @@ where /// This corresponds to [`SSL_set_connect_state`]. /// /// [`SSL_set_connect_state`]: https://www.openssl.org/docs/manmaster/man3/SSL_set_connect_state.html - pub fn set_connect_state(&mut self) { + fn set_connect_state(&mut self) { unsafe { ffi::SSL_set_connect_state(self.inner.ssl.as_ptr()) } } @@ -3493,15 +3466,14 @@ where /// This corresponds to [`SSL_set_accept_state`]. /// /// [`SSL_set_accept_state`]: https://www.openssl.org/docs/manmaster/man3/SSL_set_accept_state.html - pub fn set_accept_state(&mut self) { + fn set_accept_state(&mut self) { unsafe { ffi::SSL_set_accept_state(self.inner.ssl.as_ptr()) } } /// Initiates a client-side TLS handshake, returning a [`MidHandshakeSslStream`]. /// - /// This method calls [`Self::set_connect_state`] and returns without actually - /// initiating the handshake. The caller is then free to call - /// [`MidHandshakeSslStream`] and loop on [`HandshakeError::WouldBlock`]. + /// The caller is then free to call [`MidHandshakeSslStream::handshake`] and retry + /// on blocking errors. pub fn setup_connect(mut self) -> MidHandshakeSslStream { self.set_connect_state(); @@ -3517,19 +3489,10 @@ where } } - /// Attempts a client-side TLS handshake. - /// - /// This is a convenience method which combines [`Self::setup_connect`] and - /// [`MidHandshakeSslStream::handshake`]. - pub fn connect(self) -> Result, HandshakeError> { - self.setup_connect().handshake() - } - /// Initiates a server-side TLS handshake, returning a [`MidHandshakeSslStream`]. /// - /// This method calls [`Self::set_accept_state`] and returns without actually - /// initiating the handshake. The caller is then free to call - /// [`MidHandshakeSslStream`] and loop on [`HandshakeError::WouldBlock`]. + /// The caller is then free to call [`MidHandshakeSslStream::handshake`] and retry + /// on blocking errors. pub fn setup_accept(mut self) -> MidHandshakeSslStream { self.set_accept_state(); @@ -3544,41 +3507,6 @@ where }, } } - - /// Attempts a server-side TLS handshake. - /// - /// This is a convenience method which combines [`Self::setup_accept`] and - /// [`MidHandshakeSslStream::handshake`]. - pub fn accept(self) -> Result, HandshakeError> { - self.setup_accept().handshake() - } - - /// Initiates the handshake. - /// - /// This will fail if `set_accept_state` or `set_connect_state` was not called first. - /// - /// This corresponds to [`SSL_do_handshake`]. - /// - /// [`SSL_do_handshake`]: https://www.openssl.org/docs/manmaster/man3/SSL_do_handshake.html - pub fn handshake(self) -> Result, HandshakeError> { - let mut stream = self.inner; - let ret = unsafe { ffi::SSL_do_handshake(stream.ssl.as_ptr()) }; - if ret > 0 { - Ok(stream) - } else { - let error = stream.make_error(ret); - match error.would_block() { - true => Err(HandshakeError::WouldBlock(MidHandshakeSslStream { - stream, - error, - })), - false => Err(HandshakeError::Failure(MidHandshakeSslStream { - stream, - error, - })), - } - } - } } impl SslStreamBuilder { diff --git a/boring/src/ssl/test/mod.rs b/boring/src/ssl/test/mod.rs index e72dc086..3d76fe06 100644 --- a/boring/src/ssl/test/mod.rs +++ b/boring/src/ssl/test/mod.rs @@ -25,10 +25,9 @@ use crate::ssl; use crate::ssl::test::server::Server; use crate::ssl::SslVersion; use crate::ssl::{ - Error, ExtensionType, HandshakeError, MidHandshakeSslStream, ShutdownResult, ShutdownState, - Ssl, SslAcceptor, SslAcceptorBuilder, SslConnector, SslContext, SslContextBuilder, SslFiletype, - SslMethod, SslOptions, SslSessionCacheMode, SslStream, SslStreamBuilder, SslVerifyMode, - StatusType, + Error, ExtensionType, MidHandshakeSslStream, ShutdownResult, ShutdownState, Ssl, SslAcceptor, + SslAcceptorBuilder, SslConnector, SslContext, SslContextBuilder, SslFiletype, SslMethod, + SslOptions, SslSessionCacheMode, SslStream, SslStreamBuilder, SslVerifyMode, StatusType, }; use crate::x509::store::X509StoreBuilder; use crate::x509::verify::X509CheckFlags; @@ -317,7 +316,7 @@ fn test_connect_with_srtp_ctx() { .unwrap(); let mut ssl = Ssl::new(&ctx.build()).unwrap(); ssl.set_mtu(1500).unwrap(); - let mut stream = ssl.accept(stream).unwrap(); + let mut stream = ssl.setup_accept(stream).handshake().unwrap(); let mut buf = [0; 60]; stream @@ -336,7 +335,7 @@ fn test_connect_with_srtp_ctx() { .unwrap(); let mut ssl = Ssl::new(&ctx.build()).unwrap(); ssl.set_mtu(1500).unwrap(); - let mut stream = ssl.connect(stream).unwrap(); + let mut stream = ssl.setup_connect(stream).handshake().unwrap(); let mut buf = [1; 60]; { @@ -386,7 +385,7 @@ fn test_connect_with_srtp_ssl() { profilenames ); ssl.set_mtu(1500).unwrap(); - let mut stream = ssl.accept(stream).unwrap(); + let mut stream = ssl.setup_accept(stream).handshake().unwrap(); let mut buf = [0; 60]; stream @@ -405,7 +404,7 @@ fn test_connect_with_srtp_ssl() { ssl.set_tlsext_use_srtp("SRTP_AES128_CM_SHA1_80:SRTP_AES128_CM_SHA1_32") .unwrap(); ssl.set_mtu(1500).unwrap(); - let mut stream = ssl.connect(stream).unwrap(); + let mut stream = ssl.setup_connect(stream).handshake().unwrap(); let mut buf = [1; 60]; { @@ -584,7 +583,10 @@ fn write_panic() { let stream = ExplodingStream(server.connect_tcp()); let ctx = SslContext::builder(SslMethod::tls()).unwrap(); - let _ = Ssl::new(&ctx.build()).unwrap().connect(stream); + let _ = Ssl::new(&ctx.build()) + .unwrap() + .setup_connect(stream) + .handshake(); } #[test] @@ -615,7 +617,10 @@ fn read_panic() { let stream = ExplodingStream(server.connect_tcp()); let ctx = SslContext::builder(SslMethod::tls()).unwrap(); - let _ = Ssl::new(&ctx.build()).unwrap().connect(stream); + let _ = Ssl::new(&ctx.build()) + .unwrap() + .setup_connect(stream) + .handshake(); } #[test] @@ -646,7 +651,10 @@ fn flush_panic() { let stream = ExplodingStream(server.connect_tcp()); let ctx = SslContext::builder(SslMethod::tls()).unwrap(); - let _ = Ssl::new(&ctx.build()).unwrap().connect(stream); + let _ = Ssl::new(&ctx.build()) + .unwrap() + .setup_connect(stream) + .handshake(); } #[test] @@ -676,7 +684,7 @@ fn default_verify_paths() { }; let mut ssl = Ssl::new(&ctx).unwrap(); ssl.set_hostname("google.com").unwrap(); - let mut socket = ssl.connect(s).unwrap(); + let mut socket = ssl.setup_connect(s).handshake().unwrap(); socket.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap(); let mut result = vec![]; @@ -738,7 +746,12 @@ fn connector_valid_hostname() { connector.set_ca_file("test/root-ca.pem").unwrap(); let s = server.connect_tcp(); - let mut s = connector.build().connect("foobar.com", s).unwrap(); + let mut s = connector + .build() + .setup_connect("foobar.com", s) + .unwrap() + .handshake() + .unwrap(); s.read_exact(&mut [0]).unwrap(); } @@ -752,7 +765,12 @@ fn connector_invalid_hostname() { connector.set_ca_file("test/root-ca.pem").unwrap(); let s = server.connect_tcp(); - connector.build().connect("bogus.com", s).unwrap_err(); + connector + .build() + .setup_connect("bogus.com", s) + .unwrap() + .handshake() + .unwrap_err(); } #[test] @@ -768,7 +786,9 @@ fn connector_invalid_no_hostname_verification() { .configure() .unwrap() .verify_hostname(false) - .connect("bogus.com", s) + .setup_connect("bogus.com", s) + .unwrap() + .handshake() .unwrap(); s.read_exact(&mut [0]).unwrap(); } @@ -786,7 +806,9 @@ fn connector_no_hostname_still_verifies() { .configure() .unwrap() .verify_hostname(false) - .connect("fizzbuzz.com", s) + .setup_connect("fizzbuzz.com", s) + .unwrap() + .handshake() .is_err()); } @@ -803,7 +825,9 @@ fn connector_no_hostname_can_disable_verify() { .configure() .unwrap() .verify_hostname(false) - .connect("foobar.com", s) + .setup_connect("foobar.com", s) + .unwrap() + .handshake() .unwrap(); s.read_exact(&mut [0]).unwrap(); } @@ -820,7 +844,7 @@ fn test_mozilla_server(new: fn(SslMethod) -> Result Result SslStream { let socket = TcpStream::connect(self.addr).unwrap(); - let mut s = self.ssl.connect(socket).unwrap(); + let mut s = self.ssl.setup_connect(socket).handshake().unwrap(); s.read_exact(&mut [0]).unwrap(); s } pub fn connect_err(self) { let socket = TcpStream::connect(self.addr).unwrap(); - self.ssl.connect(socket).unwrap_err(); + self.ssl.setup_connect(socket).handshake().unwrap_err(); } } diff --git a/hyper-boring/src/lib.rs b/hyper-boring/src/lib.rs index d6e82c1b..7d1ad16d 100644 --- a/hyper-boring/src/lib.rs +++ b/hyper-boring/src/lib.rs @@ -245,7 +245,7 @@ where } let config = inner.setup_ssl(&uri, host)?; - let stream = tokio_boring::connect(config, host, conn).await?; + let stream = tokio_boring::connect(config, host, conn)?.await?; Ok(MaybeHttpsStream::Https(stream)) }; diff --git a/hyper-boring/src/test.rs b/hyper-boring/src/test.rs index 226a0487..73c9d664 100644 --- a/hyper-boring/src/test.rs +++ b/hyper-boring/src/test.rs @@ -43,7 +43,10 @@ async fn localhost() { for _ in 0..3 { let stream = listener.accept().await.unwrap().0; - let stream = tokio_boring::accept(&acceptor, stream).await.unwrap(); + let stream = tokio_boring::accept(&acceptor, stream) + .unwrap() + .await + .unwrap(); let service = service::service_fn(|_| async { Ok::<_, io::Error>(Response::new(Body::empty())) }); @@ -105,7 +108,10 @@ async fn alpn_h2() { let acceptor = acceptor.build(); let stream = listener.accept().await.unwrap().0; - let stream = tokio_boring::accept(&acceptor, stream).await.unwrap(); + let stream = tokio_boring::accept(&acceptor, stream) + .unwrap() + .await + .unwrap(); assert_eq!(stream.ssl().selected_alpn_protocol().unwrap(), b"h2"); let service = diff --git a/tokio-boring/examples/simple-async.rs b/tokio-boring/examples/simple-async.rs index f4a69a1c..d9b159de 100644 --- a/tokio-boring/examples/simple-async.rs +++ b/tokio-boring/examples/simple-async.rs @@ -11,6 +11,6 @@ async fn main() -> anyhow::Result<()> { ssl_builder.set_default_verify_paths()?; ssl_builder.set_verify(ssl::SslVerifyMode::PEER); let acceptor = ssl_builder.build(); - let _ssl_stream = tokio_boring::accept(&acceptor, tcp_stream).await?; + let _ssl_stream = tokio_boring::accept(&acceptor, tcp_stream)?.await?; Ok(()) } diff --git a/tokio-boring/src/lib.rs b/tokio-boring/src/lib.rs index f594231d..6cff31bd 100644 --- a/tokio-boring/src/lib.rs +++ b/tokio-boring/src/lib.rs @@ -18,6 +18,7 @@ use boring::ssl::{ self, ConnectConfiguration, ErrorCode, MidHandshakeSslStream, ShutdownResult, SslAcceptor, SslRef, }; +use boring::x509::X509VerifyResult; use boring_sys as ffi; use std::error::Error; use std::fmt; @@ -28,38 +29,35 @@ use std::task::{Context, Poll}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; /// Asynchronously performs a client-side TLS handshake over the provided stream. -pub async fn connect( +pub fn connect( config: ConnectConfiguration, domain: &str, stream: S, -) -> Result, HandshakeError> +) -> Result, ErrorStack> where S: AsyncRead + AsyncWrite + Unpin, { - handshake(|s| config.setup_connect(domain, s), stream).await + handshake(|s| config.setup_connect(domain, s), stream) } /// Asynchronously performs a server-side TLS handshake over the provided stream. -pub async fn accept(acceptor: &SslAcceptor, stream: S) -> Result, HandshakeError> +pub fn accept(acceptor: &SslAcceptor, stream: S) -> Result, ErrorStack> where S: AsyncRead + AsyncWrite + Unpin, { - handshake(|s| acceptor.setup_accept(s), stream).await + handshake(|s| acceptor.setup_accept(s), stream) } -async fn handshake( +fn handshake( f: impl FnOnce(StreamWrapper) -> Result>, ErrorStack>, stream: S, -) -> Result, HandshakeError> +) -> Result, ErrorStack> where S: AsyncRead + AsyncWrite + Unpin, { - let ongoing_handshake = Some( - f(StreamWrapper { stream, context: 0 }) - .map_err(|err| HandshakeError(ssl::HandshakeError::SetupFailure(err)))?, - ); + let ongoing_handshake = Some(f(StreamWrapper { stream, context: 0 })?); - HandshakeFuture(ongoing_handshake).await + Ok(HandshakeFuture(ongoing_handshake)) } struct StreamWrapper { @@ -265,47 +263,32 @@ where } /// The error type returned after a failed handshake. -pub struct HandshakeError(ssl::HandshakeError>); +pub struct HandshakeError(MidHandshakeSslStream>); impl HandshakeError { /// Returns a shared reference to the `Ssl` object associated with this error. - pub fn ssl(&self) -> Option<&SslRef> { - match &self.0 { - ssl::HandshakeError::Failure(s) => Some(s.ssl()), - _ => None, - } + pub fn ssl(&self) -> &SslRef { + self.0.ssl() } /// Converts error to the source data stream that was used for the handshake. - pub fn into_source_stream(self) -> Option { - match self.0 { - ssl::HandshakeError::Failure(s) => Some(s.into_source_stream().stream), - _ => None, - } + pub fn into_source_stream(self) -> S { + self.0.into_source_stream().stream } /// Returns a reference to the source data stream. - pub fn as_source_stream(&self) -> Option<&S> { - match &self.0 { - ssl::HandshakeError::Failure(s) => Some(&s.get_ref().stream), - _ => None, - } + pub fn as_source_stream(&self) -> &S { + &self.0.get_ref().stream } /// Returns the error code, if any. - pub fn code(&self) -> Option { - match &self.0 { - ssl::HandshakeError::Failure(s) => Some(s.error().code()), - _ => None, - } + pub fn code(&self) -> ErrorCode { + self.0.error().code() } /// Returns a reference to the inner I/O error, if any. pub fn as_io_error(&self) -> Option<&io::Error> { - match &self.0 { - ssl::HandshakeError::Failure(s) => s.error().io_error(), - _ => None, - } + self.0.error().io_error() } } @@ -320,7 +303,20 @@ where impl fmt::Display for HandshakeError { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt::Display::fmt(&self.0, fmt) + fmt.write_str("TLS handshake failed: ")?; + + #[cfg(feature = "rpk")] + if self.0.ssl().ssl_context().is_rpk() { + return write!(fmt, " {}", self.0.error()); + } + + let verify = self.0.ssl().verify_result(); + + if verify != X509VerifyResult::OK { + write!(fmt, "cert verification failed - {verify}")?; + } + + write!(fmt, "TLS handshake failed: {}", self.0.error()) } } @@ -329,11 +325,14 @@ where S: fmt::Debug, { fn source(&self) -> Option<&(dyn Error + 'static)> { - self.0.source() + self.0.error().source() } } -struct HandshakeFuture(Option>>); +/// Future for an ongoing TLS handshake. +/// +/// See [`connect`] and [`accept`]. +pub struct HandshakeFuture(Option>>); impl Future for HandshakeFuture where @@ -352,10 +351,10 @@ where Poll::Ready(Ok(SslStream(s))) } - Err(ssl::HandshakeError::WouldBlock(mut s)) => { - s.get_mut().context = 0; + Err(mut handshake) if handshake.error().would_block() => { + handshake.get_mut().context = 0; - self.0 = Some(s); + self.0 = Some(handshake); Poll::Pending } diff --git a/tokio-boring/tests/client_server.rs b/tokio-boring/tests/client_server.rs index 925f9875..872033f8 100644 --- a/tokio-boring/tests/client_server.rs +++ b/tokio-boring/tests/client_server.rs @@ -19,6 +19,7 @@ async fn google() { .configure() .unwrap(); let mut stream = tokio_boring::connect(config, "google.com", stream) + .unwrap() .await .unwrap(); @@ -44,15 +45,11 @@ async fn handshake_error() { let (stream, addr) = create_server(|_| ()); let server = async { - let err = stream.await.unwrap_err(); - - assert!(err.into_source_stream().is_some()); + let _err = stream.await.unwrap_err(); }; let client = async { - let err = connect(addr, |_| Ok(())).await.unwrap_err(); - - assert!(err.into_source_stream().is_some()); + let _err = connect(addr, |_| Ok(())).await.unwrap_err(); }; future::join(server, client).await; diff --git a/tokio-boring/tests/common/mod.rs b/tokio-boring/tests/common/mod.rs index 6ed394ef..fdcf86fb 100644 --- a/tokio-boring/tests/common/mod.rs +++ b/tokio-boring/tests/common/mod.rs @@ -41,7 +41,7 @@ pub(crate) fn create_server( let stream = listener.accept().await.unwrap().0; - tokio_boring::accept(&acceptor, stream).await + tokio_boring::accept(&acceptor, stream).unwrap().await }; (server, addr) @@ -59,7 +59,9 @@ pub(crate) async fn connect( let stream = TcpStream::connect(&addr).await.unwrap(); - tokio_boring::connect(config, "localhost", stream).await + tokio_boring::connect(config, "localhost", stream) + .unwrap() + .await } pub(crate) async fn with_trivial_client_server_exchange( diff --git a/tokio-boring/tests/rpk.rs b/tokio-boring/tests/rpk.rs index 5492767a..48f53655 100644 --- a/tokio-boring/tests/rpk.rs +++ b/tokio-boring/tests/rpk.rs @@ -34,7 +34,7 @@ mod test_rpk { let stream = listener.accept().await.unwrap().0; - tokio_boring::accept(&acceptor, stream).await + tokio_boring::accept(&acceptor, stream).unwrap().await }; (server, addr) @@ -66,6 +66,7 @@ mod test_rpk { let stream = TcpStream::connect(&addr).await.unwrap(); let mut stream = tokio_boring::connect(config, "localhost", stream) + .unwrap() .await .unwrap(); @@ -97,6 +98,7 @@ mod test_rpk { let stream = TcpStream::connect(&addr).await.unwrap(); let err = tokio_boring::connect(config, "localhost", stream) + .unwrap() .await .unwrap_err();