From 8ebb10ef04ba3f2ac2e896a46531be76cce7f894 Mon Sep 17 00:00:00 2001 From: Isaac Date: Wed, 6 Nov 2024 21:20:32 -0800 Subject: [PATCH 1/2] add sqlite worker_queue implementation --- Cargo.toml | 59 ++++--- src/bgworker/mod.rs | 67 ++++++- src/bgworker/sqlt.rs | 405 +++++++++++++++++++++++++++++++++++++++++++ src/config.rs | 35 ++++ 4 files changed, 541 insertions(+), 25 deletions(-) create mode 100644 src/bgworker/sqlt.rs diff --git a/Cargo.toml b/Cargo.toml index ea8c63ca5..ad57098a2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,7 +22,15 @@ rust-version.workspace = true # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] -default = ["auth_jwt", "cli", "with-db", "cache_inmem", "bg_redis", "bg_pg"] +default = [ + "auth_jwt", + "cli", + "with-db", + "cache_inmem", + "bg_redis", + "bg_pg", + "bg_sqlt", +] auth_jwt = ["dep:jsonwebtoken"] cli = ["dep:clap"] testing = ["dep:axum-test"] @@ -37,6 +45,7 @@ storage_gcp = ["object_store/gcp"] cache_inmem = ["dep:moka"] bg_redis = ["dep:rusty-sidekiq", "dep:bb8"] bg_pg = ["dep:sqlx", "dep:ulid"] +bg_sqlt = ["dep:sqlx", "dep:ulid"] [dependencies] loco-gen = { version = "0.12.0", path = "./loco-gen" } @@ -48,10 +57,10 @@ colored = "2" sea-orm = { version = "1.1.0", features = [ - "sqlx-postgres", # `DATABASE_DRIVER` feature - "sqlx-sqlite", - "runtime-tokio-rustls", - "macros", + "sqlx-postgres", # `DATABASE_DRIVER` feature + "sqlx-sqlite", + "runtime-tokio-rustls", + "macros", ], optional = true } tokio = { version = "1.33.0", default-features = false } @@ -75,10 +84,10 @@ fs-err = "2.11.0" tera = "1.19.1" heck = "0.4.0" lettre = { version = "0.11.4", default-features = false, features = [ - "builder", - "hostname", - "smtp-transport", - "tokio1-rustls-tls", + "builder", + "hostname", + "smtp-transport", + "tokio1-rustls-tls", ] } include_dir = "0.7.3" thiserror = { workspace = true } @@ -125,9 +134,11 @@ moka = { version = "0.12.7", features = ["sync"], optional = true } tokio-cron-scheduler = { version = "0.11.0", features = ["signal"] } english-to-cron = { version = "0.1.2" } +# bg_sqlt: sqlite workers # bg_pg: postgres workers sqlx = { version = "0.8.2", default-features = false, features = [ - "postgres", + "postgres", + "sqlite", ], optional = true } ulid = { version = "1", optional = true } @@ -147,26 +158,26 @@ async-trait = { version = "0.1.74" } axum = { version = "0.7.5", features = ["macros"] } tower = "0.4" tower-http = { version = "0.6.1", features = [ - "trace", - "catch-panic", - "timeout", - "add-extension", - "cors", - "fs", - "set-header", - "compression-full", + "trace", + "catch-panic", + "timeout", + "add-extension", + "cors", + "fs", + "set-header", + "compression-full", ] } [dependencies.sea-orm-migration] optional = true version = "1.0.0" features = [ - # Enable at least one `ASYNC_RUNTIME` and `DATABASE_DRIVER` feature if you want to run migration via CLI. - # View the list of supported features at https://www.sea-ql.org/SeaORM/docs/install-and-config/database-and-async-runtime. - # e.g. - "runtime-tokio-rustls", # `ASYNC_RUNTIME` feature - "sqlx-postgres", # `DATABASE_DRIVER` feature - "sqlx-sqlite", + # Enable at least one `ASYNC_RUNTIME` and `DATABASE_DRIVER` feature if you want to run migration via CLI. + # View the list of supported features at https://www.sea-ql.org/SeaORM/docs/install-and-config/database-and-async-runtime. + # e.g. + "runtime-tokio-rustls", # `ASYNC_RUNTIME` feature + "sqlx-postgres", # `DATABASE_DRIVER` feature + "sqlx-sqlite", ] [package.metadata.docs.rs] diff --git a/src/bgworker/mod.rs b/src/bgworker/mod.rs index 0e5754d42..85b6edfc3 100644 --- a/src/bgworker/mod.rs +++ b/src/bgworker/mod.rs @@ -8,10 +8,15 @@ use tracing::{debug, error}; pub mod pg; #[cfg(feature = "bg_redis")] pub mod skq; +#[cfg(feature = "bg_sqlt")] +pub mod sqlt; use crate::{ app::AppContext, - config::{self, Config, PostgresQueueConfig, QueueConfig, RedisQueueConfig, WorkerMode}, + config::{ + self, Config, PostgresQueueConfig, QueueConfig, RedisQueueConfig, SqliteQueueConfig, + WorkerMode, + }, Error, Result, }; @@ -29,6 +34,12 @@ pub enum Queue { std::sync::Arc>, pg::RunOpts, ), + #[cfg(feature = "bg_sqlt")] + Sqlite( + sqlt::SqlitePool, + std::sync::Arc>, + sqlt::RunOpts, + ), None, } @@ -62,6 +73,18 @@ impl Queue { .await .map_err(Box::from)?; } + #[cfg(feature = "bg_sqlt")] + Self::Sqlite(pool, _, _) => { + sqlt::enqueue( + pool, + &class, + serde_json::to_value(args)?, + chrono::Utc::now(), + None, + ) + .await + .map_err(Box::from)?; + } _ => {} } Ok(()) @@ -91,6 +114,11 @@ impl Queue { let mut r = registry.lock().await; r.register_worker(W::class_name(), worker)?; } + #[cfg(feature = "bg_sqlt")] + Self::Sqlite(_, registry, _) => { + let mut r = registry.lock().await; + r.register_worker(W::class_name(), worker)?; + } _ => {} } Ok(()) @@ -116,6 +144,14 @@ impl Queue { handle.await?; } } + #[cfg(feature = "bg_sqlt")] + Self::Sqlite(pool, registry, run_opts) => { + //TODOQ: num workers to config + let handles = registry.lock().await.run(pool, run_opts); + for handle in handles { + handle.await?; + } + } _ => { error!( "no queue provider is configured: compile with at least one queue provider \ @@ -140,6 +176,10 @@ impl Queue { Self::Postgres(pool, _, _) => { pg::initialize_database(pool).await.map_err(Box::from)?; } + #[cfg(feature = "bg_sqlt")] + Self::Sqlite(pool, _, _) => { + sqlt::initialize_database(pool).await.map_err(Box::from)?; + } _ => {} } Ok(()) @@ -161,6 +201,10 @@ impl Queue { Self::Postgres(pool, _, _) => { pg::clear(pool).await.map_err(Box::from)?; } + #[cfg(feature = "bg_sqlt")] + Self::Sqlite(pool, _, _) => { + sqlt::clear(pool).await.map_err(Box::from)?; + } _ => {} } Ok(()) @@ -182,6 +226,10 @@ impl Queue { Self::Postgres(pool, _, _) => { pg::ping(pool).await.map_err(Box::from)?; } + #[cfg(feature = "bg_sqlt")] + Self::Sqlite(pool, _, _) => { + sqlt::ping(pool).await.map_err(Box::from)?; + } _ => {} } Ok(()) @@ -194,6 +242,8 @@ impl Queue { Self::Redis(_, _, _) => "redis queue".to_string(), #[cfg(feature = "bg_pg")] Self::Postgres(_, _, _) => "postgres queue".to_string(), + #[cfg(feature = "bg_sqlt")] + Self::Sqlite(_, _, _) => "sqlite queue".to_string(), _ => "no queue".to_string(), } } @@ -286,6 +336,17 @@ pub async fn converge(queue: &Queue, config: &QueueConfig) -> Result<()> { num_workers: _, min_connections: _, }) + | QueueConfig::Sqlite(SqliteQueueConfig { + dangerously_flush, + uri: _, + max_connections: _, + enable_logging: _, + connect_timeout: _, + idle_timeout: _, + poll_interval_sec: _, + num_workers: _, + min_connections: _, + }) | QueueConfig::Redis(RedisQueueConfig { dangerously_flush, uri: _, @@ -319,6 +380,10 @@ pub async fn create_queue_provider(config: &Config) -> Result> config::QueueConfig::Postgres(qcfg) => { Ok(Some(Arc::new(pg::create_provider(qcfg).await?))) } + #[cfg(feature = "bg_sqlt")] + config::QueueConfig::Sqlite(qcfg) => { + Ok(Some(Arc::new(sqlt::create_provider(qcfg).await?))) + } #[allow(unreachable_patterns)] _ => Err(Error::string( diff --git a/src/bgworker/sqlt.rs b/src/bgworker/sqlt.rs new file mode 100644 index 000000000..1e656be51 --- /dev/null +++ b/src/bgworker/sqlt.rs @@ -0,0 +1,405 @@ +/// SQLite based background job queue provider +use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc, time::Duration}; + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use serde_json::Value as JsonValue; +pub use sqlx::SqlitePool; +use sqlx::{ + sqlite::{SqliteConnectOptions, SqlitePoolOptions, SqliteRow}, + ConnectOptions, Row, +}; +use tokio::{task::JoinHandle, time::sleep}; +use tracing::{debug, error, trace}; +use ulid::Ulid; + +use super::{BackgroundWorker, Queue}; +use crate::{config::SqliteQueueConfig, Error, Result}; +type TaskId = String; +type TaskData = JsonValue; +type TaskStatus = String; + +type TaskHandler = Box< + dyn Fn( + TaskId, + TaskData, + ) -> Pin> + Send>> + + Send + + Sync, +>; + +#[derive(Debug, Deserialize, Serialize)] +struct Task { + pub id: TaskId, + pub name: String, + #[allow(clippy::struct_field_names)] + pub task_data: TaskData, + pub status: TaskStatus, + pub run_at: DateTime, + pub interval: Option, +} + +pub struct TaskRegistry { + handlers: Arc>, +} + +impl TaskRegistry { + /// Creates a new `TaskRegistry`. + #[must_use] + pub fn new() -> Self { + Self { + handlers: Arc::new(HashMap::new()), + } + } + + /// Registers a task handler with the provided name. + /// # Errors + /// Fails if cannot register worker + pub fn register_worker(&mut self, name: String, worker: W) -> Result<()> + where + Args: Send + Serialize + Sync + 'static, + W: BackgroundWorker + 'static, + for<'de> Args: Deserialize<'de>, + { + let worker = Arc::new(worker); + let wrapped_handler = move |_task_id: String, task_data: TaskData| { + let w = worker.clone(); + + Box::pin(async move { + let args = serde_json::from_value::(task_data); + match args { + Ok(args) => w.perform(args).await, + Err(err) => Err(err.into()), + } + }) as Pin> + Send>> + }; + + Arc::get_mut(&mut self.handlers) + .ok_or_else(|| Error::string("cannot register worker"))? + .insert(name, Box::new(wrapped_handler)); + Ok(()) + } + + /// Returns a reference to the task handlers. + #[must_use] + pub fn handlers(&self) -> &Arc> { + &self.handlers + } + + /// Runs the task handlers with the provided number of workers. + #[must_use] + pub fn run(&self, pool: &SqlitePool, opts: &RunOpts) -> Vec> { + let mut tasks = Vec::new(); + + let interval = opts.poll_interval_sec; + for idx in 0..opts.num_workers { + let handlers = self.handlers.clone(); + + let pool = pool.clone(); + let task = tokio::spawn(async move { + loop { + trace!( + pool_conns = pool.num_idle(), + worker_num = idx, + "sqlite workers stats" + ); + let task_opt = match dequeue(&pool).await { + Ok(t) => t, + Err(err) => { + error!(err = err.to_string(), "cannot fetch from queue"); + None + } + }; + + if let Some(task) = task_opt { + debug!(task_id = task.id, name = task.name, "working on task"); + if let Some(handler) = handlers.get(&task.name) { + match handler(task.id.clone(), task.task_data.clone()).await { + Ok(()) => { + if let Err(err) = + complete_task(&pool, &task.id, task.interval).await + { + error!( + err = err.to_string(), + task = ?task, + "cannot complete task" + ); + } + } + Err(err) => { + if let Err(err) = fail_task(&pool, &task.id, &err).await { + error!( + err = err.to_string(), + task = ?task, + "cannot fail task" + ); + } + } + } + } else { + error!(task = task.name, "no handler found for task"); + } + } else { + sleep(Duration::from_secs(interval.into())).await; + } + } + }); + + tasks.push(task); + } + + tasks + } +} + +impl Default for TaskRegistry { + fn default() -> Self { + Self::new() + } +} + +async fn connect(cfg: &SqliteQueueConfig) -> Result { + let mut conn_opts: SqliteConnectOptions = cfg.uri.parse()?; + if !cfg.enable_logging { + conn_opts = conn_opts.disable_statement_logging(); + } + let pool = SqlitePoolOptions::new() + .min_connections(cfg.min_connections) + .max_connections(cfg.max_connections) + .idle_timeout(Duration::from_millis(cfg.idle_timeout)) + .acquire_timeout(Duration::from_millis(cfg.connect_timeout)) + .connect_with(conn_opts) + .await?; + Ok(pool) +} + +/// Initialize task tables +/// +/// # Errors +/// +/// This function will return an error if it fails +pub async fn initialize_database(pool: &SqlitePool) -> Result<()> { + debug!("sqlite worker: initialize database"); + sqlx::query( + r" + CREATE TABLE IF NOT EXISTS sqlt_loco_queue ( + id TEXT NOT NULL, + name TEXT NOT NULL, + task_data JSON NOT NULL, + status TEXT NOT NULL DEFAULT 'queued', + run_at TIMESTAMP NOT NULL, + interval INTEGER, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + ); + + CREATE TABLE IF NOT EXISTS aquire_queue_write_lock ( + id INTEGER PRIMARY KEY CHECK (id = 1), + is_locked BOOLEAN NOT NULL DEFAULT FALSE, + locked_at TIMESTAMP NULL + ); + + INSERT OR IGNORE INTO aquire_queue_write_lock (id, is_locked) VALUES (1, FALSE); + + CREATE INDEX IF NOT EXISTS idx_sqlt_queue_status_run_at ON sqlt_loco_queue(status, run_at); + ", + ) + .execute(pool) + .await?; + Ok(()) +} + +/// Add a task +/// +/// # Errors +/// +/// This function will return an error if it fails +pub async fn enqueue( + pool: &SqlitePool, + name: &str, + task_data: TaskData, + run_at: DateTime, + interval: Option, +) -> Result { + let task_data_json = serde_json::to_value(task_data)?; + + #[allow(clippy::cast_possible_truncation)] + let interval_ms: Option = interval.map(|i| i.as_millis() as i64); + + let id = Ulid::new().to_string(); + sqlx::query( + "INSERT INTO sqlt_loco_queue (id, task_data, name, run_at, interval) VALUES ($1, $2, $3, \ + DATETIME($4), $5)", + ) + .bind(id.clone()) + .bind(task_data_json) + .bind(name) + .bind(run_at) + .bind(interval_ms) + .execute(pool) + .await?; + Ok(id) +} + +async fn dequeue(client: &SqlitePool) -> Result> { + let mut tx = client.begin().await?; + + let acquired_write_lock = sqlx::query( + "UPDATE aquire_queue_write_lock SET + is_locked = TRUE, + locked_at = CURRENT_TIMESTAMP + WHERE id = 1 AND is_locked = FALSE", + ) + .execute(&mut *tx) + .await?; + + // Couldn't aquire the write lock + if acquired_write_lock.rows_affected() == 0 { + tx.rollback().await?; + return Ok(None); + } + + let row = sqlx::query( + "SELECT id, name, task_data, status, run_at, interval + FROM sqlt_loco_queue + WHERE + status = 'queued' AND + run_at <= CURRENT_TIMESTAMP + ORDER BY run_at LIMIT 1", + ) + // avoid using FromRow because it requires the 'macros' feature, which nothing + // in our dep tree uses, so it'll create smaller, faster builds if we do this manually + .map(|row: SqliteRow| Task { + id: row.get("id"), + name: row.get("name"), + task_data: row.get("task_data"), + status: row.get("status"), + run_at: row.get("run_at"), + interval: row.get("interval"), + }) + .fetch_optional(&mut *tx) + .await?; + + if let Some(task) = row { + sqlx::query( + "UPDATE sqlt_loco_queue SET status = 'processing', updated_at = CURRENT_TIMESTAMP WHERE id = $1", + ) + .bind(&task.id) + .execute(&mut *tx) + .await?; + + // Release the write lock + sqlx::query( + "UPDATE aquire_queue_write_lock + SET is_locked = FALSE, + locked_at = NULL + WHERE id = 1", + ) + .execute(&mut *tx) + .await?; + + tx.commit().await?; + + Ok(Some(task)) + } else { + // Release the write lock, no task found + sqlx::query( + "UPDATE aquire_queue_write_lock + SET is_locked = FALSE, + locked_at = NULL + WHERE id = 1", + ) + .execute(&mut *tx) + .await?; + + tx.commit().await?; + Ok(None) + } +} + +async fn complete_task( + pool: &SqlitePool, + task_id: &TaskId, + interval_ms: Option, +) -> Result<()> { + if let Some(interval_ms) = interval_ms { + let next_run_at = Utc::now() + chrono::Duration::milliseconds(interval_ms); + sqlx::query( + "UPDATE sqlt_loco_queue SET status = 'queued', updated_at = CURRENT_TIMESTAMP, run_at = DATETIME($1) WHERE id = $2", + ) + .bind(next_run_at) + .bind(task_id) + .execute(pool) + .await?; + } else { + sqlx::query( + "UPDATE sqlt_loco_queue SET status = 'completed', updated_at = CURRENT_TIMESTAMP WHERE id = $1", + ) + .bind(task_id) + .execute(pool) + .await?; + } + Ok(()) +} + +async fn fail_task(pool: &SqlitePool, task_id: &TaskId, error: &crate::Error) -> Result<()> { + let msg = error.to_string(); + error!(err = msg, "failed task"); + let error_json = serde_json::json!({ "error": msg }); + sqlx::query( + "UPDATE sqlt_loco_queue SET status = 'failed', updated_at = CURRENT_TIMESTAMP, task_data = json_patch(task_data, $1) WHERE id = $2", + ) + .bind(error_json) + .bind(task_id) + .execute(pool) + .await?; + Ok(()) +} + +/// Clear all tasks +/// +/// # Errors +/// +/// This function will return an error if it fails +pub async fn clear(pool: &SqlitePool) -> Result<()> { + sqlx::query("DELETE from sqlt_loco_queue") + .execute(pool) + .await?; + Ok(()) +} + +/// Ping system +/// +/// # Errors +/// +/// This function will return an error if it fails +pub async fn ping(pool: &SqlitePool) -> Result<()> { + sqlx::query("SELECT id from sqlt_loco_queue LIMIT 1") + .execute(pool) + .await?; + Ok(()) +} + +#[derive(Debug)] +pub struct RunOpts { + pub num_workers: u32, + pub poll_interval_sec: u32, +} + +/// Create this provider +/// +/// # Errors +/// +/// This function will return an error if it fails +pub async fn create_provider(qcfg: &SqliteQueueConfig) -> Result { + let pool = connect(qcfg).await.map_err(Box::from)?; + let registry = TaskRegistry::new(); + Ok(Queue::Sqlite( + pool, + Arc::new(tokio::sync::Mutex::new(registry)), + RunOpts { + num_workers: qcfg.num_workers, + poll_interval_sec: qcfg.poll_interval_sec, + }, + )) +} diff --git a/src/config.rs b/src/config.rs index e3faec3e4..31eeccb1a 100644 --- a/src/config.rs +++ b/src/config.rs @@ -227,6 +227,8 @@ pub enum QueueConfig { Redis(RedisQueueConfig), /// Postgres queue Postgres(PostgresQueueConfig), + /// Sqlite queue + Sqlite(SqliteQueueConfig), } #[derive(Debug, Clone, Deserialize, Serialize)] @@ -272,6 +274,35 @@ pub struct PostgresQueueConfig { pub num_workers: u32, } +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct SqliteQueueConfig { + pub uri: String, + + #[serde(default)] + pub dangerously_flush: bool, + + #[serde(default)] + pub enable_logging: bool, + + #[serde(default = "db_max_conn")] + pub max_connections: u32, + + #[serde(default = "db_min_conn")] + pub min_connections: u32, + + #[serde(default = "db_connect_timeout")] + pub connect_timeout: u64, + + #[serde(default = "db_idle_timeout")] + pub idle_timeout: u64, + + #[serde(default = "sqlt_poll_interval")] + pub poll_interval_sec: u32, + + #[serde(default = "num_workers")] + pub num_workers: u32, +} + fn db_min_conn() -> u32 { 1 } @@ -292,6 +323,10 @@ fn pgq_poll_interval() -> u32 { 1 } +fn sqlt_poll_interval() -> u32 { + 1 +} + fn num_workers() -> u32 { 2 } From 4a7e6f00b13cf1289b48210371fd96113315abc9 Mon Sep 17 00:00:00 2001 From: Isaac Date: Wed, 6 Nov 2024 21:59:29 -0800 Subject: [PATCH 2/2] add backticks for clippy --- src/bgworker/sqlt.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bgworker/sqlt.rs b/src/bgworker/sqlt.rs index 1e656be51..5d8c4b1d7 100644 --- a/src/bgworker/sqlt.rs +++ b/src/bgworker/sqlt.rs @@ -1,4 +1,4 @@ -/// SQLite based background job queue provider +/// `SQLite` based background job queue provider use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc, time::Duration}; use chrono::{DateTime, Utc};