feat(code): tree-sitter symbol extraction + auto-indexing
Symbol extraction (symbols.rs):
- tree-sitter parsers for Rust, TypeScript, Python
- Extracts: functions, structs, enums, traits, classes, interfaces
- Signatures, docstrings, line ranges for each symbol
- extract_project_symbols() walks the project directory
- Skips hidden/vendor/target/node_modules directories and files >100KB

Proto: IndexSymbols + SymbolEntry messages for client→server symbol relay.

Client: after SessionReady, extracts symbols and sends IndexSymbols to Sol for indexing into the code search index.

14 unit tests for symbol extraction across Rust/TS/Python.
This commit is contained in:
59
Cargo.lock
generated
59
Cargo.lock
generated
@@ -4017,6 +4017,12 @@ version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
|
||||
|
||||
[[package]]
|
||||
name = "streaming-iterator"
|
||||
version = "0.1.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2b2231b7c3057d5e4ad0156fb3dc807d900806020c5ffa3ee6ff2c8c76fb8520"
|
||||
|
||||
[[package]]
|
||||
name = "strsim"
|
||||
version = "0.11.1"
|
||||
@@ -4074,6 +4080,10 @@ dependencies = [
|
||||
"tonic",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"tree-sitter",
|
||||
"tree-sitter-python",
|
||||
"tree-sitter-rust",
|
||||
"tree-sitter-typescript",
|
||||
"tui-markdown",
|
||||
]
|
||||
|
||||
@@ -4654,6 +4664,55 @@ dependencies = [
|
||||
"tracing-log",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tree-sitter"
|
||||
version = "0.24.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a5387dffa7ffc7d2dae12b50c6f7aab8ff79d6210147c6613561fc3d474c6f75"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"regex",
|
||||
"regex-syntax",
|
||||
"streaming-iterator",
|
||||
"tree-sitter-language",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tree-sitter-language"
|
||||
version = "0.1.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "009994f150cc0cd50ff54917d5bc8bffe8cad10ca10d81c34da2ec421ae61782"
|
||||
|
||||
[[package]]
|
||||
name = "tree-sitter-python"
|
||||
version = "0.23.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3d065aaa27f3aaceaf60c1f0e0ac09e1cb9eb8ed28e7bcdaa52129cffc7f4b04"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"tree-sitter-language",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tree-sitter-rust"
|
||||
version = "0.23.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ca8ccb3e3a3495c8a943f6c3fd24c3804c471fd7f4f16087623c7fa4c0068e8a"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"tree-sitter-language",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tree-sitter-typescript"
|
||||
version = "0.23.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6c5f76ed8d947a75cc446d5fccd8b602ebf0cde64ccf2ffa434d873d7a575eff"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"tree-sitter-language",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "try-lock"
|
||||
version = "0.2.5"
|
||||
|
||||
@@ -16,9 +16,28 @@ message ClientMessage {
|
||||
ToolResult tool_result = 3;
|
||||
ToolApproval approval = 4;
|
||||
EndSession end = 5;
|
||||
IndexSymbols index_symbols = 6;
|
||||
}
|
||||
}
|
||||
|
||||
message IndexSymbols {
|
||||
string project_name = 1;
|
||||
string branch = 2;
|
||||
repeated SymbolEntry symbols = 3;
|
||||
}
|
||||
|
||||
message SymbolEntry {
|
||||
string file_path = 1;
|
||||
string name = 2;
|
||||
string kind = 3;
|
||||
string signature = 4;
|
||||
string docstring = 5;
|
||||
int32 start_line = 6;
|
||||
int32 end_line = 7;
|
||||
string language = 8;
|
||||
string content = 9;
|
||||
}
|
||||
|
||||
message StartSession {
|
||||
string project_path = 1;
|
||||
string prompt_md = 2;
|
||||
|
||||
@@ -29,6 +29,10 @@ futures = "0.3"
|
||||
crossbeam-channel = "0.5"
|
||||
textwrap = "0.16"
|
||||
tui-markdown = "=0.3.6"
|
||||
tree-sitter = "0.24"
|
||||
tree-sitter-rust = "0.23"
|
||||
tree-sitter-typescript = "0.23"
|
||||
tree-sitter-python = "0.23"
|
||||
|
||||
[dev-dependencies]
|
||||
tokio-stream = { version = "0.1", features = ["net"] }
|
||||
|
||||
@@ -104,6 +104,38 @@ pub async fn connect(
|
||||
}
|
||||
};
|
||||
|
||||
// Extract and send symbols for code index (fire-and-forget)
|
||||
let symbols = super::symbols::extract_project_symbols(&project.path);
|
||||
if !symbols.is_empty() {
|
||||
let branch = project.git_branch.clone().unwrap_or_else(|| "mainline".into());
|
||||
let proto_symbols: Vec<_> = symbols
|
||||
.iter()
|
||||
.map(|s| SymbolEntry {
|
||||
file_path: s.file_path.clone(),
|
||||
name: s.name.clone(),
|
||||
kind: s.kind.clone(),
|
||||
signature: s.signature.clone(),
|
||||
docstring: s.docstring.clone(),
|
||||
start_line: s.start_line as i32,
|
||||
end_line: s.end_line as i32,
|
||||
language: s.language.clone(),
|
||||
content: s.content.clone(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
let project_name = project.path.split('/').last().unwrap_or("unknown").to_string();
|
||||
let _ = tx
|
||||
.send(ClientMessage {
|
||||
payload: Some(client_message::Payload::IndexSymbols(IndexSymbols {
|
||||
project_name,
|
||||
branch,
|
||||
symbols: proto_symbols,
|
||||
})),
|
||||
})
|
||||
.await;
|
||||
info!(count = symbols.len(), "Sent project symbols for indexing");
|
||||
}
|
||||
|
||||
let history = ready
|
||||
.history
|
||||
.into_iter()
|
||||
|
||||
@@ -2,6 +2,7 @@ pub mod agent;
|
||||
pub mod client;
|
||||
pub mod config;
|
||||
pub mod project;
|
||||
pub mod symbols;
|
||||
pub mod tools;
|
||||
pub mod tui;
|
||||
|
||||
|
||||
659
sunbeam/src/code/symbols.rs
Normal file
659
sunbeam/src/code/symbols.rs
Normal file
@@ -0,0 +1,659 @@
|
||||
//! Symbol extraction from source code using tree-sitter.
|
||||
//!
|
||||
//! Extracts function signatures, struct/enum/trait definitions, and
|
||||
//! docstrings from Rust, TypeScript, and Python files. These symbols
|
||||
//! are sent to Sol for indexing in the code search index.
|
||||
|
||||
use std::path::Path;
|
||||
use tracing::debug;
|
||||
|
||||
/// An extracted code symbol with file context.
///
/// This is the client-side record relayed to the server as a `SymbolEntry`
/// protobuf message for indexing into the code search index.
#[derive(Debug, Clone)]
pub struct ProjectSymbol {
    pub file_path: String, // relative to project root
    pub name: String,      // symbol identifier (fn/struct/class name)
    pub kind: String,      // e.g. "function", "struct", "enum", "trait", "class"
    pub signature: String, // signature text up to (not including) the body
    pub docstring: String, // doc comment / docstring text, empty if none
    pub start_line: u32,   // 1-based first line of the definition
    pub end_line: u32,     // 1-based last line of the definition
    pub language: String,  // "rust" | "typescript" | "javascript" | "python"
    pub content: String,   // signature + body excerpt, truncated to ~500 bytes
}
|
||||
|
||||
/// Extract symbols from all source files in a project.
|
||||
pub fn extract_project_symbols(project_root: &str) -> Vec<ProjectSymbol> {
|
||||
let root = Path::new(project_root);
|
||||
let mut symbols = Vec::new();
|
||||
|
||||
walk_directory(root, root, &mut symbols);
|
||||
debug!(count = symbols.len(), "Extracted project symbols");
|
||||
symbols
|
||||
}
|
||||
|
||||
fn walk_directory(dir: &Path, root: &Path, symbols: &mut Vec<ProjectSymbol>) {
|
||||
let Ok(entries) = std::fs::read_dir(dir) else { return };
|
||||
|
||||
for entry in entries.flatten() {
|
||||
let path = entry.path();
|
||||
let name = entry.file_name().to_string_lossy().to_string();
|
||||
|
||||
// Skip hidden, vendor, target, node_modules, etc.
|
||||
if name.starts_with('.') || name == "target" || name == "vendor"
|
||||
|| name == "node_modules" || name == "dist" || name == "build"
|
||||
|| name == "__pycache__" || name == ".git"
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if path.is_dir() {
|
||||
walk_directory(&path, root, symbols);
|
||||
} else if path.is_file() {
|
||||
let path_str = path.to_string_lossy().to_string();
|
||||
if detect_language(&path_str).is_some() {
|
||||
// Read file (skip large files)
|
||||
if let Ok(content) = std::fs::read_to_string(&path) {
|
||||
if content.len() > 100_000 { continue; } // skip >100KB
|
||||
|
||||
let rel_path = path.strip_prefix(root)
|
||||
.map(|p| p.to_string_lossy().to_string())
|
||||
.unwrap_or(path_str.clone());
|
||||
|
||||
for sym in extract_symbols(&path_str, &content) {
|
||||
// Build content: signature + body up to 500 chars
|
||||
let body_start = content.lines()
|
||||
.take(sym.start_line as usize - 1)
|
||||
.map(|l| l.len() + 1)
|
||||
.sum::<usize>();
|
||||
let body_end = content.lines()
|
||||
.take(sym.end_line as usize)
|
||||
.map(|l| l.len() + 1)
|
||||
.sum::<usize>()
|
||||
.min(content.len());
|
||||
let body = &content[body_start..body_end];
|
||||
let truncated = if body.len() > 500 {
|
||||
format!("{}…", &body[..497])
|
||||
} else {
|
||||
body.to_string()
|
||||
};
|
||||
|
||||
symbols.push(ProjectSymbol {
|
||||
file_path: rel_path.clone(),
|
||||
name: sym.name,
|
||||
kind: sym.kind,
|
||||
signature: sym.signature,
|
||||
docstring: sym.docstring,
|
||||
start_line: sym.start_line,
|
||||
end_line: sym.end_line,
|
||||
language: sym.language,
|
||||
content: truncated,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// An extracted code symbol.
///
/// Produced by the per-language extractors; file-agnostic. `ProjectSymbol`
/// adds the file path and content excerpt when walking a whole project.
#[derive(Debug, Clone)]
pub struct CodeSymbol {
    pub name: String,      // symbol identifier
    pub kind: String,      // "function", "struct", "enum", "trait", "class", "interface", "method"
    pub signature: String, // full signature line
    pub docstring: String, // doc comment / docstring
    pub start_line: u32,   // 1-based
    pub end_line: u32,     // 1-based
    pub language: String,  // language the symbol was parsed from
}
|
||||
|
||||
/// Map a file path's extension to a supported language identifier.
///
/// Returns `None` for extensionless paths and unsupported extensions.
pub fn detect_language(path: &str) -> Option<&'static str> {
    let extension = Path::new(path).extension()?.to_str()?;
    let lang = match extension {
        "rs" => "rust",
        "py" => "python",
        "ts" | "tsx" => "typescript",
        "js" | "jsx" => "javascript",
        _ => return None,
    };
    Some(lang)
}
|
||||
|
||||
/// Extract symbols from a source file's content.
|
||||
pub fn extract_symbols(path: &str, content: &str) -> Vec<CodeSymbol> {
|
||||
let Some(lang) = detect_language(path) else {
|
||||
return Vec::new();
|
||||
};
|
||||
|
||||
match lang {
|
||||
"rust" => extract_rust_symbols(content),
|
||||
"typescript" | "javascript" => extract_ts_symbols(content),
|
||||
"python" => extract_python_symbols(content),
|
||||
_ => Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
// ── Rust ────────────────────────────────────────────────────────────────
|
||||
|
||||
fn extract_rust_symbols(content: &str) -> Vec<CodeSymbol> {
|
||||
let mut parser = tree_sitter::Parser::new();
|
||||
parser.set_language(&tree_sitter_rust::LANGUAGE.into()).ok();
|
||||
|
||||
let Some(tree) = parser.parse(content, None) else {
|
||||
return Vec::new();
|
||||
};
|
||||
|
||||
let mut symbols = Vec::new();
|
||||
let root = tree.root_node();
|
||||
let bytes = content.as_bytes();
|
||||
|
||||
walk_rust_node(root, bytes, content, &mut symbols);
|
||||
symbols
|
||||
}
|
||||
|
||||
fn walk_rust_node(
|
||||
node: tree_sitter::Node,
|
||||
bytes: &[u8],
|
||||
source: &str,
|
||||
symbols: &mut Vec<CodeSymbol>,
|
||||
) {
|
||||
match node.kind() {
|
||||
"function_item" | "function_signature_item" => {
|
||||
if let Some(sym) = extract_rust_function(node, bytes, source) {
|
||||
symbols.push(sym);
|
||||
}
|
||||
}
|
||||
"struct_item" => {
|
||||
if let Some(sym) = extract_rust_type(node, bytes, source, "struct") {
|
||||
symbols.push(sym);
|
||||
}
|
||||
}
|
||||
"enum_item" => {
|
||||
if let Some(sym) = extract_rust_type(node, bytes, source, "enum") {
|
||||
symbols.push(sym);
|
||||
}
|
||||
}
|
||||
"trait_item" => {
|
||||
if let Some(sym) = extract_rust_type(node, bytes, source, "trait") {
|
||||
symbols.push(sym);
|
||||
}
|
||||
}
|
||||
"impl_item" => {
|
||||
// Walk impl methods
|
||||
for i in 0..node.child_count() {
|
||||
if let Some(child) = node.child(i) {
|
||||
if child.kind() == "declaration_list" {
|
||||
walk_rust_node(child, bytes, source, symbols);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
for i in 0..node.child_count() {
|
||||
if let Some(child) = node.child(i) {
|
||||
walk_rust_node(child, bytes, source, symbols);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_rust_function(node: tree_sitter::Node, bytes: &[u8], source: &str) -> Option<CodeSymbol> {
|
||||
let name = node.child_by_field_name("name")?;
|
||||
let name_str = name.utf8_text(bytes).ok()?.to_string();
|
||||
|
||||
// Build signature: everything from start to the opening brace (or end if no body)
|
||||
let start_byte = node.start_byte();
|
||||
let sig_end = find_rust_sig_end(node, source);
|
||||
let signature = source[start_byte..sig_end].trim().to_string();
|
||||
|
||||
// Extract doc comment (line comments starting with /// before the function)
|
||||
let docstring = extract_rust_doc_comment(node, source);
|
||||
|
||||
Some(CodeSymbol {
|
||||
name: name_str,
|
||||
kind: "function".into(),
|
||||
signature,
|
||||
docstring,
|
||||
start_line: node.start_position().row as u32 + 1,
|
||||
end_line: node.end_position().row as u32 + 1,
|
||||
language: "rust".into(),
|
||||
})
|
||||
}
|
||||
|
||||
fn extract_rust_type(node: tree_sitter::Node, bytes: &[u8], source: &str, kind: &str) -> Option<CodeSymbol> {
|
||||
let name = node.child_by_field_name("name")?;
|
||||
let name_str = name.utf8_text(bytes).ok()?.to_string();
|
||||
|
||||
// Signature: first line of the definition
|
||||
let start = node.start_byte();
|
||||
let first_line_end = source[start..].find('\n').map(|i| start + i).unwrap_or(node.end_byte());
|
||||
let signature = source[start..first_line_end].trim().to_string();
|
||||
|
||||
let docstring = extract_rust_doc_comment(node, source);
|
||||
|
||||
Some(CodeSymbol {
|
||||
name: name_str,
|
||||
kind: kind.into(),
|
||||
signature,
|
||||
docstring,
|
||||
start_line: node.start_position().row as u32 + 1,
|
||||
end_line: node.end_position().row as u32 + 1,
|
||||
language: "rust".into(),
|
||||
})
|
||||
}
|
||||
|
||||
fn find_rust_sig_end(node: tree_sitter::Node, source: &str) -> usize {
|
||||
// Find the opening brace
|
||||
for i in 0..node.child_count() {
|
||||
if let Some(child) = node.child(i) {
|
||||
if child.kind() == "block" || child.kind() == "field_declaration_list"
|
||||
|| child.kind() == "enum_variant_list" || child.kind() == "declaration_list"
|
||||
{
|
||||
return child.start_byte();
|
||||
}
|
||||
}
|
||||
}
|
||||
// No body (e.g., trait method signature)
|
||||
node.end_byte().min(source.len())
|
||||
}
|
||||
|
||||
fn extract_rust_doc_comment(node: tree_sitter::Node, source: &str) -> String {
|
||||
let start_line = node.start_position().row;
|
||||
if start_line == 0 {
|
||||
return String::new();
|
||||
}
|
||||
|
||||
let lines: Vec<&str> = source.lines().collect();
|
||||
let mut doc_lines = Vec::new();
|
||||
|
||||
// Walk backwards from the line before the node
|
||||
let mut line_idx = start_line.saturating_sub(1);
|
||||
loop {
|
||||
if line_idx >= lines.len() {
|
||||
break;
|
||||
}
|
||||
let line = lines[line_idx].trim();
|
||||
if line.starts_with("///") {
|
||||
doc_lines.push(line.trim_start_matches("///").trim());
|
||||
} else if line.starts_with("#[") || line.is_empty() {
|
||||
// Skip attributes and blank lines between doc and function
|
||||
if line.is_empty() && !doc_lines.is_empty() {
|
||||
break; // blank line after doc block = stop
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
if line_idx == 0 {
|
||||
break;
|
||||
}
|
||||
line_idx -= 1;
|
||||
}
|
||||
|
||||
doc_lines.reverse();
|
||||
doc_lines.join("\n")
|
||||
}
|
||||
|
||||
// ── TypeScript / JavaScript ─────────────────────────────────────────────
|
||||
|
||||
fn extract_ts_symbols(content: &str) -> Vec<CodeSymbol> {
|
||||
let mut parser = tree_sitter::Parser::new();
|
||||
parser.set_language(&tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into()).ok();
|
||||
|
||||
let Some(tree) = parser.parse(content, None) else {
|
||||
return Vec::new();
|
||||
};
|
||||
|
||||
let mut symbols = Vec::new();
|
||||
walk_ts_node(tree.root_node(), content.as_bytes(), content, &mut symbols);
|
||||
symbols
|
||||
}
|
||||
|
||||
fn walk_ts_node(
|
||||
node: tree_sitter::Node,
|
||||
bytes: &[u8],
|
||||
source: &str,
|
||||
symbols: &mut Vec<CodeSymbol>,
|
||||
) {
|
||||
match node.kind() {
|
||||
"function_declaration" | "method_definition" | "arrow_function" => {
|
||||
if let Some(name) = node.child_by_field_name("name") {
|
||||
let name_str = name.utf8_text(bytes).unwrap_or("").to_string();
|
||||
if !name_str.is_empty() {
|
||||
let start = node.start_byte();
|
||||
let first_line_end = source[start..].find('\n').map(|i| start + i).unwrap_or(node.end_byte());
|
||||
symbols.push(CodeSymbol {
|
||||
name: name_str,
|
||||
kind: "function".into(),
|
||||
signature: source[start..first_line_end].trim().to_string(),
|
||||
docstring: String::new(), // TODO: JSDoc extraction
|
||||
start_line: node.start_position().row as u32 + 1,
|
||||
end_line: node.end_position().row as u32 + 1,
|
||||
language: "typescript".into(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
"class_declaration" | "interface_declaration" | "type_alias_declaration" | "enum_declaration" => {
|
||||
if let Some(name) = node.child_by_field_name("name") {
|
||||
let name_str = name.utf8_text(bytes).unwrap_or("").to_string();
|
||||
let kind = match node.kind() {
|
||||
"class_declaration" => "class",
|
||||
"interface_declaration" => "interface",
|
||||
"enum_declaration" => "enum",
|
||||
_ => "type",
|
||||
};
|
||||
let start = node.start_byte();
|
||||
let first_line_end = source[start..].find('\n').map(|i| start + i).unwrap_or(node.end_byte());
|
||||
symbols.push(CodeSymbol {
|
||||
name: name_str,
|
||||
kind: kind.into(),
|
||||
signature: source[start..first_line_end].trim().to_string(),
|
||||
docstring: String::new(),
|
||||
start_line: node.start_position().row as u32 + 1,
|
||||
end_line: node.end_position().row as u32 + 1,
|
||||
language: "typescript".into(),
|
||||
});
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
for i in 0..node.child_count() {
|
||||
if let Some(child) = node.child(i) {
|
||||
walk_ts_node(child, bytes, source, symbols);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Python ──────────────────────────────────────────────────────────────
|
||||
|
||||
fn extract_python_symbols(content: &str) -> Vec<CodeSymbol> {
|
||||
let mut parser = tree_sitter::Parser::new();
|
||||
parser.set_language(&tree_sitter_python::LANGUAGE.into()).ok();
|
||||
|
||||
let Some(tree) = parser.parse(content, None) else {
|
||||
return Vec::new();
|
||||
};
|
||||
|
||||
let mut symbols = Vec::new();
|
||||
walk_python_node(tree.root_node(), content.as_bytes(), content, &mut symbols);
|
||||
symbols
|
||||
}
|
||||
|
||||
fn walk_python_node(
|
||||
node: tree_sitter::Node,
|
||||
bytes: &[u8],
|
||||
source: &str,
|
||||
symbols: &mut Vec<CodeSymbol>,
|
||||
) {
|
||||
match node.kind() {
|
||||
"function_definition" => {
|
||||
if let Some(name) = node.child_by_field_name("name") {
|
||||
let name_str = name.utf8_text(bytes).unwrap_or("").to_string();
|
||||
let start = node.start_byte();
|
||||
let first_line_end = source[start..].find('\n').map(|i| start + i).unwrap_or(node.end_byte());
|
||||
let docstring = extract_python_docstring(node, bytes);
|
||||
symbols.push(CodeSymbol {
|
||||
name: name_str,
|
||||
kind: "function".into(),
|
||||
signature: source[start..first_line_end].trim().to_string(),
|
||||
docstring,
|
||||
start_line: node.start_position().row as u32 + 1,
|
||||
end_line: node.end_position().row as u32 + 1,
|
||||
language: "python".into(),
|
||||
});
|
||||
}
|
||||
}
|
||||
"class_definition" => {
|
||||
if let Some(name) = node.child_by_field_name("name") {
|
||||
let name_str = name.utf8_text(bytes).unwrap_or("").to_string();
|
||||
let start = node.start_byte();
|
||||
let first_line_end = source[start..].find('\n').map(|i| start + i).unwrap_or(node.end_byte());
|
||||
let docstring = extract_python_docstring(node, bytes);
|
||||
symbols.push(CodeSymbol {
|
||||
name: name_str,
|
||||
kind: "class".into(),
|
||||
signature: source[start..first_line_end].trim().to_string(),
|
||||
docstring,
|
||||
start_line: node.start_position().row as u32 + 1,
|
||||
end_line: node.end_position().row as u32 + 1,
|
||||
language: "python".into(),
|
||||
});
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
for i in 0..node.child_count() {
|
||||
if let Some(child) = node.child(i) {
|
||||
walk_python_node(child, bytes, source, symbols);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_python_docstring(node: tree_sitter::Node, bytes: &[u8]) -> String {
|
||||
// Python docstrings are the first expression_statement in the body
|
||||
if let Some(body) = node.child_by_field_name("body") {
|
||||
if let Some(first_stmt) = body.child(0) {
|
||||
if first_stmt.kind() == "expression_statement" {
|
||||
if let Some(expr) = first_stmt.child(0) {
|
||||
if expr.kind() == "string" {
|
||||
let text = expr.utf8_text(bytes).unwrap_or("");
|
||||
// Strip triple quotes
|
||||
let trimmed = text
|
||||
.trim_start_matches("\"\"\"")
|
||||
.trim_start_matches("'''")
|
||||
.trim_end_matches("\"\"\"")
|
||||
.trim_end_matches("'''")
|
||||
.trim();
|
||||
return trimmed.to_string();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
String::new()
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Language detection is driven purely by file extension.
    #[test]
    fn test_detect_language() {
        assert_eq!(detect_language("src/main.rs"), Some("rust"));
        assert_eq!(detect_language("app.ts"), Some("typescript"));
        assert_eq!(detect_language("app.tsx"), Some("typescript"));
        assert_eq!(detect_language("script.py"), Some("python"));
        assert_eq!(detect_language("script.js"), Some("javascript"));
        assert_eq!(detect_language("data.json"), None);
        assert_eq!(detect_language("README.md"), None);
    }

    // A `///` doc comment above a fn should land in `docstring`, and the
    // signature should stop before the body.
    #[test]
    fn test_extract_rust_function() {
        let source = r#"
/// Generate a response.
pub async fn generate(&self, req: &GenerateRequest) -> Option<String> {
    self.run_and_emit(req).await
}
"#;
        let symbols = extract_rust_symbols(source);
        assert!(!symbols.is_empty(), "Should extract at least one symbol");

        let func = &symbols[0];
        assert_eq!(func.name, "generate");
        assert_eq!(func.kind, "function");
        assert!(func.signature.contains("pub async fn generate"));
        assert!(func.docstring.contains("Generate a response"));
        assert_eq!(func.language, "rust");
    }

    #[test]
    fn test_extract_rust_struct() {
        let source = r#"
/// A request to generate.
pub struct GenerateRequest {
    pub text: String,
    pub user_id: String,
}
"#;
        let symbols = extract_rust_symbols(source);
        let structs: Vec<_> = symbols.iter().filter(|s| s.kind == "struct").collect();
        assert!(!structs.is_empty());
        assert_eq!(structs[0].name, "GenerateRequest");
        assert!(structs[0].docstring.contains("request to generate"));
    }

    #[test]
    fn test_extract_rust_enum() {
        let source = r#"
/// Whether server or client.
pub enum ToolSide {
    Server,
    Client,
}
"#;
        let symbols = extract_rust_symbols(source);
        let enums: Vec<_> = symbols.iter().filter(|s| s.kind == "enum").collect();
        assert!(!enums.is_empty());
        assert_eq!(enums[0].name, "ToolSide");
    }

    #[test]
    fn test_extract_rust_trait() {
        let source = r#"
pub trait Executor {
    fn execute(&self, args: &str) -> String;
}
"#;
        let symbols = extract_rust_symbols(source);
        let traits: Vec<_> = symbols.iter().filter(|s| s.kind == "trait").collect();
        assert!(!traits.is_empty());
        assert_eq!(traits[0].name, "Executor");
    }

    // Methods inside an impl block are reached via the declaration_list
    // descent in walk_rust_node.
    #[test]
    fn test_extract_rust_impl_methods() {
        let source = r#"
impl Orchestrator {
    /// Create new.
    pub fn new(config: Config) -> Self {
        Self { config }
    }

    /// Subscribe to events.
    pub fn subscribe(&self) -> Receiver {
        self.tx.subscribe()
    }
}
"#;
        let symbols = extract_rust_symbols(source);
        let fns: Vec<_> = symbols.iter().filter(|s| s.kind == "function").collect();
        assert!(fns.len() >= 2, "Should find impl methods, got {}", fns.len());
        let names: Vec<&str> = fns.iter().map(|s| s.name.as_str()).collect();
        assert!(names.contains(&"new"));
        assert!(names.contains(&"subscribe"));
    }

    #[test]
    fn test_extract_ts_function() {
        let source = r#"
function greet(name: string): string {
    return `Hello, ${name}`;
}
"#;
        let symbols = extract_ts_symbols(source);
        assert!(!symbols.is_empty());
        assert_eq!(symbols[0].name, "greet");
        assert_eq!(symbols[0].kind, "function");
    }

    #[test]
    fn test_extract_ts_class() {
        let source = r#"
class UserService {
    constructor(private db: Database) {}

    async getUser(id: string): Promise<User> {
        return this.db.find(id);
    }
}
"#;
        let symbols = extract_ts_symbols(source);
        let classes: Vec<_> = symbols.iter().filter(|s| s.kind == "class").collect();
        assert!(!classes.is_empty());
        assert_eq!(classes[0].name, "UserService");
    }

    #[test]
    fn test_extract_ts_interface() {
        let source = r#"
interface User {
    id: string;
    name: string;
    email?: string;
}
"#;
        let symbols = extract_ts_symbols(source);
        let ifaces: Vec<_> = symbols.iter().filter(|s| s.kind == "interface").collect();
        assert!(!ifaces.is_empty());
        assert_eq!(ifaces[0].name, "User");
    }

    // Python docstrings are recognized as the first statement of the body.
    #[test]
    fn test_extract_python_function() {
        let source = r#"
def process_data(items: list[str]) -> dict:
    """Process a list of items into a dictionary."""
    return {item: len(item) for item in items}
"#;
        let symbols = extract_python_symbols(source);
        assert!(!symbols.is_empty());
        assert_eq!(symbols[0].name, "process_data");
        assert_eq!(symbols[0].kind, "function");
        assert!(symbols[0].docstring.contains("Process a list"));
    }

    #[test]
    fn test_extract_python_class() {
        let source = r#"
class DataProcessor:
    """Processes data from various sources."""

    def __init__(self, config):
        self.config = config

    def run(self):
        pass
"#;
        let symbols = extract_python_symbols(source);
        let classes: Vec<_> = symbols.iter().filter(|s| s.kind == "class").collect();
        assert!(!classes.is_empty());
        assert_eq!(classes[0].name, "DataProcessor");
        assert!(classes[0].docstring.contains("Processes data"));
    }

    #[test]
    fn test_extract_symbols_unknown_language() {
        let symbols = extract_symbols("data.json", "{}");
        assert!(symbols.is_empty());
    }

    #[test]
    fn test_extract_symbols_empty_file() {
        let symbols = extract_symbols("empty.rs", "");
        assert!(symbols.is_empty());
    }

    // tree-sitter rows are 0-based; the extractors must add 1.
    #[test]
    fn test_line_numbers_are_1_based() {
        let source = "fn first() {}\nfn second() {}\nfn third() {}";
        let symbols = extract_rust_symbols(source);
        assert!(symbols.len() >= 3);
        assert_eq!(symbols[0].start_line, 1);
        assert_eq!(symbols[1].start_line, 2);
        assert_eq!(symbols[2].start_line, 3);
    }
}
|
||||
Reference in New Issue
Block a user