From fe55a47dd65256ea69795246834a4551e86c9647 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20Zwoli=C5=84ski?= Date: Thu, 27 Apr 2023 14:11:37 +0200 Subject: [PATCH] Add implementation for mbedtls - Initial work by @MabezDev Co-authored-by: Scott Mabin --- Cargo.toml | 5 +- src/imp/mbedtls.rs | 532 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 536 insertions(+), 1 deletion(-) create mode 100644 src/imp/mbedtls.rs diff --git a/Cargo.toml b/Cargo.toml index f63b0223..57840e7e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,12 +26,15 @@ tempfile = "3.1.0" [target.'cfg(target_os = "windows")'.dependencies] schannel = "0.1.17" -[target.'cfg(not(any(target_os = "windows", target_os = "macos", target_os = "ios")))'.dependencies] +[target.'cfg(not(any(target_os = "windows", target_os = "macos", target_os = "ios", target_os = "espidf")))'.dependencies] log = "0.4.5" openssl = "0.10.29" openssl-sys = "0.9.55" openssl-probe = "0.1" +[target.'cfg(target_os = "espidf")'.dependencies] +mbedtls = { version = "0.8.1", features = ["pkcs12", "std"], path = "/home/mabez/development/rust/embedded/util/rust-mbedtls/mbedtls" } + [dev-dependencies] tempfile = "3.0" test-cert-gen = "0.9" diff --git a/src/imp/mbedtls.rs b/src/imp/mbedtls.rs new file mode 100644 index 00000000..26f045dc --- /dev/null +++ b/src/imp/mbedtls.rs @@ -0,0 +1,532 @@ +extern crate mbedtls; + +use self::mbedtls::ssl::context::IoCallback; + +use self::mbedtls::alloc::{Box as MbedtlsBox, List as MbedtlsList}; +use self::mbedtls::hash::{Md, Type as MdType}; +use self::mbedtls::pk::Pk; +use self::mbedtls::pkcs12::{Pfx, Pkcs12Error}; +use self::mbedtls::rng::{CtrDrbg, OsEntropy}; +use self::mbedtls::ssl::config::{Endpoint, Preset, Transport}; +use self::mbedtls::ssl::{Config, Context, Version}; +use self::mbedtls::x509::certificate::Certificate as MbedtlsCert; +use self::mbedtls::Error as TlsError; +use self::mbedtls::Result as TlsResult; + +use std::error; +use std::fmt::{self, Debug}; +use std::fs; +use std::io::{self, Read}; +use std::sync::Arc; + +use {Protocol, TlsAcceptorBuilder, TlsConnectorBuilder}; + +fn load_ca_certs(dir: &str) -> TlsResult> { + let paths = fs::read_dir(dir).map_err(|_| TlsError::X509FileIoError)?; + + let mut certs = Vec::new(); + + for path in paths { + if let Ok(mut file) = fs::File::open(path.unwrap().path()) { + let mut contents = Vec::new(); + if let Ok(_) = file.read_to_end(&mut contents) { + contents.push(0); // needs NULL terminator + if let Ok(cert) = ::Certificate::from_pem(&contents) { + certs.push(cert); + } + } + } + } + + Ok(certs) +} + +fn load_system_trust_roots() -> Result, Error> { + let paths = [ + "/etc/pki/CA/certs", // Fedora, RHEL + "/usr/share/ca-certificates/mozilla", // Ubuntu, Debian, Arch, Gentoo + ]; + + for path in paths.iter() { + if let Ok(certs) = load_ca_certs(path) { + return Ok(certs); + } + } + + Err(Error::Custom( + "Could not load system default trust roots".to_owned(), + )) +} + +#[derive(Debug)] +pub enum Error { + Normal(TlsError), + Pkcs12(Pkcs12Error), + Custom(String), +} + +#[derive(Debug, Copy, Clone)] +enum ProtocolRole { + Client, + Server, +} + +impl From for Error { + fn from(err: TlsError) -> Error { + Error::Normal(err) + } +} + +impl From for Error { + fn from(err: Pkcs12Error) -> Error { + Error::Pkcs12(err) + } +} + +impl From for HandshakeError { + fn from(e: TlsError) -> HandshakeError { + HandshakeError::Failure(e.into()) + } +} + +impl error::Error for Error { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { + // error::Error::source(&self) + todo!() + } +} + +impl fmt::Display for Error { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + match *self { + Error::Normal(ref e) => fmt::Display::fmt(e, fmt), + Error::Pkcs12(ref e) => fmt::Display::fmt(e, fmt), + Error::Custom(ref e) => fmt::Display::fmt(e, fmt), + } + } +} + +fn map_version(protocol: Option) -> Option { + if let Some(protocol) = protocol { + match protocol { + Protocol::Sslv3 => Some(Version::Ssl3), + Protocol::Tlsv10 => Some(Version::Tls1_0), + Protocol::Tlsv11 => Some(Version::Tls1_1), + Protocol::Tlsv12 => Some(Version::Tls1_2), + _ => None, + } + } else { + None + } +} + +pub struct Identity(Pfx); + +impl Identity { + pub fn from_pkcs12(buf: &[u8], pass: &str) -> Result { + let pkcs12 = Pfx::parse(buf).map_err(Error::Pkcs12)?; + let decrypted = pkcs12.decrypt(&pass, None).map_err(Error::Pkcs12)?; + Ok(Identity(decrypted)) + } +} + +impl Clone for Identity { + fn clone(&self) -> Self { + Identity(self.0.clone()) + } +} + +#[derive(Clone)] +pub struct Certificate(MbedtlsBox); +unsafe impl Sync for Certificate {} + +impl Certificate { + pub fn from_der(buf: &[u8]) -> Result { + let cert = MbedtlsCert::from_der(buf).map_err(Error::Normal)?; + Ok(Certificate(cert)) + } + + pub fn from_pem(buf: &[u8]) -> Result { + // Mbedtls needs there to be a trailing NULL byte ... + let mut pem = buf.to_vec(); + pem.push(0); + let cert = MbedtlsCert::from_pem(&pem).map_err(Error::Normal)?; + Ok(Certificate(cert)) + } + + pub fn to_der(&self) -> Result, Error> { + let der = self.0.as_der().to_vec(); + Ok(der) + } +} + +fn cert_to_vec(certs_in: &[::Certificate]) -> Vec> { + certs_in.iter().map(|cert| (cert.0).0.clone()).collect() +} + +#[allow(unused)] +pub struct TlsStream { + role: ProtocolRole, + ca_certs: Vec>, + ca_cert_list: Arc>, + cred_pk: Option>, + cred_certs: Vec>, + cred_cert_list: Arc>, + entropy: Arc, + rng: Arc, + config: Arc, + ctx: Context, +} + +unsafe impl Sync for TlsStream {} +unsafe impl Send for TlsStream {} + +impl Debug for TlsStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TlsStream") + .field("role", &self.role) + .field("ca_certs", &self.ca_certs) + // .field("ca_cert_list", &self.ca_cert_list) + // .field("cred_pk", &self.cred_pk) + // .field("cred_certs", &self.cred_certs) + // .field("cred_cert_list", &self.cred_cert_list) + // .field("entropy", &self.entropy) + // .field("rng", &self.rng) + // .field("config", &self.config) + // .field("ctx", &self.ctx) + // .field("stream", &self.stream) + .finish() + } +} + +#[derive(Debug)] +pub struct MidHandshakeTlsStream { + stream: TlsStream, + error: Error, +} + +pub enum HandshakeError { + Failure(Error), + WouldBlock(MidHandshakeTlsStream), +} + +impl MidHandshakeTlsStream { + pub fn get_ref(&self) -> &S { + self.stream.get_ref() + } + + pub fn get_mut(&mut self) -> &mut S { + self.stream.get_mut() + } +} + +impl MidHandshakeTlsStream +where + S: io::Read + io::Write, +{ + pub fn handshake(self) -> Result, HandshakeError> { + Ok(self.stream) + } +} + +#[derive(Clone)] +pub struct TlsConnector { + min_protocol: Option, + max_protocol: Option, + root_certificates: Vec<::Certificate>, + identity: Option<::Identity>, + accept_invalid_certs: bool, + accept_invalid_hostnames: bool, + use_sni: bool, +} + +impl Debug for TlsConnector { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TlsConnector") + .field("min_protocol", &self.min_protocol) + .field("max_protocol", &self.max_protocol) + // .field("root_certificates", &self.root_certificates) + // .field("identity", &self.identity) + .field("accept_invalid_certs", &self.accept_invalid_certs) + .field("accept_invalid_hostnames", &self.accept_invalid_hostnames) + .field("use_sni", &self.use_sni) + .finish() + } +} + +impl TlsConnector { + pub fn new(builder: &TlsConnectorBuilder) -> Result { + let trust_roots = if builder.root_certificates.len() > 0 { + builder.root_certificates.clone() + } else { + load_system_trust_roots()? + }; + + Ok(TlsConnector { + min_protocol: builder.min_protocol, + max_protocol: builder.max_protocol, + root_certificates: trust_roots, + identity: builder.identity.clone(), + accept_invalid_certs: builder.accept_invalid_certs, + accept_invalid_hostnames: builder.accept_invalid_hostnames, + use_sni: builder.use_sni, + }) + } + + pub fn connect(&self, domain: &str, stream: S) -> Result, HandshakeError> + where + S: IoCallback, + { + println!("CONNECTING IN MBETLS"); + let identity = if let Some(identity) = &self.identity { + let mut keys = (identity.0).0.private_keys().collect::>(); + let certificates = (identity.0).0.certificates().collect::>(); + + if keys.len() != 1 { + return Err(HandshakeError::Failure(Error::Custom( + "Unexpected number of keys in PKCS12 file".to_owned(), + ))); + } + if certificates.len() == 0 { + return Err(HandshakeError::Failure(Error::Custom( + "PKCS12 file is missing certificate chain".to_owned(), + ))); + } + + let mut cert_chain = vec![]; + for cert in certificates { + cert_chain.push(cert.0?); + } + + fn pk_clone(pk: &mut Pk) -> TlsResult { + let der = pk.write_private_der_vec()?; + Pk::from_private_key(&der, None) + } + let key = Box::new(keys.pop().unwrap().0.map_err(|_| TlsError::PkInvalidAlg)?); + + Some((cert_chain, key)) + } else { + None + }; + + let ca_vec = cert_to_vec(&self.root_certificates); + let mut ca_list = MbedtlsList::new(); + ca_vec.clone().into_iter().for_each(|c| ca_list.push(c)); + let ca_list = Arc::new(ca_list); + + let entropy = Arc::new(OsEntropy::new()); + let rng = Arc::new(CtrDrbg::new(entropy.clone(), None)?); + let mut config = Config::new(Endpoint::Client, Transport::Stream, Preset::Default); + config.set_rng(rng.clone()); + config.set_ca_list(ca_list.clone(), None); + + let mut cred_certs = Default::default(); + let mut cred_cert_list = Arc::new(MbedtlsList::new()); + let mut cred_pk = None; + + if let Some((certificates, mut pk)) = identity { + cred_certs = certificates.to_vec(); + let mut tmp = MbedtlsList::new(); + cred_certs.clone().into_iter().for_each(|c| tmp.push(c)); + cred_cert_list = Arc::new(tmp); + + let cpk = Arc::new(Pk::from_private_key(&pk.write_private_der_vec()?, None)?); + cred_pk = Some(cpk.clone()); + + config.push_cert(cred_cert_list.clone(), cpk.clone())?; + } + + if self.accept_invalid_certs { + config.set_authmode(mbedtls::ssl::config::AuthMode::None); + } + + if let Some(min_version) = map_version(self.min_protocol) { + config.set_min_version(min_version)?; + } + if let Some(max_version) = map_version(self.max_protocol) { + config.set_max_version(max_version)?; + } + + let config = Arc::new(config); + let mut ctx = Context::new(config.clone()); + + let hostname = if self.accept_invalid_hostnames { + None + } else { + Some(domain) + }; + + ctx.establish(stream, hostname)?; + + Ok(TlsStream { + role: ProtocolRole::Client, + ca_certs: ca_vec, + ca_cert_list: ca_list, + cred_pk: cred_pk, + cred_certs: cred_certs, + cred_cert_list: cred_cert_list, + entropy, + rng, + config, + ctx, + }) + } +} + +#[derive(Clone)] +pub struct TlsAcceptor { + identity: Pfx, + min_protocol: Option, + max_protocol: Option, +} + +impl TlsAcceptor { + pub fn new(builder: &TlsAcceptorBuilder) -> Result { + Ok(TlsAcceptor { + identity: (builder.identity.0).0.clone(), + min_protocol: builder.min_protocol, + max_protocol: builder.max_protocol, + }) + } + + pub fn accept(&self, stream: S) -> Result, HandshakeError> + where + S: IoCallback, + { + println!("ACCEPTING IN MBETLS"); + let mut keys = self.identity.private_keys().collect::>(); + let certificates = self.identity.certificates().collect::>(); + + if keys.len() != 1 { + return Err(HandshakeError::Failure(Error::Custom( + "Unexpected number of keys in PKCS12 file".to_owned(), + ))); + } + if certificates.len() == 0 { + return Err(HandshakeError::Failure(Error::Custom( + "PKCS12 file is missing certificate chain".to_owned(), + ))); + } + + let mut cert_chain = vec![]; + for cert in certificates { + cert_chain.push(cert.0?); + } + + let key: &mut Pk = &mut keys.pop().unwrap().0.map_err(|_| TlsError::PkInvalidAlg)?; + + let pk = Arc::new(Pk::from_private_key(&key.write_private_der_vec()?, None)?); + let mut cert_list = MbedtlsList::new(); + cert_chain + .to_vec() + .into_iter() + .for_each(|c| cert_list.push(c)); + let cert_list = Arc::new(cert_list); + + let entropy = Arc::new(OsEntropy::new()); + let rng = Arc::new(CtrDrbg::new(entropy.clone(), None)?); + let mut config = Config::new(Endpoint::Server, Transport::Stream, Preset::Default); + config.set_rng(rng.clone()); + config.push_cert(cert_list.clone(), pk.clone())?; + + if let Some(min_version) = map_version(self.min_protocol) { + config.set_min_version(min_version)?; + } + if let Some(max_version) = map_version(self.max_protocol) { + config.set_max_version(max_version)?; + } + + let config = Arc::new(config); + + let mut ctx = Context::new(config.clone()); + + ctx.establish(stream, None)?; + + Ok(TlsStream { + role: ProtocolRole::Server, + ca_certs: Vec::new(), + ca_cert_list: Arc::new(MbedtlsList::new()), + cred_pk: Some(pk), + cred_certs: cert_chain, + cred_cert_list: cert_list, + entropy, + rng, + config, + ctx, + }) + } +} + +impl TlsStream { + pub fn get_ref(&self) -> &S { + self.ctx.io().unwrap() + } + + pub fn get_mut(&mut self) -> &mut S { + self.ctx.io_mut().unwrap() + } + + pub fn buffered_read_size(&self) -> Result { + Ok(self.ctx.bytes_available()) + } + + pub fn peer_certificate(&self) -> Result, Error> { + match self.ctx.peer_cert()? { + None => Ok(None), + Some(certs) => match certs.iter().next() { + None => Ok(None), + Some(c) => Ok(Some(Certificate::from_der(c.as_der())?)), + }, + } + } + + fn server_certificate(&self) -> Result, Error> { + match self.role { + ProtocolRole::Client => self.peer_certificate(), + ProtocolRole::Server => match self.cred_certs.first() { + None => Ok(None), + Some(c) => Ok(Some(Certificate::from_der(c.as_der())?)), + }, + } + } + + pub fn tls_server_end_point(&self) -> Result>, Error> { + let cert = match self.server_certificate()? { + Some(cert) => cert, + None => return Ok(None), + }; + + let md = match cert.0.digest_type() { + MdType::Md5 | MdType::Sha1 => MdType::Sha256, + md => md, + }; + + let der = cert.to_der()?; + let mut digest = vec![0; 64]; + let len = Md::hash(md, &der, &mut digest).map_err(Error::Normal)?; + digest.truncate(len); + + Ok(Some(digest)) + } + + pub fn shutdown(&mut self) -> io::Result<()> { + // Shutdown happens as a result of drop ... + Ok(()) + } +} + +impl io::Read for TlsStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.ctx.read(buf) + } +} + +impl io::Write for TlsStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.ctx.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + self.ctx.flush() + } +}