Skip to content

Commit 62a02ed

Browse files
committed
hard miner working
Signed-off-by: Lejin Varghese <lejinsnests@gmail.com>
1 parent 3214806 commit 62a02ed

3 files changed

Lines changed: 63 additions & 1 deletion

File tree

representation_learning/product_search/moe/adapters/aggregator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ def __init__(
1414
splits: list[str] = ["train", "test"],
1515
):
1616
self.sources = [HomeDepotDataset, AmazonDataset, WayfairDataset, GoogleDataset]
17+
self.sources = [GoogleDataset]
1718
self.sample_size = sample_size
1819
self.splits = splits
1920
self.datasets = self.generate_datasets()

representation_learning/product_search/moe/adapters/google.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from click import secho
22
from datasets import load_dataset, Dataset
33
from adapters.core import BaseDataset, RANDOM_STATE
4+
from adapters.negative_miner import HardNegativeMiner
45

56
FEATURE_COLUMNS = [
67
"query",
@@ -26,7 +27,7 @@ def __init__(
2627

2728
def _map_relevance(self):
2829
self._data = self._data.map(
29-
lambda x: {"relevance": round(x.get("score_reciprocal", 0.0), 2)},
30+
lambda x: {"relevance": round(1 + (x.get("score_reciprocal", 0.0) / 100) * 2, 2)},
3031
num_proc=self._num_procs,
3132
remove_columns=["score_reciprocal"],
3233
)
@@ -60,3 +61,16 @@ def generate_document(self):
6061
num_proc=self._num_procs,
6162
)
6263
self._n_documents = len(set(self._data.unique("document")))
64+
65+
def generate_triplets(self, threshold=1.0):
    """Build triplets by delegating to the parent class.

    This override only changes the default ``threshold`` to ``1.0``; the
    value is forwarded unchanged to the parent's ``generate_triplets``.

    Args:
        threshold: Forwarded as the ``threshold`` keyword argument.

    Returns:
        Whatever the parent implementation returns.
    """
    triplets = super().generate_triplets(threshold=threshold)
    return triplets
67+
68+
def generate_negatives(self, threshold=0.8):
    """Mine hard negatives for this dataset's queries.

    The ``query`` column is re-exposed as ``anchor`` (the column name the
    miner is configured to read), then the rows are handed to
    ``HardNegativeMiner`` with ``threshold`` as its ``max_score``.

    Args:
        threshold: Passed to ``HardNegativeMiner`` as ``max_score``.

    Returns:
        The dataset returned by ``HardNegativeMiner.run()``.
    """
    # Copy "query" into "anchor" and drop the original column in one pass.
    anchored = self._data.map(
        lambda row: {"anchor": row["query"]},
        remove_columns=["query"],
        num_proc=self._num_procs,
    )

    miner = HardNegativeMiner(dataset=anchored, max_score=threshold)
    neg = miner.run()

    secho(f"Generated {len(neg)} negatives.", fg="green")
    return neg
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from sentence_transformers.util import mine_hard_negatives
2+
from sentence_transformers import SentenceTransformer, CrossEncoder
3+
import torch
4+
from multiprocessing import cpu_count
5+
6+
7+
# Torch device string used when loading models: Apple's Metal backend ("mps")
# when torch reports it as available, otherwise CPU. CUDA is not probed here.
DEVICE = "mps" if torch.backends.mps.is_available() else "cpu"
8+
9+
10+
class HardNegativeMiner:
    """Mine hard negatives for (anchor, document) pairs and label them.

    Wraps ``sentence_transformers.util.mine_hard_negatives`` with a fixed
    configuration: a bi-encoder is passed as ``model``, a cross-encoder as
    ``cross_encoder``, and the mined rows are tagged with a constant
    ``relevance`` value.

    Args:
        dataset: Dataset containing at least the ``anchor`` and ``document``
            columns (those are the column names passed to the miner).
        bi_encoder_name: Model name loaded with ``SentenceTransformer``.
        cross_encoder_name: Model name loaded with ``CrossEncoder``.
        max_score: Forwarded to ``mine_hard_negatives`` as ``max_score``.
    """

    def __init__(
        self,
        dataset,
        bi_encoder_name="thenlper/gte-base",
        cross_encoder_name="Alibaba-NLP/gte-reranker-modernbert-base",
        max_score=0.8,
    ):
        self.dataset = dataset
        self.bi_encoder = SentenceTransformer(bi_encoder_name, device=DEVICE)
        self.cross_encoder = CrossEncoder(
            cross_encoder_name,
            device=DEVICE,
            model_kwargs={"torch_dtype": "auto"},
        )
        self.max_score = max_score
        self.num_procs = self._worker_count()

    @staticmethod
    def _worker_count():
        """Return the worker count for ``Dataset.map``: all cores but one, minimum 1."""
        # cpu_count() - 1 is 0 on a single-core host, and Dataset.map rejects
        # num_proc <= 0, so clamp to at least one worker.
        return max(1, cpu_count() - 1)

    @staticmethod
    def _label_negative(_row):
        """Return the constant relevance label attached to every mined row."""
        # NOTE(review): 0.9 presumably sits just below the relevance range
        # assigned to positives elsewhere -- confirm against the adapters.
        return {"relevance": 0.9}

    def run(self):
        """Run hard-negative mining and return the labelled result.

        Returns:
            The dataset produced by ``mine_hard_negatives`` with a
            ``relevance`` column set to ``0.9`` on every row and the
            ``document`` column removed.
        """
        mined = mine_hard_negatives(
            dataset=self.dataset,
            model=self.bi_encoder,
            cross_encoder=self.cross_encoder,
            anchor_column_name="anchor",
            positive_column_name="document",
            range_min=5,
            range_max=30,
            max_score=self.max_score,
            min_score=0.5,
            margin=0,
            num_negatives=10,
            sampling_strategy="random",
            batch_size=32,
            use_faiss=False,
        )
        # Dataset.map needs a callable; the original passed a bare dict here,
        # which fails with "'dict' object is not callable".
        labelled = mined.map(
            self._label_negative,
            num_proc=self.num_procs,
            remove_columns=["document"],
        )
        return labelled

0 commit comments

Comments
 (0)