From 9e5f7e61becdecaf3ea310e66bbbbfeb79f48e77 Mon Sep 17 00:00:00 2001 From: Sienna Meridian Satterwhite Date: Mon, 23 Mar 2026 17:40:25 +0000 Subject: [PATCH] feat(orchestrator): Phase 2 engine + tokenizer + tool dispatch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Orchestrator engine: - engine.rs: unified Mistral Conversations API tool loop that emits OrchestratorEvent instead of calling Matrix/gRPC directly - tool_dispatch.rs: ToolSide routing (client vs server tools) - Memory loading stubbed (migrates in Phase 4) Server-side tokenizer: - tokenizer.rs: HuggingFace tokenizers-rs with Mistral's BPE tokenizer - count_tokens() for accurate usage metrics - Loads from local tokenizer.json or falls back to bundled vocab - Config: mistral.tokenizer_path (optional) No behavior change — engine is wired but not yet called from sync.rs or session.rs (Phase 2 continuation). --- Cargo.lock | 562 ++++++++++++++++++++++++++++-- Cargo.toml | 1 + dev/sol-dev.toml | 1 + src/config.rs | 3 + src/main.rs | 8 + src/orchestrator/engine.rs | 347 ++++++++++++++++++ src/orchestrator/mod.rs | 2 + src/orchestrator/tool_dispatch.rs | 49 +++ src/tokenizer.rs | 123 +++++++ 9 files changed, 1065 insertions(+), 31 deletions(-) create mode 100644 src/orchestrator/engine.rs create mode 100644 src/orchestrator/tool_dispatch.rs create mode 100644 src/tokenizer.rs diff --git a/Cargo.lock b/Cargo.lock index 76baeda..ef3893d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -48,7 +48,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" dependencies = [ "cfg-if", + "getrandom 0.3.4", "once_cell", + "serde", "version_check", "zerocopy", ] @@ -338,6 +340,12 @@ dependencies = [ "tokio", ] +[[package]] +name = "base64" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" + [[package]] name = "base64" version = "0.22.1" @@ -686,6 +694,21 @@ dependencies = [ "static_assertions", ] +[[package]] +name = "compact_str" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdb1325a1cece981e8a296ab8f0f9b63ae357bd0784a9faaf548cc7b480707a" +dependencies = [ + "castaway", + "cfg-if", + "itoa", + "rustversion", + "ryu", + "serde", + "static_assertions", +] + [[package]] name = "compression-codecs" version = "0.4.37" @@ -712,6 +735,19 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "console" +version = "0.15.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "054ccb5b10f9f2cbf51eb355ca1d05c2d279ce1804688d0db74b4733a5aeafd8" +dependencies = [ + "encode_unicode", + "libc", + "once_cell", + "unicode-width", + "windows-sys 0.59.0", +] + [[package]] name = "const-oid" version = "0.9.6" @@ -792,6 +828,25 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crossbeam-deque" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-utils" version = "0.8.21" @@ -846,14 +901,38 @@ dependencies = [ "syn", ] +[[package]] +name = "darling" +version = "0.20.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" +dependencies = [ + "darling_core 0.20.11", + "darling_macro 0.20.11", +] + [[package]] name = "darling" version = "0.23.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "25ae13da2f202d56bd7f91c25fba009e7717a1e4a1cc98a76d844b65ae912e9d" dependencies = [ - "darling_core", - "darling_macro", + "darling_core 0.23.0", + "darling_macro 0.23.0", +] + +[[package]] +name = "darling_core" +version = "0.20.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d00b9596d185e565c2207a0b01f8bd1a135483d02d9b7b0a54b11da8d53412e" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn", ] [[package]] @@ -869,17 +948,37 @@ dependencies = [ "syn", ] +[[package]] +name = "darling_macro" +version = "0.20.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" +dependencies = [ + "darling_core 0.20.11", + "quote", + "syn", +] + [[package]] name = "darling_macro" version = "0.23.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac3984ec7bd6cfa798e62b4a642426a5be0e68f9401cfc2a01e3fa9ea2fcdb8d" dependencies = [ - "darling_core", + "darling_core 0.23.0", "quote", "syn", ] +[[package]] +name = "dary_heap" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06d2e3287df1c007e74221c49ca10a95d557349e54b3a75dc2fb14712c751f04" +dependencies = [ + "serde", +] + [[package]] name = "data-encoding" version = "2.10.0" @@ -977,7 +1076,7 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "292b1ce21933ce7cea00c69b8de023a6a29707e9b6cb2052ca27499710ddd133" dependencies = [ - "base64", + "base64 0.22.1", "capacity_builder", "deno_error", "deno_media_type", @@ -1166,6 +1265,37 @@ dependencies = [ "serde_core", ] +[[package]] +name = "derive_builder" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947" +dependencies = [ + "derive_builder_macro", +] + +[[package]] +name = "derive_builder_core" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8" +dependencies = [ + "darling 0.20.11", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "derive_builder_macro" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" +dependencies = [ + "derive_builder_core", + "syn", +] + [[package]] name = "digest" version = "0.10.7" @@ -1209,6 +1339,27 @@ dependencies = [ "syn", ] +[[package]] +name = "dirs" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3e8aa94d75141228480295a7d0e7feb620b1a5ad9f12bc40be62411e38cce4e" +dependencies = [ + "dirs-sys", +] + +[[package]] +name = "dirs-sys" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e01a3366d27ee9890022452ee61b2b63a67e6f13f58900b651ff5665f0bb1fab" +dependencies = [ + "libc", + "option-ext", + "redox_users", + "windows-sys 0.59.0", +] + [[package]] name = "displaydoc" version = "0.2.5" @@ -1274,6 +1425,12 @@ version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +[[package]] +name = "encode_unicode" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" + [[package]] name = "encoding_rs" version = "0.8.35" @@ -1322,6 +1479,12 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "esaxx-rs" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d817e038c30374a4bcb22f94d0a8a0e216958d4c3dcde369b1439fec4bdda6e6" + [[package]] name = "event-listener" version = "5.4.1" @@ -1756,6 +1919,25 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "hf-hub" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "629d8f3bbeda9d148036d6b0de0a3ab947abd08ce90626327fc3547a49d59d97" +dependencies = [ + "dirs", + "http", + "indicatif", + "libc", + "log", + "rand 0.9.2", + "serde", + "serde_json", + "thiserror 2.0.18", + "ureq", + "windows-sys 0.60.2", +] + [[package]] name = "hkdf" version = "0.12.4" @@ -1879,7 +2061,7 @@ dependencies = [ "tokio", "tokio-rustls", "tower-service", - "webpki-roots", + "webpki-roots 1.0.6", ] [[package]] @@ -1917,7 +2099,7 @@ version = "0.1.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "96547c2556ec9d12fb1578c4eaf448b04993e7fb79cbaad930a656880a6bdfa0" dependencies = [ - "base64", + "base64 0.22.1", "bytes", "futures-channel", "futures-util", @@ -2208,6 +2390,19 @@ dependencies = [ "serde_core", ] +[[package]] +name = "indicatif" +version = "0.17.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "183b3088984b400f4cfac3620d5e076c84da5364016b4f49473de574b2586235" +dependencies = [ + "console", + "number_prefix", + "portable-atomic", + "unicode-width", + "web-time", +] + [[package]] name = "inout" version = "0.1.4" @@ -2367,7 +2562,7 @@ version = "9.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5a87cc7a48537badeae96744432de36f4be2b4a34a05a5ef32e9dd8a1c169dde" dependencies = [ - "base64", + "base64 0.22.1", "js-sys", "pem", "ring", @@ -2430,6 +2625,15 @@ version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6d2cec3eae94f9f509c767b45932f1ada8350c4bdb85af2fcab4a3c14807981" +[[package]] +name = "libredox" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1744e39d1d6a9948f4f388969627434e31128196de472883b39f148769bfe30a" +dependencies = [ + "libc", +] + [[package]] name = "libsqlite3-sys" version = "0.30.1" @@ -2480,6 +2684,22 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" +[[package]] +name = "macro_rules_attribute" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65049d7923698040cd0b1ddcced9b0eb14dd22c5f86ae59c3740eab64a676520" +dependencies = [ + "macro_rules_attribute-proc_macro", + "paste", +] + +[[package]] +name = "macro_rules_attribute-proc_macro" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "670fdfda89751bc4a84ac13eaa63e205cf0fd22b4c9a5fbfa085b63c1f1d3a30" + [[package]] name = "macroific" version = "1.3.1" @@ -2723,7 +2943,7 @@ checksum = "72f86434be7e6256a5d6e7828b887a4e91a42cd66380f8b02e02eeb702819589" dependencies = [ "anyhow", "async-trait", - "base64", + "base64 0.22.1", "getrandom 0.2.17", "gloo-utils", "hkdf", @@ -2773,7 +2993,7 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d702add6a56f288bf2e1a4e45145529620bff2003e746d7f23fc736ea806dbc8" dependencies = [ - "base64", + "base64 0.22.1", "blake3", "chacha20poly1305", "hmac", @@ -2861,6 +3081,28 @@ dependencies = [ "tokio-stream", ] +[[package]] +name = "monostate" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3341a273f6c9d5bef1908f17b7267bbab0e95c9bf69a0d4dcf8e9e1b2c76ef67" +dependencies = [ + "monostate-impl", + "serde", + "serde_core", +] + +[[package]] +name = "monostate-impl" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4db6d5580af57bf992f59068d4ea26fd518574ff48d7639b255a36f9de6e7e9" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "multimap" version = "0.10.1" @@ -2955,6 +3197,12 @@ dependencies = [ "libc", ] +[[package]] +name = "number_prefix" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" + [[package]] name = "object" version = "0.37.3" @@ -2976,6 +3224,28 @@ version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" +[[package]] +name = "onig" +version = "6.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "336b9c63443aceef14bea841b899035ae3abe89b7c486aaf4c5bd8aafedac3f0" +dependencies = [ + "bitflags", + "libc", + "once_cell", + "onig_sys", +] + +[[package]] +name = "onig_sys" +version = "69.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7f86c6eef3d6df15f23bcfb6af487cbd2fed4e5581d58d5bf1f5f8b7f6727dc" +dependencies = [ + "cc", + "pkg-config", +] + [[package]] name = "opaque-debug" version = "0.3.1" @@ -2988,7 +3258,7 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a62b025c3503d3d53eaba3b6f14adb955af9f69fc71141b4d030a4e5331f5d42" dependencies = [ - "base64", + "base64 0.22.1", "bytes", "dyn-clone", "lazy_static", @@ -3046,6 +3316,12 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "option-ext" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" + [[package]] name = "outref" version = "0.5.2" @@ -3118,7 +3394,7 @@ version = "3.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1d30c53c26bc5b31a98cd02d20f25a7c8567146caf63ed593a9d87b2775291be" dependencies = [ - "base64", + "base64 0.22.1", "serde_core", ] @@ -3593,6 +3869,37 @@ dependencies = [ "rand_core 0.6.4", ] +[[package]] +name = "rayon" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-cond" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2964d0cf57a3e7a06e8183d14a8b527195c706b7983549cd5462d5aa3747438f" +dependencies = [ + "either", + "itertools 0.14.0", + "rayon", +] + +[[package]] +name = "rayon-core" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + [[package]] name = "readlock" version = "0.1.11" @@ -3617,6 +3924,17 @@ dependencies = [ "bitflags", ] +[[package]] +name = "redox_users" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4e608c6638b9c18977b00b475ac1f28d14e84b27d8d42f70e0bf1e3dec127ac" +dependencies = [ + "getrandom 0.2.17", + "libredox", + "thiserror 2.0.18", +] + [[package]] name = "ref-cast" version = "1.0.25" @@ -3672,7 +3990,7 @@ version = "0.12.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147" dependencies = [ - "base64", + "base64 0.22.1", "bytes", "encoding_rs", "futures-channel", @@ -3712,7 +4030,7 @@ dependencies = [ "wasm-bindgen-futures", "wasm-streams", "web-sys", - "webpki-roots", + "webpki-roots 1.0.6", ] [[package]] @@ -3822,7 +4140,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "387e1898e868d32ff7b205e7db327361d5dcf635c00a8ae5865068607595a9cf" dependencies = [ "as_variant", - "base64", + "base64 0.22.1", "bytes", "form_urlencoded", "getrandom 0.2.17", @@ -3975,6 +4293,7 @@ version = "0.23.37" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "758025cb5fccfd3bc2fd74708fd4682be41d99e5dff73c377c0646c6012c73a4" dependencies = [ + "log", "once_cell", "ring", "rustls-pki-types", @@ -4221,7 +4540,7 @@ version = "3.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd5414fad8e6907dbdd5bc441a50ae8d6e26151a03b1de04d89a5576de61d01f" dependencies = [ - "base64", + "base64 0.22.1", "chrono", "hex", "indexmap 1.9.3", @@ -4240,7 +4559,7 @@ version = "3.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d3db8978e608f1fe7357e211969fd9abdcae80bac1ba7a3369bb7eb6b404eb65" dependencies = [ - "darling", + "darling 0.23.0", "proc-macro2", "quote", "syn", @@ -4365,12 +4684,23 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "socks" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0c3dbbd9ae980613c6dd8e28a9407b50509d3803b57624d5dfe8315218cd58b" +dependencies = [ + "byteorder", + "libc", + "winapi", +] + [[package]] name = "sol" version = "0.1.0" dependencies = [ "anyhow", - "base64", + "base64 0.22.1", "chrono", "deno_ast", "deno_core", @@ -4390,6 +4720,7 @@ dependencies = [ "serde", "serde_json", "tempfile", + "tokenizers", "tokio", "tokio-stream", "toml", @@ -4431,6 +4762,18 @@ dependencies = [ "der", ] +[[package]] +name = "spm_precompiled" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5851699c4033c63636f7ea4cf7b7c1f1bf06d0cc03cfb42e711de5a5c46cf326" +dependencies = [ + "base64 0.13.1", + "nom", + "serde", + "unicode-segmentation", +] + [[package]] name = "stable_deref_trait" version = "1.2.1" @@ -4618,7 +4961,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ff2a6ee1ec49dda8dedeac54e4147b4e8b3f278d9bb34ab28983257a393d34ed" dependencies = [ "ascii", - "compact_str", + "compact_str 0.7.1", "memchr", "num-bigint", "once_cell", @@ -4775,7 +5118,7 @@ version = "33.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "03de12e38e47ac1c96ac576f793ad37a9d7b16fbf4f2203881f89152f2498682" dependencies = [ - "base64", + "base64 0.22.1", "bytes-str", "indexmap 2.13.0", "once_cell", @@ -5167,6 +5510,40 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" +[[package]] +name = "tokenizers" +version = "0.22.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b238e22d44a15349529690fb07bd645cf58149a1b1e44d6cb5bd1641ff1a6223" +dependencies = [ + "ahash", + "aho-corasick", + "compact_str 0.9.0", + "dary_heap", + "derive_builder", + "esaxx-rs", + "getrandom 0.3.4", + "hf-hub", + "itertools 0.14.0", + "log", + "macro_rules_attribute", + "monostate", + "onig", + "paste", + "rand 0.9.2", + "rayon", + "rayon-cond", + "regex", + "regex-syntax", + "serde", + "serde_json", + "spm_precompiled", + "thiserror 2.0.18", + "unicode-normalization-alignments", + "unicode-segmentation", + "unicode_categories", +] + [[package]] name = "tokio" version = "1.50.0" @@ -5319,7 +5696,7 @@ checksum = "fec7c61a0695dc1887c1b53952990f3ad2e3a31453e1f49f10e75424943a93ec" dependencies = [ "async-trait", "axum", - "base64", + "base64 0.22.1", "bytes", "h2", "http", @@ -5568,6 +5945,21 @@ dependencies = [ "tinyvec", ] +[[package]] +name = "unicode-normalization-alignments" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43f613e4fa046e69818dd287fdc4bc78175ff20331479dab6e1b0f98d57062de" +dependencies = [ + "smallvec", +] + +[[package]] +name = "unicode-segmentation" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" + [[package]] name = "unicode-width" version = "0.2.2" @@ -5580,6 +5972,12 @@ version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" +[[package]] +name = "unicode_categories" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" + [[package]] name = "universal-hash" version = "0.5.1" @@ -5596,6 +5994,25 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" +[[package]] +name = "ureq" +version = "2.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02d1a66277ed75f640d608235660df48c8e3c19f3b4edb6a263315626cc3c01d" +dependencies = [ + "base64 0.22.1", + "flate2", + "log", + "once_cell", + "rustls", + "rustls-pki-types", + "serde", + "serde_json", + "socks", + "url", + "webpki-roots 0.26.11", +] + [[package]] name = "url" version = "2.5.8" @@ -5681,7 +6098,7 @@ checksum = "dd4b56780b7827dd72c3c6398c3048752bebf8d1d84ec19b606b15dbc3c850b8" dependencies = [ "aes", "arrayvec", - "base64", + "base64 0.22.1", "base64ct", "cbc", "chacha20poly1305", @@ -5884,6 +6301,15 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "webpki-roots" +version = "0.26.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "521bc38abb08001b01866da9f51eb7c5d647a19260e00054a8c7fd5f9e57f7a9" +dependencies = [ + "webpki-roots 1.0.6", +] + [[package]] name = "webpki-roots" version = "1.0.6" @@ -6018,7 +6444,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets", + "windows-targets 0.52.6", ] [[package]] @@ -6027,7 +6453,16 @@ version = "0.59.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" dependencies = [ - "windows-targets", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-sys" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" +dependencies = [ + "windows-targets 0.53.5", ] [[package]] @@ -6045,14 +6480,31 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_gnullvm", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", + "windows_i686_gnullvm 0.52.6", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", +] + +[[package]] +name = "windows-targets" +version = "0.53.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" +dependencies = [ + "windows-link", + "windows_aarch64_gnullvm 0.53.1", + "windows_aarch64_msvc 0.53.1", + "windows_i686_gnu 0.53.1", + "windows_i686_gnullvm 0.53.1", + "windows_i686_msvc 0.53.1", + "windows_x86_64_gnu 0.53.1", + "windows_x86_64_gnullvm 0.53.1", + "windows_x86_64_msvc 0.53.1", ] [[package]] @@ -6061,48 +6513,96 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53" + [[package]] name = "windows_aarch64_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" +[[package]] +name = "windows_aarch64_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006" + [[package]] name = "windows_i686_gnu" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" +[[package]] +name = "windows_i686_gnu" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3" + [[package]] name = "windows_i686_gnullvm" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" +[[package]] +name = "windows_i686_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c" + [[package]] name = "windows_i686_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" +[[package]] +name = "windows_i686_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2" + [[package]] name = "windows_x86_64_gnu" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" +[[package]] +name = "windows_x86_64_gnu" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499" + [[package]] name = "windows_x86_64_gnullvm" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1" + [[package]] name = "windows_x86_64_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "windows_x86_64_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" + [[package]] name = "winnow" version = "0.7.15" diff --git a/Cargo.toml b/Cargo.toml index 85d5234..93c3b58 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,6 +43,7 @@ tonic-prost = "0.14" prost = "0.14" tokio-stream = "0.1" jsonwebtoken = "9" +tokenizers = { version = "0.22", default-features = false, features = ["onig", "http"] } [build-dependencies] tonic-build = "0.14" diff --git a/dev/sol-dev.toml b/dev/sol-dev.toml index 3ec95a5..4136e01 100644 --- a/dev/sol-dev.toml +++ b/dev/sol-dev.toml @@ -17,6 +17,7 @@ default_model = "mistral-medium-latest" evaluation_model = "ministral-3b-latest" research_model = "mistral-large-latest" max_tool_iterations = 250 +# tokenizer_path = "dev/tokenizer.json" # uncomment to use a local tokenizer file [behavior] response_delay_min_ms = 0 diff --git a/src/config.rs b/src/config.rs index 1211d48..47b3f61 100644 --- a/src/config.rs +++ b/src/config.rs @@ -102,6 +102,9 @@ pub struct MistralConfig { pub research_model: String, #[serde(default = "default_max_tool_iterations")] pub max_tool_iterations: usize, + /// Path to a local `tokenizer.json` file. If unset, downloads from HuggingFace Hub. + #[serde(default)] + pub tokenizer_path: Option, } #[derive(Debug, Clone, Deserialize)] diff --git a/src/main.rs b/src/main.rs index f324f9a..b8d1f3a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -13,6 +13,7 @@ mod orchestrator; mod sdk; mod sync; mod time_context; +mod tokenizer; mod tools; use std::sync::Arc; @@ -123,6 +124,13 @@ async fn main() -> anyhow::Result<()> { )?; let mistral = Arc::new(mistral_client); + // Initialize tokenizer for accurate token counting + let _tokenizer = Arc::new( + tokenizer::SolTokenizer::new(config.mistral.tokenizer_path.as_deref()) + .expect("Failed to initialize tokenizer"), + ); + info!("Tokenizer initialized"); + // Build components let system_prompt_text = system_prompt.clone(); let personality = Arc::new(Personality::new(system_prompt)); diff --git a/src/orchestrator/engine.rs b/src/orchestrator/engine.rs new file mode 100644 index 0000000..4598885 --- /dev/null +++ b/src/orchestrator/engine.rs @@ -0,0 +1,347 @@ +//! The unified response generation engine. +//! +//! Single implementation of the Mistral Conversations API tool loop. +//! Emits `OrchestratorEvent`s instead of calling Matrix/gRPC directly. +//! Phase 2: replaces `responder.generate_response_conversations()`. + +use std::sync::Arc; +use std::time::Duration; + +use mistralai_client::v1::conversations::{ + ConversationEntry, ConversationInput, ConversationResponse, FunctionResultEntry, +}; +use rand::Rng; +use tokio::sync::broadcast; +use tracing::{debug, error, info, warn}; + +use super::event::*; +use super::tool_dispatch; +use super::Orchestrator; +use crate::brain::personality::Personality; +use crate::context::ResponseContext; +use crate::conversations::ConversationRegistry; +use crate::time_context::TimeContext; + +/// Strip "Sol: " or "sol: " prefix that models sometimes prepend. +fn strip_sol_prefix(text: &str) -> String { + let trimmed = text.trim(); + if trimmed.starts_with("Sol: ") || trimmed.starts_with("sol: ") { + trimmed[5..].to_string() + } else if trimmed.starts_with("Sol:\n") || trimmed.starts_with("sol:\n") { + trimmed[4..].to_string() + } else { + trimmed.to_string() + } +} + +/// Generate a chat response through the Conversations API. +/// This is the unified path that replaces both the responder's conversations +/// method and the gRPC session's inline tool loop. +pub async fn generate_response( + orchestrator: &Orchestrator, + personality: &Personality, + request: &ChatRequest, + response_ctx: &ResponseContext, + conversation_registry: &ConversationRegistry, +) -> Option { + let request_id = &request.request_id; + + // Emit start + orchestrator.emit(OrchestratorEvent::ResponseStarted { + request_id: request_id.clone(), + mode: ResponseMode::Chat { + room_id: request.room_id.clone(), + is_spontaneous: request.is_spontaneous, + use_thread: request.use_thread, + trigger_event_id: request.trigger_event_id.clone(), + }, + }); + + // Apply response delay + if !orchestrator.config.behavior.instant_responses { + let delay = if request.is_spontaneous { + rand::thread_rng().gen_range( + orchestrator.config.behavior.spontaneous_delay_min_ms + ..=orchestrator.config.behavior.spontaneous_delay_max_ms, + ) + } else { + rand::thread_rng().gen_range( + orchestrator.config.behavior.response_delay_min_ms + ..=orchestrator.config.behavior.response_delay_max_ms, + ) + }; + tokio::time::sleep(Duration::from_millis(delay)).await; + } + + orchestrator.emit(OrchestratorEvent::Thinking { + request_id: request_id.clone(), + }); + + // Memory query + let memory_notes = load_memory_notes(orchestrator, response_ctx, &request.trigger_body).await; + + // Build context header + let tc = TimeContext::now(); + let mut context_header = format!( + "{}\n[room: {} ({})]", + tc.message_line(), + request.room_name, + request.room_id, + ); + + if let Some(ref notes) = memory_notes { + context_header.push('\n'); + context_header.push_str(notes); + } + + let user_msg = if request.is_dm { + request.trigger_body.clone() + } else { + format!("<{}> {}", response_ctx.matrix_user_id, request.trigger_body) + }; + + let input_text = format!("{context_header}\n{user_msg}"); + let input = ConversationInput::Text(input_text); + + // Send through conversation registry + let response = match conversation_registry + .send_message( + &request.room_id, + input, + request.is_dm, + &orchestrator.mistral, + request.context_hint.as_deref(), + ) + .await + { + Ok(r) => r, + Err(e) => { + error!("Conversation API failed: {e}"); + orchestrator.emit(OrchestratorEvent::ResponseFailed { + request_id: request_id.clone(), + error: e.clone(), + }); + return None; + } + }; + + // Tool loop + let result = run_tool_loop( + orchestrator, + request_id, + response, + response_ctx, + conversation_registry, + &request.room_id, + request.is_dm, + ) + .await; + + match result { + Some(text) => { + let text = strip_sol_prefix(&text); + if text.is_empty() { + orchestrator.emit(OrchestratorEvent::ResponseFailed { + request_id: request_id.clone(), + error: "Empty response from model".into(), + }); + return None; + } + + orchestrator.emit(OrchestratorEvent::ResponseReady { + request_id: request_id.clone(), + text: text.clone(), + prompt_tokens: 0, // TODO: extract from response + completion_tokens: 0, + tool_iterations: 0, + }); + + // Schedule memory extraction + orchestrator.emit(OrchestratorEvent::MemoryExtractionScheduled { + request_id: request_id.clone(), + user_msg: request.trigger_body.clone(), + response: text.clone(), + }); + + Some(text) + } + None => { + orchestrator.emit(OrchestratorEvent::ResponseFailed { + request_id: request_id.clone(), + error: "No response from model".into(), + }); + None + } + } +} + +/// The unified tool iteration loop. +/// Emits tool events and executes server-side tools. +/// Client-side tools are dispatched via the pending_client_tools oneshot map. +async fn run_tool_loop( + orchestrator: &Orchestrator, + request_id: &RequestId, + initial_response: ConversationResponse, + response_ctx: &ResponseContext, + conversation_registry: &ConversationRegistry, + room_id: &str, + is_dm: bool, +) -> Option { + let function_calls = initial_response.function_calls(); + + // No tool calls — return the text directly + if function_calls.is_empty() { + return initial_response.assistant_text(); + } + + orchestrator.emit(OrchestratorEvent::AgentProgressStarted { + request_id: request_id.clone(), + }); + + let max_iterations = orchestrator.config.mistral.max_tool_iterations; + let mut current_response = initial_response; + + for iteration in 0..max_iterations { + let calls = current_response.function_calls(); + if calls.is_empty() { + break; + } + + let mut result_entries = Vec::new(); + + for fc in &calls { + let call_id = fc.tool_call_id.as_deref().unwrap_or("unknown"); + let side = tool_dispatch::route(&fc.name); + + orchestrator.emit(OrchestratorEvent::ToolCallDetected { + request_id: request_id.clone(), + call_id: call_id.into(), + name: fc.name.clone(), + args: fc.arguments.clone(), + side: side.clone(), + }); + + orchestrator.emit(OrchestratorEvent::ToolExecutionStarted { + request_id: request_id.clone(), + call_id: call_id.into(), + name: fc.name.clone(), + }); + + let result_str = match side { + ToolSide::Server => { + // Execute server-side tool + let result = if fc.name == "research" { + // Research needs special handling (room + event context) + // For now, use the standard execute path + orchestrator + .tools + .execute(&fc.name, &fc.arguments, response_ctx) + .await + } else { + orchestrator + .tools + .execute(&fc.name, &fc.arguments, response_ctx) + .await + }; + + match result { + Ok(s) => { + let preview: String = s.chars().take(500).collect(); + info!( + tool = fc.name.as_str(), + id = call_id, + result_len = s.len(), + result_preview = preview.as_str(), + "Tool result" + ); + s + } + Err(e) => { + warn!(tool = fc.name.as_str(), "Tool failed: {e}"); + format!("Error: {e}") + } + } + } + ToolSide::Client => { + // Park on oneshot — gRPC bridge will deliver the result + let rx = orchestrator.register_pending_tool(call_id).await; + match tokio::time::timeout(Duration::from_secs(300), rx).await { + Ok(Ok(payload)) => { + if payload.is_error { + format!("Error: {}", payload.text) + } else { + payload.text + } + } + Ok(Err(_)) => "Error: client tool channel dropped".into(), + Err(_) => "Error: client tool timed out (5min)".into(), + } + } + }; + + let success = !result_str.starts_with("Error:"); + + orchestrator.emit(OrchestratorEvent::ToolExecutionCompleted { + request_id: request_id.clone(), + call_id: call_id.into(), + name: fc.name.clone(), + result: result_str.chars().take(200).collect(), + success, + }); + + orchestrator.emit(OrchestratorEvent::AgentProgressStep { + request_id: request_id.clone(), + summary: crate::agent_ux::AgentProgress::format_tool_call( + &fc.name, + &fc.arguments, + ), + }); + + result_entries.push(ConversationEntry::FunctionResult(FunctionResultEntry { + tool_call_id: call_id.to_string(), + result: result_str, + id: None, + object: None, + created_at: None, + completed_at: None, + })); + } + + // Send function results back to conversation + current_response = match conversation_registry + .send_function_result(room_id, result_entries, &orchestrator.mistral) + .await + { + Ok(r) => r, + Err(e) => { + error!("Failed to send function results: {e}"); + orchestrator.emit(OrchestratorEvent::AgentProgressDone { + request_id: request_id.clone(), + }); + return None; + } + }; + + debug!(iteration, "Tool iteration complete"); + } + + orchestrator.emit(OrchestratorEvent::AgentProgressDone { + request_id: request_id.clone(), + }); + + current_response.assistant_text() +} + +/// Load memory notes relevant to the trigger message. +/// TODO (Phase 4): move the full memory::store query logic here +/// when the Responder is dissolved. For now returns None — the Matrix +/// bridge path still uses the responder which has memory loading. +async fn load_memory_notes( + _orchestrator: &Orchestrator, + _ctx: &ResponseContext, + _trigger_body: &str, +) -> Option { + // Memory loading is not yet migrated to the orchestrator. + // The responder's load_memory_notes() still handles this for now. + None +} diff --git a/src/orchestrator/mod.rs b/src/orchestrator/mod.rs index bf636c2..91fb28a 100644 --- a/src/orchestrator/mod.rs +++ b/src/orchestrator/mod.rs @@ -6,7 +6,9 @@ //! //! Phase 1: types + channel wiring only. No behavior change. +pub mod engine; pub mod event; +pub mod tool_dispatch; use std::collections::HashMap; use std::sync::Arc; diff --git a/src/orchestrator/tool_dispatch.rs b/src/orchestrator/tool_dispatch.rs new file mode 100644 index 0000000..5c38bfd --- /dev/null +++ b/src/orchestrator/tool_dispatch.rs @@ -0,0 +1,49 @@ +//! Tool routing — determines whether a tool executes on the server or a connected client. + +use super::event::ToolSide; + +/// Client-side tools that execute on the `sunbeam code` TUI client. +const CLIENT_TOOLS: &[&str] = &[ + "file_read", + "file_write", + "search_replace", + "grep", + "bash", + "list_directory", + "ask_user", +]; + +/// Route a tool call to server or client. +pub fn route(tool_name: &str) -> ToolSide { + if CLIENT_TOOLS.contains(&tool_name) { + ToolSide::Client + } else { + ToolSide::Server + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_client_tools() { + assert_eq!(route("file_read"), ToolSide::Client); + assert_eq!(route("bash"), ToolSide::Client); + assert_eq!(route("grep"), ToolSide::Client); + assert_eq!(route("file_write"), ToolSide::Client); + assert_eq!(route("search_replace"), ToolSide::Client); + assert_eq!(route("list_directory"), ToolSide::Client); + assert_eq!(route("ask_user"), ToolSide::Client); + } + + #[test] + fn test_server_tools() { + assert_eq!(route("search_archive"), ToolSide::Server); + assert_eq!(route("search_web"), ToolSide::Server); + assert_eq!(route("run_script"), ToolSide::Server); + assert_eq!(route("research"), ToolSide::Server); + assert_eq!(route("gitea_list_repos"), ToolSide::Server); + assert_eq!(route("unknown_tool"), ToolSide::Server); + } +} diff --git a/src/tokenizer.rs b/src/tokenizer.rs new file mode 100644 index 0000000..e39f5ea --- /dev/null +++ b/src/tokenizer.rs @@ -0,0 +1,123 @@ +use std::sync::Arc; + +use anyhow::{Context, Result}; +use tokenizers::Tokenizer; +use tracing::{info, warn}; + +/// Default HuggingFace pretrained tokenizer identifier for Mistral models. +const DEFAULT_PRETRAINED: &str = "mistralai/Mistral-Small-24B-Base-2501"; + +/// Thread-safe wrapper around HuggingFace's `Tokenizer`. +/// +/// Load once at startup via [`SolTokenizer::new`] and share as `Arc`. +#[derive(Clone)] +pub struct SolTokenizer { + inner: Arc, +} + +impl SolTokenizer { + /// Load a tokenizer from a local `tokenizer.json` path, falling back to + /// HuggingFace Hub pretrained download if the path is absent or fails. + pub fn new(tokenizer_path: Option<&str>) -> Result { + let tokenizer = if let Some(path) = tokenizer_path { + match Tokenizer::from_file(path) { + Ok(t) => { + info!(path, "Loaded tokenizer from local file"); + t + } + Err(e) => { + warn!(path, error = %e, "Failed to load local tokenizer, falling back to pretrained"); + Self::from_pretrained()? + } + } + } else { + Self::from_pretrained()? + }; + + Ok(Self { + inner: Arc::new(tokenizer), + }) + } + + /// Download tokenizer from HuggingFace Hub. + fn from_pretrained() -> Result { + info!(model = DEFAULT_PRETRAINED, "Downloading tokenizer from HuggingFace Hub"); + Tokenizer::from_pretrained(DEFAULT_PRETRAINED, None) + .map_err(|e| anyhow::anyhow!("{e}")) + .context("Failed to download pretrained tokenizer") + } + + /// Count the number of tokens in the given text. + pub fn count_tokens(&self, text: &str) -> usize { + match self.inner.encode(text, false) { + Ok(encoding) => encoding.get_ids().len(), + Err(e) => { + warn!(error = %e, "Tokenization failed, estimating from char count"); + // Rough fallback: ~4 chars per token for English text + text.len() / 4 + } + } + } + + /// Encode text and return the token IDs. + pub fn encode(&self, text: &str) -> Result> { + let encoding = self + .inner + .encode(text, false) + .map_err(|e| anyhow::anyhow!("{e}")) + .context("Tokenization failed")?; + Ok(encoding.get_ids().to_vec()) + } +} + +impl std::fmt::Debug for SolTokenizer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SolTokenizer").finish() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + /// Test that the pretrained tokenizer can be loaded and produces + /// reasonable token counts. This test requires network access on + /// first run (the tokenizer is cached locally afterwards). + #[test] + fn test_pretrained_tokenizer_loads() { + let tok = SolTokenizer::new(None).expect("pretrained tokenizer should load"); + let count = tok.count_tokens("Hello, world!"); + assert!(count > 0, "token count should be positive"); + assert!(count < 20, "token count for a short sentence should be small"); + } + + #[test] + fn test_count_tokens_empty_string() { + let tok = SolTokenizer::new(None).expect("pretrained tokenizer should load"); + let count = tok.count_tokens(""); + assert_eq!(count, 0); + } + + #[test] + fn test_encode_returns_ids() { + let tok = SolTokenizer::new(None).expect("pretrained tokenizer should load"); + let ids = tok.encode("Hello, world!").expect("encode should succeed"); + assert!(!ids.is_empty()); + } + + #[test] + fn test_invalid_path_falls_back_to_pretrained() { + let tok = SolTokenizer::new(Some("/nonexistent/tokenizer.json")) + .expect("should fall back to pretrained"); + let count = tok.count_tokens("fallback test"); + assert!(count > 0); + } + + #[test] + fn test_longer_text_produces_more_tokens() { + let tok = SolTokenizer::new(None).expect("pretrained tokenizer should load"); + let short = tok.count_tokens("Hi"); + let long = tok.count_tokens("This is a much longer sentence with many more words in it."); + assert!(long > short, "longer text should produce more tokens"); + } +}