Skip to content

Commit

Permalink
Add mechanism to speedtest the TunnelCommunity
Browse files Browse the repository at this point in the history
  • Loading branch information
egbertbouman committed Mar 25, 2024
1 parent 8f7c2f2 commit c274977
Show file tree
Hide file tree
Showing 4 changed files with 211 additions and 1 deletion.
11 changes: 10 additions & 1 deletion ipv8_rust_tunnels/endpoint.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from collections import UserDict
from typing import TYPE_CHECKING
from typing import Callable, TYPE_CHECKING

if TYPE_CHECKING:
from ipv8.messaging.anonymization.community import TunnelCommunity, TunnelSettings
Expand Down Expand Up @@ -209,3 +209,12 @@ def reset_byte_counters(self) -> None:
"""
self.bytes_up = 0
self.bytes_down = 0

def run_speedtest(self, target_addr: str, associate_port: int, num_packets: int, request_size: int,
response_size: int, timeout_ms: int, window_size: int, callback: Callable) -> None:
"""
Perform a TunnelCommunity speedtest. Connects to an existing UDP associate
port and sends test messages to a given target address.
"""
self.rust_ep.run_speedtest(target_addr, associate_port, num_packets, request_size,
response_size, timeout_ms, window_size, callback)
26 changes: 26 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ mod payload;
mod routing;
mod socket;
mod socks5;
mod speedtest;
mod util;
#[macro_use]
extern crate log;
Expand Down Expand Up @@ -186,6 +187,31 @@ impl Endpoint {
Ok(())
}

fn run_speedtest(
&mut self,
server_addr: String,
associate_port: u16,
num_packets: usize,
request_size: u16,
response_size: u16,
timeout_ms: usize,
window_size: usize,
callback: PyObject,
) -> PyResult<()> {
let settings = self.settings.clone().unwrap().clone();
settings.load().handle.spawn(speedtest::run_speedtest(
server_addr,
associate_port,
num_packets,
request_size,
response_size,
timeout_ms,
window_size,
callback,
));
Ok(())
}

fn set_prefix(&mut self, prefix: &PyBytes) -> PyResult<()> {
if let Some(settings) = &self.settings {
let mut new_settings = TunnelSettings::clone(&settings.load_full());
Expand Down
165 changes: 165 additions & 0 deletions src/speedtest.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
use std::{
collections::HashMap,
io::Cursor,
sync::{Arc, Mutex},
time::Duration,
};

use pyo3::{types::IntoPyDict, PyObject, Python};
use rand::{Rng, RngCore};
use socks5_proto::{Address, UdpHeader};
use tokio::{
net::UdpSocket,
sync::{broadcast, oneshot, OwnedSemaphorePermit, Semaphore},
};

use crate::util;

pub async fn run_speedtest(
server_addr: String,
associate_port: u16,
num_packets: usize,
request_size: u16,
response_size: u16,
timeout_ms: usize,
window_size: usize,
callback: PyObject,
) {
let (socket_tx, socket_rx) = oneshot::channel();

let msg_tx = tokio::sync::broadcast::Sender::new(window_size);
let recv_task = tokio::spawn(receive_and_broadcast(associate_port, socket_tx, msg_tx.clone()));
let socket = match socket_rx.await {
Ok(socket) => socket,
Err(e) => {
error!("Error while receiving speedtest socket: {}. Aborting test.", e);
return;
}
};

let semaphore = Arc::new(Semaphore::new(window_size));
let results = Arc::new(Mutex::new(HashMap::new()));

debug!("Sending packets with window={}", window_size);

for _ in 0..num_packets {
let permit = semaphore.clone().acquire_owned().await.unwrap();
tokio::spawn(send_and_wait(
Address::SocketAddress(server_addr.parse().unwrap()),
socket.clone(),
request_size,
response_size,
timeout_ms,
msg_tx.subscribe(),
results.clone(),
permit,
));
}

debug!("All {} packets sent!", num_packets);
tokio::time::sleep(Duration::from_millis(timeout_ms.try_into().unwrap())).await;
recv_task.abort();
let _ =
Python::with_gil(|py| callback.call1(py, (results.lock().unwrap().clone().into_py_dict(py),)));
}

pub async fn receive_and_broadcast(
associate_port: u16,
socket_tx: oneshot::Sender<Arc<UdpSocket>>,
tx: broadcast::Sender<(u32, usize)>,
) {
let socket = Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap());
let _ = socket.connect(format!("127.0.0.1:{}", associate_port)).await;
let _ = socket_tx.send(socket.clone());

let mut buf = [0; 2048];
loop {
match socket.recv(&mut buf).await {
Ok(n) => {
let mut packet = &buf[..n];
// Strip SOCKS5 header
let Ok(header) = UdpHeader::read_from(&mut Cursor::new(packet)).await else {
error!("Failed to decode SOCKS5 header address");
continue;
};
packet = &packet[header.serialized_len()..];

// Payload format: 'd' + 4-byte transaction ID + the exit IP + payload + 'e'
if packet.len() < 13 {
error!("Dropping packet (response too small");
continue;
}

// Broadcast transaction ID
let tid = u32::from_be_bytes(packet[1..5].try_into().unwrap());
let _ = tx.send((tid, n));
debug!("Broadcasting response for request {}", tid);
}
Err(ref e) if e.kind() == tokio::io::ErrorKind::WouldBlock => {}
Err(e) => error!("Error while reading socket: {}", e),
}
}
}

pub async fn send_and_wait(
target: Address,
socket: Arc<UdpSocket>,
request_size: u16,
response_size: u16,
timeout_ms: usize,
mut rx: broadcast::Receiver<(u32, usize)>,
results: Arc<Mutex<HashMap<u32, [usize; 4]>>>,
_: OwnedSemaphorePermit,
) {
let mut random_data = [0; 2048];
rand::thread_rng().fill_bytes(&mut random_data);
let payload = &random_data[..request_size as usize];
let tid: u32 = rand::thread_rng().gen();
let header = UdpHeader::new(0, target.clone());
let mut socks5_pkt = Vec::with_capacity(header.serialized_len());
if let Err(e) = header.write_to(&mut socks5_pkt).await {
error!("Error while writing SOCKS5 header: {}", e);
return;
};
socks5_pkt.extend_from_slice(&[b'd']);
socks5_pkt.extend_from_slice(&tid.to_be_bytes());
socks5_pkt.extend_from_slice(&response_size.to_be_bytes());
socks5_pkt.extend_from_slice(&payload);
socks5_pkt.extend_from_slice(&[b'e']);

match socket.send(&socks5_pkt).await {
Ok(size) => {
debug!("Sent request {} ({} bytes)", tid, size);
results
.lock()
.unwrap()
.insert(tid, [util::get_time_ms() as usize, size, 0, 0]);
let wait_for_tid = async {
loop {
match rx.recv().await {
Ok((tid_received, n)) => {
debug!("Received {}, looking for {}", tid_received, tid);
if tid_received == tid {
debug!("Received response for request {}", tid);
if let Some(result) = results.lock().unwrap().get_mut(&tid) {
result[2] = util::get_time_ms() as usize;
result[3] = n;
break;
};
}
}
Err(_) => {}
}
}
};
if let Err(_) =
tokio::time::timeout(Duration::from_millis(timeout_ms.try_into().unwrap()), wait_for_tid)
.await
{
warn!("Request {} timedout", tid);
}
}
Err(ref e) if e.kind() == tokio::io::ErrorKind::WouldBlock => {}
Err(e) => error!("Error while writing to socket: {}", e),
};
}
10 changes: 10 additions & 0 deletions src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@ pub fn get_time() -> u64 {
}
}

pub fn get_time_ms() -> u128 {
match SystemTime::now().duration_since(UNIX_EPOCH) {
Ok(time) => time.as_millis(),
Err(error) => {
error!("Failed to get system time: {}", error);
0
}
}
}

pub fn create_socket(addr: SocketAddr) -> Result<Arc<UdpSocket>> {
let socket_std = match std::net::UdpSocket::bind(addr) {
Ok(socket) => {
Expand Down

0 comments on commit c274977

Please sign in to comment.