Files
mistralai-client-rs/examples/chat_with_function_calling.rs

74 lines
2.1 KiB
Rust
Raw Normal View History

use mistralai_client::v1::{
chat::{ChatMessage, ChatParams},
client::Client,
constants::Model,
tool::{Function, Tool, ToolChoice},
};
use serde::Deserialize;
use std::any::Any;
/// Arguments the model sends (as a JSON object) when it calls the
/// `get_city_temperature` tool.
#[derive(Deserialize, Debug)]
struct GetCityTemperatureArguments {
    /// Name of the city whose temperature is requested.
    city: String,
}
struct GetCityTemperatureFunction;
#[async_trait::async_trait]
impl Function for GetCityTemperatureFunction {
async fn execute(&self, arguments: String) -> Box<dyn Any + Send> {
let GetCityTemperatureArguments { city } = serde_json::from_str(&arguments).unwrap();
let temperature = match city.as_str() {
"Paris" => "20°C",
_ => "Unknown city",
};
Box::new(temperature.to_string())
}
}
/// Asks the model for the temperature in Paris and lets it answer by calling the
/// locally registered `get_city_temperature` tool, then prints that tool's result.
fn main() {
    // Declare the tool to the model: its name, description, and a JSON Schema for its arguments.
    let tools = vec![Tool::new(
        "get_city_temperature".to_string(),
        "Get the current temperature in a city.".to_string(),
        serde_json::json!({
            "type": "object",
            "properties": {
                "city": {
                    "type": "string",
                    "description": "The name of the city."
                }
            },
            "required": ["city"]
        }),
    )];

    // This example assumes you have set the `MISTRAL_API_KEY` environment variable.
    let mut client = Client::new(None, None, None, None)
        .expect("failed to create the Mistral client (is `MISTRAL_API_KEY` set?)");
    // Register the Rust implementation under the same name as the declared tool so the
    // client can map the model's tool call to it.
    client.register_function(
        "get_city_temperature".to_string(),
        Box::new(GetCityTemperatureFunction),
    );

    let model = Model::mistral_small_latest();
    let messages = vec![ChatMessage::new_user_message(
        "What's the temperature in Paris?",
    )];
    let options = ChatParams {
        // Zero temperature plus a fixed seed keep the example's output reproducible.
        temperature: Some(0.0),
        random_seed: Some(42),
        // Let the model decide whether to call the tool.
        tool_choice: Some(ToolChoice::Auto),
        tools: Some(tools),
        ..Default::default()
    };

    client
        .chat(model, messages, Some(options))
        .expect("chat request failed");

    let temperature = client
        .get_last_function_call_result()
        .expect("no function call result (the model did not call `get_city_temperature`)")
        .downcast::<String>()
        .expect("`get_city_temperature` is expected to return a `String`");
    println!("The temperature in Paris is: {}.", temperature);
    // => "The temperature in Paris is: 20°C."
}