diff --git a/src/net/mod.rs b/src/net/mod.rs index 7ec8bc698..7fd5c20f1 100644 --- a/src/net/mod.rs +++ b/src/net/mod.rs @@ -25,7 +25,7 @@ pub use crate::maybe_polyfill::net::{ }; pub use send_recv::*; pub use socket::*; -pub use socket_addr_any::{SocketAddrAny, SocketAddrStorage}; +pub use socket_addr_any::{RawSocketAddr, SocketAddrAny, SocketAddrStorage}; #[cfg(not(any(windows, target_os = "wasi")))] pub use socketpair::socketpair; pub use types::*; diff --git a/src/net/socket_addr_any.rs b/src/net/socket_addr_any.rs index b43d09667..86ca29387 100644 --- a/src/net/socket_addr_any.rs +++ b/src/net/socket_addr_any.rs @@ -14,6 +14,7 @@ use crate::net::xdp::SocketAddrXdp; #[cfg(unix)] use crate::net::SocketAddrUnix; use crate::net::{AddressFamily, SocketAddr, SocketAddrV4, SocketAddrV6}; +use crate::utils::{as_mut_ptr, as_ptr}; use crate::{backend, io}; #[cfg(feature = "std")] use core::fmt; @@ -83,6 +84,23 @@ impl SocketAddrAny { } } + /// Creates a platform-specific encoding of this socket address, + /// and returns it. + pub fn to_raw(&self) -> RawSocketAddr { + let mut raw = RawSocketAddr { + storage: unsafe { std::mem::zeroed() }, + len: 0, + }; + + raw.len = unsafe { self.write(raw.as_mut_ptr()) }; + raw + } + + /// Reads a platform-specific encoding of a socket address. + pub fn from_raw(raw: RawSocketAddr) -> io::Result { + unsafe { Self::read(raw.as_ptr(), raw.len) } + } + /// Writes a platform-specific encoding of this socket address to /// the memory pointed to by `storage`, and returns the number of /// bytes used. @@ -107,6 +125,35 @@ impl SocketAddrAny { } } +/// A raw sockaddr and its length. +#[repr(C)] +pub struct RawSocketAddr { + pub(crate) storage: SocketAddrStorage, + pub(crate) len: usize, +} + +impl RawSocketAddr { + /// Creates a raw encoded sockaddr from the given address. + pub fn new(addr: impl Into) -> Self { + addr.into().to_raw() + } + + /// Returns a raw pointer to the sockaddr. + pub fn as_ptr(&self) -> *const SocketAddrStorage { + as_ptr(&self.storage) + } + + /// Returns a raw mutable pointer to the sockaddr. + pub fn as_mut_ptr(&mut self) -> *mut SocketAddrStorage { + as_mut_ptr(&mut self.storage) + } + + /// Returns the length of the encoded sockaddr. + pub fn namelen(&self) -> usize { + self.len + } +} + #[cfg(feature = "std")] impl fmt::Debug for SocketAddrAny { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { diff --git a/tests/net/addr.rs b/tests/net/addr.rs index 81b86d74b..40542c101 100644 --- a/tests/net/addr.rs +++ b/tests/net/addr.rs @@ -13,12 +13,22 @@ fn encode_decode() { let decoded = SocketAddrAny::read(encoded.as_ptr(), len).unwrap(); assert_eq!(decoded, SocketAddrAny::V4(orig)); + let orig = SocketAddrV4::new(Ipv4Addr::new(2, 3, 5, 6), 33); + let encoded = SocketAddrAny::V4(orig).to_raw(); + let decoded = SocketAddrAny::from_raw(encoded).unwrap(); + assert_eq!(decoded, SocketAddrAny::V4(orig)); + let orig = SocketAddrV6::new(Ipv6Addr::new(2, 3, 5, 6, 8, 9, 11, 12), 33, 34, 36); let mut encoded = std::mem::MaybeUninit::::uninit(); let len = SocketAddrAny::V6(orig).write(encoded.as_mut_ptr()); let decoded = SocketAddrAny::read(encoded.as_ptr(), len).unwrap(); assert_eq!(decoded, SocketAddrAny::V6(orig)); + let orig = SocketAddrV6::new(Ipv6Addr::new(2, 3, 5, 6, 8, 9, 11, 12), 33, 34, 36); + let encoded = SocketAddrAny::V6(orig).to_raw(); + let decoded = SocketAddrAny::from_raw(encoded).unwrap(); + assert_eq!(decoded, SocketAddrAny::V6(orig)); + #[cfg(not(windows))] { let orig = SocketAddrUnix::new("/path/to/socket").unwrap(); @@ -26,6 +36,11 @@ fn encode_decode() { let len = SocketAddrAny::Unix(orig.clone()).write(encoded.as_mut_ptr()); let decoded = SocketAddrAny::read(encoded.as_ptr(), len).unwrap(); assert_eq!(decoded, SocketAddrAny::Unix(orig)); + + let orig = SocketAddrUnix::new("/path/to/socket").unwrap(); + let encoded = SocketAddrAny::Unix(orig.clone()).to_raw(); + let decoded = SocketAddrAny::from_raw(encoded).unwrap(); + assert_eq!(decoded, SocketAddrAny::Unix(orig)); } } }