Skip to content

Commit

Permalink
Merge branch 'set_mtu_on_wireguard'
Browse files Browse the repository at this point in the history
  • Loading branch information
Jontified committed Jun 8, 2022
2 parents b51a6de + a036cbf commit e80115d
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 9 deletions.
70 changes: 70 additions & 0 deletions talpid-core/src/routing/linux.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ pub enum Error {
#[error(display = "No netlink response for route query")]
NoRouteError,

#[error(display = "No link found")]
LinkNotFoundError,

/// Unable to create routing table for tagged connections and packets.
#[error(display = "Cannot find a free routing table ID")]
NoFreeRoutingTableId,
Expand Down Expand Up @@ -363,6 +366,9 @@ impl RouteManagerImpl {
RouteManagerCommand::GetDestinationRoute(destination, set_mark, result_tx) => {
let _ = result_tx.send(self.get_destination_route(&destination, set_mark).await);
}
RouteManagerCommand::GetMtuForRoute(ip, result_tx) => {
let _ = result_tx.send(self.get_mtu_for_route(ip).await);
}
RouteManagerCommand::ClearRoutes => {
log::debug!("Clearing routes");
self.cleanup_routes().await;
Expand Down Expand Up @@ -720,6 +726,70 @@ impl RouteManagerImpl {
}
}

async fn get_mtu_for_route(&self, ip: IpAddr) -> Result<u16> {
// RECURSION_LIMIT controls how many times we recurse to find the device name by looking up
// an IP with `get_destination_route`.
const RECURSION_LIMIT: usize = 10;
const STANDARD_MTU: u16 = 1500;
let mut attempted_ip = ip;
for _ in 0..RECURSION_LIMIT {
let route = self.get_destination_route(&attempted_ip, false).await?;
match route {
Some(route) => {
let node = route.get_node();
match (node.get_device(), node.get_address()) {
(Some(device), None) => {
let mtu = self.get_device_mtu(device.to_string()).await?;
if mtu != STANDARD_MTU {
log::info!(
"Found MTU: {} on device {} which is different from the standard {}",
mtu,
device,
STANDARD_MTU
);
}
return Ok(mtu);
}
(None, Some(address)) => attempted_ip = address,
_ => {
panic!("Route must contain either an IP or a device.");
}
}
}
None => {
log::error!("No route detected when assigning the mtu to the Wireguard tunnel");
return Err(Error::NoRouteError);
}
}
}
log::error!(
"Retried {} times looking for the correct device and could not find it",
RECURSION_LIMIT
);
Err(Error::NoRouteError)
}

async fn get_device_mtu(&self, device: String) -> Result<u16> {
let mut links = self.handle.link().get().execute();
let target_device = LinkNla::IfName(device);
while let Some(msg) = links
.try_next()
.await
.map_err(|_| Error::LinkNotFoundError)?
{
let found = msg.nlas.iter().any(|e| *e == target_device);
if found {
if let Some(LinkNla::Mtu(mtu)) =
msg.nlas.iter().find(|e| matches!(e, LinkNla::Mtu(_)))
{
return Ok(u16::try_from(*mtu)
.expect("MTU returned by device does not fit into a u16"));
}
}
}
Err(Error::LinkNotFoundError)
}

async fn get_destination_route(
&self,
destination: &IpAddr,
Expand Down
2 changes: 1 addition & 1 deletion talpid-core/src/routing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ pub use imp::{Error, RouteManager};

pub use imp::RouteManagerHandle;

/// A netowrk route with a specific network node, destinaiton and an optional metric.
/// A network route with a specific network node, destinaiton and an optional metric.
#[derive(Debug, Hash, Eq, PartialEq, Clone)]
pub struct Route {
node: Node,
Expand Down
15 changes: 15 additions & 0 deletions talpid-core/src/routing/unix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,19 @@ impl RouteManagerHandle {
.map_err(|_| Error::ManagerChannelDown)?
.map_err(Error::PlatformError)
}

/// Listen for route changes.
#[cfg(target_os = "linux")]
pub async fn get_mtu_for_route(&self, ip: IpAddr) -> Result<u16, Error> {
let (response_tx, response_rx) = oneshot::channel();
self.tx
.unbounded_send(RouteManagerCommand::GetMtuForRoute(ip, response_tx))
.map_err(|_| Error::RouteManagerDown)?;
response_rx
.await
.map_err(|_| Error::ManagerChannelDown)?
.map_err(Error::PlatformError)
}
}

/// Commands for the underlying route manager object.
Expand All @@ -151,6 +164,8 @@ pub(crate) enum RouteManagerCommand {
#[cfg(target_os = "linux")]
NewChangeListener(oneshot::Sender<mpsc::UnboundedReceiver<CallbackMessage>>),
#[cfg(target_os = "linux")]
GetMtuForRoute(IpAddr, oneshot::Sender<Result<u16, PlatformError>>),
#[cfg(target_os = "linux")]
GetDestinationRoute(
IpAddr,
bool,
Expand Down
53 changes: 47 additions & 6 deletions talpid-core/src/tunnel/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ pub enum Error {
/// There was an error listening for events from the Wireguard tunnel
#[error(display = "Failed while listening for events from the Wireguard tunnel")]
WireguardTunnelMonitoringError(#[error(source)] wireguard::Error),

/// Could not detect and assign the correct mtu
#[error(display = "Could not detect and assign a correct MTU for the Wireguard tunnel")]
AssignMtuError,
}

/// Possible events from the VPN tunnel and the child process managing it.
Expand Down Expand Up @@ -101,7 +105,7 @@ impl TunnelMonitor {
#[cfg_attr(any(target_os = "android", windows), allow(unused_variables))]
pub fn start<L>(
runtime: tokio::runtime::Handle,
tunnel_parameters: &TunnelParameters,
tunnel_parameters: &mut TunnelParameters,
log_dir: &Option<PathBuf>,
resource_dir: &Path,
on_event: L,
Expand Down Expand Up @@ -134,9 +138,9 @@ impl TunnelMonitor {
#[cfg(target_os = "android")]
TunnelParameters::OpenVpn(_) => Err(Error::UnsupportedPlatform),

TunnelParameters::Wireguard(config) => Self::start_wireguard_tunnel(
TunnelParameters::Wireguard(ref mut config) => Self::start_wireguard_tunnel(
runtime,
&config,
config,
log_file,
resource_dir,
on_event,
Expand Down Expand Up @@ -172,7 +176,7 @@ impl TunnelMonitor {

fn start_wireguard_tunnel<L>(
runtime: tokio::runtime::Handle,
params: &wireguard_types::TunnelParameters,
params: &mut wireguard_types::TunnelParameters,
log: Option<PathBuf>,
resource_dir: &Path,
on_event: L,
Expand All @@ -188,11 +192,13 @@ impl TunnelMonitor {
+ Clone
+ 'static,
{
let config = wireguard::config::Config::from_parameters(&params)?;
#[cfg(target_os = "linux")]
runtime.block_on(Self::assign_mtu(&route_manager, params));
let config = wireguard::config::Config::from_parameters(params)?;
let monitor = wireguard::WireguardMonitor::start(
runtime,
config,
log.as_ref().map(|p| p.as_path()),
log.as_deref(),
resource_dir,
on_event,
tun_provider,
Expand All @@ -205,6 +211,41 @@ impl TunnelMonitor {
})
}

#[cfg(target_os = "linux")]
fn set_mtu(params: &mut wireguard_types::TunnelParameters, mtu: u16) {
const WIREGUARD_HEADER_SIZE: u16 = 80;
// The largest tunnel MTU that we allow. Standard MTU - Wireguard header
const MAX_TUNNEL_MTU: u16 = 1420;
// The minimum allowed MTU size for our tunnel in IPv6 is 1280
const MIN_IPV6_MTU: u16 = 1280;
const MIN_IPV4_MTU: u16 = 576;
let min_mtu = match params.generic_options.enable_ipv6 {
true => MIN_IPV6_MTU,
false => MIN_IPV4_MTU,
};
let mtu = std::cmp::max(
mtu.checked_sub(WIREGUARD_HEADER_SIZE).unwrap_or(min_mtu),
min_mtu,
);
let upstream_mtu = std::cmp::min(MAX_TUNNEL_MTU, mtu);
params.options.mtu = Some(upstream_mtu);
}

#[cfg(target_os = "linux")]
async fn assign_mtu(
route_manager: &RouteManagerHandle,
params: &mut wireguard_types::TunnelParameters,
) {
// It is fine to leave the params untouched if getting the mtu for the route fails. In that
// case we will do our regular default.
if let Ok(mtu) = route_manager
.get_mtu_for_route(params.connection.peer.endpoint.ip())
.await
{
Self::set_mtu(params, mtu);
}
}

#[cfg(not(target_os = "android"))]
async fn start_openvpn_tunnel<L>(
config: &openvpn_types::TunnelParameters,
Expand Down
4 changes: 2 additions & 2 deletions talpid-core/src/tunnel_state_machine/connecting_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ impl ConnectingState {
let (tunnel_close_tx, tunnel_close_rx) = oneshot::channel();
let (tunnel_close_event_tx, tunnel_close_event_rx) = oneshot::channel();

let tunnel_parameters = parameters.clone();
let mut tunnel_parameters = parameters.clone();

tokio::task::spawn_blocking(move || {
let start = Instant::now();
Expand All @@ -141,7 +141,7 @@ impl ConnectingState {

let block_reason = match TunnelMonitor::start(
runtime,
&tunnel_parameters,
&mut tunnel_parameters,
&log_dir,
&resource_dir,
on_tunnel_event,
Expand Down

0 comments on commit e80115d

Please sign in to comment.