diff --git a/Cargo.toml b/Cargo.toml index 6b1ca4a..c6d67d5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,8 @@ readme = "README.md" repository = "https://github.com/ivangabriele/mistralai-client-rs" [dependencies] -reqwest = { version = "0.11.24", features = ["json", "blocking"] } +futures = "0.3.30" +reqwest = { version = "0.11.24", features = ["json", "blocking", "stream"] } serde = { version = "1.0.197", features = ["derive"] } serde_json = "1.0.114" thiserror = "1.0.57" diff --git a/README.md b/README.md index 931e1a4..a828533 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ Rust client for the Mistral AI API. - [Usage](#usage) - [Chat without streaming](#chat-without-streaming) - [Chat without streaming (async)](#chat-without-streaming-async) - - [Chat with streaming](#chat-with-streaming) + - [Chat with streaming (async)](#chat-with-streaming-async) - [Embeddings](#embeddings) - [Embeddings (async)](#embeddings-async) - [List models](#list-models) @@ -29,7 +29,7 @@ Rust client for the Mistral AI API. 
- [x] Chat without streaming - [x] Chat without streaming (async) -- [ ] Chat with streaming +- [x] Chat with streaming - [x] Embedding - [x] Embedding (async) - [x] List models @@ -71,7 +71,7 @@ fn main() { ```rs use mistralai_client::v1::{ - chat_completion::{ChatCompletionMessage, ChatCompletionMessageRole, ChatCompletionRequestOptions}, + chat_completion::{ChatCompletionParams, ChatMessage, ChatMessageRole}, client::Client, constants::Model, }; @@ -81,8 +81,8 @@ fn main() { let client = Client::new(None, None, None, None).unwrap(); let model = Model::OpenMistral7b; - let messages = vec![ChatCompletionMessage { - role: ChatCompletionMessageRole::user, + let messages = vec![ChatMessage { + role: ChatMessageRole::user, content: "Just guess the next word: \"Eiffel ...\"?".to_string(), }]; - let options = ChatCompletionRequestOptions { + let options = ChatCompletionParams { @@ -101,7 +101,7 @@ fn main() { ```rs use mistralai_client::v1::{ - chat_completion::{ChatCompletionMessage, ChatCompletionMessageRole, ChatCompletionRequestOptions}, + chat_completion::{ChatCompletionParams, ChatMessage, ChatMessageRole}, client::Client, constants::Model, }; @@ -112,8 +112,8 @@ async fn main() { let client = Client::new(None, None, None, None).unwrap(); let model = Model::OpenMistral7b; - let messages = vec![ChatCompletionMessage { - role: ChatCompletionMessageRole::user, + let messages = vec![ChatMessage { + role: ChatMessageRole::user, content: "Just guess the next word: \"Eiffel ...\"?".to_string(), }]; - let options = ChatCompletionRequestOptions { + let options = ChatCompletionParams { @@ -128,9 +128,44 @@ async fn main() { } ``` -### Chat with streaming +### Chat with streaming (async) -_In progress._ +```rs +use futures::stream::StreamExt; +use mistralai_client::v1::{ + chat_completion::{ChatCompletionParams, ChatMessage, ChatMessageRole}, + client::Client, + constants::Model, +}; + +#[tokio::main] +async fn main() { + // This example assumes you have set the `MISTRAL_API_KEY` environment variable.
+ let client = Client::new(None, None, None, None).unwrap(); + + let model = Model::OpenMistral7b; + let messages = vec![ChatMessage { + role: ChatMessageRole::user, + content: "Just guess the next word: \"Eiffel ...\"?".to_string(), + }]; + let options = ChatCompletionParams { + temperature: Some(0.0), + random_seed: Some(42), + ..Default::default() + }; + + let stream_result = client.chat_stream(model, messages, Some(options)).await; + let mut stream = stream_result.expect("Failed to create stream."); + while let Some(chunk_result) = stream.next().await { + match chunk_result { + Ok(chunk) => { + println!("Assistant (message chunk): {}", chunk.choices[0].delta.content); + } + Err(e) => eprintln!("Error processing chunk: {:?}", e), + } + } +} +``` ### Embeddings diff --git a/src/v1/chat_completion.rs b/src/v1/chat_completion.rs index 6349219..52b3cba 100644 --- a/src/v1/chat_completion.rs +++ b/src/v1/chat_completion.rs @@ -2,6 +2,25 @@ use serde::{Deserialize, Serialize}; use crate::v1::{common, constants}; +// ----------------------------------------------------------------------------- +// Definitions + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ChatMessage { + pub role: ChatMessageRole, + pub content: String, +} + +#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] +#[allow(non_camel_case_types)] +pub enum ChatMessageRole { + assistant, + user, +} + +// ----------------------------------------------------------------------------- +// Request + #[derive(Debug)] pub struct ChatCompletionParams { pub tools: Option, @@ -9,7 +28,6 @@ pub struct ChatCompletionParams { pub max_tokens: Option, pub top_p: Option, pub random_seed: Option, - pub stream: Option, pub safe_prompt: Option, } impl Default for ChatCompletionParams { @@ -20,7 +38,6 @@ impl Default for ChatCompletionParams { max_tokens: None, top_p: None, random_seed: None, - stream: None, safe_prompt: None, } } @@ -28,7 +45,7 @@ impl Default for ChatCompletionParams { 
#[derive(Debug, Serialize, Deserialize)] pub struct ChatCompletionRequest { - pub messages: Vec, + pub messages: Vec, pub model: constants::Model, #[serde(skip_serializing_if = "Option::is_none")] pub tools: Option, @@ -40,8 +57,7 @@ pub struct ChatCompletionRequest { pub top_p: Option, #[serde(skip_serializing_if = "Option::is_none")] pub random_seed: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub stream: Option, + pub stream: bool, #[serde(skip_serializing_if = "Option::is_none")] pub safe_prompt: Option, // TODO Check this prop (seen in official Python client but not in API doc). @@ -52,7 +68,8 @@ pub struct ChatCompletionRequest { impl ChatCompletionRequest { pub fn new( model: constants::Model, - messages: Vec, + messages: Vec, + stream: bool, options: Option, ) -> Self { let ChatCompletionParams { @@ -61,7 +78,6 @@ impl ChatCompletionRequest { max_tokens, top_p, random_seed, - stream, safe_prompt, } = options.unwrap_or_default(); @@ -79,6 +95,9 @@ impl ChatCompletionRequest { } } +// ----------------------------------------------------------------------------- +// Response + #[derive(Clone, Debug, Deserialize, Serialize)] pub struct ChatCompletionResponse { pub id: String, @@ -86,28 +105,45 @@ pub struct ChatCompletionResponse { /// Unix timestamp (in seconds). pub created: u32, pub model: constants::Model, - pub choices: Vec, + pub choices: Vec, pub usage: common::ResponseUsage, } #[derive(Clone, Debug, Deserialize, Serialize)] -pub struct ChatCompletionChoice { +pub struct ChatCompletionResponseChoice { pub index: u32, - pub message: ChatCompletionMessage, + pub message: ChatMessage, pub finish_reason: String, // TODO Check this prop (seen in API responses but undocumented). // pub logprobs: ??? 
} -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct ChatCompletionMessage { - pub role: ChatCompletionMessageRole, - pub content: String, +// ----------------------------------------------------------------------------- +// Stream + +#[derive(Debug, Deserialize)] +pub struct ChatCompletionStreamChunk { + pub id: String, + pub object: String, + /// Unix timestamp (in seconds). + pub created: u32, + pub model: constants::Model, + pub choices: Vec, + // TODO Check this prop (seen in API responses but undocumented). + // pub usage: ???, } -#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] -#[allow(non_camel_case_types)] -pub enum ChatCompletionMessageRole { - assistant, - user, +#[derive(Debug, Deserialize)] +pub struct ChatCompletionStreamChunkChoice { + pub index: u32, + pub delta: ChatCompletionStreamChunkChoiceDelta, + pub finish_reason: Option, + // TODO Check this prop (seen in API responses but undocumented). + // pub logprobs: ???, +} + +#[derive(Debug, Deserialize)] +pub struct ChatCompletionStreamChunkChoiceDelta { + pub role: Option, + pub content: String, } diff --git a/src/v1/client.rs b/src/v1/client.rs index 44ddfa0..99d5c64 100644 --- a/src/v1/client.rs +++ b/src/v1/client.rs @@ -1,9 +1,13 @@ -use crate::v1::error::ApiError; +use futures::stream::StreamExt; +use futures::Stream; use reqwest::Error as ReqwestError; +use serde_json::from_str; + +use crate::v1::error::ApiError; use crate::v1::{ chat_completion::{ - ChatCompletionMessage, ChatCompletionParams, ChatCompletionRequest, ChatCompletionResponse, + ChatCompletionParams, ChatCompletionRequest, ChatCompletionResponse, ChatMessage, }, constants::{EmbedModel, Model, API_URL_BASE}, embedding::{EmbeddingRequest, EmbeddingRequestOptions, EmbeddingResponse}, @@ -11,6 +15,8 @@ use crate::v1::{ model_list::ModelListResponse, }; +use super::chat_completion::ChatCompletionStreamChunk; + pub struct Client { pub api_key: String, pub endpoint: String, @@ -44,10 +50,10 @@ impl Client { pub 
fn chat( &self, model: Model, - messages: Vec, + messages: Vec, options: Option, ) -> Result { - let request = ChatCompletionRequest::new(model, messages, options); + let request = ChatCompletionRequest::new(model, messages, false, options); let response = self.post_sync("/chat/completions", &request)?; let result = response.json::(); @@ -60,10 +66,10 @@ impl Client { pub async fn chat_async( &self, model: Model, - messages: Vec, + messages: Vec, options: Option, ) -> Result { - let request = ChatCompletionRequest::new(model, messages, options); + let request = ChatCompletionRequest::new(model, messages, false, options); let response = self.post_async("/chat/completions", &request).await?; let result = response.json::().await; @@ -73,6 +79,50 @@ impl Client { } } + pub async fn chat_stream( + &self, + model: Model, + messages: Vec, + options: Option, + ) -> Result>, ApiError> { + let request = ChatCompletionRequest::new(model, messages, true, options); + let response = self + .post_stream("/chat/completions", &request) + .await + .map_err(|e| ApiError { + message: e.to_string(), + })?; + if !response.status().is_success() { + let status = response.status(); + let text = response.text().await.unwrap_or_default(); + return Err(ApiError { + message: format!("{}: {}", status, text), + }); + } + + let deserialized_stream = + response + .bytes_stream() + .map(|item| -> Result { + match item { + Ok(bytes) => { + let text = String::from_utf8(bytes.to_vec()).map_err(|e| ApiError { + message: e.to_string(), + })?; + let text_trimmed = text.trim_start_matches("data: "); + from_str(&text_trimmed).map_err(|e| ApiError { + message: e.to_string(), + }) + } + Err(e) => Err(ApiError { + message: e.to_string(), + }), + } + }); + + Ok(deserialized_stream) + } + pub fn embeddings( &self, model: EmbedModel, @@ -156,10 +206,25 @@ impl Client { request_builder } + fn build_request_stream(&self, request: reqwest::RequestBuilder) -> reqwest::RequestBuilder { + let user_agent = format!( + 
"ivangabriele/mistralai-client-rs/{}", + env!("CARGO_PKG_VERSION") + ); + + let request_builder = request + .bearer_auth(&self.api_key) + .header("Accept", "text/event-stream") + .header("Content-Type", "application/json") + .header("User-Agent", user_agent); + + request_builder + } + fn get_sync(&self, path: &str) -> Result { - let client_sync = reqwest::blocking::Client::new(); + let reqwest_client = reqwest::blocking::Client::new(); let url = format!("{}{}", self.endpoint, path); - let request = self.build_request_sync(client_sync.get(url)); + let request = self.build_request_sync(reqwest_client.get(url)); let result = request.send(); match result { @@ -263,6 +328,35 @@ impl Client { } } + async fn post_stream( + &self, + path: &str, + params: &T, + ) -> Result { + let reqwest_client = reqwest::Client::new(); + let url = format!("{}{}", self.endpoint, path); + let request_builder = reqwest_client.post(url).json(params); + let request = self.build_request_stream(request_builder); + + let result = request.send().await; + match result { + Ok(response) => { + if response.status().is_success() { + Ok(response) + } else { + let status = response.status(); + let text = response.text().await.unwrap_or_default(); + Err(ApiError { + message: format!("{}: {}", status, text), + }) + } + } + Err(error) => Err(ApiError { + message: error.to_string(), + }), + } + } + fn to_api_error(&self, err: ReqwestError) -> ApiError { ApiError { message: err.to_string(), diff --git a/tests/v1_client_chat_async_test.rs b/tests/v1_client_chat_async_test.rs index 2c8eb72..299d4aa 100644 --- a/tests/v1_client_chat_async_test.rs +++ b/tests/v1_client_chat_async_test.rs @@ -1,6 +1,6 @@ use jrest::expect; use mistralai_client::v1::{ - chat_completion::{ChatCompletionMessage, ChatCompletionMessageRole, ChatCompletionParams}, + chat_completion::{ChatCompletionParams, ChatMessage, ChatMessageRole}, client::Client, constants::Model, }; @@ -10,8 +10,8 @@ async fn test_client_chat_async() { let client 
= Client::new(None, None, None, None).unwrap(); let model = Model::OpenMistral7b; - let messages = vec![ChatCompletionMessage { - role: ChatCompletionMessageRole::user, + let messages = vec![ChatMessage { + role: ChatMessageRole::user, content: "Just guess the next word: \"Eiffel ...\"?".to_string(), }]; let options = ChatCompletionParams { @@ -29,7 +29,7 @@ async fn test_client_chat_async() { expect!(response.object).to_be("chat.completion".to_string()); expect!(response.choices.len()).to_be(1); expect!(response.choices[0].index).to_be(0); - expect!(response.choices[0].message.role.clone()).to_be(ChatCompletionMessageRole::assistant); + expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::assistant); expect!(response.choices[0].message.content.clone()) .to_be("Tower. The Eiffel Tower is a famous landmark in Paris, France.".to_string()); expect!(response.usage.prompt_tokens).to_be_greater_than(0); diff --git a/tests/v1_client_chat_stream_test.rs b/tests/v1_client_chat_stream_test.rs new file mode 100644 index 0000000..28449f3 --- /dev/null +++ b/tests/v1_client_chat_stream_test.rs @@ -0,0 +1,40 @@ +use futures::stream::StreamExt; +use jrest::expect; +use mistralai_client::v1::{ + chat_completion::{ChatCompletionParams, ChatMessage, ChatMessageRole}, + client::Client, + constants::Model, +}; + +#[tokio::test] +async fn test_client_chat_stream() { + let client = Client::new(None, None, None, None).unwrap(); + + let model = Model::OpenMistral7b; + let messages = vec![ChatMessage { + role: ChatMessageRole::user, + content: "Just guess the next word: \"Eiffel ...\"?".to_string(), + }]; + let options = ChatCompletionParams { + temperature: Some(0.0), + random_seed: Some(42), + ..Default::default() + }; + + let stream_result = client.chat_stream(model, messages, Some(options)).await; + let mut stream = stream_result.expect("Failed to create stream."); + while let Some(chunk_result) = stream.next().await { + match chunk_result { + Ok(chunk) => { + if 
chunk.choices[0].delta.role == Some(ChatMessageRole::assistant) + || chunk.choices[0].finish_reason == Some("stop".to_string()) + { + expect!(chunk.choices[0].delta.content.len()).to_be(0); + } else { + expect!(chunk.choices[0].delta.content.len()).to_be_greater_than(0); + } + } + Err(e) => eprintln!("Error processing chunk: {:?}", e), + } + } +} diff --git a/tests/v1_client_chat_test.rs b/tests/v1_client_chat_test.rs index 6ae3909..276029d 100644 --- a/tests/v1_client_chat_test.rs +++ b/tests/v1_client_chat_test.rs @@ -1,6 +1,6 @@ use jrest::expect; use mistralai_client::v1::{ - chat_completion::{ChatCompletionMessage, ChatCompletionMessageRole, ChatCompletionParams}, + chat_completion::{ChatCompletionParams, ChatMessage, ChatMessageRole}, client::Client, constants::Model, }; @@ -10,8 +10,8 @@ fn test_client_chat() { let client = Client::new(None, None, None, None).unwrap(); let model = Model::OpenMistral7b; - let messages = vec![ChatCompletionMessage { - role: ChatCompletionMessageRole::user, + let messages = vec![ChatMessage { + role: ChatMessageRole::user, content: "Just guess the next word: \"Eiffel ...\"?".to_string(), }]; let options = ChatCompletionParams { @@ -26,7 +26,7 @@ fn test_client_chat() { expect!(response.object).to_be("chat.completion".to_string()); expect!(response.choices.len()).to_be(1); expect!(response.choices[0].index).to_be(0); - expect!(response.choices[0].message.role.clone()).to_be(ChatCompletionMessageRole::assistant); + expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::assistant); expect!(response.choices[0].message.content.clone()) .to_be("Tower. The Eiffel Tower is a famous landmark in Paris, France.".to_string()); expect!(response.usage.prompt_tokens).to_be_greater_than(0);