-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpredict.py
More file actions
77 lines (59 loc) · 2.45 KB
/
Copy pathpredict.py
File metadata and controls
77 lines (59 loc) · 2.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# Copyright (c) 2025 lightning-hydra-boilerplate
# Licensed under the MIT License.
"""Inference script to generate and save model predictions."""
from pathlib import Path
import hydra
import lightning.pytorch as pl
from hydra.core.hydra_config import HydraConfig
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from utils.logger_utils import log_message, setup_logger
from utils.pred_utils import save_predictions
def predict(cfg: DictConfig) -> None:
    """Run inference on a specified data split using a trained model from a checkpoint.

    The model, datamodule and trainer are built from ``cfg`` with Hydra's
    ``instantiate``. The dataloader matching ``cfg.data_split`` is handed to
    ``Trainer.predict``; the returned batches are flattened into one record per
    sample, sorted by ``idx`` and written to ``predictions.<cfg.save_format>``
    inside the Hydra run directory.

    Each prediction batch is expected to be indexable by ``"idx"`` and ``"pred"``,
    with elements that expose ``.item()``.

    Args:
        cfg (DictConfig): Hydra configuration object containing model, data, and trainer settings.
            Keys read here: ``model``, ``data``, ``trainer``, ``data_split``,
            ``save_format`` and the optional ``ckpt_path``.

    Raises:
        ValueError: If ``cfg.data_split`` is not ``train``, ``val``, ``test`` or ``predict``.
        RuntimeError: If ``Trainer.predict`` returns ``None`` instead of predictions.
    """
    log_message("info", f"Running prediction with configuration:\n{OmegaConf.to_yaml(cfg)}")

    # Validate the requested split before building anything, so a typo in the
    # config fails immediately instead of after model/datamodule construction.
    supported_splits = ("train", "val", "test", "predict")
    data_split = cfg.data_split
    if data_split not in supported_splits:
        error_msg = f"Unsupported data_split: {data_split}"
        raise ValueError(error_msg)

    model = instantiate(cfg.model)
    datamodule = instantiate(cfg.data)
    trainer_params = instantiate(cfg.trainer)

    # Callbacks are configured as a mapping; the Trainer receives only the values.
    # A missing or null entry is left as-is so the Trainer uses its own defaults.
    callbacks = trainer_params.get("callbacks")
    if callbacks is not None:
        trainer_params["callbacks"] = list(callbacks.values())
    trainer = pl.Trainer(**trainer_params)

    datamodule.setup()

    # Select appropriate dataloader: every supported split maps to the
    # datamodule method named "<split>_dataloader".
    dataloader = getattr(datamodule, f"{data_split}_dataloader")()

    predictions_batches = trainer.predict(
        model=model,
        dataloaders=dataloader,
        ckpt_path=cfg.get("ckpt_path", None),
    )
    if predictions_batches is None:
        # Without this guard the flatten step below fails with an opaque
        # "'NoneType' object is not iterable".
        error_msg = "Trainer.predict returned None instead of a list of prediction batches"
        raise RuntimeError(error_msg)

    # Flatten and sort predictions by idx
    flat_preds = [
        {"idx": idx.item(), "pred": pred.item()}
        for batch in predictions_batches
        for idx, pred in zip(batch["idx"], batch["pred"])
    ]
    flat_preds.sort(key=lambda record: record["idx"])

    # Save the predictions
    # NOTE(review): hydra.run.dir is the active output directory only for single
    # runs; under --multirun Hydra writes to the sweep directory instead
    # (hydra.runtime.output_dir) - confirm multirun is not used with this script.
    run_dir = HydraConfig.get().run.dir
    output_path = Path(run_dir) / f"predictions.{cfg.save_format}"
    save_predictions(flat_preds, output_path, cfg.save_format)
    log_message("info", f"Predictions of {data_split} split are saved to {output_path}")
@hydra.main(version_base=None, config_path="../configs", config_name="predict")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: set up logging, then run prediction.

    Args:
        cfg (DictConfig): Configuration supplied by the ``hydra.main`` decorator
            (config name ``predict``, config path ``../configs``).
    """
    # The logger is configured first so that predict() can emit its messages.
    setup_logger()
    predict(cfg)
# Invoke the Hydra-decorated entry point only when this file is run as a script,
# not when it is imported.
if __name__ == "__main__":
    main()