Skip to content

Commit de0a511

Browse files
authored
Implement basic unit tests (#11)
1 parent e5dce41 commit de0a511

3 files changed

Lines changed: 228 additions & 2 deletions

File tree

artefacts.yaml

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,25 @@ project: artefacts/go2-isaac-example
22
version: 0.1.0
33

44
jobs:
5+
unit_tests:
6+
type: test
7+
timeout: 15 # minutes
8+
package:
9+
docker:
10+
build:
11+
dockerfile: empty
12+
runtime:
13+
framework: other
14+
simulator: isaac_sim
15+
16+
scenarios:
17+
settings:
18+
- name: navigator_unit_tests
19+
pytest_file: "nodes/navigator/tests"
20+
21+
- name: policy_controller_unit_tests
22+
pytest_file: "nodes/policy_controller/tests"
23+
524
waypoint_missions:
625
type: test
726
timeout: 20 # minutes
@@ -14,8 +33,6 @@ jobs:
1433
simulator: isaac_sim
1534

1635
scenarios:
17-
# Reinstall artefacts-click as a temporary workaround for conflicting click versions
18-
# see https://github.com/art-e-fact/artefacts-client/issues/370
1936
settings:
2037
- name: report_based_waypoint_mission_test
2138
run: "rm -rf outputs/artefacts && uv run dataflow --test-waypoint-report"

nodes/navigator/tests/test_navigator.py

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
"""Test module for navigator package."""
22

3+
import queue
4+
import threading
5+
from unittest.mock import MagicMock, patch
6+
7+
import numpy as np
8+
import pyarrow as pa
39
import pytest
10+
from msgs import Transform, Twist2D, Waypoint, WaypointList, WaypointStatus
411

512

613
def test_import_main():
@@ -11,3 +18,159 @@ def test_import_main():
1118
# as we're not running in a Dora dataflow.
1219
with pytest.raises(RuntimeError):
1320
main()
21+
22+
23+
def test_navigator_logic():
24+
"""Test the main navigator logic with mocked inputs."""
25+
# Patch dora.Node in the context of the navigator.main module
26+
with patch("navigator.main.Node") as MockNode:
27+
mock_node_instance = MagicMock()
28+
MockNode.return_value = mock_node_instance
29+
30+
# Define the inputs that the node will receive.
31+
robot_pose = Transform.from_position_and_quaternion(
32+
np.array([0.0, 0.0, 0.0]),
33+
np.array([1.0, 0.0, 0.0, 0.0]), # scalar-first quaternion (w, x, y, z)
34+
)
35+
waypoints = WaypointList(
36+
[
37+
Waypoint(
38+
transform=Transform.from_position_and_quaternion(
39+
np.array([10.0, 0.0, 0.0]), np.array([1.0, 0.0, 0.0, 0.0])
40+
),
41+
status=WaypointStatus.ACTIVE,
42+
)
43+
]
44+
)
45+
46+
# The mocked node will yield these events when iterated.
47+
mock_node_instance.__iter__.return_value = [
48+
{"type": "INPUT", "id": "robot_pose", "value": robot_pose.to_arrow()},
49+
{"type": "INPUT", "id": "waypoints", "value": waypoints.to_arrow()},
50+
{
51+
"type": "INPUT",
52+
"id": "tick",
53+
"value": pa.array([]), # The tick value is not used.
54+
},
55+
{
56+
"type": "INPUT",
57+
"id": "stop", # Stop the loop
58+
},
59+
]
60+
61+
from navigator.main import main
62+
63+
main()
64+
65+
# Check that send_output was called correctly.
66+
# The robot is at the origin, facing the goal at (10, 0).
67+
# It should command maximum forward velocity and no angular velocity.
68+
expected_command = Twist2D(linear_x=1.0, linear_y=0.0, angular_z=0.0)
69+
70+
# mock_node_instance.send_output.assert_called_once() # Fails if called more than once
71+
mock_node_instance.send_output.assert_called_with(
72+
"command_2d", expected_command.to_arrow()
73+
)
74+
75+
76+
def test_navigator_logic_stateful():
77+
"""Test if sends command output after every tick input."""
78+
# A queue to send events to the node
79+
event_queue = queue.Queue()
80+
81+
with patch("navigator.main.Node") as MockNode:
82+
mock_node_instance = MagicMock()
83+
MockNode.return_value = mock_node_instance
84+
mock_node_instance.__iter__.return_value = iter(event_queue.get, None)
85+
86+
from navigator.main import main
87+
88+
# Run the main function in a separate thread
89+
main_thread = threading.Thread(target=main, daemon=True)
90+
main_thread.start()
91+
92+
# Set initial robot pose and waypoints
93+
robot_pose = Transform.from_position_and_quaternion(
94+
np.array([0.0, 0.0, 0.0]),
95+
np.array([1.0, 0.0, 0.0, 0.0]), # scalar-first (w, x, y, z)
96+
)
97+
waypoints = WaypointList(
98+
[
99+
Waypoint(
100+
transform=Transform.from_position_and_quaternion(
101+
np.array([10.0, 0.0, 0.0]), np.array([1.0, 0.0, 0.0, 0.0])
102+
),
103+
status=WaypointStatus.ACTIVE,
104+
)
105+
]
106+
)
107+
event_queue.put(
108+
{"type": "INPUT", "id": "robot_pose", "value": robot_pose.to_arrow()}
109+
)
110+
event_queue.put(
111+
{"type": "INPUT", "id": "waypoints", "value": waypoints.to_arrow()}
112+
)
113+
114+
# Send a tick to trigger a command calculation
115+
event_queue.put({"type": "INPUT", "id": "tick", "value": pa.array([])})
116+
117+
threading.Event().wait(0.1) # Wait for async operations
118+
mock_node_instance.send_output.assert_called_once()
119+
args, _ = mock_node_instance.send_output.call_args
120+
assert args[0] == "command_2d"
121+
command_output = Twist2D.from_arrow(args[1])
122+
assert command_output.linear_x > 0, (
123+
f"Expected to go towards X direction, got {command_output}"
124+
)
125+
assert command_output.angular_z == 0.0, (
126+
f"Expected to go straight, got {command_output}"
127+
)
128+
129+
# Reset mock to check for the next call
130+
mock_node_instance.send_output.reset_mock()
131+
132+
# Set a new robot pose
133+
new_robot_pose = Transform.from_position_and_quaternion(
134+
np.array([10.0, 10.0, 0.0]),
135+
np.array([0.707, 0.0, 0.0, 0.707]), # Rotated 90 degrees (facing +y)
136+
)
137+
event_queue.put(
138+
{"type": "INPUT", "id": "robot_pose", "value": new_robot_pose.to_arrow()}
139+
)
140+
141+
# Send another tick
142+
event_queue.put({"type": "INPUT", "id": "tick", "value": pa.array([])})
143+
144+
# Check the new output
145+
threading.Event().wait(0.1) # Wait for async operations
146+
mock_node_instance.send_output.assert_called_once()
147+
args, _ = mock_node_instance.send_output.call_args
148+
assert args[0] == "command_2d"
149+
command_output = Twist2D.from_arrow(args[1])
150+
assert command_output.angular_z > 0.0, f"Expected to turn, got {command_output}"
151+
152+
event_queue.put(None) # Signal the iterator to end
153+
main_thread.join(timeout=1) # Wait for the thread to finish
154+
155+
156+
def test_stop_after_stop_signal():
157+
"""Test that the navigator stops after receiving a stop signal."""
158+
event_queue = queue.Queue()
159+
160+
with patch("navigator.main.Node") as MockNode:
161+
mock_node_instance = MagicMock()
162+
MockNode.return_value = mock_node_instance
163+
mock_node_instance.__iter__.return_value = iter(event_queue.get, None)
164+
165+
from navigator.main import main
166+
167+
# Start the main function in a separate thread
168+
main_thread = threading.Thread(target=main, daemon=True)
169+
main_thread.start()
170+
171+
# Send stop signal
172+
event_queue.put({"type": "INPUT", "id": "stop"})
173+
threading.Event().wait(0.1)
174+
175+
# Check if the thread has finished
176+
assert not main_thread.is_alive(), "Main thread did not stop after stop signal"

nodes/policy_controller/tests/test_policy_controller.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
"""Test module for policy_controller package."""
22

3+
from unittest.mock import MagicMock, patch
4+
5+
import numpy as np
36
import pytest
7+
from msgs import JointCommands, Observations, Timestamp, Twist2D
48

59

610
def test_import_main():
@@ -11,3 +15,45 @@ def test_import_main():
1115
# as we're not running in a Dora dataflow.
1216
with pytest.raises(RuntimeError):
1317
main()
18+
19+
20+
def test_generates_commands():
21+
"""Check if the policy runs and generates joint commands."""
22+
with patch("policy_controller.main.Node") as MockNode:
23+
mock_node_instance = MagicMock()
24+
MockNode.return_value = mock_node_instance
25+
26+
# Create mock inputs
27+
command_2d = Twist2D(linear_x=0.5, linear_y=0.0, angular_z=0.0)
28+
observations = Observations(
29+
lin_vel=np.zeros(3),
30+
ang_vel=np.zeros(3),
31+
gravity=np.array([0.0, 0.0, -9.81]),
32+
joint_positions=np.zeros(12),
33+
joint_velocities=np.zeros(12),
34+
height_scan=np.zeros(154),
35+
)
36+
clock = Timestamp.now()
37+
38+
# The mocked node will yield these events when iterated.
39+
mock_node_instance.__iter__.return_value = [
40+
{"type": "INPUT", "id": "command_2d", "value": command_2d.to_arrow()},
41+
{"type": "INPUT", "id": "clock", "value": clock.to_arrow()},
42+
{
43+
"type": "INPUT",
44+
"id": "observations",
45+
"value": observations.to_arrow(),
46+
},
47+
]
48+
49+
from policy_controller.main import main
50+
51+
main()
52+
53+
# Check that send_output was called with joint_commands
54+
mock_node_instance.send_output.assert_called()
55+
args, _ = mock_node_instance.send_output.call_args
56+
assert args[0] == "joint_commands"
57+
joint_commands = JointCommands.from_arrow(args[1])
58+
assert joint_commands.positions is not None
59+
assert len(joint_commands.positions) == 12

0 commit comments

Comments
 (0)