-
Notifications
You must be signed in to change notification settings - Fork 84
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add missing where clause to guarentee correct derivation of Send
- Loading branch information
Showing
2 changed files
with
368 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,365 @@ | ||
pub mod context { | ||
use core::any::Any; | ||
use core::result::Result as StdResult; | ||
#[cfg(feature = "std")] | ||
use std::io::{Read, Write, Result as IoResult}; | ||
#[cfg(feature = "std")] | ||
use std::sync::Arc; | ||
use mbedtls_sys::types::raw_types::{c_int, c_uchar, c_void}; | ||
use mbedtls_sys::types::size_t; | ||
use mbedtls_sys::*; | ||
use crate::alloc::{List as MbedtlsList}; | ||
use crate::error::{Error, Result, IntoResult}; | ||
use crate::pk::Pk; | ||
use crate::private::UnsafeFrom; | ||
use crate::ssl::config::{Config, Version, AuthMode}; | ||
use crate::x509::{Certificate, Crl, VerifyError}; | ||
pub trait IoCallback: Any { | ||
unsafe extern "C" fn call_recv( | ||
user_data: *mut c_void, | ||
data: *mut c_uchar, | ||
len: size_t, | ||
) -> c_int | ||
where | ||
Self: Sized; | ||
unsafe extern "C" fn call_send( | ||
user_data: *mut c_void, | ||
data: *const c_uchar, | ||
len: size_t, | ||
) -> c_int | ||
where | ||
Self: Sized; | ||
fn data_ptr(&mut self) -> *mut c_void; | ||
} | ||
impl<IO: Read + Write + 'static> IoCallback for IO { | ||
unsafe extern "C" fn call_recv( | ||
user_data: *mut c_void, | ||
data: *mut c_uchar, | ||
len: size_t, | ||
) -> c_int { | ||
let len = if len > (c_int::max_value() as size_t) { | ||
c_int::max_value() as size_t | ||
} else { | ||
len | ||
}; | ||
match (&mut *(user_data as *mut IO)).read(::core::slice::from_raw_parts_mut(data, len)) | ||
{ | ||
Ok(i) => i as c_int, | ||
Err(_) => ::mbedtls_sys::ERR_NET_RECV_FAILED, | ||
} | ||
} | ||
unsafe extern "C" fn call_send( | ||
user_data: *mut c_void, | ||
data: *const c_uchar, | ||
len: size_t, | ||
) -> c_int { | ||
let len = if len > (c_int::max_value() as size_t) { | ||
c_int::max_value() as size_t | ||
} else { | ||
len | ||
}; | ||
match (&mut *(user_data as *mut IO)).write(::core::slice::from_raw_parts(data, len)) { | ||
Ok(i) => i as c_int, | ||
Err(_) => ::mbedtls_sys::ERR_NET_SEND_FAILED, | ||
} | ||
} | ||
fn data_ptr(&mut self) -> *mut c_void { | ||
self as *mut IO as *mut _ | ||
} | ||
} | ||
#[allow(dead_code)] | ||
#[repr(C)] | ||
pub struct Context<S> { | ||
inner: ::mbedtls_sys::ssl_context, | ||
config: Arc<Config>, | ||
io: Option<Box<S>>, | ||
handshake_ca_cert: Option<Arc<MbedtlsList<Certificate>>>, | ||
handshake_crl: Option<Arc<Crl>>, | ||
handshake_cert: Vec<Arc<MbedtlsList<Certificate>>>, | ||
handshake_pk: Vec<Arc<Pk>>, | ||
} | ||
#[allow(dead_code)] | ||
impl<S> Context<S> { | ||
pub(crate) fn into_inner(self) -> ::mbedtls_sys::ssl_context { | ||
let inner = self.inner; | ||
::core::mem::forget(self); | ||
inner | ||
} | ||
pub(crate) fn handle(&self) -> &::mbedtls_sys::ssl_context { | ||
&self.inner | ||
} | ||
pub(crate) fn handle_mut(&mut self) -> &mut ::mbedtls_sys::ssl_context { | ||
&mut self.inner | ||
} | ||
} | ||
unsafe impl<S> Send for Context<S> where S: Send {} | ||
impl<'a, S> Into<*const ssl_context> for &'a Context<S> { | ||
fn into(self) -> *const ssl_context { | ||
self.handle() | ||
} | ||
} | ||
impl<'a, S> Into<*mut ssl_context> for &'a mut Context<S> { | ||
fn into(self) -> *mut ssl_context { | ||
self.handle_mut() | ||
} | ||
} | ||
impl<S> Context<S> { | ||
#[doc = r" Needed for compatibility with mbedtls - where we could pass"] | ||
#[doc = r" `*const` but function signature requires `*mut`"] | ||
#[allow(dead_code)] | ||
pub(crate) unsafe fn inner_ffi_mut(&self) -> *mut ssl_context { | ||
self.handle() as *const _ as *mut ssl_context | ||
} | ||
} | ||
impl<'a, S> crate::private::UnsafeFrom<*const ssl_context> for &'a Context<S> { | ||
unsafe fn from(ptr: *const ssl_context) -> Option<Self> { | ||
(ptr as *const Context<S>).as_ref() | ||
} | ||
} | ||
impl<'a, S> crate::private::UnsafeFrom<*mut ssl_context> for &'a mut Context<S> { | ||
unsafe fn from(ptr: *mut ssl_context) -> Option<Self> { | ||
(ptr as *mut Context<S>).as_mut() | ||
} | ||
} | ||
impl<S: IoCallback> Context<S> { | ||
pub fn establish(&mut self, io: S, hostname: Option<&str>) -> Result<()> { | ||
unsafe { | ||
let mut io = Box::new(io); | ||
ssl_session_reset(self.into()).into_result()?; | ||
self.set_hostname(hostname)?; | ||
let ptr = &mut *io as *mut _ as *mut c_void; | ||
ssl_set_bio( | ||
self.into(), | ||
ptr, | ||
Some(S::call_send), | ||
Some(S::call_recv), | ||
None, | ||
); | ||
self.io = Some(io); | ||
self.handshake_cert.clear(); | ||
self.handshake_pk.clear(); | ||
self.handshake_ca_cert = None; | ||
self.handshake_crl = None; | ||
match ssl_handshake(self.into()).into_result() { | ||
Err(e) => { | ||
ssl_set_bio(self.into(), ::core::ptr::null_mut(), None, None, None); | ||
self.io = None; | ||
Err(e) | ||
} | ||
Ok(_) => Ok(()), | ||
} | ||
} | ||
} | ||
} | ||
impl<S> Context<S> { | ||
pub fn new(config: Arc<Config>) -> Self { | ||
let mut inner = ssl_context::default(); | ||
unsafe { | ||
ssl_init(&mut inner); | ||
ssl_setup(&mut inner, (&*config).into()); | ||
}; | ||
Context { | ||
inner, | ||
config: config.clone(), | ||
io: None, | ||
handshake_ca_cert: None, | ||
handshake_crl: None, | ||
handshake_cert: ::alloc::vec::Vec::new(), | ||
handshake_pk: ::alloc::vec::Vec::new(), | ||
} | ||
} | ||
#[cfg(feature = "std")] | ||
fn set_hostname(&mut self, hostname: Option<&str>) -> Result<()> { | ||
if let Some(s) = hostname { | ||
let cstr = ::std::ffi::CString::new(s).map_err(|_| Error::SslBadInputData)?; | ||
unsafe { | ||
ssl_set_hostname(self.into(), cstr.as_ptr()) | ||
.into_result() | ||
.map(|_| ()) | ||
} | ||
} else { | ||
Ok(()) | ||
} | ||
} | ||
pub fn verify_result(&self) -> StdResult<(), VerifyError> { | ||
match unsafe { ssl_get_verify_result(self.into()) } { | ||
0 => Ok(()), | ||
flags => Err(VerifyError::from_bits_truncate(flags)), | ||
} | ||
} | ||
pub fn config(&self) -> &Arc<Config> { | ||
&self.config | ||
} | ||
pub fn close(&mut self) { | ||
unsafe { | ||
ssl_close_notify(self.into()); | ||
ssl_set_bio(self.into(), ::core::ptr::null_mut(), None, None, None); | ||
self.io = None; | ||
} | ||
} | ||
pub fn io(&self) -> Option<&Box<S>> { | ||
self.io.as_ref() | ||
} | ||
pub fn io_mut(&mut self) -> Option<&mut Box<S>> { | ||
self.io.as_mut() | ||
} | ||
#[doc = " Return the minor number of the negotiated TLS version"] | ||
pub fn minor_version(&self) -> i32 { | ||
self.inner.minor_ver | ||
} | ||
#[doc = " Return the major number of the negotiated TLS version"] | ||
pub fn major_version(&self) -> i32 { | ||
self.inner.major_ver | ||
} | ||
#[doc = " Return the number of bytes currently available to read that"] | ||
#[doc = " are stored in the Session's internal read buffer"] | ||
pub fn bytes_available(&self) -> usize { | ||
unsafe { ssl_get_bytes_avail(self.into()) } | ||
} | ||
pub fn version(&self) -> Version { | ||
let major = self.major_version(); | ||
{ | ||
match (&major, &3) { | ||
(left_val, right_val) => { | ||
if !(*left_val == *right_val) { | ||
let kind = ::core::panicking::AssertKind::Eq; | ||
::core::panicking::assert_failed( | ||
kind, | ||
&*left_val, | ||
&*right_val, | ||
::core::option::Option::None, | ||
); | ||
} | ||
} | ||
} | ||
}; | ||
let minor = self.minor_version(); | ||
match minor { | ||
0 => Version::Ssl3, | ||
1 => Version::Tls1_0, | ||
2 => Version::Tls1_1, | ||
3 => Version::Tls1_2, | ||
_ => ::core::panicking::panic_fmt(::core::fmt::Arguments::new_v1( | ||
&["internal error: entered unreachable code: "], | ||
&match (&"unexpected TLS version",) { | ||
(arg0,) => [::core::fmt::ArgumentV1::new( | ||
arg0, | ||
::core::fmt::Display::fmt, | ||
)], | ||
}, | ||
)), | ||
} | ||
} | ||
#[doc = " Return the 16-bit ciphersuite identifier."] | ||
#[doc = " All assigned ciphersuites are listed by the IANA in"] | ||
#[doc = " https://www.iana.org/assignments/tls-parameters/tls-parameters.txt"] | ||
pub fn ciphersuite(&self) -> Result<u16> { | ||
if self.inner.session.is_null() { | ||
return Err(Error::SslBadInputData); | ||
} | ||
Ok(unsafe { self.inner.session.as_ref().unwrap().ciphersuite as u16 }) | ||
} | ||
pub fn peer_cert(&self) -> Result<Option<&MbedtlsList<Certificate>>> { | ||
if self.inner.session.is_null() { | ||
return Err(Error::SslBadInputData); | ||
} | ||
unsafe { | ||
let peer_cert: &MbedtlsList<Certificate> = UnsafeFrom::from( | ||
&((*self.inner.session).peer_cert) as *const *mut x509_crt | ||
as *const *const x509_crt, | ||
) | ||
.ok_or(Error::SslBadInputData)?; | ||
Ok(Some(peer_cert)) | ||
} | ||
} | ||
} | ||
impl<S> Drop for Context<S> { | ||
fn drop(&mut self) { | ||
unsafe { | ||
self.close(); | ||
ssl_free(self.into()); | ||
} | ||
} | ||
} | ||
impl<S> Read for Context<S> { | ||
fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> { | ||
match unsafe { ssl_read(self.into(), buf.as_mut_ptr(), buf.len()).into_result() } { | ||
Err(Error::SslPeerCloseNotify) => Ok(0), | ||
Err(e) => Err(crate::private::error_to_io_error(e)), | ||
Ok(i) => Ok(i as usize), | ||
} | ||
} | ||
} | ||
impl<S> Write for Context<S> { | ||
fn write(&mut self, buf: &[u8]) -> IoResult<usize> { | ||
match unsafe { ssl_write(self.into(), buf.as_ptr(), buf.len()).into_result() } { | ||
Err(Error::SslPeerCloseNotify) => Ok(0), | ||
Err(e) => Err(crate::private::error_to_io_error(e)), | ||
Ok(i) => Ok(i as usize), | ||
} | ||
} | ||
fn flush(&mut self) -> IoResult<()> { | ||
Ok(()) | ||
} | ||
} | ||
pub struct HandshakeContext<'ctx> { | ||
pub context: &'ctx mut Context<Box<dyn Any>>, | ||
} | ||
impl<'ctx> HandshakeContext<'ctx> { | ||
pub(crate) fn init(context: &'ctx mut Context<Box<dyn Any>>) -> Self { | ||
HandshakeContext { context } | ||
} | ||
pub fn set_authmode(&mut self, am: AuthMode) -> Result<()> { | ||
if self.context.inner.handshake as *const _ == ::core::ptr::null() { | ||
return Err(Error::SslBadInputData); | ||
} | ||
unsafe { ssl_set_hs_authmode(self.context.into(), am as i32) } | ||
Ok(()) | ||
} | ||
pub fn set_ca_list( | ||
&mut self, | ||
chain: Arc<MbedtlsList<Certificate>>, | ||
crl: Option<Arc<Crl>>, | ||
) -> Result<()> { | ||
if self.context.inner.handshake as *const _ == ::core::ptr::null() { | ||
return Err(Error::SslBadInputData); | ||
} | ||
unsafe { | ||
ssl_set_hs_ca_chain( | ||
self.context.into(), | ||
chain.inner_ffi_mut(), | ||
crl.as_ref() | ||
.map(|crl| crl.inner_ffi_mut()) | ||
.unwrap_or(::core::ptr::null_mut()), | ||
); | ||
} | ||
self.context.handshake_ca_cert = Some(chain); | ||
self.context.handshake_crl = crl; | ||
Ok(()) | ||
} | ||
#[doc = " If this is never called, will use the set of private keys and"] | ||
#[doc = " certificates configured in the `Config` associated with this `Context`."] | ||
#[doc = " If this is called at least once, all those are ignored and the set"] | ||
#[doc = " specified using this function is used."] | ||
pub fn push_cert( | ||
&mut self, | ||
chain: Arc<MbedtlsList<Certificate>>, | ||
key: Arc<Pk>, | ||
) -> Result<()> { | ||
if self.context.inner.handshake as *const _ == ::core::ptr::null() { | ||
return Err(Error::SslBadInputData); | ||
} | ||
unsafe { | ||
ssl_set_hs_own_cert( | ||
self.context.into(), | ||
chain.inner_ffi_mut(), | ||
key.inner_ffi_mut(), | ||
) | ||
.into_result()?; | ||
} | ||
self.context.handshake_cert.push(chain); | ||
self.context.handshake_pk.push(key); | ||
Ok(()) | ||
} | ||
} | ||
} |