From 0aa1dd66a2e63771a555ba965b302cf97c829366 Mon Sep 17 00:00:00 2001 From: Johnny Graettinger Date: Wed, 2 Oct 2024 10:36:26 -0500 Subject: [PATCH] automations: distributed execution of stateful tasks Introduce a new crate `automations` which offers a programming model for distributed execution of arbitrary, stateful tasks modeled as coroutines. A supporting migration introduces table `internal.tasks` which tracks the states of tasks, including task inner states and pending recieved messages. Implement a Fibonacci executor which is a load test and test-bed for automations behaviors, with various tune-able parameters. --- Cargo.lock | 17 + crates/automations/Cargo.toml | 25 ++ crates/automations/src/executors.rs | 306 ++++++++++++++++++ crates/automations/src/lib.rs | 99 ++++++ crates/automations/src/server.rs | 235 ++++++++++++++ crates/automations/tests/fibonacci.rs | 131 ++++++++ crates/automations/tests/test_fibonacci.rs | 107 ++++++ .../migrations/20241012072256_job_queue.sql | 88 +++++ 8 files changed, 1008 insertions(+) create mode 100644 crates/automations/Cargo.toml create mode 100644 crates/automations/src/executors.rs create mode 100644 crates/automations/src/lib.rs create mode 100644 crates/automations/src/server.rs create mode 100644 crates/automations/tests/fibonacci.rs create mode 100644 crates/automations/tests/test_fibonacci.rs create mode 100644 supabase/migrations/20241012072256_job_queue.sql diff --git a/Cargo.lock b/Cargo.lock index 6153a148c3..5bad5e497c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -635,6 +635,23 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" +[[package]] +name = "automations" +version = "0.0.0" +dependencies = [ + "anyhow", + "coroutines", + "futures", + "models", + "rand 0.8.5", + "serde", + "serde_json", + "sqlx", + "tokio", + "tracing", + "tracing-subscriber", +] + [[package]] name = "avro" version = "0.0.0" diff --git a/crates/automations/Cargo.toml b/crates/automations/Cargo.toml new file mode 100644 index 0000000000..70a7f166ad --- /dev/null +++ b/crates/automations/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "automations" +version.workspace = true +rust-version.workspace = true +edition.workspace = true +authors.workspace = true +homepage.workspace = true +repository.workspace = true +license.workspace = true + +[dependencies] +coroutines = { path = "../coroutines" } +models = { path = "../models", features = ["sqlx-support"] } + +anyhow = { workspace = true } +futures = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +sqlx = { workspace = true } +tokio = { workspace = true } +tracing = { workspace = true } +tracing-subscriber = { workspace = true } + +[dev-dependencies] +rand = { workspace = true } diff --git a/crates/automations/src/executors.rs b/crates/automations/src/executors.rs new file mode 100644 index 0000000000..27c67ac718 --- /dev/null +++ b/crates/automations/src/executors.rs @@ -0,0 +1,306 @@ +use super::{server, BoxedRaw, Executor, PollOutcome, TaskType}; +use anyhow::Context; +use futures::future::{BoxFuture, FutureExt}; +use sqlx::types::Json as SqlJson; + +/// ObjSafe is an object-safe and type-erased trait which is implemented for all Executors. +pub trait ObjSafe: Send + Sync + 'static { + fn task_type(&self) -> TaskType; + + fn poll<'s>( + &'s self, + task_id: models::Id, + parent_id: Option, + state: &'s mut Option>, + inbox: &'s mut Option)>>>, + ) -> BoxFuture<'s, anyhow::Result>>; +} + +impl ObjSafe for E { + fn task_type(&self) -> TaskType { + E::TASK_TYPE + } + + fn poll<'s>( + &'s self, + task_id: models::Id, + parent_id: Option, + state: &'s mut Option>, + inbox: &'s mut Option)>>>, + ) -> BoxFuture<'s, anyhow::Result>> { + async move { + let mut state_parsed: E::State = if let Some(state) = state { + serde_json::from_str(state.get()).context("failed to decode task state")? + } else { + E::State::default() + }; + + let mut inbox_parsed: std::collections::VecDeque<(models::Id, Option)> = + inbox + .as_ref() + .into_iter() + .flatten() + .map(|SqlJson((task_id, rx))| { + if let Some(rx) = rx { + anyhow::Result::Ok((*task_id, Some(serde_json::from_str(rx.get())?))) + } else { + anyhow::Result::Ok((*task_id, None)) + } + }) + .collect::>() + .context("failed to decode received message")?; + + let outcome = E::poll( + self, + task_id, + parent_id, + &mut state_parsed, + &mut inbox_parsed, + ) + .await?; + + // Re-encode state for persistence. + // If we're Done, then the output state is NULL which is implicitly Default. + if matches!(outcome, PollOutcome::Done) { + *state = None + } else { + *state = Some(SqlJson( + serde_json::value::to_raw_value(&state_parsed) + .context("failed to encode inner state")?, + )); + } + + // Re-encode the unconsumed portion of the inbox. + if inbox_parsed.is_empty() { + *inbox = None + } else { + *inbox = Some( + inbox_parsed + .into_iter() + .map(|(task_id, msg)| { + Ok(SqlJson(( + task_id, + match msg { + Some(msg) => Some(serde_json::value::to_raw_value(&msg)?), + None => None, + }, + ))) + }) + .collect::>>() + .context("failed to encode unconsumed inbox message")?, + ); + } + + Ok(match outcome { + PollOutcome::Done => PollOutcome::Done, + PollOutcome::Send(task_id, msg) => PollOutcome::Send(task_id, msg), + PollOutcome::Sleep(interval) => PollOutcome::Sleep(interval), + PollOutcome::Spawn(task_id, task_type, msg) => { + PollOutcome::Spawn(task_id, task_type, msg) + } + PollOutcome::Suspend => PollOutcome::Suspend, + PollOutcome::Yield(msg) => PollOutcome::Yield( + serde_json::value::to_raw_value(&msg) + .context("failed to encode yielded message")?, + ), + }) + } + .boxed() + } +} + +pub async fn poll_task( + server::ReadyTask { + executor, + permit: _guard, + pool, + task: + server::DequeuedTask { + id: task_id, + type_: _, + parent_id, + mut inbox, + mut state, + mut last_heartbeat, + }, + }: server::ReadyTask, + heartbeat_timeout: std::time::Duration, +) -> anyhow::Result<()> { + let mut heartbeat_ticks = tokio::time::interval(heartbeat_timeout / 2); + let _instant = heartbeat_ticks.tick().await; // Discard immediate first tick. + + // Build a Future which forever maintains our heartbeat or fails. + let update_heartbeats = async { + loop { + let _instant = heartbeat_ticks.tick().await; + + last_heartbeat = + match update_heartbeat(&pool, task_id, heartbeat_timeout, last_heartbeat).await { + Ok(last_heartbeat) => last_heartbeat, + Err(err) => return err, + } + } + }; + tokio::pin!(update_heartbeats); + + // Poll `executor` and `update_heartbeats` in tandem, so that a failure + // to update our heartbeat also cancels the executor. + let outcome = tokio::select! { + outcome = executor.poll(task_id, parent_id, &mut state, &mut inbox) => { outcome? }, + err = &mut update_heartbeats => return Err(err), + }; + + // The possibly long-lived polling operation is now complete. + // Build a Future that commits a (hopefully) brief transaction of `outcome`. + let persist_outcome = async { + let mut txn = pool.begin().await?; + () = persist_outcome(outcome, &mut *txn, task_id, parent_id, state, inbox).await?; + Ok(txn.commit().await?) + }; + + // Poll `persist_outcome` while continuing to poll `update_heartbeats`, + // to guarantee we cannot commit an outcome after our lease is lost. + tokio::select! { + result = persist_outcome => result, + err = update_heartbeats => Err(err), + } +} + +async fn update_heartbeat( + pool: &sqlx::PgPool, + task_id: models::Id, + heartbeat_timeout: std::time::Duration, + expect_heartbeat: String, +) -> anyhow::Result { + let update = sqlx::query!( + r#" + UPDATE internal.tasks + SET heartbeat = NOW() + WHERE task_id = $1 AND heartbeat::TEXT = $2 + RETURNING heartbeat::TEXT AS "heartbeat!"; + "#, + task_id as models::Id, + expect_heartbeat, + ) + .fetch_optional(pool); + + // We must guard against both explicit errors and also timeouts when updating + // the heartbeat, to ensure we bubble up an error that cancels our paired + // executor prior to `heartbeat_timeout` elapsing. + let updated = match tokio::time::timeout(heartbeat_timeout / 4, update).await { + Ok(Ok(Some(updated))) => updated, + Ok(Ok(None)) => anyhow::bail!("task heartbeat was unexpectedly updated externally"), + Ok(Err(err)) => return Err(anyhow::anyhow!(err).context("failed to update task heartbeat")), + Err(err) => return Err(anyhow::anyhow!(err).context("timed out updating task heartbeat")), + }; + + tracing::info!( + last = expect_heartbeat, + next = updated.heartbeat, + "updated task heartbeat" + ); + + Ok(updated.heartbeat) +} + +async fn persist_outcome( + outcome: PollOutcome, + txn: &mut sqlx::PgConnection, + task_id: models::Id, + parent_id: Option, + state: Option>, + inbox: Option)>>>, +) -> anyhow::Result<()> { + use std::time::Duration; + + if let PollOutcome::Spawn(spawn_id, spawn_type, _msg) = &outcome { + sqlx::query!( + "SELECT internal.create_task($1, $2, $3)", + *spawn_id as models::Id, + *spawn_type as TaskType, + task_id as models::Id, + ) + .execute(&mut *txn) + .await + .context("failed to spawn new task")?; + } + + if let Some((send_id, msg)) = match &outcome { + // When a task is spawned, send its first message. + PollOutcome::Spawn(spawn_id, _spawn_type, msg) => Some((*spawn_id, Some(msg))), + // If we're Done but have a parent, send it an EOF. + PollOutcome::Done => parent_id.map(|parent_id| (parent_id, None)), + // Send an arbitrary message to an identified task. + PollOutcome::Send(task_id, msg) => Some((*task_id, msg.as_ref())), + // Yield is sugar for sending to our parent. + PollOutcome::Yield(msg) => { + let Some(parent_id) = parent_id else { + anyhow::bail!("task yielded illegally, because it does not have a parent"); + }; + Some((parent_id, Some(msg))) + } + _ => None, + } { + sqlx::query!( + "SELECT internal.send_to_task($1, $2, $3::JSON);", + send_id as models::Id, + task_id as models::Id, + SqlJson(msg) as SqlJson<_>, + ) + .execute(&mut *txn) + .await + .with_context(|| format!("failed to send message to {send_id:?}"))?; + } + + let wake_at_interval = if inbox.is_some() { + Some(Duration::ZERO) // Always poll immediately if inbox items remain. + } else { + match &outcome { + PollOutcome::Sleep(interval) => Some(*interval), + // These outcomes do not suspend the task, and it should wake as soon as possible. + PollOutcome::Spawn(..) | PollOutcome::Send(..) | PollOutcome::Yield(..) => { + Some(Duration::ZERO) + } + // Suspend indefinitely (note that NOW() + NULL::INTERVAL is NULL). + PollOutcome::Done | PollOutcome::Suspend => None, + } + }; + + let updated = sqlx::query!( + r#" + UPDATE internal.tasks SET + heartbeat = '0001-01-01T00:00:00Z', + inbox = $3::JSON[] || inbox_next, + inbox_next = NULL, + inner_state = $2::JSON, + wake_at = + CASE WHEN inbox_next IS NOT NULL + THEN NOW() + ELSE NOW() + $4::INTERVAL + END + WHERE task_id = $1 + RETURNING wake_at IS NULL AS "suspended!" + "#, + task_id as models::Id, + state as Option>, + inbox as Option)>>>, + wake_at_interval as Option, + ) + .fetch_one(&mut *txn) + .await + .context("failed to update task row")?; + + // If we're Done and also successfully suspended, then delete ourselves. + // (Otherwise, the task has been left in a like-new state). + if matches!(&outcome, PollOutcome::Done if updated.suspended) { + sqlx::query!( + "DELETE FROM internal.tasks WHERE task_id = $1;", + task_id as models::Id, + ) + .execute(&mut *txn) + .await + .context("failed to delete task row")?; + } + + Ok(()) +} diff --git a/crates/automations/src/lib.rs b/crates/automations/src/lib.rs new file mode 100644 index 0000000000..9a43187396 --- /dev/null +++ b/crates/automations/src/lib.rs @@ -0,0 +1,99 @@ +use anyhow::Context; +use std::sync::Arc; + +mod executors; +mod server; + +/// BoxedRaw is a type-erased raw JSON message. +type BoxedRaw = Box; + +/// TaskType is the type of a task, and maps it to an Executor. +#[derive( + Debug, + serde::Deserialize, + serde::Serialize, + sqlx::Type, + PartialOrd, + PartialEq, + Ord, + Eq, + Clone, + Copy, +)] +#[sqlx(transparent)] +pub struct TaskType(pub i16); + +/// PollOutcome is the outcome of an `Executor::poll()` for a given task. +#[derive(Debug)] +pub enum PollOutcome { + /// Spawn a new TaskId with the given TaskType and send a first message. + /// The TaskId must not exist. + Spawn(models::Id, TaskType, BoxedRaw), + /// Send a message (Some) or EOF (None) to another TaskId, which must exist. + Send(models::Id, Option), + /// Yield to send a message to this task's parent. + Yield(Yield), + /// Sleep for at-most the indicated Duration, then poll again. + /// The task may be woken earlier if it receives a message. + Sleep(std::time::Duration), + /// Suspend the task until it receives a message. + Suspend, + /// Done completes and removes the task. + /// If this task has a parent, that parent is sent an EOF. + Done, +} + +/// Executor is the core trait implemented by executors of various task types. +pub trait Executor: Send + Sync + 'static { + const TASK_TYPE: TaskType; + + type Receive: serde::de::DeserializeOwned + serde::Serialize + Send; + type State: Default + serde::de::DeserializeOwned + serde::Serialize + Send; + type Yield: serde::Serialize; + + fn poll<'s>( + &'s self, + task_id: models::Id, + parent_id: Option, + state: &'s mut Self::State, + inbox: &'s mut std::collections::VecDeque<(models::Id, Option)>, + ) -> impl std::future::Future>> + Send + 's; +} + +/// Server holds registered implementations of Executor, +/// and serves them. +pub struct Server(Vec>); + +impl PollOutcome { + pub fn spawn( + spawn_id: models::Id, + task_type: TaskType, + msg: M, + ) -> anyhow::Result { + Ok(Self::Spawn( + spawn_id, + task_type, + serde_json::value::to_raw_value(&msg).context("failed to encode task spawn message")?, + )) + } + + pub fn send(task_id: models::Id, msg: Option) -> anyhow::Result { + Ok(Self::Send( + task_id, + match msg { + Some(msg) => Some( + serde_json::value::to_raw_value(&msg) + .context("failed to encode sent message")?, + ), + None => None, + }, + )) + } +} + +pub fn next_task_id() -> models::Id { + static ID_GENERATOR: std::sync::LazyLock> = + std::sync::LazyLock::new(|| std::sync::Mutex::new(models::IdGenerator::new(1))); + + ID_GENERATOR.lock().unwrap().next() +} diff --git a/crates/automations/src/server.rs b/crates/automations/src/server.rs new file mode 100644 index 0000000000..92f0055156 --- /dev/null +++ b/crates/automations/src/server.rs @@ -0,0 +1,235 @@ +use super::{executors, BoxedRaw, Executor, Server, TaskType}; +use futures::stream::StreamExt; +use sqlx::types::Json as SqlJson; +use std::sync::Arc; + +impl Server { + pub const fn new() -> Self { + Self(Vec::new()) + } + + /// Register an Executor to be served by this Server. + pub fn register(mut self, executor: E) -> Self { + let index = match self + .0 + .binary_search_by_key(&E::TASK_TYPE, |entry| entry.task_type()) + { + Ok(_index) => panic!("an Executor for {:?} is already registered", E::TASK_TYPE), + Err(index) => index, + }; + + self.0.insert(index, Arc::new(executor)); + self + } + + /// Serve this Server until signaled to stop by `shutdown`. + pub async fn serve( + self, + permits: u32, + pool: sqlx::PgPool, + dequeue_interval: std::time::Duration, + heartbeat_timeout: std::time::Duration, + shutdown: impl std::future::Future, + ) { + serve( + self, + permits, + pool, + dequeue_interval, + heartbeat_timeout, + shutdown, + ) + .await + } +} + +pub struct ReadyTask { + pub executor: Arc, + pub permit: tokio::sync::OwnedSemaphorePermit, + pub pool: sqlx::PgPool, + pub task: DequeuedTask, +} + +pub struct DequeuedTask { + pub id: models::Id, + pub type_: TaskType, + pub parent_id: Option, + pub inbox: Option)>>>, + pub state: Option>, + pub last_heartbeat: String, +} + +pub async fn serve( + executors: Server, + permits: u32, + pool: sqlx::PgPool, + dequeue_interval: std::time::Duration, + heartbeat_timeout: std::time::Duration, + shutdown: impl std::future::Future, +) { + let semaphore = Arc::new(tokio::sync::Semaphore::new(permits as usize)); + + // Use Box::pin to ensure we can fullly drop `ready_tasks` later, + // as it may hold `semaphore` permits. + let mut ready_tasks = Box::pin(ready_tasks( + executors, + pool.clone(), + dequeue_interval, + heartbeat_timeout, + semaphore.clone(), + )); + tokio::pin!(shutdown); + + // Poll for ready tasks and start them until `shutdown` is signaled. + while let Some(ready_tasks) = tokio::select! { + ready = ready_tasks.next() => ready, + () = &mut shutdown => None, + } { + let ready_tasks: Vec = match ready_tasks { + Ok(tasks) => tasks, + Err(err) => { + tracing::error!(?err, "failed to poll for tasks (will retry)"); + Vec::new() + } + }; + + for ready in ready_tasks { + tokio::spawn(async move { + let (task_id, task_type, parent_id) = + (ready.task.id, ready.task.type_, ready.task.parent_id); + + if let Err(err) = executors::poll_task(ready, heartbeat_timeout).await { + tracing::warn!( + ?task_id, + ?task_type, + ?parent_id, + ?err, + "task executor failed and will be retried after heartbeat timeout" + ); + // The task will be retried once it's heartbeat times out. + } + }); + } + } + tracing::info!("task polling loop signaled to stop and is awaiting running tasks"); + std::mem::drop(ready_tasks); + + // Acquire all permits, when only happens after all running tasks have finished. + let _ = semaphore.acquire_many_owned(permits).await.unwrap(); +} + +pub fn ready_tasks( + executors: Server, + pool: sqlx::PgPool, + dequeue_interval: std::time::Duration, + heartbeat_timeout: std::time::Duration, + semaphore: Arc, +) -> impl futures::stream::Stream>> { + let task_types: Vec<_> = executors.0.iter().map(|e| e.task_type().0).collect(); + + coroutines::coroutine(move |mut co| async move { + loop { + () = ready_tasks_iter( + &mut co, + &executors, + heartbeat_timeout, + dequeue_interval, + &pool, + &semaphore, + &task_types, + ) + .await; + } + }) +} + +async fn ready_tasks_iter( + co: &mut coroutines::Suspend>, ()>, + executors: &Server, + heartbeat_timeout: std::time::Duration, + dequeue_interval: std::time::Duration, + pool: &sqlx::PgPool, + semaphore: &Arc, + task_types: &[i16], +) { + // Block until at least one permit is available. + if semaphore.available_permits() == 0 { + let _ = semaphore.clone().acquire_owned().await.unwrap(); + } + + // Acquire all available permits, and then poll for up to that many tasks. + let mut permits = semaphore + .clone() + .acquire_many_owned(semaphore.available_permits() as u32) + .await + .unwrap(); + + let dequeued = sqlx::query_as!( + DequeuedTask, + r#" + WITH picked AS ( + SELECT task_id + FROM internal.tasks + WHERE + task_type = ANY($1) AND + wake_at < NOW() AND + heartbeat < NOW() - $2::INTERVAL + ORDER BY wake_at DESC + LIMIT $3 + FOR UPDATE SKIP LOCKED + ) + UPDATE internal.tasks + SET heartbeat = NOW() + WHERE task_id in (SELECT task_id FROM picked) + RETURNING + task_id as "id: models::Id", + task_type as "type_: TaskType", + parent_id as "parent_id: models::Id", + inbox as "inbox: Vec)>>", + inner_state as "state: SqlJson", + heartbeat::TEXT as "last_heartbeat!"; + "#, + &task_types as &[i16], + heartbeat_timeout as std::time::Duration, + permits.num_permits() as i64, + ) + .fetch_all(pool) + .await; + + let dequeued = match dequeued { + Ok(dequeued) => { + tracing::debug!(dequeued = dequeued.len(), "completed task dequeue"); + dequeued + } + Err(err) => { + () = co.yield_(Err(err)).await; + Vec::new() // We'll sleep as if it were idle, then retry. + } + }; + + let ready = dequeued + .into_iter() + .map(|task| { + let Ok(index) = task_types.binary_search(&task.type_.0) else { + panic!("polled {:?} with unexpected {:?}", task.id, task.type_); + }; + ReadyTask { + task, + executor: executors.0[index].clone(), + permit: permits.split(1).unwrap(), + pool: pool.clone(), + } + }) + .collect(); + + () = co.yield_(Ok(ready)).await; + + // If permits remain, there were not enough tasks to dequeue. + // Sleep for up-to `dequeue_interval`, cancelling early if a task completes. + if permits.num_permits() != 0 { + tokio::select! { + () = tokio::time::sleep(dequeue_interval) => (), + _ = semaphore.clone().acquire_owned() => (), // Cancel sleep. + } + } +} diff --git a/crates/automations/tests/fibonacci.rs b/crates/automations/tests/fibonacci.rs new file mode 100644 index 0000000000..83d600a42e --- /dev/null +++ b/crates/automations/tests/fibonacci.rs @@ -0,0 +1,131 @@ +use automations::PollOutcome; +use std::collections::VecDeque; + +/// Fibonacci is one of the least-efficient calculators of the Fibonacci +/// sequence on the planet. It solves in exponential time by spawning two +/// sub-tasks in the recursive case, and does not re-use the results of +/// sub-computations. +pub struct Fibonacci { + // Percentage of the time that task polls should randomly fail. + // Value should be in range [0, 1) where 0 never fails. + pub failure_rate: f32, + // Amount of time to wait before allowing a poll to complete. + pub sleep_for: std::time::Duration, +} + +#[derive(Debug, serde::Deserialize, serde::Serialize)] +pub struct Message { + pub value: i64, +} + +#[derive(Debug, serde::Deserialize, serde::Serialize)] +pub enum State { + // Init spawns the first task and transitions to SpawnOne + Init, + // SpawnOne spawns the second task and transitions to Waiting + SpawnOne(i64), + // Waiting waits for `pending` child tasks to complete, + // accumulating their yielded values, and then transitions to Sleeping. + Waiting { partial: i64, pending: usize }, + Finished, +} + +impl Default for State { + fn default() -> Self { + Self::Init + } +} + +impl automations::Executor for Fibonacci { + const TASK_TYPE: automations::TaskType = automations::TaskType(32767); + + type Receive = Message; + type Yield = Message; + type State = State; + + #[tracing::instrument( + ret, + err(Debug, level = tracing::Level::ERROR), + skip_all, + fields(?task_id, ?parent_id, ?state, ?inbox), + )] + async fn poll<'s>( + &'s self, + task_id: models::Id, + parent_id: Option, + state: &'s mut Self::State, + inbox: &'s mut VecDeque<(models::Id, Option)>, + ) -> anyhow::Result> { + if rand::random::() < self.failure_rate { + return Err( + anyhow::anyhow!("A no good, very bad error!").context("something bad happened") + ); + } + + if let State::SpawnOne(value) = state { + let spawn = PollOutcome::spawn( + automations::next_task_id(), + Self::TASK_TYPE, + Message { value: *value - 2 }, + ); + *state = State::Waiting { + partial: 0, + pending: 2, + }; + + return spawn; + } + + match (std::mem::take(state), inbox.pop_front()) { + // Base case: + (State::Init, Some((_parent_id, Some(Message { value })))) if value <= 2 => { + *state = State::Finished; + Ok(PollOutcome::Yield(Message { value: 1 })) + } + + // Recursive case: + (State::Init, Some((_parent_id, Some(Message { value })))) => { + *state = State::SpawnOne(value); + + PollOutcome::spawn( + automations::next_task_id(), + Self::TASK_TYPE, + Message { value: value - 1 }, + ) + } + + (State::Waiting { partial, pending }, None) => { + *state = State::Waiting { partial, pending }; + // Sleeping at this point in the lifecycle exercises handling of + // messages sent to a task that's currently being polled. + () = tokio::time::sleep(self.sleep_for).await; + Ok(PollOutcome::Suspend) + } + + (State::Waiting { partial, pending }, Some((_child_id, Some(Message { value })))) => { + *state = State::Waiting { + partial: partial + value, + pending, + }; + Ok(PollOutcome::Suspend) + } + + (State::Waiting { partial, pending }, Some((_child_id, None))) => { + if pending != 1 || parent_id.is_none() { + *state = State::Waiting { + partial, + pending: pending - 1, + }; + Ok(PollOutcome::Suspend) + } else { + *state = State::Finished; + Ok(PollOutcome::Yield(Message { value: partial })) + } + } + + (State::Finished, None) => Ok(PollOutcome::Done), + + state => anyhow::bail!("unexpected poll with state {state:?} and inbox {inbox:?}"), + } + } +} diff --git a/crates/automations/tests/test_fibonacci.rs b/crates/automations/tests/test_fibonacci.rs new file mode 100644 index 0000000000..badd743eb6 --- /dev/null +++ b/crates/automations/tests/test_fibonacci.rs @@ -0,0 +1,107 @@ +use std::time::Duration; + +mod fibonacci; + +// Percentage of the time that task polls should randomly fail. +// Value should be in range [0, 1) where 0 never fails. +const FAILURE_RATE: f32 = 0.00; +// Fibonacci sequence index to calculate. +// Larger numbers require exponentially more work. +const SEQUENCE: i64 = 10; +// Expected value at `SEQUENCE` offset. +const EXPECT_VALUE: i64 = 55; +// Number of concurrent polls that may run. +const CONCURRENCY: u32 = 50; +// When idle, the interval between polls for ready-to-run tasks. +// Note that `automations` will also poll after task completions. +const DEQUEUE_INTERVAL: Duration = Duration::from_secs(5); +// The timeout before a task poll is considered to have failed, +// and is eligible for retry. +const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(10); +// Amount of time each poll sleeps before responding. +const SLEEP_FOR: Duration = Duration::from_secs(0); +// Database under test. +const FIXED_DATABASE_URL: &str = "postgresql://postgres:postgres@localhost:5432/postgres"; + +#[tokio::test] +async fn test_fibonacci_bench() { + let subscriber = tracing_subscriber::FmtSubscriber::builder() + .with_env_filter( + tracing_subscriber::EnvFilter::builder() + .with_default_directive(tracing::level_filters::LevelFilter::INFO.into()) + .from_env_lossy(), + ) + .finish(); + let _ = tracing::subscriber::set_global_default(subscriber); + + let pool = sqlx::postgres::PgPool::connect(&FIXED_DATABASE_URL) + .await + .expect("connect"); + + // This cleanup is not required for correctness, but makes it nicer to + // visually review the internal.tasks table. + sqlx::query!("DELETE FROM internal.tasks WHERE task_type = 32767;") + .execute(&pool) + .await + .unwrap(); + + let root_id = automations::next_task_id(); + + sqlx::query!( + "SELECT internal.create_task($1, 32767::SMALLINT, NULL::public.flowid)", + root_id as models::Id + ) + .execute(&pool) + .await + .unwrap(); + + sqlx::query!( + r#"SELECT internal.send_to_task($1, '00:00:00:00:00:00:00:00'::flowid, $2::JSON)"#, + root_id as models::Id, + sqlx::types::Json(fibonacci::Message { value: SEQUENCE }) + as sqlx::types::Json + ) + .execute(&pool) + .await + .unwrap(); + + let monitor = async { + let mut ticker = tokio::time::interval(Duration::from_millis(500)); + + loop { + let _instant = ticker.tick().await; + + let record = sqlx::query!( + r#"SELECT inner_state as "state: sqlx::types::Json" FROM internal.tasks WHERE task_id = $1"#, + root_id as models::Id + ) + .fetch_one(&pool) + .await + .unwrap(); + + if let Some(sqlx::types::Json(fibonacci::State::Waiting { + partial, + pending: 0, + })) = record.state + { + tracing::info!(value = partial, "completed Fibonacci sequence"); + assert_eq!(partial, EXPECT_VALUE); + break; + } + } + }; + + () = automations::Server::new() + .register(fibonacci::Fibonacci { + failure_rate: FAILURE_RATE, + sleep_for: SLEEP_FOR, + }) + .serve( + CONCURRENCY, + pool.clone(), + DEQUEUE_INTERVAL, + HEARTBEAT_TIMEOUT, + monitor, + ) + .await; +} diff --git a/supabase/migrations/20241012072256_job_queue.sql b/supabase/migrations/20241012072256_job_queue.sql new file mode 100644 index 0000000000..9d164b0215 --- /dev/null +++ b/supabase/migrations/20241012072256_job_queue.sql @@ -0,0 +1,88 @@ +BEGIN; + +CREATE TABLE internal.tasks ( + task_id public.flowid PRIMARY KEY NOT NULL, + task_type SMALLINT NOT NULL, + parent_id public.flowid, + + inner_state JSON, + + wake_at TIMESTAMPTZ, + inbox JSON[], + inbox_next JSON[], + + heartbeat TIMESTAMPTZ NOT NULL DEFAULT '0001-01-01T00:00:00Z' +); + +CREATE INDEX idx_tasks_ready_at ON internal.tasks + USING btree (wake_at) INCLUDE (task_type); + +COMMENT ON TABLE internal.tasks IS ' +The tasks table supports a distributed and asynchronous task execution system +implemented in the Rust "automations" crate. + +Tasks are poll-able coroutines which are identified by task_id and have a task_type. +They may be short-lived and polled just once, or very long-lived and polled +many times over their life-cycle. + +Tasks are polled by executors which dequeue from the tasks table and run +bespoke executors parameterized by the task type. A polling routine may take +an arbitrarily long amount of time to finish, and the executor +is required to periodically update the task heartbeat as it runs. + +A task is polled by at-most one executor at a time. Executor failures are +detected through a failure to update the task heartbeat within a threshold amount +of time, which makes the task re-eligible for dequeue by another executor. + +Tasks are coroutines and may send messages to one another, which is tracked in the +inbox of each task and processed by the task executor. If a task is currently being +polled (its heartbeat is not the DEFAULT), then messages accrue in inbox_next. +'; + + +CREATE FUNCTION internal.create_task( + p_task_id public.flowid, + p_task_type SMALLINT, + p_parent_id public.flowid +) +RETURNS VOID +SET search_path = '' +AS $$ +BEGIN + + INSERT INTO internal.tasks (task_id, task_type, parent_id) + VALUES (p_task_id, p_task_type, p_parent_id); + +END; +$$ LANGUAGE plpgsql; + + +CREATE FUNCTION internal.send_to_task( + p_task_id public.flowid, + p_from_id public.flowid, + p_message JSON +) +RETURNS VOID +SET search_path = '' +AS $$ +BEGIN + + UPDATE internal.tasks SET + wake_at = LEAST(wake_at, NOW()), + inbox = + CASE WHEN heartbeat = '0001-01-01T00:00:00Z' + THEN ARRAY_APPEND(inbox, JSON_BUILD_ARRAY(p_from_id, p_message)) + ELSE inbox + END, + inbox_next = + CASE WHEN heartbeat = '0001-01-01T00:00:00Z' + THEN inbox_next + ELSE ARRAY_APPEND(inbox_next, JSON_BUILD_ARRAY(p_from_id, p_message)) + END + WHERE task_id = p_task_id; + +END; +$$ LANGUAGE plpgsql; + + +COMMIT; \ No newline at end of file