Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion natrix/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,11 @@ def parse_args() -> argparse.Namespace:
"(e.g., -p /path/to/libs /another/path)."
),
)
exports_parser.add_argument(
"--display-modules",
action="store_true",
help="Include comments showing which module functions come from",
)

# Create the call_graph sub-subcommand
call_graph_parser = codegen_subparsers.add_parser(
Expand Down Expand Up @@ -303,7 +308,11 @@ def main() -> None:
# Get extra paths if provided
extra_paths = tuple(Path(p) for p in args.path) if args.path else ()
# Generate and print exports
exports = generate_exports(Path(args.file_path), extra_paths)
exports = generate_exports(
Path(args.file_path),
extra_paths,
include_module_comments=args.display_modules,
)
print(exports)
sys.exit(0)
elif args.codegen_command == "call_graph":
Expand Down
72 changes: 69 additions & 3 deletions natrix/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,18 @@
from natrix.ast_tools import vyper_compile


def generate_exports(file_path: Path, extra_paths: tuple[Path, ...]) -> str:
def generate_exports(
file_path: Path,
extra_paths: tuple[Path, ...],
include_module_comments: bool = False,
) -> str:
"""Generate explicit exports for a Vyper contract.

Args:
file_path: Path to the Vyper contract file
extra_paths: Additional paths to search for imports
include_module_comments: Whether to include comments
showing which module functions come from

Returns:
A string containing the exports declaration
Expand All @@ -33,10 +39,28 @@ def generate_exports(file_path: Path, extra_paths: tuple[Path, ...]) -> str:
# Convert to sorted list for deterministic output
external_funcs_list = sorted(external_funcs)

# If module comments are requested, parse AST to get module mapping
func_to_module: dict[str, str] = {}
if include_module_comments and external_funcs_list:
func_to_module = _get_function_to_module_mapping(file_path, extra_paths)

# Format the exports
if external_funcs_list:
func_list = [f" {module_name}.{func}" for func in external_funcs_list]
func_names = ",\n".join(func_list)
func_list = []
for i, func in enumerate(external_funcs_list):
func_line = f" {module_name}.{func}"

# Add comma to all items except the last one
if i < len(external_funcs_list) - 1:
func_line += ","

# Add module comment if available
if include_module_comments and func in func_to_module:
func_line += f" # {func_to_module[func]}"

func_list.append(func_line)

func_names = "\n".join(func_list)
return (
f"# NOTE: Always double-check the generated exports\n"
f"exports: (\n{func_names}\n)"
Expand All @@ -45,6 +69,48 @@ def generate_exports(file_path: Path, extra_paths: tuple[Path, ...]) -> str:
return f"# No external functions found in {module_name}"


def _get_function_to_module_mapping(
file_path: Path, extra_paths: tuple[Path, ...]
) -> dict[str, str]:
"""Parse the AST to map function names to their source modules.

Args:
file_path: Path to the Vyper contract file
extra_paths: Additional paths to search for imports

Returns:
A dictionary mapping function names to their source module names
"""
# Get the annotated AST from vyper
full_dict = vyper_compile(file_path, "annotated_ast", extra_paths=extra_paths)
assert isinstance(full_dict, dict)

func_to_module: dict[str, str] = {}

# Extract imported modules and their function names
imports = full_dict.get("imports", [])
for import_dict in imports:
if "path" in import_dict and "body" in import_dict:
# Extract module name from path
import_path = import_dict["path"]
module_name = Path(import_path).stem

# Look through the imported module's AST body
# for function definitions and public variables
for node in import_dict["body"]:
if node.get("ast_type") == "FunctionDef":
func_name = node.get("name")
if func_name:
func_to_module[func_name] = module_name
elif node.get("ast_type") == "VariableDecl" and node.get("is_public"):
# Public variables generate getter functions
var_name = node.get("target", {}).get("id")
if var_name:
func_to_module[var_name] = module_name

return func_to_module


def generate_call_graph(
file_path: Path, extra_paths: tuple[Path, ...], target_function: str | None = None
) -> str:
Expand Down
40 changes: 40 additions & 0 deletions tests/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,43 @@ def test_codegen_exports():
version_dummy.view_external_marked_as_nothing
)"""
assert result.stdout.strip() == expected


def test_codegen_exports_with_module_comments():
"""Test the codegen exports command with module comments enabled."""
# Use the scrvusd_oracle contract which imports from ownable
test_contract = (
Path(__file__).parent / "contracts" / "scrvusd_oracle" / "scrvusd_oracle.vy"
)

# Run the codegen exports command with --display-modules flag
result = subprocess.run(
[
sys.executable,
"-m",
"natrix",
"codegen",
"exports",
str(test_contract),
"--display-modules",
],
capture_output=True,
text=True,
)

# Check that the command succeeded
assert result.returncode == 0

output = result.stdout.strip()

# Should contain the exports section
assert "exports: (" in output

# Should contain functions from ownable module with comments
assert "owner, # ownable" in output
assert "transfer_ownership, # ownable" in output
assert "renounce_ownership, # ownable" in output

# Should contain functions from main contract without comments
assert "scrvusd_oracle.price_v0," in output
assert "scrvusd_oracle.update_price," in output