Files
mistralai-client-rs/examples/chat_with_function_calling_async.rs
Ivan Gabriele 74bf8a96ee feat!: add function calling support to client.chat() & client.chat_async()
BREAKING CHANGE: Too many to count in this version. Check the README examples.
2024-03-09 11:40:07 +01:00

76 lines
2.2 KiB
Rust

use mistralai_client::v1::{
chat::{ChatMessage, ChatMessageRole, ChatParams},
client::Client,
constants::Model,
tool::{Function, Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType},
};
use serde::Deserialize;
use std::any::Any;
/// Arguments for the `get_city_temperature` tool.
///
/// Built by deserializing the JSON argument string handed to
/// `GetCityTemperatureFunction::execute`.
#[derive(Deserialize, Debug)]
struct GetCityTemperatureArguments {
    /// Name of the city whose temperature is requested.
    city: String,
}

/// Implementation of the `get_city_temperature` tool; registered with the client in `main`.
struct GetCityTemperatureFunction;
#[async_trait::async_trait]
impl Function for GetCityTemperatureFunction {
async fn execute(&self, arguments: String) -> Box<dyn Any + Send> {
// Deserialize arguments, perform the logic, and return the result
let GetCityTemperatureArguments { city } = serde_json::from_str(&arguments).unwrap();
let temperature = match city.as_str() {
"Paris" => "20°C",
_ => "Unknown city",
};
Box::new(temperature.to_string())
}
}
#[tokio::main]
async fn main() {
    // Declare the single tool the model may call: it takes one string parameter, `city`.
    let tools = vec![Tool::new(
        "get_city_temperature".to_string(),
        "Get the current temperature in a city.".to_string(),
        vec![ToolFunctionParameter::new(
            "city".to_string(),
            "The name of the city.".to_string(),
            ToolFunctionParameterType::String,
        )],
    )];

    // This example assumes you have set the `MISTRAL_API_KEY` environment variable.
    let mut client = Client::new(None, None, None, None)
        .expect("failed to create the Mistral client (is `MISTRAL_API_KEY` set?)");
    // Registered under the same name as the tool declared above.
    client.register_function(
        "get_city_temperature".to_string(),
        Box::new(GetCityTemperatureFunction),
    );

    let model = Model::MistralSmallLatest;
    let messages = vec![ChatMessage {
        role: ChatMessageRole::User,
        content: "What's the temperature in Paris?".to_string(),
        tool_calls: None,
    }];
    // Zero temperature and a fixed seed keep the example's output reproducible; `Auto`
    // leaves it to the model whether to call the tool.
    let options = ChatParams {
        temperature: Some(0.0),
        random_seed: Some(42),
        tool_choice: Some(ToolChoice::Auto),
        tools: Some(tools),
        ..Default::default()
    };

    client
        .chat_async(model, messages, Some(options))
        .await
        .expect("the chat completion request failed");

    // `GetCityTemperatureFunction` always boxes a `String`, so the downcast is expected
    // to succeed whenever a function-call result is available.
    let temperature = client
        .get_last_function_call_result()
        .expect("no function call result is available (did the model call `get_city_temperature`?)")
        .downcast::<String>()
        .expect("`GetCityTemperatureFunction` should return a `String`");
    println!("The temperature in Paris is: {}.", temperature);
    // => "The temperature in Paris is: 20°C."
}