Files
cli/sunbeam/src/code/symbols.rs
Sienna Meridian Satterwhite 73d7d6c15b feat(code): tree-sitter symbol extraction + auto-indexing
Symbol extraction (symbols.rs):
- tree-sitter parsers for Rust, TypeScript, Python
- Extracts: functions, structs, enums, traits, classes, interfaces
- Signatures, docstrings, line ranges for each symbol
- extract_project_symbols() walks project directory
- Skips hidden/vendor/target/node_modules, files >100KB

Proto: IndexSymbols + SymbolEntry messages for client→server symbol relay

Client: after SessionReady, extracts symbols and sends IndexSymbols
to Sol for indexing into the code search index.

14 unit tests for symbol extraction across Rust/TS/Python.
2026-03-24 00:42:03 +00:00

660 lines
22 KiB
Rust

//! Symbol extraction from source code using tree-sitter.
//!
//! Extracts function signatures, struct/enum/trait/class/interface
//! definitions, and docstrings from Rust, TypeScript/JavaScript, and Python
//! files. These symbols are sent to Sol for indexing in the code search index.
use std::path::Path;
use tracing::debug;
/// A symbol extracted from a project file, together with the file it came from
/// and a snippet of its source text.
#[derive(Clone, Debug)]
pub struct ProjectSymbol {
    /// Path of the source file relative to the project root (the full path is
    /// used when it cannot be made relative).
    pub file_path: String,
    /// Identifier of the symbol.
    pub name: String,
    /// Symbol category, e.g. `"function"`, `"struct"` or `"class"`.
    pub kind: String,
    /// Signature text of the definition.
    pub signature: String,
    /// Doc comment / docstring attached to the symbol; empty when none was found.
    pub docstring: String,
    /// First line of the symbol (1-based).
    pub start_line: u32,
    /// Last line of the symbol (1-based).
    pub end_line: u32,
    /// Language label assigned by the extractor that produced the symbol.
    pub language: String,
    /// Source text of the symbol's line range, at most 500 bytes long.
    pub content: String,
}
/// Collect the symbols of every supported source file below `project_root`.
///
/// The directory tree is traversed by `walk_directory`; the number of symbols
/// found is reported through a `debug!` event before the list is returned.
pub fn extract_project_symbols(project_root: &str) -> Vec<ProjectSymbol> {
    let mut collected: Vec<ProjectSymbol> = Vec::new();
    let base = Path::new(project_root);
    walk_directory(base, base, &mut collected);
    debug!(count = collected.len(), "Extracted project symbols");
    collected
}
/// Recursively walk `dir`, appending the symbols of every supported source file
/// to `symbols`.
///
/// * `dir` – directory to scan on this call.
/// * `root` – project root; file paths stored on the symbols are made relative to it.
/// * `symbols` – output list the extracted symbols are pushed onto.
///
/// Entries whose name starts with `.` (which covers `.git`) or equals one of
/// `target`, `vendor`, `node_modules`, `dist`, `build`, `__pycache__` are
/// skipped, as are files larger than 100 000 bytes and files that cannot be
/// read as UTF-8. Unreadable directories are silently ignored.
fn walk_directory(dir: &Path, root: &Path, symbols: &mut Vec<ProjectSymbol>) {
    /// Files bigger than this many bytes are not parsed.
    const MAX_SOURCE_BYTES: usize = 100_000;

    let Ok(entries) = std::fs::read_dir(dir) else { return };
    for entry in entries.flatten() {
        let path = entry.path();
        let name = entry.file_name().to_string_lossy().to_string();
        // Skip hidden entries and well-known build / dependency directories.
        if name.starts_with('.')
            || matches!(
                name.as_str(),
                "target" | "vendor" | "node_modules" | "dist" | "build" | "__pycache__"
            )
        {
            continue;
        }
        if path.is_dir() {
            walk_directory(&path, root, symbols);
            continue;
        }
        if !path.is_file() {
            continue;
        }
        let path_str = path.to_string_lossy().to_string();
        if detect_language(&path_str).is_none() {
            continue;
        }
        // Consult the metadata first so an oversized file is never loaded into
        // memory just to be thrown away. If the metadata is unavailable we fall
        // through and rely on the post-read length check below.
        let too_large_on_disk = std::fs::metadata(&path)
            .map(|meta| meta.len() > MAX_SOURCE_BYTES as u64)
            .unwrap_or(false);
        if too_large_on_disk {
            continue;
        }
        let Ok(content) = std::fs::read_to_string(&path) else { continue };
        if content.len() > MAX_SOURCE_BYTES {
            continue;
        }
        let rel_path = path
            .strip_prefix(root)
            .map(|p| p.to_string_lossy().to_string())
            .unwrap_or_else(|_| path_str.clone());
        // Computed once per file instead of re-scanning the text for every symbol.
        let line_starts = line_start_offsets(&content);
        for sym in extract_symbols(&path_str, &content) {
            let snippet = symbol_snippet(&content, &line_starts, sym.start_line, sym.end_line);
            symbols.push(ProjectSymbol {
                file_path: rel_path.clone(),
                name: sym.name,
                kind: sym.kind,
                signature: sym.signature,
                docstring: sym.docstring,
                start_line: sym.start_line,
                end_line: sym.end_line,
                language: sym.language,
                content: snippet,
            });
        }
    }
}

/// Byte offset at which each line of `content` starts (index 0 is line 1).
///
/// Offsets are taken from the actual positions of `\n`, so they stay correct
/// for CRLF files and always fall on UTF-8 character boundaries.
fn line_start_offsets(content: &str) -> Vec<usize> {
    std::iter::once(0)
        .chain(content.match_indices('\n').map(|(idx, _)| idx + 1))
        .collect()
}

/// Text of the 1-based, inclusive line range `start_line..=end_line`, cut down
/// to 497 bytes when it is longer than 500 bytes.
///
/// The cut is moved back to the nearest character boundary so multi-byte UTF-8
/// text cannot trigger a slicing panic. No ellipsis marker is appended.
fn symbol_snippet(content: &str, line_starts: &[usize], start_line: u32, end_line: u32) -> String {
    const MAX_SNIPPET_BYTES: usize = 500;
    const TRUNCATED_SNIPPET_BYTES: usize = 497;

    // Lines past the end of the file resolve to the end of the text.
    let offset_of = |line_index: usize| {
        line_starts
            .get(line_index)
            .copied()
            .unwrap_or(content.len())
    };
    let begin = offset_of((start_line as usize).saturating_sub(1));
    let end = offset_of(end_line as usize).max(begin);
    let body = &content[begin..end];
    if body.len() <= MAX_SNIPPET_BYTES {
        return body.to_string();
    }
    let mut cut = TRUNCATED_SNIPPET_BYTES;
    while !body.is_char_boundary(cut) {
        cut -= 1;
    }
    body[..cut].to_string()
}
/// A single definition found in one source file.
#[derive(Clone, Debug)]
pub struct CodeSymbol {
    /// Identifier of the definition.
    pub name: String,
    /// Category of the definition. The extractors in this file emit
    /// `"function"`, `"struct"`, `"enum"`, `"trait"`, `"class"`,
    /// `"interface"` and `"type"`.
    pub kind: String,
    /// Signature text of the definition.
    pub signature: String,
    /// Doc comment / docstring; empty when none was found.
    pub docstring: String,
    /// First line of the definition (1-based).
    pub start_line: u32,
    /// Last line of the definition (1-based).
    pub end_line: u32,
    /// Language label set by the extractor that produced the symbol.
    pub language: String,
}
/// Map a file path to a language label based on its extension.
///
/// Returns `Some("rust")`, `Some("typescript")`, `Some("javascript")` or
/// `Some("python")` for the extensions listed in the table below, and `None`
/// for anything else (including paths without an extension). The comparison
/// is case-sensitive.
pub fn detect_language(path: &str) -> Option<&'static str> {
    // (language label, extensions that select it)
    const LANGUAGES: [(&str, &[&str]); 4] = [
        ("rust", &["rs"]),
        ("typescript", &["ts", "tsx"]),
        ("javascript", &["js", "jsx"]),
        ("python", &["py"]),
    ];

    let extension = Path::new(path).extension().and_then(|ext| ext.to_str())?;
    LANGUAGES
        .iter()
        .find(|entry| entry.1.contains(&extension))
        .map(|entry| entry.0)
}
/// Extract the symbols contained in `content`, choosing the extractor from the
/// extension of `path`.
///
/// JavaScript sources go through the same extractor as TypeScript. An
/// unsupported extension yields an empty list.
pub fn extract_symbols(path: &str, content: &str) -> Vec<CodeSymbol> {
    match detect_language(path) {
        Some("rust") => extract_rust_symbols(content),
        Some("typescript" | "javascript") => extract_ts_symbols(content),
        Some("python") => extract_python_symbols(content),
        _ => Vec::new(),
    }
}
// ── Rust ────────────────────────────────────────────────────────────────
/// Parse `content` as Rust and collect its symbols.
///
/// Returns an empty list when the parser produces no tree.
fn extract_rust_symbols(content: &str) -> Vec<CodeSymbol> {
    let mut parser = tree_sitter::Parser::new();
    // A failure to install the grammar is deliberately discarded here.
    let _ = parser.set_language(&tree_sitter_rust::LANGUAGE.into());

    let mut found = Vec::new();
    if let Some(tree) = parser.parse(content, None) {
        walk_rust_node(tree.root_node(), content.as_bytes(), content, &mut found);
    }
    found
}
/// Recursively visit `node` and record the Rust definitions found beneath it.
///
/// * `node` – syntax node to inspect.
/// * `bytes` / `source` – the file's text as bytes and as `&str`.
/// * `symbols` – output list.
///
/// Functions, structs and enums are recorded without descending into them.
/// `impl` blocks and traits are searched through their `declaration_list`, so
/// methods, trait method signatures and default trait methods are recorded as
/// functions. Every other node kind is searched recursively.
fn walk_rust_node(
    node: tree_sitter::Node,
    bytes: &[u8],
    source: &str,
    symbols: &mut Vec<CodeSymbol>,
) {
    match node.kind() {
        "function_item" | "function_signature_item" => {
            symbols.extend(extract_rust_function(node, bytes, source));
        }
        "struct_item" => symbols.extend(extract_rust_type(node, bytes, source, "struct")),
        "enum_item" => symbols.extend(extract_rust_type(node, bytes, source, "enum")),
        "trait_item" => {
            // Record the trait first so it precedes its own methods in the output.
            symbols.extend(extract_rust_type(node, bytes, source, "trait"));
            // Previously the trait body was never visited, which left
            // `function_signature_item` (trait method declarations) unreachable.
            walk_rust_declaration_lists(node, bytes, source, symbols);
        }
        "impl_item" => walk_rust_declaration_lists(node, bytes, source, symbols),
        _ => {
            let mut cursor = node.walk();
            for child in node.children(&mut cursor) {
                walk_rust_node(child, bytes, source, symbols);
            }
        }
    }
}

/// Walk every `declaration_list` child of `node` (the body of an `impl` or trait).
fn walk_rust_declaration_lists(
    node: tree_sitter::Node,
    bytes: &[u8],
    source: &str,
    symbols: &mut Vec<CodeSymbol>,
) {
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        if child.kind() == "declaration_list" {
            walk_rust_node(child, bytes, source, symbols);
        }
    }
}
/// Build a `"function"` symbol from a Rust function node.
///
/// Returns `None` when the node has no `name` field or the name cannot be read
/// as UTF-8. The signature is the trimmed text from the start of the node up
/// to the offset reported by `find_rust_sig_end`.
fn extract_rust_function(node: tree_sitter::Node, bytes: &[u8], source: &str) -> Option<CodeSymbol> {
    let name = node
        .child_by_field_name("name")?
        .utf8_text(bytes)
        .ok()?
        .to_string();

    let signature_range = node.start_byte()..find_rust_sig_end(node, source);
    let signature = source[signature_range].trim().to_string();
    let docstring = extract_rust_doc_comment(node, source);

    // tree-sitter rows are 0-based; symbols carry 1-based line numbers.
    let first_line = node.start_position().row as u32 + 1;
    let last_line = node.end_position().row as u32 + 1;

    Some(CodeSymbol {
        name,
        kind: "function".into(),
        signature,
        docstring,
        start_line: first_line,
        end_line: last_line,
        language: "rust".into(),
    })
}
/// Build a symbol of the given `kind` from a Rust struct/enum/trait node.
///
/// Returns `None` when the node has no `name` field or the name cannot be read
/// as UTF-8. The signature is the trimmed text from the start of the node to
/// the end of that line (or to the end of the node when no newline follows).
fn extract_rust_type(node: tree_sitter::Node, bytes: &[u8], source: &str, kind: &str) -> Option<CodeSymbol> {
    let name = node
        .child_by_field_name("name")?
        .utf8_text(bytes)
        .ok()?
        .to_string();

    let begin = node.start_byte();
    let line_end = match source[begin..].find('\n') {
        Some(offset) => begin + offset,
        None => node.end_byte(),
    };
    let signature = source[begin..line_end].trim().to_string();
    let docstring = extract_rust_doc_comment(node, source);

    // tree-sitter rows are 0-based; symbols carry 1-based line numbers.
    let first_line = node.start_position().row as u32 + 1;
    let last_line = node.end_position().row as u32 + 1;

    Some(CodeSymbol {
        name,
        kind: kind.into(),
        signature,
        docstring,
        start_line: first_line,
        end_line: last_line,
        language: "rust".into(),
    })
}
/// Byte offset at which the signature of `node` ends.
///
/// That is the start of the first child that is a body (`block`,
/// `field_declaration_list`, `enum_variant_list` or `declaration_list`). When
/// there is no such child (a declaration without a body) the end of the node,
/// clamped to the length of `source`, is returned.
fn find_rust_sig_end(node: tree_sitter::Node, source: &str) -> usize {
    const BODY_KINDS: [&str; 4] = [
        "block",
        "field_declaration_list",
        "enum_variant_list",
        "declaration_list",
    ];

    let mut cursor = node.walk();
    let body_start = node
        .children(&mut cursor)
        .find(|child| BODY_KINDS.contains(&child.kind()))
        .map(|child| child.start_byte());

    body_start.unwrap_or_else(|| node.end_byte().min(source.len()))
}
/// Collect the `///` doc comment that sits directly above `node`.
///
/// Lines are scanned upwards from the line before the node:
/// * a line starting with `///` contributes its text (marker and surrounding
///   whitespace removed);
/// * a line starting with `#[` is skipped;
/// * a blank line is skipped while nothing has been collected yet and ends the
///   scan once doc lines have been found;
/// * any other line ends the scan.
///
/// The collected lines are returned in source order joined by `\n`; the result
/// is empty when the node starts on the first line or no doc lines are found.
fn extract_rust_doc_comment(node: tree_sitter::Node, source: &str) -> String {
    let first_row = node.start_position().row;
    if first_row == 0 {
        return String::new();
    }

    // Only the lines above the node matter.
    let preceding: Vec<&str> = source.lines().take(first_row).collect();
    if preceding.len() < first_row {
        // The line just above the node does not exist in `source`.
        return String::new();
    }

    let mut collected: Vec<&str> = Vec::new();
    for raw in preceding.iter().rev() {
        let text = raw.trim();
        if text.starts_with("///") {
            collected.push(text.trim_start_matches("///").trim());
        } else if text.is_empty() {
            if !collected.is_empty() {
                break;
            }
        } else if !text.starts_with("#[") {
            break;
        }
    }

    collected.reverse();
    collected.join("\n")
}
// ── TypeScript / JavaScript ─────────────────────────────────────────────
/// Parse `content` with the TypeScript grammar and collect its symbols.
///
/// Returns an empty list when the parser produces no tree.
fn extract_ts_symbols(content: &str) -> Vec<CodeSymbol> {
    let mut parser = tree_sitter::Parser::new();
    // A failure to install the grammar is deliberately discarded here.
    let _ = parser.set_language(&tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into());

    let mut found = Vec::new();
    if let Some(tree) = parser.parse(content, None) {
        walk_ts_node(tree.root_node(), content.as_bytes(), content, &mut found);
    }
    found
}
/// Recursively visit `node` and record the TypeScript/JavaScript definitions
/// found in it and beneath it.
///
/// * `node` – syntax node to inspect.
/// * `bytes` / `source` – the file's text as bytes and as `&str`.
/// * `symbols` – output list.
///
/// Recorded as `"function"`: function declarations, method definitions, and
/// arrow functions bound to a plain identifier (`const f = () => …`).
/// Recorded as `"class"` / `"interface"` / `"enum"` / `"type"`: the matching
/// declarations. Children are always visited, so nested definitions are found.
fn walk_ts_node(
    node: tree_sitter::Node,
    bytes: &[u8],
    source: &str,
    symbols: &mut Vec<CodeSymbol>,
) {
    match node.kind() {
        "function_declaration" | "method_definition" => {
            if let Some(name) = node.child_by_field_name("name") {
                let name_str = name.utf8_text(bytes).unwrap_or("");
                if !name_str.is_empty() {
                    symbols.push(ts_symbol(name_str, "function", node, node, source));
                }
            }
        }
        "variable_declarator" => {
            // An `arrow_function` node carries no `name` field of its own; the
            // name lives on the declarator it is assigned to. Matching on
            // `arrow_function` directly therefore never produced a symbol.
            let name = node
                .child_by_field_name("name")
                .filter(|n| n.kind() == "identifier");
            let is_arrow = node
                .child_by_field_name("value")
                .map_or(false, |v| v.kind() == "arrow_function");
            if let (Some(name), true) = (name, is_arrow) {
                let name_str = name.utf8_text(bytes).unwrap_or("");
                if !name_str.is_empty() {
                    // Start the signature at the enclosing `const`/`let`/`var`
                    // statement when there is one, so the keyword is included.
                    let anchor = node
                        .parent()
                        .filter(|p| {
                            matches!(p.kind(), "lexical_declaration" | "variable_declaration")
                        })
                        .unwrap_or(node);
                    symbols.push(ts_symbol(name_str, "function", anchor, node, source));
                }
            }
        }
        "class_declaration" | "interface_declaration" | "type_alias_declaration" | "enum_declaration" => {
            if let Some(name) = node.child_by_field_name("name") {
                let name_str = name.utf8_text(bytes).unwrap_or("");
                let kind = match node.kind() {
                    "class_declaration" => "class",
                    "interface_declaration" => "interface",
                    "enum_declaration" => "enum",
                    _ => "type",
                };
                symbols.push(ts_symbol(name_str, kind, node, node, source));
            }
        }
        _ => {}
    }

    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        walk_ts_node(child, bytes, source, symbols);
    }
}

/// Assemble a symbol whose signature is the trimmed text from the start of
/// `anchor` to the end of that line, and whose line range runs from the first
/// line of `anchor` to the last line of `node`.
fn ts_symbol(
    name: &str,
    kind: &str,
    anchor: tree_sitter::Node,
    node: tree_sitter::Node,
    source: &str,
) -> CodeSymbol {
    let start = anchor.start_byte();
    let first_line_end = source[start..]
        .find('\n')
        .map(|i| start + i)
        .unwrap_or(node.end_byte());
    CodeSymbol {
        name: name.to_string(),
        kind: kind.into(),
        signature: source[start..first_line_end].trim().to_string(),
        docstring: String::new(), // TODO: JSDoc extraction
        // tree-sitter rows are 0-based; symbols carry 1-based line numbers.
        start_line: anchor.start_position().row as u32 + 1,
        end_line: node.end_position().row as u32 + 1,
        // The label is fixed; this walker has no access to the file extension.
        language: "typescript".into(),
    }
}
// ── Python ──────────────────────────────────────────────────────────────
/// Parse `content` as Python and collect its symbols.
///
/// Returns an empty list when the parser produces no tree.
fn extract_python_symbols(content: &str) -> Vec<CodeSymbol> {
    let mut parser = tree_sitter::Parser::new();
    // A failure to install the grammar is deliberately discarded here.
    let _ = parser.set_language(&tree_sitter_python::LANGUAGE.into());

    let mut found = Vec::new();
    if let Some(tree) = parser.parse(content, None) {
        walk_python_node(tree.root_node(), content.as_bytes(), content, &mut found);
    }
    found
}
/// Recursively visit `node` and record Python function and class definitions.
///
/// * `node` – syntax node to inspect.
/// * `bytes` / `source` – the file's text as bytes and as `&str`.
/// * `symbols` – output list.
///
/// A `function_definition` becomes a `"function"` symbol and a
/// `class_definition` a `"class"` symbol; the signature is the first line of
/// the definition. Children are always visited, so nested definitions
/// (for example methods inside a class) are recorded too.
fn walk_python_node(
    node: tree_sitter::Node,
    bytes: &[u8],
    source: &str,
    symbols: &mut Vec<CodeSymbol>,
) {
    let label = match node.kind() {
        "function_definition" => Some("function"),
        "class_definition" => Some("class"),
        _ => None,
    };

    if let (Some(label), Some(name_node)) = (label, node.child_by_field_name("name")) {
        let name = name_node.utf8_text(bytes).unwrap_or("").to_string();
        let begin = node.start_byte();
        let line_end = match source[begin..].find('\n') {
            Some(offset) => begin + offset,
            None => node.end_byte(),
        };
        let docstring = extract_python_docstring(node, bytes);
        symbols.push(CodeSymbol {
            name,
            kind: label.into(),
            signature: source[begin..line_end].trim().to_string(),
            docstring,
            // tree-sitter rows are 0-based; symbols carry 1-based line numbers.
            start_line: node.start_position().row as u32 + 1,
            end_line: node.end_position().row as u32 + 1,
            language: "python".into(),
        });
    }

    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        walk_python_node(child, bytes, source, symbols);
    }
}
/// Return the docstring of a Python function or class node, or an empty string
/// when it has none.
///
/// The docstring is the string literal forming the first statement of the
/// node's `body`; comment nodes in front of it are skipped. The literal's
/// prefix letters and quote delimiters are removed and the remaining text is
/// trimmed.
fn extract_python_docstring(node: tree_sitter::Node, bytes: &[u8]) -> String {
    let Some(body) = node.child_by_field_name("body") else {
        return String::new();
    };
    // Comments are not statements, so they must not hide a docstring that follows.
    let mut cursor = body.walk();
    let Some(first_stmt) = body
        .children(&mut cursor)
        .find(|child| child.kind() != "comment")
    else {
        return String::new();
    };
    if first_stmt.kind() != "expression_statement" {
        return String::new();
    }
    let Some(expr) = first_stmt.child(0) else {
        return String::new();
    };
    if expr.kind() != "string" {
        return String::new();
    }
    let literal = expr.utf8_text(bytes).unwrap_or("");
    strip_python_string_delimiters(literal).trim().to_string()
}

/// Remove an optional prefix (`r`, `u`, `b`, … in any case) and the matching
/// pair of quote delimiters from a Python string literal.
///
/// Triple quotes are tried before single quotes. Exactly one delimiter is
/// removed from each end, so quote characters that belong to the text survive.
/// The literal is returned unchanged when no delimiter pair matches.
fn strip_python_string_delimiters(literal: &str) -> &str {
    let unprefixed = literal.trim_start_matches(|c: char| c.is_ascii_alphabetic());
    for quote in ["\"\"\"", "'''", "\"", "'"] {
        let inner = unprefixed
            .strip_prefix(quote)
            .and_then(|rest| rest.strip_suffix(quote));
        if let Some(inner) = inner {
            return inner;
        }
    }
    literal
}
// Unit tests for language detection and for the Rust, TypeScript and Python
// symbol extractors.
//
// NOTE(review): the raw-string fixtures below are whitespace-sensitive. In this
// view they carry no indentation, which matters for the Python fixtures in
// particular — confirm they match the intended source before relying on them.
#[cfg(test)]
mod tests {
use super::*;
// Extension-to-language mapping, including unsupported extensions.
#[test]
fn test_detect_language() {
assert_eq!(detect_language("src/main.rs"), Some("rust"));
assert_eq!(detect_language("app.ts"), Some("typescript"));
assert_eq!(detect_language("app.tsx"), Some("typescript"));
assert_eq!(detect_language("script.py"), Some("python"));
assert_eq!(detect_language("script.js"), Some("javascript"));
assert_eq!(detect_language("data.json"), None);
assert_eq!(detect_language("README.md"), None);
}
// A documented Rust function yields name, kind, signature, docstring and language.
#[test]
fn test_extract_rust_function() {
let source = r#"
/// Generate a response.
pub async fn generate(&self, req: &GenerateRequest) -> Option<String> {
self.run_and_emit(req).await
}
"#;
let symbols = extract_rust_symbols(source);
assert!(!symbols.is_empty(), "Should extract at least one symbol");
let func = &symbols[0];
assert_eq!(func.name, "generate");
assert_eq!(func.kind, "function");
assert!(func.signature.contains("pub async fn generate"));
assert!(func.docstring.contains("Generate a response"));
assert_eq!(func.language, "rust");
}
// A documented Rust struct is reported with kind "struct" and its doc comment.
#[test]
fn test_extract_rust_struct() {
let source = r#"
/// A request to generate.
pub struct GenerateRequest {
pub text: String,
pub user_id: String,
}
"#;
let symbols = extract_rust_symbols(source);
let structs: Vec<_> = symbols.iter().filter(|s| s.kind == "struct").collect();
assert!(!structs.is_empty());
assert_eq!(structs[0].name, "GenerateRequest");
assert!(structs[0].docstring.contains("request to generate"));
}
// A Rust enum is reported with kind "enum".
#[test]
fn test_extract_rust_enum() {
let source = r#"
/// Whether server or client.
pub enum ToolSide {
Server,
Client,
}
"#;
let symbols = extract_rust_symbols(source);
let enums: Vec<_> = symbols.iter().filter(|s| s.kind == "enum").collect();
assert!(!enums.is_empty());
assert_eq!(enums[0].name, "ToolSide");
}
// A Rust trait is reported with kind "trait".
#[test]
fn test_extract_rust_trait() {
let source = r#"
pub trait Executor {
fn execute(&self, args: &str) -> String;
}
"#;
let symbols = extract_rust_symbols(source);
let traits: Vec<_> = symbols.iter().filter(|s| s.kind == "trait").collect();
assert!(!traits.is_empty());
assert_eq!(traits[0].name, "Executor");
}
// Methods inside an `impl` block are reported as functions.
#[test]
fn test_extract_rust_impl_methods() {
let source = r#"
impl Orchestrator {
/// Create new.
pub fn new(config: Config) -> Self {
Self { config }
}
/// Subscribe to events.
pub fn subscribe(&self) -> Receiver {
self.tx.subscribe()
}
}
"#;
let symbols = extract_rust_symbols(source);
let fns: Vec<_> = symbols.iter().filter(|s| s.kind == "function").collect();
assert!(fns.len() >= 2, "Should find impl methods, got {}", fns.len());
let names: Vec<&str> = fns.iter().map(|s| s.name.as_str()).collect();
assert!(names.contains(&"new"));
assert!(names.contains(&"subscribe"));
}
// A TypeScript function declaration is reported as a function.
#[test]
fn test_extract_ts_function() {
let source = r#"
function greet(name: string): string {
return `Hello, ${name}`;
}
"#;
let symbols = extract_ts_symbols(source);
assert!(!symbols.is_empty());
assert_eq!(symbols[0].name, "greet");
assert_eq!(symbols[0].kind, "function");
}
// A TypeScript class declaration is reported with kind "class".
#[test]
fn test_extract_ts_class() {
let source = r#"
class UserService {
constructor(private db: Database) {}
async getUser(id: string): Promise<User> {
return this.db.find(id);
}
}
"#;
let symbols = extract_ts_symbols(source);
let classes: Vec<_> = symbols.iter().filter(|s| s.kind == "class").collect();
assert!(!classes.is_empty());
assert_eq!(classes[0].name, "UserService");
}
// A TypeScript interface declaration is reported with kind "interface".
#[test]
fn test_extract_ts_interface() {
let source = r#"
interface User {
id: string;
name: string;
email?: string;
}
"#;
let symbols = extract_ts_symbols(source);
let ifaces: Vec<_> = symbols.iter().filter(|s| s.kind == "interface").collect();
assert!(!ifaces.is_empty());
assert_eq!(ifaces[0].name, "User");
}
// A Python function is reported together with its docstring.
#[test]
fn test_extract_python_function() {
let source = r#"
def process_data(items: list[str]) -> dict:
"""Process a list of items into a dictionary."""
return {item: len(item) for item in items}
"#;
let symbols = extract_python_symbols(source);
assert!(!symbols.is_empty());
assert_eq!(symbols[0].name, "process_data");
assert_eq!(symbols[0].kind, "function");
assert!(symbols[0].docstring.contains("Process a list"));
}
// A Python class is reported together with its docstring.
#[test]
fn test_extract_python_class() {
let source = r#"
class DataProcessor:
"""Processes data from various sources."""
def __init__(self, config):
self.config = config
def run(self):
pass
"#;
let symbols = extract_python_symbols(source);
let classes: Vec<_> = symbols.iter().filter(|s| s.kind == "class").collect();
assert!(!classes.is_empty());
assert_eq!(classes[0].name, "DataProcessor");
assert!(classes[0].docstring.contains("Processes data"));
}
// An unsupported extension produces no symbols.
#[test]
fn test_extract_symbols_unknown_language() {
let symbols = extract_symbols("data.json", "{}");
assert!(symbols.is_empty());
}
// An empty Rust file produces no symbols.
#[test]
fn test_extract_symbols_empty_file() {
let symbols = extract_symbols("empty.rs", "");
assert!(symbols.is_empty());
}
// Reported line numbers start at 1, not 0.
#[test]
fn test_line_numbers_are_1_based() {
let source = "fn first() {}\nfn second() {}\nfn third() {}";
let symbols = extract_rust_symbols(source);
assert!(symbols.len() >= 3);
assert_eq!(symbols[0].start_line, 1);
assert_eq!(symbols[1].start_line, 2);
assert_eq!(symbols[2].start_line, 3);
}
}