forked from rust-lang/socket2
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
6 changed files
with
540 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,365 @@ | ||
use crate::sys; | ||
use std::borrow::Borrow; | ||
use std::convert::TryInto as _; | ||
use std::io::IoSlice; | ||
use std::iter::FromIterator; | ||
|
||
#[derive(Debug, Clone)] | ||
struct MsgHdrWalker<B> { | ||
buffer: B, | ||
position: Option<usize>, | ||
} | ||
|
||
impl<B: AsRef<[u8]>> MsgHdrWalker<B> { | ||
fn next_ptr(&mut self) -> Option<*const libc::cmsghdr> { | ||
// Build a msghdr so we can use the functionality in libc. | ||
let mut msghdr: libc::msghdr = unsafe { std::mem::zeroed() }; | ||
let buffer = self.buffer.as_ref(); | ||
// SAFETY: We're giving msghdr a mutable pointer to comply with the C | ||
// API. We'll only allow mutation of `cmsghdr`, however if `B` is | ||
// AsMut<[u8]>. | ||
msghdr.msg_control = buffer.as_ptr() as *mut _; | ||
msghdr.msg_controllen = buffer.len().try_into().expect("buffer is too long"); | ||
|
||
let nxt_hdr = if let Some(position) = self.position { | ||
if position >= buffer.len() { | ||
return None; | ||
} | ||
let cur_hdr = &buffer[position] as *const u8 as *const _; | ||
// Safety: msghdr is a valid pointer and cur_hdr is not null. | ||
unsafe { libc::CMSG_NXTHDR(&msghdr, cur_hdr) } | ||
} else { | ||
// Safety: msghdr is a valid pointer. | ||
unsafe { libc::CMSG_FIRSTHDR(&msghdr) } | ||
}; | ||
|
||
if nxt_hdr.is_null() { | ||
self.position = Some(buffer.len()); | ||
return None; | ||
} | ||
|
||
// SAFETY: nxt_hdr always points to data within the buffer, they must be | ||
// part of the same allocation. | ||
let distance = unsafe { (nxt_hdr as *const u8).offset_from(buffer.as_ptr()) }; | ||
// nxt_hdr is always ahead of the buffer and not null if we're here, | ||
// meaning the distance is always positive. | ||
self.position = Some(distance.try_into().unwrap()); | ||
Some(nxt_hdr) | ||
} | ||
|
||
fn next(&mut self) -> Option<(&libc::cmsghdr, &[u8])> { | ||
self.next_ptr().map(|cmsghdr| { | ||
// SAFETY: cmsghdr is a valid pointer given to us by `next_ptr`. | ||
let data = unsafe { libc::CMSG_DATA(cmsghdr) }; | ||
let cmsghdr = unsafe { &*cmsghdr }; | ||
// SAFETY: data points to buffer and is controlled by control | ||
// message length. | ||
let data = unsafe { | ||
std::slice::from_raw_parts( | ||
data, | ||
(cmsghdr.cmsg_len as usize) | ||
.saturating_sub(std::mem::size_of::<libc::cmsghdr>()), | ||
) | ||
}; | ||
(cmsghdr, data) | ||
}) | ||
} | ||
} | ||
|
||
impl<B: AsRef<[u8]> + AsMut<[u8]>> MsgHdrWalker<B> { | ||
fn next_mut(&mut self) -> Option<(&mut libc::cmsghdr, &mut [u8])> { | ||
match self.next_ptr() { | ||
Some(cmsghdr) => { | ||
// SAFETY: cmsghdr is a valid pointer given to us by `next_ptr`. | ||
let data = unsafe { libc::CMSG_DATA(cmsghdr) }; | ||
// SAFETY: The mutable pointer is safe because we're not going to | ||
// vend any concurrent access to the same memory region and B is | ||
// AsMut<[u8]> guaranteeing we have exclusive access to the buffer. | ||
let cmsghdr = cmsghdr as *mut libc::cmsghdr; | ||
let cmsghdr = unsafe { &mut *cmsghdr }; | ||
|
||
// We'll always yield the entirety of the rest of the buffer. | ||
let distance = unsafe { data.offset_from(self.buffer.as_ref().as_ptr()) }; | ||
// The data pointer is always part of the buffer, can't be before | ||
// it. | ||
let distance: usize = distance.try_into().unwrap(); | ||
Some((cmsghdr, &mut self.buffer.as_mut()[distance..])) | ||
} | ||
None => None, | ||
} | ||
} | ||
} | ||
|
||
/// A wrapper around a buffer that can be used to write ancillary control | ||
/// messages. | ||
#[derive(Debug)] | ||
pub struct CmsgWriter<B> { | ||
walker: MsgHdrWalker<B>, | ||
last_push: usize, | ||
} | ||
|
||
impl<B: AsMut<[u8]> + AsRef<[u8]>> CmsgWriter<B> { | ||
/// Creates a new [`CmsgBuffer`] backed by the bytes in `buffer`. | ||
pub fn new(buffer: B) -> Self { | ||
Self { | ||
walker: MsgHdrWalker { | ||
buffer, | ||
position: None, | ||
}, | ||
last_push: 0, | ||
} | ||
} | ||
|
||
/// Pushes a new control message `m` to the buffer. | ||
/// | ||
/// # Panics | ||
/// | ||
/// Panics if the contained buffer does not have enough space to fit `m`. | ||
pub fn push(&mut self, m: &Cmsg) { | ||
let (cmsg_level, cmsg_type, size) = m.level_type_size(); | ||
let (nxt_hdr, data) = self | ||
.walker | ||
.next_mut() | ||
.unwrap_or_else(|| panic!("can't fit message {:?}", m)); | ||
// Safety: All values are passed by copy. | ||
let cmsg_len = unsafe { libc::CMSG_LEN(size) }.try_into().unwrap(); | ||
*nxt_hdr = libc::cmsghdr { | ||
cmsg_len, | ||
cmsg_level, | ||
cmsg_type, | ||
}; | ||
m.write(&mut data[..size as usize]); | ||
// Always store the space required for the last push because the walker | ||
// maintains its position cursor at the currently written option, we | ||
// must always add the space for the last control message when returning | ||
// the consolidated buffer. | ||
self.last_push = unsafe { libc::CMSG_SPACE(size) } as usize; | ||
} | ||
} | ||
|
||
impl<B: AsMut<[u8]> + AsRef<[u8]>> Extend<Cmsg> for CmsgWriter<B> { | ||
fn extend<T: IntoIterator<Item = Cmsg>>(&mut self, iter: T) { | ||
for cmsg in iter { | ||
self.push(&cmsg) | ||
} | ||
} | ||
} | ||
|
||
impl<C: Borrow<Cmsg>> FromIterator<C> for CmsgWriter<Vec<u8>> { | ||
fn from_iter<T: IntoIterator<Item = C>>(iter: T) -> Self { | ||
let mut buff = CmsgWriter::new(vec![]); | ||
for cmsg in iter { | ||
let cmsg = cmsg.borrow(); | ||
buff.walker | ||
.buffer | ||
.resize(buff.walker.buffer.len() + cmsg.space(), 0); | ||
buff.push(&cmsg) | ||
} | ||
buff | ||
} | ||
} | ||
|
||
impl<B: AsRef<[u8]>> CmsgWriter<B> { | ||
pub(crate) fn io_slice(&self) -> IoSlice<'_> { | ||
IoSlice::new(self.buffer()) | ||
} | ||
|
||
pub(crate) fn buffer(&self) -> &[u8] { | ||
if let Some(position) = self.walker.position { | ||
&self.walker.buffer.as_ref()[..position + self.last_push] | ||
} else { | ||
&[] | ||
} | ||
} | ||
} | ||
|
||
/// An iterator over received control messages. | ||
#[derive(Debug, Clone)] | ||
pub struct CmsgIter<'a> { | ||
walker: MsgHdrWalker<&'a [u8]>, | ||
} | ||
|
||
impl<'a> CmsgIter<'a> { | ||
pub(crate) fn new(buffer: &'a [u8]) -> Self { | ||
Self { | ||
walker: MsgHdrWalker { | ||
buffer, | ||
position: None, | ||
}, | ||
} | ||
} | ||
} | ||
|
||
impl<'a> Iterator for CmsgIter<'a> { | ||
type Item = Cmsg; | ||
|
||
fn next(&mut self) -> Option<Self::Item> { | ||
self.walker.next().map( | ||
|( | ||
libc::cmsghdr { | ||
cmsg_len: _, | ||
cmsg_level, | ||
cmsg_type, | ||
}, | ||
data, | ||
)| Cmsg::from_raw(*cmsg_level, *cmsg_type, data), | ||
) | ||
} | ||
} | ||
|
||
/// An unknown control message. | ||
#[derive(Debug, Eq, PartialEq)] | ||
pub struct UnknownCmsg { | ||
cmsg_level: libc::c_int, | ||
cmsg_type: libc::c_int, | ||
} | ||
|
||
/// Control messages. | ||
#[derive(Debug, Eq, PartialEq)] | ||
pub enum Cmsg { | ||
/// The `IP_TTL` control message. | ||
IpTtl(u8), | ||
/// The `IPV6_PKTINFO` control message. | ||
Ipv6PktInfo { | ||
/// The address the packet is destined to/received from. Equivalent to | ||
/// `in6_pktinfo.ipi6_addr`. | ||
addr: std::net::Ipv6Addr, | ||
/// The interface index the packet is destined to/received from. | ||
/// Equivalent to `in6_pktinfo.ipi6_ifindex`. | ||
ifindex: u32, | ||
}, | ||
/// An unrecognized control message. | ||
Unknown(UnknownCmsg), | ||
} | ||
|
||
impl Cmsg { | ||
/// Returns the amount of buffer space required to hold this option. | ||
pub fn space(&self) -> usize { | ||
let (_, _, size) = self.level_type_size(); | ||
// Safety: All values are passed by copy. | ||
let size = unsafe { libc::CMSG_SPACE(size) }; | ||
size as usize | ||
} | ||
|
||
fn level_type_size(&self) -> (libc::c_int, libc::c_int, libc::c_uint) { | ||
match self { | ||
Cmsg::IpTtl(_) => ( | ||
libc::IPPROTO_IP, | ||
libc::IP_TTL, | ||
// TTL is encoded as a u32. | ||
std::mem::size_of::<u32>() as libc::c_uint, | ||
), | ||
Cmsg::Ipv6PktInfo { .. } => ( | ||
libc::IPPROTO_IPV6, | ||
libc::IPV6_PKTINFO, | ||
std::mem::size_of::<libc::in6_pktinfo>() as libc::c_uint, | ||
), | ||
Cmsg::Unknown(UnknownCmsg { | ||
cmsg_level, | ||
cmsg_type, | ||
}) => (*cmsg_level, *cmsg_type, 0), | ||
} | ||
} | ||
|
||
fn write(&self, buffer: &mut [u8]) { | ||
match self { | ||
Cmsg::IpTtl(ttl) => { | ||
let value: u32 = (*ttl).into(); | ||
let value = value.to_ne_bytes(); | ||
(&mut buffer[..value.len()]).copy_from_slice(&value[..]); | ||
} | ||
Cmsg::Ipv6PktInfo { addr, ifindex } => { | ||
let pktinfo = libc::in6_pktinfo { | ||
ipi6_addr: sys::to_in6_addr(addr), | ||
ipi6_ifindex: *ifindex, | ||
}; | ||
let size = std::mem::size_of::<libc::in6_pktinfo>(); | ||
assert_eq!(buffer.len(), size); | ||
// Safety: `pktinfo` is valid for reads for its size in bytes. | ||
// `buffer` is valid for write for the same length, as | ||
// guaranteed by the assertion above. Copy unit is byte, so | ||
// alignment is okay. The two regions do not overlap. | ||
unsafe { | ||
std::ptr::copy_nonoverlapping( | ||
&pktinfo as *const libc::in6_pktinfo as *const _, | ||
buffer.as_mut_ptr(), | ||
size, | ||
) | ||
} | ||
} | ||
Cmsg::Unknown(_) => { | ||
// NOTE: We don't actually allow users of the public API | ||
// serialize unknown control messages, but we use this code path | ||
// for testing. | ||
} | ||
} | ||
} | ||
|
||
fn from_raw(cmsg_level: libc::c_int, cmsg_type: libc::c_int, bytes: &[u8]) -> Self { | ||
match (cmsg_level, cmsg_type) { | ||
(libc::IPPROTO_IP, libc::IP_TTL) => { | ||
assert!(bytes.len() >= std::mem::size_of::<u32>(), "{:?}", bytes); | ||
Cmsg::IpTtl(bytes[0]) | ||
} | ||
(libc::IPPROTO_IPV6, libc::IPV6_PKTINFO) => { | ||
let mut pktinfo = unsafe { std::mem::zeroed::<libc::in6_pktinfo>() }; | ||
let size = std::mem::size_of::<libc::in6_pktinfo>(); | ||
assert!(bytes.len() >= size, "{:?}", bytes); | ||
// Safety: `pktinfo` is valid for writes for its size in bytes. | ||
// `buffer` is valid for read for the same length, as | ||
// guaranteed by the assertion above. Copy unit is byte, so | ||
// alignment is okay. The two regions do not overlap. | ||
unsafe { | ||
std::ptr::copy_nonoverlapping( | ||
bytes.as_ptr(), | ||
&mut pktinfo as *mut libc::in6_pktinfo as *mut _, | ||
size, | ||
) | ||
} | ||
Cmsg::Ipv6PktInfo { | ||
addr: sys::from_in6_addr(pktinfo.ipi6_addr), | ||
ifindex: pktinfo.ipi6_ifindex, | ||
} | ||
} | ||
(cmsg_level, cmsg_type) => Cmsg::Unknown(UnknownCmsg { | ||
cmsg_level, | ||
cmsg_type, | ||
}), | ||
} | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
|
||
#[test] | ||
fn ser_deser() { | ||
let cmsgs = [ | ||
Cmsg::IpTtl(2), | ||
Cmsg::Ipv6PktInfo { | ||
addr: std::net::Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8), | ||
ifindex: 13, | ||
}, | ||
Cmsg::Unknown(UnknownCmsg { | ||
cmsg_level: 12345678, | ||
cmsg_type: 87654321, | ||
}), | ||
]; | ||
let buffer: CmsgWriter<_> = cmsgs.iter().collect(); | ||
let deser = CmsgIter::new(buffer.buffer()).collect::<Vec<_>>(); | ||
assert_eq!(&cmsgs[..], &deser[..]); | ||
} | ||
|
||
#[test] | ||
#[should_panic] | ||
fn ser_insufficient_space_panics() { | ||
let mut buffer = CmsgWriter::new([0; 3]); | ||
buffer.push(&Cmsg::IpTtl(2)); | ||
} | ||
|
||
#[test] | ||
fn empty_deser() { | ||
assert_eq!(CmsgIter::new(&[]).next(), None); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.