-
Notifications
You must be signed in to change notification settings - Fork 31
Expand file tree
/
Copy pathhabitat_evaluation.py
More file actions
576 lines (473 loc) · 19.6 KB
/
habitat_evaluation.py
File metadata and controls
576 lines (473 loc) · 19.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
"""
Habitat ObjectNav Evaluation Script for HM3D/MP3D Datasets
This script evaluates object navigation performance using the Habitat simulator
with support for HM3D-v1, HM3D-v2, and MP3D datasets. It communicates with ROS for
real-time planning and decision making, incorporates vision-language models
for object detection and image-text matching, and generates comprehensive
evaluation metrics.
Usage:
# Run with HM3D-v1 dataset
python habitat_evaluation.py --dataset hm3dv1
# Run with HM3D-v2 dataset (default)
python habitat_evaluation.py --dataset hm3dv2
# Run with MP3D dataset
python habitat_evaluation.py --dataset mp3d
# Test specific episode
python habitat_evaluation.py --dataset hm3dv2 test_epi_num=10
Author: Zager-Zhang
"""
# Standard library imports
import argparse
import gzip
import json
import os
import signal
import time
from copy import deepcopy
# Third-party library imports
from hydra import initialize, compose
import numpy as np
import rospy
from geometry_msgs.msg import PoseStamped
from omegaconf import DictConfig
from prettytable import PrettyTable
from sensor_msgs.msg import PointCloud2
from std_msgs.msg import Int32, Int32MultiArray, Float32MultiArray, Float64
import tqdm
# Habitat-related imports
import habitat
from habitat.config.default import patch_config
from habitat.config.default_structured_configs import (
CollisionsMeasurementConfig,
FogOfWarConfig,
TopDownMapMeasurementConfig,
)
from habitat.sims.habitat_simulator.actions import HabitatSimActions
from habitat.utils.visualizations.utils import (
images_to_video,
observations_to_image,
overlay_frame,
)
# ROS message imports
from plan_env.msg import MultipleMasksWithConfidence
# Local project imports
from basic_utils.failure_check.count_files import count_files_in_directory
from basic_utils.failure_check.failure_check import check_failure, is_on_same_floor
from basic_utils.object_point_cloud_utils.object_point_cloud import (
get_object_point_cloud,
)
from basic_utils.record_episode.read_record import read_record
from basic_utils.record_episode.write_record import write_record
from habitat2ros import habitat_publisher
from llm.answer_reader.answer_reader import read_answer
from params import HABITAT_STATE, ROS_STATE, ACTION, RESULT_TYPES
from vlm.Labels import MP3D_ID_TO_NAME
from vlm.utils.get_itm_message import get_itm_message_cosine
from vlm.utils.get_object_utils import get_object
def publish_int32(publisher, data):
    """Wrap ``data`` in a ``std_msgs/Int32`` message and send it via ``publisher``."""
    int_msg = Int32()
    int_msg.data = data
    publisher.publish(int_msg)
def publish_float64(publisher, data):
    """Wrap ``data`` in a ``std_msgs/Float64`` message and send it via ``publisher``."""
    float_msg = Float64()
    float_msg.data = data
    publisher.publish(float_msg)
def publish_int32_array(publisher, data_list):
    """Send ``data_list`` as the payload of a ``std_msgs/Int32MultiArray`` message."""
    array_msg = Int32MultiArray()
    array_msg.data = data_list
    publisher.publish(array_msg)
def publish_float32_array(publisher, data_list):
    """Send ``data_list`` as the payload of a ``std_msgs/Float32MultiArray`` message."""
    array_msg = Float32MultiArray()
    array_msg.data = data_list
    publisher.publish(array_msg)
def signal_handler(sig, frame):
    """Handle Ctrl+C (SIGINT): announce, shut down the ROS node, exit with status 0.

    Args:
        sig: Signal number delivered by the interpreter (unused).
        frame: Stack frame that was interrupted (unused).
    """
    # os._exit() terminates the process without flushing stdio buffers, so the
    # message must be flushed explicitly or it can be lost when stdout is
    # redirected to a pipe or log file.
    print("Ctrl+C detected! Shutting down...", flush=True)
    rospy.signal_shutdown("Manual shutdown")
    os._exit(0)
def transform_rgb_bgr(image):
    """Return ``image`` with its first three last-axis channels reordered RGB -> BGR."""
    # Source channel index for each output channel: B <- 2, G <- 1, R <- 0.
    bgr_channel_order = [2, 1, 0]
    return image[:, :, bgr_channel_order]
def publish_observations(event):
    """Timer callback: re-send the latest observations, threshold, and a trigger.

    ``event`` is the timer event object passed by ``rospy.Timer``; it is not used.
    """
    global msg_observations, fusion_threshold
    global ros_pub, trigger_pub, confidence_threshold_pub

    # Publish a deep copy of the module-level observations rather than the
    # shared object itself.
    snapshot = deepcopy(msg_observations)
    ros_pub.habitat_publish_ros_topic(snapshot)

    publish_float64(confidence_threshold_pub, fusion_threshold)

    # A default-constructed PoseStamped is published as the trigger message.
    trigger_pub.publish(PoseStamped())
def ros_action_callback(msg):
    """Subscriber callback: store the received payload in module-level ``global_action``."""
    global global_action
    global_action = msg.data
def ros_state_callback(msg):
    """Subscriber callback: store the received payload in module-level ``ros_state``."""
    global ros_state
    ros_state = msg.data
def ros_final_state_callback(msg):
    """Subscriber callback: store the received payload in module-level ``final_state``."""
    global final_state
    final_state = msg.data
def ros_expl_result_callback(msg):
    """Subscriber callback: store the received payload in module-level ``expl_result``."""
    global expl_result
    expl_result = msg.data
def _parse_dataset_arg():
"""Parse CLI to choose dataset and capture remaining Hydra overrides."""
parser = argparse.ArgumentParser(
description="Habitat ObjectNav Evaluation", add_help=True
)
parser.add_argument(
"--dataset",
type=str,
choices=["hm3dv1", "hm3dv2", "mp3d"],
default="hm3dv2",
help="Choose dataset: hm3dv1, hm3dv2 or mp3d (default: hm3dv2)",
)
# Keep unknown so users can still pass Hydra-style overrides (e.g., key=value)
args, unknown = parser.parse_known_args()
return args.dataset, unknown
def main(cfg: DictConfig) -> None:
# Purpose: run the ObjectNav evaluation. Builds a Habitat environment from
# `cfg`, resumes cumulative statistics from the "continue" record file, wires
# up the ROS publishers/subscribers, then steps each episode with the actions
# received from ROS while publishing observations and detections back, and
# finally records per-episode and cumulative metrics.
#
# Parameters:
#   cfg: composed Hydra/OmegaConf configuration. Keys read here include
#        video_output_path, need_video, record_file_name, continue_file_name,
#        test_epi_num, detector, llm.* and several habitat.* entries.
# Raises:
#   ValueError: when the resumed record shows every episode is already done.
#
# NOTE(review): this listing has lost its indentation, so the comments added
# below describe intent per statement and avoid asserting exact nesting.
global msg_observations, global_action, ros_state, fusion_threshold
global ros_pub, trigger_pub, obj_point_cloud_pub, confidence_threshold_pub
global final_state, expl_result
# Load MP3D validation data for object category mapping
# NOTE(review): this MP3D path is hard-coded and is opened regardless of the
# dataset selected on the command line — confirm that is intended for HM3D.
with gzip.open(
"data/datasets/objectnav/mp3d/v1/val/val.json.gz", "rt", encoding="utf-8"
) as f:
val_data = json.load(f)
category_to_coco = val_data.get("category_to_mp3d_category_id", {})
# Pairs each category id with MP3D_ID_TO_NAME[idx] by enumeration order of the
# JSON mapping — assumes that order lines up with MP3D_ID_TO_NAME; TODO confirm.
id_to_name = {
category_to_coco[cat]: MP3D_ID_TO_NAME[idx]
for idx, cat in enumerate(category_to_coco)
}
start_time = time.time()
final_state = 0
expl_result = 0
result_list = [0] * len(RESULT_TYPES)
cfg = patch_config(cfg)
# Extract configuration parameters
video_output_path = cfg.video_output_path.format(split=cfg.habitat.dataset.split)
need_video = cfg.need_video
record_file_path = os.path.join(video_output_path, cfg.record_file_name)
continue_path = os.path.join(video_output_path, cfg.continue_file_name)
max_episode_steps = cfg.habitat.environment.max_episode_steps
success_distance = cfg.habitat.task.measurements.success.success_distance
detector_cfg = cfg.detector
llm_cfg = cfg.llm
llm_client = llm_cfg.llm_client
llm_answer_path = llm_cfg.llm_answer_path
llm_response_path = llm_cfg.llm_response_path
# Single test parameters
env_num_once = cfg.test_epi_num # Which episode to test for single run
flag_once = env_num_once != -1 # Whether to run single test
# Create directories if they don't exist
os.makedirs(os.path.dirname(llm_answer_path), exist_ok=True)
os.makedirs(video_output_path, exist_ok=True)
# Add top_down_map and collisions visualization
with habitat.config.read_write(cfg):
cfg.habitat.task.measurements.update(
{
"top_down_map": TopDownMapMeasurementConfig(
map_padding=3,
map_resolution=256,
draw_source=True,
draw_border=True,
draw_shortest_path=True,
draw_view_points=True,
draw_goal_positions=True,
draw_goal_aabbs=False,
fog_of_war=FogOfWarConfig(
draw=True,
visibility_dist=5.0,
fov=79,
),
),
"collisions": CollisionsMeasurementConfig(),
}
)
env = habitat.Env(cfg)
print("Environment creation successful")
number_of_episodes = env.number_of_episodes
# Read previous records and set initial values
(
num_total,
num_success,
spl_all,
soft_spl_all,
distance_to_goal_all,
distance_to_goal_reward_all,
last_time,
) = read_record(continue_path, flag_once)
if num_total >= number_of_episodes:
raise ValueError("Already finished all episodes.")
pbar = tqdm.tqdm(total=env.number_of_episodes)
# Fast-forward the episode iterator: by the number of episodes already
# recorded when resuming, or by test_epi_num in single-test mode.
env_count = num_total if not flag_once else env_num_once
while env_count:
pbar.update()
env.current_episode = next(env.episode_iterator)
env_count -= 1
# Initialize ROS publishers, subscribers, and timers
obj_point_cloud_pub = rospy.Publisher(
"habitat/object_point_cloud", PointCloud2, queue_size=10
)
ros_pub = habitat_publisher.ROSPublisher()
rospy.Subscriber("/habitat/plan_action", Int32, ros_action_callback, queue_size=10)
rospy.Subscriber("/ros/state", Int32, ros_state_callback, queue_size=10)
rospy.Subscriber("/ros/expl_state", Int32, ros_final_state_callback, queue_size=10)
rospy.Subscriber("/ros/expl_result", Int32, ros_expl_result_callback, queue_size=10)
state_pub = rospy.Publisher("/habitat/state", Int32, queue_size=10)
trigger_pub = rospy.Publisher("/move_base_simple/goal", PoseStamped, queue_size=10)
itm_score_pub = rospy.Publisher("/blip2/cosine_score", Float64, queue_size=10)
confidence_threshold_pub = rospy.Publisher(
"/detector/confidence_threshold", Float64, queue_size=10
)
cld_with_score_pub = rospy.Publisher(
"/detector/clouds_with_scores", MultipleMasksWithConfidence, queue_size=10
)
progress_pub = rospy.Publisher("/habitat/progress", Int32MultiArray, queue_size=10)
record_pub = rospy.Publisher("/habitat/record", Float32MultiArray, queue_size=10)
# One iteration per episode that has not been evaluated yet.
for epi in range(number_of_episodes - num_total):
# Publish progress information
publish_int32_array(progress_pub, [num_total, number_of_episodes])
# NOTE(review): env_count was already counted down to 0 by the fast-forward
# loop above, so this second skip loop appears never to run — confirm.
if flag_once:
while env_count:
env.current_episode = next(env.episode_iterator)
env_count -= 1
# Initialize episode variables
pass_object = 0.0
near_object = 0.0
global_action = None
cld_with_score_msg = MultipleMasksWithConfidence()
count_steps = 0
camera_pitch = 0.0
observations = env.reset()
# camera_pitch is attached only for the deep copy shared with the ROS
# publisher; it is removed again from the local observations dict.
observations["camera_pitch"] = camera_pitch
msg_observations = deepcopy(observations)
del observations["camera_pitch"]
label = env.current_episode.object_category
# Convert object category to coco name format
if label in category_to_coco:
coco_id = category_to_coco[label]
label = id_to_name.get(coco_id, label)
# Get LLM answer and fusion threshold for the target object
llm_answer, room, fusion_threshold = read_answer(
llm_answer_path, llm_response_path, label, llm_client
)
# Initialize video frame collection
vis_frames = []
info = env.get_metrics()
if need_video:
frame = observations_to_image(observations, info)
# Remove the map entry before overlaying the remaining metrics on the frame.
info.pop("top_down_map")
frame = overlay_frame(frame, info)
vis_frames = [frame]
# Start publishing basic information and trigger messages
pub_timer = rospy.Timer(rospy.Duration(0.25), publish_observations)
print("Agent is waiting in the environment!!!")
# Wait for ROS system to be ready
# NOTE(review): this wait polls ros_state (updated by ros_state_callback) and
# does not check rospy.is_shutdown(); it relies on the state changing.
rate = rospy.Rate(10)
ros_state = ROS_STATE.INIT
while ros_state == ROS_STATE.INIT or ros_state == ROS_STATE.WAIT_TRIGGER:
if ros_state == ROS_STATE.INIT:
print("Waiting for ROS to get odometry...")
elif ros_state == ROS_STATE.WAIT_TRIGGER:
print("Waiting for ROS trigger...")
rate.sleep()
# Stop timer publishing when starting action execution
pub_timer.shutdown()
print("Agent is ready to go!!!!")
rate = rospy.Rate(10)
while not rospy.is_shutdown() and not env.episode_over:
# Skip episode if target is not on the same floor
# NOTE(review): the inputs to this check come only from the current episode,
# so the result looks identical on every step — candidate for hoisting.
is_feasible = 0
for goal in env.current_episode.goals:
height = goal.position[1]
is_feasible += is_on_same_floor(
height=height, episode=env.current_episode
)
if not is_feasible:
break
# Parse action from decision system
action = None
if global_action is not None:
# On the last permitted step, override whatever was requested with STOP.
if count_steps == max_episode_steps - 1:
global_action = ACTION.STOP
if global_action == ACTION.MOVE_FORWARD:
action = HabitatSimActions.move_forward
elif global_action == ACTION.TURN_LEFT:
action = HabitatSimActions.turn_left
elif global_action == ACTION.TURN_RIGHT:
action = HabitatSimActions.turn_right
elif global_action == ACTION.TURN_DOWN:
action = HabitatSimActions.look_down
# Pitch is tracked locally in pi/6 rad increments — presumably matching the
# simulator's look_up/look_down tilt angle; verify against the Habitat config.
camera_pitch = camera_pitch - np.pi / 6.0
elif global_action == ACTION.TURN_UP:
action = HabitatSimActions.look_up
camera_pitch = camera_pitch + np.pi / 6.0
elif global_action == ACTION.STOP:
action = HabitatSimActions.stop
# Consume the request so the same action is not executed twice.
global_action = None
# No usable action yet: poll again without stepping the simulator.
# NOTE(review): this path appears to skip the rate.sleep() further down,
# so the loop would spin without delay while waiting — confirm.
if action is None:
continue
count_steps += 1
print(f"\n--------------Step: {count_steps}--------------")
print(f"Finding [{label}]; Action: {action};")
# Notify ROS system that action execution is starting
publish_int32(state_pub, HABITAT_STATE.ACTION_EXEC)
observations = env.step(action)
# Calculate ITM cosine similarity score
cosine = get_itm_message_cosine(observations["rgb"], label, room)
print(f"Target related room: {room}")
print(f"ITM cosine similarity: {cosine:.3f}")
publish_float64(itm_score_pub, cosine)
# Detect objects in the current observation
# The first return value replaces observations["rgb"], so later consumers
# (ROS publish, video frame) see the image returned by get_object.
observations["rgb"], score_list, object_masks_list, label_list = get_object(
label, observations["rgb"], detector_cfg, llm_answer
)
# Publish habitat observations to ROS
observations["camera_pitch"] = camera_pitch
msg_observations = deepcopy(observations)
del observations["camera_pitch"]
ros_pub.habitat_publish_ros_topic(msg_observations)
# Generate and publish object point clouds
obj_point_cloud_list = get_object_point_cloud(
cfg, observations, object_masks_list
)
# Publish detection-related information
cld_with_score_msg.point_clouds = obj_point_cloud_list
cld_with_score_msg.confidence_scores = score_list
cld_with_score_msg.label_indices = label_list
cld_with_score_pub.publish(cld_with_score_msg)
# Generate video frame
info = env.get_metrics()
if need_video:
frame = observations_to_image(observations, info)
info.pop("top_down_map")
frame = overlay_frame(frame, info)
vis_frames.append(frame)
# Track if agent has passed close to the target
# pass_object latches to 1 once the agent has been within success_distance.
distance_to_goal = info["distance_to_goal"]
if distance_to_goal <= success_distance and pass_object == 0:
pass_object = 1
# Notify ROS system that action execution is complete
publish_int32(state_pub, HABITAT_STATE.ACTION_FINISH)
rate.sleep()
# Notify ROS system that current episode evaluation is complete
publish_int32(state_pub, HABITAT_STATE.EPISODE_FINISH)
# Collect evaluation metrics
info = env.get_metrics()
spl = info["spl"]
soft_spl = info["soft_spl"]
distance_to_goal = info["distance_to_goal"]
distance_to_goal_reward = info["distance_to_goal_reward"]
success = info["success"]
# Check if agent got close to the target object
# Unlike pass_object, near_object reflects the final distance only.
if distance_to_goal <= success_distance:
near_object = 1
# Determine episode result
if success == 1:
num_success += 1
result_text = "success"
else:
result_text = check_failure(
env.current_episode,
final_state,
expl_result,
count_steps,
max_episode_steps,
pass_object,
near_object,
)
# Update cumulative statistics
num_total += 1
spl_all += spl
soft_spl_all += soft_spl
distance_to_goal_all += distance_to_goal
distance_to_goal_reward_all += distance_to_goal_reward
# Generate video file
scene_id = env.current_episode.scene_id
episode_id = env.current_episode.episode_id
video_name = f"{os.path.basename(scene_id)}_{episode_id}"
# Elapsed time of this run plus last_time carried over from the record file.
time_spend = time.time() - start_time + last_time
# Videos are grouped into one sub-folder per result label.
img2video_output_path = os.path.join(video_output_path, result_text)
if flag_once:
img2video_output_path = "videos"
video_name = "video_once"
if need_video:
images_to_video(
vis_frames, img2video_output_path, video_name, fps=6, quality=9
)
vis_frames.clear()
# Display average performance metrics
table1 = PrettyTable(["Metric", "Average"])
table1.add_row(["Average Success", f"{num_success/num_total * 100:.2f}%"])
table1.add_row(["Average SPL", f"{spl_all/num_total * 100:.2f}%"])
table1.add_row(["Average Soft SPL", f"{soft_spl_all/num_total * 100:.2f}%"])
table1.add_row(
["Average Distance to Goal", f"{distance_to_goal_all/num_total:.4f}"]
)
print(table1)
print(f"Episode {num_total} data written to {record_file_path}")
print(f"Result: {result_text}")
# Display total performance metrics
table2 = PrettyTable(["Metric", "Total"])
table2.add_row(["Total Success", f"{num_success}"])
table2.add_row(["Total SPL", f"{spl_all:.2f}"])
table2.add_row(["Total Soft SPL", f"{soft_spl_all:.2f}"])
table2.add_row(["Total Distance to Goal", f"{distance_to_goal_all:.4f}"])
# Single-test mode stops here, before either record file is written.
# NOTE(review): the "data written to" message above is printed even so.
if flag_once:
break
# Write results to record file
write_record(
scene_id,
episode_id,
table1,
result_text,
label,
num_total,
time_spend,
record_file_path,
)
# Write results to continue file
write_record(
scene_id,
episode_id,
table2,
result_text,
label,
num_total,
time_spend,
continue_path,
)
# Count files in each result category folder
for i in range(len(RESULT_TYPES)):
folder = RESULT_TYPES[i] # Get current category (folder name)
folder_path = os.path.join(video_output_path, folder) # Build folder path
file_count = count_files_in_directory(folder_path) # Count files in folder
result_list[i] = file_count
# Publish comprehensive record data
# Layout: [success %, SPL %, soft SPL %, mean distance to goal] followed by
# the per-category file counts in RESULT_TYPES order.
record_data = [
num_success / num_total * 100,
spl_all / num_total * 100,
soft_spl_all / num_total * 100,
distance_to_goal_all / num_total,
]
record_data.extend(result_list)
publish_float32_array(record_pub, record_data)
pbar.update()
# Explicitly advance to the next episode from the iterator.
env.current_episode = next(env.episode_iterator)
rospy.sleep(0.1) # wait a moment
# NOTE(review): env and pbar are closed only on the normal path; an exception
# raised earlier skips these calls.
env.close()
pbar.close()
if __name__ == "__main__":
    # Script entry point: install the Ctrl+C handler, start the ROS node, then
    # compose the dataset-specific Hydra config and run the evaluation.
    import traceback

    signal.signal(signal.SIGINT, signal_handler)
    rospy.init_node("habitat_eval_node", anonymous=True)
    try:
        dataset, overrides = _parse_dataset_arg()
        cfg_name = f"habitat_eval_{dataset}"
        # Compose the chosen config and pass through extra Hydra overrides
        with initialize(version_base=None, config_path="config"):
            cfg = compose(config_name=cfg_name, overrides=overrides)
            main(cfg)
    except Exception as e:
        # Top-level boundary: report the failure with its full traceback.
        # os._exit() skips stdio flushing, so the message is flushed explicitly;
        # otherwise it can be lost when stdout is redirected.
        print(f"Unexpected error occurred: {e}", flush=True)
        traceback.print_exc()
        rospy.signal_shutdown("Shutdown due to error")
        os._exit(1)