Skip to content

Commit

Permalink
feat: support server-side keep-alive for mysql and pg protocols (#5496)
Browse files Browse the repository at this point in the history
* feat: support server-side keep-alive for mysql and pg protocols

Signed-off-by: Ruihang Xia <[email protected]>

* update config.md

Signed-off-by: Ruihang Xia <[email protected]>

* update config to use humantime for keep-alive configuration

Signed-off-by: Ruihang Xia <[email protected]>

* chore: Update socket2 dependency

Signed-off-by: Ruihang Xia <[email protected]>

---------

Signed-off-by: Ruihang Xia <[email protected]>
  • Loading branch information
waynexia authored Feb 11, 2025
1 parent beb9c0a commit e22aa81
Show file tree
Hide file tree
Showing 15 changed files with 70 additions and 3 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions config/config.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
| `mysql.enable` | Bool | `true` | Whether to enable. |
| `mysql.addr` | String | `127.0.0.1:4002` | The addr to bind the MySQL server. |
| `mysql.runtime_size` | Integer | `2` | The number of server worker threads. |
| `mysql.keep_alive` | String | `0s` | Server-side keep-alive time.<br/>Set to 0 (default) to disable. |
| `mysql.tls` | -- | -- | -- |
| `mysql.tls.mode` | String | `disable` | TLS mode, refer to https://www.postgresql.org/docs/current/libpq-ssl.html<br/>- `disable` (default value)<br/>- `prefer`<br/>- `require`<br/>- `verify-ca`<br/>- `verify-full` |
| `mysql.tls.cert_path` | String | Unset | Certificate file path. |
Expand All @@ -49,6 +50,7 @@
| `postgres.enable` | Bool | `true` | Whether to enable |
| `postgres.addr` | String | `127.0.0.1:4003` | The addr to bind the PostgresSQL server. |
| `postgres.runtime_size` | Integer | `2` | The number of server worker threads. |
| `postgres.keep_alive` | String | `0s` | Server-side keep-alive time.<br/>Set to 0 (default) to disable. |
| `postgres.tls` | -- | -- | PostgresSQL server TLS options, see `mysql.tls` section. |
| `postgres.tls.mode` | String | `disable` | TLS mode. |
| `postgres.tls.cert_path` | String | Unset | Certificate file path. |
Expand Down Expand Up @@ -234,6 +236,7 @@
| `mysql.enable` | Bool | `true` | Whether to enable. |
| `mysql.addr` | String | `127.0.0.1:4002` | The addr to bind the MySQL server. |
| `mysql.runtime_size` | Integer | `2` | The number of server worker threads. |
| `mysql.keep_alive` | String | `0s` | Server-side keep-alive time.<br/>Set to 0 (default) to disable. |
| `mysql.tls` | -- | -- | -- |
| `mysql.tls.mode` | String | `disable` | TLS mode, refer to https://www.postgresql.org/docs/current/libpq-ssl.html<br/>- `disable` (default value)<br/>- `prefer`<br/>- `require`<br/>- `verify-ca`<br/>- `verify-full` |
| `mysql.tls.cert_path` | String | Unset | Certificate file path. |
Expand All @@ -243,6 +246,7 @@
| `postgres.enable` | Bool | `true` | Whether to enable |
| `postgres.addr` | String | `127.0.0.1:4003` | The addr to bind the PostgresSQL server. |
| `postgres.runtime_size` | Integer | `2` | The number of server worker threads. |
| `postgres.keep_alive` | String | `0s` | Server-side keep-alive time.<br/>Set to 0 (default) to disable. |
| `postgres.tls` | -- | -- | PostgresSQL server TLS options, see `mysql.tls` section. |
| `postgres.tls.mode` | String | `disable` | TLS mode. |
| `postgres.tls.cert_path` | String | Unset | Certificate file path. |
Expand Down
6 changes: 6 additions & 0 deletions config/frontend.example.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ enable = true
addr = "127.0.0.1:4002"
## The number of server worker threads.
runtime_size = 2
## Server-side keep-alive time.
## Set to 0 (default) to disable.
keep_alive = "0s"

# MySQL server TLS options.
[mysql.tls]
Expand Down Expand Up @@ -105,6 +108,9 @@ enable = true
addr = "127.0.0.1:4003"
## The number of server worker threads.
runtime_size = 2
## Server-side keep-alive time.
## Set to 0 (default) to disable.
keep_alive = "0s"

## PostgresSQL server TLS options, see `mysql.tls` section.
[postgres.tls]
Expand Down
6 changes: 6 additions & 0 deletions config/standalone.example.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ enable = true
addr = "127.0.0.1:4002"
## The number of server worker threads.
runtime_size = 2
## Server-side keep-alive time.
## Set to 0 (default) to disable.
keep_alive = "0s"

# MySQL server TLS options.
[mysql.tls]
Expand Down Expand Up @@ -109,6 +112,9 @@ enable = true
addr = "127.0.0.1:4003"
## The number of server worker threads.
runtime_size = 2
## Server-side keep-alive time.
## Set to 0 (default) to disable.
keep_alive = "0s"

## PostgresSQL server TLS options, see `mysql.tls` section.
[postgres.tls]
Expand Down
2 changes: 2 additions & 0 deletions src/frontend/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ where
Arc::new(MysqlSpawnConfig::new(
opts.tls.should_force_tls(),
tls_server_config,
opts.keep_alive.as_secs(),
opts.reject_no_database.unwrap_or(false),
)),
);
Expand All @@ -248,6 +249,7 @@ where
ServerSqlQueryHandlerAdapter::arc(instance.clone()),
opts.tls.should_force_tls(),
tls_server_config,
opts.keep_alive.as_secs(),
common_runtime::global_runtime(),
user_provider.clone(),
)) as Box<dyn Server>;
Expand Down
7 changes: 7 additions & 0 deletions src/frontend/src/service_config/mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ pub struct MysqlOptions {
#[serde(default = "Default::default")]
pub tls: TlsOption,
pub reject_no_database: Option<bool>,
/// Server-side keep-alive time.
///
/// Set to 0 (default) to disable.
#[serde(default = "Default::default")]
#[serde(with = "humantime_serde")]
pub keep_alive: std::time::Duration,
}

impl Default for MysqlOptions {
Expand All @@ -33,6 +39,7 @@ impl Default for MysqlOptions {
runtime_size: 2,
tls: TlsOption::default(),
reject_no_database: None,
keep_alive: std::time::Duration::from_secs(0),
}
}
}
7 changes: 7 additions & 0 deletions src/frontend/src/service_config/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ pub struct PostgresOptions {
pub runtime_size: usize,
#[serde(default = "Default::default")]
pub tls: TlsOption,
/// Server-side keep-alive time.
///
/// Set to 0 (default) to disable.
#[serde(default = "Default::default")]
#[serde(with = "humantime_serde")]
pub keep_alive: std::time::Duration,
}

impl Default for PostgresOptions {
Expand All @@ -31,6 +37,7 @@ impl Default for PostgresOptions {
addr: "127.0.0.1:4003".to_string(),
runtime_size: 2,
tls: Default::default(),
keep_alive: std::time::Duration::from_secs(0),
}
}
}
1 change: 1 addition & 0 deletions src/servers/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ notify.workspace = true
object-pool = "0.5"
once_cell.workspace = true
openmetrics-parser = "0.4"
socket2 = "0.5"
# use crates.io version after current revision is merged in next release
# opensrv-mysql = "0.7.0"
opensrv-mysql = { git = "https://github.com/datafuselabs/opensrv", rev = "6bbc3b65e6b19212c4f7fc4f40c20daf6f452deb" }
Expand Down
9 changes: 8 additions & 1 deletion src/servers/src/mysql/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ pub struct MysqlSpawnConfig {
// tls config
force_tls: bool,
tls: Arc<ReloadableTlsServerConfig>,
// keep-alive config
keep_alive_secs: u64,
// other shim config
reject_no_database: bool,
}
Expand All @@ -80,11 +82,13 @@ impl MysqlSpawnConfig {
pub fn new(
force_tls: bool,
tls: Arc<ReloadableTlsServerConfig>,
keep_alive_secs: u64,
reject_no_database: bool,
) -> MysqlSpawnConfig {
MysqlSpawnConfig {
force_tls,
tls,
keep_alive_secs,
reject_no_database,
}
}
Expand Down Expand Up @@ -218,7 +222,10 @@ impl Server for MysqlServer {
}

async fn start(&self, listening: SocketAddr) -> Result<SocketAddr> {
let (stream, addr) = self.base_server.bind(listening).await?;
let (stream, addr) = self
.base_server
.bind(listening, self.spawn_config.keep_alive_secs)
.await?;
let io_runtime = self.base_server.io_runtime();

let join_handle = common_runtime::spawn_global(self.accept(io_runtime, stream));
Expand Down
8 changes: 7 additions & 1 deletion src/servers/src/postgres/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ pub struct PostgresServer {
base_server: BaseTcpServer,
make_handler: Arc<MakePostgresServerHandler>,
tls_server_config: Arc<ReloadableTlsServerConfig>,
keep_alive_secs: u64,
}

impl PostgresServer {
Expand All @@ -43,6 +44,7 @@ impl PostgresServer {
query_handler: ServerSqlQueryHandlerRef,
force_tls: bool,
tls_server_config: Arc<ReloadableTlsServerConfig>,
keep_alive_secs: u64,
io_runtime: Runtime,
user_provider: Option<UserProviderRef>,
) -> PostgresServer {
Expand All @@ -58,6 +60,7 @@ impl PostgresServer {
base_server: BaseTcpServer::create_server("Postgres", io_runtime),
make_handler,
tls_server_config,
keep_alive_secs,
}
}

Expand Down Expand Up @@ -116,7 +119,10 @@ impl Server for PostgresServer {
}

async fn start(&self, listening: SocketAddr) -> Result<SocketAddr> {
let (stream, addr) = self.base_server.bind(listening).await?;
let (stream, addr) = self
.base_server
.bind(listening, self.keep_alive_secs)
.await?;

let io_runtime = self.base_server.io_runtime();
let join_handle = common_runtime::spawn_global(self.accept(io_runtime, stream));
Expand Down
16 changes: 15 additions & 1 deletion src/servers/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ impl AcceptTask {
&mut self,
addr: SocketAddr,
name: &str,
keep_alive_secs: u64,
) -> Result<(Abortable<TcpListenerStream>, SocketAddr)> {
match self.abort_registration.take() {
Some(registration) => {
Expand All @@ -157,6 +158,15 @@ impl AcceptTask {
let addr = listener.local_addr()?;
info!("{name} server started at {addr}");

// set keep-alive
if keep_alive_secs > 0 {
let socket_ref = socket2::SockRef::from(&listener);
let keep_alive = socket2::TcpKeepalive::new()
.with_time(std::time::Duration::from_secs(keep_alive_secs))
.with_interval(std::time::Duration::from_secs(keep_alive_secs));
socket_ref.set_tcp_keepalive(&keep_alive)?;
}

let stream = TcpListenerStream::new(listener);
let stream = Abortable::new(stream, registration);
Ok((stream, addr))
Expand Down Expand Up @@ -205,12 +215,16 @@ impl BaseTcpServer {
task.shutdown(&self.name).await
}

/// Bind the server to the given address and set the keep-alive time.
///
/// If `keep_alive_secs` is 0, the keep-alive will not be set.
pub(crate) async fn bind(
&self,
addr: SocketAddr,
keep_alive_secs: u64,
) -> Result<(Abortable<TcpListenerStream>, SocketAddr)> {
let mut task = self.accept_task.lock().await;
task.bind(addr, &self.name).await
task.bind(addr, &self.name, keep_alive_secs).await
}

pub(crate) async fn start_with(&self, join_handle: JoinHandle<()>) -> Result<()> {
Expand Down
1 change: 1 addition & 0 deletions src/servers/tests/mysql/mysql_server_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ fn create_mysql_server(table: TableRef, opts: MysqlOpts<'_>) -> Result<Box<dyn S
Arc::new(MysqlSpawnConfig::new(
opts.tls.should_force_tls(),
tls_server_config,
0,
opts.reject_no_database,
)),
))
Expand Down
1 change: 1 addition & 0 deletions src/servers/tests/postgres/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ fn create_postgres_server(
instance,
tls.should_force_tls(),
tls_server_config,
0,
io_runtime,
user_provider,
)))
Expand Down
2 changes: 2 additions & 0 deletions tests-integration/src/test_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,7 @@ pub async fn setup_mysql_server_with_user_provider(
ReloadableTlsServerConfig::try_new(opts.tls.clone())
.expect("Failed to load certificates and keys"),
),
0,
opts.reject_no_database.unwrap_or(false),
)),
));
Expand Down Expand Up @@ -641,6 +642,7 @@ pub async fn setup_pg_server_with_user_provider(
ServerSqlQueryHandlerAdapter::arc(fe_instance_ref),
opts.tls.should_force_tls(),
tls_server_config,
0,
runtime,
user_provider,
)) as Box<dyn Server>);
Expand Down
2 changes: 2 additions & 0 deletions tests-integration/tests/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -923,6 +923,7 @@ watch = false
enable = true
addr = "127.0.0.1:4002"
runtime_size = 2
keep_alive = "0s"
[mysql.tls]
mode = "disable"
Expand All @@ -934,6 +935,7 @@ watch = false
enable = true
addr = "127.0.0.1:4003"
runtime_size = 2
keep_alive = "0s"
[postgres.tls]
mode = "disable"
Expand Down

0 comments on commit e22aa81

Please sign in to comment.