diff --git a/wfe-sqlite/Cargo.toml b/wfe-sqlite/Cargo.toml new file mode 100644 index 0000000..cc638cf --- /dev/null +++ b/wfe-sqlite/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "wfe-sqlite" +version.workspace = true +edition.workspace = true +license.workspace = true +description = "SQLite persistence provider for WFE" + +[dependencies] +wfe-core = { workspace = true } +sqlx = { workspace = true } +tokio = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +async-trait = { workspace = true } +uuid = { workspace = true } +chrono = { workspace = true } +thiserror = { workspace = true } +tracing = { workspace = true } + +[dev-dependencies] +wfe-core = { workspace = true, features = ["test-support"] } +pretty_assertions = { workspace = true } +rstest = { workspace = true } +tokio = { workspace = true, features = ["test-util"] } diff --git a/wfe-sqlite/src/lib.rs b/wfe-sqlite/src/lib.rs new file mode 100644 index 0000000..c038fc7 --- /dev/null +++ b/wfe-sqlite/src/lib.rs @@ -0,0 +1,1120 @@ +use std::collections::HashMap; + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions}; +use sqlx::{Row, SqlitePool}; + +use wfe_core::models::{ + CommandName, Event, EventSubscription, ExecutionError, ExecutionPointer, ScheduledCommand, + WorkflowInstance, WorkflowStatus, +}; +use wfe_core::traits::{ + EventRepository, PersistenceProvider, ScheduledCommandRepository, SubscriptionRepository, + WorkflowRepository, +}; +use wfe_core::{Result, WfeError}; + +/// SQLite-backed persistence provider for the WFE workflow engine. +pub struct SqlitePersistenceProvider { + pool: SqlitePool, +} + +impl SqlitePersistenceProvider { + /// Create a new provider connected to the given SQLite database URL. + /// + /// For in-memory databases, pass `":memory:"`. + /// The schema tables are created automatically. + pub async fn new(database_url: &str) -> std::result::Result> { + let options: SqliteConnectOptions = database_url + .parse::()? + .create_if_missing(true) + .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal); + + let max_connections = if database_url.contains(":memory:") { + 1 + } else { + 4 + }; + + let pool = SqlitePoolOptions::new() + .max_connections(max_connections) + .connect_with(options) + .await?; + + // Enable WAL mode and foreign keys + sqlx::query("PRAGMA foreign_keys = ON") + .execute(&pool) + .await?; + + let provider = Self { pool }; + provider.ensure_store_exists().await?; + Ok(provider) + } + + /// Run the DDL statements to create tables and indexes. + async fn create_tables(&self) -> std::result::Result<(), sqlx::Error> { + sqlx::query( + "CREATE TABLE IF NOT EXISTS workflows ( + id TEXT PRIMARY KEY, + definition_id TEXT NOT NULL, + version INTEGER NOT NULL, + description TEXT, + reference TEXT, + status TEXT NOT NULL, + data TEXT NOT NULL, + next_execution INTEGER, + create_time TEXT NOT NULL, + complete_time TEXT + )", + ) + .execute(&self.pool) + .await?; + + sqlx::query( + "CREATE TABLE IF NOT EXISTS execution_pointers ( + id TEXT PRIMARY KEY, + workflow_id TEXT NOT NULL, + step_id INTEGER NOT NULL, + active INTEGER NOT NULL DEFAULT 1, + status TEXT NOT NULL, + sleep_until TEXT, + persistence_data TEXT, + start_time TEXT, + end_time TEXT, + event_name TEXT, + event_key TEXT, + event_published INTEGER NOT NULL DEFAULT 0, + event_data TEXT, + step_name TEXT, + retry_count INTEGER NOT NULL DEFAULT 0, + children TEXT NOT NULL DEFAULT '[]', + context_item TEXT, + predecessor_id TEXT, + outcome TEXT, + scope TEXT NOT NULL DEFAULT '[]', + extension_attributes TEXT NOT NULL DEFAULT '{}', + FOREIGN KEY (workflow_id) REFERENCES workflows(id) ON DELETE CASCADE + )", + ) + .execute(&self.pool) + .await?; + + sqlx::query( + "CREATE TABLE IF NOT EXISTS events ( + id TEXT PRIMARY KEY, + event_name TEXT NOT NULL, + event_key TEXT NOT NULL, + event_data TEXT NOT NULL, + event_time TEXT NOT NULL, + is_processed INTEGER NOT NULL DEFAULT 0 + )", + ) + .execute(&self.pool) + .await?; + + sqlx::query( + "CREATE TABLE IF NOT EXISTS event_subscriptions ( + id TEXT PRIMARY KEY, + workflow_id TEXT NOT NULL, + step_id INTEGER NOT NULL, + execution_pointer_id TEXT NOT NULL, + event_name TEXT NOT NULL, + event_key TEXT NOT NULL, + subscribe_as_of TEXT NOT NULL, + subscription_data TEXT, + external_token TEXT, + external_worker_id TEXT, + external_token_expiry TEXT, + terminated INTEGER NOT NULL DEFAULT 0 + )", + ) + .execute(&self.pool) + .await?; + + sqlx::query( + "CREATE TABLE IF NOT EXISTS execution_errors ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + error_time TEXT NOT NULL, + workflow_id TEXT NOT NULL, + execution_pointer_id TEXT NOT NULL, + message TEXT NOT NULL + )", + ) + .execute(&self.pool) + .await?; + + sqlx::query( + "CREATE TABLE IF NOT EXISTS scheduled_commands ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + command_name TEXT NOT NULL, + data TEXT NOT NULL, + execute_time INTEGER NOT NULL, + UNIQUE(command_name, data) + )", + ) + .execute(&self.pool) + .await?; + + // Indexes + sqlx::query("CREATE INDEX IF NOT EXISTS idx_workflows_next_execution ON workflows(next_execution)") + .execute(&self.pool) + .await?; + sqlx::query( + "CREATE INDEX IF NOT EXISTS idx_workflows_status ON workflows(status)", + ) + .execute(&self.pool) + .await?; + sqlx::query("CREATE INDEX IF NOT EXISTS idx_execution_pointers_workflow_id ON execution_pointers(workflow_id)") + .execute(&self.pool) + .await?; + sqlx::query("CREATE INDEX IF NOT EXISTS idx_events_name_key ON events(event_name, event_key)") + .execute(&self.pool) + .await?; + sqlx::query( + "CREATE INDEX IF NOT EXISTS idx_events_is_processed ON events(is_processed)", + ) + .execute(&self.pool) + .await?; + sqlx::query( + "CREATE INDEX IF NOT EXISTS idx_events_event_time ON events(event_time)", + ) + .execute(&self.pool) + .await?; + sqlx::query("CREATE INDEX IF NOT EXISTS idx_event_subscriptions_name_key ON event_subscriptions(event_name, event_key)") + .execute(&self.pool) + .await?; + sqlx::query("CREATE INDEX IF NOT EXISTS idx_event_subscriptions_workflow_id ON event_subscriptions(workflow_id)") + .execute(&self.pool) + .await?; + sqlx::query("CREATE INDEX IF NOT EXISTS idx_scheduled_commands_execute_time ON scheduled_commands(execute_time)") + .execute(&self.pool) + .await?; + + Ok(()) + } +} + +// ─── Helpers ─────────────────────────────────────────────────────────────── + +fn to_persistence_err(e: sqlx::Error) -> WfeError { + WfeError::Persistence(e.to_string()) +} + +fn dt_to_string(dt: &DateTime) -> String { + dt.to_rfc3339() +} + +fn opt_dt_to_string(dt: &Option>) -> Option { + dt.as_ref().map(dt_to_string) +} + +fn string_to_dt(s: &str) -> std::result::Result, WfeError> { + s.parse::>() + .map_err(|e| WfeError::Persistence(format!("Failed to parse datetime '{s}': {e}"))) +} + +fn string_to_opt_dt(s: &Option) -> std::result::Result>, WfeError> { + match s { + Some(s) => Ok(Some(string_to_dt(s)?)), + None => Ok(None), + } +} + +fn row_to_workflow( + row: &sqlx::sqlite::SqliteRow, + pointers: Vec, +) -> std::result::Result { + let status_str: String = row.try_get("status").map_err(to_persistence_err)?; + let status: WorkflowStatus = + serde_json::from_str(&format!("\"{status_str}\"")).map_err(|e| { + WfeError::Persistence(format!("Failed to deserialize WorkflowStatus: {e}")) + })?; + + let data_str: String = row.try_get("data").map_err(to_persistence_err)?; + let data: serde_json::Value = serde_json::from_str(&data_str) + .map_err(|e| WfeError::Persistence(format!("Failed to deserialize data: {e}")))?; + + let create_time_str: String = row.try_get("create_time").map_err(to_persistence_err)?; + let complete_time_str: Option = + row.try_get("complete_time").map_err(to_persistence_err)?; + + Ok(WorkflowInstance { + id: row.try_get("id").map_err(to_persistence_err)?, + workflow_definition_id: row.try_get("definition_id").map_err(to_persistence_err)?, + version: row + .try_get::("version") + .map_err(to_persistence_err)? as u32, + description: row.try_get("description").map_err(to_persistence_err)?, + reference: row.try_get("reference").map_err(to_persistence_err)?, + execution_pointers: pointers, + next_execution: row.try_get("next_execution").map_err(to_persistence_err)?, + status, + data, + create_time: string_to_dt(&create_time_str)?, + complete_time: string_to_opt_dt(&complete_time_str)?, + }) +} + +fn row_to_pointer( + row: &sqlx::sqlite::SqliteRow, +) -> std::result::Result { + let status_str: String = row.try_get("status").map_err(to_persistence_err)?; + let status: wfe_core::models::PointerStatus = + serde_json::from_str(&format!("\"{status_str}\"")).map_err(|e| { + WfeError::Persistence(format!("Failed to deserialize PointerStatus: {e}")) + })?; + + let persistence_data_str: Option = row + .try_get("persistence_data") + .map_err(to_persistence_err)?; + let persistence_data: Option = persistence_data_str + .as_deref() + .map(serde_json::from_str) + .transpose() + .map_err(|e| WfeError::Persistence(format!("Failed to deserialize persistence_data: {e}")))?; + + let event_data_str: Option = + row.try_get("event_data").map_err(to_persistence_err)?; + let event_data: Option = event_data_str + .as_deref() + .map(serde_json::from_str) + .transpose() + .map_err(|e| WfeError::Persistence(format!("Failed to deserialize event_data: {e}")))?; + + let context_item_str: Option = + row.try_get("context_item").map_err(to_persistence_err)?; + let context_item: Option = context_item_str + .as_deref() + .map(serde_json::from_str) + .transpose() + .map_err(|e| WfeError::Persistence(format!("Failed to deserialize context_item: {e}")))?; + + let outcome_str: Option = row.try_get("outcome").map_err(to_persistence_err)?; + let outcome: Option = outcome_str + .as_deref() + .map(serde_json::from_str) + .transpose() + .map_err(|e| WfeError::Persistence(format!("Failed to deserialize outcome: {e}")))?; + + let children_str: String = row.try_get("children").map_err(to_persistence_err)?; + let children: Vec = serde_json::from_str(&children_str) + .map_err(|e| WfeError::Persistence(format!("Failed to deserialize children: {e}")))?; + + let scope_str: String = row.try_get("scope").map_err(to_persistence_err)?; + let scope: Vec = serde_json::from_str(&scope_str) + .map_err(|e| WfeError::Persistence(format!("Failed to deserialize scope: {e}")))?; + + let ext_str: String = row + .try_get("extension_attributes") + .map_err(to_persistence_err)?; + let extension_attributes: HashMap = + serde_json::from_str(&ext_str).map_err(|e| { + WfeError::Persistence(format!("Failed to deserialize extension_attributes: {e}")) + })?; + + let sleep_until_str: Option = + row.try_get("sleep_until").map_err(to_persistence_err)?; + let start_time_str: Option = + row.try_get("start_time").map_err(to_persistence_err)?; + let end_time_str: Option = row.try_get("end_time").map_err(to_persistence_err)?; + + Ok(ExecutionPointer { + id: row.try_get("id").map_err(to_persistence_err)?, + step_id: row + .try_get::("step_id") + .map_err(to_persistence_err)? as usize, + active: row + .try_get::("active") + .map_err(to_persistence_err)?, + status, + sleep_until: string_to_opt_dt(&sleep_until_str)?, + persistence_data, + start_time: string_to_opt_dt(&start_time_str)?, + end_time: string_to_opt_dt(&end_time_str)?, + event_name: row.try_get("event_name").map_err(to_persistence_err)?, + event_key: row.try_get("event_key").map_err(to_persistence_err)?, + event_published: row + .try_get::("event_published") + .map_err(to_persistence_err)?, + event_data, + step_name: row.try_get("step_name").map_err(to_persistence_err)?, + retry_count: row + .try_get::("retry_count") + .map_err(to_persistence_err)? as u32, + children, + context_item, + predecessor_id: row.try_get("predecessor_id").map_err(to_persistence_err)?, + outcome, + scope, + extension_attributes, + }) +} + +fn row_to_event(row: &sqlx::sqlite::SqliteRow) -> std::result::Result { + let event_data_str: String = row.try_get("event_data").map_err(to_persistence_err)?; + let event_data: serde_json::Value = serde_json::from_str(&event_data_str) + .map_err(|e| WfeError::Persistence(format!("Failed to deserialize event_data: {e}")))?; + + let event_time_str: String = row.try_get("event_time").map_err(to_persistence_err)?; + + Ok(Event { + id: row.try_get("id").map_err(to_persistence_err)?, + event_name: row.try_get("event_name").map_err(to_persistence_err)?, + event_key: row.try_get("event_key").map_err(to_persistence_err)?, + event_data, + event_time: string_to_dt(&event_time_str)?, + is_processed: row + .try_get::("is_processed") + .map_err(to_persistence_err)?, + }) +} + +fn row_to_subscription( + row: &sqlx::sqlite::SqliteRow, +) -> std::result::Result { + let subscribe_as_of_str: String = + row.try_get("subscribe_as_of").map_err(to_persistence_err)?; + + let subscription_data_str: Option = row + .try_get("subscription_data") + .map_err(to_persistence_err)?; + let subscription_data: Option = subscription_data_str + .as_deref() + .map(serde_json::from_str) + .transpose() + .map_err(|e| { + WfeError::Persistence(format!("Failed to deserialize subscription_data: {e}")) + })?; + + let external_token_expiry_str: Option = row + .try_get("external_token_expiry") + .map_err(to_persistence_err)?; + + Ok(EventSubscription { + id: row.try_get("id").map_err(to_persistence_err)?, + workflow_id: row.try_get("workflow_id").map_err(to_persistence_err)?, + step_id: row + .try_get::("step_id") + .map_err(to_persistence_err)? as usize, + execution_pointer_id: row + .try_get("execution_pointer_id") + .map_err(to_persistence_err)?, + event_name: row.try_get("event_name").map_err(to_persistence_err)?, + event_key: row.try_get("event_key").map_err(to_persistence_err)?, + subscribe_as_of: string_to_dt(&subscribe_as_of_str)?, + subscription_data, + external_token: row.try_get("external_token").map_err(to_persistence_err)?, + external_worker_id: row + .try_get("external_worker_id") + .map_err(to_persistence_err)?, + external_token_expiry: string_to_opt_dt(&external_token_expiry_str)?, + }) +} + +// ─── Trait implementations ───────────────────────────────────────────────── + +#[async_trait] +impl WorkflowRepository for SqlitePersistenceProvider { + async fn create_new_workflow(&self, instance: &WorkflowInstance) -> Result { + let id = if instance.id.is_empty() { + uuid::Uuid::new_v4().to_string() + } else { + instance.id.clone() + }; + + let status_str = serde_json::to_value(instance.status) + .map_err(|e| WfeError::Persistence(e.to_string()))? + .as_str() + .unwrap_or("Runnable") + .to_string(); + let data_str = serde_json::to_string(&instance.data) + .map_err(|e| WfeError::Persistence(e.to_string()))?; + let create_time_str = dt_to_string(&instance.create_time); + let complete_time_str = opt_dt_to_string(&instance.complete_time); + + let mut tx = self.pool.begin().await.map_err(to_persistence_err)?; + + sqlx::query( + "INSERT INTO workflows (id, definition_id, version, description, reference, status, data, next_execution, create_time, complete_time) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)", + ) + .bind(&id) + .bind(&instance.workflow_definition_id) + .bind(instance.version as i64) + .bind(&instance.description) + .bind(&instance.reference) + .bind(&status_str) + .bind(&data_str) + .bind(instance.next_execution) + .bind(&create_time_str) + .bind(&complete_time_str) + .execute(&mut *tx) + .await + .map_err(to_persistence_err)?; + + for ptr in &instance.execution_pointers { + insert_pointer(&mut tx, &id, ptr).await?; + } + + tx.commit().await.map_err(to_persistence_err)?; + Ok(id) + } + + async fn persist_workflow(&self, instance: &WorkflowInstance) -> Result<()> { + let status_str = serde_json::to_value(instance.status) + .map_err(|e| WfeError::Persistence(e.to_string()))? + .as_str() + .unwrap_or("Runnable") + .to_string(); + let data_str = serde_json::to_string(&instance.data) + .map_err(|e| WfeError::Persistence(e.to_string()))?; + let complete_time_str = opt_dt_to_string(&instance.complete_time); + + let mut tx = self.pool.begin().await.map_err(to_persistence_err)?; + + sqlx::query( + "UPDATE workflows SET definition_id = ?1, version = ?2, description = ?3, reference = ?4, + status = ?5, data = ?6, next_execution = ?7, complete_time = ?8 + WHERE id = ?9", + ) + .bind(&instance.workflow_definition_id) + .bind(instance.version as i64) + .bind(&instance.description) + .bind(&instance.reference) + .bind(&status_str) + .bind(&data_str) + .bind(instance.next_execution) + .bind(&complete_time_str) + .bind(&instance.id) + .execute(&mut *tx) + .await + .map_err(to_persistence_err)?; + + // Replace all pointers + sqlx::query("DELETE FROM execution_pointers WHERE workflow_id = ?1") + .bind(&instance.id) + .execute(&mut *tx) + .await + .map_err(to_persistence_err)?; + + for ptr in &instance.execution_pointers { + insert_pointer(&mut tx, &instance.id, ptr).await?; + } + + tx.commit().await.map_err(to_persistence_err)?; + Ok(()) + } + + async fn persist_workflow_with_subscriptions( + &self, + instance: &WorkflowInstance, + subscriptions: &[EventSubscription], + ) -> Result<()> { + let status_str = serde_json::to_value(instance.status) + .map_err(|e| WfeError::Persistence(e.to_string()))? + .as_str() + .unwrap_or("Runnable") + .to_string(); + let data_str = serde_json::to_string(&instance.data) + .map_err(|e| WfeError::Persistence(e.to_string()))?; + let complete_time_str = opt_dt_to_string(&instance.complete_time); + + let mut tx = self.pool.begin().await.map_err(to_persistence_err)?; + + sqlx::query( + "UPDATE workflows SET definition_id = ?1, version = ?2, description = ?3, reference = ?4, + status = ?5, data = ?6, next_execution = ?7, complete_time = ?8 + WHERE id = ?9", + ) + .bind(&instance.workflow_definition_id) + .bind(instance.version as i64) + .bind(&instance.description) + .bind(&instance.reference) + .bind(&status_str) + .bind(&data_str) + .bind(instance.next_execution) + .bind(&complete_time_str) + .bind(&instance.id) + .execute(&mut *tx) + .await + .map_err(to_persistence_err)?; + + sqlx::query("DELETE FROM execution_pointers WHERE workflow_id = ?1") + .bind(&instance.id) + .execute(&mut *tx) + .await + .map_err(to_persistence_err)?; + + for ptr in &instance.execution_pointers { + insert_pointer(&mut tx, &instance.id, ptr).await?; + } + + for sub in subscriptions { + insert_subscription(&mut tx, sub).await?; + } + + tx.commit().await.map_err(to_persistence_err)?; + Ok(()) + } + + async fn get_runnable_instances(&self, as_at: DateTime) -> Result> { + let as_at_millis = as_at.timestamp_millis(); + let rows = sqlx::query( + "SELECT id FROM workflows WHERE status = 'Runnable' AND next_execution IS NOT NULL AND next_execution <= ?1", + ) + .bind(as_at_millis) + .fetch_all(&self.pool) + .await + .map_err(to_persistence_err)?; + + let ids = rows + .iter() + .map(|r| r.try_get("id").map_err(to_persistence_err)) + .collect::>>()?; + Ok(ids) + } + + async fn get_workflow_instance(&self, id: &str) -> Result { + let row = sqlx::query("SELECT * FROM workflows WHERE id = ?1") + .bind(id) + .fetch_optional(&self.pool) + .await + .map_err(to_persistence_err)? + .ok_or_else(|| WfeError::WorkflowNotFound(id.to_string()))?; + + let pointer_rows = + sqlx::query("SELECT * FROM execution_pointers WHERE workflow_id = ?1") + .bind(id) + .fetch_all(&self.pool) + .await + .map_err(to_persistence_err)?; + + let pointers = pointer_rows + .iter() + .map(row_to_pointer) + .collect::>>()?; + + row_to_workflow(&row, pointers) + } + + async fn get_workflow_instances(&self, ids: &[String]) -> Result> { + if ids.is_empty() { + return Ok(Vec::new()); + } + + let mut result = Vec::with_capacity(ids.len()); + for id in ids { + match self.get_workflow_instance(id).await { + Ok(w) => result.push(w), + Err(WfeError::WorkflowNotFound(_)) => continue, + Err(e) => return Err(e), + } + } + Ok(result) + } +} + +async fn insert_pointer( + tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, + workflow_id: &str, + ptr: &ExecutionPointer, +) -> Result<()> { + let status_str = serde_json::to_value(ptr.status) + .map_err(|e| WfeError::Persistence(e.to_string()))? + .as_str() + .unwrap_or("Pending") + .to_string(); + let persistence_data_str = ptr + .persistence_data + .as_ref() + .map(serde_json::to_string) + .transpose() + .map_err(|e| WfeError::Persistence(e.to_string()))?; + let event_data_str = ptr + .event_data + .as_ref() + .map(serde_json::to_string) + .transpose() + .map_err(|e| WfeError::Persistence(e.to_string()))?; + let context_item_str = ptr + .context_item + .as_ref() + .map(serde_json::to_string) + .transpose() + .map_err(|e| WfeError::Persistence(e.to_string()))?; + let outcome_str = ptr + .outcome + .as_ref() + .map(serde_json::to_string) + .transpose() + .map_err(|e| WfeError::Persistence(e.to_string()))?; + let children_str = + serde_json::to_string(&ptr.children).map_err(|e| WfeError::Persistence(e.to_string()))?; + let scope_str = + serde_json::to_string(&ptr.scope).map_err(|e| WfeError::Persistence(e.to_string()))?; + let ext_str = serde_json::to_string(&ptr.extension_attributes) + .map_err(|e| WfeError::Persistence(e.to_string()))?; + + let sleep_until_str = opt_dt_to_string(&ptr.sleep_until); + let start_time_str = opt_dt_to_string(&ptr.start_time); + let end_time_str = opt_dt_to_string(&ptr.end_time); + + sqlx::query( + "INSERT INTO execution_pointers + (id, workflow_id, step_id, active, status, sleep_until, persistence_data, start_time, + end_time, event_name, event_key, event_published, event_data, step_name, retry_count, + children, context_item, predecessor_id, outcome, scope, extension_attributes) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15, ?16, ?17, ?18, ?19, ?20, ?21)", + ) + .bind(&ptr.id) + .bind(workflow_id) + .bind(ptr.step_id as i64) + .bind(ptr.active) + .bind(&status_str) + .bind(&sleep_until_str) + .bind(&persistence_data_str) + .bind(&start_time_str) + .bind(&end_time_str) + .bind(&ptr.event_name) + .bind(&ptr.event_key) + .bind(ptr.event_published) + .bind(&event_data_str) + .bind(&ptr.step_name) + .bind(ptr.retry_count as i64) + .bind(&children_str) + .bind(&context_item_str) + .bind(&ptr.predecessor_id) + .bind(&outcome_str) + .bind(&scope_str) + .bind(&ext_str) + .execute(&mut **tx) + .await + .map_err(to_persistence_err)?; + + Ok(()) +} + +async fn insert_subscription( + tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, + sub: &EventSubscription, +) -> Result<()> { + let subscribe_as_of_str = dt_to_string(&sub.subscribe_as_of); + let subscription_data_str = sub + .subscription_data + .as_ref() + .map(serde_json::to_string) + .transpose() + .map_err(|e| WfeError::Persistence(e.to_string()))?; + let external_token_expiry_str = opt_dt_to_string(&sub.external_token_expiry); + + sqlx::query( + "INSERT INTO event_subscriptions + (id, workflow_id, step_id, execution_pointer_id, event_name, event_key, + subscribe_as_of, subscription_data, external_token, external_worker_id, + external_token_expiry, terminated) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, 0)", + ) + .bind(&sub.id) + .bind(&sub.workflow_id) + .bind(sub.step_id as i64) + .bind(&sub.execution_pointer_id) + .bind(&sub.event_name) + .bind(&sub.event_key) + .bind(&subscribe_as_of_str) + .bind(&subscription_data_str) + .bind(&sub.external_token) + .bind(&sub.external_worker_id) + .bind(&external_token_expiry_str) + .execute(&mut **tx) + .await + .map_err(to_persistence_err)?; + + Ok(()) +} + +#[async_trait] +impl SubscriptionRepository for SqlitePersistenceProvider { + async fn create_event_subscription( + &self, + subscription: &EventSubscription, + ) -> Result { + let id = if subscription.id.is_empty() { + uuid::Uuid::new_v4().to_string() + } else { + subscription.id.clone() + }; + + let mut stored = subscription.clone(); + stored.id = id.clone(); + + let mut tx = self.pool.begin().await.map_err(to_persistence_err)?; + insert_subscription(&mut tx, &stored).await?; + tx.commit().await.map_err(to_persistence_err)?; + Ok(id) + } + + async fn get_subscriptions( + &self, + event_name: &str, + event_key: &str, + as_of: DateTime, + ) -> Result> { + let as_of_str = dt_to_string(&as_of); + let rows = sqlx::query( + "SELECT * FROM event_subscriptions + WHERE event_name = ?1 AND event_key = ?2 AND subscribe_as_of <= ?3 AND terminated = 0", + ) + .bind(event_name) + .bind(event_key) + .bind(&as_of_str) + .fetch_all(&self.pool) + .await + .map_err(to_persistence_err)?; + + rows.iter().map(row_to_subscription).collect() + } + + async fn terminate_subscription(&self, subscription_id: &str) -> Result<()> { + let result = sqlx::query( + "UPDATE event_subscriptions SET terminated = 1 WHERE id = ?1", + ) + .bind(subscription_id) + .execute(&self.pool) + .await + .map_err(to_persistence_err)?; + + if result.rows_affected() == 0 { + return Err(WfeError::SubscriptionNotFound( + subscription_id.to_string(), + )); + } + Ok(()) + } + + async fn get_subscription(&self, subscription_id: &str) -> Result { + let row = sqlx::query("SELECT * FROM event_subscriptions WHERE id = ?1") + .bind(subscription_id) + .fetch_optional(&self.pool) + .await + .map_err(to_persistence_err)? + .ok_or_else(|| WfeError::SubscriptionNotFound(subscription_id.to_string()))?; + + row_to_subscription(&row) + } + + async fn get_first_open_subscription( + &self, + event_name: &str, + event_key: &str, + as_of: DateTime, + ) -> Result> { + let as_of_str = dt_to_string(&as_of); + let row = sqlx::query( + "SELECT * FROM event_subscriptions + WHERE event_name = ?1 AND event_key = ?2 AND subscribe_as_of <= ?3 + AND terminated = 0 AND external_token IS NULL + LIMIT 1", + ) + .bind(event_name) + .bind(event_key) + .bind(&as_of_str) + .fetch_optional(&self.pool) + .await + .map_err(to_persistence_err)?; + + match row { + Some(r) => Ok(Some(row_to_subscription(&r)?)), + None => Ok(None), + } + } + + async fn set_subscription_token( + &self, + subscription_id: &str, + token: &str, + worker_id: &str, + expiry: DateTime, + ) -> Result { + let expiry_str = dt_to_string(&expiry); + + // Only set token if external_token is currently NULL + let result = sqlx::query( + "UPDATE event_subscriptions + SET external_token = ?1, external_worker_id = ?2, external_token_expiry = ?3 + WHERE id = ?4 AND external_token IS NULL", + ) + .bind(token) + .bind(worker_id) + .bind(&expiry_str) + .bind(subscription_id) + .execute(&self.pool) + .await + .map_err(to_persistence_err)?; + + if result.rows_affected() == 0 { + // Check if the subscription exists at all + let exists = sqlx::query("SELECT 1 FROM event_subscriptions WHERE id = ?1") + .bind(subscription_id) + .fetch_optional(&self.pool) + .await + .map_err(to_persistence_err)?; + if exists.is_none() { + return Err(WfeError::SubscriptionNotFound( + subscription_id.to_string(), + )); + } + return Ok(false); + } + Ok(true) + } + + async fn clear_subscription_token( + &self, + subscription_id: &str, + token: &str, + ) -> Result<()> { + let result = sqlx::query( + "UPDATE event_subscriptions + SET external_token = NULL, external_worker_id = NULL, external_token_expiry = NULL + WHERE id = ?1 AND external_token = ?2", + ) + .bind(subscription_id) + .bind(token) + .execute(&self.pool) + .await + .map_err(to_persistence_err)?; + + if result.rows_affected() == 0 { + return Err(WfeError::SubscriptionNotFound( + subscription_id.to_string(), + )); + } + Ok(()) + } +} + +#[async_trait] +impl EventRepository for SqlitePersistenceProvider { + async fn create_event(&self, event: &Event) -> Result { + let id = if event.id.is_empty() { + uuid::Uuid::new_v4().to_string() + } else { + event.id.clone() + }; + + let event_data_str = serde_json::to_string(&event.event_data) + .map_err(|e| WfeError::Persistence(e.to_string()))?; + let event_time_str = dt_to_string(&event.event_time); + + sqlx::query( + "INSERT INTO events (id, event_name, event_key, event_data, event_time, is_processed) + VALUES (?1, ?2, ?3, ?4, ?5, ?6)", + ) + .bind(&id) + .bind(&event.event_name) + .bind(&event.event_key) + .bind(&event_data_str) + .bind(&event_time_str) + .bind(event.is_processed) + .execute(&self.pool) + .await + .map_err(to_persistence_err)?; + + Ok(id) + } + + async fn get_event(&self, id: &str) -> Result { + let row = sqlx::query("SELECT * FROM events WHERE id = ?1") + .bind(id) + .fetch_optional(&self.pool) + .await + .map_err(to_persistence_err)? + .ok_or_else(|| WfeError::EventNotFound(id.to_string()))?; + + row_to_event(&row) + } + + async fn get_runnable_events(&self, as_at: DateTime) -> Result> { + let as_at_str = dt_to_string(&as_at); + let rows = sqlx::query( + "SELECT id FROM events WHERE is_processed = 0 AND event_time <= ?1", + ) + .bind(&as_at_str) + .fetch_all(&self.pool) + .await + .map_err(to_persistence_err)?; + + rows.iter() + .map(|r| r.try_get("id").map_err(to_persistence_err)) + .collect() + } + + async fn get_events( + &self, + event_name: &str, + event_key: &str, + as_of: DateTime, + ) -> Result> { + let as_of_str = dt_to_string(&as_of); + let rows = sqlx::query( + "SELECT id FROM events WHERE event_name = ?1 AND event_key = ?2 AND event_time <= ?3", + ) + .bind(event_name) + .bind(event_key) + .bind(&as_of_str) + .fetch_all(&self.pool) + .await + .map_err(to_persistence_err)?; + + rows.iter() + .map(|r| r.try_get("id").map_err(to_persistence_err)) + .collect() + } + + async fn mark_event_processed(&self, id: &str) -> Result<()> { + let result = sqlx::query("UPDATE events SET is_processed = 1 WHERE id = ?1") + .bind(id) + .execute(&self.pool) + .await + .map_err(to_persistence_err)?; + + if result.rows_affected() == 0 { + return Err(WfeError::EventNotFound(id.to_string())); + } + Ok(()) + } + + async fn mark_event_unprocessed(&self, id: &str) -> Result<()> { + let result = sqlx::query("UPDATE events SET is_processed = 0 WHERE id = ?1") + .bind(id) + .execute(&self.pool) + .await + .map_err(to_persistence_err)?; + + if result.rows_affected() == 0 { + return Err(WfeError::EventNotFound(id.to_string())); + } + Ok(()) + } +} + +#[async_trait] +impl ScheduledCommandRepository for SqlitePersistenceProvider { + fn supports_scheduled_commands(&self) -> bool { + true + } + + async fn schedule_command(&self, command: &ScheduledCommand) -> Result<()> { + let command_name_str = serde_json::to_value(&command.command_name) + .map_err(|e| WfeError::Persistence(e.to_string()))? + .as_str() + .unwrap_or("") + .to_string(); + + sqlx::query( + "INSERT OR IGNORE INTO scheduled_commands (command_name, data, execute_time) + VALUES (?1, ?2, ?3)", + ) + .bind(&command_name_str) + .bind(&command.data) + .bind(command.execute_time) + .execute(&self.pool) + .await + .map_err(to_persistence_err)?; + + Ok(()) + } + + async fn process_commands( + &self, + as_of: DateTime, + handler: &(dyn Fn(ScheduledCommand) -> std::pin::Pin> + Send>> + + Send + + Sync), + ) -> Result<()> { + let as_of_millis = as_of.timestamp_millis(); + + // Fetch due commands + let rows = sqlx::query( + "SELECT id, command_name, data, execute_time FROM scheduled_commands WHERE execute_time <= ?1", + ) + .bind(as_of_millis) + .fetch_all(&self.pool) + .await + .map_err(to_persistence_err)?; + + let mut commands: Vec<(i64, ScheduledCommand)> = Vec::new(); + for row in &rows { + let db_id: i64 = row.try_get("id").map_err(to_persistence_err)?; + let command_name_str: String = + row.try_get("command_name").map_err(to_persistence_err)?; + let command_name: CommandName = + serde_json::from_str(&format!("\"{command_name_str}\"")).map_err(|e| { + WfeError::Persistence(format!("Failed to deserialize CommandName: {e}")) + })?; + let data: String = row.try_get("data").map_err(to_persistence_err)?; + let execute_time: i64 = row.try_get("execute_time").map_err(to_persistence_err)?; + + commands.push(( + db_id, + ScheduledCommand { + command_name, + data, + execute_time, + }, + )); + } + + // Process each command then delete it + for (db_id, cmd) in commands { + handler(cmd).await?; + sqlx::query("DELETE FROM scheduled_commands WHERE id = ?1") + .bind(db_id) + .execute(&self.pool) + .await + .map_err(to_persistence_err)?; + } + Ok(()) + } +} + +#[async_trait] +impl PersistenceProvider for SqlitePersistenceProvider { + async fn persist_errors(&self, errors: &[ExecutionError]) -> Result<()> { + for error in errors { + let error_time_str = dt_to_string(&error.error_time); + sqlx::query( + "INSERT INTO execution_errors (error_time, workflow_id, execution_pointer_id, message) + VALUES (?1, ?2, ?3, ?4)", + ) + .bind(&error_time_str) + .bind(&error.workflow_id) + .bind(&error.execution_pointer_id) + .bind(&error.message) + .execute(&self.pool) + .await + .map_err(to_persistence_err)?; + } + Ok(()) + } + + async fn ensure_store_exists(&self) -> Result<()> { + self.create_tables() + .await + .map_err(|e| WfeError::Persistence(e.to_string())) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn schema_creation_idempotent() { + let provider = SqlitePersistenceProvider::new(":memory:").await.unwrap(); + // Call ensure_store_exists again — should not fail + provider.ensure_store_exists().await.unwrap(); + provider.ensure_store_exists().await.unwrap(); + } +} diff --git a/wfe-sqlite/tests/persistence.rs b/wfe-sqlite/tests/persistence.rs new file mode 100644 index 0000000..f5f40ca --- /dev/null +++ b/wfe-sqlite/tests/persistence.rs @@ -0,0 +1,9 @@ +use wfe_core::persistence_suite; + +async fn make_provider() -> wfe_sqlite::SqlitePersistenceProvider { + wfe_sqlite::SqlitePersistenceProvider::new(":memory:") + .await + .unwrap() +} + +persistence_suite!(|| async { make_provider().await });