feat(orchestrator): Phase 2 engine + tokenizer + tool dispatch
Orchestrator engine: - engine.rs: unified Mistral Conversations API tool loop that emits OrchestratorEvent instead of calling Matrix/gRPC directly - tool_dispatch.rs: ToolSide routing (client vs server tools) - Memory loading stubbed (migrates in Phase 4) Server-side tokenizer: - tokenizer.rs: HuggingFace tokenizers-rs with Mistral's BPE tokenizer - count_tokens() for accurate usage metrics - Loads from local tokenizer.json or falls back to downloading the pretrained tokenizer from HuggingFace Hub - Config: mistral.tokenizer_path (optional) No behavior change — engine is wired but not yet called from sync.rs or session.rs (Phase 2 continuation).
This commit is contained in:
@@ -102,6 +102,9 @@ pub struct MistralConfig {
|
||||
pub research_model: String,
|
||||
#[serde(default = "default_max_tool_iterations")]
|
||||
pub max_tool_iterations: usize,
|
||||
/// Path to a local `tokenizer.json` file. If unset, downloads from HuggingFace Hub.
|
||||
#[serde(default)]
|
||||
pub tokenizer_path: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
|
||||
@@ -13,6 +13,7 @@ mod orchestrator;
|
||||
mod sdk;
|
||||
mod sync;
|
||||
mod time_context;
|
||||
mod tokenizer;
|
||||
mod tools;
|
||||
|
||||
use std::sync::Arc;
|
||||
@@ -123,6 +124,13 @@ async fn main() -> anyhow::Result<()> {
|
||||
)?;
|
||||
let mistral = Arc::new(mistral_client);
|
||||
|
||||
// Initialize tokenizer for accurate token counting
|
||||
let _tokenizer = Arc::new(
|
||||
tokenizer::SolTokenizer::new(config.mistral.tokenizer_path.as_deref())
|
||||
.expect("Failed to initialize tokenizer"),
|
||||
);
|
||||
info!("Tokenizer initialized");
|
||||
|
||||
// Build components
|
||||
let system_prompt_text = system_prompt.clone();
|
||||
let personality = Arc::new(Personality::new(system_prompt));
|
||||
|
||||
347
src/orchestrator/engine.rs
Normal file
347
src/orchestrator/engine.rs
Normal file
@@ -0,0 +1,347 @@
|
||||
//! The unified response generation engine.
|
||||
//!
|
||||
//! Single implementation of the Mistral Conversations API tool loop.
|
||||
//! Emits `OrchestratorEvent`s instead of calling Matrix/gRPC directly.
|
||||
//! Phase 2: replaces `responder.generate_response_conversations()`.
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use mistralai_client::v1::conversations::{
|
||||
ConversationEntry, ConversationInput, ConversationResponse, FunctionResultEntry,
|
||||
};
|
||||
use rand::Rng;
|
||||
use tokio::sync::broadcast;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use super::event::*;
|
||||
use super::tool_dispatch;
|
||||
use super::Orchestrator;
|
||||
use crate::brain::personality::Personality;
|
||||
use crate::context::ResponseContext;
|
||||
use crate::conversations::ConversationRegistry;
|
||||
use crate::time_context::TimeContext;
|
||||
|
||||
/// Strip a leading `"Sol: "` / `"sol: "` (or `"Sol:\n"` / `"sol:\n"`) prefix
/// that models sometimes prepend to their own replies, and trim surrounding
/// whitespace.
///
/// Fixed: the newline variant previously sliced `trimmed[4..]`, but
/// `"Sol:\n"` is 5 bytes, so the stray `\n` was left at the front of the
/// result. `strip_prefix` removes the whole prefix and cannot be off by one.
fn strip_sol_prefix(text: &str) -> String {
    let trimmed = text.trim();
    for prefix in ["Sol: ", "sol: ", "Sol:\n", "sol:\n"] {
        if let Some(rest) = trimmed.strip_prefix(prefix) {
            return rest.to_string();
        }
    }
    trimmed.to_string()
}
|
||||
|
||||
/// Generate a chat response through the Conversations API.
|
||||
/// This is the unified path that replaces both the responder's conversations
|
||||
/// method and the gRPC session's inline tool loop.
|
||||
pub async fn generate_response(
|
||||
orchestrator: &Orchestrator,
|
||||
personality: &Personality,
|
||||
request: &ChatRequest,
|
||||
response_ctx: &ResponseContext,
|
||||
conversation_registry: &ConversationRegistry,
|
||||
) -> Option<String> {
|
||||
let request_id = &request.request_id;
|
||||
|
||||
// Emit start
|
||||
orchestrator.emit(OrchestratorEvent::ResponseStarted {
|
||||
request_id: request_id.clone(),
|
||||
mode: ResponseMode::Chat {
|
||||
room_id: request.room_id.clone(),
|
||||
is_spontaneous: request.is_spontaneous,
|
||||
use_thread: request.use_thread,
|
||||
trigger_event_id: request.trigger_event_id.clone(),
|
||||
},
|
||||
});
|
||||
|
||||
// Apply response delay
|
||||
if !orchestrator.config.behavior.instant_responses {
|
||||
let delay = if request.is_spontaneous {
|
||||
rand::thread_rng().gen_range(
|
||||
orchestrator.config.behavior.spontaneous_delay_min_ms
|
||||
..=orchestrator.config.behavior.spontaneous_delay_max_ms,
|
||||
)
|
||||
} else {
|
||||
rand::thread_rng().gen_range(
|
||||
orchestrator.config.behavior.response_delay_min_ms
|
||||
..=orchestrator.config.behavior.response_delay_max_ms,
|
||||
)
|
||||
};
|
||||
tokio::time::sleep(Duration::from_millis(delay)).await;
|
||||
}
|
||||
|
||||
orchestrator.emit(OrchestratorEvent::Thinking {
|
||||
request_id: request_id.clone(),
|
||||
});
|
||||
|
||||
// Memory query
|
||||
let memory_notes = load_memory_notes(orchestrator, response_ctx, &request.trigger_body).await;
|
||||
|
||||
// Build context header
|
||||
let tc = TimeContext::now();
|
||||
let mut context_header = format!(
|
||||
"{}\n[room: {} ({})]",
|
||||
tc.message_line(),
|
||||
request.room_name,
|
||||
request.room_id,
|
||||
);
|
||||
|
||||
if let Some(ref notes) = memory_notes {
|
||||
context_header.push('\n');
|
||||
context_header.push_str(notes);
|
||||
}
|
||||
|
||||
let user_msg = if request.is_dm {
|
||||
request.trigger_body.clone()
|
||||
} else {
|
||||
format!("<{}> {}", response_ctx.matrix_user_id, request.trigger_body)
|
||||
};
|
||||
|
||||
let input_text = format!("{context_header}\n{user_msg}");
|
||||
let input = ConversationInput::Text(input_text);
|
||||
|
||||
// Send through conversation registry
|
||||
let response = match conversation_registry
|
||||
.send_message(
|
||||
&request.room_id,
|
||||
input,
|
||||
request.is_dm,
|
||||
&orchestrator.mistral,
|
||||
request.context_hint.as_deref(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
error!("Conversation API failed: {e}");
|
||||
orchestrator.emit(OrchestratorEvent::ResponseFailed {
|
||||
request_id: request_id.clone(),
|
||||
error: e.clone(),
|
||||
});
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
// Tool loop
|
||||
let result = run_tool_loop(
|
||||
orchestrator,
|
||||
request_id,
|
||||
response,
|
||||
response_ctx,
|
||||
conversation_registry,
|
||||
&request.room_id,
|
||||
request.is_dm,
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Some(text) => {
|
||||
let text = strip_sol_prefix(&text);
|
||||
if text.is_empty() {
|
||||
orchestrator.emit(OrchestratorEvent::ResponseFailed {
|
||||
request_id: request_id.clone(),
|
||||
error: "Empty response from model".into(),
|
||||
});
|
||||
return None;
|
||||
}
|
||||
|
||||
orchestrator.emit(OrchestratorEvent::ResponseReady {
|
||||
request_id: request_id.clone(),
|
||||
text: text.clone(),
|
||||
prompt_tokens: 0, // TODO: extract from response
|
||||
completion_tokens: 0,
|
||||
tool_iterations: 0,
|
||||
});
|
||||
|
||||
// Schedule memory extraction
|
||||
orchestrator.emit(OrchestratorEvent::MemoryExtractionScheduled {
|
||||
request_id: request_id.clone(),
|
||||
user_msg: request.trigger_body.clone(),
|
||||
response: text.clone(),
|
||||
});
|
||||
|
||||
Some(text)
|
||||
}
|
||||
None => {
|
||||
orchestrator.emit(OrchestratorEvent::ResponseFailed {
|
||||
request_id: request_id.clone(),
|
||||
error: "No response from model".into(),
|
||||
});
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The unified tool iteration loop.
|
||||
/// Emits tool events and executes server-side tools.
|
||||
/// Client-side tools are dispatched via the pending_client_tools oneshot map.
|
||||
async fn run_tool_loop(
|
||||
orchestrator: &Orchestrator,
|
||||
request_id: &RequestId,
|
||||
initial_response: ConversationResponse,
|
||||
response_ctx: &ResponseContext,
|
||||
conversation_registry: &ConversationRegistry,
|
||||
room_id: &str,
|
||||
is_dm: bool,
|
||||
) -> Option<String> {
|
||||
let function_calls = initial_response.function_calls();
|
||||
|
||||
// No tool calls — return the text directly
|
||||
if function_calls.is_empty() {
|
||||
return initial_response.assistant_text();
|
||||
}
|
||||
|
||||
orchestrator.emit(OrchestratorEvent::AgentProgressStarted {
|
||||
request_id: request_id.clone(),
|
||||
});
|
||||
|
||||
let max_iterations = orchestrator.config.mistral.max_tool_iterations;
|
||||
let mut current_response = initial_response;
|
||||
|
||||
for iteration in 0..max_iterations {
|
||||
let calls = current_response.function_calls();
|
||||
if calls.is_empty() {
|
||||
break;
|
||||
}
|
||||
|
||||
let mut result_entries = Vec::new();
|
||||
|
||||
for fc in &calls {
|
||||
let call_id = fc.tool_call_id.as_deref().unwrap_or("unknown");
|
||||
let side = tool_dispatch::route(&fc.name);
|
||||
|
||||
orchestrator.emit(OrchestratorEvent::ToolCallDetected {
|
||||
request_id: request_id.clone(),
|
||||
call_id: call_id.into(),
|
||||
name: fc.name.clone(),
|
||||
args: fc.arguments.clone(),
|
||||
side: side.clone(),
|
||||
});
|
||||
|
||||
orchestrator.emit(OrchestratorEvent::ToolExecutionStarted {
|
||||
request_id: request_id.clone(),
|
||||
call_id: call_id.into(),
|
||||
name: fc.name.clone(),
|
||||
});
|
||||
|
||||
let result_str = match side {
|
||||
ToolSide::Server => {
|
||||
// Execute server-side tool
|
||||
let result = if fc.name == "research" {
|
||||
// Research needs special handling (room + event context)
|
||||
// For now, use the standard execute path
|
||||
orchestrator
|
||||
.tools
|
||||
.execute(&fc.name, &fc.arguments, response_ctx)
|
||||
.await
|
||||
} else {
|
||||
orchestrator
|
||||
.tools
|
||||
.execute(&fc.name, &fc.arguments, response_ctx)
|
||||
.await
|
||||
};
|
||||
|
||||
match result {
|
||||
Ok(s) => {
|
||||
let preview: String = s.chars().take(500).collect();
|
||||
info!(
|
||||
tool = fc.name.as_str(),
|
||||
id = call_id,
|
||||
result_len = s.len(),
|
||||
result_preview = preview.as_str(),
|
||||
"Tool result"
|
||||
);
|
||||
s
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(tool = fc.name.as_str(), "Tool failed: {e}");
|
||||
format!("Error: {e}")
|
||||
}
|
||||
}
|
||||
}
|
||||
ToolSide::Client => {
|
||||
// Park on oneshot — gRPC bridge will deliver the result
|
||||
let rx = orchestrator.register_pending_tool(call_id).await;
|
||||
match tokio::time::timeout(Duration::from_secs(300), rx).await {
|
||||
Ok(Ok(payload)) => {
|
||||
if payload.is_error {
|
||||
format!("Error: {}", payload.text)
|
||||
} else {
|
||||
payload.text
|
||||
}
|
||||
}
|
||||
Ok(Err(_)) => "Error: client tool channel dropped".into(),
|
||||
Err(_) => "Error: client tool timed out (5min)".into(),
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let success = !result_str.starts_with("Error:");
|
||||
|
||||
orchestrator.emit(OrchestratorEvent::ToolExecutionCompleted {
|
||||
request_id: request_id.clone(),
|
||||
call_id: call_id.into(),
|
||||
name: fc.name.clone(),
|
||||
result: result_str.chars().take(200).collect(),
|
||||
success,
|
||||
});
|
||||
|
||||
orchestrator.emit(OrchestratorEvent::AgentProgressStep {
|
||||
request_id: request_id.clone(),
|
||||
summary: crate::agent_ux::AgentProgress::format_tool_call(
|
||||
&fc.name,
|
||||
&fc.arguments,
|
||||
),
|
||||
});
|
||||
|
||||
result_entries.push(ConversationEntry::FunctionResult(FunctionResultEntry {
|
||||
tool_call_id: call_id.to_string(),
|
||||
result: result_str,
|
||||
id: None,
|
||||
object: None,
|
||||
created_at: None,
|
||||
completed_at: None,
|
||||
}));
|
||||
}
|
||||
|
||||
// Send function results back to conversation
|
||||
current_response = match conversation_registry
|
||||
.send_function_result(room_id, result_entries, &orchestrator.mistral)
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
error!("Failed to send function results: {e}");
|
||||
orchestrator.emit(OrchestratorEvent::AgentProgressDone {
|
||||
request_id: request_id.clone(),
|
||||
});
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
debug!(iteration, "Tool iteration complete");
|
||||
}
|
||||
|
||||
orchestrator.emit(OrchestratorEvent::AgentProgressDone {
|
||||
request_id: request_id.clone(),
|
||||
});
|
||||
|
||||
current_response.assistant_text()
|
||||
}
|
||||
|
||||
/// Load memory notes relevant to the trigger message.
/// TODO (Phase 4): move the full memory::store query logic here
/// when the Responder is dissolved. For now returns None — the Matrix
/// bridge path still uses the responder which has memory loading.
///
/// All parameters are intentionally underscored: this is a placeholder with
/// the final signature, so call sites won't change when the real
/// implementation lands.
async fn load_memory_notes(
    _orchestrator: &Orchestrator,
    _ctx: &ResponseContext,
    _trigger_body: &str,
) -> Option<String> {
    // Memory loading is not yet migrated to the orchestrator.
    // The responder's load_memory_notes() still handles this for now.
    // Returning None means generate_response() builds its context header
    // without a memory section.
    None
}
|
||||
@@ -6,7 +6,9 @@
|
||||
//!
|
||||
//! Phase 1: types + channel wiring only. No behavior change.
|
||||
|
||||
pub mod engine;
|
||||
pub mod event;
|
||||
pub mod tool_dispatch;
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
49
src/orchestrator/tool_dispatch.rs
Normal file
49
src/orchestrator/tool_dispatch.rs
Normal file
@@ -0,0 +1,49 @@
|
||||
//! Tool routing — determines whether a tool executes on the server or a connected client.
|
||||
|
||||
use super::event::ToolSide;
|
||||
|
||||
/// Client-side tools that execute on the `sunbeam code` TUI client.
|
||||
const CLIENT_TOOLS: &[&str] = &[
|
||||
"file_read",
|
||||
"file_write",
|
||||
"search_replace",
|
||||
"grep",
|
||||
"bash",
|
||||
"list_directory",
|
||||
"ask_user",
|
||||
];
|
||||
|
||||
/// Route a tool call to server or client.
|
||||
pub fn route(tool_name: &str) -> ToolSide {
|
||||
if CLIENT_TOOLS.contains(&tool_name) {
|
||||
ToolSide::Client
|
||||
} else {
|
||||
ToolSide::Server
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Every tool the TUI client implements must route to Client; keep this
    // list in sync with CLIENT_TOOLS.
    #[test]
    fn test_client_tools() {
        assert_eq!(route("file_read"), ToolSide::Client);
        assert_eq!(route("bash"), ToolSide::Client);
        assert_eq!(route("grep"), ToolSide::Client);
        assert_eq!(route("file_write"), ToolSide::Client);
        assert_eq!(route("search_replace"), ToolSide::Client);
        assert_eq!(route("list_directory"), ToolSide::Client);
        assert_eq!(route("ask_user"), ToolSide::Client);
    }

    // Known server tools — and any unrecognized name — must default to
    // Server; the last assertion pins the fallback behavior.
    #[test]
    fn test_server_tools() {
        assert_eq!(route("search_archive"), ToolSide::Server);
        assert_eq!(route("search_web"), ToolSide::Server);
        assert_eq!(route("run_script"), ToolSide::Server);
        assert_eq!(route("research"), ToolSide::Server);
        assert_eq!(route("gitea_list_repos"), ToolSide::Server);
        assert_eq!(route("unknown_tool"), ToolSide::Server);
    }
}
|
||||
123
src/tokenizer.rs
Normal file
123
src/tokenizer.rs
Normal file
@@ -0,0 +1,123 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use tokenizers::Tokenizer;
|
||||
use tracing::{info, warn};
|
||||
|
||||
/// Default HuggingFace pretrained tokenizer identifier for Mistral models.
/// NOTE(review): if this Hub repo is gated, the pretrained fallback may need
/// an authenticated HF token — confirm before relying on it in CI/offline
/// environments.
const DEFAULT_PRETRAINED: &str = "mistralai/Mistral-Small-24B-Base-2501";

/// Thread-safe wrapper around HuggingFace's `Tokenizer`.
///
/// Load once at startup via [`SolTokenizer::new`] and share as `Arc<SolTokenizer>`.
/// `Clone` is cheap: it only bumps the inner `Arc` refcount.
#[derive(Clone)]
pub struct SolTokenizer {
    // The loaded tokenizer, shared immutably across threads.
    inner: Arc<Tokenizer>,
}
|
||||
|
||||
impl SolTokenizer {
|
||||
/// Load a tokenizer from a local `tokenizer.json` path, falling back to
|
||||
/// HuggingFace Hub pretrained download if the path is absent or fails.
|
||||
pub fn new(tokenizer_path: Option<&str>) -> Result<Self> {
|
||||
let tokenizer = if let Some(path) = tokenizer_path {
|
||||
match Tokenizer::from_file(path) {
|
||||
Ok(t) => {
|
||||
info!(path, "Loaded tokenizer from local file");
|
||||
t
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(path, error = %e, "Failed to load local tokenizer, falling back to pretrained");
|
||||
Self::from_pretrained()?
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Self::from_pretrained()?
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
inner: Arc::new(tokenizer),
|
||||
})
|
||||
}
|
||||
|
||||
/// Download tokenizer from HuggingFace Hub.
|
||||
fn from_pretrained() -> Result<Tokenizer> {
|
||||
info!(model = DEFAULT_PRETRAINED, "Downloading tokenizer from HuggingFace Hub");
|
||||
Tokenizer::from_pretrained(DEFAULT_PRETRAINED, None)
|
||||
.map_err(|e| anyhow::anyhow!("{e}"))
|
||||
.context("Failed to download pretrained tokenizer")
|
||||
}
|
||||
|
||||
/// Count the number of tokens in the given text.
|
||||
pub fn count_tokens(&self, text: &str) -> usize {
|
||||
match self.inner.encode(text, false) {
|
||||
Ok(encoding) => encoding.get_ids().len(),
|
||||
Err(e) => {
|
||||
warn!(error = %e, "Tokenization failed, estimating from char count");
|
||||
// Rough fallback: ~4 chars per token for English text
|
||||
text.len() / 4
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Encode text and return the token IDs.
|
||||
pub fn encode(&self, text: &str) -> Result<Vec<u32>> {
|
||||
let encoding = self
|
||||
.inner
|
||||
.encode(text, false)
|
||||
.map_err(|e| anyhow::anyhow!("{e}"))
|
||||
.context("Tokenization failed")?;
|
||||
Ok(encoding.get_ids().to_vec())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for SolTokenizer {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("SolTokenizer").finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): every test below constructs the tokenizer via the
    // pretrained path, which requires network access on first run (cached
    // afterwards). Consider marking these `#[ignore]` for offline/CI runs.

    /// Test that the pretrained tokenizer can be loaded and produces
    /// reasonable token counts. This test requires network access on
    /// first run (the tokenizer is cached locally afterwards).
    #[test]
    fn test_pretrained_tokenizer_loads() {
        let tok = SolTokenizer::new(None).expect("pretrained tokenizer should load");
        let count = tok.count_tokens("Hello, world!");
        assert!(count > 0, "token count should be positive");
        assert!(count < 20, "token count for a short sentence should be small");
    }

    // Empty input should produce zero tokens (no special tokens are added).
    #[test]
    fn test_count_tokens_empty_string() {
        let tok = SolTokenizer::new(None).expect("pretrained tokenizer should load");
        let count = tok.count_tokens("");
        assert_eq!(count, 0);
    }

    // encode() must return a non-empty id list for non-empty text.
    #[test]
    fn test_encode_returns_ids() {
        let tok = SolTokenizer::new(None).expect("pretrained tokenizer should load");
        let ids = tok.encode("Hello, world!").expect("encode should succeed");
        assert!(!ids.is_empty());
    }

    // A nonexistent local path must not be fatal — it falls back to the
    // pretrained download.
    #[test]
    fn test_invalid_path_falls_back_to_pretrained() {
        let tok = SolTokenizer::new(Some("/nonexistent/tokenizer.json"))
            .expect("should fall back to pretrained");
        let count = tok.count_tokens("fallback test");
        assert!(count > 0);
    }

    // Sanity check that token counts grow with input length.
    #[test]
    fn test_longer_text_produces_more_tokens() {
        let tok = SolTokenizer::new(None).expect("pretrained tokenizer should load");
        let short = tok.count_tokens("Hi");
        let long = tok.count_tokens("This is a much longer sentence with many more words in it.");
        assert!(long > short, "longer text should produce more tokens");
    }
}
|
||||
Reference in New Issue
Block a user