From c75925f9575ca3eda9d40ee577c90b4963e2bcc6 Mon Sep 17 00:00:00 2001 From: Nikita Kulikov Date: Tue, 3 Feb 2026 15:48:58 +0300 Subject: [PATCH 1/6] Composed loss --- replay/nn/loss/__init__.py | 2 + replay/nn/loss/base.py | 6 ++- replay/nn/loss/composed.py | 96 ++++++++++++++++++++++++++++++++++++++ replay/nn/loss/seen.py | 0 4 files changed, 102 insertions(+), 2 deletions(-) create mode 100644 replay/nn/loss/composed.py create mode 100644 replay/nn/loss/seen.py diff --git a/replay/nn/loss/__init__.py b/replay/nn/loss/__init__.py index 2116cac83..33721b35e 100644 --- a/replay/nn/loss/__init__.py +++ b/replay/nn/loss/__init__.py @@ -3,6 +3,7 @@ from .ce import CE, CESampled, CESampledWeighted, CEWeighted from .login_ce import LogInCE, LogInCESampled from .logout_ce import LogOutCE, LogOutCEWeighted +from .composed import ComposedLoss LogOutCESampled = CE @@ -10,6 +11,7 @@ "BCE", "CE", "BCESampled", + "ComposedLoss", "CESampled", "CESampledWeighted", "CEWeighted", diff --git a/replay/nn/loss/base.py b/replay/nn/loss/base.py index ef9bd88bb..b5b776bd6 100644 --- a/replay/nn/loss/base.py +++ b/replay/nn/loss/base.py @@ -4,6 +4,8 @@ from replay.data.nn import TensorMap +LogitsCallback = Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor] + class LossProto(Protocol): """Class-protocol for working with losses inside models""" @@ -11,10 +13,10 @@ class LossProto(Protocol): @property def logits_callback( self, - ) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]: ... + ) -> LogitsCallback: ... @logits_callback.setter - def logits_callback(self, func: Optional[Callable]) -> None: ... + def logits_callback(self, func: LogitsCallback) -> None: ... def forward( self, diff --git a/replay/nn/loss/composed.py b/replay/nn/loss/composed.py new file mode 100644 index 000000000..7de0a0371 --- /dev/null +++ b/replay/nn/loss/composed.py @@ -0,0 +1,96 @@ +import warnings +from typing import Optional, Self, cast + +import torch + +from replay.data.nn import TensorMap + +from .base import LogitsCallback, LossProto + +Weights = dict[str, torch.Tensor | float] + + +class ComposedLoss(torch.nn.Module): + def __init__( + self: Self, losses: dict[str, torch.nn.Module] | torch.nn.ModuleDict, weights: Weights | None = None + ) -> None: + super().__init__() + + if isinstance(losses, dict): + for loss in cast(dict, losses.values()): + if not isinstance(loss, torch.nn.Module): + msg: str = f"Unsupported type of loss. Must be `Module`. Got: {type(loss)=}." + raise TypeError(msg) + losses = torch.nn.ModuleDict(losses) + + if not isinstance(losses, torch.nn.ModuleDict): + msg: str = f"Unsupported type of `losses`. Must be `dict` or `ModuleDict`. Got {type(losses)=}." + raise TypeError(msg) + + if len(losses) < 1: + msg: str = "Empty losses are not supported." + raise ValueError(msg) + + self.losses: torch.nn.ModuleDict = cast(torch.nn.ModuleDict, losses) + + if weights is None: + weights = {} + + if not isinstance(weights, dict): + msg: str = f"Unsupported type of `weights`. Must be `dict`. Got: {type(weights)=}." + + for name, weight in cast(dict, weights): + if name not in self.losses: + msg: str = f"Unknown name of weight: {name}." + warnings.warn(msg, stacklevel=2) + continue + if isinstance(weight, float): + continue + elif isinstance(weight, torch.Tensor): + assert torch.is_tensor(weight) + continue + else: + msg: str = f"Unsupported type of weight value. Must be `float` or `Tensor`. Got: {type(weight)=}." + raise TypeError(msg) + + self.weights: dict[str, torch.Tensor | float] = cast(Weights, weights) + + self._logits_callback: Optional[LogitsCallback] = None + + @property + def logits_callback(self: Self) -> LogitsCallback: + if self._logits_callback is None: + msg: str = "No `logits_callback` provided" + raise NotImplementedError(msg) + return self._logits_callback + + @logits_callback.setter + def logits_callback(self: Self, func: LogitsCallback) -> None: + self._logits_callback = func + + for loss in self.losses.values(): + casted = cast(LossProto, loss) + casted.logits_callback = func + + def forward( + self: Self, + model_embeddings: torch.Tensor, + feature_tensors: TensorMap, + positive_labels: torch.LongTensor, + negative_labels: torch.LongTensor, + padding_mask: torch.BoolTensor, + target_padding_mask: torch.BoolTensor, + ) -> torch.Tensor: + losses = 0.0 + for name, loss in self.losses.items(): + loss_weight = self.weights.get(name, 1.0) + loss_value: torch.Tensor = loss( + model_embeddings, + feature_tensors, + positive_labels, + negative_labels, + padding_mask, + target_padding_mask, + ) + losses = losses + loss_weight * loss_value + return cast(torch.Tensor, losses) diff --git a/replay/nn/loss/seen.py b/replay/nn/loss/seen.py new file mode 100644 index 000000000..e69de29bb From 7655974ab280fda8e21db7c5c1feb351db837765 Mon Sep 17 00:00:00 2001 From: Nikita Kulikov Date: Wed, 4 Feb 2026 17:47:25 +0300 Subject: [PATCH 2/6] Working logging --- examples/09_sasrec_example.ipynb | 581 +++++++++++++++---------- examples/15_twotower_example.ipynb | 459 +++++++++++-------- replay/nn/lightning/module.py | 17 +- replay/nn/loss/__init__.py | 8 +- replay/nn/loss/base.py | 18 +- replay/nn/loss/bce.py | 23 +- replay/nn/loss/ce.py | 48 +- replay/nn/loss/composed.py | 26 +- replay/nn/loss/login_ce.py | 22 +- replay/nn/loss/logout_ce.py | 20 +- replay/nn/output.py | 3 + replay/nn/sequential/sasrec/agg.py | 6 +- replay/nn/sequential/sasrec/model.py | 15 +- replay/nn/sequential/twotower/model.py | 6 +- 14 files changed, 786 insertions(+), 466 deletions(-) diff --git a/examples/09_sasrec_example.ipynb b/examples/09_sasrec_example.ipynb index 14ab84808..ecc849764 100644 --- a/examples/09_sasrec_example.ipynb +++ b/examples/09_sasrec_example.ipynb @@ -10,7 +10,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [ { @@ -19,9 +19,21 @@ "text": [ "Seed set to 42\n" ] + }, + { + "data": { + "text/plain": [ + "42" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ + "from typing import Optional\n", + "\n", "import lightning as L\n", "import pandas as pd\n", "\n", @@ -46,7 +58,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -56,7 +68,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -105,15 +117,15 @@ " 2\n", " \n", " \n", - " 1000192\n", + " 1000007\n", " 6040\n", - " 2019\n", + " 1961\n", " 3\n", " \n", " \n", - " 1000007\n", + " 1000192\n", " 6040\n", - " 1961\n", + " 2019\n", " 4\n", " \n", " \n", @@ -135,15 +147,15 @@ " 447\n", " \n", " \n", - " 825731\n", + " 825724\n", " 4958\n", - " 2634\n", + " 3264\n", " 448\n", " \n", " \n", - " 825724\n", + " 825731\n", " 4958\n", - " 3264\n", + " 2634\n", " 449\n", " \n", " \n", @@ -162,19 +174,19 @@ "1000138 6040 858 0\n", "1000153 6040 2384 1\n", "999873 6040 593 2\n", - "1000192 6040 2019 3\n", - "1000007 6040 1961 4\n", + "1000007 6040 1961 3\n", + "1000192 6040 2019 4\n", "... ... ... ...\n", "825793 4958 2399 446\n", "825438 4958 1407 447\n", - "825731 4958 2634 448\n", - "825724 4958 3264 449\n", + "825724 4958 3264 448\n", + "825731 4958 2634 449\n", "825603 4958 1924 450\n", "\n", "[1000209 rows x 3 columns]" ] }, - "execution_count": 30, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -196,7 +208,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -228,32 +240,32 @@ " \n", " \n", " 0\n", - " 32\n", - " 0\n", + " 12\n", + " 2011\n", " 0\n", " \n", " \n", " 1\n", - " 10\n", - " 1\n", + " 68\n", + " 4078\n", " 0\n", " \n", " \n", " 2\n", - " 12\n", - " 2\n", + " 67\n", + " 4123\n", " 0\n", " \n", " \n", " 3\n", - " 339\n", - " 3\n", + " 12\n", + " 983\n", " 0\n", " \n", " \n", " 4\n", - " 144\n", - " 4\n", + " 140\n", + " 2270\n", " 0\n", " \n", " \n", @@ -264,32 +276,32 @@ " \n", " \n", " 1000204\n", - " 281\n", - " 796\n", + " 14\n", + " 855\n", " 3705\n", " \n", " \n", " 1000205\n", - " 209\n", - " 1297\n", + " 90\n", + " 1700\n", " 3705\n", " \n", " \n", " 1000206\n", - " 748\n", - " 1883\n", + " 70\n", + " 936\n", " 3705\n", " \n", " \n", " 1000207\n", - " 71\n", - " 4449\n", + " 25\n", + " 360\n", " 3705\n", " \n", " \n", " 1000208\n", - " 287\n", - " 2473\n", + " 380\n", + " 1388\n", " 3705\n", " \n", " \n", @@ -299,35 +311,36 @@ ], "text/plain": [ " timestamp user_id item_id\n", - "0 32 0 0\n", - "1 10 1 0\n", - "2 12 2 0\n", - "3 339 3 0\n", - "4 144 4 0\n", + "0 12 2011 0\n", + "1 68 4078 0\n", + "2 67 4123 0\n", + "3 12 983 0\n", + "4 140 2270 0\n", "... ... ... ...\n", - "1000204 281 796 3705\n", - "1000205 209 1297 3705\n", - "1000206 748 1883 3705\n", - "1000207 71 4449 3705\n", - "1000208 287 2473 3705\n", + "1000204 14 855 3705\n", + "1000205 90 1700 3705\n", + "1000206 70 936 3705\n", + "1000207 25 360 3705\n", + "1000208 380 1388 3705\n", "\n", "[1000209 rows x 3 columns]" ] }, - "execution_count": 31, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "from replay.preprocessing import LabelEncoder, LabelEncodingRule\n", + "from replay.preprocessing.label_encoder import LabelEncoder, LabelEncodingRule\n", "\n", "encoder = LabelEncoder(\n", " [\n", - " LabelEncodingRule(\"user_id\"),\n", - " LabelEncodingRule(\"item_id\"),\n", + " LabelEncodingRule(\"user_id\", default_value=\"last\"),\n", + " LabelEncodingRule(\"item_id\", default_value=\"last\"),\n", " ]\n", ")\n", + "interactions = interactions.sort_values(by=\"item_id\", ascending=True)\n", "encoded_interactions = encoder.fit_transform(interactions)\n", "encoded_interactions" ] @@ -343,7 +356,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -374,7 +387,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -388,7 +401,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -422,31 +435,31 @@ " 0\n", " 0\n", " [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...\n", - " [859, 309, 2371, 3442, 1108, 329, 367, 3279, 7...\n", + " [2969, 1574, 1178, 957, 2147, 1658, 3177, 1117...\n", " \n", " \n", " 1\n", " 1\n", " [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...\n", - " [354, 2426, 253, 1371, 513, 1184, 3131, 309, 8...\n", + " [1108, 1127, 1120, 2512, 1201, 2735, 1135, 110...\n", " \n", " \n", " 2\n", " 2\n", " [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...\n", - " [1008, 1120, 2439, 1066, 3197, 253, 1108, 1107...\n", + " [579, 2651, 3301, 1788, 1781, 1327, 1174, 3429...\n", " \n", " \n", " 3\n", " 3\n", " [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...\n", - " [346, 2190, 670, 802, 323, 661, 2480, 2501, 19...\n", + " [1120, 1025, 466, 3235, 3294, 1106, 253, 1108,...\n", " \n", " \n", " 4\n", " 4\n", " [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...\n", - " [1120, 758, 2426, 1838, 2621, 3341, 3377, 3502...\n", + " [2512, 858, 847, 346, 1158, 2007, 2651, 1050, ...\n", " \n", " \n", " ...\n", @@ -458,31 +471,31 @@ " 6035\n", " 6035\n", " [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...\n", - " [2426, 1279, 3321, 3151, 1178, 2501, 3301, 248...\n", + " [1574, 1703, 3206, 2183, 2235, 2480, 2375, 250...\n", " \n", " \n", " 6036\n", " 6036\n", " [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...\n", - " [1592, 2302, 1633, 1813, 2879, 1482, 2651, 200...\n", + " [1702, 672, 1175, 1848, 3275, 2932, 548, 802, ...\n", " \n", " \n", " 6037\n", " 6037\n", " [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...\n", - " [1971, 3500, 1666, 2077, 1399, 2748, 2958, 278...\n", + " [3165, 859, 1120, 1965, 1288, 346, 1007, 1066,...\n", " \n", " \n", " 6038\n", " 6038\n", " [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...\n", - " [1486, 1485, 3384, 3512, 3302, 3126, 3650, 330...\n", + " [107, 275, 1886, 1139, 869, 886, 2872, 2809, 2...\n", " \n", " \n", " 6039\n", " 6039\n", " [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...\n", - " [2432, 2960, 2114, 1848, 2142, 3248, 3091, 317...\n", + " [802, 2191, 579, 1781, 1839, 1316, 207, 2895, ...\n", " \n", " \n", "\n", @@ -504,22 +517,22 @@ "6039 6039 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", "\n", " item_id \n", - "0 [859, 309, 2371, 3442, 1108, 329, 367, 3279, 7... \n", - "1 [354, 2426, 253, 1371, 513, 1184, 3131, 309, 8... \n", - "2 [1008, 1120, 2439, 1066, 3197, 253, 1108, 1107... \n", - "3 [346, 2190, 670, 802, 323, 661, 2480, 2501, 19... \n", - "4 [1120, 758, 2426, 1838, 2621, 3341, 3377, 3502... \n", + "0 [2969, 1574, 1178, 957, 2147, 1658, 3177, 1117... \n", + "1 [1108, 1127, 1120, 2512, 1201, 2735, 1135, 110... \n", + "2 [579, 2651, 3301, 1788, 1781, 1327, 1174, 3429... \n", + "3 [1120, 1025, 466, 3235, 3294, 1106, 253, 1108,... \n", + "4 [2512, 858, 847, 346, 1158, 2007, 2651, 1050, ... \n", "... ... \n", - "6035 [2426, 1279, 3321, 3151, 1178, 2501, 3301, 248... \n", - "6036 [1592, 2302, 1633, 1813, 2879, 1482, 2651, 200... \n", - "6037 [1971, 3500, 1666, 2077, 1399, 2748, 2958, 278... \n", - "6038 [1486, 1485, 3384, 3512, 3302, 3126, 3650, 330... \n", - "6039 [2432, 2960, 2114, 1848, 2142, 3248, 3091, 317... \n", + "6035 [1574, 1703, 3206, 2183, 2235, 2480, 2375, 250... \n", + "6036 [1702, 672, 1175, 1848, 3275, 2932, 548, 802, ... \n", + "6037 [3165, 859, 1120, 1965, 1288, 346, 1007, 1066,... \n", + "6038 [107, 275, 1886, 1139, 869, 886, 2872, 2809, 2... \n", + "6039 [802, 2191, 579, 1781, 1839, 1316, 207, 2895, ... \n", "\n", "[6040 rows x 3 columns]" ] }, - "execution_count": 34, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -545,7 +558,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -561,7 +574,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -579,9 +592,20 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 10, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/nkulikov/RePlay/replay/preprocessing/label_encoder.py:964: UserWarning: There is already LabelEncoder object saved at the given path. File will be overwrited.\n", + " warnings.warn(msg)\n", + "/home/nkulikov/RePlay/replay/preprocessing/label_encoder.py:537: UserWarning: There is already LabelEncodingRule object saved at the given path. File will be overwrited.\n", + " warnings.warn(msg)\n" + ] + } + ], "source": [ "train_events.to_parquet(TRAIN_PATH)\n", "validation_events.to_parquet(VAL_PATH)\n", @@ -604,7 +628,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -659,7 +683,7 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -670,12 +694,10 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ - "from typing import Optional\n", - "\n", "MAX_SEQ_LEN = 50\n", "\n", "def create_meta(shape: int, gt_shape: Optional[int] = None):\n", @@ -696,9 +718,18 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 14, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_752239/1761019238.py:5: UserWarning: The following dataset paths aren't provided: test,predict.Make sure to disable these stages in your Lightning Trainer configuration.\n", + " parquet_module = ParquetModule(\n" + ] + } + ], "source": [ "from replay.data.nn import ParquetModule\n", "\n", @@ -730,7 +761,7 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -759,13 +790,19 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ + "from replay.nn.lightning.optimizer import OptimizerFactory\n", + "from replay.nn.lightning.scheduler import LRSchedulerFactory\n", "from replay.nn.lightning import LightningModule\n", "\n", - "model = LightningModule(sasrec)" + "model = LightningModule(\n", + " sasrec,\n", + " optimizer_factory=OptimizerFactory(),\n", + " lr_scheduler_factory=LRSchedulerFactory(),\n", + ")" ] }, { @@ -779,33 +816,36 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "GPU available: False, used: False\n", + "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", - "HPU available: False, using: 0 HPUs\n", + "You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n", + "/home/nkulikov/new_venv/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:881: Checkpoint directory /home/nkulikov/RePlay/examples/sasrec/checkpoints exists and is not empty.\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", "\n", - " | Name | Type | Params | Mode \n", - "-----------------------------------------\n", - "0 | model | SasRec | 291 K | train\n", - "-----------------------------------------\n", + " | Name | Type | Params | Mode | FLOPs\n", + "-------------------------------------------------\n", + "0 | model | SasRec | 291 K | train | 0 \n", + "-------------------------------------------------\n", "291 K Trainable params\n", "0 Non-trainable params\n", "291 K Total params\n", "1.164 Total estimated model params size (MB)\n", "39 Modules in train mode\n", - "0 Modules in eval mode\n" + "0 Modules in eval mode\n", + "0 Total Flops\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "a36bec8cdb384569a56b01dda4a8dce3", + "model_id": "92d2cf620e634e478ee28b385cacc445", "version_major": 2, "version_minor": 0 }, @@ -816,10 +856,19 @@ "metadata": {}, "output_type": "display_data" }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/nkulikov/new_venv/lib/python3.12/site-packages/lightning/pytorch/utilities/data.py:123: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.\n", + "/home/nkulikov/new_venv/lib/python3.12/site-packages/torch/nn/functional.py:6044: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.\n", + " warnings.warn(\n" + ] + }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "3788b1f40c8648e58741b90e70a66259", + "model_id": "e9a6ad3f3b8d45c0895b0a9f12e32f7f", "version_major": 2, "version_minor": 0 }, @@ -833,7 +882,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "cdc8dd8ed34441bc82f9d868f8cac7d3", + "model_id": "9716d4189b4a49b6b2cf5bfe81bec5a5", "version_major": 2, "version_minor": 0 }, @@ -848,7 +897,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Epoch 0, global step 189: 'recall@10' reached 0.03576 (best 0.03576), saving model to '/home/RePlay/examples/sasrec/checkpoints/epoch=0-step=189.ckpt' as top 1\n" + "Epoch 0, global step 189: 'recall@10' reached 0.03643 (best 0.03643), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=0-step=189.ckpt' as top 1\n" ] }, { @@ -856,16 +905,24 @@ "output_type": "stream", "text": [ "k 1 10 20 5\n", - "map 0.003808 0.010609 0.012859 0.008209\n", - "ndcg 0.003808 0.016343 0.024711 0.010422\n", - "recall 0.003808 0.035762 0.069205 0.017219\n", + "map 0.003643 0.011142 0.013291 0.009014\n", + "ndcg 0.003643 0.016971 0.024719 0.011761\n", + "recall 0.003643 0.036430 0.066898 0.020202\n", "\n" ] }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/nkulikov/new_venv/lib/python3.12/site-packages/torch/nn/functional.py:6044: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.\n", + " warnings.warn(\n" + ] + }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "6c8e2d8667a64bba9e4c05337d9a934c", + "model_id": "e1289d4fdf454f32b41c289ee7e1c2da", "version_major": 2, "version_minor": 0 }, @@ -880,7 +937,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Epoch 1, global step 378: 'recall@10' reached 0.08626 (best 0.08626), saving model to '/home/RePlay/examples/sasrec/checkpoints/epoch=1-step=378.ckpt' as top 1\n" + "Epoch 1, global step 378: 'recall@10' reached 0.08843 (best 0.08843), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=1-step=378.ckpt' as top 1\n" ] }, { @@ -888,16 +945,24 @@ "output_type": "stream", "text": [ "k 1 10 20 5\n", - "map 0.011258 0.028874 0.032790 0.024324\n", - "ndcg 0.011258 0.042089 0.056579 0.030685\n", - "recall 0.011258 0.086258 0.144040 0.050166\n", + "map 0.011095 0.028992 0.033167 0.024110\n", + "ndcg 0.011095 0.042700 0.058069 0.030709\n", + "recall 0.011095 0.088425 0.149528 0.051002\n", "\n" ] }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/nkulikov/new_venv/lib/python3.12/site-packages/torch/nn/functional.py:6044: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.\n", + " warnings.warn(\n" + ] + }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "4f29f2aa51ba4cf68af42cc95a453be5", + "model_id": "8a58c575bc3d4132915191f7e0621fdc", "version_major": 2, "version_minor": 0 }, @@ -912,7 +977,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Epoch 2, global step 567: 'recall@10' reached 0.12136 (best 0.12136), saving model to '/home/RePlay/examples/sasrec/checkpoints/epoch=2-step=567.ckpt' as top 1\n" + "Epoch 2, global step 567: 'recall@10' reached 0.12171 (best 0.12171), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=2-step=567.ckpt' as top 1\n" ] }, { @@ -920,16 +985,24 @@ "output_type": "stream", "text": [ "k 1 10 20 5\n", - "map 0.015894 0.040155 0.045144 0.033278\n", - "ndcg 0.015894 0.058849 0.077380 0.041941\n", - "recall 0.015894 0.121358 0.195364 0.068543\n", + "map 0.013413 0.038355 0.043417 0.031186\n", + "ndcg 0.013413 0.057557 0.076368 0.040041\n", + "recall 0.013413 0.121709 0.196887 0.067230\n", "\n" ] }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/nkulikov/new_venv/lib/python3.12/site-packages/torch/nn/functional.py:6044: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.\n", + " warnings.warn(\n" + ] + }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "a8c62495083f4b108e679d86d496e3e9", + "model_id": "2276b9086a564859a2af69c0a7eb4c1a", "version_major": 2, "version_minor": 0 }, @@ -944,24 +1017,32 @@ "name": "stderr", "output_type": "stream", "text": [ - "Epoch 3, global step 756: 'recall@10' reached 0.14106 (best 0.14106), saving model to '/home/RePlay/examples/sasrec/checkpoints/epoch=3-step=756.ckpt' as top 1\n" + "Epoch 3, global step 756: 'recall@10' reached 0.13562 (best 0.13562), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=3-step=756.ckpt' as top 1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "k 1 10 20 5\n", - "map 0.015232 0.045004 0.050748 0.037318\n", - "ndcg 0.015232 0.067207 0.088327 0.048297\n", - "recall 0.015232 0.141060 0.225000 0.081954\n", + "k 1 10 20 5\n", + "map 0.01275 0.041396 0.047141 0.033433\n", + "ndcg 0.01275 0.063121 0.084165 0.043618\n", + "recall 0.01275 0.135618 0.219076 0.074847\n", "\n" ] }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/nkulikov/new_venv/lib/python3.12/site-packages/torch/nn/functional.py:6044: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.\n", + " warnings.warn(\n" + ] + }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "e40e9de8673b44e3b451ccb9bdd84699", + "model_id": "3626b1e4f6f14561b885dac0bfe3bf56", "version_major": 2, "version_minor": 0 }, @@ -976,7 +1057,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Epoch 4, global step 945: 'recall@10' reached 0.15331 (best 0.15331), saving model to '/home/RePlay/examples/sasrec/checkpoints/epoch=4-step=945.ckpt' as top 1\n", + "Epoch 4, global step 945: 'recall@10' reached 0.14804 (best 0.14804), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=4-step=945-v1.ckpt' as top 1\n", "`Trainer.fit` stopped: `max_epochs=5` reached.\n" ] }, @@ -985,9 +1066,9 @@ "output_type": "stream", "text": [ "k 1 10 20 5\n", - "map 0.015728 0.048114 0.054519 0.039854\n", - "ndcg 0.015728 0.072437 0.095845 0.052167\n", - "recall 0.015728 0.153311 0.246026 0.090066\n", + "map 0.012916 0.044320 0.050978 0.035577\n", + "ndcg 0.012916 0.068219 0.092799 0.046770\n", + "recall 0.012916 0.148038 0.245902 0.081139\n", "\n" ] } @@ -1034,16 +1115,16 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "'/home/RePlay/examples/sasrec/checkpoints/epoch=4-step=945.ckpt'" + "'/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=4-step=945-v1.ckpt'" ] }, - "execution_count": 45, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } @@ -1065,10 +1146,13 @@ }, { "cell_type": "code", - "execution_count": 46, + "execution_count": 22, "metadata": {}, "outputs": [], "source": [ + "import torch\n", + "import replay\n", + "\n", "sasrec = SasRec.from_params(\n", " schema=tensor_schema,\n", " embedding_dim=EMBEDDING_DIM,\n", @@ -1079,6 +1163,11 @@ " excluded_features=None,\n", ")\n", "\n", + "torch.serialization.add_safe_globals([\n", + " replay.nn.lightning.optimizer.OptimizerFactory,\n", + " replay.nn.lightning.scheduler.LRSchedulerFactory,\n", + "])\n", + "\n", "best_model = LightningModule.load_from_checkpoint(best_model_path, model=sasrec)\n", "best_model.eval();" ] @@ -1092,9 +1181,18 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": 23, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_752239/1373759190.py:3: UserWarning: The following dataset paths aren't provided: train,validate,test.Make sure to disable these stages in your Lightning Trainer configuration.\n", + " parquet_module = ParquetModule(\n" + ] + } + ], "source": [ "inference_metadata = {\"predict\": create_meta(shape=MAX_SEQ_LEN)}\n", "\n", @@ -1118,22 +1216,30 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "GPU available: False, used: False\n", + "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", - "HPU available: False, using: 0 HPUs\n" + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "/home/nkulikov/new_venv/lib/python3.12/site-packages/lightning/pytorch/utilities/data.py:123: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "d93a289ce7f74b0f8b470cd48cb6882a", + "model_id": "5aa1ce20ed684f9b8b8a6b1ca5b9c62c", "version_major": 2, "version_minor": 0 }, @@ -1143,6 +1249,14 @@ }, "metadata": {}, "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/nkulikov/new_venv/lib/python3.12/site-packages/torch/nn/functional.py:6044: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.\n", + " warnings.warn(\n" + ] } ], "source": [ @@ -1167,7 +1281,7 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": 25, "metadata": {}, "outputs": [ { @@ -1200,32 +1314,32 @@ " \n", " 0\n", " 0\n", - " 3383\n", - " 6.106367\n", + " 354\n", + " 5.776358\n", " \n", " \n", " 0\n", " 0\n", - " 3509\n", - " 5.956742\n", + " 740\n", + " 5.754249\n", " \n", " \n", " 0\n", " 0\n", - " 3341\n", - " 5.944553\n", + " 574\n", + " 5.660694\n", " \n", " \n", " 0\n", " 0\n", - " 3510\n", - " 5.776233\n", + " 1439\n", + " 5.478026\n", " \n", " \n", " 0\n", " 0\n", - " 3512\n", - " 5.587171\n", + " 1551\n", + " 5.333538\n", " \n", " \n", " ...\n", @@ -1236,32 +1350,32 @@ " \n", " 6037\n", " 6039\n", - " 2941\n", - " 5.740072\n", + " 1656\n", + " 3.195577\n", " \n", " \n", " 6037\n", " 6039\n", - " 3049\n", - " 5.596299\n", + " 327\n", + " 3.120715\n", " \n", " \n", " 6037\n", " 6039\n", - " 2750\n", - " 5.548656\n", + " 1820\n", + " 3.099917\n", " \n", " \n", " 6037\n", " 6039\n", - " 2968\n", - " 5.302012\n", + " 515\n", + " 3.07936\n", " \n", " \n", " 6037\n", " 6039\n", - " 2202\n", - " 5.136523\n", + " 1306\n", + " 3.031433\n", " \n", " \n", "\n", @@ -1270,22 +1384,22 @@ ], "text/plain": [ " user_id item_id score\n", - "0 0 3383 6.106367\n", - "0 0 3509 5.956742\n", - "0 0 3341 5.944553\n", - "0 0 3510 5.776233\n", - "0 0 3512 5.587171\n", + "0 0 354 5.776358\n", + "0 0 740 5.754249\n", + "0 0 574 5.660694\n", + "0 0 1439 5.478026\n", + "0 0 1551 5.333538\n", "... ... ... ...\n", - "6037 6039 2941 5.740072\n", - "6037 6039 3049 5.596299\n", - "6037 6039 2750 5.548656\n", - "6037 6039 2968 5.302012\n", - "6037 6039 2202 5.136523\n", + "6037 6039 1656 3.195577\n", + "6037 6039 327 3.120715\n", + "6037 6039 1820 3.099917\n", + "6037 6039 515 3.07936\n", + "6037 6039 1306 3.031433\n", "\n", "[120760 rows x 3 columns]" ] }, - "execution_count": 49, + "execution_count": 25, "metadata": {}, "output_type": "execute_result" } @@ -1306,7 +1420,7 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": 26, "metadata": {}, "outputs": [], "source": [ @@ -1316,7 +1430,7 @@ }, { "cell_type": "code", - "execution_count": 51, + "execution_count": 27, "metadata": {}, "outputs": [], "source": [ @@ -1329,7 +1443,7 @@ }, { "cell_type": "code", - "execution_count": 52, + "execution_count": 28, "metadata": {}, "outputs": [ { @@ -1362,24 +1476,24 @@ " \n", " \n", " MAP\n", - " 0.013912\n", - " 0.046539\n", - " 0.052342\n", - " 0.038912\n", + " 0.015734\n", + " 0.047696\n", + " 0.053510\n", + " 0.039428\n", " \n", " \n", " Precision\n", - " 0.013912\n", - " 0.014492\n", - " 0.011494\n", - " 0.017456\n", + " 0.015734\n", + " 0.014790\n", + " 0.011651\n", + " 0.016959\n", " \n", " \n", " Recall\n", - " 0.013912\n", - " 0.144916\n", - " 0.229877\n", - " 0.087281\n", + " 0.015734\n", + " 0.147897\n", + " 0.233024\n", + " 0.084796\n", " \n", " \n", "\n", @@ -1387,12 +1501,12 @@ ], "text/plain": [ "k 1 10 20 5\n", - "MAP 0.013912 0.046539 0.052342 0.038912\n", - "Precision 0.013912 0.014492 0.011494 0.017456\n", - "Recall 0.013912 0.144916 0.229877 0.087281" + "MAP 0.015734 0.047696 0.053510 0.039428\n", + "Precision 0.015734 0.014790 0.011651 0.016959\n", + "Recall 0.015734 0.147897 0.233024 0.084796" ] }, - "execution_count": 52, + "execution_count": 28, "metadata": {}, "output_type": "execute_result" } @@ -1410,7 +1524,7 @@ }, { "cell_type": "code", - "execution_count": 53, + "execution_count": 29, "metadata": {}, "outputs": [ { @@ -1442,33 +1556,33 @@ " \n", " \n", " 0\n", - " 2011\n", - " 3623\n", - " 6.106367\n", + " 1\n", + " 364\n", + " 5.776358\n", " \n", " \n", " 0\n", - " 2011\n", - " 3752\n", - " 5.956742\n", + " 1\n", + " 783\n", + " 5.754249\n", " \n", " \n", " 0\n", - " 2011\n", - " 3578\n", - " 5.944553\n", + " 1\n", + " 588\n", + " 5.660694\n", " \n", " \n", " 0\n", - " 2011\n", - " 3753\n", - " 5.776233\n", + " 1\n", + " 1566\n", + " 5.478026\n", " \n", " \n", " 0\n", - " 2011\n", - " 3755\n", - " 5.587171\n", + " 1\n", + " 1688\n", + " 5.333538\n", " \n", " \n", " ...\n", @@ -1478,33 +1592,33 @@ " \n", " \n", " 6037\n", - " 5727\n", - " 3157\n", - " 5.740072\n", + " 6040\n", + " 1834\n", + " 3.195577\n", " \n", " \n", " 6037\n", - " 5727\n", - " 3273\n", - " 5.596299\n", + " 6040\n", + " 337\n", + " 3.120715\n", " \n", " \n", " 6037\n", - " 5727\n", - " 2961\n", - " 5.548656\n", + " 6040\n", + " 2000\n", + " 3.099917\n", " \n", " \n", " 6037\n", - " 5727\n", - " 3185\n", - " 5.302012\n", + " 6040\n", + " 529\n", + " 3.07936\n", " \n", " \n", " 6037\n", - " 5727\n", - " 2395\n", - " 5.136523\n", + " 6040\n", + " 1408\n", + " 3.031433\n", " \n", " \n", "\n", @@ -1513,22 +1627,22 @@ ], "text/plain": [ " user_id item_id score\n", - "0 2011 3623 6.106367\n", - "0 2011 3752 5.956742\n", - "0 2011 3578 5.944553\n", - "0 2011 3753 5.776233\n", - "0 2011 3755 5.587171\n", + "0 1 364 5.776358\n", + "0 1 783 5.754249\n", + "0 1 588 5.660694\n", + "0 1 1566 5.478026\n", + "0 1 1688 5.333538\n", "... ... ... ...\n", - "6037 5727 3157 5.740072\n", - "6037 5727 3273 5.596299\n", - "6037 5727 2961 5.548656\n", - "6037 5727 3185 5.302012\n", - "6037 5727 2395 5.136523\n", + "6037 6040 1834 3.195577\n", + "6037 6040 337 3.120715\n", + "6037 6040 2000 3.099917\n", + "6037 6040 529 3.07936\n", + "6037 6040 1408 3.031433\n", "\n", "[120760 rows x 3 columns]" ] }, - "execution_count": 53, + "execution_count": 29, "metadata": {}, "output_type": "execute_result" } @@ -1536,11 +1650,18 @@ "source": [ "encoder.inverse_transform(pandas_res)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { "kernelspec": { - "display_name": "rep", + "display_name": "new_venv", "language": "python", "name": "python3" }, @@ -1554,7 +1675,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.14" + "version": "3.12.3" }, "orig_nbformat": 4 }, diff --git a/examples/15_twotower_example.ipynb b/examples/15_twotower_example.ipynb index b189eebb3..9164b0a36 100644 --- a/examples/15_twotower_example.ipynb +++ b/examples/15_twotower_example.ipynb @@ -118,15 +118,15 @@ " 2\n", " \n", " \n", - " 1000192\n", + " 1000007\n", " 6040\n", - " 2019\n", + " 1961\n", " 3\n", " \n", " \n", - " 1000007\n", + " 1000192\n", " 6040\n", - " 1961\n", + " 2019\n", " 4\n", " \n", " \n", @@ -148,15 +148,15 @@ " 447\n", " \n", " \n", - " 825731\n", + " 825724\n", " 4958\n", - " 2634\n", + " 3264\n", " 448\n", " \n", " \n", - " 825724\n", + " 825731\n", " 4958\n", - " 3264\n", + " 2634\n", " 449\n", " \n", " \n", @@ -175,13 +175,13 @@ "1000138 6040 858 0\n", "1000153 6040 2384 1\n", "999873 6040 593 2\n", - "1000192 6040 2019 3\n", - "1000007 6040 1961 4\n", + "1000007 6040 1961 3\n", + "1000192 6040 2019 4\n", "... ... ... ...\n", "825793 4958 2399 446\n", "825438 4958 1407 447\n", - "825731 4958 2634 448\n", - "825724 4958 3264 449\n", + "825724 4958 3264 448\n", + "825731 4958 2634 449\n", "825603 4958 1924 450\n", "\n", "[1000209 rows x 3 columns]" @@ -211,7 +211,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -263,13 +263,13 @@ " 3\n", " 3\n", " 6039\n", - " 1950\n", + " 1892\n", " \n", " \n", " 4\n", " 4\n", " 6039\n", - " 1892\n", + " 1950\n", " \n", " \n", " ...\n", @@ -293,13 +293,13 @@ " 1000206\n", " 448\n", " 4957\n", - " 2565\n", + " 3195\n", " \n", " \n", " 1000207\n", " 449\n", " 4957\n", - " 3195\n", + " 2565\n", " \n", " \n", " 1000208\n", @@ -317,13 +317,13 @@ "0 0 6039 847\n", "1 1 6039 2315\n", "2 2 6039 589\n", - "3 3 6039 1950\n", - "4 4 6039 1892\n", + "3 3 6039 1892\n", + "4 4 6039 1950\n", "... ... ... ...\n", "1000204 446 4957 2330\n", "1000205 447 4957 1384\n", - "1000206 448 4957 2565\n", - "1000207 449 4957 3195\n", + "1000206 448 4957 3195\n", + "1000207 449 4957 2565\n", "1000208 450 4957 1855\n", "\n", "[1000209 rows x 3 columns]" @@ -531,25 +531,25 @@ " 0\n", " 0\n", " [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...\n", - " [3117, 1250, 1009, 1672, 2271, 1768, 3339, 118...\n", + " [3117, 1672, 1250, 1009, 2271, 1768, 3339, 118...\n", " \n", " \n", " 1\n", " 1\n", " [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...\n", - " [1180, 1192, 1199, 2648, 1273, 2874, 1207, 315...\n", + " [1180, 1199, 1192, 2648, 1273, 2874, 1207, 117...\n", " \n", " \n", " 2\n", " 2\n", " [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...\n", - " [589, 2789, 1899, 3465, 1407, 1892, 1246, 1358...\n", + " [589, 2789, 3465, 1899, 1892, 1407, 1246, 3602...\n", " \n", " \n", " 3\n", " 3\n", " [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...\n", - " [1192, 1081, 3458, 476, 3399, 257, 1180, 1178,...\n", + " [1192, 1081, 476, 3399, 3458, 1178, 257, 1180,...\n", " \n", " \n", " 4\n", @@ -567,13 +567,13 @@ " 6035\n", " 6035\n", " [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...\n", - " [1672, 1814, 3369, 2307, 2359, 2503, 2423, 278...\n", + " [1672, 1814, 3369, 2307, 2359, 2614, 2503, 263...\n", " \n", " \n", " 6036\n", " 6036\n", " [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...\n", - " [693, 1813, 3439, 1959, 1247, 558, 847, 3079, ...\n", + " [1813, 693, 1247, 1959, 3439, 3079, 558, 847, ...\n", " \n", " \n", " 6037\n", @@ -585,13 +585,13 @@ " 6038\n", " 6038\n", " [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...\n", - " [109, 279, 1998, 1211, 918, 3064, 935, 3019, 2...\n", + " [109, 279, 1998, 1211, 918, 935, 3019, 2953, 3...\n", " \n", " \n", " 6039\n", " 6039\n", " [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...\n", - " [847, 2315, 589, 1950, 1892, 3042, 211, 3436, ...\n", + " [847, 2315, 589, 1892, 1950, 1395, 211, 3042, ...\n", " \n", " \n", "\n", @@ -613,17 +613,17 @@ "6039 6039 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", "\n", " item_id \n", - "0 [3117, 1250, 1009, 1672, 2271, 1768, 3339, 118... \n", - "1 [1180, 1192, 1199, 2648, 1273, 2874, 1207, 315... \n", - "2 [589, 2789, 1899, 3465, 1407, 1892, 1246, 1358... \n", - "3 [1192, 1081, 3458, 476, 3399, 257, 1180, 1178,... \n", + "0 [3117, 1672, 1250, 1009, 2271, 1768, 3339, 118... \n", + "1 [1180, 1199, 1192, 2648, 1273, 2874, 1207, 117... \n", + "2 [589, 2789, 3465, 1899, 1892, 1407, 1246, 3602... \n", + "3 [1192, 1081, 476, 3399, 3458, 1178, 257, 1180,... \n", "4 [2648, 907, 896, 352, 1230, 2119, 2789, 1111, ... \n", "... ... \n", - "6035 [1672, 1814, 3369, 2307, 2359, 2503, 2423, 278... \n", - "6036 [693, 1813, 3439, 1959, 1247, 558, 847, 3079, ... \n", + "6035 [1672, 1814, 3369, 2307, 2359, 2614, 2503, 263... \n", + "6036 [1813, 693, 1247, 1959, 3439, 3079, 558, 847, ... \n", "6037 [3327, 908, 1192, 2077, 1366, 352, 1063, 1132,... \n", - "6038 [109, 279, 1998, 1211, 918, 3064, 935, 3019, 2... \n", - "6039 [847, 2315, 589, 1950, 1892, 3042, 211, 3436, ... \n", + "6038 [109, 279, 1998, 1211, 918, 935, 3019, 2953, 3... \n", + "6039 [847, 2315, 589, 1892, 1950, 1395, 211, 3042, ... \n", "\n", "[6040 rows x 3 columns]" ] @@ -700,7 +700,18 @@ "cell_type": "code", "execution_count": 10, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/nkulikov/RePlay/replay/preprocessing/label_encoder.py:964: UserWarning: There is already LabelEncoder object saved at the given path. File will be overwrited.\n", + " warnings.warn(msg)\n", + "/home/nkulikov/RePlay/replay/preprocessing/label_encoder.py:537: UserWarning: There is already LabelEncodingRule object saved at the given path. File will be overwrited.\n", + " warnings.warn(msg)\n" + ] + } + ], "source": [ "train_events.to_parquet(TRAIN_PATH)\n", "validation_events.to_parquet(VAL_PATH)\n", @@ -725,7 +736,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -781,7 +792,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -792,7 +803,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -820,7 +831,16 @@ "cell_type": "code", "execution_count": 14, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_766094/1761019238.py:5: UserWarning: The following dataset paths aren't provided: test,predict.Make sure to disable these stages in your Lightning Trainer configuration.\n", + " parquet_module = ParquetModule(\n" + ] + } + ], "source": [ "from replay.data.nn import ParquetModule\n", "\n", @@ -853,7 +873,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -888,7 +908,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ @@ -915,26 +935,28 @@ "name": "stderr", "output_type": "stream", "text": [ - "GPU available: False, used: False\n", + "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", - "HPU available: False, using: 0 HPUs\n", + "You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", "\n", - " | Name | Type | Params | Mode \n", - "-------------------------------------------\n", - "0 | model | TwoTower | 340 K | train\n", - "-------------------------------------------\n", - "340 K Trainable params\n", + " | Name | Type | Params | Mode | FLOPs\n", + "---------------------------------------------------\n", + "0 | model | TwoTower | 352 K | train | 0 \n", + "---------------------------------------------------\n", + "352 K Trainable params\n", "0 Non-trainable params\n", - "340 K Total params\n", - "1.364 Total estimated model params size (MB)\n", + "352 K Total params\n", + "1.409 Total estimated model params size (MB)\n", "52 Modules in train mode\n", - "0 Modules in eval mode\n" + "0 Modules in eval mode\n", + "0 Total Flops\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "32da5d7c8a1a4a2abd200e4c44472b61", + "model_id": "108862bbd4de48d2804f302e5cf4e997", "version_major": 2, "version_minor": 0 }, @@ -945,10 +967,19 @@ "metadata": {}, "output_type": "display_data" }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/nkulikov/new_venv/lib/python3.12/site-packages/lightning/pytorch/utilities/data.py:123: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.\n", + "/home/nkulikov/new_venv/lib/python3.12/site-packages/torch/nn/functional.py:6044: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.\n", + " warnings.warn(\n" + ] + }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "959fabb6e3ca4ed38097cd7128944bc8", + "model_id": "fb965f6b1006499fbfa5de8c8ed96e5b", "version_major": 2, "version_minor": 0 }, @@ -962,7 +993,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "3462811aac8b4003aaa61c4cf9fcc4a0", + "model_id": "bc8d696ab44d4a87ac4dcbf29e2250f2", "version_major": 2, "version_minor": 0 }, @@ -977,24 +1008,32 @@ "name": "stderr", "output_type": "stream", "text": [ - "Epoch 0, global step 189: 'recall@10' reached 0.03526 (best 0.03526), saving model to '/home/RePlay/examples/twotower/checkpoints/epoch=0-step=189.ckpt' as top 1\n" + "Epoch 0, global step 189: 'recall@10' reached 0.03958 (best 0.03958), saving model to '/home/nkulikov/RePlay/examples/twotower/checkpoints/epoch=0-step=189.ckpt' as top 1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "k 1 5 10 20\n", - "map 0.004801 0.009509 0.011596 0.013299\n", - "ndcg 0.004801 0.011889 0.017026 0.023337\n", - "recall 0.004801 0.019205 0.035265 0.060430\n", + "k 1 10 20 5\n", + "map 0.003643 0.011629 0.013674 0.009331\n", + "ndcg 0.003643 0.018055 0.025700 0.012422\n", + "recall 0.003643 0.039576 0.070210 0.022024\n", "\n" ] }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/nkulikov/new_venv/lib/python3.12/site-packages/torch/nn/functional.py:6044: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.\n", + " warnings.warn(\n" + ] + }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "f8dacd44d48b486c9d1217c937b2584d", + "model_id": "4d8c41219aac45d39de0bb26313a9239", "version_major": 2, "version_minor": 0 }, @@ -1009,24 +1048,32 @@ "name": "stderr", "output_type": "stream", "text": [ - "Epoch 1, global step 378: 'recall@10' reached 0.09603 (best 0.09603), saving model to '/home/RePlay/examples/twotower/checkpoints/epoch=1-step=378.ckpt' as top 1\n" + "Epoch 1, global step 378: 'recall@10' reached 0.10184 (best 0.10184), saving model to '/home/nkulikov/RePlay/examples/twotower/checkpoints/epoch=1-step=378.ckpt' as top 1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "k 1 5 10 20\n", - "map 0.012086 0.026440 0.031724 0.036210\n", - "ndcg 0.012086 0.033518 0.046537 0.063113\n", - "recall 0.012086 0.055298 0.096026 0.162086\n", + "k 1 10 20 5\n", + "map 0.010763 0.033238 0.037375 0.028159\n", + "ndcg 0.010763 0.049178 0.064472 0.036708\n", + "recall 0.010763 0.101838 0.162775 0.062924\n", "\n" ] }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/nkulikov/new_venv/lib/python3.12/site-packages/torch/nn/functional.py:6044: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.\n", + " warnings.warn(\n" + ] + }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "421acec60c4347ae98e0df5157d69ead", + "model_id": "2ac5ed9b53d144e1a776759aa104d5cd", "version_major": 2, "version_minor": 0 }, @@ -1041,24 +1088,32 @@ "name": "stderr", "output_type": "stream", "text": [ - "Epoch 2, global step 567: 'recall@10' reached 0.13576 (best 0.13576), saving model to '/home/RePlay/examples/twotower/checkpoints/epoch=2-step=567.ckpt' as top 1\n" + "Epoch 2, global step 567: 'recall@10' reached 0.13280 (best 0.13280), saving model to '/home/nkulikov/RePlay/examples/twotower/checkpoints/epoch=2-step=567.ckpt' as top 1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "k 1 5 10 20\n", - "map 0.012252 0.034048 0.041749 0.047579\n", - "ndcg 0.012252 0.044626 0.063492 0.084913\n", - "recall 0.012252 0.076987 0.135762 0.220861\n", + "k 1 10 20 5\n", + "map 0.012916 0.041789 0.047532 0.034280\n", + "ndcg 0.012916 0.062866 0.084155 0.044635\n", + "recall 0.012916 0.132803 0.217751 0.076337\n", "\n" ] }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/nkulikov/new_venv/lib/python3.12/site-packages/torch/nn/functional.py:6044: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.\n", + " warnings.warn(\n" + ] + }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "55449d2ab210406b89fea8f15deb28f9", + "model_id": "3ae8fabaec204088a33a973536c4c366", "version_major": 2, "version_minor": 0 }, @@ -1073,24 +1128,32 @@ "name": "stderr", "output_type": "stream", "text": [ - "Epoch 3, global step 756: 'recall@10' reached 0.15795 (best 0.15795), saving model to '/home/RePlay/examples/twotower/checkpoints/epoch=3-step=756.ckpt' as top 1\n" + "Epoch 3, global step 756: 'recall@10' reached 0.14920 (best 0.14920), saving model to '/home/nkulikov/RePlay/examples/twotower/checkpoints/epoch=3-step=756.ckpt' as top 1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "k 1 5 10 20\n", - "map 0.014735 0.040450 0.049221 0.055306\n", - "ndcg 0.014735 0.052765 0.074364 0.096740\n", - "recall 0.014735 0.090397 0.157947 0.246854\n", + "k 1 10 20 5\n", + "map 0.01391 0.046216 0.052732 0.037647\n", + "ndcg 0.01391 0.070051 0.093930 0.049280\n", + "recall 0.01391 0.149197 0.243915 0.084948\n", "\n" ] }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/nkulikov/new_venv/lib/python3.12/site-packages/torch/nn/functional.py:6044: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.\n", + " warnings.warn(\n" + ] + }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "5bbb1d30778b42a6ace5f22c50fdec88", + "model_id": "4f25edcbbd5842349605a5dbfa0275ed", "version_major": 2, "version_minor": 0 }, @@ -1105,7 +1168,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Epoch 4, global step 945: 'recall@10' reached 0.17053 (best 0.17053), saving model to '/home/RePlay/examples/twotower/checkpoints/epoch=4-step=945.ckpt' as top 1\n", + "Epoch 4, global step 945: 'recall@10' reached 0.16079 (best 0.16079), saving model to '/home/nkulikov/RePlay/examples/twotower/checkpoints/epoch=4-step=945.ckpt' as top 1\n", "`Trainer.fit` stopped: `max_epochs=5` reached.\n" ] }, @@ -1113,10 +1176,10 @@ "name": "stdout", "output_type": "stream", "text": [ - "k 1 5 10 20\n", - "map 0.019205 0.045563 0.054963 0.061259\n", - "ndcg 0.019205 0.058695 0.081677 0.104973\n", - "recall 0.019205 0.099007 0.170530 0.263411\n", + "k 1 10 20 5\n", + "map 0.013247 0.047682 0.054643 0.038307\n", + "ndcg 0.013247 0.073782 0.099453 0.050853\n", + "recall 0.013247 0.160788 0.262957 0.089419\n", "\n" ] } @@ -1169,7 +1232,7 @@ { "data": { "text/plain": [ - "'/home/RePlay/examples/twotower/checkpoints/epoch=4-step=945.ckpt'" + "'/home/nkulikov/RePlay/examples/twotower/checkpoints/epoch=4-step=945.ckpt'" ] }, "execution_count": 18, @@ -1194,10 +1257,13 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ + "import torch\n", + "import replay\n", + "\n", "twotower = TwoTower.from_params(\n", " schema=tensor_schema,\n", " embedding_dim=EMBEDDING_DIM,\n", @@ -1212,6 +1278,11 @@ " )\n", ")\n", "\n", + "torch.serialization.add_safe_globals([\n", + " replay.nn.lightning.optimizer.OptimizerFactory,\n", + " replay.nn.lightning.scheduler.LRSchedulerFactory,\n", + "])\n", + "\n", "best_model = LightningModule.load_from_checkpoint(best_model_path, model=twotower)\n", "best_model.eval();" ] @@ -1227,7 +1298,16 @@ "cell_type": "code", "execution_count": 20, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_766094/2227877614.py:3: UserWarning: The following dataset paths aren't provided: train,validate,test.Make sure to disable these stages in your Lightning Trainer configuration.\n", + " parquet_module = ParquetModule(\n" + ] + } + ], "source": [ "inference_metadata = {\"predict\": create_meta(shape=MAX_SEQ_LEN)}\n", "\n", @@ -1258,15 +1338,17 @@ "name": "stderr", "output_type": "stream", "text": [ - "GPU available: False, used: False\n", + "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n", + "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", - "HPU available: False, using: 0 HPUs\n" + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "/home/nkulikov/new_venv/lib/python3.12/site-packages/lightning/pytorch/utilities/data.py:123: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "61636a98f63042a4bfce791a7dd2d01b", + "model_id": "c03aa242b0b3403fb6405c310f7d505d", "version_major": 2, "version_minor": 0 }, @@ -1276,6 +1358,14 @@ }, "metadata": {}, "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/nkulikov/new_venv/lib/python3.12/site-packages/torch/nn/functional.py:6044: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.\n", + " warnings.warn(\n" + ] } ], "source": [ @@ -1333,32 +1423,32 @@ " \n", " 0\n", " 0\n", - " 3383\n", - " 41.500851\n", + " 773\n", + " 26.849968\n", " \n", " \n", " 0\n", " 0\n", - " 3509\n", - " 40.856842\n", + " 360\n", + " 26.444904\n", " \n", " \n", " 0\n", " 0\n", - " 3475\n", - " 40.794147\n", + " 1526\n", + " 26.433756\n", " \n", " \n", " 0\n", " 0\n", - " 3512\n", - " 40.674751\n", + " 2618\n", + " 26.282482\n", " \n", " \n", " 0\n", " 0\n", - " 3341\n", - " 40.665249\n", + " 1838\n", + " 26.202333\n", " \n", " \n", " ...\n", @@ -1369,32 +1459,32 @@ " \n", " 6037\n", " 6039\n", - " 3550\n", - " 26.971941\n", + " 1680\n", + " 26.977921\n", " \n", " \n", " 6037\n", " 6039\n", - " 1490\n", - " 26.874283\n", + " 1375\n", + " 26.976624\n", " \n", " \n", " 6037\n", " 6039\n", - " 2489\n", - " 26.785303\n", + " 2125\n", + " 26.967751\n", " \n", " \n", " 6037\n", " 6039\n", - " 2634\n", - " 26.725142\n", + " 1439\n", + " 26.963537\n", " \n", " \n", " 6037\n", " 6039\n", - " 2203\n", - " 26.696005\n", + " 623\n", + " 26.949759\n", " \n", " \n", "\n", @@ -1403,17 +1493,17 @@ ], "text/plain": [ " user_id item_id score\n", - "0 0 3383 41.500851\n", - "0 0 3509 40.856842\n", - "0 0 3475 40.794147\n", - "0 0 3512 40.674751\n", - "0 0 3341 40.665249\n", + "0 0 773 26.849968\n", + "0 0 360 26.444904\n", + "0 0 1526 26.433756\n", + "0 0 2618 26.282482\n", + "0 0 1838 26.202333\n", "... ... ... ...\n", - "6037 6039 3550 26.971941\n", - "6037 6039 1490 26.874283\n", - "6037 6039 2489 26.785303\n", - "6037 6039 2634 26.725142\n", - "6037 6039 2203 26.696005\n", + "6037 6039 1680 26.977921\n", + "6037 6039 1375 26.976624\n", + "6037 6039 2125 26.967751\n", + "6037 6039 1439 26.963537\n", + "6037 6039 623 26.949759\n", "\n", "[120760 rows x 3 columns]" ] @@ -1487,42 +1577,42 @@ " \n", " k\n", " 1\n", - " 5\n", " 10\n", " 20\n", + " 5\n", " \n", " \n", " \n", " \n", " MAP\n", - " 0.016727\n", - " 0.041714\n", - " 0.050679\n", - " 0.056809\n", + " 0.016893\n", + " 0.050191\n", + " 0.056355\n", + " 0.041200\n", " \n", " \n", " Precision\n", - " 0.016727\n", - " 0.018152\n", - " 0.015982\n", - " 0.012421\n", + " 0.016893\n", + " 0.015916\n", + " 0.012479\n", + " 0.018085\n", " \n", " \n", " Recall\n", - " 0.016727\n", - " 0.090759\n", - " 0.159821\n", - " 0.248427\n", + " 0.016893\n", + " 0.159159\n", + " 0.249586\n", + " 0.090427\n", " \n", " \n", "\n", "" ], "text/plain": [ - "k 1 5 10 20\n", - "MAP 0.016727 0.041714 0.050679 0.056809\n", - "Precision 0.016727 0.018152 0.015982 0.012421\n", - "Recall 0.016727 0.090759 0.159821 0.248427" + "k 1 10 20 5\n", + "MAP 0.016893 0.050191 0.056355 0.041200\n", + "Precision 0.016893 0.015916 0.012479 0.018085\n", + "Recall 0.016893 0.159159 0.249586 0.090427" ] }, "execution_count": 25, @@ -1575,33 +1665,33 @@ " \n", " \n", " 0\n", - " 2011\n", - " 3623\n", - " 41.500851\n", + " 1\n", + " 783\n", + " 26.849968\n", " \n", " \n", " 0\n", - " 2011\n", - " 3752\n", - " 40.856842\n", + " 1\n", + " 364\n", + " 26.444904\n", " \n", " \n", " 0\n", - " 2011\n", - " 3717\n", - " 40.794147\n", + " 1\n", + " 1566\n", + " 26.433756\n", " \n", " \n", " 0\n", - " 2011\n", - " 3755\n", - " 40.674751\n", + " 1\n", + " 2687\n", + " 26.282482\n", " \n", " \n", " 0\n", - " 2011\n", - " 3578\n", - " 40.665249\n", + " 1\n", + " 1907\n", + " 26.202333\n", " \n", " \n", " ...\n", @@ -1611,33 +1701,33 @@ " \n", " \n", " 6037\n", - " 5727\n", - " 3793\n", - " 26.971941\n", + " 6040\n", + " 1729\n", + " 26.977921\n", " \n", " \n", " 6037\n", - " 5727\n", - " 1623\n", - " 26.874283\n", + " 6040\n", + " 1396\n", + " 26.976624\n", " \n", " \n", " 6037\n", - " 5727\n", - " 2693\n", - " 26.785303\n", + " 6040\n", + " 2194\n", + " 26.967751\n", " \n", " \n", " 6037\n", - " 5727\n", - " 2841\n", - " 26.725142\n", + " 6040\n", + " 1466\n", + " 26.963537\n", " \n", " \n", " 6037\n", - " 5727\n", - " 2396\n", - " 26.696005\n", + " 6040\n", + " 628\n", + " 26.949759\n", " \n", " \n", "\n", @@ -1646,17 +1736,17 @@ ], "text/plain": [ " user_id item_id score\n", - "0 2011 3623 41.500851\n", - "0 2011 3752 40.856842\n", - "0 2011 3717 40.794147\n", - "0 2011 3755 40.674751\n", - "0 2011 3578 40.665249\n", + "0 1 783 26.849968\n", + "0 1 364 26.444904\n", + "0 1 1566 26.433756\n", + "0 1 2687 26.282482\n", + "0 1 1907 26.202333\n", "... ... ... ...\n", - "6037 5727 3793 26.971941\n", - "6037 5727 1623 26.874283\n", - "6037 5727 2693 26.785303\n", - "6037 5727 2841 26.725142\n", - "6037 5727 2396 26.696005\n", + "6037 6040 1729 26.977921\n", + "6037 6040 1396 26.976624\n", + "6037 6040 2194 26.967751\n", + "6037 6040 1466 26.963537\n", + "6037 6040 628 26.949759\n", "\n", "[120760 rows x 3 columns]" ] @@ -1669,11 +1759,18 @@ "source": [ "encoder.inverse_transform(pandas_res)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { "kernelspec": { - "display_name": "rep", + "display_name": "new_venv", "language": "python", "name": "python3" }, @@ -1687,7 +1784,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.14" + "version": "3.12.3" }, "orig_nbformat": 4 }, diff --git a/replay/nn/lightning/module.py b/replay/nn/lightning/module.py index ba6d237b6..ad5f837cc 100644 --- a/replay/nn/lightning/module.py +++ b/replay/nn/lightning/module.py @@ -40,7 +40,7 @@ def __init__( self._lr_scheduler_factory = lr_scheduler_factory self.candidates_to_score = None - def forward(self, batch: dict) -> Union[TrainOutput, InferenceOutput]: + def forward(self, batch: dict, return_info: bool = False) -> Union[TrainOutput, InferenceOutput]: """ Implementation of the forward function. @@ -57,12 +57,21 @@ def forward(self, batch: dict) -> Union[TrainOutput, InferenceOutput]: batch["candidates_to_score"] = self.candidates_to_score # select only args for model.forward modified_batch = {k: v for k, v in batch.items() if k in inspect.signature(self.model.forward).parameters} - return self.model(**modified_batch) + return self.model(**modified_batch, return_info=return_info) def training_step(self, batch: dict) -> torch.Tensor: - model_output: TrainOutput = self(batch) - loss = model_output["loss"] + model_output: TrainOutput = self(batch, return_info=True) + loss, info = model_output["loss"], model_output.get("info", None) lr = self.optimizers().param_groups[0]["lr"] # Get current learning rate + if info is not None: + assert isinstance(info, dict) + self.log_dict( + dictionary=info, + on_step=True, + on_epoch=True, + prog_bar=True, + sync_dist=True, + ) self.log("learning_rate", lr, on_step=True, on_epoch=True, prog_bar=True) self.log( "train_loss", diff --git a/replay/nn/loss/__init__.py b/replay/nn/loss/__init__.py index 33721b35e..eed615375 100644 --- a/replay/nn/loss/__init__.py +++ b/replay/nn/loss/__init__.py @@ -1,9 +1,9 @@ -from .base import LossProto +from .base import LossInfo, LossOutput, LossProto from .bce import BCE, BCESampled from .ce import CE, CESampled, CESampledWeighted, CEWeighted +from .composed import ComposedLoss from .login_ce import LogInCE, LogInCESampled from .logout_ce import LogOutCE, LogOutCEWeighted -from .composed import ComposedLoss LogOutCESampled = CE @@ -11,14 +11,16 @@ "BCE", "CE", "BCESampled", - "ComposedLoss", "CESampled", "CESampledWeighted", "CEWeighted", + "ComposedLoss", "LogInCE", "LogInCESampled", "LogOutCE", "LogOutCESampled", "LogOutCEWeighted", + "LossInfo", + "LossOutput", "LossProto", ] diff --git a/replay/nn/loss/base.py b/replay/nn/loss/base.py index b5b776bd6..841c45d60 100644 --- a/replay/nn/loss/base.py +++ b/replay/nn/loss/base.py @@ -4,6 +4,8 @@ from replay.data.nn import TensorMap +LossInfo = dict[str, torch.Tensor | float] +LossOutput = tuple[torch.Tensor, None] | tuple[torch.Tensor, LossInfo] LogitsCallback = Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor] @@ -26,7 +28,19 @@ def forward( negative_labels: torch.LongTensor, padding_mask: torch.BoolTensor, target_padding_mask: torch.BoolTensor, - ) -> torch.Tensor: ... + return_info: bool = False, + ) -> LossOutput: ... + + def __call__( + self, + model_embeddings: torch.Tensor, + feature_tensors: TensorMap, + positive_labels: torch.LongTensor, + negative_labels: torch.LongTensor, + padding_mask: torch.BoolTensor, + target_padding_mask: torch.BoolTensor, + return_info: bool = False, + ) -> LossOutput: ... class SampledLossOutput(TypedDict): @@ -41,6 +55,8 @@ class SampledLossOutput(TypedDict): class SampledLossBase(torch.nn.Module): """The base class for calculating sampled losses""" + _logits_callback: LogitsCallback | None + @property def logits_callback( self, diff --git a/replay/nn/loss/bce.py b/replay/nn/loss/bce.py index 2f5901d76..7e354646c 100644 --- a/replay/nn/loss/bce.py +++ b/replay/nn/loss/bce.py @@ -4,7 +4,7 @@ from replay.data.nn import TensorMap -from .base import SampledLossBase, mask_negative_logits +from .base import LogitsCallback, LossOutput, SampledLossBase, mask_negative_logits class BCE(torch.nn.Module): @@ -28,7 +28,7 @@ def __init__(self, **kwargs): @property def logits_callback( self, - ) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]: + ) -> LogitsCallback: """ Property for calling a function for the logits computation.\n @@ -46,7 +46,7 @@ def logits_callback( return self._logits_callback @logits_callback.setter - def logits_callback(self, func: Optional[Callable]) -> None: + def logits_callback(self, func: LogitsCallback) -> None: self._logits_callback = func def forward( @@ -57,7 +57,8 @@ def forward( negative_labels: torch.LongTensor, # noqa: ARG002 padding_mask: torch.BoolTensor, # noqa: ARG002 target_padding_mask: torch.BoolTensor, - ) -> torch.Tensor: + return_info: bool = False, + ) -> LossOutput: """ forward(model_embeddings, positive_labels, target_padding_mask) :param model_embeddings: model output of shape ``(batch_size, sequence_length, embedding_dim)``. @@ -92,7 +93,11 @@ def forward( ) loss = self._loss(logits, bce_labels) / logits.size(0) - return loss + + if return_info: + return (loss, {"BCE": loss.detach()}) + else: + return (loss, None) class BCESampled(SampledLossBase): @@ -159,7 +164,8 @@ def forward( negative_labels: torch.LongTensor, padding_mask: torch.BoolTensor, # noqa: ARG002 target_padding_mask: torch.BoolTensor, - ) -> torch.Tensor: + return_info: bool = False, + ) -> LossOutput: """ forward(model_embeddings, positive_labels, negative_labels, target_padding_mask) @@ -213,4 +219,7 @@ def forward( loss = -(positive_loss + negative_loss) loss /= positive_logits.size(0) - return loss + if return_info: + return (loss, {"BCESampled": loss.detach()}) + else: + return (loss, None) diff --git a/replay/nn/loss/ce.py b/replay/nn/loss/ce.py index 07f6e7b17..6286b9038 100644 --- a/replay/nn/loss/ce.py +++ b/replay/nn/loss/ce.py @@ -4,7 +4,7 @@ from replay.data.nn import TensorMap -from .base import SampledLossBase, mask_negative_logits +from .base import LogitsCallback, LossOutput, SampledLossBase, mask_negative_logits class CE(torch.nn.Module): @@ -20,12 +20,12 @@ def __init__(self, **kwargs): """ super().__init__() self._loss = torch.nn.CrossEntropyLoss(**kwargs) - self._logits_callback = None + self._logits_callback: LogitsCallback | None = None @property def logits_callback( self, - ) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]: + ) -> LossOutput: """ Property for calling a function for the logits computation.\n @@ -43,7 +43,7 @@ def logits_callback( return self._logits_callback @logits_callback.setter - def logits_callback(self, func: Optional[Callable]) -> None: + def logits_callback(self, func: LossOutput) -> None: self._logits_callback = func def forward( @@ -54,7 +54,8 @@ def forward( negative_labels: torch.LongTensor, # noqa: ARG002 padding_mask: torch.BoolTensor, # noqa: ARG002 target_padding_mask: torch.BoolTensor, - ) -> torch.Tensor: + return_info: bool = False + ) -> LossOutput: """ forward(model_embeddings, positive_labels, target_padding_mask) :param model_embeddings: model output of shape ``(batch_size, sequence_length, embedding_dim)``. @@ -78,7 +79,11 @@ def forward( # [batch_size, seq_len, 1] -> [batch_size * seq_len] labels_flat: torch.LongTensor = labels.view(-1) loss = self._loss(logits_flat, labels_flat) - return loss + + if return_info: + return (loss, {"CE": loss.detach()}) + else: + return (loss, None) class CEWeighted(CE): @@ -116,7 +121,8 @@ def forward( negative_labels: torch.LongTensor, # noqa: ARG002 padding_mask: torch.BoolTensor, # noqa: ARG002 target_padding_mask: torch.BoolTensor, - ) -> torch.Tensor: + return_info: bool = False, + ) -> LossOutput: """ forward(model_embeddings, feature_tensors, positive_labels, target_padding_mask) :param feature_tensors: a dictionary of tensors from dataloader. @@ -140,7 +146,11 @@ def forward( ) sample_weight = feature_tensors[self.feature_name] loss = (loss * sample_weight).mean() - return loss + + if return_info: + return (loss, {"CEWeighted": loss.detach()}) + else: + return (loss, None) class CESampled(SampledLossBase): @@ -170,7 +180,7 @@ def __init__( super().__init__() self.negative_labels_ignore_index = negative_labels_ignore_index self._loss = torch.nn.CrossEntropyLoss(**kwargs) - self._logits_callback = None + self._logits_callback: LogitsCallback | None = None @property def logits_callback( @@ -193,7 +203,7 @@ def logits_callback( return self._logits_callback @logits_callback.setter - def logits_callback(self, func: Optional[Callable]) -> None: + def logits_callback(self, func: LogitsCallback) -> None: self._logits_callback = func def forward( @@ -204,7 +214,8 @@ def forward( negative_labels: torch.LongTensor, padding_mask: torch.BoolTensor, # noqa: ARG002 target_padding_mask: torch.BoolTensor, - ) -> torch.Tensor: + return_info: bool = False, + ) -> LossOutput: """ forward(model_embeddings, positive_labels, negative_labels, target_padding_mask) @@ -246,7 +257,11 @@ def forward( target = torch.zeros(positive_logits.size(0), dtype=torch.long, device=logits.device) # [masked_batch_size] - loss for all recommendation points loss = self._loss(logits, target) - return loss + + if return_info: + return (loss, {"BCE": loss.detach()}) + else: + return (loss, None) class CESampledWeighted(CESampled): @@ -293,7 +308,8 @@ def forward( negative_labels: torch.LongTensor, padding_mask: torch.BoolTensor, # noqa: ARG002 target_padding_mask: torch.BoolTensor, - ) -> torch.Tensor: + return_info: bool = False, + ) -> LossOutput: """ forward(model_embeddings, feature_tensors, positive_labels, negative_labels, target_padding_mask) :param model_embeddings: model output of shape ``(batch_size, sequence_length, embedding_dim)``. @@ -314,4 +330,8 @@ def forward( sample_weight = feature_tensors[self.feature_name] sample_weight = sample_weight[target_padding_mask] loss = (loss * sample_weight).mean() - return loss + + if return_info: + return (loss, {"CESampledWeighted": loss.detach()}) + else: + return (loss, None) diff --git a/replay/nn/loss/composed.py b/replay/nn/loss/composed.py index 7de0a0371..d18577244 100644 --- a/replay/nn/loss/composed.py +++ b/replay/nn/loss/composed.py @@ -5,7 +5,7 @@ from replay.data.nn import TensorMap -from .base import LogitsCallback, LossProto +from .base import LogitsCallback, LossInfo, LossOutput, LossProto Weights = dict[str, torch.Tensor | float] @@ -48,6 +48,9 @@ def __init__( continue elif isinstance(weight, torch.Tensor): assert torch.is_tensor(weight) + if torch.numel(weight) > 1: + msg: str = f"Too many values in weight: {torch.numel(weight)=}." + raise ValueError(msg) continue else: msg: str = f"Unsupported type of weight value. Must be `float` or `Tensor`. Got: {type(weight)=}." @@ -80,11 +83,11 @@ def forward( negative_labels: torch.LongTensor, padding_mask: torch.BoolTensor, target_padding_mask: torch.BoolTensor, - ) -> torch.Tensor: - losses = 0.0 + return_info: bool = False, + ) -> LossOutput: + raw_losses: dict[str, torch.Tensor] = {} for name, loss in self.losses.items(): - loss_weight = self.weights.get(name, 1.0) - loss_value: torch.Tensor = loss( + raw_losses[name] = loss( model_embeddings, feature_tensors, positive_labels, @@ -92,5 +95,14 @@ def forward( padding_mask, target_padding_mask, ) - losses = losses + loss_weight * loss_value - return cast(torch.Tensor, losses) + + loss = cast(torch.Tensor, sum(value * self.weights.get(name, 1.0) for name, value in raw_losses.items())) + + if return_info: + info: LossInfo = { + "ComposedLoss": loss.detach(), + **{name: value.detach() for name, value in raw_losses.items()}, + } + return (loss, info) + else: + return (loss, None) diff --git a/replay/nn/loss/login_ce.py b/replay/nn/loss/login_ce.py index c6702c229..44d1bdf8c 100644 --- a/replay/nn/loss/login_ce.py +++ b/replay/nn/loss/login_ce.py @@ -4,7 +4,7 @@ from replay.data.nn import TensorMap -from .base import SampledLossBase, mask_negative_logits +from .base import LossOutput, SampledLossBase, mask_negative_logits class LogInCESampledOutput(TypedDict): @@ -174,7 +174,8 @@ def forward( negative_labels: torch.LongTensor, # noqa: ARG002 padding_mask: torch.BoolTensor, # noqa: ARG002 target_padding_mask: torch.BoolTensor, - ) -> torch.Tensor: + return_info: bool = False, + ) -> LossOutput: """ forward(model_embeddings, positive_labels, target_padding_mask) **Note**: At forward pass, the whole catalog of items is used as negatives. @@ -235,7 +236,12 @@ def forward( -self.clamp_border, self.clamp_border, ) - return loss.mean() + loss = loss.mean() + + if return_info: + return (loss, {"LogInCE": loss.detach()}) + else: + return (loss, None) class LogInCESampled(LogInCEBase): @@ -310,7 +316,8 @@ def forward( negative_labels: torch.LongTensor, padding_mask: torch.BoolTensor, # noqa: ARG002 target_padding_mask: torch.BoolTensor, - ) -> torch.Tensor: + return_info: bool = False, + ) -> LossOutput: """ forward(model_embeddings, positive_labels, negative_labels, target_padding_mask) @@ -370,4 +377,9 @@ def forward( -self.clamp_border, self.clamp_border, ) - return loss.mean() + loss = loss.mean() + + if return_info: + return (loss, {"LogInCESampled": loss.detach()}) + else: + return (loss, None) diff --git a/replay/nn/loss/logout_ce.py b/replay/nn/loss/logout_ce.py index d0d5bba32..b5a24b482 100644 --- a/replay/nn/loss/logout_ce.py +++ b/replay/nn/loss/logout_ce.py @@ -4,7 +4,7 @@ from replay.data.nn import TensorMap -from .base import mask_negative_logits +from .base import LossOutput, mask_negative_logits class LogOutCE(torch.nn.Module): @@ -80,7 +80,8 @@ def forward( negative_labels: torch.LongTensor, # noqa: ARG002 padding_mask: torch.BoolTensor, # noqa: ARG002 target_padding_mask: torch.BoolTensor, - ) -> torch.Tensor: + return_info: bool = False, + ) -> LossOutput: """ forward(model_embeddings, positive_labels, target_padding_mask) **Note**: At forward pass, the whole catalog of items is used as negatives. @@ -144,7 +145,11 @@ def forward( target = torch.zeros(logits.size(0), dtype=torch.long, device=positive_labels.device) # [masked_batch_size] - loss for all recommendation points loss = self._loss(logits, target) - return loss + + if return_info: + return (loss, {"LogOutCE": loss.detach()}) + else: + return (loss, None) class LogOutCEWeighted(LogOutCE): @@ -205,7 +210,8 @@ def forward( negative_labels: torch.LongTensor, # noqa: ARG002 padding_mask: torch.BoolTensor, # noqa: ARG002 target_padding_mask: torch.BoolTensor, - ) -> torch.Tensor: + return_info: bool = False, + ) -> LossOutput: """ forward(model_embeddings, feature_tensors, positive_labels, target_padding_mask) **Note**: At forward pass, the whole catalog of items is used as negatives. @@ -227,4 +233,8 @@ def forward( sample_weight = feature_tensors[self.feature_name] sample_weight = sample_weight[target_padding_mask] loss = (loss * sample_weight).mean() - return loss + + if return_info: + return (loss, {"LogOutCEWeighted": loss.detach()}) + else: + return (loss, None) diff --git a/replay/nn/output.py b/replay/nn/output.py index 868a6e2cd..e378f3f0d 100644 --- a/replay/nn/output.py +++ b/replay/nn/output.py @@ -3,6 +3,8 @@ import torch from typing_extensions import NotRequired +from .loss import LossInfo + class TrainOutput(TypedDict): """ @@ -17,6 +19,7 @@ class TrainOutput(TypedDict): """ loss: torch.Tensor + info: NotRequired[LossInfo | None] hidden_states: NotRequired[tuple[torch.Tensor, ...]] diff --git a/replay/nn/sequential/sasrec/agg.py b/replay/nn/sequential/sasrec/agg.py index bc2f42d5c..73e31556b 100644 --- a/replay/nn/sequential/sasrec/agg.py +++ b/replay/nn/sequential/sasrec/agg.py @@ -43,9 +43,9 @@ def forward(self, feature_tensors: TensorMap) -> torch.Tensor: seqs: torch.Tensor = self.embedding_aggregator(feature_tensors) assert seqs.dim() == 3 batch_size, seq_len, embedding_dim = seqs.size() - assert ( - seq_len <= self.pe.num_embeddings - ), f"Sequence length = {seq_len} is greater then positional embedding num = {self.pe.num_embeddings}" + assert seq_len <= self.pe.num_embeddings, ( + f"Sequence length = {seq_len} is greater then positional embedding num = {self.pe.num_embeddings}" + ) seqs *= embedding_dim**0.5 seqs += self.pe.weight[:seq_len].unsqueeze(0).repeat(batch_size, 1, 1) diff --git a/replay/nn/sequential/sasrec/model.py b/replay/nn/sequential/sasrec/model.py index ed2539dd2..ea4ff2d9e 100644 --- a/replay/nn/sequential/sasrec/model.py +++ b/replay/nn/sequential/sasrec/model.py @@ -270,23 +270,26 @@ def forward_train( positive_labels: torch.LongTensor, negative_labels: torch.LongTensor, target_padding_mask: torch.BoolTensor, + return_info: bool = False, ) -> TrainOutput: hidden_states: torch.Tensor = self.body(feature_tensors, padding_mask) assert hidden_states.dim() == 3 - loss: torch.Tensor = self.loss( + loss, info = self.loss( model_embeddings=hidden_states, feature_tensors=feature_tensors, positive_labels=positive_labels, negative_labels=negative_labels, padding_mask=padding_mask, target_padding_mask=target_padding_mask, + return_info=return_info, ) - return { - "loss": loss, - "hidden_states": (hidden_states,), - } + return TrainOutput( + loss=loss, + info=info, + hidden_states=(hidden_states,), + ) def forward_inference( self, @@ -313,6 +316,7 @@ def forward( positive_labels: Optional[torch.LongTensor] = None, negative_labels: Optional[torch.LongTensor] = None, target_padding_mask: Optional[torch.BoolTensor] = None, + return_info: bool = False, ) -> Union[TrainOutput, InferenceOutput]: """ :param feature_tensors: a dictionary of tensors to generate embeddings. @@ -358,6 +362,7 @@ def forward( positive_labels=positive_labels, negative_labels=negative_labels, target_padding_mask=target_padding_mask, + return_info=return_info, ) all( diff --git a/replay/nn/sequential/twotower/model.py b/replay/nn/sequential/twotower/model.py index 6016ae94f..8e580d8f7 100644 --- a/replay/nn/sequential/twotower/model.py +++ b/replay/nn/sequential/twotower/model.py @@ -542,6 +542,7 @@ def forward_train( positive_labels: torch.LongTensor, negative_labels: torch.LongTensor, target_padding_mask: torch.BoolTensor, + return_info: bool = False, ) -> TrainOutput: hidden_states = () query_hidden_states: torch.Tensor = self.body.query_tower( @@ -559,17 +560,19 @@ def forward_train( assert query_hidden_states.dim() == 3 hidden_states += (query_hidden_states,) - loss: torch.Tensor = self.loss( + loss, info = self.loss( model_embeddings=query_hidden_states, feature_tensors=feature_tensors, positive_labels=positive_labels, negative_labels=negative_labels, padding_mask=padding_mask, target_padding_mask=target_padding_mask, + return_info=return_info, ) return TrainOutput( loss=loss, + info=info, hidden_states=hidden_states, ) @@ -612,6 +615,7 @@ def forward( positive_labels: Optional[torch.LongTensor] = None, negative_labels: Optional[torch.LongTensor] = None, target_padding_mask: Optional[torch.BoolTensor] = None, + return_info: bool = False, ) -> Union[TrainOutput, InferenceOutput]: """ :param feature_tensors: a dictionary of tensors to generate embeddings. From e6f18a1b3efbbbe63ef47eea3e834860e9cb3b3e Mon Sep 17 00:00:00 2001 From: Nikita Kulikov Date: Wed, 4 Feb 2026 17:48:28 +0300 Subject: [PATCH 3/6] Ruff applied --- replay/nn/loss/ce.py | 4 ++-- replay/nn/sequential/twotower/model.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/replay/nn/loss/ce.py b/replay/nn/loss/ce.py index 6286b9038..8f380cf64 100644 --- a/replay/nn/loss/ce.py +++ b/replay/nn/loss/ce.py @@ -54,7 +54,7 @@ def forward( negative_labels: torch.LongTensor, # noqa: ARG002 padding_mask: torch.BoolTensor, # noqa: ARG002 target_padding_mask: torch.BoolTensor, - return_info: bool = False + return_info: bool = False, ) -> LossOutput: """ forward(model_embeddings, positive_labels, target_padding_mask) @@ -79,7 +79,7 @@ def forward( # [batch_size, seq_len, 1] -> [batch_size * seq_len] labels_flat: torch.LongTensor = labels.view(-1) loss = self._loss(logits_flat, labels_flat) - + if return_info: return (loss, {"CE": loss.detach()}) else: diff --git a/replay/nn/sequential/twotower/model.py b/replay/nn/sequential/twotower/model.py index 8e580d8f7..eaa4874cf 100644 --- a/replay/nn/sequential/twotower/model.py +++ b/replay/nn/sequential/twotower/model.py @@ -661,6 +661,7 @@ def forward( positive_labels=positive_labels, negative_labels=negative_labels, target_padding_mask=target_padding_mask, + return_info=return_info, ) all( From 4d383dbd2e927ea4cb80be90b8d57ed14b038def Mon Sep 17 00:00:00 2001 From: Nikita Kulikov Date: Thu, 5 Feb 2026 08:21:04 +0300 Subject: [PATCH 4/6] Minor refactoring --- examples/09_sasrec_example.ipynb | 98 ++++++++++++++---------------- examples/15_twotower_example.ipynb | 33 ++++------ replay/nn/loss/base.py | 2 + replay/nn/loss/bce.py | 21 ++++--- replay/nn/loss/ce.py | 36 ++++++----- replay/nn/loss/composed.py | 94 +++++++++++++++++++++------- replay/nn/loss/login_ce.py | 26 ++++---- replay/nn/loss/logout_ce.py | 18 +++--- 8 files changed, 188 insertions(+), 140 deletions(-) diff --git a/examples/09_sasrec_example.ipynb b/examples/09_sasrec_example.ipynb index ecc849764..4038021fd 100644 --- a/examples/09_sasrec_example.ipynb +++ b/examples/09_sasrec_example.ipynb @@ -10,7 +10,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 30, "metadata": {}, "outputs": [ { @@ -26,7 +26,7 @@ "42" ] }, - "execution_count": 1, + "execution_count": 30, "metadata": {}, "output_type": "execute_result" } @@ -58,7 +58,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 31, "metadata": {}, "outputs": [], "source": [ @@ -68,7 +68,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 32, "metadata": {}, "outputs": [ { @@ -186,7 +186,7 @@ "[1000209 rows x 3 columns]" ] }, - "execution_count": 3, + "execution_count": 32, "metadata": {}, "output_type": "execute_result" } @@ -208,7 +208,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 33, "metadata": {}, "outputs": [ { @@ -326,7 +326,7 @@ "[1000209 rows x 3 columns]" ] }, - "execution_count": 4, + "execution_count": 33, "metadata": {}, "output_type": "execute_result" } @@ -356,7 +356,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 34, "metadata": {}, "outputs": [], "source": [ @@ -387,7 +387,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 35, "metadata": {}, "outputs": [], "source": [ @@ -401,7 +401,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 36, "metadata": {}, "outputs": [ { @@ -532,7 +532,7 @@ "[6040 rows x 3 columns]" ] }, - "execution_count": 7, + "execution_count": 36, "metadata": {}, "output_type": "execute_result" } @@ -558,7 +558,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 37, "metadata": {}, "outputs": [], "source": [ @@ -574,7 +574,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 38, "metadata": {}, "outputs": [], "source": [ @@ -592,7 +592,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 39, "metadata": {}, "outputs": [ { @@ -628,7 +628,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 40, "metadata": {}, "outputs": [], "source": [ @@ -683,7 +683,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 41, "metadata": {}, "outputs": [], "source": [ @@ -694,7 +694,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 42, "metadata": {}, "outputs": [], "source": [ @@ -718,7 +718,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 43, "metadata": {}, "outputs": [ { @@ -761,7 +761,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 44, "metadata": {}, "outputs": [], "source": [ @@ -790,7 +790,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 45, "metadata": {}, "outputs": [], "source": [ @@ -816,7 +816,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 46, "metadata": {}, "outputs": [ { @@ -825,8 +825,6 @@ "text": [ "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", - "You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n", - "/home/nkulikov/new_venv/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:881: Checkpoint directory /home/nkulikov/RePlay/examples/sasrec/checkpoints exists and is not empty.\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", "\n", " | Name | Type | Params | Mode | FLOPs\n", @@ -845,7 +843,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "92d2cf620e634e478ee28b385cacc445", + "model_id": "3d4b2a25f75e41faba84295680abc9be", "version_major": 2, "version_minor": 0 }, @@ -868,7 +866,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "e9a6ad3f3b8d45c0895b0a9f12e32f7f", + "model_id": "874cf01cadc74a3abab656fbdc19121c", "version_major": 2, "version_minor": 0 }, @@ -882,7 +880,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "9716d4189b4a49b6b2cf5bfe81bec5a5", + "model_id": "fc3bd1a21dff4d6d9de53f54971cbaa4", "version_major": 2, "version_minor": 0 }, @@ -922,7 +920,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "e1289d4fdf454f32b41c289ee7e1c2da", + "model_id": "a4040344f9cd47cdb6dc44e8ed527b9a", "version_major": 2, "version_minor": 0 }, @@ -962,7 +960,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "8a58c575bc3d4132915191f7e0621fdc", + "model_id": "41e5ebf1a38849a0bf447d6f669f867a", "version_major": 2, "version_minor": 0 }, @@ -1002,7 +1000,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "2276b9086a564859a2af69c0a7eb4c1a", + "model_id": "621212bec8064e978a66731bc87a3eba", "version_major": 2, "version_minor": 0 }, @@ -1042,7 +1040,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "3626b1e4f6f14561b885dac0bfe3bf56", + "model_id": "f8c8ac6854d84543ae66b1249643668f", "version_major": 2, "version_minor": 0 }, @@ -1057,7 +1055,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Epoch 4, global step 945: 'recall@10' reached 0.14804 (best 0.14804), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=4-step=945-v1.ckpt' as top 1\n", + "Epoch 4, global step 945: 'recall@10' reached 0.14804 (best 0.14804), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=4-step=945.ckpt' as top 1\n", "`Trainer.fit` stopped: `max_epochs=5` reached.\n" ] }, @@ -1115,16 +1113,16 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 47, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "'/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=4-step=945-v1.ckpt'" + "'/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=4-step=945.ckpt'" ] }, - "execution_count": 18, + "execution_count": 47, "metadata": {}, "output_type": "execute_result" } @@ -1146,7 +1144,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 48, "metadata": {}, "outputs": [], "source": [ @@ -1181,7 +1179,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 49, "metadata": {}, "outputs": [ { @@ -1216,20 +1214,14 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 50, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ + "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n", "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", @@ -1239,7 +1231,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "5aa1ce20ed684f9b8b8a6b1ca5b9c62c", + "model_id": "4c9752b93ff34c15a3caeb8291efce50", "version_major": 2, "version_minor": 0 }, @@ -1281,7 +1273,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 51, "metadata": {}, "outputs": [ { @@ -1399,7 +1391,7 @@ "[120760 rows x 3 columns]" ] }, - "execution_count": 25, + "execution_count": 51, "metadata": {}, "output_type": "execute_result" } @@ -1420,7 +1412,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 52, "metadata": {}, "outputs": [], "source": [ @@ -1430,7 +1422,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 53, "metadata": {}, "outputs": [], "source": [ @@ -1443,7 +1435,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 54, "metadata": {}, "outputs": [ { @@ -1506,7 +1498,7 @@ "Recall 0.015734 0.147897 0.233024 0.084796" ] }, - "execution_count": 28, + "execution_count": 54, "metadata": {}, "output_type": "execute_result" } @@ -1524,7 +1516,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 55, "metadata": {}, "outputs": [ { @@ -1642,7 +1634,7 @@ "[120760 rows x 3 columns]" ] }, - "execution_count": 29, + "execution_count": 55, "metadata": {}, "output_type": "execute_result" } diff --git a/examples/15_twotower_example.ipynb b/examples/15_twotower_example.ipynb index 9164b0a36..bac925a97 100644 --- a/examples/15_twotower_example.ipynb +++ b/examples/15_twotower_example.ipynb @@ -700,18 +700,7 @@ "cell_type": "code", "execution_count": 10, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/nkulikov/RePlay/replay/preprocessing/label_encoder.py:964: UserWarning: There is already LabelEncoder object saved at the given path. File will be overwrited.\n", - " warnings.warn(msg)\n", - "/home/nkulikov/RePlay/replay/preprocessing/label_encoder.py:537: UserWarning: There is already LabelEncodingRule object saved at the given path. File will be overwrited.\n", - " warnings.warn(msg)\n" - ] - } - ], + "outputs": [], "source": [ "train_events.to_parquet(TRAIN_PATH)\n", "validation_events.to_parquet(VAL_PATH)\n", @@ -836,7 +825,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "/tmp/ipykernel_766094/1761019238.py:5: UserWarning: The following dataset paths aren't provided: test,predict.Make sure to disable these stages in your Lightning Trainer configuration.\n", + "/tmp/ipykernel_775501/1761019238.py:5: UserWarning: The following dataset paths aren't provided: test,predict.Make sure to disable these stages in your Lightning Trainer configuration.\n", " parquet_module = ParquetModule(\n" ] } @@ -956,7 +945,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "108862bbd4de48d2804f302e5cf4e997", + "model_id": "30be06dd73e4450f8fc37ee83bbd1746", "version_major": 2, "version_minor": 0 }, @@ -979,7 +968,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "fb965f6b1006499fbfa5de8c8ed96e5b", + "model_id": "dd8c6b547ea84fd4abcbe132b7da42cb", "version_major": 2, "version_minor": 0 }, @@ -993,7 +982,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "bc8d696ab44d4a87ac4dcbf29e2250f2", + "model_id": "301d175de5f24790b75dc4943e83e687", "version_major": 2, "version_minor": 0 }, @@ -1033,7 +1022,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "4d8c41219aac45d39de0bb26313a9239", + "model_id": "01468999f3674e389b8416c5e8f43a46", "version_major": 2, "version_minor": 0 }, @@ -1073,7 +1062,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "2ac5ed9b53d144e1a776759aa104d5cd", + "model_id": "148c7a181eda40d98249e05b50e4bfa1", "version_major": 2, "version_minor": 0 }, @@ -1113,7 +1102,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "3ae8fabaec204088a33a973536c4c366", + "model_id": "fced5bd80eb342909431d34449b9ca88", "version_major": 2, "version_minor": 0 }, @@ -1153,7 +1142,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "4f25edcbbd5842349605a5dbfa0275ed", + "model_id": "3a21e049ada84d73b3aef467158ee42d", "version_major": 2, "version_minor": 0 }, @@ -1303,7 +1292,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "/tmp/ipykernel_766094/2227877614.py:3: UserWarning: The following dataset paths aren't provided: train,validate,test.Make sure to disable these stages in your Lightning Trainer configuration.\n", + "/tmp/ipykernel_775501/2227877614.py:3: UserWarning: The following dataset paths aren't provided: train,validate,test.Make sure to disable these stages in your Lightning Trainer configuration.\n", " parquet_module = ParquetModule(\n" ] } @@ -1348,7 +1337,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "c03aa242b0b3403fb6405c310f7d505d", + "model_id": "527412a8a14545399dfad9ab127e3db3", "version_major": 2, "version_minor": 0 }, diff --git a/replay/nn/loss/base.py b/replay/nn/loss/base.py index 841c45d60..996fcde9c 100644 --- a/replay/nn/loss/base.py +++ b/replay/nn/loss/base.py @@ -12,6 +12,8 @@ class LossProto(Protocol): """Class-protocol for working with losses inside models""" + loss_name: str + @property def logits_callback( self, diff --git a/replay/nn/loss/bce.py b/replay/nn/loss/bce.py index 7e354646c..62e38edc3 100644 --- a/replay/nn/loss/bce.py +++ b/replay/nn/loss/bce.py @@ -1,5 +1,3 @@ -from typing import Callable, Optional - import torch from replay.data.nn import TensorMap @@ -16,14 +14,15 @@ class BCE(torch.nn.Module): (there are several labels for each position in the sequence). """ - def __init__(self, **kwargs): + def __init__(self, loss_name: str = "BCELoss", **kwargs): """ To calculate the loss, ``torch.nn.BCEWithLogitsLoss`` is used with the parameter ``reduction="sum"``. You can pass all other parameters for initializing the object via kwargs. """ super().__init__() self._loss = torch.nn.BCEWithLogitsLoss(reduction="sum", **kwargs) - self._logits_callback = None + self._logits_callback: LogitsCallback | None = None + self.loss_name: str = loss_name @property def logits_callback( @@ -95,7 +94,7 @@ def forward( loss = self._loss(logits, bce_labels) / logits.size(0) if return_info: - return (loss, {"BCE": loss.detach()}) + return (loss, {self.loss_name: loss.detach()}) else: return (loss, None) @@ -114,7 +113,8 @@ def __init__( log_epsilon: float = 1e-6, clamp_border: float = 100.0, negative_labels_ignore_index: int = -100, - ): + loss_name: str = "BCESampledLoss", + ) -> None: """ :param log_epsilon: correction to avoid zero in the logarithm during loss calculating. Default: ``1e-6``. @@ -130,12 +130,13 @@ def __init__( self.log_epsilon = log_epsilon self.clamp_border = clamp_border self.negative_labels_ignore_index = negative_labels_ignore_index - self._logits_callback = None + self._logits_callback: LogitsCallback | None = None + self.loss_name: str = loss_name @property def logits_callback( self, - ) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]: + ) -> LogitsCallback: """ Property for calling a function for the logits computation.\n @@ -153,7 +154,7 @@ def logits_callback( return self._logits_callback @logits_callback.setter - def logits_callback(self, func: Optional[Callable]) -> None: + def logits_callback(self, func: LogitsCallback) -> None: self._logits_callback = func def forward( @@ -220,6 +221,6 @@ def forward( loss /= positive_logits.size(0) if return_info: - return (loss, {"BCESampled": loss.detach()}) + return (loss, {self.loss_name: loss.detach()}) else: return (loss, None) diff --git a/replay/nn/loss/ce.py b/replay/nn/loss/ce.py index 8f380cf64..152afc563 100644 --- a/replay/nn/loss/ce.py +++ b/replay/nn/loss/ce.py @@ -1,5 +1,3 @@ -from typing import Callable, Optional - import torch from replay.data.nn import TensorMap @@ -13,19 +11,20 @@ class CE(torch.nn.Module): Calculates loss over all items catalog. """ - def __init__(self, **kwargs): + def __init__(self, loss_name: str = "CELoss", **kwargs): """ To calculate the loss, ``torch.nn.CrossEntropyLoss`` is used. You can pass all parameters for initializing the object via kwargs. """ super().__init__() + self.loss_name: str = loss_name self._loss = torch.nn.CrossEntropyLoss(**kwargs) self._logits_callback: LogitsCallback | None = None @property def logits_callback( self, - ) -> LossOutput: + ) -> LogitsCallback: """ Property for calling a function for the logits computation.\n @@ -43,7 +42,7 @@ def logits_callback( return self._logits_callback @logits_callback.setter - def logits_callback(self, func: LossOutput) -> None: + def logits_callback(self, func: LogitsCallback) -> None: self._logits_callback = func def forward( @@ -97,11 +96,14 @@ class CEWeighted(CE): which is fed into the model. """ + loss_name: str = "CEWeightedLoss" + def __init__( self, feature_name: str, + loss_name: str = "CEWEightedLoss", **kwargs, - ): + ) -> None: """ To calculate the loss, ``torch.nn.CrossEntropyLoss`` is used with the parameter ``reduction="none"``. You can pass all other parameters for initializing the object via kwargs. @@ -110,7 +112,8 @@ def __init__( The tensor is expected to contain sample weights. """ super().__init__() - self.feature_name = feature_name + self.loss_name: str = loss_name + self.feature_name: str = feature_name self._loss = torch.nn.CrossEntropyLoss(reduction="none", **kwargs) def forward( @@ -148,7 +151,7 @@ def forward( loss = (loss * sample_weight).mean() if return_info: - return (loss, {"CEWeighted": loss.detach()}) + return (loss, {self.loss_name: loss.detach()}) else: return (loss, None) @@ -165,8 +168,9 @@ class CESampled(SampledLossBase): def __init__( self, negative_labels_ignore_index: int = -100, + loss_name: str = "CESampledLoss", **kwargs, - ): + ) -> None: """ To calculate the loss, ``torch.nn.CrossEntropyLoss`` is used. You can pass all parameters for initializing the object via kwargs. @@ -178,6 +182,7 @@ def __init__( Default: ``-100``. """ super().__init__() + self.loss_name: str = loss_name self.negative_labels_ignore_index = negative_labels_ignore_index self._loss = torch.nn.CrossEntropyLoss(**kwargs) self._logits_callback: LogitsCallback | None = None @@ -185,7 +190,7 @@ def __init__( @property def logits_callback( self, - ) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]: + ) -> LogitsCallback: """ Property for calling a function for the logits computation.\n @@ -259,7 +264,7 @@ def forward( loss = self._loss(logits, target) if return_info: - return (loss, {"BCE": loss.detach()}) + return (loss, {self.loss_name: loss.detach()}) else: return (loss, None) @@ -282,8 +287,9 @@ def __init__( self, feature_name: str, negative_labels_ignore_index: int = -100, + loss_name: str = "CESampledWeighedLoss", **kwargs, - ): + ) -> None: """ To calculate the loss, ``torch.nn.CrossEntropyLoss`` is used with the parameter ``reduction="none"``. You can pass all other parameters for initializing the object via kwargs. @@ -297,7 +303,9 @@ def __init__( Default: ``-100``. """ super().__init__(negative_labels_ignore_index=negative_labels_ignore_index) - self.feature_name = feature_name + + self.loss_name: str = loss_name + self.feature_name: str = feature_name self._loss = torch.nn.CrossEntropyLoss(reduction="none", **kwargs) def forward( @@ -332,6 +340,6 @@ def forward( loss = (loss * sample_weight).mean() if return_info: - return (loss, {"CESampledWeighted": loss.detach()}) + return (loss, {self.loss_name: loss.detach()}) else: return (loss, None) diff --git a/replay/nn/loss/composed.py b/replay/nn/loss/composed.py index d18577244..aff5dbe37 100644 --- a/replay/nn/loss/composed.py +++ b/replay/nn/loss/composed.py @@ -1,5 +1,5 @@ import warnings -from typing import Optional, Self, cast +from typing import Iterable, Self, cast import torch @@ -8,20 +8,41 @@ from .base import LogitsCallback, LossInfo, LossOutput, LossProto Weights = dict[str, torch.Tensor | float] +Losses = Iterable[torch.nn.Module] | dict[str, torch.nn.Module] | torch.nn.ModuleDict class ComposedLoss(torch.nn.Module): def __init__( - self: Self, losses: dict[str, torch.nn.Module] | torch.nn.ModuleDict, weights: Weights | None = None + self: Self, + losses: Losses, + weights: Weights | None = None, + loss_name: str = "ComposedLoss", ) -> None: super().__init__() - if isinstance(losses, dict): - for loss in cast(dict, losses.values()): - if not isinstance(loss, torch.nn.Module): - msg: str = f"Unsupported type of loss. Must be `Module`. Got: {type(loss)=}." - raise TypeError(msg) - losses = torch.nn.ModuleDict(losses) + self.losses: torch.nn.ModuleDict = self._handle_losses(losses) + self.weights: Weights = self._handle_weights(weights) + self._logits_callback: LogitsCallback | None = None + self.loss_name: str = loss_name + + def _handle_losses(self: Self, losses: Losses) -> torch.nn.ModuleDict: + if not isinstance(losses, torch.nn.ModuleDict): + if isinstance(losses, dict): + for loss in cast(dict, losses.values()): + if not isinstance(loss, torch.nn.Module): + msg: str = f"Unsupported type of loss. Must be `Module`. Got: {type(loss)=}." + raise TypeError(msg) + losses_dict: dict[str, torch.nn.Module] = cast(dict[str, torch.nn.Module], losses) + else: + losses_dict: dict[str, torch.nn.Module] = {} + for loss in iter(losses): + casted: LossProto = cast(LossProto, loss) + name: str = casted.loss_name + if name in losses_dict: + msg: str = f"Loss names must be unique. Got {name} twice." + raise KeyError(name) + losses_dict[name] = loss + losses = torch.nn.ModuleDict(losses_dict) if not isinstance(losses, torch.nn.ModuleDict): msg: str = f"Unsupported type of `losses`. Must be `dict` or `ModuleDict`. Got {type(losses)=}." @@ -31,13 +52,14 @@ def __init__( msg: str = "Empty losses are not supported." raise ValueError(msg) - self.losses: torch.nn.ModuleDict = cast(torch.nn.ModuleDict, losses) + return cast(torch.nn.ModuleDict, losses) + def _handle_weights(self: Self, weights: Weights | None) -> Weights: if weights is None: weights = {} - - if not isinstance(weights, dict): + elif not isinstance(weights, dict): msg: str = f"Unsupported type of `weights`. Must be `dict`. Got: {type(weights)=}." + raise TypeError(msg) for name, weight in cast(dict, weights): if name not in self.losses: @@ -56,9 +78,7 @@ def __init__( msg: str = f"Unsupported type of weight value. Must be `float` or `Tensor`. Got: {type(weight)=}." raise TypeError(msg) - self.weights: dict[str, torch.Tensor | float] = cast(Weights, weights) - - self._logits_callback: Optional[LogitsCallback] = None + return cast(Weights, weights) @property def logits_callback(self: Self) -> LogitsCallback: @@ -75,7 +95,7 @@ def logits_callback(self: Self, func: LogitsCallback) -> None: casted = cast(LossProto, loss) casted.logits_callback = func - def forward( + def _compute_raw_losses( self: Self, model_embeddings: torch.Tensor, feature_tensors: TensorMap, @@ -83,8 +103,7 @@ def forward( negative_labels: torch.LongTensor, padding_mask: torch.BoolTensor, target_padding_mask: torch.BoolTensor, - return_info: bool = False, - ) -> LossOutput: + ) -> dict[str, torch.Tensor]: raw_losses: dict[str, torch.Tensor] = {} for name, loss in self.losses.items(): raw_losses[name] = loss( @@ -95,14 +114,45 @@ def forward( padding_mask, target_padding_mask, ) + return raw_losses + + def _apply_weights(self: Self, raw_losses: dict[str, torch.Tensor]) -> torch.Tensor: + losses_list: list[torch.Tensor] = [] + + for name, value in raw_losses.items(): + weight = self.weights.get(name, 1.0) + losses_list.append(weight * value) + + losses: torch.Tensor = torch.cat(losses_list) + return torch.sum(losses) + + def _detach_dict(self: Self, raw_losses: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + return {name: value.detach() for name, value in raw_losses.items()} + + def forward( + self: Self, + model_embeddings: torch.Tensor, + feature_tensors: TensorMap, + positive_labels: torch.LongTensor, + negative_labels: torch.LongTensor, + padding_mask: torch.BoolTensor, + target_padding_mask: torch.BoolTensor, + return_info: bool = False, + ) -> LossOutput: + raw_losses: dict[str, torch.Tensor] = self._compute_raw_losses( + model_embeddings=model_embeddings, + feature_tensors=feature_tensors, + positive_labels=positive_labels, + negative_labels=negative_labels, + padding_mask=padding_mask, + target_padding_mask=target_padding_mask, + ) - loss = cast(torch.Tensor, sum(value * self.weights.get(name, 1.0) for name, value in raw_losses.items())) + loss: torch.Tensor = self._apply_weights(raw_losses) if return_info: - info: LossInfo = { - "ComposedLoss": loss.detach(), - **{name: value.detach() for name, value in raw_losses.items()}, - } + base_info: LossInfo = self._detach_dict(raw_losses) + info: LossInfo = {self.loss_name: loss.detach(), **base_info} return (loss, info) else: return (loss, None) diff --git a/replay/nn/loss/login_ce.py b/replay/nn/loss/login_ce.py index 44d1bdf8c..cf78312eb 100644 --- a/replay/nn/loss/login_ce.py +++ b/replay/nn/loss/login_ce.py @@ -1,10 +1,10 @@ -from typing import Callable, Optional, TypedDict +from typing import TypedDict import torch from replay.data.nn import TensorMap -from .base import LossOutput, SampledLossBase, mask_negative_logits +from .base import LogitsCallback, LossOutput, SampledLossBase, mask_negative_logits class LogInCESampledOutput(TypedDict): @@ -121,7 +121,8 @@ def __init__( log_epsilon: float = 1e-6, clamp_border: float = 100.0, negative_labels_ignore_index: int = -100, - ): + loss_name: str = "LogInCELoss", + ) -> None: """ :param cardinality: number of unique items in vocabulary (catalog). The specified cardinality value must not take into account the padding value. @@ -140,12 +141,13 @@ def __init__( self.log_epsilon = log_epsilon self.clamp_border = clamp_border self.negative_labels_ignore_index = negative_labels_ignore_index - self._logits_callback = None + self._logits_callback: LogitsCallback | None = None + self.loss_name: str = loss_name @property def logits_callback( self, - ) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]: + ) -> LogitsCallback: """ Property for calling a function for the logits computation.\n @@ -163,7 +165,7 @@ def logits_callback( return self._logits_callback @logits_callback.setter - def logits_callback(self, func: Optional[Callable]) -> None: + def logits_callback(self, func: LogitsCallback) -> None: self._logits_callback = func def forward( @@ -239,7 +241,7 @@ def forward( loss = loss.mean() if return_info: - return (loss, {"LogInCE": loss.detach()}) + return (loss, {self.loss_name: loss.detach()}) else: return (loss, None) @@ -266,6 +268,7 @@ def __init__( log_epsilon: float = 1e-6, clamp_border: float = 100.0, negative_labels_ignore_index: int = -100, + loss_name: str = "LogInCESampledLoss", ): """ :param log_epsilon: correction to avoid zero in the logarithm during loss calculating. @@ -282,12 +285,13 @@ def __init__( self.log_epsilon = log_epsilon self.clamp_border = clamp_border self.negative_labels_ignore_index = negative_labels_ignore_index - self._logits_callback = None + self._logits_callback: LogitsCallback | None = None + self.loss_name: str = loss_name @property def logits_callback( self, - ) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]: + ) -> LogitsCallback: """ Property for calling a function for the logits computation.\n @@ -305,7 +309,7 @@ def logits_callback( return self._logits_callback @logits_callback.setter - def logits_callback(self, func: Optional[Callable]) -> None: + def logits_callback(self, func: LogitsCallback) -> None: self._logits_callback = func def forward( @@ -380,6 +384,6 @@ def forward( loss = loss.mean() if return_info: - return (loss, {"LogInCESampled": loss.detach()}) + return (loss, {self.loss_name: loss.detach()}) else: return (loss, None) diff --git a/replay/nn/loss/logout_ce.py b/replay/nn/loss/logout_ce.py index b5a24b482..9299af070 100644 --- a/replay/nn/loss/logout_ce.py +++ b/replay/nn/loss/logout_ce.py @@ -1,10 +1,8 @@ -from typing import Callable, Optional - import torch from replay.data.nn import TensorMap -from .base import LossOutput, mask_negative_logits +from .base import LogitsCallback, LossOutput, mask_negative_logits class LogOutCE(torch.nn.Module): @@ -28,6 +26,7 @@ def __init__( self, cardinality: int, negative_labels_ignore_index: int = -100, + loss_name: str = "LogOutCELoss", **kwargs, ): """ @@ -46,12 +45,13 @@ def __init__( self.cardinality = cardinality self.negative_labels_ignore_index = negative_labels_ignore_index self._loss = torch.nn.CrossEntropyLoss(**kwargs) - self._logits_callback = None + self._logits_callback: LogitsCallback | None = None + self.loss_name: str = loss_name @property def logits_callback( self, - ) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]: + ) -> LogitsCallback: """ Property for calling a function for the logits computation.\n @@ -69,7 +69,7 @@ def logits_callback( return self._logits_callback @logits_callback.setter - def logits_callback(self, func: Optional[Callable]) -> None: + def logits_callback(self, func: LogitsCallback) -> None: self._logits_callback = func def forward( @@ -147,7 +147,7 @@ def forward( loss = self._loss(logits, target) if return_info: - return (loss, {"LogOutCE": loss.detach()}) + return (loss, {self.loss_name: loss.detach()}) else: return (loss, None) @@ -179,6 +179,7 @@ def __init__( cardinality: int, feature_name: str, negative_labels_ignore_index: int = -100, + loss_name: str = "LogOutCEWeightedLoss", **kwargs, ): """ @@ -201,6 +202,7 @@ def __init__( ) self.feature_name = feature_name self._loss = torch.nn.CrossEntropyLoss(reduction="none", **kwargs) + self.loss_name: str = loss_name def forward( self, @@ -235,6 +237,6 @@ def forward( loss = (loss * sample_weight).mean() if return_info: - return (loss, {"LogOutCEWeighted": loss.detach()}) + return (loss, {self.loss_name: loss.detach()}) else: return (loss, None) From baa10247a8fd5e574091b3243607254cca6ca693 Mon Sep 17 00:00:00 2001 From: Nikita Kulikov Date: Thu, 5 Feb 2026 08:24:42 +0300 Subject: [PATCH 5/6] Minor updates --- replay/nn/loss/composed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/replay/nn/loss/composed.py b/replay/nn/loss/composed.py index aff5dbe37..7d93b06cc 100644 --- a/replay/nn/loss/composed.py +++ b/replay/nn/loss/composed.py @@ -151,7 +151,7 @@ def forward( loss: torch.Tensor = self._apply_weights(raw_losses) if return_info: - base_info: LossInfo = self._detach_dict(raw_losses) + base_info: dict[str, torch.Tensor] = self._detach_dict(raw_losses) info: LossInfo = {self.loss_name: loss.detach(), **base_info} return (loss, info) else: From aea482c84991946d2454705342be81351e34ff18 Mon Sep 17 00:00:00 2001 From: Nikita Kulikov Date: Thu, 5 Feb 2026 09:14:40 +0300 Subject: [PATCH 6/6] Changes in example to show comosable loss --- examples/09_sasrec_example.ipynb | 309 +++++++++++++++++-------------- replay/nn/loss/composed.py | 5 +- replay/nn/loss/seen.py | 0 3 files changed, 168 insertions(+), 146 deletions(-) delete mode 100644 replay/nn/loss/seen.py diff --git a/examples/09_sasrec_example.ipynb b/examples/09_sasrec_example.ipynb index 4038021fd..46de00ac6 100644 --- a/examples/09_sasrec_example.ipynb +++ b/examples/09_sasrec_example.ipynb @@ -1,5 +1,15 @@ { "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, { "attachments": {}, "cell_type": "markdown", @@ -10,7 +20,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -26,7 +36,7 @@ "42" ] }, - "execution_count": 30, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -58,7 +68,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -68,7 +78,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -186,7 +196,7 @@ "[1000209 rows x 3 columns]" ] }, - "execution_count": 32, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -208,7 +218,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -326,7 +336,7 @@ "[1000209 rows x 3 columns]" ] }, - "execution_count": 33, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -356,7 +366,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -387,7 +397,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -401,7 +411,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -532,7 +542,7 @@ "[6040 rows x 3 columns]" ] }, - "execution_count": 36, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -558,7 +568,7 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -574,7 +584,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -592,7 +602,7 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -628,7 +638,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -683,7 +693,7 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -694,7 +704,7 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -718,14 +728,14 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/tmp/ipykernel_752239/1761019238.py:5: UserWarning: The following dataset paths aren't provided: test,predict.Make sure to disable these stages in your Lightning Trainer configuration.\n", + "/tmp/ipykernel_37022/1761019238.py:5: UserWarning: The following dataset paths aren't provided: test,predict.Make sure to disable these stages in your Lightning Trainer configuration.\n", " parquet_module = ParquetModule(\n" ] } @@ -761,11 +771,12 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "from replay.nn.sequential import SasRec\n", + "from replay.nn.loss import ComposedLoss, CE, BCE\n", "\n", "NUM_BLOCKS = 2\n", "NUM_HEADS = 2\n", @@ -778,7 +789,15 @@ " num_heads=NUM_HEADS,\n", " num_blocks=NUM_BLOCKS,\n", " dropout=DROPOUT,\n", - ")" + ")\n", + "\n", + "sasrec.loss = ComposedLoss(\n", + " losses = [\n", + " CE(ignore_index=tensor_schema.item_id_features.item().padding_value),\n", + " BCE()\n", + " ]\n", + ")\n", + "sasrec.loss.logits_callback = sasrec.get_logits" ] }, { @@ -790,7 +809,7 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ @@ -816,7 +835,7 @@ }, { "cell_type": "code", - "execution_count": 46, + "execution_count": 18, "metadata": {}, "outputs": [ { @@ -825,6 +844,8 @@ "text": [ "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", + "You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n", + "/home/nkulikov/new_venv/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:881: Checkpoint directory /home/nkulikov/RePlay/examples/sasrec/checkpoints exists and is not empty.\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", "\n", " | Name | Type | Params | Mode | FLOPs\n", @@ -835,7 +856,7 @@ "0 Non-trainable params\n", "291 K Total params\n", "1.164 Total estimated model params size (MB)\n", - "39 Modules in train mode\n", + "43 Modules in train mode\n", "0 Modules in eval mode\n", "0 Total Flops\n" ] @@ -843,7 +864,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "3d4b2a25f75e41faba84295680abc9be", + "model_id": "b6dfa4dd9a2a41c4acdcfeac89157c0f", "version_major": 2, "version_minor": 0 }, @@ -866,7 +887,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "874cf01cadc74a3abab656fbdc19121c", + "model_id": "68ad24eb0bb14809b761afba56ea080d", "version_major": 2, "version_minor": 0 }, @@ -880,7 +901,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "fc3bd1a21dff4d6d9de53f54971cbaa4", + "model_id": "ef579d347f4f47f986a64875ecccadee", "version_major": 2, "version_minor": 0 }, @@ -895,7 +916,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Epoch 0, global step 189: 'recall@10' reached 0.03643 (best 0.03643), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=0-step=189.ckpt' as top 1\n" + "Epoch 0, global step 189: 'recall@10' reached 0.01954 (best 0.01954), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=0-step=189.ckpt' as top 1\n" ] }, { @@ -903,9 +924,9 @@ "output_type": "stream", "text": [ "k 1 10 20 5\n", - "map 0.003643 0.011142 0.013291 0.009014\n", - "ndcg 0.003643 0.016971 0.024719 0.011761\n", - "recall 0.003643 0.036430 0.066898 0.020202\n", + "map 0.003809 0.006980 0.008672 0.005729\n", + "ndcg 0.003809 0.009822 0.016091 0.006684\n", + "recall 0.003809 0.019540 0.044544 0.009604\n", "\n" ] }, @@ -920,7 +941,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "a4040344f9cd47cdb6dc44e8ed527b9a", + "model_id": "3448560f23764ea6a4bedcc0349a529b", "version_major": 2, "version_minor": 0 }, @@ -935,7 +956,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Epoch 1, global step 378: 'recall@10' reached 0.08843 (best 0.08843), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=1-step=378.ckpt' as top 1\n" + "Epoch 1, global step 378: 'recall@10' reached 0.02500 (best 0.02500), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=1-step=378.ckpt' as top 1\n" ] }, { @@ -943,9 +964,9 @@ "output_type": "stream", "text": [ "k 1 10 20 5\n", - "map 0.011095 0.028992 0.033167 0.024110\n", - "ndcg 0.011095 0.042700 0.058069 0.030709\n", - "recall 0.011095 0.088425 0.149528 0.051002\n", + "map 0.003809 0.009115 0.010547 0.007866\n", + "ndcg 0.003809 0.012794 0.018238 0.009690\n", + "recall 0.003809 0.025004 0.047028 0.015234\n", "\n" ] }, @@ -960,7 +981,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "41e5ebf1a38849a0bf447d6f669f867a", + "model_id": "754ca088536d48099db8255917dfea00", "version_major": 2, "version_minor": 0 }, @@ -975,7 +996,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Epoch 2, global step 567: 'recall@10' reached 0.12171 (best 0.12171), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=2-step=567.ckpt' as top 1\n" + "Epoch 2, global step 567: 'recall@10' reached 0.02534 (best 0.02534), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=2-step=567.ckpt' as top 1\n" ] }, { @@ -983,9 +1004,9 @@ "output_type": "stream", "text": [ "k 1 10 20 5\n", - "map 0.013413 0.038355 0.043417 0.031186\n", - "ndcg 0.013413 0.057557 0.076368 0.040041\n", - "recall 0.013413 0.121709 0.196887 0.067230\n", + "map 0.003809 0.009250 0.010724 0.007827\n", + "ndcg 0.003809 0.012979 0.018452 0.009511\n", + "recall 0.003809 0.025335 0.047193 0.014572\n", "\n" ] }, @@ -1000,7 +1021,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "621212bec8064e978a66731bc87a3eba", + "model_id": "0f2fe27d36e64ee0889cdade17c286de", "version_major": 2, "version_minor": 0 }, @@ -1015,17 +1036,17 @@ "name": "stderr", "output_type": "stream", "text": [ - "Epoch 3, global step 756: 'recall@10' reached 0.13562 (best 0.13562), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=3-step=756.ckpt' as top 1\n" + "Epoch 3, global step 756: 'recall@10' was not in top 1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "k 1 10 20 5\n", - "map 0.01275 0.041396 0.047141 0.033433\n", - "ndcg 0.01275 0.063121 0.084165 0.043618\n", - "recall 0.01275 0.135618 0.219076 0.074847\n", + "k 1 10 20 5\n", + "map 0.003809 0.007385 0.009081 0.006154\n", + "ndcg 0.003809 0.010590 0.016888 0.007441\n", + "recall 0.003809 0.021527 0.046696 0.011426\n", "\n" ] }, @@ -1040,7 +1061,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "f8c8ac6854d84543ae66b1249643668f", + "model_id": "d3f48b662b444c5baa3644f1eff468c7", "version_major": 2, "version_minor": 0 }, @@ -1055,7 +1076,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Epoch 4, global step 945: 'recall@10' reached 0.14804 (best 0.14804), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=4-step=945.ckpt' as top 1\n", + "Epoch 4, global step 945: 'recall@10' reached 0.03279 (best 0.03279), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=4-step=945-v1.ckpt' as top 1\n", "`Trainer.fit` stopped: `max_epochs=5` reached.\n" ] }, @@ -1064,9 +1085,9 @@ "output_type": "stream", "text": [ "k 1 10 20 5\n", - "map 0.012916 0.044320 0.050978 0.035577\n", - "ndcg 0.012916 0.068219 0.092799 0.046770\n", - "recall 0.012916 0.148038 0.245902 0.081139\n", + "map 0.003146 0.009925 0.012081 0.007874\n", + "ndcg 0.003146 0.015168 0.023088 0.010099\n", + "recall 0.003146 0.032787 0.064249 0.016890\n", "\n" ] } @@ -1113,16 +1134,16 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "'/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=4-step=945.ckpt'" + "'/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=4-step=945-v1.ckpt'" ] }, - "execution_count": 47, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } @@ -1144,7 +1165,7 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ @@ -1179,14 +1200,14 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/tmp/ipykernel_752239/1373759190.py:3: UserWarning: The following dataset paths aren't provided: train,validate,test.Make sure to disable these stages in your Lightning Trainer configuration.\n", + "/tmp/ipykernel_37022/1373759190.py:3: UserWarning: The following dataset paths aren't provided: train,validate,test.Make sure to disable these stages in your Lightning Trainer configuration.\n", " parquet_module = ParquetModule(\n" ] } @@ -1214,7 +1235,7 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": 22, "metadata": {}, "outputs": [ { @@ -1231,7 +1252,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "4c9752b93ff34c15a3caeb8291efce50", + "model_id": "8c6243a9d5a64606864f5dba6caf3018", "version_major": 2, "version_minor": 0 }, @@ -1273,7 +1294,7 @@ }, { "cell_type": "code", - "execution_count": 51, + "execution_count": 23, "metadata": {}, "outputs": [ { @@ -1306,32 +1327,32 @@ " \n", " 0\n", " 0\n", - " 354\n", - " 5.776358\n", + " 2426\n", + " -5.474451\n", " \n", " \n", " 0\n", " 0\n", - " 740\n", - " 5.754249\n", + " 2557\n", + " -5.80117\n", " \n", " \n", " 0\n", " 0\n", - " 574\n", - " 5.660694\n", + " 466\n", + " -5.805716\n", " \n", " \n", " 0\n", " 0\n", - " 1439\n", - " 5.478026\n", + " 627\n", + " -5.844583\n", " \n", " \n", " 0\n", " 0\n", - " 1551\n", - " 5.333538\n", + " 574\n", + " -5.853208\n", " \n", " \n", " ...\n", @@ -1342,32 +1363,32 @@ " \n", " 6037\n", " 6039\n", - " 1656\n", - " 3.195577\n", + " 2775\n", + " -5.961242\n", " \n", " \n", " 6037\n", " 6039\n", - " 327\n", - " 3.120715\n", + " 3341\n", + " -6.006827\n", " \n", " \n", " 6037\n", " 6039\n", - " 1820\n", - " 3.099917\n", + " 2708\n", + " -6.026894\n", " \n", " \n", " 6037\n", " 6039\n", - " 515\n", - " 3.07936\n", + " 2959\n", + " -6.033692\n", " \n", " \n", " 6037\n", " 6039\n", - " 1306\n", - " 3.031433\n", + " 1618\n", + " -6.041715\n", " \n", " \n", "\n", @@ -1376,22 +1397,22 @@ ], "text/plain": [ " user_id item_id score\n", - "0 0 354 5.776358\n", - "0 0 740 5.754249\n", - "0 0 574 5.660694\n", - "0 0 1439 5.478026\n", - "0 0 1551 5.333538\n", + "0 0 2426 -5.474451\n", + "0 0 2557 -5.80117\n", + "0 0 466 -5.805716\n", + "0 0 627 -5.844583\n", + "0 0 574 -5.853208\n", "... ... ... ...\n", - "6037 6039 1656 3.195577\n", - "6037 6039 327 3.120715\n", - "6037 6039 1820 3.099917\n", - "6037 6039 515 3.07936\n", - "6037 6039 1306 3.031433\n", + "6037 6039 2775 -5.961242\n", + "6037 6039 3341 -6.006827\n", + "6037 6039 2708 -6.026894\n", + "6037 6039 2959 -6.033692\n", + "6037 6039 1618 -6.041715\n", "\n", "[120760 rows x 3 columns]" ] }, - "execution_count": 51, + "execution_count": 23, "metadata": {}, "output_type": "execute_result" } @@ -1412,7 +1433,7 @@ }, { "cell_type": "code", - "execution_count": 52, + "execution_count": 24, "metadata": {}, "outputs": [], "source": [ @@ -1422,7 +1443,7 @@ }, { "cell_type": "code", - "execution_count": 53, + "execution_count": 25, "metadata": {}, "outputs": [], "source": [ @@ -1435,7 +1456,7 @@ }, { "cell_type": "code", - "execution_count": 54, + "execution_count": 26, "metadata": {}, "outputs": [ { @@ -1468,37 +1489,37 @@ " \n", " \n", " MAP\n", - " 0.015734\n", - " 0.047696\n", - " 0.053510\n", - " 0.039428\n", + " 0.00265\n", + " 0.008525\n", + " 0.010557\n", + " 0.006558\n", " \n", " \n", " Precision\n", - " 0.015734\n", - " 0.014790\n", - " 0.011651\n", - " 0.016959\n", + " 0.00265\n", + " 0.002865\n", + " 0.002956\n", + " 0.002816\n", " \n", " \n", " Recall\n", - " 0.015734\n", - " 0.147897\n", - " 0.233024\n", - " 0.084796\n", + " 0.00265\n", + " 0.028652\n", + " 0.059126\n", + " 0.014078\n", " \n", " \n", "\n", "" ], "text/plain": [ - "k 1 10 20 5\n", - "MAP 0.015734 0.047696 0.053510 0.039428\n", - "Precision 0.015734 0.014790 0.011651 0.016959\n", - "Recall 0.015734 0.147897 0.233024 0.084796" + "k 1 10 20 5\n", + "MAP 0.00265 0.008525 0.010557 0.006558\n", + "Precision 0.00265 0.002865 0.002956 0.002816\n", + "Recall 0.00265 0.028652 0.059126 0.014078" ] }, - "execution_count": 54, + "execution_count": 26, "metadata": {}, "output_type": "execute_result" } @@ -1516,7 +1537,7 @@ }, { "cell_type": "code", - "execution_count": 55, + "execution_count": 27, "metadata": {}, "outputs": [ { @@ -1549,32 +1570,32 @@ " \n", " 0\n", " 1\n", - " 364\n", - " 5.776358\n", + " 2628\n", + " -5.474451\n", " \n", " \n", " 0\n", " 1\n", - " 783\n", - " 5.754249\n", + " 2762\n", + " -5.80117\n", " \n", " \n", " 0\n", " 1\n", - " 588\n", - " 5.660694\n", + " 480\n", + " -5.805716\n", " \n", " \n", " 0\n", " 1\n", - " 1566\n", - " 5.478026\n", + " 648\n", + " -5.844583\n", " \n", " \n", " 0\n", " 1\n", - " 1688\n", - " 5.333538\n", + " 588\n", + " -5.853208\n", " \n", " \n", " ...\n", @@ -1585,32 +1606,32 @@ " \n", " 6037\n", " 6040\n", - " 1834\n", - " 3.195577\n", + " 2987\n", + " -5.961242\n", " \n", " \n", " 6037\n", " 6040\n", - " 337\n", - " 3.120715\n", + " 3578\n", + " -6.006827\n", " \n", " \n", " 6037\n", " 6040\n", - " 2000\n", - " 3.099917\n", + " 2916\n", + " -6.026894\n", " \n", " \n", " 6037\n", " 6040\n", - " 529\n", - " 3.07936\n", + " 3176\n", + " -6.033692\n", " \n", " \n", " 6037\n", " 6040\n", - " 1408\n", - " 3.031433\n", + " 1784\n", + " -6.041715\n", " \n", " \n", "\n", @@ -1619,22 +1640,22 @@ ], "text/plain": [ " user_id item_id score\n", - "0 1 364 5.776358\n", - "0 1 783 5.754249\n", - "0 1 588 5.660694\n", - "0 1 1566 5.478026\n", - "0 1 1688 5.333538\n", + "0 1 2628 -5.474451\n", + "0 1 2762 -5.80117\n", + "0 1 480 -5.805716\n", + "0 1 648 -5.844583\n", + "0 1 588 -5.853208\n", "... ... ... ...\n", - "6037 6040 1834 3.195577\n", - "6037 6040 337 3.120715\n", - "6037 6040 2000 3.099917\n", - "6037 6040 529 3.07936\n", - "6037 6040 1408 3.031433\n", + "6037 6040 2987 -5.961242\n", + "6037 6040 3578 -6.006827\n", + "6037 6040 2916 -6.026894\n", + "6037 6040 3176 -6.033692\n", + "6037 6040 1784 -6.041715\n", "\n", "[120760 rows x 3 columns]" ] }, - "execution_count": 55, + "execution_count": 27, "metadata": {}, "output_type": "execute_result" } diff --git a/replay/nn/loss/composed.py b/replay/nn/loss/composed.py index 7d93b06cc..c52ebffdd 100644 --- a/replay/nn/loss/composed.py +++ b/replay/nn/loss/composed.py @@ -106,7 +106,7 @@ def _compute_raw_losses( ) -> dict[str, torch.Tensor]: raw_losses: dict[str, torch.Tensor] = {} for name, loss in self.losses.items(): - raw_losses[name] = loss( + value, _ = loss( model_embeddings, feature_tensors, positive_labels, @@ -114,6 +114,7 @@ def _compute_raw_losses( padding_mask, target_padding_mask, ) + raw_losses[name] = value return raw_losses def _apply_weights(self: Self, raw_losses: dict[str, torch.Tensor]) -> torch.Tensor: @@ -121,7 +122,7 @@ def _apply_weights(self: Self, raw_losses: dict[str, torch.Tensor]) -> torch.Ten for name, value in raw_losses.items(): weight = self.weights.get(name, 1.0) - losses_list.append(weight * value) + losses_list.append(weight * value[None]) losses: torch.Tensor = torch.cat(losses_list) return torch.sum(losses) diff --git a/replay/nn/loss/seen.py b/replay/nn/loss/seen.py deleted file mode 100644 index e69de29bb..000000000