diff --git a/README.md b/README.md
index aebbe59..931e1a4 100644
--- a/README.md
+++ b/README.md
@@ -31,7 +31,7 @@ Rust client for the Mistral AI API.
 - [x] Chat without streaming (async)
 - [ ] Chat with streaming
 - [x] Embedding
-- [ ] Embedding (async)
+- [x] Embedding (async)
 - [x] List models
 - [x] List models (async)
 - [ ] Function Calling
@@ -122,7 +122,7 @@ async fn main() {
         ..Default::default()
     };
 
-    let result = client.chat(model, messages, Some(options)).await.unwrap();
+    let result = client.chat_async(model, messages, Some(options)).await.unwrap();
     println!("Assistant: {}", result.choices[0].message.content);
     // => "Assistant: Tower. [...]"
 }
@@ -156,7 +156,26 @@ fn main() {
 
 ### Embeddings (async)
 
-_In progress._
+```rs
+use mistralai_client::v1::{client::Client, constants::EmbedModel};
+
+#[tokio::main]
+async fn main() {
+    // This example supposes you have set the `MISTRAL_API_KEY` environment variable.
+    let client: Client = Client::new(None, None, None, None).unwrap();
+
+    let model = EmbedModel::MistralEmbed;
+    let input = vec!["Embed this sentence.", "As well as this one."]
+        .iter()
+        .map(|s| s.to_string())
+        .collect();
+    let options = None;
+
+    let response = client.embeddings_async(model, input, options).await.unwrap();
+    println!("Embeddings: {:?}", response.data);
+    // => "Embeddings: [{...}, {...}]"
+}
+```
 
 ### List models
 
@@ -183,7 +202,7 @@ async fn main() {
     // This example suppose you have set the `MISTRAL_API_KEY` environment variable.
     let client = Client::new(None, None, None, None).await.unwrap();
 
-    let result = client.list_models().unwrap();
+    let result = client.list_models_async().await.unwrap();
     println!("First Model ID: {:?}", result.data[0].id);
     // => "First Model ID: open-mistral-7b"
 }
diff --git a/src/v1/client.rs b/src/v1/client.rs
index 4ec2ce4..5c59afb 100644
--- a/src/v1/client.rs
+++ b/src/v1/client.rs
@@ -89,6 +89,22 @@ impl Client {
         }
     }
 
+    pub async fn embeddings_async(
+        &self,
+        model: EmbedModel,
+        input: Vec<String>,
+        options: Option<EmbeddingRequestOptions>,
+    ) -> Result<EmbeddingResponse, ApiError> {
+        let request = EmbeddingRequest::new(model, input, options);
+
+        let response = self.post_async("/embeddings", &request).await?;
+        let result = response.json::<EmbeddingResponse>().await;
+        match result {
+            Ok(response) => Ok(response),
+            Err(error) => Err(self.to_api_error(error)),
+        }
+    }
+
     pub fn list_models(&self) -> Result<ModelListResponse, ApiError> {
         let response = self.get_sync("/models")?;
         let result = response.json::<ModelListResponse>();
diff --git a/tests/v1_client_embeddings_async_test.rs b/tests/v1_client_embeddings_async_test.rs
new file mode 100644
index 0000000..ad0c689
--- /dev/null
+++ b/tests/v1_client_embeddings_async_test.rs
@@ -0,0 +1,29 @@
+use jrest::expect;
+use mistralai_client::v1::{client::Client, constants::EmbedModel};
+
+#[tokio::test]
+async fn test_client_embeddings_async() {
+    let client: Client = Client::new(None, None, None, None).unwrap();
+
+    let model = EmbedModel::MistralEmbed;
+    let input = vec!["Embed this sentence.", "As well as this one."]
+        .iter()
+        .map(|s| s.to_string())
+        .collect();
+    let options = None;
+
+    let response = client
+        .embeddings_async(model, input, options)
+        .await
+        .unwrap();
+
+    expect!(response.model).to_be(EmbedModel::MistralEmbed);
+    expect!(response.object).to_be("list".to_string());
+    expect!(response.data.len()).to_be(2);
+    expect!(response.data[0].index).to_be(0);
+    expect!(response.data[0].object.clone()).to_be("embedding".to_string());
+    expect!(response.data[0].embedding.len()).to_be_greater_than(0);
+    expect!(response.usage.prompt_tokens).to_be_greater_than(0);
+    expect!(response.usage.completion_tokens).to_be(0);
+    expect!(response.usage.total_tokens).to_be_greater_than(0);
+}