-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
59 lines (43 loc) · 1.61 KB
/
main.py
File metadata and controls
59 lines (43 loc) · 1.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
"""Main file to launch the experiments based on configuration.
Author: Yoshinari Motokawa <yoshinari.moto@fuji.waseda.jp>
"""
import os
import warnings
import hydra
import matplotlib.pyplot as plt
import seaborn as sns
from omegaconf import DictConfig
from configs import config_names
from core.environments import generate_environment
from core.handlers.evaluators import generate_evaluator
from core.handlers.trainers import generate_trainer
from core.maps import generate_map
from core.utils.logging import initialize_logging
from core.utils.seed import set_seed
from core.worlds import generate_world
# Use a white background both for rendered figures and for saved image files.
plt.rcParams.update(
    {
        "figure.facecolor": "white",
        "savefig.facecolor": "white",
    }
)
# Apply seaborn's default theme on top of the matplotlib settings.
sns.set()

# Silence every warning category for the whole process.
warnings.simplefilter(action="ignore")

# Module-level logger, built by the project helper from this module's name.
logger = initialize_logging(__name__)
@hydra.main(config_path="configs", config_name=config_names["IJCNN2025_gpos_eda6_iqn"])
def main(config: DictConfig) -> None:
    """Build the experiment from ``config`` and run the requested phase.

    The GPU selection and the random seed are applied first, then the map,
    the world and the environment are generated in that order.
    ``config.phase`` decides what runs next: ``"training"`` runs a trainer
    and ``"evaluation"`` runs an evaluator.

    Args:
        config: Hydra-composed configuration. This function reads ``gpu``,
            ``seed`` and ``phase`` directly; the whole object is also handed
            to each generator helper.

    Raises:
        ValueError: If ``config.phase`` is neither ``"training"`` nor
            ``"evaluation"``.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = str(config.gpu)
    set_seed(seed=config.seed)

    world_map = generate_map(config=config)
    world = generate_world(config=config, world_map=world_map)
    env = generate_environment(config=config, world=world)

    if config.phase == "training":
        trainer = generate_trainer(config=config, environment=env)
        try:
            trainer.run()
        finally:
            # Runs whether trainer.run() returns or raises, so the state dict
            # is still saved when training is interrupted or fails.
            trainer.save_state_dict()
    elif config.phase == "evaluation":
        evaluator = generate_evaluator(config=config, environment=env)
        evaluator.run()
    else:
        message = f"Unexpected phase is given. config.phase: {config.phase}"
        # Logger.warn is a deprecated alias of Logger.warning.
        # NOTE(review): assumes initialize_logging returns a standard
        # logging.Logger (or compatible) object - confirm.
        logger.warning(message)
        # Carry the same message in the exception so the failure is
        # diagnosable even when the log output is not at hand.
        raise ValueError(message)
# Launch the experiment only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()