Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
8ab3c2c
Training pipeline (#1)
Escape142 Nov 24, 2024
1f601e7
Внести правки в пайплайн (#2)
Escape142 Nov 25, 2024
be2aadf
add sce loss to bert4rec and sasrec
petrsokerin Nov 27, 2024
a87ce66
debug bert4rec
petrsokerin Nov 27, 2024
8a4a0cc
modify pipeline
petrsokerin Nov 27, 2024
40fc298
minor changes
petrsokerin Nov 27, 2024
1026546
debug seed logging
petrsokerin Nov 27, 2024
ddd674d
add configs
petrsokerin Nov 27, 2024
089062c
debug cold users/items
petrsokerin Nov 27, 2024
e2cf8ef
testing benchmark
petrsokerin Nov 27, 2024
08a3211
add notebook with different metrics
petrsokerin Nov 27, 2024
7011c28
update notebooks
petrsokerin Nov 27, 2024
abbf619
update test notebooks
petrsokerin Nov 28, 2024
9a02394
debug
petrsokerin Nov 29, 2024
e50db28
add params desciption
petrsokerin Nov 29, 2024
cd77a25
update configs
petrsokerin Nov 29, 2024
89a4aad
Sce dev (#3)
petrsokerin Nov 29, 2024
7c0bb0c
decrease_memory_consumption
Escape142 Dec 1, 2024
2dc715c
debug hotfix
petrsokerin Dec 2, 2024
d67fc08
Merge branch 'main' of https://github.com/On-Point-RND/RePlay-Acceler…
petrsokerin Dec 2, 2024
d3edd4e
decrease_memory_consumption
Escape142 Dec 2, 2024
a1480dc
Merge branch 'main' of https://github.com/On-Point-RND/RePlay-Acceler…
petrsokerin Dec 2, 2024
9895e19
Sce dev (#5)
petrsokerin Dec 2, 2024
c97f52e
add custom model name in logs and saving training time
petrsokerin Dec 2, 2024
a572878
Merge branch 'main' of https://github.com/On-Point-RND/RePlay-Acceler…
petrsokerin Dec 2, 2024
606aa4e
fix tensor schema
Escape142 Dec 2, 2024
f00e986
add simple lightning profiler
petrsokerin Dec 2, 2024
34c8e14
Merge branch 'sce_dev' of https://github.com/On-Point-RND/RePlay-Acce…
petrsokerin Dec 2, 2024
0a136f1
minor fix
petrsokerin Dec 2, 2024
5c4736f
pull changes with user_item_id
petrsokerin Dec 2, 2024
2fdb39e
fix model_save_name
petrsokerin Dec 2, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,12 @@ airflow.yaml

# temporary
examples/tests
outputs

### d3rlpy logs
d3rlpy_logs/
d3rlpy_logs/

# datasets
replay_benchmarks/data
# logs and checkpoints
replay_benchmarks/artifacts
17 changes: 17 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Build image for RePlay-Accelerated: Python 3.11 with system build tools,
# a Java 11 runtime and the project installed through Poetry.
FROM python:3.11-slim AS builder

# System packages. The apt package lists are removed in the same layer so
# they do not end up in the final image (no later step uses apt).
RUN apt-get update \
    && apt-get install -y --no-install-recommends apt-utils \
    && apt-get install libgomp1 build-essential pandoc -y \
    && apt-get install git -y --no-install-recommends \
    && rm -rf /var/lib/apt/lists/*

# Java runtime is copied from the OpenJDK image instead of being installed via apt.
COPY --from=openjdk:11-jre-slim /usr/local/openjdk-11 /usr/local/openjdk-11
# Use the key=value form of ENV; the whitespace-separated form is legacy syntax.
ENV JAVA_HOME=/usr/local/openjdk-11
RUN update-alternatives --install /usr/bin/java java /usr/local/openjdk-11/bin/java 1

WORKDIR /root

# Poetry installs into the system interpreter (no virtualenv inside the container).
RUN pip install --no-cache-dir --upgrade pip wheel poetry==1.5.1 poetry-dynamic-versioning \
    && python -m poetry config virtualenvs.create false
COPY . RePlay-Accelerated/
RUN cd RePlay-Accelerated && ./poetry_wrapper.sh install --all-extras
CMD ["/bin/bash"]
38 changes: 38 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""Main module"""

import os
import logging
import warnings
import yaml

from replay_benchmarks.utils.conf import load_config, seed_everything
from replay_benchmarks import TrainRunner, InferRunner

# Configure the root logger once at import time: INFO level, with the
# timestamp, level name and message in every record.
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
)
# Suppress every warning emitted through the `warnings` module for the whole process.
warnings.filterwarnings("ignore")


def main() -> None:
    """
    Entry point of the benchmark.

    Loads the base configuration from ``./replay_benchmarks/configs/config.yaml``,
    logs it, fixes the random seed taken from ``config["env"]["SEED"]`` and then
    runs either the training or the inference runner, depending on
    ``config["mode"]["name"]``.

    :raises ValueError: if the configured mode name is neither ``train`` nor ``infer``.
    """
    config_dir = "./replay_benchmarks/configs"
    config = load_config(os.path.join(config_dir, "config.yaml"), config_dir)
    logging.info("Configuration:\n%s", yaml.dump(config))

    seed = config["env"]["SEED"]
    seed_everything(seed)
    logging.info(f"Fixing seed: {seed}")

    mode_name = config["mode"]["name"]
    # Guard clause: reject unknown modes before any runner is constructed.
    if mode_name not in ("train", "infer"):
        raise ValueError(f"Unsupported mode: {config['mode']}")

    runner = TrainRunner(config) if mode_name == "train" else InferRunner(config)
    runner.run()


# Run the benchmark only when this module is executed as a script, not on import.
if __name__ == "__main__":
main()
78 changes: 77 additions & 1 deletion replay/models/nn/sequential/bert4rec/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ def __init__(
loss_sample_count: Optional[int] = None,
negative_sampling_strategy: str = "global_uniform",
negatives_sharing: bool = False,
n_buckets: int = 100,
bucket_size_x: int = 100,
bucket_size_y: int = 100,
mix_x: bool = False,
optimizer_factory: OptimizerFactory = FatOptimizerFactory(),
lr_scheduler_factory: Optional[LRSchedulerFactory] = None,
):
Expand Down Expand Up @@ -64,6 +68,14 @@ def __init__(
Default: ``global_uniform``.
:param negatives_sharing: Apply negative sharing in calculating sampled logits.
Default: ``False``.
:param n_buckets: Number of buckets for SCE loss.
Default: ``100``
:param bucket_size_x: Size of x buckets for SCE loss.
Default: ``100``
:param bucket_size_y: Size of y buckets for SCE loss.
Default: ``100``
:param mix_x: Mix states embeddings with random matrix for SCE loss.
Default: ``False``
:param optimizer_factory: Optimizer factory.
Default: ``FatOptimizerFactory``.
:param lr_scheduler_factory: Learning rate schedule factory.
Expand All @@ -90,6 +102,10 @@ def __init__(
self._lr_scheduler_factory = lr_scheduler_factory
self._loss = self._create_loss()
self._schema = tensor_schema
self._n_buckets = n_buckets
self._bucket_size_x = bucket_size_x
self._bucket_size_y = bucket_size_y
self._mix_x = mix_x
assert negative_sampling_strategy in {"global_uniform", "inbatch"}

item_count = tensor_schema.item_id_features.item().cardinality
Expand Down Expand Up @@ -207,6 +223,8 @@ def _compute_loss(self, batch: Bert4RecTrainingBatch) -> torch.Tensor:
loss_func = self._compute_loss_bce if self._loss_sample_count is None else self._compute_loss_bce_sampled
elif self._loss_type == "CE":
loss_func = self._compute_loss_ce if self._loss_sample_count is None else self._compute_loss_ce_sampled
elif self._loss_type == "SCE":
loss_func = self._compute_loss_scalable_ce
else:
msg = f"Not supported loss type: {self._loss_type}"
raise ValueError(msg)
Expand Down Expand Up @@ -325,6 +343,64 @@ def _compute_loss_ce_sampled(
labels_flat = torch.zeros(positive_logits.size(0), dtype=torch.long, device=padding_mask.device)
loss = self._loss(logits, labels_flat)
return loss

def _compute_loss_scalable_ce(
    self,
    feature_tensors: TensorMap,
    positive_labels: torch.LongTensor,
    padding_mask: torch.BoolTensor,
    tokens_mask: torch.BoolTensor,
) -> torch.Tensor:
    """
    Compute the SCE (scalable cross-entropy) loss.

    Instead of scoring every position against the whole item catalogue, random
    projections ("buckets") are drawn. For each bucket the ``bucket_size_x``
    positions and ``bucket_size_y`` item embeddings with the largest projection
    are selected, and cross-entropy is computed only between the selected
    positions and the selected items (plus each position's own positive item).
    A position selected by several buckets keeps its largest loss.

    :param feature_tensors: Model input features; the item id feature is read
        from it to find the id stored at padded positions.
    :param positive_labels: Target item ids, flattened to one label per position.
    :param padding_mask: Mask that is treated as ``False`` at padded positions.
    :param tokens_mask: Tokens mask passed to the model; the loss is averaged
        over positions where ``padding_mask`` is ``True`` and ``tokens_mask``
        is ``False``.
    :returns: Scalar loss tensor.
    """
    # Positions that contribute to the final loss
    # (equivalent to ~((~padding_mask) + tokens_mask) for boolean masks).
    masked_tokens = padding_mask & ~tokens_mask

    flat_padding_mask = padding_mask.view(-1)
    pad_item_ids = feature_tensors[self._schema.item_id_feature_name].view(-1)[~flat_padding_mask]
    # A batch without any padded position has no pad id to read; indexing [0]
    # on the empty selection would raise IndexError, so the pad-item masking
    # below is skipped in that case.
    pad_token = pad_item_ids[0] if pad_item_ids.numel() > 0 else None

    emb = self._model.forward_step(feature_tensors, padding_mask, tokens_mask)
    hd = emb.shape[-1]  # hidden size as a plain int
    n_buckets = self._n_buckets

    x = emb.view(-1, hd)
    y = positive_labels.view(-1)
    w = self.get_all_embeddings()["item_embedding"]

    # torch.topk fails when k exceeds the size of the dimension, so the bucket
    # sizes are clamped to the number of available positions / items.
    bucket_size_x = min(self._bucket_size_x, x.shape[0])
    bucket_size_y = min(self._bucket_size_y, w.shape[0])

    correct_class_logits_ = (x * torch.index_select(w, dim=0, index=y)).sum(dim=1)  # (bs,)

    scale = hd ** -0.25  # same value as 1 / sqrt(sqrt(hd))
    # Bucket construction and top-k selection only produce indices, so no
    # gradient is tracked for them.
    with torch.no_grad():
        if self._mix_x:
            omega = scale * torch.randn(x.shape[0], n_buckets, device=x.device)
            buckets = omega.T @ x  # (n_b, hd)
            del omega
        else:
            buckets = scale * torch.randn(n_buckets, hd, device=x.device)  # (n_b, hd)

        x_bucket = buckets @ x.T  # (n_b, hd) x (hd, b) -> (n_b, b)
        x_bucket[:, ~flat_padding_mask] = float("-inf")  # push padded positions to the end of the ranking
        _, top_x_bucket = torch.topk(x_bucket, dim=1, k=bucket_size_x)  # (n_b, bs_x)
        del x_bucket

        y_bucket = buckets @ w.T  # (n_b, hd) x (hd, n_cl) -> (n_b, n_cl)
        if pad_token is not None:
            y_bucket[:, pad_token] = float("-inf")  # keep the pad item out of the negatives
        _, top_y_bucket = torch.topk(y_bucket, dim=1, k=bucket_size_y)  # (n_b, bs_y)
        del y_bucket

    flat_top_x = top_x_bucket.view(-1)

    x_bucket = torch.gather(x, 0, flat_top_x.view(-1, 1).expand(-1, hd)).view(
        n_buckets, bucket_size_x, hd
    )  # (n_b, bs_x, hd)
    y_bucket = torch.gather(w, 0, top_y_bucket.view(-1, 1).expand(-1, hd)).view(
        n_buckets, bucket_size_y, hd
    )  # (n_b, bs_y, hd)

    wrong_class_logits = x_bucket @ y_bucket.transpose(-1, -2)  # (n_b, bs_x, bs_y)
    bucket_labels = torch.index_select(y, dim=0, index=flat_top_x).view(n_buckets, bucket_size_x)
    # A selected item that is the position's own positive must not act as a negative.
    mask = bucket_labels[:, :, None] == top_y_bucket[:, None, :]  # (n_b, bs_x, bs_y)
    wrong_class_logits = wrong_class_logits.masked_fill(mask, float("-inf"))  # (n_b, bs_x, bs_y)
    correct_class_logits = torch.index_select(correct_class_logits_, dim=0, index=flat_top_x).view(
        n_buckets, bucket_size_x
    )[:, :, None]  # (n_b, bs_x, 1)
    logits = torch.cat((wrong_class_logits, correct_class_logits), dim=2)  # (n_b, bs_x, bs_y + 1)

    flat_logits = logits.view(-1, logits.shape[-1])
    # The positive logit is always the last column.
    targets = torch.full(
        (flat_logits.shape[0],),
        logits.shape[-1] - 1,
        dtype=torch.int64,
        device=logits.device,
    )
    loss_ = torch.nn.functional.cross_entropy(flat_logits, targets, reduction="none")  # (n_b * bs_x,)

    # scatter_reduce_ requires the destination and source dtypes to match, so
    # the accumulator takes the dtype of the per-sample losses.
    loss = torch.zeros(x.shape[0], device=x.device, dtype=loss_.dtype)
    # Keep, for every position, the largest loss over all buckets that selected
    # it; positions never selected stay at 0 and are dropped below.
    loss.scatter_reduce_(0, flat_top_x, loss_, reduce="amax", include_self=False)
    loss = loss[(loss != 0) & masked_tokens.view(-1)]
    return torch.mean(loss)

def _get_sampled_logits(
self,
Expand Down Expand Up @@ -412,7 +488,7 @@ def _create_loss(self) -> Union[torch.nn.BCEWithLogitsLoss, torch.nn.CrossEntrop
if self._loss_type == "BCE":
return torch.nn.BCEWithLogitsLoss(reduction="sum")

if self._loss_type == "CE":
if self._loss_type == "CE" or self._loss_type == "SCE":
return torch.nn.CrossEntropyLoss()

msg = "Not supported loss_type"
Expand Down
75 changes: 74 additions & 1 deletion replay/models/nn/sequential/sasrec/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ def __init__(
loss_sample_count: Optional[int] = None,
negative_sampling_strategy: str = "global_uniform",
negatives_sharing: bool = False,
n_buckets: int = 100,
bucket_size_x: int = 100,
bucket_size_y: int = 100,
mix_x: bool = False,
optimizer_factory: OptimizerFactory = FatOptimizerFactory(),
lr_scheduler_factory: Optional[LRSchedulerFactory] = None,
):
Expand Down Expand Up @@ -62,6 +66,14 @@ def __init__(
Default: ``global_uniform``.
:param negatives_sharing: Apply negative sharing in calculating sampled logits.
Default: ``False``.
:param n_buckets: Number of buckets for SCE loss.
Default: ``100``
:param bucket_size_x: Size of x buckets for SCE loss.
Default: ``100``
:param bucket_size_y: Size of y buckets for SCE loss.
Default: ``100``
:param mix_x: Mix states embeddings with random matrix for SCE loss.
Default: ``False``
:param optimizer_factory: Optimizer factory.
Default: ``FatOptimizerFactory``.
:param lr_scheduler_factory: Learning rate schedule factory.
Expand All @@ -87,6 +99,10 @@ def __init__(
self._lr_scheduler_factory = lr_scheduler_factory
self._loss = self._create_loss()
self._schema = tensor_schema
self._n_buckets = n_buckets
self._bucket_size_x = bucket_size_x
self._bucket_size_y = bucket_size_y
self._mix_x = mix_x
assert negative_sampling_strategy in {"global_uniform", "inbatch"}

item_count = tensor_schema.item_id_features.item().cardinality
Expand Down Expand Up @@ -190,6 +206,8 @@ def _compute_loss(self, batch: SasRecTrainingBatch) -> torch.Tensor:
loss_func = self._compute_loss_bce if self._loss_sample_count is None else self._compute_loss_bce_sampled
elif self._loss_type == "CE":
loss_func = self._compute_loss_ce if self._loss_sample_count is None else self._compute_loss_ce_sampled
elif self._loss_type == "SCE":
loss_func = self._compute_loss_scalable_ce
else:
msg = f"Not supported loss type: {self._loss_type}"
raise ValueError(msg)
Expand Down Expand Up @@ -306,6 +324,61 @@ def _compute_loss_ce_sampled(
labels_flat = torch.zeros(positive_logits.size(0), dtype=torch.long, device=padding_mask.device)
loss = self._loss(logits, labels_flat)
return loss

def _compute_loss_scalable_ce(
    self,
    feature_tensors: TensorMap,
    positive_labels: torch.LongTensor,
    padding_mask: torch.BoolTensor,
    target_padding_mask: torch.BoolTensor,
) -> torch.Tensor:
    """
    Compute the SCE (scalable cross-entropy) loss.

    Instead of scoring every position against the whole item catalogue, random
    projections ("buckets") are drawn. For each bucket the ``bucket_size_x``
    positions and ``bucket_size_y`` item embeddings with the largest projection
    are selected, and cross-entropy is computed only between the selected
    positions and the selected items (plus each position's own positive item).
    A position selected by several buckets keeps its largest loss.

    :param feature_tensors: Model input features; the item id feature is read
        from it to find the id stored at padded positions.
    :param positive_labels: Target item ids; only entries selected by
        ``target_padding_mask`` are used.
    :param padding_mask: Mask that is treated as ``False`` at padded input positions.
    :param target_padding_mask: Boolean mask selecting the positions (model
        outputs and labels) that take part in the loss.
    :returns: Scalar loss tensor.
    """
    pad_item_ids = feature_tensors[self._schema.item_id_feature_name].view(-1)[~padding_mask.view(-1)]
    # A batch without any padded position has no pad id to read; indexing [0]
    # on the empty selection would raise IndexError, so the pad-item masking
    # below is skipped in that case.
    pad_token = pad_item_ids[0] if pad_item_ids.numel() > 0 else None

    emb = self._model.forward_step(feature_tensors, padding_mask)[target_padding_mask]
    hd = emb.shape[-1]  # hidden size as a plain int
    n_buckets = self._n_buckets

    x = emb.view(-1, hd)
    y = positive_labels[target_padding_mask].view(-1)
    w = self.get_all_embeddings()["item_embedding"]

    # torch.topk fails when k exceeds the size of the dimension (e.g. a small
    # last batch), so the bucket sizes are clamped to what is available.
    bucket_size_x = min(self._bucket_size_x, x.shape[0])
    bucket_size_y = min(self._bucket_size_y, w.shape[0])

    correct_class_logits_ = (x * torch.index_select(w, dim=0, index=y)).sum(dim=1)  # (bs,)

    scale = hd ** -0.25  # same value as 1 / sqrt(sqrt(hd))
    # Among the kept target positions, those whose input position is padded.
    padded_inputs = ~padding_mask[target_padding_mask].view(-1)

    # Bucket construction and top-k selection only produce indices, so no
    # gradient is tracked for them.
    with torch.no_grad():
        if self._mix_x:
            omega = scale * torch.randn(x.shape[0], n_buckets, device=x.device)
            buckets = omega.T @ x  # (n_b, hd)
            del omega
        else:
            buckets = scale * torch.randn(n_buckets, hd, device=x.device)  # (n_b, hd)

        x_bucket = buckets @ x.T  # (n_b, hd) x (hd, b) -> (n_b, b)
        x_bucket[:, padded_inputs] = float("-inf")  # push padded positions to the end of the ranking
        _, top_x_bucket = torch.topk(x_bucket, dim=1, k=bucket_size_x)  # (n_b, bs_x)
        del x_bucket

        y_bucket = buckets @ w.T  # (n_b, hd) x (hd, n_cl) -> (n_b, n_cl)
        if pad_token is not None:
            y_bucket[:, pad_token] = float("-inf")  # keep the pad item out of the negatives
        _, top_y_bucket = torch.topk(y_bucket, dim=1, k=bucket_size_y)  # (n_b, bs_y)
        del y_bucket

    flat_top_x = top_x_bucket.view(-1)

    x_bucket = torch.gather(x, 0, flat_top_x.view(-1, 1).expand(-1, hd)).view(
        n_buckets, bucket_size_x, hd
    )  # (n_b, bs_x, hd)
    y_bucket = torch.gather(w, 0, top_y_bucket.view(-1, 1).expand(-1, hd)).view(
        n_buckets, bucket_size_y, hd
    )  # (n_b, bs_y, hd)

    wrong_class_logits = x_bucket @ y_bucket.transpose(-1, -2)  # (n_b, bs_x, bs_y)
    bucket_labels = torch.index_select(y, dim=0, index=flat_top_x).view(n_buckets, bucket_size_x)
    # A selected item that is the position's own positive must not act as a negative.
    mask = bucket_labels[:, :, None] == top_y_bucket[:, None, :]  # (n_b, bs_x, bs_y)
    wrong_class_logits = wrong_class_logits.masked_fill(mask, float("-inf"))  # (n_b, bs_x, bs_y)
    correct_class_logits = torch.index_select(correct_class_logits_, dim=0, index=flat_top_x).view(
        n_buckets, bucket_size_x
    )[:, :, None]  # (n_b, bs_x, 1)
    logits = torch.cat((wrong_class_logits, correct_class_logits), dim=2)  # (n_b, bs_x, bs_y + 1)

    flat_logits = logits.view(-1, logits.shape[-1])
    # The positive logit is always the last column.
    targets = torch.full(
        (flat_logits.shape[0],),
        logits.shape[-1] - 1,
        dtype=torch.int64,
        device=logits.device,
    )
    loss_ = torch.nn.functional.cross_entropy(flat_logits, targets, reduction="none")  # (n_b * bs_x,)

    # scatter_reduce_ requires the destination and source dtypes to match, so
    # the accumulator takes the dtype of the per-sample losses.
    loss = torch.zeros(x.shape[0], device=x.device, dtype=loss_.dtype)
    # Keep, for every position, the largest loss over all buckets that selected
    # it; positions never selected stay at 0 and are dropped below.
    loss.scatter_reduce_(0, flat_top_x, loss_, reduce="amax", include_self=False)
    loss = loss[loss != 0]
    return torch.mean(loss)

def _get_sampled_logits(
self,
Expand Down Expand Up @@ -391,7 +464,7 @@ def _create_loss(self) -> Union[torch.nn.BCEWithLogitsLoss, torch.nn.CrossEntrop
if self._loss_type == "BCE":
return torch.nn.BCEWithLogitsLoss(reduction="sum")

if self._loss_type == "CE":
if self._loss_type == "CE" or self._loss_type == "SCE":
return torch.nn.CrossEntropyLoss()

msg = "Not supported loss_type"
Expand Down
Loading