Implement Happy Eyeballs connection RFC #718

Closed · wants to merge 2 commits
163 changes: 163 additions & 0 deletions src/eyeballs.rs
@@ -0,0 +1,163 @@
//! A Happy Eyeballs RFC implementation
//!
//! Races interleaved IPv4 and IPv6 connections to provide the fastest connection
//! in cases where certain addresses or address families might be blocked, broken, or slow.
//! (See <https://datatracker.ietf.org/doc/html/rfc8305>)
//!
//! ureq strives for simplicity, and avoids spawning threads where it can,
//! but - like with SOCKS - there's no way around it here.
//! Some mini internal async executor
//! (discussed in <https://github.com/algesten/ureq/issues/535#issuecomment-1229433311>)
//! wouldn't help - `connect()` is a blocking syscall with no non-blocking alternative.
//! (Big async runtimes like Tokio "solve" this problem by keeping a pool of OS threads
//! around for just these sorts of blocking calls.)
//! We _could_ have some thread pool (a la rayon) to avoid spawning threads
//! on each connection attempt, but spawning a few threads is a cheap operation
//! compared to everything else going on here.
//! (DNS resolution, handshaking across the Internet...)
//!
//! Much of this implementation was inspired by attohttpc's:
//! <https://github.com/sbstp/attohttpc/blob/master/src/happy.rs>

use std::{
io,
iter::FusedIterator,
net::{SocketAddr, TcpStream},
sync::mpsc::{channel, RecvTimeoutError},
thread,
time::Instant,
};

use log::debug;

use crate::timeout::{io_err_timeout, time_until_deadline};

const TIMEOUT_MSG: &str = "timed out connecting";

pub fn connect(
netloc: String,
addrs: &[SocketAddr],
deadline: Option<Instant>,
) -> io::Result<(TcpStream, SocketAddr)> {
assert!(!addrs.is_empty());

// No racing needed if there's a single address.
if let [single] = addrs {
return single_connection(&netloc, *single, deadline);
}

// Interleave IPv4 and IPv6 addresses.
let fours = addrs.iter().filter(|a| matches!(a, SocketAddr::V4(_)));
let sixes = addrs.iter().filter(|a| matches!(a, SocketAddr::V6(_)));
let sorted = interleave(fours, sixes);

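// Worker threads send their connection results (or errors) back over this channel.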
let (tx, rx) = channel();
let mut first_error = None;

// Race connections!
// The RFC says:
//
// 1. Not to start connections "simultaneously", but since `connect()`
// syscalls don't return until they've connected or timed out,
// we don't have a way to start an attempt without blocking until it finishes.
// (And if we did that, we wouldn't be racing!)
//
// 2. Once we have a successful connection, all other attempts should be cancelled.
// Doing so would require a lot of nasty (and platform-specific) signal handling,
// as it's the only way to interrupt `connect()`.
for s in sorted {
// Instead, make a best effort to not start new connections if we've got one already.
if let Ok(resp) = rx.try_recv() {
match resp {
Ok(c) => return Ok(c),
Err(e) => {
let _ = first_error.get_or_insert(e);
}
}
}

let tx2 = tx.clone();
let nl2 = netloc.clone();
let s2 = *s;
thread::spawn(move || {
// If the receiver was dropped, someone else already won the race.
let _ = tx2.send(single_connection(&nl2, s2, deadline));
});
}
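// Drop the original sender so the receiver disconnects once every worker thread has finished.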
drop(tx);

const UNREACHABLE_MSG: &str =
"Unreachable: All Happy Eyeballs connections failed, but no error";

if let Some(d) = deadline {
// Wait for a successful connection, or for us to run out of time
loop {
let timeout = time_until_deadline(d, TIMEOUT_MSG)?;
match rx.recv_timeout(timeout) {
Ok(Ok(c)) => return Ok(c),
Ok(Err(e)) => {
let _ = first_error.get_or_insert(e);
}
Err(RecvTimeoutError::Timeout) => {
return Err(io_err_timeout(TIMEOUT_MSG.to_string()))
}
// If all the connecting threads hung up and none succeeded,
// return the first error.
Err(RecvTimeoutError::Disconnected) => {
return Err(first_error.expect(UNREACHABLE_MSG))
}
};
}
} else {
// If there's no deadline, just wait around.
let connections = rx.iter();
for c in connections {
match c {
Ok(c) => return Ok(c),
Err(e) => {
let _ = first_error.get_or_insert(e);
}
}
}
// If we got here, everyone failed. Return the first error.
Err(first_error.expect(UNREACHABLE_MSG))
}
}

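/// Make a single connection attempt to `addr`, bounded by the time remaining until `deadline` (if any).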
fn single_connection(
netloc: &str,
addr: SocketAddr,
deadline: Option<Instant>,
) -> io::Result<(TcpStream, SocketAddr)> {
debug!("connecting to {} at {}", netloc, addr);
if let Some(d) = deadline {
let timeout = time_until_deadline(d, TIMEOUT_MSG)?;
Ok((TcpStream::connect_timeout(&addr, timeout)?, addr))
} else {
Ok((TcpStream::connect(addr)?, addr))
}
}

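/// Alternate items from `left` and `right`; once one side is exhausted, yield the remainder of the other.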
fn interleave<T, A, B>(mut left: A, mut right: B) -> impl Iterator<Item = T>
where
A: FusedIterator<Item = T>,
B: FusedIterator<Item = T>,
{
let mut last_right = None;

std::iter::from_fn(move || {
if let Some(r) = last_right.take() {
return Some(r);
}

match (left.next(), right.next()) {
(Some(l), Some(r)) => {
last_right = Some(r);
Some(l)
}
(Some(l), None) => Some(l),
(None, Some(r)) => Some(r),
(None, None) => None,
}
})
}
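For illustration only (not part of the diff above): a minimal sketch of how `interleave` orders candidate addresses, written as a hypothetical test in the same module.

#[cfg(test)]
mod tests {
    use super::interleave;

    #[test]
    fn alternates_then_drains_the_longer_side() {
        let left = [1, 3, 5];
        let right = [2, 4];
        let merged: Vec<i32> = interleave(left.iter().copied(), right.iter().copied()).collect();
        // Alternates while both iterators have items, then yields the leftovers.
        assert_eq!(merged, vec![1, 2, 3, 4, 5]);
    }
}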
2 changes: 2 additions & 0 deletions src/lib.rs
@@ -358,6 +358,7 @@ mod agent;
mod body;
mod chunked;
mod error;
mod eyeballs;
mod header;
mod middleware;
mod pool;
@@ -366,6 +367,7 @@ mod request;
mod resolve;
mod response;
mod stream;
mod timeout;
mod unit;

// rustls is our default tls engine. If the feature is on, it will be
2 changes: 1 addition & 1 deletion src/response.rs
@@ -530,7 +530,7 @@ impl Response {
/// ```
#[cfg(feature = "json")]
pub fn into_json<T: DeserializeOwned>(self) -> io::Result<T> {
use crate::stream::io_err_timeout;
use crate::timeout::io_err_timeout;

let reader = self.into_reader();
serde_json::from_reader(reader).map_err(|e| {
94 changes: 27 additions & 67 deletions src/stream.rs
@@ -2,7 +2,6 @@ use log::debug;
use std::io::{self, BufRead, BufReader, Read, Write};
use std::net::SocketAddr;
use std::net::TcpStream;
use std::ops::Div;
use std::time::Duration;
use std::time::Instant;
use std::{fmt, io::Cursor};
@@ -12,8 +11,10 @@ use socks::{TargetAddr, ToTargetAddr};

use crate::chunked::Decoder as ChunkDecoder;
use crate::error::ErrorKind;
use crate::eyeballs;
use crate::pool::{PoolKey, PoolReturner};
use crate::proxy::Proxy;
use crate::timeout::{io_err_timeout, time_until_deadline};
use crate::unit::Unit;
use crate::Response;
use crate::{error::Error, proxy::Proto};
@@ -83,7 +84,7 @@ impl From<DeadlineStream> for Stream {
impl BufRead for DeadlineStream {
fn fill_buf(&mut self) -> io::Result<&[u8]> {
if let Some(deadline) = self.deadline {
let timeout = time_until_deadline(deadline)?;
let timeout = time_until_deadline(deadline, "timed out reading response")?;
if let Some(socket) = self.stream.socket() {
socket.set_read_timeout(Some(timeout))?;
socket.set_write_timeout(Some(timeout))?;
Expand Down Expand Up @@ -130,20 +131,6 @@ impl Read for DeadlineStream {
}
}

// If the deadline is in the future, return the remaining time until
// then. Otherwise return a TimedOut error.
fn time_until_deadline(deadline: Instant) -> io::Result<Duration> {
let now = Instant::now();
match deadline.checked_duration_since(now) {
None => Err(io_err_timeout("timed out reading response".to_string())),
Some(duration) => Ok(duration),
}
}

pub(crate) fn io_err_timeout(error: String) -> io::Error {
io::Error::new(io::ErrorKind::TimedOut, error)
}

#[derive(Debug)]
pub(crate) struct ReadOnlyStream(Cursor<Vec<u8>>);

@@ -348,6 +335,7 @@ pub(crate) fn connect_host(
hostname: &str,
port: u16,
) -> Result<(TcpStream, SocketAddr), Error> {
const TIMEOUT_MSG: &str = "timed out connecting";
let connect_deadline: Option<Instant> =
if let Some(timeout_connect) = unit.agent.config.timeout_connect {
Instant::now().checked_add(timeout_connect)
@@ -373,71 +361,43 @@

let proto = proxy.as_ref().map(|proxy| proxy.proto);

let mut any_err = None;
let mut any_stream_and_addr = None;
// Find the first sock_addr that accepts a connection
let multiple_addrs = sock_addrs.len() > 1;

for sock_addr in sock_addrs {
// ensure connect timeout or overall timeout aren't yet hit.
let timeout = match connect_deadline {
Some(deadline) => {
let mut deadline = time_until_deadline(deadline)?;
if multiple_addrs {
deadline = deadline.div(2);
}
Some(deadline)
}
None => None,
};

debug!("connecting to {} at {}", netloc, &sock_addr);

// connect with a configured timeout.
#[allow(clippy::unnecessary_unwrap)]
let stream = if proto.is_some() && Some(Proto::HTTP) != proto {
connect_socks(
let (mut stream, remote_addr) = if proto.is_some() && Some(Proto::HTTP) != proto {
// SOCKS proxy connections.
// Don't mix that with Happy Eyeballs
// (where we race multiple connections and take the fastest)
// since we'd be repeatedly connecting to the same proxy server.
let mut stream_and_addr_result = None;
// Find the first sock_addr that accepts a connection
for sock_addr in sock_addrs {
// ensure connect timeout or overall timeout aren't yet hit.
debug!("connecting to {} at {}", netloc, &sock_addr);

// connect with a configured timeout.
#[allow(clippy::unnecessary_unwrap)]
let stream = connect_socks(
unit,
proxy.clone().unwrap(),
connect_deadline,
sock_addr,
hostname,
port,
proto.unwrap(),
)
} else if let Some(timeout) = timeout {
TcpStream::connect_timeout(&sock_addr, timeout)
} else {
TcpStream::connect(sock_addr)
};

if let Ok(stream) = stream {
any_stream_and_addr = Some((stream, sock_addr));
break;
} else if let Err(err) = stream {
any_err = Some(err);
);
stream_and_addr_result = Some(stream.map(|s| (s, sock_addr)));
}
}

let (mut stream, remote_addr) = if let Some(stream_and_addr) = any_stream_and_addr {
stream_and_addr
} else if let Some(e) = any_err {
return Err(ErrorKind::ConnectionFailed.msg("Connect error").src(e));
stream_and_addr_result.expect("unreachable: connected to IPs, but no result")
} else {
panic!("shouldn't happen: failed to connect to all IPs, but no error");
};
eyeballs::connect(netloc, &sock_addrs, connect_deadline)
}
.map_err(|e| ErrorKind::ConnectionFailed.msg("Connect error").src(e))?;

stream.set_nodelay(unit.agent.config.no_delay)?;

if let Some(deadline) = unit.deadline {
stream.set_read_timeout(Some(time_until_deadline(deadline)?))?;
stream.set_read_timeout(Some(time_until_deadline(deadline, TIMEOUT_MSG)?))?;
stream.set_write_timeout(Some(time_until_deadline(deadline, TIMEOUT_MSG)?))?;
} else {
stream.set_read_timeout(unit.agent.config.timeout_read)?;
}

if let Some(deadline) = unit.deadline {
stream.set_write_timeout(Some(time_until_deadline(deadline)?))?;
} else {
stream.set_write_timeout(unit.agent.config.timeout_write)?;
}

@@ -562,7 +522,7 @@ fn connect_socks(
let (lock, cvar) = &*master_signal;
let done = lock.lock().unwrap();

let timeout_connect = time_until_deadline(deadline)?;
let timeout_connect = time_until_deadline(deadline, "SOCKS proxy timed out connecting")?;
let done_result = cvar.wait_timeout(done, timeout_connect).unwrap();
let done = done_result.0;
if *done {
18 changes: 18 additions & 0 deletions src/timeout.rs
@@ -0,0 +1,18 @@
//! Timeout utilities, mostly used during connecting.

use std::io;
use std::time::{Duration, Instant};

/// If the deadline is in the future, return the remaining time until
/// then. Otherwise return a TimedOut error.
pub fn time_until_deadline<S: Into<String>>(deadline: Instant, error: S) -> io::Result<Duration> {
let now = Instant::now();
match deadline.checked_duration_since(now) {
None => Err(io_err_timeout(error.into())),
Some(duration) => Ok(duration),
}
}

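/// Build an `io::Error` of kind `TimedOut` with the given message.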
pub fn io_err_timeout(error: String) -> io::Error {
io::Error::new(io::ErrorKind::TimedOut, error)
}