Skip to content

Commit

Permalink
Add missing where clause to guarentee correct derivation of Send
Browse files Browse the repository at this point in the history
  • Loading branch information
MabezDev committed Dec 10, 2021
1 parent 60833da commit e03e46f
Show file tree
Hide file tree
Showing 2 changed files with 368 additions and 1 deletion.
4 changes: 3 additions & 1 deletion mbedtls/src/wrapper_macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,9 @@ macro_rules! define_struct {
);

as_item!(
unsafe impl<$($g)*> Send for $name<$($g)*> {}
unsafe impl<$($g)*> Send for $name<$($g)*>
where $($g: Send)*
{}
);
};

Expand Down
365 changes: 365 additions & 0 deletions mbedtls/test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,365 @@
pub mod context {
use core::any::Any;
use core::result::Result as StdResult;
#[cfg(feature = "std")]
use std::io::{Read, Write, Result as IoResult};
#[cfg(feature = "std")]
use std::sync::Arc;
use mbedtls_sys::types::raw_types::{c_int, c_uchar, c_void};
use mbedtls_sys::types::size_t;
use mbedtls_sys::*;
use crate::alloc::{List as MbedtlsList};
use crate::error::{Error, Result, IntoResult};
use crate::pk::Pk;
use crate::private::UnsafeFrom;
use crate::ssl::config::{Config, Version, AuthMode};
use crate::x509::{Certificate, Crl, VerifyError};
pub trait IoCallback: Any {
unsafe extern "C" fn call_recv(
user_data: *mut c_void,
data: *mut c_uchar,
len: size_t,
) -> c_int
where
Self: Sized;
unsafe extern "C" fn call_send(
user_data: *mut c_void,
data: *const c_uchar,
len: size_t,
) -> c_int
where
Self: Sized;
fn data_ptr(&mut self) -> *mut c_void;
}
impl<IO: Read + Write + 'static> IoCallback for IO {
unsafe extern "C" fn call_recv(
user_data: *mut c_void,
data: *mut c_uchar,
len: size_t,
) -> c_int {
let len = if len > (c_int::max_value() as size_t) {
c_int::max_value() as size_t
} else {
len
};
match (&mut *(user_data as *mut IO)).read(::core::slice::from_raw_parts_mut(data, len))
{
Ok(i) => i as c_int,
Err(_) => ::mbedtls_sys::ERR_NET_RECV_FAILED,
}
}
unsafe extern "C" fn call_send(
user_data: *mut c_void,
data: *const c_uchar,
len: size_t,
) -> c_int {
let len = if len > (c_int::max_value() as size_t) {
c_int::max_value() as size_t
} else {
len
};
match (&mut *(user_data as *mut IO)).write(::core::slice::from_raw_parts(data, len)) {
Ok(i) => i as c_int,
Err(_) => ::mbedtls_sys::ERR_NET_SEND_FAILED,
}
}
fn data_ptr(&mut self) -> *mut c_void {
self as *mut IO as *mut _
}
}
#[allow(dead_code)]
#[repr(C)]
pub struct Context<S> {
inner: ::mbedtls_sys::ssl_context,
config: Arc<Config>,
io: Option<Box<S>>,
handshake_ca_cert: Option<Arc<MbedtlsList<Certificate>>>,
handshake_crl: Option<Arc<Crl>>,
handshake_cert: Vec<Arc<MbedtlsList<Certificate>>>,
handshake_pk: Vec<Arc<Pk>>,
}
#[allow(dead_code)]
impl<S> Context<S> {
pub(crate) fn into_inner(self) -> ::mbedtls_sys::ssl_context {
let inner = self.inner;
::core::mem::forget(self);
inner
}
pub(crate) fn handle(&self) -> &::mbedtls_sys::ssl_context {
&self.inner
}
pub(crate) fn handle_mut(&mut self) -> &mut ::mbedtls_sys::ssl_context {
&mut self.inner
}
}
unsafe impl<S> Send for Context<S> where S: Send {}
impl<'a, S> Into<*const ssl_context> for &'a Context<S> {
fn into(self) -> *const ssl_context {
self.handle()
}
}
impl<'a, S> Into<*mut ssl_context> for &'a mut Context<S> {
fn into(self) -> *mut ssl_context {
self.handle_mut()
}
}
impl<S> Context<S> {
#[doc = r" Needed for compatibility with mbedtls - where we could pass"]
#[doc = r" `*const` but function signature requires `*mut`"]
#[allow(dead_code)]
pub(crate) unsafe fn inner_ffi_mut(&self) -> *mut ssl_context {
self.handle() as *const _ as *mut ssl_context
}
}
impl<'a, S> crate::private::UnsafeFrom<*const ssl_context> for &'a Context<S> {
unsafe fn from(ptr: *const ssl_context) -> Option<Self> {
(ptr as *const Context<S>).as_ref()
}
}
impl<'a, S> crate::private::UnsafeFrom<*mut ssl_context> for &'a mut Context<S> {
unsafe fn from(ptr: *mut ssl_context) -> Option<Self> {
(ptr as *mut Context<S>).as_mut()
}
}
impl<S: IoCallback> Context<S> {
pub fn establish(&mut self, io: S, hostname: Option<&str>) -> Result<()> {
unsafe {
let mut io = Box::new(io);
ssl_session_reset(self.into()).into_result()?;
self.set_hostname(hostname)?;
let ptr = &mut *io as *mut _ as *mut c_void;
ssl_set_bio(
self.into(),
ptr,
Some(S::call_send),
Some(S::call_recv),
None,
);
self.io = Some(io);
self.handshake_cert.clear();
self.handshake_pk.clear();
self.handshake_ca_cert = None;
self.handshake_crl = None;
match ssl_handshake(self.into()).into_result() {
Err(e) => {
ssl_set_bio(self.into(), ::core::ptr::null_mut(), None, None, None);
self.io = None;
Err(e)
}
Ok(_) => Ok(()),
}
}
}
}
impl<S> Context<S> {
pub fn new(config: Arc<Config>) -> Self {
let mut inner = ssl_context::default();
unsafe {
ssl_init(&mut inner);
ssl_setup(&mut inner, (&*config).into());
};
Context {
inner,
config: config.clone(),
io: None,
handshake_ca_cert: None,
handshake_crl: None,
handshake_cert: ::alloc::vec::Vec::new(),
handshake_pk: ::alloc::vec::Vec::new(),
}
}
#[cfg(feature = "std")]
fn set_hostname(&mut self, hostname: Option<&str>) -> Result<()> {
if let Some(s) = hostname {
let cstr = ::std::ffi::CString::new(s).map_err(|_| Error::SslBadInputData)?;
unsafe {
ssl_set_hostname(self.into(), cstr.as_ptr())
.into_result()
.map(|_| ())
}
} else {
Ok(())
}
}
pub fn verify_result(&self) -> StdResult<(), VerifyError> {
match unsafe { ssl_get_verify_result(self.into()) } {
0 => Ok(()),
flags => Err(VerifyError::from_bits_truncate(flags)),
}
}
pub fn config(&self) -> &Arc<Config> {
&self.config
}
pub fn close(&mut self) {
unsafe {
ssl_close_notify(self.into());
ssl_set_bio(self.into(), ::core::ptr::null_mut(), None, None, None);
self.io = None;
}
}
pub fn io(&self) -> Option<&Box<S>> {
self.io.as_ref()
}
pub fn io_mut(&mut self) -> Option<&mut Box<S>> {
self.io.as_mut()
}
#[doc = " Return the minor number of the negotiated TLS version"]
pub fn minor_version(&self) -> i32 {
self.inner.minor_ver
}
#[doc = " Return the major number of the negotiated TLS version"]
pub fn major_version(&self) -> i32 {
self.inner.major_ver
}
#[doc = " Return the number of bytes currently available to read that"]
#[doc = " are stored in the Session's internal read buffer"]
pub fn bytes_available(&self) -> usize {
unsafe { ssl_get_bytes_avail(self.into()) }
}
pub fn version(&self) -> Version {
let major = self.major_version();
{
match (&major, &3) {
(left_val, right_val) => {
if !(*left_val == *right_val) {
let kind = ::core::panicking::AssertKind::Eq;
::core::panicking::assert_failed(
kind,
&*left_val,
&*right_val,
::core::option::Option::None,
);
}
}
}
};
let minor = self.minor_version();
match minor {
0 => Version::Ssl3,
1 => Version::Tls1_0,
2 => Version::Tls1_1,
3 => Version::Tls1_2,
_ => ::core::panicking::panic_fmt(::core::fmt::Arguments::new_v1(
&["internal error: entered unreachable code: "],
&match (&"unexpected TLS version",) {
(arg0,) => [::core::fmt::ArgumentV1::new(
arg0,
::core::fmt::Display::fmt,
)],
},
)),
}
}
#[doc = " Return the 16-bit ciphersuite identifier."]
#[doc = " All assigned ciphersuites are listed by the IANA in"]
#[doc = " https://www.iana.org/assignments/tls-parameters/tls-parameters.txt"]
pub fn ciphersuite(&self) -> Result<u16> {
if self.inner.session.is_null() {
return Err(Error::SslBadInputData);
}
Ok(unsafe { self.inner.session.as_ref().unwrap().ciphersuite as u16 })
}
pub fn peer_cert(&self) -> Result<Option<&MbedtlsList<Certificate>>> {
if self.inner.session.is_null() {
return Err(Error::SslBadInputData);
}
unsafe {
let peer_cert: &MbedtlsList<Certificate> = UnsafeFrom::from(
&((*self.inner.session).peer_cert) as *const *mut x509_crt
as *const *const x509_crt,
)
.ok_or(Error::SslBadInputData)?;
Ok(Some(peer_cert))
}
}
}
impl<S> Drop for Context<S> {
fn drop(&mut self) {
unsafe {
self.close();
ssl_free(self.into());
}
}
}
impl<S> Read for Context<S> {
fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
match unsafe { ssl_read(self.into(), buf.as_mut_ptr(), buf.len()).into_result() } {
Err(Error::SslPeerCloseNotify) => Ok(0),
Err(e) => Err(crate::private::error_to_io_error(e)),
Ok(i) => Ok(i as usize),
}
}
}
impl<S> Write for Context<S> {
fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
match unsafe { ssl_write(self.into(), buf.as_ptr(), buf.len()).into_result() } {
Err(Error::SslPeerCloseNotify) => Ok(0),
Err(e) => Err(crate::private::error_to_io_error(e)),
Ok(i) => Ok(i as usize),
}
}
fn flush(&mut self) -> IoResult<()> {
Ok(())
}
}
pub struct HandshakeContext<'ctx> {
pub context: &'ctx mut Context<Box<dyn Any>>,
}
impl<'ctx> HandshakeContext<'ctx> {
pub(crate) fn init(context: &'ctx mut Context<Box<dyn Any>>) -> Self {
HandshakeContext { context }
}
pub fn set_authmode(&mut self, am: AuthMode) -> Result<()> {
if self.context.inner.handshake as *const _ == ::core::ptr::null() {
return Err(Error::SslBadInputData);
}
unsafe { ssl_set_hs_authmode(self.context.into(), am as i32) }
Ok(())
}
pub fn set_ca_list(
&mut self,
chain: Arc<MbedtlsList<Certificate>>,
crl: Option<Arc<Crl>>,
) -> Result<()> {
if self.context.inner.handshake as *const _ == ::core::ptr::null() {
return Err(Error::SslBadInputData);
}
unsafe {
ssl_set_hs_ca_chain(
self.context.into(),
chain.inner_ffi_mut(),
crl.as_ref()
.map(|crl| crl.inner_ffi_mut())
.unwrap_or(::core::ptr::null_mut()),
);
}
self.context.handshake_ca_cert = Some(chain);
self.context.handshake_crl = crl;
Ok(())
}
#[doc = " If this is never called, will use the set of private keys and"]
#[doc = " certificates configured in the `Config` associated with this `Context`."]
#[doc = " If this is called at least once, all those are ignored and the set"]
#[doc = " specified using this function is used."]
pub fn push_cert(
&mut self,
chain: Arc<MbedtlsList<Certificate>>,
key: Arc<Pk>,
) -> Result<()> {
if self.context.inner.handshake as *const _ == ::core::ptr::null() {
return Err(Error::SslBadInputData);
}
unsafe {
ssl_set_hs_own_cert(
self.context.into(),
chain.inner_ffi_mut(),
key.inner_ffi_mut(),
)
.into_result()?;
}
self.context.handshake_cert.push(chain);
self.context.handshake_pk.push(key);
Ok(())
}
}
}

0 comments on commit e03e46f

Please sign in to comment.