diff --git a/README.md b/README.md index 0a2a3c5..b13e9c2 100644 --- a/README.md +++ b/README.md @@ -87,8 +87,7 @@ fn main() { ..Default::default() }; - let chat_completion_request = ChatCompletionRequest::new(model, messages, Some(options)); - let result = client.chat(chat_completion_request).unwrap(); + let result = client.chat(model, messages, Some(options)).unwrap(); println!("Assistant: {}", result.choices[0].message.content); // => "Assistant: Tower. [...]" } diff --git a/src/v1/chat_completion.rs b/src/v1/chat_completion.rs index 25a2467..c527bd9 100644 --- a/src/v1/chat_completion.rs +++ b/src/v1/chat_completion.rs @@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize}; use crate::v1::common; #[derive(Debug)] -pub struct ChatCompletionRequestOptions { +pub struct ChatCompletionParams { pub tools: Option, pub temperature: Option, pub max_tokens: Option, @@ -12,7 +12,7 @@ pub struct ChatCompletionRequestOptions { pub stream: Option, pub safe_prompt: Option, } -impl Default for ChatCompletionRequestOptions { +impl Default for ChatCompletionParams { fn default() -> Self { Self { tools: None, @@ -53,9 +53,9 @@ impl ChatCompletionRequest { pub fn new( model: String, messages: Vec, - options: Option, + options: Option, ) -> Self { - let ChatCompletionRequestOptions { + let ChatCompletionParams { tools, temperature, max_tokens, diff --git a/src/v1/client.rs b/src/v1/client.rs index 293cef9..2e95266 100644 --- a/src/v1/client.rs +++ b/src/v1/client.rs @@ -6,6 +6,8 @@ use crate::v1::{ constants::API_URL_BASE, }; +use super::chat_completion::{ChatCompletionMessage, ChatCompletionParams}; + pub struct Client { pub api_key: String, pub endpoint: String, @@ -80,18 +82,22 @@ impl Client { let result = request.with_json(params).unwrap().send(); match result { - Ok(res) => { - print!("{:?}", res.as_str().unwrap()); + Ok(response) => { + print!("{:?}", response.as_str().unwrap()); - if (200..=299).contains(&res.status_code) { - Ok(res) + if (200..=299).contains(&response.status_code) { + Ok(response) } else { Err(APIError { - message: format!("{}: {}", res.status_code, res.as_str().unwrap()), + message: format!( + "{}: {}", + response.status_code, + response.as_str().unwrap() + ), }) } } - Err(e) => Err(self.new_error(e)), + Err(error) => Err(self.new_error(error)), } } @@ -101,16 +107,20 @@ impl Client { let result = request.send(); match result { - Ok(res) => { - if (200..=299).contains(&res.status_code) { - Ok(res) + Ok(response) => { + if (200..=299).contains(&response.status_code) { + Ok(response) } else { Err(APIError { - message: format!("{}: {}", res.status_code, res.as_str().unwrap()), + message: format!( + "{}: {}", + response.status_code, + response.as_str().unwrap() + ), }) } } - Err(e) => Err(self.new_error(e)), + Err(error) => Err(self.new_error(error)), } } @@ -132,12 +142,19 @@ impl Client { // } // } - pub fn chat(&self, request: ChatCompletionRequest) -> Result { + pub fn chat( + &self, + model: String, + messages: Vec, + options: Option, + ) -> Result { + let request = ChatCompletionRequest::new(model, messages, options); + let response = self.post("/chat/completions", &request)?; let result = response.json::(); match result { - Ok(r) => Ok(r), - Err(e) => Err(self.new_error(e)), + Ok(response) => Ok(response), + Err(error) => Err(self.new_error(error)), } } diff --git a/tests/v1_chat_completion_test.rs b/tests/v1_chat_completion_test.rs index 302e5e8..e4eab28 100644 --- a/tests/v1_chat_completion_test.rs +++ b/tests/v1_chat_completion_test.rs @@ -1,9 +1,6 @@ use jrest::expect; use mistralai_client::v1::{ - chat_completion::{ - ChatCompletionMessage, ChatCompletionMessageRole, ChatCompletionRequest, - ChatCompletionRequestOptions, - }, + chat_completion::{ChatCompletionMessage, ChatCompletionMessageRole, ChatCompletionParams}, client::Client, constants::OPEN_MISTRAL_7B, }; @@ -22,32 +19,22 @@ fn test_chat_completion() { role: ChatCompletionMessageRole::user, content: "Just guess the next word: \"Eiffel ...\"?".to_string(), }]; - let options = ChatCompletionRequestOptions { + let options = ChatCompletionParams { temperature: Some(0.0), random_seed: Some(42), ..Default::default() }; - let chat_completion_request = ChatCompletionRequest::new(model, messages, Some(options)); - let result = client.chat(chat_completion_request); + let response = client.chat(model, messages, Some(options)).unwrap(); - match result { - Ok(res) => { - expect!(res.model).to_be("open-mistral-7b".to_string()); - expect!(res.object).to_be("chat.completion".to_string()); - expect!(res.choices.len()).to_be(1); - expect!(res.choices[0].index).to_be(0); - expect!(res.choices[0].message.role.clone()) - .to_be(ChatCompletionMessageRole::assistant); - expect!(res.choices[0].message.content.clone()).to_be( - "Tower. The Eiffel Tower is a famous landmark in Paris, France.".to_string(), - ); - expect!(res.usage.prompt_tokens).to_be_greater_than(0); - expect!(res.usage.completion_tokens).to_be_greater_than(0); - expect!(res.usage.total_tokens).to_be_greater_than(21); - } - Err(err) => { - panic!("Error: {}", err); - } - } + expect!(response.model).to_be("open-mistral-7b".to_string()); + expect!(response.object).to_be("chat.completion".to_string()); + expect!(response.choices.len()).to_be(1); + expect!(response.choices[0].index).to_be(0); + expect!(response.choices[0].message.role.clone()).to_be(ChatCompletionMessageRole::assistant); + expect!(response.choices[0].message.content.clone()) + .to_be("Tower. The Eiffel Tower is a famous landmark in Paris, France.".to_string()); + expect!(response.usage.prompt_tokens).to_be_greater_than(0); + expect!(response.usage.completion_tokens).to_be_greater_than(0); + expect!(response.usage.total_tokens).to_be_greater_than(21); }