-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdata_loader.py
More file actions
99 lines (83 loc) · 3.1 KB
/
data_loader.py
File metadata and controls
99 lines (83 loc) · 3.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# Create a class that holds the data as torch.Tensor objects and has two methods:
# one `train` with the training data and one `test` with the test data.
import pandas as pd
import torch
import torchvision.transforms as transforms
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
def balance_dataframe(df, label_col="label"):
    """Oversample every class up to the size of the largest class.

    Each class found in ``label_col`` is sampled with replacement until it
    has as many rows as the most frequent class, and the combined result is
    shuffled. Sampling and shuffling both use the fixed seed 42, so the
    output is deterministic for a given input.

    Args:
        df: DataFrame that contains a ``label_col`` column.
        label_col: Name of the column holding the class labels.

    Returns:
        A new, shuffled DataFrame with a fresh 0..n-1 index in which every
        class appears the same number of times. Rows whose label is missing
        are not included (``groupby`` skips NaN keys by default). If there
        are no labelled rows, an empty DataFrame with the same columns is
        returned.

    Raises:
        KeyError: If ``label_col`` is not a column of ``df``.
    """
    counts = df[label_col].value_counts()
    if counts.empty:
        # Nothing to balance: no rows, or no row carries a label.
        return df.iloc[0:0].reset_index(drop=True)

    target = int(counts.max())
    # Sample each class explicitly instead of groupby().apply(): the applied
    # frame keeps the label column on every pandas version this way.
    oversampled = [
        group.sample(n=target, replace=True, random_state=42)
        for _, group in df.groupby(label_col)
    ]
    return (
        pd.concat(oversampled)
        .sample(frac=1, random_state=42)  # shuffle
        .reset_index(drop=True)
    )
def split_dataframe(df, test_size=0.2, random_state=42, label_col="label"):
    """Split ``df`` into train and test DataFrames, stratified by label.

    Args:
        df: DataFrame that contains a ``label_col`` column.
        test_size: Passed to ``train_test_split`` as ``test_size``.
        random_state: Seed passed to ``train_test_split``.
        label_col: Name of the column used for stratification.

    Returns:
        A ``(train_df, test_df)`` tuple; both frames get a fresh 0..n-1
        index.

    Raises:
        KeyError: If ``label_col`` is not a column of ``df``.
    """
    train_df, test_df = train_test_split(
        df,
        test_size=test_size,
        stratify=df[label_col],
        random_state=random_state,
    )
    return train_df.reset_index(drop=True), test_df.reset_index(drop=True)
class CustomData(Dataset):
    """Image dataset backed by a DataFrame of image ids and labels.

    The DataFrame is read by position: column 0 is the image id and column 1
    is the label. Images are loaded from ``{img_dir}/{img_id}.png`` and
    converted to RGB.

    When ``augment`` is true the dataset reports twice as many items: even
    indices yield the original image and odd indices yield the same image
    mirrored left-to-right.
    """

    def __init__(
        self,
        dataframe: pd.DataFrame,
        img_dir: str = "data/img",
        transform=None,
        augment: bool = False,
    ):
        """Store the label table and loading options.

        Args:
            dataframe: Table with the image id in column 0 and the label in
                column 1.
            img_dir: Directory containing the ``.png`` files.
            transform: Callable applied to each loaded PIL image. Defaults
                to ``transforms.ToTensor()`` when ``None``.
            augment: If true, also serve a horizontally flipped copy of
                every image.
        """
        self.labels_df = dataframe
        self.img_dir = img_dir
        self.transform = transform if transform is not None else transforms.ToTensor()
        self.augment = augment

    def __len__(self):
        """Return the number of items, doubled when augmentation is on."""
        return len(self.labels_df) * (2 if self.augment else 1)

    def __getitem__(self, idx):
        """Load one ``(image, label)`` pair.

        Args:
            idx: Item index. With augmentation on, ``idx // 2`` selects the
                row and an odd ``idx`` selects the flipped variant.

        Returns:
            A tuple of the transformed image and the label as a
            ``torch.long`` tensor.
        """
        if self.augment:
            real_idx, flip = idx // 2, idx % 2 == 1
        else:
            real_idx, flip = idx, False

        img_id = self.labels_df.iloc[real_idx, 0]
        label = self.labels_df.iloc[real_idx, 1]
        img_path = f"{self.img_dir}/{img_id}.png"

        # convert() returns a fully loaded copy, so the underlying file can
        # be closed immediately instead of lingering until garbage collection.
        with Image.open(img_path) as source:
            image = source.convert("RGB")

        if flip:
            image = image.transpose(Image.Transpose.FLIP_LEFT_RIGHT)
        if self.transform is not None:
            image = self.transform(image)
        return image, torch.tensor(label, dtype=torch.long)
def get_train_test_datasets(
    csv_file="data/labels.csv", img_dir="data/img", transform=None, test_size=0.2
):
    """Build class-balanced train and test datasets from a label CSV.

    The CSV is split (stratified by label) first, and each split is then
    balanced by oversampling on its own. Oversampling before the split would
    place copies of the same row in both train and test, leaking training
    samples into evaluation.

    Args:
        csv_file: Path of the CSV read with ``pd.read_csv``.
        img_dir: Directory containing the image files.
        transform: Transform forwarded to both datasets.
        test_size: Forwarded to ``split_dataframe``.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple of ``CustomData`` objects.
    """
    df = pd.read_csv(csv_file)

    # Split the raw rows first so no source row can appear on both sides.
    train_df, test_df = split_dataframe(df, test_size=test_size)
    train_df = balance_dataframe(train_df)
    test_df = balance_dataframe(test_df)

    train_dataset = CustomData(
        train_df, img_dir=img_dir, transform=transform, augment=True
    )
    # NOTE(review): the test split is also built with augment=True, so
    # evaluation includes mirrored copies — confirm this is intended.
    test_dataset = CustomData(
        test_df, img_dir=img_dir, transform=transform, augment=True
    )
    return train_dataset, test_dataset
"""
Example
-------
>>> # Create datasets & loader
>>> train_dataset, test_dataset = get_train_test_datasets(
...     csv_file="data/labels.csv",
...     img_dir="data/img",
...     transform=transform,
... )
>>> dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
>>> # Iterate through batches
>>> for images, labels in dataloader:
...     print(images.shape, labels)
...     break
"""