|
1 | | -import json |
2 | | -from click import secho |
3 | | - |
4 | | -from datasets import Dataset |
5 | 1 | from adapters.core import BaseDataset |
6 | 2 |
|
# Raw columns to pull from the WANDS (napsternxg/wands) dataset.
# NOTE: "category hierarchy" really does contain a space — it matches the
# upstream column name, so do not "fix" it to snake_case.
FEATURE_COLUMNS = [
    "query",
    "product_id",
    "product_name",
    "product_description",
    "product_features",
    "category hierarchy",
    "label",
]
| 12 | + |
7 | 13 |
|
class WayfairDataset(BaseDataset):
    """Adapter for the Wayfair WANDS product-search dataset (napsternxg/wands).

    Loads the raw rows, flattens each product into a single `document` string,
    and converts the integer `label` column into a float `relevance` column so
    that BaseDataset can build pairs/triplets from (query, document, relevance).

    NOTE(review): `generate_query` is not overridden here, so the inherited
    BaseDataset implementation is relied upon — confirm it handles the raw
    "query" column as-is.
    """

    def __init__(
        self,
        repo_id="napsternxg/wands",
        sample_size=None,
        split="train",
        cols=None,
    ):
        """Load the dataset and derive the document/relevance columns.

        Args:
            repo_id: Hugging Face repo to load from.
            sample_size: optional row cap, forwarded to BaseDataset.
            split: dataset split name.
            cols: columns to keep; defaults to FEATURE_COLUMNS. A `None`
                sentinel is used instead of the list itself so the shared
                module-level constant can never be mutated via the default
                argument.
        """
        if cols is None:
            cols = FEATURE_COLUMNS
        super().__init__(repo_id, sample_size, split, cols)
        self.name = "wayfair"
        self.generate_query()
        self.generate_document()
        self._map_relevance()

    def _map_relevance(self):
        """Replace the integer `label` column with a float `relevance` column."""
        self._data = self._data.map(
            lambda x: {"relevance": float(x["label"])},
            num_proc=self._num_procs,
            remove_columns=["label"],
        )

    def _parse_attributes(self, text):
        """Parse pipe-separated key-value pairs into an attributes dictionary.

        Example: "color: red | size: large | material: cotton"
        Returns: {"color": "red", "size": "large", "material": "cotton"}

        Non-string input (e.g. None for a missing field) yields {}.  Fragments
        without a colon, or with an empty key or value, are skipped rather
        than aborting the whole parse.
        """
        if not isinstance(text, str):
            return {}

        attributes = {}
        for pair in text.split("|"):
            # partition() splits on the FIRST colon and never raises, so
            # values may themselves contain colons.  Both "k: v" and "k : v"
            # are accepted — the previous code only matched " : ", which
            # silently rejected the docstring's own "k: v" example, and a
            # leftover debug `print(..., fg=...)` raised TypeError into a
            # bare `except:` that truncated the result.
            key, sep, value = pair.partition(":")
            key, value = key.strip(), value.strip()
            if sep and key and value:
                attributes[key] = value
        return attributes

    def generate_document(self):
        """Build the `document` column and drop the raw product columns."""
        # Pass 1: structured attributes from the free-text feature field.
        self._data = self._data.map(
            lambda row: {
                "product_attributes": self._parse_attributes(row.get("product_features", "")),
            },
            num_proc=self._num_procs,
        )
        # Pass 2: flatten title/description/category/attributes into one string.
        self._data = self._data.map(
            lambda row: {
                "document": self.format_document(
                    title=row.get("product_name"),
                    description=row.get("product_description"),
                    category=row.get("category hierarchy"),  # upstream column name has a space
                    attributes=row.get("product_attributes", {}),
                )
            },
            remove_columns=[
                "product_id",
                "product_name",
                "product_description",
                "product_features",
                "category hierarchy",
                "product_attributes",
            ],
            num_proc=self._num_procs,
        )

    def generate_triplets(self, threshold=2):
        """Build triplets with a threshold tuned for this dataset's scale.

        NOTE(review): threshold=2 presumably keeps only the top relevance
        grade as positives (WANDS labels look like a small integer scale) —
        confirm against the upstream label definition before changing.
        """
        return super().generate_triplets(threshold=threshold)
0 commit comments