Files
mistralai-client-rs/src/v1/tool.rs
Federico G. Schwindt 8e9f7a5386 feat: mark Function trait as Send (#12)
Allow the Client to be used in places that must also be Send.
2024-07-24 20:16:11 +02:00

145 lines
3.9 KiB
Rust

use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::{any::Any, collections::HashMap, fmt::Debug};
// -----------------------------------------------------------------------------
// Definitions
/// A tool call, wrapping the function that is to be invoked.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolCall {
    /// Name and raw arguments of the function being called.
    pub function: ToolCallFunction,
}
/// The function part of a [`ToolCall`]: which function, and with what arguments.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolCallFunction {
    /// Name of the function to call.
    pub name: String,
    /// Arguments kept as an unparsed string (presumably JSON-encoded — confirm
    /// against the API); this is the same type `Function::execute` accepts.
    pub arguments: String,
}
/// A tool definition that can be serialized into a request.
///
/// Use [`Tool::new`] to build one from a name, a description and a parameter list.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
    /// Kind of tool; serialized under the key `type`.
    pub r#type: ToolType,
    /// Name, description and parameter schema of the function.
    pub function: ToolFunction,
}
impl Tool {
    /// Builds a `function`-type tool from a name, a description and its parameters.
    ///
    /// Each parameter becomes an entry of the `properties` map (keyed by parameter
    /// name) of an `object` schema, and every parameter is listed in `required`.
    ///
    /// `required` preserves the order in which the parameters were supplied, so the
    /// serialized tool is reproducible across runs. Deriving it from the keys of a
    /// `HashMap` gives an order that varies from process to process.
    ///
    /// When several parameters share a name:
    /// - the last one's type and description win;
    /// - the name is listed once, at the position of its first occurrence.
    pub fn new(
        function_name: String,
        function_description: String,
        function_parameters: Vec<ToolFunctionParameter>,
    ) -> Self {
        let mut properties: HashMap<String, ToolFunctionParameterProperty> =
            HashMap::with_capacity(function_parameters.len());
        let mut required: Vec<String> = Vec::with_capacity(function_parameters.len());

        for param in function_parameters {
            let property = ToolFunctionParameterProperty {
                r#type: param.r#type,
                description: param.description,
            };
            // `insert` returns `Some` for a repeated name. The property is replaced
            // (same as collecting into a map), but the name is not listed twice.
            if properties.insert(param.name.clone(), property).is_none() {
                required.push(param.name);
            }
        }

        let parameters = ToolFunctionParameters {
            r#type: ToolFunctionParametersType::Object,
            properties,
            required,
        };

        Self {
            r#type: ToolType::Function,
            function: ToolFunction {
                name: function_name,
                description: function_description,
                parameters,
            },
        }
    }
}
// -----------------------------------------------------------------------------
// Request
/// The function half of a [`Tool`]: its name, description and parameter schema.
///
/// Fields are private; instances are built through [`Tool::new`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolFunction {
    /// Function name.
    name: String,
    /// Free-text description of the function.
    description: String,
    /// Object schema describing the accepted parameters.
    parameters: ToolFunctionParameters,
}
/// One parameter of a function, as supplied to [`Tool::new`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolFunctionParameter {
    /// Parameter name; becomes the key in the `properties` schema map.
    name: String,
    /// Free-text description of the parameter.
    description: String,
    /// Parameter type; serialized under the key `type`.
    r#type: ToolFunctionParameterType,
}
impl ToolFunctionParameter {
pub fn new(name: String, description: String, r#type: ToolFunctionParameterType) -> Self {
Self {
name,
r#type,
description,
}
}
}
/// Object-style parameter schema of a [`ToolFunction`].
///
/// The `properties` map is keyed by parameter name. As a `HashMap`, its iteration
/// (and therefore serialization) order is unspecified.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolFunctionParameters {
    /// Schema type; serialized under the key `type`.
    r#type: ToolFunctionParametersType,
    /// Per-parameter type and description, keyed by parameter name.
    properties: HashMap<String, ToolFunctionParameterProperty>,
    /// Names of the parameters that are required.
    required: Vec<String>,
}
/// Schema entry for a single parameter inside [`ToolFunctionParameters`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolFunctionParameterProperty {
    /// Parameter type; serialized under the key `type`.
    r#type: ToolFunctionParameterType,
    /// Free-text description of the parameter.
    description: String,
}
/// Schema type of a parameter list; the only supported value is `"object"`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolFunctionParametersType {
    /// Serialized as `"object"`.
    Object,
}
/// Type of a single function parameter; the only supported value is `"string"`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolFunctionParameterType {
    /// Serialized as `"string"`.
    String,
}
/// Kind of a [`Tool`]; the only supported value is `"function"`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    /// Serialized as `"function"`.
    Function,
}
/// Controls whether, and how, the model should call functions.
///
/// Variants serialize as their lowercase names: `"any"`, `"auto"`, `"none"`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    /// The model must call a function.
    Any,
    /// The model decides whether to reply with a message or call a function.
    Auto,
    /// The model does not call any function and replies with a message instead.
    None,
}
// -----------------------------------------------------------------------------
// Custom
/// An asynchronously executed function backing a tool.
///
/// `execute` receives a raw arguments string and returns a type-erased result.
/// The arguments are presumably the `arguments` of a [`ToolCallFunction`] — confirm
/// against the caller. The `Send` supertrait bound makes `dyn Function` itself
/// `Send`, so types that store such trait objects can be `Send` as well.
#[async_trait]
pub trait Function
where
    Self: Send,
{
    /// Runs the function on the given raw arguments string.
    async fn execute(&self, arguments: String) -> Box<dyn Any + Send>;
}
/// Gives `dyn Function` trait objects a `Debug` representation.
///
/// This lets containers holding them implement or derive `Debug`. The function
/// itself is opaque, so a fixed placeholder is printed.
impl Debug for dyn Function {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str("Function()")
    }
}