Skip to content

Commit bd6ef7e

Browse files
committed
wayfair working for 100k sample
Signed-off-by: Lejin Varghese <lejinsnests@gmail.com>
1 parent eb4fe73 commit bd6ef7e

6 files changed

Lines changed: 130 additions & 60 deletions

File tree

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from adapters.core import BaseDataset
22
from adapters.home_depot import HomeDepotDataset
33
from adapters.amazon import AmazonDataset
4+
from adapters.google import GoogleDataset
5+
from adapters.wayfair import WayfairDataset
46
from adapters.aggregator import DatasetAggregator

deep_learning/moe/adapters/aggregator.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@
22
from datasets import concatenate_datasets, DatasetDict, Dataset
33
from click import secho
44

5-
from adapters import AmazonDataset
6-
from adapters import HomeDepotDataset
7-
from adapters import BaseDataset
5+
from adapters import BaseDataset, AmazonDataset, HomeDepotDataset, GoogleDataset, WayfairDataset
86

97
DATASET_NAME = "lv12/ProductSearchDataset"
108

@@ -15,7 +13,7 @@ def __init__(
1513
sample_size: Optional[int] = None,
1614
splits: list[str] = ["train", "test"],
1715
):
18-
self.sources = [AmazonDataset, HomeDepotDataset]
16+
self.sources = [HomeDepotDataset, AmazonDataset, WayfairDataset]
1917
self.sample_size = sample_size
2018
self.splits = splits
2119
self.datasets = self.generate_datasets()

deep_learning/moe/adapters/amazon.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
"esci_label",
1313
]
1414

15-
ESCI_LABEL_MAPPING = {
15+
LABEL_MAPPING = {
1616
"Exact": 3.0,
1717
"Substitute": 2.0,
1818
"Complement": 1.0,
@@ -36,7 +36,7 @@ def __init__(
3636

3737
def _map_relevance(self):
    """Convert ESCI string labels to numeric relevance scores.

    Replaces the raw "esci_label" column with a float "relevance" column,
    defaulting to 0.0 for labels absent from LABEL_MAPPING.
    """
    def to_relevance(example):
        return {"relevance": LABEL_MAPPING.get(example["esci_label"], 0.0)}

    self._data = self._data.map(
        to_relevance,
        num_proc=self._num_procs,
        remove_columns=["esci_label"],
    )

deep_learning/moe/adapters/core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def generate_pairs(self):
7979
metadata = [{"source": self.name}] * len(pairs)
8080
pairs = pairs.add_column("metadata", metadata)
8181
secho(f"Generated {len(pairs)} pairs.", fg="green")
82-
secho(f"First sample: {pairs[0]}", fg=(229, 192, 123))
82+
secho(f"Pairs sample: {pairs[0]}", fg=(229, 192, 123))
8383
return pairs
8484

8585
def generate_triplets(self, threshold=3.0):
@@ -96,7 +96,7 @@ def generate_triplets(self, threshold=3.0):
9696

9797
triplets = Dataset.from_pandas(triplets, preserve_index=False)
9898
secho(f"Generated {len(triplets)} triplets.", fg="green")
99-
secho(f"First sample: {triplets[0]}", fg=(229, 192, 123))
99+
secho(f"Triplets sample: {triplets[0]}", fg=(229, 192, 123))
100100
return triplets
101101

102102
def generate_positives(self, threshold):
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
from click import secho
from datasets import load_dataset
from adapters.core import BaseDataset, RANDOM_STATE

# Columns requested from the Marqo-GS-10M repo.
FEATURE_COLUMNS = [
    "query",
    "product_id",
    "title",
    "score_reciprocal",
]


class GoogleDataset(BaseDataset):
    """Adapter for the Marqo Google Shopping (marqo-GS-10M) dataset.

    Loads the requested split, builds the shared "query"/"document" columns,
    and exposes the dataset's "score_reciprocal" as a float "relevance" column.
    """

    def __init__(
        self,
        repo_id="Marqo/marqo-GS-10M",
        sample_size=None,
        split="train",
        cols=FEATURE_COLUMNS,
    ):
        super().__init__(repo_id, sample_size, split, cols)
        self.name = "google"
        self.generate_query()
        self.generate_document()
        self._map_relevance()

    def _map_relevance(self):
        # Use the reciprocal-rank score directly as the relevance signal;
        # rows missing the column default to 0.0.
        self._data = self._data.map(
            lambda x: {"relevance": x.get("score_reciprocal", 0.0)},
            num_proc=self._num_procs,
            remove_columns=["score_reciprocal"],
        )

    def load(self, split: str, cols: list[str] = FEATURE_COLUMNS):
        """Load one split of the dataset, optionally down-sampling it.

        The generic "train"/"test" split names are translated to the repo's
        "in_domain"/"zero_shot" splits.
        """
        secho(
            f"Loading data from {self._repo_id} using: {self._num_procs} cores",
            fg=(229, 192, 123),
        )
        if split == "train":
            split = "in_domain"
        elif split == "test":
            split = "zero_shot"
        # Fix: the rest of the class (and the secho above) uses the
        # `self._repo_id` attribute set by BaseDataset; `self.repo_id` does
        # not exist on this class.
        data = load_dataset(self._repo_id, num_proc=self._num_procs, split=split, columns=cols)
        # Fix: removed the `product_locale == "us"` filter copied from the
        # Amazon adapter — "product_locale" is not among the loaded columns,
        # so `row.get(...)` was always None and the filter dropped every row.
        if self._sample_size is None:
            return data
        return data.shuffle(seed=RANDOM_STATE).select(range(self._sample_size))

    def generate_document(self):
        # Fix: only FEATURE_COLUMNS are loaded, so the product title lives
        # under "title"; the previous `row.get("product_title")` always
        # returned None. (NOTE(review): confirm "title" against the
        # marqo-GS-10M dataset card.)
        self._data = self._data.map(
            lambda row: {"document": self.format_document(title=row.get("title"))},
            remove_columns=["product_id"],
            num_proc=self._num_procs,
        )
Lines changed: 67 additions & 52 deletions
Original file line numberDiff line numberDiff line change
from adapters.core import BaseDataset

# Columns requested from the WANDS (Wayfair) dataset.
FEATURE_COLUMNS = [
    "query",
    "product_id",
    "product_name",
    "product_description",
    "product_features",
    "category hierarchy",
    "label",
]


class WayfairDataset(BaseDataset):
    """Adapter for the WANDS (Wayfair ANnotation DataSet) relevance dataset.

    Builds the shared "query"/"document" columns (folding parsed product
    attributes into the document) and exposes the graded "label" column as a
    float "relevance" column.
    """

    def __init__(
        self,
        repo_id="napsternxg/wands",
        sample_size=None,
        split="train",
        cols=FEATURE_COLUMNS,
    ):
        super().__init__(repo_id, sample_size, split, cols)
        self.name = "wayfair"
        self.generate_query()
        self.generate_document()
        self._map_relevance()

    def _map_relevance(self):
        # Expose the dataset's graded label under the shared "relevance"
        # column name, coerced to float for consistency with other adapters.
        self._data = self._data.map(
            lambda x: {"relevance": float(x["label"])},
            num_proc=self._num_procs,
            remove_columns=["label"],
        )

    def _parse_attributes(self, text):
        """Parse pipe-separated, " : "-delimited key-value pairs into a dict.

        Example: "color : red | size : large | material : cotton"
        Returns: {"color": "red", "size": "large", "material": "cotton"}
        Non-string input (e.g. None) yields an empty dict; segments without
        the " : " separator are skipped.
        """
        if not isinstance(text, str):
            return {}

        attributes = {}
        for pair in text.split("|"):
            pair = pair.strip()
            # Fix: the previous debug `print(..., fg="green")` raised
            # TypeError (print has no `fg` kwarg) before the assignment, and
            # the bare `except:` swallowed it with an early return — so
            # attributes were never populated. The split below cannot raise,
            # so the try/except is gone entirely.
            if " : " not in pair:
                continue
            key, value = pair.split(" : ", 1)
            key = key.strip()
            value = value.strip()
            if key and value:
                attributes[key] = value
        return attributes

    def generate_document(self):
        """Build the "document" column from name/description/category/features."""
        # First pass: materialize parsed attributes as their own column so
        # the second pass can consume them.
        self._data = self._data.map(
            lambda row: {
                "product_attributes": self._parse_attributes(row.get("product_features", "")),
            },
            num_proc=self._num_procs,
        )
        # Second pass: fold everything into a single formatted document and
        # drop the now-redundant source columns.
        self._data = self._data.map(
            lambda row: {
                "document": self.format_document(
                    title=row.get("product_name"),
                    description=row.get("product_description"),
                    category=row.get("category hierarchy"),
                    attributes=row.get("product_attributes", {}),
                )
            },
            remove_columns=[
                "product_id",
                "product_name",
                "product_description",
                "product_features",
                "category hierarchy",
                "product_attributes",
            ],
            num_proc=self._num_procs,
        )

    def generate_triplets(self, threshold=2):
        # WANDS relevance is graded on a smaller scale than other sources, so
        # override the BaseDataset positive-pair threshold default.
        return super().generate_triplets(threshold=threshold)

0 commit comments

Comments
 (0)