diff --git a/examples/09_sasrec_example.ipynb b/examples/09_sasrec_example.ipynb index 14ab84808..46de00ac6 100644 --- a/examples/09_sasrec_example.ipynb +++ b/examples/09_sasrec_example.ipynb @@ -1,5 +1,15 @@ { "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, { "attachments": {}, "cell_type": "markdown", @@ -10,7 +20,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -19,9 +29,21 @@ "text": [ "Seed set to 42\n" ] + }, + { + "data": { + "text/plain": [ + "42" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ + "from typing import Optional\n", + "\n", "import lightning as L\n", "import pandas as pd\n", "\n", @@ -46,7 +68,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -56,7 +78,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -105,15 +127,15 @@ " 2\n", " \n", " \n", - " 1000192\n", + " 1000007\n", " 6040\n", - " 2019\n", + " 1961\n", " 3\n", " \n", " \n", - " 1000007\n", + " 1000192\n", " 6040\n", - " 1961\n", + " 2019\n", " 4\n", " \n", " \n", @@ -135,15 +157,15 @@ " 447\n", " \n", " \n", - " 825731\n", + " 825724\n", " 4958\n", - " 2634\n", + " 3264\n", " 448\n", " \n", " \n", - " 825724\n", + " 825731\n", " 4958\n", - " 3264\n", + " 2634\n", " 449\n", " \n", " \n", @@ -162,19 +184,19 @@ "1000138 6040 858 0\n", "1000153 6040 2384 1\n", "999873 6040 593 2\n", - "1000192 6040 2019 3\n", - "1000007 6040 1961 4\n", + "1000007 6040 1961 3\n", + "1000192 6040 2019 4\n", "... ... ... ...\n", "825793 4958 2399 446\n", "825438 4958 1407 447\n", - "825731 4958 2634 448\n", - "825724 4958 3264 449\n", + "825724 4958 3264 448\n", + "825731 4958 2634 449\n", "825603 4958 1924 450\n", "\n", "[1000209 rows x 3 columns]" ] }, - "execution_count": 30, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -196,7 +218,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -228,32 +250,32 @@ " \n", " \n", " 0\n", - " 32\n", - " 0\n", + " 12\n", + " 2011\n", " 0\n", " \n", " \n", " 1\n", - " 10\n", - " 1\n", + " 68\n", + " 4078\n", " 0\n", " \n", " \n", " 2\n", - " 12\n", - " 2\n", + " 67\n", + " 4123\n", " 0\n", " \n", " \n", " 3\n", - " 339\n", - " 3\n", + " 12\n", + " 983\n", " 0\n", " \n", " \n", " 4\n", - " 144\n", - " 4\n", + " 140\n", + " 2270\n", " 0\n", " \n", " \n", @@ -264,32 +286,32 @@ " \n", " \n", " 1000204\n", - " 281\n", - " 796\n", + " 14\n", + " 855\n", " 3705\n", " \n", " \n", " 1000205\n", - " 209\n", - " 1297\n", + " 90\n", + " 1700\n", " 3705\n", " \n", " \n", " 1000206\n", - " 748\n", - " 1883\n", + " 70\n", + " 936\n", " 3705\n", " \n", " \n", " 1000207\n", - " 71\n", - " 4449\n", + " 25\n", + " 360\n", " 3705\n", " \n", " \n", " 1000208\n", - " 287\n", - " 2473\n", + " 380\n", + " 1388\n", " 3705\n", " \n", " \n", @@ -299,35 +321,36 @@ ], "text/plain": [ " timestamp user_id item_id\n", - "0 32 0 0\n", - "1 10 1 0\n", - "2 12 2 0\n", - "3 339 3 0\n", - "4 144 4 0\n", + "0 12 2011 0\n", + "1 68 4078 0\n", + "2 67 4123 0\n", + "3 12 983 0\n", + "4 140 2270 0\n", "... ... ... ...\n", - "1000204 281 796 3705\n", - "1000205 209 1297 3705\n", - "1000206 748 1883 3705\n", - "1000207 71 4449 3705\n", - "1000208 287 2473 3705\n", + "1000204 14 855 3705\n", + "1000205 90 1700 3705\n", + "1000206 70 936 3705\n", + "1000207 25 360 3705\n", + "1000208 380 1388 3705\n", "\n", "[1000209 rows x 3 columns]" ] }, - "execution_count": 31, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "from replay.preprocessing import LabelEncoder, LabelEncodingRule\n", + "from replay.preprocessing.label_encoder import LabelEncoder, LabelEncodingRule\n", "\n", "encoder = LabelEncoder(\n", " [\n", - " LabelEncodingRule(\"user_id\"),\n", - " LabelEncodingRule(\"item_id\"),\n", + " LabelEncodingRule(\"user_id\", default_value=\"last\"),\n", + " LabelEncodingRule(\"item_id\", default_value=\"last\"),\n", " ]\n", ")\n", + "interactions = interactions.sort_values(by=\"item_id\", ascending=True)\n", "encoded_interactions = encoder.fit_transform(interactions)\n", "encoded_interactions" ] @@ -343,7 +366,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -374,7 +397,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -388,7 +411,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -422,31 +445,31 @@ " 0\n", " 0\n", " [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...\n", - " [859, 309, 2371, 3442, 1108, 329, 367, 3279, 7...\n", + " [2969, 1574, 1178, 957, 2147, 1658, 3177, 1117...\n", " \n", " \n", " 1\n", " 1\n", " [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...\n", - " [354, 2426, 253, 1371, 513, 1184, 3131, 309, 8...\n", + " [1108, 1127, 1120, 2512, 1201, 2735, 1135, 110...\n", " \n", " \n", " 2\n", " 2\n", " [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...\n", - " [1008, 1120, 2439, 1066, 3197, 253, 1108, 1107...\n", + " [579, 2651, 3301, 1788, 1781, 1327, 1174, 3429...\n", " \n", " \n", " 3\n", " 3\n", " [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...\n", - " [346, 2190, 670, 802, 323, 661, 2480, 2501, 19...\n", + " [1120, 1025, 466, 3235, 3294, 1106, 253, 1108,...\n", " \n", " \n", " 4\n", " 4\n", " [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...\n", - " [1120, 758, 2426, 1838, 2621, 3341, 3377, 3502...\n", + " [2512, 858, 847, 346, 1158, 2007, 2651, 1050, ...\n", " \n", " \n", " ...\n", @@ -458,31 +481,31 @@ " 6035\n", " 6035\n", " [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...\n", - " [2426, 1279, 3321, 3151, 1178, 2501, 3301, 248...\n", + " [1574, 1703, 3206, 2183, 2235, 2480, 2375, 250...\n", " \n", " \n", " 6036\n", " 6036\n", " [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...\n", - " [1592, 2302, 1633, 1813, 2879, 1482, 2651, 200...\n", + " [1702, 672, 1175, 1848, 3275, 2932, 548, 802, ...\n", " \n", " \n", " 6037\n", " 6037\n", " [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...\n", - " [1971, 3500, 1666, 2077, 1399, 2748, 2958, 278...\n", + " [3165, 859, 1120, 1965, 1288, 346, 1007, 1066,...\n", " \n", " \n", " 6038\n", " 6038\n", " [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...\n", - " [1486, 1485, 3384, 3512, 3302, 3126, 3650, 330...\n", + " [107, 275, 1886, 1139, 869, 886, 2872, 2809, 2...\n", " \n", " \n", " 6039\n", " 6039\n", " [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...\n", - " [2432, 2960, 2114, 1848, 2142, 3248, 3091, 317...\n", + " [802, 2191, 579, 1781, 1839, 1316, 207, 2895, ...\n", " \n", " \n", "\n", @@ -504,22 +527,22 @@ "6039 6039 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", "\n", " item_id \n", - "0 [859, 309, 2371, 3442, 1108, 329, 367, 3279, 7... \n", - "1 [354, 2426, 253, 1371, 513, 1184, 3131, 309, 8... \n", - "2 [1008, 1120, 2439, 1066, 3197, 253, 1108, 1107... \n", - "3 [346, 2190, 670, 802, 323, 661, 2480, 2501, 19... \n", - "4 [1120, 758, 2426, 1838, 2621, 3341, 3377, 3502... \n", + "0 [2969, 1574, 1178, 957, 2147, 1658, 3177, 1117... \n", + "1 [1108, 1127, 1120, 2512, 1201, 2735, 1135, 110... \n", + "2 [579, 2651, 3301, 1788, 1781, 1327, 1174, 3429... \n", + "3 [1120, 1025, 466, 3235, 3294, 1106, 253, 1108,... \n", + "4 [2512, 858, 847, 346, 1158, 2007, 2651, 1050, ... \n", "... ... \n", - "6035 [2426, 1279, 3321, 3151, 1178, 2501, 3301, 248... \n", - "6036 [1592, 2302, 1633, 1813, 2879, 1482, 2651, 200... \n", - "6037 [1971, 3500, 1666, 2077, 1399, 2748, 2958, 278... \n", - "6038 [1486, 1485, 3384, 3512, 3302, 3126, 3650, 330... \n", - "6039 [2432, 2960, 2114, 1848, 2142, 3248, 3091, 317... \n", + "6035 [1574, 1703, 3206, 2183, 2235, 2480, 2375, 250... \n", + "6036 [1702, 672, 1175, 1848, 3275, 2932, 548, 802, ... \n", + "6037 [3165, 859, 1120, 1965, 1288, 346, 1007, 1066,... \n", + "6038 [107, 275, 1886, 1139, 869, 886, 2872, 2809, 2... \n", + "6039 [802, 2191, 579, 1781, 1839, 1316, 207, 2895, ... \n", "\n", "[6040 rows x 3 columns]" ] }, - "execution_count": 34, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -545,7 +568,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -561,7 +584,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -579,9 +602,20 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 11, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/nkulikov/RePlay/replay/preprocessing/label_encoder.py:964: UserWarning: There is already LabelEncoder object saved at the given path. File will be overwrited.\n", + " warnings.warn(msg)\n", + "/home/nkulikov/RePlay/replay/preprocessing/label_encoder.py:537: UserWarning: There is already LabelEncodingRule object saved at the given path. File will be overwrited.\n", + " warnings.warn(msg)\n" + ] + } + ], "source": [ "train_events.to_parquet(TRAIN_PATH)\n", "validation_events.to_parquet(VAL_PATH)\n", @@ -604,7 +638,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -659,7 +693,7 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -670,12 +704,10 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ - "from typing import Optional\n", - "\n", "MAX_SEQ_LEN = 50\n", "\n", "def create_meta(shape: int, gt_shape: Optional[int] = None):\n", @@ -696,9 +728,18 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 15, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_37022/1761019238.py:5: UserWarning: The following dataset paths aren't provided: test,predict.Make sure to disable these stages in your Lightning Trainer configuration.\n", + " parquet_module = ParquetModule(\n" + ] + } + ], "source": [ "from replay.data.nn import ParquetModule\n", "\n", @@ -730,11 +771,12 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "from replay.nn.sequential import SasRec\n", + "from replay.nn.loss import ComposedLoss, CE, BCE\n", "\n", "NUM_BLOCKS = 2\n", "NUM_HEADS = 2\n", @@ -747,7 +789,15 @@ " num_heads=NUM_HEADS,\n", " num_blocks=NUM_BLOCKS,\n", " dropout=DROPOUT,\n", - ")" + ")\n", + "\n", + "sasrec.loss = ComposedLoss(\n", + " losses = [\n", + " CE(ignore_index=tensor_schema.item_id_features.item().padding_value),\n", + " BCE()\n", + " ]\n", + ")\n", + "sasrec.loss.logits_callback = sasrec.get_logits" ] }, { @@ -759,13 +809,19 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ + "from replay.nn.lightning.optimizer import OptimizerFactory\n", + "from replay.nn.lightning.scheduler import LRSchedulerFactory\n", "from replay.nn.lightning import LightningModule\n", "\n", - "model = LightningModule(sasrec)" + "model = LightningModule(\n", + " sasrec,\n", + " optimizer_factory=OptimizerFactory(),\n", + " lr_scheduler_factory=LRSchedulerFactory(),\n", + ")" ] }, { @@ -779,33 +835,36 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "GPU available: False, used: False\n", + "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", - "HPU available: False, using: 0 HPUs\n", + "You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n", + "/home/nkulikov/new_venv/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:881: Checkpoint directory /home/nkulikov/RePlay/examples/sasrec/checkpoints exists and is not empty.\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", "\n", - " | Name | Type | Params | Mode \n", - "-----------------------------------------\n", - "0 | model | SasRec | 291 K | train\n", - "-----------------------------------------\n", + " | Name | Type | Params | Mode | FLOPs\n", + "-------------------------------------------------\n", + "0 | model | SasRec | 291 K | train | 0 \n", + "-------------------------------------------------\n", "291 K Trainable params\n", "0 Non-trainable params\n", "291 K Total params\n", "1.164 Total estimated model params size (MB)\n", - "39 Modules in train mode\n", - "0 Modules in eval mode\n" + "43 Modules in train mode\n", + "0 Modules in eval mode\n", + "0 Total Flops\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "a36bec8cdb384569a56b01dda4a8dce3", + "model_id": "b6dfa4dd9a2a41c4acdcfeac89157c0f", "version_major": 2, "version_minor": 0 }, @@ -816,10 +875,19 @@ "metadata": {}, "output_type": "display_data" }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/nkulikov/new_venv/lib/python3.12/site-packages/lightning/pytorch/utilities/data.py:123: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.\n", + "/home/nkulikov/new_venv/lib/python3.12/site-packages/torch/nn/functional.py:6044: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.\n", + " warnings.warn(\n" + ] + }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "3788b1f40c8648e58741b90e70a66259", + "model_id": "68ad24eb0bb14809b761afba56ea080d", "version_major": 2, "version_minor": 0 }, @@ -833,7 +901,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "cdc8dd8ed34441bc82f9d868f8cac7d3", + "model_id": "ef579d347f4f47f986a64875ecccadee", "version_major": 2, "version_minor": 0 }, @@ -848,7 +916,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Epoch 0, global step 189: 'recall@10' reached 0.03576 (best 0.03576), saving model to '/home/RePlay/examples/sasrec/checkpoints/epoch=0-step=189.ckpt' as top 1\n" + "Epoch 0, global step 189: 'recall@10' reached 0.01954 (best 0.01954), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=0-step=189.ckpt' as top 1\n" ] }, { @@ -856,16 +924,24 @@ "output_type": "stream", "text": [ "k 1 10 20 5\n", - "map 0.003808 0.010609 0.012859 0.008209\n", - "ndcg 0.003808 0.016343 0.024711 0.010422\n", - "recall 0.003808 0.035762 0.069205 0.017219\n", + "map 0.003809 0.006980 0.008672 0.005729\n", + "ndcg 0.003809 0.009822 0.016091 0.006684\n", + "recall 0.003809 0.019540 0.044544 0.009604\n", "\n" ] }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/nkulikov/new_venv/lib/python3.12/site-packages/torch/nn/functional.py:6044: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.\n", + " warnings.warn(\n" + ] + }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "6c8e2d8667a64bba9e4c05337d9a934c", + "model_id": "3448560f23764ea6a4bedcc0349a529b", "version_major": 2, "version_minor": 0 }, @@ -880,7 +956,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Epoch 1, global step 378: 'recall@10' reached 0.08626 (best 0.08626), saving model to '/home/RePlay/examples/sasrec/checkpoints/epoch=1-step=378.ckpt' as top 1\n" + "Epoch 1, global step 378: 'recall@10' reached 0.02500 (best 0.02500), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=1-step=378.ckpt' as top 1\n" ] }, { @@ -888,16 +964,24 @@ "output_type": "stream", "text": [ "k 1 10 20 5\n", - "map 0.011258 0.028874 0.032790 0.024324\n", - "ndcg 0.011258 0.042089 0.056579 0.030685\n", - "recall 0.011258 0.086258 0.144040 0.050166\n", + "map 0.003809 0.009115 0.010547 0.007866\n", + "ndcg 0.003809 0.012794 0.018238 0.009690\n", + "recall 0.003809 0.025004 0.047028 0.015234\n", "\n" ] }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/nkulikov/new_venv/lib/python3.12/site-packages/torch/nn/functional.py:6044: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.\n", + " warnings.warn(\n" + ] + }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "4f29f2aa51ba4cf68af42cc95a453be5", + "model_id": "754ca088536d48099db8255917dfea00", "version_major": 2, "version_minor": 0 }, @@ -912,7 +996,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Epoch 2, global step 567: 'recall@10' reached 0.12136 (best 0.12136), saving model to '/home/RePlay/examples/sasrec/checkpoints/epoch=2-step=567.ckpt' as top 1\n" + "Epoch 2, global step 567: 'recall@10' reached 0.02534 (best 0.02534), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=2-step=567.ckpt' as top 1\n" ] }, { @@ -920,16 +1004,24 @@ "output_type": "stream", "text": [ "k 1 10 20 5\n", - "map 0.015894 0.040155 0.045144 0.033278\n", - "ndcg 0.015894 0.058849 0.077380 0.041941\n", - "recall 0.015894 0.121358 0.195364 0.068543\n", + "map 0.003809 0.009250 0.010724 0.007827\n", + "ndcg 0.003809 0.012979 0.018452 0.009511\n", + "recall 0.003809 0.025335 0.047193 0.014572\n", "\n" ] }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/nkulikov/new_venv/lib/python3.12/site-packages/torch/nn/functional.py:6044: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.\n", + " warnings.warn(\n" + ] + }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "a8c62495083f4b108e679d86d496e3e9", + "model_id": "0f2fe27d36e64ee0889cdade17c286de", "version_major": 2, "version_minor": 0 }, @@ -944,7 +1036,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Epoch 3, global step 756: 'recall@10' reached 0.14106 (best 0.14106), saving model to '/home/RePlay/examples/sasrec/checkpoints/epoch=3-step=756.ckpt' as top 1\n" + "Epoch 3, global step 756: 'recall@10' was not in top 1\n" ] }, { @@ -952,16 +1044,24 @@ "output_type": "stream", "text": [ "k 1 10 20 5\n", - "map 0.015232 0.045004 0.050748 0.037318\n", - "ndcg 0.015232 0.067207 0.088327 0.048297\n", - "recall 0.015232 0.141060 0.225000 0.081954\n", + "map 0.003809 0.007385 0.009081 0.006154\n", + "ndcg 0.003809 0.010590 0.016888 0.007441\n", + "recall 0.003809 0.021527 0.046696 0.011426\n", "\n" ] }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/nkulikov/new_venv/lib/python3.12/site-packages/torch/nn/functional.py:6044: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.\n", + " warnings.warn(\n" + ] + }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "e40e9de8673b44e3b451ccb9bdd84699", + "model_id": "d3f48b662b444c5baa3644f1eff468c7", "version_major": 2, "version_minor": 0 }, @@ -976,7 +1076,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Epoch 4, global step 945: 'recall@10' reached 0.15331 (best 0.15331), saving model to '/home/RePlay/examples/sasrec/checkpoints/epoch=4-step=945.ckpt' as top 1\n", + "Epoch 4, global step 945: 'recall@10' reached 0.03279 (best 0.03279), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=4-step=945-v1.ckpt' as top 1\n", "`Trainer.fit` stopped: `max_epochs=5` reached.\n" ] }, @@ -985,9 +1085,9 @@ "output_type": "stream", "text": [ "k 1 10 20 5\n", - "map 0.015728 0.048114 0.054519 0.039854\n", - "ndcg 0.015728 0.072437 0.095845 0.052167\n", - "recall 0.015728 0.153311 0.246026 0.090066\n", + "map 0.003146 0.009925 0.012081 0.007874\n", + "ndcg 0.003146 0.015168 0.023088 0.010099\n", + "recall 0.003146 0.032787 0.064249 0.016890\n", "\n" ] } @@ -1034,16 +1134,16 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "'/home/RePlay/examples/sasrec/checkpoints/epoch=4-step=945.ckpt'" + "'/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=4-step=945-v1.ckpt'" ] }, - "execution_count": 45, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } @@ -1065,10 +1165,13 @@ }, { "cell_type": "code", - "execution_count": 46, + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ + "import torch\n", + "import replay\n", + "\n", "sasrec = SasRec.from_params(\n", " schema=tensor_schema,\n", " embedding_dim=EMBEDDING_DIM,\n", @@ -1079,6 +1182,11 @@ " excluded_features=None,\n", ")\n", "\n", + "torch.serialization.add_safe_globals([\n", + " replay.nn.lightning.optimizer.OptimizerFactory,\n", + " replay.nn.lightning.scheduler.LRSchedulerFactory,\n", + "])\n", + "\n", "best_model = LightningModule.load_from_checkpoint(best_model_path, model=sasrec)\n", "best_model.eval();" ] @@ -1092,9 +1200,18 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": 21, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_37022/1373759190.py:3: UserWarning: The following dataset paths aren't provided: train,validate,test.Make sure to disable these stages in your Lightning Trainer configuration.\n", + " parquet_module = ParquetModule(\n" + ] + } + ], "source": [ "inference_metadata = {\"predict\": create_meta(shape=MAX_SEQ_LEN)}\n", "\n", @@ -1118,22 +1235,24 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "GPU available: False, used: False\n", + "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n", + "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", - "HPU available: False, using: 0 HPUs\n" + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "/home/nkulikov/new_venv/lib/python3.12/site-packages/lightning/pytorch/utilities/data.py:123: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "d93a289ce7f74b0f8b470cd48cb6882a", + "model_id": "8c6243a9d5a64606864f5dba6caf3018", "version_major": 2, "version_minor": 0 }, @@ -1143,6 +1262,14 @@ }, "metadata": {}, "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/nkulikov/new_venv/lib/python3.12/site-packages/torch/nn/functional.py:6044: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.\n", + " warnings.warn(\n" + ] } ], "source": [ @@ -1167,7 +1294,7 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": 23, "metadata": {}, "outputs": [ { @@ -1200,32 +1327,32 @@ " \n", " 0\n", " 0\n", - " 3383\n", - " 6.106367\n", + " 2426\n", + " -5.474451\n", " \n", " \n", " 0\n", " 0\n", - " 3509\n", - " 5.956742\n", + " 2557\n", + " -5.80117\n", " \n", " \n", " 0\n", " 0\n", - " 3341\n", - " 5.944553\n", + " 466\n", + " -5.805716\n", " \n", " \n", " 0\n", " 0\n", - " 3510\n", - " 5.776233\n", + " 627\n", + " -5.844583\n", " \n", " \n", " 0\n", " 0\n", - " 3512\n", - " 5.587171\n", + " 574\n", + " -5.853208\n", " \n", " \n", " ...\n", @@ -1236,32 +1363,32 @@ " \n", " 6037\n", " 6039\n", - " 2941\n", - " 5.740072\n", + " 2775\n", + " -5.961242\n", " \n", " \n", " 6037\n", " 6039\n", - " 3049\n", - " 5.596299\n", + " 3341\n", + " -6.006827\n", " \n", " \n", " 6037\n", " 6039\n", - " 2750\n", - " 5.548656\n", + " 2708\n", + " -6.026894\n", " \n", " \n", " 6037\n", " 6039\n", - " 2968\n", - " 5.302012\n", + " 2959\n", + " -6.033692\n", " \n", " \n", " 6037\n", " 6039\n", - " 2202\n", - " 5.136523\n", + " 1618\n", + " -6.041715\n", " \n", " \n", "\n", @@ -1270,22 +1397,22 @@ ], "text/plain": [ " user_id item_id score\n", - "0 0 3383 6.106367\n", - "0 0 3509 5.956742\n", - "0 0 3341 5.944553\n", - "0 0 3510 5.776233\n", - "0 0 3512 5.587171\n", + "0 0 2426 -5.474451\n", + "0 0 2557 -5.80117\n", + "0 0 466 -5.805716\n", + "0 0 627 -5.844583\n", + "0 0 574 -5.853208\n", "... ... ... ...\n", - "6037 6039 2941 5.740072\n", - "6037 6039 3049 5.596299\n", - "6037 6039 2750 5.548656\n", - "6037 6039 2968 5.302012\n", - "6037 6039 2202 5.136523\n", + "6037 6039 2775 -5.961242\n", + "6037 6039 3341 -6.006827\n", + "6037 6039 2708 -6.026894\n", + "6037 6039 2959 -6.033692\n", + "6037 6039 1618 -6.041715\n", "\n", "[120760 rows x 3 columns]" ] }, - "execution_count": 49, + "execution_count": 23, "metadata": {}, "output_type": "execute_result" } @@ -1306,7 +1433,7 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": 24, "metadata": {}, "outputs": [], "source": [ @@ -1316,7 +1443,7 @@ }, { "cell_type": "code", - "execution_count": 51, + "execution_count": 25, "metadata": {}, "outputs": [], "source": [ @@ -1329,7 +1456,7 @@ }, { "cell_type": "code", - "execution_count": 52, + "execution_count": 26, "metadata": {}, "outputs": [ { @@ -1362,37 +1489,37 @@ " \n", " \n", " MAP\n", - " 0.013912\n", - " 0.046539\n", - " 0.052342\n", - " 0.038912\n", + " 0.00265\n", + " 0.008525\n", + " 0.010557\n", + " 0.006558\n", " \n", " \n", " Precision\n", - " 0.013912\n", - " 0.014492\n", - " 0.011494\n", - " 0.017456\n", + " 0.00265\n", + " 0.002865\n", + " 0.002956\n", + " 0.002816\n", " \n", " \n", " Recall\n", - " 0.013912\n", - " 0.144916\n", - " 0.229877\n", - " 0.087281\n", + " 0.00265\n", + " 0.028652\n", + " 0.059126\n", + " 0.014078\n", " \n", " \n", "\n", "" ], "text/plain": [ - "k 1 10 20 5\n", - "MAP 0.013912 0.046539 0.052342 0.038912\n", - "Precision 0.013912 0.014492 0.011494 0.017456\n", - "Recall 0.013912 0.144916 0.229877 0.087281" + "k 1 10 20 5\n", + "MAP 0.00265 0.008525 0.010557 0.006558\n", + "Precision 0.00265 0.002865 0.002956 0.002816\n", + "Recall 0.00265 0.028652 0.059126 0.014078" ] }, - "execution_count": 52, + "execution_count": 26, "metadata": {}, "output_type": "execute_result" } @@ -1410,7 +1537,7 @@ }, { "cell_type": "code", - "execution_count": 53, + "execution_count": 27, "metadata": {}, "outputs": [ { @@ -1442,33 +1569,33 @@ " \n", " \n", " 0\n", - " 2011\n", - " 3623\n", - " 6.106367\n", + " 1\n", + " 2628\n", + " -5.474451\n", " \n", " \n", " 0\n", - " 2011\n", - " 3752\n", - " 5.956742\n", + " 1\n", + " 2762\n", + " -5.80117\n", " \n", " \n", " 0\n", - " 2011\n", - " 3578\n", - " 5.944553\n", + " 1\n", + " 480\n", + " -5.805716\n", " \n", " \n", " 0\n", - " 2011\n", - " 3753\n", - " 5.776233\n", + " 1\n", + " 648\n", + " -5.844583\n", " \n", " \n", " 0\n", - " 2011\n", - " 3755\n", - " 5.587171\n", + " 1\n", + " 588\n", + " -5.853208\n", " \n", " \n", " ...\n", @@ -1478,33 +1605,33 @@ " \n", " \n", " 6037\n", - " 5727\n", - " 3157\n", - " 5.740072\n", + " 6040\n", + " 2987\n", + " -5.961242\n", " \n", " \n", " 6037\n", - " 5727\n", - " 3273\n", - " 5.596299\n", + " 6040\n", + " 3578\n", + " -6.006827\n", " \n", " \n", " 6037\n", - " 5727\n", - " 2961\n", - " 5.548656\n", + " 6040\n", + " 2916\n", + " -6.026894\n", " \n", " \n", " 6037\n", - " 5727\n", - " 3185\n", - " 5.302012\n", + " 6040\n", + " 3176\n", + " -6.033692\n", " \n", " \n", " 6037\n", - " 5727\n", - " 2395\n", - " 5.136523\n", + " 6040\n", + " 1784\n", + " -6.041715\n", " \n", " \n", "\n", @@ -1513,22 +1640,22 @@ ], "text/plain": [ " user_id item_id score\n", - "0 2011 3623 6.106367\n", - "0 2011 3752 5.956742\n", - "0 2011 3578 5.944553\n", - "0 2011 3753 5.776233\n", - "0 2011 3755 5.587171\n", + "0 1 2628 -5.474451\n", + "0 1 2762 -5.80117\n", + "0 1 480 -5.805716\n", + "0 1 648 -5.844583\n", + "0 1 588 -5.853208\n", "... ... ... ...\n", - "6037 5727 3157 5.740072\n", - "6037 5727 3273 5.596299\n", - "6037 5727 2961 5.548656\n", - "6037 5727 3185 5.302012\n", - "6037 5727 2395 5.136523\n", + "6037 6040 2987 -5.961242\n", + "6037 6040 3578 -6.006827\n", + "6037 6040 2916 -6.026894\n", + "6037 6040 3176 -6.033692\n", + "6037 6040 1784 -6.041715\n", "\n", "[120760 rows x 3 columns]" ] }, - "execution_count": 53, + "execution_count": 27, "metadata": {}, "output_type": "execute_result" } @@ -1536,11 +1663,18 @@ "source": [ "encoder.inverse_transform(pandas_res)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { "kernelspec": { - "display_name": "rep", + "display_name": "new_venv", "language": "python", "name": "python3" }, @@ -1554,7 +1688,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.14" + "version": "3.12.3" }, "orig_nbformat": 4 }, diff --git a/examples/15_twotower_example.ipynb b/examples/15_twotower_example.ipynb index b189eebb3..bac925a97 100644 --- a/examples/15_twotower_example.ipynb +++ b/examples/15_twotower_example.ipynb @@ -118,15 +118,15 @@ " 2\n", " \n", " \n", - " 1000192\n", + " 1000007\n", " 6040\n", - " 2019\n", + " 1961\n", " 3\n", " \n", " \n", - " 1000007\n", + " 1000192\n", " 6040\n", - " 1961\n", + " 2019\n", " 4\n", " \n", " \n", @@ -148,15 +148,15 @@ " 447\n", " \n", " \n", - " 825731\n", + " 825724\n", " 4958\n", - " 2634\n", + " 3264\n", " 448\n", " \n", " \n", - " 825724\n", + " 825731\n", " 4958\n", - " 3264\n", + " 2634\n", " 449\n", " \n", " \n", @@ -175,13 +175,13 @@ "1000138 6040 858 0\n", "1000153 6040 2384 1\n", "999873 6040 593 2\n", - "1000192 6040 2019 3\n", - "1000007 6040 1961 4\n", + "1000007 6040 1961 3\n", + "1000192 6040 2019 4\n", "... ... ... ...\n", "825793 4958 2399 446\n", "825438 4958 1407 447\n", - "825731 4958 2634 448\n", - "825724 4958 3264 449\n", + "825724 4958 3264 448\n", + "825731 4958 2634 449\n", "825603 4958 1924 450\n", "\n", "[1000209 rows x 3 columns]" @@ -211,7 +211,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -263,13 +263,13 @@ " 3\n", " 3\n", " 6039\n", - " 1950\n", + " 1892\n", " \n", " \n", " 4\n", " 4\n", " 6039\n", - " 1892\n", + " 1950\n", " \n", " \n", " ...\n", @@ -293,13 +293,13 @@ " 1000206\n", " 448\n", " 4957\n", - " 2565\n", + " 3195\n", " \n", " \n", " 1000207\n", " 449\n", " 4957\n", - " 3195\n", + " 2565\n", " \n", " \n", " 1000208\n", @@ -317,13 +317,13 @@ "0 0 6039 847\n", "1 1 6039 2315\n", "2 2 6039 589\n", - "3 3 6039 1950\n", - "4 4 6039 1892\n", + "3 3 6039 1892\n", + "4 4 6039 1950\n", "... ... ... ...\n", "1000204 446 4957 2330\n", "1000205 447 4957 1384\n", - "1000206 448 4957 2565\n", - "1000207 449 4957 3195\n", + "1000206 448 4957 3195\n", + "1000207 449 4957 2565\n", "1000208 450 4957 1855\n", "\n", "[1000209 rows x 3 columns]" @@ -531,25 +531,25 @@ " 0\n", " 0\n", " [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...\n", - " [3117, 1250, 1009, 1672, 2271, 1768, 3339, 118...\n", + " [3117, 1672, 1250, 1009, 2271, 1768, 3339, 118...\n", " \n", " \n", " 1\n", " 1\n", " [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...\n", - " [1180, 1192, 1199, 2648, 1273, 2874, 1207, 315...\n", + " [1180, 1199, 1192, 2648, 1273, 2874, 1207, 117...\n", " \n", " \n", " 2\n", " 2\n", " [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...\n", - " [589, 2789, 1899, 3465, 1407, 1892, 1246, 1358...\n", + " [589, 2789, 3465, 1899, 1892, 1407, 1246, 3602...\n", " \n", " \n", " 3\n", " 3\n", " [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...\n", - " [1192, 1081, 3458, 476, 3399, 257, 1180, 1178,...\n", + " [1192, 1081, 476, 3399, 3458, 1178, 257, 1180,...\n", " \n", " \n", " 4\n", @@ -567,13 +567,13 @@ " 6035\n", " 6035\n", " [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...\n", - " [1672, 1814, 3369, 2307, 2359, 2503, 2423, 278...\n", + " [1672, 1814, 3369, 2307, 2359, 2614, 2503, 263...\n", " \n", " \n", " 6036\n", " 6036\n", " [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...\n", - " [693, 1813, 3439, 1959, 1247, 558, 847, 3079, ...\n", + " [1813, 693, 1247, 1959, 3439, 3079, 558, 847, ...\n", " \n", " \n", " 6037\n", @@ -585,13 +585,13 @@ " 6038\n", " 6038\n", " [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...\n", - " [109, 279, 1998, 1211, 918, 3064, 935, 3019, 2...\n", + " [109, 279, 1998, 1211, 918, 935, 3019, 2953, 3...\n", " \n", " \n", " 6039\n", " 6039\n", " [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...\n", - " [847, 2315, 589, 1950, 1892, 3042, 211, 3436, ...\n", + " [847, 2315, 589, 1892, 1950, 1395, 211, 3042, ...\n", " \n", " \n", "\n", @@ -613,17 +613,17 @@ "6039 6039 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n", "\n", " item_id \n", - "0 [3117, 1250, 1009, 1672, 2271, 1768, 3339, 118... \n", - "1 [1180, 1192, 1199, 2648, 1273, 2874, 1207, 315... \n", - "2 [589, 2789, 1899, 3465, 1407, 1892, 1246, 1358... \n", - "3 [1192, 1081, 3458, 476, 3399, 257, 1180, 1178,... \n", + "0 [3117, 1672, 1250, 1009, 2271, 1768, 3339, 118... \n", + "1 [1180, 1199, 1192, 2648, 1273, 2874, 1207, 117... \n", + "2 [589, 2789, 3465, 1899, 1892, 1407, 1246, 3602... \n", + "3 [1192, 1081, 476, 3399, 3458, 1178, 257, 1180,... \n", "4 [2648, 907, 896, 352, 1230, 2119, 2789, 1111, ... \n", "... ... \n", - "6035 [1672, 1814, 3369, 2307, 2359, 2503, 2423, 278... \n", - "6036 [693, 1813, 3439, 1959, 1247, 558, 847, 3079, ... \n", + "6035 [1672, 1814, 3369, 2307, 2359, 2614, 2503, 263... \n", + "6036 [1813, 693, 1247, 1959, 3439, 3079, 558, 847, ... \n", "6037 [3327, 908, 1192, 2077, 1366, 352, 1063, 1132,... \n", - "6038 [109, 279, 1998, 1211, 918, 3064, 935, 3019, 2... \n", - "6039 [847, 2315, 589, 1950, 1892, 3042, 211, 3436, ... \n", + "6038 [109, 279, 1998, 1211, 918, 935, 3019, 2953, 3... \n", + "6039 [847, 2315, 589, 1892, 1950, 1395, 211, 3042, ... \n", "\n", "[6040 rows x 3 columns]" ] @@ -725,7 +725,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -781,7 +781,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -792,7 +792,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -820,7 +820,16 @@ "cell_type": "code", "execution_count": 14, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_775501/1761019238.py:5: UserWarning: The following dataset paths aren't provided: test,predict.Make sure to disable these stages in your Lightning Trainer configuration.\n", + " parquet_module = ParquetModule(\n" + ] + } + ], "source": [ "from replay.data.nn import ParquetModule\n", "\n", @@ -853,7 +862,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -888,7 +897,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ @@ -915,26 +924,28 @@ "name": "stderr", "output_type": "stream", "text": [ - "GPU available: False, used: False\n", + "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", - "HPU available: False, using: 0 HPUs\n", + "You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", "\n", - " | Name | Type | Params | Mode \n", - "-------------------------------------------\n", - "0 | model | TwoTower | 340 K | train\n", - "-------------------------------------------\n", - "340 K Trainable params\n", + " | Name | Type | Params | Mode | FLOPs\n", + "---------------------------------------------------\n", + "0 | model | TwoTower | 352 K | train | 0 \n", + "---------------------------------------------------\n", + "352 K Trainable params\n", "0 Non-trainable params\n", - "340 K Total params\n", - "1.364 Total estimated model params size (MB)\n", + "352 K Total params\n", + "1.409 Total estimated model params size (MB)\n", "52 Modules in train mode\n", - "0 Modules in eval mode\n" + "0 Modules in eval mode\n", + "0 Total Flops\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "32da5d7c8a1a4a2abd200e4c44472b61", + "model_id": "30be06dd73e4450f8fc37ee83bbd1746", "version_major": 2, "version_minor": 0 }, @@ -945,10 +956,19 @@ "metadata": {}, "output_type": "display_data" }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/nkulikov/new_venv/lib/python3.12/site-packages/lightning/pytorch/utilities/data.py:123: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.\n", + "/home/nkulikov/new_venv/lib/python3.12/site-packages/torch/nn/functional.py:6044: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.\n", + " warnings.warn(\n" + ] + }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "959fabb6e3ca4ed38097cd7128944bc8", + "model_id": "dd8c6b547ea84fd4abcbe132b7da42cb", "version_major": 2, "version_minor": 0 }, @@ -962,7 +982,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "3462811aac8b4003aaa61c4cf9fcc4a0", + "model_id": "301d175de5f24790b75dc4943e83e687", "version_major": 2, "version_minor": 0 }, @@ -977,24 +997,32 @@ "name": "stderr", "output_type": "stream", "text": [ - "Epoch 0, global step 189: 'recall@10' reached 0.03526 (best 0.03526), saving model to '/home/RePlay/examples/twotower/checkpoints/epoch=0-step=189.ckpt' as top 1\n" + "Epoch 0, global step 189: 'recall@10' reached 0.03958 (best 0.03958), saving model to '/home/nkulikov/RePlay/examples/twotower/checkpoints/epoch=0-step=189.ckpt' as top 1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "k 1 5 10 20\n", - "map 0.004801 0.009509 0.011596 0.013299\n", - "ndcg 0.004801 0.011889 0.017026 0.023337\n", - "recall 0.004801 0.019205 0.035265 0.060430\n", + "k 1 10 20 5\n", + "map 0.003643 0.011629 0.013674 0.009331\n", + "ndcg 0.003643 0.018055 0.025700 0.012422\n", + "recall 0.003643 0.039576 0.070210 0.022024\n", "\n" ] }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/nkulikov/new_venv/lib/python3.12/site-packages/torch/nn/functional.py:6044: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.\n", + " warnings.warn(\n" + ] + }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "f8dacd44d48b486c9d1217c937b2584d", + "model_id": "01468999f3674e389b8416c5e8f43a46", "version_major": 2, "version_minor": 0 }, @@ -1009,24 +1037,32 @@ "name": "stderr", "output_type": "stream", "text": [ - "Epoch 1, global step 378: 'recall@10' reached 0.09603 (best 0.09603), saving model to '/home/RePlay/examples/twotower/checkpoints/epoch=1-step=378.ckpt' as top 1\n" + "Epoch 1, global step 378: 'recall@10' reached 0.10184 (best 0.10184), saving model to '/home/nkulikov/RePlay/examples/twotower/checkpoints/epoch=1-step=378.ckpt' as top 1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "k 1 5 10 20\n", - "map 0.012086 0.026440 0.031724 0.036210\n", - "ndcg 0.012086 0.033518 0.046537 0.063113\n", - "recall 0.012086 0.055298 0.096026 0.162086\n", + "k 1 10 20 5\n", + "map 0.010763 0.033238 0.037375 0.028159\n", + "ndcg 0.010763 0.049178 0.064472 0.036708\n", + "recall 0.010763 0.101838 0.162775 0.062924\n", "\n" ] }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/nkulikov/new_venv/lib/python3.12/site-packages/torch/nn/functional.py:6044: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.\n", + " warnings.warn(\n" + ] + }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "421acec60c4347ae98e0df5157d69ead", + "model_id": "148c7a181eda40d98249e05b50e4bfa1", "version_major": 2, "version_minor": 0 }, @@ -1041,24 +1077,32 @@ "name": "stderr", "output_type": "stream", "text": [ - "Epoch 2, global step 567: 'recall@10' reached 0.13576 (best 0.13576), saving model to '/home/RePlay/examples/twotower/checkpoints/epoch=2-step=567.ckpt' as top 1\n" + "Epoch 2, global step 567: 'recall@10' reached 0.13280 (best 0.13280), saving model to '/home/nkulikov/RePlay/examples/twotower/checkpoints/epoch=2-step=567.ckpt' as top 1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "k 1 5 10 20\n", - "map 0.012252 0.034048 0.041749 0.047579\n", - "ndcg 0.012252 0.044626 0.063492 0.084913\n", - "recall 0.012252 0.076987 0.135762 0.220861\n", + "k 1 10 20 5\n", + "map 0.012916 0.041789 0.047532 0.034280\n", + "ndcg 0.012916 0.062866 0.084155 0.044635\n", + "recall 0.012916 0.132803 0.217751 0.076337\n", "\n" ] }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/nkulikov/new_venv/lib/python3.12/site-packages/torch/nn/functional.py:6044: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.\n", + " warnings.warn(\n" + ] + }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "55449d2ab210406b89fea8f15deb28f9", + "model_id": "fced5bd80eb342909431d34449b9ca88", "version_major": 2, "version_minor": 0 }, @@ -1073,24 +1117,32 @@ "name": "stderr", "output_type": "stream", "text": [ - "Epoch 3, global step 756: 'recall@10' reached 0.15795 (best 0.15795), saving model to '/home/RePlay/examples/twotower/checkpoints/epoch=3-step=756.ckpt' as top 1\n" + "Epoch 3, global step 756: 'recall@10' reached 0.14920 (best 0.14920), saving model to '/home/nkulikov/RePlay/examples/twotower/checkpoints/epoch=3-step=756.ckpt' as top 1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "k 1 5 10 20\n", - "map 0.014735 0.040450 0.049221 0.055306\n", - "ndcg 0.014735 0.052765 0.074364 0.096740\n", - "recall 0.014735 0.090397 0.157947 0.246854\n", + "k 1 10 20 5\n", + "map 0.01391 0.046216 0.052732 0.037647\n", + "ndcg 0.01391 0.070051 0.093930 0.049280\n", + "recall 0.01391 0.149197 0.243915 0.084948\n", "\n" ] }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/nkulikov/new_venv/lib/python3.12/site-packages/torch/nn/functional.py:6044: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.\n", + " warnings.warn(\n" + ] + }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "5bbb1d30778b42a6ace5f22c50fdec88", + "model_id": "3a21e049ada84d73b3aef467158ee42d", "version_major": 2, "version_minor": 0 }, @@ -1105,7 +1157,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Epoch 4, global step 945: 'recall@10' reached 0.17053 (best 0.17053), saving model to '/home/RePlay/examples/twotower/checkpoints/epoch=4-step=945.ckpt' as top 1\n", + "Epoch 4, global step 945: 'recall@10' reached 0.16079 (best 0.16079), saving model to '/home/nkulikov/RePlay/examples/twotower/checkpoints/epoch=4-step=945.ckpt' as top 1\n", "`Trainer.fit` stopped: `max_epochs=5` reached.\n" ] }, @@ -1113,10 +1165,10 @@ "name": "stdout", "output_type": "stream", "text": [ - "k 1 5 10 20\n", - "map 0.019205 0.045563 0.054963 0.061259\n", - "ndcg 0.019205 0.058695 0.081677 0.104973\n", - "recall 0.019205 0.099007 0.170530 0.263411\n", + "k 1 10 20 5\n", + "map 0.013247 0.047682 0.054643 0.038307\n", + "ndcg 0.013247 0.073782 0.099453 0.050853\n", + "recall 0.013247 0.160788 0.262957 0.089419\n", "\n" ] } @@ -1169,7 +1221,7 @@ { "data": { "text/plain": [ - "'/home/RePlay/examples/twotower/checkpoints/epoch=4-step=945.ckpt'" + "'/home/nkulikov/RePlay/examples/twotower/checkpoints/epoch=4-step=945.ckpt'" ] }, "execution_count": 18, @@ -1194,10 +1246,13 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ + "import torch\n", + "import replay\n", + "\n", "twotower = TwoTower.from_params(\n", " schema=tensor_schema,\n", " embedding_dim=EMBEDDING_DIM,\n", @@ -1212,6 +1267,11 @@ " )\n", ")\n", "\n", + "torch.serialization.add_safe_globals([\n", + " replay.nn.lightning.optimizer.OptimizerFactory,\n", + " replay.nn.lightning.scheduler.LRSchedulerFactory,\n", + "])\n", + "\n", "best_model = LightningModule.load_from_checkpoint(best_model_path, model=twotower)\n", "best_model.eval();" ] @@ -1227,7 +1287,16 @@ "cell_type": "code", "execution_count": 20, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_775501/2227877614.py:3: UserWarning: The following dataset paths aren't provided: train,validate,test.Make sure to disable these stages in your Lightning Trainer configuration.\n", + " parquet_module = ParquetModule(\n" + ] + } + ], "source": [ "inference_metadata = {\"predict\": create_meta(shape=MAX_SEQ_LEN)}\n", "\n", @@ -1258,15 +1327,17 @@ "name": "stderr", "output_type": "stream", "text": [ - "GPU available: False, used: False\n", + "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n", + "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", - "HPU available: False, using: 0 HPUs\n" + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "/home/nkulikov/new_venv/lib/python3.12/site-packages/lightning/pytorch/utilities/data.py:123: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "61636a98f63042a4bfce791a7dd2d01b", + "model_id": "527412a8a14545399dfad9ab127e3db3", "version_major": 2, "version_minor": 0 }, @@ -1276,6 +1347,14 @@ }, "metadata": {}, "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/nkulikov/new_venv/lib/python3.12/site-packages/torch/nn/functional.py:6044: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.\n", + " warnings.warn(\n" + ] } ], "source": [ @@ -1333,32 +1412,32 @@ " \n", " 0\n", " 0\n", - " 3383\n", - " 41.500851\n", + " 773\n", + " 26.849968\n", " \n", " \n", " 0\n", " 0\n", - " 3509\n", - " 40.856842\n", + " 360\n", + " 26.444904\n", " \n", " \n", " 0\n", " 0\n", - " 3475\n", - " 40.794147\n", + " 1526\n", + " 26.433756\n", " \n", " \n", " 0\n", " 0\n", - " 3512\n", - " 40.674751\n", + " 2618\n", + " 26.282482\n", " \n", " \n", " 0\n", " 0\n", - " 3341\n", - " 40.665249\n", + " 1838\n", + " 26.202333\n", " \n", " \n", " ...\n", @@ -1369,32 +1448,32 @@ " \n", " 6037\n", " 6039\n", - " 3550\n", - " 26.971941\n", + " 1680\n", + " 26.977921\n", " \n", " \n", " 6037\n", " 6039\n", - " 1490\n", - " 26.874283\n", + " 1375\n", + " 26.976624\n", " \n", " \n", " 6037\n", " 6039\n", - " 2489\n", - " 26.785303\n", + " 2125\n", + " 26.967751\n", " \n", " \n", " 6037\n", " 6039\n", - " 2634\n", - " 26.725142\n", + " 1439\n", + " 26.963537\n", " \n", " \n", " 6037\n", " 6039\n", - " 2203\n", - " 26.696005\n", + " 623\n", + " 26.949759\n", " \n", " \n", "\n", @@ -1403,17 +1482,17 @@ ], "text/plain": [ " user_id item_id score\n", - "0 0 3383 41.500851\n", - "0 0 3509 40.856842\n", - "0 0 3475 40.794147\n", - "0 0 3512 40.674751\n", - "0 0 3341 40.665249\n", + "0 0 773 26.849968\n", + "0 0 360 26.444904\n", + "0 0 1526 26.433756\n", + "0 0 2618 26.282482\n", + "0 0 1838 26.202333\n", "... ... ... ...\n", - "6037 6039 3550 26.971941\n", - "6037 6039 1490 26.874283\n", - "6037 6039 2489 26.785303\n", - "6037 6039 2634 26.725142\n", - "6037 6039 2203 26.696005\n", + "6037 6039 1680 26.977921\n", + "6037 6039 1375 26.976624\n", + "6037 6039 2125 26.967751\n", + "6037 6039 1439 26.963537\n", + "6037 6039 623 26.949759\n", "\n", "[120760 rows x 3 columns]" ] @@ -1487,42 +1566,42 @@ " \n", " k\n", " 1\n", - " 5\n", " 10\n", " 20\n", + " 5\n", " \n", " \n", " \n", " \n", " MAP\n", - " 0.016727\n", - " 0.041714\n", - " 0.050679\n", - " 0.056809\n", + " 0.016893\n", + " 0.050191\n", + " 0.056355\n", + " 0.041200\n", " \n", " \n", " Precision\n", - " 0.016727\n", - " 0.018152\n", - " 0.015982\n", - " 0.012421\n", + " 0.016893\n", + " 0.015916\n", + " 0.012479\n", + " 0.018085\n", " \n", " \n", " Recall\n", - " 0.016727\n", - " 0.090759\n", - " 0.159821\n", - " 0.248427\n", + " 0.016893\n", + " 0.159159\n", + " 0.249586\n", + " 0.090427\n", " \n", " \n", "\n", "" ], "text/plain": [ - "k 1 5 10 20\n", - "MAP 0.016727 0.041714 0.050679 0.056809\n", - "Precision 0.016727 0.018152 0.015982 0.012421\n", - "Recall 0.016727 0.090759 0.159821 0.248427" + "k 1 10 20 5\n", + "MAP 0.016893 0.050191 0.056355 0.041200\n", + "Precision 0.016893 0.015916 0.012479 0.018085\n", + "Recall 0.016893 0.159159 0.249586 0.090427" ] }, "execution_count": 25, @@ -1575,33 +1654,33 @@ " \n", " \n", " 0\n", - " 2011\n", - " 3623\n", - " 41.500851\n", + " 1\n", + " 783\n", + " 26.849968\n", " \n", " \n", " 0\n", - " 2011\n", - " 3752\n", - " 40.856842\n", + " 1\n", + " 364\n", + " 26.444904\n", " \n", " \n", " 0\n", - " 2011\n", - " 3717\n", - " 40.794147\n", + " 1\n", + " 1566\n", + " 26.433756\n", " \n", " \n", " 0\n", - " 2011\n", - " 3755\n", - " 40.674751\n", + " 1\n", + " 2687\n", + " 26.282482\n", " \n", " \n", " 0\n", - " 2011\n", - " 3578\n", - " 40.665249\n", + " 1\n", + " 1907\n", + " 26.202333\n", " \n", " \n", " ...\n", @@ -1611,33 +1690,33 @@ " \n", " \n", " 6037\n", - " 5727\n", - " 3793\n", - " 26.971941\n", + " 6040\n", + " 1729\n", + " 26.977921\n", " \n", " \n", " 6037\n", - " 5727\n", - " 1623\n", - " 26.874283\n", + " 6040\n", + " 1396\n", + " 26.976624\n", " \n", " \n", " 6037\n", - " 5727\n", - " 2693\n", - " 26.785303\n", + " 6040\n", + " 2194\n", + " 26.967751\n", " \n", " \n", " 6037\n", - " 5727\n", - " 2841\n", - " 26.725142\n", + " 6040\n", + " 1466\n", + " 26.963537\n", " \n", " \n", " 6037\n", - " 5727\n", - " 2396\n", - " 26.696005\n", + " 6040\n", + " 628\n", + " 26.949759\n", " \n", " \n", "\n", @@ -1646,17 +1725,17 @@ ], "text/plain": [ " user_id item_id score\n", - "0 2011 3623 41.500851\n", - "0 2011 3752 40.856842\n", - "0 2011 3717 40.794147\n", - "0 2011 3755 40.674751\n", - "0 2011 3578 40.665249\n", + "0 1 783 26.849968\n", + "0 1 364 26.444904\n", + "0 1 1566 26.433756\n", + "0 1 2687 26.282482\n", + "0 1 1907 26.202333\n", "... ... ... ...\n", - "6037 5727 3793 26.971941\n", - "6037 5727 1623 26.874283\n", - "6037 5727 2693 26.785303\n", - "6037 5727 2841 26.725142\n", - "6037 5727 2396 26.696005\n", + "6037 6040 1729 26.977921\n", + "6037 6040 1396 26.976624\n", + "6037 6040 2194 26.967751\n", + "6037 6040 1466 26.963537\n", + "6037 6040 628 26.949759\n", "\n", "[120760 rows x 3 columns]" ] @@ -1669,11 +1748,18 @@ "source": [ "encoder.inverse_transform(pandas_res)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { "kernelspec": { - "display_name": "rep", + "display_name": "new_venv", "language": "python", "name": "python3" }, @@ -1687,7 +1773,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.14" + "version": "3.12.3" }, "orig_nbformat": 4 }, diff --git a/replay/nn/lightning/module.py b/replay/nn/lightning/module.py index ba6d237b6..ad5f837cc 100644 --- a/replay/nn/lightning/module.py +++ b/replay/nn/lightning/module.py @@ -40,7 +40,7 @@ def __init__( self._lr_scheduler_factory = lr_scheduler_factory self.candidates_to_score = None - def forward(self, batch: dict) -> Union[TrainOutput, InferenceOutput]: + def forward(self, batch: dict, return_info: bool = False) -> Union[TrainOutput, InferenceOutput]: """ Implementation of the forward function. @@ -57,12 +57,21 @@ def forward(self, batch: dict) -> Union[TrainOutput, InferenceOutput]: batch["candidates_to_score"] = self.candidates_to_score # select only args for model.forward modified_batch = {k: v for k, v in batch.items() if k in inspect.signature(self.model.forward).parameters} - return self.model(**modified_batch) + return self.model(**modified_batch, return_info=return_info) def training_step(self, batch: dict) -> torch.Tensor: - model_output: TrainOutput = self(batch) - loss = model_output["loss"] + model_output: TrainOutput = self(batch, return_info=True) + loss, info = model_output["loss"], model_output.get("info", None) lr = self.optimizers().param_groups[0]["lr"] # Get current learning rate + if info is not None: + assert isinstance(info, dict) + self.log_dict( + dictionary=info, + on_step=True, + on_epoch=True, + prog_bar=True, + sync_dist=True, + ) self.log("learning_rate", lr, on_step=True, on_epoch=True, prog_bar=True) self.log( "train_loss", diff --git a/replay/nn/loss/__init__.py b/replay/nn/loss/__init__.py index 2116cac83..eed615375 100644 --- a/replay/nn/loss/__init__.py +++ b/replay/nn/loss/__init__.py @@ -1,6 +1,7 @@ -from .base import LossProto +from .base import LossInfo, LossOutput, LossProto from .bce import BCE, BCESampled from .ce import CE, CESampled, CESampledWeighted, CEWeighted +from .composed import ComposedLoss from .login_ce import LogInCE, LogInCESampled from .logout_ce import LogOutCE, LogOutCEWeighted @@ -13,10 +14,13 @@ "CESampled", "CESampledWeighted", "CEWeighted", + "ComposedLoss", "LogInCE", "LogInCESampled", "LogOutCE", "LogOutCESampled", "LogOutCEWeighted", + "LossInfo", + "LossOutput", "LossProto", ] diff --git a/replay/nn/loss/base.py b/replay/nn/loss/base.py index ef9bd88bb..996fcde9c 100644 --- a/replay/nn/loss/base.py +++ b/replay/nn/loss/base.py @@ -4,17 +4,23 @@ from replay.data.nn import TensorMap +LossInfo = dict[str, torch.Tensor | float] +LossOutput = tuple[torch.Tensor, None] | tuple[torch.Tensor, LossInfo] +LogitsCallback = Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor] + class LossProto(Protocol): """Class-protocol for working with losses inside models""" + loss_name: str + @property def logits_callback( self, - ) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]: ... + ) -> LogitsCallback: ... @logits_callback.setter - def logits_callback(self, func: Optional[Callable]) -> None: ... + def logits_callback(self, func: LogitsCallback) -> None: ... def forward( self, @@ -24,7 +30,19 @@ def forward( negative_labels: torch.LongTensor, padding_mask: torch.BoolTensor, target_padding_mask: torch.BoolTensor, - ) -> torch.Tensor: ... + return_info: bool = False, + ) -> LossOutput: ... + + def __call__( + self, + model_embeddings: torch.Tensor, + feature_tensors: TensorMap, + positive_labels: torch.LongTensor, + negative_labels: torch.LongTensor, + padding_mask: torch.BoolTensor, + target_padding_mask: torch.BoolTensor, + return_info: bool = False, + ) -> LossOutput: ... class SampledLossOutput(TypedDict): @@ -39,6 +57,8 @@ class SampledLossOutput(TypedDict): class SampledLossBase(torch.nn.Module): """The base class for calculating sampled losses""" + _logits_callback: LogitsCallback | None + @property def logits_callback( self, diff --git a/replay/nn/loss/bce.py b/replay/nn/loss/bce.py index 2f5901d76..62e38edc3 100644 --- a/replay/nn/loss/bce.py +++ b/replay/nn/loss/bce.py @@ -1,10 +1,8 @@ -from typing import Callable, Optional - import torch from replay.data.nn import TensorMap -from .base import SampledLossBase, mask_negative_logits +from .base import LogitsCallback, LossOutput, SampledLossBase, mask_negative_logits class BCE(torch.nn.Module): @@ -16,19 +14,20 @@ class BCE(torch.nn.Module): (there are several labels for each position in the sequence). """ - def __init__(self, **kwargs): + def __init__(self, loss_name: str = "BCELoss", **kwargs): """ To calculate the loss, ``torch.nn.BCEWithLogitsLoss`` is used with the parameter ``reduction="sum"``. You can pass all other parameters for initializing the object via kwargs. """ super().__init__() self._loss = torch.nn.BCEWithLogitsLoss(reduction="sum", **kwargs) - self._logits_callback = None + self._logits_callback: LogitsCallback | None = None + self.loss_name: str = loss_name @property def logits_callback( self, - ) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]: + ) -> LogitsCallback: """ Property for calling a function for the logits computation.\n @@ -46,7 +45,7 @@ def logits_callback( return self._logits_callback @logits_callback.setter - def logits_callback(self, func: Optional[Callable]) -> None: + def logits_callback(self, func: LogitsCallback) -> None: self._logits_callback = func def forward( @@ -57,7 +56,8 @@ def forward( negative_labels: torch.LongTensor, # noqa: ARG002 padding_mask: torch.BoolTensor, # noqa: ARG002 target_padding_mask: torch.BoolTensor, - ) -> torch.Tensor: + return_info: bool = False, + ) -> LossOutput: """ forward(model_embeddings, positive_labels, target_padding_mask) :param model_embeddings: model output of shape ``(batch_size, sequence_length, embedding_dim)``. @@ -92,7 +92,11 @@ def forward( ) loss = self._loss(logits, bce_labels) / logits.size(0) - return loss + + if return_info: + return (loss, {self.loss_name: loss.detach()}) + else: + return (loss, None) class BCESampled(SampledLossBase): @@ -109,7 +113,8 @@ def __init__( log_epsilon: float = 1e-6, clamp_border: float = 100.0, negative_labels_ignore_index: int = -100, - ): + loss_name: str = "BCESampledLoss", + ) -> None: """ :param log_epsilon: correction to avoid zero in the logarithm during loss calculating. Default: ``1e-6``. @@ -125,12 +130,13 @@ def __init__( self.log_epsilon = log_epsilon self.clamp_border = clamp_border self.negative_labels_ignore_index = negative_labels_ignore_index - self._logits_callback = None + self._logits_callback: LogitsCallback | None = None + self.loss_name: str = loss_name @property def logits_callback( self, - ) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]: + ) -> LogitsCallback: """ Property for calling a function for the logits computation.\n @@ -148,7 +154,7 @@ def logits_callback( return self._logits_callback @logits_callback.setter - def logits_callback(self, func: Optional[Callable]) -> None: + def logits_callback(self, func: LogitsCallback) -> None: self._logits_callback = func def forward( @@ -159,7 +165,8 @@ def forward( negative_labels: torch.LongTensor, padding_mask: torch.BoolTensor, # noqa: ARG002 target_padding_mask: torch.BoolTensor, - ) -> torch.Tensor: + return_info: bool = False, + ) -> LossOutput: """ forward(model_embeddings, positive_labels, negative_labels, target_padding_mask) @@ -213,4 +220,7 @@ def forward( loss = -(positive_loss + negative_loss) loss /= positive_logits.size(0) - return loss + if return_info: + return (loss, {self.loss_name: loss.detach()}) + else: + return (loss, None) diff --git a/replay/nn/loss/ce.py b/replay/nn/loss/ce.py index 07f6e7b17..152afc563 100644 --- a/replay/nn/loss/ce.py +++ b/replay/nn/loss/ce.py @@ -1,10 +1,8 @@ -from typing import Callable, Optional - import torch from replay.data.nn import TensorMap -from .base import SampledLossBase, mask_negative_logits +from .base import LogitsCallback, LossOutput, SampledLossBase, mask_negative_logits class CE(torch.nn.Module): @@ -13,19 +11,20 @@ class CE(torch.nn.Module): Calculates loss over all items catalog. """ - def __init__(self, **kwargs): + def __init__(self, loss_name: str = "CELoss", **kwargs): """ To calculate the loss, ``torch.nn.CrossEntropyLoss`` is used. You can pass all parameters for initializing the object via kwargs. """ super().__init__() + self.loss_name: str = loss_name self._loss = torch.nn.CrossEntropyLoss(**kwargs) - self._logits_callback = None + self._logits_callback: LogitsCallback | None = None @property def logits_callback( self, - ) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]: + ) -> LogitsCallback: """ Property for calling a function for the logits computation.\n @@ -43,7 +42,7 @@ def logits_callback( return self._logits_callback @logits_callback.setter - def logits_callback(self, func: Optional[Callable]) -> None: + def logits_callback(self, func: LogitsCallback) -> None: self._logits_callback = func def forward( @@ -54,7 +53,8 @@ def forward( negative_labels: torch.LongTensor, # noqa: ARG002 padding_mask: torch.BoolTensor, # noqa: ARG002 target_padding_mask: torch.BoolTensor, - ) -> torch.Tensor: + return_info: bool = False, + ) -> LossOutput: """ forward(model_embeddings, positive_labels, target_padding_mask) :param model_embeddings: model output of shape ``(batch_size, sequence_length, embedding_dim)``. @@ -78,7 +78,11 @@ def forward( # [batch_size, seq_len, 1] -> [batch_size * seq_len] labels_flat: torch.LongTensor = labels.view(-1) loss = self._loss(logits_flat, labels_flat) - return loss + + if return_info: + return (loss, {"CE": loss.detach()}) + else: + return (loss, None) class CEWeighted(CE): @@ -92,11 +96,14 @@ class CEWeighted(CE): which is fed into the model. """ + loss_name: str = "CEWeightedLoss" + def __init__( self, feature_name: str, + loss_name: str = "CEWEightedLoss", **kwargs, - ): + ) -> None: """ To calculate the loss, ``torch.nn.CrossEntropyLoss`` is used with the parameter ``reduction="none"``. You can pass all other parameters for initializing the object via kwargs. @@ -105,7 +112,8 @@ def __init__( The tensor is expected to contain sample weights. """ super().__init__() - self.feature_name = feature_name + self.loss_name: str = loss_name + self.feature_name: str = feature_name self._loss = torch.nn.CrossEntropyLoss(reduction="none", **kwargs) def forward( @@ -116,7 +124,8 @@ def forward( negative_labels: torch.LongTensor, # noqa: ARG002 padding_mask: torch.BoolTensor, # noqa: ARG002 target_padding_mask: torch.BoolTensor, - ) -> torch.Tensor: + return_info: bool = False, + ) -> LossOutput: """ forward(model_embeddings, feature_tensors, positive_labels, target_padding_mask) :param feature_tensors: a dictionary of tensors from dataloader. @@ -140,7 +149,11 @@ def forward( ) sample_weight = feature_tensors[self.feature_name] loss = (loss * sample_weight).mean() - return loss + + if return_info: + return (loss, {self.loss_name: loss.detach()}) + else: + return (loss, None) class CESampled(SampledLossBase): @@ -155,8 +168,9 @@ class CESampled(SampledLossBase): def __init__( self, negative_labels_ignore_index: int = -100, + loss_name: str = "CESampledLoss", **kwargs, - ): + ) -> None: """ To calculate the loss, ``torch.nn.CrossEntropyLoss`` is used. You can pass all parameters for initializing the object via kwargs. @@ -168,14 +182,15 @@ def __init__( Default: ``-100``. """ super().__init__() + self.loss_name: str = loss_name self.negative_labels_ignore_index = negative_labels_ignore_index self._loss = torch.nn.CrossEntropyLoss(**kwargs) - self._logits_callback = None + self._logits_callback: LogitsCallback | None = None @property def logits_callback( self, - ) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]: + ) -> LogitsCallback: """ Property for calling a function for the logits computation.\n @@ -193,7 +208,7 @@ def logits_callback( return self._logits_callback @logits_callback.setter - def logits_callback(self, func: Optional[Callable]) -> None: + def logits_callback(self, func: LogitsCallback) -> None: self._logits_callback = func def forward( @@ -204,7 +219,8 @@ def forward( negative_labels: torch.LongTensor, padding_mask: torch.BoolTensor, # noqa: ARG002 target_padding_mask: torch.BoolTensor, - ) -> torch.Tensor: + return_info: bool = False, + ) -> LossOutput: """ forward(model_embeddings, positive_labels, negative_labels, target_padding_mask) @@ -246,7 +262,11 @@ def forward( target = torch.zeros(positive_logits.size(0), dtype=torch.long, device=logits.device) # [masked_batch_size] - loss for all recommendation points loss = self._loss(logits, target) - return loss + + if return_info: + return (loss, {self.loss_name: loss.detach()}) + else: + return (loss, None) class CESampledWeighted(CESampled): @@ -267,8 +287,9 @@ def __init__( self, feature_name: str, negative_labels_ignore_index: int = -100, + loss_name: str = "CESampledWeighedLoss", **kwargs, - ): + ) -> None: """ To calculate the loss, ``torch.nn.CrossEntropyLoss`` is used with the parameter ``reduction="none"``. You can pass all other parameters for initializing the object via kwargs. @@ -282,7 +303,9 @@ def __init__( Default: ``-100``. """ super().__init__(negative_labels_ignore_index=negative_labels_ignore_index) - self.feature_name = feature_name + + self.loss_name: str = loss_name + self.feature_name: str = feature_name self._loss = torch.nn.CrossEntropyLoss(reduction="none", **kwargs) def forward( @@ -293,7 +316,8 @@ def forward( negative_labels: torch.LongTensor, padding_mask: torch.BoolTensor, # noqa: ARG002 target_padding_mask: torch.BoolTensor, - ) -> torch.Tensor: + return_info: bool = False, + ) -> LossOutput: """ forward(model_embeddings, feature_tensors, positive_labels, negative_labels, target_padding_mask) :param model_embeddings: model output of shape ``(batch_size, sequence_length, embedding_dim)``. @@ -314,4 +338,8 @@ def forward( sample_weight = feature_tensors[self.feature_name] sample_weight = sample_weight[target_padding_mask] loss = (loss * sample_weight).mean() - return loss + + if return_info: + return (loss, {self.loss_name: loss.detach()}) + else: + return (loss, None) diff --git a/replay/nn/loss/composed.py b/replay/nn/loss/composed.py new file mode 100644 index 000000000..c52ebffdd --- /dev/null +++ b/replay/nn/loss/composed.py @@ -0,0 +1,159 @@ +import warnings +from typing import Iterable, Self, cast + +import torch + +from replay.data.nn import TensorMap + +from .base import LogitsCallback, LossInfo, LossOutput, LossProto + +Weights = dict[str, torch.Tensor | float] +Losses = Iterable[torch.nn.Module] | dict[str, torch.nn.Module] | torch.nn.ModuleDict + + +class ComposedLoss(torch.nn.Module): + def __init__( + self: Self, + losses: Losses, + weights: Weights | None = None, + loss_name: str = "ComposedLoss", + ) -> None: + super().__init__() + + self.losses: torch.nn.ModuleDict = self._handle_losses(losses) + self.weights: Weights = self._handle_weights(weights) + self._logits_callback: LogitsCallback | None = None + self.loss_name: str = loss_name + + def _handle_losses(self: Self, losses: Losses) -> torch.nn.ModuleDict: + if not isinstance(losses, torch.nn.ModuleDict): + if isinstance(losses, dict): + for loss in cast(dict, losses.values()): + if not isinstance(loss, torch.nn.Module): + msg: str = f"Unsupported type of loss. Must be `Module`. Got: {type(loss)=}." + raise TypeError(msg) + losses_dict: dict[str, torch.nn.Module] = cast(dict[str, torch.nn.Module], losses) + else: + losses_dict: dict[str, torch.nn.Module] = {} + for loss in iter(losses): + casted: LossProto = cast(LossProto, loss) + name: str = casted.loss_name + if name in losses_dict: + msg: str = f"Loss names must be unique. Got {name} twice." + raise KeyError(name) + losses_dict[name] = loss + losses = torch.nn.ModuleDict(losses_dict) + + if not isinstance(losses, torch.nn.ModuleDict): + msg: str = f"Unsupported type of `losses`. Must be `dict` or `ModuleDict`. Got {type(losses)=}." + raise TypeError(msg) + + if len(losses) < 1: + msg: str = "Empty losses are not supported." + raise ValueError(msg) + + return cast(torch.nn.ModuleDict, losses) + + def _handle_weights(self: Self, weights: Weights | None) -> Weights: + if weights is None: + weights = {} + elif not isinstance(weights, dict): + msg: str = f"Unsupported type of `weights`. Must be `dict`. Got: {type(weights)=}." + raise TypeError(msg) + + for name, weight in cast(dict, weights): + if name not in self.losses: + msg: str = f"Unknown name of weight: {name}." + warnings.warn(msg, stacklevel=2) + continue + if isinstance(weight, float): + continue + elif isinstance(weight, torch.Tensor): + assert torch.is_tensor(weight) + if torch.numel(weight) > 1: + msg: str = f"Too many values in weight: {torch.numel(weight)=}." + raise ValueError(msg) + continue + else: + msg: str = f"Unsupported type of weight value. Must be `float` or `Tensor`. Got: {type(weight)=}." + raise TypeError(msg) + + return cast(Weights, weights) + + @property + def logits_callback(self: Self) -> LogitsCallback: + if self._logits_callback is None: + msg: str = "No `logits_callback` provided" + raise NotImplementedError(msg) + return self._logits_callback + + @logits_callback.setter + def logits_callback(self: Self, func: LogitsCallback) -> None: + self._logits_callback = func + + for loss in self.losses.values(): + casted = cast(LossProto, loss) + casted.logits_callback = func + + def _compute_raw_losses( + self: Self, + model_embeddings: torch.Tensor, + feature_tensors: TensorMap, + positive_labels: torch.LongTensor, + negative_labels: torch.LongTensor, + padding_mask: torch.BoolTensor, + target_padding_mask: torch.BoolTensor, + ) -> dict[str, torch.Tensor]: + raw_losses: dict[str, torch.Tensor] = {} + for name, loss in self.losses.items(): + value, _ = loss( + model_embeddings, + feature_tensors, + positive_labels, + negative_labels, + padding_mask, + target_padding_mask, + ) + raw_losses[name] = value + return raw_losses + + def _apply_weights(self: Self, raw_losses: dict[str, torch.Tensor]) -> torch.Tensor: + losses_list: list[torch.Tensor] = [] + + for name, value in raw_losses.items(): + weight = self.weights.get(name, 1.0) + losses_list.append(weight * value[None]) + + losses: torch.Tensor = torch.cat(losses_list) + return torch.sum(losses) + + def _detach_dict(self: Self, raw_losses: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + return {name: value.detach() for name, value in raw_losses.items()} + + def forward( + self: Self, + model_embeddings: torch.Tensor, + feature_tensors: TensorMap, + positive_labels: torch.LongTensor, + negative_labels: torch.LongTensor, + padding_mask: torch.BoolTensor, + target_padding_mask: torch.BoolTensor, + return_info: bool = False, + ) -> LossOutput: + raw_losses: dict[str, torch.Tensor] = self._compute_raw_losses( + model_embeddings=model_embeddings, + feature_tensors=feature_tensors, + positive_labels=positive_labels, + negative_labels=negative_labels, + padding_mask=padding_mask, + target_padding_mask=target_padding_mask, + ) + + loss: torch.Tensor = self._apply_weights(raw_losses) + + if return_info: + base_info: dict[str, torch.Tensor] = self._detach_dict(raw_losses) + info: LossInfo = {self.loss_name: loss.detach(), **base_info} + return (loss, info) + else: + return (loss, None) diff --git a/replay/nn/loss/login_ce.py b/replay/nn/loss/login_ce.py index c6702c229..cf78312eb 100644 --- a/replay/nn/loss/login_ce.py +++ b/replay/nn/loss/login_ce.py @@ -1,10 +1,10 @@ -from typing import Callable, Optional, TypedDict +from typing import TypedDict import torch from replay.data.nn import TensorMap -from .base import SampledLossBase, mask_negative_logits +from .base import LogitsCallback, LossOutput, SampledLossBase, mask_negative_logits class LogInCESampledOutput(TypedDict): @@ -121,7 +121,8 @@ def __init__( log_epsilon: float = 1e-6, clamp_border: float = 100.0, negative_labels_ignore_index: int = -100, - ): + loss_name: str = "LogInCELoss", + ) -> None: """ :param cardinality: number of unique items in vocabulary (catalog). The specified cardinality value must not take into account the padding value. @@ -140,12 +141,13 @@ def __init__( self.log_epsilon = log_epsilon self.clamp_border = clamp_border self.negative_labels_ignore_index = negative_labels_ignore_index - self._logits_callback = None + self._logits_callback: LogitsCallback | None = None + self.loss_name: str = loss_name @property def logits_callback( self, - ) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]: + ) -> LogitsCallback: """ Property for calling a function for the logits computation.\n @@ -163,7 +165,7 @@ def logits_callback( return self._logits_callback @logits_callback.setter - def logits_callback(self, func: Optional[Callable]) -> None: + def logits_callback(self, func: LogitsCallback) -> None: self._logits_callback = func def forward( @@ -174,7 +176,8 @@ def forward( negative_labels: torch.LongTensor, # noqa: ARG002 padding_mask: torch.BoolTensor, # noqa: ARG002 target_padding_mask: torch.BoolTensor, - ) -> torch.Tensor: + return_info: bool = False, + ) -> LossOutput: """ forward(model_embeddings, positive_labels, target_padding_mask) **Note**: At forward pass, the whole catalog of items is used as negatives. @@ -235,7 +238,12 @@ def forward( -self.clamp_border, self.clamp_border, ) - return loss.mean() + loss = loss.mean() + + if return_info: + return (loss, {self.loss_name: loss.detach()}) + else: + return (loss, None) class LogInCESampled(LogInCEBase): @@ -260,6 +268,7 @@ def __init__( log_epsilon: float = 1e-6, clamp_border: float = 100.0, negative_labels_ignore_index: int = -100, + loss_name: str = "LogInCESampledLoss", ): """ :param log_epsilon: correction to avoid zero in the logarithm during loss calculating. @@ -276,12 +285,13 @@ def __init__( self.log_epsilon = log_epsilon self.clamp_border = clamp_border self.negative_labels_ignore_index = negative_labels_ignore_index - self._logits_callback = None + self._logits_callback: LogitsCallback | None = None + self.loss_name: str = loss_name @property def logits_callback( self, - ) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]: + ) -> LogitsCallback: """ Property for calling a function for the logits computation.\n @@ -299,7 +309,7 @@ def logits_callback( return self._logits_callback @logits_callback.setter - def logits_callback(self, func: Optional[Callable]) -> None: + def logits_callback(self, func: LogitsCallback) -> None: self._logits_callback = func def forward( @@ -310,7 +320,8 @@ def forward( negative_labels: torch.LongTensor, padding_mask: torch.BoolTensor, # noqa: ARG002 target_padding_mask: torch.BoolTensor, - ) -> torch.Tensor: + return_info: bool = False, + ) -> LossOutput: """ forward(model_embeddings, positive_labels, negative_labels, target_padding_mask) @@ -370,4 +381,9 @@ def forward( -self.clamp_border, self.clamp_border, ) - return loss.mean() + loss = loss.mean() + + if return_info: + return (loss, {self.loss_name: loss.detach()}) + else: + return (loss, None) diff --git a/replay/nn/loss/logout_ce.py b/replay/nn/loss/logout_ce.py index d0d5bba32..9299af070 100644 --- a/replay/nn/loss/logout_ce.py +++ b/replay/nn/loss/logout_ce.py @@ -1,10 +1,8 @@ -from typing import Callable, Optional - import torch from replay.data.nn import TensorMap -from .base import mask_negative_logits +from .base import LogitsCallback, LossOutput, mask_negative_logits class LogOutCE(torch.nn.Module): @@ -28,6 +26,7 @@ def __init__( self, cardinality: int, negative_labels_ignore_index: int = -100, + loss_name: str = "LogOutCELoss", **kwargs, ): """ @@ -46,12 +45,13 @@ def __init__( self.cardinality = cardinality self.negative_labels_ignore_index = negative_labels_ignore_index self._loss = torch.nn.CrossEntropyLoss(**kwargs) - self._logits_callback = None + self._logits_callback: LogitsCallback | None = None + self.loss_name: str = loss_name @property def logits_callback( self, - ) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]: + ) -> LogitsCallback: """ Property for calling a function for the logits computation.\n @@ -69,7 +69,7 @@ def logits_callback( return self._logits_callback @logits_callback.setter - def logits_callback(self, func: Optional[Callable]) -> None: + def logits_callback(self, func: LogitsCallback) -> None: self._logits_callback = func def forward( @@ -80,7 +80,8 @@ def forward( negative_labels: torch.LongTensor, # noqa: ARG002 padding_mask: torch.BoolTensor, # noqa: ARG002 target_padding_mask: torch.BoolTensor, - ) -> torch.Tensor: + return_info: bool = False, + ) -> LossOutput: """ forward(model_embeddings, positive_labels, target_padding_mask) **Note**: At forward pass, the whole catalog of items is used as negatives. @@ -144,7 +145,11 @@ def forward( target = torch.zeros(logits.size(0), dtype=torch.long, device=positive_labels.device) # [masked_batch_size] - loss for all recommendation points loss = self._loss(logits, target) - return loss + + if return_info: + return (loss, {self.loss_name: loss.detach()}) + else: + return (loss, None) class LogOutCEWeighted(LogOutCE): @@ -174,6 +179,7 @@ def __init__( cardinality: int, feature_name: str, negative_labels_ignore_index: int = -100, + loss_name: str = "LogOutCEWeightedLoss", **kwargs, ): """ @@ -196,6 +202,7 @@ def __init__( ) self.feature_name = feature_name self._loss = torch.nn.CrossEntropyLoss(reduction="none", **kwargs) + self.loss_name: str = loss_name def forward( self, @@ -205,7 +212,8 @@ def forward( negative_labels: torch.LongTensor, # noqa: ARG002 padding_mask: torch.BoolTensor, # noqa: ARG002 target_padding_mask: torch.BoolTensor, - ) -> torch.Tensor: + return_info: bool = False, + ) -> LossOutput: """ forward(model_embeddings, feature_tensors, positive_labels, target_padding_mask) **Note**: At forward pass, the whole catalog of items is used as negatives. @@ -227,4 +235,8 @@ def forward( sample_weight = feature_tensors[self.feature_name] sample_weight = sample_weight[target_padding_mask] loss = (loss * sample_weight).mean() - return loss + + if return_info: + return (loss, {self.loss_name: loss.detach()}) + else: + return (loss, None) diff --git a/replay/nn/output.py b/replay/nn/output.py index 868a6e2cd..e378f3f0d 100644 --- a/replay/nn/output.py +++ b/replay/nn/output.py @@ -3,6 +3,8 @@ import torch from typing_extensions import NotRequired +from .loss import LossInfo + class TrainOutput(TypedDict): """ @@ -17,6 +19,7 @@ class TrainOutput(TypedDict): """ loss: torch.Tensor + info: NotRequired[LossInfo | None] hidden_states: NotRequired[tuple[torch.Tensor, ...]] diff --git a/replay/nn/sequential/sasrec/agg.py b/replay/nn/sequential/sasrec/agg.py index bc2f42d5c..73e31556b 100644 --- a/replay/nn/sequential/sasrec/agg.py +++ b/replay/nn/sequential/sasrec/agg.py @@ -43,9 +43,9 @@ def forward(self, feature_tensors: TensorMap) -> torch.Tensor: seqs: torch.Tensor = self.embedding_aggregator(feature_tensors) assert seqs.dim() == 3 batch_size, seq_len, embedding_dim = seqs.size() - assert ( - seq_len <= self.pe.num_embeddings - ), f"Sequence length = {seq_len} is greater then positional embedding num = {self.pe.num_embeddings}" + assert seq_len <= self.pe.num_embeddings, ( + f"Sequence length = {seq_len} is greater then positional embedding num = {self.pe.num_embeddings}" + ) seqs *= embedding_dim**0.5 seqs += self.pe.weight[:seq_len].unsqueeze(0).repeat(batch_size, 1, 1) diff --git a/replay/nn/sequential/sasrec/model.py b/replay/nn/sequential/sasrec/model.py index ed2539dd2..ea4ff2d9e 100644 --- a/replay/nn/sequential/sasrec/model.py +++ b/replay/nn/sequential/sasrec/model.py @@ -270,23 +270,26 @@ def forward_train( positive_labels: torch.LongTensor, negative_labels: torch.LongTensor, target_padding_mask: torch.BoolTensor, + return_info: bool = False, ) -> TrainOutput: hidden_states: torch.Tensor = self.body(feature_tensors, padding_mask) assert hidden_states.dim() == 3 - loss: torch.Tensor = self.loss( + loss, info = self.loss( model_embeddings=hidden_states, feature_tensors=feature_tensors, positive_labels=positive_labels, negative_labels=negative_labels, padding_mask=padding_mask, target_padding_mask=target_padding_mask, + return_info=return_info, ) - return { - "loss": loss, - "hidden_states": (hidden_states,), - } + return TrainOutput( + loss=loss, + info=info, + hidden_states=(hidden_states,), + ) def forward_inference( self, @@ -313,6 +316,7 @@ def forward( positive_labels: Optional[torch.LongTensor] = None, negative_labels: Optional[torch.LongTensor] = None, target_padding_mask: Optional[torch.BoolTensor] = None, + return_info: bool = False, ) -> Union[TrainOutput, InferenceOutput]: """ :param feature_tensors: a dictionary of tensors to generate embeddings. @@ -358,6 +362,7 @@ def forward( positive_labels=positive_labels, negative_labels=negative_labels, target_padding_mask=target_padding_mask, + return_info=return_info, ) all( diff --git a/replay/nn/sequential/twotower/model.py b/replay/nn/sequential/twotower/model.py index 6016ae94f..eaa4874cf 100644 --- a/replay/nn/sequential/twotower/model.py +++ b/replay/nn/sequential/twotower/model.py @@ -542,6 +542,7 @@ def forward_train( positive_labels: torch.LongTensor, negative_labels: torch.LongTensor, target_padding_mask: torch.BoolTensor, + return_info: bool = False, ) -> TrainOutput: hidden_states = () query_hidden_states: torch.Tensor = self.body.query_tower( @@ -559,17 +560,19 @@ def forward_train( assert query_hidden_states.dim() == 3 hidden_states += (query_hidden_states,) - loss: torch.Tensor = self.loss( + loss, info = self.loss( model_embeddings=query_hidden_states, feature_tensors=feature_tensors, positive_labels=positive_labels, negative_labels=negative_labels, padding_mask=padding_mask, target_padding_mask=target_padding_mask, + return_info=return_info, ) return TrainOutput( loss=loss, + info=info, hidden_states=hidden_states, ) @@ -612,6 +615,7 @@ def forward( positive_labels: Optional[torch.LongTensor] = None, negative_labels: Optional[torch.LongTensor] = None, target_padding_mask: Optional[torch.BoolTensor] = None, + return_info: bool = False, ) -> Union[TrainOutput, InferenceOutput]: """ :param feature_tensors: a dictionary of tensors to generate embeddings. @@ -657,6 +661,7 @@ def forward( positive_labels=positive_labels, negative_labels=negative_labels, target_padding_mask=target_padding_mask, + return_info=return_info, ) all(