feat(grpc): proper tool result relay via tokio::select
session_chat_via_orchestrator now: - Spawns generation on a background task - Reads in_stream for client tool results in foreground - Forwards results to orchestrator.submit_tool_result() - Uses tokio::select! to handle both concurrently - Uses GenerateRequest + Metadata (no transport types in orchestrator) - Calls grpc::bridge (not orchestrator::grpc_bridge)
This commit is contained in:
@@ -5,7 +5,7 @@ use futures::Stream;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tonic::{Request, Response, Status, Streaming};
|
||||
use tracing::{error, info, warn};
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use super::auth::Claims;
|
||||
use super::proto::code_agent_server::CodeAgent;
|
||||
@@ -122,19 +122,41 @@ async fn run_session(
|
||||
}))
|
||||
.await?;
|
||||
|
||||
// Check if orchestrator is available
|
||||
let has_orch = state.orchestrator.is_some();
|
||||
info!(has_orchestrator = has_orch, "Checking orchestrator availability");
|
||||
let orchestrator = state.orchestrator.as_ref().cloned();
|
||||
|
||||
// Main message loop
|
||||
while let Some(msg) = in_stream.message().await? {
|
||||
match msg.payload {
|
||||
Some(client_message::Payload::Input(input)) => {
|
||||
if let Err(e) = session.chat(&input.text, tx, in_stream).await {
|
||||
error!("Chat error: {e}");
|
||||
tx.send(Ok(ServerMessage {
|
||||
payload: Some(server_message::Payload::Error(Error {
|
||||
message: e.to_string(),
|
||||
fatal: false,
|
||||
})),
|
||||
}))
|
||||
.await?;
|
||||
if let Some(ref orch) = orchestrator {
|
||||
// Orchestrator path: delegate tool loop, bridge forwards events
|
||||
if let Err(e) = session_chat_via_orchestrator(
|
||||
&mut session, &input.text, orch, tx, in_stream,
|
||||
).await {
|
||||
error!("Chat error: {e}");
|
||||
tx.send(Ok(ServerMessage {
|
||||
payload: Some(server_message::Payload::Error(Error {
|
||||
message: e.to_string(),
|
||||
fatal: false,
|
||||
})),
|
||||
}))
|
||||
.await?;
|
||||
}
|
||||
} else {
|
||||
// Fallback: inline tool loop (legacy)
|
||||
if let Err(e) = session.chat(&input.text, tx, in_stream).await {
|
||||
error!("Chat error: {e}");
|
||||
tx.send(Ok(ServerMessage {
|
||||
payload: Some(server_message::Payload::Error(Error {
|
||||
message: e.to_string(),
|
||||
fatal: false,
|
||||
})),
|
||||
}))
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(client_message::Payload::End(_)) => {
|
||||
@@ -150,10 +172,95 @@ async fn run_session(
|
||||
Some(client_message::Payload::Start(_)) => {
|
||||
warn!("Received duplicate StartSession — ignoring");
|
||||
}
|
||||
// ToolResult and Approval are handled inside session.chat()
|
||||
// ToolResult and Approval are handled by the orchestrator bridge
|
||||
_ => continue,
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Chat via the orchestrator: session handles conversation creation,
|
||||
/// orchestrator handles the tool loop, gRPC bridge forwards events.
|
||||
/// Client-side tool results are read from in_stream and forwarded to the orchestrator.
|
||||
async fn session_chat_via_orchestrator(
|
||||
session: &mut super::session::CodeSession,
|
||||
text: &str,
|
||||
orchestrator: &Arc<crate::orchestrator::Orchestrator>,
|
||||
tx: &mpsc::Sender<Result<ServerMessage, Status>>,
|
||||
in_stream: &mut Streaming<ClientMessage>,
|
||||
) -> anyhow::Result<()> {
|
||||
use crate::orchestrator::event::*;
|
||||
|
||||
let conversation_response = session.create_or_append_conversation(text).await?;
|
||||
session.post_to_matrix(text).await;
|
||||
|
||||
let request_id = RequestId::new();
|
||||
let request = GenerateRequest {
|
||||
request_id: request_id.clone(),
|
||||
text: text.into(),
|
||||
user_id: "dev".into(),
|
||||
display_name: None,
|
||||
conversation_key: session.session_id.clone(),
|
||||
is_direct: true,
|
||||
image: None,
|
||||
metadata: Metadata::new()
|
||||
.with("session_id", session.session_id.as_str())
|
||||
.with("room_id", session.room_id.as_str()),
|
||||
};
|
||||
|
||||
// Subscribe BEFORE starting generation
|
||||
let event_rx = orchestrator.subscribe();
|
||||
|
||||
// Spawn gRPC bridge (lives in grpc module, not orchestrator)
|
||||
let tx_clone = tx.clone();
|
||||
let rid_for_bridge = request_id.clone();
|
||||
let bridge_handle = tokio::spawn(async move {
|
||||
super::bridge::bridge_events_to_grpc(rid_for_bridge, event_rx, tx_clone).await;
|
||||
});
|
||||
|
||||
// Spawn orchestrator generation
|
||||
let orch_for_gen = orchestrator.clone();
|
||||
let mut gen_handle = tokio::spawn(async move {
|
||||
orch_for_gen.generate_from_response(&request, conversation_response).await
|
||||
});
|
||||
|
||||
// Read client tool results while generation runs
|
||||
loop {
|
||||
tokio::select! {
|
||||
result = &mut gen_handle => {
|
||||
let gen_result = result.unwrap_or(None);
|
||||
if let Some(ref response_text) = gen_result {
|
||||
session.post_response_to_matrix(response_text).await;
|
||||
}
|
||||
break;
|
||||
}
|
||||
msg = in_stream.message() => {
|
||||
match msg {
|
||||
Ok(Some(msg)) => match msg.payload {
|
||||
Some(client_message::Payload::ToolResult(result)) => {
|
||||
debug!(call_id = result.call_id.as_str(), "Forwarding tool result");
|
||||
let _ = orchestrator.submit_tool_result(
|
||||
&result.call_id,
|
||||
ToolResultPayload { text: result.result, is_error: result.is_error },
|
||||
).await;
|
||||
}
|
||||
Some(client_message::Payload::Approval(a)) if !a.approved => {
|
||||
let _ = orchestrator.submit_tool_result(
|
||||
&a.call_id,
|
||||
ToolResultPayload { text: "Denied by user.".into(), is_error: true },
|
||||
).await;
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
Ok(None) => break,
|
||||
Err(e) => { warn!("Client stream error: {e}"); break; }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let _ = bridge_handle.await;
|
||||
session.touch();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -480,6 +480,87 @@ you also have access to server-side tools: search_archive, search_web, research,
|
||||
tools
|
||||
}
|
||||
|
||||
/// Create or append to the Mistral conversation. Returns the response
|
||||
/// for the orchestrator to run through its tool loop.
|
||||
pub async fn create_or_append_conversation(
|
||||
&mut self,
|
||||
text: &str,
|
||||
) -> anyhow::Result<mistralai_client::v1::conversations::ConversationResponse> {
|
||||
let context_header = self.build_context_header();
|
||||
let input_text = format!("{context_header}\n{text}");
|
||||
|
||||
if let Some(ref conv_id) = self.conversation_id {
|
||||
let req = AppendConversationRequest {
|
||||
inputs: ConversationInput::Text(input_text),
|
||||
completion_args: None,
|
||||
handoff_execution: None,
|
||||
store: Some(true),
|
||||
tool_confirmations: None,
|
||||
stream: false,
|
||||
};
|
||||
self.state
|
||||
.mistral
|
||||
.append_conversation_async(conv_id, &req)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("append_conversation failed: {}", e.message))
|
||||
} else {
|
||||
let instructions = self.build_instructions();
|
||||
let req = CreateConversationRequest {
|
||||
inputs: ConversationInput::Text(input_text),
|
||||
model: Some(self.model.clone()),
|
||||
agent_id: None,
|
||||
agent_version: None,
|
||||
name: Some(format!("code-{}", self.project_name)),
|
||||
description: None,
|
||||
instructions: Some(instructions),
|
||||
completion_args: None,
|
||||
tools: Some(self.build_tool_definitions()),
|
||||
handoff_execution: None,
|
||||
metadata: None,
|
||||
store: Some(true),
|
||||
stream: false,
|
||||
};
|
||||
let resp = self.state
|
||||
.mistral
|
||||
.create_conversation_async(&req)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("create_conversation failed: {}", e.message))?;
|
||||
|
||||
self.conversation_id = Some(resp.conversation_id.clone());
|
||||
self.state.store.set_code_session_conversation(
|
||||
&self.session_id,
|
||||
&resp.conversation_id,
|
||||
);
|
||||
|
||||
info!(
|
||||
conversation_id = resp.conversation_id.as_str(),
|
||||
"Created Mistral conversation for code session"
|
||||
);
|
||||
Ok(resp)
|
||||
}
|
||||
}
|
||||
|
||||
/// Post user message to the Matrix room.
|
||||
pub async fn post_to_matrix(&self, text: &str) {
|
||||
if let Some(ref room) = self.room {
|
||||
let content = RoomMessageEventContent::notice_plain(text);
|
||||
let _ = room.send(content).await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Post assistant response to the Matrix room.
|
||||
pub async fn post_response_to_matrix(&self, text: &str) {
|
||||
if let Some(ref room) = self.room {
|
||||
let content = RoomMessageEventContent::text_markdown(text);
|
||||
let _ = room.send(content).await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Touch the session's last_active timestamp.
|
||||
pub fn touch(&self) {
|
||||
self.state.store.touch_code_session(&self.session_id);
|
||||
}
|
||||
|
||||
/// Disconnect from the session (keeps it active for future reconnection).
|
||||
pub fn end(&self) {
|
||||
self.state.store.touch_code_session(&self.session_id);
|
||||
|
||||
@@ -307,6 +307,13 @@ async fn main() -> anyhow::Result<()> {
|
||||
if config.grpc.is_some() {
|
||||
let orchestrator_id = state.conversation_registry.get_agent_id().await
|
||||
.unwrap_or_default();
|
||||
let orch = Arc::new(orchestrator::Orchestrator::new(
|
||||
config.clone(),
|
||||
state.responder.tools(),
|
||||
state.mistral.clone(),
|
||||
state.conversation_registry.clone(),
|
||||
system_prompt_text.clone(),
|
||||
));
|
||||
let grpc_state = std::sync::Arc::new(grpc::GrpcState {
|
||||
config: config.clone(),
|
||||
tools: state.responder.tools(),
|
||||
@@ -315,6 +322,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
matrix: matrix_client.clone(),
|
||||
system_prompt: system_prompt_text.clone(),
|
||||
orchestrator_agent_id: orchestrator_id,
|
||||
orchestrator: Some(orch),
|
||||
});
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = grpc::start_server(grpc_state).await {
|
||||
|
||||
Reference in New Issue
Block a user