Commit b8f47c9

fix most
Signed-off-by: Lejin Varghese <lejinsnests@gmail.com>
1 parent 31e630c commit b8f47c9

7 files changed

Lines changed: 131 additions & 268 deletions

representation_learning/product_search/moe/adapters/amazon.py

Lines changed: 2 additions & 0 deletions
@@ -1,4 +1,5 @@
 from adapters.core import BaseDataset
+from click import secho

 FEATURE_COLUMNS = [
     "query",
@@ -31,6 +32,7 @@ def __init__(
     ):
         super().__init__(repo_id, sample_size, chunk_size, split, cols)
         self.name = "amazon"
+        self._data = self.load(split, cols)
         self._map_relevance()
         self.generate_query()
         self.generate_document()
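
The addition above pairs with the core.py change below: BaseDataset.__init__ no longer loads data, so each adapter triggers the load itself after setting its name. A minimal sketch of the initialization order this establishes, using a hypothetical adapter (the class name and constructor defaults are illustrative, not from this commit):

from adapters.core import BaseDataset


class WayfairDataset(BaseDataset):  # hypothetical adapter, for illustration only
    def __init__(self, repo_id, sample_size=None, chunk_size=1000, split="train", cols=None):
        super().__init__(repo_id, sample_size, chunk_size, split, cols)
        self.name = "wayfair"  # set before loading: chunk sizing keys off self.name
        self._data = self.load(split, cols)  # the load step now lives in the subclass
        self.generate_query()
        self.generate_document()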

representation_learning/product_search/moe/adapters/core.py

Lines changed: 98 additions & 148 deletions
@@ -5,19 +5,13 @@

 from click import secho
 import random
-
+
 from datasets import Dataset, load_dataset, Features, Value
+from adapters.miners import HardNegativeMiner

 RANDOM_STATE = 42
 random.seed(RANDOM_STATE)

-DATASET_CHUNK_SIZES = {  # Define specific chunk sizes here
-    "wayfair": 100,
-    "amazon": 5000,
-    # Add other datasets if needed
-}
-DEFAULT_CHUNK_SIZE = 1000
-
 class BaseDataset(ABC):
     def __init__(
         self,
@@ -32,8 +26,7 @@ def __init__(
         self._chunk_size = chunk_size
         self._num_procs = cpu_count() - 1
         self._split = split
-        self._data = self.load(split, cols)
-        secho(f"Total records loaded: {len(self._data)}", fg="green")
+        self._data = None

     @property
     def repo_id(self):
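
With _data initialized to None, the base class defers all I/O to the subclass, and nothing in BaseDataset guards against an adapter that forgets the load step. A hypothetical guard, not part of the commit, that methods touching self._data could call first:

    def _require_data(self):  # hypothetical helper, not in this commit
        if self._data is None:
            raise RuntimeError(
                f"{self.name}: call self.load(split, cols) in the adapter's "
                "__init__ before generate_query() or generate_triplets()"
            )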
@@ -51,35 +44,49 @@ def n_queries(self):
     def n_documents(self):
         return self._n_documents

-    def generate_query(self, queries_already_sampled=False):
-        secho(f"Generating queries for {self.name} dataset...", fg="blue")
+    def generate_query(self):
+        secho(f"Generating queries for {self.name} dataset", fg="blue")
+        secho(f"Initial dataset size: {len(self._data)}", fg="blue")

+        self._data = self._data.map(
+            lambda x: {"query": x["query"].lower()},
+            num_proc=self._num_procs,
+        )
         self._unique_queries = list(set(self._data.unique("query")))
         self._n_queries = len(self._unique_queries)
-
-        if not queries_already_sampled and self._sample_size is not None and self._sample_size < self._n_queries:
-            secho(f"Applying sampling in BaseDataset.generate_query: {self._sample_size} queries", fg="yellow")
-            sampled_queries = random.sample(self._unique_queries, self._sample_size)
-            self._unique_queries = sampled_queries
-            self._n_queries = len(self._unique_queries)
-            self._data = self._data.filter(
-                lambda x: x["query"] in self._unique_queries,
-                num_proc=self._num_procs
-            )
-
+        secho(f"Unique queries before sampling: {self._n_queries}", fg="blue")
+
+        if self._sample_size is not None and self._sample_size < self._n_queries:
+            secho(f"Sampling {self._sample_size} queries from {self._n_queries} total queries", fg="green")
+            sampled_queries = random.sample(self._unique_queries, self._sample_size)
+            self._unique_queries = sampled_queries
+            self._n_queries = len(self._unique_queries)
+
+            self._data = self._data.filter(
+                lambda x: x["query"] in self._unique_queries,
+                num_proc=self._num_procs
+            )
+            secho(f"Filtered dataset to {len(self._data)} records with sampled queries", fg="green")
+
+        # Create chunks for the queries
         chunks = {}
-        effective_chunk_size = DATASET_CHUNK_SIZES.get(getattr(self, 'name', None), self._chunk_size or DEFAULT_CHUNK_SIZE)
-
-        if self._n_queries > 0 and effective_chunk_size > 0:
-            for i in range(0, self._n_queries, effective_chunk_size):
-                chunk_index = i // effective_chunk_size
-                chunks[chunk_index] = self._unique_queries[i:i + effective_chunk_size]
+        if self._chunk_size is not None:
+            if self.name == "wayfair":
+                self._chunk_size = 100
+            elif self.name == "amazon":
+                self._chunk_size = 5000
+
+            for i in range(0, self._n_queries, self._chunk_size):
+                chunk_index = i // self._chunk_size
+                chunks[chunk_index] = self._unique_queries[i:i + self._chunk_size]
+                secho(f"Chunk {chunk_index}: {len(chunks[chunk_index])} queries", fg="blue")
         else:
             chunks = {0: self._unique_queries}
+            secho(f"Single chunk with {len(chunks[0])} queries", fg="blue")

-        self._max_chunks = len(chunks)
+        self._max_chunks = len(chunks.keys())
         self._query_chunks = chunks
-        secho(f"Total query chunks created: {self._max_chunks}", fg="blue")
+        secho(f"Total chunks: {self._max_chunks}", fg="blue")

     def generate_document(self):
         pass
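
The rewritten generate_query lower-cases queries, optionally samples them, then slices the unique-query list into fixed-size chunks, overriding _chunk_size to 100 for wayfair and 5000 for amazon. The slicing arithmetic as a standalone sketch, with example numbers not taken from the commit: 12,345 unique queries at chunk size 5,000 yield chunks of 5,000, 5,000, and 2,345.

unique_queries = [f"q{i}" for i in range(12_345)]  # example volume
chunk_size = 5_000  # the amazon override above

# Same loop shape as generate_query: integer-divide to get the chunk index,
# slice to get that chunk's queries; the last chunk holds the remainder.
chunks = {}
for i in range(0, len(unique_queries), chunk_size):
    chunks[i // chunk_size] = unique_queries[i:i + chunk_size]

assert [len(c) for c in chunks.values()] == [5_000, 5_000, 2_345]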
@@ -123,137 +130,80 @@ def generate_pairs(self):
         pairs = pairs.add_column("source", source)
         secho(f"Generated {len(pairs)} pairs.", fg="green")
         secho(f"Queries: {self.n_queries}, Documents: {self.n_documents}.", fg="green")
+        # secho(f"Pairs sample: {pairs[0]}", fg=(229, 192, 123))
         return pairs

     def generate_triplets(self, threshold=3.0, chunk_index: int = None):
+        secho(f"Generating triplets for {self.name} dataset with threshold {threshold}", fg="blue")
+        positives = self.generate_positives(threshold=threshold).to_pandas()
+        secho(f"Generated {len(positives)} positives for {self.name}", fg="blue")
+
+        negatives = self.generate_negatives(threshold=threshold).to_pandas()
+        secho(f"Generated {len(negatives)} negatives for {self.name}", fg="blue")
+
         if chunk_index is not None:
-            secho(f"Generating triplets for {self.name} chunk {chunk_index}...", fg="blue")
+            chunk_queries = self._query_chunks.get(chunk_index, [])
+            secho(f"Filtering for chunk {chunk_index} with {len(chunk_queries)} queries", fg="blue")
+            positives = positives[positives["anchor"].isin(chunk_queries)]
+            negatives = negatives[negatives["anchor"].isin(chunk_queries)]
+            secho(f"After filtering: {len(positives)} positives, {len(negatives)} negatives", fg="blue")

-        chunk_data = self._data
+        if len(positives) == 0 or len(negatives) == 0:
+            secho(f"Not enough data to generate triplets: {len(positives)} positives, {len(negatives)} negatives", fg="red")
+            return Dataset.from_dict({
+                "anchor": [],
+                "positive": [],
+                "negative": [],
+                "margin": [],
+                "source": [],
+                "metadata": []
+            }, features=Features({
+                "anchor": Value("string"),
+                "positive": Value("string"),
+                "negative": Value("string"),
+                "margin": Value("float64"),
+                "source": Value("string"),
+                "metadata": Value("string")
+            }))

-        if chunk_index is not None and self._query_chunks and chunk_index in self._query_chunks:
-            chunk_queries = set(self._query_chunks[chunk_index])
-            if not chunk_queries:
-                return self._create_empty_triplet_dataset()
-
-            chunk_data = self._data.filter(
-                lambda x: x["query"] in chunk_queries,
-                num_proc=self._num_procs
-            )
-        elif chunk_index is not None:
-            secho(f"Warning: Chunk index {chunk_index} not found or query chunks empty.", fg="yellow")
-            return self._create_empty_triplet_dataset()
+        triplets = positives.merge(negatives, on="anchor", suffixes=("_positive", "_negative"))
+        secho(f"Merged into {len(triplets)} triplets for {self.name}", fg="blue")

-        positives_ds = self.generate_positives(threshold=threshold, data_subset=chunk_data)
-        negatives_ds = self.generate_negatives(threshold=threshold, data_subset=chunk_data)
-
-        if len(positives_ds) == 0 or len(negatives_ds) == 0:
-            return self._create_empty_triplet_dataset()
-
-        try:
-            positives = positives_ds.to_pandas()
-            negatives = negatives_ds.to_pandas()
-        except Exception as e:
-            secho(f"Error converting dataset subset to pandas (chunk {chunk_index}): {e}", fg="red")
-            return self._create_empty_triplet_dataset()
-
-        positives = positives.rename(columns={"positive": "document", "relevance": "relevance_positive"})
-        negatives = negatives.rename(columns={"negative": "document", "relevance": "relevance_negative"})
-
-        if "anchor" not in positives.columns or "anchor" not in negatives.columns:
-            secho("Error: 'anchor' column missing before merge.", fg="red")
-            return self._create_empty_triplet_dataset()
-        if "relevance_positive" not in positives.columns:
-            positives['relevance_positive'] = threshold
-            secho("Warning: 'relevance' column missing in positives, added default.", fg="yellow")
-        if "relevance_negative" not in negatives.columns:
-            negatives['relevance_negative'] = threshold - 0.1
-            secho("Warning: 'relevance' column missing in negatives, added default.", fg="yellow")
-
-        try:
-            triplets = positives.merge(negatives, on="anchor", suffixes=("_pos", "_neg"))
-        except Exception as e:
-            secho(f"Error merging pandas DataFrames (chunk {chunk_index}): {e}", fg="red")
-            return self._create_empty_triplet_dataset()
-
-        if triplets.empty:
-            return self._create_empty_triplet_dataset()
-
         triplets["margin"] = round(triplets["relevance_positive"] - triplets["relevance_negative"], 2)
         triplets["source"] = self.name
-        triplets = triplets.rename(columns={"document_pos": "positive", "document_neg": "negative"})
-
-        metadata_cols = [col for col in ['relevance_positive', 'relevance_negative'] if col in triplets.columns]
-        if metadata_cols:
-            try:
-                triplets["metadata"] = triplets[metadata_cols].apply(lambda x: json.dumps(x.to_dict()), axis=1)
-                triplets = triplets.drop(columns=metadata_cols)
-            except Exception as e:
-                secho(f"Error creating metadata JSON (chunk {chunk_index}): {e}", fg="yellow")
-                triplets["metadata"] = "{}"
-        else:
-            triplets["metadata"] = "{}"
-
-        final_cols = ["anchor", "positive", "negative", "margin", "source", "metadata"]
-        missing_cols = [col for col in final_cols if col not in triplets.columns]
-        if missing_cols:
-            secho(f"Error: Final columns missing before Dataset creation: {missing_cols}", fg="red")
-            return self._create_empty_triplet_dataset()
-
-        triplets_final_df = triplets[final_cols]

-        try:
-            triplets_dataset = Dataset.from_pandas(triplets_final_df, preserve_index=False, features=self._get_triplet_features())
-            secho(f"Generated {len(triplets_dataset)} triplets for chunk {chunk_index}.", fg="green")
-            return triplets_dataset
-        except Exception as e:
-            secho(f"Error converting final DataFrame to Dataset (chunk {chunk_index}): {e}", fg="red")
-            return self._create_empty_triplet_dataset()
+        include_cols = {"anchor", "positive", "negative", "margin", "source"}
+        metadata_cols = [col for col in triplets.columns if col not in include_cols]
+        triplets["metadata"] = triplets[metadata_cols].apply(lambda x: json.dumps(x.to_dict()), axis=1)
+        triplets = triplets.drop(columns=metadata_cols)

-    def _get_triplet_features(self):
-        return Features({
-            "anchor": Value("string"),
-            "positive": Value("string"),
-            "negative": Value("string"),
-            "margin": Value("float64"),
-            "source": Value("string"),
-            "metadata": Value("string")
-        })
+        triplets = Dataset.from_pandas(triplets, preserve_index=False)
+        secho(f"Generated {len(triplets)} triplets for {self.name}.", fg="green")
+        # secho(f"Triplets sample: {triplets[0]}", fg=(229, 192, 123))
+        return triplets

-    def _create_empty_triplet_dataset(self):
-        return Dataset.from_dict({
-            "anchor": [], "positive": [], "negative": [],
-            "margin": [], "source": [], "metadata": []
-        }, features=self._get_triplet_features())
-
-    def generate_positives(self, threshold, data_subset=None):
-        data_to_process = data_subset if data_subset is not None else self._data
-        if not data_to_process or len(data_to_process) == 0:
-            return Dataset.from_dict({"anchor": [], "positive": [], "relevance": []})
-
-        if "relevance" not in data_to_process.column_names:
-            secho("Error: 'relevance' column missing for generate_positives.", fg="red")
-            return Dataset.from_dict({"anchor": [], "positive": [], "relevance": []})
-
-        pos = data_to_process.filter(lambda x: x["relevance"] >= threshold, num_proc=self._num_procs).map(
-            lambda x: {"anchor": x["query"], "positive": x["document"], "relevance": x["relevance"]},
+    def generate_positives(self, threshold):
+        pos = self._data.filter(lambda x: x["relevance"] >= threshold).map(
+            lambda x: {"anchor": x["query"], "positive": x["document"]},
             num_proc=self._num_procs,
-            remove_columns=[col for col in data_to_process.column_names if col not in ["query", "document", "relevance"]],
+            remove_columns=["query", "document"],
         )
+        secho(f"Generated {len(pos)} positives.", fg="green")
         return pos

-    def generate_negatives(self, threshold, data_subset=None):
-        data_to_process = data_subset if data_subset is not None else self._data
-        if not data_to_process or len(data_to_process) == 0:
-            return Dataset.from_dict({"anchor": [], "negative": [], "relevance": []})
-
-        if "relevance" not in data_to_process.column_names:
-            secho("Error: 'relevance' column missing for generate_negatives (base).", fg="red")
-            return Dataset.from_dict({"anchor": [], "negative": [], "relevance": []})
-
-        neg = data_to_process.filter(lambda x: x["relevance"] < threshold, num_proc=self._num_procs).map(
-            lambda x: {"anchor": x["query"], "negative": x["document"], "relevance": x["relevance"]},
-            num_proc=self._num_procs,
-            remove_columns=[col for col in data_to_process.column_names if col not in ["query", "document", "relevance"]],
-        )
+    def generate_negatives(self, threshold):
+        if self.name == "google":
+            neg = self._data.map(
+                lambda x: {"anchor": x["query"]},
+                num_proc=self._num_procs,
+                remove_columns=["query"],
+            )
+            neg = HardNegativeMiner(dataset=neg, max_score=threshold).run()
+        else:
+            neg = self._data.filter(lambda x: x["relevance"] < threshold).map(
+                lambda x: {"anchor": x["query"], "negative": x["document"]},
+                num_proc=self._num_procs,
+                remove_columns=["query", "document"],
+            )
+        secho(f"Generated {len(neg)} negatives.", fg="green")
         return neg
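
The new generate_triplets is an inner join of positives and negatives on "anchor", so a query with p positives and n negatives contributes p x n triplets, and margin is the relevance gap between the joined pair. A toy pandas sketch of that merge, with data invented for illustration:

import pandas as pd

positives = pd.DataFrame({"anchor": ["red shoes"] * 2,
                          "positive": ["doc_a", "doc_b"],
                          "relevance": [4.0, 3.5]})
negatives = pd.DataFrame({"anchor": ["red shoes"] * 3,
                          "negative": ["doc_x", "doc_y", "doc_z"],
                          "relevance": [2.0, 1.5, 1.0]})

# Same merge as above: the shared "relevance" column picks up the
# _positive/_negative suffixes used for the margin computation.
triplets = positives.merge(negatives, on="anchor", suffixes=("_positive", "_negative"))
triplets["margin"] = round(triplets["relevance_positive"] - triplets["relevance_negative"], 2)

assert len(triplets) == 6  # 2 positives x 3 negatives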

representation_learning/product_search/moe/adapters/crowdflower.py

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ def __init__(
     ):
         super().__init__(repo_id, sample_size, chunk_size, split, cols)
         self.name = "crowdflower"
+        self._data = self.load(split, cols)
         self.generate_query()
         self.generate_document()
         self._map_relevance()
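
Taken together, the three files move to a flow where the adapter loads its own data, generate_query builds _query_chunks, and callers pull triplets one chunk at a time. A hedged usage sketch — the adapter class name, repo id, and direct use of the private _max_chunks attribute are assumptions for illustration, not confirmed by the commit:

from datasets import concatenate_datasets
from adapters.amazon import AmazonDataset  # assumed class name for the amazon.py adapter

ds = AmazonDataset(repo_id="org/esci-queries",  # placeholder repo id
                   sample_size=10_000, chunk_size=1_000,
                   split="train", cols=["query", "document", "relevance"])

# generate_triplets filters the precomputed positives/negatives down to one
# chunk of queries per call, keeping each merge's memory footprint bounded.
parts = [ds.generate_triplets(threshold=3.0, chunk_index=i)
         for i in range(ds._max_chunks)]
triplets = concatenate_datasets([p for p in parts if len(p) > 0])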
