diff --git a/wfe-core/src/executor/workflow_executor.rs b/wfe-core/src/executor/workflow_executor.rs index 11378d2..a6a3b41 100644 --- a/wfe-core/src/executor/workflow_executor.rs +++ b/wfe-core/src/executor/workflow_executor.rs @@ -7,7 +7,7 @@ use super::error_handler; use super::result_processor; use super::step_registry::StepRegistry; use crate::models::{ - ExecutionError, PointerStatus, QueueType, WorkflowDefinition, WorkflowStatus, + Event, ExecutionError, PointerStatus, QueueType, WorkflowDefinition, WorkflowStatus, }; use crate::traits::{ DistributedLockProvider, LifecyclePublisher, PersistenceProvider, QueueProvider, SearchIndex, @@ -61,7 +61,7 @@ impl WorkflowExecutor { /// 8. Release lock #[tracing::instrument( name = "workflow.execute", - skip(self, definition, step_registry), + skip(self, definition, step_registry, host_context), fields( workflow.id = %workflow_id, workflow.definition_id, @@ -73,6 +73,7 @@ impl WorkflowExecutor { workflow_id: &str, definition: &WorkflowDefinition, step_registry: &StepRegistry, + host_context: Option<&dyn crate::traits::HostContext>, ) -> Result<()> { // 1. Acquire distributed lock. let acquired = self.lock_provider.acquire_lock(workflow_id).await?; @@ -82,7 +83,7 @@ impl WorkflowExecutor { } let result = self - .execute_inner(workflow_id, definition, step_registry) + .execute_inner(workflow_id, definition, step_registry, host_context) .await; // 7. Release lock (always). @@ -98,6 +99,7 @@ impl WorkflowExecutor { workflow_id: &str, definition: &WorkflowDefinition, step_registry: &StepRegistry, + host_context: Option<&dyn crate::traits::HostContext>, ) -> Result<()> { // 2. Load workflow instance. let mut workflow = self @@ -173,6 +175,7 @@ impl WorkflowExecutor { step, workflow: &workflow, cancellation_token, + host_context, }; // d. Call step.run(context). 
@@ -280,6 +283,18 @@ impl WorkflowExecutor { info!(workflow_id, "All pointers complete, workflow finished"); workflow.status = WorkflowStatus::Complete; workflow.complete_time = Some(Utc::now()); + + // Publish completion event for SubWorkflow parents. + let completion_event = Event::new( + "wfe.workflow.completed", + workflow_id, + serde_json::json!({ "status": "Complete", "data": workflow.data }), + ); + let _ = self.persistence.create_event(&completion_event).await; + let _ = self + .queue_provider + .queue_work(&completion_event.id, QueueType::Event) + .await; } tracing::Span::current().record("workflow.status", tracing::field::debug(&workflow.status)); @@ -573,7 +588,7 @@ mod tests { persistence.create_new_workflow(&instance).await.unwrap(); - executor.execute(&instance.id, &def, ®istry).await.unwrap(); + executor.execute(&instance.id, &def, ®istry, None).await.unwrap(); let updated = persistence.get_workflow_instance(&instance.id).await.unwrap(); assert_eq!(updated.status, WorkflowStatus::Complete); @@ -604,7 +619,7 @@ mod tests { persistence.create_new_workflow(&instance).await.unwrap(); // First execution: step 0 completes, step 1 pointer created. - executor.execute(&instance.id, &def, ®istry).await.unwrap(); + executor.execute(&instance.id, &def, ®istry, None).await.unwrap(); let updated = persistence.get_workflow_instance(&instance.id).await.unwrap(); assert_eq!(updated.execution_pointers.len(), 2); @@ -613,7 +628,7 @@ mod tests { assert_eq!(updated.execution_pointers[1].step_id, 1); // Second execution: step 1 completes. - executor.execute(&instance.id, &def, ®istry).await.unwrap(); + executor.execute(&instance.id, &def, ®istry, None).await.unwrap(); let updated = persistence.get_workflow_instance(&instance.id).await.unwrap(); assert_eq!(updated.status, WorkflowStatus::Complete); @@ -644,7 +659,7 @@ mod tests { // Execute three times for three steps. 
for _ in 0..3 { - executor.execute(&instance.id, &def, ®istry).await.unwrap(); + executor.execute(&instance.id, &def, ®istry, None).await.unwrap(); } let updated = persistence.get_workflow_instance(&instance.id).await.unwrap(); @@ -684,7 +699,7 @@ mod tests { instance.execution_pointers.push(ExecutionPointer::new(0)); persistence.create_new_workflow(&instance).await.unwrap(); - executor.execute(&instance.id, &def, ®istry).await.unwrap(); + executor.execute(&instance.id, &def, ®istry, None).await.unwrap(); let updated = persistence.get_workflow_instance(&instance.id).await.unwrap(); assert_eq!(updated.execution_pointers.len(), 2); @@ -707,7 +722,7 @@ mod tests { instance.execution_pointers.push(ExecutionPointer::new(0)); persistence.create_new_workflow(&instance).await.unwrap(); - executor.execute(&instance.id, &def, ®istry).await.unwrap(); + executor.execute(&instance.id, &def, ®istry, None).await.unwrap(); let updated = persistence.get_workflow_instance(&instance.id).await.unwrap(); assert_eq!(updated.status, WorkflowStatus::Runnable); @@ -733,7 +748,7 @@ mod tests { instance.execution_pointers.push(ExecutionPointer::new(0)); persistence.create_new_workflow(&instance).await.unwrap(); - executor.execute(&instance.id, &def, ®istry).await.unwrap(); + executor.execute(&instance.id, &def, ®istry, None).await.unwrap(); let updated = persistence.get_workflow_instance(&instance.id).await.unwrap(); assert_eq!(updated.execution_pointers[0].status, PointerStatus::Sleeping); @@ -756,7 +771,7 @@ mod tests { instance.execution_pointers.push(ExecutionPointer::new(0)); persistence.create_new_workflow(&instance).await.unwrap(); - executor.execute(&instance.id, &def, ®istry).await.unwrap(); + executor.execute(&instance.id, &def, ®istry, None).await.unwrap(); let updated = persistence.get_workflow_instance(&instance.id).await.unwrap(); assert_eq!( @@ -796,7 +811,7 @@ mod tests { instance.execution_pointers.push(pointer); persistence.create_new_workflow(&instance).await.unwrap(); - 
executor.execute(&instance.id, &def, ®istry).await.unwrap(); + executor.execute(&instance.id, &def, ®istry, None).await.unwrap(); let updated = persistence.get_workflow_instance(&instance.id).await.unwrap(); assert_eq!(updated.execution_pointers[0].status, PointerStatus::Complete); @@ -822,7 +837,7 @@ mod tests { instance.execution_pointers.push(ExecutionPointer::new(0)); persistence.create_new_workflow(&instance).await.unwrap(); - executor.execute(&instance.id, &def, ®istry).await.unwrap(); + executor.execute(&instance.id, &def, ®istry, None).await.unwrap(); let updated = persistence.get_workflow_instance(&instance.id).await.unwrap(); // 1 original + 3 children. @@ -858,7 +873,7 @@ mod tests { instance.execution_pointers.push(ExecutionPointer::new(0)); persistence.create_new_workflow(&instance).await.unwrap(); - executor.execute(&instance.id, &def, ®istry).await.unwrap(); + executor.execute(&instance.id, &def, ®istry, None).await.unwrap(); let updated = persistence.get_workflow_instance(&instance.id).await.unwrap(); assert_eq!(updated.execution_pointers[0].retry_count, 1); @@ -884,7 +899,7 @@ mod tests { instance.execution_pointers.push(ExecutionPointer::new(0)); persistence.create_new_workflow(&instance).await.unwrap(); - executor.execute(&instance.id, &def, ®istry).await.unwrap(); + executor.execute(&instance.id, &def, ®istry, None).await.unwrap(); let updated = persistence.get_workflow_instance(&instance.id).await.unwrap(); assert_eq!(updated.status, WorkflowStatus::Suspended); @@ -908,7 +923,7 @@ mod tests { instance.execution_pointers.push(ExecutionPointer::new(0)); persistence.create_new_workflow(&instance).await.unwrap(); - executor.execute(&instance.id, &def, ®istry).await.unwrap(); + executor.execute(&instance.id, &def, ®istry, None).await.unwrap(); let updated = persistence.get_workflow_instance(&instance.id).await.unwrap(); assert_eq!(updated.status, WorkflowStatus::Terminated); @@ -936,7 +951,7 @@ mod tests { 
instance.execution_pointers.push(ExecutionPointer::new(0)); persistence.create_new_workflow(&instance).await.unwrap(); - executor.execute(&instance.id, &def, ®istry).await.unwrap(); + executor.execute(&instance.id, &def, ®istry, None).await.unwrap(); let updated = persistence.get_workflow_instance(&instance.id).await.unwrap(); assert_eq!(updated.execution_pointers[0].status, PointerStatus::Failed); @@ -964,7 +979,7 @@ mod tests { instance.execution_pointers.push(ExecutionPointer::new(1)); persistence.create_new_workflow(&instance).await.unwrap(); - executor.execute(&instance.id, &def, ®istry).await.unwrap(); + executor.execute(&instance.id, &def, ®istry, None).await.unwrap(); let updated = persistence.get_workflow_instance(&instance.id).await.unwrap(); assert_eq!(updated.status, WorkflowStatus::Complete); @@ -999,7 +1014,7 @@ mod tests { persistence.create_new_workflow(&instance).await.unwrap(); // Should not error on a completed workflow. - executor.execute(&instance.id, &def, ®istry).await.unwrap(); + executor.execute(&instance.id, &def, ®istry, None).await.unwrap(); let updated = persistence.get_workflow_instance(&instance.id).await.unwrap(); assert_eq!(updated.status, WorkflowStatus::Complete); @@ -1024,7 +1039,7 @@ mod tests { instance.execution_pointers.push(pointer); persistence.create_new_workflow(&instance).await.unwrap(); - executor.execute(&instance.id, &def, ®istry).await.unwrap(); + executor.execute(&instance.id, &def, ®istry, None).await.unwrap(); let updated = persistence.get_workflow_instance(&instance.id).await.unwrap(); // Should still be sleeping since sleep_until is in the future. 
@@ -1048,7 +1063,7 @@ mod tests { instance.execution_pointers.push(ExecutionPointer::new(0)); persistence.create_new_workflow(&instance).await.unwrap(); - executor.execute(&instance.id, &def, ®istry).await.unwrap(); + executor.execute(&instance.id, &def, ®istry, None).await.unwrap(); let errors = persistence.get_errors().await; assert_eq!(errors.len(), 1); @@ -1072,7 +1087,7 @@ mod tests { instance.execution_pointers.push(ExecutionPointer::new(0)); persistence.create_new_workflow(&instance).await.unwrap(); - executor.execute(&instance.id, &def, ®istry).await.unwrap(); + executor.execute(&instance.id, &def, ®istry, None).await.unwrap(); // Executor itself doesn't publish lifecycle events in the current implementation, // but the with_lifecycle builder works correctly. @@ -1097,7 +1112,7 @@ mod tests { instance.execution_pointers.push(ExecutionPointer::new(0)); persistence.create_new_workflow(&instance).await.unwrap(); - executor.execute(&instance.id, &def, ®istry).await.unwrap(); + executor.execute(&instance.id, &def, ®istry, None).await.unwrap(); let updated = persistence.get_workflow_instance(&instance.id).await.unwrap(); assert_eq!(updated.status, WorkflowStatus::Terminated); @@ -1118,7 +1133,7 @@ mod tests { instance.execution_pointers.push(ExecutionPointer::new(0)); persistence.create_new_workflow(&instance).await.unwrap(); - executor.execute(&instance.id, &def, ®istry).await.unwrap(); + executor.execute(&instance.id, &def, ®istry, None).await.unwrap(); let updated = persistence.get_workflow_instance(&instance.id).await.unwrap(); assert!(updated.execution_pointers[0].start_time.is_some()); @@ -1148,7 +1163,7 @@ mod tests { instance.execution_pointers.push(ExecutionPointer::new(0)); persistence.create_new_workflow(&instance).await.unwrap(); - executor.execute(&instance.id, &def, ®istry).await.unwrap(); + executor.execute(&instance.id, &def, ®istry, None).await.unwrap(); let updated = persistence.get_workflow_instance(&instance.id).await.unwrap(); assert_eq!( @@ 
-1203,13 +1218,13 @@ mod tests { persistence.create_new_workflow(&instance).await.unwrap(); // First execution: fails, retry scheduled. - executor.execute(&instance.id, &def, ®istry).await.unwrap(); + executor.execute(&instance.id, &def, ®istry, None).await.unwrap(); let updated = persistence.get_workflow_instance(&instance.id).await.unwrap(); assert_eq!(updated.execution_pointers[0].retry_count, 1); assert_eq!(updated.execution_pointers[0].status, PointerStatus::Sleeping); // Second execution: succeeds (sleep_until is in the past with 0ms interval). - executor.execute(&instance.id, &def, ®istry).await.unwrap(); + executor.execute(&instance.id, &def, ®istry, None).await.unwrap(); let updated = persistence.get_workflow_instance(&instance.id).await.unwrap(); assert_eq!(updated.execution_pointers[0].status, PointerStatus::Complete); assert_eq!(updated.status, WorkflowStatus::Complete); @@ -1227,7 +1242,7 @@ mod tests { // No execution pointers at all. persistence.create_new_workflow(&instance).await.unwrap(); - executor.execute(&instance.id, &def, ®istry).await.unwrap(); + executor.execute(&instance.id, &def, ®istry, None).await.unwrap(); let updated = persistence.get_workflow_instance(&instance.id).await.unwrap(); assert_eq!(updated.status, WorkflowStatus::Runnable); diff --git a/wfe-core/src/models/mod.rs b/wfe-core/src/models/mod.rs index 169099f..070f61c 100644 --- a/wfe-core/src/models/mod.rs +++ b/wfe-core/src/models/mod.rs @@ -7,6 +7,7 @@ pub mod lifecycle; pub mod poll_config; pub mod queue_type; pub mod scheduled_command; +pub mod schema; pub mod status; pub mod workflow_definition; pub mod workflow_instance; @@ -20,6 +21,7 @@ pub use lifecycle::{LifecycleEvent, LifecycleEventType}; pub use poll_config::{HttpMethod, PollCondition, PollEndpointConfig}; pub use queue_type::QueueType; pub use scheduled_command::{CommandName, ScheduledCommand}; +pub use schema::{SchemaType, WorkflowSchema}; pub use status::{PointerStatus, WorkflowStatus}; pub use 
workflow_definition::{StepOutcome, WorkflowDefinition, WorkflowStep}; pub use workflow_instance::WorkflowInstance; diff --git a/wfe-core/src/models/schema.rs b/wfe-core/src/models/schema.rs new file mode 100644 index 0000000..c3e0010 --- /dev/null +++ b/wfe-core/src/models/schema.rs @@ -0,0 +1,483 @@ +use std::collections::HashMap; + +use serde::{Deserialize, Serialize}; + +/// Describes a single type in the workflow schema type system. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum SchemaType { + String, + Number, + Integer, + Bool, + Optional(Box), + List(Box), + Map(Box), + Any, +} + +/// Defines expected input and output schemas for a workflow. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct WorkflowSchema { + #[serde(default)] + pub inputs: HashMap, + #[serde(default)] + pub outputs: HashMap, +} + +/// Parse a type string into a [`SchemaType`]. +/// +/// Supported formats: +/// - `"string"`, `"number"`, `"integer"`, `"bool"`, `"any"` +/// - `"string?"` (optional) +/// - `"list"`, `"map"` (generic containers) +/// - Nested: `"list>"` +pub fn parse_type(s: &str) -> crate::Result { + let s = s.trim(); + + // Handle optional suffix. + if let Some(inner) = s.strip_suffix('?') { + let inner_type = parse_type(inner)?; + return Ok(SchemaType::Optional(Box::new(inner_type))); + } + + // Handle generic containers: list<...> and map<...>. + if let Some(rest) = s.strip_prefix("list<") { + let inner = rest + .strip_suffix('>') + .ok_or_else(|| crate::WfeError::StepExecution(format!("Invalid type syntax: {s}")))?; + let inner_type = parse_type(inner)?; + return Ok(SchemaType::List(Box::new(inner_type))); + } + if let Some(rest) = s.strip_prefix("map<") { + let inner = rest + .strip_suffix('>') + .ok_or_else(|| crate::WfeError::StepExecution(format!("Invalid type syntax: {s}")))?; + let inner_type = parse_type(inner)?; + return Ok(SchemaType::Map(Box::new(inner_type))); + } + + // Primitive types. 
+ match s { + "string" => Ok(SchemaType::String), + "number" => Ok(SchemaType::Number), + "integer" => Ok(SchemaType::Integer), + "bool" => Ok(SchemaType::Bool), + "any" => Ok(SchemaType::Any), + _ => Err(crate::WfeError::StepExecution(format!( + "Unknown type: {s}" + ))), + } +} + +/// Validate that a JSON value matches the expected [`SchemaType`]. +pub fn validate_value(value: &serde_json::Value, expected: &SchemaType) -> Result<(), String> { + match expected { + SchemaType::String => { + if value.is_string() { + Ok(()) + } else { + Err(format!("expected string, got {}", value_type_name(value))) + } + } + SchemaType::Number => { + if value.is_number() { + Ok(()) + } else { + Err(format!("expected number, got {}", value_type_name(value))) + } + } + SchemaType::Integer => { + if value.is_i64() || value.is_u64() { + Ok(()) + } else { + Err(format!("expected integer, got {}", value_type_name(value))) + } + } + SchemaType::Bool => { + if value.is_boolean() { + Ok(()) + } else { + Err(format!("expected bool, got {}", value_type_name(value))) + } + } + SchemaType::Optional(inner) => { + if value.is_null() { + Ok(()) + } else { + validate_value(value, inner) + } + } + SchemaType::List(inner) => { + if let Some(arr) = value.as_array() { + for (i, item) in arr.iter().enumerate() { + validate_value(item, inner) + .map_err(|e| format!("list element [{i}]: {e}"))?; + } + Ok(()) + } else { + Err(format!("expected list, got {}", value_type_name(value))) + } + } + SchemaType::Map(inner) => { + if let Some(obj) = value.as_object() { + for (key, val) in obj { + validate_value(val, inner) + .map_err(|e| format!("map key \"{key}\": {e}"))?; + } + Ok(()) + } else { + Err(format!("expected map, got {}", value_type_name(value))) + } + } + SchemaType::Any => Ok(()), + } +} + +fn value_type_name(value: &serde_json::Value) -> &'static str { + match value { + serde_json::Value::Null => "null", + serde_json::Value::Bool(_) => "bool", + serde_json::Value::Number(_) => "number", + 
serde_json::Value::String(_) => "string", + serde_json::Value::Array(_) => "array", + serde_json::Value::Object(_) => "object", + } +} + +impl WorkflowSchema { + /// Validate that the given data satisfies all input field requirements. + pub fn validate_inputs(&self, data: &serde_json::Value) -> Result<(), Vec> { + self.validate_fields(&self.inputs, data) + } + + /// Validate that the given data satisfies all output field requirements. + pub fn validate_outputs(&self, data: &serde_json::Value) -> Result<(), Vec> { + self.validate_fields(&self.outputs, data) + } + + fn validate_fields( + &self, + fields: &HashMap, + data: &serde_json::Value, + ) -> Result<(), Vec> { + let obj = match data.as_object() { + Some(o) => o, + None => { + return Err(vec!["expected an object".to_string()]); + } + }; + + let mut errors = Vec::new(); + + for (name, schema_type) in fields { + match obj.get(name) { + Some(value) => { + if let Err(e) = validate_value(value, schema_type) { + errors.push(format!("field \"{name}\": {e}")); + } + } + None => { + // Missing field is OK for optional types (null is acceptable). 
+ if !matches!(schema_type, SchemaType::Optional(_)) { + errors.push(format!("missing required field: \"{name}\"")); + } + } + } + } + + if errors.is_empty() { + Ok(()) + } else { + Err(errors) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + // -- parse_type tests -- + + #[test] + fn parse_type_string() { + assert_eq!(parse_type("string").unwrap(), SchemaType::String); + } + + #[test] + fn parse_type_number() { + assert_eq!(parse_type("number").unwrap(), SchemaType::Number); + } + + #[test] + fn parse_type_integer() { + assert_eq!(parse_type("integer").unwrap(), SchemaType::Integer); + } + + #[test] + fn parse_type_bool() { + assert_eq!(parse_type("bool").unwrap(), SchemaType::Bool); + } + + #[test] + fn parse_type_any() { + assert_eq!(parse_type("any").unwrap(), SchemaType::Any); + } + + #[test] + fn parse_type_optional_string() { + assert_eq!( + parse_type("string?").unwrap(), + SchemaType::Optional(Box::new(SchemaType::String)) + ); + } + + #[test] + fn parse_type_optional_number() { + assert_eq!( + parse_type("number?").unwrap(), + SchemaType::Optional(Box::new(SchemaType::Number)) + ); + } + + #[test] + fn parse_type_list_string() { + assert_eq!( + parse_type("list").unwrap(), + SchemaType::List(Box::new(SchemaType::String)) + ); + } + + #[test] + fn parse_type_list_number() { + assert_eq!( + parse_type("list").unwrap(), + SchemaType::List(Box::new(SchemaType::Number)) + ); + } + + #[test] + fn parse_type_map_string() { + assert_eq!( + parse_type("map").unwrap(), + SchemaType::Map(Box::new(SchemaType::String)) + ); + } + + #[test] + fn parse_type_map_number() { + assert_eq!( + parse_type("map").unwrap(), + SchemaType::Map(Box::new(SchemaType::Number)) + ); + } + + #[test] + fn parse_type_nested_list() { + assert_eq!( + parse_type("list>").unwrap(), + SchemaType::List(Box::new(SchemaType::List(Box::new(SchemaType::String)))) + ); + } + + #[test] + fn parse_type_unknown_errors() { + assert!(parse_type("foobar").is_err()); + } 
+ + #[test] + fn parse_type_trims_whitespace() { + assert_eq!(parse_type(" string ").unwrap(), SchemaType::String); + } + + // -- validate_value tests -- + + #[test] + fn validate_string_match() { + assert!(validate_value(&json!("hello"), &SchemaType::String).is_ok()); + } + + #[test] + fn validate_string_mismatch() { + assert!(validate_value(&json!(42), &SchemaType::String).is_err()); + } + + #[test] + fn validate_number_match() { + assert!(validate_value(&json!(2.78), &SchemaType::Number).is_ok()); + } + + #[test] + fn validate_number_mismatch() { + assert!(validate_value(&json!("not a number"), &SchemaType::Number).is_err()); + } + + #[test] + fn validate_integer_match() { + assert!(validate_value(&json!(42), &SchemaType::Integer).is_ok()); + } + + #[test] + fn validate_integer_mismatch_float() { + assert!(validate_value(&json!(2.78), &SchemaType::Integer).is_err()); + } + + #[test] + fn validate_bool_match() { + assert!(validate_value(&json!(true), &SchemaType::Bool).is_ok()); + } + + #[test] + fn validate_bool_mismatch() { + assert!(validate_value(&json!(1), &SchemaType::Bool).is_err()); + } + + #[test] + fn validate_optional_null_passes() { + let ty = SchemaType::Optional(Box::new(SchemaType::String)); + assert!(validate_value(&json!(null), &ty).is_ok()); + } + + #[test] + fn validate_optional_correct_inner_passes() { + let ty = SchemaType::Optional(Box::new(SchemaType::String)); + assert!(validate_value(&json!("hello"), &ty).is_ok()); + } + + #[test] + fn validate_optional_wrong_inner_fails() { + let ty = SchemaType::Optional(Box::new(SchemaType::String)); + assert!(validate_value(&json!(42), &ty).is_err()); + } + + #[test] + fn validate_list_match() { + let ty = SchemaType::List(Box::new(SchemaType::Number)); + assert!(validate_value(&json!([1, 2, 3]), &ty).is_ok()); + } + + #[test] + fn validate_list_mismatch_element() { + let ty = SchemaType::List(Box::new(SchemaType::Number)); + assert!(validate_value(&json!([1, "two", 3]), &ty).is_err()); + } + + 
#[test] + fn validate_list_not_array() { + let ty = SchemaType::List(Box::new(SchemaType::Number)); + assert!(validate_value(&json!("not a list"), &ty).is_err()); + } + + #[test] + fn validate_map_match() { + let ty = SchemaType::Map(Box::new(SchemaType::Number)); + assert!(validate_value(&json!({"a": 1, "b": 2}), &ty).is_ok()); + } + + #[test] + fn validate_map_mismatch_value() { + let ty = SchemaType::Map(Box::new(SchemaType::Number)); + assert!(validate_value(&json!({"a": 1, "b": "two"}), &ty).is_err()); + } + + #[test] + fn validate_map_not_object() { + let ty = SchemaType::Map(Box::new(SchemaType::Number)); + assert!(validate_value(&json!([1, 2]), &ty).is_err()); + } + + #[test] + fn validate_any_always_passes() { + assert!(validate_value(&json!(null), &SchemaType::Any).is_ok()); + assert!(validate_value(&json!("str"), &SchemaType::Any).is_ok()); + assert!(validate_value(&json!(42), &SchemaType::Any).is_ok()); + assert!(validate_value(&json!([1, 2]), &SchemaType::Any).is_ok()); + } + + // -- WorkflowSchema validate_inputs / validate_outputs tests -- + + #[test] + fn validate_inputs_all_present() { + let schema = WorkflowSchema { + inputs: HashMap::from([ + ("name".into(), SchemaType::String), + ("age".into(), SchemaType::Integer), + ]), + outputs: HashMap::new(), + }; + let data = json!({"name": "Alice", "age": 30}); + assert!(schema.validate_inputs(&data).is_ok()); + } + + #[test] + fn validate_inputs_missing_required_field() { + let schema = WorkflowSchema { + inputs: HashMap::from([ + ("name".into(), SchemaType::String), + ("age".into(), SchemaType::Integer), + ]), + outputs: HashMap::new(), + }; + let data = json!({"name": "Alice"}); + let errs = schema.validate_inputs(&data).unwrap_err(); + assert!(errs.iter().any(|e| e.contains("age"))); + } + + #[test] + fn validate_inputs_wrong_type() { + let schema = WorkflowSchema { + inputs: HashMap::from([("count".into(), SchemaType::Integer)]), + outputs: HashMap::new(), + }; + let data = json!({"count": 
"not-a-number"}); + let errs = schema.validate_inputs(&data).unwrap_err(); + assert!(!errs.is_empty()); + } + + #[test] + fn validate_outputs_missing_field() { + let schema = WorkflowSchema { + inputs: HashMap::new(), + outputs: HashMap::from([("result".into(), SchemaType::String)]), + }; + let data = json!({}); + let errs = schema.validate_outputs(&data).unwrap_err(); + assert!(errs.iter().any(|e| e.contains("result"))); + } + + #[test] + fn validate_inputs_optional_field_missing_is_ok() { + let schema = WorkflowSchema { + inputs: HashMap::from([( + "nickname".into(), + SchemaType::Optional(Box::new(SchemaType::String)), + )]), + outputs: HashMap::new(), + }; + let data = json!({}); + assert!(schema.validate_inputs(&data).is_ok()); + } + + #[test] + fn validate_not_object_errors() { + let schema = WorkflowSchema { + inputs: HashMap::from([("x".into(), SchemaType::String)]), + outputs: HashMap::new(), + }; + let errs = schema.validate_inputs(&json!("not an object")).unwrap_err(); + assert!(errs[0].contains("expected an object")); + } + + #[test] + fn schema_serde_round_trip() { + let schema = WorkflowSchema { + inputs: HashMap::from([("name".into(), SchemaType::String)]), + outputs: HashMap::from([("result".into(), SchemaType::Bool)]), + }; + let json_str = serde_json::to_string(&schema).unwrap(); + let deserialized: WorkflowSchema = serde_json::from_str(&json_str).unwrap(); + assert_eq!(deserialized.inputs["name"], SchemaType::String); + assert_eq!(deserialized.outputs["result"], SchemaType::Bool); + } +} diff --git a/wfe-core/src/primitives/mod.rs b/wfe-core/src/primitives/mod.rs index 6e6e9ea..32262e4 100644 --- a/wfe-core/src/primitives/mod.rs +++ b/wfe-core/src/primitives/mod.rs @@ -8,6 +8,7 @@ pub mod recur; pub mod saga_container; pub mod schedule; pub mod sequence; +pub mod sub_workflow; pub mod wait_for; pub mod while_step; @@ -21,6 +22,7 @@ pub use recur::RecurStep; pub use saga_container::SagaContainerStep; pub use schedule::ScheduleStep; pub use 
sequence::SequenceStep; +pub use sub_workflow::SubWorkflowStep; pub use wait_for::WaitForStep; pub use while_step::WhileStep; @@ -42,6 +44,7 @@ mod test_helpers { step, workflow, cancellation_token: CancellationToken::new(), + host_context: None, } } diff --git a/wfe-core/src/primitives/sub_workflow.rs b/wfe-core/src/primitives/sub_workflow.rs new file mode 100644 index 0000000..e5bdfac --- /dev/null +++ b/wfe-core/src/primitives/sub_workflow.rs @@ -0,0 +1,437 @@ +use async_trait::async_trait; +use chrono::Utc; + +use crate::models::schema::WorkflowSchema; +use crate::models::ExecutionResult; +use crate::traits::step::{StepBody, StepExecutionContext}; + +/// A step that starts a child workflow and waits for its completion. +/// +/// On first invocation, it validates inputs against `input_schema`, starts the +/// child workflow via the host context, and returns a "wait for event" result. +/// +/// When the child workflow completes, the event data arrives, output keys are +/// extracted, and the step proceeds. +#[derive(Default)] +pub struct SubWorkflowStep { + /// The definition ID of the child workflow to start. + pub workflow_id: String, + /// The version of the child workflow definition. + pub version: u32, + /// Input data to pass to the child workflow. + pub inputs: serde_json::Value, + /// Keys to extract from the child workflow's completion event data. + pub output_keys: Vec, + /// Optional schema to validate inputs before starting the child. + pub input_schema: Option, + /// Optional schema to validate outputs from the child. + pub output_schema: Option, +} + +#[async_trait] +impl StepBody for SubWorkflowStep { + async fn run(&mut self, context: &StepExecutionContext<'_>) -> crate::Result { + // If event data has arrived, the child workflow completed. + if let Some(event_data) = &context.execution_pointer.event_data { + // Extract output_keys from event data. 
+ let mut output = serde_json::Map::new(); + + // The event data contains { "status": "...", "data": { ... } }. + let child_data = event_data + .get("data") + .cloned() + .unwrap_or(serde_json::Value::Null); + + if self.output_keys.is_empty() { + // If no specific keys requested, pass all child data through. + if let serde_json::Value::Object(map) = child_data { + output = map; + } + } else { + // Extract only the requested keys. + for key in &self.output_keys { + if let Some(val) = child_data.get(key) { + output.insert(key.clone(), val.clone()); + } + } + } + + let output_value = serde_json::Value::Object(output); + + // Validate against output schema if present. + if let Some(ref schema) = self.output_schema + && let Err(errors) = schema.validate_outputs(&output_value) + { + return Err(crate::WfeError::StepExecution(format!( + "SubWorkflow output validation failed: {}", + errors.join("; ") + ))); + } + + let mut result = ExecutionResult::next(); + result.output_data = Some(output_value); + return Ok(result); + } + + // Hydrate from step_config if our fields are empty (created via Default). + if self.workflow_id.is_empty() + && let Some(config) = &context.step.step_config + { + if let Some(wf_id) = config.get("workflow_id").and_then(|v| v.as_str()) { + self.workflow_id = wf_id.to_string(); + } + if let Some(ver) = config.get("version").and_then(|v| v.as_u64()) { + self.version = ver as u32; + } + if let Some(inputs) = config.get("inputs") { + self.inputs = inputs.clone(); + } + if let Some(keys) = config.get("output_keys").and_then(|v| v.as_array()) { + self.output_keys = keys + .iter() + .filter_map(|v| v.as_str().map(|s| s.to_string())) + .collect(); + } + } + + // First call: validate inputs and start child workflow. 
+ if let Some(ref schema) = self.input_schema + && let Err(errors) = schema.validate_inputs(&self.inputs) + { + return Err(crate::WfeError::StepExecution(format!( + "SubWorkflow input validation failed: {}", + errors.join("; ") + ))); + } + + let host = context.host_context.ok_or_else(|| { + crate::WfeError::StepExecution( + "SubWorkflowStep requires a host context to start child workflows".to_string(), + ) + })?; + + let child_instance_id = host + .start_workflow(&self.workflow_id, self.version, self.inputs.clone()) + .await?; + + Ok(ExecutionResult::wait_for_event( + "wfe.workflow.completed", + child_instance_id, + Utc::now(), + )) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::models::schema::SchemaType; + use crate::models::ExecutionPointer; + use crate::primitives::test_helpers::*; + use crate::traits::step::HostContext; + use serde_json::json; + use std::collections::HashMap; + use std::sync::Mutex; + + /// A mock HostContext that records calls and returns a fixed instance ID. + struct MockHostContext { + started: Mutex>, + result_id: String, + } + + impl MockHostContext { + fn new(result_id: &str) -> Self { + Self { + started: Mutex::new(Vec::new()), + result_id: result_id.to_string(), + } + } + + fn calls(&self) -> Vec<(String, u32, serde_json::Value)> { + self.started.lock().unwrap().clone() + } + } + + impl HostContext for MockHostContext { + fn start_workflow( + &self, + definition_id: &str, + version: u32, + data: serde_json::Value, + ) -> std::pin::Pin> + Send + '_>> + { + let def_id = definition_id.to_string(); + let result_id = self.result_id.clone(); + Box::pin(async move { + self.started + .lock() + .unwrap() + .push((def_id, version, data)); + Ok(result_id) + }) + } + } + + /// A mock HostContext that returns an error. 
+ struct FailingHostContext; + + impl HostContext for FailingHostContext { + fn start_workflow( + &self, + _definition_id: &str, + _version: u32, + _data: serde_json::Value, + ) -> std::pin::Pin> + Send + '_>> + { + Box::pin(async { + Err(crate::WfeError::StepExecution( + "failed to start child".to_string(), + )) + }) + } + } + + fn make_context_with_host<'a>( + pointer: &'a ExecutionPointer, + step: &'a crate::models::WorkflowStep, + workflow: &'a crate::models::WorkflowInstance, + host: &'a dyn HostContext, + ) -> StepExecutionContext<'a> { + StepExecutionContext { + item: None, + execution_pointer: pointer, + persistence_data: pointer.persistence_data.as_ref(), + step, + workflow, + cancellation_token: tokio_util::sync::CancellationToken::new(), + host_context: Some(host), + } + } + + #[tokio::test] + async fn first_call_starts_child_and_waits() { + let host = MockHostContext::new("child-123"); + let mut step = SubWorkflowStep { + workflow_id: "child-def".into(), + version: 1, + inputs: json!({"x": 10}), + ..Default::default() + }; + + let pointer = ExecutionPointer::new(0); + let wf_step = default_step(); + let workflow = default_workflow(); + let ctx = make_context_with_host(&pointer, &wf_step, &workflow, &host); + + let result = step.run(&ctx).await.unwrap(); + assert!(!result.proceed); + assert_eq!(result.event_name.as_deref(), Some("wfe.workflow.completed")); + assert_eq!(result.event_key.as_deref(), Some("child-123")); + assert!(result.event_as_of.is_some()); + + let calls = host.calls(); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].0, "child-def"); + assert_eq!(calls[0].1, 1); + assert_eq!(calls[0].2, json!({"x": 10})); + } + + #[tokio::test] + async fn child_completed_proceeds_with_output() { + let mut step = SubWorkflowStep { + workflow_id: "child-def".into(), + version: 1, + inputs: json!({}), + output_keys: vec!["result".into()], + ..Default::default() + }; + + let mut pointer = ExecutionPointer::new(0); + pointer.event_data = Some(json!({ + 
"status": "Complete", + "data": {"result": "success", "extra": "ignored"} + })); + let wf_step = default_step(); + let workflow = default_workflow(); + let ctx = make_context(&pointer, &wf_step, &workflow); + + let result = step.run(&ctx).await.unwrap(); + assert!(result.proceed); + assert_eq!( + result.output_data, + Some(json!({"result": "success"})) + ); + } + + #[tokio::test] + async fn child_completed_no_output_keys_passes_all() { + let mut step = SubWorkflowStep { + workflow_id: "child-def".into(), + version: 1, + inputs: json!({}), + output_keys: vec![], + ..Default::default() + }; + + let mut pointer = ExecutionPointer::new(0); + pointer.event_data = Some(json!({ + "status": "Complete", + "data": {"a": 1, "b": 2} + })); + let wf_step = default_step(); + let workflow = default_workflow(); + let ctx = make_context(&pointer, &wf_step, &workflow); + + let result = step.run(&ctx).await.unwrap(); + assert!(result.proceed); + assert_eq!( + result.output_data, + Some(json!({"a": 1, "b": 2})) + ); + } + + #[tokio::test] + async fn no_host_context_errors() { + let mut step = SubWorkflowStep { + workflow_id: "child-def".into(), + version: 1, + inputs: json!({}), + ..Default::default() + }; + + let pointer = ExecutionPointer::new(0); + let wf_step = default_step(); + let workflow = default_workflow(); + let ctx = make_context(&pointer, &wf_step, &workflow); + + let err = step.run(&ctx).await.unwrap_err(); + assert!(err.to_string().contains("host context")); + } + + #[tokio::test] + async fn input_validation_failure() { + let host = MockHostContext::new("child-123"); + let mut step = SubWorkflowStep { + workflow_id: "child-def".into(), + version: 1, + inputs: json!({"name": 42}), // wrong type + input_schema: Some(WorkflowSchema { + inputs: HashMap::from([("name".into(), SchemaType::String)]), + outputs: HashMap::new(), + }), + ..Default::default() + }; + + let pointer = ExecutionPointer::new(0); + let wf_step = default_step(); + let workflow = default_workflow(); + let 
ctx = make_context_with_host(&pointer, &wf_step, &workflow, &host); + + let err = step.run(&ctx).await.unwrap_err(); + assert!(err.to_string().contains("input validation failed")); + assert!(host.calls().is_empty()); + } + + #[tokio::test] + async fn output_validation_failure() { + let mut step = SubWorkflowStep { + workflow_id: "child-def".into(), + version: 1, + inputs: json!({}), + output_keys: vec![], + output_schema: Some(WorkflowSchema { + inputs: HashMap::new(), + outputs: HashMap::from([("result".into(), SchemaType::String)]), + }), + ..Default::default() + }; + + let mut pointer = ExecutionPointer::new(0); + pointer.event_data = Some(json!({ + "status": "Complete", + "data": {"result": 42} + })); + let wf_step = default_step(); + let workflow = default_workflow(); + let ctx = make_context(&pointer, &wf_step, &workflow); + + let err = step.run(&ctx).await.unwrap_err(); + assert!(err.to_string().contains("output validation failed")); + } + + #[tokio::test] + async fn input_validation_passes_then_starts_child() { + let host = MockHostContext::new("child-456"); + let mut step = SubWorkflowStep { + workflow_id: "child-def".into(), + version: 2, + inputs: json!({"name": "Alice"}), + input_schema: Some(WorkflowSchema { + inputs: HashMap::from([("name".into(), SchemaType::String)]), + outputs: HashMap::new(), + }), + ..Default::default() + }; + + let pointer = ExecutionPointer::new(0); + let wf_step = default_step(); + let workflow = default_workflow(); + let ctx = make_context_with_host(&pointer, &wf_step, &workflow, &host); + + let result = step.run(&ctx).await.unwrap(); + assert!(!result.proceed); + assert_eq!(result.event_key.as_deref(), Some("child-456")); + assert_eq!(host.calls().len(), 1); + } + + #[tokio::test] + async fn host_start_workflow_error_propagates() { + let host = FailingHostContext; + let mut step = SubWorkflowStep { + workflow_id: "child-def".into(), + version: 1, + inputs: json!({}), + ..Default::default() + }; + + let pointer = 
ExecutionPointer::new(0); + let wf_step = default_step(); + let workflow = default_workflow(); + let ctx = make_context_with_host(&pointer, &wf_step, &workflow, &host); + + let err = step.run(&ctx).await.unwrap_err(); + assert!(err.to_string().contains("failed to start child")); + } + + #[tokio::test] + async fn event_data_without_data_field_returns_empty_output() { + let mut step = SubWorkflowStep { + workflow_id: "child-def".into(), + version: 1, + inputs: json!({}), + output_keys: vec!["foo".into()], + ..Default::default() + }; + + let mut pointer = ExecutionPointer::new(0); + pointer.event_data = Some(json!({"status": "Complete"})); + let wf_step = default_step(); + let workflow = default_workflow(); + let ctx = make_context(&pointer, &wf_step, &workflow); + + let result = step.run(&ctx).await.unwrap(); + assert!(result.proceed); + assert_eq!(result.output_data, Some(json!({}))); + } + + #[tokio::test] + async fn default_step_has_empty_fields() { + let step = SubWorkflowStep::default(); + assert!(step.workflow_id.is_empty()); + assert_eq!(step.version, 0); + assert_eq!(step.inputs, json!(null)); + assert!(step.output_keys.is_empty()); + assert!(step.input_schema.is_none()); + assert!(step.output_schema.is_none()); + } +} diff --git a/wfe-core/src/traits/middleware.rs b/wfe-core/src/traits/middleware.rs index 7bb4152..dd6b9a7 100644 --- a/wfe-core/src/traits/middleware.rs +++ b/wfe-core/src/traits/middleware.rs @@ -68,6 +68,7 @@ mod tests { step: &step, workflow: &instance, cancellation_token: tokio_util::sync::CancellationToken::new(), + host_context: None, }; mw.pre_step(&ctx).await.unwrap(); } @@ -86,6 +87,7 @@ mod tests { step: &step, workflow: &instance, cancellation_token: tokio_util::sync::CancellationToken::new(), + host_context: None, }; let result = ExecutionResult::next(); mw.post_step(&ctx, &result).await.unwrap(); diff --git a/wfe-core/src/traits/mod.rs b/wfe-core/src/traits/mod.rs index 85151f2..1a5d79f 100644 --- a/wfe-core/src/traits/mod.rs +++ 
b/wfe-core/src/traits/mod.rs
@@ -17,4 +17,4 @@ pub use persistence::{
 pub use queue::QueueProvider;
 pub use registry::WorkflowRegistry;
 pub use search::{Page, SearchFilter, SearchIndex, WorkflowSearchResult};
-pub use step::{StepBody, StepExecutionContext, WorkflowData};
+pub use step::{HostContext, StepBody, StepExecutionContext, WorkflowData};
diff --git a/wfe-core/src/traits/step.rs b/wfe-core/src/traits/step.rs
index 47b4b14..d6967e3 100644
--- a/wfe-core/src/traits/step.rs
+++ b/wfe-core/src/traits/step.rs
@@ -11,8 +11,18 @@ pub trait WorkflowData: Serialize + DeserializeOwned + Send + Sync + Clone + 'st
 /// Blanket implementation: any type satisfying the bounds is WorkflowData.
 impl<T> WorkflowData for T where T: Serialize + DeserializeOwned + Send + Sync + Clone + 'static {}
 
+/// Context for steps that need to interact with the workflow host.
+/// Implemented by WorkflowHost to allow steps like SubWorkflow to start child workflows.
+pub trait HostContext: Send + Sync {
+    fn start_workflow(
+        &self,
+        definition_id: &str,
+        version: u32,
+        data: serde_json::Value,
+    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = crate::Result<String>> + Send + '_>>;
+}
+
 /// Context available to a step during execution.
-#[derive(Debug)]
 pub struct StepExecutionContext<'a> {
     /// The current item when iterating (ForEach).
     pub item: Option<&'a serde_json::Value>,
@@ -26,6 +36,22 @@ pub struct StepExecutionContext<'a> {
     pub workflow: &'a WorkflowInstance,
     /// Cancellation token.
     pub cancellation_token: tokio_util::sync::CancellationToken,
+    /// Host context for starting child workflows. None if not available.
+    pub host_context: Option<&'a dyn HostContext>,
+}
+
+// Manual Debug impl since dyn HostContext is not Debug.
+impl<'a> std::fmt::Debug for StepExecutionContext<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("StepExecutionContext") + .field("item", &self.item) + .field("execution_pointer", &self.execution_pointer) + .field("persistence_data", &self.persistence_data) + .field("step", &self.step) + .field("workflow", &self.workflow) + .field("host_context", &self.host_context.is_some()) + .finish() + } } /// The core unit of work in a workflow. Each step implements this trait.