Skip to content

Commit

Permalink
Add async feature implemented with tokio and rustls, make it default
Browse files Browse the repository at this point in the history
  • Loading branch information
moparisthebest committed Jul 23, 2020
1 parent 225d71d commit f3a78b2
Show file tree
Hide file tree
Showing 10 changed files with 992 additions and 229 deletions.
437 changes: 429 additions & 8 deletions Cargo.lock

Large diffs are not rendered by default.

10 changes: 9 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,18 @@ include = [
]

[features]
default = []
default = ["async"]
tls = ["openssl"]
openssl_vendored = ["openssl/vendored"]
verbose = []
async = ["tokio", "tokio-rustls", "ring", "base64"]

[dependencies]
# only for non-async build with TLS support
openssl = { version = "0.10.26", optional = true }
# the rest of these are only required for async build
tokio = { version = "0.2", features = [ "macros", "net", "udp", "io-std", "io-util", "rt-threaded" ], optional = true }
tokio-rustls = { version = "0.14", features = ["dangerous_configuration"], optional = true }
# probably should try to keep ring the exact same version as rustls, same features too
ring = { version = "0.16.11", optional = true }
base64 = { version = "0.12.3", optional = true }
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,10 @@ Binaries:

Building:

- `cargo build --release` - minimal build without TLS support, no dependencies
- `cargo build --release --feature tls` - links to system openssl
- `cargo build --release --feature openssl_vendored` - compiles vendored openssl and link to it
- `cargo build --release` - async build with TLS support supplied by rustls
- `cargo build --release --no-default-features ` - minimal build without TLS support, no dependencies
- `cargo build --release --no-default-features --feature tls` - links to system openssl
- `cargo build --release --no-default-features --feature openssl_vendored` - compiles vendored openssl and link to it

Testing:

Expand Down
304 changes: 304 additions & 0 deletions src/asyncmod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,304 @@

use tokio::net::UdpSocket;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::runtime::Runtime;
use tokio_rustls::webpki::DNSNameRef;

use crate::error;
use crate::error::Result;
use crate::*;

pub struct TcpUdpPipe<T: AsyncReadExt + AsyncWriteExt + std::marker::Unpin + std::marker::Send + 'static> {
buf: [u8; 2050], // 2048 + 2 for len
tcp_stream: T,
udp_socket: UdpSocket,
}

impl<T: AsyncReadExt + AsyncWriteExt + std::marker::Unpin + std::marker::Send + 'static> TcpUdpPipe<T> {

pub fn new(tcp_stream: T, udp_socket: UdpSocket) -> TcpUdpPipe<T> {
TcpUdpPipe {
tcp_stream,
udp_socket,
buf: [0u8; 2050],
}
}

pub async fn shuffle_after_first_udp(mut self) -> Result<usize> {
let (len, src_addr) = self.udp_socket.recv_from(&mut self.buf[2..]).await?;

println!("first packet from {}, connecting to that", src_addr);
self.udp_socket.connect(src_addr).await?;

send_udp(&mut self.buf, &mut self.tcp_stream, len).await?;

self.shuffle().await
}

pub async fn shuffle(self) -> Result<usize> {
// todo: investigate https://docs.rs/tokio/0.2.22/tokio/net/struct.TcpStream.html#method.into_split
let (mut tcp_rd, mut tcp_wr) = tokio::io::split(self.tcp_stream);
let (mut udp_rd, mut udp_wr) = self.udp_socket.split();
let mut recv_buf = self.buf.clone(); // or zeroed or?

tokio::spawn(async move {
loop {
let len = udp_rd.recv(&mut recv_buf[2..]).await?;
send_udp(&mut recv_buf, &mut tcp_wr, len).await?;
}

// Sometimes, the rust type inferencer needs
// a little help
#[allow(unreachable_code)]
{
unsafe { std::hint::unreachable_unchecked(); }
Ok::<_, error::Error>(())
}
});

let mut send_buf = self.buf.clone(); // or zeroed or?

loop {
tcp_rd.read_exact(&mut send_buf[..2]).await?;
let len = ((send_buf[0] as usize) << 8) + send_buf[1] as usize;
#[cfg(feature = "verbose")]
println!("tcp expecting len: {}", len);
tcp_rd.read_exact(&mut send_buf[..len]).await?;
#[cfg(feature = "verbose")]
println!("tcp got len: {}", len);
udp_wr.send(&send_buf[..len]).await?;
}

#[allow(unreachable_code)]
{
unsafe { std::hint::unreachable_unchecked(); }
Ok(0)
}
}
}

async fn send_udp<T: AsyncWriteExt + std::marker::Unpin + 'static>(buf: &mut [u8; 2050], tcp_stream: &mut T, len: usize) -> Result<()> {
#[cfg(feature = "verbose")]
println!("udp got len: {}", len);

buf[0] = ((len >> 8) & 0xFF) as u8;
buf[1] = (len & 0xFF) as u8;

// todo: tcp_stream.write_all(&buf[..len + 2]).await
Ok(tcp_stream.write_all(&buf[..len + 2]).await?)
// todo: do this? self.tcp_stream.flush()
}

impl ProxyClient {

pub async fn start_async(&self) -> Result<usize> {
let tcp_stream = self.tcp_connect()?;

let udp_socket = self.udp_connect()?;

TcpUdpPipe::new(tokio::net::TcpStream::from_std(tcp_stream).expect("how could this tokio tcp fail?"), UdpSocket::from_std(udp_socket).expect("how could this tokio udp fail?"))
.shuffle_after_first_udp().await
}

pub fn start(&self) -> Result<usize> {
let mut rt = Runtime::new()?;

rt.block_on(async {
self.start_async().await
})
}

pub async fn start_tls_async(&self, hostname: Option<&str>, pinnedpubkey: Option<&str>) -> Result<usize> {
let tcp_stream = self.tcp_connect()?;
let tcp_stream = tokio::net::TcpStream::from_std(tcp_stream).expect("how could this tokio tcp fail?");

use tokio_rustls::{ TlsConnector, rustls::ClientConfig };

let mut config = ClientConfig::new();
config.dangerous().set_certificate_verifier(match pinnedpubkey {
Some(pinnedpubkey) => Arc::new(PinnedpubkeyCertVerifier { pinnedpubkey: pinnedpubkey.to_owned() }),
None => Arc::new(DummyCertVerifier{}),
});

let hostname = match hostname {
Some(hostname) => match DNSNameRef::try_from_ascii_str(hostname) {
Ok(hostname) => hostname,
Err(_) => {
config.enable_sni = false;
DNSNameRef::try_from_ascii_str(&"dummy.hostname").unwrap() // why does rustls ABSOLUTELY REQUIRE this ????
}
},
None => {
config.enable_sni = false;
DNSNameRef::try_from_ascii_str(&"dummy.hostname").unwrap() // why does rustls ABSOLUTELY REQUIRE this ????
}
};
//println!("hostname: {:?}", hostname);

let connector = TlsConnector::from(Arc::new(config));

let tcp_stream= connector.connect(hostname, tcp_stream).await?;

let udp_socket = self.udp_connect()?;

// we want to wait for first udp packet from client first, to set the target to respond to
TcpUdpPipe::new(tcp_stream, UdpSocket::from_std(udp_socket).expect("how could this tokio udp fail?"))
.shuffle_after_first_udp().await
}

pub fn start_tls(&self, hostname: Option<&str>, pinnedpubkey: Option<&str>) -> Result<usize> {
let mut rt = Runtime::new()?;

rt.block_on(async {
self.start_tls_async(hostname, pinnedpubkey).await
})
}
}

use tokio_rustls::rustls;
use tokio_rustls::webpki;

struct DummyCertVerifier;

impl rustls::ServerCertVerifier for DummyCertVerifier {
fn verify_server_cert(&self,
_roots: &rustls::RootCertStore,
_certs: &[rustls::Certificate],
_hostname: webpki::DNSNameRef<'_>,
_ocsp: &[u8]) -> core::result::Result<rustls::ServerCertVerified, rustls::TLSError> {
// verify nothing, subject to MITM
Ok(rustls::ServerCertVerified::assertion())
}
}

struct PinnedpubkeyCertVerifier {
pinnedpubkey: String,
}

impl rustls::ServerCertVerifier for PinnedpubkeyCertVerifier {
fn verify_server_cert(&self,
_roots: &rustls::RootCertStore,
certs: &[rustls::Certificate],
_hostname: webpki::DNSNameRef<'_>,
_ocsp: &[u8]) -> core::result::Result<rustls::ServerCertVerified, rustls::TLSError> {
if certs.is_empty() {
return Err(rustls::TLSError::NoCertificatesPresented);
}
let cert = webpki::trust_anchor_util::cert_der_as_trust_anchor(&certs[0].0)
.map_err(rustls::TLSError::WebPKIError)?;

//println!("spki.len(): {}", cert.spki.len());
//println!("spki: {:?}", cert.spki);
// todo: what is wrong with webpki? it returns *almost* the right answer but missing these leading bytes:
// guess I'll open an issue... (I assume this is some type of algorithm identifying header or something)
let mut pubkey: Vec<u8> = vec![48, 130, 1, 34];
pubkey.extend(cert.spki);

let pubkey = ring::digest::digest(&ring::digest::SHA256, &pubkey);
let pubkey = base64::encode(pubkey);
let pubkey = ["sha256//", &pubkey].join("");

for key in self.pinnedpubkey.split(";") {
if key == pubkey {
return Ok(rustls::ServerCertVerified::assertion());
}
}

Err(rustls::TLSError::General(format!("pubkey '{}' not found in allowed list '{}'", pubkey, self.pinnedpubkey)))
}
}

impl ProxyServer {

pub async fn start_async(&self) -> Result<()> {
let mut listener = tokio::net::TcpListener::bind(&self.tcp_host).await?;
println!("Listening for connections on {}", &self.tcp_host);

loop {
let (stream, _) = listener.accept().await?;
let client_handler = self.client_handler.clone();
tokio::spawn(async move {
client_handler
.handle_client_async(stream).await
.expect("error handling connection");
});
}

#[allow(unreachable_code)]
{
unsafe { std::hint::unreachable_unchecked(); }
Ok(())
}
}

pub fn start(&self) -> Result<()> {
let mut rt = Runtime::new()?;

rt.block_on(async {
self.start_async().await
})
}

pub async fn start_tls_async(&self, tls_key: &str, tls_cert: &str) -> Result<()> {

use std::fs::File;
use std::io::BufReader;
use std::io;
use tokio_rustls::rustls::internal::pemfile::{ certs, pkcs8_private_keys };

let mut tls_key = pkcs8_private_keys(&mut BufReader::new(File::open(tls_key)?))
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid key"))?;
if tls_key.is_empty() {
return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid key"))?;
}
let tls_key = tls_key.remove(0);

let tls_cert = certs(&mut BufReader::new(File::open(tls_cert)?))
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid cert"))?;

let mut config = rustls::ServerConfig::new(rustls::NoClientAuth::new());
config.set_single_cert(tls_cert, tls_key)
.map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(config));

let mut listener = tokio::net::TcpListener::bind(&self.tcp_host).await?;
println!("Listening for TLS connections on {}", &self.tcp_host);

loop {
let (stream, _) = listener.accept().await?;
let client_handler = self.client_handler.clone();
let acceptor = acceptor.clone();

tokio::spawn(async move {
let stream = acceptor.accept(stream).await.expect("failed to wrap with TLS?");

client_handler
.handle_client_async(stream).await
.expect("error handling connection");
});
}

#[allow(unreachable_code)]
{
unsafe { std::hint::unreachable_unchecked(); }
Ok(())
}
}

pub fn start_tls(&self, tls_key: &str, tls_cert: &str) -> Result<()> {
let mut rt = Runtime::new()?;

rt.block_on(async {
self.start_tls_async(tls_key, tls_cert).await
})
}
}

impl ProxyServerClientHandler {

pub async fn handle_client_async<T: AsyncReadExt + AsyncWriteExt + std::marker::Unpin + std::marker::Send + 'static>(&self, tcp_stream: T) -> Result<usize> {
TcpUdpPipe::new(tcp_stream,
UdpSocket::from_std(self.udp_bind()?).expect("how could this tokio udp fail?")
).shuffle().await
}
}
8 changes: 5 additions & 3 deletions src/bin/wireguard-proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@ fn main() {

if args.flag("-V") || args.flag("--version") {
print!("wireguard-proxy {} ", env!("CARGO_PKG_VERSION"));
#[cfg(not(any(feature = "tls", feature = "openssl_vendored")))]
#[cfg(not(any(feature = "tls", feature = "openssl_vendored", feature = "async")))]
println!("TLS support: None");
#[cfg(feature = "openssl_vendored")]
#[cfg(feature = "async")]
println!("TLS support: tokio-rustls");
#[cfg(all(feature = "openssl_vendored", not(feature = "async")))]
println!("TLS support: Static/Vendored OpenSSL");
#[cfg(feature = "tls")]
#[cfg(all(feature = "tls", not(feature = "openssl_vendored"), not(feature = "async")))]
println!("TLS support: System OpenSSL");
return;
}
Expand Down
5 changes: 2 additions & 3 deletions src/error.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use core::result;

use std::error::Error as StdError;

#[cfg(not(any(feature = "async")))]
pub type IoResult<T> = result::Result<T, std::io::Error>;

pub type Result<T> = result::Result<T, Error>;
Expand Down Expand Up @@ -32,7 +31,7 @@ impl std::error::Error for Error {

impl From<std::io::Error> for Error {
fn from(value: std::io::Error) -> Self {
Error::new(value.description())
Error::new(&format!("{}", value))
}
}

Expand Down
Loading

0 comments on commit f3a78b2

Please sign in to comment.