Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for original_dst for windows #529

Merged
merged 8 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions src/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2225,6 +2225,48 @@ impl Socket {
)
}
}

/// Get the value for the `SO_ORIGINAL_DST` option on this socket.
#[cfg(all(
feature = "all",
any(
target_os = "android",
target_os = "fuchsia",
target_os = "linux",
target_os = "windows",
)
))]
#[cfg_attr(
docsrs,
doc(cfg(all(
feature = "all",
any(
target_os = "android",
target_os = "fuchsia",
target_os = "linux",
target_os = "windows",
)
)))
)]
pub fn original_dst(&self) -> io::Result<SockAddr> {
sys::original_dst(self.as_raw())
}

/// Get the value for the `IP6T_SO_ORIGINAL_DST` option on this socket.
#[cfg(all(
feature = "all",
any(target_os = "android", target_os = "linux", target_os = "windows")
))]
#[cfg_attr(
docsrs,
doc(cfg(all(
feature = "all",
any(target_os = "android", target_os = "linux", target_os = "windows")
)))
)]
pub fn original_dst_ipv6(&self) -> io::Result<SockAddr> {
sys::original_dst_ipv6(self.as_raw())
}
keithmattix marked this conversation as resolved.
Show resolved Hide resolved
}

impl Read for Socket {
Expand Down
108 changes: 52 additions & 56 deletions src/sys/unix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1406,6 +1406,58 @@ pub(crate) const fn to_mreqn(
}
}

#[cfg(all(
feature = "all",
any(target_os = "android", target_os = "fuchsia", target_os = "linux")
))]
#[cfg_attr(
docsrs,
doc(cfg(all(
feature = "all",
any(target_os = "android", target_os = "fuchsia", target_os = "linux")
)))
)]
Thomasdezeeuw marked this conversation as resolved.
Show resolved Hide resolved
pub(crate) fn original_dst(fd: Socket) -> io::Result<SockAddr> {
// Safety: `getsockopt` initialises the `SockAddr` for us.
unsafe {
SockAddr::try_init(|storage, len| {
syscall!(getsockopt(
fd,
libc::SOL_IP,
libc::SO_ORIGINAL_DST,
storage.cast(),
len
))
})
}
.map(|(_, addr)| addr)
}

/// Get the value for the `IP6T_SO_ORIGINAL_DST` option on this socket.
///
/// This value contains the original destination IPv6 address of the connection
/// redirected using `ip6tables` `REDIRECT` or `TPROXY`.
#[cfg(all(feature = "all", any(target_os = "android", target_os = "linux")))]
#[cfg_attr(
docsrs,
doc(cfg(all(feature = "all", any(target_os = "android", target_os = "linux"))))
)]
Thomasdezeeuw marked this conversation as resolved.
Show resolved Hide resolved
pub(crate) fn original_dst_ipv6(fd: Socket) -> io::Result<SockAddr> {
// Safety: `getsockopt` initialises the `SockAddr` for us.
unsafe {
SockAddr::try_init(|storage, len| {
syscall!(getsockopt(
fd,
libc::SOL_IPV6,
libc::IP6T_SO_ORIGINAL_DST,
storage.cast(),
len
))
})
}
.map(|(_, addr)| addr)
}

/// Unix only API.
impl crate::Socket {
/// Accept a new incoming connection from this listener.
Expand Down Expand Up @@ -2402,62 +2454,6 @@ impl crate::Socket {
}
}

/// Get the value for the `SO_ORIGINAL_DST` option on this socket.
///
/// This value contains the original destination IPv4 address of the connection
/// redirected using `iptables` `REDIRECT` or `TPROXY`.
#[cfg(all(
feature = "all",
any(target_os = "android", target_os = "fuchsia", target_os = "linux")
))]
#[cfg_attr(
docsrs,
doc(cfg(all(
feature = "all",
any(target_os = "android", target_os = "fuchsia", target_os = "linux")
)))
)]
pub fn original_dst(&self) -> io::Result<SockAddr> {
// Safety: `getsockopt` initialises the `SockAddr` for us.
unsafe {
SockAddr::try_init(|storage, len| {
syscall!(getsockopt(
self.as_raw(),
libc::SOL_IP,
libc::SO_ORIGINAL_DST,
storage.cast(),
len
))
})
}
.map(|(_, addr)| addr)
}

/// Get the value for the `IP6T_SO_ORIGINAL_DST` option on this socket.
///
/// This value contains the original destination IPv6 address of the connection
/// redirected using `ip6tables` `REDIRECT` or `TPROXY`.
#[cfg(all(feature = "all", any(target_os = "android", target_os = "linux")))]
#[cfg_attr(
docsrs,
doc(cfg(all(feature = "all", any(target_os = "android", target_os = "linux"))))
)]
pub fn original_dst_ipv6(&self) -> io::Result<SockAddr> {
// Safety: `getsockopt` initialises the `SockAddr` for us.
unsafe {
SockAddr::try_init(|storage, len| {
syscall!(getsockopt(
self.as_raw(),
libc::SOL_IPV6,
libc::IP6T_SO_ORIGINAL_DST,
storage.cast(),
len
))
})
}
.map(|(_, addr)| addr)
}

/// Copies data between a `file` and this socket using the `sendfile(2)`
/// system call. Because this copying is done within the kernel,
/// `sendfile()` is more efficient than the combination of `read(2)` and
Expand Down
64 changes: 62 additions & 2 deletions src/sys/windows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,16 @@ use std::time::{Duration, Instant};
use std::{process, ptr, slice};

use windows_sys::Win32::Foundation::{SetHandleInformation, HANDLE, HANDLE_FLAG_INHERIT};
#[cfg(feature = "all")]
use windows_sys::Win32::Networking::WinSock::SO_PROTOCOL_INFOW;
use windows_sys::Win32::Networking::WinSock::{
self, tcp_keepalive, FIONBIO, IN6_ADDR, IN6_ADDR_0, INVALID_SOCKET, IN_ADDR, IN_ADDR_0,
POLLERR, POLLHUP, POLLRDNORM, POLLWRNORM, SD_BOTH, SD_RECEIVE, SD_SEND, SIO_KEEPALIVE_VALS,
SOCKET_ERROR, WSABUF, WSAEMSGSIZE, WSAESHUTDOWN, WSAPOLLFD, WSAPROTOCOL_INFOW,
WSA_FLAG_NO_HANDLE_INHERIT, WSA_FLAG_OVERLAPPED,
};
#[cfg(feature = "all")]
use windows_sys::Win32::Networking::WinSock::{
IP6T_SO_ORIGINAL_DST, SOL_IP, SO_ORIGINAL_DST, SO_PROTOCOL_INFOW,
};
use windows_sys::Win32::System::Threading::INFINITE;

use crate::{MsgHdr, RecvFlags, SockAddr, TcpKeepalive, Type};
Expand Down Expand Up @@ -857,6 +859,64 @@ pub(crate) fn to_mreqn(
}
}

/// Get the value for the `SO_ORIGINAL_DST` option on this socket.
/// Only valid for sockets in accepting mode.
///
/// Note: if using this function in a proxy context, you must query the
/// redirect records for this socket and set them on the outbound socket
/// created by your proxy in order for any OS level firewall rules to be
/// applied. Read more in the Windows bind and connect redirection
/// [documentation](https://learn.microsoft.com/en-us/windows-hardware/drivers/network/using-bind-or-connect-redirection).
#[cfg(feature = "all")]
#[cfg_attr(docsrs, doc(cfg(all(windows, feature = "all"))))]
Thomasdezeeuw marked this conversation as resolved.
Show resolved Hide resolved
pub(crate) fn original_dst(socket: Socket) -> io::Result<SockAddr> {
unsafe {
SockAddr::try_init(|storage, len| {
syscall!(
getsockopt(
socket,
SOL_IP as i32,
SO_ORIGINAL_DST as i32,
storage.cast(),
len,
),
PartialEq::eq,
SOCKET_ERROR
)
})
}
.map(|(_, addr)| addr)
}

/// Get the value for the `IP6T_SO_ORIGINAL_DST` option on this socket.
/// Only valid for sockets in accepting mode.
///
/// Note: if using this function in a proxy context, you must query the
/// redirect records for this socket and set them on the outbound socket
/// created by your proxy in order for any OS level firewall rules to be
/// applied. Read more in the Windows bind and connect redirection
/// [documentation](https://learn.microsoft.com/en-us/windows-hardware/drivers/network/using-bind-or-connect-redirection).
#[cfg(feature = "all")]
#[cfg_attr(docsrs, doc(cfg(all(windows, feature = "all"))))]
Thomasdezeeuw marked this conversation as resolved.
Show resolved Hide resolved
pub(crate) fn original_dst_ipv6(socket: Socket) -> io::Result<SockAddr> {
unsafe {
SockAddr::try_init(|storage, len| {
syscall!(
getsockopt(
socket,
SOL_IP as i32,
IP6T_SO_ORIGINAL_DST as i32,
storage.cast(),
len,
),
PartialEq::eq,
SOCKET_ERROR
)
})
}
.map(|(_, addr)| addr)
}

#[allow(unsafe_op_in_unsafe_fn)]
pub(crate) fn unix_sockaddr(path: &Path) -> io::Result<SockAddr> {
// SAFETY: a `sockaddr_storage` of all zeros is valid.
Expand Down
34 changes: 28 additions & 6 deletions tests/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ use std::num::NonZeroUsize;
use std::os::unix::io::AsRawFd;
#[cfg(windows)]
use std::os::windows::io::AsRawSocket;

Thomasdezeeuw marked this conversation as resolved.
Show resolved Hide resolved
#[cfg(unix)]
use std::path::Path;
use std::str;
Expand Down Expand Up @@ -1600,36 +1601,57 @@ fn header_included_ipv6() {
#[test]
#[cfg(all(
feature = "all",
any(target_os = "android", target_os = "fuchsia", target_os = "linux")
any(
target_os = "android",
target_os = "fuchsia",
target_os = "linux",
target_os = "windows"
)
))]
fn original_dst() {
let socket = Socket::new(Domain::IPV4, Type::STREAM, None).unwrap();
#[cfg(not(target_os = "windows"))]
let expected = Some(libc::ENOENT);
#[cfg(target_os = "windows")]
let expected = Some(windows_sys::Win32::Networking::WinSock::WSAEINVAL);

match socket.original_dst() {
Ok(_) => panic!("original_dst on non-redirected socket should fail"),
Err(err) => assert_eq!(err.raw_os_error(), Some(libc::ENOENT)),
Err(err) => assert_eq!(err.raw_os_error(), expected),
}

let socket = Socket::new(Domain::IPV6, Type::STREAM, None).unwrap();
match socket.original_dst() {
Ok(_) => panic!("original_dst on non-redirected socket should fail"),
Err(err) => assert_eq!(err.raw_os_error(), Some(libc::ENOENT)),
Err(err) => assert_eq!(err.raw_os_error(), expected),
}
}

#[test]
#[cfg(all(feature = "all", any(target_os = "android", target_os = "linux")))]
#[cfg(all(
feature = "all",
any(target_os = "android", target_os = "fuchsia", target_os = "linux")
))]
fn original_dst_ipv6() {
let socket = Socket::new(Domain::IPV6, Type::STREAM, None).unwrap();
#[cfg(not(target_os = "windows"))]
let expected = Some(libc::ENOENT);
#[cfg(target_os = "windows")]
let expected = Some(windows_sys::Win32::Networking::WinSock::WSAEINVAL);
#[cfg(not(target_os = "windows"))]
let expected_v4 = Some(libc::EOPNOTSUPP);
#[cfg(target_os = "windows")]
let expected_v4 = Some(windows_sys::Win32::Networking::WinSock::WSAEINVAL);
match socket.original_dst_ipv6() {
Ok(_) => panic!("original_dst_ipv6 on non-redirected socket should fail"),
Err(err) => assert_eq!(err.raw_os_error(), Some(libc::ENOENT)),
Err(err) => assert_eq!(err.raw_os_error(), expected),
}

// Not supported on IPv4 socket.
let socket = Socket::new(Domain::IPV4, Type::STREAM, None).unwrap();
match socket.original_dst_ipv6() {
Ok(_) => panic!("original_dst_ipv6 on non-redirected socket should fail"),
Err(err) => assert_eq!(err.raw_os_error(), Some(libc::EOPNOTSUPP)),
Err(err) => assert_eq!(err.raw_os_error(), expected_v4),
}
}

Expand Down
Loading