-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun_stage1.py
More file actions
61 lines (50 loc) · 2.17 KB
/
run_stage1.py
File metadata and controls
61 lines (50 loc) · 2.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from __future__ import annotations
import argparse
from pathlib import Path
from nora_stage1.datasets import get_dataset_adapter
from nora_stage1.exporters import export_stage1_outputs
from nora_stage1.normalization import normalize_stage1
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Run NORA stage 1 preprocessing.")
parser.add_argument("--dataset", required=True, help="Dataset adapter name, for example: nell-995")
parser.add_argument("--source-dir", required=True, help="Path to the raw dataset directory")
parser.add_argument("--output-dir", required=True, help="Path where stage-1 outputs will be written")
parser.add_argument(
"--ambiguity-threshold",
type=float,
default=0.5,
help="3KG-NF ambiguity threshold. Relations above this multi-object subject ratio are specialized.",
)
parser.add_argument(
"--ambiguity-stats-split",
default="train",
choices=("train", "valid", "test"),
help="Split used to estimate the 3KG-NF ambiguity ratio.",
)
return parser.parse_args()
def main() -> None:
    """Run stage 1 end to end: load, normalize, export, then report.

    Options come from the command line via ``parse_args``. The raw dataset is
    loaded through the adapter selected by ``--dataset``, normalized with the
    requested ambiguity settings, and exported to the output directory. A
    short summary is printed to stdout afterwards.
    """
    cli = parse_args()
    dataset_adapter = get_dataset_adapter(cli.dataset)

    # Resolve both locations up front so the summary reports an absolute path.
    raw_root = Path(cli.source_dir).resolve()
    export_root = Path(cli.output_dir).resolve()

    normalized = normalize_stage1(
        dataset_adapter.load(raw_root),
        ambiguity_threshold=cli.ambiguity_threshold,
        ambiguity_stats_split=cli.ambiguity_stats_split,
    )
    export_stage1_outputs(normalized, export_root)

    report = normalized.metadata["stage1"]
    print(f"Dataset: {normalized.dataset_name}")
    print(f"Entities: {len(normalized.entities)}")
    print(f"Relations: {len(normalized.relations)}")
    print(f"Split sizes: {report['final_split_sizes']}")

    # Per-normal-form counters recorded in the stage-1 metadata.
    blank_nodes = report["nf1"]["blank_entity_count"]
    nf2_splits = report["nf2"]["split_relation_count"]
    nf3_splits = report["nf3"]["split_relation_count"]
    print(
        "Stage 1 summary: "
        f"NF1 blank nodes={blank_nodes}, "
        f"NF2 split relations={nf2_splits}, "
        f"NF3 split relations={nf3_splits}"
    )
    print(f"Outputs written to: {export_root}")
# Only run the pipeline when executed as a script, so importing this module
# has no side effects.
if __name__ == "__main__":
    main()