@@ -5,19 +5,13 @@
 
 from click import secho
 import random
-
+
 from datasets import Dataset, load_dataset, Features, Value
+from adapters.miners import HardNegativeMiner
 
 RANDOM_STATE = 42
 random.seed(RANDOM_STATE)
 
-DATASET_CHUNK_SIZES = {  # Define specific chunk sizes here
-    "wayfair": 100,
-    "amazon": 5000,
-    # Add other datasets if needed
-}
-DEFAULT_CHUNK_SIZE = 1000
-
 class BaseDataset(ABC):
     def __init__(
         self,
@@ -32,8 +26,7 @@ def __init__(
         self._chunk_size = chunk_size
         self._num_procs = cpu_count() - 1
         self._split = split
-        self._data = self.load(split, cols)
-        secho(f"Total records loaded: {len(self._data)}", fg="green")
+        self._data = None
 
     @property
     def repo_id(self):
@@ -51,35 +44,49 @@ def n_queries(self):
     def n_documents(self):
         return self._n_documents
 
-    def generate_query(self, queries_already_sampled=False):
-        secho(f"Generating queries for {self.name} dataset...", fg="blue")
+    def generate_query(self):
+        secho(f"Generating queries for {self.name} dataset", fg="blue")
+        secho(f"Initial dataset size: {len(self._data)}", fg="blue")
 
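+        # Lowercase queries first so that dedup and sampling treat case variants as the same query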
+        self._data = self._data.map(
+            lambda x: {"query": x["query"].lower()},
+            num_proc=self._num_procs,
+        )
         self._unique_queries = list(set(self._data.unique("query")))
         self._n_queries = len(self._unique_queries)
-
-        if not queries_already_sampled and self._sample_size is not None and self._sample_size < self._n_queries:
-            secho(f"Applying sampling in BaseDataset.generate_query: {self._sample_size} queries", fg="yellow")
-            sampled_queries = random.sample(self._unique_queries, self._sample_size)
-            self._unique_queries = sampled_queries
-            self._n_queries = len(self._unique_queries)
-            self._data = self._data.filter(
-                lambda x: x["query"] in self._unique_queries,
-                num_proc=self._num_procs
-            )
-
+        secho(f"Unique queries before sampling: {self._n_queries}", fg="blue")
+
+        if self._sample_size is not None and self._sample_size < self._n_queries:
+            secho(f"Sampling {self._sample_size} queries from {self._n_queries} total queries", fg="green")
+            sampled_queries = random.sample(self._unique_queries, self._sample_size)
+            self._unique_queries = sampled_queries
+            self._n_queries = len(self._unique_queries)
+
+            self._data = self._data.filter(
+                lambda x: x["query"] in self._unique_queries,
+                num_proc=self._num_procs
+            )
+            secho(f"Filtered dataset to {len(self._data)} records with sampled queries", fg="green")
+
+        # Create chunks for the queries
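+        # Note: "wayfair" and "amazon" override self._chunk_size below; other datasets keep the configured value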
         chunks = {}
-        effective_chunk_size = DATASET_CHUNK_SIZES.get(getattr(self, 'name', None), self._chunk_size or DEFAULT_CHUNK_SIZE)
-
-        if self._n_queries > 0 and effective_chunk_size > 0:
-            for i in range(0, self._n_queries, effective_chunk_size):
-                chunk_index = i // effective_chunk_size
-                chunks[chunk_index] = self._unique_queries[i:i + effective_chunk_size]
+        if self._chunk_size is not None:
+            if self.name == "wayfair":
+                self._chunk_size = 100
+            elif self.name == "amazon":
+                self._chunk_size = 5000
+
+            for i in range(0, self._n_queries, self._chunk_size):
+                chunk_index = i // self._chunk_size
+                chunks[chunk_index] = self._unique_queries[i:i + self._chunk_size]
+                secho(f"Chunk {chunk_index}: {len(chunks[chunk_index])} queries", fg="blue")
         else:
             chunks = {0: self._unique_queries}
+            secho(f"Single chunk with {len(chunks[0])} queries", fg="blue")
 
-        self._max_chunks = len(chunks)
+        self._max_chunks = len(chunks.keys())
         self._query_chunks = chunks
-        secho(f"Total query chunks created: {self._max_chunks}", fg="blue")
+        secho(f"Total chunks: {self._max_chunks}", fg="blue")
 
     def generate_document(self):
         pass
@@ -123,137 +130,80 @@ def generate_pairs(self):
         pairs = pairs.add_column("source", source)
         secho(f"Generated {len(pairs)} pairs.", fg="green")
         secho(f"Queries: {self.n_queries}, Documents: {self.n_documents}.", fg="green")
+        # secho(f"Pairs sample: {pairs[0]}", fg=(229, 192, 123))
         return pairs
 
     def generate_triplets(self, threshold=3.0, chunk_index: int = None):
+        secho(f"Generating triplets for {self.name} dataset with threshold {threshold}", fg="blue")
+        positives = self.generate_positives(threshold=threshold).to_pandas()
+        secho(f"Generated {len(positives)} positives for {self.name}", fg="blue")
+
+        negatives = self.generate_negatives(threshold=threshold).to_pandas()
+        secho(f"Generated {len(negatives)} negatives for {self.name}", fg="blue")
+
         if chunk_index is not None:
-            secho(f"Generating triplets for {self.name} chunk {chunk_index}...", fg="blue")
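+            # Keep only rows whose anchor (query) belongs to the requested chunk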
+            chunk_queries = self._query_chunks.get(chunk_index, [])
+            secho(f"Filtering for chunk {chunk_index} with {len(chunk_queries)} queries", fg="blue")
+            positives = positives[positives["anchor"].isin(chunk_queries)]
+            negatives = negatives[negatives["anchor"].isin(chunk_queries)]
+            secho(f"After filtering: {len(positives)} positives, {len(negatives)} negatives", fg="blue")
 
-        chunk_data = self._data
+        if len(positives) == 0 or len(negatives) == 0:
+            secho(f"Not enough data to generate triplets: {len(positives)} positives, {len(negatives)} negatives", fg="red")
+            return Dataset.from_dict({
+                "anchor": [],
+                "positive": [],
+                "negative": [],
+                "margin": [],
+                "source": [],
+                "metadata": []
+            }, features=Features({
+                "anchor": Value("string"),
+                "positive": Value("string"),
+                "negative": Value("string"),
+                "margin": Value("float64"),
+                "source": Value("string"),
+                "metadata": Value("string")
+            }))
 
-        if chunk_index is not None and self._query_chunks and chunk_index in self._query_chunks:
-            chunk_queries = set(self._query_chunks[chunk_index])
-            if not chunk_queries:
-                return self._create_empty_triplet_dataset()
-
-            chunk_data = self._data.filter(
-                lambda x: x["query"] in chunk_queries,
-                num_proc=self._num_procs
-            )
-        elif chunk_index is not None:
-            secho(f"Warning: Chunk index {chunk_index} not found or query chunks empty.", fg="yellow")
-            return self._create_empty_triplet_dataset()
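+        # Inner-join positives and negatives on "anchor": every positive/negative combination
+        # per query becomes a candidate triplet; shared columns get _positive/_negative suffixes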
+        triplets = positives.merge(negatives, on="anchor", suffixes=("_positive", "_negative"))
+        secho(f"Merged into {len(triplets)} triplets for {self.name}", fg="blue")
 
-        positives_ds = self.generate_positives(threshold=threshold, data_subset=chunk_data)
-        negatives_ds = self.generate_negatives(threshold=threshold, data_subset=chunk_data)
-
-        if len(positives_ds) == 0 or len(negatives_ds) == 0:
-            return self._create_empty_triplet_dataset()
-
-        try:
-            positives = positives_ds.to_pandas()
-            negatives = negatives_ds.to_pandas()
-        except Exception as e:
-            secho(f"Error converting dataset subset to pandas (chunk {chunk_index}): {e}", fg="red")
-            return self._create_empty_triplet_dataset()
-
-        positives = positives.rename(columns={"positive": "document", "relevance": "relevance_positive"})
-        negatives = negatives.rename(columns={"negative": "document", "relevance": "relevance_negative"})
-
-        if "anchor" not in positives.columns or "anchor" not in negatives.columns:
-            secho("Error: 'anchor' column missing before merge.", fg="red")
-            return self._create_empty_triplet_dataset()
-        if "relevance_positive" not in positives.columns:
-            positives['relevance_positive'] = threshold
-            secho("Warning: 'relevance' column missing in positives, added default.", fg="yellow")
-        if "relevance_negative" not in negatives.columns:
-            negatives['relevance_negative'] = threshold - 0.1
-            secho("Warning: 'relevance' column missing in negatives, added default.", fg="yellow")
-
-        try:
-            triplets = positives.merge(negatives, on="anchor", suffixes=("_pos", "_neg"))
-        except Exception as e:
-            secho(f"Error merging pandas DataFrames (chunk {chunk_index}): {e}", fg="red")
-            return self._create_empty_triplet_dataset()
-
-        if triplets.empty:
-            return self._create_empty_triplet_dataset()
-
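+        # The margin is the relevance gap between the positive and the negative document for the same anchor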
         triplets["margin"] = round(triplets["relevance_positive"] - triplets["relevance_negative"], 2)
         triplets["source"] = self.name
-        triplets = triplets.rename(columns={"document_pos": "positive", "document_neg": "negative"})
-
-        metadata_cols = [col for col in ['relevance_positive', 'relevance_negative'] if col in triplets.columns]
-        if metadata_cols:
-            try:
-                triplets["metadata"] = triplets[metadata_cols].apply(lambda x: json.dumps(x.to_dict()), axis=1)
-                triplets = triplets.drop(columns=metadata_cols)
-            except Exception as e:
-                secho(f"Error creating metadata JSON (chunk {chunk_index}): {e}", fg="yellow")
-                triplets["metadata"] = "{}"
-        else:
-            triplets["metadata"] = "{}"
-
-        final_cols = ["anchor", "positive", "negative", "margin", "source", "metadata"]
-        missing_cols = [col for col in final_cols if col not in triplets.columns]
-        if missing_cols:
-            secho(f"Error: Final columns missing before Dataset creation: {missing_cols}", fg="red")
-            return self._create_empty_triplet_dataset()
-
-        triplets_final_df = triplets[final_cols]
 
-        try:
-            triplets_dataset = Dataset.from_pandas(triplets_final_df, preserve_index=False, features=self._get_triplet_features())
-            secho(f"Generated {len(triplets_dataset)} triplets for chunk {chunk_index}.", fg="green")
-            return triplets_dataset
-        except Exception as e:
-            secho(f"Error converting final DataFrame to Dataset (chunk {chunk_index}): {e}", fg="red")
-            return self._create_empty_triplet_dataset()
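+        # Fold all remaining columns (e.g. relevance_positive, relevance_negative) into one JSON "metadata" string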
+        include_cols = {"anchor", "positive", "negative", "margin", "source"}
+        metadata_cols = [col for col in triplets.columns if col not in include_cols]
+        triplets["metadata"] = triplets[metadata_cols].apply(lambda x: json.dumps(x.to_dict()), axis=1)
+        triplets = triplets.drop(columns=metadata_cols)
 
-    def _get_triplet_features(self):
-        return Features({
-            "anchor": Value("string"),
-            "positive": Value("string"),
-            "negative": Value("string"),
-            "margin": Value("float64"),
-            "source": Value("string"),
-            "metadata": Value("string")
-        })
+        triplets = Dataset.from_pandas(triplets, preserve_index=False)
+        secho(f"Generated {len(triplets)} triplets for {self.name}.", fg="green")
+        # secho(f"Triplets sample: {triplets[0]}", fg=(229, 192, 123))
+        return triplets
 
-    def _create_empty_triplet_dataset(self):
-        return Dataset.from_dict({
-            "anchor": [], "positive": [], "negative": [],
-            "margin": [], "source": [], "metadata": []
-        }, features=self._get_triplet_features())
-
-    def generate_positives(self, threshold, data_subset=None):
-        data_to_process = data_subset if data_subset is not None else self._data
-        if not data_to_process or len(data_to_process) == 0:
-            return Dataset.from_dict({"anchor": [], "positive": [], "relevance": []})
-
-        if "relevance" not in data_to_process.column_names:
-            secho("Error: 'relevance' column missing for generate_positives.", fg="red")
-            return Dataset.from_dict({"anchor": [], "positive": [], "relevance": []})
-
-        pos = data_to_process.filter(lambda x: x["relevance"] >= threshold, num_proc=self._num_procs).map(
-            lambda x: {"anchor": x["query"], "positive": x["document"], "relevance": x["relevance"]},
+    def generate_positives(self, threshold):
+        pos = self._data.filter(lambda x: x["relevance"] >= threshold).map(
+            lambda x: {"anchor": x["query"], "positive": x["document"]},
             num_proc=self._num_procs,
-            remove_columns=[col for col in data_to_process.column_names if col not in ["query", "document", "relevance"]],
+            remove_columns=["query", "document"],
         )
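+        # "relevance" is not in remove_columns, so it survives; generate_triplets relies on it for the margin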
+        secho(f"Generated {len(pos)} positives.", fg="green")
         return pos
 
-    def generate_negatives(self, threshold, data_subset=None):
-        data_to_process = data_subset if data_subset is not None else self._data
-        if not data_to_process or len(data_to_process) == 0:
-            return Dataset.from_dict({"anchor": [], "negative": [], "relevance": []})
-
-        if "relevance" not in data_to_process.column_names:
-            secho("Error: 'relevance' column missing for generate_negatives (base).", fg="red")
-            return Dataset.from_dict({"anchor": [], "negative": [], "relevance": []})
-
-        neg = data_to_process.filter(lambda x: x["relevance"] < threshold, num_proc=self._num_procs).map(
-            lambda x: {"anchor": x["query"], "negative": x["document"], "relevance": x["relevance"]},
-            num_proc=self._num_procs,
-            remove_columns=[col for col in data_to_process.column_names if col not in ["query", "document", "relevance"]],
-        )
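+    # For "google", negatives come from HardNegativeMiner rather than from filtering on a stored relevance score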
+    def generate_negatives(self, threshold):
+        if self.name == "google":
+            neg = self._data.map(
+                lambda x: {"anchor": x["query"]},
+                num_proc=self._num_procs,
+                remove_columns=["query"],
+            )
+            neg = HardNegativeMiner(dataset=neg, max_score=threshold).run()
+        else:
+            neg = self._data.filter(lambda x: x["relevance"] < threshold).map(
+                lambda x: {"anchor": x["query"], "negative": x["document"]},
+                num_proc=self._num_procs,
+                remove_columns=["query", "document"],
+            )
+        secho(f"Generated {len(neg)} negatives.", fg="green")
         return neg
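
A minimal usage sketch of the resulting flow (not part of the commit): the subclass name, constructor arguments, and loader body below are hypothetical, and it assumes load() returns a datasets.Dataset with "query", "document", and "relevance" columns.

    class WayfairDataset(BaseDataset):
        def load(self, split, cols):
            # hypothetical loader: fetch the split and keep only the needed columns
            return load_dataset(self.repo_id, split=split).select_columns(cols)

    ds = WayfairDataset(...)  # constructor arguments as defined in __init__
    ds._data = ds.load("train", ["query", "document", "relevance"])
    ds.generate_query()  # lowercase, sample, and chunk the unique queries
    for i in range(ds._max_chunks):
        triplets = ds.generate_triplets(threshold=3.0, chunk_index=i)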