initial commit
Signed-off-by: Sienna Meridian Satterwhite <sienna@r3t.io>
This commit is contained in:
23
src/api/config.rs
Normal file
23
src/api/config.rs
Normal file
@@ -0,0 +1,23 @@
|
||||
use actix_web::web;
|
||||
use super::handlers::{health_check, store_fact, search_facts, list_facts, delete_fact};
|
||||
use super::mcp_http::{mcp_post, mcp_get, mcp_delete};
|
||||
|
||||
pub fn configure_api(cfg: &mut web::ServiceConfig) {
|
||||
// MCP Streamable HTTP transport endpoint
|
||||
cfg.service(
|
||||
web::resource("/mcp")
|
||||
.route(web::post().to(mcp_post))
|
||||
.route(web::get().to(mcp_get))
|
||||
.route(web::delete().to(mcp_delete))
|
||||
);
|
||||
|
||||
// REST convenience API (not MCP — useful for curl / testing)
|
||||
cfg.service(
|
||||
web::scope("/api")
|
||||
.route("/health", web::get().to(health_check))
|
||||
.route("/facts", web::post().to(store_fact))
|
||||
.route("/facts/search", web::get().to(search_facts))
|
||||
.route("/facts/{namespace}", web::get().to(list_facts))
|
||||
.route("/facts/{id}", web::delete().to(delete_fact))
|
||||
);
|
||||
}
|
||||
94
src/api/handlers.rs
Normal file
94
src/api/handlers.rs
Normal file
@@ -0,0 +1,94 @@
|
||||
// API Handlers
|
||||
use actix_web::{web, HttpResponse};
|
||||
use crate::api::types::{
|
||||
FactRequest, FactResponse, SearchParams, ListParams,
|
||||
SearchResponse, ErrorResponse, HealthResponse,
|
||||
};
|
||||
use crate::memory::service::MemoryService;
|
||||
use crate::mcp::server::parse_ts;
|
||||
|
||||
fn fact_to_response(fact: crate::memory::service::MemoryFact, score: Option<f32>) -> FactResponse {
|
||||
FactResponse {
|
||||
id: fact.id,
|
||||
namespace: fact.namespace,
|
||||
content: fact.content,
|
||||
created_at: fact.created_at,
|
||||
score: score.or(if fact.score != 0.0 { Some(fact.score) } else { None }),
|
||||
source: fact.source,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn health_check() -> HttpResponse {
|
||||
HttpResponse::Ok().json(HealthResponse {
|
||||
status: "healthy".to_string(),
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn store_fact(
|
||||
memory: web::Data<MemoryService>,
|
||||
body: web::Json<FactRequest>,
|
||||
) -> HttpResponse {
|
||||
let namespace = body.namespace.as_deref().unwrap_or("default");
|
||||
match memory.add_fact(namespace, &body.content, body.source.as_deref()).await {
|
||||
Ok(fact) => HttpResponse::Created().json(fact_to_response(fact, None)),
|
||||
Err(e) => HttpResponse::InternalServerError()
|
||||
.json(ErrorResponse { error: e.to_string() }),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn search_facts(
|
||||
memory: web::Data<MemoryService>,
|
||||
params: web::Query<SearchParams>,
|
||||
) -> HttpResponse {
|
||||
if params.q.trim().is_empty() {
|
||||
return HttpResponse::BadRequest()
|
||||
.json(ErrorResponse { error: "`q` must not be empty".to_string() });
|
||||
}
|
||||
let limit = params.limit.unwrap_or(10);
|
||||
match memory.search_facts(¶ms.q, limit, params.namespace.as_deref()).await {
|
||||
Ok(results) => {
|
||||
let total = results.len();
|
||||
HttpResponse::Ok().json(SearchResponse {
|
||||
results: results.into_iter().map(|f| fact_to_response(f, None)).collect(),
|
||||
total,
|
||||
})
|
||||
}
|
||||
Err(e) => HttpResponse::InternalServerError()
|
||||
.json(ErrorResponse { error: e.to_string() }),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn list_facts(
|
||||
memory: web::Data<MemoryService>,
|
||||
namespace: web::Path<String>,
|
||||
params: web::Query<ListParams>,
|
||||
) -> HttpResponse {
|
||||
let limit = params.limit.unwrap_or(50);
|
||||
let from_ts = params.from.as_deref().and_then(parse_ts);
|
||||
let to_ts = params.to.as_deref().and_then(parse_ts);
|
||||
match memory.list_facts(&namespace, limit, from_ts, to_ts).await {
|
||||
Ok(facts) => {
|
||||
let total = facts.len();
|
||||
HttpResponse::Ok().json(SearchResponse {
|
||||
results: facts.into_iter().map(|f| fact_to_response(f, None)).collect(),
|
||||
total,
|
||||
})
|
||||
}
|
||||
Err(e) => HttpResponse::InternalServerError()
|
||||
.json(ErrorResponse { error: e.to_string() }),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn delete_fact(
|
||||
memory: web::Data<MemoryService>,
|
||||
id: web::Path<String>,
|
||||
) -> HttpResponse {
|
||||
match memory.delete_fact(&id).await {
|
||||
Ok(true) => HttpResponse::NoContent().finish(),
|
||||
Ok(false) => HttpResponse::NotFound()
|
||||
.json(ErrorResponse { error: format!("fact {id} not found") }),
|
||||
Err(e) => HttpResponse::InternalServerError()
|
||||
.json(ErrorResponse { error: e.to_string() }),
|
||||
}
|
||||
}
|
||||
250
src/api/mcp_http.rs
Normal file
250
src/api/mcp_http.rs
Normal file
@@ -0,0 +1,250 @@
|
||||
// MCP Streamable HTTP transport (protocol version 2025-06-18)
|
||||
//
|
||||
// Spec: https://spec.modelcontextprotocol.io/specification/2025-06-18/basic/transports/
|
||||
//
|
||||
// POST /mcp — client→server JSON-RPC (requests, notifications, responses)
|
||||
// GET /mcp — server→client SSE push (405: we have no server-initiated messages)
|
||||
// DELETE /mcp — explicit session teardown
|
||||
|
||||
use actix_web::{web, HttpRequest, HttpResponse};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Mutex;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::api::oidc::OidcVerifier;
|
||||
use crate::memory::service::MemoryService;
|
||||
use crate::mcp::protocol::{Request, Response, PARSE_ERROR, INVALID_PARAMS};
|
||||
use crate::mcp::server::handle;
|
||||
|
||||
/// Auth configuration, chosen at startup based on environment variables.
///
/// NOTE(review): `config.rs` reads `MCP_AUTH_TOKEN` / `MCP_OIDC_ISSUER`;
/// presumably those select `Bearer` / `Oidc` here — confirm against the
/// startup code that constructs this value.
pub enum AuthConfig {
    /// No auth — bind localhost, check Origin header (default).
    LocalOnly,
    /// Simple shared bearer token — bind 0.0.0.0, skip Origin check.
    Bearer(String),
    /// OIDC JWT verification — bind 0.0.0.0, skip Origin check.
    Oidc(OidcVerifier),
}
|
||||
|
||||
impl AuthConfig {
|
||||
/// Returns true when the server should bind to all interfaces.
|
||||
pub fn is_remote(&self) -> bool {
|
||||
!matches!(self, Self::LocalOnly)
|
||||
}
|
||||
}
|
||||
|
||||
/// MCP protocol revision this server speaks; POSTs advertising a different
/// `MCP-Protocol-Version` header are rejected with 400 in `mcp_post`.
const PROTOCOL_VERSION: &str = "2025-06-18";
|
||||
|
||||
// ── session store ─────────────────────────────────────────────────────────────
|
||||
|
||||
pub struct SessionStore(Mutex<HashSet<String>>);
|
||||
|
||||
impl SessionStore {
|
||||
pub fn new() -> Self {
|
||||
Self(Mutex::new(HashSet::new()))
|
||||
}
|
||||
|
||||
fn create(&self) -> String {
|
||||
let id = Uuid::new_v4().to_string();
|
||||
self.0.lock().unwrap().insert(id.clone());
|
||||
id
|
||||
}
|
||||
|
||||
fn is_valid(&self, id: &str) -> bool {
|
||||
self.0.lock().unwrap().contains(id)
|
||||
}
|
||||
|
||||
fn remove(&self, id: &str) {
|
||||
self.0.lock().unwrap().remove(id);
|
||||
}
|
||||
}
|
||||
|
||||
// ── POST /mcp ─────────────────────────────────────────────────────────────────
|
||||
|
||||
/// POST /mcp — the client→server leg of the MCP Streamable HTTP transport.
///
/// Pipeline: auth → protocol-version check → JSON parse → message
/// classification → session validation → dispatch to `mcp::server::handle`.
/// Responses/notifications from the client are acknowledged with 202 and no
/// body; only requests (messages with a `method` and a non-null `id`) are
/// dispatched. A successful `initialize` mints a new `Mcp-Session-Id`.
pub async fn mcp_post(
    req: HttpRequest,
    body: web::Bytes,
    memory: web::Data<MemoryService>,
    sessions: web::Data<SessionStore>,
    auth: web::Data<AuthConfig>,
) -> HttpResponse {
    // 1. Auth / origin check
    if let Some(err) = check_auth(&req, &auth) {
        return err;
    }

    // 2. Protocol version check (SHOULD per spec)
    // The header is optional; only a *mismatching* value is rejected.
    if let Some(v) = req.headers().get("MCP-Protocol-Version") {
        if v.to_str().unwrap_or("") != PROTOCOL_VERSION {
            return HttpResponse::BadRequest()
                .body(format!("unsupported MCP-Protocol-Version; expected {PROTOCOL_VERSION}"));
        }
    }

    // 3. Parse body
    let text = match std::str::from_utf8(&body) {
        Ok(s) => s,
        Err(_) => return HttpResponse::BadRequest().body("body must be UTF-8"),
    };
    let raw: Value = match serde_json::from_str(text) {
        Ok(v) => v,
        Err(e) => {
            let resp = Response::err(Value::Null, PARSE_ERROR, e.to_string());
            return HttpResponse::BadRequest().json(resp);
        }
    };

    // 4. Classify the message
    // A JSON-RPC request has both `method` and a non-null `id`; a
    // notification has `method` but no id; a response has result/error.
    let has_method = raw.get("method").is_some();
    let has_id = raw.get("id").filter(|v| !v.is_null()).is_some();
    let is_response = !has_method && (raw.get("result").is_some() || raw.get("error").is_some());
    let is_notification = has_method && !has_id;

    if is_response || is_notification {
        return HttpResponse::Accepted().finish();
    }

    if !has_method {
        let resp = Response::err(Value::Null, INVALID_PARAMS, "missing 'method' field".to_string());
        return HttpResponse::BadRequest().json(resp);
    }

    // 5. Session check
    // Every request except `initialize` must carry a live Mcp-Session-Id.
    let method = raw["method"].as_str().unwrap_or("").to_string();
    let session_id = header_str(&req, "Mcp-Session-Id");

    if method.as_str() != "initialize" {
        match session_id.as_deref() {
            Some(id) if sessions.is_valid(id) => {}
            Some(_) => return HttpResponse::NotFound().body("session not found or expired"),
            None => return HttpResponse::BadRequest().body("Mcp-Session-Id required"),
        }
    }

    // 6. Deserialise and dispatch
    let mcp_req: Request = match serde_json::from_value(raw) {
        Ok(r) => r,
        Err(e) => {
            let resp = Response::err(Value::Null, PARSE_ERROR, e.to_string());
            return HttpResponse::BadRequest().json(resp);
        }
    };

    // `handle` returning None means "nothing to send back" → 202.
    let Some(mcp_resp) = handle(&mcp_req, &memory).await else {
        return HttpResponse::Accepted().finish();
    };

    // 7. On successful initialize, issue a new session ID
    let mut builder = HttpResponse::Ok();
    builder.content_type("application/json");

    if method.as_str() == "initialize" && mcp_resp.error.is_none() {
        let sid = sessions.create();
        builder.insert_header(("Mcp-Session-Id", sid));
    }

    builder.json(mcp_resp)
}
|
||||
|
||||
// ── GET /mcp ──────────────────────────────────────────────────────────────────
|
||||
|
||||
pub async fn mcp_get(req: HttpRequest, auth: web::Data<AuthConfig>) -> HttpResponse {
|
||||
if let Some(err) = check_auth(&req, &auth) {
|
||||
return err;
|
||||
}
|
||||
HttpResponse::MethodNotAllowed()
|
||||
.insert_header(("Allow", "POST, DELETE"))
|
||||
.finish()
|
||||
}
|
||||
|
||||
// ── DELETE /mcp ───────────────────────────────────────────────────────────────
|
||||
|
||||
pub async fn mcp_delete(
|
||||
req: HttpRequest,
|
||||
sessions: web::Data<SessionStore>,
|
||||
auth: web::Data<AuthConfig>,
|
||||
) -> HttpResponse {
|
||||
if let Some(err) = check_auth(&req, &auth) {
|
||||
return err;
|
||||
}
|
||||
if let Some(id) = header_str(&req, "Mcp-Session-Id") {
|
||||
sessions.remove(&id);
|
||||
}
|
||||
HttpResponse::Ok().finish()
|
||||
}
|
||||
|
||||
// ── auth helpers ──────────────────────────────────────────────────────────────
|
||||
|
||||
/// Returns `Some(error response)` if the request fails auth, `None` if it passes.
|
||||
fn check_auth(req: &HttpRequest, auth: &AuthConfig) -> Option<HttpResponse> {
|
||||
match auth {
|
||||
AuthConfig::LocalOnly => check_origin(req),
|
||||
|
||||
AuthConfig::Bearer(expected) => {
|
||||
if check_bearer(req, expected) {
|
||||
None
|
||||
} else {
|
||||
Some(
|
||||
HttpResponse::Unauthorized()
|
||||
.insert_header(("WWW-Authenticate", "Bearer"))
|
||||
.body("invalid or missing Authorization header"),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
AuthConfig::Oidc(verifier) => {
|
||||
let token = match extract_bearer(req) {
|
||||
Some(t) => t,
|
||||
None => {
|
||||
return Some(
|
||||
HttpResponse::Unauthorized()
|
||||
.insert_header(("WWW-Authenticate", "Bearer"))
|
||||
.body("Authorization: Bearer <token> required"),
|
||||
)
|
||||
}
|
||||
};
|
||||
match verifier.verify(&token) {
|
||||
Ok(()) => None,
|
||||
Err(e) => Some(
|
||||
HttpResponse::Unauthorized()
|
||||
.insert_header(("WWW-Authenticate", "Bearer"))
|
||||
.body(format!("token rejected: {e}")),
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract the raw token string from `Authorization: Bearer <token>`.
|
||||
fn extract_bearer(req: &HttpRequest) -> Option<String> {
|
||||
let header = req.headers().get("Authorization")?.to_str().ok()?;
|
||||
header.strip_prefix("Bearer ").map(|t| t.to_string())
|
||||
}
|
||||
|
||||
/// Check `Authorization: Bearer <token>` using constant-time comparison.
|
||||
fn check_bearer(req: &HttpRequest, expected: &str) -> bool {
|
||||
let Some(token) = extract_bearer(req) else { return false };
|
||||
if token.len() != expected.len() {
|
||||
return false;
|
||||
}
|
||||
token.bytes().zip(expected.bytes()).fold(0u8, |acc, (a, b)| acc | (a ^ b)) == 0
|
||||
}
|
||||
|
||||
/// Reject cross-origin requests (DNS rebinding protection) — used in LocalOnly mode.
|
||||
/// Requests with no Origin header (e.g. curl, server-to-server) are always allowed.
|
||||
fn check_origin(req: &HttpRequest) -> Option<HttpResponse> {
|
||||
let origin = req.headers().get("Origin")?.to_str().unwrap_or("").to_string();
|
||||
if ["localhost", "127.0.0.1", "::1"].iter().any(|h| origin.contains(h)) {
|
||||
None
|
||||
} else {
|
||||
Some(HttpResponse::Forbidden().body(format!("Origin '{origin}' not allowed")))
|
||||
}
|
||||
}
|
||||
|
||||
fn header_str(req: &HttpRequest, name: &str) -> Option<String> {
|
||||
req.headers()
|
||||
.get(name)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.to_string())
|
||||
}
|
||||
1
src/api/middleware.rs
Normal file
1
src/api/middleware.rs
Normal file
@@ -0,0 +1 @@
|
||||
// Middleware placeholder — auth not yet implemented for HTTP mode.
|
||||
6
src/api/mod.rs
Normal file
6
src/api/mod.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
pub mod config;
|
||||
pub mod handlers;
|
||||
pub mod mcp_http;
|
||||
pub mod middleware;
|
||||
pub mod oidc;
|
||||
pub mod types;
|
||||
146
src/api/oidc.rs
Normal file
146
src/api/oidc.rs
Normal file
@@ -0,0 +1,146 @@
|
||||
// OIDC JWT verification
|
||||
//
|
||||
// Fetches the provider's JWKS at startup via the OpenID Connect discovery document
|
||||
// and validates RS256/ES256/PS256 bearer tokens on every request.
|
||||
//
|
||||
// Environment variables:
|
||||
// MCP_OIDC_ISSUER — issuer URL, e.g. https://auth.example.com
|
||||
// MCP_OIDC_AUDIENCE — (optional) expected `aud` claim
|
||||
|
||||
use jsonwebtoken::{
|
||||
decode, decode_header,
|
||||
jwk::{AlgorithmParameters, JwkSet},
|
||||
Algorithm, DecodingKey, Validation,
|
||||
};
|
||||
use serde_json::Value;
|
||||
|
||||
/// Errors from the OIDC verifier, grouped by the phase that failed.
#[derive(Debug)]
pub enum OidcError {
    /// The discovery document could not be fetched or parsed.
    Discovery(String),
    /// The JWKS could not be fetched or parsed.
    Jwks(String),
    /// A presented bearer token failed verification.
    Token(String),
}
|
||||
|
||||
impl std::fmt::Display for OidcError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Discovery(s) => write!(f, "OIDC discovery: {s}"),
|
||||
Self::Jwks(s) => write!(f, "OIDC JWKS: {s}"),
|
||||
Self::Token(s) => write!(f, "OIDC token: {s}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Verifies OIDC JWT bearer tokens against a JWKS fetched once at startup.
pub struct OidcVerifier {
    /// Expected `iss` claim — the issuer URL supplied at construction.
    issuer: String,
    /// Expected `aud` claim; `None` disables audience validation.
    audience: Option<String>,
    /// Provider key set fetched from the discovery document's `jwks_uri`.
    jwks: JwkSet,
}
|
||||
|
||||
impl OidcVerifier {
    /// Fetch discovery doc and JWKS from the issuer. Called once at startup.
    ///
    /// NOTE(review): keys are fetched once and never refreshed, so provider
    /// key rotation requires a server restart — confirm this is acceptable.
    pub async fn new(issuer: &str, audience: Option<String>) -> Result<Self, OidcError> {
        let client = reqwest::Client::new();

        // 1. Discovery document (trailing slash on the issuer is tolerated)
        let disc_url = format!(
            "{}/.well-known/openid-configuration",
            issuer.trim_end_matches('/')
        );
        let disc: Value = client
            .get(&disc_url)
            .send()
            .await
            .map_err(|e| OidcError::Discovery(e.to_string()))?
            .error_for_status()
            .map_err(|e| OidcError::Discovery(e.to_string()))?
            .json()
            .await
            .map_err(|e| OidcError::Discovery(e.to_string()))?;

        let jwks_uri = disc["jwks_uri"]
            .as_str()
            .ok_or_else(|| OidcError::Discovery("discovery doc missing jwks_uri".to_string()))?;

        // 2. JWKS
        let jwks: JwkSet = client
            .get(jwks_uri)
            .send()
            .await
            .map_err(|e| OidcError::Jwks(e.to_string()))?
            .error_for_status()
            .map_err(|e| OidcError::Jwks(e.to_string()))?
            .json()
            .await
            .map_err(|e| OidcError::Jwks(e.to_string()))?;

        Ok(Self {
            issuer: issuer.to_string(),
            audience,
            jwks,
        })
    }

    /// Verify a raw Bearer token string. Returns `Ok(())` if valid.
    pub fn verify(&self, token: &str) -> Result<(), OidcError> {
        // Decode header to get kid + algorithm
        let header = decode_header(token).map_err(|e| OidcError::Token(e.to_string()))?;

        // NOTE(review): a token with no kid is looked up as "" — confirm
        // `JwkSet::find` behaviour for kid-less keys matches the intent.
        let kid = header.kid.as_deref().unwrap_or("");

        // Find the matching JWK
        let jwk = self
            .jwks
            .find(kid)
            .ok_or_else(|| OidcError::Token(format!("no JWK found for kid={kid:?}")))?;

        // Build the decoding key from the JWK
        let key = match &jwk.algorithm {
            AlgorithmParameters::RSA(rsa) => {
                DecodingKey::from_rsa_components(&rsa.n, &rsa.e)
                    .map_err(|e| OidcError::Token(e.to_string()))?
            }
            AlgorithmParameters::EllipticCurve(ec) => {
                DecodingKey::from_ec_components(&ec.x, &ec.y)
                    .map_err(|e| OidcError::Token(e.to_string()))?
            }
            other => {
                return Err(OidcError::Token(format!(
                    "unsupported JWK algorithm: {other:?}"
                )))
            }
        };

        // The algorithm comes from the (attacker-controlled) token header,
        // so it is restricted below to an explicit allow-list of asymmetric
        // algorithms — symmetric HS* variants are never accepted.
        let alg = header.alg;
        // Only allow standard OIDC signing algorithms
        match alg {
            Algorithm::RS256
            | Algorithm::RS384
            | Algorithm::RS512
            | Algorithm::PS256
            | Algorithm::PS384
            | Algorithm::PS512
            | Algorithm::ES256
            | Algorithm::ES384 => {}
            other => {
                return Err(OidcError::Token(format!(
                    "algorithm {other:?} not permitted"
                )))
            }
        }

        let mut validation = Validation::new(alg);
        validation.set_issuer(&[&self.issuer]);

        if let Some(aud) = &self.audience {
            validation.set_audience(&[aud]);
        } else {
            validation.validate_aud = false;
        }

        // Signature + configured claim checks; the decoded claims themselves
        // are discarded (this verifier only gates access).
        decode::<Value>(token, &key, &validation)
            .map(|_| ())
            .map_err(|e| OidcError::Token(e.to_string()))
    }
}
|
||||
15
src/api/rest.rs
Normal file
15
src/api/rest.rs
Normal file
@@ -0,0 +1,15 @@
|
||||
use actix_web::{web, HttpResponse, Responder};
|
||||
|
||||
pub fn configure_rest(cfg: &mut web::ServiceConfig) {
|
||||
cfg.service(
|
||||
web::resource("/health")
|
||||
.route(web::get().to(health_check))
|
||||
);
|
||||
}
|
||||
|
||||
async fn health_check() -> impl Responder {
|
||||
HttpResponse::Ok().json(serde_json::json!({
|
||||
"status": "healthy",
|
||||
"version": env!("CARGO_PKG_VERSION")
|
||||
}))
|
||||
}
|
||||
50
src/api/types.rs
Normal file
50
src/api/types.rs
Normal file
@@ -0,0 +1,50 @@
|
||||
// API Request/Response Types
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Body of `POST /api/facts`.
#[derive(Debug, Deserialize)]
pub struct FactRequest {
    /// Target namespace; handlers fall back to "default" when omitted.
    pub namespace: Option<String>,
    pub content: String,
    /// Optional provenance label stored alongside the fact.
    pub source: Option<String>,
}

/// Query string of `GET /api/facts/search`.
#[derive(Debug, Deserialize)]
pub struct SearchParams {
    /// Search query; handlers reject an empty/blank value with 400.
    pub q: String,
    /// Maximum number of results; handlers default this to 10.
    pub limit: Option<usize>,
    /// Optional namespace filter.
    pub namespace: Option<String>,
}

/// Query string of `GET /api/facts/{namespace}`.
#[derive(Debug, Deserialize)]
pub struct ListParams {
    /// Maximum number of results; handlers default this to 50.
    pub limit: Option<usize>,
    /// Optional lower time bound (string form, parsed by `parse_ts`).
    pub from: Option<String>,
    /// Optional upper time bound (string form, parsed by `parse_ts`).
    pub to: Option<String>,
}

/// A single fact as returned by every fact endpoint.
#[derive(Debug, Serialize)]
pub struct FactResponse {
    pub id: String,
    pub namespace: String,
    pub content: String,
    pub created_at: String,
    /// Relevance score; `None` when no score was recorded (handlers treat a
    /// stored 0.0 as "no score").
    pub score: Option<f32>,
    pub source: Option<String>,
}

/// Envelope for search and list results.
#[derive(Debug, Serialize)]
pub struct SearchResponse {
    pub results: Vec<FactResponse>,
    /// Number of entries in `results` (not a grand total across pages).
    pub total: usize,
}

/// Uniform error body used by all non-2xx JSON responses.
#[derive(Debug, Serialize)]
pub struct ErrorResponse {
    pub error: String,
}

/// Body of the health endpoints.
#[derive(Debug, Serialize)]
pub struct HealthResponse {
    pub status: String,
    pub version: String,
}
|
||||
119
src/auth.rs
Normal file
119
src/auth.rs
Normal file
@@ -0,0 +1,119 @@
|
||||
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
use crate::error::{Result, ServerError};
|
||||
|
||||
/// JWT claims payload for this server's shared-secret tokens.
#[derive(Debug, Serialize, Deserialize)]
pub struct Claims {
    /// Subject — the user/principal identifier.
    pub sub: String,
    /// Expiry, seconds since the Unix epoch.
    pub exp: usize,
    /// Issued-at, seconds since the Unix epoch.
    pub iat: usize,
    /// Coarse permission strings checked via `has_permission`.
    pub permissions: Vec<String>,
}
|
||||
|
||||
impl Claims {
|
||||
pub fn new(sub: String, permissions: Vec<String>, expires_in: usize) -> Self {
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs() as usize;
|
||||
|
||||
Claims {
|
||||
sub,
|
||||
exp: now + expires_in,
|
||||
iat: now,
|
||||
permissions,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn has_permission(&self, permission: &str) -> bool {
|
||||
self.permissions.iter().any(|p| p == permission)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn generate_jwt(
|
||||
sub: String,
|
||||
permissions: Vec<String>,
|
||||
secret: &str,
|
||||
expires_in: usize,
|
||||
) -> Result<String> {
|
||||
let claims = Claims::new(sub, permissions, expires_in);
|
||||
let token = encode(
|
||||
&Header::default(),
|
||||
&claims,
|
||||
&EncodingKey::from_secret(secret.as_bytes()),
|
||||
);
|
||||
|
||||
token.map_err(|e| ServerError::AuthError(e.to_string()))
|
||||
}
|
||||
|
||||
pub fn validate_jwt(token: &str, secret: &str) -> Result<Claims> {
|
||||
let mut validation = Validation::default();
|
||||
validation.validate_exp = true; // Enable expiration validation
|
||||
|
||||
let decoded = decode::<Claims>(
|
||||
token,
|
||||
&DecodingKey::from_secret(secret.as_bytes()),
|
||||
&validation,
|
||||
);
|
||||
|
||||
decoded
|
||||
.map(|token_data| token_data.claims)
|
||||
.map_err(|e| ServerError::AuthError(e.to_string()))
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Round-trip: a generated token validates with the same secret and
    // carries the subject and permissions it was built with.
    #[test]
    fn test_jwt_generation_and_validation() {
        let secret = "test-secret";
        let sub = "user123".to_string();
        let permissions = vec!["read".to_string(), "write".to_string()];

        let token = generate_jwt(sub.clone(), permissions.clone(), secret, 3600).unwrap();
        let claims = validate_jwt(&token, secret).unwrap();

        assert_eq!(claims.sub, sub);
        assert!(claims.has_permission("read"));
        assert!(claims.has_permission("write"));
        assert!(!claims.has_permission("delete"));
    }

    // A token signed with one secret must not validate with another.
    #[test]
    fn test_jwt_with_invalid_secret() {
        let token = generate_jwt("user123".to_string(), vec![], "secret1", 3600).unwrap();
        let result = validate_jwt(&token, "wrong-secret");
        assert!(result.is_err());
    }

    // Hand-craft claims whose exp lies in the past and confirm rejection.
    #[test]
    fn test_jwt_expiration() {
        let secret = "test-secret";
        let sub = "user123".to_string();
        let permissions = vec!["read".to_string()];

        // Generate token that expires in the past
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs() as usize;

        let claims = Claims {
            sub: sub.clone(),
            exp: now - 1000, // Expired 1000 seconds ago
            iat: now,
            permissions: permissions.clone(),
        };

        let token = encode(
            &Header::default(),
            &claims,
            &EncodingKey::from_secret(secret.as_bytes()),
        ).unwrap();

        let result = validate_jwt(&token, secret);
        assert!(result.is_err());
    }
}
|
||||
37
src/config.rs
Normal file
37
src/config.rs
Normal file
@@ -0,0 +1,37 @@
|
||||
/// Server configuration, typically populated from environment variables via
/// [`MemoryConfig::from_env`].
#[derive(Debug, Clone)]
pub struct MemoryConfig {
    /// Root directory for on-disk storage. Read from `MCP_MEMORY_BASE_DIR`.
    pub base_dir: String,
    /// Simple bearer token auth. Set `MCP_AUTH_TOKEN` for single-user remote hosting.
    /// Mutually exclusive with `oidc_issuer` (OIDC takes priority if both are set).
    pub auth_token: Option<String>,
    /// OIDC issuer URL (e.g. `https://auth.example.com`). When set, the server fetches
    /// the JWKS at startup and validates JWT Bearer tokens on every request.
    /// Read from `MCP_OIDC_ISSUER`.
    pub oidc_issuer: Option<String>,
    /// Optional OIDC audience claim to validate. Read from `MCP_OIDC_AUDIENCE`.
    /// Leave unset to skip audience validation.
    pub oidc_audience: Option<String>,
}
|
||||
|
||||
impl Default for MemoryConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
base_dir: "./data/memory".to_string(),
|
||||
auth_token: None,
|
||||
oidc_issuer: None,
|
||||
oidc_audience: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl MemoryConfig {
|
||||
pub fn from_env() -> Self {
|
||||
Self {
|
||||
base_dir: std::env::var("MCP_MEMORY_BASE_DIR")
|
||||
.unwrap_or_else(|_| "./data/memory".to_string()),
|
||||
auth_token: std::env::var("MCP_AUTH_TOKEN").ok(),
|
||||
oidc_issuer: std::env::var("MCP_OIDC_ISSUER").ok(),
|
||||
oidc_audience: std::env::var("MCP_OIDC_AUDIENCE").ok(),
|
||||
}
|
||||
}
|
||||
}
|
||||
2
src/embedding/mod.rs
Normal file
2
src/embedding/mod.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
// Embedding module
|
||||
pub mod service;
|
||||
124
src/embedding/service.rs
Normal file
124
src/embedding/service.rs
Normal file
@@ -0,0 +1,124 @@
|
||||
// Embedding Service — wraps fastembed with a process-wide model cache so each
|
||||
// model is loaded from disk exactly once regardless of how many EmbeddingService
|
||||
// instances are created.
|
||||
|
||||
use fastembed::{EmbeddingModel, TextEmbedding, InitOptions};
|
||||
use thiserror::Error;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use once_cell::sync::Lazy;
|
||||
|
||||
/// Errors surfaced by the embedding service and its model cache.
#[derive(Error, Debug)]
pub enum EmbeddingError {
    /// Requested a model this build does not know about.
    #[error("Model {0} not supported")]
    UnsupportedModel(String),

    /// The backing model failed to initialise/load from disk.
    #[error("Failed to load model: {0}")]
    LoadError(String),

    /// The embed call itself failed.
    #[error("Embedding generation failed: {0}")]
    GenerationError(String),
}
|
||||
|
||||
/// Embedding models the service can be configured with.
///
/// NOTE(review): `CodeBert`/`GraphCodeBert` are mapped to fastembed's
/// `AllMpnetBaseV2`/`NomicEmbedTextV1` (see `to_fastembed_model`) — they are
/// stand-ins, not actual CodeBERT weights; confirm this is intentional.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EmbeddingModelType {
    BgeBaseEnglish,
    CodeBert,
    GraphCodeBert,
}
|
||||
|
||||
impl EmbeddingModelType {
|
||||
pub fn model_name(&self) -> &'static str {
|
||||
match self {
|
||||
EmbeddingModelType::BgeBaseEnglish => "bge-base-en-v1.5",
|
||||
EmbeddingModelType::CodeBert => "codebert",
|
||||
EmbeddingModelType::GraphCodeBert => "graphcodebert",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn dimensions(&self) -> usize {
|
||||
match self {
|
||||
EmbeddingModelType::BgeBaseEnglish => 768,
|
||||
EmbeddingModelType::CodeBert => 768, // AllMpnetBaseV2
|
||||
EmbeddingModelType::GraphCodeBert => 768, // NomicEmbedTextV1
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_fastembed_model(&self) -> EmbeddingModel {
|
||||
match self {
|
||||
EmbeddingModelType::BgeBaseEnglish => EmbeddingModel::BGEBaseENV15,
|
||||
EmbeddingModelType::CodeBert => EmbeddingModel::AllMpnetBaseV2,
|
||||
EmbeddingModelType::GraphCodeBert => EmbeddingModel::NomicEmbedTextV1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Global model cache ────────────────────────────────────────────────────────
|
||||
// Each model is wrapped in Arc<Mutex<TextEmbedding>> so that:
|
||||
// - Arc → the same TextEmbedding allocation is shared across all
|
||||
// EmbeddingService instances that use the same model.
|
||||
// - Mutex → TextEmbedding::embed takes &mut self so we need exclusive access.
|
||||
// The lock is held only for the duration of the embed call, which
|
||||
// is CPU-bound and returns quickly.
|
||||
|
||||
type CachedModel = Arc<Mutex<TextEmbedding>>;
|
||||
|
||||
struct ModelCache {
|
||||
models: HashMap<EmbeddingModelType, CachedModel>,
|
||||
}
|
||||
|
||||
impl ModelCache {
|
||||
fn new() -> Self {
|
||||
Self { models: HashMap::new() }
|
||||
}
|
||||
|
||||
fn get_or_load(&mut self, model_type: EmbeddingModelType) -> Result<CachedModel, EmbeddingError> {
|
||||
if let Some(model) = self.models.get(&model_type) {
|
||||
return Ok(Arc::clone(model));
|
||||
}
|
||||
|
||||
let text_embedding = TextEmbedding::try_new(
|
||||
InitOptions::new(model_type.to_fastembed_model())
|
||||
).map_err(|e| EmbeddingError::LoadError(e.to_string()))?;
|
||||
|
||||
let model = Arc::new(Mutex::new(text_embedding));
|
||||
self.models.insert(model_type, Arc::clone(&model));
|
||||
Ok(model)
|
||||
}
|
||||
}
|
||||
|
||||
static MODEL_CACHE: Lazy<Mutex<ModelCache>> = Lazy::new(|| Mutex::new(ModelCache::new()));
|
||||
|
||||
// ── EmbeddingService ──────────────────────────────────────────────────────────
|
||||
|
||||
/// Cheap-to-clone handle onto a globally cached embedding model.
#[derive(Clone)]
pub struct EmbeddingService {
    /// Shared, mutex-guarded model instance (see `MODEL_CACHE`).
    model: CachedModel,
    /// Which model this service was created for.
    model_type: EmbeddingModelType,
}
|
||||
|
||||
impl EmbeddingService {
|
||||
/// Obtain a service backed by the globally cached model. Loading from disk
|
||||
/// only happens the first time a given model type is requested.
|
||||
pub async fn new(model_type: EmbeddingModelType) -> Result<Self, EmbeddingError> {
|
||||
let model = MODEL_CACHE.lock().unwrap().get_or_load(model_type)?;
|
||||
Ok(Self { model, model_type })
|
||||
}
|
||||
|
||||
/// Generate embeddings using the cached model.
|
||||
pub async fn embed(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
|
||||
self.model
|
||||
.lock()
|
||||
.unwrap()
|
||||
.embed(texts, None)
|
||||
.map_err(|e| EmbeddingError::GenerationError(e.to_string()))
|
||||
}
|
||||
|
||||
pub fn current_model(&self) -> EmbeddingModelType {
|
||||
self.model_type
|
||||
}
|
||||
|
||||
pub fn dimensions(&self) -> usize {
|
||||
self.model_type.dimensions()
|
||||
}
|
||||
}
|
||||
44
src/error.rs
Normal file
44
src/error.rs
Normal file
@@ -0,0 +1,44 @@
|
||||
use thiserror::Error;
|
||||
use crate::embedding::service::EmbeddingError;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum ServerError {
|
||||
#[error("Configuration error: {0}")]
|
||||
ConfigError(String),
|
||||
|
||||
#[error("Memory operation failed: {0}")]
|
||||
MemoryError(String),
|
||||
|
||||
#[error("Database error: {0}")]
|
||||
DatabaseError(String),
|
||||
|
||||
#[error("Not found: {0}")]
|
||||
NotFound(String),
|
||||
|
||||
#[error("Invalid argument: {0}")]
|
||||
InvalidArgument(String),
|
||||
}
|
||||
|
||||
impl From<EmbeddingError> for ServerError {
|
||||
fn from(err: EmbeddingError) -> Self {
|
||||
match err {
|
||||
EmbeddingError::UnsupportedModel(m) => ServerError::ConfigError(m),
|
||||
EmbeddingError::LoadError(e) => ServerError::MemoryError(e),
|
||||
EmbeddingError::GenerationError(e) => ServerError::MemoryError(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for ServerError {
|
||||
fn from(err: std::io::Error) -> Self {
|
||||
ServerError::MemoryError(err.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<rusqlite::Error> for ServerError {
|
||||
fn from(err: rusqlite::Error) -> Self {
|
||||
ServerError::DatabaseError(err.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
/// Crate-wide result alias using [`ServerError`].
pub type Result<T> = std::result::Result<T, ServerError>;
|
||||
9
src/lib.rs
Normal file
9
src/lib.rs
Normal file
@@ -0,0 +1,9 @@
|
||||
pub mod config;
|
||||
pub mod error;
|
||||
pub mod memory;
|
||||
pub mod embedding;
|
||||
pub mod semantic;
|
||||
pub mod logging;
|
||||
pub mod mcp;
|
||||
pub mod urn;
|
||||
pub mod api;
|
||||
33
src/logging.rs
Normal file
33
src/logging.rs
Normal file
@@ -0,0 +1,33 @@
|
||||
// Logging module for MCP Server
|
||||
use std::fs::OpenOptions;
|
||||
use std::io::Write;
|
||||
use std::path::Path;
|
||||
use chrono::Local;
|
||||
|
||||
/// Simple file logger. Cloning copies only the path string; each `log`
/// call opens the file anew, so clones never share a file handle.
#[derive(Clone)]
pub struct FileLogger {
    // Filesystem path of the log file; parent directory is created in `new`.
    log_path: String,
}
|
||||
|
||||
impl FileLogger {
|
||||
pub fn new(log_path: String) -> Self {
|
||||
// Create directory if it doesn't exist
|
||||
if let Some(parent) = Path::new(&log_path).parent() {
|
||||
let _ = std::fs::create_dir_all(parent);
|
||||
}
|
||||
Self { log_path }
|
||||
}
|
||||
|
||||
pub fn log(&self, method: &str, path: &str, status: &str) {
|
||||
let timestamp = Local::now().format("%Y-%m-%d %H:%M:%S").to_string();
|
||||
let log_entry = format!("{} {} {} {}\n", timestamp, method, path, status);
|
||||
|
||||
if let Ok(mut file) = OpenOptions::new()
|
||||
.create(true)
|
||||
.append(true)
|
||||
.open(&self.log_path) {
|
||||
let _ = file.write_all(log_entry.as_bytes());
|
||||
}
|
||||
}
|
||||
}
|
||||
148
src/main.rs
Normal file
148
src/main.rs
Normal file
@@ -0,0 +1,148 @@
|
||||
use mcp_server::{
|
||||
config::MemoryConfig,
|
||||
memory::service::MemoryService,
|
||||
mcp::protocol::{Request, Response, PARSE_ERROR},
|
||||
mcp::server::handle,
|
||||
};
|
||||
use serde_json::Value;
|
||||
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
|
||||
// --http [PORT] → run HTTP REST server
|
||||
if let Some(pos) = args.iter().position(|a| a == "--http") {
|
||||
let port: u16 = args.get(pos + 1)
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(3456);
|
||||
run_http(port).await;
|
||||
return;
|
||||
}
|
||||
|
||||
run_stdio().await;
|
||||
}
|
||||
|
||||
// ── stdio MCP transport ───────────────────────────────────────────────────────
|
||||
|
||||
/// Run the MCP server over stdio: one JSON-RPC message per line on stdin,
/// one JSON response per line on stdout. All diagnostics go to stderr so
/// they never corrupt the protocol stream.
async fn run_stdio() {
    let config = MemoryConfig::from_env();
    eprintln!(
        "sunbeam-memory {}: loading model and opening store at {}…",
        env!("CARGO_PKG_VERSION"),
        config.base_dir
    );

    let memory = init_memory(config).await;
    eprintln!("sunbeam-memory ready (stdio transport)");

    let stdin = BufReader::new(tokio::io::stdin());
    let mut stdout = tokio::io::stdout();
    let mut lines = stdin.lines();

    // Loop ends on stdin EOF or a read error.
    while let Ok(Some(line)) = lines.next_line().await {
        let line = line.trim().to_string();
        if line.is_empty() {
            continue;
        }

        // Unparseable input yields a JSON-RPC parse error with a null id;
        // notifications (detected inside `handle`) yield no response at all.
        let response: Option<Response> = match serde_json::from_str::<Request>(&line) {
            Ok(req) => handle(&req, &memory).await,
            Err(err) => Some(Response::err(Value::Null, PARSE_ERROR, err.to_string())),
        };

        if let Some(resp) = response {
            match serde_json::to_string(&resp) {
                // Newline-delimited framing: exactly one JSON object per line,
                // flushed immediately so the client isn't left waiting.
                Ok(mut json) => {
                    json.push('\n');
                    let _ = stdout.write_all(json.as_bytes()).await;
                    let _ = stdout.flush().await;
                }
                Err(e) => eprintln!("error: failed to serialize response: {e}"),
            }
        }
    }
}
|
||||
|
||||
// ── HTTP REST server ──────────────────────────────────────────────────────────
|
||||
|
||||
async fn run_http(port: u16) {
|
||||
use actix_web::{web, App, HttpServer};
|
||||
use mcp_server::api::config::configure_api;
|
||||
use mcp_server::api::mcp_http::{AuthConfig, SessionStore};
|
||||
use mcp_server::api::oidc::OidcVerifier;
|
||||
|
||||
let config = MemoryConfig::from_env();
|
||||
eprintln!(
|
||||
"sunbeam-memory {}: loading model and opening store at {}…",
|
||||
env!("CARGO_PKG_VERSION"),
|
||||
config.base_dir
|
||||
);
|
||||
|
||||
// Build auth config — OIDC takes priority over simple bearer token
|
||||
let auth_config = if let Some(issuer) = config.oidc_issuer.clone() {
|
||||
eprintln!(" fetching OIDC JWKS from {issuer}…");
|
||||
match OidcVerifier::new(&issuer, config.oidc_audience.clone()).await {
|
||||
Ok(v) => {
|
||||
eprintln!(" OIDC ready (issuer: {issuer})");
|
||||
AuthConfig::Oidc(v)
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("fatal: {e}");
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
} else if let Some(token) = config.auth_token.clone() {
|
||||
AuthConfig::Bearer(token.clone())
|
||||
} else {
|
||||
AuthConfig::LocalOnly
|
||||
};
|
||||
|
||||
let bind_addr = if auth_config.is_remote() { "0.0.0.0" } else { "127.0.0.1" };
|
||||
|
||||
match &auth_config {
|
||||
AuthConfig::LocalOnly => {
|
||||
eprintln!("sunbeam-memory ready (MCP HTTP on {bind_addr}:{port}, localhost only)");
|
||||
eprintln!(" MCP endpoint: http://127.0.0.1:{port}/mcp");
|
||||
}
|
||||
AuthConfig::Bearer(tok) => {
|
||||
eprintln!("sunbeam-memory ready (MCP HTTP on {bind_addr}:{port}, bearer auth)");
|
||||
eprintln!(" MCP endpoint: http://<your-host>:{port}/mcp");
|
||||
eprintln!(" Token: {tok}");
|
||||
}
|
||||
AuthConfig::Oidc(_) => {
|
||||
eprintln!("sunbeam-memory ready (MCP HTTP on {bind_addr}:{port}, OIDC auth)");
|
||||
eprintln!(" MCP endpoint: http://<your-host>:{port}/mcp");
|
||||
}
|
||||
}
|
||||
|
||||
let memory = init_memory(config).await;
|
||||
let memory_data = web::Data::new(memory);
|
||||
let sessions = web::Data::new(SessionStore::new());
|
||||
let auth = web::Data::new(auth_config);
|
||||
|
||||
HttpServer::new(move || {
|
||||
App::new()
|
||||
.app_data(memory_data.clone())
|
||||
.app_data(sessions.clone())
|
||||
.app_data(auth.clone())
|
||||
.configure(configure_api)
|
||||
})
|
||||
.bind((bind_addr, port))
|
||||
.unwrap_or_else(|e| { eprintln!("fatal: cannot bind {bind_addr}:{port}: {e}"); std::process::exit(1) })
|
||||
.run()
|
||||
.await
|
||||
.unwrap_or_else(|e| eprintln!("fatal: server error: {e}"));
|
||||
}
|
||||
|
||||
// ── shared init ───────────────────────────────────────────────────────────────
|
||||
|
||||
async fn init_memory(config: MemoryConfig) -> MemoryService {
|
||||
match MemoryService::new(&config).await {
|
||||
Ok(svc) => svc,
|
||||
Err(e) => {
|
||||
eprintln!("fatal: failed to initialise memory service: {e}");
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
2
src/mcp/mod.rs
Normal file
2
src/mcp/mod.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
pub mod protocol;
|
||||
pub mod server;
|
||||
97
src/mcp/protocol.rs
Normal file
97
src/mcp/protocol.rs
Normal file
@@ -0,0 +1,97 @@
|
||||
// JSON-RPC 2.0 + MCP wire types
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
// ── Incoming ─────────────────────────────────────────────────────────────────
|
||||
|
||||
/// An incoming JSON-RPC 2.0 message (request or notification).
#[derive(Debug, Deserialize)]
pub struct Request {
    /// Protocol version string; expected to be "2.0" (not validated here).
    pub jsonrpc: String,
    /// Absent on notifications; present on requests.
    pub id: Option<Value>,
    /// Method name, e.g. "initialize" or "tools/call".
    pub method: String,
    /// Optional parameters; `#[serde(default)]` lets a missing field
    /// deserialize to `None` instead of erroring.
    #[serde(default)]
    pub params: Option<Value>,
}
|
||||
|
||||
impl Request {
    /// True when this is a notification (no id, no response expected).
    /// JSON-RPC forbids replying to notifications, so `handle` returns
    /// `None` for them.
    pub fn is_notification(&self) -> bool {
        self.id.is_none()
    }
}
|
||||
|
||||
// ── Outgoing ─────────────────────────────────────────────────────────────────
|
||||
|
||||
/// An outgoing JSON-RPC 2.0 response. Exactly one of `result`/`error`
/// is populated; the empty one is omitted from the serialized JSON.
#[derive(Debug, Serialize)]
pub struct Response {
    /// Always "2.0" (set by the `ok`/`err` constructors).
    pub jsonrpc: &'static str,
    /// Mirrors the request id; `Null` when the request was unparseable.
    pub id: Value,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<RpcError>,
}
|
||||
|
||||
/// JSON-RPC error object: numeric code, human-readable message, and
/// optional structured data (omitted from JSON when `None`).
#[derive(Debug, Serialize)]
pub struct RpcError {
    pub code: i32,
    pub message: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<Value>,
}
|
||||
|
||||
// Standard JSON-RPC error codes
|
||||
pub const PARSE_ERROR: i32 = -32700;
|
||||
pub const INVALID_REQUEST: i32 = -32600;
|
||||
pub const METHOD_NOT_FOUND: i32 = -32601;
|
||||
pub const INVALID_PARAMS: i32 = -32602;
|
||||
pub const INTERNAL_ERROR: i32 = -32603;
|
||||
|
||||
impl Response {
|
||||
pub fn ok(id: Value, result: Value) -> Self {
|
||||
Self { jsonrpc: "2.0", id, result: Some(result), error: None }
|
||||
}
|
||||
|
||||
pub fn err(id: Value, code: i32, message: impl Into<String>) -> Self {
|
||||
Self {
|
||||
jsonrpc: "2.0",
|
||||
id,
|
||||
result: None,
|
||||
error: Some(RpcError { code, message: message.into(), data: None }),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── MCP tool response content ─────────────────────────────────────────────────
|
||||
|
||||
/// One MCP content item; serialized with a `"type"` discriminator field.
#[derive(Debug, Serialize)]
pub struct TextContent {
    #[serde(rename = "type")]
    pub kind: &'static str, // always "text"
    pub text: String,
}
|
||||
|
||||
/// Result payload for MCP `tools/call`: one or more content blocks plus
/// an error flag, serialized as `"isError"` per the MCP schema.
#[derive(Debug, Serialize)]
pub struct ToolResult {
    pub content: Vec<TextContent>,
    #[serde(rename = "isError")]
    pub is_error: bool,
}
|
||||
|
||||
impl ToolResult {
|
||||
pub fn text(text: impl Into<String>) -> Self {
|
||||
Self {
|
||||
content: vec![TextContent { kind: "text", text: text.into() }],
|
||||
is_error: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn error(message: impl Into<String>) -> Self {
|
||||
Self {
|
||||
content: vec![TextContent { kind: "text", text: message.into() }],
|
||||
is_error: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
435
src/mcp/server.rs
Normal file
435
src/mcp/server.rs
Normal file
@@ -0,0 +1,435 @@
|
||||
// MCP request dispatcher and tool implementations
|
||||
|
||||
use serde_json::{json, Value};
|
||||
use crate::memory::service::MemoryService;
|
||||
use crate::urn::{SourceUrn, schema_json, invalid_urn_response, SPEC};
|
||||
use chrono;
|
||||
use crate::mcp::protocol::{
|
||||
Request, Response, ToolResult,
|
||||
METHOD_NOT_FOUND, INVALID_PARAMS, INTERNAL_ERROR,
|
||||
};
|
||||
|
||||
const PROTOCOL_VERSION: &str = "2025-06-18";
|
||||
|
||||
/// Dispatch an incoming JSON-RPC request. Returns None for notifications
|
||||
/// (which expect no response).
|
||||
pub async fn handle(req: &Request, memory: &MemoryService) -> Option<Response> {
|
||||
if req.is_notification() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let id = req.id.clone().unwrap_or(Value::Null);
|
||||
|
||||
let outcome = match req.method.as_str() {
|
||||
"initialize" => Ok(initialize()),
|
||||
"ping" => Ok(json!({})),
|
||||
"tools/list" => Ok(tools_list()),
|
||||
"tools/call" => tools_call(req.params.as_ref(), memory).await,
|
||||
other => Err((METHOD_NOT_FOUND, format!("Unknown method: {other}"))),
|
||||
};
|
||||
|
||||
Some(match outcome {
|
||||
Ok(v) => Response::ok(id, v),
|
||||
Err((c, msg)) => Response::err(id, c, msg),
|
||||
})
|
||||
}
|
||||
|
||||
// ── initialize ────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Result payload for the MCP `initialize` handshake: protocol version,
/// advertised capabilities (tools only), and server identity.
fn initialize() -> Value {
    json!({
        "protocolVersion": PROTOCOL_VERSION,
        "capabilities": { "tools": {} },
        "serverInfo": {
            "name": "sunbeam-memory",
            "version": env!("CARGO_PKG_VERSION")
        }
    })
}
|
||||
|
||||
// ── tools/list ────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Result payload for MCP `tools/list`: the static catalogue of tools
/// this server exposes, each with a JSON-Schema `inputSchema`. Tool names
/// here must match the dispatch arms in `tools_call`.
fn tools_list() -> Value {
    json!({
        "tools": [
            // CRUD over facts …
            {
                "name": "store_fact",
                "description": "Embed and store a piece of text in semantic memory. Returns the fact ID.",
                "inputSchema": {
                    "type": "object",
                    "properties": {
                        "content": {
                            "type": "string",
                            "description": "Text to store"
                        },
                        "namespace": {
                            "type": "string",
                            "description": "Logical grouping (e.g. 'code', 'docs', 'notes'). Defaults to 'default'."
                        },
                        "source": {
                            "type": "string",
                            "description": "Optional smem URN identifying where this content came from. Must be a valid urn:smem: URN if provided. Use build_source_urn to construct one. Example: urn:smem:code:fs:/path/to/file.rs#L10-L30"
                        }
                    },
                    "required": ["content"]
                }
            },
            {
                "name": "search_facts",
                "description": "Search semantic memory for content similar to a query. Returns ranked results with similarity scores.",
                "inputSchema": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "Search query"
                        },
                        "limit": {
                            "type": "integer",
                            "description": "Max results to return (default: 10)"
                        },
                        "namespace": {
                            "type": "string",
                            "description": "Restrict search to a specific namespace"
                        }
                    },
                    "required": ["query"]
                }
            },
            {
                "name": "update_fact",
                "description": "Update an existing fact in place, keeping the same ID. Re-embeds the new content and replaces the vector.",
                "inputSchema": {
                    "type": "object",
                    "properties": {
                        "id": { "type": "string", "description": "Fact ID to update" },
                        "content": { "type": "string", "description": "New text content" },
                        "source": { "type": "string", "description": "New smem URN (optional)" }
                    },
                    "required": ["id", "content"]
                }
            },
            {
                "name": "delete_fact",
                "description": "Delete a stored fact by its ID.",
                "inputSchema": {
                    "type": "object",
                    "properties": {
                        "id": {
                            "type": "string",
                            "description": "Fact ID (as returned by store_fact or search_facts)"
                        }
                    },
                    "required": ["id"]
                }
            },
            {
                "name": "list_facts",
                "description": "List facts stored in a namespace, most recent first. Supports date range filtering.",
                "inputSchema": {
                    "type": "object",
                    "properties": {
                        "namespace": {
                            "type": "string",
                            "description": "Namespace to list (default: 'default')"
                        },
                        "limit": {
                            "type": "integer",
                            "description": "Max facts to return (default: 50)"
                        },
                        "from": {
                            "type": "string",
                            "description": "Return only facts stored on or after this time. Accepts RFC 3339 (e.g. '2026-03-01T00:00:00Z') or Unix timestamp integer as string."
                        },
                        "to": {
                            "type": "string",
                            "description": "Return only facts stored on or before this time. Same format as 'from'."
                        }
                    }
                }
            },
            // … plus URN helper tools backed by crate::urn.
            {
                "name": "build_source_urn",
                "description": "Build a valid smem URN from its components. Use this before calling store_fact with a source.",
                "inputSchema": {
                    "type": "object",
                    "properties": {
                        "content_type": {
                            "type": "string",
                            "description": "Content type: code | doc | web | data | note | conf"
                        },
                        "origin": {
                            "type": "string",
                            "description": "Origin: git | fs | https | http | db | api | manual"
                        },
                        "locator": {
                            "type": "string",
                            "description": "Origin-specific locator. See describe_urn_schema for shapes."
                        },
                        "fragment": {
                            "type": "string",
                            "description": "Optional fragment: L42 (line), L10-L30 (range), or a slug anchor."
                        }
                    },
                    "required": ["content_type", "origin", "locator"]
                }
            },
            {
                "name": "parse_source_urn",
                "description": "Parse and validate a smem URN. Returns structured components or an error with the spec.",
                "inputSchema": {
                    "type": "object",
                    "properties": {
                        "urn": {
                            "type": "string",
                            "description": "The URN to parse, e.g. urn:smem:code:fs:/path/to/file.rs#L10"
                        }
                    },
                    "required": ["urn"]
                }
            },
            {
                "name": "describe_urn_schema",
                "description": "Return the full machine-readable smem URN taxonomy: content types, origins, locator shapes, and examples.",
                "inputSchema": {
                    "type": "object",
                    "properties": {}
                }
            }
        ]
    })
}
|
||||
|
||||
// ── tools/call ────────────────────────────────────────────────────────────────
|
||||
|
||||
async fn tools_call(
|
||||
params: Option<&Value>,
|
||||
memory: &MemoryService,
|
||||
) -> Result<Value, (i32, String)> {
|
||||
let params = params.ok_or((INVALID_PARAMS, "Missing params".to_string()))?;
|
||||
|
||||
let name = params["name"]
|
||||
.as_str()
|
||||
.ok_or((INVALID_PARAMS, "Missing tool name".to_string()))?;
|
||||
|
||||
let args = ¶ms["arguments"];
|
||||
|
||||
let result = match name {
|
||||
"store_fact" => tool_store_fact(args, memory).await,
|
||||
"update_fact" => tool_update_fact(args, memory).await,
|
||||
"search_facts" => tool_search_facts(args, memory).await,
|
||||
"delete_fact" => tool_delete_fact(args, memory).await,
|
||||
"list_facts" => tool_list_facts(args, memory).await,
|
||||
"build_source_urn" => tool_build_source_urn(args),
|
||||
"parse_source_urn" => tool_parse_source_urn(args),
|
||||
"describe_urn_schema" => tool_describe_urn_schema(),
|
||||
other => ToolResult::error(format!("Unknown tool: {other}")),
|
||||
};
|
||||
|
||||
serde_json::to_value(result)
|
||||
.map_err(|e| (INTERNAL_ERROR, e.to_string()))
|
||||
}
|
||||
|
||||
// ── individual tools ──────────────────────────────────────────────────────────
|
||||
|
||||
/// MCP tool: embed `content` and persist it, optionally tagged with a
/// source URN. Returns a human-readable confirmation with the new fact ID.
async fn tool_store_fact(args: &Value, memory: &MemoryService) -> ToolResult {
    let content = match args["content"].as_str() {
        Some(c) if !c.trim().is_empty() => c,
        _ => return ToolResult::error("`content` is required and must not be empty"),
    };
    let namespace = args["namespace"].as_str().unwrap_or("default");

    // Validate source URN if provided
    let source = args["source"].as_str();
    if let Some(s) = source {
        if let Err(e) = SourceUrn::parse(s) {
            // Invalid URNs are returned as a *text* payload embedding the
            // spec (not is_error), so the client model can read it and retry.
            return ToolResult::text(
                serde_json::to_string(&invalid_urn_response(s, &e)).unwrap_or_default()
            );
        }
    }

    match memory.add_fact(namespace, content, source).await {
        Ok(fact) => {
            // Echo the raw URN plus a decoded human-readable form when present.
            let source_line = match &fact.source {
                Some(s) => {
                    let desc = SourceUrn::parse(s)
                        .map(|u| u.human_readable())
                        .unwrap_or_else(|_| s.clone());
                    format!("\nSource: {s} ({desc})")
                }
                None => String::new(),
            };
            ToolResult::text(format!(
                "Stored.\nID: {}\nNamespace: {}\nCreated: {}{}",
                fact.id, fact.namespace, fact.created_at, source_line
            ))
        }
        Err(e) => ToolResult::error(format!("Failed to store: {e}")),
    }
}
|
||||
|
||||
/// MCP tool: semantic search over stored facts; formats up to `limit`
/// ranked matches (score, id, timestamp, content, optional source).
async fn tool_search_facts(args: &Value, memory: &MemoryService) -> ToolResult {
    let query = match args["query"].as_str() {
        Some(q) if !q.trim().is_empty() => q,
        _ => return ToolResult::error("`query` is required and must not be empty"),
    };
    let limit = args["limit"].as_u64().unwrap_or(10) as usize;
    let namespace = args["namespace"].as_str();

    match memory.search_facts(query, limit, namespace).await {
        Ok(results) if results.is_empty() => ToolResult::text("No results found."),
        Ok(results) => {
            let mut out = format!("Found {} result(s):\n\n", results.len());
            for (i, f) in results.iter().enumerate() {
                out.push_str(&format!(
                    "{}. [{}] score={:.3} id={} created={}\n {}",
                    i + 1, f.namespace, f.score, f.id, f.created_at, f.content
                ));
                if let Some(s) = &f.source {
                    // Append the raw URN plus its decoded human-readable form.
                    let desc = SourceUrn::parse(s)
                        .map(|u| u.human_readable())
                        .unwrap_or_else(|_| s.clone());
                    out.push_str(&format!("\n source: {s} ({desc})"));
                }
                out.push_str("\n\n");
            }
            // Drop the final blank lines for a tidy payload.
            ToolResult::text(out.trim_end())
        }
        Err(e) => ToolResult::error(format!("Search failed: {e}")),
    }
}
|
||||
|
||||
/// MCP tool: replace a fact's content (and optionally its source URN)
/// while keeping the same ID; the new content is re-embedded.
async fn tool_update_fact(args: &Value, memory: &MemoryService) -> ToolResult {
    let id = match args["id"].as_str() {
        Some(s) if !s.trim().is_empty() => s,
        _ => return ToolResult::error("`id` is required"),
    };
    let content = match args["content"].as_str() {
        Some(s) if !s.trim().is_empty() => s,
        _ => return ToolResult::error("`content` is required and must not be empty"),
    };
    // A new source, when supplied, must be a well-formed smem URN;
    // invalid ones are returned as a spec-bearing text payload (same
    // convention as tool_store_fact).
    let source = args["source"].as_str();
    if let Some(s) = source {
        if let Err(e) = SourceUrn::parse(s) {
            return ToolResult::text(
                serde_json::to_string(&invalid_urn_response(s, &e)).unwrap_or_default()
            );
        }
    }
    match memory.update_fact(id, content, source).await {
        Ok(fact) => {
            let source_line = match &fact.source {
                Some(s) => {
                    let desc = SourceUrn::parse(s).map(|u| u.human_readable()).unwrap_or_else(|_| s.clone());
                    format!("\nSource: {s} ({desc})")
                }
                None => String::new(),
            };
            ToolResult::text(format!(
                "Updated.\nID: {}\nNamespace: {}\nCreated: {}{}",
                fact.id, fact.namespace, fact.created_at, source_line
            ))
        }
        Err(e) => ToolResult::error(format!("Update failed: {e}")),
    }
}
|
||||
|
||||
async fn tool_delete_fact(args: &Value, memory: &MemoryService) -> ToolResult {
|
||||
let id = match args["id"].as_str() {
|
||||
Some(id) if !id.trim().is_empty() => id,
|
||||
_ => return ToolResult::error("`id` is required"),
|
||||
};
|
||||
|
||||
match memory.delete_fact(id).await {
|
||||
Ok(true) => ToolResult::text(format!("Deleted {id}.")),
|
||||
Ok(false) => ToolResult::text(format!("Fact {id} not found.")),
|
||||
Err(e) => ToolResult::error(format!("Delete failed: {e}")),
|
||||
}
|
||||
}
|
||||
|
||||
/// MCP tool: list facts in a namespace, most recent first, with optional
/// from/to time bounds.
async fn tool_list_facts(args: &Value, memory: &MemoryService) -> ToolResult {
    let namespace = args["namespace"].as_str().unwrap_or("default");
    let limit = args["limit"].as_u64().unwrap_or(50) as usize;
    // Unparseable date strings silently become None (no bound applied).
    let from_ts = parse_date_arg(args["from"].as_str());
    let to_ts = parse_date_arg(args["to"].as_str());

    match memory.list_facts(namespace, limit, from_ts, to_ts).await {
        Ok(facts) if facts.is_empty() => {
            ToolResult::text(format!("No facts in namespace '{namespace}'."))
        }
        Ok(facts) => {
            let mut out = format!("{} fact(s) in '{namespace}':\n\n", facts.len());
            for f in &facts {
                out.push_str(&format!("id={}\ncreated: {}\n{}", f.id, f.created_at, f.content));
                if let Some(s) = &f.source {
                    // Raw URN plus decoded human-readable form.
                    let desc = SourceUrn::parse(s)
                        .map(|u| u.human_readable())
                        .unwrap_or_else(|_| s.clone());
                    out.push_str(&format!("\nsource: {s} ({desc})"));
                }
                out.push_str("\n\n");
            }
            ToolResult::text(out.trim_end())
        }
        Err(e) => ToolResult::error(format!("List failed: {e}")),
    }
}
|
||||
|
||||
/// Parse a date string: RFC 3339 or raw Unix timestamp integer.
|
||||
pub fn parse_ts(s: &str) -> Option<i64> {
|
||||
if let Ok(ts) = s.parse::<i64>() {
|
||||
return Some(ts);
|
||||
}
|
||||
chrono::DateTime::parse_from_rfc3339(s)
|
||||
.ok()
|
||||
.map(|dt| dt.timestamp())
|
||||
}
|
||||
|
||||
fn parse_date_arg(s: Option<&str>) -> Option<i64> {
|
||||
parse_ts(s?)
|
||||
}
|
||||
|
||||
fn tool_build_source_urn(args: &Value) -> ToolResult {
|
||||
let content_type = match args["content_type"].as_str() {
|
||||
Some(s) if !s.is_empty() => s,
|
||||
_ => return ToolResult::error("`content_type` is required"),
|
||||
};
|
||||
let origin = match args["origin"].as_str() {
|
||||
Some(s) if !s.is_empty() => s,
|
||||
_ => return ToolResult::error("`origin` is required"),
|
||||
};
|
||||
let locator = match args["locator"].as_str() {
|
||||
Some(s) if !s.is_empty() => s,
|
||||
_ => return ToolResult::error("`locator` is required"),
|
||||
};
|
||||
let fragment = args["fragment"].as_str();
|
||||
|
||||
match SourceUrn::build(content_type, origin, locator, fragment) {
|
||||
Ok(urn) => ToolResult::text(urn),
|
||||
Err(e) => ToolResult::text(
|
||||
serde_json::to_string(&serde_json::json!({
|
||||
"error": e.to_string(),
|
||||
"spec": SPEC,
|
||||
})).unwrap_or_default()
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
fn tool_parse_source_urn(args: &Value) -> ToolResult {
|
||||
let urn_str = match args["urn"].as_str() {
|
||||
Some(s) if !s.is_empty() => s,
|
||||
_ => return ToolResult::error("`urn` is required"),
|
||||
};
|
||||
|
||||
let result = match SourceUrn::parse(urn_str) {
|
||||
Ok(urn) => urn.describe(),
|
||||
Err(e) => invalid_urn_response(urn_str, &e),
|
||||
};
|
||||
|
||||
ToolResult::text(serde_json::to_string_pretty(&result).unwrap_or_default())
|
||||
}
|
||||
|
||||
/// MCP tool: dump the full machine-readable URN taxonomy as
/// pretty-printed JSON (no arguments).
fn tool_describe_urn_schema() -> ToolResult {
    ToolResult::text(serde_json::to_string_pretty(&schema_json()).unwrap_or_default())
}
|
||||
1
src/memory/mod.rs
Normal file
1
src/memory/mod.rs
Normal file
@@ -0,0 +1 @@
|
||||
pub mod service;
|
||||
150
src/memory/service.rs
Normal file
150
src/memory/service.rs
Normal file
@@ -0,0 +1,150 @@
|
||||
use crate::config::MemoryConfig;
|
||||
use crate::error::Result;
|
||||
use crate::semantic::store::SemanticStore;
|
||||
use crate::embedding::service::{EmbeddingService, EmbeddingModelType};
|
||||
use std::sync::Arc;
|
||||
use chrono::TimeZone;
|
||||
|
||||
/// Facade over the semantic store plus the embedding model. Cheap to
/// clone: both fields are `Arc`s shared across clones.
#[derive(Clone)]
pub struct MemoryService {
    // Vector + metadata storage (SQLite-backed SemanticStore).
    store: Arc<SemanticStore>,
    // Current embedding model, swappable at runtime via `switch_model`.
    // NOTE(review): the std Mutex is only held to clone the service out
    // before awaiting — confirm EmbeddingService::clone is cheap.
    embedding_service: Arc<std::sync::Mutex<EmbeddingService>>,
}
|
||||
|
||||
/// A stored fact as surfaced to callers, with the timestamp already
/// rendered as text.
#[derive(Debug, Clone)]
pub struct MemoryFact {
    pub id: String,
    pub namespace: String,
    pub content: String,
    pub created_at: String, // RFC 3339
    pub score: f32, // cosine similarity; 0.0 when not from a search
    pub source: Option<String>, // smem URN identifying where this fact came from
}
|
||||
|
||||
impl MemoryService {
|
||||
pub async fn new(config: &MemoryConfig) -> Result<Self> {
|
||||
let store = SemanticStore::new(&crate::semantic::SemanticConfig {
|
||||
base_dir: config.base_dir.clone(),
|
||||
dimension: 768,
|
||||
model_name: "bge-base-en-v1.5".to_string(),
|
||||
}).await?;
|
||||
let svc = EmbeddingService::new(EmbeddingModelType::BgeBaseEnglish).await?;
|
||||
Ok(Self {
|
||||
store: Arc::new(store),
|
||||
embedding_service: Arc::new(std::sync::Mutex::new(svc)),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn new_with_model(config: &MemoryConfig, model_type: EmbeddingModelType) -> Result<Self> {
|
||||
let store = SemanticStore::new(&crate::semantic::SemanticConfig {
|
||||
base_dir: config.base_dir.clone(),
|
||||
dimension: 768,
|
||||
model_name: "bge-base-en-v1.5".to_string(),
|
||||
}).await?;
|
||||
let svc = EmbeddingService::new(model_type).await?;
|
||||
Ok(Self {
|
||||
store: Arc::new(store),
|
||||
embedding_service: Arc::new(std::sync::Mutex::new(svc)),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn get_store(&self) -> Arc<SemanticStore> {
|
||||
Arc::clone(&self.store)
|
||||
}
|
||||
|
||||
pub fn current_model(&self) -> EmbeddingModelType {
|
||||
self.embedding_service.lock().unwrap().current_model()
|
||||
}
|
||||
|
||||
/// Embed and store content. Returns the created fact.
|
||||
pub async fn add_fact(&self, namespace: &str, content: &str, source: Option<&str>) -> Result<MemoryFact> {
|
||||
let svc = self.embedding_service.lock().unwrap().clone();
|
||||
let embeddings = svc.embed(&[content]).await?;
|
||||
let (fact_id, created_at_ts) = self.store.add_fact(namespace, content, &embeddings[0], source).await?;
|
||||
Ok(MemoryFact {
|
||||
id: fact_id,
|
||||
namespace: namespace.to_string(),
|
||||
content: content.to_string(),
|
||||
created_at: ts_to_rfc3339(created_at_ts),
|
||||
score: 0.0,
|
||||
source: source.map(|s| s.to_string()),
|
||||
})
|
||||
}
|
||||
|
||||
/// Semantic search. Optional `namespace` restricts results to one namespace.
|
||||
pub async fn search_facts(
|
||||
&self,
|
||||
query: &str,
|
||||
limit: usize,
|
||||
namespace: Option<&str>,
|
||||
) -> Result<Vec<MemoryFact>> {
|
||||
let svc = self.embedding_service.lock().unwrap().clone();
|
||||
let embeddings = svc.embed(&[query]).await?;
|
||||
let results = self.store.search(&embeddings[0], limit, namespace).await?;
|
||||
Ok(results.into_iter().map(|(fact, score)| MemoryFact {
|
||||
id: fact.id,
|
||||
namespace: fact.namespace,
|
||||
content: fact.content,
|
||||
created_at: ts_to_rfc3339(fact.created_at),
|
||||
score,
|
||||
source: fact.source,
|
||||
}).collect())
|
||||
}
|
||||
|
||||
/// Update an existing fact in place. Returns false if the ID doesn't exist.
|
||||
pub async fn update_fact(&self, fact_id: &str, content: &str, source: Option<&str>) -> Result<MemoryFact> {
|
||||
let existing = self.store.get_fact(fact_id).await?
|
||||
.ok_or_else(|| crate::error::ServerError::NotFound(fact_id.to_string()))?;
|
||||
let svc = self.embedding_service.lock().unwrap().clone();
|
||||
let embeddings = svc.embed(&[content]).await?;
|
||||
self.store.update_fact(fact_id, content, &embeddings[0], source).await?;
|
||||
Ok(MemoryFact {
|
||||
id: fact_id.to_string(),
|
||||
namespace: existing.namespace,
|
||||
content: content.to_string(),
|
||||
created_at: ts_to_rfc3339(existing.created_at),
|
||||
score: 0.0,
|
||||
source: source.map(|s| s.to_string()),
|
||||
})
|
||||
}
|
||||
|
||||
/// Delete a fact. Returns true if it existed.
|
||||
pub async fn delete_fact(&self, fact_id: &str) -> Result<bool> {
|
||||
self.store.delete_fact(fact_id).await
|
||||
}
|
||||
|
||||
/// List facts in a namespace, most recent first.
|
||||
/// Optional `from`/`to` are RFC 3339 or Unix timestamp strings for date filtering.
|
||||
pub async fn list_facts(
|
||||
&self,
|
||||
namespace: &str,
|
||||
limit: usize,
|
||||
from_ts: Option<i64>,
|
||||
to_ts: Option<i64>,
|
||||
) -> Result<Vec<MemoryFact>> {
|
||||
let facts = self.store.list_facts(namespace, limit, from_ts, to_ts).await?;
|
||||
Ok(facts.into_iter().map(|fact| MemoryFact {
|
||||
id: fact.id,
|
||||
namespace: fact.namespace,
|
||||
content: fact.content,
|
||||
created_at: ts_to_rfc3339(fact.created_at),
|
||||
score: 0.0,
|
||||
source: fact.source,
|
||||
}).collect())
|
||||
}
|
||||
|
||||
/// Switch to a different embedding model for future operations.
|
||||
pub async fn switch_model(&self, new_model: EmbeddingModelType) -> Result<()> {
|
||||
let new_svc = EmbeddingService::new(new_model).await?;
|
||||
*self.embedding_service.lock().unwrap() = new_svc;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn ts_to_rfc3339(ts: i64) -> String {
|
||||
chrono::Utc
|
||||
.timestamp_opt(ts, 0)
|
||||
.single()
|
||||
.map(|dt| dt.to_rfc3339())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
2
src/memory/store.rs
Normal file
2
src/memory/store.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
// Removed: MemoryStoreWrapper was a transition placeholder, now dead code.
|
||||
// Semantic storage lives in crate::semantic::store::SemanticStore.
|
||||
37
src/semantic.rs
Normal file
37
src/semantic.rs
Normal file
@@ -0,0 +1,37 @@
|
||||
// Main semantic module
|
||||
// Re-exports semantic functionality
|
||||
|
||||
pub mod db;
|
||||
pub mod store;
|
||||
pub mod index;
|
||||
// search.rs removed — SemanticSearch was dead code, functionality lives in store.rs
|
||||
|
||||
/// Semantic fact storage: one stored fact together with its embedding.
#[derive(Debug, Clone)]
pub struct SemanticFact {
    pub id: String,
    /// Logical grouping, e.g. "code", "docs", "notes".
    pub namespace: String,
    pub content: String,
    pub created_at: i64, // Unix timestamp
    // Embedding of `content`; presumably its length matches
    // `SemanticConfig::dimension` — confirm against the store.
    pub embedding: Vec<f32>,
    pub source: Option<String>, // smem URN identifying where this fact came from
}
|
||||
|
||||
/// Store configuration: on-disk location plus the embedding geometry.
#[derive(Debug, Clone)]
pub struct SemanticConfig {
    pub base_dir: String,
    pub dimension: usize,
    pub model_name: String,
}

/// Defaults match the bge-base-en-v1.5 model (768-dimensional vectors)
/// stored under ./data/semantic.
impl Default for SemanticConfig {
    fn default() -> Self {
        SemanticConfig {
            base_dir: String::from("./data/semantic"),
            dimension: 768,
            model_name: String::from("bge-base-en-v1.5"),
        }
    }
}
|
||||
|
||||
pub use store::SemanticStore;
|
||||
231
src/semantic/db.rs
Normal file
231
src/semantic/db.rs
Normal file
@@ -0,0 +1,231 @@
|
||||
// Semantic Database Layer
|
||||
|
||||
use crate::error::{Result, ServerError};
|
||||
use crate::semantic::SemanticFact;
|
||||
use rusqlite::{Connection, params, OptionalExtension};
|
||||
use std::path::Path;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// SQLite-based semantic fact storage
|
||||
pub struct SemanticDB {
|
||||
conn: Connection,
|
||||
}
|
||||
|
||||
impl SemanticDB {
|
||||
/// Create new database connection and initialize schema
|
||||
pub fn new(base_dir: &str) -> Result<Self> {
|
||||
let db_path = Path::new(base_dir).join("semantic.db");
|
||||
|
||||
if let Some(parent) = db_path.parent() {
|
||||
std::fs::create_dir_all(parent)?;
|
||||
}
|
||||
|
||||
let conn = Connection::open(db_path)?;
|
||||
|
||||
conn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS facts (
|
||||
id TEXT PRIMARY KEY,
|
||||
namespace TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
created_at INTEGER NOT NULL,
|
||||
source TEXT
|
||||
)",
|
||||
[],
|
||||
)?;
|
||||
|
||||
// Migrate existing databases — no-op if column already present
|
||||
let has_source: bool = conn.query_row(
|
||||
"SELECT COUNT(*) FROM pragma_table_info('facts') WHERE name='source'",
|
||||
[],
|
||||
|row| row.get::<_, i64>(0),
|
||||
)? > 0;
|
||||
if !has_source {
|
||||
conn.execute("ALTER TABLE facts ADD COLUMN source TEXT", [])?;
|
||||
}
|
||||
|
||||
conn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS embeddings (
|
||||
fact_id TEXT PRIMARY KEY,
|
||||
embedding BLOB NOT NULL,
|
||||
FOREIGN KEY (fact_id) REFERENCES facts(id)
|
||||
)",
|
||||
[],
|
||||
)?;
|
||||
|
||||
Ok(Self { conn })
|
||||
}
|
||||
|
||||
/// Add a fact and return its generated (id, created_at timestamp).
|
||||
pub fn add_fact(&self, fact: &SemanticFact) -> Result<(String, i64)> {
|
||||
let fact_id = Uuid::new_v4().to_string();
|
||||
let timestamp = chrono::Utc::now().timestamp();
|
||||
|
||||
self.conn.execute(
|
||||
"INSERT INTO facts (id, namespace, content, created_at, source) VALUES (?, ?, ?, ?, ?)",
|
||||
params![&fact_id, &fact.namespace, &fact.content, timestamp, &fact.source],
|
||||
)?;
|
||||
|
||||
let embedding_bytes: Vec<u8> = fact.embedding.iter()
|
||||
.flat_map(|f| f.to_le_bytes().to_vec())
|
||||
.collect();
|
||||
|
||||
self.conn.execute(
|
||||
"INSERT INTO embeddings (fact_id, embedding) VALUES (?, ?)",
|
||||
params![&fact_id, &embedding_bytes],
|
||||
)?;
|
||||
|
||||
Ok((fact_id, timestamp))
|
||||
}
|
||||
|
||||
/// Get a fact by ID, including its embedding.
|
||||
pub fn get_fact(&self, fact_id: &str) -> Result<Option<SemanticFact>> {
|
||||
let fact = self.conn.query_row(
|
||||
"SELECT namespace, content, created_at, source FROM facts WHERE id = ?",
|
||||
params![fact_id],
|
||||
|row| {
|
||||
Ok(SemanticFact {
|
||||
id: fact_id.to_string(),
|
||||
namespace: row.get(0)?,
|
||||
content: row.get(1)?,
|
||||
created_at: row.get(2)?,
|
||||
embedding: vec![],
|
||||
source: row.get(3)?,
|
||||
})
|
||||
},
|
||||
).optional()?;
|
||||
|
||||
let fact = match fact {
|
||||
Some(mut f) => {
|
||||
let embedding_bytes: Vec<u8> = self.conn.query_row(
|
||||
"SELECT embedding FROM embeddings WHERE fact_id = ?",
|
||||
params![fact_id],
|
||||
|row| row.get(0),
|
||||
).optional()?.unwrap_or_default();
|
||||
|
||||
f.embedding = decode_embedding(&embedding_bytes)?;
|
||||
Some(f)
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
|
||||
Ok(fact)
|
||||
}
|
||||
|
||||
/// Get embedding by fact ID.
|
||||
pub fn get_embedding(&self, fact_id: &str) -> Result<Option<Vec<f32>>> {
|
||||
let embedding_bytes: Vec<u8> = self.conn.query_row(
|
||||
"SELECT embedding FROM embeddings WHERE fact_id = ?",
|
||||
params![fact_id],
|
||||
|row| row.get(0),
|
||||
).optional()?.unwrap_or_default();
|
||||
|
||||
if embedding_bytes.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
Ok(Some(decode_embedding(&embedding_bytes)?))
|
||||
}
|
||||
|
||||
/// Search facts by namespace with optional Unix timestamp bounds.
|
||||
pub fn search_by_namespace(
|
||||
&self,
|
||||
namespace: &str,
|
||||
limit: usize,
|
||||
from_ts: Option<i64>,
|
||||
to_ts: Option<i64>,
|
||||
) -> Result<Vec<SemanticFact>> {
|
||||
let from = from_ts.unwrap_or(i64::MIN);
|
||||
let to = to_ts.unwrap_or(i64::MAX);
|
||||
|
||||
let mut stmt = self.conn.prepare(
|
||||
"SELECT id FROM facts WHERE namespace = ? AND created_at >= ? AND created_at <= ? ORDER BY created_at DESC LIMIT ?",
|
||||
)?;
|
||||
|
||||
let fact_ids: Vec<String> = stmt
|
||||
.query_map(params![namespace, from, to, limit as i64], |row| row.get(0))?
|
||||
.filter_map(|r| r.ok())
|
||||
.collect();
|
||||
|
||||
let mut results = vec![];
|
||||
for fact_id in fact_ids {
|
||||
if let Some(fact) = self.get_fact(&fact_id)? {
|
||||
results.push(fact);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
/// Update content and/or source of an existing fact. Re-stores the embedding.
|
||||
/// Returns true if the fact existed.
|
||||
pub fn update_fact(&self, fact_id: &str, content: &str, source: Option<&str>, embedding: &[f32]) -> Result<bool> {
|
||||
let updated = self.conn.execute(
|
||||
"UPDATE facts SET content = ?, source = ? WHERE id = ?",
|
||||
params![content, source, fact_id],
|
||||
)?;
|
||||
if updated == 0 {
|
||||
return Ok(false);
|
||||
}
|
||||
let embedding_bytes: Vec<u8> = embedding.iter()
|
||||
.flat_map(|f| f.to_le_bytes())
|
||||
.collect();
|
||||
self.conn.execute(
|
||||
"UPDATE embeddings SET embedding = ? WHERE fact_id = ?",
|
||||
params![&embedding_bytes, fact_id],
|
||||
)?;
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
/// Delete a fact and its embedding. Returns true if the fact existed.
|
||||
pub fn delete_fact(&mut self, fact_id: &str) -> Result<bool> {
|
||||
let tx = self.conn.transaction()?;
|
||||
|
||||
let count = tx.execute(
|
||||
"DELETE FROM facts WHERE id = ?",
|
||||
params![fact_id],
|
||||
)?;
|
||||
|
||||
tx.execute(
|
||||
"DELETE FROM embeddings WHERE fact_id = ?",
|
||||
params![fact_id],
|
||||
)?;
|
||||
|
||||
tx.commit()?;
|
||||
|
||||
Ok(count > 0)
|
||||
}
|
||||
|
||||
/// Get all facts (used by hybrid search).
|
||||
pub fn get_all_facts(&self) -> Result<Vec<SemanticFact>> {
|
||||
let mut stmt = self.conn.prepare(
|
||||
"SELECT id FROM facts ORDER BY created_at DESC",
|
||||
)?;
|
||||
|
||||
let fact_ids: Vec<String> = stmt
|
||||
.query_map([], |row| row.get(0))?
|
||||
.filter_map(|r| r.ok())
|
||||
.collect();
|
||||
|
||||
let mut results = vec![];
|
||||
for fact_id in fact_ids {
|
||||
if let Some(fact) = self.get_fact(&fact_id)? {
|
||||
results.push(fact);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
}
|
||||
|
||||
/// Decode a raw byte blob into an f32 vector, returning an error on malformed data
|
||||
/// instead of panicking.
|
||||
fn decode_embedding(bytes: &[u8]) -> Result<Vec<f32>> {
|
||||
bytes.chunks_exact(4)
|
||||
.map(|chunk| -> Result<f32> {
|
||||
let arr: [u8; 4] = chunk.try_into().map_err(|_| {
|
||||
ServerError::DatabaseError("malformed embedding blob: length not divisible by 4".to_string())
|
||||
})?;
|
||||
Ok(f32::from_le_bytes(arr))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
51
src/semantic/index.rs
Normal file
51
src/semantic/index.rs
Normal file
@@ -0,0 +1,51 @@
|
||||
// Semantic Index — cosine similarity search over in-memory vectors.
|
||||
// Uses a HashMap so deletion is O(1) and the index stays consistent
|
||||
// with the database after deletes.
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// In-memory cosine-similarity index over fact embeddings.
///
/// Backed by a HashMap so removal is O(1) and the index stays consistent
/// with the database after deletes.
pub struct SemanticIndex {
    // fact id -> embedding vector
    vectors: HashMap<String, Vec<f32>>,
}

impl SemanticIndex {
    /// Create an empty index. The dimension argument is currently unused;
    /// vectors of any length may be inserted.
    pub fn new(_dimension: usize) -> Self {
        SemanticIndex { vectors: HashMap::new() }
    }

    /// Cosine similarity of two vectors; 0.0 when either norm is zero.
    pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        let dot = a.iter().zip(b.iter()).fold(0.0f32, |acc, (x, y)| acc + x * y);
        let norm_a = a.iter().fold(0.0f32, |acc, x| acc + x * x).sqrt();
        let norm_b = b.iter().fold(0.0f32, |acc, x| acc + x * x).sqrt();
        if norm_a == 0.0 || norm_b == 0.0 {
            0.0
        } else {
            dot / (norm_a * norm_b)
        }
    }

    /// Insert (or overwrite) the vector for `id`. Always returns true.
    pub fn add_vector(&mut self, vector: &[f32], id: &str) -> bool {
        self.vectors.insert(id.to_string(), vector.to_vec());
        true
    }

    /// Remove a vector by id. Returns true if it existed.
    pub fn remove_vector(&mut self, id: &str) -> bool {
        self.vectors.remove(id).is_some()
    }

    /// Top-k most similar ids with their scores, highest first.
    pub fn search(&self, query: &[f32], k: usize) -> Vec<(String, f32)> {
        let mut scored = Vec::with_capacity(self.vectors.len());
        for (id, vec) in &self.vectors {
            scored.push((id.clone(), Self::cosine_similarity(query, vec)));
        }
        // NaN-safe descending sort: incomparable pairs are treated as equal.
        scored.sort_by(|l, r| r.1.partial_cmp(&l.1).unwrap_or(std::cmp::Ordering::Equal));
        scored.truncate(k);
        scored
    }
}
|
||||
2
src/semantic/search.rs
Normal file
2
src/semantic/search.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
// Removed: SemanticSearch was a stub with no callers.
|
||||
// Hybrid search is implemented in crate::semantic::store::SemanticStore::hybrid_search.
|
||||
174
src/semantic/store.rs
Normal file
174
src/semantic/store.rs
Normal file
@@ -0,0 +1,174 @@
|
||||
// Semantic Store Implementation
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::semantic::{SemanticConfig, SemanticFact};
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
/// Semantic store using SQLite for persistence and an in-memory cosine-similarity index.
|
||||
pub struct SemanticStore {
|
||||
config: SemanticConfig,
|
||||
db: Arc<Mutex<super::db::SemanticDB>>,
|
||||
index: Arc<Mutex<super::index::SemanticIndex>>,
|
||||
}
|
||||
|
||||
impl SemanticStore {
|
||||
pub async fn new(config: &SemanticConfig) -> Result<Self> {
|
||||
let db = Arc::new(Mutex::new(super::db::SemanticDB::new(&config.base_dir)?));
|
||||
let index = Arc::new(Mutex::new(super::index::SemanticIndex::new(config.dimension)));
|
||||
|
||||
Ok(Self {
|
||||
config: config.clone(),
|
||||
db,
|
||||
index,
|
||||
})
|
||||
}
|
||||
|
||||
/// Add a fact with its embedding. Returns (fact_id, created_at unix timestamp).
|
||||
pub async fn add_fact(
|
||||
&self,
|
||||
namespace: &str,
|
||||
content: &str,
|
||||
embedding: &[f32],
|
||||
source: Option<&str>,
|
||||
) -> Result<(String, i64)> {
|
||||
let fact = SemanticFact {
|
||||
id: String::new(),
|
||||
namespace: namespace.to_string(),
|
||||
content: content.to_string(),
|
||||
created_at: 0,
|
||||
embedding: embedding.to_vec(),
|
||||
source: source.map(|s| s.to_string()),
|
||||
};
|
||||
|
||||
let (fact_id, created_at) = self.db.lock().unwrap().add_fact(&fact)?;
|
||||
self.index.lock().unwrap().add_vector(embedding, &fact_id);
|
||||
|
||||
Ok((fact_id, created_at))
|
||||
}
|
||||
|
||||
/// Vector similarity search. Returns facts paired with their cosine similarity
|
||||
/// score (0–1, higher is better). When a namespace filter is supplied,
|
||||
/// oversamples from the index to ensure `limit` results after filtering.
|
||||
pub async fn search(
|
||||
&self,
|
||||
query_embedding: &[f32],
|
||||
limit: usize,
|
||||
namespace_filter: Option<&str>,
|
||||
) -> Result<Vec<(SemanticFact, f32)>> {
|
||||
let search_limit = if namespace_filter.is_some() {
|
||||
(limit * 10).max(50)
|
||||
} else {
|
||||
limit
|
||||
};
|
||||
|
||||
let similar_ids = self.index.lock().unwrap().search(query_embedding, search_limit);
|
||||
|
||||
let mut results = vec![];
|
||||
for (fact_id, score) in similar_ids {
|
||||
if results.len() >= limit {
|
||||
break;
|
||||
}
|
||||
if let Some(fact) = self.db.lock().unwrap().get_fact(&fact_id)? {
|
||||
if let Some(namespace) = namespace_filter {
|
||||
if fact.namespace == namespace {
|
||||
results.push((fact, score));
|
||||
}
|
||||
} else {
|
||||
results.push((fact, score));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
/// Return facts in a namespace ordered by creation time, without scoring.
|
||||
/// Optional `from_ts`/`to_ts` are Unix timestamps (seconds) for date filtering.
|
||||
pub async fn list_facts(
|
||||
&self,
|
||||
namespace: &str,
|
||||
limit: usize,
|
||||
from_ts: Option<i64>,
|
||||
to_ts: Option<i64>,
|
||||
) -> Result<Vec<SemanticFact>> {
|
||||
self.db.lock().unwrap().search_by_namespace(namespace, limit, from_ts, to_ts)
|
||||
}
|
||||
|
||||
/// Hybrid search: keyword filter + vector ranking.
|
||||
///
|
||||
/// If any facts contain `keyword`, those are ranked by vector similarity and
|
||||
/// the top `limit` are returned. If no facts contain the keyword the search
|
||||
/// falls back to a pure vector search over all facts so callers always get
|
||||
/// useful results.
|
||||
pub async fn hybrid_search(
|
||||
&self,
|
||||
keyword: &str,
|
||||
query_embedding: &[f32],
|
||||
limit: usize,
|
||||
) -> Result<Vec<SemanticFact>> {
|
||||
let all_facts = self.db.lock().unwrap().get_all_facts()?;
|
||||
let keyword_lower = keyword.to_lowercase();
|
||||
|
||||
let keyword_matches: Vec<SemanticFact> = all_facts
|
||||
.iter()
|
||||
.filter(|f| f.content.to_lowercase().contains(&keyword_lower))
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
// Fall back to all facts when the keyword matches nothing, rather than
|
||||
// returning an empty result that is worse than a plain vector search.
|
||||
let candidates = if keyword_matches.is_empty() {
|
||||
all_facts
|
||||
} else {
|
||||
keyword_matches
|
||||
};
|
||||
|
||||
let mut scored: Vec<(SemanticFact, f32)> = candidates
|
||||
.into_iter()
|
||||
.map(|fact| {
|
||||
let sim = super::index::SemanticIndex::cosine_similarity(
|
||||
query_embedding,
|
||||
&fact.embedding,
|
||||
);
|
||||
(fact, sim)
|
||||
})
|
||||
.collect();
|
||||
|
||||
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
||||
Ok(scored.into_iter().take(limit).map(|(f, _)| f).collect())
|
||||
}
|
||||
|
||||
/// Update an existing fact in place. Re-embeds with the new content and
|
||||
/// replaces the vector in the index. Returns false if the ID doesn't exist.
|
||||
pub async fn update_fact(
|
||||
&self,
|
||||
fact_id: &str,
|
||||
content: &str,
|
||||
embedding: &[f32],
|
||||
source: Option<&str>,
|
||||
) -> Result<bool> {
|
||||
let updated = self.db.lock().unwrap().update_fact(fact_id, content, source, embedding)?;
|
||||
if updated {
|
||||
self.index.lock().unwrap().add_vector(embedding, fact_id);
|
||||
}
|
||||
Ok(updated)
|
||||
}
|
||||
|
||||
/// Delete a fact from both the database and the vector index.
|
||||
pub async fn delete_fact(&self, fact_id: &str) -> Result<bool> {
|
||||
let deleted = self.db.lock().unwrap().delete_fact(fact_id)?;
|
||||
if deleted {
|
||||
self.index.lock().unwrap().remove_vector(fact_id);
|
||||
}
|
||||
Ok(deleted)
|
||||
}
|
||||
|
||||
/// Get a fact by ID.
|
||||
pub async fn get_fact(&self, fact_id: &str) -> Result<Option<SemanticFact>> {
|
||||
self.db.lock().unwrap().get_fact(fact_id)
|
||||
}
|
||||
|
||||
pub fn get_base_dir(&self) -> String {
|
||||
self.config.base_dir.clone()
|
||||
}
|
||||
}
|
||||
821
src/urn.rs
Normal file
821
src/urn.rs
Normal file
@@ -0,0 +1,821 @@
|
||||
/// smem URN — a stable, machine-readable source identifier for any fact stored
|
||||
/// in semantic memory, regardless of where it originated.
|
||||
///
|
||||
/// # Format
|
||||
///
|
||||
/// ```text
|
||||
/// urn:smem:<type>:<origin>:<locator>[#<fragment>]
|
||||
/// ```
|
||||
///
|
||||
/// Every component is lowercase ASCII except `<locator>` and `<fragment>`,
|
||||
/// which preserve the case of the underlying system (paths, URLs, etc.).
|
||||
///
|
||||
/// # Content types
|
||||
///
|
||||
/// | Type | Meaning |
|
||||
/// |--------|--------------------------------------------------|
|
||||
/// | `code` | Source code file |
|
||||
/// | `doc` | Documentation, markdown, prose |
|
||||
/// | `web` | Live web content |
|
||||
/// | `data` | Structured data — DB rows, API payloads, CSV |
|
||||
/// | `note` | Manually authored or synthesized fact |
|
||||
/// | `conf` | Configuration file |
|
||||
///
|
||||
/// # Origins and locator shapes
|
||||
///
|
||||
/// | Origin | Locator shape | Example locator |
|
||||
/// |----------|---------------------------------------------------------|----------------------------------------------------------|
|
||||
/// | `git` | `<host>:<org>:<repo>:<branch>:<path>` | `github.com:acme:repo:main:src/lib.rs` |
|
||||
/// | `git` | `<host>::<repo>:<branch>:<path>` | `git.example.com::repo:main:src/lib.rs` (no org) |
|
||||
/// | `fs` | `[<hostname>:]<absolute-path>` | `/home/user/project/src/main.rs` or `my-nas:/mnt/share/file.txt` |
|
||||
/// | `https` | `<host>/<path>` | `docs.example.com/api/overview` |
|
||||
/// | `http` | `<host>/<path>` | `intranet.corp/wiki/setup` |
|
||||
/// | `db` | `<driver>/<host>/<database>/<table>/<pk>` | `postgres/localhost/myapp/users/abc-123` |
|
||||
/// | `api` | `<host>/<path>` | `api.example.com/v2/facts/abc-123` |
|
||||
/// | `manual` | `<label>` | `2026-03-04/onboarding-session` |
|
||||
///
|
||||
/// # Fragment (`#`)
|
||||
///
|
||||
/// | Shape | Meaning |
|
||||
/// |------------|------------------------------------------------------|
|
||||
/// | `L42` | Single line 42 |
|
||||
/// | `L10-L30` | Lines 10 through 30 |
|
||||
/// | `<slug>` | Section anchor (HTML `id` or Markdown heading slug) |
|
||||
///
|
||||
/// # Full examples
|
||||
///
|
||||
/// ```text
|
||||
/// urn:smem:code:git:github.com:acme:repo:main:src/lib.rs#L1-L50
|
||||
/// urn:smem:code:fs:/Users/sienna/Development/sunbeam/mcp-server/src/main.rs#L10
|
||||
/// urn:smem:doc:https:docs.anthropic.com/mcp/protocol#tools
|
||||
/// urn:smem:doc:fs:/Users/sienna/Development/sunbeam/README.md
|
||||
/// urn:smem:data:db:postgres/localhost/sunbeam/facts/abc-123
|
||||
/// urn:smem:note:manual:2026-03-04/onboarding-session
|
||||
/// urn:smem:conf:fs:/etc/myapp/config.toml
|
||||
/// ```
|
||||
|
||||
use serde_json::{json, Value};
|
||||
|
||||
// ── constants ─────────────────────────────────────────────────────────────────
|
||||
|
||||
const PREFIX: &str = "urn:smem:";
|
||||
|
||||
/// Machine-readable spec string embedded into MCP tool descriptions so any
|
||||
/// client can understand the format without out-of-band documentation.
|
||||
pub const SPEC: &str = "\
|
||||
smem URN format: urn:smem:<type>:<origin>:<locator>[#<fragment>]
|
||||
|
||||
Content types: code doc web data note conf
|
||||
Origins: git fs https http db api manual
|
||||
|
||||
Origin locator shapes:
|
||||
git <host>:<org>:<repo>:<branch>:<path> (ARN-like format, org optional)
|
||||
fs [<hostname>:]<absolute-path> (hostname optional; identifies NAS, remote machine, cloud drive, etc.)
|
||||
https <host>/<path>
|
||||
http <host>/<path>
|
||||
db <driver>/<host>/<database>/<table>/<pk>
|
||||
api <host>/<path>
|
||||
manual <label>
|
||||
|
||||
Fragment (#):
|
||||
L42 single line
|
||||
L10-L30 line range
|
||||
<slug> section anchor / HTML id
|
||||
|
||||
Examples:
|
||||
urn:smem:code:git:github.com:acme:repo:main:src/lib.rs#L1-L50
|
||||
urn:smem:code:git:git.example.com::repo:main:src/lib.rs
|
||||
urn:smem:code:fs:/home/user/project/src/main.rs#L10
|
||||
urn:smem:doc:https:docs.example.com/api/overview#authentication
|
||||
urn:smem:data:db:postgres/localhost/mydb/users/abc-123
|
||||
urn:smem:note:manual:2026-03-04/session-notes";
|
||||
|
||||
// ── types ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Kind of content a fact was extracted from.
/// One of: code, doc, web, data, note, conf.
#[derive(Debug, Clone, PartialEq)]
pub enum ContentType {
    /// Source code file.
    Code,
    /// Documentation, markdown, prose.
    Doc,
    /// Live web content.
    Web,
    /// Structured data — DB rows, API payloads, CSV.
    Data,
    /// Manually authored or synthesized fact.
    Note,
    /// Configuration file.
    Conf,
}
|
||||
|
||||
/// System a fact originated from.
/// One of: git, fs, https, http, db, api, manual.
#[derive(Debug, Clone, PartialEq)]
pub enum Origin {
    /// Git repository (ARN-like locator).
    Git,
    /// Local or remote filesystem.
    Fs,
    /// Web URL over HTTPS.
    Https,
    /// Web URL over HTTP.
    Http,
    /// Database record.
    Db,
    /// API endpoint.
    Api,
    /// Manually authored.
    Manual,
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SourceUrn {
|
||||
pub content_type: ContentType,
|
||||
pub origin: Origin,
|
||||
/// Origin-specific locator string; see the per-origin shapes above.
|
||||
pub locator: String,
|
||||
/// Optional sub-location within the source (line range, anchor, etc.).
|
||||
pub fragment: Option<String>,
|
||||
}
|
||||
|
||||
/// Error produced when building or parsing a smem URN; wraps a
/// human-readable message.
#[derive(Debug)]
pub struct UrnError(pub String);

impl std::fmt::Display for UrnError {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str(&self.0)
    }
}
|
||||
|
||||
// ── parsing ───────────────────────────────────────────────────────────────────
|
||||
|
||||
impl ContentType {
|
||||
fn parse(s: &str) -> Result<Self, UrnError> {
|
||||
match s {
|
||||
"code" => Ok(Self::Code),
|
||||
"doc" => Ok(Self::Doc),
|
||||
"web" => Ok(Self::Web),
|
||||
"data" => Ok(Self::Data),
|
||||
"note" => Ok(Self::Note),
|
||||
"conf" => Ok(Self::Conf),
|
||||
other => Err(UrnError(format!(
|
||||
"unknown content type {other:?}; valid: code, doc, web, data, note, conf"
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Code => "code",
|
||||
Self::Doc => "doc",
|
||||
Self::Web => "web",
|
||||
Self::Data => "data",
|
||||
Self::Note => "note",
|
||||
Self::Conf => "conf",
|
||||
}
|
||||
}
|
||||
|
||||
fn label(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Code => "source code",
|
||||
Self::Doc => "documentation",
|
||||
Self::Web => "web page",
|
||||
Self::Data => "data record",
|
||||
Self::Note => "note",
|
||||
Self::Conf => "configuration",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Origin {
|
||||
fn parse(s: &str) -> Result<Self, UrnError> {
|
||||
match s {
|
||||
"git" => Ok(Self::Git),
|
||||
"fs" => Ok(Self::Fs),
|
||||
"https" => Ok(Self::Https),
|
||||
"http" => Ok(Self::Http),
|
||||
"db" => Ok(Self::Db),
|
||||
"api" => Ok(Self::Api),
|
||||
"manual" => Ok(Self::Manual),
|
||||
other => Err(UrnError(format!(
|
||||
"unknown origin {other:?}; valid: git, fs, https, http, db, api, manual"
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Git => "git",
|
||||
Self::Fs => "fs",
|
||||
Self::Https => "https",
|
||||
Self::Http => "http",
|
||||
Self::Db => "db",
|
||||
Self::Api => "api",
|
||||
Self::Manual => "manual",
|
||||
}
|
||||
}
|
||||
|
||||
fn label(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Git => "git repository",
|
||||
Self::Fs => "local file",
|
||||
Self::Https | Self::Http => "web URL",
|
||||
Self::Db => "database",
|
||||
Self::Api => "API endpoint",
|
||||
Self::Manual => "manually authored",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── SourceUrn ─────────────────────────────────────────────────────────────────
|
||||
|
||||
impl SourceUrn {
|
||||
/// Build and validate a URN from string components.
|
||||
///
|
||||
/// Validates each component and returns the canonical URN string, or a
|
||||
/// `UrnError` if any component is invalid.
|
||||
pub fn build(
|
||||
content_type_str: &str,
|
||||
origin_str: &str,
|
||||
locator: &str,
|
||||
fragment: Option<&str>,
|
||||
) -> Result<String, UrnError> {
|
||||
if locator.is_empty() {
|
||||
return Err(UrnError("locator must not be empty".to_string()));
|
||||
}
|
||||
if let Some(f) = fragment {
|
||||
if f.is_empty() {
|
||||
return Err(UrnError("fragment must not be empty if provided".to_string()));
|
||||
}
|
||||
}
|
||||
let urn = Self {
|
||||
content_type: ContentType::parse(content_type_str)?,
|
||||
origin: Origin::parse(origin_str)?,
|
||||
locator: locator.to_string(),
|
||||
fragment: fragment.map(|f| f.to_string()),
|
||||
};
|
||||
Ok(urn.to_urn())
|
||||
}
|
||||
|
||||
/// Parse a `urn:smem:...` string into its components.
|
||||
///
|
||||
/// Returns `UrnError` with a human-readable message on any malformed input.
|
||||
pub fn parse(input: &str) -> Result<Self, UrnError> {
|
||||
let rest = input.strip_prefix(PREFIX).ok_or_else(|| {
|
||||
UrnError(format!("must start with '{PREFIX}'; got {input:?}"))
|
||||
})?;
|
||||
|
||||
// Fragment splits on the first '#'; the fragment may itself contain '#'.
|
||||
let (body, fragment) = match rest.split_once('#') {
|
||||
Some((b, f)) if f.is_empty() => {
|
||||
return Err(UrnError("fragment after '#' must not be empty".to_string()));
|
||||
}
|
||||
Some((b, f)) => (b, Some(f.to_string())),
|
||||
None => (rest, None),
|
||||
};
|
||||
|
||||
// body = <type>:<origin>:<locator>
|
||||
// splitn(3) so that the locator can itself contain colons (e.g. fs paths, db URIs).
|
||||
let mut parts = body.splitn(3, ':');
|
||||
let type_str = parts.next().filter(|s| !s.is_empty())
|
||||
.ok_or_else(|| UrnError("missing <type> component".to_string()))?;
|
||||
let origin_str = parts.next().filter(|s| !s.is_empty())
|
||||
.ok_or_else(|| UrnError("missing <origin> component".to_string()))?;
|
||||
let locator = parts.next().filter(|s| !s.is_empty())
|
||||
.ok_or_else(|| UrnError("missing <locator> component".to_string()))?;
|
||||
|
||||
Ok(Self {
|
||||
content_type: ContentType::parse(type_str)?,
|
||||
origin: Origin::parse(origin_str)?,
|
||||
locator: locator.to_string(),
|
||||
fragment,
|
||||
})
|
||||
}
|
||||
|
||||
/// Reconstitute the canonical URN string.
|
||||
pub fn to_urn(&self) -> String {
|
||||
let frag = self.fragment.as_deref()
|
||||
.map(|f| format!("#{f}"))
|
||||
.unwrap_or_default();
|
||||
format!(
|
||||
"{}{}:{}:{}{}",
|
||||
PREFIX, self.content_type.as_str(), self.origin.as_str(), self.locator, frag
|
||||
)
|
||||
}
|
||||
|
||||
/// Return structured JSON describing every parsed component.
|
||||
/// Used by the `parse_source_urn` MCP tool.
|
||||
pub fn describe(&self) -> Value {
|
||||
json!({
|
||||
"valid": true,
|
||||
"content_type": self.content_type.as_str(),
|
||||
"origin": self.origin.as_str(),
|
||||
"locator": self.locator,
|
||||
"fragment": self.fragment,
|
||||
"human_readable": self.human_readable(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn human_readable(&self) -> String {
|
||||
let frag = match &self.fragment {
|
||||
None => String::new(),
|
||||
Some(f) if f.starts_with('L') => format!(" ({})", f.replace('-', "–")),
|
||||
Some(f) => format!(" §{f}"),
|
||||
};
|
||||
format!(
|
||||
"{} from {}: {}{}",
|
||||
self.content_type.label(), self.origin.label(), self.locator, frag
|
||||
)
|
||||
}
|
||||
|
||||
//─── Git-specific Methods (ARN-like format) ────────────────────────────────
|
||||
|
||||
/// Validate that a git URN has the correct ARN-like format
|
||||
pub fn is_valid_git_urn(&self) -> bool {
|
||||
if self.origin != Origin::Git {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Split by colon to parse ARN-like format
|
||||
let parts: Vec<&str> = self.locator.split(':').collect();
|
||||
|
||||
// Minimum valid format: host::repo:branch:path (5 parts)
|
||||
// Or: host:org:repo:branch:path (6+ parts)
|
||||
if parts.len() < 5 {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Validate required components are not empty
|
||||
// parts[0] = host, parts[2] = repo, parts[3] = branch, parts[4..] = path
|
||||
parts[0].is_empty() == false &&
|
||||
parts[2].is_empty() == false &&
|
||||
parts[3].is_empty() == false &&
|
||||
parts[4..].join(":").is_empty() == false
|
||||
}
|
||||
|
||||
/// Extract host from git locator (ARN-like format: host:org:repo:branch:path)
|
||||
pub fn extract_git_host(&self) -> Option<&str> {
|
||||
if !self.is_valid_git_urn() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let parts: Vec<&str> = self.locator.split(':').collect();
|
||||
Some(parts[0])
|
||||
}
|
||||
|
||||
/// Extract organization from git locator (returns None if no org)
|
||||
pub fn extract_git_org(&self) -> Option<&str> {
|
||||
if !self.is_valid_git_urn() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let parts: Vec<&str> = self.locator.split(':').collect();
|
||||
|
||||
// parts[1] is org field - empty string means no org
|
||||
if parts[1].is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(parts[1])
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract repository name from git locator
|
||||
pub fn extract_git_repo(&self) -> Option<&str> {
|
||||
if !self.is_valid_git_urn() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let parts: Vec<&str> = self.locator.split(':').collect();
|
||||
|
||||
// parts[2] is always the repo name
|
||||
Some(parts[2])
|
||||
}
|
||||
|
||||
/// Extract branch from git locator
|
||||
pub fn extract_git_branch(&self) -> Option<&str> {
|
||||
if !self.is_valid_git_urn() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let parts: Vec<&str> = self.locator.split(':').collect();
|
||||
|
||||
// parts[3] is always the branch name
|
||||
Some(parts[3])
|
||||
}
|
||||
|
||||
/// Extract file path from git locator (preserves / separators)
|
||||
pub fn extract_git_path(&self) -> Option<String> {
|
||||
if !self.is_valid_git_urn() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let parts: Vec<&str> = self.locator.split(':').collect();
|
||||
|
||||
// parts[4..] is the path - join with : to preserve any colons in the path
|
||||
// This handles paths like "src:lib.rs" (though rare)
|
||||
Some(parts[4..].join(":"))
|
||||
}
|
||||
|
||||
/// Build a git URN with ARN-like format: host:org:repo:branch:path
|
||||
pub fn build_git_urn(
|
||||
host: &str,
|
||||
org: Option<&str>,
|
||||
repo: &str,
|
||||
branch: &str,
|
||||
path: &str,
|
||||
fragment: Option<&str>
|
||||
) -> Result<String, UrnError> {
|
||||
// Validate required components
|
||||
if host.is_empty() {
|
||||
return Err(UrnError("host cannot be empty".to_string()));
|
||||
}
|
||||
if repo.is_empty() {
|
||||
return Err(UrnError("repo cannot be empty".to_string()));
|
||||
}
|
||||
if branch.is_empty() {
|
||||
return Err(UrnError("branch cannot be empty".to_string()));
|
||||
}
|
||||
if path.is_empty() {
|
||||
return Err(UrnError("path cannot be empty".to_string()));
|
||||
}
|
||||
|
||||
// Build locator: host:org:repo:branch:path
|
||||
let org_part = org.unwrap_or("");
|
||||
let locator = format!("{}:{}:{}:{}:{}", host, org_part, repo, branch, path);
|
||||
|
||||
// Build full URN
|
||||
let fragment_part = fragment.map(|f| format!("#{}", f)).unwrap_or_default();
|
||||
Ok(format!("urn:smem:code:git:{}{}", locator, fragment_part))
|
||||
}
|
||||
}
|
||||
|
||||
// ── schema ────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Return machine-readable taxonomy of all valid URN components.
/// Used by the `describe_urn_schema` MCP tool.
///
/// The returned JSON enumerates every accepted content type, origin
/// (with its locator shape and examples), and fragment pattern, plus
/// full example URNs — intended for clients to discover the format
/// without hard-coding it.
pub fn schema_json() -> Value {
    json!({
        // Overall URN grammar; fragment is optional.
        "format": "urn:smem:<type>:<origin>:<locator>[#<fragment>]",
        // Accepted <type> values with human-readable labels.
        "content_types": [
            { "value": "code", "label": "source code" },
            { "value": "doc", "label": "documentation" },
            { "value": "web", "label": "web page" },
            { "value": "data", "label": "data record" },
            { "value": "note", "label": "note" },
            { "value": "conf", "label": "configuration" }
        ],
        // Accepted <origin> values; each entry documents its locator shape.
        "origins": [
            {
                "value": "git",
                "label": "git repository",
                // ARN-like colon-separated locator; org may be empty ("::").
                "locator_shape": "<host>:<org>:<repo>:<branch>:<path>",
                "note": "ARN-like format with colons. Org is optional (use :: for no org). Path preserves / separators.",
                "examples": [
                    "github.com:acme:repo:main:src/lib.rs",
                    "git.example.com::repo:main:src/lib.rs",
                    "github.com:acme:repo:feature/new-auth:src/auth.rs"
                ]
            },
            {
                "value": "fs",
                "label": "local or remote file",
                // Optional hostname prefix distinguishes remote mounts.
                "locator_shape": "[<hostname>:]<absolute-path>",
                "note": "Hostname is optional. Omit for local files; include to identify a NAS, remote machine, cloud drive mount, etc.",
                "examples": [
                    "/home/user/project/src/main.rs",
                    "my-nas:/mnt/share/docs/readme.txt",
                    "macbook-pro:/Users/sienna/notes.md",
                    "google-drive:/My Drive/design.gdoc"
                ]
            },
            {
                "value": "https",
                "label": "web URL (https)",
                "locator_shape": "<host>/<path>",
                "example": "docs.example.com/api/overview"
            },
            {
                "value": "http",
                "label": "web URL (http)",
                "locator_shape": "<host>/<path>",
                "example": "intranet.corp/wiki/setup"
            },
            {
                "value": "db",
                "label": "database record",
                "locator_shape": "<driver>/<host>/<database>/<table>/<pk>",
                "example": "postgres/localhost/myapp/users/abc-123"
            },
            {
                "value": "api",
                "label": "API endpoint",
                "locator_shape": "<host>/<path>",
                "example": "api.example.com/v2/facts/abc-123"
            },
            {
                "value": "manual",
                "label": "manually authored",
                "locator_shape": "<label>",
                "example": "2026-03-04/onboarding-session"
            }
        ],
        // Recognized fragment patterns (line references or section anchors).
        "fragment_shapes": [
            { "pattern": "L42", "meaning": "single line 42" },
            { "pattern": "L10-L30", "meaning": "lines 10 through 30" },
            { "pattern": "<slug>", "meaning": "section anchor / HTML id" }
        ],
        // Complete example URNs covering the common origin/type combinations.
        "examples": [
            "urn:smem:code:git:github.com:acme:repo:main:src/lib.rs#L1-L50",
            "urn:smem:code:git:git.example.com::repo:main:src/lib.rs",
            "urn:smem:code:fs:/home/user/project/src/main.rs#L10",
            "urn:smem:code:fs:my-nas:/mnt/share/src/main.rs",
            "urn:smem:doc:https:docs.example.com/api/overview#authentication",
            "urn:smem:data:db:postgres/localhost/mydb/users/abc-123",
            "urn:smem:note:manual:2026-03-04/session-notes"
        ]
    })
}
|
||||
|
||||
// ── error response helper ─────────────────────────────────────────────────────
|
||||
|
||||
/// Build the JSON response for an invalid URN (used by the MCP tool).
///
/// Echoes the offending `input` back alongside the parse error message
/// and the canonical `SPEC` string so clients can self-correct.
pub fn invalid_urn_response(input: &str, err: &UrnError) -> Value {
    json!({
        "valid": false,
        "input": input,
        "error": err.to_string(),
        "spec": SPEC,
    })
}
|
||||
|
||||
// ── tests ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // ── legacy slash-style locators (taken verbatim, split happens on ':') ──

    /// Parse + serialize round-trip for a slash-style git locator with a
    /// line-range fragment.
    #[test]
    fn roundtrip_git() {
        let s = "urn:smem:code:git:github.com/acme/repo/refs/heads/main/src/lib.rs#L1-L50";
        let urn = SourceUrn::parse(s).unwrap();
        assert_eq!(urn.content_type, ContentType::Code);
        assert_eq!(urn.origin, Origin::Git);
        assert_eq!(urn.locator, "github.com/acme/repo/refs/heads/main/src/lib.rs");
        assert_eq!(urn.fragment.as_deref(), Some("L1-L50"));
        assert_eq!(urn.to_urn(), s);
    }

    /// fs URN with no fragment round-trips and reports `fragment == None`.
    #[test]
    fn roundtrip_fs_no_fragment() {
        let s = "urn:smem:doc:fs:/Users/sienna/README.md";
        let urn = SourceUrn::parse(s).unwrap();
        assert_eq!(urn.origin, Origin::Fs);
        assert!(urn.fragment.is_none());
        assert_eq!(urn.to_urn(), s);
    }

    #[test]
    fn roundtrip_db_with_colons_in_locator() {
        // locator contains slashes but the split is on ':' — fs paths have ':' on Windows
        // and db locators may too; splitn(3) ensures the locator is taken verbatim.
        let s = "urn:smem:data:db:postgres/localhost/mydb/users/abc-123";
        let urn = SourceUrn::parse(s).unwrap();
        assert_eq!(urn.locator, "postgres/localhost/mydb/users/abc-123");
        assert_eq!(urn.to_urn(), s);
    }

    /// Section-anchor fragments (non-line-range) survive a round-trip.
    #[test]
    fn roundtrip_https_with_anchor() {
        let s = "urn:smem:doc:https:docs.example.com/api/overview#authentication";
        let urn = SourceUrn::parse(s).unwrap();
        assert_eq!(urn.origin, Origin::Https);
        assert_eq!(urn.fragment.as_deref(), Some("authentication"));
        assert_eq!(urn.to_urn(), s);
    }

    // ── parse error cases ──

    #[test]
    fn err_missing_prefix() {
        assert!(SourceUrn::parse("smem:code:fs:/foo").is_err());
    }

    #[test]
    fn err_unknown_type() {
        assert!(SourceUrn::parse("urn:smem:blob:fs:/foo").is_err());
    }

    #[test]
    fn err_unknown_origin() {
        assert!(SourceUrn::parse("urn:smem:code:ftp:/foo").is_err());
    }

    #[test]
    fn err_empty_fragment() {
        // A trailing '#' with nothing after it is rejected.
        assert!(SourceUrn::parse("urn:smem:code:fs:/foo#").is_err());
    }

    // ── generic builder ──

    #[test]
    fn build_valid_urn() {
        let urn = SourceUrn::build("code", "fs", "/Users/sienna/file.rs", Some("L10-L30")).unwrap();
        assert_eq!(urn, "urn:smem:code:fs:/Users/sienna/file.rs#L10-L30");
    }

    /// fs locators may carry a hostname prefix; it must survive build+parse.
    #[test]
    fn build_valid_fs_with_hostname() {
        let urn = SourceUrn::build("doc", "fs", "my-nas:/mnt/share/readme.txt", None).unwrap();
        assert_eq!(urn, "urn:smem:doc:fs:my-nas:/mnt/share/readme.txt");
        // Verify it round-trips through parse
        let parsed = SourceUrn::parse(&urn).unwrap();
        assert_eq!(parsed.locator, "my-nas:/mnt/share/readme.txt");
    }

    #[test]
    fn build_err_empty_locator() {
        assert!(SourceUrn::build("code", "fs", "", None).is_err());
    }

    #[test]
    fn build_err_empty_fragment() {
        // Some("") is distinct from None and must be rejected.
        assert!(SourceUrn::build("code", "fs", "/foo", Some("")).is_err());
    }

    // ── schema + display helpers ──

    #[test]
    fn schema_json_has_required_fields() {
        let s = schema_json();
        assert!(s["content_types"].is_array());
        assert!(s["origins"].is_array());
        assert!(s["fragment_shapes"].is_array());
        assert!(s["examples"].is_array());
        // fs origin should describe hostname support
        let fs = s["origins"].as_array().unwrap()
            .iter()
            .find(|o| o["value"] == "fs")
            .expect("fs origin");
        assert!(fs["note"].as_str().unwrap().contains("NAS") ||
                fs["locator_shape"].as_str().unwrap().contains("hostname"));
    }

    /// human_readable() renders the type/origin labels and an en-dash
    /// line range ("L10–L30") for L-range fragments.
    #[test]
    fn human_readable_line_range() {
        let urn = SourceUrn::parse(
            "urn:smem:code:git:github.com/acme/repo/refs/heads/main/src/lib.rs#L10-L30"
        ).unwrap();
        let desc = urn.human_readable();
        assert!(desc.contains("source code"));
        assert!(desc.contains("git repository"));
        assert!(desc.contains("L10–L30"));
    }

    //─── New ARN-like Git URN Tests ─────────────────────────────────────────
    // NOTE(review): the original file had a second test,
    // `test_git_urn_basic_format`, that was an exact duplicate of
    // `test_git_urn_basic_parsing` (same URN string, same assertions);
    // the duplicate has been removed.

    #[test]
    fn test_git_urn_basic_parsing() {
        // Test that the new format can be parsed at all
        let urn_str = "urn:smem:code:git:github.com:acme:repo:main:src/lib.rs#L1-L50";
        let urn = SourceUrn::parse(urn_str).unwrap();

        // Verify basic parsing works
        assert_eq!(urn.content_type, ContentType::Code);
        assert_eq!(urn.origin, Origin::Git);
        assert_eq!(urn.locator, "github.com:acme:repo:main:src/lib.rs");
        assert_eq!(urn.fragment.as_deref(), Some("L1-L50"));

        // Verify round-trip works
        assert_eq!(urn.to_urn(), urn_str);
    }

    #[test]
    fn test_git_urn_no_org_parsing() {
        // Test parsing without organization (double colon)
        let urn_str = "urn:smem:code:git:git.example.com::repo:main:src/lib.rs";
        let urn = SourceUrn::parse(urn_str).unwrap();

        assert_eq!(urn.content_type, ContentType::Code);
        assert_eq!(urn.origin, Origin::Git);
        assert_eq!(urn.locator, "git.example.com::repo:main:src/lib.rs");
        assert_eq!(urn.to_urn(), urn_str);
    }

    #[test]
    fn test_git_urn_with_organization() {
        let urn = SourceUrn::parse("urn:smem:code:git:github.com:acme:repo:main:src/lib.rs").unwrap();
        assert_eq!(urn.extract_git_org(), Some("acme"));
    }

    #[test]
    fn test_git_urn_without_organization() {
        // Empty org field (double colon)
        let urn = SourceUrn::parse("urn:smem:code:git:git.example.com::repo:main:src/lib.rs").unwrap();
        assert_eq!(urn.extract_git_org(), None);
    }

    /// All five extract_git_* accessors agree on one locator.
    #[test]
    fn test_git_urn_component_extraction() {
        let urn = SourceUrn::parse("urn:smem:code:git:gitlab.com:company:project:dev:src/main.rs").unwrap();

        assert_eq!(urn.extract_git_host(), Some("gitlab.com"));
        assert_eq!(urn.extract_git_org(), Some("company"));
        assert_eq!(urn.extract_git_repo(), Some("project"));
        assert_eq!(urn.extract_git_branch(), Some("dev"));
        assert_eq!(urn.extract_git_path(), Some("src/main.rs".to_string()));
    }

    #[test]
    fn test_git_urn_path_preservation() {
        // Paths should preserve their natural / separators
        let urn = SourceUrn::parse("urn:smem:code:git:github.com:acme:repo:main:src/components/Button.tsx").unwrap();
        assert_eq!(urn.extract_git_path(), Some("src/components/Button.tsx".to_string()));
    }

    #[test]
    fn test_git_urn_branch_with_slashes() {
        // Feature branches often have slashes
        let urn = SourceUrn::parse("urn:smem:code:git:github.com:acme:repo:feature/new-auth:src/auth.rs").unwrap();
        assert_eq!(urn.extract_git_branch(), Some("feature/new-auth"));
    }

    #[test]
    fn test_git_urn_validation_valid_cases() {
        // All valid formats
        assert!(SourceUrn::parse("urn:smem:code:git:github.com:acme:repo:main:src/lib.rs").is_ok());
        assert!(SourceUrn::parse("urn:smem:code:git:git.example.com::repo:main:src/lib.rs").is_ok());
        assert!(SourceUrn::parse("urn:smem:code:git:github.com:acme:repo:feature/branch:src/file.rs").is_ok());
    }

    /// Structurally incomplete git locators still parse (locator is taken
    /// verbatim) but fail is_valid_git_urn().
    #[test]
    fn test_git_urn_validation_invalid_cases() {
        // Missing components - these should parse but be invalid
        let urn1 = SourceUrn::parse("urn:smem:code:git:github.com:acme:repo:main").unwrap(); // Missing path
        assert!(!urn1.is_valid_git_urn());

        let urn2 = SourceUrn::parse("urn:smem:code:git:github.com:acme:repo").unwrap(); // Missing branch and path
        assert!(!urn2.is_valid_git_urn());

        let urn3 = SourceUrn::parse("urn:smem:code:git:github.com:acme").unwrap(); // Missing repo, branch, path
        assert!(!urn3.is_valid_git_urn());

        // Empty required components
        let urn4 = SourceUrn::parse("urn:smem:code:git::acme:repo:main:src/lib.rs").unwrap(); // Empty host
        assert!(!urn4.is_valid_git_urn());

        let urn5 = SourceUrn::parse("urn:smem:code:git:github.com::::src/lib.rs").unwrap(); // Empty repo and branch
        assert!(!urn5.is_valid_git_urn());
    }

    #[test]
    fn test_git_urn_roundtrip_consistency() {
        // Parse -> Serialize -> Parse should be identical
        let original = "urn:smem:code:git:gitlab.com:company:project:feature/new-ui:src/components/Button.tsx#L42";

        let parsed1 = SourceUrn::parse(original).unwrap();
        let serialized = parsed1.to_urn();
        let parsed2 = SourceUrn::parse(&serialized).unwrap();

        assert_eq!(parsed1.locator, parsed2.locator);
        assert_eq!(parsed1.fragment, parsed2.fragment);
        assert_eq!(serialized, original);
    }

    // ── build_git_urn ──

    #[test]
    fn test_build_git_urn_with_organization() {
        let urn = SourceUrn::build_git_urn(
            "github.com",
            Some("acme"),
            "repo",
            "main",
            "src/lib.rs",
            Some("L10-L20")
        ).unwrap();

        assert_eq!(urn, "urn:smem:code:git:github.com:acme:repo:main:src/lib.rs#L10-L20");
    }

    #[test]
    fn test_build_git_urn_without_organization() {
        // None org serializes as an empty field (double colon).
        let urn = SourceUrn::build_git_urn(
            "git.example.com",
            None,
            "repo",
            "main",
            "src/lib.rs",
            None
        ).unwrap();

        assert_eq!(urn, "urn:smem:code:git:git.example.com::repo:main:src/lib.rs");
    }

    #[test]
    fn test_build_git_urn_validation() {
        // Empty host
        assert!(SourceUrn::build_git_urn("", Some("acme"), "repo", "main", "src/lib.rs", None).is_err());

        // Empty repo
        assert!(SourceUrn::build_git_urn("github.com", Some("acme"), "", "main", "src/lib.rs", None).is_err());

        // Empty branch
        assert!(SourceUrn::build_git_urn("github.com", Some("acme"), "repo", "", "src/lib.rs", None).is_err());

        // Empty path
        assert!(SourceUrn::build_git_urn("github.com", Some("acme"), "repo", "main", "", None).is_err());
    }
}
|
||||
Reference in New Issue
Block a user