From 593d930aa771836456b27552d07d260a9ccaf06e Mon Sep 17 00:00:00 2001 From: ar7casper Date: Tue, 3 Mar 2026 16:44:20 +0200 Subject: [PATCH] feat(openant-core): add Rust parser support - Add 4-stage Rust parser using tree-sitter-rust - repository_scanner.py: enumerate .rs files - function_extractor.py: extract functions/methods from AST - call_graph_builder.py: build bidirectional call graphs - unit_generator.py: generate dataset.json - test_pipeline.py: orchestrator with all 4 processing levels - Register Rust in parser_adapter.py (detection + dispatch) - Add 'rust' to CLI language whitelist (cli.py, parse.go) - Add tree-sitter-rust to dependencies (requirements.txt, pyproject.toml) - Update README with Rust in supported languages - Add adding-a-parser.md guide with CLI whitelist and venv dependency docs --- README.md | 5 + adding-a-parser.md | 595 ++++++++++ apps/openant-cli/cmd/parse.go | 4 +- libs/openant-core/core/parser_adapter.py | 66 +- libs/openant-core/openant/cli.py | 4 +- libs/openant-core/parsers/rust/__init__.py | 1 + .../parsers/rust/call_graph_builder.py | 498 ++++++++ .../parsers/rust/function_extractor.py | 532 +++++++++ .../parsers/rust/repository_scanner.py | 264 +++++ .../parsers/rust/test_pipeline.py | 1034 +++++++++++++++++ .../parsers/rust/unit_generator.py | 400 +++++++ libs/openant-core/pyproject.toml | 1 + libs/openant-core/requirements.txt | 1 + 13 files changed, 3400 insertions(+), 5 deletions(-) create mode 100644 adding-a-parser.md create mode 100644 libs/openant-core/parsers/rust/__init__.py create mode 100644 libs/openant-core/parsers/rust/call_graph_builder.py create mode 100644 libs/openant-core/parsers/rust/function_extractor.py create mode 100644 libs/openant-core/parsers/rust/repository_scanner.py create mode 100644 libs/openant-core/parsers/rust/test_pipeline.py create mode 100644 libs/openant-core/parsers/rust/unit_generator.py diff --git a/README.md b/README.md index f3f1b4a..1adc5ac 100644 --- a/README.md +++ b/README.md @@ 
-28,6 +28,11 @@ To submit your repo for scanning: - C/C++ (beta) - PHP (beta) - Ruby (beta) +- Rust (beta) + +## Adding a new language + +Want to add support for another language? See [adding-a-parser.md](adding-a-parser.md) for the full guide. The Ruby parser is the cleanest reference implementation to copy. ## Credits Research and ideation: [Nahum Korda](https://github.com/NahumKorda/). diff --git a/adding-a-parser.md b/adding-a-parser.md new file mode 100644 index 0000000..91f5ea0 --- /dev/null +++ b/adding-a-parser.md @@ -0,0 +1,595 @@ +# Adding a Parser to OpenAnt + +This guide explains how to add support for a new programming language to OpenAnt's parsing pipeline. + +## Overview + +OpenAnt parsers transform source code repositories into **analysis units** — self-contained code snippets with dependency context that can be analyzed for vulnerabilities. Every parser follows the same 4-stage pipeline: + +``` +Repository → [1. Scanner] → [2. Extractor] → [3. Call Graph] → [4. Unit Generator] → Dataset +``` + +The output is a standardized `dataset.json` that downstream stages (analyzer, verifier, enhancer) consume. As long as your parser produces the correct output schema, you can implement it however you like. + +## Quick Start Checklist + +Adding a new language requires: + +1. **Create parser directory**: `libs/openant-core/parsers//` +2. **Implement 4 stage modules** (see [Pipeline Stages](#the-4-stage-pipeline)) +3. **Create pipeline orchestrator**: `test_pipeline.py` +4. **Register in adapter**: `libs/openant-core/core/parser_adapter.py` +5. **Update language detection**: Add file extensions to `detect_language()` +6. **Add to CLI whitelist**: `libs/openant-core/openant/cli.py` (required for `--language` flag) +7. **Add dependencies**: `requirements.txt` and `pyproject.toml` (for venv auto-install) +8. **Update CLI help**: `apps/openant-cli/cmd/parse.go` (optional, help text only) +9. 
**Update README**: Add language to "Supported languages" list + +## The 4-Stage Pipeline + +### Stage 1: Repository Scanner + +**Purpose**: Enumerate all source files in the repository. + +**Class**: `RepositoryScanner` + +**Input**: Repository path + options (skip_tests, exclude_patterns) + +**Output**: `scan_results.json` + +```json +{ + "repository": "/path/to/repo", + "scan_time": "2025-01-15T10:30:00", + "files": [ + { "path": "src/main.rs", "size": 1234 }, + { "path": "src/lib.rs", "size": 5678 } + ], + "statistics": { + "total_files": 150, + "total_size_bytes": 500000, + "directories_scanned": 25, + "directories_excluded": 10 + } +} +``` + +**Key responsibilities**: +- Walk directory tree, respecting exclude patterns (`.git`, `vendor`, `node_modules`, etc.) +- Filter by file extension for your language +- Optionally skip test files when `skip_tests=True` + +### Stage 2: Function Extractor + +**Purpose**: Extract all functions, methods, and classes from source files using AST parsing. + +**Class**: `FunctionExtractor` + +**Input**: Repository path + scan results + +**Output**: `functions.json` (intermediate, not written to disk in most implementations) + +```json +{ + "repository": "/path/to/repo", + "extraction_time": "2025-01-15T10:30:05", + "functions": { + "src/main.rs:main": { + "name": "main", + "qualified_name": "main", + "file_path": "src/main.rs", + "start_line": 10, + "end_line": 25, + "code": "fn main() {\n ...\n}", + "class_name": null, + "module_name": null, + "parameters": [], + "unit_type": "function" + }, + "src/lib.rs:Config.new": { + "name": "new", + "qualified_name": "Config.new", + "file_path": "src/lib.rs", + "start_line": 15, + "end_line": 20, + "code": "pub fn new() -> Self { ... }", + "class_name": "Config", + "module_name": null, + "parameters": [], + "unit_type": "constructor" + } + }, + "classes": { ... }, + "imports": { ... 
}, + "statistics": { + "total_functions": 150, + "total_classes": 20, + "files_processed": 50, + "files_with_errors": 2 + } +} +``` + +**Function ID format**: `:` + +Examples: +- `src/main.rs:main` +- `src/lib.rs:Config.new` +- `app/controllers/users_controller.rb:UsersController.create` + +**Required fields per function**: + +| Field | Type | Description | +|-------|------|-------------| +| `name` | string | Simple function name | +| `qualified_name` | string | Class.method or just name if top-level | +| `file_path` | string | Relative path from repo root | +| `start_line` | int | 1-indexed line number | +| `end_line` | int | 1-indexed line number | +| `code` | string | Full source code of the function | +| `class_name` | string \| null | Containing class/struct name | +| `module_name` | string \| null | Containing module/namespace | +| `parameters` | string[] | Parameter names | +| `unit_type` | string | See unit types below | + +**Unit types** (used for classification and filtering): +- `function` — standalone function +- `method` — instance method in a class +- `constructor` — `__init__`, `new`, `initialize`, etc. +- `route_handler` — HTTP endpoint handler +- `callback` — lifecycle hooks, event handlers +- `test` — test functions (filtered out when `skip_tests=True`) +- `singleton_method` — class/static methods + +### Stage 3: Call Graph Builder + +**Purpose**: Build bidirectional call graphs showing function dependencies. + +**Class**: `CallGraphBuilder` + +**Input**: Function extractor output + +**Output**: `call_graph.json` + +```json +{ + "repository": "/path/to/repo", + "functions": { ... }, + "classes": { ... }, + "imports": { ... 
}, + "call_graph": { + "src/main.rs:main": ["src/lib.rs:Config.new", "src/lib.rs:run"], + "src/lib.rs:run": ["src/lib.rs:process"] + }, + "reverse_call_graph": { + "src/lib.rs:Config.new": ["src/main.rs:main"], + "src/lib.rs:run": ["src/main.rs:main"], + "src/lib.rs:process": ["src/lib.rs:run"] + }, + "statistics": { + "total_functions": 150, + "total_edges": 500, + "avg_out_degree": 3.33, + "max_out_degree": 15, + "isolated_functions": 20 + } +} +``` + +**Key responsibilities**: +- Parse function bodies to find call sites +- Resolve calls to function IDs (same file → imported files → unique name match) +- Filter out language builtins and standard library calls +- Build both forward (`call_graph`) and reverse (`reverse_call_graph`) mappings + +### Stage 4: Unit Generator + +**Purpose**: Create self-contained analysis units with dependency context. + +**Class**: `UnitGenerator` + +**Input**: Call graph output + +**Output**: `dataset.json` + `analyzer_output.json` + +This is the **critical output** — downstream stages depend on this exact schema. + +#### `dataset.json` Schema + +```json +{ + "name": "my-project", + "repository": "/path/to/repo", + "units": [ + { + "id": "src/lib.rs:process", + "unit_type": "function", + "code": { + "primary_code": "fn process() { ... }\n\n# ========== File Boundary ==========\n\nfn helper() { ... 
}", + "primary_origin": { + "file_path": "src/lib.rs", + "start_line": 30, + "end_line": 45, + "function_name": "process", + "class_name": null, + "enhanced": true, + "files_included": ["src/lib.rs", "src/helpers.rs"], + "original_length": 250, + "enhanced_length": 800 + }, + "dependencies": [], + "dependency_metadata": { + "depth": 3, + "total_upstream": 2, + "total_downstream": 1, + "direct_calls": 2, + "direct_callers": 1 + } + }, + "ground_truth": { + "status": "UNKNOWN", + "vulnerability_types": [], + "issues": [], + "annotation_source": null, + "annotation_key": null, + "notes": null + }, + "metadata": { + "parameters": ["input"], + "generator": "rust_unit_generator.py", + "direct_calls": ["src/helpers.rs:validate", "src/helpers.rs:transform"], + "direct_callers": ["src/lib.rs:run"] + } + } + ], + "statistics": { + "total_units": 150, + "by_type": { "function": 100, "method": 40, "constructor": 10 }, + "units_with_upstream": 120, + "units_with_downstream": 80, + "units_enhanced": 130, + "avg_upstream": 2.5, + "avg_downstream": 1.8 + }, + "metadata": { + "generator": "rust_unit_generator.py", + "generated_at": "2025-01-15T10:30:15", + "dependency_depth": 3 + } +} +``` + +**Enhanced code assembly**: The `primary_code` field contains the function's code plus its dependencies, separated by file boundary markers: + +``` +fn main() { + process(); +} + +# ========== File Boundary ========== + +fn process() { + // dependency code +} +``` + +Use your language's comment syntax for the boundary marker. + +#### `analyzer_output.json` Schema + +This file provides a function index used by the verifier for cross-referencing: + +```json +{ + "repository": "/path/to/repo", + "functions": { + "src/lib.rs:process": { + "name": "process", + "unitType": "function", + "code": "fn process() { ... }", + "filePath": "src/lib.rs", + "startLine": 30, + "endLine": 45, + "isExported": true, + "parameters": ["input"], + "className": null + } + }, + "call_graph": { ... 
}, + "reverse_call_graph": { ... } +} +``` + +Note: `analyzer_output.json` uses **camelCase** field names for historical reasons. + +## Pipeline Orchestrator (`test_pipeline.py`) + +Each parser has a `test_pipeline.py` that wires the 4 stages together and handles CLI arguments. This is the entry point called by `parser_adapter.py`. + +**Required CLI interface**: + +```bash +python test_pipeline.py \ + --output \ + --processing-level \ + --skip-tests \ + --name +``` + +**Processing levels** (filtering modes): +- `all` — Include all functions +- `reachable` — Filter to functions reachable from entry points +- `codeql` — Filter to reachable + CodeQL-flagged functions +- `exploitable` — Filter to reachable + CodeQL + LLM-classified exploitable + +The `reachable` filter uses `utilities/agentic_enhancer/entry_point_detector.py` and `reachability_analyzer.py`. See the Ruby parser's `apply_reachability_filter()` method for an example. + +**Exit code**: Return 0 on success, non-zero on failure. + +## Registering Your Parser + +### 1. Update `parser_adapter.py` + +Location: `libs/openant-core/core/parser_adapter.py` + +Add your language to three places: + +**a) `detect_language()` — file extension mapping**: + +```python +def detect_language(repo_path: str) -> str: + counts = {"python": 0, "javascript": 0, "go": 0, "c": 0, "ruby": 0, "php": 0, "rust": 0} # Add here + # ... + elif suffix == ".rs": # Add extension check + counts["rust"] += 1 +``` + +**b) `parse_repository()` — dispatch branch**: + +```python +def parse_repository(...) -> ParseResult: + # ... 
+ elif language == "rust": + return _parse_rust(repo_path, output_dir, processing_level, skip_tests, name) +``` + +**c) Add `_parse_()` function**: + +```python +def _parse_rust(repo_path: str, output_dir: str, processing_level: str, + skip_tests: bool = True, name: str = None) -> ParseResult: + """Invoke the Rust parser.""" + print("[Parser] Running Rust parser...", file=sys.stderr) + + parser_script = _CORE_ROOT / "parsers" / "rust" / "test_pipeline.py" + + cmd = [ + sys.executable, str(parser_script), + repo_path, + "--output", output_dir, + "--processing-level", processing_level, + ] + + if name: + cmd.extend(["--name", name]) + if skip_tests: + cmd.append("--skip-tests") + + result = subprocess.run( + cmd, + stdout=sys.stderr, + stderr=sys.stderr, + cwd=str(_CORE_ROOT), + timeout=1800, # 30 min timeout + ) + + if result.returncode != 0: + raise RuntimeError(f"Rust parser failed with exit code {result.returncode}") + + dataset_path = os.path.join(output_dir, "dataset.json") + analyzer_output_path = os.path.join(output_dir, "analyzer_output.json") + + units_count = 0 + if os.path.exists(dataset_path): + with open(dataset_path) as f: + data = json.load(f) + units_count = len(data.get("units", [])) + + print(f" Rust parser complete: {units_count} units", file=sys.stderr) + + return ParseResult( + dataset_path=dataset_path, + analyzer_output_path=analyzer_output_path if os.path.exists(analyzer_output_path) else None, + units_count=units_count, + language="rust", + processing_level=processing_level, + ) +``` + +### 2. Add to CLI Whitelist (required) + +Location: `libs/openant-core/openant/cli.py` + +The CLI validates the `--language` flag against a whitelist. 
Without this, users get: +``` +error: argument --language/-l: invalid choice: 'rust' +``` + +Update the `choices` list in **two places** (for `scan` and `parse` commands): + +```python +# Around line 465 (scan command) +scan_p.add_argument( + "--language", "-l", + choices=["auto", "python", "javascript", "go", "c", "ruby", "php", "rust"], # Add here + ... +) + +# Around line 500 (parse command) +parse_p.add_argument( + "--language", "-l", + choices=["auto", "python", "javascript", "go", "c", "ruby", "php", "rust"], # Add here + ... +) +``` + +### 3. Add Dependencies (required for venv) + +When users run `openant init` or any command for the first time, OpenAnt creates a managed venv at `~/.openant/venv/` and installs dependencies from `pyproject.toml`. For your parser's dependencies to be included, add them to **both** files: + +**a) `libs/openant-core/requirements.txt`**: + +``` +tree-sitter-rust>=0.21.0 +``` + +**b) `libs/openant-core/pyproject.toml`**: + +```toml +dependencies = [ + # ... existing deps ... + "tree-sitter-rust>=0.21.0", +] +``` + +Without this, users will see `ModuleNotFoundError` when running the parser. + +### 4. Update CLI help (optional) + +Location: `apps/openant-cli/cmd/parse.go` + +Update the `--language` flag description: + +```go +parseCmd.Flags().StringVarP(&parseLanguage, "language", "l", "", + "Language: python, javascript, go, c, ruby, php, rust, auto") +``` + +### 5. Update README.md + +Add your language to the "Supported languages" list. + +## Recommended Approach: tree-sitter + +For most languages, [tree-sitter](https://tree-sitter.github.io/tree-sitter/) is the easiest path. It provides fast, incremental parsing with pre-built grammars for 100+ languages. 
+ +**Why tree-sitter**: +- No external runtime needed (pure Python bindings) +- Consistent API across languages +- Handles syntax errors gracefully +- Pre-built grammars for most languages + +**Dependencies**: + +Add to both `libs/openant-core/requirements.txt` and `libs/openant-core/pyproject.toml`: + +``` +tree-sitter>=0.21.0 +tree-sitter-<language>>=0.21.0 # e.g., tree-sitter-rust +``` + +See [Add Dependencies](#3-add-dependencies-required-for-venv) for details. + +**Basic usage**: + +```python +import tree_sitter_rust as ts_rust +from tree_sitter import Language, Parser + +RUST_LANGUAGE = Language(ts_rust.language()) +parser = Parser(RUST_LANGUAGE) + +source = b"fn main() { println!(\"Hello\"); }" +tree = parser.parse(source) + +# Walk the AST +def walk(node, depth=0): + print(" " * depth + f"{node.type}: {source[node.start_byte:node.end_byte]}") + for child in node.children: + walk(child, depth + 1) + +walk(tree.root_node) +``` + +**Alternative approaches**: + +If tree-sitter doesn't have a grammar for your language, you can: +- Use the language's native AST parser (like Python's `ast` module) +- Use a subprocess to call an external parser (like the Go and JS parsers do) +- Write a regex-based fallback (less accurate, but works) + +## Reference Implementation + +The **Ruby parser** (`libs/openant-core/parsers/ruby/`) is the cleanest tree-sitter implementation to use as a template: + +| File | Purpose | +|------|---------| +| `repository_scanner.py` | Stage 1: File enumeration | +| `function_extractor.py` | Stage 2: AST parsing with tree-sitter | +| `call_graph_builder.py` | Stage 3: Call resolution | +| `unit_generator.py` | Stage 4: Dataset generation | +| `test_pipeline.py` | Pipeline orchestrator | +| `__init__.py` | Empty (required for Python imports) | + +Copy this directory, rename it, and adapt: +1. Change file extensions in `RepositoryScanner` +2. Update tree-sitter import and language in `FunctionExtractor` +3.
Adjust call resolution patterns in `CallGraphBuilder` for your language's semantics +4. Update builtins list in `CallGraphBuilder.RUBY_BUILTINS` +5. Adjust unit type classification logic + +## Testing Your Parser + +### 1. Run on a test repository + +```bash +cd libs/openant-core +python parsers/<language>/test_pipeline.py /path/to/test/repo --output /tmp/test-output +``` + +### 2. Verify outputs + +Check that these files exist and have valid JSON: +- `/tmp/test-output/dataset.json` +- `/tmp/test-output/analyzer_output.json` +- `/tmp/test-output/scan_results.json` +- `/tmp/test-output/call_graph.json` (if your pipeline writes it) + +### 3. Validate dataset schema + +```python +import json + +with open("/tmp/test-output/dataset.json") as f: + dataset = json.load(f) + +# Check required fields +assert "name" in dataset +assert "units" in dataset +assert len(dataset["units"]) > 0 + +for unit in dataset["units"]: + assert "id" in unit + assert "unit_type" in unit + assert "code" in unit + assert "primary_code" in unit["code"] + assert "primary_origin" in unit["code"] +``` + +### 4. Test through the full pipeline + +```bash +# From repo root +openant init /path/to/test/repo -l <language> --name test/repo +openant parse +openant analyze # Requires ANTHROPIC_API_KEY +``` + +### 5. Compare with existing parsers + +Parse the same polyglot repo with your parser and an existing one. The output structure should be identical — only the content differs. + +## Questions? + +Open an issue on GitHub or check existing parser implementations for examples. diff --git a/apps/openant-cli/cmd/parse.go b/apps/openant-cli/cmd/parse.go index 671df1a..148ba0c 100644 --- a/apps/openant-cli/cmd/parse.go +++ b/apps/openant-cli/cmd/parse.go @@ -15,7 +15,7 @@ var parseCmd = &cobra.Command{ Long: `Parse extracts analyzable code units from a repository. The output is a JSON dataset that can be fed into the analyze command. -Supports Python, JavaScript/TypeScript, Go, C/C++, Ruby, and PHP repositories.
+Supports Python, JavaScript/TypeScript, Go, C/C++, Ruby, PHP, and Rust repositories. If no repository path is given, the active project is used (see: openant init).`, Args: cobra.MaximumNArgs(1), @@ -30,7 +30,7 @@ var ( func init() { parseCmd.Flags().StringVarP(&parseOutput, "output", "o", "", "Output directory (default: project scan dir)") - parseCmd.Flags().StringVarP(&parseLanguage, "language", "l", "", "Language: python, javascript, go, c, ruby, php, auto") + parseCmd.Flags().StringVarP(&parseLanguage, "language", "l", "", "Language: python, javascript, go, c, ruby, php, rust, auto") parseCmd.Flags().StringVar(&parseLevel, "level", "all", "Processing level: all, reachable, codeql, exploitable") } diff --git a/libs/openant-core/core/parser_adapter.py b/libs/openant-core/core/parser_adapter.py index 8e3ecc7..553b488 100644 --- a/libs/openant-core/core/parser_adapter.py +++ b/libs/openant-core/core/parser_adapter.py @@ -30,7 +30,7 @@ def detect_language(repo_path: str) -> str: "python", "javascript", or "go" """ repo = Path(repo_path) - counts = {"python": 0, "javascript": 0, "go": 0, "c": 0, "ruby": 0, "php": 0} + counts = {"python": 0, "javascript": 0, "go": 0, "c": 0, "ruby": 0, "php": 0, "rust": 0} for f in repo.rglob("*"): if not f.is_file(): @@ -56,6 +56,8 @@ def detect_language(repo_path: str) -> str: counts["ruby"] += 1 elif suffix == ".php": counts["php"] += 1 + elif suffix == ".rs": + counts["rust"] += 1 if not any(counts.values()): raise ValueError( @@ -116,6 +118,8 @@ def parse_repository( return _parse_ruby(repo_path, output_dir, processing_level, skip_tests, name) elif language == "php": return _parse_php(repo_path, output_dir, processing_level, skip_tests, name) + elif language == "rust": + return _parse_rust(repo_path, output_dir, processing_level, skip_tests, name) else: raise ValueError(f"Unsupported language: {language}") @@ -594,3 +598,63 @@ def _parse_php(repo_path: str, output_dir: str, processing_level: str, skip_test language="php", 
processing_level=processing_level, ) + + +# --------------------------------------------------------------------------- +# Rust parser +# --------------------------------------------------------------------------- + +def _parse_rust(repo_path: str, output_dir: str, processing_level: str, skip_tests: bool = True, name: str = None) -> ParseResult: + """Invoke the Rust parser. + + The Rust parser uses tree-sitter for function extraction and call graph + building. Invoked via subprocess (same pattern as other parsers). + + Requires: tree-sitter, tree-sitter-rust + """ + print("[Parser] Running Rust parser...", file=sys.stderr) + + parser_script = _CORE_ROOT / "parsers" / "rust" / "test_pipeline.py" + + cmd = [ + sys.executable, str(parser_script), + repo_path, + "--output", output_dir, + "--processing-level", processing_level, + ] + + if name: + cmd.extend(["--name", name]) + if skip_tests: + cmd.append("--skip-tests") + + result = subprocess.run( + cmd, + stdout=sys.stderr, + stderr=sys.stderr, + cwd=str(_CORE_ROOT), + timeout=1800, + ) + + if result.returncode != 0: + raise RuntimeError(f"Rust parser failed with exit code {result.returncode}") + + dataset_path = os.path.join(output_dir, "dataset.json") + analyzer_output_path = os.path.join(output_dir, "analyzer_output.json") + + # Count units + units_count = 0 + if os.path.exists(dataset_path): + with open(dataset_path) as f: + data = json.load(f) + units_count = len(data.get("units", [])) + + print(f" Rust parser complete: {units_count} units", file=sys.stderr) + + return ParseResult( + dataset_path=dataset_path, + analyzer_output_path=analyzer_output_path if os.path.exists(analyzer_output_path) else None, + units_count=units_count, + language="rust", + processing_level=processing_level, + ) diff --git a/libs/openant-core/openant/cli.py b/libs/openant-core/openant/cli.py index cdaf2bf..c6a6d37 100644 --- a/libs/openant-core/openant/cli.py +++ b/libs/openant-core/openant/cli.py @@ -462,7 +462,7 @@ def main(): 
scan_p.add_argument("--output", "-o", help="Output directory (default: temp dir)") scan_p.add_argument( "--language", "-l", - choices=["auto", "python", "javascript", "go", "c", "ruby", "php"], + choices=["auto", "python", "javascript", "go", "c", "ruby", "php", "rust"], default="auto", help="Language (default: auto-detect)", ) @@ -497,7 +497,7 @@ def main(): parse_p.add_argument("--output", "-o", help="Output directory (default: temp dir)") parse_p.add_argument( "--language", "-l", - choices=["auto", "python", "javascript", "go", "c", "ruby", "php"], + choices=["auto", "python", "javascript", "go", "c", "ruby", "php", "rust"], default="auto", help="Language (default: auto-detect)", ) diff --git a/libs/openant-core/parsers/rust/__init__.py b/libs/openant-core/parsers/rust/__init__.py new file mode 100644 index 0000000..bdd2f23 --- /dev/null +++ b/libs/openant-core/parsers/rust/__init__.py @@ -0,0 +1 @@ +# Rust parser for OpenAnt diff --git a/libs/openant-core/parsers/rust/call_graph_builder.py b/libs/openant-core/parsers/rust/call_graph_builder.py new file mode 100644 index 0000000..f35cf93 --- /dev/null +++ b/libs/openant-core/parsers/rust/call_graph_builder.py @@ -0,0 +1,498 @@ +#!/usr/bin/env python3 +""" +Call Graph Builder for Rust Codebases + +Builds bidirectional call graphs from extracted function data: +- Forward graph: function -> functions it calls +- Reverse graph: function -> functions that call it + +This is Phase 3 of the Rust parser - dependency resolution. + +Usage: + python call_graph_builder.py [--output ] [--depth ] + +Output (JSON): + { + "functions": {...}, + "call_graph": { + "src/main.rs:main": ["src/lib.rs:Config::new", "src/lib.rs:run"], + ... + }, + "reverse_call_graph": { + "src/lib.rs:Config::new": ["src/main.rs:main"], + ... 
+ }, + "statistics": { + "total_edges": 500, + "avg_out_degree": 2.5, + "max_out_degree": 15, + "isolated_functions": 20 + } + } +""" + +import json +import re +import sys +from pathlib import Path +from typing import Dict, List, Optional, Set + +import tree_sitter_rust as ts_rust +from tree_sitter import Language, Parser + + +RUST_LANGUAGE = Language(ts_rust.language()) + +# Rust builtins, macros, and common methods to filter out +RUST_BUILTINS = { + # Macros (commonly used) + 'println', 'print', 'eprintln', 'eprint', 'format', 'write', 'writeln', + 'vec', 'panic', 'todo', 'unimplemented', 'unreachable', + 'assert', 'assert_eq', 'assert_ne', 'debug_assert', 'debug_assert_eq', + 'dbg', 'env', 'option_env', 'concat', 'stringify', 'include', 'include_str', + 'include_bytes', 'file', 'line', 'column', 'module_path', + 'cfg', 'cfg_attr', 'derive', 'test', 'bench', + # Standard library common methods + 'clone', 'to_string', 'to_owned', 'into', 'from', 'as_ref', 'as_mut', + 'borrow', 'borrow_mut', 'deref', 'deref_mut', + 'unwrap', 'expect', 'unwrap_or', 'unwrap_or_else', 'unwrap_or_default', + 'ok', 'err', 'ok_or', 'ok_or_else', + 'map', 'map_err', 'map_or', 'map_or_else', 'and_then', 'or_else', + 'filter', 'filter_map', 'find', 'find_map', 'position', + 'collect', 'iter', 'into_iter', 'iter_mut', + 'len', 'is_empty', 'capacity', + 'push', 'pop', 'insert', 'remove', 'get', 'get_mut', + 'contains', 'contains_key', 'entry', + 'first', 'last', 'nth', 'take', 'skip', 'step_by', + 'enumerate', 'zip', 'chain', 'flatten', 'flat_map', + 'fold', 'reduce', 'sum', 'product', 'count', + 'any', 'all', 'min', 'max', 'min_by', 'max_by', 'min_by_key', 'max_by_key', + 'sort', 'sort_by', 'sort_by_key', 'reverse', 'rev', + 'split', 'split_at', 'split_whitespace', 'lines', 'chars', 'bytes', + 'trim', 'trim_start', 'trim_end', 'strip_prefix', 'strip_suffix', + 'starts_with', 'ends_with', 'contains', 'replace', 'replacen', + 'parse', 'try_into', 'try_from', + # Type conversions + 'as_bytes', 
'as_str', 'as_slice', 'as_ptr', 'as_mut_ptr', + 'to_vec', 'to_lowercase', 'to_uppercase', 'to_ascii_lowercase', 'to_ascii_uppercase', + # Memory/allocation + 'drop', 'forget', 'take', 'replace', 'swap', 'mem', + 'box', 'Box', 'Rc', 'Arc', 'Cell', 'RefCell', 'Mutex', 'RwLock', + # Common trait methods + 'default', 'Default', 'eq', 'ne', 'lt', 'le', 'gt', 'ge', 'cmp', 'partial_cmp', + 'hash', 'fmt', 'display', 'debug', + # Async + 'await', 'poll', 'spawn', 'block_on', + # Result/Option specific + 'is_some', 'is_none', 'is_ok', 'is_err', + 'transpose', 'flatten', + # Logging (common crates) + 'info', 'warn', 'error', 'debug', 'trace', 'log', + # Testing + 'assert', 'assert_eq', 'assert_ne', +} + + +class CallGraphBuilder: + """ + Build bidirectional call graphs from extracted Rust function data. + + This is Stage 3 of the Rust parser pipeline. + """ + + def __init__(self, extractor_output: Dict, options: Optional[Dict] = None): + options = options or {} + + self.functions = extractor_output.get('functions', {}) + self.classes = extractor_output.get('classes', {}) # impl blocks + self.imports = extractor_output.get('imports', {}) + self.repo_path = extractor_output.get('repository', '') + + self.max_depth = options.get('max_depth', 3) + + # Call graphs + self.call_graph: Dict[str, List[str]] = {} + self.reverse_call_graph: Dict[str, List[str]] = {} + + # Indexes for faster lookup + self.functions_by_name: Dict[str, List[str]] = {} + self.functions_by_file: Dict[str, List[str]] = {} + self.methods_by_impl: Dict[str, List[str]] = {} + + self._build_indexes() + + # Parser for re-parsing function bodies + self.rust_parser = Parser(RUST_LANGUAGE) + + def _build_indexes(self) -> None: + """Build lookup indexes for faster resolution.""" + for func_id, func_data in self.functions.items(): + name = func_data.get('name', '') + if name: + if name not in self.functions_by_name: + self.functions_by_name[name] = [] + self.functions_by_name[name].append(func_id) + + file_path = 
func_data.get('file_path', '') + if file_path: + if file_path not in self.functions_by_file: + self.functions_by_file[file_path] = [] + self.functions_by_file[file_path].append(func_id) + + class_name = func_data.get('class_name') # impl block name + if class_name: + impl_key = f"{file_path}:{class_name}" + if impl_key not in self.methods_by_impl: + self.methods_by_impl[impl_key] = [] + self.methods_by_impl[impl_key].append(func_id) + + def _is_builtin(self, name: str) -> bool: + """Check if name is a Rust builtin or common method.""" + return name in RUST_BUILTINS + + def _extract_calls_from_code(self, code: str, caller_id: str) -> Set[str]: + """Extract function call references from code using tree-sitter.""" + calls = set() + caller_file = caller_id.split(':')[0] + caller_func = self.functions.get(caller_id, {}) + caller_impl = caller_func.get('class_name') + + code_bytes = code.encode('utf-8', errors='replace') + try: + tree = self.rust_parser.parse(code_bytes) + except Exception: + return self._extract_calls_regex(code, caller_id) + + stack = [tree.root_node] + while stack: + node = stack.pop() + if node.type == 'call_expression': + resolved = self._resolve_call_node(node, code_bytes, caller_file, caller_impl) + if resolved: + calls.add(resolved) + elif node.type == 'method_call_expression': + resolved = self._resolve_method_call(node, code_bytes, caller_file, caller_impl) + if resolved: + calls.add(resolved) + stack.extend(reversed(node.children)) + + return calls + + def _resolve_call_node(self, node, source: bytes, caller_file: str, + caller_impl: Optional[str]) -> Optional[str]: + """Resolve a tree-sitter call_expression node to a function ID. 
+ + Handles: + - foo() - simple function call + - Type::method() - associated function call + - self.method() - method call on self (handled by method_call_expression) + """ + function_node = node.child_by_field_name('function') + if function_node is None: + return None + + func_text = source[function_node.start_byte:function_node.end_byte].decode('utf-8', errors='replace') + + # Skip macros (end with !) + if func_text.rstrip().endswith('!'): + return None + + # Type::method() pattern + if '::' in func_text: + parts = func_text.split('::') + if len(parts) >= 2: + type_name = parts[-2] + method_name = parts[-1] + + if self._is_builtin(method_name): + return None + + return self._resolve_associated_call(type_name, method_name, caller_file) + + # Simple function call: foo() + func_name = func_text.strip() + if self._is_builtin(func_name): + return None + + return self._resolve_simple_call(func_name, caller_file, caller_impl) + + def _resolve_method_call(self, node, source: bytes, caller_file: str, + caller_impl: Optional[str]) -> Optional[str]: + """Resolve method_call_expression: receiver.method()""" + # Get method name + method_node = node.child_by_field_name('name') + if method_node is None: + return None + + method_name = source[method_node.start_byte:method_node.end_byte].decode('utf-8', errors='replace') + + if self._is_builtin(method_name): + return None + + # Get receiver + receiver_node = node.child_by_field_name('value') + if receiver_node: + receiver_text = source[receiver_node.start_byte:receiver_node.end_byte].decode('utf-8', errors='replace') + + # self.method() - same impl block + if receiver_text == 'self' and caller_impl: + return self._resolve_self_call(method_name, caller_file, caller_impl) + + # Self::method() handled by call_expression + + # Can't resolve receiver type statically, try unique name match + candidates = self.functions_by_name.get(method_name, []) + # Prefer methods over standalone functions + method_candidates = [c for c in 
candidates if self.functions.get(c, {}).get('class_name')]
        # Only accept an unambiguous match; multiple candidates stay unresolved.
        if len(method_candidates) == 1:
            return method_candidates[0]

        return None

    def _resolve_simple_call(self, func_name: str, caller_file: str,
                             caller_impl: Optional[str]) -> Optional[str]:
        """Resolve a simple function call to a function ID.

        Resolution order (first hit wins):
        1. same impl block, 2. same file, 3. use-imported name,
        4. unique standalone function anywhere in the repo.
        Returns None when nothing matches unambiguously.
        """
        # 1. Check same impl block first (implicit self for associated functions)
        if caller_impl:
            result = self._resolve_self_call(func_name, caller_file, caller_impl)
            if result:
                return result

        # 2. Check same file (module-level functions)
        same_file_funcs = self.functions_by_file.get(caller_file, [])
        for func_id in same_file_funcs:
            func_data = self.functions.get(func_id, {})
            if func_data.get('name') == func_name and not func_data.get('class_name'):
                return func_id

        # 3. Check use-imported items
        file_imports = self.imports.get(caller_file, {})
        if func_name in file_imports:
            # Try to find the actual function
            # NOTE(review): returns the first name match even when several
            # functions share this name — potentially the wrong one. Consider
            # matching the import path against the candidate's file.
            for func_id in self.functions_by_name.get(func_name, []):
                return func_id

        # 4. Unique name match across files (standalone functions only)
        candidates = self.functions_by_name.get(func_name, [])
        standalone = [c for c in candidates if not self.functions.get(c, {}).get('class_name')]
        if len(standalone) == 1:
            return standalone[0]

        return None

    def _resolve_self_call(self, method_name: str, caller_file: str,
                           caller_impl: str) -> Optional[str]:
        """Resolve a self.method() or Self::method() call within an impl block.

        Looks the method up in the caller's own impl block only
        (keyed "file:impl_target" in methods_by_impl).
        """
        impl_key = f"{caller_file}:{caller_impl}"
        impl_methods = self.methods_by_impl.get(impl_key, [])

        for func_id in impl_methods:
            func_data = self.functions.get(func_id, {})
            if func_data.get('name') == method_name:
                return func_id

        return None

    def _resolve_associated_call(self, type_name: str, method_name: str,
                                 caller_file: str) -> Optional[str]:
        """Resolve a Type::method() associated function call.

        Prefers an impl of the type in the caller's file, then falls back to
        the first matching impl anywhere in the repository (impl blocks for
        one type may be split across files).
        """
        # Check same file first
        impl_key = f"{caller_file}:{type_name}"
        if impl_key in self.methods_by_impl:
            for func_id in self.methods_by_impl[impl_key]:
                func_data = self.functions.get(func_id, {})
                if func_data.get('name') == method_name:
                    return func_id

        # Check all files for the type
        for key, func_ids in self.methods_by_impl.items():
            if key.endswith(f":{type_name}"):
                for func_id in func_ids:
                    func_data = self.functions.get(func_id, {})
                    if func_data.get('name') == method_name:
                        return func_id

        return None

    def _extract_calls_regex(self, code: str, caller_id: str) -> Set[str]:
        """Fallback regex-based call extraction for unparseable code.

        Coarse heuristic: any identifier followed by '(' that is not a Rust
        keyword or builtin is treated as a call and resolved without impl
        context (caller_impl=None).
        """
        calls = set()
        caller_file = caller_id.split(':')[0]

        # Match function calls: name(
        pattern = r'\b([a-zA-Z_][a-zA-Z0-9_]*)\s*[\(]'
        for match in re.finditer(pattern, code):
            func_name = match.group(1)
            # Skip Rust keywords
            if func_name in ('if', 'else', 'match', 'while', 'for', 'loop',
                             'fn', 'struct', 'enum', 'impl', 'trait', 'mod',
                             'use', 'pub', 'let', 'mut', 'const', 'static',
                             'return', 'break', 'continue', 'async',
'await',
                             'where', 'type', 'unsafe', 'extern', 'crate', 'self', 'Self'):
                continue
            if not self._is_builtin(func_name):
                resolved = self._resolve_simple_call(func_name, caller_file, None)
                if resolved:
                    calls.add(resolved)

        return calls

    def build_call_graph(self) -> None:
        """Build the complete call graph for all functions.

        Populates self.call_graph (caller -> callees) and
        self.reverse_call_graph (callee -> callers). Self-edges and calls to
        unknown IDs are dropped.
        """
        for func_id, func_data in self.functions.items():
            code = func_data.get('code', '')
            if not code:
                self.call_graph[func_id] = []
                continue

            calls = self._extract_calls_from_code(code, func_id)

            # Filter to valid function IDs (must exist, not self-calls)
            valid_calls = [c for c in calls if c in self.functions and c != func_id]
            self.call_graph[func_id] = valid_calls

            # Build reverse graph
            for called_id in valid_calls:
                if called_id not in self.reverse_call_graph:
                    self.reverse_call_graph[called_id] = []
                if func_id not in self.reverse_call_graph[called_id]:
                    self.reverse_call_graph[called_id].append(func_id)

    def get_dependencies(self, func_id: str, depth: Optional[int] = None) -> List[str]:
        """Get all dependencies (callees) for a function up to max depth.

        BFS over call_graph; `visited` guards against cycles. The result is
        in breadth-first discovery order and excludes func_id itself.
        NOTE: queue.pop(0) is O(n) per pop — fine for typical repo sizes,
        collections.deque would make it O(1) if this ever shows up in profiles.
        """
        max_d = depth if depth is not None else self.max_depth
        dependencies = []
        visited = {func_id}
        queue = [(func_id, 0)]

        while queue:
            current_id, current_depth = queue.pop(0)

            if current_depth >= max_d:
                continue

            calls = self.call_graph.get(current_id, [])
            for called_id in calls:
                if called_id not in visited:
                    visited.add(called_id)
                    dependencies.append(called_id)
                    queue.append((called_id, current_depth + 1))

        return dependencies

    def get_callers(self, func_id: str, depth: Optional[int] = None) -> List[str]:
        """Get all callers for a function up to max depth.

        Mirror of get_dependencies over reverse_call_graph.
        """
        max_d = depth if depth is not None else self.max_depth
        callers = []
        visited = {func_id}
        queue = [(func_id, 0)]

        while queue:
            current_id, current_depth = queue.pop(0)

            if current_depth >= max_d:
                continue

            caller_ids =
self.reverse_call_graph.get(current_id, []) + for caller_id in caller_ids: + if caller_id not in visited: + visited.add(caller_id) + callers.append(caller_id) + queue.append((caller_id, current_depth + 1)) + + return callers + + def get_statistics(self) -> Dict: + """Calculate call graph statistics.""" + total_edges = sum(len(calls) for calls in self.call_graph.values()) + num_funcs = len(self.functions) + + out_degrees = [len(self.call_graph.get(f, [])) for f in self.functions] + in_degrees = [len(self.reverse_call_graph.get(f, [])) for f in self.functions] + + isolated = sum(1 for f in self.functions + if len(self.call_graph.get(f, [])) == 0 + and len(self.reverse_call_graph.get(f, [])) == 0) + + return { + 'total_functions': num_funcs, + 'total_edges': total_edges, + 'avg_out_degree': round(total_edges / num_funcs, 2) if num_funcs > 0 else 0, + 'avg_in_degree': round(total_edges / num_funcs, 2) if num_funcs > 0 else 0, + 'max_out_degree': max(out_degrees) if out_degrees else 0, + 'max_in_degree': max(in_degrees) if in_degrees else 0, + 'isolated_functions': isolated, + } + + def export(self) -> Dict: + """Export the call graph data.""" + return { + 'repository': self.repo_path, + 'functions': self.functions, + 'classes': self.classes, + 'imports': self.imports, + 'call_graph': self.call_graph, + 'reverse_call_graph': self.reverse_call_graph, + 'statistics': self.get_statistics(), + } + + +def main(): + """Command line interface.""" + import argparse + + parser = argparse.ArgumentParser( + description='Build call graphs from extracted Rust function data', + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=''' +Examples: + python call_graph_builder.py functions.json + python call_graph_builder.py functions.json --output call_graph.json + python call_graph_builder.py functions.json --depth 5 + ''' + ) + + parser.add_argument('input_file', help='Function extractor output JSON file') + parser.add_argument('--output', '-o', help='Output file (default: 
stdout)') + parser.add_argument('--depth', '-d', type=int, default=3, + help='Max dependency resolution depth (default: 3)') + + args = parser.parse_args() + + try: + with open(args.input_file) as f: + extractor_output = json.load(f) + + print(f"Processing {len(extractor_output.get('functions', {}))} functions...", file=sys.stderr) + + builder = CallGraphBuilder(extractor_output, {'max_depth': args.depth}) + builder.build_call_graph() + + result = builder.export() + stats = result['statistics'] + + print(f"Call graph built:", file=sys.stderr) + print(f" Total functions: {stats['total_functions']}", file=sys.stderr) + print(f" Total edges: {stats['total_edges']}", file=sys.stderr) + print(f" Avg out-degree: {stats['avg_out_degree']}", file=sys.stderr) + print(f" Max out-degree: {stats['max_out_degree']}", file=sys.stderr) + print(f" Isolated functions: {stats['isolated_functions']}", file=sys.stderr) + + output = json.dumps(result, indent=2) + + if args.output: + with open(args.output, 'w') as f: + f.write(output) + print(f"Output written to: {args.output}", file=sys.stderr) + else: + print(output) + + except Exception as e: + print(f"Error: {e}", file=sys.stderr) + import traceback + traceback.print_exc() + sys.exit(1) + + +if __name__ == '__main__': + main() diff --git a/libs/openant-core/parsers/rust/function_extractor.py b/libs/openant-core/parsers/rust/function_extractor.py new file mode 100644 index 0000000..c23d850 --- /dev/null +++ b/libs/openant-core/parsers/rust/function_extractor.py @@ -0,0 +1,532 @@ +#!/usr/bin/env python3 +""" +Function Extractor for Rust Codebases + +Extracts ALL functions, methods, and impl blocks from Rust source files using tree-sitter. +This is Phase 2 of the Rust parser - function inventory. 
+ +Usage: + python function_extractor.py [--output ] [--scan-file ] + +Output (JSON): + { + "repository": "/path/to/repo", + "extraction_time": "2025-12-30T...", + "functions": { + "src/main.rs:main": { + "name": "main", + "qualified_name": "main", + "file_path": "src/main.rs", + "start_line": 10, + "end_line": 25, + "code": "fn main() {...}", + "class_name": null, + "module_name": null, + "parameters": [], + "is_public": true, + "is_async": false, + "unit_type": "function" + } + }, + "classes": { ... }, + "imports": { ... }, + "statistics": { ... } + } +""" + +import json +import os +import sys +from datetime import datetime +from pathlib import Path +from typing import Dict, List, Optional, Set, Tuple + +import tree_sitter_rust as ts_rust +from tree_sitter import Language, Parser + + +RUST_LANGUAGE = Language(ts_rust.language()) + + +class FunctionExtractor: + """ + Extract all functions and impl blocks from Rust source files using tree-sitter. + + This is Stage 2 of the Rust parser pipeline. 
+ """ + + def __init__(self, repo_path: str): + self.repo_path = Path(repo_path).resolve() + self.functions: Dict[str, Dict] = {} + self.classes: Dict[str, Dict] = {} # impl blocks (struct/trait impls) + self.imports: Dict[str, Dict[str, str]] = {} # use statements + + self.parser = Parser(RUST_LANGUAGE) + + self.file_cache: Dict[str, bytes] = {} + + self.stats = { + 'total_functions': 0, + 'total_classes': 0, # impl blocks + 'total_methods': 0, + 'standalone_functions': 0, + 'async_functions': 0, + 'public_functions': 0, + 'files_processed': 0, + 'files_with_errors': 0, + 'by_type': {}, + } + + def read_file(self, file_path: Path) -> bytes: + """Read and cache file contents as bytes (tree-sitter needs bytes).""" + path_str = str(file_path) + if path_str not in self.file_cache: + try: + self.file_cache[path_str] = file_path.read_bytes() + except Exception as e: + print(f"Warning: Cannot read {file_path}: {e}", file=sys.stderr) + self.file_cache[path_str] = b"" + return self.file_cache[path_str] + + def _node_text(self, node, source: bytes) -> str: + """Extract text from a tree-sitter node.""" + return source[node.start_byte:node.end_byte].decode('utf-8', errors='replace') + + def _get_function_name(self, node, source: bytes) -> Optional[str]: + """Extract function name from a function_item node.""" + name_node = node.child_by_field_name('name') + if name_node: + return self._node_text(name_node, source) + # Fallback: search for identifier child + for child in node.children: + if child.type == 'identifier': + return self._node_text(child, source) + return None + + def _get_parameters(self, node, source: bytes) -> List[str]: + """Extract parameters from a function node.""" + params = [] + params_node = node.child_by_field_name('parameters') + if params_node is None: + return params + + for child in params_node.children: + if child.type == 'parameter': + # Get parameter pattern (name) + pattern_node = child.child_by_field_name('pattern') + if pattern_node: + 
param_name = self._node_text(pattern_node, source) + params.append(param_name) + elif child.type == 'self_parameter': + # &self, &mut self, self + params.append(self._node_text(child, source)) + + return params + + def _is_public(self, node, source: bytes) -> bool: + """Check if a function has pub visibility.""" + # Look for visibility_modifier as first child or sibling + for child in node.children: + if child.type == 'visibility_modifier': + return True + return False + + def _is_async(self, node, source: bytes) -> bool: + """Check if a function is async.""" + code = self._node_text(node, source) + return code.strip().startswith('async ') or 'async fn' in code[:50] + + def _has_test_attribute(self, node, source: bytes) -> bool: + """Check if a function has #[test] or #[cfg(test)] attribute.""" + # Look for attribute nodes before the function + parent = node.parent + if parent is None: + return False + + # Check siblings before this node for attributes + found_self = False + for sibling in parent.children: + if sibling.id == node.id: + found_self = True + break + if sibling.type == 'attribute_item': + attr_text = self._node_text(sibling, source) + if '#[test]' in attr_text or '#[cfg(test)]' in attr_text: + return True + + # Also check if inside a #[cfg(test)] module + current = parent + while current: + if current.type == 'mod_item': + # Check for #[cfg(test)] attribute on the module + for child in current.children: + if child.type == 'attribute_item': + attr_text = self._node_text(child, source) + if '#[cfg(test)]' in attr_text: + return True + current = current.parent + + return False + + def _has_route_attribute(self, node, source: bytes) -> bool: + """Check if a function has route handler attributes (actix, axum, rocket).""" + parent = node.parent + if parent is None: + return False + + route_patterns = ['#[get', '#[post', '#[put', '#[delete', '#[patch', + '#[route', '#[handler', '#[endpoint'] + + for sibling in parent.children: + if sibling.id == node.id: + 
break + if sibling.type == 'attribute_item': + attr_text = self._node_text(sibling, source).lower() + for pattern in route_patterns: + if pattern in attr_text: + return True + + return False + + def _has_main_attribute(self, node, source: bytes) -> bool: + """Check if function has async runtime main attributes.""" + parent = node.parent + if parent is None: + return False + + main_patterns = ['#[tokio::main', '#[async_std::main', '#[actix_web::main', + '#[actix_rt::main', '#[rocket::main', '#[rocket::launch'] + + for sibling in parent.children: + if sibling.id == node.id: + break + if sibling.type == 'attribute_item': + attr_text = self._node_text(sibling, source) + for pattern in main_patterns: + if pattern in attr_text: + return True + + return False + + def _classify_function(self, func_name: str, impl_name: Optional[str], + module_name: Optional[str], is_public: bool, + file_path: str, has_self: bool, + is_test: bool, is_route: bool, is_main: bool) -> str: + """Classify a function by its type/purpose.""" + path_lower = file_path.lower() + + # Test functions + if is_test or func_name.startswith('test_'): + return 'test' + + # Constructor patterns + if func_name in ('new', 'default', 'create', 'init', 'build'): + return 'constructor' + + # Route handlers + if is_route: + return 'route_handler' + + # Entry points + if func_name == 'main' or is_main: + return 'entry_point' + + # Method in impl block with self + if impl_name and has_self: + return 'method' + + # Associated function (no self, but in impl block) + if impl_name and not has_self: + return 'associated_function' + + # Standalone function + return 'function' + + def _extract_imports(self, tree, source: bytes) -> Dict[str, str]: + """Extract use statements from a file.""" + imports = {} + stack = [tree.root_node] + + while stack: + node = stack.pop() + + if node.type == 'use_declaration': + # Extract the use path + use_text = self._node_text(node, source) + # Parse use statement: use foo::bar::Baz; + # 
Store as import_name -> 'use' + if '::' in use_text: + # Get the last part as the imported name + parts = use_text.replace('use ', '').replace(';', '').strip() + if '{' in parts: + # use foo::{bar, baz}; + base = parts.split('{')[0].rstrip('::') + items = parts.split('{')[1].rstrip('}').split(',') + for item in items: + item = item.strip() + if item: + imports[item] = f"use {base}::{item}" + else: + # use foo::bar::Baz; + last_part = parts.split('::')[-1] + imports[last_part] = use_text.strip() + + stack.extend(reversed(node.children)) + + return imports + + def _extract_functions_from_tree(self, tree, source: bytes, file_path: Path, + relative_path: str) -> None: + """Extract all function definitions from a parsed tree.""" + # Stack-based traversal: (node, impl_name, module_name) + stack = [(tree.root_node, None, None)] + + while stack: + node, impl_name, module_name = stack.pop() + + if node.type == 'function_item': + self._process_function_node( + node, source, relative_path, impl_name, module_name + ) + + elif node.type == 'impl_item': + # Extract impl target (struct/trait name) + type_node = node.child_by_field_name('type') + new_impl_name = self._node_text(type_node, source) if type_node else None + + # Check for trait impl: impl Trait for Type + trait_node = node.child_by_field_name('trait') + if trait_node: + trait_name = self._node_text(trait_node, source) + new_impl_name = f"{trait_name} for {new_impl_name}" if new_impl_name else trait_name + + if new_impl_name: + impl_id = f"{relative_path}:{new_impl_name}" + body_node = node.child_by_field_name('body') + methods = [] + if body_node: + for child in body_node.children: + if child.type == 'function_item': + mname = self._get_function_name(child, source) + if mname: + methods.append(mname) + + self.classes[impl_id] = { + 'name': new_impl_name, + 'file_path': relative_path, + 'start_line': node.start_point[0] + 1, + 'end_line': node.end_point[0] + 1, + 'methods': methods, + 'module_name': module_name, + } + 
self.stats['total_classes'] += 1 + + # Recurse into impl body with updated impl_name + body_node = node.child_by_field_name('body') + if body_node: + for child in reversed(body_node.children): + stack.append((child, new_impl_name, module_name)) + continue # Don't walk children again + + elif node.type == 'mod_item': + # Extract module name + name_node = node.child_by_field_name('name') + new_module_name = self._node_text(name_node, source) if name_node else module_name + + # Recurse into module body + body_node = node.child_by_field_name('body') + if body_node: + for child in reversed(body_node.children): + stack.append((child, impl_name, new_module_name)) + continue # Don't walk children again + + else: + for child in reversed(node.children): + stack.append((child, impl_name, module_name)) + + def _process_function_node(self, node, source: bytes, relative_path: str, + impl_name: Optional[str], module_name: Optional[str]) -> None: + """Process a single function_item node.""" + name = self._get_function_name(node, source) + if not name: + return + + code = self._node_text(node, source) + start_line = node.start_point[0] + 1 # tree-sitter is 0-indexed + end_line = node.end_point[0] + 1 + parameters = self._get_parameters(node, source) + + is_public = self._is_public(node, source) + is_async = self._is_async(node, source) + is_test = self._has_test_attribute(node, source) + is_route = self._has_route_attribute(node, source) + is_main = self._has_main_attribute(node, source) + + # Check if method has self parameter + has_self = any('self' in p for p in parameters) + + unit_type = self._classify_function( + name, impl_name, module_name, is_public, relative_path, + has_self, is_test, is_route, is_main + ) + + # Build qualified name and function ID + if impl_name: + qualified_name = f"{impl_name}::{name}" + elif module_name: + qualified_name = f"{module_name}::{name}" + else: + qualified_name = name + + func_id = f"{relative_path}:{qualified_name}" + + func_data = { + 
'name': name, + 'qualified_name': qualified_name, + 'file_path': relative_path, + 'start_line': start_line, + 'end_line': end_line, + 'code': code, + 'class_name': impl_name, # In Rust, this is the impl block target + 'module_name': module_name, + 'parameters': parameters, + 'is_public': is_public, + 'is_async': is_async, + 'unit_type': unit_type, + } + + self.functions[func_id] = func_data + self.stats['total_functions'] += 1 + + if impl_name: + self.stats['total_methods'] += 1 + else: + self.stats['standalone_functions'] += 1 + + if is_async: + self.stats['async_functions'] += 1 + + if is_public: + self.stats['public_functions'] += 1 + + self.stats['by_type'][unit_type] = self.stats['by_type'].get(unit_type, 0) + 1 + + def process_file(self, file_path: Path) -> None: + """Process a single Rust file.""" + source = self.read_file(file_path) + if not source: + self.stats['files_with_errors'] += 1 + return + + relative_path = str(file_path.relative_to(self.repo_path)) + + try: + tree = self.parser.parse(source) + except Exception as e: + print(f"Parse error in {file_path}: {e}", file=sys.stderr) + self.stats['files_with_errors'] += 1 + return + + self.stats['files_processed'] += 1 + + # Extract imports + self.imports[relative_path] = self._extract_imports(tree, source) + + # Extract functions + self._extract_functions_from_tree(tree, source, file_path, relative_path) + + def extract_from_scan(self, scan_result: Dict) -> Dict: + """Extract functions from files listed in a scan result.""" + for file_info in scan_result.get('files', []): + file_path = self.repo_path / file_info['path'] + self.process_file(file_path) + + return self.export() + + def extract_all(self, files: Optional[List[str]] = None) -> Dict: + """Extract functions from all Rust files or a specific list.""" + if files: + for file_rel_path in files: + file_path = self.repo_path / file_rel_path + if file_path.exists(): + self.process_file(file_path) + else: + for file_path in self.repo_path.rglob('*.rs'): 
+ path_str = str(file_path) + if any(excl in path_str for excl in ['.git', 'target', '.cargo', 'vendor']): + continue + self.process_file(file_path) + + return self.export() + + def export(self) -> Dict: + """Export extraction results.""" + return { + 'repository': str(self.repo_path), + 'extraction_time': datetime.now().isoformat(), + 'functions': self.functions, + 'classes': self.classes, + 'imports': self.imports, + 'statistics': self.stats, + } + + +def main(): + """Command line interface.""" + import argparse + + parser = argparse.ArgumentParser( + description='Extract all functions and impl blocks from a Rust repository', + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=''' +Examples: + python function_extractor.py /path/to/repo + python function_extractor.py /path/to/repo --output functions.json + python function_extractor.py /path/to/repo --scan-file scan_results.json + ''' + ) + + parser.add_argument('repo_path', help='Path to the repository') + parser.add_argument('--output', '-o', help='Output file (default: stdout)') + parser.add_argument('--scan-file', help='Use file list from repository scanner output') + + args = parser.parse_args() + + try: + extractor = FunctionExtractor(args.repo_path) + + if args.scan_file: + with open(args.scan_file) as f: + scan_result = json.load(f) + result = extractor.extract_from_scan(scan_result) + else: + result = extractor.extract_all() + + output = json.dumps(result, indent=2) + + if args.output: + with open(args.output, 'w') as f: + f.write(output) + print(f"Extraction complete. 
Results written to: {args.output}", file=sys.stderr) + print(f"Total functions: {result['statistics']['total_functions']}", file=sys.stderr) + print(f" Standalone: {result['statistics']['standalone_functions']}", file=sys.stderr) + print(f" Methods: {result['statistics']['total_methods']}", file=sys.stderr) + print(f" Async: {result['statistics']['async_functions']}", file=sys.stderr) + print(f" Public: {result['statistics']['public_functions']}", file=sys.stderr) + print(f"Total impl blocks: {result['statistics']['total_classes']}", file=sys.stderr) + print(f"Files processed: {result['statistics']['files_processed']}", file=sys.stderr) + if result['statistics']['files_with_errors'] > 0: + print(f"Files with errors: {result['statistics']['files_with_errors']}", file=sys.stderr) + print(f"By type:", file=sys.stderr) + for unit_type, count in sorted(result['statistics']['by_type'].items()): + print(f" {unit_type}: {count}", file=sys.stderr) + else: + print(output) + + except Exception as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) + + +if __name__ == '__main__': + main() diff --git a/libs/openant-core/parsers/rust/repository_scanner.py b/libs/openant-core/parsers/rust/repository_scanner.py new file mode 100644 index 0000000..f81abc2 --- /dev/null +++ b/libs/openant-core/parsers/rust/repository_scanner.py @@ -0,0 +1,264 @@ +#!/usr/bin/env python3 +""" +Repository Scanner for Rust Codebases + +Enumerates ALL Rust source files in a repository for complete coverage. +This is Phase 1 of the Rust parser - file discovery. 
class RepositoryScanner:
    """
    Scan a repository for all Rust source files.

    This is Stage 1 of the Rust parser pipeline. It walks the directory tree,
    identifies Rust source files, and collects metadata about each file.

    Key features:
    - Excludes common non-source directories (target, .cargo, .git, etc.)
    - Optionally skips test files (tests/, *_test.rs, #[test] modules)
    - Collects file size statistics for monitoring

    Usage:
        scanner = RepositoryScanner('/path/to/repo')
        result = scanner.scan()
        # result['files'] contains list of {path, size} dicts

    Attributes:
        repo_path: Absolute path to the repository root
        exclude_patterns: Set of directory names to skip
        source_extensions: Set of file extensions to include (default: {'.rs'})
        skip_tests: Whether to exclude test files (default: False)
    """

    def __init__(self, repo_path: str, options: Optional[Dict] = None):
        self.repo_path = Path(repo_path).resolve()
        options = options or {}

        # Default exclude patterns for Rust projects
        self.exclude_patterns: Set[str] = set(options.get('exclude_patterns', [
            '.git',
            'target',        # Rust build output
            '.cargo',        # Cargo cache
            'vendor',        # Vendored dependencies
            'node_modules',  # If mixed with JS (e.g., wasm projects)
            '.cache',
            'doc',
            'docs',
            'examples',      # Usually not production code
            'benches',       # Benchmarks
        ]))

        # Source file extensions
        self.source_extensions: Set[str] = set(options.get('source_extensions', [
            '.rs',
        ]))

        # Skip test files by default (can be overridden)
        self.skip_tests = options.get('skip_tests', False)
        # Patterns that indicate test files/directories
        self.test_patterns = {
            'tests/',     # Standard test directory
            'test/',      # Alternative test directory
            '_test.rs',   # Test file suffix
            'test_',      # Test file prefix (less common in Rust)
            '/tests.rs',  # Module test file
        }

        # Statistics (single source of truth for the key set lives in
        # _fresh_stats so __init__ and scan() can't drift apart)
        self.stats = self._fresh_stats()

        # Results
        self.files: List[Dict] = []

    @staticmethod
    def _fresh_stats() -> Dict:
        """Return a zeroed statistics dict (shared by __init__ and scan)."""
        return {
            'total_files': 0,
            'total_size_bytes': 0,
            'directories_scanned': 0,
            'directories_excluded': 0,
            'test_files_skipped': 0,
        }

    def should_exclude_directory(self, dir_name: str) -> bool:
        """Check if a directory should be excluded."""
        # Exact match against configured patterns
        if dir_name in self.exclude_patterns:
            return True
        if dir_name.startswith('.'):
            # Exclude hidden directories
            return True
        return False

    def is_source_file(self, file_name: str) -> bool:
        """Check if a file is a Rust source file (by extension)."""
        ext = os.path.splitext(file_name)[1].lower()
        return ext in self.source_extensions

    def is_test_file(self, relative_path: str) -> bool:
        """Check if a file is a test file.

        Substring heuristic against test_patterns. Separators are normalized
        to '/' first so patterns like 'tests/' also match Windows paths
        (they are built with os.path.join and would otherwise use '\\').
        """
        path_lower = relative_path.lower().replace(os.sep, '/')
        for pattern in self.test_patterns:
            if pattern in path_lower:
                return True
        return False

    def scan_directory(self, dir_path: Path, relative_path: str = '') -> None:
        """Recursively scan a directory, accumulating files and statistics."""
        self.stats['directories_scanned'] += 1

        try:
            entries = list(dir_path.iterdir())
        except PermissionError:
            print(f"Warning: Cannot read directory {dir_path}: Permission denied", file=sys.stderr)
            return
        except Exception as e:
            print(f"Warning: Cannot read directory {dir_path}: {e}", file=sys.stderr)
            return

        # Sorted traversal keeps output deterministic across platforms.
        for entry in sorted(entries, key=lambda e: e.name):
            entry_relative = os.path.join(relative_path, entry.name) if relative_path else entry.name

            if entry.is_dir():
                if self.should_exclude_directory(entry.name):
                    self.stats['directories_excluded'] += 1
                    continue
                self.scan_directory(entry, entry_relative)

            elif entry.is_file():
                if not self.is_source_file(entry.name):
                    continue

                # Skip test files if configured
                if self.skip_tests and self.is_test_file(entry_relative):
                    self.stats['test_files_skipped'] += 1
                    continue

                try:
                    file_size = entry.stat().st_size
                except Exception:
                    # Unreadable metadata should not abort the scan
                    file_size = 0

                self.files.append({
                    'path': entry_relative,
                    'size': file_size,
                })

                self.stats['total_files'] += 1
                self.stats['total_size_bytes'] += file_size

    def scan(self) -> Dict:
        """Execute the repository scan and return results.

        Raises:
            FileNotFoundError: repo_path does not exist.
            NotADirectoryError: repo_path is not a directory.
        """
        if not self.repo_path.exists():
            raise FileNotFoundError(f"Repository path does not exist: {self.repo_path}")

        if not self.repo_path.is_dir():
            raise NotADirectoryError(f"Repository path is not a directory: {self.repo_path}")

        # Reset state so scan() is repeatable on the same instance
        self.files = []
        self.stats = self._fresh_stats()

        # Run scan
        self.scan_directory(self.repo_path)

        # Sort files by path for consistent output
        self.files.sort(key=lambda f: f['path'])

        return {
            'repository': str(self.repo_path),
            'scan_time': datetime.now().isoformat(),
            'files': self.files,
            'statistics': self.stats,
        }
help='Comma-separated additional exclude patterns') + parser.add_argument('--skip-tests', action='store_true', help='Skip test files') + + args = parser.parse_args() + + # Build options + options = {} + if args.exclude: + additional_excludes = [p.strip() for p in args.exclude.split(',')] + default_excludes = [ + '.git', 'target', '.cargo', 'vendor', 'node_modules', + '.cache', 'doc', 'docs', 'examples', 'benches', + ] + options['exclude_patterns'] = default_excludes + additional_excludes + + options['skip_tests'] = args.skip_tests + + try: + scanner = RepositoryScanner(args.repo_path, options) + result = scanner.scan() + + output = json.dumps(result, indent=2) + + if args.output: + with open(args.output, 'w') as f: + f.write(output) + print(f"Scan complete. Results written to: {args.output}", file=sys.stderr) + print(f"Total files found: {result['statistics']['total_files']}", file=sys.stderr) + print(f"Total size: {result['statistics']['total_size_bytes']:,} bytes", file=sys.stderr) + print(f"Directories scanned: {result['statistics']['directories_scanned']}", file=sys.stderr) + print(f"Directories excluded: {result['statistics']['directories_excluded']}", file=sys.stderr) + if args.skip_tests: + print(f"Test files skipped: {result['statistics']['test_files_skipped']}", file=sys.stderr) + else: + print(output) + + except Exception as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) + + +if __name__ == '__main__': + main() diff --git a/libs/openant-core/parsers/rust/test_pipeline.py b/libs/openant-core/parsers/rust/test_pipeline.py new file mode 100644 index 0000000..8eb5949 --- /dev/null +++ b/libs/openant-core/parsers/rust/test_pipeline.py @@ -0,0 +1,1034 @@ +#!/usr/bin/env python3 +""" +Rust Parser Pipeline + +Tests the Rust parser pipeline components: +1. RepositoryScanner - Enumerates .rs files +2. FunctionExtractor - Extracts functions via tree-sitter +3. CallGraphBuilder - Builds bidirectional call graphs +4. 
UnitGenerator - Creates OpenAnt dataset format +5. CodeQL (optional) - Static analysis pre-filter +6. ContextEnhancer (optional) - LLM enhancement using Claude Sonnet + +Usage: + python test_pipeline.py [--output ] [--llm] [--agentic] [--processing-level LEVEL] + +Processing Levels (cumulative filtering): + Level 1: all - Process all units (no filtering) + Level 2: reachable - Process only units reachable from entry points + Level 3: codeql - Process only reachable + CodeQL-flagged units + Level 4: exploitable - Process only reachable + CodeQL-flagged + exploitable units + +Example: + # Static analysis only + python test_pipeline.py /path/to/repo --output /tmp/output + + # With agentic LLM enhancement + python test_pipeline.py /path/to/repo --output /tmp/output --llm --agentic + + # CodeQL pre-filter + agentic classification + python test_pipeline.py /path/to/repo --output /tmp/output --llm --agentic --processing-level codeql + + # Maximum cost savings: only exploitable units + python test_pipeline.py /path/to/repo --output /tmp/output --llm --agentic --processing-level exploitable +""" + +import argparse +import json +import os +import subprocess +import sys +from datetime import datetime +from enum import Enum +from pathlib import Path +from typing import Set + +# Add parent directory to path for utilities import +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) +from utilities.context_enhancer import ContextEnhancer +from utilities.agentic_enhancer import EntryPointDetector, ReachabilityAnalyzer + +# Local imports +from repository_scanner import RepositoryScanner +from function_extractor import FunctionExtractor +from call_graph_builder import CallGraphBuilder +from unit_generator import UnitGenerator + + +class ProcessingLevel(Enum): + """ + Processing level determines which units are processed. + Levels are cumulative - each level includes filters from previous levels. 
+ """ + ALL = "all" + REACHABLE = "reachable" + CODEQL = "codeql" + EXPLOITABLE = "exploitable" + + +class RustPipelineTest: + def __init__( + self, + repo_path: str, + output_dir: str = None, + enable_llm: bool = False, + agentic: bool = False, + processing_level: ProcessingLevel = ProcessingLevel.ALL, + skip_tests: bool = False, + depth: int = 3, + name: str = None + ): + self.repo_path = os.path.abspath(repo_path) + self.output_dir = output_dir or os.path.join(os.path.dirname(__file__), 'test_output') + self.parser_dir = os.path.dirname(os.path.abspath(__file__)) + self.enable_llm = enable_llm + self.agentic = agentic + self.processing_level = processing_level + self.skip_tests = skip_tests + self.depth = depth + self.dataset_name = name + + # Pipeline artifacts + self.scan_results_file = None + self.analyzer_output_file = None + self.dataset_file = None + + # Reachability data + self.entry_points: Set[str] = set() + self.reachable_units: Set[str] = set() + + # CodeQL data + self.codeql_flagged_units: Set[str] = set() + self.codeql_findings: list = [] + + # Results + self.results = { + 'repository': self.repo_path, + 'test_time': datetime.now().isoformat(), + 'processing_level': processing_level.value, + 'stages': {} + } + + def setup(self): + """Create output directory.""" + os.makedirs(self.output_dir, exist_ok=True) + print(f"Output directory: {self.output_dir}") + print() + return True + + def run_parser_pipeline(self) -> bool: + """Run the full Rust parser pipeline (scan, extract, call graph, generate).""" + self.dataset_file = os.path.join(self.output_dir, 'dataset.json') + self.analyzer_output_file = os.path.join(self.output_dir, 'analyzer_output.json') + + print("=" * 60) + print("STAGE: rust_parser_pipeline") + print("=" * 60) + print() + + start_time = datetime.now() + + try: + # Stage 1: Scan + print(" [1/4] Scanning repository for Rust files...") + scanner_options = {'skip_tests': self.skip_tests} + scanner = RepositoryScanner(self.repo_path, 
scanner_options) + scan_result = scanner.scan() + file_count = scan_result['statistics']['total_files'] + print(f" Found {file_count} files ({scan_result['statistics']['total_size_bytes']:,} bytes)") + + # Save scan results + self.scan_results_file = os.path.join(self.output_dir, 'scan_results.json') + with open(self.scan_results_file, 'w') as f: + json.dump(scan_result, f, indent=2) + + # Stage 2: Extract functions + print(" [2/4] Extracting functions via tree-sitter...") + extractor = FunctionExtractor(self.repo_path) + extract_result = extractor.extract_from_scan(scan_result) + func_count = extract_result['statistics']['total_functions'] + print(f" Extracted {func_count} functions from {extract_result['statistics']['files_processed']} files") + if extract_result['statistics']['files_with_errors'] > 0: + print(f" ({extract_result['statistics']['files_with_errors']} files with errors)") + + # Print type breakdown + by_type = extract_result['statistics'].get('by_type', {}) + if by_type: + print(f" Types: {', '.join(f'{t}={c}' for t, c in sorted(by_type.items()))}") + + # Stage 3: Build call graph + print(" [3/4] Building call graph...") + builder = CallGraphBuilder(extract_result, {'max_depth': self.depth}) + builder.build_call_graph() + graph_result = builder.export() + graph_stats = graph_result['statistics'] + print(f" {graph_stats['total_edges']} edges, avg out-degree: {graph_stats['avg_out_degree']}") + print(f" {graph_stats['isolated_functions']} isolated functions") + + # Save call graph for reachability filter + call_graph_file = os.path.join(self.output_dir, 'call_graph.json') + with open(call_graph_file, 'w') as f: + json.dump(graph_result, f, indent=2) + + # Stage 4: Generate units + print(" [4/4] Generating dataset units...") + gen_options = {'max_depth': self.depth} + if self.dataset_name: + gen_options['dataset_name'] = self.dataset_name + generator = UnitGenerator(graph_result, gen_options) + dataset = generator.generate_units() + unit_count = 
dataset['statistics']['total_units'] + print(f" Generated {unit_count} units") + print(f" Enhanced: {dataset['statistics']['units_enhanced']}") + print(f" Avg upstream deps: {dataset['statistics']['avg_upstream']}") + + # Write dataset + with open(self.dataset_file, 'w') as f: + json.dump(dataset, f, indent=2) + + # Write analyzer output + analyzer_output = generator.generate_analyzer_output() + with open(self.analyzer_output_file, 'w') as f: + json.dump(analyzer_output, f, indent=2) + + elapsed = (datetime.now() - start_time).total_seconds() + + summary = { + 'total_files': file_count, + 'total_functions': func_count, + 'total_units': unit_count, + 'by_type': by_type, + 'call_graph_edges': graph_stats['total_edges'], + 'avg_out_degree': graph_stats['avg_out_degree'], + } + + result = { + 'success': True, + 'elapsed_seconds': elapsed, + 'output_file': self.dataset_file, + 'summary': summary + } + + print() + print(f" Success ({elapsed:.2f}s)") + print() + + self.results['stages']['rust_parser'] = result + return True + + except Exception as e: + elapsed = (datetime.now() - start_time).total_seconds() + print(f" Error: {e}") + import traceback + traceback.print_exc() + result = { + 'success': False, + 'elapsed_seconds': elapsed, + 'error': str(e) + } + self.results['stages']['rust_parser'] = result + return False + + def apply_reachability_filter(self) -> bool: + """Filter dataset to only include units reachable from entry points.""" + if not self.analyzer_output_file or not os.path.exists(self.analyzer_output_file): + print("No analyzer output for reachability filtering") + return False + + if not self.dataset_file or not os.path.exists(self.dataset_file): + print("No dataset to filter") + return False + + print("=" * 60) + print("STAGE: reachability_filter (static analysis)") + print("=" * 60) + print() + + start_time = datetime.now() + + try: + with open(self.analyzer_output_file, 'r') as f: + analyzer = json.load(f) + + functions = analyzer.get("functions", {}) 
+ + # Normalize for EntryPointDetector (expects camelCase) + normalized_functions = {} + for func_id, func_data in functions.items(): + normalized_functions[func_id] = { + 'name': func_data.get('name', ''), + 'unitType': func_data.get('unitType', func_data.get('unit_type', 'function')), + 'code': func_data.get('code', ''), + 'filePath': func_data.get('filePath', func_data.get('file_path', '')), + 'startLine': func_data.get('startLine', func_data.get('start_line', 0)), + 'endLine': func_data.get('endLine', func_data.get('end_line', 0)), + 'isExported': func_data.get('isExported', func_data.get('is_public', True)), + 'isPublic': func_data.get('isPublic', func_data.get('is_public', False)), + 'isAsync': func_data.get('isAsync', func_data.get('is_async', False)), + } + + # Build call graph from dataset unit metadata + with open(self.dataset_file, 'r') as f: + dataset = json.load(f) + + call_graph = {} + reverse_call_graph = {} + for unit in dataset.get('units', []): + unit_id = unit.get('id') + metadata = unit.get('metadata', {}) + direct_calls = metadata.get('direct_calls', metadata.get('directCalls', [])) + direct_callers = metadata.get('direct_callers', metadata.get('directCallers', [])) + + if direct_calls: + call_graph[unit_id] = direct_calls + if direct_callers: + reverse_call_graph[unit_id] = direct_callers + + # Detect entry points + detector = EntryPointDetector(normalized_functions, call_graph) + self.entry_points = detector.detect_entry_points() + + # Build reachability + reachability = ReachabilityAnalyzer( + functions=normalized_functions, + reverse_call_graph=reverse_call_graph, + entry_points=self.entry_points + ) + self.reachable_units = reachability.get_all_reachable() + + units = dataset.get("units", []) + original_count = len(units) + + filtered_units = [] + for u in units: + unit_id = u.get("id", "") + if unit_id in self.reachable_units: + u["reachable"] = True + u["is_entry_point"] = unit_id in self.entry_points + if unit_id in self.entry_points: + 
u["entry_point_reason"] = detector.get_entry_point_reason(unit_id) + filtered_units.append(u) + + dataset["units"] = filtered_units + dataset["metadata"] = dataset.get("metadata", {}) + dataset["metadata"]["reachability_filter"] = { + "original_units": original_count, + "entry_points": len(self.entry_points), + "reachable_units": len(filtered_units), + "filtered_out": original_count - len(filtered_units), + "reduction_percentage": round((1 - len(filtered_units) / original_count) * 100, 1) if original_count > 0 else 0 + } + + with open(self.dataset_file, 'w') as f: + json.dump(dataset, f, indent=2) + + elapsed = (datetime.now() - start_time).total_seconds() + + summary = { + 'original_units': original_count, + 'entry_points': len(self.entry_points), + 'reachable_units': len(filtered_units), + 'reduction_percentage': dataset["metadata"]["reachability_filter"]["reduction_percentage"] + } + + result = { + 'success': True, + 'elapsed_seconds': elapsed, + 'output_file': self.dataset_file, + 'summary': summary + } + + print(f" Success ({elapsed:.2f}s)") + print(f" Entry points detected: {len(self.entry_points)}") + print(f" Units: {original_count} -> {len(filtered_units)} ({summary['reduction_percentage']}% reduction)") + print() + + self.results['stages']['reachability_filter'] = result + return True + + except Exception as e: + elapsed = (datetime.now() - start_time).total_seconds() + print(f" Error: {e}") + import traceback + traceback.print_exc() + result = { + 'success': False, + 'elapsed_seconds': elapsed, + 'error': str(e) + } + self.results['stages']['reachability_filter'] = result + return False + + def run_codeql_analysis(self) -> bool: + """Run CodeQL analysis on the repository.""" + print("=" * 60) + print("STAGE: codeql_analysis") + print("=" * 60) + print() + + start_time = datetime.now() + + language = "rust" + print(f"Language: {language}") + + codeql_db_path = os.path.join(self.output_dir, 'codeql-db') + sarif_output = os.path.join(self.output_dir, 
'codeql-results.sarif') + + try: + # Step 1: Create CodeQL database + print("Creating CodeQL database...") + create_db_cmd = [ + 'codeql', 'database', 'create', + codeql_db_path, + f'--language={language}', + f'--source-root={self.repo_path}', + '--overwrite' + ] + + result = subprocess.run( + create_db_cmd, + capture_output=True, + text=True, + timeout=600 + ) + + if result.returncode != 0: + print(f" CodeQL database creation failed") + print(f" stderr: {result.stderr[:500] if result.stderr else 'none'}") + elapsed = (datetime.now() - start_time).total_seconds() + self.results['stages']['codeql_analysis'] = { + 'success': False, + 'elapsed_seconds': elapsed, + 'error': 'Database creation failed', + 'stderr': result.stderr + } + return False + + print(" Database created successfully") + + # Step 2: Run security queries + print("Running security queries...") + analyze_cmd = [ + 'codeql', 'database', 'analyze', + codeql_db_path, + '--format=sarif-latest', + f'--output={sarif_output}', + f'codeql/{language}-queries:codeql-suites/{language}-security-extended.qls' + ] + + result = subprocess.run( + analyze_cmd, + capture_output=True, + text=True, + timeout=1800 + ) + + if result.returncode != 0: + print(f" CodeQL analysis failed") + print(f" stderr: {result.stderr[:500] if result.stderr else 'none'}") + elapsed = (datetime.now() - start_time).total_seconds() + self.results['stages']['codeql_analysis'] = { + 'success': False, + 'elapsed_seconds': elapsed, + 'error': 'Analysis failed', + 'stderr': result.stderr + } + return False + + print(" Analysis completed") + + # Step 3: Parse SARIF output + print("Parsing results...") + if not os.path.exists(sarif_output): + print(" SARIF output not found") + elapsed = (datetime.now() - start_time).total_seconds() + self.results['stages']['codeql_analysis'] = { + 'success': False, + 'elapsed_seconds': elapsed, + 'error': 'SARIF output not found' + } + return False + + with open(sarif_output, 'r') as f: + sarif_data = json.load(f) + 
+ self.codeql_findings = [] + + for run in sarif_data.get('runs', []): + for result_item in run.get('results', []): + rule_id = result_item.get('ruleId', 'unknown') + message = result_item.get('message', {}).get('text', '') + level = result_item.get('level', 'warning') + + for location in result_item.get('locations', []): + physical = location.get('physicalLocation', {}) + artifact = physical.get('artifactLocation', {}) + uri = artifact.get('uri', '') + region = physical.get('region', {}) + finding_start = region.get('startLine', 0) + finding_end = region.get('endLine', finding_start) + + finding = { + 'rule_id': rule_id, + 'message': message, + 'level': level, + 'file': uri, + 'start_line': finding_start, + 'end_line': finding_end + } + self.codeql_findings.append(finding) + + elapsed = (datetime.now() - start_time).total_seconds() + + summary = { + 'total_findings': len(self.codeql_findings), + 'unique_files': len(set(f['file'] for f in self.codeql_findings)), + 'by_level': {}, + 'by_rule': {} + } + + for finding in self.codeql_findings: + level = finding['level'] + rule = finding['rule_id'] + summary['by_level'][level] = summary['by_level'].get(level, 0) + 1 + summary['by_rule'][rule] = summary['by_rule'].get(rule, 0) + 1 + + result_data = { + 'success': True, + 'elapsed_seconds': elapsed, + 'output_file': sarif_output, + 'summary': summary + } + + print(f" Success ({elapsed:.2f}s)") + print(f" Total findings: {len(self.codeql_findings)}") + print(f" Unique files: {summary['unique_files']}") + if summary['by_level']: + print(f" By level: {summary['by_level']}") + print() + + self.results['stages']['codeql_analysis'] = result_data + return True + + except FileNotFoundError: + elapsed = (datetime.now() - start_time).total_seconds() + print(" CodeQL not found. 
Please install CodeQL CLI.") + print(" See: https://docs.github.com/en/code-security/codeql-cli") + self.results['stages']['codeql_analysis'] = { + 'success': False, + 'elapsed_seconds': elapsed, + 'error': 'CodeQL CLI not installed' + } + return False + + except subprocess.TimeoutExpired: + elapsed = (datetime.now() - start_time).total_seconds() + print(" CodeQL analysis timed out") + self.results['stages']['codeql_analysis'] = { + 'success': False, + 'elapsed_seconds': elapsed, + 'error': 'Timeout' + } + return False + + except Exception as e: + elapsed = (datetime.now() - start_time).total_seconds() + print(f" Error: {e}") + import traceback + traceback.print_exc() + self.results['stages']['codeql_analysis'] = { + 'success': False, + 'elapsed_seconds': elapsed, + 'error': str(e) + } + return False + + def apply_codeql_filter(self) -> bool: + """Filter dataset to only include units flagged by CodeQL.""" + if not self.dataset_file or not os.path.exists(self.dataset_file): + print("No dataset to filter") + return False + + if not self.codeql_findings: + print("No CodeQL findings to filter by") + return False + + print("=" * 60) + print("STAGE: codeql_filter") + print("=" * 60) + print() + + start_time = datetime.now() + + try: + with open(self.dataset_file, 'r') as f: + dataset = json.load(f) + + # Build mapping of file -> [(start_line, end_line, func_id)] + file_functions = {} + for unit in dataset.get('units', []): + unit_id = unit.get('id', '') + origin = unit.get('code', {}).get('primary_origin', {}) + file_path = origin.get('file_path', '') + unit_start = origin.get('start_line', 0) + unit_end = origin.get('end_line', unit_start) + + if file_path: + if file_path not in file_functions: + file_functions[file_path] = [] + file_functions[file_path].append((unit_start, unit_end, unit_id)) + + # Map CodeQL findings to function units + for finding in self.codeql_findings: + file_uri = finding['file'] + finding_start = finding['start_line'] + finding_end = 
finding['end_line'] + + matched_file = None + for file_path in file_functions.keys(): + if file_path.endswith(file_uri) or file_uri.endswith(file_path) or file_path == file_uri: + matched_file = file_path + break + + if matched_file: + for start, end, func_id in file_functions[matched_file]: + if start <= finding_start <= end or start <= finding_end <= end: + self.codeql_flagged_units.add(func_id) + + units = dataset.get("units", []) + original_count = len(units) + + filtered_units = [u for u in units if u.get("id") in self.codeql_flagged_units] + + dataset["units"] = filtered_units + dataset["metadata"] = dataset.get("metadata", {}) + dataset["metadata"]["codeql_filter"] = { + "original_units": original_count, + "codeql_findings": len(self.codeql_findings), + "flagged_units": len(self.codeql_flagged_units), + "filtered_units": len(filtered_units), + "filtered_out": original_count - len(filtered_units), + "reduction_percentage": round((1 - len(filtered_units) / original_count) * 100, 1) if original_count > 0 else 0 + } + + with open(self.dataset_file, 'w') as f: + json.dump(dataset, f, indent=2) + + elapsed = (datetime.now() - start_time).total_seconds() + + summary = { + 'original_units': original_count, + 'codeql_findings': len(self.codeql_findings), + 'flagged_units': len(self.codeql_flagged_units), + 'filtered_units': len(filtered_units), + 'reduction_percentage': dataset["metadata"]["codeql_filter"]["reduction_percentage"] + } + + result = { + 'success': True, + 'elapsed_seconds': elapsed, + 'output_file': self.dataset_file, + 'summary': summary + } + + print(f" Success ({elapsed:.2f}s)") + print(f" CodeQL findings: {len(self.codeql_findings)}") + print(f" Flagged function units: {len(self.codeql_flagged_units)}") + print(f" Units: {original_count} -> {len(filtered_units)} ({summary['reduction_percentage']}% reduction)") + print() + + self.results['stages']['codeql_filter'] = result + return True + + except Exception as e: + elapsed = (datetime.now() - 
start_time).total_seconds() + print(f" Error: {e}") + import traceback + traceback.print_exc() + result = { + 'success': False, + 'elapsed_seconds': elapsed, + 'error': str(e) + } + self.results['stages']['codeql_filter'] = result + return False + + def run_context_enhancer(self) -> bool: + """Stage 4 (optional): Enhance dataset with LLM context.""" + if not self.dataset_file or not os.path.exists(self.dataset_file): + print("No dataset to enhance") + return False + + mode = "agentic" if self.agentic else "single-shot" + print("=" * 60) + print(f"STAGE: context_enhancer (Rust, {mode} mode)") + print("=" * 60) + print() + + start_time = datetime.now() + + try: + with open(self.dataset_file, 'r') as f: + dataset = json.load(f) + + enhancer = ContextEnhancer() + + if self.agentic: + enhanced = enhancer.enhance_dataset_agentic( + dataset, + analyzer_output_path=self.analyzer_output_file, + repo_path=self.repo_path, + batch_size=5, + verbose=False + ) + agentic_stats = enhanced.get('metadata', {}).get('agentic_stats', {}) + summary = { + 'mode': 'agentic', + 'units_processed': agentic_stats.get('units_processed', 0), + 'units_with_context': agentic_stats.get('units_with_context', 0), + 'functions_added': agentic_stats.get('functions_added', 0), + 'security_controls_found': agentic_stats.get('security_controls_found', 0), + 'vulnerable_found': agentic_stats.get('vulnerable_found', 0), + 'neutral_found': agentic_stats.get('neutral_found', 0) + } + else: + enhanced = enhancer.enhance_dataset(dataset) + summary = { + 'mode': 'single-shot', + 'units_enhanced': enhancer.stats['units_enhanced'], + 'dependencies_added': enhancer.stats['dependencies_added'], + 'callers_added': enhancer.stats['callers_added'], + 'data_flows_extracted': enhancer.stats['data_flows_extracted'] + } + + with open(self.dataset_file, 'w') as f: + json.dump(enhanced, f, indent=2) + + elapsed = (datetime.now() - start_time).total_seconds() + + result = { + 'success': True, + 'elapsed_seconds': elapsed, + 
'output_file': self.dataset_file, + 'summary': summary + } + + print() + print(f" Success ({elapsed:.2f}s)") + + self.results['stages']['context_enhancer'] = result + return True + + except Exception as e: + elapsed = (datetime.now() - start_time).total_seconds() + print(f" Error: {e}") + import traceback + traceback.print_exc() + result = { + 'success': False, + 'elapsed_seconds': elapsed, + 'error': str(e) + } + self.results['stages']['context_enhancer'] = result + return False + + def apply_exploitable_filter(self) -> bool: + """Filter dataset to only include units classified as 'exploitable'.""" + if not self.dataset_file or not os.path.exists(self.dataset_file): + print("No dataset to filter") + return False + + print("=" * 60) + print("STAGE: exploitable_filter") + print("=" * 60) + print() + + start_time = datetime.now() + + try: + with open(self.dataset_file, 'r') as f: + dataset = json.load(f) + + units = dataset.get("units", []) + original_count = len(units) + + filtered_units = [] + classification_counts = {} + + for unit in units: + agent_context = unit.get("agent_context", {}) + classification = agent_context.get("security_classification", "unknown") + classification_counts[classification] = classification_counts.get(classification, 0) + 1 + + if classification == "exploitable": + filtered_units.append(unit) + + dataset["units"] = filtered_units + dataset["metadata"] = dataset.get("metadata", {}) + dataset["metadata"]["exploitable_filter"] = { + "original_units": original_count, + "exploitable_units": len(filtered_units), + "filtered_out": original_count - len(filtered_units), + "classification_counts": classification_counts, + "reduction_percentage": round((1 - len(filtered_units) / original_count) * 100, 1) if original_count > 0 else 0 + } + + with open(self.dataset_file, 'w') as f: + json.dump(dataset, f, indent=2) + + elapsed = (datetime.now() - start_time).total_seconds() + + summary = { + 'original_units': original_count, + 'exploitable_units': 
len(filtered_units), + 'classification_counts': classification_counts, + 'reduction_percentage': dataset["metadata"]["exploitable_filter"]["reduction_percentage"] + } + + result = { + 'success': True, + 'elapsed_seconds': elapsed, + 'output_file': self.dataset_file, + 'summary': summary + } + + print(f" Success ({elapsed:.2f}s)") + print(f" Classification breakdown:") + for cls, count in sorted(classification_counts.items()): + marker = "->" if cls == "exploitable" else " " + print(f" {marker} {cls}: {count}") + print(f" Units: {original_count} -> {len(filtered_units)} ({summary['reduction_percentage']}% reduction)") + print() + + self.results['stages']['exploitable_filter'] = result + return True + + except Exception as e: + elapsed = (datetime.now() - start_time).total_seconds() + print(f" Error: {e}") + import traceback + traceback.print_exc() + result = { + 'success': False, + 'elapsed_seconds': elapsed, + 'error': str(e) + } + self.results['stages']['exploitable_filter'] = result + return False + + def run_full_pipeline(self): + """Run the complete pipeline.""" + print("=" * 60) + print("RUST PARSER PIPELINE") + print("=" * 60) + print(f"Repository: {self.repo_path}") + print(f"Processing Level: {self.processing_level.value}") + print(f"Started: {self.results['test_time']}") + print() + + if not self.setup(): + print("Pipeline stopped: Setup failed") + return self.results + + # Stage 1-4: Run parser pipeline + if not self.run_parser_pipeline(): + print("Pipeline stopped: Parser pipeline failed") + return self.results + + # Stage 3.5 (optional): Reachability Filter + if self.processing_level in (ProcessingLevel.REACHABLE, ProcessingLevel.CODEQL, ProcessingLevel.EXPLOITABLE): + if not self.apply_reachability_filter(): + print("Warning: Reachability filter failed, continuing with all units") + + # Stage 3.6-3.7 (optional): CodeQL Analysis and Filter + if self.processing_level in (ProcessingLevel.CODEQL, ProcessingLevel.EXPLOITABLE): + codeql_success = 
self.run_codeql_analysis() + if codeql_success: + if not self.apply_codeql_filter(): + print("Warning: CodeQL filter failed, continuing with reachable units") + else: + print("Warning: CodeQL analysis failed, continuing with reachable units only") + + # Stage 4 (optional): Context Enhancer + if self.enable_llm: + if not self.run_context_enhancer(): + print("Warning: Context enhancer failed, continuing with static analysis only") + + # Stage 4.5 (optional): Exploitable Filter + if self.processing_level == ProcessingLevel.EXPLOITABLE: + if self.agentic: + if not self.apply_exploitable_filter(): + print("Warning: Exploitable filter failed") + else: + print() + print("Warning: Exploitable filter requires --agentic mode for classification") + print("Skipping exploitable filter") + else: + print() + print("Skipping LLM enhancement (use --llm to enable)") + if self.processing_level == ProcessingLevel.EXPLOITABLE: + print("Warning: Exploitable level requires --llm --agentic for classification") + + # Summary + print("=" * 60) + print("PIPELINE SUMMARY") + print("=" * 60) + + all_success = all( + stage.get('success', False) + for stage in self.results['stages'].values() + ) + + self.results['success'] = all_success + + if all_success: + print(" All stages completed successfully") + else: + print(" Some stages failed") + + print() + for stage_name, stage_result in self.results['stages'].items(): + status = "OK" if stage_result.get('success') else "FAIL" + elapsed = stage_result.get('elapsed_seconds', 0) + print(f" [{status}] {stage_name}: {elapsed:.2f}s") + + if 'summary' in stage_result: + summary = stage_result['summary'] + if 'total_files' in summary: + print(f" Files: {summary['total_files']}") + if 'total_functions' in summary: + print(f" Functions: {summary['total_functions']}") + if 'total_units' in summary: + print(f" Units: {summary['total_units']}") + edges = summary.get('call_graph_edges', 0) + avg_deg = summary.get('avg_out_degree', 0) + if edges: + print(f" Call 
graph: {edges} edges, avg degree: {avg_deg:.2f}") + if 'entry_points' in summary: + print(f" Entry points: {summary['entry_points']}") + print(f" Reachable: {summary.get('reachable_units', 0)}") + print(f" Reduction: {summary.get('reduction_percentage', 0)}%") + + print() + print(f"Output files in: {self.output_dir}") + + # Save results summary + results_file = os.path.join(self.output_dir, 'pipeline_results.json') + with open(results_file, 'w') as f: + clean_results = { + 'repository': self.results['repository'], + 'test_time': self.results['test_time'], + 'processing_level': self.results.get('processing_level', 'all'), + 'success': self.results.get('success', False), + 'stages': {} + } + for stage_name, stage_result in self.results['stages'].items(): + clean_results['stages'][stage_name] = { + 'success': stage_result.get('success', False), + 'elapsed_seconds': stage_result.get('elapsed_seconds', 0), + 'output_file': stage_result.get('output_file'), + 'summary': stage_result.get('summary', {}) + } + json.dump(clean_results, f, indent=2) + + print(f"Results summary: {results_file}") + + return self.results + + +def main(): + parser = argparse.ArgumentParser( + description='Run the Rust parser pipeline on a repository', + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Processing Levels (cumulative filtering): + all Level 1: Process all units (no filtering, highest cost) + reachable Level 2: Filter to units reachable from entry points + codeql Level 3: Filter to reachable + CodeQL-flagged units (requires CodeQL CLI) + exploitable Level 4: Filter to reachable + CodeQL-flagged + exploitable (requires --llm --agentic) + +Examples: + # Static analysis only (all units) + python test_pipeline.py /path/to/repo + + # With reachability filtering only + python test_pipeline.py /path/to/repo --processing-level reachable + + # With CodeQL pre-filter + agentic classification + python test_pipeline.py /path/to/repo --llm --agentic --processing-level codeql + + 
def main():
    """CLI entry point: parse arguments, validate the repo path, and run
    the pipeline. Exits 0 when all executed stages succeeded, 1 otherwise.
    """
    cli = argparse.ArgumentParser(
        description='Run the Rust parser pipeline on a repository',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Processing Levels (cumulative filtering):
  all          Level 1: Process all units (no filtering, highest cost)
  reachable    Level 2: Filter to units reachable from entry points
  codeql       Level 3: Filter to reachable + CodeQL-flagged units (requires CodeQL CLI)
  exploitable  Level 4: Filter to reachable + CodeQL-flagged + exploitable (requires --llm --agentic)

Examples:
  # Static analysis only (all units)
  python test_pipeline.py /path/to/repo

  # With reachability filtering only
  python test_pipeline.py /path/to/repo --processing-level reachable

  # With CodeQL pre-filter + agentic classification
  python test_pipeline.py /path/to/repo --llm --agentic --processing-level codeql

  # Maximum cost savings: only exploitable units
  python test_pipeline.py /path/to/repo --llm --agentic --processing-level exploitable
""",
    )
    cli.add_argument('repo_path', help='Path to the Rust repository to analyze')
    cli.add_argument('--output', '-o', default=None,
                     help='Output directory for pipeline artifacts')
    cli.add_argument('--llm', action='store_true',
                     help='Enable LLM context enhancement (uses Claude Sonnet)')
    cli.add_argument('--agentic', action='store_true',
                     help='Use agentic mode with iterative tool use (more accurate, more expensive)')
    cli.add_argument('--processing-level',
                     choices=['all', 'reachable', 'codeql', 'exploitable'],
                     default='all',
                     help='Processing level: all (L1), reachable (L2), codeql (L3), exploitable (L4)')
    cli.add_argument('--skip-tests', action='store_true', help='Skip test files')
    cli.add_argument('--depth', '-d', type=int, default=3,
                     help='Max dependency resolution depth (default: 3)')
    cli.add_argument('--name', '-n', default=None,
                     help='Dataset name (default: derived from repo path)')
    opts = cli.parse_args()

    if not os.path.exists(opts.repo_path):
        print(f"Error: Repository not found: {opts.repo_path}")
        sys.exit(1)

    level = ProcessingLevel(opts.processing_level)
    if level == ProcessingLevel.EXPLOITABLE and not (opts.llm and opts.agentic):
        # Without agentic LLM classification the exploitable filter is a no-op.
        print("Warning: --processing-level exploitable requires --llm --agentic for classification")
        print("Units will be filtered by reachability only, not by exploitability")

    pipeline = RustPipelineTest(
        opts.repo_path,
        opts.output,
        enable_llm=opts.llm,
        agentic=opts.agentic,
        processing_level=level,
        skip_tests=opts.skip_tests,
        depth=opts.depth,
        name=opts.name,
    )
    outcome = pipeline.run_full_pipeline()

    sys.exit(0 if outcome.get('success', False) else 1)


if __name__ == '__main__':
    main()
#!/usr/bin/env python3
"""
Unit Generator for Rust Codebases

Creates self-contained analysis units for ALL functions extracted from a repository.
Each unit includes:
- Primary code (the function itself)
- Upstream dependencies (functions this calls)
- Downstream callers (functions that call this)
- Assembled enhanced code with file boundaries

This is Phase 4 of the Rust parser - dataset generation.

Usage:
    python unit_generator.py <call_graph.json> [--output <file>] [--depth <n>]

Output (JSON):
    {
      "name": "dataset_name",
      "repository": "/path/to/repo",
      "units": [ ... ],
      "statistics": { ... }
    }
"""

import json
import sys
from collections import deque
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Set


# File boundary marker for enhanced code (Rust uses // comments)
FILE_BOUNDARY = '\n\n// ========== File Boundary ==========\n\n'


class UnitGenerator:
    """
    Generate self-contained analysis units from call graph data.

    This is Stage 4 (final stage) of the Rust parser pipeline.
    """

    def __init__(self, call_graph_data: Dict, options: Optional[Dict] = None):
        """
        Args:
            call_graph_data: Output of the call-graph builder stage. Expected
                keys: 'functions', 'classes', 'call_graph',
                'reverse_call_graph', 'repository'. Missing keys default to
                empty containers so partial inputs do not crash.
            options: Optional settings: 'max_depth' (int, default 3) and
                'dataset_name' (str, default: repo directory name).
        """
        options = options or {}

        self.functions = call_graph_data.get('functions', {})
        self.classes = call_graph_data.get('classes', {})  # impl blocks
        self.call_graph = call_graph_data.get('call_graph', {})
        self.reverse_call_graph = call_graph_data.get('reverse_call_graph', {})
        self.repo_path = call_graph_data.get('repository', '')

        self.max_depth = options.get('max_depth', 3)
        self.dataset_name = options.get(
            'dataset_name',
            Path(self.repo_path).name if self.repo_path else 'dataset')

        self.units: List[Dict] = []
        self.statistics = {
            'total_units': 0,
            'by_type': {},
            'units_with_upstream': 0,
            'units_with_downstream': 0,
            'units_enhanced': 0,
            'avg_upstream': 0,
            'avg_downstream': 0,
        }

    def _bfs_neighbors(self, adjacency: Dict[str, List[str]], func_id: str,
                       depth: Optional[int]) -> List[str]:
        """
        Breadth-first traversal of ``adjacency`` starting at ``func_id``.

        Returns every reachable id (excluding ``func_id`` itself) within
        ``depth`` hops, in discovery order. Shared by get_dependencies /
        get_callers, which differ only in the direction of the graph.
        Uses a deque so each dequeue is O(1) instead of list.pop(0)'s O(n).
        """
        max_d = depth if depth is not None else self.max_depth
        found: List[str] = []
        visited: Set[str] = {func_id}
        queue = deque([(func_id, 0)])

        while queue:
            current_id, current_depth = queue.popleft()

            # Nodes at max depth are recorded but not expanded further.
            if current_depth >= max_d:
                continue

            for next_id in adjacency.get(current_id, []):
                if next_id not in visited:
                    visited.add(next_id)
                    found.append(next_id)
                    queue.append((next_id, current_depth + 1))

        return found

    def get_dependencies(self, func_id: str, depth: Optional[int] = None) -> List[str]:
        """Get all dependencies (callees) for a function up to max depth."""
        return self._bfs_neighbors(self.call_graph, func_id, depth)

    def get_callers(self, func_id: str, depth: Optional[int] = None) -> List[str]:
        """Get all callers for a function up to max depth."""
        return self._bfs_neighbors(self.reverse_call_graph, func_id, depth)

    def assemble_enhanced_code(self, func_data: Dict,
                               upstream_deps: List[Dict],
                               downstream_callers: List[Dict]) -> str:
        """
        Assemble enhanced code with all dependencies using file boundary markers.

        The primary function's code comes first, followed by upstream then
        downstream snippets. Identical code snippets are included only once.
        """
        parts = []
        included_code: Set[str] = set()

        # Add primary code first
        primary_code = func_data.get('code', '')
        parts.append(primary_code)
        included_code.add(primary_code)

        # Add upstream dependencies (functions this calls), then downstream
        # callers (functions that call this), skipping duplicates.
        for entry in list(upstream_deps) + list(downstream_callers):
            snippet = entry.get('code', '')
            if snippet and snippet not in included_code:
                parts.append(snippet)
                included_code.add(snippet)

        return FILE_BOUNDARY.join(parts)

    def collect_files_included(self, primary_file: str,
                               upstream_deps: List[Dict],
                               downstream_callers: List[Dict]) -> List[str]:
        """Collect unique file paths from primary and all dependencies, sorted."""
        files: Set[str] = {primary_file}

        for entry in list(upstream_deps) + list(downstream_callers):
            file_path = entry.get('file_path', '')
            if file_path:
                files.add(file_path)

        return sorted(files)

    def _context_entry(self, func_id: str, func_data: Dict) -> Dict:
        """Build the summary record stored for one upstream/downstream function."""
        return {
            'id': func_id,
            'name': func_data.get('name'),
            'code': func_data.get('code', ''),
            'file_path': func_data.get('file_path', ''),
            'unit_type': func_data.get('unit_type', 'function'),
            'class_name': func_data.get('class_name'),
        }

    def create_unit(self, func_id: str, func_data: Dict) -> Dict:
        """
        Create a single analysis unit with full context.

        Resolves upstream dependencies and downstream callers (to max_depth),
        assembles the enhanced code, and packages everything with ground-truth
        placeholders and metadata.
        """
        file_path = func_data.get('file_path', '')
        func_name = func_data.get('name', '')
        class_name = func_data.get('class_name')  # impl block name in Rust
        module_name = func_data.get('module_name')
        unit_type = func_data.get('unit_type', 'function')

        # Upstream dependencies (functions this calls); unknown ids are skipped.
        upstream_deps = [
            self._context_entry(dep_id, self.functions[dep_id])
            for dep_id in self.get_dependencies(func_id)
            if self.functions.get(dep_id)
        ]

        # Downstream callers (functions that call this); unknown ids are skipped.
        downstream_callers = [
            self._context_entry(caller_id, self.functions[caller_id])
            for caller_id in self.get_callers(func_id)
            if self.functions.get(caller_id)
        ]

        # Assemble enhanced code
        enhanced_code = self.assemble_enhanced_code(func_data, upstream_deps, downstream_callers)
        files_included = self.collect_files_included(file_path, upstream_deps, downstream_callers)
        is_enhanced = len(upstream_deps) > 0 or len(downstream_callers) > 0

        # Get direct calls/callers (depth 1 only)
        direct_calls = self.call_graph.get(func_id, [])
        direct_callers = self.reverse_call_graph.get(func_id, [])

        # Build the unit
        unit = {
            'id': func_id,
            'unit_type': unit_type,
            'code': {
                'primary_code': enhanced_code,
                'primary_origin': {
                    'file_path': file_path,
                    'start_line': func_data.get('start_line'),
                    'end_line': func_data.get('end_line'),
                    'function_name': func_name,
                    'class_name': class_name,
                    'enhanced': is_enhanced,
                    'files_included': files_included,
                    'original_length': len(func_data.get('code', '')),
                    'enhanced_length': len(enhanced_code),
                },
                'dependencies': [],
                'dependency_metadata': {
                    'depth': self.max_depth,
                    'total_upstream': len(upstream_deps),
                    'total_downstream': len(downstream_callers),
                    'direct_calls': len(direct_calls),
                    'direct_callers': len(direct_callers),
                }
            },
            'ground_truth': {
                'status': 'UNKNOWN',
                'vulnerability_types': [],
                'issues': [],
                'annotation_source': None,
                'annotation_key': None,
                'notes': None,
            },
            'metadata': {
                'is_public': func_data.get('is_public', False),
                'is_async': func_data.get('is_async', False),
                'module_name': module_name,
                'parameters': func_data.get('parameters', []),
                'generator': 'rust_unit_generator.py',
                'direct_calls': direct_calls,
                'direct_callers': direct_callers,
            }
        }

        return unit

    def update_statistics(self, unit: Dict) -> None:
        """Update running statistics with one generated unit."""
        self.statistics['total_units'] += 1

        unit_type = unit.get('unit_type', 'function')
        self.statistics['by_type'][unit_type] = self.statistics['by_type'].get(unit_type, 0) + 1

        dep_meta = unit.get('code', {}).get('dependency_metadata', {})
        if dep_meta.get('total_upstream', 0) > 0:
            self.statistics['units_with_upstream'] += 1
        if dep_meta.get('total_downstream', 0) > 0:
            self.statistics['units_with_downstream'] += 1
        if unit.get('code', {}).get('primary_origin', {}).get('enhanced', False):
            self.statistics['units_enhanced'] += 1

    def generate_units(self) -> Dict:
        """
        Generate analysis units for all functions.

        Returns the full dataset dict: name, repository, units list,
        aggregate statistics, and generation metadata.
        """
        total_upstream = 0
        total_downstream = 0

        for func_id, func_data in self.functions.items():
            unit = self.create_unit(func_id, func_data)
            self.units.append(unit)
            self.update_statistics(unit)

            dep_meta = unit.get('code', {}).get('dependency_metadata', {})
            total_upstream += dep_meta.get('total_upstream', 0)
            total_downstream += dep_meta.get('total_downstream', 0)

        # Calculate averages (guard against an empty repository)
        if self.statistics['total_units'] > 0:
            self.statistics['avg_upstream'] = round(total_upstream / self.statistics['total_units'], 2)
            self.statistics['avg_downstream'] = round(total_downstream / self.statistics['total_units'], 2)

        return {
            'name': self.dataset_name,
            'repository': self.repo_path,
            'units': self.units,
            'statistics': self.statistics,
            'metadata': {
                'generator': 'rust_unit_generator.py',
                'generated_at': datetime.now().isoformat(),
                'dependency_depth': self.max_depth,
            }
        }

    def generate_analyzer_output(self) -> Dict:
        """Generate analyzer_output.json with camelCase fields for compatibility."""
        functions = {}
        for func_id, func_data in self.functions.items():
            functions[func_id] = {
                'name': func_data.get('name', ''),
                'unitType': func_data.get('unit_type', 'function'),
                'code': func_data.get('code', ''),
                'filePath': func_data.get('file_path', ''),
                'startLine': func_data.get('start_line', 0),
                'endLine': func_data.get('end_line', 0),
                'isPublic': func_data.get('is_public', False),
                'isAsync': func_data.get('is_async', False),
                'isExported': func_data.get('is_public', False),  # Rust: public = exported
                'moduleName': func_data.get('module_name'),
                'parameters': func_data.get('parameters', []),
                'className': func_data.get('class_name'),  # impl block name
            }

        return {
            'repository': self.repo_path,
            'functions': functions,
            'call_graph': self.call_graph,
            'reverse_call_graph': self.reverse_call_graph,
        }


def main():
    """Command line interface."""
    import argparse

    parser = argparse.ArgumentParser(
        description='Generate analysis units from Rust call graph data',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog='''
Examples:
  python unit_generator.py call_graph.json
  python unit_generator.py call_graph.json --output dataset.json
  python unit_generator.py call_graph.json --depth 2 --name my_dataset
        '''
    )

    parser.add_argument('input_file', help='Call graph JSON file')
    parser.add_argument('--output', '-o', help='Output file (default: stdout)')
    parser.add_argument('--analyzer-output', help='Path for analyzer_output.json')
    parser.add_argument('--depth', '-d', type=int, default=3,
                        help='Max dependency resolution depth (default: 3)')
    parser.add_argument('--name', '-n', help='Dataset name (default: derived from repo path)')

    args = parser.parse_args()

    try:
        with open(args.input_file) as f:
            call_graph_data = json.load(f)

        options = {
            'max_depth': args.depth,
        }
        if args.name:
            options['dataset_name'] = args.name

        # Progress/diagnostics go to stderr so stdout stays valid JSON.
        print(f"Processing {len(call_graph_data.get('functions', {}))} functions...", file=sys.stderr)
        print(f"Dependency resolution depth: {args.depth}", file=sys.stderr)

        generator = UnitGenerator(call_graph_data, options)
        result = generator.generate_units()

        stats = result['statistics']
        print("\nDataset generated:", file=sys.stderr)
        print(f"  Total units: {stats['total_units']}", file=sys.stderr)
        print(f"  Units with upstream deps: {stats['units_with_upstream']}", file=sys.stderr)
        print(f"  Units with downstream callers: {stats['units_with_downstream']}", file=sys.stderr)
        print(f"  Enhanced units: {stats['units_enhanced']}", file=sys.stderr)
        print(f"  Avg upstream deps: {stats['avg_upstream']}", file=sys.stderr)
        print(f"  Avg downstream callers: {stats['avg_downstream']}", file=sys.stderr)
        print("\nBy type:", file=sys.stderr)
        for unit_type, count in sorted(stats['by_type'].items()):
            print(f"  {unit_type}: {count}", file=sys.stderr)

        output = json.dumps(result, indent=2)

        if args.output:
            with open(args.output, 'w') as f:
                f.write(output)
            print(f"\nOutput written to: {args.output}", file=sys.stderr)
        else:
            print(output)

        # Write analyzer output if requested
        if args.analyzer_output:
            analyzer = generator.generate_analyzer_output()
            with open(args.analyzer_output, 'w') as f:
                json.dump(analyzer, f, indent=2)
            print(f"Analyzer output written to: {args.analyzer_output}", file=sys.stderr)

    except Exception as e:
        # CLI boundary: report and exit non-zero rather than traceback to stdout.
        print(f"Error: {e}", file=sys.stderr)
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == '__main__':
    main()
= [ "tree-sitter-cpp>=0.21.0", "tree-sitter-ruby>=0.21.0", "tree-sitter-php>=0.22.0", + "tree-sitter-rust>=0.21.0", ] [project.optional-dependencies] diff --git a/libs/openant-core/requirements.txt b/libs/openant-core/requirements.txt index 966904a..6f94cae 100644 --- a/libs/openant-core/requirements.txt +++ b/libs/openant-core/requirements.txt @@ -22,3 +22,4 @@ tree-sitter-c>=0.21.0 tree-sitter-cpp>=0.21.0 tree-sitter-ruby>=0.21.0 tree-sitter-php>=0.22.0 +tree-sitter-rust>=0.21.0