diff --git a/tests/v1_client_chat_async_test.rs b/tests/v1_client_chat_async_test.rs index c06aa11..6afb942 100644 --- a/tests/v1_client_chat_async_test.rs +++ b/tests/v1_client_chat_async_test.rs @@ -3,7 +3,7 @@ use mistralai_client::v1::{ chat::{ChatMessage, ChatMessageRole, ChatParams, ChatResponseChoiceFinishReason}, client::Client, constants::Model, - tool::{Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType}, + tool::{Tool, ToolChoice}, }; mod setup; @@ -14,12 +14,12 @@ async fn test_client_chat_async() { let client = Client::new(None, None, None, None).unwrap(); - let model = Model::OpenMistral7b; + let model = Model::mistral_small_latest(); let messages = vec![ChatMessage::new_user_message( "Guess the next word: \"Eiffel ...\"?", )]; let options = ChatParams { - temperature: 0.0, + temperature: Some(0.0), random_seed: Some(42), ..Default::default() }; @@ -29,7 +29,6 @@ async fn test_client_chat_async() { .await .unwrap(); - expect!(response.model).to_be(Model::OpenMistral7b); expect!(response.object).to_be("chat.completion".to_string()); expect!(response.choices.len()).to_be(1); @@ -56,21 +55,26 @@ async fn test_client_chat_async_with_function_calling() { let tools = vec![Tool::new( "get_city_temperature".to_string(), "Get the current temperature in a city.".to_string(), - vec![ToolFunctionParameter::new( - "city".to_string(), - "The name of the city.".to_string(), - ToolFunctionParameterType::String, - )], + serde_json::json!({ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The name of the city." + } + }, + "required": ["city"] + }), )]; let client = Client::new(None, None, None, None).unwrap(); - let model = Model::MistralSmallLatest; + let model = Model::mistral_small_latest(); let messages = vec![ChatMessage::new_user_message( "What's the current temperature in Paris?", )]; let options = ChatParams { - temperature: 0.0, + temperature: Some(0.0), random_seed: Some(42), tool_choice: Some(ToolChoice::Any), tools: Some(tools), @@ -82,7 +86,6 @@ async fn test_client_chat_async_with_function_calling() { .await .unwrap(); - expect!(response.model).to_be(Model::MistralSmallLatest); expect!(response.object).to_be("chat.completion".to_string()); expect!(response.choices.len()).to_be(1); @@ -91,13 +94,6 @@ async fn test_client_chat_async_with_function_calling() { .to_be(ChatResponseChoiceFinishReason::ToolCalls); expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::Assistant); - expect!(response.choices[0].message.content.clone()).to_be("".to_string()); - // expect!(response.choices[0].message.tool_calls.clone()).to_be(Some(vec![ToolCall { - // function: ToolCallFunction { - // name: "get_city_temperature".to_string(), - // arguments: "{\"city\": \"Paris\"}".to_string(), - // }, - // }])); expect!(response.usage.prompt_tokens).to_be_greater_than(0); expect!(response.usage.completion_tokens).to_be_greater_than(0); diff --git a/tests/v1_client_chat_stream_test.rs b/tests/v1_client_chat_stream_test.rs index 23379fa..aae04c8 100644 --- a/tests/v1_client_chat_stream_test.rs +++ b/tests/v1_client_chat_stream_test.rs @@ -1,16 +1,18 @@ +// Streaming tests require a live API key and are not run in CI. +// Uncomment to test locally. + // use futures::stream::StreamExt; -// use jrest::expect; // use mistralai_client::v1::{ -// chat_completion::{ChatParams, ChatMessage, ChatMessageRole}, +// chat::{ChatMessage, ChatParams}, // client::Client, // constants::Model, // }; - +// // #[tokio::test] // async fn test_client_chat_stream() { // let client = Client::new(None, None, None, None).unwrap(); - -// let model = Model::OpenMistral7b; +// +// let model = Model::mistral_small_latest(); // let messages = vec![ChatMessage::new_user_message( // "Just guess the next word: \"Eiffel ...\"?", // )]; @@ -19,22 +21,24 @@ // random_seed: Some(42), // ..Default::default() // }; - -// let stream_result = client.chat_stream(model, messages, Some(options)).await; -// let mut stream = stream_result.expect("Failed to create stream."); -// while let Some(maybe_chunk_result) = stream.next().await { -// match maybe_chunk_result { -// Some(Ok(chunk)) => { -// if chunk.choices[0].delta.role == Some(ChatMessageRole::Assistant) -// || chunk.choices[0].finish_reason == Some("stop".to_string()) -// { -// expect!(chunk.choices[0].delta.content.len()).to_be(0); -// } else { -// expect!(chunk.choices[0].delta.content.len()).to_be_greater_than(0); +// +// let stream = client +// .chat_stream(model, messages, Some(options)) +// .await +// .expect("Failed to create stream."); +// +// stream +// .for_each(|chunk_result| async { +// match chunk_result { +// Ok(chunks) => { +// for chunk in &chunks { +// if let Some(content) = &chunk.choices[0].delta.content { +// print!("{}", content); +// } +// } // } +// Err(error) => eprintln!("Error: {:?}", error), // } -// Some(Err(error)) => eprintln!("Error processing chunk: {:?}", error), -// None => (), -// } -// } +// }) +// .await; // } diff --git a/tests/v1_client_chat_test.rs b/tests/v1_client_chat_test.rs index e489591..1ecf769 100644 --- a/tests/v1_client_chat_test.rs +++ b/tests/v1_client_chat_test.rs @@ -3,7 +3,7 @@ use mistralai_client::v1::{ chat::{ChatMessage, ChatMessageRole, ChatParams, ChatResponseChoiceFinishReason}, client::Client, constants::Model, - tool::{Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType}, + tool::{Tool, ToolChoice}, }; mod setup; @@ -14,19 +14,18 @@ fn test_client_chat() { let client = Client::new(None, None, None, None).unwrap(); - let model = Model::OpenMistral7b; + let model = Model::mistral_small_latest(); let messages = vec![ChatMessage::new_user_message( "Guess the next word: \"Eiffel ...\"?", )]; let options = ChatParams { - temperature: 0.0, + temperature: Some(0.0), random_seed: Some(42), ..Default::default() }; let response = client.chat(model, messages, Some(options)).unwrap(); - expect!(response.model).to_be(Model::OpenMistral7b); expect!(response.object).to_be("chat.completion".to_string()); expect!(response.choices.len()).to_be(1); expect!(response.choices[0].index).to_be(0); @@ -50,21 +49,26 @@ fn test_client_chat_with_function_calling() { let tools = vec![Tool::new( "get_city_temperature".to_string(), "Get the current temperature in a city.".to_string(), - vec![ToolFunctionParameter::new( - "city".to_string(), - "The name of the city.".to_string(), - ToolFunctionParameterType::String, - )], + serde_json::json!({ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The name of the city." + } + }, + "required": ["city"] + }), )]; let client = Client::new(None, None, None, None).unwrap(); - let model = Model::MistralSmallLatest; + let model = Model::mistral_small_latest(); let messages = vec![ChatMessage::new_user_message( "What's the current temperature in Paris?", )]; let options = ChatParams { - temperature: 0.0, + temperature: Some(0.0), random_seed: Some(42), tool_choice: Some(ToolChoice::Auto), tools: Some(tools), @@ -73,12 +77,10 @@ fn test_client_chat_with_function_calling() { let response = client.chat(model, messages, Some(options)).unwrap(); - expect!(response.model).to_be(Model::MistralSmallLatest); expect!(response.object).to_be("chat.completion".to_string()); expect!(response.choices.len()).to_be(1); expect!(response.choices[0].index).to_be(0); expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::Assistant); - expect!(response.choices[0].message.content.clone()).to_be("".to_string()); expect!(response.choices[0].finish_reason.clone()) .to_be(ChatResponseChoiceFinishReason::ToolCalls); expect!(response.usage.prompt_tokens).to_be_greater_than(0); diff --git a/tests/v1_client_embeddings_async_test.rs b/tests/v1_client_embeddings_async_test.rs index ad0c689..9fd972f 100644 --- a/tests/v1_client_embeddings_async_test.rs +++ b/tests/v1_client_embeddings_async_test.rs @@ -1,11 +1,11 @@ use jrest::expect; -use mistralai_client::v1::{client::Client, constants::EmbedModel}; +use mistralai_client::v1::{client::Client, constants::Model}; #[tokio::test] async fn test_client_embeddings_async() { let client: Client = Client::new(None, None, None, None).unwrap(); - let model = EmbedModel::MistralEmbed; + let model = Model::mistral_embed(); let input = vec!["Embed this sentence.", "As well as this one."] .iter() .map(|s| s.to_string()) @@ -17,7 +17,6 @@ async fn test_client_embeddings_async() { .await .unwrap(); - expect!(response.model).to_be(EmbedModel::MistralEmbed); expect!(response.object).to_be("list".to_string()); expect!(response.data.len()).to_be(2); expect!(response.data[0].index).to_be(0); diff --git a/tests/v1_client_embeddings_test.rs b/tests/v1_client_embeddings_test.rs index bb32fa4..d4c2a80 100644 --- a/tests/v1_client_embeddings_test.rs +++ b/tests/v1_client_embeddings_test.rs @@ -1,11 +1,11 @@ use jrest::expect; -use mistralai_client::v1::{client::Client, constants::EmbedModel}; +use mistralai_client::v1::{client::Client, constants::Model}; #[test] fn test_client_embeddings() { let client: Client = Client::new(None, None, None, None).unwrap(); - let model = EmbedModel::MistralEmbed; + let model = Model::mistral_embed(); let input = vec!["Embed this sentence.", "As well as this one."] .iter() .map(|s| s.to_string()) @@ -14,7 +14,6 @@ fn test_client_embeddings() { let response = client.embeddings(model, input, options).unwrap(); - expect!(response.model).to_be(EmbedModel::MistralEmbed); expect!(response.object).to_be("list".to_string()); expect!(response.data.len()).to_be(2); expect!(response.data[0].index).to_be(0); diff --git a/tests/v1_constants_test.rs b/tests/v1_constants_test.rs index 7903acf..7039af4 100644 --- a/tests/v1_constants_test.rs +++ b/tests/v1_constants_test.rs @@ -6,26 +6,19 @@ use mistralai_client::v1::{ }; #[test] -fn test_model_constant() { +fn test_model_constants() { let models = vec![ - Model::OpenMistral7b, - Model::OpenMixtral8x7b, - Model::OpenMixtral8x22b, - Model::OpenMistralNemo, - Model::MistralTiny, - Model::MistralSmallLatest, - Model::MistralMediumLatest, - Model::MistralLargeLatest, - Model::MistralLarge, - Model::CodestralLatest, - Model::CodestralMamba, + Model::mistral_small_latest(), + Model::mistral_large_latest(), + Model::open_mistral_nemo(), + Model::codestral_latest(), ]; let client = Client::new(None, None, None, None).unwrap(); let messages = vec![ChatMessage::new_user_message("A number between 0 and 100?")]; let options = ChatParams { - temperature: 0.0, + temperature: Some(0.0), random_seed: Some(42), ..Default::default() };