forked from ml-research/blendrl-dev
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathevaluate.py
More file actions
45 lines (40 loc) · 1.11 KB
/
Copy pathevaluate.py
File metadata and controls
45 lines (40 loc) · 1.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
import tyro
from nudge.evaluator import Evaluator
from dataclasses import dataclass
import tyro
def main(
    env_name: str = "seaquest",
    agent_path: str = "out/runs/models/kangaroo_demo",
    fps: int = 5,
    episodes: int = 2,
    model: str = "blendrl",
    device: str = "cuda:0",
    seed: int = 0,
    modified_env: str | None = None,
    detect_all_enemies: bool = False,
    deterministic: bool = False,
    render_predicate_probs: bool = True,
) -> None:
    """Evaluate a trained BlendRL agent on new episodes.

    Builds an ``Evaluator`` from the given settings and runs it. Because this
    function is exposed via ``tyro.cli``, every parameter below is also a
    command-line flag.

    NOTE(review): the default ``agent_path`` points at a kangaroo model while
    ``env_name`` defaults to ``"seaquest"`` -- confirm these defaults are meant
    to be used together.

    Args:
        env_name: Environment name, forwarded to ``Evaluator(env_name=...)``.
        agent_path: Agent location, forwarded to ``Evaluator(agent_path=...)``.
        fps: Forwarded to ``Evaluator(fps=...)``.
        episodes: Number of episodes, forwarded to ``Evaluator(episodes=...)``.
        model: Model identifier. NOTE(review): this value is currently not
            forwarded to ``Evaluator``; confirm whether it should be.
        device: Device string (e.g. ``"cuda:0"``), forwarded to
            ``Evaluator(device=...)``.
        seed: Forwarded to ``Evaluator(seed=...)``.
        modified_env: Passed to the evaluator inside ``env_kwargs`` under the
            ``"modified_env"`` key.
        detect_all_enemies: Passed to the evaluator inside ``env_kwargs`` under
            the ``"detect_all_enemies"`` key.
        deterministic: Forwarded to ``Evaluator(deterministic=...)``.
            Defaults to ``False``.
        render_predicate_probs: Forwarded to
            ``Evaluator(render_predicate_probs=...)``. Defaults to ``True``.

    Returns:
        The value returned by ``Evaluator.run()`` (the signature is annotated
        ``None``; NOTE(review): confirm what ``run()`` actually returns).
    """
    # Environment-specific options are bundled into a single dict for the
    # evaluator.
    env_kwargs = {
        "modified_env": modified_env,
        "detect_all_enemies": detect_all_enemies,
    }
    print("env_kwargs for Evaluator:", env_kwargs)
    evaluator = Evaluator(
        episodes=episodes,
        agent_path=agent_path,
        env_name=env_name,
        fps=fps,
        deterministic=deterministic,
        device=device,
        seed=seed,
        env_kwargs=env_kwargs,
        render_predicate_probs=render_predicate_probs,
    )
    return evaluator.run()
if __name__ == "__main__":
    # Run only when executed as a script: tyro derives command-line flags
    # from main()'s signature, parses the arguments, and calls main with them.
    entry_point = main
    tyro.cli(entry_point)