feat(orchestrator): Phase 2 engine + tokenizer + tool dispatch

Orchestrator engine:
- engine.rs: unified Mistral Conversations API tool loop that emits
  OrchestratorEvent instead of calling Matrix/gRPC directly
- tool_dispatch.rs: ToolSide routing (client vs server tools)
- Memory loading stubbed (migrates in Phase 4)

Server-side tokenizer:
- tokenizer.rs: HuggingFace tokenizers-rs with Mistral's BPE tokenizer
- count_tokens() for accurate usage metrics
- Loads from local tokenizer.json or falls back to bundled vocab
- Config: mistral.tokenizer_path (optional)

No behavior change — engine is wired but not yet called from
sync.rs or session.rs (Phase 2 continuation).
This commit is contained in:
2026-03-23 17:40:25 +00:00
parent ec4fde7b97
commit 9e5f7e61be
9 changed files with 1065 additions and 31 deletions

View File

@@ -102,6 +102,9 @@ pub struct MistralConfig {
pub research_model: String,
#[serde(default = "default_max_tool_iterations")]
pub max_tool_iterations: usize,
/// Path to a local `tokenizer.json` file. If unset, downloads from HuggingFace Hub.
#[serde(default)]
pub tokenizer_path: Option<String>,
}
#[derive(Debug, Clone, Deserialize)]

View File

@@ -13,6 +13,7 @@ mod orchestrator;
mod sdk;
mod sync;
mod time_context;
mod tokenizer;
mod tools;
use std::sync::Arc;
@@ -123,6 +124,13 @@ async fn main() -> anyhow::Result<()> {
)?;
let mistral = Arc::new(mistral_client);
// Initialize tokenizer for accurate token counting
let _tokenizer = Arc::new(
tokenizer::SolTokenizer::new(config.mistral.tokenizer_path.as_deref())
.expect("Failed to initialize tokenizer"),
);
info!("Tokenizer initialized");
// Build components
let system_prompt_text = system_prompt.clone();
let personality = Arc::new(Personality::new(system_prompt));

347
src/orchestrator/engine.rs Normal file
View File

@@ -0,0 +1,347 @@
//! The unified response generation engine.
//!
//! Single implementation of the Mistral Conversations API tool loop.
//! Emits `OrchestratorEvent`s instead of calling Matrix/gRPC directly.
//! Phase 2: replaces `responder.generate_response_conversations()`.
use std::sync::Arc;
use std::time::Duration;
use mistralai_client::v1::conversations::{
ConversationEntry, ConversationInput, ConversationResponse, FunctionResultEntry,
};
use rand::Rng;
use tokio::sync::broadcast;
use tracing::{debug, error, info, warn};
use super::event::*;
use super::tool_dispatch;
use super::Orchestrator;
use crate::brain::personality::Personality;
use crate::context::ResponseContext;
use crate::conversations::ConversationRegistry;
use crate::time_context::TimeContext;
/// Strip a leading "Sol: " / "sol: " (or "Sol:\n" / "sol:\n") prefix that
/// models sometimes prepend to their own replies.
///
/// The input is trimmed of surrounding whitespace first. When a prefix is
/// found, any whitespace remaining between the prefix and the reply body is
/// also removed — this fixes the old off-by-one where the "Sol:\n" branch
/// sliced `[4..]` and left the newline attached to the returned text.
fn strip_sol_prefix(text: &str) -> String {
    let trimmed = text.trim();
    for prefix in ["Sol: ", "sol: ", "Sol:\n", "sol:\n"] {
        if let Some(rest) = trimmed.strip_prefix(prefix) {
            // Drop any leftover leading whitespace so the reply never
            // starts with a stray newline or double space.
            return rest.trim_start().to_string();
        }
    }
    trimmed.to_string()
}
/// Generate a chat response through the Conversations API.
/// This is the unified path that replaces both the responder's conversations
/// method and the gRPC session's inline tool loop.
///
/// Flow: emit `ResponseStarted` → optional humanizing delay → `Thinking` →
/// build a context header (time line + room + memory notes) → send the user
/// message through the `ConversationRegistry` → run the tool loop → emit
/// `ResponseReady` and `MemoryExtractionScheduled` on success, or
/// `ResponseFailed` on any error path.
///
/// Returns the final response text on success, `None` otherwise; failure
/// details reach observers through the emitted events, not the return value.
///
/// NOTE(review): `personality` is not referenced in this body — presumably
/// the system prompt is applied when the conversation is created; confirm
/// before removing the parameter.
pub async fn generate_response(
    orchestrator: &Orchestrator,
    personality: &Personality,
    request: &ChatRequest,
    response_ctx: &ResponseContext,
    conversation_registry: &ConversationRegistry,
) -> Option<String> {
    let request_id = &request.request_id;
    // Announce the request to observers before doing any work.
    orchestrator.emit(OrchestratorEvent::ResponseStarted {
        request_id: request_id.clone(),
        mode: ResponseMode::Chat {
            room_id: request.room_id.clone(),
            is_spontaneous: request.is_spontaneous,
            use_thread: request.use_thread,
            trigger_event_id: request.trigger_event_id.clone(),
        },
    });
    // Humanizing delay: spontaneous messages and triggered replies use
    // separate configured [min, max] ranges. Skipped entirely when
    // `instant_responses` is set (e.g. for tests / CLI use).
    if !orchestrator.config.behavior.instant_responses {
        let delay = if request.is_spontaneous {
            rand::thread_rng().gen_range(
                orchestrator.config.behavior.spontaneous_delay_min_ms
                    ..=orchestrator.config.behavior.spontaneous_delay_max_ms,
            )
        } else {
            rand::thread_rng().gen_range(
                orchestrator.config.behavior.response_delay_min_ms
                    ..=orchestrator.config.behavior.response_delay_max_ms,
            )
        };
        tokio::time::sleep(Duration::from_millis(delay)).await;
    }
    orchestrator.emit(OrchestratorEvent::Thinking {
        request_id: request_id.clone(),
    });
    // Memory query — currently always `None` (see `load_memory_notes`;
    // migration scheduled for Phase 4).
    let memory_notes = load_memory_notes(orchestrator, response_ctx, &request.trigger_body).await;
    // Context header prepended to the user message: timestamp line, room
    // identification, and (eventually) memory notes.
    let tc = TimeContext::now();
    let mut context_header = format!(
        "{}\n[room: {} ({})]",
        tc.message_line(),
        request.room_name,
        request.room_id,
    );
    if let Some(ref notes) = memory_notes {
        context_header.push('\n');
        context_header.push_str(notes);
    }
    // DMs send the raw body; group rooms prefix an IRC-style "<user>" tag.
    // NOTE(review): the tag uses `response_ctx.matrix_user_id` — presumably
    // the triggering user's Matrix ID; verify against how ResponseContext
    // is populated.
    let user_msg = if request.is_dm {
        request.trigger_body.clone()
    } else {
        format!("<{}> {}", response_ctx.matrix_user_id, request.trigger_body)
    };
    let input_text = format!("{context_header}\n{user_msg}");
    let input = ConversationInput::Text(input_text);
    // Send through the per-room conversation registry; an API error here
    // terminates the request with a ResponseFailed event.
    let response = match conversation_registry
        .send_message(
            &request.room_id,
            input,
            request.is_dm,
            &orchestrator.mistral,
            request.context_hint.as_deref(),
        )
        .await
    {
        Ok(r) => r,
        Err(e) => {
            error!("Conversation API failed: {e}");
            orchestrator.emit(OrchestratorEvent::ResponseFailed {
                request_id: request_id.clone(),
                error: e.clone(),
            });
            return None;
        }
    };
    // Drive any tool calls the model requested until it produces text.
    let result = run_tool_loop(
        orchestrator,
        request_id,
        response,
        response_ctx,
        conversation_registry,
        &request.room_id,
        request.is_dm,
    )
    .await;
    match result {
        Some(text) => {
            let text = strip_sol_prefix(&text);
            // An all-whitespace / prefix-only reply counts as a failure.
            if text.is_empty() {
                orchestrator.emit(OrchestratorEvent::ResponseFailed {
                    request_id: request_id.clone(),
                    error: "Empty response from model".into(),
                });
                return None;
            }
            orchestrator.emit(OrchestratorEvent::ResponseReady {
                request_id: request_id.clone(),
                text: text.clone(),
                prompt_tokens: 0, // TODO: extract from response
                completion_tokens: 0,
                tool_iterations: 0,
            });
            // Schedule memory extraction on the (user message, response) pair.
            orchestrator.emit(OrchestratorEvent::MemoryExtractionScheduled {
                request_id: request_id.clone(),
                user_msg: request.trigger_body.clone(),
                response: text.clone(),
            });
            Some(text)
        }
        None => {
            orchestrator.emit(OrchestratorEvent::ResponseFailed {
                request_id: request_id.clone(),
                error: "No response from model".into(),
            });
            None
        }
    }
}
/// The unified tool iteration loop.
///
/// Emits tool events and executes server-side tools; client-side tools are
/// parked on a oneshot registered in the orchestrator's pending-tool map and
/// completed by the gRPC bridge (5-minute timeout).
///
/// Runs at most `config.mistral.max_tool_iterations` rounds. Each round
/// executes every function call in the current model response, then relays
/// the results back through the conversation registry to obtain the next
/// response. Returns the final assistant text, or `None` when relaying
/// results to the API fails.
///
/// NOTE(review): `is_dm` is currently unused in this body — kept for
/// interface stability; confirm whether a future phase needs it here.
async fn run_tool_loop(
    orchestrator: &Orchestrator,
    request_id: &RequestId,
    initial_response: ConversationResponse,
    response_ctx: &ResponseContext,
    conversation_registry: &ConversationRegistry,
    room_id: &str,
    is_dm: bool,
) -> Option<String> {
    let function_calls = initial_response.function_calls();
    // No tool calls — return the text directly.
    if function_calls.is_empty() {
        return initial_response.assistant_text();
    }
    orchestrator.emit(OrchestratorEvent::AgentProgressStarted {
        request_id: request_id.clone(),
    });
    let max_iterations = orchestrator.config.mistral.max_tool_iterations;
    let mut current_response = initial_response;
    for iteration in 0..max_iterations {
        let calls = current_response.function_calls();
        if calls.is_empty() {
            break;
        }
        let mut result_entries = Vec::new();
        for fc in &calls {
            let call_id = fc.tool_call_id.as_deref().unwrap_or("unknown");
            let side = tool_dispatch::route(&fc.name);
            orchestrator.emit(OrchestratorEvent::ToolCallDetected {
                request_id: request_id.clone(),
                call_id: call_id.into(),
                name: fc.name.clone(),
                args: fc.arguments.clone(),
                side: side.clone(),
            });
            orchestrator.emit(OrchestratorEvent::ToolExecutionStarted {
                request_id: request_id.clone(),
                call_id: call_id.into(),
                name: fc.name.clone(),
            });
            let result_str = match side {
                ToolSide::Server => {
                    // All server-side tools currently share the standard
                    // execute path. (A former special case for "research"
                    // was byte-identical to this branch; plumbing room +
                    // event context through for it is still TODO.)
                    match orchestrator
                        .tools
                        .execute(&fc.name, &fc.arguments, response_ctx)
                        .await
                    {
                        Ok(s) => {
                            // Log a bounded preview so huge results don't
                            // flood the log stream.
                            let preview: String = s.chars().take(500).collect();
                            info!(
                                tool = fc.name.as_str(),
                                id = call_id,
                                result_len = s.len(),
                                result_preview = preview.as_str(),
                                "Tool result"
                            );
                            s
                        }
                        Err(e) => {
                            warn!(tool = fc.name.as_str(), "Tool failed: {e}");
                            format!("Error: {e}")
                        }
                    }
                }
                ToolSide::Client => {
                    // Park on a oneshot — the gRPC bridge delivers the result.
                    let rx = orchestrator.register_pending_tool(call_id).await;
                    match tokio::time::timeout(Duration::from_secs(300), rx).await {
                        Ok(Ok(payload)) => {
                            if payload.is_error {
                                format!("Error: {}", payload.text)
                            } else {
                                payload.text
                            }
                        }
                        Ok(Err(_)) => "Error: client tool channel dropped".into(),
                        Err(_) => "Error: client tool timed out (5min)".into(),
                    }
                }
            };
            // Convention: failures are reported inline as "Error: …" strings
            // so the model can react; success is inferred from the prefix.
            let success = !result_str.starts_with("Error:");
            orchestrator.emit(OrchestratorEvent::ToolExecutionCompleted {
                request_id: request_id.clone(),
                call_id: call_id.into(),
                name: fc.name.clone(),
                result: result_str.chars().take(200).collect(),
                success,
            });
            orchestrator.emit(OrchestratorEvent::AgentProgressStep {
                request_id: request_id.clone(),
                summary: crate::agent_ux::AgentProgress::format_tool_call(
                    &fc.name,
                    &fc.arguments,
                ),
            });
            result_entries.push(ConversationEntry::FunctionResult(FunctionResultEntry {
                tool_call_id: call_id.to_string(),
                result: result_str,
                id: None,
                object: None,
                created_at: None,
                completed_at: None,
            }));
        }
        // Relay this round's results back to the conversation to obtain the
        // model's next response (more tool calls, or the final text).
        current_response = match conversation_registry
            .send_function_result(room_id, result_entries, &orchestrator.mistral)
            .await
        {
            Ok(r) => r,
            Err(e) => {
                error!("Failed to send function results: {e}");
                orchestrator.emit(OrchestratorEvent::AgentProgressDone {
                    request_id: request_id.clone(),
                });
                return None;
            }
        };
        debug!(iteration, "Tool iteration complete");
    }
    orchestrator.emit(OrchestratorEvent::AgentProgressDone {
        request_id: request_id.clone(),
    });
    current_response.assistant_text()
}
/// Load memory notes relevant to the trigger message.
///
/// Currently a stub that always returns `None`: memory loading has not yet
/// been migrated into the orchestrator, and the Matrix bridge path still
/// goes through the responder's own `load_memory_notes()`.
///
/// TODO (Phase 4): move the full memory::store query logic here when the
/// Responder is dissolved. The unused parameters keep the eventual
/// signature stable for callers in the meantime.
async fn load_memory_notes(
    _orchestrator: &Orchestrator,
    _ctx: &ResponseContext,
    _trigger_body: &str,
) -> Option<String> {
    // Memory loading is not yet migrated to the orchestrator.
    // The responder's load_memory_notes() still handles this for now.
    None
}

View File

@@ -6,7 +6,9 @@
//!
//! Phase 1: types + channel wiring only. No behavior change.
pub mod engine;
pub mod event;
pub mod tool_dispatch;
use std::collections::HashMap;
use std::sync::Arc;

View File

@@ -0,0 +1,49 @@
//! Tool routing — determines whether a tool executes on the server or a connected client.
use super::event::ToolSide;
/// Client-side tools that execute on the `sunbeam code` TUI client.
const CLIENT_TOOLS: &[&str] = &[
    "file_read",
    "file_write",
    "search_replace",
    "grep",
    "bash",
    "list_directory",
    "ask_user",
];

/// Route a tool call to server or client.
///
/// Any name not listed in [`CLIENT_TOOLS`] — including unknown tools —
/// defaults to server-side execution.
pub fn route(tool_name: &str) -> ToolSide {
    let is_client = CLIENT_TOOLS.iter().any(|&known| known == tool_name);
    if is_client {
        ToolSide::Client
    } else {
        ToolSide::Server
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every tool in the client allowlist routes to the client side.
    #[test]
    fn test_client_tools() {
        let client_side = [
            "file_read",
            "bash",
            "grep",
            "file_write",
            "search_replace",
            "list_directory",
            "ask_user",
        ];
        for name in client_side {
            assert_eq!(route(name), ToolSide::Client, "{name} should be client-side");
        }
    }

    /// Server tools — and anything unrecognized — route to the server side.
    #[test]
    fn test_server_tools() {
        let server_side = [
            "search_archive",
            "search_web",
            "run_script",
            "research",
            "gitea_list_repos",
            "unknown_tool",
        ];
        for name in server_side {
            assert_eq!(route(name), ToolSide::Server, "{name} should be server-side");
        }
    }
}

123
src/tokenizer.rs Normal file
View File

@@ -0,0 +1,123 @@
use std::sync::Arc;
use anyhow::{Context, Result};
use tokenizers::Tokenizer;
use tracing::{info, warn};
/// Default HuggingFace pretrained tokenizer identifier for Mistral models.
/// Used as the fallback when no local `tokenizer.json` is configured or it
/// fails to load (see `SolTokenizer::new`).
const DEFAULT_PRETRAINED: &str = "mistralai/Mistral-Small-24B-Base-2501";
/// Thread-safe wrapper around HuggingFace's `Tokenizer`.
///
/// Load once at startup via [`SolTokenizer::new`] and share as `Arc<SolTokenizer>`.
#[derive(Clone)]
pub struct SolTokenizer {
    // Arc so cloning the wrapper is a cheap refcount bump rather than a
    // copy of the underlying vocabulary tables.
    inner: Arc<Tokenizer>,
}
impl SolTokenizer {
    /// Load a tokenizer from a local `tokenizer.json` path, falling back to
    /// a HuggingFace Hub pretrained download when the path is absent or the
    /// file fails to load (the failure is logged, not propagated).
    pub fn new(tokenizer_path: Option<&str>) -> Result<Self> {
        // Try the configured local file first; a load failure demotes us to
        // the pretrained fallback rather than erroring out.
        let local = tokenizer_path.and_then(|path| match Tokenizer::from_file(path) {
            Ok(t) => {
                info!(path, "Loaded tokenizer from local file");
                Some(t)
            }
            Err(e) => {
                warn!(path, error = %e, "Failed to load local tokenizer, falling back to pretrained");
                None
            }
        });
        let tokenizer = match local {
            Some(t) => t,
            None => Self::from_pretrained()?,
        };
        Ok(Self {
            inner: Arc::new(tokenizer),
        })
    }
    /// Download tokenizer from HuggingFace Hub.
    fn from_pretrained() -> Result<Tokenizer> {
        info!(model = DEFAULT_PRETRAINED, "Downloading tokenizer from HuggingFace Hub");
        let tokenizer = Tokenizer::from_pretrained(DEFAULT_PRETRAINED, None)
            .map_err(|e| anyhow::anyhow!("{e}"))
            .context("Failed to download pretrained tokenizer")?;
        Ok(tokenizer)
    }
    /// Count the number of tokens in the given text.
    ///
    /// Never fails: if tokenization errors, falls back to a rough byte-based
    /// estimate (~4 chars per token for English text) with a warning.
    pub fn count_tokens(&self, text: &str) -> usize {
        self.inner
            .encode(text, false)
            .map(|encoding| encoding.get_ids().len())
            .unwrap_or_else(|e| {
                warn!(error = %e, "Tokenization failed, estimating from char count");
                // Rough fallback: ~4 chars per token for English text
                text.len() / 4
            })
    }
    /// Encode text and return the token IDs.
    ///
    /// # Errors
    /// Returns an error when the underlying tokenizer fails to encode.
    pub fn encode(&self, text: &str) -> Result<Vec<u32>> {
        match self.inner.encode(text, false) {
            Ok(encoding) => Ok(encoding.get_ids().to_vec()),
            Err(e) => Err(anyhow::anyhow!("{e}")).context("Tokenization failed"),
        }
    }
}
impl std::fmt::Debug for SolTokenizer {
    /// The inner `Tokenizer` has no useful Debug output, so render just the
    /// type name — the same text a field-less `debug_struct(..).finish()`
    /// produces.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("SolTokenizer")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    /// Test that the pretrained tokenizer can be loaded and produces
    /// reasonable token counts. This test requires network access on
    /// first run (the tokenizer is cached locally afterwards).
    ///
    /// NOTE(review): all tests in this module that call `SolTokenizer::new`
    /// share that first-run network dependency — they will fail offline
    /// until the HuggingFace cache is populated.
    #[test]
    fn test_pretrained_tokenizer_loads() {
        let tok = SolTokenizer::new(None).expect("pretrained tokenizer should load");
        let count = tok.count_tokens("Hello, world!");
        assert!(count > 0, "token count should be positive");
        assert!(count < 20, "token count for a short sentence should be small");
    }
    // Empty input must produce zero tokens (encode is called without
    // special tokens, so nothing is prepended).
    #[test]
    fn test_count_tokens_empty_string() {
        let tok = SolTokenizer::new(None).expect("pretrained tokenizer should load");
        let count = tok.count_tokens("");
        assert_eq!(count, 0);
    }
    // `encode` should surface actual token IDs, not an empty vector.
    #[test]
    fn test_encode_returns_ids() {
        let tok = SolTokenizer::new(None).expect("pretrained tokenizer should load");
        let ids = tok.encode("Hello, world!").expect("encode should succeed");
        assert!(!ids.is_empty());
    }
    // A bad local path must not be fatal — `new` falls back to pretrained.
    #[test]
    fn test_invalid_path_falls_back_to_pretrained() {
        let tok = SolTokenizer::new(Some("/nonexistent/tokenizer.json"))
            .expect("should fall back to pretrained");
        let count = tok.count_tokens("fallback test");
        assert!(count > 0);
    }
    // Sanity check that counts scale with input length (monotonicity on a
    // trivially longer input, not an exact-count assertion).
    #[test]
    fn test_longer_text_produces_more_tokens() {
        let tok = SolTokenizer::new(None).expect("pretrained tokenizer should load");
        let short = tok.count_tokens("Hi");
        let long = tok.count_tokens("This is a much longer sentence with many more words in it.");
        assert!(long > short, "longer text should produce more tokens");
    }
}