Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add TCP keepalive for MySQL and PostgresSQL. #3559

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions sqlx-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ hashlink = "0.9.0"
indexmap = "2.0"
event-listener = "5.2.0"
hashbrown = "0.14.5"
socket2 = "0.5.7"

[dev-dependencies]
sqlx = { workspace = true, features = ["postgres", "sqlite", "mysql", "migrate", "macros", "time", "uuid"] }
Expand Down
3 changes: 2 additions & 1 deletion sqlx-core/src/net/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@ mod socket;
pub mod tls;

pub use socket::{
connect_tcp, connect_uds, BufferedSocket, Socket, SocketIntoBox, WithSocket, WriteBuffer,
connect_tcp, connect_uds, BufferedSocket, Socket, SocketIntoBox, TcpKeepalive, WithSocket,
WriteBuffer,
};
106 changes: 103 additions & 3 deletions sqlx-core/src/net/socket/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::io;
use std::path::Path;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;

use bytes::BufMut;
use futures_core::ready;
Expand Down Expand Up @@ -182,10 +183,87 @@ impl<S: Socket + ?Sized> Socket for Box<S> {
}
}

#[derive(Debug, Clone, Copy)]
pub struct TcpKeepalive {
pub time: Option<Duration>,
pub interval: Option<Duration>,
pub retries: Option<u32>,
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we want to encapsulate instead of using socket2::TcpKeepalive in the public api (that is something that @abonander should decide), we should use a newtype wrapper instead of just copying the other code.

Suggested change
pub struct TcpKeepalive {
pub time: Option<Duration>,
pub interval: Option<Duration>,
pub retries: Option<u32>,
}
pub struct TcpKeepalive(socket2::TcpKeepalive)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's reasonable, thx

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good idea. But socket2::TcpKeepalive is not Copy (Sorry I made a mistake before).
So doing so we won't have Copy trait.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting. I don't know how I came up with the idea that it does. Must have misread a part of the code.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I have mixed feelings about using a bespoke struct here.

These builder methods could just be directly on the ConnectOptions structs.

I suppose being able to define TcpKeepalive separately in a const could be nice for reusability, but at the same time, having to import and build a separate type would also be a little annoying if you only use it once.

I think maybe because interval and retries aren't supported on all platforms, we should only expose the tcp_keepalive_time() on the ConnectOptions builders. We can always add the others later.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These fields should also be #[doc(hidden)] or just private.


impl TcpKeepalive {
/// Returns a new, empty set of TCP keepalive parameters.
pub const fn new() -> TcpKeepalive {
xuehaonan27 marked this conversation as resolved.
Show resolved Hide resolved
TcpKeepalive {
time: None,
interval: None,
retries: None,
}
}

/// Set the amount of time after which TCP keepalive probes will be sent on
/// idle connections.
///
/// This will set `TCP_KEEPALIVE` on macOS and iOS, and
/// `TCP_KEEPIDLE` on all other Unix operating systems, except
/// OpenBSD and Haiku which don't support any way to set this
/// option. On Windows, this sets the value of the `tcp_keepalive`
/// struct's `keepalivetime` field.
///
/// Some platforms specify this value in seconds, so sub-second
/// specifications may be omitted.
pub const fn with_time(self, time: Duration) -> Self {
Self {
time: Some(time),
..self
}
}

/// Set the value of the `TCP_KEEPINTVL` option. On Windows, this sets the
/// value of the `tcp_keepalive` struct's `keepaliveinterval` field.
///
/// Sets the time interval between TCP keepalive probes.
///
/// Some platforms specify this value in seconds, so sub-second
/// specifications may be omitted.
pub const fn with_interval(self, interval: Duration) -> Self {
Self {
interval: Some(interval),
..self
}
}

/// Set the value of the `TCP_KEEPCNT` option.
///
/// Set the maximum number of TCP keepalive probes that will be sent before
/// dropping a connection, if TCP keepalive is enabled on this socket.
pub const fn with_retries(self, retries: u32) -> Self {
xuehaonan27 marked this conversation as resolved.
Show resolved Hide resolved
Self {
retries: Some(retries),
..self
}
}

/// Convert `TcpKeepalive` to `socket2::TcpKeepalive`.
pub const fn socket2(self) -> socket2::TcpKeepalive {
xuehaonan27 marked this conversation as resolved.
Show resolved Hide resolved
let mut ka = socket2::TcpKeepalive::new();
if let Some(time) = self.time {
ka = ka.with_time(time);
}
if let Some(interval) = self.interval {
ka = ka.with_interval(interval);
xuehaonan27 marked this conversation as resolved.
Show resolved Hide resolved
}
if let Some(retries) = self.retries {
ka = ka.with_retries(retries);
xuehaonan27 marked this conversation as resolved.
Show resolved Hide resolved
}
ka
}
}

pub async fn connect_tcp<Ws: WithSocket>(
host: &str,
port: u16,
with_socket: Ws,
keepalive: Option<&TcpKeepalive>,
) -> crate::Result<Ws::Output> {
// IPv6 addresses in URLs will be wrapped in brackets and the `url` crate doesn't trim those.
let host = host.trim_matches(&['[', ']'][..]);
Expand All @@ -197,6 +275,13 @@ pub async fn connect_tcp<Ws: WithSocket>(
let stream = TcpStream::connect((host, port)).await?;
stream.set_nodelay(true)?;

// set tcp keepalive
if let Some(keepalive) = keepalive {
let keepalive = keepalive.socket2();
let sock_ref = socket2::SockRef::from(&stream);
sock_ref.set_tcp_keepalive(&keepalive)?;
}

return Ok(with_socket.with_socket(stream));
}

Expand All @@ -216,9 +301,24 @@ pub async fn connect_tcp<Ws: WithSocket>(
s.get_ref().set_nodelay(true)?;
Ok(s)
});
match stream {
Ok(stream) => return Ok(with_socket.with_socket(stream)),
Err(e) => last_err = Some(e),
let stream = match stream {
Ok(stream) => stream,
Err(e) => {
last_err = Some(e);
continue;
}
};
// set tcp keepalive
if let Some(keepalive) = keepalive {
let keepalive = socket2::TcpKeepalive::new()
.with_interval(keepalive.interval)
.with_retries(keepalive.retries)
.with_time(keepalive.time);
let sock_ref = socket2::SockRef::from(&stream);
match sock_ref.set_tcp_keepalive(&keepalive) {
Ok(_) => return Ok(with_socket.with_socket(stream)),
Err(e) => last_err = Some(e),
}
}
}

Expand Down
10 changes: 9 additions & 1 deletion sqlx-mysql/src/connection/establish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,15 @@ impl MySqlConnection {

let handshake = match &options.socket {
Some(path) => crate::net::connect_uds(path, do_handshake).await?,
None => crate::net::connect_tcp(&options.host, options.port, do_handshake).await?,
None => {
crate::net::connect_tcp(
&options.host,
options.port,
do_handshake,
options.tcp_keep_alive.as_ref(),
)
.await?
}
};

let stream = handshake.await?;
Expand Down
10 changes: 9 additions & 1 deletion sqlx-mysql/src/options/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ mod connect;
mod parse;
mod ssl_mode;

use crate::{connection::LogSettings, net::tls::CertificateInput};
use crate::{connection::LogSettings, net::tls::CertificateInput, net::TcpKeepalive};
pub use ssl_mode::MySqlSslMode;

/// Options and flags which can be used to configure a MySQL connection.
Expand Down Expand Up @@ -80,6 +80,7 @@ pub struct MySqlConnectOptions {
pub(crate) no_engine_substitution: bool,
pub(crate) timezone: Option<String>,
pub(crate) set_names: bool,
pub(crate) tcp_keep_alive: Option<TcpKeepalive>,
}

impl Default for MySqlConnectOptions {
Expand Down Expand Up @@ -111,6 +112,7 @@ impl MySqlConnectOptions {
no_engine_substitution: true,
timezone: Some(String::from("+00:00")),
set_names: true,
tcp_keep_alive: None,
}
}

Expand Down Expand Up @@ -403,6 +405,12 @@ impl MySqlConnectOptions {
self.set_names = flag_val;
self
}

/// Sets the TCP keepalive configuration for the connection.
pub fn tcp_keep_alive(mut self, tcp_keep_alive: TcpKeepalive) -> Self {
self.tcp_keep_alive = Some(tcp_keep_alive);
self
}
}

impl MySqlConnectOptions {
Expand Down
10 changes: 9 additions & 1 deletion sqlx-postgres/src/connection/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,15 @@ impl PgStream {
pub(super) async fn connect(options: &PgConnectOptions) -> Result<Self, Error> {
let socket_future = match options.fetch_socket() {
Some(ref path) => net::connect_uds(path, MaybeUpgradeTls(options)).await?,
None => net::connect_tcp(&options.host, options.port, MaybeUpgradeTls(options)).await?,
None => {
net::connect_tcp(
&options.host,
options.port,
MaybeUpgradeTls(options),
options.tcp_keep_alive.as_ref(),
)
.await?
}
};

let socket = socket_future.await?;
Expand Down
10 changes: 9 additions & 1 deletion sqlx-postgres/src/options/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::path::{Path, PathBuf};

pub use ssl_mode::PgSslMode;

use crate::{connection::LogSettings, net::tls::CertificateInput};
use crate::{connection::LogSettings, net::tls::CertificateInput, net::TcpKeepalive};

mod connect;
mod parse;
Expand Down Expand Up @@ -102,6 +102,7 @@ pub struct PgConnectOptions {
pub(crate) application_name: Option<String>,
pub(crate) log_settings: LogSettings,
pub(crate) extra_float_digits: Option<Cow<'static, str>>,
pub(crate) tcp_keep_alive: Option<TcpKeepalive>,
pub(crate) options: Option<String>,
}

Expand Down Expand Up @@ -168,6 +169,7 @@ impl PgConnectOptions {
application_name: var("PGAPPNAME").ok(),
extra_float_digits: Some("2".into()),
log_settings: Default::default(),
tcp_keep_alive: None,
options: var("PGOPTIONS").ok(),
}
}
Expand Down Expand Up @@ -493,6 +495,12 @@ impl PgConnectOptions {
self
}

/// Sets the TCP keepalive configuration for the connection.
pub fn tcp_keep_alive(mut self, tcp_keep_alive: TcpKeepalive) -> Self {
self.tcp_keep_alive = Some(tcp_keep_alive);
self
}

/// Set additional startup options for the connection as a list of key-value pairs.
///
/// # Example
Expand Down