Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add TCP keepalive for MySQL and PostgresSQL. #3559

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions sqlx-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ hashlink = "0.9.0"
indexmap = "2.0"
event-listener = "5.2.0"
hashbrown = "0.14.5"
socket2 = { version = "0.5.7", features = ["all"] }

[dev-dependencies]
sqlx = { workspace = true, features = ["postgres", "sqlite", "mysql", "migrate", "macros", "time", "uuid"] }
Expand Down
3 changes: 2 additions & 1 deletion sqlx-core/src/net/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@ mod socket;
pub mod tls;

pub use socket::{
connect_tcp, connect_uds, BufferedSocket, Socket, SocketIntoBox, WithSocket, WriteBuffer,
connect_tcp, connect_uds, BufferedSocket, Socket, SocketIntoBox, TcpKeepalive, WithSocket,
WriteBuffer,
};
28 changes: 25 additions & 3 deletions sqlx-core/src/net/socket/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@ use bytes::BufMut;
use futures_core::ready;

pub use buffered::{BufferedSocket, WriteBuffer};
pub use tcp_keepalive::TcpKeepalive;

use crate::io::ReadBuf;

mod buffered;
mod tcp_keepalive;

pub trait Socket: Send + Sync + Unpin + 'static {
fn try_read(&mut self, buf: &mut dyn ReadBuf) -> io::Result<usize>;
Expand Down Expand Up @@ -186,6 +188,7 @@ pub async fn connect_tcp<Ws: WithSocket>(
host: &str,
port: u16,
with_socket: Ws,
keepalive: Option<&TcpKeepalive>,
) -> crate::Result<Ws::Output> {
// IPv6 addresses in URLs will be wrapped in brackets and the `url` crate doesn't trim those.
let host = host.trim_matches(&['[', ']'][..]);
Expand All @@ -197,6 +200,13 @@ pub async fn connect_tcp<Ws: WithSocket>(
let stream = TcpStream::connect((host, port)).await?;
stream.set_nodelay(true)?;

// set tcp keepalive
if let Some(keepalive) = keepalive {
let keepalive = keepalive.socket2();
let sock_ref = socket2::SockRef::from(&stream);
sock_ref.set_tcp_keepalive(&keepalive)?;
}

return Ok(with_socket.with_socket(stream));
}

Expand All @@ -216,9 +226,21 @@ pub async fn connect_tcp<Ws: WithSocket>(
s.get_ref().set_nodelay(true)?;
Ok(s)
});
match stream {
Ok(stream) => return Ok(with_socket.with_socket(stream)),
Err(e) => last_err = Some(e),
let stream = match stream {
Ok(stream) => stream,
Err(e) => {
last_err = Some(e);
continue;
}
};
// set tcp keepalive
if let Some(keepalive) = keepalive {
let keepalive = keepalive.socket2();
let sock_ref = socket2::SockRef::from(&stream);
match sock_ref.set_tcp_keepalive(&keepalive) {
Ok(_) => return Ok(with_socket.with_socket(stream)),
Err(e) => last_err = Some(e),
}
}
}

Expand Down
244 changes: 244 additions & 0 deletions sqlx-core/src/net/socket/tcp_keepalive.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
use std::time::Duration;

/// Configures a socket's TCP keepalive parameters.
#[derive(Debug, Clone, Copy)]
pub struct TcpKeepalive {
#[cfg_attr(
any(target_os = "openbsd", target_os = "haiku", target_os = "vita"),
allow(dead_code)
)]
time: Option<Duration>,
#[cfg(not(any(
target_os = "openbsd",
target_os = "redox",
target_os = "solaris",
target_os = "nto",
target_os = "espidf",
target_os = "vita",
target_os = "haiku",
)))]
interval: Option<Duration>,
#[cfg(not(any(
target_os = "openbsd",
target_os = "redox",
target_os = "solaris",
target_os = "windows",
target_os = "nto",
target_os = "espidf",
target_os = "vita",
target_os = "haiku",
)))]
retries: Option<u32>,
}

impl TcpKeepalive {
/// Returns a new, empty set of TCP keepalive parameters.
/// The unset parameters will use OS-defined defaults.
pub const fn new() -> TcpKeepalive {
TcpKeepalive {
time: None,
#[cfg(not(any(
target_os = "openbsd",
target_os = "redox",
target_os = "solaris",
target_os = "nto",
target_os = "espidf",
target_os = "vita",
target_os = "haiku",
)))]
interval: None,
#[cfg(not(any(
target_os = "openbsd",
target_os = "redox",
target_os = "solaris",
target_os = "windows",
target_os = "nto",
target_os = "espidf",
target_os = "vita",
target_os = "haiku",
)))]
retries: None,
}
}

/// Set the amount of time after which TCP keepalive probes will be sent on
/// idle connections.
///
/// This will set `TCP_KEEPALIVE` on macOS and iOS, and
/// `TCP_KEEPIDLE` on all other Unix operating systems, except
/// OpenBSD and Haiku which don't support any way to set this
/// option. On Windows, this sets the value of the `tcp_keepalive`
/// struct's `keepalivetime` field.
///
/// Some platforms specify this value in seconds, so sub-second
/// specifications may be omitted.
pub const fn with_time(self, time: Duration) -> Self {
Self {
time: Some(time),
..self
}
}

/// Set the value of the `TCP_KEEPINTVL` option. On Windows, this sets the
/// value of the `tcp_keepalive` struct's `keepaliveinterval` field.
///
/// Sets the time interval between TCP keepalive probes.
///
/// Some platforms specify this value in seconds, so sub-second
/// specifications may be omitted.
#[cfg(any(
target_os = "android",
target_os = "dragonfly",
target_os = "freebsd",
target_os = "fuchsia",
target_os = "illumos",
target_os = "ios",
target_os = "linux",
target_os = "macos",
target_os = "netbsd",
target_os = "tvos",
target_os = "watchos",
target_os = "windows",
))]
#[cfg_attr(
docsrs,
doc(cfg(any(
target_os = "android",
target_os = "dragonfly",
target_os = "freebsd",
target_os = "fuchsia",
target_os = "illumos",
target_os = "ios",
target_os = "linux",
target_os = "macos",
target_os = "netbsd",
target_os = "tvos",
target_os = "watchos",
target_os = "windows",
)))
)]
pub const fn with_interval(self, interval: Duration) -> Self {
Self {
interval: Some(interval),
..self
}
}

/// Set the value of the `TCP_KEEPCNT` option.
///
/// Set the maximum number of TCP keepalive probes that will be sent before
/// dropping a connection, if TCP keepalive is enabled on this socket.
///
/// This setter has no effect on Windows.
#[cfg(all(any(
target_os = "android",
target_os = "dragonfly",
target_os = "freebsd",
target_os = "fuchsia",
target_os = "illumos",
target_os = "ios",
target_os = "linux",
target_os = "macos",
target_os = "netbsd",
target_os = "tvos",
target_os = "watchos",
)))]
#[cfg_attr(
docsrs,
doc(cfg(all(any(
target_os = "android",
target_os = "dragonfly",
target_os = "freebsd",
target_os = "fuchsia",
target_os = "illumos",
target_os = "ios",
target_os = "linux",
target_os = "macos",
target_os = "netbsd",
target_os = "tvos",
target_os = "watchos",
))))
)]
pub const fn with_retries(self, retries: u32) -> Self {
Self {
retries: Some(retries),
..self
}
}

/// Convert `TcpKeepalive` to `socket2::TcpKeepalive`.
#[doc(hidden)]
pub(super) const fn socket2(self) -> socket2::TcpKeepalive {
let mut ka = socket2::TcpKeepalive::new();
if let Some(time) = self.time {
ka = ka.with_time(time);
}
#[cfg(any(
target_os = "android",
target_os = "dragonfly",
target_os = "freebsd",
target_os = "fuchsia",
target_os = "illumos",
target_os = "ios",
target_os = "linux",
target_os = "macos",
target_os = "netbsd",
target_os = "tvos",
target_os = "watchos",
target_os = "windows",
))]
#[cfg_attr(
docsrs,
doc(cfg(any(
target_os = "android",
target_os = "dragonfly",
target_os = "freebsd",
target_os = "fuchsia",
target_os = "illumos",
target_os = "ios",
target_os = "linux",
target_os = "macos",
target_os = "netbsd",
target_os = "tvos",
target_os = "watchos",
target_os = "windows",
)))
)]
if let Some(interval) = self.interval {
ka = ka.with_interval(interval);
}
#[cfg(all(any(
target_os = "android",
target_os = "dragonfly",
target_os = "freebsd",
target_os = "fuchsia",
target_os = "illumos",
target_os = "ios",
target_os = "linux",
target_os = "macos",
target_os = "netbsd",
target_os = "tvos",
target_os = "watchos",
)))]
#[cfg_attr(
docsrs,
doc(cfg(all(any(
target_os = "android",
target_os = "dragonfly",
target_os = "freebsd",
target_os = "fuchsia",
target_os = "illumos",
target_os = "ios",
target_os = "linux",
target_os = "macos",
target_os = "netbsd",
target_os = "tvos",
target_os = "watchos",
))))
)]
if let Some(retries) = self.retries {
ka = ka.with_retries(retries);
}
ka
}
}
10 changes: 9 additions & 1 deletion sqlx-mysql/src/connection/establish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,15 @@ impl MySqlConnection {

let handshake = match &options.socket {
Some(path) => crate::net::connect_uds(path, do_handshake).await?,
None => crate::net::connect_tcp(&options.host, options.port, do_handshake).await?,
None => {
crate::net::connect_tcp(
&options.host,
options.port,
do_handshake,
options.tcp_keep_alive.as_ref(),
)
.await?
}
};

let stream = handshake.await?;
Expand Down
Loading