Files
sol/src/brain/responder.rs

180 lines
5.9 KiB
Rust
Raw Normal View History

use std::sync::Arc;
use mistralai_client::v1::{
chat::{ChatMessage, ChatParams, ChatResponse, ChatResponseChoiceFinishReason},
constants::Model,
error::ApiError,
tool::ToolChoice,
};
use rand::Rng;
use tokio::time::{sleep, Duration};
use tracing::{debug, error, info, warn};
use crate::brain::conversation::ContextMessage;
use crate::brain::personality::Personality;
use crate::config::Config;
use crate::tools::ToolRegistry;
/// Performs a Mistral chat completion on tokio's blocking thread pool.
///
/// The client's `chat_async` keeps a `std::sync::MutexGuard` alive across an
/// `.await`, which makes its future `!Send`. To stay usable from `Send`
/// futures, this wrapper calls the synchronous `chat()` inside
/// `spawn_blocking` and awaits the join handle instead.
///
/// # Errors
///
/// Returns whatever `ApiError` the client's `chat()` produced. If the blocking
/// task itself cannot be joined, the join error is reported as an `ApiError`
/// whose message starts with `spawn_blocking join error:`.
pub(crate) async fn chat_blocking(
    client: &Arc<mistralai_client::v1::client::Client>,
    model: Model,
    messages: Vec<ChatMessage>,
    params: ChatParams,
) -> Result<ChatResponse, ApiError> {
    // The closure must own everything it touches, so give it its own handle.
    let worker_client = Arc::clone(client);
    let handle = tokio::task::spawn_blocking(move || {
        worker_client.chat(model, messages, Some(params))
    });

    match handle.await {
        Ok(chat_result) => chat_result,
        Err(e) => Err(ApiError {
            message: format!("spawn_blocking join error: {e}"),
        }),
    }
}
/// Generates chat replies from conversation context.
///
/// Holds shared handles to the configuration, the personality used to build
/// the system prompt, and the registry that executes tool calls. The Mistral
/// client itself is not stored here; it is passed to `generate_response` on
/// each call.
pub struct Responder {
/// Shared configuration: supplies the reply delay bounds, the Matrix user id
/// used to recognise the responder's own messages, the default Mistral model
/// name, and the maximum number of tool iterations.
config: Arc<Config>,
/// Builds the system prompt that opens every conversation sent to the model.
personality: Arc<Personality>,
/// Executes the tool calls requested by the model.
tools: Arc<ToolRegistry>,
}
impl Responder {
pub fn new(
config: Arc<Config>,
personality: Arc<Personality>,
tools: Arc<ToolRegistry>,
) -> Self {
Self {
config,
personality,
tools,
}
}
pub async fn generate_response(
&self,
context: &[ContextMessage],
trigger_body: &str,
trigger_sender: &str,
room_name: &str,
members: &[String],
is_spontaneous: bool,
mistral: &Arc<mistralai_client::v1::client::Client>,
) -> Option<String> {
// Apply response delay
let delay = if is_spontaneous {
rand::thread_rng().gen_range(
self.config.behavior.spontaneous_delay_min_ms
..=self.config.behavior.spontaneous_delay_max_ms,
)
} else {
rand::thread_rng().gen_range(
self.config.behavior.response_delay_min_ms
..=self.config.behavior.response_delay_max_ms,
)
};
sleep(Duration::from_millis(delay)).await;
let system_prompt = self.personality.build_system_prompt(room_name, members);
let mut messages = vec![ChatMessage::new_system_message(&system_prompt)];
// Add context messages
for msg in context {
if msg.sender == self.config.matrix.user_id {
messages.push(ChatMessage::new_assistant_message(&msg.content, None));
} else {
let user_msg = format!("{}: {}", msg.sender, msg.content);
messages.push(ChatMessage::new_user_message(&user_msg));
}
}
// Add the triggering message
let trigger = format!("{trigger_sender}: {trigger_body}");
messages.push(ChatMessage::new_user_message(&trigger));
let tool_defs = ToolRegistry::tool_definitions();
let model = Model::new(&self.config.mistral.default_model);
let max_iterations = self.config.mistral.max_tool_iterations;
for iteration in 0..=max_iterations {
let params = ChatParams {
tools: if iteration < max_iterations {
Some(tool_defs.clone())
} else {
None
},
tool_choice: if iteration < max_iterations {
Some(ToolChoice::Auto)
} else {
None
},
..Default::default()
};
let response = match chat_blocking(mistral, model.clone(), messages.clone(), params).await {
Ok(r) => r,
Err(e) => {
error!("Mistral chat failed: {e}");
return None;
}
};
let choice = &response.choices[0];
if choice.finish_reason == ChatResponseChoiceFinishReason::ToolCalls {
if let Some(tool_calls) = &choice.message.tool_calls {
// Add assistant message with tool calls
messages.push(ChatMessage::new_assistant_message(
&choice.message.content,
Some(tool_calls.clone()),
));
for tc in tool_calls {
let call_id = tc.id.as_deref().unwrap_or("unknown");
info!(
tool = tc.function.name.as_str(),
id = call_id,
"Executing tool call"
);
let result = self
.tools
.execute(&tc.function.name, &tc.function.arguments)
.await;
let result_str = match result {
Ok(s) => s,
Err(e) => {
warn!(tool = tc.function.name.as_str(), "Tool failed: {e}");
format!("Error: {e}")
}
};
messages.push(ChatMessage::new_tool_message(
&result_str,
call_id,
Some(&tc.function.name),
));
}
debug!(iteration, "Tool iteration complete, continuing");
continue;
}
}
// Final text response
let text = choice.message.content.trim().to_string();
if text.is_empty() {
return None;
}
return Some(text);
}
warn!("Exceeded max tool iterations");
None
}
}