Skip to content

Commit 3cd49be

Browse files
committed
added publish
Signed-off-by: Lejin Varghese <lejinsnests@gmail.com>
1 parent e26ed9e commit 3cd49be

5 files changed

Lines changed: 107 additions & 15 deletions

File tree

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from adapters.core import BaseDataset
22
from adapters.home_depot import HomeDepotDataset
33
from adapters.amazon import AmazonDataset
4+
from adapters.aggregator import DatasetAggregator
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
from typing import Optional
2+
from datasets import concatenate_datasets, DatasetDict
3+
from click import secho
4+
5+
from adapters.amazon import AmazonDataset
6+
from adapters.home_depot import HomeDepotDataset
7+
8+
DATASET_NAME = "lv12/ProductSearchDataset"
9+
10+
11+
class DatasetAggregator:
    """Aggregate multiple product-search dataset adapters into one dataset.

    Instantiates every adapter class listed in ``sources`` and exposes
    combined pair / triplet generation plus a HuggingFace Hub publish step.
    """

    def __init__(
        self,
        sample_size: Optional[int] = None,
        split: str = "train",
    ):
        """
        Args:
            sample_size: Per-source sample cap; ``None`` loads each source fully.
            split: Dataset split to load from each source.
        """
        self.sources = [AmazonDataset, HomeDepotDataset]
        self.sample_size = sample_size
        self.split = split
        self.datasets = self.generate_datasets()

    def generate_datasets(self):
        """Instantiate one dataset adapter per class in ``self.sources``."""
        # Drive instantiation from self.sources (the original re-hardcoded the
        # classes here, silently ignoring the declared source list) so adding
        # a new source is a one-line change.
        return [
            source(sample_size=self.sample_size, split=self.split)
            for source in self.sources
        ]

    def generate_pairs(self):
        """Generate pairs from all datasets and concatenate them."""
        if not self.datasets:
            raise ValueError("No datasets added to aggregator")

        pairs_list = [dataset.generate_pairs() for dataset in self.datasets]
        combined_pairs = concatenate_datasets(pairs_list)
        secho(f"Total combined pairs: {len(combined_pairs)}", fg="blue")
        return combined_pairs

    def generate_triplets(self):
        """Generate triplets from all datasets and concatenate them."""
        if not self.datasets:
            raise ValueError("No datasets added to aggregator")

        triplets_list = [dataset.generate_triplets() for dataset in self.datasets]
        combined_triplets = concatenate_datasets(triplets_list)
        secho(f"Total combined triplets: {len(combined_triplets)}", fg="blue")
        return combined_triplets

    def push_to_hub(
        self,
        repo_id: str = DATASET_NAME,
        private: bool = False,
        overwrite: bool = True,
    ):
        """Push the combined dataset to HuggingFace Hub.

        Args:
            repo_id: Target hub repository id.
            private: Whether the hub repo should be private.
            overwrite: Currently unused — kept for interface compatibility.
                TODO(review): honor or remove.
        """
        secho(f"Pushing combined dataset to {repo_id}", fg=(229, 192, 123))

        # Generate combined pairs and triplets, each under a "train" split.
        pairs = DatasetDict({"train": self.generate_pairs()})
        triplets = DatasetDict({"train": self.generate_triplets()})

        pairs.push_to_hub(
            repo_id,
            private=private,
            config_name="pairs",
        )
        # BUG FIX: the original pushed the `pairs` DatasetDict a second time
        # under config_name="triplets", so triplets were never published.
        triplets.push_to_hub(
            repo_id,
            private=private,
            config_name="triplets",
        )

        secho(f"Successfully pushed combined dataset to {repo_id}", fg="green")

deep_learning/moe/adapters/core.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ def __init__(
1717
):
1818
self._repo_id = repo_id
1919
self._sample_size = sample_size
20-
2120
self._num_procs = cpu_count() - 1
2221
self._data = self.load(split, cols)
2322
secho(f"Total records loaded: {len(self._data)}", fg="green")
@@ -70,7 +69,7 @@ def format_document(**kwargs):
7069
def load(self, split: str, cols: list[str] = None):
7170
secho(
7271
f"Loading data from {self._repo_id} using: {self._num_procs} cores",
73-
fg="yellow",
72+
fg=(229, 192, 123),
7473
)
7574
data = load_dataset(self.repo_id, num_proc=self._num_procs, split=split, columns=cols)
7675
if self._sample_size is None:
@@ -79,12 +78,12 @@ def load(self, split: str, cols: list[str] = None):
7978
return data.shuffle(seed=RANDOM_STATE).select(range(self._sample_size))
8079

8180
def generate_pairs(self):
    """Return the loaded records as pairs, tagged with a source-name metadata column."""
    source_tag = {"source": self.name}
    pairs = self._data.add_column("metadata", [source_tag] * len(self._data))
    secho(f"Generated {len(pairs)} pairs.", fg="green")
    secho(f"First sample: {pairs[0]}", fg=(229, 192, 123))
    return pairs
8887

8988
def generate_triplets(self, threshold=3.0):
9089
positives = self.generate_positives(threshold=threshold).to_pandas()
@@ -98,10 +97,10 @@ def generate_triplets(self, threshold=3.0):
9897
triplets["metadata"] = triplets[metadata_cols].apply(lambda x: json.dumps(x.to_dict()), axis=1)
9998
triplets = triplets.drop(columns=metadata_cols)
10099

101-
self.triplets = Dataset.from_pandas(triplets, preserve_index=False)
102-
secho(f"Generated {len(self.triplets)} triplets.", fg="green")
103-
secho(f"First sample: {self.triplets[0]}", fg="yellow")
104-
return self.triplets
100+
triplets = Dataset.from_pandas(triplets, preserve_index=False)
101+
secho(f"Generated {len(triplets)} triplets.", fg="green")
102+
secho(f"First sample: {triplets[0]}", fg=(229, 192, 123))
103+
return triplets
105104

106105
def generate_positives(self, threshold):
107106
pos = self._data.filter(lambda x: x["relevance"] >= threshold).map(

deep_learning/moe/adapters/home_depot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,4 @@ def generate_document(self):
2222
)
2323

2424
def generate_triplets(self, threshold=2.5):
    """Generate triplets via the base implementation, with a Home-Depot-specific default relevance threshold of 2.5."""
    return super().generate_triplets(threshold=threshold)

deep_learning/moe/processor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import click
2-
from adapters import HomeDepotDataset, AmazonDataset
2+
from adapters import DatasetAggregator, HomeDepotDataset
33

44

55
@click.command()
@click.option("--sample_size", default=None, type=int, help="Number of samples to generate.")
def main(sample_size):
    """Aggregate all source datasets and publish them to the HuggingFace Hub.

    Args:
        sample_size: Optional per-source cap on the number of samples loaded.
    """
    ds = DatasetAggregator(sample_size=sample_size)
    # push_to_hub() generates pairs and triplets internally; the original's
    # explicit generate_pairs()/generate_triplets() calls were dead stores
    # (both assigned to the same unused `samples` variable) doing the work twice.
    ds.push_to_hub()
1112

1213

1314
if __name__ == "__main__":

0 commit comments

Comments
 (0)