//! Symbol extraction from source code using tree-sitter. //! //! Extracts function signatures, struct/enum/trait definitions, and //! docstrings from Rust, TypeScript, and Python files. These symbols //! are sent to Sol for indexing in the code search index. use std::path::Path; use tracing::debug; /// An extracted code symbol with file context. #[derive(Debug, Clone)] pub struct ProjectSymbol { pub file_path: String, // relative to project root pub name: String, pub kind: String, pub signature: String, pub docstring: String, pub start_line: u32, pub end_line: u32, pub language: String, pub content: String, } /// Extract symbols from all source files in a project. pub fn extract_project_symbols(project_root: &str) -> Vec { let root = Path::new(project_root); let mut symbols = Vec::new(); walk_directory(root, root, &mut symbols); debug!(count = symbols.len(), "Extracted project symbols"); symbols } fn walk_directory(dir: &Path, root: &Path, symbols: &mut Vec) { let Ok(entries) = std::fs::read_dir(dir) else { return }; for entry in entries.flatten() { let path = entry.path(); let name = entry.file_name().to_string_lossy().to_string(); // Skip hidden, vendor, target, node_modules, etc. if name.starts_with('.') || name == "target" || name == "vendor" || name == "node_modules" || name == "dist" || name == "build" || name == "__pycache__" || name == ".git" { continue; } if path.is_dir() { walk_directory(&path, root, symbols); } else if path.is_file() { let path_str = path.to_string_lossy().to_string(); if detect_language(&path_str).is_some() { // Read file (skip large files) if let Ok(content) = std::fs::read_to_string(&path) { if content.len() > 100_000 { continue; } // skip >100KB let rel_path = path.strip_prefix(root) .map(|p| p.to_string_lossy().to_string()) .unwrap_or(path_str.clone()); for sym in extract_symbols(&path_str, &content) { // Build content: signature + body up to 500 chars let body_start = content.lines() .take(sym.start_line as usize - 1) .map(|l| l.len() + 1) .sum::(); let body_end = content.lines() .take(sym.end_line as usize) .map(|l| l.len() + 1) .sum::() .min(content.len()); let body = &content[body_start..body_end]; let truncated = if body.len() > 500 { format!("{}…", &body[..497]) } else { body.to_string() }; symbols.push(ProjectSymbol { file_path: rel_path.clone(), name: sym.name, kind: sym.kind, signature: sym.signature, docstring: sym.docstring, start_line: sym.start_line, end_line: sym.end_line, language: sym.language, content: truncated, }); } } } } } } /// An extracted code symbol. #[derive(Debug, Clone)] pub struct CodeSymbol { pub name: String, pub kind: String, // "function", "struct", "enum", "trait", "class", "interface", "method" pub signature: String, // full signature line pub docstring: String, // doc comment / docstring pub start_line: u32, // 1-based pub end_line: u32, // 1-based pub language: String, } /// Detect language from file extension. pub fn detect_language(path: &str) -> Option<&'static str> { let ext = Path::new(path).extension()?.to_str()?; match ext { "rs" => Some("rust"), "ts" | "tsx" => Some("typescript"), "js" | "jsx" => Some("javascript"), "py" => Some("python"), _ => None, } } /// Extract symbols from a source file's content. pub fn extract_symbols(path: &str, content: &str) -> Vec { let Some(lang) = detect_language(path) else { return Vec::new(); }; match lang { "rust" => extract_rust_symbols(content), "typescript" | "javascript" => extract_ts_symbols(content), "python" => extract_python_symbols(content), _ => Vec::new(), } } // ── Rust ──────────────────────────────────────────────────────────────── fn extract_rust_symbols(content: &str) -> Vec { let mut parser = tree_sitter::Parser::new(); parser.set_language(&tree_sitter_rust::LANGUAGE.into()).ok(); let Some(tree) = parser.parse(content, None) else { return Vec::new(); }; let mut symbols = Vec::new(); let root = tree.root_node(); let bytes = content.as_bytes(); walk_rust_node(root, bytes, content, &mut symbols); symbols } fn walk_rust_node( node: tree_sitter::Node, bytes: &[u8], source: &str, symbols: &mut Vec, ) { match node.kind() { "function_item" | "function_signature_item" => { if let Some(sym) = extract_rust_function(node, bytes, source) { symbols.push(sym); } } "struct_item" => { if let Some(sym) = extract_rust_type(node, bytes, source, "struct") { symbols.push(sym); } } "enum_item" => { if let Some(sym) = extract_rust_type(node, bytes, source, "enum") { symbols.push(sym); } } "trait_item" => { if let Some(sym) = extract_rust_type(node, bytes, source, "trait") { symbols.push(sym); } } "impl_item" => { // Walk impl methods for i in 0..node.child_count() { if let Some(child) = node.child(i) { if child.kind() == "declaration_list" { walk_rust_node(child, bytes, source, symbols); } } } } _ => { for i in 0..node.child_count() { if let Some(child) = node.child(i) { walk_rust_node(child, bytes, source, symbols); } } } } } fn extract_rust_function(node: tree_sitter::Node, bytes: &[u8], source: &str) -> Option { let name = node.child_by_field_name("name")?; let name_str = name.utf8_text(bytes).ok()?.to_string(); // Build signature: everything from start to the opening brace (or end if no body) let start_byte = node.start_byte(); let sig_end = find_rust_sig_end(node, source); let signature = source[start_byte..sig_end].trim().to_string(); // Extract doc comment (line comments starting with /// before the function) let docstring = extract_rust_doc_comment(node, source); Some(CodeSymbol { name: name_str, kind: "function".into(), signature, docstring, start_line: node.start_position().row as u32 + 1, end_line: node.end_position().row as u32 + 1, language: "rust".into(), }) } fn extract_rust_type(node: tree_sitter::Node, bytes: &[u8], source: &str, kind: &str) -> Option { let name = node.child_by_field_name("name")?; let name_str = name.utf8_text(bytes).ok()?.to_string(); // Signature: first line of the definition let start = node.start_byte(); let first_line_end = source[start..].find('\n').map(|i| start + i).unwrap_or(node.end_byte()); let signature = source[start..first_line_end].trim().to_string(); let docstring = extract_rust_doc_comment(node, source); Some(CodeSymbol { name: name_str, kind: kind.into(), signature, docstring, start_line: node.start_position().row as u32 + 1, end_line: node.end_position().row as u32 + 1, language: "rust".into(), }) } fn find_rust_sig_end(node: tree_sitter::Node, source: &str) -> usize { // Find the opening brace for i in 0..node.child_count() { if let Some(child) = node.child(i) { if child.kind() == "block" || child.kind() == "field_declaration_list" || child.kind() == "enum_variant_list" || child.kind() == "declaration_list" { return child.start_byte(); } } } // No body (e.g., trait method signature) node.end_byte().min(source.len()) } fn extract_rust_doc_comment(node: tree_sitter::Node, source: &str) -> String { let start_line = node.start_position().row; if start_line == 0 { return String::new(); } let lines: Vec<&str> = source.lines().collect(); let mut doc_lines = Vec::new(); // Walk backwards from the line before the node let mut line_idx = start_line.saturating_sub(1); loop { if line_idx >= lines.len() { break; } let line = lines[line_idx].trim(); if line.starts_with("///") { doc_lines.push(line.trim_start_matches("///").trim()); } else if line.starts_with("#[") || line.is_empty() { // Skip attributes and blank lines between doc and function if line.is_empty() && !doc_lines.is_empty() { break; // blank line after doc block = stop } } else { break; } if line_idx == 0 { break; } line_idx -= 1; } doc_lines.reverse(); doc_lines.join("\n") } // ── TypeScript / JavaScript ───────────────────────────────────────────── fn extract_ts_symbols(content: &str) -> Vec { let mut parser = tree_sitter::Parser::new(); parser.set_language(&tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into()).ok(); let Some(tree) = parser.parse(content, None) else { return Vec::new(); }; let mut symbols = Vec::new(); walk_ts_node(tree.root_node(), content.as_bytes(), content, &mut symbols); symbols } fn walk_ts_node( node: tree_sitter::Node, bytes: &[u8], source: &str, symbols: &mut Vec, ) { match node.kind() { "function_declaration" | "method_definition" | "arrow_function" => { if let Some(name) = node.child_by_field_name("name") { let name_str = name.utf8_text(bytes).unwrap_or("").to_string(); if !name_str.is_empty() { let start = node.start_byte(); let first_line_end = source[start..].find('\n').map(|i| start + i).unwrap_or(node.end_byte()); symbols.push(CodeSymbol { name: name_str, kind: "function".into(), signature: source[start..first_line_end].trim().to_string(), docstring: String::new(), // TODO: JSDoc extraction start_line: node.start_position().row as u32 + 1, end_line: node.end_position().row as u32 + 1, language: "typescript".into(), }); } } } "class_declaration" | "interface_declaration" | "type_alias_declaration" | "enum_declaration" => { if let Some(name) = node.child_by_field_name("name") { let name_str = name.utf8_text(bytes).unwrap_or("").to_string(); let kind = match node.kind() { "class_declaration" => "class", "interface_declaration" => "interface", "enum_declaration" => "enum", _ => "type", }; let start = node.start_byte(); let first_line_end = source[start..].find('\n').map(|i| start + i).unwrap_or(node.end_byte()); symbols.push(CodeSymbol { name: name_str, kind: kind.into(), signature: source[start..first_line_end].trim().to_string(), docstring: String::new(), start_line: node.start_position().row as u32 + 1, end_line: node.end_position().row as u32 + 1, language: "typescript".into(), }); } } _ => {} } for i in 0..node.child_count() { if let Some(child) = node.child(i) { walk_ts_node(child, bytes, source, symbols); } } } // ── Python ────────────────────────────────────────────────────────────── fn extract_python_symbols(content: &str) -> Vec { let mut parser = tree_sitter::Parser::new(); parser.set_language(&tree_sitter_python::LANGUAGE.into()).ok(); let Some(tree) = parser.parse(content, None) else { return Vec::new(); }; let mut symbols = Vec::new(); walk_python_node(tree.root_node(), content.as_bytes(), content, &mut symbols); symbols } fn walk_python_node( node: tree_sitter::Node, bytes: &[u8], source: &str, symbols: &mut Vec, ) { match node.kind() { "function_definition" => { if let Some(name) = node.child_by_field_name("name") { let name_str = name.utf8_text(bytes).unwrap_or("").to_string(); let start = node.start_byte(); let first_line_end = source[start..].find('\n').map(|i| start + i).unwrap_or(node.end_byte()); let docstring = extract_python_docstring(node, bytes); symbols.push(CodeSymbol { name: name_str, kind: "function".into(), signature: source[start..first_line_end].trim().to_string(), docstring, start_line: node.start_position().row as u32 + 1, end_line: node.end_position().row as u32 + 1, language: "python".into(), }); } } "class_definition" => { if let Some(name) = node.child_by_field_name("name") { let name_str = name.utf8_text(bytes).unwrap_or("").to_string(); let start = node.start_byte(); let first_line_end = source[start..].find('\n').map(|i| start + i).unwrap_or(node.end_byte()); let docstring = extract_python_docstring(node, bytes); symbols.push(CodeSymbol { name: name_str, kind: "class".into(), signature: source[start..first_line_end].trim().to_string(), docstring, start_line: node.start_position().row as u32 + 1, end_line: node.end_position().row as u32 + 1, language: "python".into(), }); } } _ => {} } for i in 0..node.child_count() { if let Some(child) = node.child(i) { walk_python_node(child, bytes, source, symbols); } } } fn extract_python_docstring(node: tree_sitter::Node, bytes: &[u8]) -> String { // Python docstrings are the first expression_statement in the body if let Some(body) = node.child_by_field_name("body") { if let Some(first_stmt) = body.child(0) { if first_stmt.kind() == "expression_statement" { if let Some(expr) = first_stmt.child(0) { if expr.kind() == "string" { let text = expr.utf8_text(bytes).unwrap_or(""); // Strip triple quotes let trimmed = text .trim_start_matches("\"\"\"") .trim_start_matches("'''") .trim_end_matches("\"\"\"") .trim_end_matches("'''") .trim(); return trimmed.to_string(); } } } } } String::new() } #[cfg(test)] mod tests { use super::*; #[test] fn test_detect_language() { assert_eq!(detect_language("src/main.rs"), Some("rust")); assert_eq!(detect_language("app.ts"), Some("typescript")); assert_eq!(detect_language("app.tsx"), Some("typescript")); assert_eq!(detect_language("script.py"), Some("python")); assert_eq!(detect_language("script.js"), Some("javascript")); assert_eq!(detect_language("data.json"), None); assert_eq!(detect_language("README.md"), None); } #[test] fn test_extract_rust_function() { let source = r#" /// Generate a response. pub async fn generate(&self, req: &GenerateRequest) -> Option { self.run_and_emit(req).await } "#; let symbols = extract_rust_symbols(source); assert!(!symbols.is_empty(), "Should extract at least one symbol"); let func = &symbols[0]; assert_eq!(func.name, "generate"); assert_eq!(func.kind, "function"); assert!(func.signature.contains("pub async fn generate")); assert!(func.docstring.contains("Generate a response")); assert_eq!(func.language, "rust"); } #[test] fn test_extract_rust_struct() { let source = r#" /// A request to generate. pub struct GenerateRequest { pub text: String, pub user_id: String, } "#; let symbols = extract_rust_symbols(source); let structs: Vec<_> = symbols.iter().filter(|s| s.kind == "struct").collect(); assert!(!structs.is_empty()); assert_eq!(structs[0].name, "GenerateRequest"); assert!(structs[0].docstring.contains("request to generate")); } #[test] fn test_extract_rust_enum() { let source = r#" /// Whether server or client. pub enum ToolSide { Server, Client, } "#; let symbols = extract_rust_symbols(source); let enums: Vec<_> = symbols.iter().filter(|s| s.kind == "enum").collect(); assert!(!enums.is_empty()); assert_eq!(enums[0].name, "ToolSide"); } #[test] fn test_extract_rust_trait() { let source = r#" pub trait Executor { fn execute(&self, args: &str) -> String; } "#; let symbols = extract_rust_symbols(source); let traits: Vec<_> = symbols.iter().filter(|s| s.kind == "trait").collect(); assert!(!traits.is_empty()); assert_eq!(traits[0].name, "Executor"); } #[test] fn test_extract_rust_impl_methods() { let source = r#" impl Orchestrator { /// Create new. pub fn new(config: Config) -> Self { Self { config } } /// Subscribe to events. pub fn subscribe(&self) -> Receiver { self.tx.subscribe() } } "#; let symbols = extract_rust_symbols(source); let fns: Vec<_> = symbols.iter().filter(|s| s.kind == "function").collect(); assert!(fns.len() >= 2, "Should find impl methods, got {}", fns.len()); let names: Vec<&str> = fns.iter().map(|s| s.name.as_str()).collect(); assert!(names.contains(&"new")); assert!(names.contains(&"subscribe")); } #[test] fn test_extract_ts_function() { let source = r#" function greet(name: string): string { return `Hello, ${name}`; } "#; let symbols = extract_ts_symbols(source); assert!(!symbols.is_empty()); assert_eq!(symbols[0].name, "greet"); assert_eq!(symbols[0].kind, "function"); } #[test] fn test_extract_ts_class() { let source = r#" class UserService { constructor(private db: Database) {} async getUser(id: string): Promise { return this.db.find(id); } } "#; let symbols = extract_ts_symbols(source); let classes: Vec<_> = symbols.iter().filter(|s| s.kind == "class").collect(); assert!(!classes.is_empty()); assert_eq!(classes[0].name, "UserService"); } #[test] fn test_extract_ts_interface() { let source = r#" interface User { id: string; name: string; email?: string; } "#; let symbols = extract_ts_symbols(source); let ifaces: Vec<_> = symbols.iter().filter(|s| s.kind == "interface").collect(); assert!(!ifaces.is_empty()); assert_eq!(ifaces[0].name, "User"); } #[test] fn test_extract_python_function() { let source = r#" def process_data(items: list[str]) -> dict: """Process a list of items into a dictionary.""" return {item: len(item) for item in items} "#; let symbols = extract_python_symbols(source); assert!(!symbols.is_empty()); assert_eq!(symbols[0].name, "process_data"); assert_eq!(symbols[0].kind, "function"); assert!(symbols[0].docstring.contains("Process a list")); } #[test] fn test_extract_python_class() { let source = r#" class DataProcessor: """Processes data from various sources.""" def __init__(self, config): self.config = config def run(self): pass "#; let symbols = extract_python_symbols(source); let classes: Vec<_> = symbols.iter().filter(|s| s.kind == "class").collect(); assert!(!classes.is_empty()); assert_eq!(classes[0].name, "DataProcessor"); assert!(classes[0].docstring.contains("Processes data")); } #[test] fn test_extract_symbols_unknown_language() { let symbols = extract_symbols("data.json", "{}"); assert!(symbols.is_empty()); } #[test] fn test_extract_symbols_empty_file() { let symbols = extract_symbols("empty.rs", ""); assert!(symbols.is_empty()); } #[test] fn test_line_numbers_are_1_based() { let source = "fn first() {}\nfn second() {}\nfn third() {}"; let symbols = extract_rust_symbols(source); assert!(symbols.len() >= 3); assert_eq!(symbols[0].start_line, 1); assert_eq!(symbols[1].start_line, 2); assert_eq!(symbols[2].start_line, 3); } }