Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions crates/aegis-core/src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ pub mod imports;
pub mod languages;
pub mod parsed_file;
pub mod registry;
pub mod symbols;

pub use adapter::{default_max_chain_depth, LanguageAdapter};
pub use imports::{extract_imports, Import};
pub use parsed_file::{parse, ParsedFile};
pub use registry::LanguageRegistry;
pub use symbols::{extract_imported_symbols, extract_public_symbols};
21 changes: 21 additions & 0 deletions crates/aegis-core/src/ast/parsed_file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
//! not a hard short-circuit baked into the parse layer.

use std::cell::OnceCell;
use std::collections::HashSet;

use crate::ast::adapter::LanguageAdapter;
use crate::ast::imports::{extract_imports, Import};
use crate::ast::registry::LanguageRegistry;
use crate::ast::symbols::{extract_imported_symbols, extract_public_symbols};

/// Output of a successful parse — tree + source + the adapter that
/// produced it. Cheap to pass by reference; nothing here is cloned.
Expand All @@ -30,6 +32,8 @@ pub struct ParsedFile<'src> {
source: &'src str,
language_name: &'static str,
imports_cache: OnceCell<Vec<Import>>,
public_symbols_cache: OnceCell<HashSet<String>>,
imported_symbols_cache: OnceCell<HashSet<String>>,
}

impl<'src> ParsedFile<'src> {
Expand Down Expand Up @@ -87,6 +91,21 @@ impl<'src> ParsedFile<'src> {
self.imports_cache.get_or_init(|| extract_imports(self))
}

/// Names of the top-level public symbols this file exposes
/// (functions, classes, types, traits, exports).
///
/// The set is built by `extract_public_symbols` on the first call
/// and stored in a `OnceCell`; every later call hands back a
/// reference to that same `HashSet`.
pub fn public_symbols(&self) -> &HashSet<String> {
    let cache = &self.public_symbols_cache;
    cache.get_or_init(|| extract_public_symbols(self))
}

/// Names this file brings in through its own import statements
/// (Python `from X import Y`, TS `import { a, b } from 'x'`).
///
/// Computed by `extract_imported_symbols` on first use and cached in
/// a `OnceCell`, so repeated calls return the same `HashSet`.
pub fn imported_symbols(&self) -> &HashSet<String> {
    let cache = &self.imported_symbols_cache;
    cache.get_or_init(|| extract_imported_symbols(self))
}

/// Best-effort receiver-to-import lookup: given a call receiver
/// like `rand` or `myrand`, return the matching `Import` if one
/// of the file's imports plausibly produced that name.
Expand Down Expand Up @@ -140,6 +159,8 @@ pub fn parse<'src>(path: &str, source: &'src str) -> Option<ParsedFile<'src>> {
source,
language_name: adapter.name(),
imports_cache: OnceCell::new(),
public_symbols_cache: OnceCell::new(),
imported_symbols_cache: OnceCell::new(),
})
}

Expand Down
228 changes: 228 additions & 0 deletions crates/aegis-core/src/ast/symbols.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
//! Per-file symbol extraction — Layer 1 fact derivation.
//!
//! Mirrors the architecture of `ast::imports`: pulls the
//! cross-file public-symbol / imported-symbol extraction logic out
//! of `workspace.rs::summarize_file` into Layer 1 so any consumer
//! (security checks, future signal rules, the existing R2
//! workspace finding) can read the same canonical list without
//! re-walking the tree.
//!
//! Two collections per file:
//!
//! - **`public_symbols`** — top-level function / class / type /
//! trait / variable names this file exposes to its callers.
//! Powers the `public_symbol_removed` workspace finding.
//! - **`imported_symbols`** — names this file pulls in from its
//! own imports (Python `from X import Y, Z`, TS
//! `import { a, b } from 'x'`). Used by R2 to know what
//! downstream files depend on.
//!
//! Caller is responsible for caching; the supported entry point is
//! `ParsedFile::public_symbols()` / `ParsedFile::imported_symbols()`.

use std::collections::HashSet;

use crate::ast::parsed_file::ParsedFile;

/// Collects the names of the public declarations found in `parsed`.
///
/// Walks the syntax tree from the root on every call and returns a
/// freshly built set. This function does no caching itself; use
/// `ParsedFile::public_symbols()` for the cached view.
pub fn extract_public_symbols(parsed: &ParsedFile<'_>) -> HashSet<String> {
    let (root, bytes) = (parsed.root_node(), parsed.source_bytes());
    let mut symbols = HashSet::new();
    walk_public(root, bytes, &mut symbols);
    symbols
}

/// Collects the names `parsed` pulls in through its import statements.
///
/// Walks the syntax tree from the root on every call and returns a
/// freshly built set. This function does no caching itself; use
/// `ParsedFile::imported_symbols()` for the cached view.
pub fn extract_imported_symbols(parsed: &ParsedFile<'_>) -> HashSet<String> {
    let (root, bytes) = (parsed.root_node(), parsed.source_bytes());
    let mut symbols = HashSet::new();
    walk_imported(root, bytes, &mut symbols);
    symbols
}

/// Recursively records the names of declarations under `node` that look
/// like part of the file's public surface.
///
/// A node contributes the text of its `name` field when its kind is one
/// of the recognised declaration kinds, the name passes
/// `is_public_name`, and `is_likely_public` accepts the node. An
/// `export_statement` carries no `name` field of its own, so it is
/// handed to `walk_export` instead.
///
/// Function and method bodies are never descended into: helpers declared
/// inside a body are local to it and not part of the file's public API.
fn walk_public(node: tree_sitter::Node<'_>, src: &[u8], out: &mut HashSet<String>) {
    let kind = node.kind();
    let is_decl = matches!(
        kind,
        "function_definition" | "function_declaration" | "function_item"
            | "class_definition" | "class_declaration"
            | "method_definition" | "method_declaration"
            | "interface_declaration" | "enum_declaration"
            | "struct_item" | "trait_item" | "type_alias"
            | "lexical_declaration" | "variable_declaration"
            | "export_statement"
    );
    if is_decl {
        match node.child_by_field_name("name") {
            Some(name_node) => {
                if let Ok(name) = name_node.utf8_text(src) {
                    if is_public_name(name) && is_likely_public(node, src) {
                        out.insert(name.to_string());
                    }
                }
            }
            None if kind == "export_statement" => walk_export(node, src, out),
            None => {}
        }
    }

    // Every function-like declaration kind recognised above owns a body
    // that must be skipped. `function_declaration` / `method_declaration`
    // and the JS/TS body kind `statement_block` are included so the skip
    // also applies to grammars that use those node kinds.
    let owns_body = matches!(
        kind,
        "function_definition" | "function_declaration" | "function_item"
            | "method_definition" | "method_declaration"
    );
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        let is_body = matches!(
            child.kind(),
            "block" | "function_body" | "compound_statement" | "statement_block"
        );
        if owns_body && is_body {
            continue;
        }
        walk_public(child, src, out);
    }
}

/// Returns `true` unless `name` begins with an underscore.
///
/// Follows the Python convention that a leading `_` marks a private
/// name. The empty string has no leading underscore and is reported
/// as public.
fn is_public_name(name: &str) -> bool {
    name.chars().next() != Some('_')
}

/// Decides whether a declaration node is visible outside its file.
///
/// Rust `function_item` / `struct_item` / `trait_item` nodes count as
/// public only when one of their direct children is a
/// `visibility_modifier` whose text starts with `pub` (so restricted
/// forms such as `pub(crate)` also qualify). Every other node kind is
/// assumed public.
fn is_likely_public(node: tree_sitter::Node<'_>, src: &[u8]) -> bool {
    let needs_pub_modifier = matches!(
        node.kind(),
        "function_item" | "struct_item" | "trait_item"
    );
    if !needs_pub_modifier {
        return true;
    }

    let mut cursor = node.walk();
    let has_pub = node.children(&mut cursor).any(|child| {
        child.kind() == "visibility_modifier"
            && child
                .utf8_text(src)
                .map_or(false, |text| text.starts_with("pub"))
    });
    has_pub
}

/// Records the names made public by an `export_statement` subtree.
///
/// Three sources contribute names:
///
/// - `export default …` adds the synthetic name `"default"`. That is
///   the slot TS/JS consumers bind with `import Foo from './x'`; the
///   local name on the consumer side is arbitrary, but the export
///   slot's identity is "default", so callers can detect its loss.
/// - Any descendant with a `name` field (exported functions, classes,
///   variable declarators, …) adds that name if it passes
///   `is_public_name`.
/// - `export_specifier` entries (`export { a, b as c }`) add the local
///   name and, when present, the `as` alias. The alias is the name
///   importers actually reference. Entries in an explicit export list
///   are recorded even when they start with an underscore.
///
/// `statement_block` children are not descended into: they are the
/// bodies of exported functions, methods and arrow functions, and the
/// declarations inside them are locals rather than exports.
fn walk_export(node: tree_sitter::Node<'_>, src: &[u8], out: &mut HashSet<String>) {
    let raw = node.utf8_text(src).unwrap_or("");
    if raw.trim_start().starts_with("export default") {
        out.insert("default".to_string());
    }

    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        let child_kind = child.kind();
        if child_kind == "statement_block" {
            // Body of an exported function/method: locals only.
            continue;
        }

        let name_node = child.child_by_field_name("name");
        if child_kind == "export_specifier" {
            // Fall back to the first named child for grammar versions
            // that do not label the specifier's name as a field.
            let local = name_node.or_else(|| child.named_child(0));
            let alias = child.child_by_field_name("alias");
            for exported in local.into_iter().chain(alias) {
                if let Ok(name) = exported.utf8_text(src) {
                    out.insert(name.to_string());
                }
            }
        } else if let Some(name_node) = name_node {
            if let Ok(name) = name_node.utf8_text(src) {
                if is_public_name(name) {
                    out.insert(name.to_string());
                }
            }
        }

        walk_export(child, src, out);
    }
}

/// Recursively records the names imported anywhere under `node`.
///
/// - Python `from X import Y, Z as W`: each direct `dotted_name` child
///   other than the `module_name` field is recorded, and for an
///   `aliased_import` the first named child (the original name, not the
///   alias) is recorded.
/// - TS/JS `import { x, y as z } from 'X'`: each `import_specifier`
///   contributes its `name` field, or its first named child when that
///   field is absent.
fn walk_imported(node: tree_sitter::Node<'_>, src: &[u8], out: &mut HashSet<String>) {
    match node.kind() {
        "import_from_statement" => {
            // The module being imported from is not itself an imported
            // symbol; look it up once so it can be excluded below.
            let module = node.child_by_field_name("module_name");
            let mut cursor = node.walk();
            for child in node.children(&mut cursor) {
                let imported = match child.kind() {
                    "dotted_name" if module != Some(child) => Some(child),
                    "aliased_import" => child.named_child(0),
                    _ => None,
                };
                if let Some(text) = imported.and_then(|n| n.utf8_text(src).ok()) {
                    out.insert(text.to_string());
                }
            }
        }
        "import_specifier" => {
            let target = node
                .child_by_field_name("name")
                .or_else(|| node.named_child(0));
            if let Some(text) = target.and_then(|n| n.utf8_text(src).ok()) {
                out.insert(text.to_string());
            }
        }
        _ => {}
    }

    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        walk_imported(child, src, out);
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::parsed_file::parse;

    /// Python: plain functions and classes are public, `_`-prefixed
    /// names are not.
    #[test]
    fn python_public_symbols() {
        let source =
            "def public_fn():\n pass\n\nclass Public:\n pass\n\ndef _private():\n pass\n";
        let parsed = parse("lib.py", source).unwrap();
        let symbols = extract_public_symbols(&parsed);
        for expected in ["public_fn", "Public"] {
            assert!(symbols.contains(expected));
        }
        assert!(!symbols.contains("_private"));
    }

    /// Rust: only items carrying a `pub` modifier are reported.
    #[test]
    fn rust_pub_only() {
        let source = "pub fn exposed() {}\nfn hidden() {}\npub struct Pubst;\nstruct Privst;\n";
        let parsed = parse("lib.rs", source).unwrap();
        let symbols = extract_public_symbols(&parsed);
        for expected in ["exposed", "Pubst"] {
            assert!(symbols.contains(expected));
        }
        for unexpected in ["hidden", "Privst"] {
            assert!(!symbols.contains(unexpected));
        }
    }

    /// Python: `from X import ...` yields the imported names, using the
    /// original name for aliased imports.
    #[test]
    fn python_imported_names() {
        let source =
            "from collections import OrderedDict, defaultdict\nfrom os.path import join as joinpath\n";
        let parsed = parse("a.py", source).unwrap();
        let symbols = extract_imported_symbols(&parsed);
        for expected in ["OrderedDict", "defaultdict", "join"] {
            assert!(symbols.contains(expected));
        }
    }
}
Loading
Loading