refactor(orchestrator): unified engine + clean public API
- Single run_tool_loop() replaces chat/code-specific variants
- generate() for ConversationRegistry path
- generate_from_response() for caller-managed conversations
- Engine uses ToolContext via execute_with_context(), no ResponseContext
- No imports from grpc, sync, matrix, context, or agent_ux modules
This commit is contained in:
@@ -1,206 +1,50 @@
|
|||||||
//! The unified response generation engine.
|
//! The unified response generation engine.
|
||||||
//!
|
//!
|
||||||
//! Single implementation of the Mistral Conversations API tool loop.
|
//! Single implementation of the Mistral Conversations API tool loop.
|
||||||
//! Emits `OrchestratorEvent`s instead of calling Matrix/gRPC directly.
|
//! Emits `OrchestratorEvent`s — no transport knowledge.
|
||||||
//! Phase 2: replaces `responder.generate_response_conversations()`.
|
|
||||||
|
|
||||||
use std::sync::Arc;
|
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
use mistralai_client::v1::conversations::{
|
use mistralai_client::v1::conversations::{
|
||||||
ConversationEntry, ConversationInput, ConversationResponse, FunctionResultEntry,
|
ConversationEntry, ConversationInput, ConversationResponse,
|
||||||
|
AppendConversationRequest, FunctionResultEntry,
|
||||||
};
|
};
|
||||||
use rand::Rng;
|
|
||||||
use tokio::sync::broadcast;
|
|
||||||
use tracing::{debug, error, info, warn};
|
use tracing::{debug, error, info, warn};
|
||||||
|
|
||||||
use super::event::*;
|
use super::event::*;
|
||||||
use super::tool_dispatch;
|
use super::tool_dispatch;
|
||||||
use super::Orchestrator;
|
use super::Orchestrator;
|
||||||
use crate::brain::personality::Personality;
|
|
||||||
use crate::context::ResponseContext;
|
|
||||||
use crate::conversations::ConversationRegistry;
|
|
||||||
use crate::time_context::TimeContext;
|
|
||||||
|
|
||||||
/// Strip "Sol: " or "sol: " prefix that models sometimes prepend.
|
/// Run the Mistral tool iteration loop on a conversation response.
|
||||||
fn strip_sol_prefix(text: &str) -> String {
|
/// Emits events for every state transition. Returns the final text + usage.
|
||||||
let trimmed = text.trim();
|
pub async fn run_tool_loop(
|
||||||
if trimmed.starts_with("Sol: ") || trimmed.starts_with("sol: ") {
|
|
||||||
trimmed[5..].to_string()
|
|
||||||
} else if trimmed.starts_with("Sol:\n") || trimmed.starts_with("sol:\n") {
|
|
||||||
trimmed[4..].to_string()
|
|
||||||
} else {
|
|
||||||
trimmed.to_string()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generate a chat response through the Conversations API.
|
|
||||||
/// This is the unified path that replaces both the responder's conversations
|
|
||||||
/// method and the gRPC session's inline tool loop.
|
|
||||||
pub async fn generate_response(
|
|
||||||
orchestrator: &Orchestrator,
|
orchestrator: &Orchestrator,
|
||||||
personality: &Personality,
|
request: &GenerateRequest,
|
||||||
request: &ChatRequest,
|
|
||||||
response_ctx: &ResponseContext,
|
|
||||||
conversation_registry: &ConversationRegistry,
|
|
||||||
) -> Option<String> {
|
|
||||||
let request_id = &request.request_id;
|
|
||||||
|
|
||||||
// Emit start
|
|
||||||
orchestrator.emit(OrchestratorEvent::ResponseStarted {
|
|
||||||
request_id: request_id.clone(),
|
|
||||||
mode: ResponseMode::Chat {
|
|
||||||
room_id: request.room_id.clone(),
|
|
||||||
is_spontaneous: request.is_spontaneous,
|
|
||||||
use_thread: request.use_thread,
|
|
||||||
trigger_event_id: request.trigger_event_id.clone(),
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
// Apply response delay
|
|
||||||
if !orchestrator.config.behavior.instant_responses {
|
|
||||||
let delay = if request.is_spontaneous {
|
|
||||||
rand::thread_rng().gen_range(
|
|
||||||
orchestrator.config.behavior.spontaneous_delay_min_ms
|
|
||||||
..=orchestrator.config.behavior.spontaneous_delay_max_ms,
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
rand::thread_rng().gen_range(
|
|
||||||
orchestrator.config.behavior.response_delay_min_ms
|
|
||||||
..=orchestrator.config.behavior.response_delay_max_ms,
|
|
||||||
)
|
|
||||||
};
|
|
||||||
tokio::time::sleep(Duration::from_millis(delay)).await;
|
|
||||||
}
|
|
||||||
|
|
||||||
orchestrator.emit(OrchestratorEvent::Thinking {
|
|
||||||
request_id: request_id.clone(),
|
|
||||||
});
|
|
||||||
|
|
||||||
// Memory query
|
|
||||||
let memory_notes = load_memory_notes(orchestrator, response_ctx, &request.trigger_body).await;
|
|
||||||
|
|
||||||
// Build context header
|
|
||||||
let tc = TimeContext::now();
|
|
||||||
let mut context_header = format!(
|
|
||||||
"{}\n[room: {} ({})]",
|
|
||||||
tc.message_line(),
|
|
||||||
request.room_name,
|
|
||||||
request.room_id,
|
|
||||||
);
|
|
||||||
|
|
||||||
if let Some(ref notes) = memory_notes {
|
|
||||||
context_header.push('\n');
|
|
||||||
context_header.push_str(notes);
|
|
||||||
}
|
|
||||||
|
|
||||||
let user_msg = if request.is_dm {
|
|
||||||
request.trigger_body.clone()
|
|
||||||
} else {
|
|
||||||
format!("<{}> {}", response_ctx.matrix_user_id, request.trigger_body)
|
|
||||||
};
|
|
||||||
|
|
||||||
let input_text = format!("{context_header}\n{user_msg}");
|
|
||||||
let input = ConversationInput::Text(input_text);
|
|
||||||
|
|
||||||
// Send through conversation registry
|
|
||||||
let response = match conversation_registry
|
|
||||||
.send_message(
|
|
||||||
&request.room_id,
|
|
||||||
input,
|
|
||||||
request.is_dm,
|
|
||||||
&orchestrator.mistral,
|
|
||||||
request.context_hint.as_deref(),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(r) => r,
|
|
||||||
Err(e) => {
|
|
||||||
error!("Conversation API failed: {e}");
|
|
||||||
orchestrator.emit(OrchestratorEvent::ResponseFailed {
|
|
||||||
request_id: request_id.clone(),
|
|
||||||
error: e.clone(),
|
|
||||||
});
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Tool loop
|
|
||||||
let result = run_tool_loop(
|
|
||||||
orchestrator,
|
|
||||||
request_id,
|
|
||||||
response,
|
|
||||||
response_ctx,
|
|
||||||
conversation_registry,
|
|
||||||
&request.room_id,
|
|
||||||
request.is_dm,
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
match result {
|
|
||||||
Some(text) => {
|
|
||||||
let text = strip_sol_prefix(&text);
|
|
||||||
if text.is_empty() {
|
|
||||||
orchestrator.emit(OrchestratorEvent::ResponseFailed {
|
|
||||||
request_id: request_id.clone(),
|
|
||||||
error: "Empty response from model".into(),
|
|
||||||
});
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
orchestrator.emit(OrchestratorEvent::ResponseReady {
|
|
||||||
request_id: request_id.clone(),
|
|
||||||
text: text.clone(),
|
|
||||||
prompt_tokens: 0, // TODO: extract from response
|
|
||||||
completion_tokens: 0,
|
|
||||||
tool_iterations: 0,
|
|
||||||
});
|
|
||||||
|
|
||||||
// Schedule memory extraction
|
|
||||||
orchestrator.emit(OrchestratorEvent::MemoryExtractionScheduled {
|
|
||||||
request_id: request_id.clone(),
|
|
||||||
user_msg: request.trigger_body.clone(),
|
|
||||||
response: text.clone(),
|
|
||||||
});
|
|
||||||
|
|
||||||
Some(text)
|
|
||||||
}
|
|
||||||
None => {
|
|
||||||
orchestrator.emit(OrchestratorEvent::ResponseFailed {
|
|
||||||
request_id: request_id.clone(),
|
|
||||||
error: "No response from model".into(),
|
|
||||||
});
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// The unified tool iteration loop.
|
|
||||||
/// Emits tool events and executes server-side tools.
|
|
||||||
/// Client-side tools are dispatched via the pending_client_tools oneshot map.
|
|
||||||
async fn run_tool_loop(
|
|
||||||
orchestrator: &Orchestrator,
|
|
||||||
request_id: &RequestId,
|
|
||||||
initial_response: ConversationResponse,
|
initial_response: ConversationResponse,
|
||||||
response_ctx: &ResponseContext,
|
) -> Option<(String, TokenUsage)> {
|
||||||
conversation_registry: &ConversationRegistry,
|
let request_id = &request.request_id;
|
||||||
room_id: &str,
|
|
||||||
is_dm: bool,
|
|
||||||
) -> Option<String> {
|
|
||||||
let function_calls = initial_response.function_calls();
|
let function_calls = initial_response.function_calls();
|
||||||
|
|
||||||
// No tool calls — return the text directly
|
// No tool calls — return text directly
|
||||||
if function_calls.is_empty() {
|
if function_calls.is_empty() {
|
||||||
return initial_response.assistant_text();
|
return initial_response.assistant_text().map(|text| {
|
||||||
|
(text, TokenUsage {
|
||||||
|
prompt_tokens: initial_response.usage.prompt_tokens,
|
||||||
|
completion_tokens: initial_response.usage.completion_tokens,
|
||||||
|
})
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
orchestrator.emit(OrchestratorEvent::AgentProgressStarted {
|
let conv_id = initial_response.conversation_id.clone();
|
||||||
request_id: request_id.clone(),
|
|
||||||
});
|
|
||||||
|
|
||||||
let max_iterations = orchestrator.config.mistral.max_tool_iterations;
|
let max_iterations = orchestrator.config.mistral.max_tool_iterations;
|
||||||
let mut current_response = initial_response;
|
let mut current_response = initial_response;
|
||||||
|
|
||||||
|
let tool_ctx = ToolContext {
|
||||||
|
user_id: request.user_id.clone(),
|
||||||
|
scope_key: request.conversation_key.clone(),
|
||||||
|
is_direct: request.is_direct,
|
||||||
|
};
|
||||||
|
|
||||||
for iteration in 0..max_iterations {
|
for iteration in 0..max_iterations {
|
||||||
let calls = current_response.function_calls();
|
let calls = current_response.function_calls();
|
||||||
if calls.is_empty() {
|
if calls.is_empty() {
|
||||||
@@ -221,7 +65,7 @@ async fn run_tool_loop(
|
|||||||
side: side.clone(),
|
side: side.clone(),
|
||||||
});
|
});
|
||||||
|
|
||||||
orchestrator.emit(OrchestratorEvent::ToolExecutionStarted {
|
orchestrator.emit(OrchestratorEvent::ToolStarted {
|
||||||
request_id: request_id.clone(),
|
request_id: request_id.clone(),
|
||||||
call_id: call_id.into(),
|
call_id: call_id.into(),
|
||||||
name: fc.name.clone(),
|
name: fc.name.clone(),
|
||||||
@@ -229,31 +73,14 @@ async fn run_tool_loop(
|
|||||||
|
|
||||||
let result_str = match side {
|
let result_str = match side {
|
||||||
ToolSide::Server => {
|
ToolSide::Server => {
|
||||||
// Execute server-side tool
|
let result = orchestrator
|
||||||
let result = if fc.name == "research" {
|
.tools
|
||||||
// Research needs special handling (room + event context)
|
.execute_with_context(&fc.name, &fc.arguments, &tool_ctx)
|
||||||
// For now, use the standard execute path
|
.await;
|
||||||
orchestrator
|
|
||||||
.tools
|
|
||||||
.execute(&fc.name, &fc.arguments, response_ctx)
|
|
||||||
.await
|
|
||||||
} else {
|
|
||||||
orchestrator
|
|
||||||
.tools
|
|
||||||
.execute(&fc.name, &fc.arguments, response_ctx)
|
|
||||||
.await
|
|
||||||
};
|
|
||||||
|
|
||||||
match result {
|
match result {
|
||||||
Ok(s) => {
|
Ok(s) => {
|
||||||
let preview: String = s.chars().take(500).collect();
|
info!(tool = fc.name.as_str(), id = call_id, result_len = s.len(), "Tool result");
|
||||||
info!(
|
|
||||||
tool = fc.name.as_str(),
|
|
||||||
id = call_id,
|
|
||||||
result_len = s.len(),
|
|
||||||
result_preview = preview.as_str(),
|
|
||||||
"Tool result"
|
|
||||||
);
|
|
||||||
s
|
s
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
@@ -263,7 +90,7 @@ async fn run_tool_loop(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
ToolSide::Client => {
|
ToolSide::Client => {
|
||||||
// Park on oneshot — gRPC bridge will deliver the result
|
// Park on oneshot — transport bridge delivers the result
|
||||||
let rx = orchestrator.register_pending_tool(call_id).await;
|
let rx = orchestrator.register_pending_tool(call_id).await;
|
||||||
match tokio::time::timeout(Duration::from_secs(300), rx).await {
|
match tokio::time::timeout(Duration::from_secs(300), rx).await {
|
||||||
Ok(Ok(payload)) => {
|
Ok(Ok(payload)) => {
|
||||||
@@ -281,22 +108,14 @@ async fn run_tool_loop(
|
|||||||
|
|
||||||
let success = !result_str.starts_with("Error:");
|
let success = !result_str.starts_with("Error:");
|
||||||
|
|
||||||
orchestrator.emit(OrchestratorEvent::ToolExecutionCompleted {
|
orchestrator.emit(OrchestratorEvent::ToolCompleted {
|
||||||
request_id: request_id.clone(),
|
request_id: request_id.clone(),
|
||||||
call_id: call_id.into(),
|
call_id: call_id.into(),
|
||||||
name: fc.name.clone(),
|
name: fc.name.clone(),
|
||||||
result: result_str.chars().take(200).collect(),
|
result_preview: result_str.chars().take(200).collect(),
|
||||||
success,
|
success,
|
||||||
});
|
});
|
||||||
|
|
||||||
orchestrator.emit(OrchestratorEvent::AgentProgressStep {
|
|
||||||
request_id: request_id.clone(),
|
|
||||||
summary: crate::agent_ux::AgentProgress::format_tool_call(
|
|
||||||
&fc.name,
|
|
||||||
&fc.arguments,
|
|
||||||
),
|
|
||||||
});
|
|
||||||
|
|
||||||
result_entries.push(ConversationEntry::FunctionResult(FunctionResultEntry {
|
result_entries.push(ConversationEntry::FunctionResult(FunctionResultEntry {
|
||||||
tool_call_id: call_id.to_string(),
|
tool_call_id: call_id.to_string(),
|
||||||
result: result_str,
|
result: result_str,
|
||||||
@@ -307,17 +126,24 @@ async fn run_tool_loop(
|
|||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send function results back to conversation
|
// Send results back to Mistral conversation
|
||||||
current_response = match conversation_registry
|
let req = AppendConversationRequest {
|
||||||
.send_function_result(room_id, result_entries, &orchestrator.mistral)
|
inputs: ConversationInput::Entries(result_entries),
|
||||||
|
completion_args: None,
|
||||||
|
handoff_execution: None,
|
||||||
|
store: Some(true),
|
||||||
|
tool_confirmations: None,
|
||||||
|
stream: false,
|
||||||
|
};
|
||||||
|
|
||||||
|
current_response = match orchestrator
|
||||||
|
.mistral
|
||||||
|
.append_conversation_async(&conv_id, &req)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(r) => r,
|
Ok(r) => r,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Failed to send function results: {e}");
|
error!("Failed to send function results: {}", e.message);
|
||||||
orchestrator.emit(OrchestratorEvent::AgentProgressDone {
|
|
||||||
request_id: request_id.clone(),
|
|
||||||
});
|
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -325,23 +151,10 @@ async fn run_tool_loop(
|
|||||||
debug!(iteration, "Tool iteration complete");
|
debug!(iteration, "Tool iteration complete");
|
||||||
}
|
}
|
||||||
|
|
||||||
orchestrator.emit(OrchestratorEvent::AgentProgressDone {
|
current_response.assistant_text().map(|text| {
|
||||||
request_id: request_id.clone(),
|
(text, TokenUsage {
|
||||||
});
|
prompt_tokens: current_response.usage.prompt_tokens,
|
||||||
|
completion_tokens: current_response.usage.completion_tokens,
|
||||||
current_response.assistant_text()
|
})
|
||||||
}
|
})
|
||||||
|
|
||||||
/// Load memory notes relevant to the trigger message.
|
|
||||||
/// TODO (Phase 4): move the full memory::store query logic here
|
|
||||||
/// when the Responder is dissolved. For now returns None — the Matrix
|
|
||||||
/// bridge path still uses the responder which has memory loading.
|
|
||||||
async fn load_memory_notes(
|
|
||||||
_orchestrator: &Orchestrator,
|
|
||||||
_ctx: &ResponseContext,
|
|
||||||
_trigger_body: &str,
|
|
||||||
) -> Option<String> {
|
|
||||||
// Memory loading is not yet migrated to the orchestrator.
|
|
||||||
// The responder's load_memory_notes() still handles this for now.
|
|
||||||
None
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
//! Event-driven orchestrator — the single response generation pipeline.
|
//! Event-driven orchestrator — Sol's transport-agnostic response pipeline.
|
||||||
//!
|
//!
|
||||||
//! The orchestrator owns the Mistral tool loop and emits events through
|
//! The orchestrator receives a `GenerateRequest`, runs the Mistral
|
||||||
//! a `tokio::broadcast` channel. Transport bridges (Matrix, gRPC, etc.)
|
//! conversation + tool loop, and emits `OrchestratorEvent`s through a
|
||||||
//! subscribe to these events and translate them to their protocol.
|
//! `tokio::broadcast` channel. It has zero knowledge of Matrix, gRPC,
|
||||||
|
//! or any specific transport.
|
||||||
//!
|
//!
|
||||||
//! Phase 1: types + channel wiring only. No behavior change.
|
//! Transport bridges subscribe externally via `subscribe()` and translate
|
||||||
|
//! events to their protocol.
|
||||||
|
|
||||||
pub mod engine;
|
pub mod engine;
|
||||||
pub mod event;
|
pub mod event;
|
||||||
@@ -20,79 +22,152 @@ pub use event::*;
|
|||||||
|
|
||||||
use crate::config::Config;
|
use crate::config::Config;
|
||||||
use crate::conversations::ConversationRegistry;
|
use crate::conversations::ConversationRegistry;
|
||||||
use crate::persistence::Store;
|
|
||||||
use crate::tools::ToolRegistry;
|
use crate::tools::ToolRegistry;
|
||||||
|
|
||||||
const EVENT_CHANNEL_CAPACITY: usize = 256;
|
const EVENT_CHANNEL_CAPACITY: usize = 256;
|
||||||
|
|
||||||
/// Result payload from a client-side tool execution.
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct ToolResultPayload {
|
|
||||||
pub text: String,
|
|
||||||
pub is_error: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// The orchestrator — Sol's response generation pipeline.
|
/// The orchestrator — Sol's response generation pipeline.
|
||||||
///
|
///
|
||||||
/// Owns the event broadcast channel. Bridges subscribe via `subscribe()`.
|
/// Owns the event broadcast channel. Transport bridges subscribe via `subscribe()`.
|
||||||
/// Phase 2+ will add `generate_chat_response()` and `generate_code_response()`.
|
/// Call `generate()` or `generate_from_response()` to run the pipeline.
|
||||||
pub struct Orchestrator {
|
pub struct Orchestrator {
|
||||||
pub config: Arc<Config>,
|
pub config: Arc<Config>,
|
||||||
pub tools: Arc<ToolRegistry>,
|
pub tools: Arc<ToolRegistry>,
|
||||||
pub store: Arc<Store>,
|
|
||||||
pub mistral: Arc<mistralai_client::v1::client::Client>,
|
pub mistral: Arc<mistralai_client::v1::client::Client>,
|
||||||
pub conversation_registry: Arc<ConversationRegistry>,
|
pub conversations: Arc<ConversationRegistry>,
|
||||||
pub system_prompt: String,
|
pub system_prompt: String,
|
||||||
|
|
||||||
/// Broadcast sender — all orchestration events go here.
|
/// Broadcast sender — all orchestration events go here.
|
||||||
event_tx: broadcast::Sender<OrchestratorEvent>,
|
event_tx: broadcast::Sender<OrchestratorEvent>,
|
||||||
|
|
||||||
/// Pending client-side tool calls awaiting results from gRPC clients.
|
/// Pending client-side tool calls awaiting results from external sources.
|
||||||
/// Key: call_id, Value: oneshot sender to unblock the engine.
|
|
||||||
pending_client_tools: Arc<Mutex<HashMap<String, oneshot::Sender<ToolResultPayload>>>>,
|
pending_client_tools: Arc<Mutex<HashMap<String, oneshot::Sender<ToolResultPayload>>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Orchestrator {
|
impl Orchestrator {
|
||||||
/// Create a new orchestrator. Returns the orchestrator and an initial
|
|
||||||
/// event receiver (for the first subscriber, typically the Matrix bridge).
|
|
||||||
pub fn new(
|
pub fn new(
|
||||||
config: Arc<Config>,
|
config: Arc<Config>,
|
||||||
tools: Arc<ToolRegistry>,
|
tools: Arc<ToolRegistry>,
|
||||||
store: Arc<Store>,
|
|
||||||
mistral: Arc<mistralai_client::v1::client::Client>,
|
mistral: Arc<mistralai_client::v1::client::Client>,
|
||||||
conversation_registry: Arc<ConversationRegistry>,
|
conversations: Arc<ConversationRegistry>,
|
||||||
system_prompt: String,
|
system_prompt: String,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let (event_tx, _) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
|
let (event_tx, _) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
|
||||||
|
|
||||||
info!("Orchestrator initialized (event channel capacity: {EVENT_CHANNEL_CAPACITY})");
|
info!("Orchestrator initialized (event channel capacity: {EVENT_CHANNEL_CAPACITY})");
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
config,
|
config,
|
||||||
tools,
|
tools,
|
||||||
store,
|
|
||||||
mistral,
|
mistral,
|
||||||
conversation_registry,
|
conversations,
|
||||||
system_prompt,
|
system_prompt,
|
||||||
event_tx,
|
event_tx,
|
||||||
pending_client_tools: Arc::new(Mutex::new(HashMap::new())),
|
pending_client_tools: Arc::new(Mutex::new(HashMap::new())),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Subscribe to the event stream. Each subscriber gets its own receiver
|
/// Subscribe to the event stream.
|
||||||
/// that independently tracks position in the broadcast buffer.
|
|
||||||
pub fn subscribe(&self) -> broadcast::Receiver<OrchestratorEvent> {
|
pub fn subscribe(&self) -> broadcast::Receiver<OrchestratorEvent> {
|
||||||
self.event_tx.subscribe()
|
self.event_tx.subscribe()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Emit an event to all subscribers.
|
/// Emit an event to all subscribers.
|
||||||
pub fn emit(&self, event: OrchestratorEvent) {
|
pub fn emit(&self, event: OrchestratorEvent) {
|
||||||
// Ignore send errors (no subscribers is fine during startup).
|
|
||||||
let _ = self.event_tx.send(event);
|
let _ = self.event_tx.send(event);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Submit a tool result from an external source (e.g., gRPC client).
|
/// Generate a response using the ConversationRegistry.
|
||||||
/// Unblocks the engine's tool loop for the matching call_id.
|
/// Creates or appends to a conversation keyed by `request.conversation_key`.
|
||||||
|
pub async fn generate(&self, request: &GenerateRequest) -> Option<String> {
|
||||||
|
self.emit(OrchestratorEvent::Started {
|
||||||
|
request_id: request.request_id.clone(),
|
||||||
|
metadata: request.metadata.clone(),
|
||||||
|
});
|
||||||
|
|
||||||
|
self.emit(OrchestratorEvent::Thinking {
|
||||||
|
request_id: request.request_id.clone(),
|
||||||
|
});
|
||||||
|
|
||||||
|
let input = mistralai_client::v1::conversations::ConversationInput::Text(
|
||||||
|
request.text.clone(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let response = match self.conversations
|
||||||
|
.send_message(
|
||||||
|
&request.conversation_key,
|
||||||
|
input,
|
||||||
|
request.is_direct,
|
||||||
|
&self.mistral,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(e) => {
|
||||||
|
self.emit(OrchestratorEvent::Failed {
|
||||||
|
request_id: request.request_id.clone(),
|
||||||
|
error: e.clone(),
|
||||||
|
});
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
self.run_and_emit(request, response).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generate a response from a pre-built ConversationResponse.
|
||||||
|
/// The caller already created/appended the conversation externally.
|
||||||
|
/// The orchestrator only runs the tool loop and emits events.
|
||||||
|
pub async fn generate_from_response(
|
||||||
|
&self,
|
||||||
|
request: &GenerateRequest,
|
||||||
|
response: mistralai_client::v1::conversations::ConversationResponse,
|
||||||
|
) -> Option<String> {
|
||||||
|
self.emit(OrchestratorEvent::Started {
|
||||||
|
request_id: request.request_id.clone(),
|
||||||
|
metadata: request.metadata.clone(),
|
||||||
|
});
|
||||||
|
|
||||||
|
self.emit(OrchestratorEvent::Thinking {
|
||||||
|
request_id: request.request_id.clone(),
|
||||||
|
});
|
||||||
|
|
||||||
|
self.run_and_emit(request, response).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run the tool loop and emit Done/Failed events.
|
||||||
|
async fn run_and_emit(
|
||||||
|
&self,
|
||||||
|
request: &GenerateRequest,
|
||||||
|
response: mistralai_client::v1::conversations::ConversationResponse,
|
||||||
|
) -> Option<String> {
|
||||||
|
let result = engine::run_tool_loop(self, request, response).await;
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Some((text, usage)) => {
|
||||||
|
info!(
|
||||||
|
prompt_tokens = usage.prompt_tokens,
|
||||||
|
completion_tokens = usage.completion_tokens,
|
||||||
|
"Response ready"
|
||||||
|
);
|
||||||
|
self.emit(OrchestratorEvent::Done {
|
||||||
|
request_id: request.request_id.clone(),
|
||||||
|
text: text.clone(),
|
||||||
|
usage,
|
||||||
|
});
|
||||||
|
Some(text)
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
self.emit(OrchestratorEvent::Failed {
|
||||||
|
request_id: request.request_id.clone(),
|
||||||
|
error: "No response from model".into(),
|
||||||
|
});
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Submit a tool result from an external source.
|
||||||
pub async fn submit_tool_result(
|
pub async fn submit_tool_result(
|
||||||
&self,
|
&self,
|
||||||
call_id: &str,
|
call_id: &str,
|
||||||
@@ -112,8 +187,7 @@ impl Orchestrator {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Register a pending client-side tool call. Returns a oneshot receiver
|
/// Register a pending client-side tool call.
|
||||||
/// that the engine awaits for the result.
|
|
||||||
pub async fn register_pending_tool(
|
pub async fn register_pending_tool(
|
||||||
&self,
|
&self,
|
||||||
call_id: &str,
|
call_id: &str,
|
||||||
@@ -145,35 +219,14 @@ mod tests {
|
|||||||
assert_eq!(received.request_id(), &id);
|
assert_eq!(received.request_id(), &id);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_multiple_subscribers() {
|
|
||||||
let (event_tx, _) = broadcast::channel(16);
|
|
||||||
let mut rx1 = event_tx.subscribe();
|
|
||||||
let mut rx2 = event_tx.subscribe();
|
|
||||||
|
|
||||||
let id = RequestId::new();
|
|
||||||
let _ = event_tx.send(OrchestratorEvent::Thinking {
|
|
||||||
request_id: id.clone(),
|
|
||||||
});
|
|
||||||
|
|
||||||
let r1 = rx1.recv().await.unwrap();
|
|
||||||
let r2 = rx2.recv().await.unwrap();
|
|
||||||
assert_eq!(r1.request_id(), &id);
|
|
||||||
assert_eq!(r2.request_id(), &id);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_submit_tool_result() {
|
async fn test_submit_tool_result() {
|
||||||
let pending: Arc<Mutex<HashMap<String, oneshot::Sender<ToolResultPayload>>>> =
|
let pending: Arc<Mutex<HashMap<String, oneshot::Sender<ToolResultPayload>>>> =
|
||||||
Arc::new(Mutex::new(HashMap::new()));
|
Arc::new(Mutex::new(HashMap::new()));
|
||||||
|
|
||||||
let (tx, rx) = oneshot::channel();
|
let (tx, rx) = oneshot::channel();
|
||||||
pending
|
pending.lock().await.insert("call-1".into(), tx);
|
||||||
.lock()
|
|
||||||
.await
|
|
||||||
.insert("call-1".into(), tx);
|
|
||||||
|
|
||||||
// Simulate submitting result
|
|
||||||
let sender = pending.lock().await.remove("call-1").unwrap();
|
let sender = pending.lock().await.remove("call-1").unwrap();
|
||||||
sender
|
sender
|
||||||
.send(ToolResultPayload {
|
.send(ToolResultPayload {
|
||||||
|
|||||||
Reference in New Issue
Block a user