-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcluster.py
More file actions
92 lines (72 loc) · 2.73 KB
/
cluster.py
File metadata and controls
92 lines (72 loc) · 2.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
from typing import Any, Dict, Optional

import yaml
class DBConnect:
    """Plain holder for the settings describing one database connection.

    The constructor arguments are stored unchanged as same-named attributes;
    no connection is opened and no validation is performed.
    """

    def __init__(self, host: str, port: int, user: str, password: str, name: str):
        """Record the connection settings on the instance.

        Args:
            host (str): Stored as ``self.host``.
            port (int): Stored as ``self.port``.
            user (str): Stored as ``self.user``.
            password (str): Stored as ``self.password``.
            name (str): Stored as ``self.name``.
        """
        # Endpoint first, then credentials, then the database name.
        self.host, self.port = host, port
        self.user, self.password = user, password
        self.name = name
class ClusterManager:
    """Expose the settings of a single cluster read from a YAML config file.

    The YAML document is expected to be a mapping whose top-level keys are
    cluster names; only the section belonging to the selected cluster is kept
    and served through the properties below.
    """

    def __init__(self, name: Optional[str] = None, config_file: str = "config/system.yaml", auto: bool = True):
        """Creates a ClusterManager object that can automatically configure multiple clusters.

        Args:
            name (str, optional): Name of the cluster as given in YAML. Only used
                when ``auto`` is False. Defaults to None.
            config_file (str, optional): Path to the YAML config file. Defaults to
                "config/system.yaml".
            auto (bool, optional): Whether the cluster name should be read from the
                ``CLUSTER_NAME`` environment variable instead of ``name``.
                Defaults to True.

        Raises:
            OSError: ``auto`` is True and ``CLUSTER_NAME`` is not set, or the YAML
                config file cannot be opened (e.g. it does not exist).
            NotImplementedError: The selected cluster name is not a top-level key
                of the YAML config file (this includes an empty or non-mapping
                file, and ``name=None`` with ``auto=False``).
        """
        if auto:
            sys_name = os.getenv("CLUSTER_NAME")
            if sys_name is None:
                raise OSError("CLUSTER_NAME not found in environment variables. Autoselecting system failed.")
            self.name = sys_name
        else:
            self.name = name

        with open(config_file, encoding="utf-8") as file:
            # NOTE(review): FullLoader can construct more than plain data types.
            # If the config file could ever come from an untrusted source,
            # switch to yaml.safe_load.
            configs = yaml.load(file, Loader=yaml.FullLoader)

        # An empty file loads as None and a scalar/list document is not a
        # mapping; report both as "cluster not configured" rather than failing
        # with an unrelated AttributeError/TypeError.
        if not isinstance(configs, dict) or self.name not in configs:
            raise NotImplementedError(f"System {self.name} not implemented in '{config_file}'")

        # Keep only the section of the selected cluster.
        self._configs = configs[self.name]

    @property
    def project_dir(self) -> str:
        """Value of the ``PROJECT_DIR`` entry of the cluster section."""
        return self._configs["PROJECT_DIR"]

    @property
    def num_workers(self) -> int:
        """Value of the ``NUM_WORKERS`` entry of the cluster section."""
        return self._configs["NUM_WORKERS"]

    @property
    def data_dir(self) -> str:
        """Value of the ``DATA_DIR`` entry of the cluster section."""
        return self._configs["DATA_DIR"]

    @property
    def log_dir(self) -> str:
        """Value of the ``LOG_DIR`` entry of the cluster section."""
        return self._configs["LOG_DIR"]

    @property
    def artifact_dir(self) -> str:
        """Value of the ``ARTIFACTS_DIR`` entry of the cluster section."""
        return self._configs["ARTIFACTS_DIR"]

    @property
    def network(self):
        """Value of the ``NETWORK`` entry of the cluster section."""
        return self._configs["NETWORK"]

    @property
    def use_GPU(self) -> bool:
        """Value of the ``USE_GPU`` entry of the cluster section."""
        return self._configs["USE_GPU"]

    @property
    def get_pid(self) -> int:
        """Integer identifier for the current run.

        Returns the ``SLURM_JOB_ID`` environment variable converted to ``int``
        when it is set to a numeric value; otherwise the OS process id.
        """
        try:
            # Environment values are strings; convert so the result always
            # matches the declared ``int`` return type.
            return int(os.environ["SLURM_JOB_ID"])
        except (KeyError, ValueError):
            # Variable missing or not numeric: fall back to the process id.
            return os.getpid()

    @property
    def db(self) -> "DBConnect":
        """Build a new ``DBConnect`` from the ``DB_*`` entries of the cluster section.

        A fresh object is created on every access.
        """
        return DBConnect(
            host=self._configs["DB_HOST"],
            port=self._configs["DB_PORT"],
            user=self._configs["DB_USER"],
            password=self._configs["DB_PASSWORD"],
            name=self._configs["DB_NAME"],
        )

    def to_dict(self) -> Dict[str, Any]:
        """Return the cluster section as a new dict with lower-cased keys."""
        return {k.lower(): v for k, v in self._configs.items()}