diff --git a/modeling/datasets.py b/modeling/datasets.py index 6621efc..487077e 100644 --- a/modeling/datasets.py +++ b/modeling/datasets.py @@ -4,7 +4,7 @@ class EmbeddingsDataset: - def __init__(self, all_interactions_path, all_embeddings_path, train_parts=None): + def __init__(self, all_interactions_path, all_embeddings_path, parts=None): self.all_interactions_path = all_interactions_path self.all_embeddings_path = all_embeddings_path @@ -21,15 +21,15 @@ def __init__(self, all_interactions_path, all_embeddings_path, train_parts=None) logger.info(f"Loaded {len(self.all_interactions)} interactions") logger.info(f"Loaded {len(all_embeddings)} embeddings") - if train_parts is not None: - train_interactions = self.get_interactions_by_part(train_parts[0], train_parts[1]) - train_embeddings_dict = self._get_interactions_embeddings(train_interactions, all_embeddings) + if parts is not None: + interactions = self.get_interactions_by_part(parts[0], parts[1]) + embeddings_dict = self._get_interactions_embeddings(interactions, all_embeddings) else: - logger.info("TRAIN PARTS IS NONE (it is ok if it infer)") - train_embeddings_dict = self._get_interactions_embeddings(self.all_interactions, all_embeddings) + logger.info("PARTS IS NONE (loading all items)") + embeddings_dict = self._get_interactions_embeddings(self.all_interactions, all_embeddings) - self.item_ids = list(train_embeddings_dict.keys()) - self.embeddings = list(train_embeddings_dict.values()) + self.item_ids = list(embeddings_dict.keys()) + self.embeddings = list(embeddings_dict.values()) def __getitem__(self, idx): tensor_emb = self.embeddings[idx] @@ -64,15 +64,13 @@ def __init__( all_interactions_path, all_embeddings_path, train_parts=(0, 8), - val_parts=(8, 9), - test_parts=(9, 10), + eval_parts=(9, 10), max_seq_len=20, ): self.all_interactions_path = all_interactions_path self.all_embeddings_path = all_embeddings_path self.train_parts = train_parts - self.val_parts = val_parts - self.test_parts = test_parts + self.eval_parts = eval_parts self.max_sequence_length = max_seq_len self._validate_parts() @@ -82,17 +80,11 @@ def __init__( def _validate_parts(self): if self.train_parts[0] > self.train_parts[1]: raise ValueError("train_parts: start should be less than end") - if self.val_parts[0] > self.val_parts[1]: - raise ValueError("val_parts: start should be less than end") - if self.test_parts[0] > self.test_parts[1]: - raise ValueError("test_parts: should be less than end") - - if self.train_parts[1] > self.val_parts[0] and self.train_parts[0] < self.val_parts[1]: - logger.warning("Train and Val intersect!") - if self.train_parts[1] > self.test_parts[0] and self.train_parts[0] < self.test_parts[1]: + if self.eval_parts[0] > self.eval_parts[1]: + raise ValueError("eval_parts: should be less than end") + + if self.train_parts[1] > self.eval_parts[0] and self.train_parts[0] < self.eval_parts[1]: logger.warning("Train and Test intersect!") - if self.val_parts[1] > self.test_parts[0] and self.val_parts[0] < self.test_parts[1]: - logger.warning("Val and Test intersect!") def _load_data(self): logger.info("Loading all interactions...") @@ -124,14 +116,9 @@ def _load_data(self): def _create_samples(self): self.train_samples = self._create_train_samples() - self.val_samples = self._create_eval_samples( - eval_start_part=self.val_parts[0], - eval_end_part=self.val_parts[1], - ) - - self.test_samples = self._create_eval_samples( - eval_start_part=self.test_parts[0], - eval_end_part=self.test_parts[1], + self.eval_samples = self._create_eval_samples( + eval_start_part=self.eval_parts[0], + eval_end_part=self.eval_parts[1], ) def _create_train_samples(self): @@ -217,10 +204,9 @@ def __init__( self, all_interactions_path, all_embeddings_path, - train_parts=(0, 17), - gap_parts=(17, 18), - val_parts=(18, 19), - test_parts=(19, 20), + train_parts=(0, 8), + gap_parts=(8, 9), + eval_parts=(9, 10), max_seq_len=20, ): self.gap_parts = gap_parts @@ -228,8 +214,7 @@ def __init__( all_interactions_path, all_embeddings_path, train_parts=train_parts, - val_parts=val_parts, - test_parts=test_parts, + eval_parts=eval_parts, max_seq_len=max_seq_len, ) @@ -237,13 +222,13 @@ def _create_samples(self): self.train_samples = self._create_train_samples() self.val_samples = self._create_eval_samples( - eval_start_part=self.val_parts[0], - eval_end_part=self.val_parts[1], + eval_start_part=self.gap_parts[0], + eval_end_part=self.gap_parts[1], ) - self.test_samples = self._create_eval_samples( - eval_start_part=self.test_parts[0], - eval_end_part=self.test_parts[1], + self.eval_samples = self._create_eval_samples( + eval_start_part=self.eval_parts[0], + eval_end_part=self.eval_parts[1], ) def _create_train_samples(self): diff --git a/modeling/training.py b/modeling/training.py index ed3bdc4..bbb2095 100644 --- a/modeling/training.py +++ b/modeling/training.py @@ -1,7 +1,6 @@ import datetime from pathlib import Path -import torch from loguru import logger from torch.utils.tensorboard import SummaryWriter @@ -9,10 +8,11 @@ class TensorboardLogger: def __init__(self, experiment_name, logdir): self._experiment_name = experiment_name + self.timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M") self.logdir = Path(logdir) self.logdir.mkdir(parents=True, exist_ok=True) - log_path = self.logdir / (f"{experiment_name}_{datetime.datetime.now().strftime('%Y-%m-%dT%H:%M')}") + log_path = self.logdir / f"{self._experiment_name}_{self.timestamp}" self.writer = SummaryWriter(log_dir=log_path) def add_metrics(self, step, metrics): @@ -25,59 +25,3 @@ def add_metrics(self, step, metrics): def close(self): self.writer.close() - - -class EarlyStopper: - def __init__(self, metric, patience, minimize=True, checkpoints_dir=None, experiment_name=None): - self.metric = metric - self.best_metric = None - self.minimize = minimize - self.patience = patience - self.wait = 0 - self.checkpoints_dir = Path(checkpoints_dir) - self.experiment_name = experiment_name - self.best_model_file = None - self.checkpoints_dir.mkdir(parents=True, exist_ok=True) - - def check(self, current_metric, model): - if self.best_metric is None: - self.best_metric = current_metric - self.best_model_file = self._save_model(model, current_metric) - return False - - improved = (self.minimize and current_metric < self.best_metric) or ( - not self.minimize and current_metric > self.best_metric - ) - - if improved: - if self.best_model_file and self.best_model_file.exists(): - self.best_model_file.unlink() - logger.info(f"Removed previous best model: {self.best_model_file.name}") - - self.wait = 0 - self.best_metric = current_metric - self.best_model_file = self._save_model(model, current_metric) - logger.info( - f"New best value for {self.metric}: {self.best_metric:.4f} (saved to {self.best_model_file.name})" - ) - return False - else: - self.wait += 1 - logger.info(f"Wait is increased to {self.wait}") - - if self.wait >= self.patience: - logger.info( - f"Patience for {self.metric} is reached: " - f"couldn't beat value {self.best_metric:.4f} for {self.wait} calls" - ) - return True - return False - - def _save_model(self, model, metric): - rounded_metric = round(metric, 4) - filepath = self.checkpoints_dir / f"{self.experiment_name}_best_{rounded_metric}.pth" - torch.save(model.state_dict(), filepath) - return filepath - - def get_best_model_path(self): - return self.best_model_file diff --git a/notebooks/AlignSemanticIDs.ipynb b/notebooks/AlignSemanticIDs.ipynb index 5ddebd5..9203aaa 100644 --- a/notebooks/AlignSemanticIDs.ipynb +++ b/notebooks/AlignSemanticIDs.ipynb @@ -27,10 +27,10 @@ "CODEBOOK_SIZE = 512\n", "NUM_CODEBOOKS = 4\n", "\n", - "OLD_TRAIN_SPLIT = \"0-8TR_8-9V_9-10T\"\n", + "OLD_TRAIN_SPLIT = \"0-8TR_9-10TE\"\n", "OLD_TRAIN_PART = \"0-8TR\"\n", "\n", - "NEW_TRAIN_SPLIT = \"0-9TR_8-9V_9-10T\"\n", + "NEW_TRAIN_SPLIT = \"0-9TR_9-10TE\"\n", "NEW_TRAIN_PART = \"0-8TR\"\n", "\n", "\n", diff --git a/scripts/dense-retriever/configs/config.yaml b/scripts/dense-retriever/configs/config.yaml index 345cce9..2ae4a1c 100644 --- a/scripts/dense-retriever/configs/config.yaml +++ b/scripts/dense-retriever/configs/config.yaml @@ -12,14 +12,12 @@ paths: dataset: name: "vk-lsvd" train_parts: [0, 8] - val_parts: [8, 9] - test_parts: [9, 10] + eval_parts: [9, 10] finetune: train_parts: [0, 8] gap_parts: [8, 9] - val_parts: [9, 10] - test_parts: [9, 10] + eval_parts: [9, 10] model: max_seq_len: 20 @@ -37,8 +35,5 @@ training: num_epochs: 150 train_batch_size: 256 valid_batch_size: 1024 - metric: "validation/ndcg@20" - patience: 60 - minimize_metric: false lr: 1e-4 save_embeddings: true \ No newline at end of file diff --git a/scripts/dense-retriever/finetune_gap.py b/scripts/dense-retriever/finetune_gap.py index 40b5eae..275aa5e 100644 --- a/scripts/dense-retriever/finetune_gap.py +++ b/scripts/dense-retriever/finetune_gap.py @@ -15,7 +15,7 @@ sys.path.append("..") from modeling.datasets import FinetuneDataset -from modeling.training import EarlyStopper, TensorboardLogger +from modeling.training import TensorboardLogger from modeling.utils import collate, fix_random_seed, run_evaluation @@ -37,8 +37,7 @@ def _transform(batch): def generate_constants(cfg: DictConfig): split_name = ( f"{cfg.dataset.train_parts[0]}-{cfg.dataset.train_parts[1]}TR_" - f"{cfg.dataset.val_parts[0]}-{cfg.dataset.val_parts[1]}V_" - f"{cfg.dataset.test_parts[0]}-{cfg.dataset.test_parts[1]}T" + f"{cfg.dataset.eval_parts[0]}-{cfg.dataset.eval_parts[1]}TE" ) old_experiment_name = f"dense-retriever_{cfg.dataset.name}_{split_name}" @@ -46,8 +45,8 @@ def generate_constants(cfg: DictConfig): interactions_path = Path(cfg.paths.data_dir) / "all_data_interactions_with_groups.parquet" embeddings_path = Path(cfg.paths.data_dir) / "items_metadata_remapped.parquet" - experiment_name = f"{old_experiment_name}_finetuned_on_{cfg.finetune.gap_parts[0]}-{cfg.finetune.gap_parts[1]}" - previous_model_mask = f"{old_experiment_name}_best_*.pth" + experiment_name = f"finetuned_{old_experiment_name}_on_{cfg.finetune.gap_parts[0]}-{cfg.finetune.gap_parts[1]}" + previous_model_mask = f"{old_experiment_name}_*.pth" return { "SPLIT_NAME": split_name, @@ -75,14 +74,13 @@ def finetune_sasrec(cfg: DictConfig): all_embeddings_path=consts["EMBEDDINGS_PATH"], train_parts=cfg.finetune.train_parts, gap_parts=cfg.finetune.gap_parts, - val_parts=cfg.finetune.val_parts, - test_parts=cfg.finetune.test_parts, + eval_parts=cfg.finetune.eval_parts, max_seq_len=cfg.model.max_seq_len, ) train_dataset = SASRecTrainDataset(data.train_samples) valid_dataset = SASRecEvalDataset(data.val_samples) - eval_dataset = SASRecEvalDataset(data.test_samples) + eval_dataset = SASRecEvalDataset(data.eval_samples) train_dataloader = DataLoader( dataset=train_dataset, @@ -122,7 +120,7 @@ def finetune_sasrec(cfg: DictConfig): ).to(device) model_files = list(Path(cfg.paths.checkpoints_dir).glob(consts["PREVIOUS_MODEL_MASK"])) - assert len(model_files) == 1, f"Expected exactly one model file, found {len(model_files)}" + assert len(model_files) >= 1, f"Expected at least one model file, found {len(model_files)}" finetune_model_path = max(model_files, key=lambda p: p.stat().st_mtime) logger.info(f"MODEL TO FINETUNE: {finetune_model_path}") state_dict = torch.load(finetune_model_path) @@ -147,14 +145,6 @@ def finetune_sasrec(cfg: DictConfig): tensorboard_logger = TensorboardLogger(experiment_name=consts["EXPERIMENT_NAME"], logdir=cfg.paths.tensorboard_dir) - early_stopper = EarlyStopper( - metric=cfg.training.metric, - patience=cfg.training.patience, - minimize=cfg.training.minimize_metric, - checkpoints_dir=cfg.paths.checkpoints_dir, - experiment_name=consts["EXPERIMENT_NAME"], - ) - logger.debug("Everything is ready for finetuning process!") for epoch in range(cfg.training.num_epochs): @@ -177,20 +167,17 @@ def finetune_sasrec(cfg: DictConfig): all_metrics = {**train_metrics, **validation_metrics, **eval_metrics} tensorboard_logger.add_metrics((epoch + 1) * (batch_idx + 1), all_metrics) - if early_stopper.check(all_metrics[cfg.training.metric], model): - logger.info("Early stopping triggered") - break - tensorboard_logger.close() + Path(cfg.paths.checkpoints_dir).mkdir(parents=True, exist_ok=True) + last_model_path = ( + Path(cfg.paths.checkpoints_dir) / f"{consts['EXPERIMENT_NAME']}_{tensorboard_logger.timestamp}.pth" + ) + torch.save(model.state_dict(), last_model_path) + logger.info(f"Last model saved to: {last_model_path}") logger.info("Finetuning completed successfully!") - best_model_file = early_stopper.get_best_model_path() - logger.info(f"Best model path is: {best_model_file}") - if cfg.training.save_embeddings and cfg.model.num_layers == 2: - state_dict = torch.load(best_model_file) - model.load_state_dict(state_dict) torch.save(model._item_embeddings.weight.detach().cpu(), consts["EMBEDDINGS_FILE"]) logger.info(f"Embeddings saved to {consts['EMBEDDINGS_FILE']}!") diff --git a/scripts/dense-retriever/train.py b/scripts/dense-retriever/train.py index bc72c18..0d15a2d 100644 --- a/scripts/dense-retriever/train.py +++ b/scripts/dense-retriever/train.py @@ -15,7 +15,7 @@ sys.path.append("..") from modeling.datasets import SequentialDataset -from modeling.training import EarlyStopper, TensorboardLogger +from modeling.training import TensorboardLogger from modeling.utils import collate, fix_random_seed, run_evaluation @@ -37,8 +37,7 @@ def _transform(batch): def generate_constants(cfg: DictConfig): split_name = ( f"{cfg.dataset.train_parts[0]}-{cfg.dataset.train_parts[1]}TR_" - f"{cfg.dataset.val_parts[0]}-{cfg.dataset.val_parts[1]}V_" - f"{cfg.dataset.test_parts[0]}-{cfg.dataset.test_parts[1]}T" + f"{cfg.dataset.eval_parts[0]}-{cfg.dataset.eval_parts[1]}TE" ) results_path = Path(cfg.paths.results_dir) / split_name / "dense-retriever" @@ -69,14 +68,12 @@ def train_model(cfg: DictConfig): all_interactions_path=consts["INTERACTIONS_PATH"], all_embeddings_path=consts["EMBEDDINGS_PATH"], train_parts=cfg.dataset.train_parts, - val_parts=cfg.dataset.val_parts, - test_parts=cfg.dataset.test_parts, + eval_parts=cfg.dataset.eval_parts, max_seq_len=cfg.model.max_seq_len, ) train_dataset = SASRecTrainDataset(data.train_samples) - valid_dataset = SASRecEvalDataset(data.val_samples) - eval_dataset = SASRecEvalDataset(data.test_samples) + eval_dataset = SASRecEvalDataset(data.eval_samples) train_dataloader = DataLoader( dataset=train_dataset, @@ -86,14 +83,6 @@ def train_model(cfg: DictConfig): collate_fn=collate_fn(device), ) - valid_dataloader = DataLoader( - dataset=valid_dataset, - batch_size=cfg.training.valid_batch_size, - shuffle=False, - drop_last=False, - collate_fn=collate_fn(device), - ) - eval_dataloader = DataLoader( dataset=eval_dataset, batch_size=cfg.training.valid_batch_size, @@ -128,14 +117,6 @@ def train_model(cfg: DictConfig): tensorboard_logger = TensorboardLogger(experiment_name=consts["EXPERIMENT_NAME"], logdir=cfg.paths.tensorboard_dir) - early_stopper = EarlyStopper( - metric=cfg.training.metric, - patience=cfg.training.patience, - minimize=cfg.training.minimize_metric, - checkpoints_dir=cfg.paths.checkpoints_dir, - experiment_name=consts["EXPERIMENT_NAME"], - ) - logger.debug("Everything is ready for training process!") for epoch in range(cfg.training.num_epochs): @@ -153,25 +134,21 @@ def train_model(cfg: DictConfig): losses.append(outputs["loss"].item()) train_metrics = {"train/loss": sum(losses) / len(losses)} - validation_metrics = run_evaluation(model, valid_dataloader, "validation/") eval_metrics = run_evaluation(model, eval_dataloader, "eval/") - all_metrics = {**train_metrics, **validation_metrics, **eval_metrics} + all_metrics = {**train_metrics, **eval_metrics} tensorboard_logger.add_metrics((epoch + 1) * (batch_idx + 1), all_metrics) - if early_stopper.check(all_metrics[cfg.training.metric], model): - logger.info("Early stopping triggered") - break - tensorboard_logger.close() + Path(cfg.paths.checkpoints_dir).mkdir(parents=True, exist_ok=True) + last_model_path = ( + Path(cfg.paths.checkpoints_dir) / f"{consts['EXPERIMENT_NAME']}_{tensorboard_logger.timestamp}.pth" + ) + torch.save(model.state_dict(), last_model_path) + logger.info(f"Last model saved to: {last_model_path}") logger.info("Training completed successfully!") - best_model_file = early_stopper.get_best_model_path() - logger.info(f"Best model path is: {best_model_file}") - if cfg.training.save_embeddings and cfg.model.num_layers == 2: - state_dict = torch.load(best_model_file) - model.load_state_dict(state_dict) torch.save(model._item_embeddings.weight.detach().cpu(), consts["EMBEDDINGS_FILE"]) logger.info(f"Embeddings saved to {consts['EMBEDDINGS_FILE']}!") diff --git a/scripts/rqvae-collab/configs/config.yaml b/scripts/rqvae-collab/configs/config.yaml index 62b9e77..10dd059 100644 --- a/scripts/rqvae-collab/configs/config.yaml +++ b/scripts/rqvae-collab/configs/config.yaml @@ -14,8 +14,7 @@ dataset: train: rqvae_train_parts: [0, 8] - rqvae_val_parts: [8, 9] - rqvae_test_parts: [9, 10] + rqvae_eval_parts: [9, 10] inference: allowed_items_parts: [0, 8] @@ -32,7 +31,4 @@ training: device: "cuda:0" num_epochs: 100 batch_size: 4096 - metric: "validation/loss" - patience: 40 - minimize_metric: true lr: 1e-4 diff --git a/scripts/rqvae-collab/inference.py b/scripts/rqvae-collab/inference.py index bddea66..33543b3 100644 --- a/scripts/rqvae-collab/inference.py +++ b/scripts/rqvae-collab/inference.py @@ -15,8 +15,7 @@ def generate_constants(cfg: DictConfig): rqvae_split_name = ( f"{cfg.train.rqvae_train_parts[0]}-{cfg.train.rqvae_train_parts[1]}TR_" - f"{cfg.train.rqvae_val_parts[0]}-{cfg.train.rqvae_val_parts[1]}V_" - f"{cfg.train.rqvae_test_parts[0]}-{cfg.train.rqvae_test_parts[1]}T" + f"{cfg.train.rqvae_eval_parts[0]}-{cfg.train.rqvae_eval_parts[1]}TE" ) results_path = Path(cfg.paths.results_dir) / rqvae_split_name / "rqvae-collab" diff --git a/scripts/rqvae-collab/train.py b/scripts/rqvae-collab/train.py index cf0ddae..92c5cf9 100644 --- a/scripts/rqvae-collab/train.py +++ b/scripts/rqvae-collab/train.py @@ -16,7 +16,7 @@ sys.path.append("..") from modeling.datasets import EmbeddingsDataset -from modeling.training import EarlyStopper, TensorboardLogger +from modeling.training import TensorboardLogger from modeling.utils import fix_random_seed, run_evaluation @@ -63,8 +63,7 @@ def run_inference(model, dataloader, save_path): def generate_constants(cfg: DictConfig): split_name = ( f"{cfg.train.rqvae_train_parts[0]}-{cfg.train.rqvae_train_parts[1]}TR_" - f"{cfg.train.rqvae_val_parts[0]}-{cfg.train.rqvae_val_parts[1]}V_" - f"{cfg.train.rqvae_test_parts[0]}-{cfg.train.rqvae_test_parts[1]}T" + f"{cfg.train.rqvae_eval_parts[0]}-{cfg.train.rqvae_eval_parts[1]}TE" ) results_path = Path(cfg.paths.results_dir) / split_name / "rqvae-collab" @@ -97,22 +96,36 @@ def train_rqvae(cfg: DictConfig): device = cfg.training.device if torch.cuda.is_available() and cfg.training.device != "cpu" else "cpu" logger.info(f"Using device: {device}") - dataset = EmbeddingsDataset( + train_dataset = EmbeddingsDataset( all_interactions_path=consts["INTERACTIONS_PATH"], all_embeddings_path=consts["EMBEDDINGS_PATH"], - train_parts=cfg.train.rqvae_train_parts, + parts=cfg.train.rqvae_train_parts, + ) + + eval_dataset = EmbeddingsDataset( + all_interactions_path=consts["INTERACTIONS_PATH"], + all_embeddings_path=consts["EMBEDDINGS_PATH"], + parts=cfg.train.rqvae_eval_parts, ) train_dataloader = StatefulDataLoader( - dataset=dataset, + dataset=train_dataset, batch_size=cfg.training.batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn(device), ) - valid_dataloader = StatefulDataLoader( - dataset, + eval_train_dataloader = StatefulDataLoader( + train_dataset, + batch_size=cfg.training.batch_size, + shuffle=False, + drop_last=False, + collate_fn=collate_fn(device), + ) + + eval_dataloader = StatefulDataLoader( + eval_dataset, batch_size=cfg.training.batch_size, shuffle=False, drop_last=False, @@ -131,7 +144,7 @@ def train_rqvae(cfg: DictConfig): cf_embeddings=cf_embeddings, ).to(device) - codebook_initialize(model, valid_dataloader) + codebook_initialize(model, eval_train_dataloader) total_params = sum(p.numel() for p in model.parameters()) trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) @@ -143,14 +156,6 @@ def train_rqvae(cfg: DictConfig): tensorboard_logger = TensorboardLogger(experiment_name=consts["EXPERIMENT_NAME"], logdir=cfg.paths.tensorboard_dir) - early_stopper = EarlyStopper( - metric=cfg.training.metric, - patience=cfg.training.patience, - minimize=cfg.training.minimize_metric, - checkpoints_dir=cfg.paths.checkpoints_dir, - experiment_name=consts["EXPERIMENT_NAME"], - ) - logger.debug("Everything is ready for training process!") last_max_collisions = 0 @@ -167,7 +172,7 @@ def train_rqvae(cfg: DictConfig): loss.backward() optimizer.step() - num_fixed, max_collisions = fix_dead_codebooks(model, valid_dataloader) + num_fixed, max_collisions = fix_dead_codebooks(model, eval_train_dataloader) last_max_collisions = max_collisions train_accumulators["train/loss"].append(outputs["loss"]) @@ -180,24 +185,18 @@ def train_rqvae(cfg: DictConfig): train_metrics = {key: sum(values) / len(values) for key, values in train_accumulators.items()} train_metrics["num_dead/max_collisitons_num"] = last_max_collisions - validation_metrics = run_evaluation( - model, valid_dataloader, "validation/", ["loss", "recon_loss", "rqvae_loss"] - ) - all_metrics = {**train_metrics, **validation_metrics} + eval_metrics = run_evaluation(model, eval_dataloader, "eval/", ["loss", "recon_loss", "rqvae_loss"]) + all_metrics = {**train_metrics, **eval_metrics} tensorboard_logger.add_metrics((epoch + 1) * (batch_idx + 1), all_metrics) - if early_stopper.check(all_metrics[cfg.training.metric], model): - logger.info("Early stopping triggered") - break - tensorboard_logger.close() - best_model_file = early_stopper.get_best_model_path() - assert best_model_file is not None - - logger.info(f"Loading best model from: {best_model_file}") - state_dict = torch.load(best_model_file) - model.load_state_dict(state_dict) + Path(cfg.paths.checkpoints_dir).mkdir(parents=True, exist_ok=True) + last_model_path = ( + Path(cfg.paths.checkpoints_dir) / f"{consts['EXPERIMENT_NAME']}_{tensorboard_logger.timestamp}.pth" + ) + torch.save(model.state_dict(), last_model_path) + logger.info(f"Last model saved to: {last_model_path}") inference_dataset = EmbeddingsDataset( all_interactions_path=consts["INTERACTIONS_PATH"], all_embeddings_path=consts["EMBEDDINGS_PATH"] @@ -239,7 +238,7 @@ def train_rqvae(cfg: DictConfig): with open(consts["ALL_MAPPING_PATH"], "w") as f: json.dump(all_mapping, f, indent=2) - train_interactions = dataset.get_interactions_by_part( + train_interactions = train_dataset.get_interactions_by_part( cfg.train.rqvae_train_parts[0], cfg.train.rqvae_train_parts[1] ) diff --git a/scripts/rqvae-content/configs/config.yaml b/scripts/rqvae-content/configs/config.yaml index 62b9e77..10dd059 100644 --- a/scripts/rqvae-content/configs/config.yaml +++ b/scripts/rqvae-content/configs/config.yaml @@ -14,8 +14,7 @@ dataset: train: rqvae_train_parts: [0, 8] - rqvae_val_parts: [8, 9] - rqvae_test_parts: [9, 10] + rqvae_eval_parts: [9, 10] inference: allowed_items_parts: [0, 8] @@ -32,7 +31,4 @@ training: device: "cuda:0" num_epochs: 100 batch_size: 4096 - metric: "validation/loss" - patience: 40 - minimize_metric: true lr: 1e-4 diff --git a/scripts/rqvae-content/inference.py b/scripts/rqvae-content/inference.py index 6ce8887..02a3d7d 100644 --- a/scripts/rqvae-content/inference.py +++ b/scripts/rqvae-content/inference.py @@ -15,8 +15,7 @@ def generate_constants(cfg: DictConfig): rqvae_split_name = ( f"{cfg.train.rqvae_train_parts[0]}-{cfg.train.rqvae_train_parts[1]}TR_" - f"{cfg.train.rqvae_val_parts[0]}-{cfg.train.rqvae_val_parts[1]}V_" - f"{cfg.train.rqvae_test_parts[0]}-{cfg.train.rqvae_test_parts[1]}T" + f"{cfg.train.rqvae_eval_parts[0]}-{cfg.train.rqvae_eval_parts[1]}TE" ) results_path = Path(cfg.paths.results_dir) / rqvae_split_name / "rqvae-content" diff --git a/scripts/rqvae-content/train.py b/scripts/rqvae-content/train.py index 6a4067e..e867db4 100644 --- a/scripts/rqvae-content/train.py +++ b/scripts/rqvae-content/train.py @@ -16,7 +16,7 @@ sys.path.append("..") from modeling.datasets import EmbeddingsDataset -from modeling.training import EarlyStopper, TensorboardLogger +from modeling.training import TensorboardLogger from modeling.utils import fix_random_seed, run_evaluation @@ -63,8 +63,7 @@ def run_inference(model, dataloader, save_path): def generate_constants(cfg: DictConfig): split_name = ( f"{cfg.train.rqvae_train_parts[0]}-{cfg.train.rqvae_train_parts[1]}TR_" - f"{cfg.train.rqvae_val_parts[0]}-{cfg.train.rqvae_val_parts[1]}V_" - f"{cfg.train.rqvae_test_parts[0]}-{cfg.train.rqvae_test_parts[1]}T" + f"{cfg.train.rqvae_eval_parts[0]}-{cfg.train.rqvae_eval_parts[1]}TE" ) results_path = Path(cfg.paths.results_dir) / split_name / "rqvae-content" @@ -94,22 +93,36 @@ def train_rqvae(cfg: DictConfig): device = cfg.training.device if torch.cuda.is_available() and cfg.training.device != "cpu" else "cpu" logger.info(f"Using device: {device}") - dataset = EmbeddingsDataset( + train_dataset = EmbeddingsDataset( all_interactions_path=consts["INTERACTIONS_PATH"], all_embeddings_path=consts["EMBEDDINGS_PATH"], - train_parts=cfg.train.rqvae_train_parts, + parts=cfg.train.rqvae_train_parts, + ) + + eval_dataset = EmbeddingsDataset( + all_interactions_path=consts["INTERACTIONS_PATH"], + all_embeddings_path=consts["EMBEDDINGS_PATH"], + parts=cfg.train.rqvae_eval_parts, ) train_dataloader = StatefulDataLoader( - dataset=dataset, + dataset=train_dataset, batch_size=cfg.training.batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn(device), ) - valid_dataloader = StatefulDataLoader( - dataset, + eval_train_dataloader = StatefulDataLoader( + train_dataset, + batch_size=cfg.training.batch_size, + shuffle=False, + drop_last=False, + collate_fn=collate_fn(device), + ) + + eval_dataloader = StatefulDataLoader( + eval_dataset, batch_size=cfg.training.batch_size, shuffle=False, drop_last=False, @@ -124,7 +137,7 @@ def train_rqvae(cfg: DictConfig): beta=cfg.model.beta, ).to(device) - codebook_initialize(model, valid_dataloader) + codebook_initialize(model, eval_train_dataloader) total_params = sum(p.numel() for p in model.parameters()) trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) @@ -136,14 +149,6 @@ def train_rqvae(cfg: DictConfig): tensorboard_logger = TensorboardLogger(experiment_name=consts["EXPERIMENT_NAME"], logdir=cfg.paths.tensorboard_dir) - early_stopper = EarlyStopper( - metric=cfg.training.metric, - patience=cfg.training.patience, - minimize=cfg.training.minimize_metric, - checkpoints_dir=cfg.paths.checkpoints_dir, - experiment_name=consts["EXPERIMENT_NAME"], - ) - logger.debug("Everything is ready for training process!") last_max_collisions = 0 @@ -160,7 +165,7 @@ def train_rqvae(cfg: DictConfig): loss.backward() optimizer.step() - num_fixed, max_collisions = fix_dead_codebooks(model, valid_dataloader) + num_fixed, max_collisions = fix_dead_codebooks(model, eval_train_dataloader) last_max_collisions = max_collisions train_accumulators["train/loss"].append(outputs["loss"]) @@ -173,24 +178,18 @@ def train_rqvae(cfg: DictConfig): train_metrics = {key: sum(values) / len(values) for key, values in train_accumulators.items()} train_metrics["num_dead/max_collisitons_num"] = last_max_collisions - validation_metrics = run_evaluation( - model, valid_dataloader, "validation/", ["loss", "recon_loss", "rqvae_loss"] - ) - all_metrics = {**train_metrics, **validation_metrics} + eval_metrics = run_evaluation(model, eval_dataloader, "eval/", ["loss", "recon_loss", "rqvae_loss"]) + all_metrics = {**train_metrics, **eval_metrics} tensorboard_logger.add_metrics((epoch + 1) * (batch_idx + 1), all_metrics) - if early_stopper.check(all_metrics[cfg.training.metric], model): - logger.info("Early stopping triggered") - break - tensorboard_logger.close() - best_model_file = early_stopper.get_best_model_path() - assert best_model_file is not None - - logger.info(f"Loading best model from: {best_model_file}") - state_dict = torch.load(best_model_file) - model.load_state_dict(state_dict) + Path(cfg.paths.checkpoints_dir).mkdir(parents=True, exist_ok=True) + last_model_path = ( + Path(cfg.paths.checkpoints_dir) / f"{consts['EXPERIMENT_NAME']}_{tensorboard_logger.timestamp}.pth" + ) + torch.save(model.state_dict(), last_model_path) + logger.info(f"Last model saved to: {last_model_path}") inference_dataset = EmbeddingsDataset( all_interactions_path=consts["INTERACTIONS_PATH"], all_embeddings_path=consts["EMBEDDINGS_PATH"] @@ -232,7 +231,7 @@ def train_rqvae(cfg: DictConfig): with open(consts["ALL_MAPPING_PATH"], "w") as f: json.dump(all_mapping, f, indent=2) - train_interactions = dataset.get_interactions_by_part( + train_interactions = train_dataset.get_interactions_by_part( cfg.train.rqvae_train_parts[0], cfg.train.rqvae_train_parts[1] ) diff --git a/scripts/sid-retriever/configs/config.yaml b/scripts/sid-retriever/configs/config.yaml index 76a1596..a9d4915 100644 --- a/scripts/sid-retriever/configs/config.yaml +++ b/scripts/sid-retriever/configs/config.yaml @@ -16,23 +16,19 @@ dataset: train: sid_retriever_train_parts: [0, 8] - sid_retriever_val_parts: [8, 9] - sid_retriever_test_parts: [9, 10] + sid_retriever_eval_parts: [9, 10] allowed_items_parts: [0, 8] rqvae_train_parts: [0, 8] - rqvae_val_parts: [8, 9] - rqvae_test_parts: [9, 10] + rqvae_eval_parts: [9, 10] finetune: matching_method: "hungarian" # none, hungarian or greedy sid_retriever_train_parts: [0, 8] sid_retriever_gap_parts: [8, 9] - sid_retriever_val_parts: [8, 9] - sid_retriever_test_parts: [9, 10] + sid_retriever_eval_parts: [9, 10] allowed_items_parts: [0, 8] rqvae_train_parts: [0, 8] - rqvae_val_parts: [8, 9] - rqvae_test_parts: [9, 10] + rqvae_eval_parts: [9, 10] model: max_seq_len: 20 @@ -55,10 +51,7 @@ training: valid_batch_size: 256 num_epochs: 30 lr: 1e-4 - patience: 40 - minimize_metric: false - metric: "eval/ndcg@20" inference: use_finetune_model: True - test_parts: [9, 10] + eval_parts: [9, 10] diff --git a/scripts/sid-retriever/finetune_gap.py b/scripts/sid-retriever/finetune_gap.py index cef12f5..daf8bd4 100644 --- a/scripts/sid-retriever/finetune_gap.py +++ b/scripts/sid-retriever/finetune_gap.py @@ -15,7 +15,7 @@ sys.path.append("..") from modeling.datasets import FinetuneDataset -from modeling.training import EarlyStopper, TensorboardLogger +from modeling.training import TensorboardLogger from modeling.utils import collate, fix_random_seed, run_evaluation @@ -46,18 +46,18 @@ def generate_constants(cfg: DictConfig): pretrained_allowed_parts = cfg.train.allowed_items_parts or cfg.train.sid_retriever_train_parts pretrained_sid_retriever_split_name = ( f"{cfg.train.sid_retriever_train_parts[0]}-{cfg.train.sid_retriever_train_parts[1]}TR_" - f"{cfg.train.sid_retriever_val_parts[0]}-{cfg.train.sid_retriever_val_parts[1]}V_" - f"{cfg.train.sid_retriever_test_parts[0]}-{cfg.train.sid_retriever_test_parts[1]}T_" + f"{cfg.train.sid_retriever_eval_parts[0]}-{cfg.train.sid_retriever_eval_parts[1]}TE_" f"items-{pretrained_allowed_parts[0]}-{pretrained_allowed_parts[1]}" ) pretrained_rqvae_split_name = ( f"{cfg.train.rqvae_train_parts[0]}-{cfg.train.rqvae_train_parts[1]}TR_" - f"{cfg.train.rqvae_val_parts[0]}-{cfg.train.rqvae_val_parts[1]}V_" - f"{cfg.train.rqvae_test_parts[0]}-{cfg.train.rqvae_test_parts[1]}T" + f"{cfg.train.rqvae_eval_parts[0]}-{cfg.train.rqvae_eval_parts[1]}TE" ) pretrained_name = ( - f"sid-retriever_{cfg.dataset.name}_{pretrained_sid_retriever_split_name}_{pretrained_rqvae_split_name}" + f"sid-retriever-{cfg.dataset.rqvae.model_name}_" + f"{cfg.dataset.name}_{pretrained_sid_retriever_split_name}_" + f"{pretrained_rqvae_split_name}" ) finetune_allowed_parts = ( @@ -65,19 +65,17 @@ def generate_constants(cfg: DictConfig): ) finetune_sid_retriever_split_name = ( f"{cfg.finetune.sid_retriever_train_parts[0]}-{cfg.finetune.sid_retriever_train_parts[1]}TR_" - f"{cfg.finetune.sid_retriever_val_parts[0]}-{cfg.finetune.sid_retriever_val_parts[1]}V_" - f"{cfg.finetune.sid_retriever_test_parts[0]}-{cfg.finetune.sid_retriever_test_parts[1]}T_" + f"{cfg.finetune.sid_retriever_eval_parts[0]}-{cfg.finetune.sid_retriever_eval_parts[1]}TE_" f"items-{finetune_allowed_parts[0]}-{finetune_allowed_parts[1]}" ) finetune_rqvae_split_name = ( f"{cfg.finetune.rqvae_train_parts[0]}-{cfg.finetune.rqvae_train_parts[1]}TR_" - f"{cfg.finetune.rqvae_val_parts[0]}-{cfg.finetune.rqvae_val_parts[1]}V_" - f"{cfg.finetune.rqvae_test_parts[0]}-{cfg.finetune.rqvae_test_parts[1]}T" + f"{cfg.finetune.rqvae_eval_parts[0]}-{cfg.finetune.rqvae_eval_parts[1]}TE" ) assert cfg.finetune.matching_method in ["greedy", "hungarian", "none"] experiment_name = ( - f"{pretrained_name}_finetuned_" + f"finetuned_{pretrained_name}_on_" f"{cfg.finetune.sid_retriever_gap_parts[0]}-{cfg.finetune.sid_retriever_gap_parts[1]}G_" f"{cfg.finetune.matching_method}_" f"{finetune_sid_retriever_split_name}_" @@ -101,7 +99,7 @@ def generate_constants(cfg: DictConfig): f"{cfg.finetune.matching_method}_to_{pretrained_rqvae_split_name}_{train_part_mapping_path_name}" ) - pretrained_model_mask = f"{pretrained_name}_best_*.pth" + pretrained_model_mask = f"{pretrained_name}_*.pth" return { "EXPERIMENT_NAME": experiment_name, @@ -123,7 +121,7 @@ def train_tiger_finetune(cfg: DictConfig): logger.info(f"Using device: {device}") model_files = list(Path(cfg.paths.checkpoints_dir).glob(consts["PRETRAINED_MODEL_MASK"])) - assert len(model_files) == 1, f"Expected exactly one model file, found {len(model_files)}" + assert len(model_files) >= 1, f"Expected at least one model file, found {len(model_files)}" pretrained_model_path = max(model_files, key=lambda p: p.stat().st_mtime) logger.info(f"Loading pre-trained model from: {pretrained_model_path}") logger.info(f"Semantic IDs train mapping path: {consts['TRAIN_PART_SEMANTIC_MAPPING_PATH']}") @@ -140,8 +138,7 @@ def train_tiger_finetune(cfg: DictConfig): all_embeddings_path=consts["EMBEDDINGS_PATH"], train_parts=cfg.finetune.sid_retriever_train_parts, gap_parts=cfg.finetune.sid_retriever_gap_parts, - val_parts=cfg.finetune.sid_retriever_val_parts, - test_parts=cfg.finetune.sid_retriever_test_parts, + eval_parts=cfg.finetune.sid_retriever_eval_parts, max_seq_len=cfg.model.max_seq_len, ) @@ -149,8 +146,12 @@ def train_tiger_finetune(cfg: DictConfig): data.train_samples, all_semantics_mapping_array, cfg.model.num_codebooks, cfg.model.num_user_hash ) + valid_dataset = TigerEvalDataset( + data.val_samples, all_semantics_mapping_array, cfg.model.num_codebooks, cfg.model.num_user_hash + ) + eval_dataset = TigerEvalDataset( - data.test_samples, all_semantics_mapping_array, cfg.model.num_codebooks, cfg.model.num_user_hash + data.eval_samples, all_semantics_mapping_array, cfg.model.num_codebooks, cfg.model.num_user_hash ) train_dataloader = DataLoader( @@ -161,6 +162,14 @@ def train_tiger_finetune(cfg: DictConfig): collate_fn=create_collate_fn(cfg.model.num_codebooks, cfg.model.codebook_size, device, is_eval=False), ) + valid_dataloader = DataLoader( + dataset=valid_dataset, + batch_size=cfg.training.valid_batch_size, + shuffle=False, + drop_last=False, + collate_fn=create_collate_fn(cfg.model.num_codebooks, cfg.model.codebook_size, device, is_eval=True), + ) + eval_dataloader = DataLoader( dataset=eval_dataset, batch_size=cfg.training.valid_batch_size, @@ -205,14 +214,6 @@ def train_tiger_finetune(cfg: DictConfig): tensorboard_logger = TensorboardLogger(experiment_name=consts["EXPERIMENT_NAME"], logdir=cfg.paths.tensorboard_dir) - early_stopper = EarlyStopper( - metric=cfg.training.metric, - patience=cfg.training.patience, - minimize=cfg.training.minimize_metric, - checkpoints_dir=cfg.paths.checkpoints_dir, - experiment_name=consts["EXPERIMENT_NAME"], - ) - logger.debug("Everything is ready for fine-tuning process!") for epoch in range(cfg.training.num_epochs): @@ -231,24 +232,27 @@ def train_tiger_finetune(cfg: DictConfig): all_metrics = {"train/loss": sum(losses) / len(losses)} - if (epoch + 1) % 2 == 0: - logger.info("Doing evaluation") + if (epoch + 1) % 4 == 0: + logger.info("Doing validation") + validation_metrics = run_evaluation(model, valid_dataloader, "validation/") + all_metrics.update(validation_metrics) + logger.info("Doing test evaluation") eval_metrics = run_evaluation(model, eval_dataloader, "eval/") all_metrics.update(eval_metrics) tensorboard_logger.add_metrics((epoch + 1) * (batch_idx + 1), all_metrics) - - if early_stopper.check(all_metrics[cfg.training.metric], model): - logger.info("Early stopping triggered") - break else: tensorboard_logger.add_metrics((epoch + 1) * (batch_idx + 1), all_metrics) tensorboard_logger.close() - best_model_file = early_stopper.get_best_model_path() - logger.info(f"Best model path is: {best_model_file}") + Path(cfg.paths.checkpoints_dir).mkdir(parents=True, exist_ok=True) + last_model_path = ( + Path(cfg.paths.checkpoints_dir) / f"{consts['EXPERIMENT_NAME']}_{tensorboard_logger.timestamp}.pth" + ) + torch.save(model.state_dict(), last_model_path) + logger.info(f"Last model saved to: {last_model_path}") logger.info("Fine-tuning completed successfully!") diff --git a/scripts/sid-retriever/inference.py b/scripts/sid-retriever/inference.py index 1850b9e..68d3a7d 100644 --- a/scripts/sid-retriever/inference.py +++ b/scripts/sid-retriever/inference.py @@ -43,18 +43,18 @@ def generate_constants(cfg: DictConfig): pretrained_allowed_parts = cfg.train.allowed_items_parts or cfg.train.sid_retriever_train_parts pretrained_sid_retriever_split_name = ( f"{cfg.train.sid_retriever_train_parts[0]}-{cfg.train.sid_retriever_train_parts[1]}TR_" - f"{cfg.train.sid_retriever_val_parts[0]}-{cfg.train.sid_retriever_val_parts[1]}V_" - f"{cfg.train.sid_retriever_test_parts[0]}-{cfg.train.sid_retriever_test_parts[1]}T_" + f"{cfg.train.sid_retriever_eval_parts[0]}-{cfg.train.sid_retriever_eval_parts[1]}TE_" f"items-{pretrained_allowed_parts[0]}-{pretrained_allowed_parts[1]}" ) pretrained_rqvae_split_name = ( f"{cfg.train.rqvae_train_parts[0]}-{cfg.train.rqvae_train_parts[1]}TR_" - f"{cfg.train.rqvae_val_parts[0]}-{cfg.train.rqvae_val_parts[1]}V_" - f"{cfg.train.rqvae_test_parts[0]}-{cfg.train.rqvae_test_parts[1]}T" + f"{cfg.train.rqvae_eval_parts[0]}-{cfg.train.rqvae_eval_parts[1]}TE" ) pretrained_name = ( - f"sid-retriever_{cfg.dataset.name}_{pretrained_sid_retriever_split_name}_{pretrained_rqvae_split_name}" + f"sid-retriever-{cfg.dataset.rqvae.model_name}_" + f"{cfg.dataset.name}_{pretrained_sid_retriever_split_name}_" + f"{pretrained_rqvae_split_name}" ) if cfg.inference.use_finetune_model: @@ -63,19 +63,17 @@ def generate_constants(cfg: DictConfig): ) finetune_sid_retriever_split_name = ( f"{cfg.finetune.sid_retriever_train_parts[0]}-{cfg.finetune.sid_retriever_train_parts[1]}TR_" - f"{cfg.finetune.sid_retriever_val_parts[0]}-{cfg.finetune.sid_retriever_val_parts[1]}V_" - f"{cfg.finetune.sid_retriever_test_parts[0]}-{cfg.finetune.sid_retriever_test_parts[1]}T_" + f"{cfg.finetune.sid_retriever_eval_parts[0]}-{cfg.finetune.sid_retriever_eval_parts[1]}TE_" f"items-{finetune_allowed_parts[0]}-{finetune_allowed_parts[1]}" ) finetune_rqvae_split_name = ( f"{cfg.finetune.rqvae_train_parts[0]}-{cfg.finetune.rqvae_train_parts[1]}TR_" - f"{cfg.finetune.rqvae_val_parts[0]}-{cfg.finetune.rqvae_val_parts[1]}V_" - f"{cfg.finetune.rqvae_test_parts[0]}-{cfg.finetune.rqvae_test_parts[1]}T" + f"{cfg.finetune.rqvae_eval_parts[0]}-{cfg.finetune.rqvae_eval_parts[1]}TE" ) assert cfg.finetune.matching_method in ["greedy", "hungarian", "none"] experiment_name = ( - f"{pretrained_name}_finetuned_" + f"finetuned_{pretrained_name}_on_" f"{cfg.finetune.sid_retriever_gap_parts[0]}-{cfg.finetune.sid_retriever_gap_parts[1]}G_" f"{cfg.finetune.matching_method}_" f"{finetune_sid_retriever_split_name}_" @@ -114,7 +112,7 @@ def generate_constants(cfg: DictConfig): "EMBEDDINGS_PATH": Path(cfg.paths.data_dir) / "items_metadata_remapped.parquet", "ALL_ITEMS_SEMANTIC_MAPPING_PATH": rqvae_results_path / all_items_mapping_path_name, "TRAIN_PART_SEMANTIC_MAPPING_PATH": rqvae_results_path / train_part_mapping_path_name, - "PRETRAINED_MODEL_MASK": f"{experiment_name}_best_*.pth", + "PRETRAINED_MODEL_MASK": f"{experiment_name}_*.pth", } @@ -128,10 +126,10 @@ def tiger_inference(cfg: DictConfig): logger.info(f"Using device: {device}") model_files = list(Path(cfg.paths.checkpoints_dir).glob(consts["PRETRAINED_MODEL_MASK"])) - assert len(model_files) == 1, f"Expected exactly one model file, found {len(model_files)}" + assert len(model_files) >= 1, f"Expected at least one model file, found {len(model_files)}" pretrained_model_path = max(model_files, key=lambda p: p.stat().st_mtime) logger.info(f"Loading pre-trained model from: {pretrained_model_path}") - logger.info(f"Eval parts interval: [{cfg.inference.test_parts[0]}, {cfg.inference.test_parts[1]})") + logger.info(f"Eval parts interval: [{cfg.inference.eval_parts[0]}, {cfg.inference.eval_parts[1]})") logger.info(f"Semantic IDs train mapping path: {consts['TRAIN_PART_SEMANTIC_MAPPING_PATH']}") with open(consts["ALL_ITEMS_SEMANTIC_MAPPING_PATH"]) as f: @@ -141,17 +139,16 @@ def tiger_inference(cfg: DictConfig): all_semantics_mapping_array = create_semantic_mapping_array(all_mappings, cfg.model.num_codebooks) - test_samples = SequentialDataset( + eval_samples = SequentialDataset( all_interactions_path=consts["INTERACTIONS_PATH"], all_embeddings_path=consts["EMBEDDINGS_PATH"], - train_parts=[0, cfg.inference.test_parts[0]], - val_parts=[0, cfg.inference.test_parts[0]], - test_parts=cfg.inference.test_parts, + train_parts=[0, cfg.inference.eval_parts[0]], + eval_parts=cfg.inference.eval_parts, max_seq_len=cfg.model.max_seq_len, - ).test_samples + ).eval_samples eval_dataset = TigerEvalDataset( - test_samples, all_semantics_mapping_array, cfg.model.num_codebooks, cfg.model.num_user_hash + eval_samples, all_semantics_mapping_array, cfg.model.num_codebooks, cfg.model.num_user_hash ) eval_dataloader = DataLoader( diff --git a/scripts/sid-retriever/train.py b/scripts/sid-retriever/train.py index 850211f..3ecce1a 100644 --- a/scripts/sid-retriever/train.py +++ b/scripts/sid-retriever/train.py @@ -15,7 +15,7 @@ sys.path.append("..") from modeling.datasets import SequentialDataset -from modeling.training import EarlyStopper, TensorboardLogger +from modeling.training import TensorboardLogger from modeling.utils import collate, fix_random_seed, run_evaluation @@ -46,18 +46,20 @@ def generate_constants(cfg: DictConfig): allowed_parts = cfg.train.allowed_items_parts or cfg.train.sid_retriever_train_parts sid_retriever_split_name = ( f"{cfg.train.sid_retriever_train_parts[0]}-{cfg.train.sid_retriever_train_parts[1]}TR_" - f"{cfg.train.sid_retriever_val_parts[0]}-{cfg.train.sid_retriever_val_parts[1]}V_" - f"{cfg.train.sid_retriever_test_parts[0]}-{cfg.train.sid_retriever_test_parts[1]}T_" + f"{cfg.train.sid_retriever_eval_parts[0]}-{cfg.train.sid_retriever_eval_parts[1]}TE_" f"items-{allowed_parts[0]}-{allowed_parts[1]}" ) rqvae_split_name = ( f"{cfg.train.rqvae_train_parts[0]}-{cfg.train.rqvae_train_parts[1]}TR_" - f"{cfg.train.rqvae_val_parts[0]}-{cfg.train.rqvae_val_parts[1]}V_" - f"{cfg.train.rqvae_test_parts[0]}-{cfg.train.rqvae_test_parts[1]}T" + f"{cfg.train.rqvae_eval_parts[0]}-{cfg.train.rqvae_eval_parts[1]}TE" ) - experiment_name = f"sid-retriever_{cfg.dataset.name}_{sid_retriever_split_name}_{rqvae_split_name}" + experiment_name = ( + f"sid-retriever-{cfg.dataset.rqvae.model_name}_" + f"{cfg.dataset.name}_{sid_retriever_split_name}_" + f"{rqvae_split_name}" + ) results_path = Path(cfg.paths.results_dir) / rqvae_split_name / f"rqvae-{cfg.dataset.rqvae.model_name}" @@ -91,8 +93,7 @@ def train_model(cfg: DictConfig): all_interactions_path=consts["INTERACTIONS_PATH"], all_embeddings_path=consts["EMBEDDINGS_PATH"], train_parts=cfg.train.sid_retriever_train_parts, - val_parts=cfg.train.sid_retriever_val_parts, - test_parts=cfg.train.sid_retriever_test_parts, + eval_parts=cfg.train.sid_retriever_eval_parts, max_seq_len=cfg.model.max_seq_len, ) @@ -101,7 +102,7 @@ def train_model(cfg: DictConfig): ) eval_dataset = TigerEvalDataset( - data.test_samples, all_semantics_mapping_array, cfg.model.num_codebooks, cfg.model.num_user_hash + data.eval_samples, all_semantics_mapping_array, cfg.model.num_codebooks, cfg.model.num_user_hash ) train_dataloader = DataLoader( @@ -153,14 +154,6 @@ def train_model(cfg: DictConfig): tensorboard_logger = TensorboardLogger(experiment_name=consts["EXPERIMENT_NAME"], logdir=cfg.paths.tensorboard_dir) - early_stopper = EarlyStopper( - metric=cfg.training.metric, - patience=cfg.training.patience, - minimize=cfg.training.minimize_metric, - checkpoints_dir=cfg.paths.checkpoints_dir, - experiment_name=consts["EXPERIMENT_NAME"], - ) - logger.debug("Everything is ready for training process!") for epoch in range(cfg.training.num_epochs): @@ -177,26 +170,26 @@ def train_model(cfg: DictConfig): losses.append(outputs["loss"].item()) - all_metrics = {"train/loss": sum(losses) / len(losses)} - - if (epoch + 1) % 2 == 0: - logger.info("Doing evaluation") + train_metrics = {"train/loss": sum(losses) / len(losses)} + all_metrics = train_metrics.copy() + if (epoch + 1) % 4 == 0: + logger.info("Doing test evaluation") eval_metrics = run_evaluation(model, eval_dataloader, "eval/") all_metrics.update(eval_metrics) tensorboard_logger.add_metrics((epoch + 1) * (batch_idx + 1), all_metrics) - - if early_stopper.check(all_metrics[cfg.training.metric], model): - logger.info("Early stopping triggered") - break else: tensorboard_logger.add_metrics((epoch + 1) * (batch_idx + 1), all_metrics) tensorboard_logger.close() - best_model_file = early_stopper.get_best_model_path() - logger.info(f"Best model path is: {best_model_file}") + Path(cfg.paths.checkpoints_dir).mkdir(parents=True, exist_ok=True) + last_model_path = ( + Path(cfg.paths.checkpoints_dir) / f"{consts['EXPERIMENT_NAME']}_{tensorboard_logger.timestamp}.pth" + ) + torch.save(model.state_dict(), last_model_path) + logger.info(f"Last model saved to: {last_model_path}") logger.info("Training completed successfully!")