-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun_stage2.py
More file actions
93 lines (85 loc) · 4.32 KB
/
run_stage2.py
File metadata and controls
93 lines (85 loc) · 4.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from __future__ import annotations
import argparse
from pathlib import Path
from nora_stage2.pipeline import Stage2Config, run_stage2
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the NORA stage-2 command-line flags.

    Args:
        argv: Argument strings to parse (any sequence of ``str`` works).
            When ``None`` (the default, matching the previous zero-argument
            signature), ``argparse`` falls back to ``sys.argv[1:]``. Passing
            an explicit list lets tests and other Python callers reuse this
            parser without touching ``sys.argv``.

    Returns:
        The parsed namespace. Attribute names are the flag names with dashes
        replaced by underscores (e.g. ``--stage1-dir`` -> ``stage1_dir``).

    Raises:
        SystemExit: raised by ``argparse`` when ``--stage1-dir`` or
            ``--output-dir`` is missing, a value fails its ``type``
            conversion, or ``--best-checkpoint-metric`` is not ``loss`` or
            ``balanced``; also raised (with code 0) for ``--help``.
    """
    parser = argparse.ArgumentParser(description='Run NORA stage 2 (RAE compact-set extraction).')

    # Required input / output locations.
    parser.add_argument('--stage1-dir', required=True, help='Path to the stage-1 output directory.')
    parser.add_argument('--output-dir', required=True, help='Path where stage-2 outputs will be written.')

    # Training hyper-parameters.
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=512)
    parser.add_argument('--neg-num', type=int, default=10)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--l2', type=float, default=0.0)

    # Model size / GCN settings (--num-blocks defaults to None).
    parser.add_argument('--init-dim', type=int, default=100)
    parser.add_argument('--hidden-dim', type=int, default=200)
    parser.add_argument('--gcn-layers', type=int, default=2)
    parser.add_argument('--num-bases', type=int, default=100)
    parser.add_argument('--num-blocks', type=int, default=None)
    parser.add_argument('--hidden-dropout', type=float, default=0.1)

    # Sparsity / mask settings. --no-gumbel is a boolean switch that main()
    # inverts into Stage2Config(use_gumbel=...).
    parser.add_argument('--sparsity-gamma', type=float, default=0.5)
    parser.add_argument('--mcp-alpha', type=float, default=10.0)
    parser.add_argument('--mcp-lambda', type=float, default=1.0)
    parser.add_argument('--tau', type=float, default=0.1)
    parser.add_argument('--no-gumbel', action='store_true', default=False)
    parser.add_argument('--mask-threshold', type=float, default=0.5)
    parser.add_argument('--compact-export-threshold', type=float, default=None)
    parser.add_argument('--reconstruction-threshold', type=float, default=0.5)

    # Runtime, logging, checkpointing and batching knobs.
    parser.add_argument('--device', default='auto')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--max-train-batches', type=int, default=None)
    parser.add_argument('--backward-chunk-batches', type=int, default=16)
    parser.add_argument('--log-batch-interval', type=int, default=10)
    parser.add_argument('--checkpoint-every-epochs', type=int, default=1)
    parser.add_argument('--resume-checkpoint', default=None)
    parser.add_argument('--best-checkpoint-metric', choices=('loss', 'balanced'), default='balanced')
    parser.add_argument('--edge-mask-batch-size', type=int, default=32768)
    parser.add_argument('--edge-message-batch-size', type=int, default=1024)

    # argv=None -> argparse reads sys.argv[1:], preserving the old behaviour.
    return parser.parse_args(argv)
def main() -> None:
    """Script entry point: build a ``Stage2Config`` from CLI flags, run stage 2, report.

    ``--stage1-dir`` and ``--output-dir`` are resolved to absolute paths and
    ``--no-gumbel`` is inverted into ``use_gumbel``; every other flag is
    forwarded to ``Stage2Config`` under its own name. After ``run_stage2``
    returns, selected entries of ``artifacts.summary`` are printed, followed
    by the resolved output directory.
    """
    cli = parse_args()

    # Flags whose argparse dest is exactly the Stage2Config keyword.
    passthrough_fields = (
        'epochs',
        'batch_size',
        'neg_num',
        'lr',
        'l2',
        'init_dim',
        'hidden_dim',
        'gcn_layers',
        'num_bases',
        'num_blocks',
        'hidden_dropout',
        'sparsity_gamma',
        'mcp_alpha',
        'mcp_lambda',
        'tau',
        'mask_threshold',
        'compact_export_threshold',
        'reconstruction_threshold',
        'device',
        'seed',
        'max_train_batches',
        'backward_chunk_batches',
        'log_batch_interval',
        'checkpoint_every_epochs',
        'resume_checkpoint',
        'best_checkpoint_metric',
        'edge_mask_batch_size',
        'edge_message_batch_size',
    )
    forwarded = {field: getattr(cli, field) for field in passthrough_fields}

    config = Stage2Config(
        stage1_dir=str(Path(cli.stage1_dir).resolve()),
        output_dir=str(Path(cli.output_dir).resolve()),
        use_gumbel=not cli.no_gumbel,
        **forwarded,
    )

    report = run_stage2(config).summary

    # (printed label, summary key) pairs, emitted in this order.
    summary_lines = (
        ('Dataset: ', 'dataset_name'),
        ('Device: ', 'device'),
        ('Directed train edges: ', 'directed_train_edge_count'),
        ('Compact counts: ', 'compact_counts'),
        ('Selected epoch: ', 'best_epoch'),
        ('Selection metric: ', 'selected_checkpoint_metric'),
        ('Best loss: ', 'best_loss'),
    )
    for label, key in summary_lines:
        print(f"{label}{report[key]}")
    print(f"Outputs written to: {config.output_dir}")
if __name__ == '__main__':
    # Only run when executed as a script; importing this module has no side effects.
    main()