-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path visualize.py
More file actions
83 lines (68 loc) · 3.09 KB
/
visualize.py
File metadata and controls
83 lines (68 loc) · 3.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import torch
from torch.utils.data import DataLoader
import argparse
import os
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from model import PointNetSeg
from data_loader import ShapeNetPartDataset
def _ensure_parent_dir(path):
    """Create the directory that will hold *path*, if *path* names one.

    ``os.path.dirname('plot.png')`` is ``''`` and ``os.makedirs('')`` raises
    ``FileNotFoundError``, so a bare filename must skip directory creation.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def _label_colors(labels, palette):
    """Map integer part labels to RGBA rows by cycling through *palette*'s entries.

    Integer inputs to a matplotlib colormap index its lookup table directly,
    so label ``k`` gets entry ``k % palette.N``. Neighbouring labels therefore
    receive different colors, and the mapping does not depend on which labels
    happen to be present in a given array.
    """
    indices = np.asarray(labels).astype(np.int64).reshape(-1) % palette.N
    return palette(indices)


def _plot_segmentation(ax, xyz, colors, title):
    """Scatter a colored point cloud on a 3D axis with labelled axes and a fixed camera."""
    ax.scatter(xyz[:, 0], xyz[:, 1], xyz[:, 2], c=colors, s=10)
    ax.set_title(title)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.view_init(elev=20, azim=-45)


def main(args):
    """Render ground-truth vs. predicted part segmentation for one test sample.

    Args:
        args: Parsed CLI namespace. Reads ``data_path``, ``checkpoint_path``,
            ``index``, ``output_file`` and ``no_cuda``.

    Side effects:
        Creates the parent directory of ``args.output_file`` if needed, saves
        the side-by-side figure there, prints a status line and calls
        ``plt.show()``. If ``args.index`` is past the end of the test split,
        prints an error and returns without plotting.
    """
    device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")

    # Create visualization directory (skipped when output_file has no directory part).
    _ensure_parent_dir(args.output_file)

    # --- Data Loading ---
    test_dataset = ShapeNetPartDataset(file_path=args.data_path, split='test')
    if args.index >= len(test_dataset):
        print(f"Error: Index {args.index} is out of bounds for the test set of size {len(test_dataset)}.")
        return
    # The middle element is unused here (presumably the object category -- confirm).
    points, _, seg_labels = test_dataset[args.index]

    # --- Model ---
    num_part_classes = 50
    model = PointNetSeg(num_part_classes=num_part_classes).to(device)
    # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
    model.load_state_dict(torch.load(args.checkpoint_path, map_location=device))
    model.eval()

    # --- Prediction ---
    with torch.no_grad():
        # Add a batch dimension of 1; the model's second output is not needed here.
        preds, _ = model(points.unsqueeze(0).to(device))
    # Flatten to one row of class scores per point and pick the best class.
    # Assumes preds is laid out point-major, i.e. (..., N, num_part_classes) -- confirm.
    pred_choice = preds.reshape(-1, num_part_classes).argmax(dim=1)

    # --- Visualization ---
    points_np = points.cpu().numpy().squeeze()
    gt_np = seg_labels.cpu().numpy().reshape(-1)
    pred_np = pred_choice.cpu().numpy()

    # Both plots share one explicit label->color table. Passing raw labels with a
    # cmap lets each scatter normalize to its own min/max, which can paint the
    # same part in different colors in the two panels.
    palette = plt.get_cmap('tab20')
    gt_colors = _label_colors(gt_np, palette)
    pred_colors = _label_colors(pred_np, palette)

    fig = plt.figure(figsize=(12, 6))
    _plot_segmentation(fig.add_subplot(121, projection='3d'), points_np, gt_colors,
                       'Ground Truth Segmentation')
    _plot_segmentation(fig.add_subplot(122, projection='3d'), points_np, pred_colors,
                       'Predicted Segmentation')
    plt.tight_layout()
    plt.savefig(args.output_file)
    print(f"Visualization saved to {args.output_file}")
    plt.show()
if __name__ == '__main__':
    # Command-line entry point: choose the data, the weights and which test sample to draw.
    cli = argparse.ArgumentParser(description='PointNet Part Segmentation Visualization')

    # Mandatory inputs.
    cli.add_argument('--data_path', type=str, required=True,
                     help='Path to the HDF5 dataset file')
    cli.add_argument('--checkpoint_path', type=str, required=True,
                     help='Path to the trained model checkpoint')

    # Optional settings.
    cli.add_argument('--index', type=int, default=0,
                     help='Index of the test sample to visualize')
    cli.add_argument('--output_file', type=str,
                     default='./results/visualizations/segmentation_comparison.png',
                     help='Path to save the output visualization image')
    cli.add_argument('--no_cuda', action='store_true',
                     help='disables CUDA')

    args = cli.parse_args()
    main(args)