-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathoptimize.py
More file actions
57 lines (43 loc) · 2.23 KB
/
optimize.py
File metadata and controls
57 lines (43 loc) · 2.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import argparse
import json
import os
from run import run_sac
from UnitCell_Environment.unitcell_environment.env.utils import load_compositions
if __name__ == "__main__":
    # Script entry point: parse the command line, load the three JSON config
    # files, patch the env/algo settings for an optimization run, register the
    # supported compositions and hand everything to run_sac.
    parser = argparse.ArgumentParser(description='Optional app description')
    parser.add_argument("--algo_config", type=str, default="config.json",
                        help="The config file for tunable parameters of the chosen RL method")
    parser.add_argument("--env_config", type=str, default="env_config.json",
                        help="The config file for the env parameters")
    parser.add_argument("--policy", type=str, default=os.path.abspath("policy"),
                        help="the folder containing the policy to be used")
    parser.add_argument("--comp_config", type=str, default="comp_config.json",
                        help="the config file with compositions and structure sizes supported")
    parser.add_argument("--comp", type=str, default="SrTiO3x8",
                        help="specific composition/structure size to optimize")
    parser.add_argument("--input_dir", type=str, default="starting_structures/SrTiO3x8",
                        help="input dir with starting structures")
    parser.add_argument("--output_dir", type=str, default=None,
                        help="The output dir to save the optimized structures and the energy history")
    args = parser.parse_args()

    # JSON is UTF-8 by specification, so do not depend on the platform's
    # default text encoding when reading the config files.
    with open(args.algo_config, "r", encoding="utf-8") as f:
        algo_config = json.load(f)
    # Always override the worker count from the file with 0.
    # NOTE(review): presumably so the run stays in a single process -- confirm
    # against run_sac.
    algo_config["num_workers"] = 0

    with open(args.env_config, "r", encoding="utf-8") as f:
        env_config = json.load(f)

    # Mirror algorithm-level settings into the environment config.
    env_config["gamma"] = algo_config.get("gamma", 0.99)
    # The env "gpu" flag is true only when the algo config requests >= 1 GPU.
    env_config["gpu"] = algo_config.get("num_gpus", 0) > 0
    # NOTE(review): the meaning of the "2_stepsize_" level is defined by the
    # environment, not visible here -- verify before changing.
    env_config["agent_levels"] = ["2_stepsize_"]

    # Command-line values, when given, take precedence over the env config file.
    if args.output_dir is not None:
        env_config["output_dir"] = args.output_dir
    if args.comp is not None:
        env_config["comp"] = args.comp
    env_config["store_trajectory"] = False

    with open(args.comp_config, "r", encoding="utf-8") as f:
        comp_config = json.load(f)
    # Must run before run_sac: registers the compositions from comp_config.
    load_compositions(comp_config)

    # Start from the policy folder given on the command line. The value
    # returned by run_sac is kept for interactive inspection but is not used
    # further in this script.
    checkpoint = run_sac(env_config,
                         algo_config,
                         input_dir=args.input_dir,
                         checkpoint=args.policy)