diff --git a/Cargo.lock b/Cargo.lock index dd20e41..52ebbdb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -828,6 +828,15 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crossbeam-channel" +version = "0.5.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-deque" version = "0.8.6" @@ -1357,7 +1366,7 @@ dependencies = [ "libc", "option-ext", "redox_users", - "windows-sys 0.61.2", + "windows-sys 0.59.0", ] [[package]] @@ -1482,7 +1491,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys 0.61.2", + "windows-sys 0.59.0", ] [[package]] @@ -3070,9 +3079,9 @@ dependencies = [ [[package]] name = "mistralai-client" -version = "1.1.0" +version = "1.2.0" source = "sparse+https://src.sunbeam.pt/api/packages/studio/cargo/" -checksum = "6a17c5600508f30965a8e6ab78e947a0db7017edef8ffb022aa9ac71f50ebaff" +checksum = "24258b2ba72432f9c147a5a80e6d862c7b3c43db859135c619c948cf6c9e4fdf" dependencies = [ "async-stream", "async-trait", @@ -3154,7 +3163,7 @@ version = "0.50.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" dependencies = [ - "windows-sys 0.61.2", + "windows-sys 0.59.0", ] [[package]] @@ -3777,7 +3786,7 @@ dependencies = [ "once_cell", "socket2", "tracing", - "windows-sys 0.60.2", + "windows-sys 0.59.0", ] [[package]] @@ -4290,7 +4299,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys 0.12.1", - "windows-sys 0.61.2", + "windows-sys 0.59.0", ] [[package]] @@ -4736,6 +4745,7 @@ dependencies = [ "tonic-prost", "tonic-prost-build", "tracing", + "tracing-appender", "tracing-subscriber", "tree-sitter", "tree-sitter-python", @@ -5355,7 +5365,7 @@ dependencies = [ "getrandom 
0.4.2", "once_cell", "rustix 1.1.4", - "windows-sys 0.61.2", + "windows-sys 0.59.0", ] [[package]] @@ -5838,6 +5848,18 @@ dependencies = [ "tracing-core", ] +[[package]] +name = "tracing-appender" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "786d480bce6247ab75f005b14ae1624ad978d3029d9113f0a22fa1ac773faeaf" +dependencies = [ + "crossbeam-channel", + "thiserror 2.0.18", + "time", + "tracing-subscriber", +] + [[package]] name = "tracing-attributes" version = "0.1.31" @@ -6425,7 +6447,7 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" dependencies = [ - "windows-sys 0.61.2", + "windows-sys 0.59.0", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 492b84b..9bbd6b9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ name = "sol" path = "src/main.rs" [dependencies] -mistralai-client = { version = "1.1.0", registry = "sunbeam" } +mistralai-client = { version = "1.2.0", registry = "sunbeam" } matrix-sdk = { version = "0.9", features = ["e2e-encryption", "sqlite"] } opensearch = "2" tokio = { version = "1", features = ["full"] } diff --git a/src/grpc/service.rs b/src/grpc/service.rs index c368418..1b18f75 100644 --- a/src/grpc/service.rs +++ b/src/grpc/service.rs @@ -314,7 +314,7 @@ async fn session_chat_via_orchestrator( ) -> anyhow::Result<()> { use crate::orchestrator::event::*; - let conversation_response = session.create_or_append_conversation(text).await?; + let conversation_response = session.create_or_append_conversation_streaming(text, tx).await?; session.post_to_matrix(text).await; let request_id = RequestId::new(); diff --git a/src/grpc/session.rs b/src/grpc/session.rs index c86c530..7075896 100644 --- a/src/grpc/session.rs +++ b/src/grpc/session.rs @@ -388,46 +388,160 @@ you also have access to server-side tools: search_archive, search_web, research, tools } - /// Create or append 
to the Mistral conversation. Returns the response - /// for the orchestrator to run through its tool loop. - pub async fn create_or_append_conversation( + /// Create or append to the Mistral conversation with streaming. + /// Emits TextDelta messages to the gRPC client as chunks arrive. + /// Returns the accumulated response for the orchestrator's tool loop. + pub async fn create_or_append_conversation_streaming( &mut self, text: &str, + client_tx: &tokio::sync::mpsc::Sender<Result<ServerMessage, tonic::Status>>, ) -> anyhow::Result<ConversationResponse> { + use futures::StreamExt; + use mistralai_client::v1::conversation_stream::{self, ConversationEvent}; + let context_header = self.build_context_header(text).await; let input_text = format!("{context_header}\n{text}"); + let conv_id_for_accumulate: String; + let mut events = Vec::new(); + if let Some(ref conv_id) = self.conversation_id { + conv_id_for_accumulate = conv_id.clone(); let req = AppendConversationRequest { inputs: ConversationInput::Text(input_text.clone()), completion_args: None, handoff_execution: None, store: Some(true), tool_confirmations: None, - stream: false, + stream: true, }; - match self.state - .mistral - .append_conversation_async(conv_id, &req) + match self.state.mistral + .append_conversation_stream_async(conv_id, &req) .await { - Ok(resp) => Ok(resp), + Ok(stream) => { + tokio::pin!(stream); + while let Some(result) = stream.next().await { + match result { + Ok(event) => { + if let Some(delta) = event.text_delta() { + let _ = client_tx.send(Ok(ServerMessage { + payload: Some(server_message::Payload::Delta(TextDelta { + text: delta, + })), + })).await; + } + events.push(event); + } + Err(e) => { + warn!("Stream event error: {}", e.message); + } + } + } + } Err(e) if e.message.contains("function calls and responses") || e.message.contains("invalid_request_error") => { warn!( conversation_id = conv_id.as_str(), error = e.message.as_str(), - "Conversation corrupted — creating fresh conversation" + "Conversation corrupted — creating fresh 
conversation (streaming)" ); self.conversation_id = None; - self.create_fresh_conversation(input_text).await + return self.create_fresh_conversation_streaming(input_text, client_tx).await; } - Err(e) => Err(anyhow::anyhow!("append_conversation failed: {}", e.message)), + Err(e) => return Err(anyhow::anyhow!("append_conversation_stream failed: {}", e.message)), } } else { - self.create_fresh_conversation(input_text).await + return self.create_fresh_conversation_streaming(input_text, client_tx).await; } + + Ok(conversation_stream::accumulate(&conv_id_for_accumulate, &events)) + } + + /// Create a fresh streaming conversation — used on first message or after corruption. + async fn create_fresh_conversation_streaming( + &mut self, + input_text: String, + client_tx: &tokio::sync::mpsc::Sender<Result<ServerMessage, tonic::Status>>, + ) -> anyhow::Result<ConversationResponse> { + use futures::StreamExt; + use mistralai_client::v1::conversation_stream::{self, ConversationEvent}; + + let instructions = self.build_instructions(); + let req = CreateConversationRequest { + inputs: ConversationInput::Text(input_text), + model: Some(self.model.clone()), + agent_id: None, + agent_version: None, + name: Some(format!("code-{}", self.project_name)), + description: None, + instructions: Some(instructions), + completion_args: None, + tools: Some(self.build_tool_definitions()), + handoff_execution: None, + metadata: None, + store: Some(true), + stream: true, + }; + + let stream = self.state.mistral + .create_conversation_stream_async(&req) + .await + .map_err(|e| anyhow::anyhow!("create_conversation_stream failed: {}", e.message))?; + + tokio::pin!(stream); + let mut events = Vec::new(); + let mut conversation_id = String::new(); + + while let Some(result) = stream.next().await { + match result { + Ok(event) => { + if let Some(delta) = event.text_delta() { + let _ = client_tx.send(Ok(ServerMessage { + payload: Some(server_message::Payload::Delta(TextDelta { + text: delta, + })), + })).await; + } + events.push(event); + } + Err(e) => { + 
warn!("Stream event error: {}", e.message); + } + } + } + + let resp = conversation_stream::accumulate(&conversation_id, &events); + + // Extract conversation_id from the accumulated response + if !resp.conversation_id.is_empty() { + conversation_id = resp.conversation_id.clone(); + } + + // If we still don't have a conversation_id, the streaming didn't return one. + // This happens with the Conversations API — the ID comes from the response headers + // or the ResponseDone event. Let's check events for it. + if conversation_id.is_empty() { + // Fallback: create non-streaming to get the conversation_id + warn!("Streaming didn't return conversation_id — falling back to non-streaming create"); + return self.create_fresh_conversation( + "".into() // empty — conversation already created, just need the ID + ).await; + } + + self.conversation_id = Some(conversation_id.clone()); + self.state.store.set_code_session_conversation( + &self.session_id, + &conversation_id, + ); + + info!( + conversation_id = conversation_id.as_str(), + "Created streaming Mistral conversation for code session" + ); + + Ok(resp) } /// Post user message to the Matrix room. 
diff --git a/src/integration_test.rs b/src/integration_test.rs index ad0c538..1d426f6 100644 --- a/src/integration_test.rs +++ b/src/integration_test.rs @@ -6051,3 +6051,213 @@ class Calculator: assert!(sessions.is_empty(), "Session should be marked complete, not running"); } } + +// ══════════════════════════════════════════════════════════════════════════ +// Conversation streaming — Mistral API integration +// ══════════════════════════════════════════════════════════════════════════ + +mod conversation_stream_tests { + use futures::StreamExt; + use mistralai_client::v1::conversation_stream::{self, ConversationEvent}; + use mistralai_client::v1::conversations::*; + + fn load_env() -> Option<String> { + let env_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join(".env"); + if let Ok(contents) = std::fs::read_to_string(&env_path) { + for line in contents.lines() { + let line = line.trim(); + if line.is_empty() || line.starts_with('#') { continue; } + if let Some((k, v)) = line.split_once('=') { + std::env::set_var(k.trim(), v.trim()); + } + } + } + std::env::var("SOL_MISTRAL_API_KEY").ok() + } + + #[tokio::test] + async fn test_streaming_create_conversation() { + let Some(api_key) = load_env() else { eprintln!("Skipping: no API key"); return; }; + let client = mistralai_client::v1::client::Client::new(Some(api_key), None, None, None).unwrap(); + + let req = CreateConversationRequest { + inputs: ConversationInput::Text("What is 7 * 8? 
Answer with just the number.".into()), + model: Some("mistral-medium-latest".into()), + agent_id: None, + agent_version: None, + name: None, + description: None, + instructions: Some("Answer concisely.".into()), + completion_args: None, + tools: None, + handoff_execution: None, + metadata: None, + store: Some(false), + stream: true, + }; + + let stream = client.create_conversation_stream_async(&req).await.unwrap(); + tokio::pin!(stream); + + let mut events = Vec::new(); + let mut saw_text = false; + let mut saw_done = false; + + while let Some(result) = stream.next().await { + let event = result.unwrap(); + if event.text_delta().is_some() { saw_text = true; } + if matches!(&event, ConversationEvent::ResponseDone { .. }) { saw_done = true; } + events.push(event); + } + + assert!(saw_text, "Should receive text deltas"); + assert!(saw_done, "Should receive ResponseDone"); + + let resp = conversation_stream::accumulate("", &events); + let text = resp.assistant_text().unwrap_or_default(); + assert!(text.contains("56"), "Should compute 7*8=56, got: {text}"); + assert!(resp.usage.total_tokens > 0, "Should have token usage"); + } + + #[tokio::test] + async fn test_streaming_append_conversation() { + let Some(api_key) = load_env() else { eprintln!("Skipping: no API key"); return; }; + let client = mistralai_client::v1::client::Client::new(Some(api_key), None, None, None).unwrap(); + + // Create non-streaming first + let create_req = CreateConversationRequest { + inputs: ConversationInput::Text("The magic word is SUNBEAM. 
Just say ok.".into()), + model: Some("mistral-medium-latest".into()), + agent_id: None, + agent_version: None, + name: None, + description: None, + instructions: Some("Be brief.".into()), + completion_args: None, + tools: None, + handoff_execution: None, + metadata: None, + store: Some(true), + stream: false, + }; + let created = client.create_conversation_async(&create_req).await.unwrap(); + + // Append with streaming + let append_req = AppendConversationRequest { + inputs: ConversationInput::Text("What is the magic word?".into()), + completion_args: None, + handoff_execution: None, + store: Some(true), + tool_confirmations: None, + stream: true, + }; + let stream = client + .append_conversation_stream_async(&created.conversation_id, &append_req) + .await + .unwrap(); + tokio::pin!(stream); + + let mut events = Vec::new(); + while let Some(result) = stream.next().await { + events.push(result.unwrap()); + } + + let resp = conversation_stream::accumulate(&created.conversation_id, &events); + let text = resp.assistant_text().unwrap_or_default().to_uppercase(); + assert!(text.contains("SUNBEAM"), "Should recall magic word, got: {text}"); + + // Cleanup + let _ = client.delete_conversation_async(&created.conversation_id).await; + } + + #[test] + fn test_parse_sse_line_data_prefix() { + let line = r#"data: {"type":"conversation.response.started"}"#; + let event = conversation_stream::parse_sse_line(line).unwrap(); + assert!(event.is_some()); + assert!(matches!(event.unwrap(), ConversationEvent::ResponseStarted { .. 
})); + } + + #[test] + fn test_parse_sse_line_event_prefix_skipped() { + let line = "event: conversation.response.started"; + assert!(conversation_stream::parse_sse_line(line).unwrap().is_none()); + } + + #[test] + fn test_parse_sse_line_done() { + assert!(conversation_stream::parse_sse_line("data: [DONE]").unwrap().is_none()); + } + + #[test] + fn test_parse_sse_line_empty() { + assert!(conversation_stream::parse_sse_line("").unwrap().is_none()); + assert!(conversation_stream::parse_sse_line(" ").unwrap().is_none()); + } + + #[test] + fn test_accumulate_text_and_usage() { + let events = vec![ + ConversationEvent::ResponseStarted { created_at: None }, + ConversationEvent::MessageOutput { + id: "m1".into(), + content: serde_json::json!("hello "), + role: "assistant".into(), + output_index: 0, + content_index: 0, + model: None, + }, + ConversationEvent::MessageOutput { + id: "m1".into(), + content: serde_json::json!("world"), + role: "assistant".into(), + output_index: 0, + content_index: 0, + model: None, + }, + ConversationEvent::ResponseDone { + usage: ConversationUsageInfo { + prompt_tokens: 10, + completion_tokens: 5, + total_tokens: 15, + }, + created_at: None, + }, + ]; + + let resp = conversation_stream::accumulate("conv-123", &events); + assert_eq!(resp.conversation_id, "conv-123"); + assert_eq!(resp.assistant_text().unwrap(), "hello world"); + assert_eq!(resp.usage.total_tokens, 15); + } + + #[test] + fn test_accumulate_with_function_calls() { + let events = vec![ + ConversationEvent::FunctionCall { + id: "fc-1".into(), + name: "search_web".into(), + tool_call_id: "tc-1".into(), + arguments: r#"{"query":"rust"}"#.into(), + output_index: 0, + model: None, + confirmation_status: None, + }, + ConversationEvent::ResponseDone { + usage: ConversationUsageInfo { + prompt_tokens: 20, + completion_tokens: 10, + total_tokens: 30, + }, + created_at: None, + }, + ]; + + let resp = conversation_stream::accumulate("conv-456", &events); + 
assert!(resp.assistant_text().is_none()); + let calls = resp.function_calls(); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "search_web"); + assert_eq!(calls[0].tool_call_id.as_deref(), Some("tc-1")); + } +}