Compare commits
47 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
4c7f1cde0a
|
|||
|
a29c3c0109
|
|||
|
d5eb16dffc
|
|||
|
4dee5cfe0f
|
|||
|
ab2ecec351
|
|||
|
5e6fac99a9
|
|||
|
63f0edf574
|
|||
|
ce96bcfeeb
|
|||
|
4d6eca62ef
|
|||
|
bbb6aaed1c
|
|||
|
83396773ce
|
|||
|
|
9ad6a1dc84 | ||
|
|
7c464830ee | ||
|
|
161b33c725 | ||
|
|
79a410b298 | ||
|
|
4a45ad337f | ||
|
|
2114916941 | ||
|
|
9bfbf2e9dc | ||
|
|
67aa5bbaef | ||
|
|
415fd98167 | ||
|
|
8e9f7a5386 | ||
|
|
3afeec1d58 | ||
|
|
0c097aa56d | ||
|
|
e6539c0ccf | ||
|
|
30156c5273 | ||
|
|
ecd0c3028f | ||
|
|
0df67b1b25 | ||
|
|
f7d012b280 | ||
|
|
5b5bd2d68e | ||
|
|
2fc0642a5e | ||
|
|
cf68a77320 | ||
|
|
e61ace9a18 | ||
|
|
64034402ca | ||
|
|
85c3611afb | ||
|
|
da5fe54115 | ||
|
|
7a5e0679c1 | ||
|
|
99d9d099e2 | ||
|
|
91fb775132 | ||
|
|
7474aa6730 | ||
|
|
6a99eca49c | ||
|
|
fccd59c0cc | ||
|
|
a463cb3106 | ||
|
|
8bee874bd4 | ||
|
|
16464a4c3d | ||
|
|
a4c2d4623d | ||
|
|
ab91154d35 | ||
|
|
74bf8a96ee |
2
.cargo/config.toml
Normal file
2
.cargo/config.toml
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
[registries.sunbeam]
|
||||||
|
index = "sparse+https://src.sunbeam.pt/api/packages/studio/cargo/"
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
# This key is only used for development purposes.
|
|
||||||
# You'll only need one if you want to contribute to this library.
|
|
||||||
export MISTRAL_API_KEY=
|
|
||||||
35
.github/ISSUE_TEMPLATE/bug_report.md
vendored
35
.github/ISSUE_TEMPLATE/bug_report.md
vendored
@@ -1,35 +0,0 @@
|
|||||||
---
|
|
||||||
name: Bug report
|
|
||||||
about: Create a report to help us improve
|
|
||||||
title: ''
|
|
||||||
labels: ''
|
|
||||||
assignees: ''
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
**Describe the bug**
|
|
||||||
|
|
||||||
A clear and concise description of what the bug is.
|
|
||||||
|
|
||||||
**To Reproduce**
|
|
||||||
|
|
||||||
Steps to reproduce the behavior:
|
|
||||||
|
|
||||||
1. ...
|
|
||||||
2. ...
|
|
||||||
|
|
||||||
**Expected behavior**
|
|
||||||
|
|
||||||
A clear and concise description of what you expected to happen.
|
|
||||||
|
|
||||||
**Screenshots**
|
|
||||||
|
|
||||||
If applicable, add screenshots to help explain your problem.
|
|
||||||
|
|
||||||
**Version**
|
|
||||||
|
|
||||||
If applicable, what version did you use?
|
|
||||||
|
|
||||||
**Environment**
|
|
||||||
|
|
||||||
Add useful information about your configuration and environment here.
|
|
||||||
24
.github/ISSUE_TEMPLATE/feature_request.md
vendored
24
.github/ISSUE_TEMPLATE/feature_request.md
vendored
@@ -1,24 +0,0 @@
|
|||||||
---
|
|
||||||
name: Feature request
|
|
||||||
about: Suggest an idea for this project
|
|
||||||
title: ''
|
|
||||||
labels: ''
|
|
||||||
assignees: ''
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
**Is your feature request related to a problem? Please describe.**
|
|
||||||
|
|
||||||
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
|
|
||||||
|
|
||||||
**Describe the solution you'd like**
|
|
||||||
|
|
||||||
A clear and concise description of what you want to happen.
|
|
||||||
|
|
||||||
**Describe alternatives you've considered**
|
|
||||||
|
|
||||||
A clear and concise description of any alternative solutions or features you've considered.
|
|
||||||
|
|
||||||
**Additional context**
|
|
||||||
|
|
||||||
Add any other context or screenshots about the feature request here.
|
|
||||||
8
.github/pull_request_template.md
vendored
8
.github/pull_request_template.md
vendored
@@ -1,8 +0,0 @@
|
|||||||
## Description
|
|
||||||
|
|
||||||
A clear and concise description of what your pull request is about.
|
|
||||||
|
|
||||||
## Checklist
|
|
||||||
|
|
||||||
- [ ] I updated the documentation accordingly. Or I don't need to.
|
|
||||||
- [ ] I updated the tests accordingly. Or I don't need to.
|
|
||||||
3
.github/renovate.json
vendored
3
.github/renovate.json
vendored
@@ -1,3 +0,0 @@
|
|||||||
{
|
|
||||||
"extends": ["github>ivangabriele/renovate-config"]
|
|
||||||
}
|
|
||||||
42
.github/workflows/test.yml
vendored
42
.github/workflows/test.yml
vendored
@@ -1,42 +0,0 @@
|
|||||||
name: Test
|
|
||||||
|
|
||||||
on: push
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
test:
|
|
||||||
name: Test
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- name: Checkout
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
- name: Setup Rust
|
|
||||||
uses: actions-rs/toolchain@v1
|
|
||||||
with:
|
|
||||||
toolchain: 1.76.0
|
|
||||||
- name: Install cargo-llvm-cov
|
|
||||||
uses: taiki-e/install-action@cargo-llvm-cov
|
|
||||||
- name: Run tests (with coverage)
|
|
||||||
run: cargo llvm-cov --lcov --output-path ./lcov.info
|
|
||||||
env:
|
|
||||||
MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }}
|
|
||||||
- name: Upload tests coverage
|
|
||||||
uses: codecov/codecov-action@v4
|
|
||||||
with:
|
|
||||||
fail_ci_if_error: true
|
|
||||||
files: ./lcov.info
|
|
||||||
token: ${{ secrets.CODECOV_TOKEN }}
|
|
||||||
|
|
||||||
test_documentation:
|
|
||||||
name: Test Documentation
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- name: Checkout
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
- name: Setup Rust
|
|
||||||
uses: actions-rs/toolchain@v1
|
|
||||||
with:
|
|
||||||
toolchain: 1.76.0
|
|
||||||
- name: Run documentation tests
|
|
||||||
run: make test-doc
|
|
||||||
env:
|
|
||||||
MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }}
|
|
||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -23,3 +23,4 @@ Cargo.lock
|
|||||||
/cobertura.xml
|
/cobertura.xml
|
||||||
|
|
||||||
.env
|
.env
|
||||||
|
.envrc
|
||||||
|
|||||||
104
CHANGELOG.md
104
CHANGELOG.md
@@ -1,3 +1,107 @@
|
|||||||
|
# Changelog
|
||||||
|
|
||||||
|
## [1.0.0](https://src.sunbeam.pt/studio/mistralai-client-rs) (2026-03-20)
|
||||||
|
|
||||||
|
Forked from [ivangabriele/mistralai-client-rs](https://github.com/ivangabriele/mistralai-client-rs) v0.14.0 and updated to the latest Mistral AI API.
|
||||||
|
|
||||||
|
### ⚠ BREAKING CHANGES
|
||||||
|
|
||||||
|
* `Model` is now a string-based struct with constructor methods instead of a closed enum
|
||||||
|
* `EmbedModel` is removed — use `Model::mistral_embed()` instead
|
||||||
|
* `Tool::new()` parameters now accept `serde_json::Value` (JSON Schema) instead of limited enum types
|
||||||
|
* `ChatParams.temperature` is now `Option<f32>` instead of `f32`
|
||||||
|
* Stream delta `content` is now `Option<String>`
|
||||||
|
|
||||||
|
### Features
|
||||||
|
|
||||||
|
* Add all current Mistral models (Large 3, Small 4, Medium 3.1, Magistral, Codestral, Devstral, Pixtral, Voxtral, Ministral)
|
||||||
|
* Add FIM (fill-in-the-middle) completions endpoint
|
||||||
|
* Add Files API (upload, list, get, delete, download URL)
|
||||||
|
* Add Fine-tuning jobs API (create, list, get, cancel, start)
|
||||||
|
* Add Batch jobs API (create, list, get, cancel)
|
||||||
|
* Add OCR endpoint (document text extraction)
|
||||||
|
* Add Audio transcription endpoint
|
||||||
|
* Add Moderations and Classifications endpoints
|
||||||
|
* Add Agent completions endpoint
|
||||||
|
* Add new chat fields: frequency_penalty, presence_penalty, stop, n, min_tokens, parallel_tool_calls, reasoning_effort, json_schema response format
|
||||||
|
* Add embedding fields: output_dimension, output_dtype
|
||||||
|
* Add tool call IDs and Required tool choice variant
|
||||||
|
* Add model get and delete endpoints
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Upstream Changelog (pre-fork)
|
||||||
|
|
||||||
|
## [0.14.0](https://github.com/ivangabriele/mistralai-client-rs/compare/v0.13.0...v) (2024-08-27)
|
||||||
|
|
||||||
|
### Features
|
||||||
|
|
||||||
|
* **constants:** update model constants ([#17](https://github.com/ivangabriele/mistralai-client-rs/issues/17)) ([161b33c](https://github.com/ivangabriele/mistralai-client-rs/commit/161b33c72539a6e982207349942a436df95399b7))
|
||||||
|
## [0.13.0](https://github.com/ivangabriele/mistralai-client-rs/compare/v0.12.0...v) (2024-08-21)
|
||||||
|
|
||||||
|
### ⚠ BREAKING CHANGES
|
||||||
|
|
||||||
|
* **client:** `v1::model_list::ModelListData` struct has been updated.
|
||||||
|
|
||||||
|
### Bug Fixes
|
||||||
|
|
||||||
|
* **client:** remove the `Content-Type` from the headers of the reqwest builders. ([#14](https://github.com/ivangabriele/mistralai-client-rs/issues/14)) ([9bfbf2e](https://github.com/ivangabriele/mistralai-client-rs/commit/9bfbf2e9dc7b48103ac56923fb8b3ac9a5e2d9cf)), closes [#13](https://github.com/ivangabriele/mistralai-client-rs/issues/13)
|
||||||
|
* **client:** update ModelListData struct following API changes ([2114916](https://github.com/ivangabriele/mistralai-client-rs/commit/2114916941e1ff5aa242290df5f092c0d4954afc))
|
||||||
|
## [0.12.0](https://github.com/ivangabriele/mistralai-client-rs/compare/v0.11.0...v) (2024-07-24)
|
||||||
|
|
||||||
|
### Features
|
||||||
|
|
||||||
|
* implement the Debug trait for Client ([#11](https://github.com/ivangabriele/mistralai-client-rs/issues/11)) ([3afeec1](https://github.com/ivangabriele/mistralai-client-rs/commit/3afeec1d586022e43c7b10906acec5e65927ba7d))
|
||||||
|
* mark Function trait as Send ([#12](https://github.com/ivangabriele/mistralai-client-rs/issues/12)) ([8e9f7a5](https://github.com/ivangabriele/mistralai-client-rs/commit/8e9f7a53863879b2ad618e9e5707b198e4f3b135))
|
||||||
|
## [0.11.0](https://github.com/ivangabriele/mistralai-client-rs/compare/v0.10.0...v) (2024-06-22)
|
||||||
|
|
||||||
|
### Features
|
||||||
|
|
||||||
|
* **constants:** add OpenMixtral8x22b, MistralTiny & CodestralLatest to Model enum ([ecd0c30](https://github.com/ivangabriele/mistralai-client-rs/commit/ecd0c3028fdcfab32b867eb1eed86182f5f4ab81))
|
||||||
|
|
||||||
|
### Bug Fixes
|
||||||
|
|
||||||
|
* **chat:** implement Clone trait for ChatParams & ResponseFormat ([0df67b1](https://github.com/ivangabriele/mistralai-client-rs/commit/0df67b1b2571fb04b636ce015a2daabe629ff352))
|
||||||
|
## [0.10.0](https://github.com/ivangabriele/mistralai-client-rs/compare/v0.9.0...v) (2024-06-07)
|
||||||
|
|
||||||
|
### ⚠ BREAKING CHANGES
|
||||||
|
|
||||||
|
* **chat:** - `Chat::ChatParams.safe_prompt` & `Chat::ChatRequest.safe_prompt` are now `bool` instead of `Option<bool>`. Default is `false`.
|
||||||
|
- `Chat::ChatParams.temperature` & `Chat::ChatRequest.temperature` are now `f32` instead of `Option<f32>`. Default is `0.7`.
|
||||||
|
- `Chat::ChatParams.top_p` & `Chat::ChatRequest.top_p` are now `f32` instead of `Option<f32>`. Default is `1.0`.
|
||||||
|
|
||||||
|
### Features
|
||||||
|
|
||||||
|
* **chat:** add response_format for JSON return values ([85c3611](https://github.com/ivangabriele/mistralai-client-rs/commit/85c3611afbbe8df30dfc7512cc381ed304ce4024))
|
||||||
|
* **chat:** add the 'system' and 'tool' message roles ([#10](https://github.com/ivangabriele/mistralai-client-rs/issues/10)) ([2fc0642](https://github.com/ivangabriele/mistralai-client-rs/commit/2fc0642a5e4c024b15710acaab7735480e8dfe6a))
|
||||||
|
* **chat:** change safe_prompt, temperature & top_p to non-Option types ([cf68a77](https://github.com/ivangabriele/mistralai-client-rs/commit/cf68a773201ebe0e802face52af388711acf0c27))
|
||||||
|
|
||||||
|
### Bug Fixes
|
||||||
|
|
||||||
|
* **chat:** skip serializing tool_calls if null, to avoid 422 error ([da5fe54](https://github.com/ivangabriele/mistralai-client-rs/commit/da5fe54115ce622379776661a440e2708b24810c))
|
||||||
|
## [0.9.0](https://github.com/ivangabriele/mistralai-client-rs/compare/v0.8.0...v) (2024-04-13)
|
||||||
|
|
||||||
|
|
||||||
|
### ⚠ BREAKING CHANGES
|
||||||
|
|
||||||
|
* `Model.OpenMistral8x7b` has been renamed to `Model.OpenMixtral8x7b`.
|
||||||
|
|
||||||
|
### Bug Fixes
|
||||||
|
|
||||||
|
* **deps:** update rust crate reqwest to 0.12.0 ([#6](https://github.com/ivangabriele/mistralai-client-rs/issues/6)) ([fccd59c](https://github.com/ivangabriele/mistralai-client-rs/commit/fccd59c0cc783edddec1b404363faabb009eecd6))
|
||||||
|
* fix typo in OpenMixtral8x7b model name ([#8](https://github.com/ivangabriele/mistralai-client-rs/issues/8)) ([6a99eca](https://github.com/ivangabriele/mistralai-client-rs/commit/6a99eca49c0cc8e3764a56f6dfd7762ec44a4c3b))
|
||||||
|
|
||||||
|
## [0.8.0](https://github.com/ivangabriele/mistralai-client-rs/compare/v0.7.0...v) (2024-03-09)
|
||||||
|
|
||||||
|
|
||||||
|
### ⚠ BREAKING CHANGES
|
||||||
|
|
||||||
|
* Too many to count in this version. Check the README examples.
|
||||||
|
|
||||||
|
### Features
|
||||||
|
|
||||||
|
* add function calling support to client.chat() & client.chat_async() ([74bf8a9](https://github.com/ivangabriele/mistralai-client-rs/commit/74bf8a96ee31f9d54ee3d7404619e803a182918b))
|
||||||
|
|
||||||
## [0.7.0](https://github.com/ivangabriele/mistralai-client-rs/compare/v0.6.0...v) (2024-03-05)
|
## [0.7.0](https://github.com/ivangabriele/mistralai-client-rs/compare/v0.6.0...v) (2024-03-05)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
325
CONTRIBUTING.md
325
CONTRIBUTING.md
@@ -1,62 +1,277 @@
|
|||||||
# Contribute
|
# Contributing to mistralai-client-rs
|
||||||
|
|
||||||
|
Thank you for your interest in contributing! We're excited to work with you.
|
||||||
|
|
||||||
|
This document provides guidelines for contributing to the project. Following these guidelines helps maintain code quality and makes the review process smoother for everyone.
|
||||||
|
|
||||||
|
## Table of Contents
|
||||||
|
|
||||||
- [Getting Started](#getting-started)
|
|
||||||
- [Requirements](#requirements)
|
|
||||||
- [First setup](#first-setup)
|
|
||||||
- [Optional requirements](#optional-requirements)
|
|
||||||
- [Test](#test)
|
|
||||||
- [Code of Conduct](#code-of-conduct)
|
- [Code of Conduct](#code-of-conduct)
|
||||||
- [Commit Message Format](#commit-message-format)
|
- [Getting Started](#getting-started)
|
||||||
|
- [Development Environment Setup](#development-environment-setup)
|
||||||
---
|
- [How to Contribute](#how-to-contribute)
|
||||||
|
- [Coding Standards](#coding-standards)
|
||||||
## Getting Started
|
- [Testing](#testing)
|
||||||
|
- [Pull Request Process](#pull-request-process)
|
||||||
### Requirements
|
- [Reporting Bugs](#reporting-bugs)
|
||||||
|
- [Suggesting Features](#suggesting-features)
|
||||||
- [Rust](https://www.rust-lang.org/tools/install): v1
|
- [AI Usage Policy](#ai-usage-policy)
|
||||||
|
- [Questions?](#questions)
|
||||||
### First setup
|
|
||||||
|
|
||||||
> [!IMPORTANT]
|
|
||||||
> If you're under **Windows**, you must run all CLI commands under a Linux shell-like terminal (e.g., WSL or Git Bash).
|
|
||||||
|
|
||||||
Then run:
|
|
||||||
|
|
||||||
```sh
|
|
||||||
git clone https://github.com/ivangabriele/mistralai-client-rs.git # or your fork
|
|
||||||
cd ./mistralai-client-rs
|
|
||||||
cargo build
|
|
||||||
cp .env.example .env
|
|
||||||
```
|
|
||||||
|
|
||||||
Then edit the `.env` file to set your `MISTRAL_API_KEY`.
|
|
||||||
|
|
||||||
> [!NOTE]
|
|
||||||
> All tests use either the `open-mistral-7b` or `mistral-embed` models and only consume a few dozen tokens.
|
|
||||||
> So you would have to run them thousands of times to even reach a single dollar of usage.
|
|
||||||
|
|
||||||
### Optional requirements
|
|
||||||
|
|
||||||
- [cargo-llvm-cov](https://github.com/taiki-e/cargo-llvm-cov?tab=readme-ov-file#installation) for `make test-cover`
|
|
||||||
- [cargo-watch](https://github.com/watchexec/cargo-watch#install) for `make test-watch`.
|
|
||||||
|
|
||||||
### Test
|
|
||||||
|
|
||||||
```sh
|
|
||||||
make test
|
|
||||||
```
|
|
||||||
|
|
||||||
or
|
|
||||||
|
|
||||||
```sh
|
|
||||||
make test-watch
|
|
||||||
```
|
|
||||||
|
|
||||||
## Code of Conduct
|
## Code of Conduct
|
||||||
|
|
||||||
Help us keep this project open and inclusive. Please read and follow our [Code of Conduct](./CODE_OF_CONDUCT.md).
|
This project adheres to the [Contributor Covenant Code of Conduct](CODE_OF_CONDUCT.md). By participating, you are expected to uphold this code. Please report unacceptable behavior to the project maintainers.
|
||||||
|
|
||||||
## Commit Message Format
|
## Getting Started
|
||||||
|
|
||||||
This repository follows the [Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/) specification.
|
1. **Fork the repository** on Gitea
|
||||||
|
2. **Clone your fork** locally
|
||||||
|
3. **Set up your development environment** (see below)
|
||||||
|
4. **Create a branch** for your changes
|
||||||
|
5. **Make your changes** with clear commit messages
|
||||||
|
6. **Test your changes** thoroughly
|
||||||
|
7. **Submit a pull request**
|
||||||
|
|
||||||
|
## Development Environment Setup
|
||||||
|
|
||||||
|
### Prerequisites
|
||||||
|
|
||||||
|
- **Rust** 2024 edition or later (install via [rustup](https://rustup.rs/))
|
||||||
|
- **Git** for version control
|
||||||
|
|
||||||
|
### Initial Setup
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Clone your fork
|
||||||
|
git clone git@src.sunbeam.pt:studio/mistralai-client-rs.git
|
||||||
|
cd mistralai-client-rs
|
||||||
|
|
||||||
|
# Build the project
|
||||||
|
cargo build
|
||||||
|
|
||||||
|
# Run tests (requires MISTRAL_API_KEY)
|
||||||
|
cargo test
|
||||||
|
```
|
||||||
|
|
||||||
|
### Useful Commands
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Check code without building
|
||||||
|
cargo check
|
||||||
|
|
||||||
|
# Run clippy for linting
|
||||||
|
cargo clippy
|
||||||
|
|
||||||
|
# Format code
|
||||||
|
cargo fmt
|
||||||
|
|
||||||
|
# Run tests with output
|
||||||
|
cargo test -- --nocapture
|
||||||
|
|
||||||
|
# Build documentation
|
||||||
|
cargo doc --open
|
||||||
|
```
|
||||||
|
|
||||||
|
### Environment Variables with `.envrc`
|
||||||
|
|
||||||
|
This project uses [direnv](https://direnv.net/) for managing environment variables.
|
||||||
|
|
||||||
|
#### Setup
|
||||||
|
|
||||||
|
1. **Install direnv** (if not already installed):
|
||||||
|
```bash
|
||||||
|
# macOS
|
||||||
|
brew install direnv
|
||||||
|
|
||||||
|
# Add to your shell profile (~/.zshrc or ~/.bashrc)
|
||||||
|
eval "$(direnv hook zsh)" # or bash
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Create `.envrc` file** in the project root:
|
||||||
|
```bash
|
||||||
|
# The .envrc file is already gitignored for security
|
||||||
|
export MISTRAL_API_KEY=your_api_key_here
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **Allow direnv** to load the file:
|
||||||
|
```bash
|
||||||
|
direnv allow .
|
||||||
|
```
|
||||||
|
|
||||||
|
## How to Contribute
|
||||||
|
|
||||||
|
### Types of Contributions
|
||||||
|
|
||||||
|
We welcome many types of contributions:
|
||||||
|
|
||||||
|
- **Bug fixes** - Fix issues and improve stability
|
||||||
|
- **Features** - Implement new functionality (discuss first in an issue)
|
||||||
|
- **Documentation** - Improve or add documentation
|
||||||
|
- **Examples** - Create new examples or demos
|
||||||
|
- **Tests** - Add test coverage
|
||||||
|
- **Performance** - Optimize existing code
|
||||||
|
- **Refactoring** - Improve code quality
|
||||||
|
|
||||||
|
### Before You Start
|
||||||
|
|
||||||
|
For **bug fixes and small improvements**, feel free to open a PR directly.
|
||||||
|
|
||||||
|
For **new features or significant changes**:
|
||||||
|
1. **Open an issue first** to discuss the proposal
|
||||||
|
2. Wait for maintainer feedback before investing significant time
|
||||||
|
3. Reference the issue in your PR
|
||||||
|
|
||||||
|
This helps ensure your work aligns with project direction and avoids duplicate effort.
|
||||||
|
|
||||||
|
## Coding Standards
|
||||||
|
|
||||||
|
### Rust Style
|
||||||
|
|
||||||
|
- Follow the [Rust API Guidelines](https://rust-lang.github.io/api-guidelines/)
|
||||||
|
- Use `cargo fmt` to format code (run before committing)
|
||||||
|
- Address all `cargo clippy` warnings
|
||||||
|
- Use meaningful variable and function names
|
||||||
|
- Add doc comments (`///`) for public APIs
|
||||||
|
|
||||||
|
### Code Organization
|
||||||
|
|
||||||
|
- Keep modules focused and cohesive
|
||||||
|
- Prefer composition over inheritance
|
||||||
|
- Use Rust's type system to enforce invariants
|
||||||
|
- Avoid unnecessary `unsafe` code
|
||||||
|
|
||||||
|
### Documentation
|
||||||
|
|
||||||
|
- Add doc comments for all public types, traits, and functions
|
||||||
|
- Include examples in doc comments when helpful
|
||||||
|
- Keep README.md in sync with current capabilities
|
||||||
|
|
||||||
|
### Commit Messages
|
||||||
|
|
||||||
|
This repository follows the [Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/) specification.
|
||||||
|
|
||||||
|
Write clear, descriptive commit messages:
|
||||||
|
|
||||||
|
```
|
||||||
|
feat: add batch jobs API
|
||||||
|
|
||||||
|
fix: handle empty tool_calls in streaming response
|
||||||
|
|
||||||
|
refactor!: replace Model enum with string-based type
|
||||||
|
|
||||||
|
BREAKING CHANGE: Model is now a struct, not an enum.
|
||||||
|
```
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
### Running Tests
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run all tests (requires MISTRAL_API_KEY)
|
||||||
|
cargo test
|
||||||
|
|
||||||
|
# Run specific test
|
||||||
|
cargo test test_client_chat
|
||||||
|
|
||||||
|
# Run tests with output
|
||||||
|
cargo test -- --nocapture
|
||||||
|
```
|
||||||
|
|
||||||
|
### Writing Tests
|
||||||
|
|
||||||
|
- Add unit tests in the same file as the code (in a `mod tests` block)
|
||||||
|
- Add integration tests in `tests/` directory
|
||||||
|
- Test edge cases and error conditions
|
||||||
|
- Keep tests focused and readable
|
||||||
|
- Use descriptive test names
|
||||||
|
|
||||||
|
### Test Coverage
|
||||||
|
|
||||||
|
Tests hit the live Mistral API, so a valid `MISTRAL_API_KEY` is required. The tests use small, cheap models and consume minimal tokens.
|
||||||
|
|
||||||
|
## Pull Request Process
|
||||||
|
|
||||||
|
### Before Submitting
|
||||||
|
|
||||||
|
1. **Update your branch** with latest upstream changes
|
||||||
|
```bash
|
||||||
|
git fetch origin
|
||||||
|
git rebase origin/main
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Run the test suite** and ensure all tests pass
|
||||||
|
```bash
|
||||||
|
cargo test
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **Run clippy** and fix any warnings
|
||||||
|
```bash
|
||||||
|
cargo clippy
|
||||||
|
```
|
||||||
|
|
||||||
|
4. **Format your code**
|
||||||
|
```bash
|
||||||
|
cargo fmt
|
||||||
|
```
|
||||||
|
|
||||||
|
5. **Update documentation** if you changed APIs or behavior
|
||||||
|
|
||||||
|
### Submitting Your PR
|
||||||
|
|
||||||
|
1. **Push to your fork**
|
||||||
|
2. **Open a pull request** on Gitea
|
||||||
|
3. **Fill out the PR description** with what changed and why
|
||||||
|
4. **Request review** from maintainers
|
||||||
|
|
||||||
|
### During Review
|
||||||
|
|
||||||
|
- Be responsive to feedback
|
||||||
|
- Make requested changes promptly
|
||||||
|
- Push updates to the same branch
|
||||||
|
- Be patient - maintainers are volunteers with limited time
|
||||||
|
|
||||||
|
## Reporting Bugs
|
||||||
|
|
||||||
|
### Before Reporting
|
||||||
|
|
||||||
|
1. **Check existing issues** to avoid duplicates
|
||||||
|
2. **Verify it's a bug** and not expected behavior
|
||||||
|
3. **Test on the latest version** from main branch
|
||||||
|
|
||||||
|
### Bug Report Template
|
||||||
|
|
||||||
|
When opening a bug report, please include:
|
||||||
|
|
||||||
|
- **Description** - What went wrong?
|
||||||
|
- **Expected behavior** - What should have happened?
|
||||||
|
- **Actual behavior** - What actually happened?
|
||||||
|
- **Steps to reproduce** - Minimal steps to reproduce the issue
|
||||||
|
- **Environment**:
|
||||||
|
- OS version
|
||||||
|
- Rust version (`rustc --version`)
|
||||||
|
- Crate version or commit hash
|
||||||
|
- **Logs/Stack traces** - Error messages or relevant log output
|
||||||
|
|
||||||
|
## Suggesting Features
|
||||||
|
|
||||||
|
We welcome feature suggestions! When suggesting a feature, please include:
|
||||||
|
|
||||||
|
- **Problem statement** - What problem does this solve?
|
||||||
|
- **Proposed solution** - How would this feature work?
|
||||||
|
- **Alternatives considered** - What other approaches did you think about?
|
||||||
|
- **Use cases** - Real-world scenarios where this helps
|
||||||
|
|
||||||
|
## AI Usage Policy
|
||||||
|
|
||||||
|
- AI tools (Copilot, ChatGPT, etc.) are allowed for productivity
|
||||||
|
- You must understand and be accountable for all code you submit
|
||||||
|
- Humans make all architectural decisions, not AI
|
||||||
|
- When in doubt, ask yourself: "Can I maintain and debug this?"
|
||||||
|
|
||||||
|
## Questions?
|
||||||
|
|
||||||
|
Open an issue on the repository for any questions about contributing.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
Thank you for contributing! Your effort helps make this library better for everyone.
|
||||||
|
|||||||
30
Cargo.toml
30
Cargo.toml
@@ -2,27 +2,31 @@
|
|||||||
name = "mistralai-client"
|
name = "mistralai-client"
|
||||||
description = "Mistral AI API client library for Rust (unofficial)."
|
description = "Mistral AI API client library for Rust (unofficial)."
|
||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
version = "0.7.0"
|
version = "1.2.0"
|
||||||
|
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
rust-version = "1.76.0"
|
rust-version = "1.76.0"
|
||||||
|
|
||||||
authors = ["Ivan Gabriele <ivan.gabriele@protonmail.com>"]
|
authors = ["Sunbeam Studios <hello@sunbeam.pt>"]
|
||||||
categories = ["api-bindings"]
|
categories = ["api-bindings"]
|
||||||
homepage = "https://github.com/ivangabriele/mistralai-client-rs#readme"
|
homepage = "https://sunbeam.pt"
|
||||||
keywords = ["mistral", "mistralai", "client", "api", "llm"]
|
keywords = ["mistral", "mistralai", "client", "api", "llm"]
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
repository = "https://github.com/ivangabriele/mistralai-client-rs"
|
repository = "https://src.sunbeam.pt/studio/mistralai-client-rs"
|
||||||
|
publish = ["sunbeam"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
futures = "0.3.30"
|
async-stream = "0.3"
|
||||||
reqwest = { version = "0.11.24", features = ["json", "blocking", "stream"] }
|
async-trait = "0.1"
|
||||||
serde = { version = "1.0.197", features = ["derive"] }
|
env_logger = "0.11"
|
||||||
serde_json = "1.0.114"
|
futures = "0.3"
|
||||||
strum = "0.26.1"
|
log = "0.4"
|
||||||
strum_macros = "0.26.1"
|
reqwest = { version = "0.12", features = ["json", "blocking", "stream", "multipart"] }
|
||||||
thiserror = "1.0.57"
|
serde = { version = "1", features = ["derive"] }
|
||||||
tokio = { version = "1.36.0", features = ["full"] }
|
serde_json = "1"
|
||||||
|
thiserror = "2"
|
||||||
|
tokio = { version = "1", features = ["full"] }
|
||||||
|
tokio-stream = "0.1"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
jrest = "0.2.3"
|
jrest = "0.2"
|
||||||
|
|||||||
45
Makefile
45
Makefile
@@ -1,45 +0,0 @@
|
|||||||
SHELL := /bin/bash
|
|
||||||
|
|
||||||
.PHONY: test
|
|
||||||
|
|
||||||
define source_env_if_not_ci
|
|
||||||
@if [ -z "$${CI}" ]; then \
|
|
||||||
if [ -f ./.env ]; then \
|
|
||||||
source ./.env; \
|
|
||||||
else \
|
|
||||||
echo "No .env file found"; \
|
|
||||||
exit 1; \
|
|
||||||
fi \
|
|
||||||
fi
|
|
||||||
endef
|
|
||||||
|
|
||||||
define RELEASE_TEMPLATE
|
|
||||||
conventional-changelog -p conventionalcommits -i ./CHANGELOG.md -s
|
|
||||||
git add .
|
|
||||||
git commit -m "docs(changelog): update"
|
|
||||||
git push origin HEAD
|
|
||||||
cargo release $(1) --execute
|
|
||||||
git push origin HEAD --tags
|
|
||||||
endef
|
|
||||||
|
|
||||||
doc:
|
|
||||||
cargo doc
|
|
||||||
open ./target/doc/mistralai_client/index.html
|
|
||||||
|
|
||||||
release-patch:
|
|
||||||
$(call RELEASE_TEMPLATE,patch)
|
|
||||||
|
|
||||||
release-minor:
|
|
||||||
$(call RELEASE_TEMPLATE,minor)
|
|
||||||
|
|
||||||
release-major:
|
|
||||||
$(call RELEASE_TEMPLATE,major)
|
|
||||||
|
|
||||||
test:
|
|
||||||
@$(source_env_if_not_ci) && cargo test --no-fail-fast
|
|
||||||
test-cover:
|
|
||||||
@$(source_env_if_not_ci) && cargo llvm-cov
|
|
||||||
test-doc:
|
|
||||||
@$(source_env_if_not_ci) && cargo test --doc --no-fail-fast
|
|
||||||
test-watch:
|
|
||||||
@source ./.env && cargo watch -x "test -- --nocapture"
|
|
||||||
324
README.md
324
README.md
@@ -1,214 +1,207 @@
|
|||||||
# Mistral AI Rust Client
|
# Mistral AI Rust Client
|
||||||
|
|
||||||
[](https://crates.io/crates/mistralai-client)
|
Rust client for the [Mistral AI API](https://docs.mistral.ai/api/).
|
||||||
[](https://docs.rs/mistralai-client/latest/mistralai-client)
|
|
||||||
[](https://github.com/ivangabriele/mistralai-client-rs/actions?query=branch%3Amain+workflow%3ATest++)
|
|
||||||
[](https://app.codecov.io/github/ivangabriele/mistralai-client-rs)
|
|
||||||
|
|
||||||
Rust client for the Mistral AI API.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
- [Supported APIs](#supported-apis)
|
|
||||||
- [Installation](#installation)
|
|
||||||
- [Mistral API Key](#mistral-api-key)
|
|
||||||
- [As an environment variable](#as-an-environment-variable)
|
|
||||||
- [As a client argument](#as-a-client-argument)
|
|
||||||
- [Usage](#usage)
|
|
||||||
- [Chat without streaming](#chat-without-streaming)
|
|
||||||
- [Chat without streaming (async)](#chat-without-streaming-async)
|
|
||||||
- [Chat with streaming (async)](#chat-with-streaming-async)
|
|
||||||
- [Embeddings](#embeddings)
|
|
||||||
- [Embeddings (async)](#embeddings-async)
|
|
||||||
- [List models](#list-models)
|
|
||||||
- [List models (async)](#list-models-async)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Supported APIs
|
## Supported APIs
|
||||||
|
|
||||||
- [x] Chat without streaming
|
- [x] Chat completions (sync, async, streaming)
|
||||||
- [x] Chat without streaming (async)
|
- [x] Function calling / tool use
|
||||||
- [x] Chat with streaming
|
- [x] FIM (fill-in-the-middle) code completions
|
||||||
- [x] Embedding
|
- [x] Embeddings (sync, async)
|
||||||
- [x] Embedding (async)
|
- [x] Models (list, get, delete)
|
||||||
- [x] List models
|
- [x] Files (upload, list, get, delete, download URL)
|
||||||
- [x] List models (async)
|
- [x] Fine-tuning jobs (create, list, get, cancel, start)
|
||||||
- [ ] Function Calling
|
- [x] Batch jobs (create, list, get, cancel)
|
||||||
- [ ] Function Calling (async)
|
- [x] OCR (document text extraction)
|
||||||
|
- [x] Audio transcription
|
||||||
|
- [x] Moderations & classifications
|
||||||
|
- [x] Agent completions
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
You can install the library in your project using:
|
|
||||||
|
|
||||||
```sh
|
```sh
|
||||||
cargo add mistralai-client
|
cargo add mistralai-client
|
||||||
```
|
```
|
||||||
|
|
||||||
### Mistral API Key
|
### API Key
|
||||||
|
|
||||||
You can get your Mistral API Key there: <https://docs.mistral.ai/#api-access>.
|
Get your key at <https://console.mistral.ai/api-keys>.
|
||||||
|
|
||||||
#### As an environment variable
|
|
||||||
|
|
||||||
Just set the `MISTRAL_API_KEY` environment variable.
|
|
||||||
|
|
||||||
#### As a client argument
|
|
||||||
|
|
||||||
```rs
|
```rs
|
||||||
use mistralai_client::v1::client::Client;
|
use mistralai_client::v1::client::Client;
|
||||||
|
|
||||||
fn main() {
|
// From MISTRAL_API_KEY environment variable:
|
||||||
let api_key = "your_api_key";
|
let client = Client::new(None, None, None, None).unwrap();
|
||||||
|
|
||||||
let client = Client::new(Some(api_key), None, None, None).unwrap();
|
// Or pass directly:
|
||||||
}
|
let client = Client::new(Some("your_api_key".to_string()), None, None, None).unwrap();
|
||||||
```
|
```
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
### Chat without streaming
|
### Chat
|
||||||
|
|
||||||
```rs
|
```rs
|
||||||
use mistralai_client::v1::{
|
use mistralai_client::v1::{
|
||||||
chat_completion::{ChatCompletionParams, ChatMessage, ChatMessageRole},
|
chat::{ChatMessage, ChatParams},
|
||||||
client::Client,
|
client::Client,
|
||||||
constants::Model,
|
constants::Model,
|
||||||
};
|
};
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
|
||||||
let client = Client::new(None, None, None, None).unwrap();
|
let client = Client::new(None, None, None, None).unwrap();
|
||||||
|
|
||||||
let model = Model::OpenMistral7b;
|
let model = Model::mistral_small_latest();
|
||||||
let messages = vec![ChatMessage {
|
let messages = vec![ChatMessage::new_user_message("What is the Eiffel Tower?")];
|
||||||
role: ChatMessageRole::user,
|
let options = ChatParams {
|
||||||
content: "Just guess the next word: \"Eiffel ...\"?".to_string(),
|
temperature: Some(0.7),
|
||||||
}];
|
|
||||||
let options = ChatCompletionRequestOptions {
|
|
||||||
temperature: Some(0.0),
|
|
||||||
random_seed: Some(42),
|
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
let result = client.chat(model, messages, Some(options)).unwrap();
|
let result = client.chat(model, messages, Some(options)).unwrap();
|
||||||
println!("Assistant: {}", result.choices[0].message.content);
|
println!("{}", result.choices[0].message.content);
|
||||||
// => "Assistant: Tower. [...]"
|
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
### Chat without streaming (async)
|
### Chat (async)
|
||||||
|
|
||||||
```rs
|
```rs
|
||||||
use mistralai_client::v1::{
|
use mistralai_client::v1::{
|
||||||
chat_completion::{ChatCompletionParams, ChatMessage, ChatMessageRole},
|
chat::{ChatMessage, ChatParams},
|
||||||
client::Client,
|
client::Client,
|
||||||
constants::Model,
|
constants::Model,
|
||||||
};
|
};
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() {
|
async fn main() {
|
||||||
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
|
||||||
let client = Client::new(None, None, None, None).unwrap();
|
let client = Client::new(None, None, None, None).unwrap();
|
||||||
|
|
||||||
let model = Model::OpenMistral7b;
|
let model = Model::mistral_small_latest();
|
||||||
let messages = vec![ChatMessage {
|
let messages = vec![ChatMessage::new_user_message("What is the Eiffel Tower?")];
|
||||||
role: ChatMessageRole::user,
|
|
||||||
content: "Just guess the next word: \"Eiffel ...\"?".to_string(),
|
|
||||||
}];
|
|
||||||
let options = ChatCompletionRequestOptions {
|
|
||||||
temperature: Some(0.0),
|
|
||||||
random_seed: Some(42),
|
|
||||||
..Default::default()
|
|
||||||
};
|
|
||||||
|
|
||||||
let result = client.chat_async(model, messages, Some(options)).await.unwrap();
|
let result = client.chat_async(model, messages, None).await.unwrap();
|
||||||
println!("Assistant: {}", result.choices[0].message.content);
|
println!("{}", result.choices[0].message.content);
|
||||||
// => "Assistant: Tower. [...]"
|
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
### Chat with streaming (async)
|
### Chat with streaming
|
||||||
|
|
||||||
```rs
|
```rs
|
||||||
use futures::stream::StreamExt;
|
use futures::stream::StreamExt;
|
||||||
use mistralai_client::v1::{
|
use mistralai_client::v1::{
|
||||||
chat_completion::{ChatCompletionParams, ChatMessage, ChatMessageRole},
|
chat::{ChatMessage, ChatParams},
|
||||||
client::Client,
|
client::Client,
|
||||||
constants::Model,
|
constants::Model,
|
||||||
};
|
};
|
||||||
|
use std::io::{self, Write};
|
||||||
|
|
||||||
[#tokio::main]
|
#[tokio::main]
|
||||||
async fn main() {
|
async fn main() {
|
||||||
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
let client = Client::new(None, None, None, None).unwrap();
|
||||||
let client = Client::new(None, None, None, None).unwrap();
|
|
||||||
|
|
||||||
let model = Model::OpenMistral7b;
|
let model = Model::mistral_small_latest();
|
||||||
let messages = vec![ChatMessage {
|
let messages = vec![ChatMessage::new_user_message("Tell me a short story.")];
|
||||||
role: ChatMessageRole::user,
|
|
||||||
content: "Just guess the next word: \"Eiffel ...\"?".to_string(),
|
let stream = client.chat_stream(model, messages, None).await.unwrap();
|
||||||
}];
|
stream
|
||||||
let options = ChatCompletionParams {
|
.for_each(|chunk_result| async {
|
||||||
temperature: Some(0.0),
|
match chunk_result {
|
||||||
random_seed: Some(42),
|
Ok(chunks) => chunks.iter().for_each(|chunk| {
|
||||||
|
if let Some(content) = &chunk.choices[0].delta.content {
|
||||||
|
print!("{}", content);
|
||||||
|
io::stdout().flush().unwrap();
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
Err(error) => eprintln!("Error: {:?}", error),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
println!();
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Function calling
|
||||||
|
|
||||||
|
```rs
|
||||||
|
use mistralai_client::v1::{
|
||||||
|
chat::{ChatMessage, ChatParams},
|
||||||
|
client::Client,
|
||||||
|
constants::Model,
|
||||||
|
tool::{Function, Tool, ToolChoice},
|
||||||
|
};
|
||||||
|
use serde::Deserialize;
|
||||||
|
use std::any::Any;
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct GetWeatherArgs { city: String }
|
||||||
|
|
||||||
|
struct GetWeatherFunction;
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl Function for GetWeatherFunction {
|
||||||
|
async fn execute(&self, arguments: String) -> Box<dyn Any + Send> {
|
||||||
|
let args: GetWeatherArgs = serde_json::from_str(&arguments).unwrap();
|
||||||
|
Box::new(format!("20°C in {}", args.city))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
let tools = vec![Tool::new(
|
||||||
|
"get_weather".to_string(),
|
||||||
|
"Get the weather in a city.".to_string(),
|
||||||
|
serde_json::json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"city": { "type": "string", "description": "City name" }
|
||||||
|
},
|
||||||
|
"required": ["city"]
|
||||||
|
}),
|
||||||
|
)];
|
||||||
|
|
||||||
|
let mut client = Client::new(None, None, None, None).unwrap();
|
||||||
|
client.register_function("get_weather".to_string(), Box::new(GetWeatherFunction));
|
||||||
|
|
||||||
|
let messages = vec![ChatMessage::new_user_message("What's the weather in Paris?")];
|
||||||
|
let options = ChatParams {
|
||||||
|
tool_choice: Some(ToolChoice::Auto),
|
||||||
|
tools: Some(tools),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
let stream_result = client.chat_stream(model, messages, Some(options)).await;
|
client.chat(Model::mistral_small_latest(), messages, Some(options)).unwrap();
|
||||||
let mut stream = stream_result.expect("Failed to create stream.");
|
let result = client.get_last_function_call_result().unwrap().downcast::<String>().unwrap();
|
||||||
while let Some(chunk_result) = stream.next().await {
|
println!("{}", result);
|
||||||
match chunk_result {
|
}
|
||||||
Ok(chunk) => {
|
```
|
||||||
println!("Assistant (message chunk): {}", chunk.choices[0].delta.content);
|
|
||||||
}
|
### FIM (code completion)
|
||||||
Err(e) => eprintln!("Error processing chunk: {:?}", e),
|
|
||||||
}
|
```rs
|
||||||
}
|
use mistralai_client::v1::{client::Client, constants::Model, fim::FimParams};
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
let client = Client::new(None, None, None, None).unwrap();
|
||||||
|
|
||||||
|
let options = FimParams {
|
||||||
|
suffix: Some("\n return result".to_string()),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let result = client.fim(Model::codestral_latest(), "def fibonacci(".to_string(), Some(options)).unwrap();
|
||||||
|
println!("{}", result.choices[0].message.content);
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
### Embeddings
|
### Embeddings
|
||||||
|
|
||||||
```rs
|
```rs
|
||||||
use mistralai_client::v1::{client::Client, constants::EmbedModel};
|
use mistralai_client::v1::{client::Client, constants::Model};
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
let client = Client::new(None, None, None, None).unwrap();
|
||||||
let client: Client = Client::new(None, None, None, None).unwrap();
|
|
||||||
|
|
||||||
let model = EmbedModel::MistralEmbed;
|
let input = vec!["Hello world".to_string(), "Goodbye world".to_string()];
|
||||||
let input = vec!["Embed this sentence.", "As well as this one."]
|
let response = client.embeddings(Model::mistral_embed(), input, None).unwrap();
|
||||||
.iter()
|
println!("Dimensions: {}", response.data[0].embedding.len());
|
||||||
.map(|s| s.to_string())
|
|
||||||
.collect();
|
|
||||||
let options = None;
|
|
||||||
|
|
||||||
let response = client.embeddings(model, input, options).unwrap();
|
|
||||||
println!("Embeddings: {:?}", response.data);
|
|
||||||
// => "Embeddings: [{...}, {...}]"
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### Embeddings (async)
|
|
||||||
|
|
||||||
```rs
|
|
||||||
use mistralai_client::v1::{client::Client, constants::EmbedModel};
|
|
||||||
|
|
||||||
#[tokio::main]
|
|
||||||
async fn main() {
|
|
||||||
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
|
||||||
let client: Client = Client::new(None, None, None, None).unwrap();
|
|
||||||
|
|
||||||
let model = EmbedModel::MistralEmbed;
|
|
||||||
let input = vec!["Embed this sentence.", "As well as this one."]
|
|
||||||
.iter()
|
|
||||||
.map(|s| s.to_string())
|
|
||||||
.collect();
|
|
||||||
let options = None;
|
|
||||||
|
|
||||||
let response = client.embeddings_async(model, input, options).await.unwrap();
|
|
||||||
println!("Embeddings: {:?}", response.data);
|
|
||||||
// => "Embeddings: [{...}, {...}]"
|
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -218,27 +211,62 @@ async fn main() {
|
|||||||
use mistralai_client::v1::client::Client;
|
use mistralai_client::v1::client::Client;
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
|
||||||
let client = Client::new(None, None, None, None).unwrap();
|
let client = Client::new(None, None, None, None).unwrap();
|
||||||
|
|
||||||
let result = client.list_models().unwrap();
|
let models = client.list_models().unwrap();
|
||||||
println!("First Model ID: {:?}", result.data[0].id);
|
for model in &models.data {
|
||||||
// => "First Model ID: open-mistral-7b"
|
println!("{}", model.id);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
### List models (async)
|
### OCR
|
||||||
|
|
||||||
```rs
|
```rs
|
||||||
use mistralai_client::v1::client::Client;
|
use mistralai_client::v1::{
|
||||||
|
client::Client,
|
||||||
|
constants::Model,
|
||||||
|
ocr::{OcrDocument, OcrRequest},
|
||||||
|
};
|
||||||
|
|
||||||
#[tokio::main]
|
fn main() {
|
||||||
async fn main() {
|
let client = Client::new(None, None, None, None).unwrap();
|
||||||
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
|
||||||
let client = Client::new(None, None, None, None).await.unwrap();
|
|
||||||
|
|
||||||
let result = client.list_models_async().unwrap();
|
let request = OcrRequest {
|
||||||
println!("First Model ID: {:?}", result.data[0].id);
|
model: Model::mistral_ocr_latest(),
|
||||||
// => "First Model ID: open-mistral-7b"
|
document: OcrDocument::from_url("https://example.com/document.pdf"),
|
||||||
|
pages: Some(vec![0]),
|
||||||
|
table_format: None,
|
||||||
|
include_image_base64: None,
|
||||||
|
image_limit: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = client.ocr(&request).unwrap();
|
||||||
|
println!("{}", response.pages[0].markdown);
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Available Models
|
||||||
|
|
||||||
|
Use `Model::new("any-model-id")` for any model, or use the built-in constructors:
|
||||||
|
|
||||||
|
| Constructor | Model ID |
|
||||||
|
|---|---|
|
||||||
|
| `Model::mistral_large_latest()` | `mistral-large-latest` |
|
||||||
|
| `Model::mistral_medium_latest()` | `mistral-medium-latest` |
|
||||||
|
| `Model::mistral_small_latest()` | `mistral-small-latest` |
|
||||||
|
| `Model::mistral_small_4()` | `mistral-small-4-0-26-03` |
|
||||||
|
| `Model::codestral_latest()` | `codestral-latest` |
|
||||||
|
| `Model::magistral_medium_latest()` | `magistral-medium-latest` |
|
||||||
|
| `Model::magistral_small_latest()` | `magistral-small-latest` |
|
||||||
|
| `Model::mistral_embed()` | `mistral-embed` |
|
||||||
|
| `Model::mistral_ocr_latest()` | `mistral-ocr-latest` |
|
||||||
|
| `Model::mistral_moderation_latest()` | `mistral-moderation-26-03` |
|
||||||
|
| `Model::pixtral_large()` | `pixtral-large-2411` |
|
||||||
|
| `Model::voxtral_mini_transcribe()` | `voxtral-mini-transcribe-2-26-02` |
|
||||||
|
|
||||||
|
See `constants.rs` for the full list.
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
Apache-2.0
|
||||||
|
|||||||
24
examples/chat.rs
Normal file
24
examples/chat.rs
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
use mistralai_client::v1::{
|
||||||
|
chat::{ChatMessage, ChatParams},
|
||||||
|
client::Client,
|
||||||
|
constants::Model,
|
||||||
|
};
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
||||||
|
let client = Client::new(None, None, None, None).unwrap();
|
||||||
|
|
||||||
|
let model = Model::mistral_small_latest();
|
||||||
|
let messages = vec![ChatMessage::new_user_message(
|
||||||
|
"Just guess the next word: \"Eiffel ...\"?",
|
||||||
|
)];
|
||||||
|
let options = ChatParams {
|
||||||
|
temperature: Some(0.0),
|
||||||
|
random_seed: Some(42),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let result = client.chat(model, messages, Some(options)).unwrap();
|
||||||
|
println!("Assistant: {}", result.choices[0].message.content);
|
||||||
|
// => "Assistant: Tower. The Eiffel Tower is a famous landmark in Paris, France."
|
||||||
|
}
|
||||||
30
examples/chat_async.rs
Normal file
30
examples/chat_async.rs
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
use mistralai_client::v1::{
|
||||||
|
chat::{ChatMessage, ChatParams},
|
||||||
|
client::Client,
|
||||||
|
constants::Model,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() {
|
||||||
|
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
||||||
|
let client = Client::new(None, None, None, None).unwrap();
|
||||||
|
|
||||||
|
let model = Model::mistral_small_latest();
|
||||||
|
let messages = vec![ChatMessage::new_user_message(
|
||||||
|
"Just guess the next word: \"Eiffel ...\"?",
|
||||||
|
)];
|
||||||
|
let options = ChatParams {
|
||||||
|
temperature: Some(0.0),
|
||||||
|
random_seed: Some(42),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let result = client
|
||||||
|
.chat_async(model, messages, Some(options))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
println!(
|
||||||
|
"{:?}: {}",
|
||||||
|
result.choices[0].message.role, result.choices[0].message.content
|
||||||
|
);
|
||||||
|
}
|
||||||
73
examples/chat_with_function_calling.rs
Normal file
73
examples/chat_with_function_calling.rs
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
use mistralai_client::v1::{
|
||||||
|
chat::{ChatMessage, ChatParams},
|
||||||
|
client::Client,
|
||||||
|
constants::Model,
|
||||||
|
tool::{Function, Tool, ToolChoice},
|
||||||
|
};
|
||||||
|
use serde::Deserialize;
|
||||||
|
use std::any::Any;
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct GetCityTemperatureArguments {
|
||||||
|
city: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct GetCityTemperatureFunction;
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl Function for GetCityTemperatureFunction {
|
||||||
|
async fn execute(&self, arguments: String) -> Box<dyn Any + Send> {
|
||||||
|
let GetCityTemperatureArguments { city } = serde_json::from_str(&arguments).unwrap();
|
||||||
|
|
||||||
|
let temperature = match city.as_str() {
|
||||||
|
"Paris" => "20°C",
|
||||||
|
_ => "Unknown city",
|
||||||
|
};
|
||||||
|
|
||||||
|
Box::new(temperature.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
let tools = vec![Tool::new(
|
||||||
|
"get_city_temperature".to_string(),
|
||||||
|
"Get the current temperature in a city.".to_string(),
|
||||||
|
serde_json::json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"city": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The name of the city."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["city"]
|
||||||
|
}),
|
||||||
|
)];
|
||||||
|
|
||||||
|
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
||||||
|
let mut client = Client::new(None, None, None, None).unwrap();
|
||||||
|
client.register_function(
|
||||||
|
"get_city_temperature".to_string(),
|
||||||
|
Box::new(GetCityTemperatureFunction),
|
||||||
|
);
|
||||||
|
|
||||||
|
let model = Model::mistral_small_latest();
|
||||||
|
let messages = vec![ChatMessage::new_user_message(
|
||||||
|
"What's the temperature in Paris?",
|
||||||
|
)];
|
||||||
|
let options = ChatParams {
|
||||||
|
temperature: Some(0.0),
|
||||||
|
random_seed: Some(42),
|
||||||
|
tool_choice: Some(ToolChoice::Auto),
|
||||||
|
tools: Some(tools),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
client.chat(model, messages, Some(options)).unwrap();
|
||||||
|
let temperature = client
|
||||||
|
.get_last_function_call_result()
|
||||||
|
.unwrap()
|
||||||
|
.downcast::<String>()
|
||||||
|
.unwrap();
|
||||||
|
println!("The temperature in Paris is: {}.", temperature);
|
||||||
|
// => "The temperature in Paris is: 20°C."
|
||||||
|
}
|
||||||
77
examples/chat_with_function_calling_async.rs
Normal file
77
examples/chat_with_function_calling_async.rs
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
use mistralai_client::v1::{
|
||||||
|
chat::{ChatMessage, ChatParams},
|
||||||
|
client::Client,
|
||||||
|
constants::Model,
|
||||||
|
tool::{Function, Tool, ToolChoice},
|
||||||
|
};
|
||||||
|
use serde::Deserialize;
|
||||||
|
use std::any::Any;
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct GetCityTemperatureArguments {
|
||||||
|
city: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct GetCityTemperatureFunction;
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl Function for GetCityTemperatureFunction {
|
||||||
|
async fn execute(&self, arguments: String) -> Box<dyn Any + Send> {
|
||||||
|
let GetCityTemperatureArguments { city } = serde_json::from_str(&arguments).unwrap();
|
||||||
|
|
||||||
|
let temperature = match city.as_str() {
|
||||||
|
"Paris" => "20°C",
|
||||||
|
_ => "Unknown city",
|
||||||
|
};
|
||||||
|
|
||||||
|
Box::new(temperature.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() {
|
||||||
|
let tools = vec![Tool::new(
|
||||||
|
"get_city_temperature".to_string(),
|
||||||
|
"Get the current temperature in a city.".to_string(),
|
||||||
|
serde_json::json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"city": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The name of the city."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["city"]
|
||||||
|
}),
|
||||||
|
)];
|
||||||
|
|
||||||
|
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
||||||
|
let mut client = Client::new(None, None, None, None).unwrap();
|
||||||
|
client.register_function(
|
||||||
|
"get_city_temperature".to_string(),
|
||||||
|
Box::new(GetCityTemperatureFunction),
|
||||||
|
);
|
||||||
|
|
||||||
|
let model = Model::mistral_small_latest();
|
||||||
|
let messages = vec![ChatMessage::new_user_message(
|
||||||
|
"What's the temperature in Paris?",
|
||||||
|
)];
|
||||||
|
let options = ChatParams {
|
||||||
|
temperature: Some(0.0),
|
||||||
|
random_seed: Some(42),
|
||||||
|
tool_choice: Some(ToolChoice::Auto),
|
||||||
|
tools: Some(tools),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
client
|
||||||
|
.chat_async(model, messages, Some(options))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let temperature = client
|
||||||
|
.get_last_function_call_result()
|
||||||
|
.unwrap()
|
||||||
|
.downcast::<String>()
|
||||||
|
.unwrap();
|
||||||
|
println!("The temperature in Paris is: {}.", temperature);
|
||||||
|
// => "The temperature in Paris is: 20°C."
|
||||||
|
}
|
||||||
42
examples/chat_with_streaming.rs
Normal file
42
examples/chat_with_streaming.rs
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
use futures::stream::StreamExt;
|
||||||
|
use mistralai_client::v1::{
|
||||||
|
chat::{ChatMessage, ChatParams},
|
||||||
|
client::Client,
|
||||||
|
constants::Model,
|
||||||
|
};
|
||||||
|
use std::io::{self, Write};
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() {
|
||||||
|
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
||||||
|
let client = Client::new(None, None, None, None).unwrap();
|
||||||
|
|
||||||
|
let model = Model::mistral_small_latest();
|
||||||
|
let messages = vec![ChatMessage::new_user_message("Tell me a short happy story.")];
|
||||||
|
let options = ChatParams {
|
||||||
|
temperature: Some(0.0),
|
||||||
|
random_seed: Some(42),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let stream_result = client
|
||||||
|
.chat_stream(model, messages, Some(options))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
stream_result
|
||||||
|
.for_each(|chunk_result| async {
|
||||||
|
match chunk_result {
|
||||||
|
Ok(chunks) => chunks.iter().for_each(|chunk| {
|
||||||
|
if let Some(content) = &chunk.choices[0].delta.content {
|
||||||
|
print!("{}", content);
|
||||||
|
io::stdout().flush().unwrap();
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
Err(error) => {
|
||||||
|
eprintln!("Error processing chunk: {:?}", error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
println!();
|
||||||
|
}
|
||||||
16
examples/embeddings.rs
Normal file
16
examples/embeddings.rs
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
use mistralai_client::v1::{client::Client, constants::Model};
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
||||||
|
let client: Client = Client::new(None, None, None, None).unwrap();
|
||||||
|
|
||||||
|
let model = Model::mistral_embed();
|
||||||
|
let input = vec!["Embed this sentence.", "As well as this one."]
|
||||||
|
.iter()
|
||||||
|
.map(|s| s.to_string())
|
||||||
|
.collect();
|
||||||
|
let options = None;
|
||||||
|
|
||||||
|
let response = client.embeddings(model, input, options).unwrap();
|
||||||
|
println!("First Embedding: {:?}", response.data[0]);
|
||||||
|
}
|
||||||
20
examples/embeddings_async.rs
Normal file
20
examples/embeddings_async.rs
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
use mistralai_client::v1::{client::Client, constants::Model};
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() {
|
||||||
|
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
||||||
|
let client: Client = Client::new(None, None, None, None).unwrap();
|
||||||
|
|
||||||
|
let model = Model::mistral_embed();
|
||||||
|
let input = vec!["Embed this sentence.", "As well as this one."]
|
||||||
|
.iter()
|
||||||
|
.map(|s| s.to_string())
|
||||||
|
.collect();
|
||||||
|
let options = None;
|
||||||
|
|
||||||
|
let response = client
|
||||||
|
.embeddings_async(model, input, options)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
println!("First Embedding: {:?}", response.data[0]);
|
||||||
|
}
|
||||||
21
examples/fim.rs
Normal file
21
examples/fim.rs
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
use mistralai_client::v1::{
|
||||||
|
client::Client,
|
||||||
|
constants::Model,
|
||||||
|
fim::FimParams,
|
||||||
|
};
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
||||||
|
let client = Client::new(None, None, None, None).unwrap();
|
||||||
|
|
||||||
|
let model = Model::codestral_latest();
|
||||||
|
let prompt = "def fibonacci(n):".to_string();
|
||||||
|
let options = FimParams {
|
||||||
|
suffix: Some("\n return result".to_string()),
|
||||||
|
temperature: Some(0.0),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = client.fim(model, prompt, Some(options)).unwrap();
|
||||||
|
println!("Completion: {}", response.choices[0].message.content);
|
||||||
|
}
|
||||||
10
examples/list_models.rs
Normal file
10
examples/list_models.rs
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
use mistralai_client::v1::client::Client;
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
||||||
|
let client = Client::new(None, None, None, None).unwrap();
|
||||||
|
|
||||||
|
let result = client.list_models().unwrap();
|
||||||
|
println!("First Model ID: {:?}", result.data[0].id);
|
||||||
|
// => "First Model ID: open-mistral-7b"
|
||||||
|
}
|
||||||
11
examples/list_models_async.rs
Normal file
11
examples/list_models_async.rs
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
use mistralai_client::v1::client::Client;
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() {
|
||||||
|
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
||||||
|
let client = Client::new(None, None, None, None).unwrap();
|
||||||
|
|
||||||
|
let result = client.list_models_async().await.unwrap();
|
||||||
|
println!("First Model ID: {:?}", result.data[0].id);
|
||||||
|
// => "First Model ID: open-mistral-7b"
|
||||||
|
}
|
||||||
25
examples/ocr.rs
Normal file
25
examples/ocr.rs
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
use mistralai_client::v1::{
|
||||||
|
client::Client,
|
||||||
|
constants::Model,
|
||||||
|
ocr::{OcrDocument, OcrRequest},
|
||||||
|
};
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
||||||
|
let client = Client::new(None, None, None, None).unwrap();
|
||||||
|
|
||||||
|
let request = OcrRequest {
|
||||||
|
model: Model::mistral_ocr_latest(),
|
||||||
|
document: OcrDocument::from_url("https://arxiv.org/pdf/2201.04234"),
|
||||||
|
pages: Some(vec![0]),
|
||||||
|
table_format: None,
|
||||||
|
include_image_base64: None,
|
||||||
|
image_limit: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = client.ocr(&request).unwrap();
|
||||||
|
for page in &response.pages {
|
||||||
|
println!("--- Page {} ---", page.index);
|
||||||
|
println!("{}", &page.markdown[..200.min(page.markdown.len())]);
|
||||||
|
}
|
||||||
|
}
|
||||||
27
justfile
Normal file
27
justfile
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
check:
|
||||||
|
cargo check --all-targets
|
||||||
|
|
||||||
|
doc:
|
||||||
|
cargo doc --open
|
||||||
|
|
||||||
|
fmt:
|
||||||
|
cargo fmt
|
||||||
|
|
||||||
|
lint:
|
||||||
|
cargo clippy --all-targets
|
||||||
|
|
||||||
|
publish:
|
||||||
|
cargo publish --registry sunbeam
|
||||||
|
|
||||||
|
test:
|
||||||
|
cargo test --no-fail-fast
|
||||||
|
|
||||||
|
test-cover:
|
||||||
|
cargo llvm-cov
|
||||||
|
|
||||||
|
test-examples:
|
||||||
|
#!/usr/bin/env bash
|
||||||
|
for example in $(ls examples/*.rs | sed 's/examples\/\(.*\)\.rs/\1/'); do
|
||||||
|
echo "Running $example"
|
||||||
|
cargo run --example "$example"
|
||||||
|
done
|
||||||
@@ -1,4 +1,7 @@
|
|||||||
//! This crate provides a easy bindings and types for MistralAI's API.
|
//! Rust client for the [Mistral AI API](https://docs.mistral.ai/api/).
|
||||||
|
//!
|
||||||
|
//! Supports chat completions, embeddings, FIM, files, fine-tuning, batch jobs,
|
||||||
|
//! OCR, audio transcription, moderations, classifications, and agent completions.
|
||||||
|
|
||||||
/// The v1 module contains the types and methods for the v1 API endpoints.
|
/// The v1 module contains types and methods for all `/v1/` API endpoints.
|
||||||
pub mod v1;
|
pub mod v1;
|
||||||
|
|||||||
283
src/v1/agents.rs
Normal file
283
src/v1/agents.rs
Normal file
@@ -0,0 +1,283 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::v1::{chat, common, constants, tool};
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// Agent Completions (existing — POST /v1/agents/completions)
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct AgentCompletionParams {
|
||||||
|
pub max_tokens: Option<u32>,
|
||||||
|
pub min_tokens: Option<u32>,
|
||||||
|
pub temperature: Option<f32>,
|
||||||
|
pub top_p: Option<f32>,
|
||||||
|
pub random_seed: Option<u32>,
|
||||||
|
pub stop: Option<Vec<String>>,
|
||||||
|
pub response_format: Option<chat::ResponseFormat>,
|
||||||
|
pub tools: Option<Vec<tool::Tool>>,
|
||||||
|
pub tool_choice: Option<tool::ToolChoice>,
|
||||||
|
}
|
||||||
|
impl Default for AgentCompletionParams {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
max_tokens: None,
|
||||||
|
min_tokens: None,
|
||||||
|
temperature: None,
|
||||||
|
top_p: None,
|
||||||
|
random_seed: None,
|
||||||
|
stop: None,
|
||||||
|
response_format: None,
|
||||||
|
tools: None,
|
||||||
|
tool_choice: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct AgentCompletionRequest {
|
||||||
|
pub agent_id: String,
|
||||||
|
pub messages: Vec<chat::ChatMessage>,
|
||||||
|
pub stream: bool,
|
||||||
|
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub max_tokens: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub min_tokens: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub temperature: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub top_p: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub random_seed: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub stop: Option<Vec<String>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub response_format: Option<chat::ResponseFormat>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tools: Option<Vec<tool::Tool>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tool_choice: Option<tool::ToolChoice>,
|
||||||
|
}
|
||||||
|
impl AgentCompletionRequest {
|
||||||
|
pub fn new(
|
||||||
|
agent_id: String,
|
||||||
|
messages: Vec<chat::ChatMessage>,
|
||||||
|
stream: bool,
|
||||||
|
options: Option<AgentCompletionParams>,
|
||||||
|
) -> Self {
|
||||||
|
let opts = options.unwrap_or_default();
|
||||||
|
|
||||||
|
Self {
|
||||||
|
agent_id,
|
||||||
|
messages,
|
||||||
|
stream,
|
||||||
|
max_tokens: opts.max_tokens,
|
||||||
|
min_tokens: opts.min_tokens,
|
||||||
|
temperature: opts.temperature,
|
||||||
|
top_p: opts.top_p,
|
||||||
|
random_seed: opts.random_seed,
|
||||||
|
stop: opts.stop,
|
||||||
|
response_format: opts.response_format,
|
||||||
|
tools: opts.tools,
|
||||||
|
tool_choice: opts.tool_choice,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Agent completion response (same shape as chat completions)
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct AgentCompletionResponse {
|
||||||
|
pub id: String,
|
||||||
|
pub object: String,
|
||||||
|
pub created: u64,
|
||||||
|
pub model: constants::Model,
|
||||||
|
pub choices: Vec<chat::ChatResponseChoice>,
|
||||||
|
pub usage: common::ResponseUsage,
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// Agents API — CRUD (Beta)
|
||||||
|
// POST/GET/PATCH/DELETE /v1/agents
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Tool types for agents
|
||||||
|
|
||||||
|
/// A function tool definition for an agent.
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
pub struct FunctionTool {
|
||||||
|
pub function: tool::ToolFunction,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Tool types available to agents.
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
#[serde(tag = "type")]
|
||||||
|
pub enum AgentTool {
|
||||||
|
#[serde(rename = "function")]
|
||||||
|
Function(FunctionTool),
|
||||||
|
#[serde(rename = "web_search")]
|
||||||
|
WebSearch {},
|
||||||
|
#[serde(rename = "web_search_premium")]
|
||||||
|
WebSearchPremium {},
|
||||||
|
#[serde(rename = "code_interpreter")]
|
||||||
|
CodeInterpreter {},
|
||||||
|
#[serde(rename = "image_generation")]
|
||||||
|
ImageGeneration {},
|
||||||
|
#[serde(rename = "document_library")]
|
||||||
|
DocumentLibrary {},
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AgentTool {
|
||||||
|
/// Create a function tool from name, description, and JSON Schema parameters.
|
||||||
|
pub fn function(name: String, description: String, parameters: serde_json::Value) -> Self {
|
||||||
|
Self::Function(FunctionTool {
|
||||||
|
function: tool::ToolFunction {
|
||||||
|
name,
|
||||||
|
description,
|
||||||
|
parameters,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn web_search() -> Self {
|
||||||
|
Self::WebSearch {}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn code_interpreter() -> Self {
|
||||||
|
Self::CodeInterpreter {}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn image_generation() -> Self {
|
||||||
|
Self::ImageGeneration {}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn document_library() -> Self {
|
||||||
|
Self::DocumentLibrary {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Completion args (subset of chat params allowed for agents)
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
|
||||||
|
pub struct CompletionArgs {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub temperature: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub top_p: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub max_tokens: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub min_tokens: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub random_seed: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub stop: Option<Vec<String>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub frequency_penalty: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub presence_penalty: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub response_format: Option<chat::ResponseFormat>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tool_choice: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub prediction: Option<serde_json::Value>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub reasoning_effort: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Create agent request
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
pub struct CreateAgentRequest {
|
||||||
|
pub model: String,
|
||||||
|
pub name: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub description: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub instructions: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tools: Option<Vec<AgentTool>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub handoffs: Option<Vec<String>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub completion_args: Option<CompletionArgs>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub metadata: Option<serde_json::Value>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Update agent request
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
|
||||||
|
pub struct UpdateAgentRequest {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub model: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub name: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub description: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub instructions: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tools: Option<Vec<AgentTool>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub handoffs: Option<Vec<String>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub completion_args: Option<CompletionArgs>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub metadata: Option<serde_json::Value>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Agent response
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct Agent {
|
||||||
|
pub id: String,
|
||||||
|
pub object: String,
|
||||||
|
pub name: String,
|
||||||
|
pub model: String,
|
||||||
|
pub created_at: String,
|
||||||
|
pub updated_at: String,
|
||||||
|
#[serde(default)]
|
||||||
|
pub version: u64,
|
||||||
|
#[serde(default)]
|
||||||
|
pub versions: Vec<u64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub description: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub instructions: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub tools: Vec<AgentTool>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub handoffs: Option<Vec<String>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub completion_args: Option<CompletionArgs>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub metadata: Option<serde_json::Value>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub deployment_chat: bool,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub source: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub version_message: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub guardrails: Vec<serde_json::Value>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// List agents response. The API returns a raw JSON array of agents.
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
#[serde(transparent)]
|
||||||
|
pub struct AgentListResponse {
|
||||||
|
pub data: Vec<Agent>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Delete agent response. The API returns 204 No Content on success.
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct AgentDeleteResponse {
|
||||||
|
pub deleted: bool,
|
||||||
|
}
|
||||||
78
src/v1/audio.rs
Normal file
78
src/v1/audio.rs
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::v1::constants;
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Request (multipart form, but we define the params struct)
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct AudioTranscriptionParams {
|
||||||
|
pub model: constants::Model,
|
||||||
|
pub language: Option<String>,
|
||||||
|
pub temperature: Option<f32>,
|
||||||
|
pub diarize: Option<bool>,
|
||||||
|
pub timestamp_granularities: Option<Vec<TimestampGranularity>>,
|
||||||
|
}
|
||||||
|
impl Default for AudioTranscriptionParams {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
model: constants::Model::voxtral_mini_transcribe(),
|
||||||
|
language: None,
|
||||||
|
temperature: None,
|
||||||
|
diarize: None,
|
||||||
|
timestamp_granularities: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum TimestampGranularity {
|
||||||
|
Segment,
|
||||||
|
Word,
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Response
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct AudioTranscriptionResponse {
|
||||||
|
pub text: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub model: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub language: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub segments: Option<Vec<TranscriptionSegment>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub words: Option<Vec<TranscriptionWord>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub usage: Option<AudioUsage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct TranscriptionSegment {
|
||||||
|
pub id: u32,
|
||||||
|
pub start: f32,
|
||||||
|
pub end: f32,
|
||||||
|
pub text: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub speaker: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct TranscriptionWord {
|
||||||
|
pub word: String,
|
||||||
|
pub start: f32,
|
||||||
|
pub end: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct AudioUsage {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub prompt_audio_seconds: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub prompt_tokens: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub total_tokens: Option<u32>,
|
||||||
|
}
|
||||||
53
src/v1/batch.rs
Normal file
53
src/v1/batch.rs
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::v1::constants;
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Request
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct BatchJobRequest {
|
||||||
|
pub input_files: Vec<String>,
|
||||||
|
pub model: constants::Model,
|
||||||
|
pub endpoint: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub metadata: Option<serde_json::Value>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Response
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct BatchJobResponse {
|
||||||
|
pub id: String,
|
||||||
|
pub object: String,
|
||||||
|
pub model: constants::Model,
|
||||||
|
pub endpoint: String,
|
||||||
|
pub input_files: Vec<String>,
|
||||||
|
pub status: String,
|
||||||
|
pub created_at: u64,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub completed_at: Option<u64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub output_file: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub error_file: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub metadata: Option<serde_json::Value>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub total_requests: Option<u64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub completed_requests: Option<u64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub succeeded_requests: Option<u64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub failed_requests: Option<u64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct BatchJobListResponse {
|
||||||
|
pub data: Vec<BatchJobResponse>,
|
||||||
|
pub object: String,
|
||||||
|
#[serde(default)]
|
||||||
|
pub total: u32,
|
||||||
|
}
|
||||||
354
src/v1/chat.rs
Normal file
354
src/v1/chat.rs
Normal file
@@ -0,0 +1,354 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::v1::{common, constants, tool};
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Content parts (multimodal)
|
||||||
|
|
||||||
|
/// A single part of a multimodal message.
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
#[serde(tag = "type")]
|
||||||
|
pub enum ContentPart {
|
||||||
|
#[serde(rename = "text")]
|
||||||
|
Text { text: String },
|
||||||
|
#[serde(rename = "image_url")]
|
||||||
|
ImageUrl {
|
||||||
|
image_url: ImageUrl,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct ImageUrl {
|
||||||
|
pub url: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub detail: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Message content: either a plain text string or multimodal content parts.
|
||||||
|
///
|
||||||
|
/// Serializes as a JSON string for text, or a JSON array for parts.
|
||||||
|
/// All existing `new_*_message()` constructors produce `Text` variants,
|
||||||
|
/// so existing code continues to work unchanged.
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum ChatMessageContent {
|
||||||
|
Text(String),
|
||||||
|
Parts(Vec<ContentPart>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ChatMessageContent {
|
||||||
|
/// Extract the text content. For multimodal messages, concatenates all text parts.
|
||||||
|
pub fn text(&self) -> String {
|
||||||
|
match self {
|
||||||
|
Self::Text(s) => s.clone(),
|
||||||
|
Self::Parts(parts) => parts
|
||||||
|
.iter()
|
||||||
|
.filter_map(|p| match p {
|
||||||
|
ContentPart::Text { text } => Some(text.as_str()),
|
||||||
|
_ => None,
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(""),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the content as a string slice if it is a plain text message.
|
||||||
|
pub fn as_text(&self) -> Option<&str> {
|
||||||
|
match self {
|
||||||
|
Self::Text(s) => Some(s),
|
||||||
|
Self::Parts(_) => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns true if this is a multimodal message with image parts.
|
||||||
|
pub fn has_images(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
Self::Text(_) => false,
|
||||||
|
Self::Parts(parts) => parts.iter().any(|p| matches!(p, ContentPart::ImageUrl { .. })),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for ChatMessageContent {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(f, "{}", self.text())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<String> for ChatMessageContent {
|
||||||
|
fn from(s: String) -> Self {
|
||||||
|
Self::Text(s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<&str> for ChatMessageContent {
|
||||||
|
fn from(s: &str) -> Self {
|
||||||
|
Self::Text(s.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Definitions
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct ChatMessage {
|
||||||
|
pub role: ChatMessageRole,
|
||||||
|
pub content: ChatMessageContent,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tool_calls: Option<Vec<tool::ToolCall>>,
|
||||||
|
/// Tool call ID, required when role is Tool.
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tool_call_id: Option<String>,
|
||||||
|
/// Function name, used when role is Tool.
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub name: Option<String>,
|
||||||
|
}
|
||||||
|
impl ChatMessage {
|
||||||
|
pub fn new_system_message(content: &str) -> Self {
|
||||||
|
Self {
|
||||||
|
role: ChatMessageRole::System,
|
||||||
|
content: ChatMessageContent::Text(content.to_string()),
|
||||||
|
tool_calls: None,
|
||||||
|
tool_call_id: None,
|
||||||
|
name: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new_assistant_message(content: &str, tool_calls: Option<Vec<tool::ToolCall>>) -> Self {
|
||||||
|
Self {
|
||||||
|
role: ChatMessageRole::Assistant,
|
||||||
|
content: ChatMessageContent::Text(content.to_string()),
|
||||||
|
tool_calls,
|
||||||
|
tool_call_id: None,
|
||||||
|
name: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new_user_message(content: &str) -> Self {
|
||||||
|
Self {
|
||||||
|
role: ChatMessageRole::User,
|
||||||
|
content: ChatMessageContent::Text(content.to_string()),
|
||||||
|
tool_calls: None,
|
||||||
|
tool_call_id: None,
|
||||||
|
name: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a user message with mixed text and image content.
|
||||||
|
pub fn new_user_message_with_images(parts: Vec<ContentPart>) -> Self {
|
||||||
|
Self {
|
||||||
|
role: ChatMessageRole::User,
|
||||||
|
content: ChatMessageContent::Parts(parts),
|
||||||
|
tool_calls: None,
|
||||||
|
tool_call_id: None,
|
||||||
|
name: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new_tool_message(content: &str, tool_call_id: &str, name: Option<&str>) -> Self {
|
||||||
|
Self {
|
||||||
|
role: ChatMessageRole::Tool,
|
||||||
|
content: ChatMessageContent::Text(content.to_string()),
|
||||||
|
tool_calls: None,
|
||||||
|
tool_call_id: Some(tool_call_id.to_string()),
|
||||||
|
name: name.map(|n| n.to_string()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// See the [chat completions API](https://docs.mistral.ai/api/#tag/chat/operation/chat_completion_v1_chat_completions_post).
|
||||||
|
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
||||||
|
pub enum ChatMessageRole {
|
||||||
|
#[serde(rename = "system")]
|
||||||
|
System,
|
||||||
|
#[serde(rename = "assistant")]
|
||||||
|
Assistant,
|
||||||
|
#[serde(rename = "user")]
|
||||||
|
User,
|
||||||
|
#[serde(rename = "tool")]
|
||||||
|
Tool,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The format that the model must output.
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
pub struct ResponseFormat {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
pub type_: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub json_schema: Option<serde_json::Value>,
|
||||||
|
}
|
||||||
|
impl ResponseFormat {
|
||||||
|
pub fn text() -> Self {
|
||||||
|
Self {
|
||||||
|
type_: "text".to_string(),
|
||||||
|
json_schema: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn json_object() -> Self {
|
||||||
|
Self {
|
||||||
|
type_: "json_object".to_string(),
|
||||||
|
json_schema: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn json_schema(schema: serde_json::Value) -> Self {
|
||||||
|
Self {
|
||||||
|
type_: "json_schema".to_string(),
|
||||||
|
json_schema: Some(schema),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Request
|
||||||
|
|
||||||
|
/// The parameters for the chat request.
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct ChatParams {
|
||||||
|
pub max_tokens: Option<u32>,
|
||||||
|
pub min_tokens: Option<u32>,
|
||||||
|
pub random_seed: Option<u32>,
|
||||||
|
pub response_format: Option<ResponseFormat>,
|
||||||
|
pub safe_prompt: bool,
|
||||||
|
pub temperature: Option<f32>,
|
||||||
|
pub tool_choice: Option<tool::ToolChoice>,
|
||||||
|
pub tools: Option<Vec<tool::Tool>>,
|
||||||
|
pub top_p: Option<f32>,
|
||||||
|
pub stop: Option<Vec<String>>,
|
||||||
|
pub n: Option<u32>,
|
||||||
|
pub frequency_penalty: Option<f32>,
|
||||||
|
pub presence_penalty: Option<f32>,
|
||||||
|
pub parallel_tool_calls: Option<bool>,
|
||||||
|
/// For reasoning models (Magistral). "high" or "none".
|
||||||
|
pub reasoning_effort: Option<String>,
|
||||||
|
}
|
||||||
|
impl Default for ChatParams {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
max_tokens: None,
|
||||||
|
min_tokens: None,
|
||||||
|
random_seed: None,
|
||||||
|
safe_prompt: false,
|
||||||
|
response_format: None,
|
||||||
|
temperature: None,
|
||||||
|
tool_choice: None,
|
||||||
|
tools: None,
|
||||||
|
top_p: None,
|
||||||
|
stop: None,
|
||||||
|
n: None,
|
||||||
|
frequency_penalty: None,
|
||||||
|
presence_penalty: None,
|
||||||
|
parallel_tool_calls: None,
|
||||||
|
reasoning_effort: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct ChatRequest {
|
||||||
|
pub model: constants::Model,
|
||||||
|
pub messages: Vec<ChatMessage>,
|
||||||
|
pub stream: bool,
|
||||||
|
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub max_tokens: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub min_tokens: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub random_seed: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub response_format: Option<ResponseFormat>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub safe_prompt: Option<bool>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub temperature: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tool_choice: Option<tool::ToolChoice>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tools: Option<Vec<tool::Tool>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub top_p: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub stop: Option<Vec<String>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub n: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub frequency_penalty: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub presence_penalty: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub parallel_tool_calls: Option<bool>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub reasoning_effort: Option<String>,
|
||||||
|
}
|
||||||
|
impl ChatRequest {
|
||||||
|
pub fn new(
|
||||||
|
model: constants::Model,
|
||||||
|
messages: Vec<ChatMessage>,
|
||||||
|
stream: bool,
|
||||||
|
options: Option<ChatParams>,
|
||||||
|
) -> Self {
|
||||||
|
let opts = options.unwrap_or_default();
|
||||||
|
let safe_prompt = if opts.safe_prompt { Some(true) } else { None };
|
||||||
|
|
||||||
|
Self {
|
||||||
|
model,
|
||||||
|
messages,
|
||||||
|
stream,
|
||||||
|
max_tokens: opts.max_tokens,
|
||||||
|
min_tokens: opts.min_tokens,
|
||||||
|
random_seed: opts.random_seed,
|
||||||
|
safe_prompt,
|
||||||
|
temperature: opts.temperature,
|
||||||
|
tool_choice: opts.tool_choice,
|
||||||
|
tools: opts.tools,
|
||||||
|
top_p: opts.top_p,
|
||||||
|
response_format: opts.response_format,
|
||||||
|
stop: opts.stop,
|
||||||
|
n: opts.n,
|
||||||
|
frequency_penalty: opts.frequency_penalty,
|
||||||
|
presence_penalty: opts.presence_penalty,
|
||||||
|
parallel_tool_calls: opts.parallel_tool_calls,
|
||||||
|
reasoning_effort: opts.reasoning_effort,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Response
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct ChatResponse {
|
||||||
|
pub id: String,
|
||||||
|
pub object: String,
|
||||||
|
/// Unix timestamp (in seconds).
|
||||||
|
pub created: u64,
|
||||||
|
pub model: constants::Model,
|
||||||
|
pub choices: Vec<ChatResponseChoice>,
|
||||||
|
pub usage: common::ResponseUsage,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct ChatResponseChoice {
|
||||||
|
pub index: u32,
|
||||||
|
pub message: ChatMessage,
|
||||||
|
pub finish_reason: ChatResponseChoiceFinishReason,
|
||||||
|
/// Reasoning content returned by Magistral models.
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub reasoning: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
||||||
|
pub enum ChatResponseChoiceFinishReason {
|
||||||
|
#[serde(rename = "stop")]
|
||||||
|
Stop,
|
||||||
|
#[serde(rename = "length")]
|
||||||
|
Length,
|
||||||
|
#[serde(rename = "tool_calls")]
|
||||||
|
ToolCalls,
|
||||||
|
#[serde(rename = "model_length")]
|
||||||
|
ModelLength,
|
||||||
|
#[serde(rename = "error")]
|
||||||
|
Error,
|
||||||
|
}
|
||||||
@@ -1,149 +0,0 @@
|
|||||||
use serde::{Deserialize, Serialize};
|
|
||||||
|
|
||||||
use crate::v1::{common, constants};
|
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
|
||||||
// Definitions
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
|
||||||
pub struct ChatMessage {
|
|
||||||
pub role: ChatMessageRole,
|
|
||||||
pub content: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, strum_macros::Display, Eq, PartialEq, Deserialize, Serialize)]
|
|
||||||
#[allow(non_camel_case_types)]
|
|
||||||
pub enum ChatMessageRole {
|
|
||||||
assistant,
|
|
||||||
user,
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
|
||||||
// Request
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct ChatCompletionParams {
|
|
||||||
pub tools: Option<String>,
|
|
||||||
pub temperature: Option<f32>,
|
|
||||||
pub max_tokens: Option<u32>,
|
|
||||||
pub top_p: Option<f32>,
|
|
||||||
pub random_seed: Option<u32>,
|
|
||||||
pub safe_prompt: Option<bool>,
|
|
||||||
}
|
|
||||||
impl Default for ChatCompletionParams {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self {
|
|
||||||
tools: None,
|
|
||||||
temperature: None,
|
|
||||||
max_tokens: None,
|
|
||||||
top_p: None,
|
|
||||||
random_seed: None,
|
|
||||||
safe_prompt: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
pub struct ChatCompletionRequest {
|
|
||||||
pub messages: Vec<ChatMessage>,
|
|
||||||
pub model: constants::Model,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub tools: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub temperature: Option<f32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub max_tokens: Option<u32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub top_p: Option<f32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub random_seed: Option<u32>,
|
|
||||||
pub stream: bool,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub safe_prompt: Option<bool>,
|
|
||||||
// TODO Check this prop (seen in official Python client but not in API doc).
|
|
||||||
// pub tool_choice: Option<String>,
|
|
||||||
// TODO Check this prop (seen in official Python client but not in API doc).
|
|
||||||
// pub response_format: Option<String>,
|
|
||||||
}
|
|
||||||
impl ChatCompletionRequest {
|
|
||||||
pub fn new(
|
|
||||||
model: constants::Model,
|
|
||||||
messages: Vec<ChatMessage>,
|
|
||||||
stream: bool,
|
|
||||||
options: Option<ChatCompletionParams>,
|
|
||||||
) -> Self {
|
|
||||||
let ChatCompletionParams {
|
|
||||||
tools,
|
|
||||||
temperature,
|
|
||||||
max_tokens,
|
|
||||||
top_p,
|
|
||||||
random_seed,
|
|
||||||
safe_prompt,
|
|
||||||
} = options.unwrap_or_default();
|
|
||||||
|
|
||||||
Self {
|
|
||||||
messages,
|
|
||||||
model,
|
|
||||||
tools,
|
|
||||||
temperature,
|
|
||||||
max_tokens,
|
|
||||||
top_p,
|
|
||||||
random_seed,
|
|
||||||
stream,
|
|
||||||
safe_prompt,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
|
||||||
// Response
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
|
||||||
pub struct ChatCompletionResponse {
|
|
||||||
pub id: String,
|
|
||||||
pub object: String,
|
|
||||||
/// Unix timestamp (in seconds).
|
|
||||||
pub created: u32,
|
|
||||||
pub model: constants::Model,
|
|
||||||
pub choices: Vec<ChatCompletionResponseChoice>,
|
|
||||||
pub usage: common::ResponseUsage,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
|
||||||
pub struct ChatCompletionResponseChoice {
|
|
||||||
pub index: u32,
|
|
||||||
pub message: ChatMessage,
|
|
||||||
pub finish_reason: String,
|
|
||||||
// TODO Check this prop (seen in API responses but undocumented).
|
|
||||||
// pub logprobs: ???
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
|
||||||
// Stream
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
pub struct ChatCompletionStreamChunk {
|
|
||||||
pub id: String,
|
|
||||||
pub object: String,
|
|
||||||
/// Unix timestamp (in seconds).
|
|
||||||
pub created: u32,
|
|
||||||
pub model: constants::Model,
|
|
||||||
pub choices: Vec<ChatCompletionStreamChunkChoice>,
|
|
||||||
// TODO Check this prop (seen in API responses but undocumented).
|
|
||||||
// pub usage: ???,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
pub struct ChatCompletionStreamChunkChoice {
|
|
||||||
pub index: u32,
|
|
||||||
pub delta: ChatCompletionStreamChunkChoiceDelta,
|
|
||||||
pub finish_reason: Option<String>,
|
|
||||||
// TODO Check this prop (seen in API responses but undocumented).
|
|
||||||
// pub logprobs: ???,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
pub struct ChatCompletionStreamChunkChoiceDelta {
|
|
||||||
pub role: Option<ChatMessageRole>,
|
|
||||||
pub content: String,
|
|
||||||
}
|
|
||||||
56
src/v1/chat_stream.rs
Normal file
56
src/v1/chat_stream.rs
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::from_str;
|
||||||
|
|
||||||
|
use crate::v1::{chat, common, constants, error, tool};
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Response
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct ChatStreamChunk {
|
||||||
|
pub id: String,
|
||||||
|
pub object: String,
|
||||||
|
/// Unix timestamp (in seconds).
|
||||||
|
pub created: u64,
|
||||||
|
pub model: constants::Model,
|
||||||
|
pub choices: Vec<ChatStreamChunkChoice>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub usage: Option<common::ResponseUsage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct ChatStreamChunkChoice {
|
||||||
|
pub index: u32,
|
||||||
|
pub delta: ChatStreamChunkChoiceDelta,
|
||||||
|
pub finish_reason: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct ChatStreamChunkChoiceDelta {
|
||||||
|
pub role: Option<chat::ChatMessageRole>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub content: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tool_calls: Option<Vec<tool::ToolCall>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extracts serialized chunks from a stream message.
|
||||||
|
pub fn get_chunk_from_stream_message_line(
|
||||||
|
line: &str,
|
||||||
|
) -> Result<Option<Vec<ChatStreamChunk>>, error::ApiError> {
|
||||||
|
if line.trim() == "data: [DONE]" {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
let chunk_as_json = line.trim_start_matches("data: ").trim();
|
||||||
|
if chunk_as_json.is_empty() {
|
||||||
|
return Ok(Some(vec![]));
|
||||||
|
}
|
||||||
|
|
||||||
|
match from_str::<ChatStreamChunk>(chunk_as_json) {
|
||||||
|
Ok(chunk) => Ok(Some(vec![chunk])),
|
||||||
|
Err(e) => Err(error::ApiError {
|
||||||
|
message: e.to_string(),
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
1736
src/v1/client.rs
1736
src/v1/client.rs
File diff suppressed because it is too large
Load Diff
@@ -3,6 +3,7 @@ use serde::{Deserialize, Serialize};
|
|||||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
pub struct ResponseUsage {
|
pub struct ResponseUsage {
|
||||||
pub prompt_tokens: u32,
|
pub prompt_tokens: u32,
|
||||||
|
#[serde(default)]
|
||||||
pub completion_tokens: u32,
|
pub completion_tokens: u32,
|
||||||
pub total_tokens: u32,
|
pub total_tokens: u32,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,23 +1,131 @@
|
|||||||
|
use std::fmt;
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
pub const API_URL_BASE: &str = "https://api.mistral.ai/v1";
|
pub const API_URL_BASE: &str = "https://api.mistral.ai/v1";
|
||||||
|
|
||||||
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
/// A Mistral AI model identifier.
|
||||||
pub enum Model {
|
///
|
||||||
#[serde(rename = "open-mistral-7b")]
|
/// Use the associated constants for known models, or construct with `Model::new()` for any model string.
|
||||||
OpenMistral7b,
|
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
|
||||||
#[serde(rename = "open-mistral-8x7b")]
|
#[serde(transparent)]
|
||||||
OpenMistral8x7b,
|
pub struct Model(pub String);
|
||||||
#[serde(rename = "mistral-small-latest")]
|
|
||||||
MistralSmallLatest,
|
impl Model {
|
||||||
#[serde(rename = "mistral-medium-latest")]
|
pub fn new(id: impl Into<String>) -> Self {
|
||||||
MistralMediumLatest,
|
Self(id.into())
|
||||||
#[serde(rename = "mistral-large-latest")]
|
}
|
||||||
MistralLargeLatest,
|
|
||||||
|
// Flagship / Premier
|
||||||
|
pub fn mistral_large_latest() -> Self {
|
||||||
|
Self::new("mistral-large-latest")
|
||||||
|
}
|
||||||
|
pub fn mistral_large_3() -> Self {
|
||||||
|
Self::new("mistral-large-3-25-12")
|
||||||
|
}
|
||||||
|
pub fn mistral_medium_latest() -> Self {
|
||||||
|
Self::new("mistral-medium-latest")
|
||||||
|
}
|
||||||
|
pub fn mistral_medium_3_1() -> Self {
|
||||||
|
Self::new("mistral-medium-3-1-25-08")
|
||||||
|
}
|
||||||
|
pub fn mistral_small_latest() -> Self {
|
||||||
|
Self::new("mistral-small-latest")
|
||||||
|
}
|
||||||
|
pub fn mistral_small_4() -> Self {
|
||||||
|
Self::new("mistral-small-4-0-26-03")
|
||||||
|
}
|
||||||
|
pub fn mistral_small_3_2() -> Self {
|
||||||
|
Self::new("mistral-small-3-2-25-06")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ministral
|
||||||
|
pub fn ministral_3_14b() -> Self {
|
||||||
|
Self::new("ministral-3-14b-25-12")
|
||||||
|
}
|
||||||
|
pub fn ministral_3_8b() -> Self {
|
||||||
|
Self::new("ministral-3-8b-25-12")
|
||||||
|
}
|
||||||
|
pub fn ministral_3_3b() -> Self {
|
||||||
|
Self::new("ministral-3-3b-25-12")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reasoning
|
||||||
|
pub fn magistral_medium_latest() -> Self {
|
||||||
|
Self::new("magistral-medium-latest")
|
||||||
|
}
|
||||||
|
pub fn magistral_small_latest() -> Self {
|
||||||
|
Self::new("magistral-small-latest")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Code
|
||||||
|
pub fn codestral_latest() -> Self {
|
||||||
|
Self::new("codestral-latest")
|
||||||
|
}
|
||||||
|
pub fn codestral_2508() -> Self {
|
||||||
|
Self::new("codestral-2508")
|
||||||
|
}
|
||||||
|
pub fn codestral_embed() -> Self {
|
||||||
|
Self::new("codestral-embed-25-05")
|
||||||
|
}
|
||||||
|
pub fn devstral_2() -> Self {
|
||||||
|
Self::new("devstral-2-25-12")
|
||||||
|
}
|
||||||
|
pub fn devstral_small_2() -> Self {
|
||||||
|
Self::new("devstral-small-2-25-12")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Multimodal / Vision
|
||||||
|
pub fn pixtral_large() -> Self {
|
||||||
|
Self::new("pixtral-large-2411")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Audio
|
||||||
|
pub fn voxtral_mini_transcribe() -> Self {
|
||||||
|
Self::new("voxtral-mini-transcribe-2-26-02")
|
||||||
|
}
|
||||||
|
pub fn voxtral_small() -> Self {
|
||||||
|
Self::new("voxtral-small-25-07")
|
||||||
|
}
|
||||||
|
pub fn voxtral_mini() -> Self {
|
||||||
|
Self::new("voxtral-mini-25-07")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Legacy (kept for backward compatibility)
|
||||||
|
pub fn open_mistral_nemo() -> Self {
|
||||||
|
Self::new("open-mistral-nemo")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Embedding
|
||||||
|
pub fn mistral_embed() -> Self {
|
||||||
|
Self::new("mistral-embed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Moderation
|
||||||
|
pub fn mistral_moderation_latest() -> Self {
|
||||||
|
Self::new("mistral-moderation-26-03")
|
||||||
|
}
|
||||||
|
|
||||||
|
// OCR
|
||||||
|
pub fn mistral_ocr_latest() -> Self {
|
||||||
|
Self::new("mistral-ocr-latest")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
impl fmt::Display for Model {
|
||||||
pub enum EmbedModel {
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
#[serde(rename = "mistral-embed")]
|
write!(f, "{}", self.0)
|
||||||
MistralEmbed,
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<&str> for Model {
|
||||||
|
fn from(s: &str) -> Self {
|
||||||
|
Self(s.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<String> for Model {
|
||||||
|
fn from(s: String) -> Self {
|
||||||
|
Self(s)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
395
src/v1/conversation_stream.rs
Normal file
395
src/v1/conversation_stream.rs
Normal file
@@ -0,0 +1,395 @@
|
|||||||
|
//! Streaming support for the Conversations API.
|
||||||
|
//!
|
||||||
|
//! When `stream: true` is set on a conversation request, the API returns
|
||||||
|
//! Server-Sent Events (SSE). Each event has an `event:` type line and a
|
||||||
|
//! `data:` JSON payload, discriminated by the `type` field.
|
||||||
|
//!
|
||||||
|
//! Event types:
|
||||||
|
//! - `conversation.response.started` — generation began
|
||||||
|
//! - `message.output.delta` — partial assistant text
|
||||||
|
//! - `function.call.delta` — a function call (tool call)
|
||||||
|
//! - `conversation.response.done` — generation complete (has usage)
|
||||||
|
//! - `conversation.response.error` — error during generation
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::v1::{conversations, error};
|
||||||
|
|
||||||
|
// ── SSE event types ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// A streaming event from the Conversations API.
|
||||||
|
/// The `type` field discriminates the variant.
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
#[serde(tag = "type")]
|
||||||
|
pub enum ConversationEvent {
|
||||||
|
/// Generation started.
|
||||||
|
#[serde(rename = "conversation.response.started")]
|
||||||
|
ResponseStarted {
|
||||||
|
#[serde(default)]
|
||||||
|
created_at: Option<String>,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Partial assistant text output.
|
||||||
|
#[serde(rename = "message.output.delta")]
|
||||||
|
MessageOutput {
|
||||||
|
id: String,
|
||||||
|
content: serde_json::Value, // string or array of chunks
|
||||||
|
#[serde(default)]
|
||||||
|
role: String,
|
||||||
|
#[serde(default)]
|
||||||
|
output_index: u32,
|
||||||
|
#[serde(default)]
|
||||||
|
content_index: u32,
|
||||||
|
#[serde(default)]
|
||||||
|
model: Option<String>,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// A function call from the model.
|
||||||
|
#[serde(rename = "function.call.delta")]
|
||||||
|
FunctionCall {
|
||||||
|
id: String,
|
||||||
|
name: String,
|
||||||
|
tool_call_id: String,
|
||||||
|
arguments: String,
|
||||||
|
#[serde(default)]
|
||||||
|
output_index: u32,
|
||||||
|
#[serde(default)]
|
||||||
|
model: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
confirmation_status: Option<String>,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Generation complete — includes token usage.
|
||||||
|
#[serde(rename = "conversation.response.done")]
|
||||||
|
ResponseDone {
|
||||||
|
usage: conversations::ConversationUsageInfo,
|
||||||
|
#[serde(default)]
|
||||||
|
created_at: Option<String>,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Error during generation.
|
||||||
|
#[serde(rename = "conversation.response.error")]
|
||||||
|
ResponseError {
|
||||||
|
message: String,
|
||||||
|
code: i32,
|
||||||
|
#[serde(default)]
|
||||||
|
created_at: Option<String>,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Tool execution started (server-side).
|
||||||
|
#[serde(rename = "tool.execution.started")]
|
||||||
|
ToolExecutionStarted {
|
||||||
|
#[serde(flatten)]
|
||||||
|
extra: serde_json::Value,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Tool execution delta (server-side).
|
||||||
|
#[serde(rename = "tool.execution.delta")]
|
||||||
|
ToolExecutionDelta {
|
||||||
|
#[serde(flatten)]
|
||||||
|
extra: serde_json::Value,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Tool execution done (server-side).
|
||||||
|
#[serde(rename = "tool.execution.done")]
|
||||||
|
ToolExecutionDone {
|
||||||
|
#[serde(flatten)]
|
||||||
|
extra: serde_json::Value,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Agent handoff started.
|
||||||
|
#[serde(rename = "agent.handoff.started")]
|
||||||
|
AgentHandoffStarted {
|
||||||
|
#[serde(flatten)]
|
||||||
|
extra: serde_json::Value,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Agent handoff done.
|
||||||
|
#[serde(rename = "agent.handoff.done")]
|
||||||
|
AgentHandoffDone {
|
||||||
|
#[serde(flatten)]
|
||||||
|
extra: serde_json::Value,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ConversationEvent {
|
||||||
|
/// Extract text content from a MessageOutput event.
|
||||||
|
pub fn text_delta(&self) -> Option<String> {
|
||||||
|
match self {
|
||||||
|
ConversationEvent::MessageOutput { content, .. } => {
|
||||||
|
// content can be a string or an array of chunks
|
||||||
|
if let Some(s) = content.as_str() {
|
||||||
|
Some(s.to_string())
|
||||||
|
} else if let Some(arr) = content.as_array() {
|
||||||
|
// Array of chunks — extract text from TextChunk items
|
||||||
|
let mut text = String::new();
|
||||||
|
for chunk in arr {
|
||||||
|
if let Some(t) = chunk.get("text").and_then(|v| v.as_str()) {
|
||||||
|
text.push_str(t);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if text.is_empty() { None } else { Some(text) }
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── SSE parsing ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Parse a single SSE `data:` line into a conversation event.
|
||||||
|
///
|
||||||
|
/// Returns:
|
||||||
|
/// - `Ok(Some(event))` — parsed event
|
||||||
|
/// - `Ok(None)` — `[DONE]` signal or empty/comment line
|
||||||
|
/// - `Err(e)` — parse error
|
||||||
|
pub fn parse_sse_line(line: &str) -> Result<Option<ConversationEvent>, error::ApiError> {
|
||||||
|
let line = line.trim();
|
||||||
|
|
||||||
|
if line.is_empty() || line.starts_with(':') || line.starts_with("event:") {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
if line == "data: [DONE]" || line == "[DONE]" {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSE data lines start with "data: "
|
||||||
|
let json = match line.strip_prefix("data: ") {
|
||||||
|
Some(j) => j.trim(),
|
||||||
|
None => return Ok(None), // not a data line
|
||||||
|
};
|
||||||
|
if json.is_empty() {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
serde_json::from_str::<ConversationEvent>(json).map(Some).map_err(|e| {
|
||||||
|
error::ApiError {
|
||||||
|
message: format!("Failed to parse conversation stream event: {e}\nRaw: {json}"),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Accumulate streaming events into a final `ConversationResponse`.
|
||||||
|
pub fn accumulate(
|
||||||
|
conversation_id: &str,
|
||||||
|
events: &[ConversationEvent],
|
||||||
|
) -> conversations::ConversationResponse {
|
||||||
|
let mut full_text = String::new();
|
||||||
|
let mut function_calls = Vec::new();
|
||||||
|
let mut usage = conversations::ConversationUsageInfo {
|
||||||
|
prompt_tokens: 0,
|
||||||
|
completion_tokens: 0,
|
||||||
|
total_tokens: 0,
|
||||||
|
};
|
||||||
|
|
||||||
|
for event in events {
|
||||||
|
match event {
|
||||||
|
ConversationEvent::MessageOutput { content, .. } => {
|
||||||
|
if let Some(s) = content.as_str() {
|
||||||
|
full_text.push_str(s);
|
||||||
|
} else if let Some(arr) = content.as_array() {
|
||||||
|
for chunk in arr {
|
||||||
|
if let Some(t) = chunk.get("text").and_then(|v| v.as_str()) {
|
||||||
|
full_text.push_str(t);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ConversationEvent::FunctionCall {
|
||||||
|
id, name, tool_call_id, arguments, ..
|
||||||
|
} => {
|
||||||
|
function_calls.push(conversations::ConversationEntry::FunctionCall(
|
||||||
|
conversations::FunctionCallEntry {
|
||||||
|
name: name.clone(),
|
||||||
|
arguments: arguments.clone(),
|
||||||
|
id: Some(id.clone()),
|
||||||
|
object: None,
|
||||||
|
tool_call_id: Some(tool_call_id.clone()),
|
||||||
|
created_at: None,
|
||||||
|
completed_at: None,
|
||||||
|
},
|
||||||
|
));
|
||||||
|
}
|
||||||
|
ConversationEvent::ResponseDone { usage: u, .. } => {
|
||||||
|
usage = u.clone();
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut outputs = Vec::new();
|
||||||
|
if !full_text.is_empty() {
|
||||||
|
outputs.push(conversations::ConversationEntry::MessageOutput(
|
||||||
|
conversations::MessageOutputEntry {
|
||||||
|
role: "assistant".into(),
|
||||||
|
content: crate::v1::chat::ChatMessageContent::Text(full_text),
|
||||||
|
id: None,
|
||||||
|
object: None,
|
||||||
|
model: None,
|
||||||
|
created_at: None,
|
||||||
|
completed_at: None,
|
||||||
|
},
|
||||||
|
));
|
||||||
|
}
|
||||||
|
outputs.extend(function_calls);
|
||||||
|
|
||||||
|
conversations::ConversationResponse {
|
||||||
|
conversation_id: conversation_id.to_string(),
|
||||||
|
outputs,
|
||||||
|
usage,
|
||||||
|
object: "conversation.response".into(),
|
||||||
|
guardrails: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_done() {
|
||||||
|
assert!(parse_sse_line("data: [DONE]").unwrap().is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_empty() {
|
||||||
|
assert!(parse_sse_line("").unwrap().is_none());
|
||||||
|
assert!(parse_sse_line(" ").unwrap().is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_comment() {
|
||||||
|
assert!(parse_sse_line(": keep-alive").unwrap().is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_response_started() {
|
||||||
|
let line = r#"data: {"type":"conversation.response.started","created_at":"2026-03-24T12:00:00Z"}"#;
|
||||||
|
let event = parse_sse_line(line).unwrap().unwrap();
|
||||||
|
assert!(matches!(event, ConversationEvent::ResponseStarted { .. }));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_message_output_string() {
|
||||||
|
let line = r#"data: {"type":"message.output.delta","id":"msg-1","content":"hello ","role":"assistant"}"#;
|
||||||
|
let event = parse_sse_line(line).unwrap().unwrap();
|
||||||
|
assert_eq!(event.text_delta(), Some("hello ".into()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_message_output_chunks() {
|
||||||
|
let line = r#"data: {"type":"message.output.delta","id":"msg-1","content":[{"type":"text","text":"world"}],"role":"assistant"}"#;
|
||||||
|
let event = parse_sse_line(line).unwrap().unwrap();
|
||||||
|
assert_eq!(event.text_delta(), Some("world".into()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_function_call() {
|
||||||
|
let line = r#"data: {"type":"function.call.delta","id":"fc-1","name":"search_web","tool_call_id":"tc-1","arguments":"{\"query\":\"test\"}"}"#;
|
||||||
|
let event = parse_sse_line(line).unwrap().unwrap();
|
||||||
|
match event {
|
||||||
|
ConversationEvent::FunctionCall { name, arguments, tool_call_id, .. } => {
|
||||||
|
assert_eq!(name, "search_web");
|
||||||
|
assert_eq!(tool_call_id, "tc-1");
|
||||||
|
assert!(arguments.contains("test"));
|
||||||
|
}
|
||||||
|
_ => panic!("Expected FunctionCall"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_response_done() {
|
||||||
|
let line = r#"data: {"type":"conversation.response.done","usage":{"prompt_tokens":100,"completion_tokens":50,"total_tokens":150}}"#;
|
||||||
|
let event = parse_sse_line(line).unwrap().unwrap();
|
||||||
|
match event {
|
||||||
|
ConversationEvent::ResponseDone { usage, .. } => {
|
||||||
|
assert_eq!(usage.prompt_tokens, 100);
|
||||||
|
assert_eq!(usage.completion_tokens, 50);
|
||||||
|
assert_eq!(usage.total_tokens, 150);
|
||||||
|
}
|
||||||
|
_ => panic!("Expected ResponseDone"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_response_error() {
|
||||||
|
let line = r#"data: {"type":"conversation.response.error","message":"rate limited","code":429}"#;
|
||||||
|
let event = parse_sse_line(line).unwrap().unwrap();
|
||||||
|
match event {
|
||||||
|
ConversationEvent::ResponseError { message, code, .. } => {
|
||||||
|
assert_eq!(message, "rate limited");
|
||||||
|
assert_eq!(code, 429);
|
||||||
|
}
|
||||||
|
_ => panic!("Expected ResponseError"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_accumulate() {
|
||||||
|
let events = vec![
|
||||||
|
ConversationEvent::ResponseStarted { created_at: None },
|
||||||
|
ConversationEvent::MessageOutput {
|
||||||
|
id: "m1".into(),
|
||||||
|
content: serde_json::json!("hello "),
|
||||||
|
role: "assistant".into(),
|
||||||
|
output_index: 0,
|
||||||
|
content_index: 0,
|
||||||
|
model: None,
|
||||||
|
},
|
||||||
|
ConversationEvent::MessageOutput {
|
||||||
|
id: "m1".into(),
|
||||||
|
content: serde_json::json!("world"),
|
||||||
|
role: "assistant".into(),
|
||||||
|
output_index: 0,
|
||||||
|
content_index: 0,
|
||||||
|
model: None,
|
||||||
|
},
|
||||||
|
ConversationEvent::ResponseDone {
|
||||||
|
usage: conversations::ConversationUsageInfo {
|
||||||
|
prompt_tokens: 10,
|
||||||
|
completion_tokens: 5,
|
||||||
|
total_tokens: 15,
|
||||||
|
},
|
||||||
|
created_at: None,
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
let resp = accumulate("conv-1", &events);
|
||||||
|
assert_eq!(resp.conversation_id, "conv-1");
|
||||||
|
assert_eq!(resp.assistant_text(), Some("hello world".into()));
|
||||||
|
assert_eq!(resp.usage.total_tokens, 15);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_accumulate_with_function_calls() {
|
||||||
|
let events = vec![
|
||||||
|
ConversationEvent::FunctionCall {
|
||||||
|
id: "fc-1".into(),
|
||||||
|
name: "search".into(),
|
||||||
|
tool_call_id: "tc-1".into(),
|
||||||
|
arguments: r#"{"q":"test"}"#.into(),
|
||||||
|
output_index: 0,
|
||||||
|
model: None,
|
||||||
|
confirmation_status: None,
|
||||||
|
},
|
||||||
|
ConversationEvent::ResponseDone {
|
||||||
|
usage: conversations::ConversationUsageInfo {
|
||||||
|
prompt_tokens: 20,
|
||||||
|
completion_tokens: 10,
|
||||||
|
total_tokens: 30,
|
||||||
|
},
|
||||||
|
created_at: None,
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
let resp = accumulate("conv-2", &events);
|
||||||
|
assert!(resp.assistant_text().is_none());
|
||||||
|
let calls = resp.function_calls();
|
||||||
|
assert_eq!(calls.len(), 1);
|
||||||
|
assert_eq!(calls[0].name, "search");
|
||||||
|
}
|
||||||
|
}
|
||||||
377
src/v1/conversations.rs
Normal file
377
src/v1/conversations.rs
Normal file
@@ -0,0 +1,377 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::v1::{agents, chat};
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// Conversations API (Beta)
|
||||||
|
// POST/GET/DELETE /v1/conversations
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Conversation entries (inputs and outputs)
|
||||||
|
// All entries share common fields: id, object, type, created_at, completed_at
|
||||||
|
|
||||||
|
/// Input entry — a message sent to the conversation.
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
pub struct MessageInputEntry {
|
||||||
|
pub role: String,
|
||||||
|
pub content: chat::ChatMessageContent,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub prefix: Option<bool>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub id: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub object: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub created_at: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub completed_at: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Output entry — an assistant message produced by the model.
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
pub struct MessageOutputEntry {
|
||||||
|
pub role: String,
|
||||||
|
pub content: chat::ChatMessageContent,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub id: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub object: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub model: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub created_at: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub completed_at: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A function call requested by the model.
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
pub struct FunctionCallEntry {
|
||||||
|
pub name: String,
|
||||||
|
pub arguments: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub id: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub object: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tool_call_id: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub created_at: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub completed_at: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Result of a function call, sent back to the conversation.
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
pub struct FunctionResultEntry {
|
||||||
|
pub tool_call_id: String,
|
||||||
|
pub result: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub id: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub object: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub created_at: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub completed_at: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A built-in tool execution (web_search, code_interpreter, etc.).
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
pub struct ToolExecutionEntry {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tool_name: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub id: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub object: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub created_at: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub completed_at: Option<String>,
|
||||||
|
#[serde(flatten)]
|
||||||
|
pub extra: serde_json::Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Agent handoff entry — transfer between agents.
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
pub struct AgentHandoffEntry {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub previous_agent_id: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub next_agent_id: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub id: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub object: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub created_at: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub completed_at: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Union of all conversation entry types.
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
#[serde(tag = "type")]
|
||||||
|
pub enum ConversationEntry {
|
||||||
|
#[serde(rename = "message.input")]
|
||||||
|
MessageInput(MessageInputEntry),
|
||||||
|
#[serde(rename = "message.output")]
|
||||||
|
MessageOutput(MessageOutputEntry),
|
||||||
|
#[serde(rename = "function.call")]
|
||||||
|
FunctionCall(FunctionCallEntry),
|
||||||
|
#[serde(rename = "function.result")]
|
||||||
|
FunctionResult(FunctionResultEntry),
|
||||||
|
#[serde(rename = "tool.execution")]
|
||||||
|
ToolExecution(ToolExecutionEntry),
|
||||||
|
#[serde(rename = "agent.handoff")]
|
||||||
|
AgentHandoff(AgentHandoffEntry),
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Conversation inputs (flexible: string or array of entries)
|
||||||
|
|
||||||
|
/// Conversation input: either a plain string or structured entry array.
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum ConversationInput {
|
||||||
|
Text(String),
|
||||||
|
Entries(Vec<ConversationEntry>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<&str> for ConversationInput {
|
||||||
|
fn from(s: &str) -> Self {
|
||||||
|
Self::Text(s.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<String> for ConversationInput {
|
||||||
|
fn from(s: String) -> Self {
|
||||||
|
Self::Text(s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<Vec<ConversationEntry>> for ConversationInput {
|
||||||
|
fn from(entries: Vec<ConversationEntry>) -> Self {
|
||||||
|
Self::Entries(entries)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Handoff execution mode
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
|
||||||
|
pub enum HandoffExecution {
|
||||||
|
#[serde(rename = "server")]
|
||||||
|
Server,
|
||||||
|
#[serde(rename = "client")]
|
||||||
|
Client,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for HandoffExecution {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::Server
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Create conversation request (POST /v1/conversations)
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
pub struct CreateConversationRequest {
|
||||||
|
pub inputs: ConversationInput,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub model: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub agent_id: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub agent_version: Option<serde_json::Value>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub name: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub description: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub instructions: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub completion_args: Option<agents::CompletionArgs>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tools: Option<Vec<agents::AgentTool>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub handoff_execution: Option<HandoffExecution>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub metadata: Option<serde_json::Value>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub store: Option<bool>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub stream: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Append to conversation request (POST /v1/conversations/{id})
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
pub struct AppendConversationRequest {
|
||||||
|
pub inputs: ConversationInput,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub completion_args: Option<agents::CompletionArgs>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub handoff_execution: Option<HandoffExecution>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub store: Option<bool>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tool_confirmations: Option<Vec<ToolCallConfirmation>>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub stream: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
pub struct ToolCallConfirmation {
|
||||||
|
pub tool_call_id: String,
|
||||||
|
pub result: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Restart conversation request (POST /v1/conversations/{id}/restart)
|
||||||
|
|
||||||
|
/// Request body for restarting a conversation from an earlier entry
/// (`POST /v1/conversations/{id}/restart`).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RestartConversationRequest {
    // Entry to restart the conversation from (required).
    pub from_entry_id: String,
    // Optional replacement inputs; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub inputs: Option<ConversationInput>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completion_args: Option<agents::CompletionArgs>,
    // Kept as raw JSON: the API's schema for this field is not pinned here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub agent_version: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub handoff_execution: Option<HandoffExecution>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub store: Option<bool>,
    // Defaults to `false` when absent from incoming JSON.
    #[serde(default)]
    pub stream: bool,
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Conversation response (returned by create, append, restart)
|
||||||
|
|
||||||
|
/// Token usage accounting attached to a conversation response.
/// Each counter falls back to zero when the API omits it.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ConversationUsageInfo {
    #[serde(default)]
    pub prompt_tokens: u32,
    #[serde(default)]
    pub completion_tokens: u32,
    #[serde(default)]
    pub total_tokens: u32,
}

/// Response returned by the conversation create, append, and restart endpoints.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ConversationResponse {
    pub conversation_id: String,
    // Entries produced by this request (messages, function calls, handoffs, ...).
    pub outputs: Vec<ConversationEntry>,
    pub usage: ConversationUsageInfo,
    // Empty string when the API omits the object tag.
    #[serde(default)]
    pub object: String,
    // Kept as raw JSON: guardrail payload schema is not pinned here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub guardrails: Option<serde_json::Value>,
}
|
||||||
|
|
||||||
|
impl ConversationResponse {
|
||||||
|
/// Extract the assistant's text response from the outputs, if any.
|
||||||
|
pub fn assistant_text(&self) -> Option<String> {
|
||||||
|
for entry in &self.outputs {
|
||||||
|
if let ConversationEntry::MessageOutput(msg) = entry {
|
||||||
|
return Some(msg.content.text());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extract all function call entries from the outputs.
|
||||||
|
pub fn function_calls(&self) -> Vec<&FunctionCallEntry> {
|
||||||
|
self.outputs
|
||||||
|
.iter()
|
||||||
|
.filter_map(|e| match e {
|
||||||
|
ConversationEntry::FunctionCall(fc) => Some(fc),
|
||||||
|
_ => None,
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if any outputs are agent handoff entries.
|
||||||
|
pub fn has_handoff(&self) -> bool {
|
||||||
|
self.outputs
|
||||||
|
.iter()
|
||||||
|
.any(|e| matches!(e, ConversationEntry::AgentHandoff(_)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Conversation history response (GET /v1/conversations/{id}/history)
|
||||||
|
|
||||||
|
/// Full entry history of a conversation
/// (`GET /v1/conversations/{id}/history`).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ConversationHistoryResponse {
    pub conversation_id: String,
    pub entries: Vec<ConversationEntry>,
    #[serde(default)]
    pub object: String,
}

/// Message view of a conversation
/// (`GET /v1/conversations/{id}/messages`).
/// Note: may have the same shape as the history response; kept separate for
/// API clarity.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ConversationMessagesResponse {
    pub conversation_id: String,
    // Accepts either key the API may use for this list.
    #[serde(alias = "messages", alias = "entries")]
    pub messages: Vec<ConversationEntry>,
    #[serde(default)]
    pub object: String,
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Conversation info (GET /v1/conversations/{id})
|
||||||
|
|
||||||
|
/// Metadata for a stored conversation (`GET /v1/conversations/{id}`).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Conversation {
    pub id: String,
    #[serde(default)]
    pub object: String,
    // Present when the conversation is bound to a pre-created agent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub agent_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub instructions: Option<String>,
    // Timestamps are kept as strings; format is not pinned here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub created_at: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub updated_at: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completion_args: Option<agents::CompletionArgs>,
    // Empty when the API omits the list.
    #[serde(default)]
    pub tools: Vec<agents::AgentTool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub guardrails: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}

/// List conversations response. The API returns a raw JSON array, hence the
/// `transparent` wrapper (de/serializes as the inner `Vec` directly).
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(transparent)]
pub struct ConversationListResponse {
    pub data: Vec<Conversation>,
}

/// Delete conversation response. The API returns 204 No Content, so this
/// carries no serde derives and is constructed client-side.
#[derive(Clone, Debug)]
pub struct ConversationDeleteResponse {
    pub deleted: bool,
}
|
||||||
@@ -2,52 +2,78 @@ use serde::{Deserialize, Serialize};
|
|||||||
|
|
||||||
use crate::v1::{common, constants};
|
use crate::v1::{common, constants};
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Request
|
||||||
|
|
||||||
/// Optional knobs for [`EmbeddingRequest::new`]; `None` fields are omitted
/// from the serialized request so the API's own defaults apply.
#[derive(Debug)]
pub struct EmbeddingRequestOptions {
    pub encoding_format: Option<EmbeddingRequestEncodingFormat>,
    // Requested dimensionality of the returned vectors.
    pub output_dimension: Option<u32>,
    // Requested numeric type of the returned values.
    pub output_dtype: Option<EmbeddingOutputDtype>,
}
|
||||||
impl Default for EmbeddingRequestOptions {
|
impl Default for EmbeddingRequestOptions {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
encoding_format: None,
|
encoding_format: None,
|
||||||
|
output_dimension: None,
|
||||||
|
output_dtype: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Request body for `POST /v1/embeddings`.
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddingRequest {
    pub model: constants::Model,
    // Texts to embed.
    pub input: Vec<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub encoding_format: Option<EmbeddingRequestEncodingFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_dimension: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_dtype: Option<EmbeddingOutputDtype>,
}
|
||||||
impl EmbeddingRequest {
|
impl EmbeddingRequest {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
model: constants::EmbedModel,
|
model: constants::Model,
|
||||||
input: Vec<String>,
|
input: Vec<String>,
|
||||||
options: Option<EmbeddingRequestOptions>,
|
options: Option<EmbeddingRequestOptions>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let EmbeddingRequestOptions { encoding_format } = options.unwrap_or_default();
|
let opts = options.unwrap_or_default();
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
model,
|
model,
|
||||||
input,
|
input,
|
||||||
encoding_format,
|
encoding_format: opts.encoding_format,
|
||||||
|
output_dimension: opts.output_dimension,
|
||||||
|
output_dtype: opts.output_dtype,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Wire encoding of the returned embedding vectors; serialized lowercase
/// (`"float"`, `"base64"`).
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum EmbeddingRequestEncodingFormat {
    Float,
    Base64,
}

/// Numeric type of the returned embedding values; serialized lowercase.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum EmbeddingOutputDtype {
    Float,
    Int8,
    Uint8,
    Binary,
    Ubinary,
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Response
|
||||||
|
|
||||||
/// Response body for `POST /v1/embeddings`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct EmbeddingResponse {
    pub object: String,
    pub model: constants::Model,
    // One item per input string.
    pub data: Vec<EmbeddingResponseDataItem>,
    pub usage: common::ResponseUsage,
}
|
||||||
|
|||||||
@@ -14,7 +14,9 @@ impl Error for ApiError {}
|
|||||||
|
|
||||||
#[derive(Debug, PartialEq, thiserror::Error)]
|
#[derive(Debug, PartialEq, thiserror::Error)]
|
||||||
pub enum ClientError {
|
pub enum ClientError {
|
||||||
#[error("You must either set the `MISTRAL_API_KEY` environment variable or specify it in `Client::new(api_key, ...).")]
|
#[error(
|
||||||
|
"You must either set the `MISTRAL_API_KEY` environment variable or specify it in `Client::new(api_key, ...)."
|
||||||
|
)]
|
||||||
MissingApiKey,
|
MissingApiKey,
|
||||||
#[error("Failed to read the response text.")]
|
#[error("Failed to read the response text.")]
|
||||||
UnreadableResponseText,
|
UnreadableResponseText,
|
||||||
|
|||||||
55
src/v1/files.rs
Normal file
55
src/v1/files.rs
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Request
|
||||||
|
|
||||||
|
/// Purpose tag attached to an uploaded file; serialized in kebab-case
/// (e.g. `FineTune` -> `"fine-tune"`).
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum FilePurpose {
    FineTune,
    Batch,
    Ocr,
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Response
|
||||||
|
|
||||||
|
/// Response for listing uploaded files.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct FileListResponse {
    pub data: Vec<FileObject>,
    pub object: String,
    // Zero when the API omits the total count.
    #[serde(default)]
    pub total: u32,
}

/// Metadata for a single uploaded file.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct FileObject {
    pub id: String,
    pub object: String,
    // File size in bytes.
    pub bytes: u64,
    // Unix timestamp; presumably seconds — TODO confirm against the API.
    pub created_at: u64,
    pub filename: String,
    // Free-form purpose string; see [`FilePurpose`] for the known values.
    pub purpose: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sample_type: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub source: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub num_lines: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub mimetype: Option<String>,
}

/// Response for deleting an uploaded file.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct FileDeleteResponse {
    pub id: String,
    pub object: String,
    pub deleted: bool,
}

/// Download URL for an uploaded file.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct FileUrlResponse {
    pub url: String,
    // Expiry time, when the API provides one.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub expires_at: Option<u64>,
}
|
||||||
101
src/v1/fim.rs
Normal file
101
src/v1/fim.rs
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::v1::{common, constants};
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Request
|
||||||
|
|
||||||
|
/// Optional parameters for a fill-in-the-middle request; `None` fields are
/// omitted from the serialized request so the API defaults apply.
#[derive(Debug)]
pub struct FimParams {
    // Text expected to come after the completion.
    pub suffix: Option<String>,
    pub max_tokens: Option<u32>,
    pub min_tokens: Option<u32>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
    // Stop sequences that end generation.
    pub stop: Option<Vec<String>>,
    pub random_seed: Option<u32>,
}
|
||||||
|
impl Default for FimParams {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
suffix: None,
|
||||||
|
max_tokens: None,
|
||||||
|
min_tokens: None,
|
||||||
|
temperature: None,
|
||||||
|
top_p: None,
|
||||||
|
stop: None,
|
||||||
|
random_seed: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Request body for the fill-in-the-middle completion endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub struct FimRequest {
    pub model: constants::Model,
    // Text that precedes the completion.
    pub prompt: String,
    // Whether the response should be streamed.
    pub stream: bool,

    // Optional parameters below mirror [`FimParams`]; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub random_seed: Option<u32>,
}
|
||||||
|
impl FimRequest {
|
||||||
|
pub fn new(
|
||||||
|
model: constants::Model,
|
||||||
|
prompt: String,
|
||||||
|
stream: bool,
|
||||||
|
options: Option<FimParams>,
|
||||||
|
) -> Self {
|
||||||
|
let opts = options.unwrap_or_default();
|
||||||
|
|
||||||
|
Self {
|
||||||
|
model,
|
||||||
|
prompt,
|
||||||
|
stream,
|
||||||
|
suffix: opts.suffix,
|
||||||
|
max_tokens: opts.max_tokens,
|
||||||
|
min_tokens: opts.min_tokens,
|
||||||
|
temperature: opts.temperature,
|
||||||
|
top_p: opts.top_p,
|
||||||
|
stop: opts.stop,
|
||||||
|
random_seed: opts.random_seed,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Response
|
||||||
|
|
||||||
|
/// Response body for a (non-streamed) fill-in-the-middle completion.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct FimResponse {
    pub id: String,
    pub object: String,
    // Unix timestamp; presumably seconds — TODO confirm against the API.
    pub created: u64,
    pub model: constants::Model,
    pub choices: Vec<FimResponseChoice>,
    pub usage: common::ResponseUsage,
}

/// One completion choice within a [`FimResponse`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct FimResponseChoice {
    pub index: u32,
    pub message: FimResponseMessage,
    pub finish_reason: String,
}

/// The generated message of a FIM choice.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct FimResponseMessage {
    pub role: String,
    pub content: String,
}
|
||||||
101
src/v1/fine_tuning.rs
Normal file
101
src/v1/fine_tuning.rs
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::v1::constants;
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Request
|
||||||
|
|
||||||
|
/// Request body for creating a fine-tuning job.
#[derive(Debug, Serialize, Deserialize)]
pub struct FineTuningJobRequest {
    // Base model to fine-tune.
    pub model: constants::Model,
    pub training_files: Vec<TrainingFile>,

    // Optional settings below are omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub validation_files: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hyperparameters: Option<Hyperparameters>,
    // Suffix appended to the fine-tuned model's name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub auto_start: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub job_type: Option<String>,
    // Third-party integrations (e.g. experiment tracking).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub integrations: Option<Vec<Integration>>,
}

/// A training file reference with an optional sampling weight.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TrainingFile {
    pub file_id: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub weight: Option<f32>,
}

/// Fine-tuning hyperparameters; unset fields defer to server defaults.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Hyperparameters {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub learning_rate: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub training_steps: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub warmup_fraction: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub epochs: Option<f64>,
}

/// Third-party integration attached to a fine-tuning job.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Integration {
    // Integration kind; `r#type` because `type` is a Rust keyword.
    pub r#type: String,
    pub project: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub api_key: Option<String>,
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Response
|
||||||
|
|
||||||
|
/// A fine-tuning job as returned by the jobs endpoints.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct FineTuningJobResponse {
    pub id: String,
    pub object: String,
    pub model: constants::Model,
    pub status: FineTuningJobStatus,
    // Unix timestamp; presumably seconds — TODO confirm against the API.
    pub created_at: u64,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub modified_at: Option<u64>,
    pub training_files: Vec<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub validation_files: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hyperparameters: Option<Hyperparameters>,
    // Name of the resulting model, once the job has produced one.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub fine_tuned_model: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub integrations: Option<Vec<Integration>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub trained_tokens: Option<u64>,
}

/// Lifecycle state of a fine-tuning job; serialized as
/// SCREAMING_SNAKE_CASE (e.g. `TimeoutExceeded` -> `"TIMEOUT_EXCEEDED"`).
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum FineTuningJobStatus {
    Queued,
    Running,
    Success,
    Failed,
    TimeoutExceeded,
    CancellationRequested,
    Cancelled,
}

/// Response for listing fine-tuning jobs.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct FineTuningJobListResponse {
    pub data: Vec<FineTuningJobResponse>,
    pub object: String,
    // Zero when the API omits the total count.
    #[serde(default)]
    pub total: u32,
}
|
||||||
@@ -1,7 +1,20 @@
|
|||||||
// Public modules of the v1 API surface, one per endpoint family.
pub mod agents;
pub mod audio;
pub mod batch;
pub mod chat;
pub mod chat_stream;
pub mod client;
pub mod common;
pub mod constants;
pub mod conversation_stream;
pub mod conversations;
pub mod embedding;
pub mod error;
pub mod files;
pub mod fim;
pub mod fine_tuning;
pub mod model_list;
pub mod moderation;
pub mod ocr;
pub mod tool;
pub mod utils;
|
||||||
|
|||||||
@@ -1,39 +1,58 @@
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Response
|
||||||
|
|
||||||
/// Response body for listing available models.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ModelListResponse {
    pub object: String,
    pub data: Vec<ModelListData>,
}
||||||
|
|
||||||
|
/// See the [models API](https://docs.mistral.ai/api/#tag/models/operation/list_models_v1_models_get).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ModelListData {
    pub id: String,
    pub object: String,
    /// Unix timestamp (in seconds).
    pub created: u64,
    pub owned_by: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub root: Option<String>,
    // `false` when the API omits the flag.
    #[serde(default)]
    pub archived: bool,
    #[serde(default)]
    pub name: Option<String>,
    #[serde(default)]
    pub description: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub capabilities: Option<ModelListDataCapabilities>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_context_length: Option<u32>,
    // Alternative names resolving to this model; empty when omitted.
    #[serde(default)]
    pub aliases: Vec<String>,
    /// ISO 8601 date (`YYYY-MM-DDTHH:MM:SSZ`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub deprecation: Option<String>,
}
||||||
|
|
||||||
/// Capability flags advertised per model; each defaults to `false` when
/// the API omits it.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ModelListDataCapabilities {
    #[serde(default)]
    pub completion_chat: bool,
    #[serde(default)]
    pub completion_fim: bool,
    #[serde(default)]
    pub function_calling: bool,
    #[serde(default)]
    pub fine_tuning: bool,
    #[serde(default)]
    pub vision: bool,
}
|
||||||
|
|
||||||
|
/// Response body for deleting a (fine-tuned) model.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ModelDeleteResponse {
    pub id: String,
    pub object: String,
    pub deleted: bool,
}
|
||||||
|
|||||||
70
src/v1/moderation.rs
Normal file
70
src/v1/moderation.rs
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::v1::constants;
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Request
|
||||||
|
|
||||||
|
/// Request body for moderating raw text inputs.
#[derive(Debug, Serialize, Deserialize)]
pub struct ModerationRequest {
    pub model: constants::Model,
    // One moderation result is produced per input string.
    pub input: Vec<String>,
}
|
||||||
|
impl ModerationRequest {
|
||||||
|
pub fn new(model: constants::Model, input: Vec<String>) -> Self {
|
||||||
|
Self { model, input }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Request body for moderating chat-style (role + content) inputs.
#[derive(Debug, Serialize, Deserialize)]
pub struct ChatModerationRequest {
    pub model: constants::Model,
    pub input: Vec<ChatModerationInput>,
}

/// One chat turn submitted for moderation or classification.
#[derive(Debug, Serialize, Deserialize)]
pub struct ChatModerationInput {
    pub role: String,
    pub content: String,
}

/// Request body for classifying raw text inputs.
#[derive(Debug, Serialize, Deserialize)]
pub struct ClassificationRequest {
    pub model: constants::Model,
    pub input: Vec<String>,
}

/// Request body for classifying chat-style inputs.
#[derive(Debug, Serialize, Deserialize)]
pub struct ChatClassificationRequest {
    pub model: constants::Model,
    pub input: Vec<ChatModerationInput>,
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Response
|
||||||
|
|
||||||
|
/// Response body for a moderation request.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ModerationResponse {
    pub id: String,
    pub model: constants::Model,
    // One result per submitted input.
    pub results: Vec<ModerationResult>,
}

/// Per-input moderation verdict. Kept as raw JSON because the category set
/// is not pinned by this client.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ModerationResult {
    pub categories: serde_json::Value,
    pub category_scores: serde_json::Value,
}

/// Response body for a classification request.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ClassificationResponse {
    pub id: String,
    pub model: constants::Model,
    pub results: Vec<ClassificationResult>,
}

/// Per-input classification verdict; raw JSON, same rationale as
/// [`ModerationResult`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ClassificationResult {
    pub categories: serde_json::Value,
    pub category_scores: serde_json::Value,
}
|
||||||
96
src/v1/ocr.rs
Normal file
96
src/v1/ocr.rs
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::v1::constants;
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Request
|
||||||
|
|
||||||
|
/// Request body for the OCR endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub struct OcrRequest {
    pub model: constants::Model,
    pub document: OcrDocument,

    // Optional settings below are omitted from JSON when `None`.
    // Page indices to process.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pages: Option<Vec<u32>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub table_format: Option<OcrTableFormat>,
    // Whether extracted images should be returned as base64 payloads.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub include_image_base64: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image_limit: Option<u32>,
}

/// Document reference for OCR: either a URL or an uploaded file ID,
/// discriminated by the `type` field. Use [`OcrDocument::from_url`] or
/// [`OcrDocument::from_file_id`] to construct a consistent value.
#[derive(Debug, Serialize, Deserialize)]
pub struct OcrDocument {
    // `"document_url"` or `"file_id"`; renamed because `type` is a keyword.
    #[serde(rename = "type")]
    pub type_: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub document_url: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub file_id: Option<String>,
}
|
||||||
|
impl OcrDocument {
|
||||||
|
pub fn from_url(url: &str) -> Self {
|
||||||
|
Self {
|
||||||
|
type_: "document_url".to_string(),
|
||||||
|
document_url: Some(url.to_string()),
|
||||||
|
file_id: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from_file_id(file_id: &str) -> Self {
|
||||||
|
Self {
|
||||||
|
type_: "file_id".to_string(),
|
||||||
|
document_url: None,
|
||||||
|
file_id: Some(file_id.to_string()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Output format for tables detected by OCR; serialized lowercase.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum OcrTableFormat {
    Markdown,
    Html,
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Response
|
||||||
|
|
||||||
|
/// Response body for an OCR request.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct OcrResponse {
    pub pages: Vec<OcrPage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage_info: Option<OcrUsageInfo>,
    pub model: constants::Model,
}

/// One OCR-processed page.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct OcrPage {
    pub index: u32,
    // Recognized page content rendered as Markdown.
    pub markdown: String,
    // Empty when the API omits the list.
    #[serde(default)]
    pub images: Vec<OcrImage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub dimensions: Option<OcrPageDimensions>,
}

/// Image extracted from a page; payload present only when base64 output
/// was requested.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct OcrImage {
    pub id: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image_base64: Option<String>,
}

/// Page dimensions; units are not documented here — presumably pixels.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct OcrPageDimensions {
    pub width: f32,
    pub height: f32,
}

/// Usage accounting for an OCR request.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct OcrUsageInfo {
    pub pages_processed: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub doc_size_bytes: Option<u64>,
}
|
||||||
91
src/v1/tool.rs
Normal file
91
src/v1/tool.rs
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
use async_trait::async_trait;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::{any::Any, fmt::Debug};
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Definitions
|
||||||
|
|
||||||
|
/// A tool invocation emitted by the model.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub struct ToolCall {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    // Call kind; `r#type` because `type` is a Rust keyword.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub r#type: Option<String>,
    pub function: ToolCallFunction,
}

/// The function the model wants called; `arguments` is a raw string
/// (the model serializes them itself).
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub struct ToolCallFunction {
    pub name: String,
    pub arguments: String,
}

/// A tool made available to the model. Use [`Tool::new`] for the common
/// function-tool case.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Tool {
    pub r#type: ToolType,
    pub function: ToolFunction,
}
|
||||||
|
impl Tool {
|
||||||
|
/// Create a tool with a JSON Schema parameters object.
|
||||||
|
pub fn new(
|
||||||
|
function_name: String,
|
||||||
|
function_description: String,
|
||||||
|
parameters: serde_json::Value,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
r#type: ToolType::Function,
|
||||||
|
function: ToolFunction {
|
||||||
|
name: function_name,
|
||||||
|
description: function_description,
|
||||||
|
parameters,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Request
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct ToolFunction {
|
||||||
|
pub name: String,
|
||||||
|
pub description: String,
|
||||||
|
pub parameters: serde_json::Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
||||||
|
pub enum ToolType {
|
||||||
|
#[serde(rename = "function")]
|
||||||
|
Function,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An enum representing how functions should be called.
|
||||||
|
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
||||||
|
pub enum ToolChoice {
|
||||||
|
/// The model is forced to call a function.
|
||||||
|
#[serde(rename = "any")]
|
||||||
|
Any,
|
||||||
|
/// The model can choose to either generate a message or call a function.
|
||||||
|
#[serde(rename = "auto")]
|
||||||
|
Auto,
|
||||||
|
/// The model won't call a function and will generate a message instead.
|
||||||
|
#[serde(rename = "none")]
|
||||||
|
None,
|
||||||
|
/// The model must call at least one tool.
|
||||||
|
#[serde(rename = "required")]
|
||||||
|
Required,
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Custom
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
pub trait Function: Send {
|
||||||
|
async fn execute(&self, arguments: String) -> Box<dyn Any + Send>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Debug for dyn Function {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(f, "Function()")
|
||||||
|
}
|
||||||
|
}
|
||||||
32
src/v1/utils.rs
Normal file
32
src/v1/utils.rs
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
use std::fmt::Debug;
|
||||||
|
|
||||||
|
use log::debug;
|
||||||
|
use serde::Serialize;
|
||||||
|
|
||||||
|
pub fn prettify_json_string(json: &String) -> String {
|
||||||
|
match serde_json::from_str::<serde_json::Value>(&json) {
|
||||||
|
Ok(json_value) => {
|
||||||
|
serde_json::to_string_pretty(&json_value).unwrap_or_else(|_| json.to_owned())
|
||||||
|
}
|
||||||
|
Err(_) => json.to_owned(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn prettify_json_struct<T: Debug + Serialize>(value: T) -> String {
|
||||||
|
match serde_json::to_string_pretty(&value) {
|
||||||
|
Ok(pretty_json) => pretty_json,
|
||||||
|
Err(_) => format!("{:?}", value),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn debug_pretty_json_from_string(label: &str, json: &String) -> () {
|
||||||
|
let pretty_json = prettify_json_string(json);
|
||||||
|
|
||||||
|
debug!("{label}: {}", pretty_json);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn debug_pretty_json_from_struct<T: Debug + Serialize>(label: &str, value: &T) -> () {
|
||||||
|
let pretty_json = prettify_json_struct(value);
|
||||||
|
|
||||||
|
debug!("{label}: {}", pretty_json);
|
||||||
|
}
|
||||||
3
tests/setup.rs
Normal file
3
tests/setup.rs
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
pub fn setup() {
|
||||||
|
let _ = env_logger::builder().is_test(true).try_init();
|
||||||
|
}
|
||||||
372
tests/v1_agents_api_test.rs
Normal file
372
tests/v1_agents_api_test.rs
Normal file
@@ -0,0 +1,372 @@
|
|||||||
|
use mistralai_client::v1::{
|
||||||
|
agents::*,
|
||||||
|
client::Client,
|
||||||
|
};
|
||||||
|
|
||||||
|
mod setup;
|
||||||
|
|
||||||
|
fn make_client() -> Client {
|
||||||
|
Client::new(None, None, None, None).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Sync tests
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_create_and_delete_agent() {
|
||||||
|
setup::setup();
|
||||||
|
let client = make_client();
|
||||||
|
|
||||||
|
let req = CreateAgentRequest {
|
||||||
|
model: "mistral-medium-latest".to_string(),
|
||||||
|
name: "test-create-delete".to_string(),
|
||||||
|
description: Some("Integration test agent".to_string()),
|
||||||
|
instructions: Some("You are a test agent. Respond briefly.".to_string()),
|
||||||
|
tools: None,
|
||||||
|
handoffs: None,
|
||||||
|
completion_args: None,
|
||||||
|
metadata: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let agent = client.create_agent(&req).unwrap();
|
||||||
|
assert!(!agent.id.is_empty());
|
||||||
|
assert_eq!(agent.name, "test-create-delete");
|
||||||
|
assert_eq!(agent.model, "mistral-medium-latest");
|
||||||
|
assert_eq!(agent.object, "agent");
|
||||||
|
// Version starts at 0 in the API
|
||||||
|
assert!(agent.description.as_deref() == Some("Integration test agent"));
|
||||||
|
assert!(agent.instructions.as_deref() == Some("You are a test agent. Respond briefly."));
|
||||||
|
|
||||||
|
// Cleanup
|
||||||
|
let del = client.delete_agent(&agent.id).unwrap();
|
||||||
|
assert!(del.deleted);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_create_agent_with_tools() {
|
||||||
|
setup::setup();
|
||||||
|
let client = make_client();
|
||||||
|
|
||||||
|
let req = CreateAgentRequest {
|
||||||
|
model: "mistral-medium-latest".to_string(),
|
||||||
|
name: "test-agent-tools".to_string(),
|
||||||
|
description: None,
|
||||||
|
instructions: Some("You can search.".to_string()),
|
||||||
|
tools: Some(vec![
|
||||||
|
AgentTool::function(
|
||||||
|
"search".to_string(),
|
||||||
|
"Search for things".to_string(),
|
||||||
|
serde_json::json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {"type": "string", "description": "Search query"}
|
||||||
|
},
|
||||||
|
"required": ["query"]
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
AgentTool::web_search(),
|
||||||
|
]),
|
||||||
|
handoffs: None,
|
||||||
|
completion_args: Some(CompletionArgs {
|
||||||
|
temperature: Some(0.3),
|
||||||
|
..Default::default()
|
||||||
|
}),
|
||||||
|
metadata: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let agent = client.create_agent(&req).unwrap();
|
||||||
|
assert_eq!(agent.tools.len(), 2);
|
||||||
|
assert!(matches!(&agent.tools[0], AgentTool::Function(_)));
|
||||||
|
assert!(matches!(&agent.tools[1], AgentTool::WebSearch {}));
|
||||||
|
|
||||||
|
// Verify completion_args round-tripped
|
||||||
|
let args = agent.completion_args.as_ref().unwrap();
|
||||||
|
assert!((args.temperature.unwrap() - 0.3).abs() < 0.01);
|
||||||
|
|
||||||
|
client.delete_agent(&agent.id).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_get_agent() {
|
||||||
|
setup::setup();
|
||||||
|
let client = make_client();
|
||||||
|
|
||||||
|
let req = CreateAgentRequest {
|
||||||
|
model: "mistral-medium-latest".to_string(),
|
||||||
|
name: "test-get-agent".to_string(),
|
||||||
|
description: Some("Get test".to_string()),
|
||||||
|
instructions: None,
|
||||||
|
tools: None,
|
||||||
|
handoffs: None,
|
||||||
|
completion_args: None,
|
||||||
|
metadata: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let created = client.create_agent(&req).unwrap();
|
||||||
|
let fetched = client.get_agent(&created.id).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(fetched.id, created.id);
|
||||||
|
assert_eq!(fetched.name, "test-get-agent");
|
||||||
|
assert_eq!(fetched.model, "mistral-medium-latest");
|
||||||
|
assert_eq!(fetched.description.as_deref(), Some("Get test"));
|
||||||
|
|
||||||
|
client.delete_agent(&created.id).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_update_agent() {
|
||||||
|
setup::setup();
|
||||||
|
let client = make_client();
|
||||||
|
|
||||||
|
let req = CreateAgentRequest {
|
||||||
|
model: "mistral-medium-latest".to_string(),
|
||||||
|
name: "test-update-agent".to_string(),
|
||||||
|
description: Some("Before update".to_string()),
|
||||||
|
instructions: Some("Original instructions".to_string()),
|
||||||
|
tools: None,
|
||||||
|
handoffs: None,
|
||||||
|
completion_args: None,
|
||||||
|
metadata: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let created = client.create_agent(&req).unwrap();
|
||||||
|
|
||||||
|
let update = UpdateAgentRequest {
|
||||||
|
name: Some("test-update-agent-renamed".to_string()),
|
||||||
|
description: Some("After update".to_string()),
|
||||||
|
instructions: Some("Updated instructions".to_string()),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let updated = client.update_agent(&created.id, &update).unwrap();
|
||||||
|
assert_eq!(updated.id, created.id);
|
||||||
|
assert_eq!(updated.name, "test-update-agent-renamed");
|
||||||
|
assert_eq!(updated.description.as_deref(), Some("After update"));
|
||||||
|
assert_eq!(updated.instructions.as_deref(), Some("Updated instructions"));
|
||||||
|
// Version should have incremented
|
||||||
|
assert!(updated.version >= created.version);
|
||||||
|
|
||||||
|
client.delete_agent(&created.id).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_list_agents() {
|
||||||
|
setup::setup();
|
||||||
|
let client = make_client();
|
||||||
|
|
||||||
|
// Create two agents
|
||||||
|
let req1 = CreateAgentRequest {
|
||||||
|
model: "mistral-medium-latest".to_string(),
|
||||||
|
name: "test-list-agent-1".to_string(),
|
||||||
|
description: None,
|
||||||
|
instructions: None,
|
||||||
|
tools: None,
|
||||||
|
handoffs: None,
|
||||||
|
completion_args: None,
|
||||||
|
metadata: None,
|
||||||
|
};
|
||||||
|
let req2 = CreateAgentRequest {
|
||||||
|
model: "mistral-medium-latest".to_string(),
|
||||||
|
name: "test-list-agent-2".to_string(),
|
||||||
|
description: None,
|
||||||
|
instructions: None,
|
||||||
|
tools: None,
|
||||||
|
handoffs: None,
|
||||||
|
completion_args: None,
|
||||||
|
metadata: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let a1 = client.create_agent(&req1).unwrap();
|
||||||
|
let a2 = client.create_agent(&req2).unwrap();
|
||||||
|
|
||||||
|
let list = client.list_agents().unwrap();
|
||||||
|
assert!(list.data.len() >= 2);
|
||||||
|
|
||||||
|
// Our two agents should be in the list
|
||||||
|
let ids: Vec<&str> = list.data.iter().map(|a| a.id.as_str()).collect();
|
||||||
|
assert!(ids.contains(&a1.id.as_str()));
|
||||||
|
assert!(ids.contains(&a2.id.as_str()));
|
||||||
|
|
||||||
|
// Cleanup
|
||||||
|
client.delete_agent(&a1.id).unwrap();
|
||||||
|
client.delete_agent(&a2.id).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Async tests
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_create_and_delete_agent_async() {
|
||||||
|
setup::setup();
|
||||||
|
let client = make_client();
|
||||||
|
|
||||||
|
let req = CreateAgentRequest {
|
||||||
|
model: "mistral-medium-latest".to_string(),
|
||||||
|
name: "test-async-create-delete".to_string(),
|
||||||
|
description: Some("Async integration test".to_string()),
|
||||||
|
instructions: Some("Respond briefly.".to_string()),
|
||||||
|
tools: None,
|
||||||
|
handoffs: None,
|
||||||
|
completion_args: None,
|
||||||
|
metadata: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let agent = client.create_agent_async(&req).await.unwrap();
|
||||||
|
assert!(!agent.id.is_empty());
|
||||||
|
assert_eq!(agent.name, "test-async-create-delete");
|
||||||
|
assert_eq!(agent.object, "agent");
|
||||||
|
|
||||||
|
let del = client.delete_agent_async(&agent.id).await.unwrap();
|
||||||
|
assert!(del.deleted);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_get_agent_async() {
|
||||||
|
setup::setup();
|
||||||
|
let client = make_client();
|
||||||
|
|
||||||
|
let req = CreateAgentRequest {
|
||||||
|
model: "mistral-medium-latest".to_string(),
|
||||||
|
name: "test-async-get".to_string(),
|
||||||
|
description: None,
|
||||||
|
instructions: None,
|
||||||
|
tools: None,
|
||||||
|
handoffs: None,
|
||||||
|
completion_args: None,
|
||||||
|
metadata: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let created = client.create_agent_async(&req).await.unwrap();
|
||||||
|
let fetched = client.get_agent_async(&created.id).await.unwrap();
|
||||||
|
assert_eq!(fetched.id, created.id);
|
||||||
|
assert_eq!(fetched.name, "test-async-get");
|
||||||
|
|
||||||
|
client.delete_agent_async(&created.id).await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_update_agent_async() {
|
||||||
|
setup::setup();
|
||||||
|
let client = make_client();
|
||||||
|
|
||||||
|
let req = CreateAgentRequest {
|
||||||
|
model: "mistral-medium-latest".to_string(),
|
||||||
|
name: "test-async-update".to_string(),
|
||||||
|
description: Some("Before".to_string()),
|
||||||
|
instructions: None,
|
||||||
|
tools: None,
|
||||||
|
handoffs: None,
|
||||||
|
completion_args: None,
|
||||||
|
metadata: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let created = client.create_agent_async(&req).await.unwrap();
|
||||||
|
|
||||||
|
let update = UpdateAgentRequest {
|
||||||
|
description: Some("After".to_string()),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let updated = client.update_agent_async(&created.id, &update).await.unwrap();
|
||||||
|
assert_eq!(updated.description.as_deref(), Some("After"));
|
||||||
|
|
||||||
|
client.delete_agent_async(&created.id).await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_list_agents_async() {
|
||||||
|
setup::setup();
|
||||||
|
let client = make_client();
|
||||||
|
|
||||||
|
let req = CreateAgentRequest {
|
||||||
|
model: "mistral-medium-latest".to_string(),
|
||||||
|
name: "test-async-list".to_string(),
|
||||||
|
description: None,
|
||||||
|
instructions: None,
|
||||||
|
tools: None,
|
||||||
|
handoffs: None,
|
||||||
|
completion_args: None,
|
||||||
|
metadata: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let agent = client.create_agent_async(&req).await.unwrap();
|
||||||
|
let list = client.list_agents_async().await.unwrap();
|
||||||
|
assert!(list.data.iter().any(|a| a.id == agent.id));
|
||||||
|
|
||||||
|
client.delete_agent_async(&agent.id).await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_create_agent_with_handoffs() {
|
||||||
|
setup::setup();
|
||||||
|
let client = make_client();
|
||||||
|
|
||||||
|
// Create a target agent first
|
||||||
|
let target_req = CreateAgentRequest {
|
||||||
|
model: "mistral-medium-latest".to_string(),
|
||||||
|
name: "test-handoff-target".to_string(),
|
||||||
|
description: Some("Target agent for handoff".to_string()),
|
||||||
|
instructions: Some("You handle math questions.".to_string()),
|
||||||
|
tools: None,
|
||||||
|
handoffs: None,
|
||||||
|
completion_args: None,
|
||||||
|
metadata: None,
|
||||||
|
};
|
||||||
|
let target = client.create_agent(&target_req).unwrap();
|
||||||
|
|
||||||
|
// Create orchestrator with handoff to target
|
||||||
|
let orch_req = CreateAgentRequest {
|
||||||
|
model: "mistral-medium-latest".to_string(),
|
||||||
|
name: "test-handoff-orchestrator".to_string(),
|
||||||
|
description: Some("Orchestrator with handoffs".to_string()),
|
||||||
|
instructions: Some("Delegate math questions.".to_string()),
|
||||||
|
tools: None,
|
||||||
|
handoffs: Some(vec![target.id.clone()]),
|
||||||
|
completion_args: None,
|
||||||
|
metadata: None,
|
||||||
|
};
|
||||||
|
let orch = client.create_agent(&orch_req).unwrap();
|
||||||
|
assert_eq!(orch.handoffs.as_ref().unwrap().len(), 1);
|
||||||
|
assert_eq!(orch.handoffs.as_ref().unwrap()[0], target.id);
|
||||||
|
|
||||||
|
// Cleanup
|
||||||
|
client.delete_agent(&orch.id).unwrap();
|
||||||
|
client.delete_agent(&target.id).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_agent_completion_with_created_agent() {
|
||||||
|
setup::setup();
|
||||||
|
let client = make_client();
|
||||||
|
|
||||||
|
let req = CreateAgentRequest {
|
||||||
|
model: "mistral-medium-latest".to_string(),
|
||||||
|
name: "test-completion-agent".to_string(),
|
||||||
|
description: None,
|
||||||
|
instructions: Some("Always respond with exactly the word 'pong'.".to_string()),
|
||||||
|
tools: None,
|
||||||
|
handoffs: None,
|
||||||
|
completion_args: Some(CompletionArgs {
|
||||||
|
temperature: Some(0.0),
|
||||||
|
..Default::default()
|
||||||
|
}),
|
||||||
|
metadata: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let agent = client.create_agent(&req).unwrap();
|
||||||
|
|
||||||
|
// Use the existing agent_completion method with the created agent
|
||||||
|
use mistralai_client::v1::chat::ChatMessage;
|
||||||
|
let messages = vec![ChatMessage::new_user_message("ping")];
|
||||||
|
let response = client
|
||||||
|
.agent_completion(agent.id.clone(), messages, None)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(!response.choices.is_empty());
|
||||||
|
let text = response.choices[0].message.content.text().to_lowercase();
|
||||||
|
assert!(text.contains("pong"), "Expected 'pong', got: {text}");
|
||||||
|
assert!(response.usage.total_tokens > 0);
|
||||||
|
|
||||||
|
client.delete_agent(&agent.id).unwrap();
|
||||||
|
}
|
||||||
119
tests/v1_agents_types_test.rs
Normal file
119
tests/v1_agents_types_test.rs
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
use mistralai_client::v1::agents::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_create_agent_request_serialization() {
|
||||||
|
let req = CreateAgentRequest {
|
||||||
|
model: "mistral-medium-latest".to_string(),
|
||||||
|
name: "sol-orchestrator".to_string(),
|
||||||
|
description: Some("Virtual librarian".to_string()),
|
||||||
|
instructions: Some("You are Sol.".to_string()),
|
||||||
|
tools: Some(vec![AgentTool::web_search()]),
|
||||||
|
handoffs: Some(vec!["agent_abc123".to_string()]),
|
||||||
|
completion_args: Some(CompletionArgs {
|
||||||
|
temperature: Some(0.3),
|
||||||
|
..Default::default()
|
||||||
|
}),
|
||||||
|
metadata: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let json = serde_json::to_value(&req).unwrap();
|
||||||
|
assert_eq!(json["model"], "mistral-medium-latest");
|
||||||
|
assert_eq!(json["name"], "sol-orchestrator");
|
||||||
|
assert_eq!(json["tools"][0]["type"], "web_search");
|
||||||
|
assert_eq!(json["handoffs"][0], "agent_abc123");
|
||||||
|
assert!(json["completion_args"]["temperature"].as_f64().unwrap() - 0.3 < 0.001);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_agent_response_deserialization() {
|
||||||
|
let json = serde_json::json!({
|
||||||
|
"id": "ag_abc123",
|
||||||
|
"object": "agent",
|
||||||
|
"name": "sol-orchestrator",
|
||||||
|
"model": "mistral-medium-latest",
|
||||||
|
"created_at": "2026-03-21T10:00:00Z",
|
||||||
|
"updated_at": "2026-03-21T10:00:00Z",
|
||||||
|
"version": 1,
|
||||||
|
"versions": [1],
|
||||||
|
"description": "Virtual librarian",
|
||||||
|
"instructions": "You are Sol.",
|
||||||
|
"tools": [
|
||||||
|
{"type": "function", "function": {"name": "search", "description": "Search", "parameters": {}}},
|
||||||
|
{"type": "web_search"},
|
||||||
|
{"type": "code_interpreter"}
|
||||||
|
],
|
||||||
|
"handoffs": ["ag_def456"],
|
||||||
|
"completion_args": {"temperature": 0.3, "response_format": {"type": "text"}}
|
||||||
|
});
|
||||||
|
|
||||||
|
let agent: Agent = serde_json::from_value(json).unwrap();
|
||||||
|
assert_eq!(agent.id, "ag_abc123");
|
||||||
|
assert_eq!(agent.name, "sol-orchestrator");
|
||||||
|
assert_eq!(agent.version, 1);
|
||||||
|
assert_eq!(agent.tools.len(), 3);
|
||||||
|
assert!(matches!(&agent.tools[0], AgentTool::Function(_)));
|
||||||
|
assert!(matches!(&agent.tools[1], AgentTool::WebSearch {}));
|
||||||
|
assert!(matches!(&agent.tools[2], AgentTool::CodeInterpreter {}));
|
||||||
|
assert_eq!(agent.handoffs.as_ref().unwrap()[0], "ag_def456");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_agent_tool_function_constructor() {
|
||||||
|
let tool = AgentTool::function(
|
||||||
|
"search_archive".to_string(),
|
||||||
|
"Search messages".to_string(),
|
||||||
|
serde_json::json!({"type": "object", "properties": {"query": {"type": "string"}}}),
|
||||||
|
);
|
||||||
|
|
||||||
|
let json = serde_json::to_value(&tool).unwrap();
|
||||||
|
assert_eq!(json["type"], "function");
|
||||||
|
assert_eq!(json["function"]["name"], "search_archive");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_completion_args_default_skips_none() {
|
||||||
|
let args = CompletionArgs::default();
|
||||||
|
let json = serde_json::to_value(&args).unwrap();
|
||||||
|
// All fields are None, so the JSON object should be empty
|
||||||
|
assert_eq!(json, serde_json::json!({}));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_agent_delete_response() {
|
||||||
|
// AgentDeleteResponse is not deserialized from JSON — the API returns 204 No Content.
|
||||||
|
// The client constructs it directly.
|
||||||
|
let resp = AgentDeleteResponse { deleted: true };
|
||||||
|
assert!(resp.deleted);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_agent_list_response() {
|
||||||
|
// API returns a raw JSON array (no wrapper object)
|
||||||
|
let json = serde_json::json!([
|
||||||
|
{
|
||||||
|
"id": "ag_1",
|
||||||
|
"object": "agent",
|
||||||
|
"name": "agent-1",
|
||||||
|
"model": "mistral-medium-latest",
|
||||||
|
"created_at": "2026-03-21T10:00:00Z",
|
||||||
|
"updated_at": "2026-03-21T10:00:00Z",
|
||||||
|
"version": 0,
|
||||||
|
"tools": []
|
||||||
|
}
|
||||||
|
]);
|
||||||
|
let resp: AgentListResponse = serde_json::from_value(json).unwrap();
|
||||||
|
assert_eq!(resp.data.len(), 1);
|
||||||
|
assert_eq!(resp.data[0].name, "agent-1");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_update_agent_partial() {
|
||||||
|
let req = UpdateAgentRequest {
|
||||||
|
instructions: Some("New instructions".to_string()),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let json = serde_json::to_value(&req).unwrap();
|
||||||
|
assert_eq!(json["instructions"], "New instructions");
|
||||||
|
assert!(json.get("model").is_none());
|
||||||
|
assert!(json.get("name").is_none());
|
||||||
|
}
|
||||||
156
tests/v1_chat_multimodal_api_test.rs
Normal file
156
tests/v1_chat_multimodal_api_test.rs
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
use mistralai_client::v1::{
|
||||||
|
chat::{
|
||||||
|
ChatMessage, ChatParams, ChatResponseChoiceFinishReason, ContentPart, ImageUrl,
|
||||||
|
},
|
||||||
|
client::Client,
|
||||||
|
constants::Model,
|
||||||
|
};
|
||||||
|
|
||||||
|
mod setup;
|
||||||
|
|
||||||
|
fn make_client() -> Client {
|
||||||
|
Client::new(None, None, None, None).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_multimodal_chat_with_image_url() {
|
||||||
|
setup::setup();
|
||||||
|
let client = make_client();
|
||||||
|
|
||||||
|
// Use a small, publicly accessible image
|
||||||
|
let msg = ChatMessage::new_user_message_with_images(vec![
|
||||||
|
ContentPart::Text {
|
||||||
|
text: "Describe this image in one sentence.".to_string(),
|
||||||
|
},
|
||||||
|
ContentPart::ImageUrl {
|
||||||
|
image_url: ImageUrl {
|
||||||
|
url: "https://picsum.photos/id/237/200/300".to_string(),
|
||||||
|
detail: None,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
|
||||||
|
let model = Model::new("pixtral-large-latest".to_string());
|
||||||
|
let options = ChatParams {
|
||||||
|
max_tokens: Some(100),
|
||||||
|
temperature: Some(0.0),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = client.chat(model, vec![msg], Some(options)).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
response.choices[0].finish_reason,
|
||||||
|
ChatResponseChoiceFinishReason::Stop
|
||||||
|
);
|
||||||
|
let text = response.choices[0].message.content.text();
|
||||||
|
assert!(!text.is_empty(), "Expected non-empty description");
|
||||||
|
assert!(response.usage.total_tokens > 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_multimodal_chat_with_image_url_async() {
|
||||||
|
setup::setup();
|
||||||
|
let client = make_client();
|
||||||
|
|
||||||
|
let msg = ChatMessage::new_user_message_with_images(vec![
|
||||||
|
ContentPart::Text {
|
||||||
|
text: "What colors do you see in this image? Reply in one sentence.".to_string(),
|
||||||
|
},
|
||||||
|
ContentPart::ImageUrl {
|
||||||
|
image_url: ImageUrl {
|
||||||
|
url: "https://picsum.photos/id/237/200/300".to_string(),
|
||||||
|
detail: None,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
|
||||||
|
let model = Model::new("pixtral-large-latest".to_string());
|
||||||
|
let options = ChatParams {
|
||||||
|
max_tokens: Some(100),
|
||||||
|
temperature: Some(0.0),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = client
|
||||||
|
.chat_async(model, vec![msg], Some(options))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let text = response.choices[0].message.content.text();
|
||||||
|
assert!(!text.is_empty(), "Expected non-empty description");
|
||||||
|
assert!(response.usage.total_tokens > 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_mixed_text_and_image_messages() {
|
||||||
|
setup::setup();
|
||||||
|
let client = make_client();
|
||||||
|
|
||||||
|
// First message: just text
|
||||||
|
let msg1 = ChatMessage::new_user_message("I'm going to show you an image next.");
|
||||||
|
|
||||||
|
// Second message: text + image
|
||||||
|
let msg2 = ChatMessage::new_user_message_with_images(vec![
|
||||||
|
ContentPart::Text {
|
||||||
|
text: "Here it is. What do you see?".to_string(),
|
||||||
|
},
|
||||||
|
ContentPart::ImageUrl {
|
||||||
|
image_url: ImageUrl {
|
||||||
|
url: "https://picsum.photos/id/237/200/300".to_string(),
|
||||||
|
detail: None,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
|
||||||
|
let model = Model::new("pixtral-large-latest".to_string());
|
||||||
|
let options = ChatParams {
|
||||||
|
max_tokens: Some(100),
|
||||||
|
temperature: Some(0.0),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = client.chat(model, vec![msg1, msg2], Some(options)).unwrap();
|
||||||
|
let text = response.choices[0].message.content.text();
|
||||||
|
assert!(!text.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_text_only_message_still_works() {
|
||||||
|
setup::setup();
|
||||||
|
let client = make_client();
|
||||||
|
|
||||||
|
// Verify that text-only messages (the common case) still work fine
|
||||||
|
// with the new ChatMessageContent type
|
||||||
|
let msg = ChatMessage::new_user_message("What is 7 + 8?");
|
||||||
|
let model = Model::mistral_small_latest();
|
||||||
|
let options = ChatParams {
|
||||||
|
temperature: Some(0.0),
|
||||||
|
max_tokens: Some(50),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = client.chat(model, vec![msg], Some(options)).unwrap();
|
||||||
|
let text = response.choices[0].message.content.text();
|
||||||
|
assert!(text.contains("15"), "Expected '15', got: {text}");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_reasoning_field_presence() {
|
||||||
|
setup::setup();
|
||||||
|
let client = make_client();
|
||||||
|
|
||||||
|
// Normal model should not have reasoning
|
||||||
|
let msg = ChatMessage::new_user_message("What is 2 + 2?");
|
||||||
|
let model = Model::mistral_small_latest();
|
||||||
|
let options = ChatParams {
|
||||||
|
temperature: Some(0.0),
|
||||||
|
max_tokens: Some(50),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = client.chat(model, vec![msg], Some(options)).unwrap();
|
||||||
|
// reasoning is None for non-Magistral models (or it might just be absent)
|
||||||
|
// This test verifies the field deserializes correctly either way
|
||||||
|
let _ = response.choices[0].reasoning.as_ref();
|
||||||
|
}
|
||||||
204
tests/v1_chat_multimodal_test.rs
Normal file
204
tests/v1_chat_multimodal_test.rs
Normal file
@@ -0,0 +1,204 @@
|
|||||||
|
use mistralai_client::v1::chat::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_content_part_text_serialization() {
|
||||||
|
let part = ContentPart::Text {
|
||||||
|
text: "hello".to_string(),
|
||||||
|
};
|
||||||
|
let json = serde_json::to_value(&part).unwrap();
|
||||||
|
assert_eq!(json["type"], "text");
|
||||||
|
assert_eq!(json["text"], "hello");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_content_part_image_url_serialization() {
|
||||||
|
let part = ContentPart::ImageUrl {
|
||||||
|
image_url: ImageUrl {
|
||||||
|
url: "https://example.com/image.png".to_string(),
|
||||||
|
detail: Some("high".to_string()),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
let json = serde_json::to_value(&part).unwrap();
|
||||||
|
assert_eq!(json["type"], "image_url");
|
||||||
|
assert_eq!(json["image_url"]["url"], "https://example.com/image.png");
|
||||||
|
assert_eq!(json["image_url"]["detail"], "high");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_content_part_image_url_no_detail() {
|
||||||
|
let part = ContentPart::ImageUrl {
|
||||||
|
image_url: ImageUrl {
|
||||||
|
url: "data:image/png;base64,abc123".to_string(),
|
||||||
|
detail: None,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
let json = serde_json::to_value(&part).unwrap();
|
||||||
|
assert_eq!(json["type"], "image_url");
|
||||||
|
assert!(json["image_url"].get("detail").is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_chat_message_content_text() {
|
||||||
|
let content = ChatMessageContent::Text("hello world".to_string());
|
||||||
|
assert_eq!(content.text(), "hello world");
|
||||||
|
assert_eq!(content.as_text(), Some("hello world"));
|
||||||
|
assert!(!content.has_images());
|
||||||
|
assert_eq!(content.to_string(), "hello world");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_chat_message_content_parts() {
|
||||||
|
let content = ChatMessageContent::Parts(vec![
|
||||||
|
ContentPart::Text {
|
||||||
|
text: "What is this? ".to_string(),
|
||||||
|
},
|
||||||
|
ContentPart::ImageUrl {
|
||||||
|
image_url: ImageUrl {
|
||||||
|
url: "https://example.com/cat.jpg".to_string(),
|
||||||
|
detail: None,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
|
||||||
|
assert_eq!(content.text(), "What is this? ");
|
||||||
|
assert!(content.as_text().is_none());
|
||||||
|
assert!(content.has_images());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_chat_message_content_text_serialization() {
|
||||||
|
let content = ChatMessageContent::Text("hello".to_string());
|
||||||
|
let json = serde_json::to_value(&content).unwrap();
|
||||||
|
assert_eq!(json, serde_json::json!("hello"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_chat_message_content_parts_serialization() {
|
||||||
|
let content = ChatMessageContent::Parts(vec![ContentPart::Text {
|
||||||
|
text: "hello".to_string(),
|
||||||
|
}]);
|
||||||
|
let json = serde_json::to_value(&content).unwrap();
|
||||||
|
assert!(json.is_array());
|
||||||
|
assert_eq!(json[0]["type"], "text");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_chat_message_content_text_deserialization() {
|
||||||
|
let content: ChatMessageContent = serde_json::from_value(serde_json::json!("hello")).unwrap();
|
||||||
|
assert_eq!(content.text(), "hello");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_chat_message_content_parts_deserialization() {
|
||||||
|
let content: ChatMessageContent = serde_json::from_value(serde_json::json!([
|
||||||
|
{"type": "text", "text": "describe this"},
|
||||||
|
{"type": "image_url", "image_url": {"url": "https://example.com/img.jpg"}}
|
||||||
|
]))
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(content.text(), "describe this");
|
||||||
|
assert!(content.has_images());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_new_user_message_text_content() {
|
||||||
|
let msg = ChatMessage::new_user_message("hello");
|
||||||
|
let json = serde_json::to_value(&msg).unwrap();
|
||||||
|
assert_eq!(json["role"], "user");
|
||||||
|
assert_eq!(json["content"], "hello");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_new_user_message_with_images() {
|
||||||
|
let msg = ChatMessage::new_user_message_with_images(vec![
|
||||||
|
ContentPart::Text {
|
||||||
|
text: "What is this?".to_string(),
|
||||||
|
},
|
||||||
|
ContentPart::ImageUrl {
|
||||||
|
image_url: ImageUrl {
|
||||||
|
url: "data:image/png;base64,abc123".to_string(),
|
||||||
|
detail: None,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
|
||||||
|
let json = serde_json::to_value(&msg).unwrap();
|
||||||
|
assert_eq!(json["role"], "user");
|
||||||
|
assert!(json["content"].is_array());
|
||||||
|
assert_eq!(json["content"][0]["type"], "text");
|
||||||
|
assert_eq!(json["content"][1]["type"], "image_url");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_chat_message_content_from_str() {
|
||||||
|
let content: ChatMessageContent = "test".into();
|
||||||
|
assert_eq!(content.text(), "test");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_chat_message_content_from_string() {
|
||||||
|
let content: ChatMessageContent = String::from("test").into();
|
||||||
|
assert_eq!(content.text(), "test");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_chat_response_choice_with_reasoning() {
|
||||||
|
let json = serde_json::json!({
|
||||||
|
"index": 0,
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "The answer is 42."
|
||||||
|
},
|
||||||
|
"finish_reason": "stop",
|
||||||
|
"reasoning": "Let me think about this step by step..."
|
||||||
|
});
|
||||||
|
|
||||||
|
let choice: ChatResponseChoice = serde_json::from_value(json).unwrap();
|
||||||
|
assert_eq!(choice.reasoning.as_deref(), Some("Let me think about this step by step..."));
|
||||||
|
assert_eq!(choice.message.content.text(), "The answer is 42.");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_chat_response_choice_without_reasoning() {
|
||||||
|
let json = serde_json::json!({
|
||||||
|
"index": 0,
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "Hello"
|
||||||
|
},
|
||||||
|
"finish_reason": "stop"
|
||||||
|
});
|
||||||
|
|
||||||
|
let choice: ChatResponseChoice = serde_json::from_value(json).unwrap();
|
||||||
|
assert!(choice.reasoning.is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_full_chat_response_roundtrip() {
|
||||||
|
let json = serde_json::json!({
|
||||||
|
"id": "chat-abc123",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": 1711000000,
|
||||||
|
"model": "mistral-medium-latest",
|
||||||
|
"choices": [{
|
||||||
|
"index": 0,
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "Hi there!"
|
||||||
|
},
|
||||||
|
"finish_reason": "stop"
|
||||||
|
}],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 10,
|
||||||
|
"completion_tokens": 5,
|
||||||
|
"total_tokens": 15
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let resp: ChatResponse = serde_json::from_value(json).unwrap();
|
||||||
|
assert_eq!(resp.choices[0].message.content.text(), "Hi there!");
|
||||||
|
assert_eq!(resp.usage.total_tokens, 15);
|
||||||
|
|
||||||
|
// Re-serialize and verify
|
||||||
|
let re_json = serde_json::to_value(&resp).unwrap();
|
||||||
|
assert_eq!(re_json["choices"][0]["message"]["content"], "Hi there!");
|
||||||
|
}
|
||||||
@@ -1,20 +1,24 @@
|
|||||||
use jrest::expect;
|
use jrest::expect;
|
||||||
use mistralai_client::v1::{
|
use mistralai_client::v1::{
|
||||||
chat_completion::{ChatCompletionParams, ChatMessage, ChatMessageRole},
|
chat::{ChatMessage, ChatMessageRole, ChatParams, ChatResponseChoiceFinishReason},
|
||||||
client::Client,
|
client::Client,
|
||||||
constants::Model,
|
constants::Model,
|
||||||
|
tool::{Tool, ToolChoice},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
mod setup;
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_client_chat_async() {
|
async fn test_client_chat_async() {
|
||||||
|
setup::setup();
|
||||||
|
|
||||||
let client = Client::new(None, None, None, None).unwrap();
|
let client = Client::new(None, None, None, None).unwrap();
|
||||||
|
|
||||||
let model = Model::OpenMistral7b;
|
let model = Model::mistral_small_latest();
|
||||||
let messages = vec![ChatMessage {
|
let messages = vec![ChatMessage::new_user_message(
|
||||||
role: ChatMessageRole::user,
|
"Guess the next word: \"Eiffel ...\"?",
|
||||||
content: "Just guess the next word: \"Eiffel ...\"?".to_string(),
|
)];
|
||||||
}];
|
let options = ChatParams {
|
||||||
let options = ChatCompletionParams {
|
|
||||||
temperature: Some(0.0),
|
temperature: Some(0.0),
|
||||||
random_seed: Some(42),
|
random_seed: Some(42),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
@@ -25,13 +29,72 @@ async fn test_client_chat_async() {
|
|||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
expect!(response.model).to_be(Model::OpenMistral7b);
|
|
||||||
expect!(response.object).to_be("chat.completion".to_string());
|
expect!(response.object).to_be("chat.completion".to_string());
|
||||||
|
|
||||||
expect!(response.choices.len()).to_be(1);
|
expect!(response.choices.len()).to_be(1);
|
||||||
expect!(response.choices[0].index).to_be(0);
|
expect!(response.choices[0].index).to_be(0);
|
||||||
expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::assistant);
|
expect!(response.choices[0].finish_reason.clone()).to_be(ChatResponseChoiceFinishReason::Stop);
|
||||||
expect!(response.choices[0].message.content.clone())
|
|
||||||
.to_be("Tower. The Eiffel Tower is a famous landmark in Paris, France.".to_string());
|
expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::Assistant);
|
||||||
|
expect!(response.choices[0]
|
||||||
|
.message
|
||||||
|
.content
|
||||||
|
.text()
|
||||||
|
.contains("Tower"))
|
||||||
|
.to_be(true);
|
||||||
|
|
||||||
|
expect!(response.usage.prompt_tokens).to_be_greater_than(0);
|
||||||
|
expect!(response.usage.completion_tokens).to_be_greater_than(0);
|
||||||
|
expect!(response.usage.total_tokens).to_be_greater_than(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_client_chat_async_with_function_calling() {
|
||||||
|
setup::setup();
|
||||||
|
|
||||||
|
let tools = vec![Tool::new(
|
||||||
|
"get_city_temperature".to_string(),
|
||||||
|
"Get the current temperature in a city.".to_string(),
|
||||||
|
serde_json::json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"city": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The name of the city."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["city"]
|
||||||
|
}),
|
||||||
|
)];
|
||||||
|
|
||||||
|
let client = Client::new(None, None, None, None).unwrap();
|
||||||
|
|
||||||
|
let model = Model::mistral_small_latest();
|
||||||
|
let messages = vec![ChatMessage::new_user_message(
|
||||||
|
"What's the current temperature in Paris?",
|
||||||
|
)];
|
||||||
|
let options = ChatParams {
|
||||||
|
temperature: Some(0.0),
|
||||||
|
random_seed: Some(42),
|
||||||
|
tool_choice: Some(ToolChoice::Any),
|
||||||
|
tools: Some(tools),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = client
|
||||||
|
.chat_async(model, messages, Some(options))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
expect!(response.object).to_be("chat.completion".to_string());
|
||||||
|
|
||||||
|
expect!(response.choices.len()).to_be(1);
|
||||||
|
expect!(response.choices[0].index).to_be(0);
|
||||||
|
expect!(response.choices[0].finish_reason.clone())
|
||||||
|
.to_be(ChatResponseChoiceFinishReason::ToolCalls);
|
||||||
|
|
||||||
|
expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::Assistant);
|
||||||
|
|
||||||
expect!(response.usage.prompt_tokens).to_be_greater_than(0);
|
expect!(response.usage.prompt_tokens).to_be_greater_than(0);
|
||||||
expect!(response.usage.completion_tokens).to_be_greater_than(0);
|
expect!(response.usage.completion_tokens).to_be_greater_than(0);
|
||||||
expect!(response.usage.total_tokens).to_be_greater_than(0);
|
expect!(response.usage.total_tokens).to_be_greater_than(0);
|
||||||
|
|||||||
@@ -1,40 +1,44 @@
|
|||||||
use futures::stream::StreamExt;
|
// Streaming tests require a live API key and are not run in CI.
|
||||||
use jrest::expect;
|
// Uncomment to test locally.
|
||||||
use mistralai_client::v1::{
|
|
||||||
chat_completion::{ChatCompletionParams, ChatMessage, ChatMessageRole},
|
|
||||||
client::Client,
|
|
||||||
constants::Model,
|
|
||||||
};
|
|
||||||
|
|
||||||
#[tokio::test]
|
// use futures::stream::StreamExt;
|
||||||
async fn test_client_chat_stream() {
|
// use mistralai_client::v1::{
|
||||||
let client = Client::new(None, None, None, None).unwrap();
|
// chat::{ChatMessage, ChatParams},
|
||||||
|
// client::Client,
|
||||||
let model = Model::OpenMistral7b;
|
// constants::Model,
|
||||||
let messages = vec![ChatMessage {
|
// };
|
||||||
role: ChatMessageRole::user,
|
//
|
||||||
content: "Just guess the next word: \"Eiffel ...\"?".to_string(),
|
// #[tokio::test]
|
||||||
}];
|
// async fn test_client_chat_stream() {
|
||||||
let options = ChatCompletionParams {
|
// let client = Client::new(None, None, None, None).unwrap();
|
||||||
temperature: Some(0.0),
|
//
|
||||||
random_seed: Some(42),
|
// let model = Model::mistral_small_latest();
|
||||||
..Default::default()
|
// let messages = vec![ChatMessage::new_user_message(
|
||||||
};
|
// "Just guess the next word: \"Eiffel ...\"?",
|
||||||
|
// )];
|
||||||
let stream_result = client.chat_stream(model, messages, Some(options)).await;
|
// let options = ChatParams {
|
||||||
let mut stream = stream_result.expect("Failed to create stream.");
|
// temperature: Some(0.0),
|
||||||
while let Some(chunk_result) = stream.next().await {
|
// random_seed: Some(42),
|
||||||
match chunk_result {
|
// ..Default::default()
|
||||||
Ok(chunk) => {
|
// };
|
||||||
if chunk.choices[0].delta.role == Some(ChatMessageRole::assistant)
|
//
|
||||||
|| chunk.choices[0].finish_reason == Some("stop".to_string())
|
// let stream = client
|
||||||
{
|
// .chat_stream(model, messages, Some(options))
|
||||||
expect!(chunk.choices[0].delta.content.len()).to_be(0);
|
// .await
|
||||||
} else {
|
// .expect("Failed to create stream.");
|
||||||
expect!(chunk.choices[0].delta.content.len()).to_be_greater_than(0);
|
//
|
||||||
}
|
// stream
|
||||||
}
|
// .for_each(|chunk_result| async {
|
||||||
Err(e) => eprintln!("Error processing chunk: {:?}", e),
|
// match chunk_result {
|
||||||
}
|
// Ok(chunks) => {
|
||||||
}
|
// for chunk in &chunks {
|
||||||
}
|
// if let Some(content) = &chunk.choices[0].delta.content {
|
||||||
|
// print!("{}", content);
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// Err(error) => eprintln!("Error: {:?}", error),
|
||||||
|
// }
|
||||||
|
// })
|
||||||
|
// .await;
|
||||||
|
// }
|
||||||
|
|||||||
@@ -1,20 +1,24 @@
|
|||||||
use jrest::expect;
|
use jrest::expect;
|
||||||
use mistralai_client::v1::{
|
use mistralai_client::v1::{
|
||||||
chat_completion::{ChatCompletionParams, ChatMessage, ChatMessageRole},
|
chat::{ChatMessage, ChatMessageRole, ChatParams, ChatResponseChoiceFinishReason},
|
||||||
client::Client,
|
client::Client,
|
||||||
constants::Model,
|
constants::Model,
|
||||||
|
tool::{Tool, ToolChoice},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
mod setup;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_client_chat() {
|
fn test_client_chat() {
|
||||||
|
setup::setup();
|
||||||
|
|
||||||
let client = Client::new(None, None, None, None).unwrap();
|
let client = Client::new(None, None, None, None).unwrap();
|
||||||
|
|
||||||
let model = Model::OpenMistral7b;
|
let model = Model::mistral_small_latest();
|
||||||
let messages = vec![ChatMessage {
|
let messages = vec![ChatMessage::new_user_message(
|
||||||
role: ChatMessageRole::user,
|
"Guess the next word: \"Eiffel ...\"?",
|
||||||
content: "Just guess the next word: \"Eiffel ...\"?".to_string(),
|
)];
|
||||||
}];
|
let options = ChatParams {
|
||||||
let options = ChatCompletionParams {
|
|
||||||
temperature: Some(0.0),
|
temperature: Some(0.0),
|
||||||
random_seed: Some(42),
|
random_seed: Some(42),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
@@ -22,13 +26,63 @@ fn test_client_chat() {
|
|||||||
|
|
||||||
let response = client.chat(model, messages, Some(options)).unwrap();
|
let response = client.chat(model, messages, Some(options)).unwrap();
|
||||||
|
|
||||||
expect!(response.model).to_be(Model::OpenMistral7b);
|
|
||||||
expect!(response.object).to_be("chat.completion".to_string());
|
expect!(response.object).to_be("chat.completion".to_string());
|
||||||
expect!(response.choices.len()).to_be(1);
|
expect!(response.choices.len()).to_be(1);
|
||||||
expect!(response.choices[0].index).to_be(0);
|
expect!(response.choices[0].index).to_be(0);
|
||||||
expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::assistant);
|
expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::Assistant);
|
||||||
expect!(response.choices[0].message.content.clone())
|
expect!(response.choices[0]
|
||||||
.to_be("Tower. The Eiffel Tower is a famous landmark in Paris, France.".to_string());
|
.message
|
||||||
|
.content
|
||||||
|
.text()
|
||||||
|
.contains("Tower"))
|
||||||
|
.to_be(true);
|
||||||
|
expect!(response.choices[0].finish_reason.clone()).to_be(ChatResponseChoiceFinishReason::Stop);
|
||||||
|
expect!(response.usage.prompt_tokens).to_be_greater_than(0);
|
||||||
|
expect!(response.usage.completion_tokens).to_be_greater_than(0);
|
||||||
|
expect!(response.usage.total_tokens).to_be_greater_than(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_client_chat_with_function_calling() {
|
||||||
|
setup::setup();
|
||||||
|
|
||||||
|
let tools = vec![Tool::new(
|
||||||
|
"get_city_temperature".to_string(),
|
||||||
|
"Get the current temperature in a city.".to_string(),
|
||||||
|
serde_json::json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"city": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The name of the city."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["city"]
|
||||||
|
}),
|
||||||
|
)];
|
||||||
|
|
||||||
|
let client = Client::new(None, None, None, None).unwrap();
|
||||||
|
|
||||||
|
let model = Model::mistral_small_latest();
|
||||||
|
let messages = vec![ChatMessage::new_user_message(
|
||||||
|
"What's the current temperature in Paris?",
|
||||||
|
)];
|
||||||
|
let options = ChatParams {
|
||||||
|
temperature: Some(0.0),
|
||||||
|
random_seed: Some(42),
|
||||||
|
tool_choice: Some(ToolChoice::Auto),
|
||||||
|
tools: Some(tools),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = client.chat(model, messages, Some(options)).unwrap();
|
||||||
|
|
||||||
|
expect!(response.object).to_be("chat.completion".to_string());
|
||||||
|
expect!(response.choices.len()).to_be(1);
|
||||||
|
expect!(response.choices[0].index).to_be(0);
|
||||||
|
expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::Assistant);
|
||||||
|
expect!(response.choices[0].finish_reason.clone())
|
||||||
|
.to_be(ChatResponseChoiceFinishReason::ToolCalls);
|
||||||
expect!(response.usage.prompt_tokens).to_be_greater_than(0);
|
expect!(response.usage.prompt_tokens).to_be_greater_than(0);
|
||||||
expect!(response.usage.completion_tokens).to_be_greater_than(0);
|
expect!(response.usage.completion_tokens).to_be_greater_than(0);
|
||||||
expect!(response.usage.total_tokens).to_be_greater_than(0);
|
expect!(response.usage.total_tokens).to_be_greater_than(0);
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
use jrest::expect;
|
use jrest::expect;
|
||||||
use mistralai_client::v1::{client::Client, constants::EmbedModel};
|
use mistralai_client::v1::{client::Client, constants::Model};
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_client_embeddings_async() {
|
async fn test_client_embeddings_async() {
|
||||||
let client: Client = Client::new(None, None, None, None).unwrap();
|
let client: Client = Client::new(None, None, None, None).unwrap();
|
||||||
|
|
||||||
let model = EmbedModel::MistralEmbed;
|
let model = Model::mistral_embed();
|
||||||
let input = vec!["Embed this sentence.", "As well as this one."]
|
let input = vec!["Embed this sentence.", "As well as this one."]
|
||||||
.iter()
|
.iter()
|
||||||
.map(|s| s.to_string())
|
.map(|s| s.to_string())
|
||||||
@@ -17,7 +17,6 @@ async fn test_client_embeddings_async() {
|
|||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
expect!(response.model).to_be(EmbedModel::MistralEmbed);
|
|
||||||
expect!(response.object).to_be("list".to_string());
|
expect!(response.object).to_be("list".to_string());
|
||||||
expect!(response.data.len()).to_be(2);
|
expect!(response.data.len()).to_be(2);
|
||||||
expect!(response.data[0].index).to_be(0);
|
expect!(response.data[0].index).to_be(0);
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
use jrest::expect;
|
use jrest::expect;
|
||||||
use mistralai_client::v1::{client::Client, constants::EmbedModel};
|
use mistralai_client::v1::{client::Client, constants::Model};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_client_embeddings() {
|
fn test_client_embeddings() {
|
||||||
let client: Client = Client::new(None, None, None, None).unwrap();
|
let client: Client = Client::new(None, None, None, None).unwrap();
|
||||||
|
|
||||||
let model = EmbedModel::MistralEmbed;
|
let model = Model::mistral_embed();
|
||||||
let input = vec!["Embed this sentence.", "As well as this one."]
|
let input = vec!["Embed this sentence.", "As well as this one."]
|
||||||
.iter()
|
.iter()
|
||||||
.map(|s| s.to_string())
|
.map(|s| s.to_string())
|
||||||
@@ -14,7 +14,6 @@ fn test_client_embeddings() {
|
|||||||
|
|
||||||
let response = client.embeddings(model, input, options).unwrap();
|
let response = client.embeddings(model, input, options).unwrap();
|
||||||
|
|
||||||
expect!(response.model).to_be(EmbedModel::MistralEmbed);
|
|
||||||
expect!(response.object).to_be("list".to_string());
|
expect!(response.object).to_be("list".to_string());
|
||||||
expect!(response.data.len()).to_be(2);
|
expect!(response.data.len()).to_be(2);
|
||||||
expect!(response.data[0].index).to_be(0);
|
expect!(response.data[0].index).to_be(0);
|
||||||
|
|||||||
@@ -9,12 +9,4 @@ async fn test_client_list_models_async() {
|
|||||||
|
|
||||||
expect!(response.object).to_be("list".to_string());
|
expect!(response.object).to_be("list".to_string());
|
||||||
expect!(response.data.len()).to_be_greater_than(0);
|
expect!(response.data.len()).to_be_greater_than(0);
|
||||||
|
|
||||||
// let open_mistral_7b_data_item = response
|
|
||||||
// .data
|
|
||||||
// .iter()
|
|
||||||
// .find(|item| item.id == "open-mistral-7b")
|
|
||||||
// .unwrap();
|
|
||||||
|
|
||||||
// expect!(open_mistral_7b_data_item.id).to_be("open-mistral-7b".to_string());
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,12 +9,4 @@ fn test_client_list_models() {
|
|||||||
|
|
||||||
expect!(response.object).to_be("list".to_string());
|
expect!(response.object).to_be("list".to_string());
|
||||||
expect!(response.data.len()).to_be_greater_than(0);
|
expect!(response.data.len()).to_be_greater_than(0);
|
||||||
|
|
||||||
// let open_mistral_7b_data_item = response
|
|
||||||
// .data
|
|
||||||
// .iter()
|
|
||||||
// .find(|item| item.id == "open-mistral-7b")
|
|
||||||
// .unwrap();
|
|
||||||
|
|
||||||
// expect!(open_mistral_7b_data_item.id).to_be("open-mistral-7b".to_string());
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,11 @@
|
|||||||
use jrest::expect;
|
use jrest::expect;
|
||||||
use mistralai_client::v1::{client::Client, error::ClientError};
|
use mistralai_client::v1::{client::Client, error::ClientError};
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct _Foo {
|
||||||
|
_client: Client,
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_client_new_with_none_params() {
|
fn test_client_new_with_none_params() {
|
||||||
let maybe_original_mistral_api_key = std::env::var("MISTRAL_API_KEY").ok();
|
let maybe_original_mistral_api_key = std::env::var("MISTRAL_API_KEY").ok();
|
||||||
|
|||||||
37
tests/v1_constants_test.rs
Normal file
37
tests/v1_constants_test.rs
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
use jrest::expect;
|
||||||
|
use mistralai_client::v1::{
|
||||||
|
chat::{ChatMessage, ChatParams},
|
||||||
|
client::Client,
|
||||||
|
constants::Model,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_model_constants() {
|
||||||
|
let models = vec![
|
||||||
|
Model::mistral_small_latest(),
|
||||||
|
Model::mistral_large_latest(),
|
||||||
|
Model::open_mistral_nemo(),
|
||||||
|
Model::codestral_latest(),
|
||||||
|
];
|
||||||
|
|
||||||
|
let client = Client::new(None, None, None, None).unwrap();
|
||||||
|
|
||||||
|
let messages = vec![ChatMessage::new_user_message("A number between 0 and 100?")];
|
||||||
|
let options = ChatParams {
|
||||||
|
temperature: Some(0.0),
|
||||||
|
random_seed: Some(42),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
for model in models {
|
||||||
|
let response = client
|
||||||
|
.chat(model.clone(), messages.clone(), Some(options.clone()))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
expect!(response.model).to_be(model);
|
||||||
|
expect!(response.object).to_be("chat.completion".to_string());
|
||||||
|
expect!(response.choices.len()).to_be(1);
|
||||||
|
expect!(response.choices[0].index).to_be(0);
|
||||||
|
expect!(response.choices[0].message.content.text().len()).to_be_greater_than(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
183
tests/v1_conversation_stream_test.rs
Normal file
183
tests/v1_conversation_stream_test.rs
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
use futures::StreamExt;
|
||||||
|
use mistralai_client::v1::{
|
||||||
|
client::Client,
|
||||||
|
conversation_stream::ConversationEvent,
|
||||||
|
conversations::*,
|
||||||
|
};
|
||||||
|
|
||||||
|
mod setup;
|
||||||
|
|
||||||
|
fn make_client() -> Client {
|
||||||
|
Client::new(None, None, None, None).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_create_conversation_stream() {
|
||||||
|
setup::setup();
|
||||||
|
let client = make_client();
|
||||||
|
|
||||||
|
let req = CreateConversationRequest {
|
||||||
|
inputs: ConversationInput::Text("What is 2 + 2? Answer in one word.".to_string()),
|
||||||
|
model: Some("mistral-medium-latest".to_string()),
|
||||||
|
agent_id: None,
|
||||||
|
agent_version: None,
|
||||||
|
name: None,
|
||||||
|
description: None,
|
||||||
|
instructions: Some("Respond concisely.".to_string()),
|
||||||
|
completion_args: None,
|
||||||
|
tools: None,
|
||||||
|
handoff_execution: None,
|
||||||
|
metadata: None,
|
||||||
|
store: Some(true),
|
||||||
|
stream: true,
|
||||||
|
};
|
||||||
|
|
||||||
|
let stream = client.create_conversation_stream_async(&req).await.unwrap();
|
||||||
|
tokio::pin!(stream);
|
||||||
|
|
||||||
|
let mut saw_started = false;
|
||||||
|
let mut saw_output = false;
|
||||||
|
let mut saw_done = false;
|
||||||
|
let mut full_text = String::new();
|
||||||
|
let mut conversation_id = String::new();
|
||||||
|
let mut usage_tokens = 0u32;
|
||||||
|
|
||||||
|
while let Some(result) = stream.next().await {
|
||||||
|
let event = result.unwrap();
|
||||||
|
match &event {
|
||||||
|
ConversationEvent::ResponseStarted { .. } => {
|
||||||
|
saw_started = true;
|
||||||
|
}
|
||||||
|
ConversationEvent::MessageOutput { .. } => {
|
||||||
|
saw_output = true;
|
||||||
|
if let Some(delta) = event.text_delta() {
|
||||||
|
full_text.push_str(&delta);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ConversationEvent::ResponseDone { usage, .. } => {
|
||||||
|
saw_done = true;
|
||||||
|
usage_tokens = usage.total_tokens;
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert!(saw_started, "Should receive ResponseStarted event");
|
||||||
|
assert!(saw_output, "Should receive at least one MessageOutput event");
|
||||||
|
assert!(saw_done, "Should receive ResponseDone event");
|
||||||
|
assert!(!full_text.is_empty(), "Should accumulate text from deltas");
|
||||||
|
assert!(usage_tokens > 0, "Should have token usage");
|
||||||
|
|
||||||
|
// Accumulate and verify
|
||||||
|
// (we can't accumulate from the consumed stream, but we verified the pieces above)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_append_conversation_stream() {
|
||||||
|
setup::setup();
|
||||||
|
let client = make_client();
|
||||||
|
|
||||||
|
// Create conversation (non-streaming) first
|
||||||
|
let create_req = CreateConversationRequest {
|
||||||
|
inputs: ConversationInput::Text("Remember: the secret word is BANANA.".to_string()),
|
||||||
|
model: Some("mistral-medium-latest".to_string()),
|
||||||
|
agent_id: None,
|
||||||
|
agent_version: None,
|
||||||
|
name: None,
|
||||||
|
description: None,
|
||||||
|
instructions: Some("Keep responses short.".to_string()),
|
||||||
|
completion_args: None,
|
||||||
|
tools: None,
|
||||||
|
handoff_execution: None,
|
||||||
|
metadata: None,
|
||||||
|
store: Some(true),
|
||||||
|
stream: false,
|
||||||
|
};
|
||||||
|
let created = client.create_conversation_async(&create_req).await.unwrap();
|
||||||
|
|
||||||
|
// Append with streaming
|
||||||
|
let append_req = AppendConversationRequest {
|
||||||
|
inputs: ConversationInput::Text("What is the secret word?".to_string()),
|
||||||
|
completion_args: None,
|
||||||
|
handoff_execution: None,
|
||||||
|
store: Some(true),
|
||||||
|
tool_confirmations: None,
|
||||||
|
stream: true,
|
||||||
|
};
|
||||||
|
|
||||||
|
let stream = client
|
||||||
|
.append_conversation_stream_async(&created.conversation_id, &append_req)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
tokio::pin!(stream);
|
||||||
|
|
||||||
|
let mut events = Vec::new();
|
||||||
|
while let Some(result) = stream.next().await {
|
||||||
|
events.push(result.unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should have started + output(s) + done
|
||||||
|
assert!(
|
||||||
|
events.iter().any(|e| matches!(e, ConversationEvent::ResponseStarted { .. })),
|
||||||
|
"Should have ResponseStarted"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
events.iter().any(|e| matches!(e, ConversationEvent::ResponseDone { .. })),
|
||||||
|
"Should have ResponseDone"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Accumulate and check the text
|
||||||
|
let resp = mistralai_client::v1::conversation_stream::accumulate(
|
||||||
|
&created.conversation_id,
|
||||||
|
&events,
|
||||||
|
);
|
||||||
|
let text = resp.assistant_text().unwrap_or_default().to_uppercase();
|
||||||
|
assert!(text.contains("BANANA"), "Should recall the secret word, got: {text}");
|
||||||
|
|
||||||
|
// Cleanup
|
||||||
|
client.delete_conversation_async(&created.conversation_id).await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_stream_accumulate_matches_non_stream() {
|
||||||
|
setup::setup();
|
||||||
|
let client = make_client();
|
||||||
|
|
||||||
|
let question = "What is the capital of Japan? One word.";
|
||||||
|
|
||||||
|
// Non-streaming
|
||||||
|
let req = CreateConversationRequest {
|
||||||
|
inputs: ConversationInput::Text(question.to_string()),
|
||||||
|
model: Some("mistral-medium-latest".to_string()),
|
||||||
|
agent_id: None,
|
||||||
|
agent_version: None,
|
||||||
|
name: None,
|
||||||
|
description: None,
|
||||||
|
instructions: Some("Answer in exactly one word.".to_string()),
|
||||||
|
completion_args: None,
|
||||||
|
tools: None,
|
||||||
|
handoff_execution: None,
|
||||||
|
metadata: None,
|
||||||
|
store: Some(false),
|
||||||
|
stream: false,
|
||||||
|
};
|
||||||
|
let non_stream = client.create_conversation_async(&req).await.unwrap();
|
||||||
|
let non_stream_text = non_stream.assistant_text().unwrap_or_default().to_lowercase();
|
||||||
|
|
||||||
|
// Streaming
|
||||||
|
let mut stream_req = req.clone();
|
||||||
|
stream_req.stream = true;
|
||||||
|
let stream = client.create_conversation_stream_async(&stream_req).await.unwrap();
|
||||||
|
tokio::pin!(stream);
|
||||||
|
|
||||||
|
let mut events = Vec::new();
|
||||||
|
while let Some(result) = stream.next().await {
|
||||||
|
events.push(result.unwrap());
|
||||||
|
}
|
||||||
|
let accumulated = mistralai_client::v1::conversation_stream::accumulate("", &events);
|
||||||
|
let stream_text = accumulated.assistant_text().unwrap_or_default().to_lowercase();
|
||||||
|
|
||||||
|
// Both should contain "tokyo"
|
||||||
|
assert!(non_stream_text.contains("tokyo"), "Non-stream should say Tokyo: {non_stream_text}");
|
||||||
|
assert!(stream_text.contains("tokyo"), "Stream should say Tokyo: {stream_text}");
|
||||||
|
}
|
||||||
642
tests/v1_conversations_api_test.rs
Normal file
642
tests/v1_conversations_api_test.rs
Normal file
@@ -0,0 +1,642 @@
|
|||||||
|
use mistralai_client::v1::{
|
||||||
|
agents::*,
|
||||||
|
client::Client,
|
||||||
|
conversations::*,
|
||||||
|
};
|
||||||
|
|
||||||
|
mod setup;
|
||||||
|
|
||||||
|
fn make_client() -> Client {
|
||||||
|
Client::new(None, None, None, None).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper: create a disposable agent for conversation tests (sync).
|
||||||
|
fn create_test_agent(client: &Client, name: &str) -> Agent {
|
||||||
|
let req = make_agent_request(name);
|
||||||
|
client.create_agent(&req).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper: create a disposable agent for conversation tests (async).
|
||||||
|
async fn create_test_agent_async(client: &Client, name: &str) -> Agent {
|
||||||
|
let req = make_agent_request(name);
|
||||||
|
client.create_agent_async(&req).await.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn make_agent_request(name: &str) -> CreateAgentRequest {
|
||||||
|
CreateAgentRequest {
|
||||||
|
model: "mistral-medium-latest".to_string(),
|
||||||
|
name: name.to_string(),
|
||||||
|
description: Some("Conversation test agent".to_string()),
|
||||||
|
instructions: Some("You are a helpful test agent. Keep responses short.".to_string()),
|
||||||
|
tools: None,
|
||||||
|
handoffs: None,
|
||||||
|
completion_args: Some(CompletionArgs {
|
||||||
|
temperature: Some(0.0),
|
||||||
|
..Default::default()
|
||||||
|
}),
|
||||||
|
metadata: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Sync tests
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
#[test]
fn test_create_conversation_with_agent() {
    setup::setup();
    let client = make_client();
    let agent = create_test_agent(&client, "conv-test-create");

    let request = CreateConversationRequest {
        inputs: ConversationInput::Text("What is 2 + 2?".to_string()),
        agent_id: Some(agent.id.clone()),
        agent_version: None,
        model: None,
        name: None,
        description: None,
        instructions: None,
        completion_args: None,
        tools: None,
        handoff_execution: None,
        metadata: None,
        store: None,
        stream: false,
    };

    let response = client.create_conversation(&request).unwrap();
    assert!(!response.conversation_id.is_empty());
    assert_eq!(response.object, "conversation.response");
    assert!(!response.outputs.is_empty());
    assert!(response.usage.total_tokens > 0);

    // The agent must have produced an assistant reply containing the answer.
    let text = response.assistant_text();
    assert!(text.is_some(), "Expected assistant text in outputs");
    assert!(text.unwrap().contains('4'), "Expected answer containing '4'");

    // Cleanup: remove both the conversation and its disposable agent.
    client.delete_conversation(&response.conversation_id).unwrap();
    client.delete_agent(&agent.id).unwrap();
}

#[test]
fn test_create_conversation_without_agent() {
    setup::setup();
    let client = make_client();

    // No agent: the conversation is driven by an explicit model + instructions.
    let request = CreateConversationRequest {
        inputs: ConversationInput::Text("Say hello.".to_string()),
        model: Some("mistral-medium-latest".to_string()),
        agent_id: None,
        agent_version: None,
        name: None,
        description: None,
        instructions: Some("Always respond with exactly 'hello'.".to_string()),
        completion_args: Some(CompletionArgs {
            temperature: Some(0.0),
            ..Default::default()
        }),
        tools: None,
        handoff_execution: None,
        metadata: None,
        store: None,
        stream: false,
    };

    let response = client.create_conversation(&request).unwrap();
    assert!(!response.conversation_id.is_empty());
    let text = response.assistant_text().unwrap().to_lowercase();
    assert!(text.contains("hello"), "Expected 'hello', got: {text}");

    client.delete_conversation(&response.conversation_id).unwrap();
}

#[test]
fn test_append_to_conversation() {
    setup::setup();
    let client = make_client();
    let agent = create_test_agent(&client, "conv-test-append");

    // First turn: give the model something to remember.
    let first_turn = CreateConversationRequest {
        inputs: ConversationInput::Text("Remember the number 42.".to_string()),
        agent_id: Some(agent.id.clone()),
        agent_version: None,
        model: None,
        name: None,
        description: None,
        instructions: None,
        completion_args: None,
        tools: None,
        handoff_execution: None,
        metadata: None,
        store: None,
        stream: false,
    };
    let created = client.create_conversation(&first_turn).unwrap();

    // Second turn: the follow-up must be answered from the earlier context.
    let follow_up = AppendConversationRequest {
        inputs: ConversationInput::Text("What number did I ask you to remember?".to_string()),
        completion_args: None,
        handoff_execution: None,
        store: None,
        tool_confirmations: None,
        stream: false,
    };
    let appended = client
        .append_conversation(&created.conversation_id, &follow_up)
        .unwrap();

    assert_eq!(appended.conversation_id, created.conversation_id);
    assert!(!appended.outputs.is_empty());
    let text = appended.assistant_text().unwrap();
    assert!(text.contains("42"), "Expected '42' in response, got: {text}");
    assert!(appended.usage.total_tokens > 0);

    client.delete_conversation(&created.conversation_id).unwrap();
    client.delete_agent(&agent.id).unwrap();
}

#[test]
fn test_get_conversation_info() {
    setup::setup();
    let client = make_client();
    let agent = create_test_agent(&client, "conv-test-get-info");

    let request = CreateConversationRequest {
        inputs: ConversationInput::Text("Hello.".to_string()),
        agent_id: Some(agent.id.clone()),
        agent_version: None,
        model: None,
        name: None,
        description: None,
        instructions: None,
        completion_args: None,
        tools: None,
        handoff_execution: None,
        metadata: None,
        store: None,
        stream: false,
    };
    let created = client.create_conversation(&request).unwrap();

    // The info endpoint must echo back both the id and the owning agent.
    let info = client.get_conversation(&created.conversation_id).unwrap();
    assert_eq!(info.id, created.conversation_id);
    assert_eq!(info.agent_id.as_deref(), Some(agent.id.as_str()));

    client.delete_conversation(&created.conversation_id).unwrap();
    client.delete_agent(&agent.id).unwrap();
}
|
||||||
|
|
||||||
|
#[test]
fn test_get_conversation_history() {
    setup::setup();
    let client = make_client();
    let agent = create_test_agent(&client, "conv-test-history");

    // Two full turns so the history has several entries to inspect.
    let first_turn = CreateConversationRequest {
        inputs: ConversationInput::Text("First message.".to_string()),
        agent_id: Some(agent.id.clone()),
        agent_version: None,
        model: None,
        name: None,
        description: None,
        instructions: None,
        completion_args: None,
        tools: None,
        handoff_execution: None,
        metadata: None,
        store: None,
        stream: false,
    };
    let created = client.create_conversation(&first_turn).unwrap();

    let second_turn = AppendConversationRequest {
        inputs: ConversationInput::Text("Second message.".to_string()),
        completion_args: None,
        handoff_execution: None,
        store: None,
        tool_confirmations: None,
        stream: false,
    };
    client
        .append_conversation(&created.conversation_id, &second_turn)
        .unwrap();

    // Two turns => at least 4 entries (user, assistant, user, assistant).
    let history = client
        .get_conversation_history(&created.conversation_id)
        .unwrap();
    assert_eq!(history.conversation_id, created.conversation_id);
    assert_eq!(history.object, "conversation.history");
    assert!(
        history.entries.len() >= 4,
        "Expected >= 4 history entries, got {}",
        history.entries.len()
    );

    // The conversation opened with a user message, so entry 0 is an input.
    assert!(matches!(
        &history.entries[0],
        ConversationEntry::MessageInput(_)
    ));

    client.delete_conversation(&created.conversation_id).unwrap();
    client.delete_agent(&agent.id).unwrap();
}

#[test]
fn test_get_conversation_messages() {
    setup::setup();
    let client = make_client();
    let agent = create_test_agent(&client, "conv-test-messages");

    let request = CreateConversationRequest {
        inputs: ConversationInput::Text("Hello there.".to_string()),
        agent_id: Some(agent.id.clone()),
        agent_version: None,
        model: None,
        name: None,
        description: None,
        instructions: None,
        completion_args: None,
        tools: None,
        handoff_execution: None,
        metadata: None,
        store: None,
        stream: false,
    };
    let created = client.create_conversation(&request).unwrap();

    let messages = client
        .get_conversation_messages(&created.conversation_id)
        .unwrap();
    assert_eq!(messages.conversation_id, created.conversation_id);
    assert!(!messages.messages.is_empty());

    client.delete_conversation(&created.conversation_id).unwrap();
    client.delete_agent(&agent.id).unwrap();
}

#[test]
fn test_list_conversations() {
    setup::setup();
    let client = make_client();

    let request = CreateConversationRequest {
        inputs: ConversationInput::Text("List test.".to_string()),
        model: Some("mistral-medium-latest".to_string()),
        agent_id: None,
        agent_version: None,
        name: None,
        description: None,
        instructions: None,
        completion_args: None,
        tools: None,
        handoff_execution: None,
        metadata: None,
        store: None,
        stream: false,
    };
    let created = client.create_conversation(&request).unwrap();

    // API returns raw array (no wrapper object)
    let list = client.list_conversations().unwrap();
    assert!(list.data.iter().any(|c| c.id == created.conversation_id));

    client.delete_conversation(&created.conversation_id).unwrap();
}

#[test]
fn test_delete_conversation() {
    setup::setup();
    let client = make_client();

    let request = CreateConversationRequest {
        inputs: ConversationInput::Text("To be deleted.".to_string()),
        model: Some("mistral-medium-latest".to_string()),
        agent_id: None,
        agent_version: None,
        name: None,
        description: None,
        instructions: None,
        completion_args: None,
        tools: None,
        handoff_execution: None,
        metadata: None,
        store: None,
        stream: false,
    };
    let created = client.create_conversation(&request).unwrap();

    let deletion = client.delete_conversation(&created.conversation_id).unwrap();
    assert!(deletion.deleted);

    // A deleted conversation must no longer appear in the listing.
    let list = client.list_conversations().unwrap();
    assert!(!list.data.iter().any(|c| c.id == created.conversation_id));
}
|
||||||
|
|
||||||
|
#[test]
fn test_conversation_with_structured_entries() {
    setup::setup();
    let client = make_client();

    use mistralai_client::v1::chat::ChatMessageContent;

    // Open the conversation with an explicit entry list instead of plain text.
    let entries = vec![ConversationEntry::MessageInput(MessageInputEntry {
        role: "user".to_string(),
        content: ChatMessageContent::Text("What is the capital of France?".to_string()),
        prefix: None,
        id: None,
        object: None,
        created_at: None,
        completed_at: None,
    })];

    let request = CreateConversationRequest {
        inputs: ConversationInput::Entries(entries),
        model: Some("mistral-medium-latest".to_string()),
        agent_id: None,
        agent_version: None,
        name: None,
        description: None,
        instructions: Some("Respond in one word.".to_string()),
        completion_args: Some(CompletionArgs {
            temperature: Some(0.0),
            ..Default::default()
        }),
        tools: None,
        handoff_execution: None,
        metadata: None,
        store: None,
        stream: false,
    };

    let response = client.create_conversation(&request).unwrap();
    let text = response.assistant_text().unwrap().to_lowercase();
    assert!(text.contains("paris"), "Expected 'Paris', got: {text}");

    client.delete_conversation(&response.conversation_id).unwrap();
}

#[test]
fn test_conversation_with_function_calling() {
    setup::setup();
    let client = make_client();

    // Agent equipped with a single function tool.
    let agent_request = CreateAgentRequest {
        model: "mistral-medium-latest".to_string(),
        name: "conv-test-function".to_string(),
        description: None,
        instructions: Some("When asked about temperature, use the get_temperature tool.".to_string()),
        tools: Some(vec![AgentTool::function(
            "get_temperature".to_string(),
            "Get the current temperature in a city".to_string(),
            serde_json::json!({
                "type": "object",
                "properties": {
                    "city": {"type": "string", "description": "City name"}
                },
                "required": ["city"]
            }),
        )]),
        handoffs: None,
        completion_args: Some(CompletionArgs {
            temperature: Some(0.0),
            ..Default::default()
        }),
        metadata: None,
    };
    let agent = client.create_agent(&agent_request).unwrap();

    // Ask a question that should trigger the tool.
    let conversation_request = CreateConversationRequest {
        inputs: ConversationInput::Text("What is the temperature in Paris?".to_string()),
        agent_id: Some(agent.id.clone()),
        agent_version: None,
        model: None,
        name: None,
        description: None,
        instructions: None,
        completion_args: None,
        tools: None,
        handoff_execution: Some(HandoffExecution::Client),
        metadata: None,
        store: None,
        stream: false,
    };
    let response = client.create_conversation(&conversation_request).unwrap();

    // With client-side execution we expect a function call in the outputs.
    let function_calls = response.function_calls();
    if !function_calls.is_empty() {
        assert_eq!(function_calls[0].name, "get_temperature");
        let args: serde_json::Value =
            serde_json::from_str(&function_calls[0].arguments).unwrap();
        assert!(args["city"].as_str().is_some());

        // Feed the function result back into the conversation.
        let tool_call_id = function_calls[0]
            .tool_call_id
            .as_deref()
            .unwrap_or("unknown");

        let result_entries = vec![ConversationEntry::FunctionResult(FunctionResultEntry {
            tool_call_id: tool_call_id.to_string(),
            result: "22°C".to_string(),
            id: None,
            object: None,
            created_at: None,
            completed_at: None,
        })];

        let append_request = AppendConversationRequest {
            inputs: ConversationInput::Entries(result_entries),
            completion_args: None,
            handoff_execution: None,
            store: None,
            tool_confirmations: None,
            stream: false,
        };
        let final_response = client
            .append_conversation(&response.conversation_id, &append_request)
            .unwrap();

        // After the tool result, the model should answer in plain text.
        let text = final_response.assistant_text();
        assert!(text.is_some(), "Expected final text after function result");
        assert!(
            text.unwrap().contains("22"),
            "Expected temperature in response"
        );
    }
    // If the API handled it server-side instead, we should still have a response
    else {
        assert!(
            response.assistant_text().is_some(),
            "Expected either function calls or assistant text"
        );
    }

    client.delete_conversation(&response.conversation_id).unwrap();
    client.delete_agent(&agent.id).unwrap();
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Async tests
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_create_conversation_async() {
|
||||||
|
setup::setup();
|
||||||
|
let client = make_client();
|
||||||
|
let agent = create_test_agent_async(&client, "conv-async-create").await;
|
||||||
|
|
||||||
|
let req = CreateConversationRequest {
|
||||||
|
inputs: ConversationInput::Text("Async test: what is 3 + 3?".to_string()),
|
||||||
|
model: None,
|
||||||
|
agent_id: Some(agent.id.clone()),
|
||||||
|
agent_version: None,
|
||||||
|
name: None,
|
||||||
|
description: None,
|
||||||
|
instructions: None,
|
||||||
|
completion_args: None,
|
||||||
|
tools: None,
|
||||||
|
handoff_execution: None,
|
||||||
|
metadata: None,
|
||||||
|
store: None,
|
||||||
|
stream: false,
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = client.create_conversation_async(&req).await.unwrap();
|
||||||
|
assert!(!response.conversation_id.is_empty());
|
||||||
|
let text = response.assistant_text().unwrap();
|
||||||
|
assert!(text.contains('6'), "Expected '6', got: {text}");
|
||||||
|
|
||||||
|
client
|
||||||
|
.delete_conversation_async(&response.conversation_id)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
client.delete_agent_async(&agent.id).await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_append_conversation_async() {
|
||||||
|
setup::setup();
|
||||||
|
let client = make_client();
|
||||||
|
let agent = create_test_agent_async(&client, "conv-async-append").await;
|
||||||
|
|
||||||
|
let create_req = CreateConversationRequest {
|
||||||
|
inputs: ConversationInput::Text("My name is Alice.".to_string()),
|
||||||
|
model: None,
|
||||||
|
agent_id: Some(agent.id.clone()),
|
||||||
|
agent_version: None,
|
||||||
|
name: None,
|
||||||
|
description: None,
|
||||||
|
instructions: None,
|
||||||
|
completion_args: None,
|
||||||
|
tools: None,
|
||||||
|
handoff_execution: None,
|
||||||
|
metadata: None,
|
||||||
|
store: None,
|
||||||
|
stream: false,
|
||||||
|
};
|
||||||
|
let created = client.create_conversation_async(&create_req).await.unwrap();
|
||||||
|
|
||||||
|
let append_req = AppendConversationRequest {
|
||||||
|
inputs: ConversationInput::Text("What is my name?".to_string()),
|
||||||
|
completion_args: None,
|
||||||
|
handoff_execution: None,
|
||||||
|
store: None,
|
||||||
|
tool_confirmations: None,
|
||||||
|
stream: false,
|
||||||
|
};
|
||||||
|
let appended = client
|
||||||
|
.append_conversation_async(&created.conversation_id, &append_req)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let text = appended.assistant_text().unwrap();
|
||||||
|
assert!(
|
||||||
|
text.to_lowercase().contains("alice"),
|
||||||
|
"Expected 'Alice' in response, got: {text}"
|
||||||
|
);
|
||||||
|
|
||||||
|
client
|
||||||
|
.delete_conversation_async(&created.conversation_id)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
client.delete_agent_async(&agent.id).await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_get_conversation_history_async() {
|
||||||
|
setup::setup();
|
||||||
|
let client = make_client();
|
||||||
|
let agent = create_test_agent_async(&client, "conv-async-history").await;
|
||||||
|
|
||||||
|
let create_req = CreateConversationRequest {
|
||||||
|
inputs: ConversationInput::Text("Hello.".to_string()),
|
||||||
|
model: None,
|
||||||
|
agent_id: Some(agent.id.clone()),
|
||||||
|
agent_version: None,
|
||||||
|
name: None,
|
||||||
|
description: None,
|
||||||
|
instructions: None,
|
||||||
|
completion_args: None,
|
||||||
|
tools: None,
|
||||||
|
handoff_execution: None,
|
||||||
|
metadata: None,
|
||||||
|
store: None,
|
||||||
|
stream: false,
|
||||||
|
};
|
||||||
|
let created = client.create_conversation_async(&create_req).await.unwrap();
|
||||||
|
|
||||||
|
let history = client
|
||||||
|
.get_conversation_history_async(&created.conversation_id)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert!(history.entries.len() >= 2); // at least user + assistant
|
||||||
|
|
||||||
|
client
|
||||||
|
.delete_conversation_async(&created.conversation_id)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
client.delete_agent_async(&agent.id).await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_list_conversations_async() {
|
||||||
|
setup::setup();
|
||||||
|
let client = make_client();
|
||||||
|
|
||||||
|
let req = CreateConversationRequest {
|
||||||
|
inputs: ConversationInput::Text("Async list test.".to_string()),
|
||||||
|
model: Some("mistral-medium-latest".to_string()),
|
||||||
|
agent_id: None,
|
||||||
|
agent_version: None,
|
||||||
|
name: None,
|
||||||
|
description: None,
|
||||||
|
instructions: None,
|
||||||
|
completion_args: None,
|
||||||
|
tools: None,
|
||||||
|
handoff_execution: None,
|
||||||
|
metadata: None,
|
||||||
|
store: None,
|
||||||
|
stream: false,
|
||||||
|
};
|
||||||
|
let created = client.create_conversation_async(&req).await.unwrap();
|
||||||
|
|
||||||
|
let list = client.list_conversations_async().await.unwrap();
|
||||||
|
assert!(list.data.iter().any(|c| c.id == created.conversation_id));
|
||||||
|
|
||||||
|
client
|
||||||
|
.delete_conversation_async(&created.conversation_id)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
226
tests/v1_conversations_types_test.rs
Normal file
226
tests/v1_conversations_types_test.rs
Normal file
@@ -0,0 +1,226 @@
|
|||||||
|
use mistralai_client::v1::chat::ChatMessageContent;
|
||||||
|
use mistralai_client::v1::conversations::*;
|
||||||
|
|
||||||
|
#[test]
fn test_conversation_input_from_string() {
    // A plain &str converts into the text variant and serializes as a bare string.
    let input: ConversationInput = "hello".into();
    let serialized = serde_json::to_value(&input).unwrap();
    assert_eq!(serialized, serde_json::json!("hello"));
}

#[test]
fn test_conversation_input_from_entries() {
    // A Vec of entries converts into the structured variant and serializes
    // as a tagged JSON array.
    let entries = vec![ConversationEntry::MessageInput(MessageInputEntry {
        role: "user".to_string(),
        content: ChatMessageContent::Text("hello".to_string()),
        prefix: None,
        id: None,
        object: None,
        created_at: None,
        completed_at: None,
    })];
    let input: ConversationInput = entries.into();
    let serialized = serde_json::to_value(&input).unwrap();
    assert!(serialized.is_array());
    assert_eq!(serialized[0]["type"], "message.input");
    assert_eq!(serialized[0]["content"], "hello");
}

#[test]
fn test_create_conversation_request() {
    let request = CreateConversationRequest {
        inputs: ConversationInput::Text("What is 2+2?".to_string()),
        agent_id: Some("ag_abc123".to_string()),
        agent_version: None,
        model: None,
        name: None,
        description: None,
        instructions: None,
        completion_args: None,
        tools: None,
        handoff_execution: Some(HandoffExecution::Server),
        metadata: None,
        store: None,
        stream: false,
    };

    let serialized = serde_json::to_value(&request).unwrap();
    assert_eq!(serialized["inputs"], "What is 2+2?");
    assert_eq!(serialized["agent_id"], "ag_abc123");
    assert_eq!(serialized["handoff_execution"], "server");
    assert_eq!(serialized["stream"], false);
}
|
||||||
|
|
||||||
|
#[test]
fn test_conversation_response_deserialization() {
    let payload = serde_json::json!({
        "conversation_id": "conv_abc123",
        "outputs": [
            {
                "type": "message.output",
                "role": "assistant",
                "content": "4"
            }
        ],
        "usage": {
            "prompt_tokens": 10,
            "completion_tokens": 5,
            "total_tokens": 15
        },
        "object": "conversation.response"
    });

    let response: ConversationResponse = serde_json::from_value(payload).unwrap();
    assert_eq!(response.conversation_id, "conv_abc123");
    assert_eq!(response.assistant_text().unwrap(), "4");
    assert_eq!(response.usage.total_tokens, 15);
    assert!(!response.has_handoff());
}

#[test]
fn test_conversation_response_with_function_calls() {
    // Mixed outputs: one function call followed by an assistant message.
    let payload = serde_json::json!({
        "conversation_id": "conv_abc123",
        "outputs": [
            {
                "type": "function.call",
                "name": "search_archive",
                "arguments": "{\"query\":\"error rate\"}",
                "tool_call_id": "tc_1"
            },
            {
                "type": "message.output",
                "role": "assistant",
                "content": "error rate is 0.3%"
            }
        ],
        "usage": {"prompt_tokens": 20, "completion_tokens": 10, "total_tokens": 30},
        "object": "conversation.response"
    });

    let response: ConversationResponse = serde_json::from_value(payload).unwrap();
    let calls = response.function_calls();
    assert_eq!(calls.len(), 1);
    assert_eq!(calls[0].name, "search_archive");
    assert_eq!(response.assistant_text().unwrap(), "error rate is 0.3%");
}

#[test]
fn test_conversation_response_with_handoff() {
    // A pure handoff output: no assistant text should be reported.
    let payload = serde_json::json!({
        "conversation_id": "conv_abc123",
        "outputs": [
            {
                "type": "agent.handoff",
                "previous_agent_id": "ag_orch",
                "next_agent_id": "ag_obs"
            }
        ],
        "usage": {"prompt_tokens": 5, "completion_tokens": 0, "total_tokens": 5},
        "object": "conversation.response"
    });

    let response: ConversationResponse = serde_json::from_value(payload).unwrap();
    assert!(response.has_handoff());
    assert!(response.assistant_text().is_none());
}

#[test]
fn test_conversation_history_response() {
    // A full multi-entry history including a function call/result round trip.
    let payload = serde_json::json!({
        "conversation_id": "conv_abc123",
        "entries": [
            {"type": "message.input", "role": "user", "content": "hi"},
            {"type": "message.output", "role": "assistant", "content": "hello"},
            {"type": "message.input", "role": "user", "content": "search for cats"},
            {"type": "function.call", "name": "search", "arguments": "{\"q\":\"cats\"}"},
            {"type": "function.result", "tool_call_id": "tc_1", "result": "found 3 results"},
            {"type": "message.output", "role": "assistant", "content": "found 3 results about cats"}
        ],
        "object": "conversation.history"
    });

    let response: ConversationHistoryResponse = serde_json::from_value(payload).unwrap();
    assert_eq!(response.entries.len(), 6);
    assert!(matches!(&response.entries[0], ConversationEntry::MessageInput(_)));
    assert!(matches!(&response.entries[3], ConversationEntry::FunctionCall(_)));
    assert!(matches!(&response.entries[4], ConversationEntry::FunctionResult(_)));
}
|
||||||
|
|
||||||
|
#[test]
fn test_append_conversation_request() {
    let request = AppendConversationRequest {
        inputs: ConversationInput::Text("follow-up question".to_string()),
        completion_args: None,
        handoff_execution: None,
        store: None,
        tool_confirmations: None,
        stream: false,
    };

    let serialized = serde_json::to_value(&request).unwrap();
    assert_eq!(serialized["inputs"], "follow-up question");
    assert_eq!(serialized["stream"], false);
}

#[test]
fn test_restart_conversation_request() {
    let request = RestartConversationRequest {
        from_entry_id: "entry_3".to_string(),
        inputs: Some(ConversationInput::Text("different question".to_string())),
        completion_args: None,
        agent_version: None,
        handoff_execution: Some(HandoffExecution::Client),
        metadata: None,
        store: None,
        stream: false,
    };

    let serialized = serde_json::to_value(&request).unwrap();
    assert_eq!(serialized["from_entry_id"], "entry_3");
    assert_eq!(serialized["handoff_execution"], "client");
}

#[test]
fn test_tool_call_confirmation() {
    // A function result sent back as a structured entry list.
    let request = AppendConversationRequest {
        inputs: ConversationInput::Entries(vec![ConversationEntry::FunctionResult(
            FunctionResultEntry {
                tool_call_id: "tc_1".to_string(),
                result: "search returned 5 results".to_string(),
                id: None,
                object: None,
                created_at: None,
                completed_at: None,
            },
        )]),
        completion_args: None,
        handoff_execution: None,
        store: None,
        tool_confirmations: None,
        stream: false,
    };

    let serialized = serde_json::to_value(&request).unwrap();
    assert_eq!(serialized["inputs"][0]["type"], "function.result");
    assert_eq!(serialized["inputs"][0]["tool_call_id"], "tc_1");
}

#[test]
fn test_handoff_execution_default() {
    assert_eq!(HandoffExecution::default(), HandoffExecution::Server);
}

#[test]
fn test_conversation_list_response() {
    // API returns a raw JSON array
    let payload = serde_json::json!([
        {"id": "conv_1", "object": "conversation", "agent_id": "ag_1", "created_at": "2026-03-21T00:00:00Z"},
        {"id": "conv_2", "object": "conversation", "model": "mistral-medium-latest"}
    ]);

    let response: ConversationListResponse = serde_json::from_value(payload).unwrap();
    assert_eq!(response.data.len(), 2);
    assert_eq!(response.data[0].agent_id.as_deref(), Some("ag_1"));
    assert!(response.data[1].agent_id.is_none());
}
|
||||||
7
tests/v1_tool_test.rs
Normal file
7
tests/v1_tool_test.rs
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
use mistralai_client::v1::client::Client;
|
||||||
|
|
||||||
|
trait _Trait: Send {}
|
||||||
|
struct _Foo {
|
||||||
|
_dummy: Client,
|
||||||
|
}
|
||||||
|
impl _Trait for _Foo {}
|
||||||
Reference in New Issue
Block a user