diff --git a/examples/09_sasrec_example.ipynb b/examples/09_sasrec_example.ipynb
index 14ab84808..46de00ac6 100644
--- a/examples/09_sasrec_example.ipynb
+++ b/examples/09_sasrec_example.ipynb
@@ -1,5 +1,15 @@
{
"cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%load_ext autoreload\n",
+ "%autoreload 2"
+ ]
+ },
{
"attachments": {},
"cell_type": "markdown",
@@ -10,7 +20,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 2,
"metadata": {},
"outputs": [
{
@@ -19,9 +29,21 @@
"text": [
"Seed set to 42\n"
]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "42"
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
}
],
"source": [
+ "from typing import Optional\n",
+ "\n",
"import lightning as L\n",
"import pandas as pd\n",
"\n",
@@ -46,7 +68,7 @@
},
{
"cell_type": "code",
- "execution_count": 29,
+ "execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@@ -56,7 +78,7 @@
},
{
"cell_type": "code",
- "execution_count": 30,
+ "execution_count": 4,
"metadata": {},
"outputs": [
{
@@ -105,15 +127,15 @@
"
2 | \n",
" \n",
" \n",
- " | 1000192 | \n",
+ " 1000007 | \n",
" 6040 | \n",
- " 2019 | \n",
+ " 1961 | \n",
" 3 | \n",
"
\n",
" \n",
- " | 1000007 | \n",
+ " 1000192 | \n",
" 6040 | \n",
- " 1961 | \n",
+ " 2019 | \n",
" 4 | \n",
"
\n",
" \n",
@@ -135,15 +157,15 @@
" | 447 | \n",
"
\n",
" \n",
- " | 825731 | \n",
+ " 825724 | \n",
" 4958 | \n",
- " 2634 | \n",
+ " 3264 | \n",
" 448 | \n",
"
\n",
" \n",
- " | 825724 | \n",
+ " 825731 | \n",
" 4958 | \n",
- " 3264 | \n",
+ " 2634 | \n",
" 449 | \n",
"
\n",
" \n",
@@ -162,19 +184,19 @@
"1000138 6040 858 0\n",
"1000153 6040 2384 1\n",
"999873 6040 593 2\n",
- "1000192 6040 2019 3\n",
- "1000007 6040 1961 4\n",
+ "1000007 6040 1961 3\n",
+ "1000192 6040 2019 4\n",
"... ... ... ...\n",
"825793 4958 2399 446\n",
"825438 4958 1407 447\n",
- "825731 4958 2634 448\n",
- "825724 4958 3264 449\n",
+ "825724 4958 3264 448\n",
+ "825731 4958 2634 449\n",
"825603 4958 1924 450\n",
"\n",
"[1000209 rows x 3 columns]"
]
},
- "execution_count": 30,
+ "execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
@@ -196,7 +218,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 5,
"metadata": {},
"outputs": [
{
@@ -228,32 +250,32 @@
"
\n",
" \n",
" | 0 | \n",
- " 32 | \n",
- " 0 | \n",
+ " 12 | \n",
+ " 2011 | \n",
" 0 | \n",
"
\n",
" \n",
" | 1 | \n",
- " 10 | \n",
- " 1 | \n",
+ " 68 | \n",
+ " 4078 | \n",
" 0 | \n",
"
\n",
" \n",
" | 2 | \n",
- " 12 | \n",
- " 2 | \n",
+ " 67 | \n",
+ " 4123 | \n",
" 0 | \n",
"
\n",
" \n",
" | 3 | \n",
- " 339 | \n",
- " 3 | \n",
+ " 12 | \n",
+ " 983 | \n",
" 0 | \n",
"
\n",
" \n",
" | 4 | \n",
- " 144 | \n",
- " 4 | \n",
+ " 140 | \n",
+ " 2270 | \n",
" 0 | \n",
"
\n",
" \n",
@@ -264,32 +286,32 @@
"
\n",
" \n",
" | 1000204 | \n",
- " 281 | \n",
- " 796 | \n",
+ " 14 | \n",
+ " 855 | \n",
" 3705 | \n",
"
\n",
" \n",
" | 1000205 | \n",
- " 209 | \n",
- " 1297 | \n",
+ " 90 | \n",
+ " 1700 | \n",
" 3705 | \n",
"
\n",
" \n",
" | 1000206 | \n",
- " 748 | \n",
- " 1883 | \n",
+ " 70 | \n",
+ " 936 | \n",
" 3705 | \n",
"
\n",
" \n",
" | 1000207 | \n",
- " 71 | \n",
- " 4449 | \n",
+ " 25 | \n",
+ " 360 | \n",
" 3705 | \n",
"
\n",
" \n",
" | 1000208 | \n",
- " 287 | \n",
- " 2473 | \n",
+ " 380 | \n",
+ " 1388 | \n",
" 3705 | \n",
"
\n",
" \n",
@@ -299,35 +321,36 @@
],
"text/plain": [
" timestamp user_id item_id\n",
- "0 32 0 0\n",
- "1 10 1 0\n",
- "2 12 2 0\n",
- "3 339 3 0\n",
- "4 144 4 0\n",
+ "0 12 2011 0\n",
+ "1 68 4078 0\n",
+ "2 67 4123 0\n",
+ "3 12 983 0\n",
+ "4 140 2270 0\n",
"... ... ... ...\n",
- "1000204 281 796 3705\n",
- "1000205 209 1297 3705\n",
- "1000206 748 1883 3705\n",
- "1000207 71 4449 3705\n",
- "1000208 287 2473 3705\n",
+ "1000204 14 855 3705\n",
+ "1000205 90 1700 3705\n",
+ "1000206 70 936 3705\n",
+ "1000207 25 360 3705\n",
+ "1000208 380 1388 3705\n",
"\n",
"[1000209 rows x 3 columns]"
]
},
- "execution_count": 31,
+ "execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "from replay.preprocessing import LabelEncoder, LabelEncodingRule\n",
+ "from replay.preprocessing.label_encoder import LabelEncoder, LabelEncodingRule\n",
"\n",
"encoder = LabelEncoder(\n",
" [\n",
- " LabelEncodingRule(\"user_id\"),\n",
- " LabelEncodingRule(\"item_id\"),\n",
+ " LabelEncodingRule(\"user_id\", default_value=\"last\"),\n",
+ " LabelEncodingRule(\"item_id\", default_value=\"last\"),\n",
" ]\n",
")\n",
+ "interactions = interactions.sort_values(by=\"item_id\", ascending=True)\n",
"encoded_interactions = encoder.fit_transform(interactions)\n",
"encoded_interactions"
]
@@ -343,7 +366,7 @@
},
{
"cell_type": "code",
- "execution_count": 32,
+ "execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@@ -374,7 +397,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
@@ -388,7 +411,7 @@
},
{
"cell_type": "code",
- "execution_count": 34,
+ "execution_count": 8,
"metadata": {},
"outputs": [
{
@@ -422,31 +445,31 @@
" 0 | \n",
" 0 | \n",
" [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... | \n",
- " [859, 309, 2371, 3442, 1108, 329, 367, 3279, 7... | \n",
+ " [2969, 1574, 1178, 957, 2147, 1658, 3177, 1117... | \n",
" \n",
" \n",
" | 1 | \n",
" 1 | \n",
" [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... | \n",
- " [354, 2426, 253, 1371, 513, 1184, 3131, 309, 8... | \n",
+ " [1108, 1127, 1120, 2512, 1201, 2735, 1135, 110... | \n",
"
\n",
" \n",
" | 2 | \n",
" 2 | \n",
" [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... | \n",
- " [1008, 1120, 2439, 1066, 3197, 253, 1108, 1107... | \n",
+ " [579, 2651, 3301, 1788, 1781, 1327, 1174, 3429... | \n",
"
\n",
" \n",
" | 3 | \n",
" 3 | \n",
" [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... | \n",
- " [346, 2190, 670, 802, 323, 661, 2480, 2501, 19... | \n",
+ " [1120, 1025, 466, 3235, 3294, 1106, 253, 1108,... | \n",
"
\n",
" \n",
" | 4 | \n",
" 4 | \n",
" [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... | \n",
- " [1120, 758, 2426, 1838, 2621, 3341, 3377, 3502... | \n",
+ " [2512, 858, 847, 346, 1158, 2007, 2651, 1050, ... | \n",
"
\n",
" \n",
" | ... | \n",
@@ -458,31 +481,31 @@
" 6035 | \n",
" 6035 | \n",
" [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... | \n",
- " [2426, 1279, 3321, 3151, 1178, 2501, 3301, 248... | \n",
+ " [1574, 1703, 3206, 2183, 2235, 2480, 2375, 250... | \n",
"
\n",
" \n",
" | 6036 | \n",
" 6036 | \n",
" [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... | \n",
- " [1592, 2302, 1633, 1813, 2879, 1482, 2651, 200... | \n",
+ " [1702, 672, 1175, 1848, 3275, 2932, 548, 802, ... | \n",
"
\n",
" \n",
" | 6037 | \n",
" 6037 | \n",
" [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... | \n",
- " [1971, 3500, 1666, 2077, 1399, 2748, 2958, 278... | \n",
+ " [3165, 859, 1120, 1965, 1288, 346, 1007, 1066,... | \n",
"
\n",
" \n",
" | 6038 | \n",
" 6038 | \n",
" [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... | \n",
- " [1486, 1485, 3384, 3512, 3302, 3126, 3650, 330... | \n",
+ " [107, 275, 1886, 1139, 869, 886, 2872, 2809, 2... | \n",
"
\n",
" \n",
" | 6039 | \n",
" 6039 | \n",
" [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... | \n",
- " [2432, 2960, 2114, 1848, 2142, 3248, 3091, 317... | \n",
+ " [802, 2191, 579, 1781, 1839, 1316, 207, 2895, ... | \n",
"
\n",
" \n",
"\n",
@@ -504,22 +527,22 @@
"6039 6039 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n",
"\n",
" item_id \n",
- "0 [859, 309, 2371, 3442, 1108, 329, 367, 3279, 7... \n",
- "1 [354, 2426, 253, 1371, 513, 1184, 3131, 309, 8... \n",
- "2 [1008, 1120, 2439, 1066, 3197, 253, 1108, 1107... \n",
- "3 [346, 2190, 670, 802, 323, 661, 2480, 2501, 19... \n",
- "4 [1120, 758, 2426, 1838, 2621, 3341, 3377, 3502... \n",
+ "0 [2969, 1574, 1178, 957, 2147, 1658, 3177, 1117... \n",
+ "1 [1108, 1127, 1120, 2512, 1201, 2735, 1135, 110... \n",
+ "2 [579, 2651, 3301, 1788, 1781, 1327, 1174, 3429... \n",
+ "3 [1120, 1025, 466, 3235, 3294, 1106, 253, 1108,... \n",
+ "4 [2512, 858, 847, 346, 1158, 2007, 2651, 1050, ... \n",
"... ... \n",
- "6035 [2426, 1279, 3321, 3151, 1178, 2501, 3301, 248... \n",
- "6036 [1592, 2302, 1633, 1813, 2879, 1482, 2651, 200... \n",
- "6037 [1971, 3500, 1666, 2077, 1399, 2748, 2958, 278... \n",
- "6038 [1486, 1485, 3384, 3512, 3302, 3126, 3650, 330... \n",
- "6039 [2432, 2960, 2114, 1848, 2142, 3248, 3091, 317... \n",
+ "6035 [1574, 1703, 3206, 2183, 2235, 2480, 2375, 250... \n",
+ "6036 [1702, 672, 1175, 1848, 3275, 2932, 548, 802, ... \n",
+ "6037 [3165, 859, 1120, 1965, 1288, 346, 1007, 1066,... \n",
+ "6038 [107, 275, 1886, 1139, 869, 886, 2872, 2809, 2... \n",
+ "6039 [802, 2191, 579, 1781, 1839, 1316, 207, 2895, ... \n",
"\n",
"[6040 rows x 3 columns]"
]
},
- "execution_count": 34,
+ "execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
@@ -545,7 +568,7 @@
},
{
"cell_type": "code",
- "execution_count": 35,
+ "execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
@@ -561,7 +584,7 @@
},
{
"cell_type": "code",
- "execution_count": 36,
+ "execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
@@ -579,9 +602,20 @@
},
{
"cell_type": "code",
- "execution_count": 37,
+ "execution_count": 11,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/nkulikov/RePlay/replay/preprocessing/label_encoder.py:964: UserWarning: There is already LabelEncoder object saved at the given path. File will be overwrited.\n",
+ " warnings.warn(msg)\n",
+ "/home/nkulikov/RePlay/replay/preprocessing/label_encoder.py:537: UserWarning: There is already LabelEncodingRule object saved at the given path. File will be overwrited.\n",
+ " warnings.warn(msg)\n"
+ ]
+ }
+ ],
"source": [
"train_events.to_parquet(TRAIN_PATH)\n",
"validation_events.to_parquet(VAL_PATH)\n",
@@ -604,7 +638,7 @@
},
{
"cell_type": "code",
- "execution_count": 38,
+ "execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
@@ -659,7 +693,7 @@
},
{
"cell_type": "code",
- "execution_count": 39,
+ "execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
@@ -670,12 +704,10 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
- "from typing import Optional\n",
- "\n",
"MAX_SEQ_LEN = 50\n",
"\n",
"def create_meta(shape: int, gt_shape: Optional[int] = None):\n",
@@ -696,9 +728,18 @@
},
{
"cell_type": "code",
- "execution_count": 41,
+ "execution_count": 15,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/tmp/ipykernel_37022/1761019238.py:5: UserWarning: The following dataset paths aren't provided: test,predict.Make sure to disable these stages in your Lightning Trainer configuration.\n",
+ " parquet_module = ParquetModule(\n"
+ ]
+ }
+ ],
"source": [
"from replay.data.nn import ParquetModule\n",
"\n",
@@ -730,11 +771,12 @@
},
{
"cell_type": "code",
- "execution_count": 42,
+ "execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"from replay.nn.sequential import SasRec\n",
+ "from replay.nn.loss import ComposedLoss, CE, BCE\n",
"\n",
"NUM_BLOCKS = 2\n",
"NUM_HEADS = 2\n",
@@ -747,7 +789,15 @@
" num_heads=NUM_HEADS,\n",
" num_blocks=NUM_BLOCKS,\n",
" dropout=DROPOUT,\n",
- ")"
+ ")\n",
+ "\n",
+ "sasrec.loss = ComposedLoss(\n",
+ " losses = [\n",
+ " CE(ignore_index=tensor_schema.item_id_features.item().padding_value),\n",
+ " BCE()\n",
+ " ]\n",
+ ")\n",
+ "sasrec.loss.logits_callback = sasrec.get_logits"
]
},
{
@@ -759,13 +809,19 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
+ "from replay.nn.lightning.optimizer import OptimizerFactory\n",
+ "from replay.nn.lightning.scheduler import LRSchedulerFactory\n",
"from replay.nn.lightning import LightningModule\n",
"\n",
- "model = LightningModule(sasrec)"
+ "model = LightningModule(\n",
+ " sasrec,\n",
+ " optimizer_factory=OptimizerFactory(),\n",
+ " lr_scheduler_factory=LRSchedulerFactory(),\n",
+ ")"
]
},
{
@@ -779,33 +835,36 @@
},
{
"cell_type": "code",
- "execution_count": 44,
+ "execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "GPU available: False, used: False\n",
+ "GPU available: True (cuda), used: True\n",
"TPU available: False, using: 0 TPU cores\n",
- "HPU available: False, using: 0 HPUs\n",
+ "You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n",
+ "/home/nkulikov/new_venv/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:881: Checkpoint directory /home/nkulikov/RePlay/examples/sasrec/checkpoints exists and is not empty.\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
"\n",
- " | Name | Type | Params | Mode \n",
- "-----------------------------------------\n",
- "0 | model | SasRec | 291 K | train\n",
- "-----------------------------------------\n",
+ " | Name | Type | Params | Mode | FLOPs\n",
+ "-------------------------------------------------\n",
+ "0 | model | SasRec | 291 K | train | 0 \n",
+ "-------------------------------------------------\n",
"291 K Trainable params\n",
"0 Non-trainable params\n",
"291 K Total params\n",
"1.164 Total estimated model params size (MB)\n",
- "39 Modules in train mode\n",
- "0 Modules in eval mode\n"
+ "43 Modules in train mode\n",
+ "0 Modules in eval mode\n",
+ "0 Total Flops\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "a36bec8cdb384569a56b01dda4a8dce3",
+ "model_id": "b6dfa4dd9a2a41c4acdcfeac89157c0f",
"version_major": 2,
"version_minor": 0
},
@@ -816,10 +875,19 @@
"metadata": {},
"output_type": "display_data"
},
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/nkulikov/new_venv/lib/python3.12/site-packages/lightning/pytorch/utilities/data.py:123: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.\n",
+ "/home/nkulikov/new_venv/lib/python3.12/site-packages/torch/nn/functional.py:6044: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.\n",
+ " warnings.warn(\n"
+ ]
+ },
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "3788b1f40c8648e58741b90e70a66259",
+ "model_id": "68ad24eb0bb14809b761afba56ea080d",
"version_major": 2,
"version_minor": 0
},
@@ -833,7 +901,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "cdc8dd8ed34441bc82f9d868f8cac7d3",
+ "model_id": "ef579d347f4f47f986a64875ecccadee",
"version_major": 2,
"version_minor": 0
},
@@ -848,7 +916,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "Epoch 0, global step 189: 'recall@10' reached 0.03576 (best 0.03576), saving model to '/home/RePlay/examples/sasrec/checkpoints/epoch=0-step=189.ckpt' as top 1\n"
+ "Epoch 0, global step 189: 'recall@10' reached 0.01954 (best 0.01954), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=0-step=189.ckpt' as top 1\n"
]
},
{
@@ -856,16 +924,24 @@
"output_type": "stream",
"text": [
"k 1 10 20 5\n",
- "map 0.003808 0.010609 0.012859 0.008209\n",
- "ndcg 0.003808 0.016343 0.024711 0.010422\n",
- "recall 0.003808 0.035762 0.069205 0.017219\n",
+ "map 0.003809 0.006980 0.008672 0.005729\n",
+ "ndcg 0.003809 0.009822 0.016091 0.006684\n",
+ "recall 0.003809 0.019540 0.044544 0.009604\n",
"\n"
]
},
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/nkulikov/new_venv/lib/python3.12/site-packages/torch/nn/functional.py:6044: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.\n",
+ " warnings.warn(\n"
+ ]
+ },
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "6c8e2d8667a64bba9e4c05337d9a934c",
+ "model_id": "3448560f23764ea6a4bedcc0349a529b",
"version_major": 2,
"version_minor": 0
},
@@ -880,7 +956,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "Epoch 1, global step 378: 'recall@10' reached 0.08626 (best 0.08626), saving model to '/home/RePlay/examples/sasrec/checkpoints/epoch=1-step=378.ckpt' as top 1\n"
+ "Epoch 1, global step 378: 'recall@10' reached 0.02500 (best 0.02500), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=1-step=378.ckpt' as top 1\n"
]
},
{
@@ -888,16 +964,24 @@
"output_type": "stream",
"text": [
"k 1 10 20 5\n",
- "map 0.011258 0.028874 0.032790 0.024324\n",
- "ndcg 0.011258 0.042089 0.056579 0.030685\n",
- "recall 0.011258 0.086258 0.144040 0.050166\n",
+ "map 0.003809 0.009115 0.010547 0.007866\n",
+ "ndcg 0.003809 0.012794 0.018238 0.009690\n",
+ "recall 0.003809 0.025004 0.047028 0.015234\n",
"\n"
]
},
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/nkulikov/new_venv/lib/python3.12/site-packages/torch/nn/functional.py:6044: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.\n",
+ " warnings.warn(\n"
+ ]
+ },
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "4f29f2aa51ba4cf68af42cc95a453be5",
+ "model_id": "754ca088536d48099db8255917dfea00",
"version_major": 2,
"version_minor": 0
},
@@ -912,7 +996,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "Epoch 2, global step 567: 'recall@10' reached 0.12136 (best 0.12136), saving model to '/home/RePlay/examples/sasrec/checkpoints/epoch=2-step=567.ckpt' as top 1\n"
+ "Epoch 2, global step 567: 'recall@10' reached 0.02534 (best 0.02534), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=2-step=567.ckpt' as top 1\n"
]
},
{
@@ -920,16 +1004,24 @@
"output_type": "stream",
"text": [
"k 1 10 20 5\n",
- "map 0.015894 0.040155 0.045144 0.033278\n",
- "ndcg 0.015894 0.058849 0.077380 0.041941\n",
- "recall 0.015894 0.121358 0.195364 0.068543\n",
+ "map 0.003809 0.009250 0.010724 0.007827\n",
+ "ndcg 0.003809 0.012979 0.018452 0.009511\n",
+ "recall 0.003809 0.025335 0.047193 0.014572\n",
"\n"
]
},
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/nkulikov/new_venv/lib/python3.12/site-packages/torch/nn/functional.py:6044: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.\n",
+ " warnings.warn(\n"
+ ]
+ },
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "a8c62495083f4b108e679d86d496e3e9",
+ "model_id": "0f2fe27d36e64ee0889cdade17c286de",
"version_major": 2,
"version_minor": 0
},
@@ -944,7 +1036,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "Epoch 3, global step 756: 'recall@10' reached 0.14106 (best 0.14106), saving model to '/home/RePlay/examples/sasrec/checkpoints/epoch=3-step=756.ckpt' as top 1\n"
+ "Epoch 3, global step 756: 'recall@10' was not in top 1\n"
]
},
{
@@ -952,16 +1044,24 @@
"output_type": "stream",
"text": [
"k 1 10 20 5\n",
- "map 0.015232 0.045004 0.050748 0.037318\n",
- "ndcg 0.015232 0.067207 0.088327 0.048297\n",
- "recall 0.015232 0.141060 0.225000 0.081954\n",
+ "map 0.003809 0.007385 0.009081 0.006154\n",
+ "ndcg 0.003809 0.010590 0.016888 0.007441\n",
+ "recall 0.003809 0.021527 0.046696 0.011426\n",
"\n"
]
},
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/nkulikov/new_venv/lib/python3.12/site-packages/torch/nn/functional.py:6044: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.\n",
+ " warnings.warn(\n"
+ ]
+ },
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "e40e9de8673b44e3b451ccb9bdd84699",
+ "model_id": "d3f48b662b444c5baa3644f1eff468c7",
"version_major": 2,
"version_minor": 0
},
@@ -976,7 +1076,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "Epoch 4, global step 945: 'recall@10' reached 0.15331 (best 0.15331), saving model to '/home/RePlay/examples/sasrec/checkpoints/epoch=4-step=945.ckpt' as top 1\n",
+ "Epoch 4, global step 945: 'recall@10' reached 0.03279 (best 0.03279), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=4-step=945-v1.ckpt' as top 1\n",
"`Trainer.fit` stopped: `max_epochs=5` reached.\n"
]
},
@@ -985,9 +1085,9 @@
"output_type": "stream",
"text": [
"k 1 10 20 5\n",
- "map 0.015728 0.048114 0.054519 0.039854\n",
- "ndcg 0.015728 0.072437 0.095845 0.052167\n",
- "recall 0.015728 0.153311 0.246026 0.090066\n",
+ "map 0.003146 0.009925 0.012081 0.007874\n",
+ "ndcg 0.003146 0.015168 0.023088 0.010099\n",
+ "recall 0.003146 0.032787 0.064249 0.016890\n",
"\n"
]
}
@@ -1034,16 +1134,16 @@
},
{
"cell_type": "code",
- "execution_count": 45,
+ "execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "'/home/RePlay/examples/sasrec/checkpoints/epoch=4-step=945.ckpt'"
+ "'/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=4-step=945-v1.ckpt'"
]
},
- "execution_count": 45,
+ "execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
@@ -1065,10 +1165,13 @@
},
{
"cell_type": "code",
- "execution_count": 46,
+ "execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
+ "import torch\n",
+ "import replay\n",
+ "\n",
"sasrec = SasRec.from_params(\n",
" schema=tensor_schema,\n",
" embedding_dim=EMBEDDING_DIM,\n",
@@ -1079,6 +1182,11 @@
" excluded_features=None,\n",
")\n",
"\n",
+ "torch.serialization.add_safe_globals([\n",
+ " replay.nn.lightning.optimizer.OptimizerFactory,\n",
+ " replay.nn.lightning.scheduler.LRSchedulerFactory,\n",
+ "])\n",
+ "\n",
"best_model = LightningModule.load_from_checkpoint(best_model_path, model=sasrec)\n",
"best_model.eval();"
]
@@ -1092,9 +1200,18 @@
},
{
"cell_type": "code",
- "execution_count": 47,
+ "execution_count": 21,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/tmp/ipykernel_37022/1373759190.py:3: UserWarning: The following dataset paths aren't provided: train,validate,test.Make sure to disable these stages in your Lightning Trainer configuration.\n",
+ " parquet_module = ParquetModule(\n"
+ ]
+ }
+ ],
"source": [
"inference_metadata = {\"predict\": create_meta(shape=MAX_SEQ_LEN)}\n",
"\n",
@@ -1118,22 +1235,24 @@
},
{
"cell_type": "code",
- "execution_count": 48,
+ "execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "GPU available: False, used: False\n",
+ "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
+ "GPU available: True (cuda), used: True\n",
"TPU available: False, using: 0 TPU cores\n",
- "HPU available: False, using: 0 HPUs\n"
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "/home/nkulikov/new_venv/lib/python3.12/site-packages/lightning/pytorch/utilities/data.py:123: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "d93a289ce7f74b0f8b470cd48cb6882a",
+ "model_id": "8c6243a9d5a64606864f5dba6caf3018",
"version_major": 2,
"version_minor": 0
},
@@ -1143,6 +1262,14 @@
},
"metadata": {},
"output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/nkulikov/new_venv/lib/python3.12/site-packages/torch/nn/functional.py:6044: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.\n",
+ " warnings.warn(\n"
+ ]
}
],
"source": [
@@ -1167,7 +1294,7 @@
},
{
"cell_type": "code",
- "execution_count": 49,
+ "execution_count": 23,
"metadata": {},
"outputs": [
{
@@ -1200,32 +1327,32 @@
" \n",
" | 0 | \n",
" 0 | \n",
- " 3383 | \n",
- " 6.106367 | \n",
+ " 2426 | \n",
+ " -5.474451 | \n",
"
\n",
" \n",
" | 0 | \n",
" 0 | \n",
- " 3509 | \n",
- " 5.956742 | \n",
+ " 2557 | \n",
+ " -5.80117 | \n",
"
\n",
" \n",
" | 0 | \n",
" 0 | \n",
- " 3341 | \n",
- " 5.944553 | \n",
+ " 466 | \n",
+ " -5.805716 | \n",
"
\n",
" \n",
" | 0 | \n",
" 0 | \n",
- " 3510 | \n",
- " 5.776233 | \n",
+ " 627 | \n",
+ " -5.844583 | \n",
"
\n",
" \n",
" | 0 | \n",
" 0 | \n",
- " 3512 | \n",
- " 5.587171 | \n",
+ " 574 | \n",
+ " -5.853208 | \n",
"
\n",
" \n",
" | ... | \n",
@@ -1236,32 +1363,32 @@
"
\n",
" | 6037 | \n",
" 6039 | \n",
- " 2941 | \n",
- " 5.740072 | \n",
+ " 2775 | \n",
+ " -5.961242 | \n",
"
\n",
" \n",
" | 6037 | \n",
" 6039 | \n",
- " 3049 | \n",
- " 5.596299 | \n",
+ " 3341 | \n",
+ " -6.006827 | \n",
"
\n",
" \n",
" | 6037 | \n",
" 6039 | \n",
- " 2750 | \n",
- " 5.548656 | \n",
+ " 2708 | \n",
+ " -6.026894 | \n",
"
\n",
" \n",
" | 6037 | \n",
" 6039 | \n",
- " 2968 | \n",
- " 5.302012 | \n",
+ " 2959 | \n",
+ " -6.033692 | \n",
"
\n",
" \n",
" | 6037 | \n",
" 6039 | \n",
- " 2202 | \n",
- " 5.136523 | \n",
+ " 1618 | \n",
+ " -6.041715 | \n",
"
\n",
" \n",
"\n",
@@ -1270,22 +1397,22 @@
],
"text/plain": [
" user_id item_id score\n",
- "0 0 3383 6.106367\n",
- "0 0 3509 5.956742\n",
- "0 0 3341 5.944553\n",
- "0 0 3510 5.776233\n",
- "0 0 3512 5.587171\n",
+ "0 0 2426 -5.474451\n",
+ "0 0 2557 -5.80117\n",
+ "0 0 466 -5.805716\n",
+ "0 0 627 -5.844583\n",
+ "0 0 574 -5.853208\n",
"... ... ... ...\n",
- "6037 6039 2941 5.740072\n",
- "6037 6039 3049 5.596299\n",
- "6037 6039 2750 5.548656\n",
- "6037 6039 2968 5.302012\n",
- "6037 6039 2202 5.136523\n",
+ "6037 6039 2775 -5.961242\n",
+ "6037 6039 3341 -6.006827\n",
+ "6037 6039 2708 -6.026894\n",
+ "6037 6039 2959 -6.033692\n",
+ "6037 6039 1618 -6.041715\n",
"\n",
"[120760 rows x 3 columns]"
]
},
- "execution_count": 49,
+ "execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
@@ -1306,7 +1433,7 @@
},
{
"cell_type": "code",
- "execution_count": 50,
+ "execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
@@ -1316,7 +1443,7 @@
},
{
"cell_type": "code",
- "execution_count": 51,
+ "execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
@@ -1329,7 +1456,7 @@
},
{
"cell_type": "code",
- "execution_count": 52,
+ "execution_count": 26,
"metadata": {},
"outputs": [
{
@@ -1362,37 +1489,37 @@
" \n",
" \n",
" | MAP | \n",
- " 0.013912 | \n",
- " 0.046539 | \n",
- " 0.052342 | \n",
- " 0.038912 | \n",
+ " 0.00265 | \n",
+ " 0.008525 | \n",
+ " 0.010557 | \n",
+ " 0.006558 | \n",
"
\n",
" \n",
" | Precision | \n",
- " 0.013912 | \n",
- " 0.014492 | \n",
- " 0.011494 | \n",
- " 0.017456 | \n",
+ " 0.00265 | \n",
+ " 0.002865 | \n",
+ " 0.002956 | \n",
+ " 0.002816 | \n",
"
\n",
" \n",
" | Recall | \n",
- " 0.013912 | \n",
- " 0.144916 | \n",
- " 0.229877 | \n",
- " 0.087281 | \n",
+ " 0.00265 | \n",
+ " 0.028652 | \n",
+ " 0.059126 | \n",
+ " 0.014078 | \n",
"
\n",
" \n",
"\n",
""
],
"text/plain": [
- "k 1 10 20 5\n",
- "MAP 0.013912 0.046539 0.052342 0.038912\n",
- "Precision 0.013912 0.014492 0.011494 0.017456\n",
- "Recall 0.013912 0.144916 0.229877 0.087281"
+ "k 1 10 20 5\n",
+ "MAP 0.00265 0.008525 0.010557 0.006558\n",
+ "Precision 0.00265 0.002865 0.002956 0.002816\n",
+ "Recall 0.00265 0.028652 0.059126 0.014078"
]
},
- "execution_count": 52,
+ "execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
@@ -1410,7 +1537,7 @@
},
{
"cell_type": "code",
- "execution_count": 53,
+ "execution_count": 27,
"metadata": {},
"outputs": [
{
@@ -1442,33 +1569,33 @@
" \n",
" \n",
" | 0 | \n",
- " 2011 | \n",
- " 3623 | \n",
- " 6.106367 | \n",
+ " 1 | \n",
+ " 2628 | \n",
+ " -5.474451 | \n",
"
\n",
" \n",
" | 0 | \n",
- " 2011 | \n",
- " 3752 | \n",
- " 5.956742 | \n",
+ " 1 | \n",
+ " 2762 | \n",
+ " -5.80117 | \n",
"
\n",
" \n",
" | 0 | \n",
- " 2011 | \n",
- " 3578 | \n",
- " 5.944553 | \n",
+ " 1 | \n",
+ " 480 | \n",
+ " -5.805716 | \n",
"
\n",
" \n",
" | 0 | \n",
- " 2011 | \n",
- " 3753 | \n",
- " 5.776233 | \n",
+ " 1 | \n",
+ " 648 | \n",
+ " -5.844583 | \n",
"
\n",
" \n",
" | 0 | \n",
- " 2011 | \n",
- " 3755 | \n",
- " 5.587171 | \n",
+ " 1 | \n",
+ " 588 | \n",
+ " -5.853208 | \n",
"
\n",
" \n",
" | ... | \n",
@@ -1478,33 +1605,33 @@
"
\n",
" \n",
" | 6037 | \n",
- " 5727 | \n",
- " 3157 | \n",
- " 5.740072 | \n",
+ " 6040 | \n",
+ " 2987 | \n",
+ " -5.961242 | \n",
"
\n",
" \n",
" | 6037 | \n",
- " 5727 | \n",
- " 3273 | \n",
- " 5.596299 | \n",
+ " 6040 | \n",
+ " 3578 | \n",
+ " -6.006827 | \n",
"
\n",
" \n",
" | 6037 | \n",
- " 5727 | \n",
- " 2961 | \n",
- " 5.548656 | \n",
+ " 6040 | \n",
+ " 2916 | \n",
+ " -6.026894 | \n",
"
\n",
" \n",
" | 6037 | \n",
- " 5727 | \n",
- " 3185 | \n",
- " 5.302012 | \n",
+ " 6040 | \n",
+ " 3176 | \n",
+ " -6.033692 | \n",
"
\n",
" \n",
" | 6037 | \n",
- " 5727 | \n",
- " 2395 | \n",
- " 5.136523 | \n",
+ " 6040 | \n",
+ " 1784 | \n",
+ " -6.041715 | \n",
"
\n",
" \n",
"\n",
@@ -1513,22 +1640,22 @@
],
"text/plain": [
" user_id item_id score\n",
- "0 2011 3623 6.106367\n",
- "0 2011 3752 5.956742\n",
- "0 2011 3578 5.944553\n",
- "0 2011 3753 5.776233\n",
- "0 2011 3755 5.587171\n",
+ "0 1 2628 -5.474451\n",
+ "0 1 2762 -5.80117\n",
+ "0 1 480 -5.805716\n",
+ "0 1 648 -5.844583\n",
+ "0 1 588 -5.853208\n",
"... ... ... ...\n",
- "6037 5727 3157 5.740072\n",
- "6037 5727 3273 5.596299\n",
- "6037 5727 2961 5.548656\n",
- "6037 5727 3185 5.302012\n",
- "6037 5727 2395 5.136523\n",
+ "6037 6040 2987 -5.961242\n",
+ "6037 6040 3578 -6.006827\n",
+ "6037 6040 2916 -6.026894\n",
+ "6037 6040 3176 -6.033692\n",
+ "6037 6040 1784 -6.041715\n",
"\n",
"[120760 rows x 3 columns]"
]
},
- "execution_count": 53,
+ "execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
@@ -1536,11 +1663,18 @@
"source": [
"encoder.inverse_transform(pandas_res)"
]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
}
],
"metadata": {
"kernelspec": {
- "display_name": "rep",
+ "display_name": "new_venv",
"language": "python",
"name": "python3"
},
@@ -1554,7 +1688,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.14"
+ "version": "3.12.3"
},
"orig_nbformat": 4
},
diff --git a/examples/15_twotower_example.ipynb b/examples/15_twotower_example.ipynb
index b189eebb3..bac925a97 100644
--- a/examples/15_twotower_example.ipynb
+++ b/examples/15_twotower_example.ipynb
@@ -118,15 +118,15 @@
" 2 | \n",
" \n",
" \n",
- " | 1000192 | \n",
+ " 1000007 | \n",
" 6040 | \n",
- " 2019 | \n",
+ " 1961 | \n",
" 3 | \n",
"
\n",
" \n",
- " | 1000007 | \n",
+ " 1000192 | \n",
" 6040 | \n",
- " 1961 | \n",
+ " 2019 | \n",
" 4 | \n",
"
\n",
" \n",
@@ -148,15 +148,15 @@
" | 447 | \n",
"
\n",
" \n",
- " | 825731 | \n",
+ " 825724 | \n",
" 4958 | \n",
- " 2634 | \n",
+ " 3264 | \n",
" 448 | \n",
"
\n",
" \n",
- " | 825724 | \n",
+ " 825731 | \n",
" 4958 | \n",
- " 3264 | \n",
+ " 2634 | \n",
" 449 | \n",
"
\n",
" \n",
@@ -175,13 +175,13 @@
"1000138 6040 858 0\n",
"1000153 6040 2384 1\n",
"999873 6040 593 2\n",
- "1000192 6040 2019 3\n",
- "1000007 6040 1961 4\n",
+ "1000007 6040 1961 3\n",
+ "1000192 6040 2019 4\n",
"... ... ... ...\n",
"825793 4958 2399 446\n",
"825438 4958 1407 447\n",
- "825731 4958 2634 448\n",
- "825724 4958 3264 449\n",
+ "825724 4958 3264 448\n",
+ "825731 4958 2634 449\n",
"825603 4958 1924 450\n",
"\n",
"[1000209 rows x 3 columns]"
@@ -211,7 +211,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 4,
"metadata": {},
"outputs": [
{
@@ -263,13 +263,13 @@
" | 3 | \n",
" 3 | \n",
" 6039 | \n",
- " 1950 | \n",
+ " 1892 | \n",
"
\n",
" \n",
" | 4 | \n",
" 4 | \n",
" 6039 | \n",
- " 1892 | \n",
+ " 1950 | \n",
"
\n",
" \n",
" | ... | \n",
@@ -293,13 +293,13 @@
" 1000206 | \n",
" 448 | \n",
" 4957 | \n",
- " 2565 | \n",
+ " 3195 | \n",
"
\n",
" \n",
" | 1000207 | \n",
" 449 | \n",
" 4957 | \n",
- " 3195 | \n",
+ " 2565 | \n",
"
\n",
" \n",
" | 1000208 | \n",
@@ -317,13 +317,13 @@
"0 0 6039 847\n",
"1 1 6039 2315\n",
"2 2 6039 589\n",
- "3 3 6039 1950\n",
- "4 4 6039 1892\n",
+ "3 3 6039 1892\n",
+ "4 4 6039 1950\n",
"... ... ... ...\n",
"1000204 446 4957 2330\n",
"1000205 447 4957 1384\n",
- "1000206 448 4957 2565\n",
- "1000207 449 4957 3195\n",
+ "1000206 448 4957 3195\n",
+ "1000207 449 4957 2565\n",
"1000208 450 4957 1855\n",
"\n",
"[1000209 rows x 3 columns]"
@@ -531,25 +531,25 @@
" 0 | \n",
" 0 | \n",
" [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... | \n",
- " [3117, 1250, 1009, 1672, 2271, 1768, 3339, 118... | \n",
+ " [3117, 1672, 1250, 1009, 2271, 1768, 3339, 118... | \n",
"
\n",
" \n",
" | 1 | \n",
" 1 | \n",
" [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... | \n",
- " [1180, 1192, 1199, 2648, 1273, 2874, 1207, 315... | \n",
+ " [1180, 1199, 1192, 2648, 1273, 2874, 1207, 117... | \n",
"
\n",
" \n",
" | 2 | \n",
" 2 | \n",
" [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... | \n",
- " [589, 2789, 1899, 3465, 1407, 1892, 1246, 1358... | \n",
+ " [589, 2789, 3465, 1899, 1892, 1407, 1246, 3602... | \n",
"
\n",
" \n",
" | 3 | \n",
" 3 | \n",
" [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... | \n",
- " [1192, 1081, 3458, 476, 3399, 257, 1180, 1178,... | \n",
+ " [1192, 1081, 476, 3399, 3458, 1178, 257, 1180,... | \n",
"
\n",
" \n",
" | 4 | \n",
@@ -567,13 +567,13 @@
" 6035 | \n",
" 6035 | \n",
" [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... | \n",
- " [1672, 1814, 3369, 2307, 2359, 2503, 2423, 278... | \n",
+ " [1672, 1814, 3369, 2307, 2359, 2614, 2503, 263... | \n",
"
\n",
" \n",
" | 6036 | \n",
" 6036 | \n",
" [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... | \n",
- " [693, 1813, 3439, 1959, 1247, 558, 847, 3079, ... | \n",
+ " [1813, 693, 1247, 1959, 3439, 3079, 558, 847, ... | \n",
"
\n",
" \n",
" | 6037 | \n",
@@ -585,13 +585,13 @@
" 6038 | \n",
" 6038 | \n",
" [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... | \n",
- " [109, 279, 1998, 1211, 918, 3064, 935, 3019, 2... | \n",
+ " [109, 279, 1998, 1211, 918, 935, 3019, 2953, 3... | \n",
"
\n",
" \n",
" | 6039 | \n",
" 6039 | \n",
" [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... | \n",
- " [847, 2315, 589, 1950, 1892, 3042, 211, 3436, ... | \n",
+ " [847, 2315, 589, 1892, 1950, 1395, 211, 3042, ... | \n",
"
\n",
" \n",
"\n",
@@ -613,17 +613,17 @@
"6039 6039 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... \n",
"\n",
" item_id \n",
- "0 [3117, 1250, 1009, 1672, 2271, 1768, 3339, 118... \n",
- "1 [1180, 1192, 1199, 2648, 1273, 2874, 1207, 315... \n",
- "2 [589, 2789, 1899, 3465, 1407, 1892, 1246, 1358... \n",
- "3 [1192, 1081, 3458, 476, 3399, 257, 1180, 1178,... \n",
+ "0 [3117, 1672, 1250, 1009, 2271, 1768, 3339, 118... \n",
+ "1 [1180, 1199, 1192, 2648, 1273, 2874, 1207, 117... \n",
+ "2 [589, 2789, 3465, 1899, 1892, 1407, 1246, 3602... \n",
+ "3 [1192, 1081, 476, 3399, 3458, 1178, 257, 1180,... \n",
"4 [2648, 907, 896, 352, 1230, 2119, 2789, 1111, ... \n",
"... ... \n",
- "6035 [1672, 1814, 3369, 2307, 2359, 2503, 2423, 278... \n",
- "6036 [693, 1813, 3439, 1959, 1247, 558, 847, 3079, ... \n",
+ "6035 [1672, 1814, 3369, 2307, 2359, 2614, 2503, 263... \n",
+ "6036 [1813, 693, 1247, 1959, 3439, 3079, 558, 847, ... \n",
"6037 [3327, 908, 1192, 2077, 1366, 352, 1063, 1132,... \n",
- "6038 [109, 279, 1998, 1211, 918, 3064, 935, 3019, 2... \n",
- "6039 [847, 2315, 589, 1950, 1892, 3042, 211, 3436, ... \n",
+ "6038 [109, 279, 1998, 1211, 918, 935, 3019, 2953, 3... \n",
+ "6039 [847, 2315, 589, 1892, 1950, 1395, 211, 3042, ... \n",
"\n",
"[6040 rows x 3 columns]"
]
@@ -725,7 +725,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
@@ -781,7 +781,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
@@ -792,7 +792,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
@@ -820,7 +820,16 @@
"cell_type": "code",
"execution_count": 14,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/tmp/ipykernel_775501/1761019238.py:5: UserWarning: The following dataset paths aren't provided: test,predict.Make sure to disable these stages in your Lightning Trainer configuration.\n",
+ " parquet_module = ParquetModule(\n"
+ ]
+ }
+ ],
"source": [
"from replay.data.nn import ParquetModule\n",
"\n",
@@ -853,7 +862,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
@@ -888,7 +897,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
@@ -915,26 +924,28 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "GPU available: False, used: False\n",
+ "GPU available: True (cuda), used: True\n",
"TPU available: False, using: 0 TPU cores\n",
- "HPU available: False, using: 0 HPUs\n",
+ "You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
"\n",
- " | Name | Type | Params | Mode \n",
- "-------------------------------------------\n",
- "0 | model | TwoTower | 340 K | train\n",
- "-------------------------------------------\n",
- "340 K Trainable params\n",
+ " | Name | Type | Params | Mode | FLOPs\n",
+ "---------------------------------------------------\n",
+ "0 | model | TwoTower | 352 K | train | 0 \n",
+ "---------------------------------------------------\n",
+ "352 K Trainable params\n",
"0 Non-trainable params\n",
- "340 K Total params\n",
- "1.364 Total estimated model params size (MB)\n",
+ "352 K Total params\n",
+ "1.409 Total estimated model params size (MB)\n",
"52 Modules in train mode\n",
- "0 Modules in eval mode\n"
+ "0 Modules in eval mode\n",
+ "0 Total Flops\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "32da5d7c8a1a4a2abd200e4c44472b61",
+ "model_id": "30be06dd73e4450f8fc37ee83bbd1746",
"version_major": 2,
"version_minor": 0
},
@@ -945,10 +956,19 @@
"metadata": {},
"output_type": "display_data"
},
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/nkulikov/new_venv/lib/python3.12/site-packages/lightning/pytorch/utilities/data.py:123: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.\n",
+ "/home/nkulikov/new_venv/lib/python3.12/site-packages/torch/nn/functional.py:6044: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.\n",
+ " warnings.warn(\n"
+ ]
+ },
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "959fabb6e3ca4ed38097cd7128944bc8",
+ "model_id": "dd8c6b547ea84fd4abcbe132b7da42cb",
"version_major": 2,
"version_minor": 0
},
@@ -962,7 +982,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "3462811aac8b4003aaa61c4cf9fcc4a0",
+ "model_id": "301d175de5f24790b75dc4943e83e687",
"version_major": 2,
"version_minor": 0
},
@@ -977,24 +997,32 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "Epoch 0, global step 189: 'recall@10' reached 0.03526 (best 0.03526), saving model to '/home/RePlay/examples/twotower/checkpoints/epoch=0-step=189.ckpt' as top 1\n"
+ "Epoch 0, global step 189: 'recall@10' reached 0.03958 (best 0.03958), saving model to '/home/nkulikov/RePlay/examples/twotower/checkpoints/epoch=0-step=189.ckpt' as top 1\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "k 1 5 10 20\n",
- "map 0.004801 0.009509 0.011596 0.013299\n",
- "ndcg 0.004801 0.011889 0.017026 0.023337\n",
- "recall 0.004801 0.019205 0.035265 0.060430\n",
+ "k 1 10 20 5\n",
+ "map 0.003643 0.011629 0.013674 0.009331\n",
+ "ndcg 0.003643 0.018055 0.025700 0.012422\n",
+ "recall 0.003643 0.039576 0.070210 0.022024\n",
"\n"
]
},
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/nkulikov/new_venv/lib/python3.12/site-packages/torch/nn/functional.py:6044: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.\n",
+ " warnings.warn(\n"
+ ]
+ },
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "f8dacd44d48b486c9d1217c937b2584d",
+ "model_id": "01468999f3674e389b8416c5e8f43a46",
"version_major": 2,
"version_minor": 0
},
@@ -1009,24 +1037,32 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "Epoch 1, global step 378: 'recall@10' reached 0.09603 (best 0.09603), saving model to '/home/RePlay/examples/twotower/checkpoints/epoch=1-step=378.ckpt' as top 1\n"
+ "Epoch 1, global step 378: 'recall@10' reached 0.10184 (best 0.10184), saving model to '/home/nkulikov/RePlay/examples/twotower/checkpoints/epoch=1-step=378.ckpt' as top 1\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "k 1 5 10 20\n",
- "map 0.012086 0.026440 0.031724 0.036210\n",
- "ndcg 0.012086 0.033518 0.046537 0.063113\n",
- "recall 0.012086 0.055298 0.096026 0.162086\n",
+ "k 1 10 20 5\n",
+ "map 0.010763 0.033238 0.037375 0.028159\n",
+ "ndcg 0.010763 0.049178 0.064472 0.036708\n",
+ "recall 0.010763 0.101838 0.162775 0.062924\n",
"\n"
]
},
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/nkulikov/new_venv/lib/python3.12/site-packages/torch/nn/functional.py:6044: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.\n",
+ " warnings.warn(\n"
+ ]
+ },
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "421acec60c4347ae98e0df5157d69ead",
+ "model_id": "148c7a181eda40d98249e05b50e4bfa1",
"version_major": 2,
"version_minor": 0
},
@@ -1041,24 +1077,32 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "Epoch 2, global step 567: 'recall@10' reached 0.13576 (best 0.13576), saving model to '/home/RePlay/examples/twotower/checkpoints/epoch=2-step=567.ckpt' as top 1\n"
+ "Epoch 2, global step 567: 'recall@10' reached 0.13280 (best 0.13280), saving model to '/home/nkulikov/RePlay/examples/twotower/checkpoints/epoch=2-step=567.ckpt' as top 1\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "k 1 5 10 20\n",
- "map 0.012252 0.034048 0.041749 0.047579\n",
- "ndcg 0.012252 0.044626 0.063492 0.084913\n",
- "recall 0.012252 0.076987 0.135762 0.220861\n",
+ "k 1 10 20 5\n",
+ "map 0.012916 0.041789 0.047532 0.034280\n",
+ "ndcg 0.012916 0.062866 0.084155 0.044635\n",
+ "recall 0.012916 0.132803 0.217751 0.076337\n",
"\n"
]
},
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/nkulikov/new_venv/lib/python3.12/site-packages/torch/nn/functional.py:6044: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.\n",
+ " warnings.warn(\n"
+ ]
+ },
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "55449d2ab210406b89fea8f15deb28f9",
+ "model_id": "fced5bd80eb342909431d34449b9ca88",
"version_major": 2,
"version_minor": 0
},
@@ -1073,24 +1117,32 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "Epoch 3, global step 756: 'recall@10' reached 0.15795 (best 0.15795), saving model to '/home/RePlay/examples/twotower/checkpoints/epoch=3-step=756.ckpt' as top 1\n"
+ "Epoch 3, global step 756: 'recall@10' reached 0.14920 (best 0.14920), saving model to '/home/nkulikov/RePlay/examples/twotower/checkpoints/epoch=3-step=756.ckpt' as top 1\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "k 1 5 10 20\n",
- "map 0.014735 0.040450 0.049221 0.055306\n",
- "ndcg 0.014735 0.052765 0.074364 0.096740\n",
- "recall 0.014735 0.090397 0.157947 0.246854\n",
+ "k 1 10 20 5\n",
+ "map 0.01391 0.046216 0.052732 0.037647\n",
+ "ndcg 0.01391 0.070051 0.093930 0.049280\n",
+ "recall 0.01391 0.149197 0.243915 0.084948\n",
"\n"
]
},
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/nkulikov/new_venv/lib/python3.12/site-packages/torch/nn/functional.py:6044: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.\n",
+ " warnings.warn(\n"
+ ]
+ },
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "5bbb1d30778b42a6ace5f22c50fdec88",
+ "model_id": "3a21e049ada84d73b3aef467158ee42d",
"version_major": 2,
"version_minor": 0
},
@@ -1105,7 +1157,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "Epoch 4, global step 945: 'recall@10' reached 0.17053 (best 0.17053), saving model to '/home/RePlay/examples/twotower/checkpoints/epoch=4-step=945.ckpt' as top 1\n",
+ "Epoch 4, global step 945: 'recall@10' reached 0.16079 (best 0.16079), saving model to '/home/nkulikov/RePlay/examples/twotower/checkpoints/epoch=4-step=945.ckpt' as top 1\n",
"`Trainer.fit` stopped: `max_epochs=5` reached.\n"
]
},
@@ -1113,10 +1165,10 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "k 1 5 10 20\n",
- "map 0.019205 0.045563 0.054963 0.061259\n",
- "ndcg 0.019205 0.058695 0.081677 0.104973\n",
- "recall 0.019205 0.099007 0.170530 0.263411\n",
+ "k 1 10 20 5\n",
+ "map 0.013247 0.047682 0.054643 0.038307\n",
+ "ndcg 0.013247 0.073782 0.099453 0.050853\n",
+ "recall 0.013247 0.160788 0.262957 0.089419\n",
"\n"
]
}
@@ -1169,7 +1221,7 @@
{
"data": {
"text/plain": [
- "'/home/RePlay/examples/twotower/checkpoints/epoch=4-step=945.ckpt'"
+ "'/home/nkulikov/RePlay/examples/twotower/checkpoints/epoch=4-step=945.ckpt'"
]
},
"execution_count": 18,
@@ -1194,10 +1246,13 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
+ "import torch\n",
+ "import replay\n",
+ "\n",
"twotower = TwoTower.from_params(\n",
" schema=tensor_schema,\n",
" embedding_dim=EMBEDDING_DIM,\n",
@@ -1212,6 +1267,11 @@
" )\n",
")\n",
"\n",
+ "torch.serialization.add_safe_globals([\n",
+ " replay.nn.lightning.optimizer.OptimizerFactory,\n",
+ " replay.nn.lightning.scheduler.LRSchedulerFactory,\n",
+ "])\n",
+ "\n",
"best_model = LightningModule.load_from_checkpoint(best_model_path, model=twotower)\n",
"best_model.eval();"
]
@@ -1227,7 +1287,16 @@
"cell_type": "code",
"execution_count": 20,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/tmp/ipykernel_775501/2227877614.py:3: UserWarning: The following dataset paths aren't provided: train,validate,test.Make sure to disable these stages in your Lightning Trainer configuration.\n",
+ " parquet_module = ParquetModule(\n"
+ ]
+ }
+ ],
"source": [
"inference_metadata = {\"predict\": create_meta(shape=MAX_SEQ_LEN)}\n",
"\n",
@@ -1258,15 +1327,17 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "GPU available: False, used: False\n",
+ "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
+ "GPU available: True (cuda), used: True\n",
"TPU available: False, using: 0 TPU cores\n",
- "HPU available: False, using: 0 HPUs\n"
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "/home/nkulikov/new_venv/lib/python3.12/site-packages/lightning/pytorch/utilities/data.py:123: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "61636a98f63042a4bfce791a7dd2d01b",
+ "model_id": "527412a8a14545399dfad9ab127e3db3",
"version_major": 2,
"version_minor": 0
},
@@ -1276,6 +1347,14 @@
},
"metadata": {},
"output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/nkulikov/new_venv/lib/python3.12/site-packages/torch/nn/functional.py:6044: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.\n",
+ " warnings.warn(\n"
+ ]
}
],
"source": [
@@ -1333,32 +1412,32 @@
" \n",
" | 0 | \n",
" 0 | \n",
- " 3383 | \n",
- " 41.500851 | \n",
+ " 773 | \n",
+ " 26.849968 | \n",
"
\n",
" \n",
" | 0 | \n",
" 0 | \n",
- " 3509 | \n",
- " 40.856842 | \n",
+ " 360 | \n",
+ " 26.444904 | \n",
"
\n",
" \n",
" | 0 | \n",
" 0 | \n",
- " 3475 | \n",
- " 40.794147 | \n",
+ " 1526 | \n",
+ " 26.433756 | \n",
"
\n",
" \n",
" | 0 | \n",
" 0 | \n",
- " 3512 | \n",
- " 40.674751 | \n",
+ " 2618 | \n",
+ " 26.282482 | \n",
"
\n",
" \n",
" | 0 | \n",
" 0 | \n",
- " 3341 | \n",
- " 40.665249 | \n",
+ " 1838 | \n",
+ " 26.202333 | \n",
"
\n",
" \n",
" | ... | \n",
@@ -1369,32 +1448,32 @@
"
\n",
" | 6037 | \n",
" 6039 | \n",
- " 3550 | \n",
- " 26.971941 | \n",
+ " 1680 | \n",
+ " 26.977921 | \n",
"
\n",
" \n",
" | 6037 | \n",
" 6039 | \n",
- " 1490 | \n",
- " 26.874283 | \n",
+ " 1375 | \n",
+ " 26.976624 | \n",
"
\n",
" \n",
" | 6037 | \n",
" 6039 | \n",
- " 2489 | \n",
- " 26.785303 | \n",
+ " 2125 | \n",
+ " 26.967751 | \n",
"
\n",
" \n",
" | 6037 | \n",
" 6039 | \n",
- " 2634 | \n",
- " 26.725142 | \n",
+ " 1439 | \n",
+ " 26.963537 | \n",
"
\n",
" \n",
" | 6037 | \n",
" 6039 | \n",
- " 2203 | \n",
- " 26.696005 | \n",
+ " 623 | \n",
+ " 26.949759 | \n",
"
\n",
" \n",
"\n",
@@ -1403,17 +1482,17 @@
],
"text/plain": [
" user_id item_id score\n",
- "0 0 3383 41.500851\n",
- "0 0 3509 40.856842\n",
- "0 0 3475 40.794147\n",
- "0 0 3512 40.674751\n",
- "0 0 3341 40.665249\n",
+ "0 0 773 26.849968\n",
+ "0 0 360 26.444904\n",
+ "0 0 1526 26.433756\n",
+ "0 0 2618 26.282482\n",
+ "0 0 1838 26.202333\n",
"... ... ... ...\n",
- "6037 6039 3550 26.971941\n",
- "6037 6039 1490 26.874283\n",
- "6037 6039 2489 26.785303\n",
- "6037 6039 2634 26.725142\n",
- "6037 6039 2203 26.696005\n",
+ "6037 6039 1680 26.977921\n",
+ "6037 6039 1375 26.976624\n",
+ "6037 6039 2125 26.967751\n",
+ "6037 6039 1439 26.963537\n",
+ "6037 6039 623 26.949759\n",
"\n",
"[120760 rows x 3 columns]"
]
@@ -1487,42 +1566,42 @@
" \n",
" | k | \n",
" 1 | \n",
- " 5 | \n",
" 10 | \n",
" 20 | \n",
+ " 5 | \n",
"
\n",
" \n",
" \n",
" \n",
" | MAP | \n",
- " 0.016727 | \n",
- " 0.041714 | \n",
- " 0.050679 | \n",
- " 0.056809 | \n",
+ " 0.016893 | \n",
+ " 0.050191 | \n",
+ " 0.056355 | \n",
+ " 0.041200 | \n",
"
\n",
" \n",
" | Precision | \n",
- " 0.016727 | \n",
- " 0.018152 | \n",
- " 0.015982 | \n",
- " 0.012421 | \n",
+ " 0.016893 | \n",
+ " 0.015916 | \n",
+ " 0.012479 | \n",
+ " 0.018085 | \n",
"
\n",
" \n",
" | Recall | \n",
- " 0.016727 | \n",
- " 0.090759 | \n",
- " 0.159821 | \n",
- " 0.248427 | \n",
+ " 0.016893 | \n",
+ " 0.159159 | \n",
+ " 0.249586 | \n",
+ " 0.090427 | \n",
"
\n",
" \n",
"\n",
""
],
"text/plain": [
- "k 1 5 10 20\n",
- "MAP 0.016727 0.041714 0.050679 0.056809\n",
- "Precision 0.016727 0.018152 0.015982 0.012421\n",
- "Recall 0.016727 0.090759 0.159821 0.248427"
+ "k 1 10 20 5\n",
+ "MAP 0.016893 0.050191 0.056355 0.041200\n",
+ "Precision 0.016893 0.015916 0.012479 0.018085\n",
+ "Recall 0.016893 0.159159 0.249586 0.090427"
]
},
"execution_count": 25,
@@ -1575,33 +1654,33 @@
" \n",
" \n",
" | 0 | \n",
- " 2011 | \n",
- " 3623 | \n",
- " 41.500851 | \n",
+ " 1 | \n",
+ " 783 | \n",
+ " 26.849968 | \n",
"
\n",
" \n",
" | 0 | \n",
- " 2011 | \n",
- " 3752 | \n",
- " 40.856842 | \n",
+ " 1 | \n",
+ " 364 | \n",
+ " 26.444904 | \n",
"
\n",
" \n",
" | 0 | \n",
- " 2011 | \n",
- " 3717 | \n",
- " 40.794147 | \n",
+ " 1 | \n",
+ " 1566 | \n",
+ " 26.433756 | \n",
"
\n",
" \n",
" | 0 | \n",
- " 2011 | \n",
- " 3755 | \n",
- " 40.674751 | \n",
+ " 1 | \n",
+ " 2687 | \n",
+ " 26.282482 | \n",
"
\n",
" \n",
" | 0 | \n",
- " 2011 | \n",
- " 3578 | \n",
- " 40.665249 | \n",
+ " 1 | \n",
+ " 1907 | \n",
+ " 26.202333 | \n",
"
\n",
" \n",
" | ... | \n",
@@ -1611,33 +1690,33 @@
"
\n",
" \n",
" | 6037 | \n",
- " 5727 | \n",
- " 3793 | \n",
- " 26.971941 | \n",
+ " 6040 | \n",
+ " 1729 | \n",
+ " 26.977921 | \n",
"
\n",
" \n",
" | 6037 | \n",
- " 5727 | \n",
- " 1623 | \n",
- " 26.874283 | \n",
+ " 6040 | \n",
+ " 1396 | \n",
+ " 26.976624 | \n",
"
\n",
" \n",
" | 6037 | \n",
- " 5727 | \n",
- " 2693 | \n",
- " 26.785303 | \n",
+ " 6040 | \n",
+ " 2194 | \n",
+ " 26.967751 | \n",
"
\n",
" \n",
" | 6037 | \n",
- " 5727 | \n",
- " 2841 | \n",
- " 26.725142 | \n",
+ " 6040 | \n",
+ " 1466 | \n",
+ " 26.963537 | \n",
"
\n",
" \n",
" | 6037 | \n",
- " 5727 | \n",
- " 2396 | \n",
- " 26.696005 | \n",
+ " 6040 | \n",
+ " 628 | \n",
+ " 26.949759 | \n",
"
\n",
" \n",
"\n",
@@ -1646,17 +1725,17 @@
],
"text/plain": [
" user_id item_id score\n",
- "0 2011 3623 41.500851\n",
- "0 2011 3752 40.856842\n",
- "0 2011 3717 40.794147\n",
- "0 2011 3755 40.674751\n",
- "0 2011 3578 40.665249\n",
+ "0 1 783 26.849968\n",
+ "0 1 364 26.444904\n",
+ "0 1 1566 26.433756\n",
+ "0 1 2687 26.282482\n",
+ "0 1 1907 26.202333\n",
"... ... ... ...\n",
- "6037 5727 3793 26.971941\n",
- "6037 5727 1623 26.874283\n",
- "6037 5727 2693 26.785303\n",
- "6037 5727 2841 26.725142\n",
- "6037 5727 2396 26.696005\n",
+ "6037 6040 1729 26.977921\n",
+ "6037 6040 1396 26.976624\n",
+ "6037 6040 2194 26.967751\n",
+ "6037 6040 1466 26.963537\n",
+ "6037 6040 628 26.949759\n",
"\n",
"[120760 rows x 3 columns]"
]
@@ -1669,11 +1748,18 @@
"source": [
"encoder.inverse_transform(pandas_res)"
]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
}
],
"metadata": {
"kernelspec": {
- "display_name": "rep",
+ "display_name": "new_venv",
"language": "python",
"name": "python3"
},
@@ -1687,7 +1773,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.14"
+ "version": "3.12.3"
},
"orig_nbformat": 4
},
diff --git a/replay/nn/lightning/module.py b/replay/nn/lightning/module.py
index ba6d237b6..ad5f837cc 100644
--- a/replay/nn/lightning/module.py
+++ b/replay/nn/lightning/module.py
@@ -40,7 +40,7 @@ def __init__(
self._lr_scheduler_factory = lr_scheduler_factory
self.candidates_to_score = None
- def forward(self, batch: dict) -> Union[TrainOutput, InferenceOutput]:
+ def forward(self, batch: dict, return_info: bool = False) -> Union[TrainOutput, InferenceOutput]:
"""
Implementation of the forward function.
@@ -57,12 +57,21 @@ def forward(self, batch: dict) -> Union[TrainOutput, InferenceOutput]:
batch["candidates_to_score"] = self.candidates_to_score
# select only args for model.forward
modified_batch = {k: v for k, v in batch.items() if k in inspect.signature(self.model.forward).parameters}
- return self.model(**modified_batch)
+ return self.model(**modified_batch, return_info=return_info)
def training_step(self, batch: dict) -> torch.Tensor:
- model_output: TrainOutput = self(batch)
- loss = model_output["loss"]
+ model_output: TrainOutput = self(batch, return_info=True)
+ loss, info = model_output["loss"], model_output.get("info", None)
lr = self.optimizers().param_groups[0]["lr"] # Get current learning rate
+ if info is not None:
+ assert isinstance(info, dict)
+ self.log_dict(
+ dictionary=info,
+ on_step=True,
+ on_epoch=True,
+ prog_bar=True,
+ sync_dist=True,
+ )
self.log("learning_rate", lr, on_step=True, on_epoch=True, prog_bar=True)
self.log(
"train_loss",
diff --git a/replay/nn/loss/__init__.py b/replay/nn/loss/__init__.py
index 2116cac83..eed615375 100644
--- a/replay/nn/loss/__init__.py
+++ b/replay/nn/loss/__init__.py
@@ -1,6 +1,7 @@
-from .base import LossProto
+from .base import LossInfo, LossOutput, LossProto
from .bce import BCE, BCESampled
from .ce import CE, CESampled, CESampledWeighted, CEWeighted
+from .composed import ComposedLoss
from .login_ce import LogInCE, LogInCESampled
from .logout_ce import LogOutCE, LogOutCEWeighted
@@ -13,10 +14,13 @@
"CESampled",
"CESampledWeighted",
"CEWeighted",
+ "ComposedLoss",
"LogInCE",
"LogInCESampled",
"LogOutCE",
"LogOutCESampled",
"LogOutCEWeighted",
+ "LossInfo",
+ "LossOutput",
"LossProto",
]
diff --git a/replay/nn/loss/base.py b/replay/nn/loss/base.py
index ef9bd88bb..996fcde9c 100644
--- a/replay/nn/loss/base.py
+++ b/replay/nn/loss/base.py
@@ -4,17 +4,23 @@
from replay.data.nn import TensorMap
+LossInfo = dict[str, torch.Tensor | float]
+LossOutput = tuple[torch.Tensor, None] | tuple[torch.Tensor, LossInfo]
+LogitsCallback = Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]
+
class LossProto(Protocol):
"""Class-protocol for working with losses inside models"""
+ loss_name: str
+
@property
def logits_callback(
self,
- ) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]: ...
+ ) -> LogitsCallback: ...
@logits_callback.setter
- def logits_callback(self, func: Optional[Callable]) -> None: ...
+ def logits_callback(self, func: LogitsCallback) -> None: ...
def forward(
self,
@@ -24,7 +30,19 @@ def forward(
negative_labels: torch.LongTensor,
padding_mask: torch.BoolTensor,
target_padding_mask: torch.BoolTensor,
- ) -> torch.Tensor: ...
+ return_info: bool = False,
+ ) -> LossOutput: ...
+
+ def __call__(
+ self,
+ model_embeddings: torch.Tensor,
+ feature_tensors: TensorMap,
+ positive_labels: torch.LongTensor,
+ negative_labels: torch.LongTensor,
+ padding_mask: torch.BoolTensor,
+ target_padding_mask: torch.BoolTensor,
+ return_info: bool = False,
+ ) -> LossOutput: ...
class SampledLossOutput(TypedDict):
@@ -39,6 +57,8 @@ class SampledLossOutput(TypedDict):
class SampledLossBase(torch.nn.Module):
"""The base class for calculating sampled losses"""
+ _logits_callback: LogitsCallback | None
+
@property
def logits_callback(
self,
diff --git a/replay/nn/loss/bce.py b/replay/nn/loss/bce.py
index 2f5901d76..62e38edc3 100644
--- a/replay/nn/loss/bce.py
+++ b/replay/nn/loss/bce.py
@@ -1,10 +1,8 @@
-from typing import Callable, Optional
-
import torch
from replay.data.nn import TensorMap
-from .base import SampledLossBase, mask_negative_logits
+from .base import LogitsCallback, LossOutput, SampledLossBase, mask_negative_logits
class BCE(torch.nn.Module):
@@ -16,19 +14,20 @@ class BCE(torch.nn.Module):
(there are several labels for each position in the sequence).
"""
- def __init__(self, **kwargs):
+ def __init__(self, loss_name: str = "BCELoss", **kwargs):
"""
To calculate the loss, ``torch.nn.BCEWithLogitsLoss`` is used with the parameter ``reduction="sum"``.
You can pass all other parameters for initializing the object via kwargs.
"""
super().__init__()
self._loss = torch.nn.BCEWithLogitsLoss(reduction="sum", **kwargs)
- self._logits_callback = None
+ self._logits_callback: LogitsCallback | None = None
+ self.loss_name: str = loss_name
@property
def logits_callback(
self,
- ) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
+ ) -> LogitsCallback:
"""
Property for calling a function for the logits computation.\n
@@ -46,7 +45,7 @@ def logits_callback(
return self._logits_callback
@logits_callback.setter
- def logits_callback(self, func: Optional[Callable]) -> None:
+ def logits_callback(self, func: LogitsCallback) -> None:
self._logits_callback = func
def forward(
@@ -57,7 +56,8 @@ def forward(
negative_labels: torch.LongTensor, # noqa: ARG002
padding_mask: torch.BoolTensor, # noqa: ARG002
target_padding_mask: torch.BoolTensor,
- ) -> torch.Tensor:
+ return_info: bool = False,
+ ) -> LossOutput:
"""
forward(model_embeddings, positive_labels, target_padding_mask)
:param model_embeddings: model output of shape ``(batch_size, sequence_length, embedding_dim)``.
@@ -92,7 +92,11 @@ def forward(
)
loss = self._loss(logits, bce_labels) / logits.size(0)
- return loss
+
+ if return_info:
+ return (loss, {self.loss_name: loss.detach()})
+ else:
+ return (loss, None)
class BCESampled(SampledLossBase):
@@ -109,7 +113,8 @@ def __init__(
log_epsilon: float = 1e-6,
clamp_border: float = 100.0,
negative_labels_ignore_index: int = -100,
- ):
+ loss_name: str = "BCESampledLoss",
+ ) -> None:
"""
:param log_epsilon: correction to avoid zero in the logarithm during loss calculating.
Default: ``1e-6``.
@@ -125,12 +130,13 @@ def __init__(
self.log_epsilon = log_epsilon
self.clamp_border = clamp_border
self.negative_labels_ignore_index = negative_labels_ignore_index
- self._logits_callback = None
+ self._logits_callback: LogitsCallback | None = None
+ self.loss_name: str = loss_name
@property
def logits_callback(
self,
- ) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
+ ) -> LogitsCallback:
"""
Property for calling a function for the logits computation.\n
@@ -148,7 +154,7 @@ def logits_callback(
return self._logits_callback
@logits_callback.setter
- def logits_callback(self, func: Optional[Callable]) -> None:
+ def logits_callback(self, func: LogitsCallback) -> None:
self._logits_callback = func
def forward(
@@ -159,7 +165,8 @@ def forward(
negative_labels: torch.LongTensor,
padding_mask: torch.BoolTensor, # noqa: ARG002
target_padding_mask: torch.BoolTensor,
- ) -> torch.Tensor:
+ return_info: bool = False,
+ ) -> LossOutput:
"""
forward(model_embeddings, positive_labels, negative_labels, target_padding_mask)
@@ -213,4 +220,7 @@ def forward(
loss = -(positive_loss + negative_loss)
loss /= positive_logits.size(0)
- return loss
+ if return_info:
+ return (loss, {self.loss_name: loss.detach()})
+ else:
+ return (loss, None)
diff --git a/replay/nn/loss/ce.py b/replay/nn/loss/ce.py
index 07f6e7b17..152afc563 100644
--- a/replay/nn/loss/ce.py
+++ b/replay/nn/loss/ce.py
@@ -1,10 +1,8 @@
-from typing import Callable, Optional
-
import torch
from replay.data.nn import TensorMap
-from .base import SampledLossBase, mask_negative_logits
+from .base import LogitsCallback, LossOutput, SampledLossBase, mask_negative_logits
class CE(torch.nn.Module):
@@ -13,19 +11,20 @@ class CE(torch.nn.Module):
Calculates loss over all items catalog.
"""
- def __init__(self, **kwargs):
+ def __init__(self, loss_name: str = "CELoss", **kwargs):
"""
To calculate the loss, ``torch.nn.CrossEntropyLoss`` is used.
You can pass all parameters for initializing the object via kwargs.
"""
super().__init__()
+ self.loss_name: str = loss_name
self._loss = torch.nn.CrossEntropyLoss(**kwargs)
- self._logits_callback = None
+ self._logits_callback: LogitsCallback | None = None
@property
def logits_callback(
self,
- ) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
+ ) -> LogitsCallback:
"""
Property for calling a function for the logits computation.\n
@@ -43,7 +42,7 @@ def logits_callback(
return self._logits_callback
@logits_callback.setter
- def logits_callback(self, func: Optional[Callable]) -> None:
+ def logits_callback(self, func: LogitsCallback) -> None:
self._logits_callback = func
def forward(
@@ -54,7 +53,8 @@ def forward(
negative_labels: torch.LongTensor, # noqa: ARG002
padding_mask: torch.BoolTensor, # noqa: ARG002
target_padding_mask: torch.BoolTensor,
- ) -> torch.Tensor:
+ return_info: bool = False,
+ ) -> LossOutput:
"""
forward(model_embeddings, positive_labels, target_padding_mask)
:param model_embeddings: model output of shape ``(batch_size, sequence_length, embedding_dim)``.
@@ -78,7 +78,11 @@ def forward(
# [batch_size, seq_len, 1] -> [batch_size * seq_len]
labels_flat: torch.LongTensor = labels.view(-1)
loss = self._loss(logits_flat, labels_flat)
- return loss
+
+ if return_info:
+ return (loss, {self.loss_name: loss.detach()})
+ else:
+ return (loss, None)
class CEWeighted(CE):
@@ -92,11 +96,14 @@ class CEWeighted(CE):
which is fed into the model.
"""
+ loss_name: str = "CEWeightedLoss"
+
def __init__(
self,
feature_name: str,
+ loss_name: str = "CEWeightedLoss",
**kwargs,
- ):
+ ) -> None:
"""
To calculate the loss, ``torch.nn.CrossEntropyLoss`` is used with the parameter ``reduction="none"``.
You can pass all other parameters for initializing the object via kwargs.
@@ -105,7 +112,8 @@ def __init__(
The tensor is expected to contain sample weights.
"""
super().__init__()
- self.feature_name = feature_name
+ self.loss_name: str = loss_name
+ self.feature_name: str = feature_name
self._loss = torch.nn.CrossEntropyLoss(reduction="none", **kwargs)
def forward(
@@ -116,7 +124,8 @@ def forward(
negative_labels: torch.LongTensor, # noqa: ARG002
padding_mask: torch.BoolTensor, # noqa: ARG002
target_padding_mask: torch.BoolTensor,
- ) -> torch.Tensor:
+ return_info: bool = False,
+ ) -> LossOutput:
"""
forward(model_embeddings, feature_tensors, positive_labels, target_padding_mask)
:param feature_tensors: a dictionary of tensors from dataloader.
@@ -140,7 +149,11 @@ def forward(
)
sample_weight = feature_tensors[self.feature_name]
loss = (loss * sample_weight).mean()
- return loss
+
+ if return_info:
+ return (loss, {self.loss_name: loss.detach()})
+ else:
+ return (loss, None)
class CESampled(SampledLossBase):
@@ -155,8 +168,9 @@ class CESampled(SampledLossBase):
def __init__(
self,
negative_labels_ignore_index: int = -100,
+ loss_name: str = "CESampledLoss",
**kwargs,
- ):
+ ) -> None:
"""
To calculate the loss, ``torch.nn.CrossEntropyLoss`` is used.
You can pass all parameters for initializing the object via kwargs.
@@ -168,14 +182,15 @@ def __init__(
Default: ``-100``.
"""
super().__init__()
+ self.loss_name: str = loss_name
self.negative_labels_ignore_index = negative_labels_ignore_index
self._loss = torch.nn.CrossEntropyLoss(**kwargs)
- self._logits_callback = None
+ self._logits_callback: LogitsCallback | None = None
@property
def logits_callback(
self,
- ) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
+ ) -> LogitsCallback:
"""
Property for calling a function for the logits computation.\n
@@ -193,7 +208,7 @@ def logits_callback(
return self._logits_callback
@logits_callback.setter
- def logits_callback(self, func: Optional[Callable]) -> None:
+ def logits_callback(self, func: LogitsCallback) -> None:
self._logits_callback = func
def forward(
@@ -204,7 +219,8 @@ def forward(
negative_labels: torch.LongTensor,
padding_mask: torch.BoolTensor, # noqa: ARG002
target_padding_mask: torch.BoolTensor,
- ) -> torch.Tensor:
+ return_info: bool = False,
+ ) -> LossOutput:
"""
forward(model_embeddings, positive_labels, negative_labels, target_padding_mask)
@@ -246,7 +262,11 @@ def forward(
target = torch.zeros(positive_logits.size(0), dtype=torch.long, device=logits.device)
# [masked_batch_size] - loss for all recommendation points
loss = self._loss(logits, target)
- return loss
+
+ if return_info:
+ return (loss, {self.loss_name: loss.detach()})
+ else:
+ return (loss, None)
class CESampledWeighted(CESampled):
@@ -267,8 +287,9 @@ def __init__(
self,
feature_name: str,
negative_labels_ignore_index: int = -100,
+ loss_name: str = "CESampledWeightedLoss",
**kwargs,
- ):
+ ) -> None:
"""
To calculate the loss, ``torch.nn.CrossEntropyLoss`` is used with the parameter ``reduction="none"``.
You can pass all other parameters for initializing the object via kwargs.
@@ -282,7 +303,9 @@ def __init__(
Default: ``-100``.
"""
super().__init__(negative_labels_ignore_index=negative_labels_ignore_index)
- self.feature_name = feature_name
+
+ self.loss_name: str = loss_name
+ self.feature_name: str = feature_name
self._loss = torch.nn.CrossEntropyLoss(reduction="none", **kwargs)
def forward(
@@ -293,7 +316,8 @@ def forward(
negative_labels: torch.LongTensor,
padding_mask: torch.BoolTensor, # noqa: ARG002
target_padding_mask: torch.BoolTensor,
- ) -> torch.Tensor:
+ return_info: bool = False,
+ ) -> LossOutput:
"""
forward(model_embeddings, feature_tensors, positive_labels, negative_labels, target_padding_mask)
:param model_embeddings: model output of shape ``(batch_size, sequence_length, embedding_dim)``.
@@ -314,4 +338,8 @@ def forward(
sample_weight = feature_tensors[self.feature_name]
sample_weight = sample_weight[target_padding_mask]
loss = (loss * sample_weight).mean()
- return loss
+
+ if return_info:
+ return (loss, {self.loss_name: loss.detach()})
+ else:
+ return (loss, None)
diff --git a/replay/nn/loss/composed.py b/replay/nn/loss/composed.py
new file mode 100644
index 000000000..c52ebffdd
--- /dev/null
+++ b/replay/nn/loss/composed.py
@@ -0,0 +1,202 @@
+import warnings
+from typing import Iterable, Self, cast
+
+import torch
+
+from replay.data.nn import TensorMap
+
+from .base import LogitsCallback, LossInfo, LossOutput, LossProto
+
+Weights = dict[str, torch.Tensor | float]
+Losses = Iterable[torch.nn.Module] | dict[str, torch.nn.Module] | torch.nn.ModuleDict
+
+
+class ComposedLoss(torch.nn.Module):
+    """
+    Weighted sum of several named losses.
+
+    Every sub-loss is called with the same arguments. Each value is multiplied by its weight
+    (``1.0`` when no weight is given for that name) and the products are summed.
+    """
+
+    def __init__(
+        self: Self,
+        losses: Losses,
+        weights: Weights | None = None,
+        loss_name: str = "ComposedLoss",
+    ) -> None:
+        """
+        :param losses: sub-losses to combine: a ``dict`` or ``ModuleDict`` that maps names to modules,
+            or an iterable of modules, each exposing a unique ``loss_name`` attribute.
+        :param weights: mapping from a loss name to its weight (a number or a one-element tensor).
+            Losses without an entry get the weight ``1.0``.
+            Default: ``None``.
+        :param loss_name: key of the combined value in the returned loss info.
+            Default: ``"ComposedLoss"``.
+        """
+        super().__init__()
+
+        self.losses: torch.nn.ModuleDict = self._handle_losses(losses)
+        self.weights: Weights = self._handle_weights(weights)
+        self._logits_callback: LogitsCallback | None = None
+        self.loss_name: str = loss_name
+
+    @staticmethod
+    def _require_module(loss: object) -> None:
+        """Raise ``TypeError`` unless ``loss`` is a ``torch.nn.Module``."""
+        if not isinstance(loss, torch.nn.Module):
+            msg = f"Unsupported type of loss. Must be `Module`. Got: {type(loss)=}."
+            raise TypeError(msg)
+
+    def _handle_losses(self: Self, losses: Losses) -> torch.nn.ModuleDict:
+        """
+        Convert ``losses`` to a non-empty ``ModuleDict``.
+
+        :raises TypeError: if ``losses`` is not iterable or holds something that is not a ``Module``.
+        :raises KeyError: if two losses of an iterable share one ``loss_name``.
+        :raises ValueError: if there are no losses.
+        """
+        if not isinstance(losses, torch.nn.ModuleDict):
+            losses_dict: dict[str, torch.nn.Module] = {}
+            if isinstance(losses, dict):
+                for loss in losses.values():
+                    self._require_module(loss)
+                losses_dict.update(losses)
+            elif isinstance(losses, Iterable):
+                # Losses given without explicit names are keyed by their own ``loss_name``.
+                for loss in losses:
+                    self._require_module(loss)
+                    name: str = cast(LossProto, loss).loss_name
+                    if name in losses_dict:
+                        msg = f"Loss names must be unique. Got {name} twice."
+                        raise KeyError(msg)
+                    losses_dict[name] = loss
+            else:
+                msg = f"Unsupported type of `losses`. Must be iterable, `dict` or `ModuleDict`. Got {type(losses)=}."
+                raise TypeError(msg)
+            losses = torch.nn.ModuleDict(losses_dict)
+
+        if len(losses) < 1:
+            msg = "Empty losses are not supported."
+            raise ValueError(msg)
+
+        return losses
+
+    def _handle_weights(self: Self, weights: Weights | None) -> Weights:
+        """
+        Validate ``weights`` and return them, or an empty ``dict`` for ``None``.
+
+        Names that match no loss only trigger a warning.
+
+        :raises TypeError: if ``weights`` is not a ``dict`` or a weight is neither a number nor a ``Tensor``.
+        :raises ValueError: if a tensor weight does not hold exactly one value.
+        """
+        if weights is None:
+            return {}
+        if not isinstance(weights, dict):
+            msg = f"Unsupported type of `weights`. Must be `dict`. Got: {type(weights)=}."
+            raise TypeError(msg)
+
+        for name, weight in weights.items():
+            if name not in self.losses:
+                msg = f"Unknown name of weight: {name}."
+                warnings.warn(msg, stacklevel=2)
+                continue
+            if isinstance(weight, torch.Tensor):
+                if torch.numel(weight) != 1:
+                    msg = f"Weight must hold exactly one value. Got: {torch.numel(weight)=}."
+                    raise ValueError(msg)
+            elif isinstance(weight, bool) or not isinstance(weight, (int, float)):
+                msg = f"Unsupported type of weight value. Must be `float` or `Tensor`. Got: {type(weight)=}."
+                raise TypeError(msg)
+
+        return weights
+
+    @property
+    def logits_callback(self: Self) -> LogitsCallback:
+        """
+        Function used for the logits computation, shared with every sub-loss.
+
+        :raises NotImplementedError: if the callback has not been set.
+        """
+        if self._logits_callback is None:
+            msg = "No `logits_callback` provided"
+            raise NotImplementedError(msg)
+        return self._logits_callback
+
+    @logits_callback.setter
+    def logits_callback(self: Self, func: LogitsCallback) -> None:
+        self._logits_callback = func
+
+        # Sub-losses compute their own logits, so each of them needs the callback as well.
+        for loss in self.losses.values():
+            cast(LossProto, loss).logits_callback = func
+
+    def _compute_raw_losses(
+        self: Self,
+        model_embeddings: torch.Tensor,
+        feature_tensors: TensorMap,
+        positive_labels: torch.LongTensor,
+        negative_labels: torch.LongTensor,
+        padding_mask: torch.BoolTensor,
+        target_padding_mask: torch.BoolTensor,
+    ) -> dict[str, torch.Tensor]:
+        """Call every sub-loss with the same inputs and collect the unweighted values by name."""
+        raw_losses: dict[str, torch.Tensor] = {}
+        for name, loss in self.losses.items():
+            value, _ = loss(
+                model_embeddings,
+                feature_tensors,
+                positive_labels,
+                negative_labels,
+                padding_mask,
+                target_padding_mask,
+            )
+            raw_losses[name] = value
+        return raw_losses
+
+    def _apply_weights(self: Self, raw_losses: dict[str, torch.Tensor]) -> torch.Tensor:
+        """Return the sum of the raw losses, each multiplied by its weight (``1.0`` by default)."""
+        # ``sum()`` collapses every product to 0-dim, whatever the shape of a one-element weight.
+        weighted = [(self.weights.get(name, 1.0) * value).sum() for name, value in raw_losses.items()]
+        return torch.stack(weighted).sum()
+
+    def _detach_dict(self: Self, raw_losses: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+        """Return a copy of ``raw_losses`` with every value detached from the graph."""
+        return {name: value.detach() for name, value in raw_losses.items()}
+
+    def forward(
+        self: Self,
+        model_embeddings: torch.Tensor,
+        feature_tensors: TensorMap,
+        positive_labels: torch.LongTensor,
+        negative_labels: torch.LongTensor,
+        padding_mask: torch.BoolTensor,
+        target_padding_mask: torch.BoolTensor,
+        return_info: bool = False,
+    ) -> LossOutput:
+        """
+        Compute the weighted sum of all sub-losses; every argument but ``return_info`` is passed to them as is.
+
+        :param return_info: if ``True``, also return the detached combined value (under ``loss_name``)
+            and the detached unweighted value of every sub-loss.
+            Default: ``False``.
+        :return: ``(loss, info)``, where ``info`` is ``None`` unless ``return_info`` is set.
+        """
+        raw_losses: dict[str, torch.Tensor] = self._compute_raw_losses(
+            model_embeddings=model_embeddings,
+            feature_tensors=feature_tensors,
+            positive_labels=positive_labels,
+            negative_labels=negative_labels,
+            padding_mask=padding_mask,
+            target_padding_mask=target_padding_mask,
+        )
+
+        loss: torch.Tensor = self._apply_weights(raw_losses)
+
+        if return_info:
+            base_info: dict[str, torch.Tensor] = self._detach_dict(raw_losses)
+            info: LossInfo = {self.loss_name: loss.detach(), **base_info}
+            return (loss, info)
+        else:
+            return (loss, None)
diff --git a/replay/nn/loss/login_ce.py b/replay/nn/loss/login_ce.py
index c6702c229..cf78312eb 100644
--- a/replay/nn/loss/login_ce.py
+++ b/replay/nn/loss/login_ce.py
@@ -1,10 +1,10 @@
-from typing import Callable, Optional, TypedDict
+from typing import TypedDict
import torch
from replay.data.nn import TensorMap
-from .base import SampledLossBase, mask_negative_logits
+from .base import LogitsCallback, LossOutput, SampledLossBase, mask_negative_logits
class LogInCESampledOutput(TypedDict):
@@ -121,7 +121,8 @@ def __init__(
log_epsilon: float = 1e-6,
clamp_border: float = 100.0,
negative_labels_ignore_index: int = -100,
- ):
+ loss_name: str = "LogInCELoss",
+ ) -> None:
"""
:param cardinality: number of unique items in vocabulary (catalog).
The specified cardinality value must not take into account the padding value.
@@ -140,12 +141,13 @@ def __init__(
self.log_epsilon = log_epsilon
self.clamp_border = clamp_border
self.negative_labels_ignore_index = negative_labels_ignore_index
- self._logits_callback = None
+ self._logits_callback: LogitsCallback | None = None
+ self.loss_name: str = loss_name
@property
def logits_callback(
self,
- ) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
+ ) -> LogitsCallback:
"""
Property for calling a function for the logits computation.\n
@@ -163,7 +165,7 @@ def logits_callback(
return self._logits_callback
@logits_callback.setter
- def logits_callback(self, func: Optional[Callable]) -> None:
+ def logits_callback(self, func: LogitsCallback) -> None:
self._logits_callback = func
def forward(
@@ -174,7 +176,8 @@ def forward(
negative_labels: torch.LongTensor, # noqa: ARG002
padding_mask: torch.BoolTensor, # noqa: ARG002
target_padding_mask: torch.BoolTensor,
- ) -> torch.Tensor:
+ return_info: bool = False,
+ ) -> LossOutput:
"""
forward(model_embeddings, positive_labels, target_padding_mask)
**Note**: At forward pass, the whole catalog of items is used as negatives.
@@ -235,7 +238,12 @@ def forward(
-self.clamp_border,
self.clamp_border,
)
- return loss.mean()
+ loss = loss.mean()
+
+ if return_info:
+ return (loss, {self.loss_name: loss.detach()})
+ else:
+ return (loss, None)
class LogInCESampled(LogInCEBase):
@@ -260,6 +268,7 @@ def __init__(
log_epsilon: float = 1e-6,
clamp_border: float = 100.0,
negative_labels_ignore_index: int = -100,
+ loss_name: str = "LogInCESampledLoss",
):
"""
:param log_epsilon: correction to avoid zero in the logarithm during loss calculating.
@@ -276,12 +285,13 @@ def __init__(
self.log_epsilon = log_epsilon
self.clamp_border = clamp_border
self.negative_labels_ignore_index = negative_labels_ignore_index
- self._logits_callback = None
+ self._logits_callback: LogitsCallback | None = None
+ self.loss_name: str = loss_name
@property
def logits_callback(
self,
- ) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
+ ) -> LogitsCallback:
"""
Property for calling a function for the logits computation.\n
@@ -299,7 +309,7 @@ def logits_callback(
return self._logits_callback
@logits_callback.setter
- def logits_callback(self, func: Optional[Callable]) -> None:
+ def logits_callback(self, func: LogitsCallback) -> None:
self._logits_callback = func
def forward(
@@ -310,7 +320,8 @@ def forward(
negative_labels: torch.LongTensor,
padding_mask: torch.BoolTensor, # noqa: ARG002
target_padding_mask: torch.BoolTensor,
- ) -> torch.Tensor:
+ return_info: bool = False,
+ ) -> LossOutput:
"""
forward(model_embeddings, positive_labels, negative_labels, target_padding_mask)
@@ -370,4 +381,9 @@ def forward(
-self.clamp_border,
self.clamp_border,
)
- return loss.mean()
+ loss = loss.mean()
+
+ if return_info:
+ return (loss, {self.loss_name: loss.detach()})
+ else:
+ return (loss, None)
diff --git a/replay/nn/loss/logout_ce.py b/replay/nn/loss/logout_ce.py
index d0d5bba32..9299af070 100644
--- a/replay/nn/loss/logout_ce.py
+++ b/replay/nn/loss/logout_ce.py
@@ -1,10 +1,8 @@
-from typing import Callable, Optional
-
import torch
from replay.data.nn import TensorMap
-from .base import mask_negative_logits
+from .base import LogitsCallback, LossOutput, mask_negative_logits
class LogOutCE(torch.nn.Module):
@@ -28,6 +26,7 @@ def __init__(
self,
cardinality: int,
negative_labels_ignore_index: int = -100,
+ loss_name: str = "LogOutCELoss",
**kwargs,
):
"""
@@ -46,12 +45,13 @@ def __init__(
self.cardinality = cardinality
self.negative_labels_ignore_index = negative_labels_ignore_index
self._loss = torch.nn.CrossEntropyLoss(**kwargs)
- self._logits_callback = None
+ self._logits_callback: LogitsCallback | None = None
+ self.loss_name: str = loss_name
@property
def logits_callback(
self,
- ) -> Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
+ ) -> LogitsCallback:
"""
Property for calling a function for the logits computation.\n
@@ -69,7 +69,7 @@ def logits_callback(
return self._logits_callback
@logits_callback.setter
- def logits_callback(self, func: Optional[Callable]) -> None:
+ def logits_callback(self, func: LogitsCallback) -> None:
self._logits_callback = func
def forward(
@@ -80,7 +80,8 @@ def forward(
negative_labels: torch.LongTensor, # noqa: ARG002
padding_mask: torch.BoolTensor, # noqa: ARG002
target_padding_mask: torch.BoolTensor,
- ) -> torch.Tensor:
+ return_info: bool = False,
+ ) -> LossOutput:
"""
forward(model_embeddings, positive_labels, target_padding_mask)
**Note**: At forward pass, the whole catalog of items is used as negatives.
@@ -144,7 +145,11 @@ def forward(
target = torch.zeros(logits.size(0), dtype=torch.long, device=positive_labels.device)
# [masked_batch_size] - loss for all recommendation points
loss = self._loss(logits, target)
- return loss
+
+ if return_info:
+ return (loss, {self.loss_name: loss.detach()})
+ else:
+ return (loss, None)
class LogOutCEWeighted(LogOutCE):
@@ -174,6 +179,7 @@ def __init__(
cardinality: int,
feature_name: str,
negative_labels_ignore_index: int = -100,
+ loss_name: str = "LogOutCEWeightedLoss",
**kwargs,
):
"""
@@ -196,6 +202,7 @@ def __init__(
)
self.feature_name = feature_name
self._loss = torch.nn.CrossEntropyLoss(reduction="none", **kwargs)
+ self.loss_name: str = loss_name
def forward(
self,
@@ -205,7 +212,8 @@ def forward(
negative_labels: torch.LongTensor, # noqa: ARG002
padding_mask: torch.BoolTensor, # noqa: ARG002
target_padding_mask: torch.BoolTensor,
- ) -> torch.Tensor:
+ return_info: bool = False,
+ ) -> LossOutput:
"""
forward(model_embeddings, feature_tensors, positive_labels, target_padding_mask)
**Note**: At forward pass, the whole catalog of items is used as negatives.
@@ -227,4 +235,8 @@ def forward(
sample_weight = feature_tensors[self.feature_name]
sample_weight = sample_weight[target_padding_mask]
loss = (loss * sample_weight).mean()
- return loss
+
+ if return_info:
+ return (loss, {self.loss_name: loss.detach()})
+ else:
+ return (loss, None)
diff --git a/replay/nn/output.py b/replay/nn/output.py
index 868a6e2cd..e378f3f0d 100644
--- a/replay/nn/output.py
+++ b/replay/nn/output.py
@@ -3,6 +3,8 @@
import torch
from typing_extensions import NotRequired
+from .loss import LossInfo
+
class TrainOutput(TypedDict):
"""
@@ -17,6 +19,7 @@ class TrainOutput(TypedDict):
"""
loss: torch.Tensor
+ info: NotRequired[LossInfo | None]
hidden_states: NotRequired[tuple[torch.Tensor, ...]]
diff --git a/replay/nn/sequential/sasrec/agg.py b/replay/nn/sequential/sasrec/agg.py
index bc2f42d5c..73e31556b 100644
--- a/replay/nn/sequential/sasrec/agg.py
+++ b/replay/nn/sequential/sasrec/agg.py
@@ -43,9 +43,9 @@ def forward(self, feature_tensors: TensorMap) -> torch.Tensor:
seqs: torch.Tensor = self.embedding_aggregator(feature_tensors)
assert seqs.dim() == 3
batch_size, seq_len, embedding_dim = seqs.size()
- assert (
- seq_len <= self.pe.num_embeddings
- ), f"Sequence length = {seq_len} is greater then positional embedding num = {self.pe.num_embeddings}"
+ assert seq_len <= self.pe.num_embeddings, (
+ f"Sequence length = {seq_len} is greater then positional embedding num = {self.pe.num_embeddings}"
+ )
seqs *= embedding_dim**0.5
seqs += self.pe.weight[:seq_len].unsqueeze(0).repeat(batch_size, 1, 1)
diff --git a/replay/nn/sequential/sasrec/model.py b/replay/nn/sequential/sasrec/model.py
index ed2539dd2..ea4ff2d9e 100644
--- a/replay/nn/sequential/sasrec/model.py
+++ b/replay/nn/sequential/sasrec/model.py
@@ -270,23 +270,26 @@ def forward_train(
positive_labels: torch.LongTensor,
negative_labels: torch.LongTensor,
target_padding_mask: torch.BoolTensor,
+ return_info: bool = False,
) -> TrainOutput:
hidden_states: torch.Tensor = self.body(feature_tensors, padding_mask)
assert hidden_states.dim() == 3
- loss: torch.Tensor = self.loss(
+ loss, info = self.loss(
model_embeddings=hidden_states,
feature_tensors=feature_tensors,
positive_labels=positive_labels,
negative_labels=negative_labels,
padding_mask=padding_mask,
target_padding_mask=target_padding_mask,
+ return_info=return_info,
)
- return {
- "loss": loss,
- "hidden_states": (hidden_states,),
- }
+ return TrainOutput(
+ loss=loss,
+ info=info,
+ hidden_states=(hidden_states,),
+ )
def forward_inference(
self,
@@ -313,6 +316,7 @@ def forward(
positive_labels: Optional[torch.LongTensor] = None,
negative_labels: Optional[torch.LongTensor] = None,
target_padding_mask: Optional[torch.BoolTensor] = None,
+ return_info: bool = False,
) -> Union[TrainOutput, InferenceOutput]:
"""
:param feature_tensors: a dictionary of tensors to generate embeddings.
@@ -358,6 +362,7 @@ def forward(
positive_labels=positive_labels,
negative_labels=negative_labels,
target_padding_mask=target_padding_mask,
+ return_info=return_info,
)
all(
diff --git a/replay/nn/sequential/twotower/model.py b/replay/nn/sequential/twotower/model.py
index 6016ae94f..eaa4874cf 100644
--- a/replay/nn/sequential/twotower/model.py
+++ b/replay/nn/sequential/twotower/model.py
@@ -542,6 +542,7 @@ def forward_train(
positive_labels: torch.LongTensor,
negative_labels: torch.LongTensor,
target_padding_mask: torch.BoolTensor,
+ return_info: bool = False,
) -> TrainOutput:
hidden_states = ()
query_hidden_states: torch.Tensor = self.body.query_tower(
@@ -559,17 +560,19 @@ def forward_train(
assert query_hidden_states.dim() == 3
hidden_states += (query_hidden_states,)
- loss: torch.Tensor = self.loss(
+ loss, info = self.loss(
model_embeddings=query_hidden_states,
feature_tensors=feature_tensors,
positive_labels=positive_labels,
negative_labels=negative_labels,
padding_mask=padding_mask,
target_padding_mask=target_padding_mask,
+ return_info=return_info,
)
return TrainOutput(
loss=loss,
+ info=info,
hidden_states=hidden_states,
)
@@ -612,6 +615,7 @@ def forward(
positive_labels: Optional[torch.LongTensor] = None,
negative_labels: Optional[torch.LongTensor] = None,
target_padding_mask: Optional[torch.BoolTensor] = None,
+ return_info: bool = False,
) -> Union[TrainOutput, InferenceOutput]:
"""
:param feature_tensors: a dictionary of tensors to generate embeddings.
@@ -657,6 +661,7 @@ def forward(
positive_labels=positive_labels,
negative_labels=negative_labels,
target_padding_mask=target_padding_mask,
+ return_info=return_info,
)
all(