-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathenrich_planet_catalog.py
More file actions
115 lines (94 loc) · 3.62 KB
/
enrich_planet_catalog.py
File metadata and controls
115 lines (94 loc) · 3.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#!/usr/bin/env python3
"""
Enrich a planet CSV using NASA Exoplanet Archive (pscomppars).
Adds columns:
- orbital_period_days (pl_orbper)
- tic_id (tic_id)
- M_star (st_mass) [Msun]
- M_planet (pl_bmassj) [Mjup]
- R_planet_jup (pl_radj) [Rjup]
- rho_planet (pl_dens) [g/cm^3]
Usage:
python enrich_planet_catalog.py --input ranked_transiting_planets.csv --output ranked_transiting_planets_enriched.csv
"""
import argparse
import time
import pandas as pd
from tqdm import tqdm
from astroquery.ipac.nexsci.nasa_exoplanet_archive import NasaExoplanetArchive
# Columns requested from the Exoplanet Archive "pscomppars" table.
# "pl_name" is the join key; the rest are ephemeris/physical parameters
# that get merged into the input catalog (see module docstring for units).
SELECT_COLUMNS = [
    "pl_name",    # planet name (join key against planet_name)
    "pl_orbper",  # orbital period [days]
    "pl_tranmid", # transit midpoint epoch
    "tic_id",     # TESS Input Catalog identifier
    "st_mass",    # host-star mass [Msun]
    "pl_bmassj",  # planet mass (best available) [Mjup]
    "pl_radj",    # planet radius [Rjup]
    "pl_dens",    # planet bulk density [g/cm^3]
]
def chunked(seq, size):
    """Yield successive slices of *seq*, each holding at most *size* items."""
    start = 0
    total = len(seq)
    while start < total:
        yield seq[start:start + size]
        start += size
def fetch_batch(names):
    """Query pscomppars for the given planet names and return a DataFrame.

    Single quotes inside a name are doubled ('' ) to escape them for the
    ADQL string literal. When the archive returns no rows, an empty
    DataFrame carrying the expected columns is returned instead.
    """
    literals = ",".join("'" + name.replace("'", "''") + "'" for name in names)
    result = NasaExoplanetArchive.query_criteria(
        table="pscomppars",
        select=",".join(SELECT_COLUMNS),
        where=f"pl_name in ({literals})",
    )
    if not len(result):
        return pd.DataFrame(columns=SELECT_COLUMNS)
    return result.to_pandas()
def main():
    """CLI entry point: read the input catalog, enrich it from the NASA
    Exoplanet Archive in batches, and write the merged CSV.

    Existing non-null values in the input are preserved unless
    --overwrite is given.
    """
    parser = argparse.ArgumentParser(description="Enrich planet CSV with ephemeris/physical parameters")
    parser.add_argument("--input", required=True, help="Input CSV with planet_name column")
    parser.add_argument("--output", required=True, help="Output CSV path")
    parser.add_argument("--batch-size", type=int, default=50, help="Batch size for archive queries")
    parser.add_argument("--sleep", type=float, default=0.0, help="Seconds to sleep between batches")
    parser.add_argument("--overwrite", action="store_true", help="Overwrite existing columns")
    args = parser.parse_args()

    df = pd.read_csv(args.input)
    if "planet_name" not in df.columns:
        raise ValueError("Input CSV must include planet_name column")

    # Deduplicate and sort names so batches are deterministic across runs.
    names = sorted(df["planet_name"].dropna().astype(str).unique().tolist())

    rows = []
    for batch in tqdm(list(chunked(names, args.batch_size)), desc="Querying archive"):
        try:
            batch_df = fetch_batch(batch)
        except Exception as exc:
            # Best-effort enrichment: a failed batch simply contributes no
            # data for those names — but report the failure instead of
            # swallowing it silently. tqdm.write keeps the progress bar intact.
            tqdm.write(f"Warning: archive query failed for batch starting at {batch[0]!r}: {exc}")
            batch_df = pd.DataFrame(columns=SELECT_COLUMNS)
        rows.append(batch_df)
        if args.sleep > 0:
            time.sleep(args.sleep)

    enrich = pd.concat(rows, ignore_index=True) if rows else pd.DataFrame(columns=SELECT_COLUMNS)
    # Rename archive column names to the catalog's column names.
    enrich = enrich.drop_duplicates(subset=["pl_name"]).rename(columns={
        "pl_name": "planet_name",
        "pl_orbper": "orbital_period_days",
        "st_mass": "M_star",
        "pl_bmassj": "M_planet",
        "pl_radj": "R_planet_jup",
        "pl_dens": "rho_planet",
    })

    # Left-merge so every input row survives; colliding columns from the
    # archive arrive with a "_new" suffix and are reconciled below.
    merged = df.merge(enrich, on="planet_name", how="left", suffixes=("", "_new"))

    def _merge_col(col):
        # Reconcile an existing column with its freshly-fetched "_new" twin.
        new_col = f"{col}_new"
        if new_col not in merged.columns:
            return
        if args.overwrite:
            merged[col] = merged[new_col]
        else:
            # Keep existing values; fill only where the input had NaN.
            merged[col] = merged[col].where(merged[col].notna(), merged[new_col])
        merged.drop(columns=[new_col], inplace=True)

    for col in ["orbital_period_days", "tic_id", "M_star", "M_planet", "R_planet_jup", "rho_planet"]:
        if col in merged.columns:
            _merge_col(col)
        else:
            # Column absent from the input entirely: adopt the fetched one.
            merged[col] = merged.get(f"{col}_new")
            if f"{col}_new" in merged.columns:
                merged.drop(columns=[f"{col}_new"], inplace=True)

    merged.to_csv(args.output, index=False)
    print(f"Wrote enriched catalog to {args.output} (rows: {len(merged)})")
# Script entry point: run only when executed directly, not on import.
if __name__ == "__main__":
    main()