Skip to content

Commit e519683

Browse files
committed
added crowdflower
Signed-off-by: Lejin Varghese <lejinsnests@gmail.com>
1 parent 62a02ed commit e519683

7 files changed

Lines changed: 69 additions & 21 deletions

File tree

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from adapters.core import BaseDataset
2-
from adapters.home_depot import HomeDepotDataset
32
from adapters.amazon import AmazonDataset
3+
from adapters.crowdflower import CrowdFlowerDataset
44
from adapters.google import GoogleDataset
5+
from adapters.home_depot import HomeDepotDataset
56
from adapters.wayfair import WayfairDataset
67
from adapters.aggregator import DatasetAggregator

representation_learning/product_search/moe/adapters/aggregator.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
from typing import Optional
2-
from datasets import concatenate_datasets, DatasetDict, Dataset
32
from click import secho
4-
5-
from adapters import BaseDataset, AmazonDataset, HomeDepotDataset, GoogleDataset, WayfairDataset
3+
from datasets import Dataset, DatasetDict, concatenate_datasets
4+
from adapters import AmazonDataset, BaseDataset, CrowdFlowerDataset, GoogleDataset, HomeDepotDataset, WayfairDataset
65

76
DATASET_NAME = "lv12/ProductSearchDataset"
87

@@ -13,8 +12,8 @@ def __init__(
1312
sample_size: Optional[int] = None,
1413
splits: list[str] = ["train", "test"],
1514
):
16-
self.sources = [HomeDepotDataset, AmazonDataset, WayfairDataset, GoogleDataset]
17-
self.sources = [GoogleDataset]
15+
self.sources = [HomeDepotDataset, AmazonDataset, WayfairDataset, GoogleDataset, CrowdFlowerDataset]
16+
# self.sources = [GoogleDataset]
1817
self.sample_size = sample_size
1918
self.splits = splits
2019
self.datasets = self.generate_datasets()
@@ -74,7 +73,7 @@ def push_to_hub(
7473
self,
7574
repo_id: str = DATASET_NAME,
7675
private: bool = False,
77-
):
76+
) -> None:
7877
"""Push the dataset to HuggingFace Hub."""
7978
secho(f"Pushing the dataset to {repo_id}", fg=(229, 192, 123))
8079

representation_learning/product_search/moe/adapters/core.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
from abc import ABC
2-
from multiprocessing import cpu_count
31
import json
42
import re
3+
from abc import ABC
4+
from multiprocessing import cpu_count
5+
56
from click import secho
6-
from datasets import load_dataset, Dataset
7+
from datasets import Dataset, load_dataset
78

89
RANDOM_STATE = 42
910

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from adapters.core import BaseDataset
2+
3+
4+
# Raw columns pulled from the upstream HuggingFace dataset.
FEATURE_COLUMNS = [
    "query",
    "product_title",
    "product_description",
    "median_relevance",
]


class CrowdFlowerDataset(BaseDataset):
    """Adapter for the Kaggle CrowdFlower e-commerce search-relevance dataset.

    Loads the raw dataset, builds the shared `query` / `document` columns,
    and converts the 1-4 `median_relevance` labels into a 0-based
    `relevance` column.
    """

    def __init__(
        self,
        repo_id="napsternxg/kaggle_crowdflower_ecommerce_search_relevance",
        sample_size=None,
        split="train",
        cols=FEATURE_COLUMNS,
    ):
        # NOTE: `cols` defaults to a shared module-level list; it is only
        # read here, but callers must not mutate it in place.
        super().__init__(repo_id, sample_size, split, cols)
        self.name = "crowdflower"
        self.generate_query()
        self.generate_document()
        self._map_relevance()

    def _map_relevance(self):
        """Shift `median_relevance` (1-4) down to a 0-based `relevance` score.

        Rows missing the label fall back to 1.0 via `x.get(...)`, i.e. they
        map to the lowest relevance (0.0).
        """
        # NOTE(review): this yields relevance in [0, 3]; the negative miner
        # assigns 0.6 to mined negatives, which suggests other sources use a
        # ~[0, 1] scale — confirm whether this needs normalizing (e.g. /3).
        self._data = self._data.map(
            lambda x: {"relevance": x.get("median_relevance", 1.0) - 1.0},
            num_proc=self._num_procs,
            remove_columns=["median_relevance"],
        )

    def generate_document(self):
        """Combine product title and description into a single `document` column."""
        self._data = self._data.map(
            lambda row: {
                "document": self.format_document(
                    title=row.get("product_title"),
                    description=row.get("product_description"),
                )
            },
            remove_columns=["product_title", "product_description"],
            num_proc=self._num_procs,
        )
        # `Dataset.unique` already de-duplicates, so wrapping it in set()
        # (as the original did) was redundant.
        self._n_documents = len(self._data.unique("document"))

representation_learning/product_search/moe/adapters/google.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from click import secho
2-
from datasets import load_dataset, Dataset
3-
from adapters.core import BaseDataset, RANDOM_STATE
4-
from adapters.negative_miner import HardNegativeMiner
2+
from datasets import Dataset, load_dataset
3+
4+
from adapters.core import RANDOM_STATE, BaseDataset
5+
from adapters.miners import HardNegativeMiner
56

67
FEATURE_COLUMNS = [
78
"query",

representation_learning/product_search/moe/adapters/negative_miner.py renamed to representation_learning/product_search/moe/adapters/miners.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
from sentence_transformers.util import mine_hard_negatives
2-
from sentence_transformers import SentenceTransformer, CrossEncoder
3-
import torch
41
from multiprocessing import cpu_count
2+
import torch
3+
from sentence_transformers import CrossEncoder, SentenceTransformer
4+
from sentence_transformers.util import mine_hard_negatives
55

6-
6+
DATASET_NAME = "lv12/ProductSearchDataset"
77
DEVICE = "mps" if torch.backends.mps.is_available() else "cpu"
88

99

@@ -14,12 +14,14 @@ def __init__(
1414
bi_encoder_name="thenlper/gte-base",
1515
cross_encoder_name="Alibaba-NLP/gte-reranker-modernbert-base",
1616
max_score=0.8,
17+
min_score=0.6,
1718
):
1819

1920
self.dataset = dataset
2021
self.bi_encoder = SentenceTransformer(bi_encoder_name, device=DEVICE)
2122
self.cross_encoder = CrossEncoder(cross_encoder_name, device=DEVICE, model_kwargs={"torch_dtype": "auto"})
2223
self.max_score = max_score
24+
self.min_score = min_score
2325
self.num_procs = cpu_count() - 1
2426

2527
def run(self):
@@ -30,17 +32,17 @@ def run(self):
3032
anchor_column_name="anchor",
3133
positive_column_name="document",
3234
range_min=5,
33-
range_max=30,
35+
range_max=20,
3436
max_score=self.max_score,
35-
min_score=0.5,
37+
min_score=self.min_score,
3638
margin=0,
3739
num_negatives=10,
3840
sampling_strategy="random",
3941
batch_size=32,
4042
use_faiss=False,
4143
)
4244
dataset = dataset.map(
43-
{"relevance": 0.9},
45+
lambda x: {"relevance": 0.6},
4446
num_proc=self.num_procs,
4547
remove_columns=["document"],
4648
)

representation_learning/product_search/moe/processor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import click
2-
from adapters import DatasetAggregator, HomeDepotDataset
2+
from adapters import DatasetAggregator
33

44

55
@click.command()

0 commit comments

Comments
 (0)