diff --git a/deeptuner/datagenerators/triplet_data_generator.py b/deeptuner/datagenerators/triplet_data_generator.py index 18523e7..665ac71 100644 --- a/deeptuner/datagenerators/triplet_data_generator.py +++ b/deeptuner/datagenerators/triplet_data_generator.py @@ -13,6 +13,15 @@ def __init__(self, image_paths, labels, batch_size, image_size, num_classes): self.num_classes = num_classes self.label_encoder = LabelEncoder() self.encoded_labels = self.label_encoder.fit_transform(labels) + + # Pre-compute label to paths mapping for O(1) sampling + self.label_to_paths = {} + for path, label in zip(self.image_paths, self.encoded_labels): + if label not in self.label_to_paths: + self.label_to_paths[label] = [] + self.label_to_paths[label].append(path) + self.unique_labels = np.array(list(self.label_to_paths.keys())) + self.image_data_generator = ImageDataGenerator(preprocessing_function=resnet.preprocess_input) self.on_epoch_end() print(f"Initialized TripletDataGenerator with {len(self.image_paths)} images") @@ -40,12 +49,13 @@ def _generate_triplet_batch(self, batch_image_paths, batch_labels): anchor_path = batch_image_paths[i] anchor_label = batch_labels[i] - positive_path = np.random.choice( - [p for p, l in zip(self.image_paths, self.encoded_labels) if l == anchor_label] - ) - negative_path = np.random.choice( - [p for p, l in zip(self.image_paths, self.encoded_labels) if l != anchor_label] - ) + positive_path = np.random.choice(self.label_to_paths[anchor_label]) + + while True: + neg_label = np.random.choice(self.unique_labels) + if neg_label != anchor_label: + break + negative_path = np.random.choice(self.label_to_paths[neg_label]) anchor_image = load_img(anchor_path, target_size=self.image_size) positive_image = load_img(positive_path, target_size=self.image_size)