-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathrunner.py
More file actions
83 lines (73 loc) · 3.65 KB
/
runner.py
File metadata and controls
83 lines (73 loc) · 3.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from datetime import datetime
from typing import Union
from typing_extensions import TypedDict
from traffic.environment import TrafficEnvironment
from traffic.agent import TrafficAgent
from utils.configs import TrafficAgentConfig, LearningAgentConfig, CanvasConfig
from utils.plotter import MultiPlotter
class RunsConfig(TypedDict):
    """ Typed dictionary describing the training runs of one kind of TrafficAgent. """

    cls: type[TrafficAgent]
    """ TrafficAgent class instantiated for every entry of ``configs``. """

    configs: list[Union[TrafficAgentConfig, LearningAgentConfig]]
    """ Configurations to build the agent with, one agent instance per configuration. """
class Runner():
    """ Runner for several TrafficAgents with possibly different hyperparameters configuration. """

    def __init__(self, canvas_config: CanvasConfig, traffic_env: TrafficEnvironment, runs_configs: list[RunsConfig]) -> None:
        """
        Builds one TrafficAgent per configuration and a MultiPlotter over the configs of all built agents.

        :param canvas_config: Canvas configuration handed to every agent and to the MultiPlotter.
        :type canvas_config: CanvasConfig
        :param traffic_env: TrafficEnvironment to perform each run in.
        :type traffic_env: TrafficEnvironment
        :param runs_configs: Agent classes, each paired with the list of configs to instantiate it with.
        :type runs_configs: list[RunsConfig]
        """
        self._traffic_env: TrafficEnvironment = traffic_env
        # Agents are keyed by config['name']: a later config reusing a name replaces the earlier agent.
        self._agents: dict[str, TrafficAgent] = {}
        for runs_config in runs_configs:
            for config in runs_config['configs']:
                name = config['name']
                self._agents[name] = runs_config['cls'](config, traffic_env, canvas_config)
        self._multi_plotter: MultiPlotter = MultiPlotter([agent.config for agent in self._agents.values()], canvas_config)

    @property
    def timestamp(self) -> str:
        """ Current local time formatted as ``YYYY-MM-DD HH:MM:SS`` (no fractional seconds). """
        return datetime.now().isoformat(sep=' ', timespec='seconds')

    def learn(self) -> dict[str, list[str]]:
        """
        Trains all TrafficAgents, each run for each agent with a different hyperparameters configuration.

        Every agent is run for as long as its ``config['repeat']`` is truthy; the value returned by each
        ``agent.run()`` call is collected and the agent's means are added to the MultiPlotter after each run.
        The MultiPlotter is closed even when a run raises; it is saved only when every run completed.

        :return: List of all saved models for each agent, keyed by agent name.
        :rtype: dict[str, list[str]]
        """
        models: dict[str, list[str]] = {}
        try:
            for agent in self._agents.values():
                models[agent.name] = []
                print(f'{self.timestamp} Learning for agent {agent.name}.')
                # NOTE(review): loop termination relies on agent.run() eventually making config['repeat'] falsy — confirm in TrafficAgent.
                while agent.config['repeat']:
                    print(f'{self.timestamp} Run #{agent.current_run + 1}.')
                    models[agent.name].append(agent.run())
                    self._multi_plotter.add_run(agent.means, agent.name)
            self._multi_plotter.save(True)
        finally:
            # Release the plotter whether or not training succeeded.
            self._multi_plotter.close()
        return models

    def run(self, models: dict[str, list[str]], seconds: Union[int, None] = None, use_gui: bool = True) -> None:
        """
        Resets the hyperparameters configuration for all specified TrafficAgents, then runs them loading each specified model for each run.

        Entries of ``models`` whose name does not match an agent of this Runner are skipped.
        The MultiPlotter is cleared up front and closed even when a run raises; it is saved only when every run completed.

        :param models: Dictionary of agent names paired with a list of models to load, one per run.
        :type models: dict[str, list[str]]
        :param seconds: Amount of simulation seconds to run, if None the TrafficEnvironment is left as it is.
        :type seconds: Union[int, None]
        :param use_gui: Whether to show SUMO GUI.
        :type use_gui: bool
        """
        self._multi_plotter.clear()
        if seconds is not None:
            self._traffic_env.set_seconds(seconds)
        try:
            for agent_name, model_paths in models.items():
                agent = self._agents.get(agent_name)
                if agent is None:
                    # No agent registered under this name: nothing to run.
                    continue
                agent.reset()
                print(f'{self.timestamp} Running agent {agent.name}.')
                while agent.config['repeat']:
                    print(f'{self.timestamp} Run #{agent.current_run + 1}.')
                    # NOTE(review): indexes by the agent's private run counter; presumably equal to agent.current_run — confirm.
                    agent.run(use_gui, model_paths[agent._runs])
                    self._multi_plotter.add_run(agent.means, agent.name)
            self._multi_plotter.save(False)
        finally:
            # Release the plotter whether or not every run succeeded.
            self._multi_plotter.close()