//! sol/src/main.rs — entry point for Sol.
//!
//! Wires together the Matrix client, OpenSearch archive, and Mistral API,
//! backfills context, then runs the sync loop until Ctrl-C.
mod archive;
mod brain;
mod config;
mod context;
mod matrix_utils;
mod memory;
mod sync;
mod tools;
use std::sync::Arc;
use matrix_sdk::Client;
use opensearch::http::transport::TransportBuilder;
use opensearch::OpenSearch;
use ruma::{OwnedDeviceId, OwnedUserId};
use tokio::signal;
use tokio::sync::Mutex;
use tracing::{error, info};
use url::Url;
use archive::indexer::Indexer;
use archive::schema::create_index_if_not_exists;
use brain::conversation::{ContextMessage, ConversationManager};
use memory::schema::create_index_if_not_exists as create_memory_index;
use brain::evaluator::Evaluator;
use brain::personality::Personality;
use brain::responder::Responder;
use config::Config;
use sync::AppState;
use tools::ToolRegistry;
/// Process entry point.
///
/// Start-up order matters: config and secrets first, then Matrix session
/// restore, then OpenSearch index creation and backfills, and only then the
/// background sync loop — so Sol has context before it starts responding.
///
/// Required env vars: SOL_MATRIX_ACCESS_TOKEN, SOL_MATRIX_DEVICE_ID,
/// SOL_MISTRAL_API_KEY. Optional: SOL_CONFIG, SOL_SYSTEM_PROMPT (paths,
/// defaulting under /etc/sol).
#[tokio::main]
async fn main() -> anyhow::Result<()> {
// Initialize tracing: an env-provided filter wins, otherwise "sol=info".
tracing_subscriber::fmt()
.with_env_filter(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("sol=info")),
)
.init();
// Load config (path overridable via SOL_CONFIG).
let config_path =
std::env::var("SOL_CONFIG").unwrap_or_else(|_| "/etc/sol/sol.toml".into());
let config = Config::load(&config_path)?;
info!("Loaded config from {config_path}");
// Load system prompt (path overridable via SOL_SYSTEM_PROMPT).
let prompt_path = std::env::var("SOL_SYSTEM_PROMPT")
.unwrap_or_else(|_| "/etc/sol/system_prompt.md".into());
let system_prompt = std::fs::read_to_string(&prompt_path)?;
info!("Loaded system prompt from {prompt_path}");
// Read secrets from environment; all three are mandatory, each mapped to a
// descriptive error so the operator knows exactly which one is missing.
let access_token = std::env::var("SOL_MATRIX_ACCESS_TOKEN")
.map_err(|_| anyhow::anyhow!("SOL_MATRIX_ACCESS_TOKEN not set"))?;
let device_id = std::env::var("SOL_MATRIX_DEVICE_ID")
.map_err(|_| anyhow::anyhow!("SOL_MATRIX_DEVICE_ID not set"))?;
let mistral_api_key = std::env::var("SOL_MISTRAL_API_KEY")
.map_err(|_| anyhow::anyhow!("SOL_MISTRAL_API_KEY not set"))?;
let config = Arc::new(config);
// Initialize Matrix client with E2EE and sqlite store (None = no passphrase).
let homeserver = Url::parse(&config.matrix.homeserver_url)?;
let matrix_client = Client::builder()
.homeserver_url(homeserver)
.sqlite_store(&config.matrix.state_store_path, None)
.build()
.await?;
// Restore session from the stored access token + device id (no password login).
let user_id: OwnedUserId = config.matrix.user_id.parse()?;
// Shadowing: convert the env-var String into a typed OwnedDeviceId.
let device_id: OwnedDeviceId = device_id.into();
let session = matrix_sdk::AuthSession::Matrix(matrix_sdk::matrix_auth::MatrixSession {
meta: matrix_sdk::SessionMeta {
user_id,
device_id,
},
tokens: matrix_sdk::matrix_auth::MatrixSessionTokens {
access_token,
refresh_token: None,
},
});
matrix_client.restore_session(session).await?;
info!(user = %config.matrix.user_id, "Matrix session restored");
// Initialize OpenSearch client against a single node (no auth configured here).
let os_url = Url::parse(&config.opensearch.url)?;
let os_transport = TransportBuilder::new(
opensearch::http::transport::SingleNodeConnectionPool::new(os_url),
)
.build()?;
let os_client = OpenSearch::new(os_transport);
// Ensure both the archive and memory indices exist before any reads/writes.
create_index_if_not_exists(&os_client, &config.opensearch.index).await?;
create_memory_index(&os_client, &config.opensearch.memory_index).await?;
// Initialize Mistral client (only the API key is supplied; other options default).
let mistral_client = mistralai_client::v1::client::Client::new(
Some(mistral_api_key),
None,
None,
None,
)?;
let mistral = Arc::new(mistral_client);
// Build components shared across the sync loop and responders.
let personality = Arc::new(Personality::new(system_prompt));
let conversations = Arc::new(Mutex::new(ConversationManager::new(
config.behavior.room_context_window,
config.behavior.dm_context_window,
)));
// Backfill conversation context from archive before starting; errors are
// logged and ignored so a cold/empty archive never blocks start-up.
if config.behavior.backfill_on_join {
info!("Backfilling conversation context from archive...");
if let Err(e) = backfill_conversations(&os_client, &config, &conversations).await {
error!("Backfill failed (non-fatal): {e}");
}
}
let tool_registry = Arc::new(ToolRegistry::new(
os_client.clone(),
matrix_client.clone(),
config.clone(),
));
let indexer = Arc::new(Indexer::new(os_client.clone(), config.clone()));
let evaluator = Arc::new(Evaluator::new(config.clone()));
let responder = Arc::new(Responder::new(
config.clone(),
personality,
tool_registry,
os_client.clone(),
));
// Start background flush task. The handle is intentionally unused: dropping
// a tokio JoinHandle detaches the task rather than cancelling it.
let _flush_handle = indexer.start_flush_task();
// Build shared state handed to the sync loop and event handlers.
let state = Arc::new(AppState {
config: config.clone(),
indexer,
evaluator,
responder,
conversations,
mistral,
opensearch: os_client,
last_response: Arc::new(tokio::sync::Mutex::new(std::collections::HashMap::new())),
responding_in: Arc::new(tokio::sync::Mutex::new(std::collections::HashSet::new())),
});
// Backfill reactions from Matrix room timelines (best-effort, non-fatal).
info!("Backfilling reactions from room timelines...");
if let Err(e) = backfill_reactions(&matrix_client, &state.indexer).await {
error!("Reaction backfill failed (non-fatal): {e}");
}
// Start sync loop in background; its errors are logged, not propagated.
let sync_client = matrix_client.clone();
let sync_state = state.clone();
let sync_handle = tokio::spawn(async move {
if let Err(e) = sync::start_sync(sync_client, sync_state).await {
error!("Sync loop error: {e}");
}
});
info!("Sol is running");
// Block until a shutdown signal (Ctrl-C / SIGINT) arrives.
signal::ctrl_c().await?;
info!("Shutdown signal received");
// Cancel sync. abort() drops any in-flight work in the task without awaiting it.
sync_handle.abort();
info!("Sol has shut down");
Ok(())
}
/// Backfill conversation context from the OpenSearch archive.
///
/// Queries the most recent messages per room (newest-first, then reversed
/// back into chronological order) and seeds the ConversationManager so Sol
/// has context surviving restarts.
///
/// # Errors
///
/// Propagates OpenSearch transport/query failures; the caller treats them as
/// non-fatal and continues start-up.
async fn backfill_conversations(
    os_client: &OpenSearch,
    config: &Config,
    conversations: &Arc<Mutex<ConversationManager>>,
) -> anyhow::Result<()> {
    use serde_json::json;
    // We can't tell DMs from group rooms at this point, so fetch enough
    // messages for the larger of the two context windows.
    let window = config.behavior.room_context_window.max(config.behavior.dm_context_window);
    let index = &config.opensearch.index;
    // Enumerate distinct rooms via a terms aggregation (capped at 500 rooms).
    let agg_body = json!({
        "size": 0,
        "aggs": {
            "rooms": {
                "terms": { "field": "room_id", "size": 500 }
            }
        }
    });
    let response = os_client
        .search(opensearch::SearchParts::Index(&[index]))
        .body(agg_body)
        .send()
        .await?;
    let body: serde_json::Value = response.json().await?;
    let buckets = body["aggregations"]["rooms"]["buckets"]
        .as_array()
        .cloned()
        .unwrap_or_default();
    let mut total = 0;
    for bucket in &buckets {
        let room_id = bucket["key"].as_str().unwrap_or("");
        if room_id.is_empty() {
            continue;
        }
        // Fetch the most recent `window` messages for this room. Sorting
        // newest-first is essential: with an ascending sort, `size: window`
        // would return the OLDEST messages in the room, not the latest.
        let query = json!({
            "size": window,
            "sort": [{ "timestamp": "desc" }],
            "query": {
                "bool": {
                    "filter": [
                        { "term": { "room_id": room_id } },
                        { "term": { "redacted": false } }
                    ]
                }
            },
            "_source": ["sender_name", "sender", "content", "timestamp"]
        });
        let resp = os_client
            .search(opensearch::SearchParts::Index(&[index]))
            .body(query)
            .send()
            .await?;
        let data: serde_json::Value = resp.json().await?;
        let mut hits = data["hits"]["hits"].as_array().cloned().unwrap_or_default();
        if hits.is_empty() {
            continue;
        }
        // Restore chronological order (oldest first) before seeding context.
        hits.reverse();
        let mut convs = conversations.lock().await;
        for hit in &hits {
            let src = &hit["_source"];
            // Prefer the stored display name; fall back to the raw sender id.
            let sender = src["sender_name"]
                .as_str()
                .or_else(|| src["sender"].as_str())
                .unwrap_or("unknown");
            let content = src["content"].as_str().unwrap_or("");
            let timestamp = src["timestamp"].as_i64().unwrap_or(0);
            convs.add_message(
                room_id,
                false, // we don't know if it's a DM from the archive, use group window
                ContextMessage {
                    sender: sender.to_string(),
                    content: content.to_string(),
                    timestamp,
                },
            );
            total += 1;
        }
    }
    info!(rooms = buckets.len(), messages = total, "Backfill complete");
    Ok(())
}
/// Backfill reactions from Matrix room timelines into the archive.
///
/// For each joined room, fetches up to 500 recent timeline events (paginating
/// backwards from the live edge) and hands every original m.reaction event to
/// the indexer. NOTE(review): dedup against reactions already in the archive
/// presumably happens inside `Indexer::add_reaction` — confirm there.
async fn backfill_reactions(
client: &Client,
indexer: &Arc<Indexer>,
) -> anyhow::Result<()> {
use matrix_sdk::room::MessagesOptions;
use ruma::events::AnySyncTimelineEvent;
use ruma::uint;
let rooms = client.joined_rooms();
let mut total = 0;
for room in &rooms {
let room_id = room.room_id().to_string();
// Fetch recent messages (backwards from now); a failed room is logged and
// skipped so one bad room can't abort the whole backfill.
let mut options = MessagesOptions::backward();
options.limit = uint!(500);
let messages = match room.messages(options).await {
Ok(m) => m,
Err(e) => {
error!(room = room_id.as_str(), "Failed to fetch timeline for reaction backfill: {e}");
continue;
}
};
for event in &messages.chunk {
// Skip events that fail to deserialize (unknown or malformed types).
let Ok(deserialized) = event.raw().deserialize() else {
continue;
};
// Only m.reaction events are of interest here.
if let AnySyncTimelineEvent::MessageLike(
ruma::events::AnySyncMessageLikeEvent::Reaction(reaction_event),
) = deserialized
{
// Redacted reactions carry no content (target/key), so skip them.
let original = match reaction_event {
ruma::events::SyncMessageLikeEvent::Original(ref o) => o,
_ => continue,
};
// relates_to holds the annotated event id and the reaction emoji key.
let target_event_id = original.content.relates_to.event_id.to_string();
let sender = original.sender.to_string();
let emoji = &original.content.relates_to.key;
// origin_server_ts is milliseconds since the Unix epoch (ruma UInt -> i64).
let timestamp: i64 = original.origin_server_ts.0.into();
indexer.add_reaction(&target_event_id, &sender, emoji, timestamp).await;
total += 1;
}
}
}
info!(reactions = total, rooms = rooms.len(), "Reaction backfill complete");
Ok(())
}