-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdata_loader.py
More file actions
96 lines (78 loc) · 3.42 KB
/
data_loader.py
File metadata and controls
96 lines (78 loc) · 3.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import os
import random
import h5py
import torch
from torch.utils.data import Dataset
import numpy as np
# --- Data Augmentation and Normalization Functions ---
def rotate_point_cloud_z(pc):
    """Rotate a point cloud about the z axis by a random angle.

    The angle is drawn uniformly from [0, 2*pi) using NumPy's global random
    state, so seeding ``np.random`` makes the result reproducible.

    Args:
        pc: Array of points whose last axis has length 3 (x, y, z); it is
            right-multiplied by a 3x3 rotation matrix.

    Returns:
        A new array with the same shape as ``pc``. The z coordinate of every
        point is unchanged.
    """
    theta = np.random.uniform() * 2 * np.pi
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    # Rotation acts only on the x/y plane; the last row/column leave z alone.
    z_rotation = np.array([
        [cos_t, -sin_t, 0],
        [sin_t, cos_t, 0],
        [0, 0, 1],
    ])
    return np.dot(pc, z_rotation)
def jitter_point_cloud(pc, sigma=0.01, clip=0.05):
    """Add clipped Gaussian noise to every coordinate of every point.

    Noise is drawn independently per element from N(0, sigma^2) using NumPy's
    global random state, then limited to the range [-clip, clip].

    Args:
        pc: 2-D array of points (rows are points, columns are channels).
        sigma: Standard deviation of the noise before clipping.
        clip: Maximum absolute value of the noise; must be positive.

    Returns:
        A new array equal to ``pc`` plus the clipped noise.

    Raises:
        AssertionError: If ``clip`` is not greater than zero.
    """
    num_points, num_channels = pc.shape
    assert clip > 0
    noise = sigma * np.random.randn(num_points, num_channels)
    bounded_noise = np.clip(noise, -clip, clip)
    return pc + bounded_noise
def random_scale_point_cloud(pc, scale_low=0.8, scale_high=1.25):
    """Scale a whole point cloud by one random factor.

    A single factor is drawn uniformly between ``scale_low`` and
    ``scale_high`` (NumPy global random state) and applied to every
    coordinate, so the shape is resized uniformly rather than per point.

    Args:
        pc: Array of points.
        scale_low: Lower bound of the scale factor.
        scale_high: Upper bound of the scale factor.

    Returns:
        A new array equal to ``pc`` multiplied by the sampled factor.
    """
    factor = np.random.uniform(scale_low, scale_high)
    scaled = pc * factor
    return scaled
def pc_normalize(pc):
    """Center a point cloud on its centroid and scale it into the unit sphere.

    Args:
        pc: 2-D array of points (rows are points).

    Returns:
        A new array in which the centroid is at the origin and the point
        farthest from it lies at (approximately) distance 1. The input array
        is not modified.
    """
    centered = pc - np.mean(pc, axis=0)
    distances = np.sqrt(np.sum(centered ** 2, axis=1))
    farthest = np.max(distances)
    # The small epsilon avoids division by zero when all points coincide.
    return centered / (farthest + 1e-9)
class ShapeNetPartDataset(Dataset):
    """
    Dataset for the HDF5 version of ShapeNetPart.

    Every ``.h5`` file in ``file_path`` whose name contains ``split`` is read
    fully into memory at construction time. Each file must provide the
    datasets ``'data'`` (points), ``'label'`` (per-shape class label) and
    either ``'seg'`` or ``'pid'`` (per-point part labels).

    Args:
        file_path: Directory that holds the ``.h5`` files.
        split: Substring used to select files, e.g. ``'train'``.
        num_points: Maximum number of points returned per shape; the first
            ``num_points`` stored points are used.
        augment: When true, a random z rotation, jitter and scale are applied
            to each sample after normalization.

    Raises:
        FileNotFoundError: If no matching ``.h5`` file is found.
    """

    def __init__(self, file_path, split='train', num_points=2048, augment=False):
        self.root = file_path
        self.npoints = num_points
        self.split = split
        self.augment = augment

        # NOTE(review): this is a substring match, so split='train' would also
        # pick up a file named e.g. '*trainval*.h5' -- confirm the naming
        # scheme of the dataset directory.
        h5_files = sorted(
            name for name in os.listdir(self.root)
            if name.endswith('.h5') and self.split in name
        )
        if not h5_files:
            raise FileNotFoundError(f"No H5 files found for split '{self.split}' in '{self.root}'")
        print(f"Loading H5 files for '{self.split}' split: {h5_files}")

        point_chunks = []
        seg_chunks = []
        cls_chunks = []
        for h5_filename in h5_files:
            with h5py.File(os.path.join(self.root, h5_filename), 'r') as f:
                point_chunks.append(f['data'][:])
                # This Kaggle release stores part labels under 'seg'; other
                # releases of the same data use 'pid'. Accept either.
                seg_key = 'seg' if 'seg' in f else 'pid'
                seg_chunks.append(f[seg_key][:])
                cls_chunks.append(f['label'][:])

        self.all_points = np.concatenate(point_chunks, axis=0)
        self.all_seg_labels = np.concatenate(seg_chunks, axis=0)
        self.all_cls_labels = self._flatten_cls_labels(np.concatenate(cls_chunks, axis=0))
        print(f'The size of {self.split} data is {len(self.all_points)}')

    @staticmethod
    def _flatten_cls_labels(labels):
        """Drop singleton axes from ``labels`` but always keep the sample axis.

        A plain ``squeeze()`` turns a single-sample ``(1, 1)`` array into a
        0-d array, which then cannot be indexed by sample in ``__getitem__``.
        """
        labels = np.asarray(labels)
        kept_dims = [dim for dim in labels.shape[1:] if dim != 1]
        return labels.reshape(labels.shape[0], *kept_dims)

    def __len__(self):
        """Return the number of shapes loaded for this split."""
        return len(self.all_points)

    def __getitem__(self, index):
        """Return ``(points, class_label, seg_labels)`` tensors for one shape.

        ``points`` is a float tensor, ``class_label`` and ``seg_labels`` are
        long tensors. At most ``num_points`` points (and their part labels)
        are returned.
        """
        points = self.all_points[index][:self.npoints].copy()
        seg_labels = self.all_seg_labels[index][:self.npoints].copy()
        cls_label = self.all_cls_labels[index].copy()

        # Normalize BEFORE augmenting: normalizing afterwards rescales the
        # cloud back to the unit sphere and silently cancels the random scale.
        points = pc_normalize(points)
        if self.augment:
            points = rotate_point_cloud_z(points)
            points = jitter_point_cloud(points)
            points = random_scale_point_cloud(points)

        return (
            torch.from_numpy(points).float(),
            torch.tensor(cls_label, dtype=torch.long),
            torch.from_numpy(seg_labels).long(),
        )