Skip to content

Commit

Permalink
Implement Ref::{try_as_ref,try_into_ref,try_into_mut}
Browse files Browse the repository at this point in the history
Only `Ref::try_as_mut` remains missing, probably pending polonius
landing in rustc.

Partially fixes #1865
Supersedes #1184
  • Loading branch information
jswrenn committed Oct 17, 2024
1 parent 0bee231 commit b528e37
Show file tree
Hide file tree
Showing 2 changed files with 314 additions and 4 deletions.
7 changes: 7 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,13 @@ impl<Src, Dst: ?Sized + TryFromBytes> ValidityError<Src, Dst> {
self.src
}

pub(crate) fn with_src<NewSrc>(self, new_src: NewSrc) -> ValidityError<NewSrc, Dst> {
// INVARIANT: `with_src` doesn't change the type of `Dst`, so the
// invariant that `Dst`'s alignment requirement is greater than one is
// preserved.
ValidityError { src: new_src, dst: SendSyncPhantomData::default() }
}

/// Maps the source value associated with the conversion error.
///
/// This can help mitigate [issues with `Send`, `Sync` and `'static`
Expand Down
311 changes: 307 additions & 4 deletions src/ref.rs
Original file line number Diff line number Diff line change
Expand Up @@ -595,10 +595,85 @@ where
}
}

impl<B, T> Ref<B, T>
where
B: ByteSlice,
T: KnownLayout + ?Sized,
{
/// Attempts to dereference this `Ref<_, T>` into a `&T` without copying.
///
/// If the bytes of `self` are a valid instance of `T`, this method returns
/// a reference to those bytes interpreted as `T`. If those bytes are not a
/// valid instance of `T`, this returns `Err`.
///
/// # Examples
///
/// ```
/// use zerocopy::Ref;
/// # use zerocopy_derive::*;
///
/// // The only valid value of this type is the byte `0xC0`
/// #[derive(TryFromBytes, KnownLayout, Immutable)]
/// #[repr(u8)]
/// enum C0 { xC0 = 0xC0 }
///
/// // The only valid value of this type is the bytes `0xC0C0`.
/// #[derive(TryFromBytes, KnownLayout, Immutable)]
/// #[repr(C, packed)]
/// struct C0C0(C0, C0);
///
/// #[derive(TryFromBytes, KnownLayout, Immutable)]
/// #[repr(C)]
/// struct Packet {
/// magic_number: C0C0,
/// mug_size: u8,
/// temperature: u8,
/// marshmallows: [[u8; 2]],
/// }
///
/// let bytes = &[0xC0, 0xC0, 240, 77, 0, 1, 2, 3, 4, 5][..];
///
/// let r = Ref::<_, Packet>::new(bytes).unwrap();
/// let packet = Ref::try_as_ref(r).unwrap();
///
/// assert_eq!(packet.mug_size, 240);
/// assert_eq!(packet.temperature, 77);
/// assert_eq!(packet.marshmallows, [[0, 1], [2, 3], [4, 5]]);
/// ```
#[must_use = "has no side effects"]
#[inline(always)]
pub fn try_as_ref(r: &Self) -> Result<&T, ValidityError<&Self, T>>
where
T: TryFromBytes + Immutable,
{
// Presumably unreachable, since we've guarded each constructor of `Ref`.
static_assert_dst_is_not_zst!(T);

// SAFETY: We don't call any methods on `r` other than those provided by
// `ByteSlice`.
let b = unsafe { r.as_byte_slice() };

match Ptr::from_ref(b.deref()).try_cast_into_no_leftover::<T, BecauseImmutable>(None) {
Ok(candidate) => match candidate.try_into_valid() {
Ok(valid) => Ok(valid.as_ref()),
Err(e) => Err(e.map_src(|_| r)),
},
Err(CastError::Validity(i)) => match i {},
Err(CastError::Alignment(_) | CastError::Size(_)) => {
// SAFETY: By invariant on `Ref::0`, the referenced byte slice
// is aligned to `T`'s alignment and its size corresponds to a
// valid size for `T`. Since properties are checked upon
// constructing `Ref`, these failures are unreachable.
unsafe { core::hint::unreachable_unchecked() }
}
}
}
}

impl<'a, B, T> Ref<B, T>
where
B: 'a + IntoByteSlice<'a>,
T: FromBytes + KnownLayout + Immutable + ?Sized,
T: TryFromBytes + KnownLayout + Immutable + ?Sized,
{
/// Converts this `Ref` into a reference.
///
Expand All @@ -609,7 +684,10 @@ where
/// there is no conflict with a method on the inner type.
#[must_use = "has no side effects"]
#[inline(always)]
pub fn into_ref(r: Self) -> &'a T {
pub fn into_ref(r: Self) -> &'a T
where
T: FromBytes,
{
// Presumably unreachable, since we've guarded each constructor of `Ref`.
static_assert_dst_is_not_zst!(T);

Expand All @@ -627,12 +705,91 @@ where
let ptr = ptr.bikeshed_recall_valid();
ptr.as_ref()
}

/// Attempts to convert this `Ref<_, T>` into a `&T` without copying.
///
/// If the bytes of `self` are a valid instance of `T`, this method returns
/// a reference to those bytes interpreted as `T`. If those bytes are not a
/// valid instance of `T`, this returns `Err`.
///
/// # Examples
///
/// ```
/// use zerocopy::Ref;
/// # use zerocopy_derive::*;
///
/// // The only valid value of this type is the byte `0xC0`
/// #[derive(TryFromBytes, KnownLayout, Immutable)]
/// #[repr(u8)]
/// enum C0 { xC0 = 0xC0 }
///
/// // The only valid value of this type is the bytes `0xC0C0`.
/// #[derive(TryFromBytes, KnownLayout, Immutable)]
/// #[repr(C, packed)]
/// struct C0C0(C0, C0);
///
/// #[derive(TryFromBytes, KnownLayout, Immutable)]
/// #[repr(C)]
/// struct Packet {
/// magic_number: C0C0,
/// mug_size: u8,
/// temperature: u8,
/// marshmallows: [[u8; 2]],
/// }
///
/// let bytes = &[0xC0, 0xC0, 240, 77, 0, 1, 2, 3, 4, 5][..];
///
/// let r = Ref::<_, Packet>::new(bytes).unwrap();
/// let packet = Ref::try_into_ref(r).unwrap();
///
/// assert_eq!(packet.mug_size, 240);
/// assert_eq!(packet.temperature, 77);
/// assert_eq!(packet.marshmallows, [[0, 1], [2, 3], [4, 5]]);
/// ```
#[must_use = "has no side effects"]
#[inline(always)]
pub fn try_into_ref(r: Self) -> Result<&'a T, ValidityError<Self, T>> {
// Presumably unreachable, since we've guarded each constructor of `Ref`.
static_assert_dst_is_not_zst!(T);

// SAFETY: We don't call any methods on `b` other than those provided by
// `ByteSlice`.
let bytes = unsafe { r.as_byte_slice() };

let bytes: &'_ [u8] = bytes.deref();

// Extend the lifetime of `bytes` to `'a`. This gives us a reference
// `bytes` with the same lifetime as if we had called
// `r.into_byte_slice()`, but without consuming `r`. This is valuable,
// since we will need to return `r` if validation fails.
//
// SAFETY: This is sound because `bytes` lives for `'a`. `Self` is
// `IntoByteSlice`, whose `.into_byte_slice()` method is guaranteed to
// produce a `&'a [u8]` with the same address and length as the slice
// obtained by `.deref()` (which is how `bytes` is obtained).
let bytes = unsafe { mem::transmute::<&[u8], &'a [u8]>(bytes) };

match Ptr::from_ref(bytes).try_cast_into_no_leftover::<T, BecauseImmutable>(None) {
Ok(candidate) => match candidate.try_into_valid() {
Ok(candidate) => Ok(candidate.as_ref()),
Err(e) => Err(e.with_src(r)),
},
Err(CastError::Validity(i)) => match i {},
Err(CastError::Alignment(_) | CastError::Size(_)) => {
// SAFETY: By invariant on `Ref::0`, the referenced byte slice
// is aligned to `T`'s alignment and its size corresponds to a
// valid size for `T`. Since properties are checked upon
// constructing `Ref`, these failures are unreachable.
unsafe { core::hint::unreachable_unchecked() }
}
}
}
}

impl<'a, B, T> Ref<B, T>
where
B: 'a + IntoByteSliceMut<'a>,
T: FromBytes + IntoBytes + KnownLayout + ?Sized,
T: TryFromBytes + IntoBytes + KnownLayout + ?Sized,
{
/// Converts this `Ref` into a mutable reference.
///
Expand All @@ -643,7 +800,10 @@ where
/// there is no conflict with a method on the inner type.
#[must_use = "has no side effects"]
#[inline(always)]
pub fn into_mut(r: Self) -> &'a mut T {
pub fn into_mut(r: Self) -> &'a mut T
where
T: FromBytes,
{
// Presumably unreachable, since we've guarded each constructor of `Ref`.
static_assert_dst_is_not_zst!(T);

Expand All @@ -661,6 +821,86 @@ where
let ptr = ptr.bikeshed_recall_valid();
ptr.as_mut()
}

/// Attempts to convert this `Ref<_, T>` into a `&mut T` without copying.
///
/// If the bytes of `self` are a valid instance of `T`, this method returns
/// a reference to those bytes interpreted as `T`. If those bytes are not a
/// valid instance of `T`, this returns `Err`.
///
/// # Examples
///
/// ```
/// use zerocopy::Ref;
/// # use zerocopy_derive::*;
///
/// // The only valid value of this type is the byte `0xC0`
/// #[derive(TryFromBytes, IntoBytes, KnownLayout, Immutable)]
/// #[repr(u8)]
/// enum C0 { xC0 = 0xC0 }
///
/// // The only valid value of this type is the bytes `0xC0C0`.
/// #[derive(TryFromBytes, IntoBytes, KnownLayout, Immutable)]
/// #[repr(C, packed)]
/// struct C0C0(C0, C0);
///
/// #[derive(TryFromBytes, IntoBytes, KnownLayout, Immutable)]
/// #[repr(C, packed)]
/// struct Packet {
/// magic_number: C0C0,
/// mug_size: u8,
/// temperature: u8,
/// marshmallows: [[u8; 2]],
/// }
///
/// let bytes = &mut [0xC0, 0xC0, 240, 77, 0, 1, 2, 3, 4, 5][..];
///
/// let r = Ref::<_, Packet>::new(bytes).unwrap();
/// let packet = Ref::try_into_mut(r).unwrap();
///
/// assert_eq!(packet.mug_size, 240);
/// assert_eq!(packet.temperature, 77);
/// assert_eq!(packet.marshmallows, [[0, 1], [2, 3], [4, 5]]);
/// ```
#[must_use = "has no side effects"]
#[inline(always)]
pub fn try_into_mut(mut r: Self) -> Result<&'a mut T, ValidityError<Self, T>> {
// Presumably unreachable, since we've guarded each constructor of `Ref`.
static_assert_dst_is_not_zst!(T);

// SAFETY: We don't call any methods on `b` other than those provided by
// `ByteSliceMut`.
let bytes = unsafe { r.as_byte_slice_mut() };

let bytes: &'_ mut [u8] = bytes.deref_mut();

// Extend the lifetime of `bytes` to `'a`. This gives us a reference
// `bytes` with the same lifetime as if we had called
// `r.into_byte_slice_mut()`, but without consuming `r`. This is
// valuable, since we will need to return `r` if validation fails.
//
// SAFETY: This is sound because `bytes` lives for `'a`. `Self` is
// `IntoByteSliceMut`, whose `.into_byte_slice_mut()` method is
// guaranteed to produce a `&'a [u8]` with the same address and length
// as the slice obtained by `.deref()` (which is how `bytes` is
// obtained).
let bytes = unsafe { mem::transmute::<&mut [u8], &'a mut [u8]>(bytes) };

match Ptr::from_mut(bytes).try_cast_into_no_leftover::<T, BecauseExclusive>(None) {
Ok(candidate) => match candidate.try_into_valid() {
Ok(candidate) => Ok(candidate.as_mut()),
Err(e) => Err(e.with_src(r)),
},
Err(CastError::Validity(i)) => match i {},
Err(CastError::Alignment(_) | CastError::Size(_)) => {
// SAFETY: By invariant on `Ref::0`, the referenced byte slice
// is aligned to `T`'s alignment and its size corresponds to a
// valid size for `T`. Since properties are checked upon
// constructing `Ref`, these failures are unreachable.
unsafe { core::hint::unreachable_unchecked() }
}
}
}
}

impl<B, T> Ref<B, T>
Expand Down Expand Up @@ -1109,6 +1349,33 @@ mod tests {
assert!(Ref::<_, [AU64]>::from_suffix_with_elems(&buf.t[..], unreasonable_len).is_err());
}

#[test]
#[allow(unstable_name_collisions)]
#[allow(clippy::as_conversions)]
fn test_try_as_ref() {
#[allow(unused)]
use crate::util::AsAddress as _;

// valid source

let buf = Align::<[u8; 8], u64>::default();
let buf_addr = (&buf.t as *const [u8; 8]).addr();

let r = Ref::<_, u64>::from_bytes(&buf.t[..]).unwrap();
let rf = Ref::try_as_ref(&r).unwrap();
assert_eq!(rf, &0u64);
assert_eq!((rf as *const u64).addr(), buf_addr);

// invalid source

let buf = Align::<[u8; 1], u64>::new([42]);
let buf_addr = (&buf.t as *const [u8; 1]).addr();

let r = Ref::<_, bool>::from_bytes(&buf.t[..]).unwrap();
let re = Ref::try_as_ref(&r).unwrap_err();
assert_eq!(Ref::bytes(re.into_src()).addr(), buf_addr);
}

#[test]
#[allow(unstable_name_collisions)]
#[allow(clippy::as_conversions)]
Expand All @@ -1132,6 +1399,42 @@ mod tests {
assert_eq!(buf.t, [0xFF; 8]);
}

#[test]
#[allow(unstable_name_collisions)]
#[allow(clippy::as_conversions)]
fn test_try_into_ref_mut() {
#[allow(unused)]
use crate::util::AsAddress as _;

// valid source

let mut buf = Align::<[u8; 8], u64>::default();
let buf_addr = (&buf.t as *const [u8; 8]).addr();

let r = Ref::<_, u64>::from_bytes(&buf.t[..]).unwrap();
let rf = Ref::try_into_ref(r).unwrap();
assert_eq!(rf, &0u64);
assert_eq!((rf as *const u64).addr(), buf_addr);

let r = Ref::<_, u64>::from_bytes(&mut buf.t[..]).unwrap();
let rf = Ref::try_into_mut(r).unwrap();
assert_eq!(rf, &mut 0u64);
assert_eq!((rf as *mut u64).addr(), buf_addr);

// invalid source

let mut buf = Align::<[u8; 1], u64>::new([42]);
let buf_addr = (&buf.t as *const [u8; 1]).addr();

let r = Ref::<_, bool>::from_bytes(&buf.t[..]).unwrap();
let re = Ref::try_into_ref(r).unwrap_err();
assert_eq!(Ref::bytes(&re.into_src()).addr(), buf_addr);

let r = Ref::<_, bool>::from_bytes(&mut buf.t[..]).unwrap();
let re = Ref::try_into_mut(r).unwrap_err();
assert_eq!(Ref::bytes(&re.into_src()).addr(), buf_addr);
}

#[test]
fn test_display_debug() {
let buf = Align::<[u8; 8], u64>::default();
Expand Down

0 comments on commit b528e37

Please sign in to comment.