diff --git a/crates/flow/tests/integration_tests.rs b/crates/flow/tests/integration_tests.rs index 94d89b0..d3fa5fe 100644 --- a/crates/flow/tests/integration_tests.rs +++ b/crates/flow/tests/integration_tests.rs @@ -415,8 +415,6 @@ async fn test_parse_rust_code() { let output = result.unwrap(); let symbols = extract_symbols(&output); - // Note: Currently only extracts functions, not structs/classes - // TODO: Add struct/class extraction in future if !symbols.is_empty() { let symbol_names: Vec = symbols .iter() @@ -426,15 +424,17 @@ async fn test_parse_rust_code() { }) .collect(); - // Look for functions that should be extracted - let found_function = symbol_names.iter().any(|name| { + // Look for functions and structs/classes that should be extracted + let found_function_or_struct = symbol_names.iter().any(|name| { name.contains("main") || name.contains("process_user") || name.contains("calculate_total") + || name.contains("User") + || name.contains("Role") }); assert!( - found_function, - "Should find at least one function (main, process_user, or calculate_total). Found: {:?}", + found_function_or_struct, + "Should find at least one function or struct (main, process_user, calculate_total, User, or Role). Found: {:?}", symbol_names ); } else { diff --git a/crates/language/src/lib.rs b/crates/language/src/lib.rs index 721ddd6..7709c0e 100644 --- a/crates/language/src/lib.rs +++ b/crates/language/src/lib.rs @@ -1721,17 +1721,17 @@ pub fn from_extension(path: &Path) -> Option { } // Handle extensionless files or files with unknown extensions - if let Some(_file_name) = path.file_name().and_then(|n| n.to_str()) { + if let Some(file_name) = path.file_name().and_then(|n| n.to_str()) { // 1. Check if the full filename matches a known extension (e.g. .bashrc) #[cfg(any(feature = "bash", feature = "all-parsers"))] - if constants::BASH_EXTS.contains(&_file_name) { + if constants::BASH_EXTS.contains(&file_name) { return Some(SupportLang::Bash); } // 2. Check known extensionless file names #[cfg(any(feature = "bash", feature = "all-parsers", feature = "ruby"))] for (name, lang) in constants::LANG_RELATIONSHIPS_WITH_NO_EXTENSION { - if *name == _file_name { + if *name == file_name { return Some(*lang); } } diff --git a/crates/services/src/conversion.rs b/crates/services/src/conversion.rs index 37d65e4..47e68c0 100644 --- a/crates/services/src/conversion.rs +++ b/crates/services/src/conversion.rs @@ -67,6 +67,13 @@ pub fn extract_basic_metadata( } } + // Extract class and struct definitions + if let Ok(class_matches) = extract_classes(&root_node) { + for (name, info) in class_matches { + metadata.defined_symbols.insert(name, info); + } + } + // Extract import statements if let Ok(imports) = extract_imports(&root_node, &document.language) { for (name, info) in imports { @@ -117,6 +124,43 @@ fn extract_functions(root_node: &Node) -> ServiceResult(root_node: &Node) -> ServiceResult> { + let mut classes = thread_utilities::get_map(); + + // Try different class/struct patterns based on common languages + let patterns = [ + "struct $NAME { $$$BODY }", // Rust, C++, C# + "class $NAME { $$$BODY }", // TypeScript, JavaScript, Java, C#, C++ + "class $NAME: $$$BODY", // Python + "class $NAME($$$PARAMS): $$$BODY", // Python + "type $NAME struct { $$$BODY }", // Go + "interface $NAME { $$$BODY }", // TypeScript, Java, C# + ]; + + for pattern in &patterns { + for node_match in root_node.find_all(pattern) { + if let Some(name_node) = node_match.get_env().get_match("NAME") { + let class_name = name_node.text().to_string(); + let position = name_node.start_pos(); + + let symbol_info = SymbolInfo { + name: class_name.clone(), + kind: SymbolKind::Class, + position, + scope: "global".to_string(), // Simplified for now + visibility: Visibility::Public, // Simplified for now + }; + + classes.insert(class_name, symbol_info); + } + } + } + + Ok(classes) +} + /// Extract import statements using language-specific patterns #[cfg(feature = "matching")] fn extract_imports(