-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun_stage3.py
More file actions
65 lines (57 loc) · 3.09 KB
/
run_stage3.py
File metadata and controls
65 lines (57 loc) · 3.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from __future__ import annotations
import argparse
from pathlib import Path
from nora_stage3.pipeline import Stage3Config, run_stage3
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the NORA stage-3 runner.

    Args:
        argv: Argument strings to parse (without the program name). When
            ``None`` (the default) argparse falls back to ``sys.argv[1:]``,
            so existing zero-argument callers behave exactly as before.
            Passing an explicit list lets the parser be driven
            programmatically or from tests.

    Returns:
        The populated ``argparse.Namespace``. ``--stage2-dir`` and
        ``--output-dir`` are required; every other option has a default.

    Raises:
        SystemExit: Raised by argparse on invalid or missing arguments, or
            when ``-h``/``--help`` is requested.
    """
    parser = argparse.ArgumentParser(description='Run NORA stage 3 (AMIE3 rule mining and constraint scoring).')

    # Input / output locations.
    parser.add_argument('--stage2-dir', required=True, help='Path to the stage-2 output directory.')
    parser.add_argument('--output-dir', required=True, help='Path where stage-3 outputs will be written.')

    # External tooling.
    parser.add_argument(
        '--amie-jar',
        default=None,
        help='Path to amie3.jar. Defaults to AMIE3_JAR or a local amie3.jar if present.',
    )
    parser.add_argument('--java-bin', default='java')

    # Rule-mining thresholds.
    parser.add_argument('--min-support', type=int, default=2)
    parser.add_argument(
        '--min-initial-support',
        type=int,
        default=1,
        help='AMIE -minis value. Defaults to 1 because Stage 1 normalization creates many sparse relations.',
    )
    parser.add_argument('--min-pca-confidence', type=float, default=0.1)
    parser.add_argument('--min-std-confidence', type=float, default=0.0)
    parser.add_argument('--min-head-coverage', type=float, default=0.01)
    parser.add_argument('--max-atoms', type=int, default=3)
    # NOTE(review): the help text names 50, 75 and 90 as the valid values,
    # but the parser accepts any int; validation, if any, happens downstream.
    parser.add_argument(
        '--pca-percentile',
        type=int,
        default=None,
        help='Override auto percentile selection with 50, 75, or 90.',
    )
    parser.add_argument('--amie-threads', type=int, default=None)

    # Rule application and logging.
    parser.add_argument(
        '--body-graph',
        choices=('existing', 'compact'),
        default='existing',
        help='Graph used to satisfy rule bodies during application. Paper-faithful default is existing.',
    )
    parser.add_argument('--max-filtered-rules', type=int, default=None)
    parser.add_argument('--log-rule-interval', type=int, default=25)
    parser.add_argument('--no-reuse-amie-output', action='store_true', default=False)

    # argv=None makes argparse read sys.argv[1:], preserving the old behaviour.
    return parser.parse_args(argv)
def main() -> None:
    """Run stage 3 from the command line and print a short summary.

    Parses the CLI options, turns them into a ``Stage3Config`` (with the
    stage-2 and output directories resolved to absolute paths), calls
    ``run_stage3`` and prints selected entries of the returned summary.
    """
    cli = parse_args()

    # Resolve both directories up front; the config stores them as strings.
    stage2_path = Path(cli.stage2_dir).resolve()
    output_path = Path(cli.output_dir).resolve()

    options = {
        'stage2_dir': str(stage2_path),
        'output_dir': str(output_path),
        'amie_jar': cli.amie_jar,
        'java_bin': cli.java_bin,
        'min_support': cli.min_support,
        'min_initial_support': cli.min_initial_support,
        'min_pca_confidence': cli.min_pca_confidence,
        'min_std_confidence': cli.min_std_confidence,
        'min_head_coverage': cli.min_head_coverage,
        'max_atoms': cli.max_atoms,
        'pca_percentile': cli.pca_percentile,
        'amie_threads': cli.amie_threads,
        'body_graph': cli.body_graph,
        'max_filtered_rules': cli.max_filtered_rules,
        'log_rule_interval': cli.log_rule_interval,
        # The CLI flag is negative ("no reuse"); the config field is positive.
        'reuse_amie_output': not cli.no_reuse_amie_output,
    }
    config = Stage3Config(**options)

    summary = run_stage3(config).summary

    print(f"Dataset: {summary['dataset_name']}")
    print(f"Compact triples: {summary['compact_triple_count']}")
    print(f"Raw rules: {summary['raw_rule_count']}")
    print(f"Filtered rules: {summary['filtered_rule_count']}")
    print(f"Predicted missing edges: {summary['predicted_missing_edge_count']}")
    print(f"Outputs written to: {config.output_dir}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()