-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathdatasetWrapper.py
More file actions
128 lines (100 loc) · 4.99 KB
/
datasetWrapper.py
File metadata and controls
128 lines (100 loc) · 4.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import os
from typing import Optional

from datasets import load_dataset, Dataset

from src.filters import PushFrontTextFilter
from src.log import Log
from src.text2image import AbstractText2Image, FixedSizeText2Image
from src.text2image import FilteredFixedSizeText2Image
# Directory holding on-disk dataset caches; a relative path, so it is resolved
# against the current working directory.
CACHE_NAME: str = "cache"
class AbstractDatasetWrapper:
    """
    Common base for the dataset wrappers in this module.

    It only initialises the shared attributes: subclasses are expected to
    set ``dataset_id`` and to populate ``dataset`` with a ``datasets.Dataset``.
    Nothing in this base class enforces that contract.
    """

    def __init__(self, text2image: AbstractText2Image):
        # Text-to-image converter made available to subclasses.
        self._text2image = text2image
        # Short identifier for the dataset; left unset until a subclass assigns it.
        self.dataset_id = None
        # Placeholder: subclasses replace this with the loaded dataset.
        self.dataset: Dataset = None
class GSM8KWrapper(AbstractDatasetWrapper):
    """
    Wrapper around the GSM8K ("main" config) test split, with rendered question images.

    On construction, the dataset is loaded from an on-disk cache under
    ``CACHE_NAME`` when one exists. Otherwise it is downloaded, post-processed
    sample-by-sample (see ``_map_sample``) and saved to that cache.

    NOTE(review): the cache is keyed only by ``cache_filename``. A cached dataset
    built with a different ``text2image`` configuration is reused as-is, so pass
    a distinct ``cache_filename`` per renderer configuration.
    """

    def __init__(
        self,
        text2image: Optional[AbstractText2Image] = None,
        cache_filename: str = "",
    ):
        """
        Load (or build and cache) the GSM8K test split.

        Args:
            text2image: Converter used to render each question as an image. When
                omitted, a fresh ``FixedSizeText2Image()`` is created for this
                wrapper. A default instance must not be shared: its font size is
                mutated via ``set_font_size``.
            cache_filename: Name of the cache entry inside ``CACHE_NAME``.
                Defaults to ``dataset_<dataset_id>``.

        Raises:
            Errors raised while downloading, mapping or caching the dataset are
            logged and re-raised unchanged.
        """
        if text2image is None:
            text2image = FixedSizeText2Image()
        super().__init__(text2image)
        self.dataset_id = "GSM8k"
        Log().logger.info(f"Loading {self.dataset_id} dataset...")

        # exist_ok avoids the race between an existence check and directory creation.
        os.makedirs(CACHE_NAME, exist_ok=True)
        cache_path = os.path.join(CACHE_NAME, cache_filename or f"dataset_{self.dataset_id}")

        if os.path.exists(cache_path):
            Log().logger.info(f"Found cached dataset at {cache_path}. Loading from cache...")
            # NOTE: the font size is not recomputed on a cache hit; the cached
            # images were rendered when the cache was built.
            self.dataset = Dataset.load_from_disk(cache_path)
        else:
            self.dataset = self._load_raw_dataset()
            self.dataset = self._prepare_dataset(self.dataset)
            self._save_to_cache(cache_path)

        Log().logger.info(f"Loaded {self.dataset_id} dataset with {len(self.dataset)} samples.")

    def _load_raw_dataset(self) -> Dataset:
        """Return the GSM8K "main" test split, logging and re-raising any load failure."""
        try:
            return load_dataset("gsm8k", "main")["test"]
        except Exception as e:
            Log().logger.error(f"Error loading dataset: {e}")
            raise

    def _prepare_dataset(self, dataset: Dataset) -> Dataset:
        """
        Size the renderer's font for the longest question, then map every sample.

        Returns the mapped dataset. Any failure is logged and re-raised.
        """
        try:
            # Same element as max(rows, key=len(question)), but reads just one column.
            longest_question = max(dataset["question"], key=len)
            # When the renderer uses a PushFrontTextFilter, size the font for the
            # filtered text. Presumably the filter adds text before rendering; confirm.
            if isinstance(self._text2image, FilteredFixedSizeText2Image) and isinstance(
                self._text2image.filter, PushFrontTextFilter
            ):
                longest_question = self._text2image.filter.apply_filter(longest_question)
            self._text2image.set_font_size(longest_question)
            Log().logger.info(f"Font size for longest question: {self._text2image.font_size}")
            return dataset.map(self._map_sample)
        except Exception as e:
            Log().logger.error(f"Error mapping dataset: {e}")
            raise

    def _save_to_cache(self, cache_path: str) -> None:
        """Persist ``self.dataset`` to ``cache_path``, logging and re-raising any failure."""
        try:
            self.dataset.save_to_disk(cache_path)
        except Exception as e:
            Log().logger.error(f"Error caching dataset: {e}")
            raise
        Log().logger.info(f"Cached dataset at {cache_path}")

    def _map_sample(self, sample):
        """
        Post-process one dataset row for ``Dataset.map``.

        Replaces ``answer`` with the stripped text after its last ``####``
        marker, and adds a ``question_image`` rendered from ``question``.
        Image-creation failures are logged and re-raised.
        """
        sample["answer"] = sample["answer"].split("####")[-1].strip()
        try:
            sample["question_image"] = self._text2image.create_image(sample["question"])
        except Exception as e:
            Log().logger.error(f"Error creating image: {e}")
            raise
        return sample
# class GSM8kWrapper_GSM8k_5_samples(AbstractDatasetWrapper):
# """
# A slimmed-down GSM8k wrapper that loads only 5 examples.
# """
# def __init__(self, text2image: AbstractText2Image = FixedSizeText2Image()):
# super().__init__(text2image)
# self.dataset_id = "GSM8k_5_samples"
# self._text2image = text2image
# Log().logger.info(f"Loading 5 examples from {self.dataset_id}...")
# try:
# # load only the first 5 test samples
# raw = load_dataset("gsm8k", "main", split="test[:5]")
# except Exception as e:
# Log().logger.error(f"Error loading 5-sample GSM8k: {e}")
# raise
# # find the question with the maximum length among these 5
# questions = [ex["question"] for ex in raw]
# longest_question = max(questions, key=len)
# Log().logger.info(
# f"Longest question (of 5): '{longest_question}' (len={len(longest_question)})"
# )
# self._text2image.set_font_size(longest_question)
# Log().logger.info(f"Font size for longest question: {self._text2image.font_size}")
# # apply mapping (answer cleanup + image creation)
# try:
# self.dataset = raw.map(self._map_sample,load_from_cache_file=False,keep_in_memory=True)
# except Exception as e:
# Log().logger.error(f"Error mapping 5-sample GSM8k: {e}")
# raise
# Log().logger.info(f"Loaded {len(self.dataset)} samples from {self.dataset_id}.")
# def _map_sample(self, sample):
# sample["answer"] = sample["answer"].split("####")[-1].strip()
# try:
# sample["question_image"] = self._text2image.create_image(sample["question"])
# except Exception as e:
# Log().logger.error(f"Error creating image: {e}")
# raise e
# return sample