67 changes: 26 additions & 41 deletions modeling/datasets.py
@@ -4,7 +4,7 @@


class EmbeddingsDataset:
def __init__(self, all_interactions_path, all_embeddings_path, train_parts=None):
def __init__(self, all_interactions_path, all_embeddings_path, parts=None):
self.all_interactions_path = all_interactions_path
self.all_embeddings_path = all_embeddings_path

@@ -21,15 +21,15 @@ def __init__(self, all_interactions_path, all_embeddings_path, train_parts=None)
logger.info(f"Loaded {len(self.all_interactions)} interactions")
logger.info(f"Loaded {len(all_embeddings)} embeddings")

if train_parts is not None:
train_interactions = self.get_interactions_by_part(train_parts[0], train_parts[1])
train_embeddings_dict = self._get_interactions_embeddings(train_interactions, all_embeddings)
if parts is not None:
interactions = self.get_interactions_by_part(parts[0], parts[1])
embeddings_dict = self._get_interactions_embeddings(interactions, all_embeddings)
else:
logger.info("TRAIN PARTS IS NONE (it is ok if it infer)")
train_embeddings_dict = self._get_interactions_embeddings(self.all_interactions, all_embeddings)
logger.info("PARTS IS NONE (loading all items)")
embeddings_dict = self._get_interactions_embeddings(self.all_interactions, all_embeddings)

self.item_ids = list(train_embeddings_dict.keys())
self.embeddings = list(train_embeddings_dict.values())
self.item_ids = list(embeddings_dict.keys())
self.embeddings = list(embeddings_dict.values())

def __getitem__(self, idx):
tensor_emb = self.embeddings[idx]
@@ -64,15 +64,13 @@ def __init__(
all_interactions_path,
all_embeddings_path,
train_parts=(0, 8),
val_parts=(8, 9),
test_parts=(9, 10),
eval_parts=(9, 10),
max_seq_len=20,
):
self.all_interactions_path = all_interactions_path
self.all_embeddings_path = all_embeddings_path
self.train_parts = train_parts
self.val_parts = val_parts
self.test_parts = test_parts
self.eval_parts = eval_parts
self.max_sequence_length = max_seq_len

self._validate_parts()
@@ -82,17 +80,11 @@ def __init__(
def _validate_parts(self):
if self.train_parts[0] > self.train_parts[1]:
raise ValueError("train_parts: start should be less than end")
if self.val_parts[0] > self.val_parts[1]:
raise ValueError("val_parts: start should be less than end")
if self.test_parts[0] > self.test_parts[1]:
raise ValueError("test_parts: should be less than end")

if self.train_parts[1] > self.val_parts[0] and self.train_parts[0] < self.val_parts[1]:
logger.warning("Train and Val intersect!")
if self.train_parts[1] > self.test_parts[0] and self.train_parts[0] < self.test_parts[1]:
if self.eval_parts[0] > self.eval_parts[1]:
raise ValueError("eval_parts: should be less than end")

if self.train_parts[1] > self.eval_parts[0] and self.train_parts[0] < self.eval_parts[1]:
logger.warning("Train and Test intersect!")
if self.val_parts[1] > self.test_parts[0] and self.val_parts[0] < self.test_parts[1]:
logger.warning("Val and Test intersect!")

def _load_data(self):
logger.info("Loading all interactions...")
@@ -124,14 +116,9 @@ def _load_data(self):
def _create_samples(self):
self.train_samples = self._create_train_samples()

self.val_samples = self._create_eval_samples(
eval_start_part=self.val_parts[0],
eval_end_part=self.val_parts[1],
)

self.test_samples = self._create_eval_samples(
eval_start_part=self.test_parts[0],
eval_end_part=self.test_parts[1],
self.eval_samples = self._create_eval_samples(
eval_start_part=self.eval_parts[0],
eval_end_part=self.eval_parts[1],
)

def _create_train_samples(self):
@@ -217,33 +204,31 @@ def __init__(
self,
all_interactions_path,
all_embeddings_path,
train_parts=(0, 17),
gap_parts=(17, 18),
val_parts=(18, 19),
test_parts=(19, 20),
train_parts=(0, 8),
gap_parts=(8, 9),
eval_parts=(9, 10),
max_seq_len=20,
):
self.gap_parts = gap_parts
super().__init__(
all_interactions_path,
all_embeddings_path,
train_parts=train_parts,
val_parts=val_parts,
test_parts=test_parts,
eval_parts=eval_parts,
max_seq_len=max_seq_len,
)

def _create_samples(self):
self.train_samples = self._create_train_samples()

self.val_samples = self._create_eval_samples(
eval_start_part=self.val_parts[0],
eval_end_part=self.val_parts[1],
eval_start_part=self.gap_parts[0],
eval_end_part=self.gap_parts[1],
)

self.test_samples = self._create_eval_samples(
eval_start_part=self.test_parts[0],
eval_end_part=self.test_parts[1],
self.eval_samples = self._create_eval_samples(
eval_start_part=self.eval_parts[0],
eval_end_part=self.eval_parts[1],
)

def _create_train_samples(self):
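Taken together, the datasets.py changes collapse the old train/val/test triple into a single range: EmbeddingsDataset takes one optional parts range, and FinetuneDataset takes eval_parts in place of val_parts/test_parts, with gap_parts doubling as the validation slice. Below is a minimal usage sketch assuming the constructor signatures shown in this diff; the parquet file names follow finetune_gap.py, but the data/ prefix is a placeholder.

```python
# Sketch only: signatures come from the diff above; paths are placeholders.
from modeling.datasets import EmbeddingsDataset, FinetuneDataset

# parts=(start, end) restricts embeddings to items seen in those interaction
# parts; parts=None loads embeddings for all interactions.
embeddings_ds = EmbeddingsDataset(
    all_interactions_path="data/all_data_interactions_with_groups.parquet",
    all_embeddings_path="data/items_metadata_remapped.parquet",
    parts=(0, 8),
)

# One eval_parts range replaces the former val_parts/test_parts pair; the gap
# range feeds val_samples, while eval_parts feeds eval_samples.
finetune_ds = FinetuneDataset(
    all_interactions_path="data/all_data_interactions_with_groups.parquet",
    all_embeddings_path="data/items_metadata_remapped.parquet",
    train_parts=(0, 8),
    gap_parts=(8, 9),
    eval_parts=(9, 10),
    max_seq_len=20,
)
```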
60 changes: 2 additions & 58 deletions modeling/training.py
@@ -1,18 +1,18 @@
import datetime
from pathlib import Path

import torch
from loguru import logger
from torch.utils.tensorboard import SummaryWriter


class TensorboardLogger:
def __init__(self, experiment_name, logdir):
self._experiment_name = experiment_name
self.timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M")
self.logdir = Path(logdir)
self.logdir.mkdir(parents=True, exist_ok=True)

log_path = self.logdir / (f"{experiment_name}_{datetime.datetime.now().strftime('%Y-%m-%dT%H:%M')}")
log_path = self.logdir / f"{self._experiment_name}_{self.timestamp}"
self.writer = SummaryWriter(log_dir=log_path)

def add_metrics(self, step, metrics):
@@ -25,59 +25,3 @@ def add_metrics(self, step, metrics):

def close(self):
self.writer.close()


class EarlyStopper:
def __init__(self, metric, patience, minimize=True, checkpoints_dir=None, experiment_name=None):
self.metric = metric
self.best_metric = None
self.minimize = minimize
self.patience = patience
self.wait = 0
self.checkpoints_dir = Path(checkpoints_dir)
self.experiment_name = experiment_name
self.best_model_file = None
self.checkpoints_dir.mkdir(parents=True, exist_ok=True)

def check(self, current_metric, model):
if self.best_metric is None:
self.best_metric = current_metric
self.best_model_file = self._save_model(model, current_metric)
return False

improved = (self.minimize and current_metric < self.best_metric) or (
not self.minimize and current_metric > self.best_metric
)

if improved:
if self.best_model_file and self.best_model_file.exists():
self.best_model_file.unlink()
logger.info(f"Removed previous best model: {self.best_model_file.name}")

self.wait = 0
self.best_metric = current_metric
self.best_model_file = self._save_model(model, current_metric)
logger.info(
f"New best value for {self.metric}: {self.best_metric:.4f} (saved to {self.best_model_file.name})"
)
return False
else:
self.wait += 1
logger.info(f"Wait is increased to {self.wait}")

if self.wait >= self.patience:
logger.info(
f"Patience for {self.metric} is reached: "
f"couldn't beat value {self.best_metric:.4f} for {self.wait} calls"
)
return True
return False

def _save_model(self, model, metric):
rounded_metric = round(metric, 4)
filepath = self.checkpoints_dir / f"{self.experiment_name}_best_{rounded_metric}.pth"
torch.save(model.state_dict(), filepath)
return filepath

def get_best_model_path(self):
return self.best_model_file
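With EarlyStopper removed, TensorboardLogger is all that remains of training.py. A short sketch of the surviving interface, assuming the methods shown above; the experiment name and metric values are illustrative only.

```python
# Sketch under the interface shown above; the name and values are made up.
from modeling.training import TensorboardLogger

tb = TensorboardLogger(experiment_name="demo-experiment", logdir="tensorboard")
for step in range(1, 4):
    tb.add_metrics(step, {"train/loss": 1.0 / step})
tb.close()

# The timestamp is now captured once in __init__, so callers can reuse it to
# name checkpoints consistently with the log directory, as finetune_gap.py does:
checkpoint_name = f"demo-experiment_{tb.timestamp}.pth"
```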
4 changes: 2 additions & 2 deletions notebooks/AlignSemanticIDs.ipynb
@@ -27,10 +27,10 @@
"CODEBOOK_SIZE = 512\n",
"NUM_CODEBOOKS = 4\n",
"\n",
"OLD_TRAIN_SPLIT = \"0-8TR_8-9V_9-10T\"\n",
"OLD_TRAIN_SPLIT = \"0-8TR_9-10TE\"\n",
"OLD_TRAIN_PART = \"0-8TR\"\n",
"\n",
"NEW_TRAIN_SPLIT = \"0-9TR_8-9V_9-10T\"\n",
"NEW_TRAIN_SPLIT = \"0-9TR_9-10TE\"\n",
"NEW_TRAIN_PART = \"0-8TR\"\n",
"\n",
"\n",
9 changes: 2 additions & 7 deletions scripts/dense-retriever/configs/config.yaml
@@ -12,14 +12,12 @@ paths:
dataset:
name: "vk-lsvd"
train_parts: [0, 8]
val_parts: [8, 9]
test_parts: [9, 10]
eval_parts: [9, 10]

finetune:
train_parts: [0, 8]
gap_parts: [8, 9]
val_parts: [9, 10]
test_parts: [9, 10]
eval_parts: [9, 10]

model:
max_seq_len: 20
@@ -37,8 +35,5 @@ training:
num_epochs: 150
train_batch_size: 256
valid_batch_size: 1024
metric: "validation/ndcg@20"
patience: 60
minimize_metric: false
lr: 1e-4
save_embeddings: true
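finetune_sasrec receives these values as a DictConfig, so the trimmed schema can be sanity-checked with OmegaConf (assumed here; the diff does not show the actual loader):

```python
# Assumed check: the script types its config as DictConfig, so OmegaConf is a
# reasonable way to load the YAML outside of the training entry point.
from omegaconf import OmegaConf

cfg = OmegaConf.load("scripts/dense-retriever/configs/config.yaml")
assert list(cfg.dataset.eval_parts) == [9, 10]  # one eval range, no val/test pair
assert "metric" not in cfg.training             # early-stopping keys are gone
```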
39 changes: 13 additions & 26 deletions scripts/dense-retriever/finetune_gap.py
@@ -15,7 +15,7 @@
sys.path.append("..")

from modeling.datasets import FinetuneDataset
from modeling.training import EarlyStopper, TensorboardLogger
from modeling.training import TensorboardLogger
from modeling.utils import collate, fix_random_seed, run_evaluation


@@ -37,17 +37,16 @@ def generate_constants(cfg: DictConfig):
def generate_constants(cfg: DictConfig):
split_name = (
f"{cfg.dataset.train_parts[0]}-{cfg.dataset.train_parts[1]}TR_"
f"{cfg.dataset.val_parts[0]}-{cfg.dataset.val_parts[1]}V_"
f"{cfg.dataset.test_parts[0]}-{cfg.dataset.test_parts[1]}T"
f"{cfg.dataset.eval_parts[0]}-{cfg.dataset.eval_parts[1]}TE"
)
old_experiment_name = f"dense-retriever_{cfg.dataset.name}_{split_name}"

results_path = Path(cfg.paths.results_dir) / split_name / "dense-retriever"
interactions_path = Path(cfg.paths.data_dir) / "all_data_interactions_with_groups.parquet"
embeddings_path = Path(cfg.paths.data_dir) / "items_metadata_remapped.parquet"

experiment_name = f"{old_experiment_name}_finetuned_on_{cfg.finetune.gap_parts[0]}-{cfg.finetune.gap_parts[1]}"
previous_model_mask = f"{old_experiment_name}_best_*.pth"
experiment_name = f"finetuned_{old_experiment_name}_on_{cfg.finetune.gap_parts[0]}-{cfg.finetune.gap_parts[1]}"
previous_model_mask = f"{old_experiment_name}_*.pth"

return {
"SPLIT_NAME": split_name,
@@ -75,14 +74,13 @@ def finetune_sasrec(cfg: DictConfig):
all_embeddings_path=consts["EMBEDDINGS_PATH"],
train_parts=cfg.finetune.train_parts,
gap_parts=cfg.finetune.gap_parts,
val_parts=cfg.finetune.val_parts,
test_parts=cfg.finetune.test_parts,
eval_parts=cfg.finetune.eval_parts,
max_seq_len=cfg.model.max_seq_len,
)

train_dataset = SASRecTrainDataset(data.train_samples)
valid_dataset = SASRecEvalDataset(data.val_samples)
eval_dataset = SASRecEvalDataset(data.test_samples)
eval_dataset = SASRecEvalDataset(data.eval_samples)

train_dataloader = DataLoader(
dataset=train_dataset,
@@ -122,7 +120,7 @@ def finetune_sasrec(cfg: DictConfig):
).to(device)

model_files = list(Path(cfg.paths.checkpoints_dir).glob(consts["PREVIOUS_MODEL_MASK"]))
assert len(model_files) == 1, f"Expected exactly one model file, found {len(model_files)}"
assert len(model_files) >= 1, f"Expected at least one model file, found {len(model_files)}"
finetune_model_path = max(model_files, key=lambda p: p.stat().st_mtime)
logger.info(f"MODEL TO FINETUNE: {finetune_model_path}")
state_dict = torch.load(finetune_model_path)
@@ -147,14 +145,6 @@

tensorboard_logger = TensorboardLogger(experiment_name=consts["EXPERIMENT_NAME"], logdir=cfg.paths.tensorboard_dir)

early_stopper = EarlyStopper(
metric=cfg.training.metric,
patience=cfg.training.patience,
minimize=cfg.training.minimize_metric,
checkpoints_dir=cfg.paths.checkpoints_dir,
experiment_name=consts["EXPERIMENT_NAME"],
)

logger.debug("Everything is ready for finetuning process!")

for epoch in range(cfg.training.num_epochs):
@@ -177,20 +167,17 @@
all_metrics = {**train_metrics, **validation_metrics, **eval_metrics}
tensorboard_logger.add_metrics((epoch + 1) * (batch_idx + 1), all_metrics)

if early_stopper.check(all_metrics[cfg.training.metric], model):
logger.info("Early stopping triggered")
break

tensorboard_logger.close()

Path(cfg.paths.checkpoints_dir).mkdir(parents=True, exist_ok=True)
last_model_path = (
Path(cfg.paths.checkpoints_dir) / f"{consts['EXPERIMENT_NAME']}_{tensorboard_logger.timestamp}.pth"
)
torch.save(model.state_dict(), last_model_path)
logger.info(f"Last model saved to: {last_model_path}")
logger.info("Finetuning completed successfully!")

best_model_file = early_stopper.get_best_model_path()
logger.info(f"Best model path is: {best_model_file}")

if cfg.training.save_embeddings and cfg.model.num_layers == 2:
state_dict = torch.load(last_model_path)
model.load_state_dict(state_dict)
torch.save(model._item_embeddings.weight.detach().cpu(), consts["EMBEDDINGS_FILE"])
logger.info(f"Embeddings saved to {consts['EMBEDDINGS_FILE']}!")
