From 33e3dae0abf877d81c7228a49245e420c06fd361 Mon Sep 17 00:00:00 2001 From: Evan Pratten Date: Wed, 24 Apr 2024 11:26:53 -0400 Subject: [PATCH] Implement Multi-Queue support for `easy-tun` --- libs/easy-tun/examples/print_traffic.rs | 7 +- libs/easy-tun/examples/print_traffic_mq.rs | 29 ++++ libs/easy-tun/src/tun.rs | 105 ++++++-------- src/args/protomask.rs | 5 + src/args/protomask_clat.rs | 5 + src/protomask-clat.rs | 115 ++++++++------- src/protomask.rs | 159 ++++++++++++--------- 7 files changed, 244 insertions(+), 181 deletions(-) create mode 100644 libs/easy-tun/examples/print_traffic_mq.rs diff --git a/libs/easy-tun/examples/print_traffic.rs b/libs/easy-tun/examples/print_traffic.rs index ac44ebb..885a05c 100644 --- a/libs/easy-tun/examples/print_traffic.rs +++ b/libs/easy-tun/examples/print_traffic.rs @@ -1,17 +1,18 @@ -use easy_tun::Tun; use std::io::Read; +use easy_tun::Tun; + fn main() { // Enable logs env_logger::init(); // Bring up a TUN interface - let mut tun = Tun::new("tun%d").unwrap(); + let mut tun = Tun::new("tun%d", 1).unwrap(); // Loop and read from the interface let mut buffer = [0u8; 1500]; loop { - let length = tun.read(&mut buffer).unwrap(); + let length = tun.fd(0).unwrap().read(&mut buffer).unwrap(); println!("{:?}", &buffer[..length]); } } diff --git a/libs/easy-tun/examples/print_traffic_mq.rs b/libs/easy-tun/examples/print_traffic_mq.rs new file mode 100644 index 0000000..69e8954 --- /dev/null +++ b/libs/easy-tun/examples/print_traffic_mq.rs @@ -0,0 +1,29 @@ +use std::{io::Read, sync::Arc}; + +use easy_tun::Tun; + +fn main() { + // Enable logs + env_logger::init(); + + // Bring up a TUN interface + let tun = Arc::new(Tun::new("tun%d", 5).unwrap()); + + // Spawn 5 threads to read from the interface + let mut threads = Vec::new(); + for i in 0..5 { + let tun = Arc::clone(&tun); + threads.push(std::thread::spawn(move || { + let mut buffer = [0u8; 1500]; + loop { + let length = tun.fd(i).unwrap().read(&mut buffer).unwrap(); + println!("Queue #{}: {:?}", i, &buffer[..length]); + } + })); + } + + // Wait for all threads to finish + for thread in threads { + thread.join().unwrap(); + } +} diff --git a/libs/easy-tun/src/tun.rs b/libs/easy-tun/src/tun.rs index 74efd79..bd1825f 100644 --- a/libs/easy-tun/src/tun.rs +++ b/libs/easy-tun/src/tun.rs @@ -6,7 +6,9 @@ use std::{ }; use ioctl_gen::{ioc, iow}; -use libc::{__c_anonymous_ifr_ifru, ifreq, ioctl, IFF_NO_PI, IFF_TUN, IF_NAMESIZE}; +use libc::{ + __c_anonymous_ifr_ifru, ifreq, ioctl, IFF_MULTI_QUEUE, IFF_NO_PI, IFF_TUN, IF_NAMESIZE, +}; /// Architecture / target environment specific definitions mod arch { @@ -22,8 +24,8 @@ mod arch { /// A TUN device pub struct Tun { - /// Internal file descriptor for the TUN device - fd: File, + /// All internal file descriptors + fds: Vec>, /// Device name name: String, } @@ -35,15 +37,23 @@ impl Tun { /// and may contain a `%d` format specifier to allow for multiple devices with the same name. #[allow(clippy::cast_possible_truncation)] #[allow(clippy::cast_lossless)] - pub fn new(dev: &str) -> Result { - log::debug!("Creating new TUN device with requested name:{}", dev); - - // Get a file descriptor for `/dev/net/tun` + pub fn new(dev: &str, queues: usize) -> Result { + log::debug!( + "Creating new TUN device with requested name: {} ({} queues)", + dev, + queues + ); + + // Create all needed file descriptors for `/dev/net/tun` log::trace!("Opening /dev/net/tun"); - let fd = OpenOptions::new() - .read(true) - .write(true) - .open("/dev/net/tun")?; + let mut fds = Vec::with_capacity(queues); + for _ in 0..queues { + let fd = OpenOptions::new() + .read(true) + .write(true) + .open("/dev/net/tun")?; + fds.push(Box::new(fd)); + } // Copy the device name into a C string with padding // NOTE: No zero padding is needed because we pre-init the array to all 0s @@ -57,25 +67,28 @@ impl Tun { let mut ifr = ifreq { ifr_name: dev_cstr, ifr_ifru: __c_anonymous_ifr_ifru { - ifru_flags: (IFF_TUN | IFF_NO_PI) as i16, + ifru_flags: (IFF_TUN | IFF_NO_PI | IFF_MULTI_QUEUE) as i16, }, }; - // Make an ioctl call to create the TUN device - log::trace!("Calling ioctl to create TUN device"); - let err = unsafe { - ioctl( - fd.as_raw_fd(), - iow!('T', 202, size_of::()) as arch::IoctlRequestType, - &mut ifr, - ) - }; - log::trace!("ioctl returned: {}", err); - - // Check for errors - if err < 0 { - log::error!("ioctl failed: {}", err); - return Err(std::io::Error::last_os_error()); + // Each FD needs to be configured separately + for fd in fds.iter_mut() { + // Make an ioctl call to create the TUN device + log::trace!("Calling ioctl to create TUN device"); + let err = unsafe { + ioctl( + fd.as_raw_fd(), + iow!('T', 202, size_of::()) as arch::IoctlRequestType, + &mut ifr, + ) + }; + log::trace!("ioctl returned: {}", err); + + // Check for errors + if err < 0 { + log::error!("ioctl failed: {}", err); + return Err(std::io::Error::last_os_error()); + } } // Get the name of the device @@ -88,7 +101,7 @@ impl Tun { log::debug!("Created TUN device: {}", name); // Build the TUN struct - Ok(Self { fd, name }) + Ok(Self { fds, name }) } /// Get the name of the TUN device @@ -99,38 +112,14 @@ impl Tun { /// Get the underlying file descriptor #[must_use] - pub fn fd(&self) -> &File { - &self.fd + pub fn fd(&self, queue_id: usize) -> Option<&File> { + self.fds.get(queue_id).map(|fd| &**fd) } -} - -impl AsRawFd for Tun { - fn as_raw_fd(&self) -> RawFd { - self.fd.as_raw_fd() - } -} -impl IntoRawFd for Tun { - fn into_raw_fd(self) -> RawFd { - self.fd.into_raw_fd() - } -} - -impl Read for Tun { - #[profiling::function] - fn read(&mut self, buf: &mut [u8]) -> std::io::Result { - self.fd.read(buf) - } -} - -impl Write for Tun { - #[profiling::function] - fn write(&mut self, buf: &[u8]) -> std::io::Result { - self.fd.write(buf) + /// Get mutable access to the underlying file descriptor + #[must_use] + pub fn fd_mut(&mut self, queue_id: usize) -> Option<&mut File> { + self.fds.get_mut(queue_id).map(|fd| &mut **fd) } - #[profiling::function] - fn flush(&mut self) -> std::io::Result<()> { - self.fd.flush() - } } diff --git a/src/args/protomask.rs b/src/args/protomask.rs index 8353392..1383491 100644 --- a/src/args/protomask.rs +++ b/src/args/protomask.rs @@ -94,6 +94,11 @@ pub struct Config { /// NAT reservation timeout in seconds #[clap(long, default_value = "7200")] pub reservation_timeout: u64, + + /// Number of queues to create on the TUN device + #[clap(long, default_value = "1")] + #[serde(rename = "queues")] + pub num_queues: usize, } #[derive(Debug, serde::Deserialize, Clone)] diff --git a/src/args/protomask_clat.rs b/src/args/protomask_clat.rs index e3aaba7..ba3357e 100644 --- a/src/args/protomask_clat.rs +++ b/src/args/protomask_clat.rs @@ -82,4 +82,9 @@ pub struct Config { serialize_with = "crate::common::rfc6052::serialize_network_specific_prefix" )] pub embed_prefix: Ipv6Net, + + /// Number of queues to create on the TUN device + #[clap(long, default_value = "1")] + #[serde(rename = "queues")] + pub num_queues: usize, } diff --git a/src/protomask-clat.rs b/src/protomask-clat.rs index 42f785c..c2ca56a 100644 --- a/src/protomask-clat.rs +++ b/src/protomask-clat.rs @@ -16,6 +16,7 @@ use interproto::protocols::ip::{translate_ipv4_to_ipv6, translate_ipv6_to_ipv4}; use ipnet::{IpNet, Ipv4Net, Ipv6Net}; use rfc6052::{embed_ipv4_addr_unchecked, extract_ipv4_addr_unchecked}; use std::io::{Read, Write}; +use std::sync::Arc; mod args; mod common; @@ -39,7 +40,7 @@ pub async fn main() { let _server = start_puffin_server(&args.profiler_args); // Bring up a TUN interface - let mut tun = Tun::new(&args.interface).unwrap(); + let tun = Arc::new(Tun::new(&args.interface, config.num_queues).unwrap()); // Get the interface index let rt_handle = rtnl::new_handle().unwrap(); @@ -87,54 +88,70 @@ pub async fn main() { // Translate all incoming packets log::info!("Translating packets on {}", tun.name()); - let mut buffer = vec![0u8; 1500]; - loop { - // Indicate to the profiler that we are starting a new packet - profiling::finish_frame!(); - profiling::scope!("packet"); - - // Read a packet - let len = tun.read(&mut buffer).unwrap(); - - // Translate it based on the Layer 3 protocol number - let translation_result: Result>, PacketHandlingError> = - match get_layer_3_proto(&buffer[..len]) { - Some(4) => { - let (source, dest) = get_ipv4_src_dst(&buffer[..len]); - translate_ipv4_to_ipv6( - &buffer[..len], - unsafe { embed_ipv4_addr_unchecked(source, config.embed_prefix) }, - unsafe { embed_ipv4_addr_unchecked(dest, config.embed_prefix) }, - ) - .map(Some) - .map_err(PacketHandlingError::from) + let mut worker_threads = Vec::new(); + for queue_id in 0..config.num_queues { + let tun = Arc::clone(&tun); + worker_threads.push(std::thread::spawn(move || { + log::debug!("Starting worker thread for queue {}", queue_id); + let mut buffer = vec![0u8; 1500]; + loop { + // Indicate to the profiler that we are starting a new packet + profiling::finish_frame!(); + profiling::scope!("packet"); + + // Read a packet + let len = tun.fd(queue_id).unwrap().read(&mut buffer).unwrap(); + + // Translate it based on the Layer 3 protocol number + let translation_result: Result>, PacketHandlingError> = + match get_layer_3_proto(&buffer[..len]) { + Some(4) => { + let (source, dest) = get_ipv4_src_dst(&buffer[..len]); + translate_ipv4_to_ipv6( + &buffer[..len], + unsafe { embed_ipv4_addr_unchecked(source, config.embed_prefix) }, + unsafe { embed_ipv4_addr_unchecked(dest, config.embed_prefix) }, + ) + .map(Some) + .map_err(PacketHandlingError::from) + } + Some(6) => { + let (source, dest) = get_ipv6_src_dst(&buffer[..len]); + translate_ipv6_to_ipv4( + &buffer[..len], + unsafe { + extract_ipv4_addr_unchecked( + source, + config.embed_prefix.prefix_len(), + ) + }, + unsafe { + extract_ipv4_addr_unchecked( + dest, + config.embed_prefix.prefix_len(), + ) + }, + ) + .map(Some) + .map_err(PacketHandlingError::from) + } + Some(proto) => { + log::warn!("Unknown Layer 3 protocol: {}", proto); + continue; + } + None => { + continue; + } + }; + + // Handle any errors and write + if let Some(output) = handle_translation_error(translation_result) { + tun.fd(queue_id).unwrap().write_all(&output).unwrap(); } - Some(6) => { - let (source, dest) = get_ipv6_src_dst(&buffer[..len]); - translate_ipv6_to_ipv4( - &buffer[..len], - unsafe { - extract_ipv4_addr_unchecked(source, config.embed_prefix.prefix_len()) - }, - unsafe { - extract_ipv4_addr_unchecked(dest, config.embed_prefix.prefix_len()) - }, - ) - .map(Some) - .map_err(PacketHandlingError::from) - } - Some(proto) => { - log::warn!("Unknown Layer 3 protocol: {}", proto); - continue; - } - None => { - continue; - } - }; - - // Handle any errors and write - if let Some(output) = handle_translation_error(translation_result) { - tun.write_all(&output).unwrap(); - } + } + })); + } + for worker in worker_threads { + worker.join().unwrap(); } } diff --git a/src/protomask.rs b/src/protomask.rs index dd1493c..4ffe108 100644 --- a/src/protomask.rs +++ b/src/protomask.rs @@ -14,8 +14,8 @@ use interproto::protocols::ip::{translate_ipv4_to_ipv6, translate_ipv6_to_ipv4}; use ipnet::IpNet; use rfc6052::{embed_ipv4_addr_unchecked, extract_ipv4_addr_unchecked}; use std::{ - cell::RefCell, io::{Read, Write}, + sync::{Arc, Mutex}, time::Duration, }; @@ -42,7 +42,7 @@ pub async fn main() { // Bring up a TUN interface log::debug!("Creating new TUN interface"); - let mut tun = Tun::new(&args.interface).unwrap(); + let tun = Arc::new(Tun::new(&args.interface, config.num_queues).unwrap()); log::debug!("Created TUN interface: {}", tun.name()); // Get the interface index @@ -78,13 +78,16 @@ pub async fn main() { } // Set up the address table - let mut addr_table = RefCell::new(CrossProtocolNetworkAddressTableWithIpv4Pool::new( - &config.pool_prefixes, - Duration::from_secs(config.reservation_timeout), + let addr_table = Arc::new(Mutex::new( + CrossProtocolNetworkAddressTableWithIpv4Pool::new( + &config.pool_prefixes, + Duration::from_secs(config.reservation_timeout), + ), )); for (v4_addr, v6_addr) in &config.static_map { addr_table - .get_mut() + .lock() + .unwrap() .insert_static(*v4_addr, *v6_addr) .unwrap(); } @@ -97,74 +100,88 @@ pub async fn main() { // Translate all incoming packets log::info!("Translating packets on {}", tun.name()); - let mut buffer = vec![0u8; 1500]; - loop { - // Indicate to the profiler that we are starting a new packet - profiling::finish_frame!(); - profiling::scope!("packet"); - - // Read a packet - let len = tun.read(&mut buffer).unwrap(); - - // Translate it based on the Layer 3 protocol number - let translation_result: Result>, PacketHandlingError> = - match get_layer_3_proto(&buffer[..len]) { - Some(4) => { - let (source, dest) = get_ipv4_src_dst(&buffer[..len]); - match addr_table.borrow().get_ipv6(&dest) { - Some(new_destination) => translate_ipv4_to_ipv6( - &buffer[..len], - unsafe { embed_ipv4_addr_unchecked(source, config.translation_prefix) }, - new_destination, - ) - .map(Some) - .map_err(PacketHandlingError::from), - None => { - protomask_metrics::metric!( - PACKET_COUNTER, - PROTOCOL_IPV4, - STATUS_DROPPED - ); - Ok(None) - } - } - } - Some(6) => { - let (source, dest) = get_ipv6_src_dst(&buffer[..len]); - match addr_table.borrow_mut().get_or_create_ipv4(&source) { - Ok(new_source) => { - translate_ipv6_to_ipv4(&buffer[..len], new_source, unsafe { - extract_ipv4_addr_unchecked( - dest, - config.translation_prefix.prefix_len(), + let mut worker_threads = Vec::new(); + for queue_id in 0..config.num_queues { + let tun = Arc::clone(&tun); + let addr_table = Arc::clone(&addr_table); + worker_threads.push(std::thread::spawn(move || { + log::debug!("Starting worker thread for queue {}", queue_id); + + let mut buffer = vec![0u8; 1500]; + loop { + // Indicate to the profiler that we are starting a new packet + profiling::finish_frame!(); + profiling::scope!("packet"); + + // Read a packet + let len = tun.fd(queue_id).unwrap().read(&mut buffer).unwrap(); + + // Translate it based on the Layer 3 protocol number + let translation_result: Result>, PacketHandlingError> = + match get_layer_3_proto(&buffer[..len]) { + Some(4) => { + let (source, dest) = get_ipv4_src_dst(&buffer[..len]); + match addr_table.lock().unwrap().get_ipv6(&dest) { + Some(new_destination) => translate_ipv4_to_ipv6( + &buffer[..len], + unsafe { + embed_ipv4_addr_unchecked(source, config.translation_prefix) + }, + new_destination, ) - }) - .map(Some) - .map_err(PacketHandlingError::from) + .map(Some) + .map_err(PacketHandlingError::from), + None => { + protomask_metrics::metric!( + PACKET_COUNTER, + PROTOCOL_IPV4, + STATUS_DROPPED + ); + Ok(None) + } + } } - Err(error) => { - log::error!("Error getting IPv4 address: {}", error); - protomask_metrics::metric!( - PACKET_COUNTER, - PROTOCOL_IPV6, - STATUS_DROPPED - ); - Ok(None) + Some(6) => { + let (source, dest) = get_ipv6_src_dst(&buffer[..len]); + match addr_table.lock().unwrap().get_or_create_ipv4(&source) { + Ok(new_source) => { + translate_ipv6_to_ipv4(&buffer[..len], new_source, unsafe { + extract_ipv4_addr_unchecked( + dest, + config.translation_prefix.prefix_len(), + ) + }) + .map(Some) + .map_err(PacketHandlingError::from) + } + Err(error) => { + log::error!("Error getting IPv4 address: {}", error); + protomask_metrics::metric!( + PACKET_COUNTER, + PROTOCOL_IPV6, + STATUS_DROPPED + ); + Ok(None) + } + } } - } - } - Some(proto) => { - log::warn!("Unknown Layer 3 protocol: {}", proto); - continue; - } - None => { - continue; - } - }; + Some(proto) => { + log::warn!("Unknown Layer 3 protocol: {}", proto); + continue; + } + None => { + continue; + } + }; - // Handle any errors and write - if let Some(output) = handle_translation_error(translation_result) { - tun.write_all(&output).unwrap(); - } + // Handle any errors and write + if let Some(output) = handle_translation_error(translation_result) { + tun.fd(queue_id).unwrap().write_all(&output).unwrap(); + } + } + })); + } + for worker in worker_threads { + worker.join().unwrap(); } }