diff --git a/drevalpy/models/PaccMann/__init__.py b/drevalpy/models/PaccMann/__init__.py new file mode 100644 index 00000000..524763c2 --- /dev/null +++ b/drevalpy/models/PaccMann/__init__.py @@ -0,0 +1,5 @@ +"""Module for the Paccmann model.""" + +from .paccmann import PaccMann + +__all__ = ["PaccMann"] diff --git a/drevalpy/models/PaccMann/hyperparameters.yaml b/drevalpy/models/PaccMann/hyperparameters.yaml new file mode 100644 index 00000000..ee4ca625 --- /dev/null +++ b/drevalpy/models/PaccMann/hyperparameters.yaml @@ -0,0 +1,51 @@ +PaccMann: + epochs: + - 3 + batch_size: + - 64 + learning_rate: + - 0.001 + weight_decay: + - 0.0 + + smiles_embedding_size: + - 8 + + filters: + - [16, 16, 16] + + molecule_heads: + - [2, 2, 2, 2] + + gene_heads: + - [2, 2, 2, 2] + + smiles_padding_length: + - 128 + + dropout: + - 0.5 + + batch_norm: + - true + + activation_fn: + - relu + + loss_fn: + - mse + + smiles_attention_size: + - 64 + + gene_attention_size: + - 1 + + molecule_temperature: + - 1.0 + + gene_temperature: + - 1.0 + + stacked_dense_hidden_sizes: + - [512, 256] diff --git a/drevalpy/models/PaccMann/paccmann.py b/drevalpy/models/PaccMann/paccmann.py new file mode 100644 index 00000000..dcad9014 --- /dev/null +++ b/drevalpy/models/PaccMann/paccmann.py @@ -0,0 +1,420 @@ +"""PaccMann model.""" + +from __future__ import annotations + +import json +import os +from typing import Any + +import joblib +import numpy as np +import torch +from sklearn.preprocessing import StandardScaler +from torch.utils.data import DataLoader, TensorDataset + +from drevalpy.datasets.dataset import DrugResponseDataset, FeatureDataset +from drevalpy.models.drp_model import DRPModel +from drevalpy.models.utils import load_and_select_gene_features + +from .paccmann_v2 import PaccMannV2 + + +class PaccMann(DRPModel): + """PaccMann model for drug response prediction. + + This DrEval wrapper combines cell line gene expression features and tokenized SMILES representations of drugs + and uses the PaccMannV2 neural network to predict drug response values. + + This wrapper: + - loads gene expression features for cell lines + - loads SMILES strings for drugs + - tokenizes SMILES into padded integer sequences + - scales gene expression on training data only + - trains a PaccMannV2 PyTorch model + """ + + early_stopping = True + is_single_drug_model = False + + cell_line_views = ["gene_expression"] + drug_views = ["smiles"] + + def __init__(self) -> None: + """Initialize the PaccMann model wrapper. + + Initialized attributes: + model: stores the PaccMann neural network + hyperparameters: stores the passed hyperparameters + device: CPU or GPU device + gene_expression_scaler: scaler fitted on training gene expression + smiles_to_idx: SMILES vocabulary + padding_idx: index used for padding tokens + unk_idx: index for unknown tokens + smiles_padding_length: sequence length used for padding + number_of_genes: number of gene features + """ + super().__init__() + self.model: PaccMannV2 | None = None + self.hyperparameters: dict[str, Any] | None = None + + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.gene_expression_scaler = StandardScaler() + + self.smiles_to_idx: dict[str, int] = { + "": 0, + "": 1, + } + self.padding_idx = 0 + self.unk_idx = 1 + + self.smiles_padding_length: int | None = None + self.number_of_genes: int | None = None + + @classmethod + def get_model_name(cls) -> str: + """Return the model name. + + :return: model name + """ + return "PaccMann" + + def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureDataset: + """Load gene expression features. + + :param data_path: path to the data directory + :param dataset_name: name of the dataset + :return: FeatureDataset containing gene expression features + """ + return load_and_select_gene_features( + feature_type="gene_expression", + data_path=data_path, + dataset_name=dataset_name, + gene_list="gene_list_paccmann_network_prop_reduced", + ) + + def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDataset: + """Load raw SMILES features. + + :param data_path: path to the data directory + :param dataset_name: name of the dataset + :return: FeatureDataset containing SMILES features + """ + return FeatureDataset.from_csv( + path_to_csv=f"{data_path}/{dataset_name}/drug_smiles.csv", + id_column="pubchem_id", + view_name="smiles", + drop_columns=["drug_name", "cactvs_fingerprint", "fingerprint"], + ) + + def build_model(self, hyperparameters: dict[str, Any]) -> None: + """Store hyperparameters for later model initialization. + + The actual PaccMannV2 network is initialized in train(), + because the number of genes depends on the loaded training data. + + :param hyperparameters: dictionary containing model hyperparameters + """ + self.hyperparameters = hyperparameters + + def _normalize_smiles_array(self, smiles_raw: np.ndarray) -> list[str]: + """Convert SMILES output from FeatureDataset into a list of strings. + + :param smiles_raw: raw SMILES array + :return: list of SMILES strings + """ + smiles_raw = np.asarray(smiles_raw, dtype=object) + + if smiles_raw.ndim == 2 and smiles_raw.shape[1] == 1: + smiles_raw = smiles_raw[:, 0] + + smiles_list = [] + for smile in smiles_raw: + if smile is None: + smiles_list.append("") + else: + smiles_list.append(str(smile)) + return smiles_list + + def _build_smiles_vocab(self, smiles_list: list[str]) -> None: + """Build a character-level vocabulary from training SMILES strings. + + :param smiles_list: list of SMILES strings + """ + for smile in smiles_list: # Build vocabulary: "C", "O", "=" ... -> {"C": 2, "O": 3, "=": 4} + for char in smile: + if char not in self.smiles_to_idx: + self.smiles_to_idx[char] = len(self.smiles_to_idx) + + def _encode_smiles(self, smiles_list: list[str]) -> np.ndarray: + """Encode SMILES strings as padded integer sequences. + + :param smiles_list: list of SMILES strings + :return: encoded SMILES array + :raises ValueError: if smiles_padding_length is not set + """ + if self.smiles_padding_length is None: + raise ValueError("smiles_padding_length is not set.") + + encoded = np.full( + (len(smiles_list), self.smiles_padding_length), + fill_value=self.padding_idx, + dtype=np.int64, + ) + + for i, smile in enumerate(smiles_list): + token_ids = [self.smiles_to_idx.get(char, self.unk_idx) for char in smile] # "CCO" -> [2, 2, 3] + token_ids = token_ids[: self.smiles_padding_length] + encoded[i, : len(token_ids)] = token_ids # Padding: [2,2,2] -> [2,2,3,0,0,0,...] + + return encoded + + def train( + self, + output: DrugResponseDataset, + cell_line_input: FeatureDataset, + drug_input: FeatureDataset | None = None, + output_earlystopping: DrugResponseDataset | None = None, + model_checkpoint_dir: str | None = None, + ) -> None: + """Train the PaccMann model on gene DrEval data. + + Procedure: + - get gene expression data for all the cell lines + - get raw SMILES for the drugs + - scale gene expression features + - build a SMILES vocabulary + - encode and pad the SMILES strings + - initialize the PaccMann network + - convert both inputs to tensors + - train the network + + :param output: training dataset containing response values, cell line ids, and drug ids + :param cell_line_input: FeatureDataset containing cell line features + :param drug_input: FeatureDataset containing drug features + :param output_earlystopping: optional early stopping dataset + :param model_checkpoint_dir: optional directory to save a model checkpoint + :raises ValueError: if drug_input is None + :raises ValueError: if the model has not been built yet + """ + if drug_input is None: + raise ValueError("drug_input (SMILES) is required for PaccMann.") + + if self.hyperparameters is None: + raise ValueError("Model has not been built yet. Call build_model first.") + + # Retrieve gene expression features for the training cell lines + gex = cell_line_input.get_feature_matrix("gene_expression", output.cell_line_ids) + + # Retrieve raw SMILES features for the corresponding drugs + smiles_raw = drug_input.get_feature_matrix("smiles", output.drug_ids) + + # Target vector containing the drug response values + y = output.response + + # Convert to numpy arrays with explicit dtypes + gex = np.asarray(gex, dtype=np.float32) + y = np.asarray(y, dtype=np.float32) + + # Convert SMILES to a list of strings + smiles = self._normalize_smiles_array(smiles_raw) + + # Scale gene expression on training data only + gex = self.gene_expression_scaler.fit_transform(gex).astype(np.float32) + + # Build SMILES vocabulary from training data only + self.smiles_to_idx = { + "": 0, + "": 1, + } + self._build_smiles_vocab(smiles) + + # Determine SMILES padding length + if "smiles_padding_length" in self.hyperparameters: + self.smiles_padding_length = int(self.hyperparameters["smiles_padding_length"]) + else: + self.smiles_padding_length = max(len(smile) for smile in smiles) + + # Encode and pad SMILES strings + smiles_encoded = self._encode_smiles(smiles) + + # Copy hyperparameters and adapt the number of genes to the training data + model_params = dict(self.hyperparameters) + model_params["number_of_genes"] = gex.shape[1] + model_params["smiles_padding_length"] = self.smiles_padding_length + model_params["smiles_vocabulary_size"] = len(self.smiles_to_idx) + self.number_of_genes = gex.shape[1] + + # Build the PaccMann neural network + self.model = PaccMannV2(model_params).to(self.device) + + # Convert all inputs to tensors + smiles_tensor = torch.tensor(smiles_encoded, dtype=torch.long) + gex_tensor = torch.tensor(gex, dtype=torch.float32) + y_tensor = torch.tensor(y, dtype=torch.float32).view(-1, 1) + + # Create PyTorch dataset and dataloader + dataset = TensorDataset(smiles_tensor, gex_tensor, y_tensor) + train_loader = DataLoader( + dataset, + batch_size=model_params.get("batch_size", 64), + shuffle=True, + ) + + # Initialize optimizer + optimizer = torch.optim.Adam( + self.model.parameters(), + lr=model_params.get("learning_rate", 1e-3), + weight_decay=model_params.get("weight_decay", 0.0), + ) + + epochs = model_params.get("epochs", 20) + + # Train the model + for _ in range(epochs): + self.model.train() + for batch_smiles, batch_gex, batch_y in train_loader: + batch_smiles = batch_smiles.to(self.device) + batch_gex = batch_gex.to(self.device) + batch_y = batch_y.to(self.device) + + optimizer.zero_grad() + + predictions, _ = self.model(batch_smiles, batch_gex) + loss = self.model.loss(predictions, batch_y) + + loss.backward() + optimizer.step() + + # Optional: save trained model checkpoint + if self.model is not None and model_checkpoint_dir is not None: + self.model.save(f"{model_checkpoint_dir}/paccmann.pt") + + def predict( + self, + cell_line_ids: np.ndarray, + drug_ids: np.ndarray, + cell_line_input: FeatureDataset, + drug_input: FeatureDataset | None = None, + ) -> np.ndarray: + """Predict drug response values. + + Procedure: + - load appropriate cell lines features and drug features + - scale gene expression features + - encode and pad SMILES strings + - convert inputs to tensors + - run the trained PaccMann model in evaluation mode + - return predicted drug response values + + :param cell_line_ids: array of cell line identifiers + :param drug_ids: array of drug identifiers + :param cell_line_input: FeatureDataset containing cell line features + :param drug_input: FeatureDataset containing drug features + :return: predicted drug response values + :raises ValueError: if drug_input is None + :raises ValueError: if the model has not been trained yet + """ + if drug_input is None: + raise ValueError("drug_input (SMILES) is required for PaccMann.") + + if self.model is None: + raise ValueError("Model has not been trained yet.") + + # Retrieve gene expression features + gex = cell_line_input.get_feature_matrix("gene_expression", cell_line_ids) + + # Retrieve raw SMILES features + smiles_raw = drug_input.get_feature_matrix("smiles", drug_ids) + + # Convert gene expression to numpy array + gex = np.asarray(gex, dtype=np.float32) + + # Convert SMILES to a list of strings + smiles = self._normalize_smiles_array(smiles_raw) + + # Apply the fitted gene expression scaler + gex = self.gene_expression_scaler.transform(gex).astype(np.float32) + + # Encode and pad SMILES strings using the training vocabulary + smiles_encoded = self._encode_smiles(smiles) + + # Convert inputs to tensors + smiles_tensor = torch.tensor(smiles_encoded, dtype=torch.long, device=self.device) + gex_tensor = torch.tensor(gex, dtype=torch.float32, device=self.device) + + # Predict drug response values + self.model.eval() + with torch.no_grad(): + predictions, _ = self.model(smiles_tensor, gex_tensor) + + return predictions.cpu().numpy().reshape(-1) + + def save(self, path: str) -> None: + """Save the trained PaccMann wrapper. + + Saved files: + - model.pt: trained model weights + - config.json: model hyperparameters + - scaler.pkl: fitted gene expression scaler + - vocab.json: SMILES vocabulary + - meta.json: additional metadata needed for loading + + :param path: directory where the model should be saved + :raises ValueError: if no model is available + """ + os.makedirs(path, exist_ok=True) + + if self.model is None: + raise ValueError("No model to save.") + + torch.save(self.model.state_dict(), f"{path}/model.pt") + + with open(f"{path}/config.json", "w") as f: + json.dump(self.hyperparameters, f) + + joblib.dump(self.gene_expression_scaler, f"{path}/scaler.pkl") + + with open(f"{path}/vocab.json", "w") as f: + json.dump(self.smiles_to_idx, f) + + with open(f"{path}/meta.json", "w") as f: + json.dump( + { + "padding_length": self.smiles_padding_length, + "num_genes": self.number_of_genes, + }, + f, + ) + + @classmethod + def load(cls, path: str) -> PaccMann: + """Load a trained PaccMann wrapper. + + :param path: directory containing the saved model files + :return: loaded PaccMann instance + """ + instance = cls() + + with open(f"{path}/config.json") as f: + instance.hyperparameters = json.load(f) + + instance.gene_expression_scaler = joblib.load(f"{path}/scaler.pkl") + + with open(f"{path}/vocab.json") as f: + instance.smiles_to_idx = json.load(f) + + with open(f"{path}/meta.json") as f: + meta = json.load(f) + instance.smiles_padding_length = meta["padding_length"] + instance.number_of_genes = meta["num_genes"] + + params = dict(instance.hyperparameters) + params["smiles_padding_length"] = instance.smiles_padding_length + params["smiles_vocabulary_size"] = len(instance.smiles_to_idx) + params["number_of_genes"] = instance.number_of_genes + + instance.model = PaccMannV2(params).to(instance.device) + instance.model.load_state_dict(torch.load(f"{path}/model.pt", map_location=instance.device)) # noqa: S614 + instance.model.eval() + + return instance diff --git a/drevalpy/models/PaccMann/paccmann_v2.py b/drevalpy/models/PaccMann/paccmann_v2.py new file mode 100644 index 00000000..07a7c77e --- /dev/null +++ b/drevalpy/models/PaccMann/paccmann_v2.py @@ -0,0 +1,385 @@ +"""Contains the PaccMannV2 model for drug response prediction. + +The model is based on the PaccMann framework for predicting anticancer drug sensitivity +from SMILES strings and gene expression data. + +Original PaccMann repository: +https://github.com/PaccMann/paccmann_predictor +""" + +import logging +import sys +from collections import OrderedDict + +import pytoda +import torch +import torch.nn as nn +from pytoda.smiles.transforms import AugmentTensor + +from .utils.hyperparams import ACTIVATION_FN_FACTORY, LOSS_FN_FACTORY +from .utils.interpret import monte_carlo_dropout, test_time_augmentation +from .utils.layers import ContextAttentionLayer, convolutional_layer, dense_layer +from .utils.utils import get_device, get_log_molar + +# setup logging +logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) +logger = logging.getLogger(__name__) + + +class PaccMannV2(nn.Module): + """PaccMannV2 model for drug response prediction. + + Based on the MCA model in Molecular Pharmaceutics: + https://pubs.acs.org/doi/10.1021/acs.molpharmaceut.9b00520. + + Main idea: + - SMILES strings are embedded and processed with convolutional layers + - gene expression is used as biological context + - context attention connects gene and drug information + - the combined representation is passed through dense layers + - output is a predicted drug sensitivity value + """ + + def __init__(self, params, *args, **kwargs): + """Initialize the PaccMannV2 model. + + :param params: A dictionary containing the parameter to built the dense encoder. + :param args: additional positional arguments passed to nn.Module + :param kwargs: additional keyword arguments passed to nn.Module + + Items in params: + - smiles_padding_length (int): Padding length for SMILES. + - smiles_embedding_size (int): dimension of tokens' embedding. + - smiles_vocabulary_size (int): size of the tokens vocabulary. + - activation_fn (string, optional): Activation function used in all layers + for specification in ACTIVATION_FN_FACTORY. Defaults to 'relu'. + - batch_norm (bool, optional): Whether batch normalization is applied. + Defaults to True. + - dropout (float, optional): Dropout probability in all except parametric layer. Defaults to 0.5. + - filters (list[int], optional): Numbers of filters to learn per + SMILES convolutional layer. Defaults to [64, 64, 64]. + - kernel_sizes (list[list[int]], optional): Sizes of kernels per SMILES convolutional layer. + Defaults to + [[3, params['smiles_embedding_size']], + [5, params['smiles_embedding_size']], + [11, params['smiles_embedding_size']]] + - molecule_heads (list[int], optional): Amount of attentive molecule_heads + per SMILES embedding. Should have len(filters)+1. + Defaults to [4, 4, 4, 4]. + - stacked_dense_hidden_sizes (list[int], optional): Sizes of the + hidden dense layers. Defaults to [1024, 512]. + - smiles_attention_size (int, optional): size of the attentive layer + for the smiles sequence. Defaults to 64. + + :raises ValueError: if the attention head settings or convolution settings do not match + """ + super().__init__(*args, **kwargs) + + # model parameter + self.device = get_device() + self.params = params + + # select loss function + self.loss_fn = LOSS_FN_FACTORY[params.get("loss_fn", "mse")] + + # scaling information + self.min_max_scaling = True if params.get("drug_sensitivity_processing_parameters", {}) != {} else False + if self.min_max_scaling: + self.IC50_max = params["drug_sensitivity_processing_parameters"]["parameters"]["max"] # yapf: disable + self.IC50_min = params["drug_sensitivity_processing_parameters"]["parameters"]["min"] # yapf: disable + + # input sizes + self.smiles_padding_length = params["smiles_padding_length"] + self.number_of_genes = params.get("number_of_genes", 2128) + + # attention settings + self.smiles_attention_size = params.get("smiles_attention_size", 64) + self.gene_attention_size = params.get("gene_attention_size", 1) + + self.molecule_temperature = params.get("molecule_temperature", 1.0) + self.gene_temperature = params.get("gene_temperature", 1.0) + + # model architecture (hyperparameter) + self.molecule_heads = params.get("molecule_heads", [4, 4, 4, 4]) + self.gene_heads = params.get("gene_heads", [2, 2, 2, 2]) + + if len(self.gene_heads) != len(self.molecule_heads): + raise ValueError("Length of gene and molecule_heads do not match.") + + self.filters = params.get("filters", [64, 64, 64]) + + # size of dense input + self.hidden_sizes = [ + self.molecule_heads[0] * params["smiles_embedding_size"] + + sum([h * f for h, f in zip(self.molecule_heads[1:], self.filters)]) + + sum(self.gene_heads) * self.number_of_genes + ] + params.get("stacked_dense_hidden_sizes", [1024, 512]) + + # general NN settings + self.dropout = params.get("dropout", 0.5) + self.temperature = params.get("temperature", 1.0) + self.act_fn = ACTIVATION_FN_FACTORY[params.get("activation_fn", "relu")] + + # Default convolution kernel size + self.kernel_sizes = params.get( + "kernel_sizes", + [ + [3, params["smiles_embedding_size"]], + [5, params["smiles_embedding_size"]], + [11, params["smiles_embedding_size"]], + ], + ) + if len(self.filters) != len(self.kernel_sizes): + raise ValueError("Length of filter and kernel size lists do not match.") + + if len(self.filters) + 1 != len(self.molecule_heads): + raise ValueError("Length of filter and multihead lists do not match") + + # Build the model + self.smiles_embedding = nn.Embedding( + self.params["smiles_vocabulary_size"], + self.params["smiles_embedding_size"], + scale_grad_by_freq=params.get("embed_scale_grad", False), + ) + + # Convolution layers over embedded SMILES + self.convolutional_layers = nn.Sequential( + OrderedDict( + [ + ( + f"convolutional_{index}", + convolutional_layer( + num_kernel, + kernel_size, + act_fn=self.act_fn, + batch_norm=params.get("batch_norm", False), + dropout=self.dropout, + ).to(self.device), + ) + for index, (num_kernel, kernel_size) in enumerate(zip(self.filters, self.kernel_sizes)) + ] + ) + ) + + # Hidden size of each SMILES representation stage: raw embedding + outputs of conv layer + smiles_hidden_sizes = [params["smiles_embedding_size"]] + self.filters + + # Attention layers: gene context -> SMILES (focus on relevant molecule parts) + self.molecule_attention_layers = nn.Sequential( + OrderedDict( + [ + ( + f"molecule_attention_{layer}_head_{head}", + ContextAttentionLayer( + reference_hidden_size=smiles_hidden_sizes[layer], + reference_sequence_length=self.smiles_padding_length, + context_hidden_size=1, + context_sequence_length=self.number_of_genes, + attention_size=self.smiles_attention_size, + individual_nonlinearity=params.get("context_nonlinearity", nn.Sequential()), + temperature=self.molecule_temperature, + ), + ) + for layer in range(len(self.molecule_heads)) + for head in range(self.molecule_heads[layer]) + ] + ) + ) # yapf: disable + + # Attention layers: SMILES -> gene expression (focus on relevant genes) + self.gene_attention_layers = nn.Sequential( + OrderedDict( + [ + ( + f"gene_attention_{layer}_head_{head}", + ContextAttentionLayer( + reference_hidden_size=1, + reference_sequence_length=self.number_of_genes, + context_hidden_size=smiles_hidden_sizes[layer], + context_sequence_length=self.smiles_padding_length, + attention_size=self.gene_attention_size, + individual_nonlinearity=params.get("context_nonlinearity", nn.Sequential()), + temperature=self.gene_temperature, + ), + ) + for layer in range(len(self.molecule_heads)) + for head in range(self.gene_heads[layer]) + ] + ) + ) # yapf: disable + + # Batch normalization for the concatenated attention output + # Only applied if params['batch_norm'] = True + self.batch_norm = nn.BatchNorm1d(self.hidden_sizes[0]) + + # Dense layers after attention + self.dense_layers = nn.Sequential( + OrderedDict( + [ + ( + f"dense_{ind}", + dense_layer( + self.hidden_sizes[ind], + self.hidden_sizes[ind + 1], + act_fn=self.act_fn, + dropout=self.dropout, + batch_norm=params.get("batch_norm", True), + ).to(self.device), + ) + for ind in range(len(self.hidden_sizes) - 1) + ] + ) + ) + + # Final output layer + self.final_dense = ( + nn.Linear(self.hidden_sizes[-1], 1) + if not params.get("final_activation", False) + else nn.Sequential( + OrderedDict( + [ + ("projection", nn.Linear(self.hidden_sizes[-1], 1)), + ("sigmoidal", ACTIVATION_FN_FACTORY["sigmoid"]), + ] + ) + ) + ) + + def forward(self, smiles, gep, confidence=False): + """Forward pass through the PaccMannV2. + + :param smiles: tokenized SMILES tensor of shape [bs, smiles_padding_length] + :param gep: gene expression tensor of shape [bs, number_of_genes] + :param confidence: whether confidence estimation should be performed + :return: + - predictions: tensor of shape [batch_size, 1] + - prediction_dict: dictionary with predictions and optional attention/confidence outputs + """ + # reshape gene input + gep = torch.unsqueeze(gep, dim=-1) + embedded_smiles = self.smiles_embedding(smiles.to(dtype=torch.int64)) + + # SMILES Convolutions. Unsqueeze has shape bs x 1 x T x H. + encoded_smiles = [embedded_smiles] + [ + self.convolutional_layers[ind](torch.unsqueeze(embedded_smiles, 1)).permute(0, 2, 1) + for ind in range(len(self.convolutional_layers)) + ] + + # Molecule context attention + encodings, smiles_alphas, gene_alphas = [], [], [] + for layer in range(len(self.molecule_heads)): + for head in range(self.molecule_heads[layer]): + + ind = self.molecule_heads[0] * layer + head + e, a = self.molecule_attention_layers[ind](encoded_smiles[layer], gep) + encodings.append(e) + smiles_alphas.append(a) + + # Gene context attention + for layer in range(len(self.gene_heads)): + for head in range(self.gene_heads[layer]): + ind = self.gene_heads[0] * layer + head + + e, a = self.gene_attention_layers[ind](gep, encoded_smiles[layer], average_seq=False) + encodings.append(e) + gene_alphas.append(a) + + # concat features + encodings = torch.cat(encodings, dim=1) + + # Apply batch normalization if specified + inputs = self.batch_norm(encodings) if self.params.get("batch_norm", False) else encodings + # NOTE: stacking dense layers as a bottleneck + for dl in self.dense_layers: + inputs = dl(inputs) + + # prediction + predictions = self.final_dense(inputs) + prediction_dict = {} + + if not self.training: + # The below is to ease postprocessing + smiles_attention = torch.cat([torch.unsqueeze(p, -1) for p in smiles_alphas], dim=-1) + gene_attention = torch.cat([torch.unsqueeze(p, -1) for p in gene_alphas], dim=-1) + prediction_dict.update( + { + "gene_attention": gene_attention, + "smiles_attention": smiles_attention, + "IC50": predictions, + "log_micromolar_IC50": ( + get_log_molar(predictions, ic50_max=self.IC50_max, ic50_min=self.IC50_min) + if self.min_max_scaling + else predictions + ), + } + ) # yapf: disable + + if confidence: + augmenter = AugmentTensor(self.smiles_language) + epi_conf, epi_pred = monte_carlo_dropout(self, regime="tensors", tensors=(smiles, gep), repetitions=5) + ale_conf, ale_pred = test_time_augmentation( + self, + regime="tensors", + tensors=(smiles, gep), + repetitions=5, + augmenter=augmenter, + tensors_to_augment=0, + ) + + prediction_dict.update( + { + "epistemic_confidence": epi_conf, + "epistemic_predictions": epi_pred, + "aleatoric_confidence": ale_conf, + "aleatoric_predictions": ale_pred, + } + ) # yapf: disable + + elif confidence: + logger.info("Using confidence in training mode is not supported.") + + return predictions, prediction_dict + + def loss(self, yhat, y): + """Compute the loss between predictions and targets. + + :param yhat: predicted values + :param y: true target values + :return: loss value + """ + return self.loss_fn(yhat, y) + + def _associate_language(self, smiles_language): + """Bind a SMILES language object to the model. + + Is only used inside the confidence estimation. + + :param smiles_language: pytoda SMILESLanguage object + :raises TypeError: if the passed object is not a valid SMILESLanguage + """ + if not isinstance(smiles_language, pytoda.smiles.smiles_language.SMILESLanguage): + raise TypeError( + "Please insert a smiles language (object of type " + "pytoda.smiles.smiles_language.SMILESLanguage). Given was " + f"{type(smiles_language)}" + ) + self.smiles_language = smiles_language + + def load(self, path, *args, **kwargs): + """Load model from path. + + :param path: path to the saved model file + :param args: additional positioinal arguments passed to torch.load + :param kwargs: additional keyword arguments passed to torch.load + """ + weights = torch.load(path, *args, **kwargs) # noqa: S614 + self.load_state_dict(weights) + + def save(self, path, *args, **kwargs): + """Save model to path. + + :param path: path where the model should be saved + :param args: additional positional arguments passed to torch.save + :param kwargs: additional keyword arguments passed to torch.save + """ + torch.save(self.state_dict(), path, *args, **kwargs) diff --git a/drevalpy/models/PaccMann/pytoda/__init__.py b/drevalpy/models/PaccMann/pytoda/__init__.py new file mode 100644 index 00000000..0371be41 --- /dev/null +++ b/drevalpy/models/PaccMann/pytoda/__init__.py @@ -0,0 +1,5 @@ +"""Module for pytoda functionality used in the PaccMann model.""" + +from . import smiles + +__all__ = ["smiles"] diff --git a/drevalpy/models/PaccMann/pytoda/smiles/__init__.py b/drevalpy/models/PaccMann/pytoda/smiles/__init__.py new file mode 100644 index 00000000..fa32bb71 --- /dev/null +++ b/drevalpy/models/PaccMann/pytoda/smiles/__init__.py @@ -0,0 +1,5 @@ +"""Module for SMILES handling in the PaccMann model.""" + +from .smiles_language import SMILESLanguage + +__all__ = ["SMILESLanguage"] diff --git a/drevalpy/models/PaccMann/pytoda/smiles/smiles_language.py b/drevalpy/models/PaccMann/pytoda/smiles/smiles_language.py new file mode 100644 index 00000000..873bbc8d --- /dev/null +++ b/drevalpy/models/PaccMann/pytoda/smiles/smiles_language.py @@ -0,0 +1,7 @@ +"""Module for SMILESLanguage used in the PaccMann model.""" + + +class SMILESLanguage: + """SMILES language representation.""" + + pass diff --git a/drevalpy/models/PaccMann/pytoda/smiles/transforms/__init__.py b/drevalpy/models/PaccMann/pytoda/smiles/transforms/__init__.py new file mode 100644 index 00000000..4d737236 --- /dev/null +++ b/drevalpy/models/PaccMann/pytoda/smiles/transforms/__init__.py @@ -0,0 +1,5 @@ +"""Module for SMILES transformations in the PaccMann model.""" + +from .augment import AugmentTensor + +__all__ = ["AugmentTensor"] diff --git a/drevalpy/models/PaccMann/pytoda/smiles/transforms/augment.py b/drevalpy/models/PaccMann/pytoda/smiles/transforms/augment.py new file mode 100644 index 00000000..76403aac --- /dev/null +++ b/drevalpy/models/PaccMann/pytoda/smiles/transforms/augment.py @@ -0,0 +1,20 @@ +"""Module for tensor transformations in the PaccMann model.""" + + +class AugmentTensor: + """Tensor transformation class.""" + + def __init__(self, smiles_language): + """Initialize the transformation. + + :param smiles_language: SMILES language object + """ + self.smiles_language = smiles_language + + def __call__(self, tensor): + """Apply the transformation. + + :param tensor: input tensor + :return: unchanged tensor + """ + return tensor diff --git a/drevalpy/models/PaccMann/utils/__init__.py b/drevalpy/models/PaccMann/utils/__init__.py new file mode 100644 index 00000000..2e0a46f6 --- /dev/null +++ b/drevalpy/models/PaccMann/utils/__init__.py @@ -0,0 +1 @@ +"""Utility modules for PaccMann.""" diff --git a/drevalpy/models/PaccMann/utils/hyperparams.py b/drevalpy/models/PaccMann/utils/hyperparams.py new file mode 100644 index 00000000..8711e8f0 --- /dev/null +++ b/drevalpy/models/PaccMann/utils/hyperparams.py @@ -0,0 +1,42 @@ +"""Customizable model hyperparameters.""" + +import torch.nn as nn +import torch.optim as optim + +from drevalpy.models.PaccMann.utils.loss_functions import ( + correlation_coefficient_loss, + mse_cc_loss, +) + +# LSTM(10, 20, 2) -> input has 10 features, 20 hidden size and 2 layers. +# NOTE: Make sure to set batch_first=True. Optionally set bidirectional=True +RNN_CELL_FACTORY = {"lstm": nn.LSTM, "gru": nn.GRU} + +LOSS_FN_FACTORY = { + "mse": nn.MSELoss(), + "l1": nn.L1Loss(), + "mse_and_pearson": mse_cc_loss, + "pearson": correlation_coefficient_loss, + "binary_cross_entropy": nn.BCELoss(), +} + +ACTIVATION_FN_FACTORY = { + "relu": nn.ReLU(), + "sigmoid": nn.Sigmoid(), + "selu": nn.SELU(), + "tanh": nn.Tanh(), + "lrelu": nn.LeakyReLU(), + "elu": nn.ELU(), +} +OPTIMIZER_FACTORY = { + "adam": optim.Adam, + "adadelta": optim.Adadelta, + "adagrad": optim.Adagrad, + "gd": optim.SGD, + "sparseadam": optim.SparseAdam, + "adamax": optim.Adamax, + "asgd": optim.ASGD, + "lbfgs": optim.LBFGS, + "rmsprop": optim.RMSprop, + "rprop": optim.Rprop, +} diff --git a/drevalpy/models/PaccMann/utils/interpret.py b/drevalpy/models/PaccMann/utils/interpret.py new file mode 100644 index 00000000..fc60812e --- /dev/null +++ b/drevalpy/models/PaccMann/utils/interpret.py @@ -0,0 +1,217 @@ +"""Utility functions for uncertainty estimation in PaccMann models.""" + +import torch +from torch import Tensor, nn + +from .utils import get_device + +# We use standard deviation to measure uncertainty since entropy is not +# defined for continuous variables and differential entropy is not ideal. +# In case all predictions are identical, std is 0. If 50% are 0 and 50% are +# one, it is maximal, i.e. 0.5. +MAX_STD = 0.5 +MIN_STD = 0.0 + +DEVICE = get_device() + + +def map_to_device(inputs: tuple[Tensor, ...]) -> tuple[Tensor, ...]: + """Move all input tensors to the configured device. + + :param inputs: Tuple of input tensors + :return: Tuple of tensors on the target device + """ + return tuple(x.to(DEVICE) for x in inputs) + + +def monte_carlo_dropout(model, regime="loader", loader=None, tensors=None, repetitions=20): # noqa C901 + """Attempts to approximate epistemic uncertainty through MC dropout. + + Performs Monte Carlo dropout for a given model and returns a list of + sample-wise confidence estimates. + This method can be used in two regimes, either by passing a dataloader + or by passing a tensor with the raw input to the model. + + :param model: Torch model to evaluate + :param regime: Either 'loader' or 'tensors' + :param loader: The dataset to be tested + The loader is expected to return a tuple with the last item + being the labels and all others the model inputs. + Is only used if 'regime'=='loader' + :param tensors: The input tensor(s) for the model + Can either be a single tensor or a tuple of tensors (in the right order) + :param repetitions: Amount of forward passes for each sample + + :return: Tuple (confidences, predictions) where confidences contain the inverse + normalized standard deviation of the MC dropout estimates. + :raises ValueError: If regime is invalid or tensor has an invalid type. + :raises AttributeError: If the loader does not use sequential sampling. + """ + if regime != "loader" and regime != "tensors": + raise ValueError("Choose regime from {'loader', 'tensors'}") + + # Activate dropout layers while keeping other rest in eval mode. + def enable_dropout(m): + if isinstance(m, nn.Dropout): + m.train() + + model.eval() + model.apply(enable_dropout) + + if regime == "loader": + + # Error handling + if not isinstance(loader.sampler, torch.utils.data.sampler.SequentialSampler): + raise AttributeError( + "Data loader does not use sequential sampling. Consider set" + "ting shuffle=False when instantiating the data loader." + ) + + # Run over all batches in the loader + + def call_fn(): + preds = [] + for inputs in loader: + # inputs is a tuple with the last element being the labels + # outs can be a n-tuple returned by the model + outs = model(*map_to_device(inputs[:-1])) + preds.append(outs[0].detach().cpu() if isinstance(outs, tuple) else outs.detach().cpu()) + + return torch.cat(preds) + + elif regime == "tensors": + + if not isinstance(tensors, tuple) and not isinstance(tensors, torch.Tensor): + raise ValueError("Tensor needs to either tuple or torch.Tensor") + + inputs = tensors if isinstance(tensors, tuple) else (tensors,) + + def call_fn(): + outs = model(*map_to_device(inputs)) + return outs[0] if isinstance(outs, tuple) else outs + + with torch.no_grad(): + predictions = [torch.unsqueeze(call_fn(), -1) for _ in range(repetitions)] + predictions = torch.cat(predictions, dim=-1) + + # Scale confidences to [0, 1] + confidences = -1 * ((predictions.std(dim=-1) - MIN_STD) / (MAX_STD - MIN_STD)) + 1 + + model.eval() + + return confidences, torch.mean(predictions, -1) + + +def test_time_augmentation( # noqa: C901 + model, + regime="loader", + loader=None, + tensors=None, + repetitions=20, + augmenter=None, + tensors_to_augment=None, +): + """Attempts to measure aleatoric uncertainty through augmentation during test time. + + It returns a list of sample-wise confidence estimates. + + This method can be used in two regimes, either by passing a dataloader + or by passing a tensor with the raw input to the model. + + :param model: The torch network to be investigated. + :param regime: Either 'loader' or 'tensors' + :param loader: The dataset to be tested + The loader is expected to return a tuple with the last item + being the labels and all others the model inputs. The loader should + natively perform data augmentation. + Is only used if 'regime'=='loader'. + :param tensors: The input tensor(s) for the model + Can either be a single tensor or a tuple of tensors (in the + right order) + :param repetitions: Amount of forward passes for each sample + :param augmenter: This can either be function that performs the augmentation, + e.g. an object of type + pytoda.smiles.AugmentTensor (if `tensors` represents a SMILES + tensor). Alternatively, it can also be a list of augmenters with + the same length like tensors_to_augment. + Only used if regime=='tensors'. + :param tensors_to_augment: This can either be an integer + pointing to the tensor to be augmented. E.g. tensors_to_augment = 0 + augments the first tensor in tensors. Can also be a list of the + same length as augmenter (if several augmentations should be + performed on several tensors simultaneously). + Only used if regime=='tensors'. + + :return: Tuple (confidences, predictions) where confidences contains + inverse normalized standard deviations and predictions contains mean + predictions across repetitions. + :raises ValueError: If regime is invalid, tensor inputs are invalid, + augmentation indices are invalid or the number of augmenters does + not match the number of tensors to augment. + :raises AttributeError: If the loader does not use sequential sampling. + """ + if regime != "loader" and regime != "tensors": + raise ValueError("Choose regime from {'loader', 'tensors'}") + + model.eval() + + if regime == "loader": + + # Error handling + if not isinstance(loader.sampler, torch.utils.data.sampler.SequentialSampler): + raise AttributeError( + "Data loader does not use sequential sampling. Consider set" + "ting shuffle=False when instantiating the data loader." + ) + + # Run over all batches in the loader + + def call_fn(): + preds = [] + for inputs in loader: + # inputs is a tuple with the last element being the labels + # outs can be a n-tuple returned by the model + outs = model(*map_to_device(inputs[:-1])) + preds.append(outs[0] if isinstance(outs, tuple) else outs) + + return torch.cat(preds) + + elif regime == "tensors": + + if not isinstance(tensors, tuple) and not isinstance(tensors, torch.Tensor): + raise ValueError("Tensor needs to either tuple or torch.Tensor") + if not isinstance(tensors_to_augment, list) and not isinstance(tensors_to_augment, int): + raise ValueError("tensors_to_augment needs to be list or int") + + # Convert input to common formats (tuples and lists) + tensors_to_augment = [tensors_to_augment] if isinstance(tensors_to_augment, int) else tensors_to_augment + inputs = tensors if isinstance(tensors, tuple) else (tensors,) + aug_fns = augmenter if isinstance(augmenter, tuple) else (augmenter,) + + # Error handling + if not len(aug_fns) == len(tensors_to_augment): + raise ValueError("Provide one augmenter for each tensor you want to augment.") + if max(tensors_to_augment) > len(inputs): + raise ValueError( + "tensors_to_augment should be indexes to the tensors used for " + f"augmentation. {max(tensors_to_augment)} is larger than " + f"length of inputs ({len(inputs)})." + ) + + def call_fn(): + # Perform augmentation on all designated functions + augmented_inputs = [ + (aug_fns[tensors_to_augment.index(ind)](tensor) if ind in tensors_to_augment else tensor) + for ind, tensor in enumerate(inputs) + ] + outs = model(*map_to_device(augmented_inputs)) + return outs[0] if isinstance(outs, tuple) else outs + + with torch.no_grad(): + predictions = [torch.unsqueeze(call_fn(), -1) for _ in range(repetitions)] + predictions = torch.cat(predictions, dim=-1) + + # Scale confidences to [0, 1] + confidences = -1 * ((predictions.std(dim=-1) - MIN_STD) / (MAX_STD - MIN_STD)) + 1 + + return torch.clamp(confidences, min=0), torch.mean(predictions, -1) diff --git a/drevalpy/models/PaccMann/utils/layers.py b/drevalpy/models/PaccMann/utils/layers.py new file mode 100644 index 00000000..bd37a3d0 --- /dev/null +++ b/drevalpy/models/PaccMann/utils/layers.py @@ -0,0 +1,297 @@ +"""Custom layers implementation.""" + +from collections import OrderedDict + +import torch +import torch.nn as nn + +from .utils import Squeeze, Temperature, Unsqueeze, get_device + +DEVICE = get_device() + + +def dense_layer( + input_size, + hidden_size, + act_fn=None, + batch_norm=False, + dropout=0.0, +): + """Build a dense layer block. + + :param input_size: Input feature size + :param hidden_size: Output feature size + :param act_fn: Activation module + :param batch_norm: whether batch normalization is applied + :param dropout: Dropout probability + :return: Sequential dense layer block + """ + if act_fn is None: + act_fn = nn.ReLU() + + return nn.Sequential( + OrderedDict( + [ + ("projection", nn.Linear(input_size, hidden_size)), + ( + "batch_norm", + nn.BatchNorm1d(hidden_size) if batch_norm else nn.Identity(), + ), + ("act_fn", act_fn), + ("dropout", nn.Dropout(p=dropout)), + ] + ) + ) + + +def dense_attention_layer(number_of_features: int, temperature: float = 1.0, dropout=0.0) -> nn.Sequential: + """Attention mechanism layer for dense inputs. + + :param number_of_features: size of the feature dimension + :param temperature: softmax temperature parameter + :param dropout: Dropout probability + :return: sequential attention layer + """ + return nn.Sequential( + OrderedDict( + [ + ("dense", nn.Linear(number_of_features, number_of_features)), + ("dropout", nn.Dropout(p=dropout)), + ("temperature", Temperature(temperature)), + ("softmax", nn.Softmax(dim=-1)), + ] + ) + ) + + +def convolutional_layer( + num_kernel, + kernel_size, + act_fn=None, + batch_norm=False, + dropout=0.0, + input_channels=1, +): + """Convolutional layer. + + :param num_kernel: number of convolution kernels + :param kernel_size: size of the convolution kernels + :param act_fn: activation module + :param batch_norm: whether batch normalization is applied + :param dropout: dropout probability + :param input_channels: number of input channels + :return: sequential convolutional layer block + """ + if act_fn is None: + act_fn = nn.ReLU() + + return nn.Sequential( + OrderedDict( + [ + ( + "convolve", + torch.nn.Conv2d( + input_channels, # channel_in + num_kernel, # channel_out + kernel_size, # kernel_size + padding=[kernel_size[0] // 2, 0], # pad for valid conv. + ), + ), + ("squeeze", Squeeze()), + ("act_fn", act_fn), + ("dropout", nn.Dropout(p=dropout)), + ( + "batch_norm", + nn.BatchNorm1d(num_kernel) if batch_norm else nn.Identity(), + ), + ] + ) + ) + + +class ContextAttentionLayer(nn.Module): + """Context attention layer used in the PaccMann architecture. + + It implements context attention as described in the PaccMann paper and + supports an optional hidden size in the context representation. + """ + + def __init__( + self, + reference_hidden_size: int, + reference_sequence_length: int, + context_hidden_size: int, + context_sequence_length: int = 1, + attention_size: int = 16, + individual_nonlinearity=None, + temperature: float = 1.0, + ): + """Initialize the context attention layer. + + :param reference_hidden_size: hidden size of the reference input + :param reference_sequence_length: sequence length of the reference input + :param context_hidden_size: hidden size or feature count of the context + :param context_sequence_length: sequence length of the context + :param attention_size: size of the attention space + :param individual_nonlinearity: optional activation module applied to each projection + :param temperature: temperature used for the softmax + """ + super().__init__() + + if individual_nonlinearity is None: + individual_nonlinearity = nn.Sequential() + + self.reference_sequence_length = reference_sequence_length + self.reference_hidden_size = reference_hidden_size + self.context_sequence_length = context_sequence_length + self.context_hidden_size = context_hidden_size + self.attention_size = attention_size + self.individual_nonlinearity = individual_nonlinearity + self.temperature = temperature + + # Project the reference into the attention space + self.reference_projection = nn.Sequential( + OrderedDict( + [ + ( + "projection", + nn.Linear(reference_hidden_size, attention_size), + ), + ("act_fn", individual_nonlinearity), + ] + ) + ) # yapf: disable + + # Project the context into the attention space + self.context_projection = nn.Sequential( + OrderedDict( + [ + ( + "projection", + nn.Linear(context_hidden_size, attention_size), + ), + ("act_fn", individual_nonlinearity), + ] + ) + ) # yapf: disable + + # Optionally reduce the hidden size in context + if context_sequence_length > 1: + self.context_hidden_projection = nn.Sequential( + OrderedDict( + [ + ( + "projection", + nn.Linear( + context_sequence_length, + reference_sequence_length, + ), + ), + ("act_fn", individual_nonlinearity), + ] + ) + ) # yapf: disable + else: + self.context_hidden_projection = nn.Sequential() + + self.alpha_projection = nn.Sequential( + OrderedDict( + [ + ("projection", nn.Linear(attention_size, 1, bias=False)), + ("squeeze", Squeeze()), + ("temperature", Temperature(self.temperature)), + ("softmax", nn.Softmax(dim=1)), + ] + ) + ) + + def forward( + self, + reference: torch.Tensor, + context: torch.Tensor, + average_seq: bool = True, + ): + """Forward pass through a context attention layer. + + :param reference: reference tensor of shape 'bs x ref_seq_length x ref_hidden_size' + :param context: context tensor of shape 'bs x context_seq_length x context_hidden_size' + :param average_seq: whether to average over the sequence length + :return: Tuple (output, attention_weights) + :raises ValueError: If reference or context is not 3-dimensional. + """ + if len(reference.shape) != 3: + raise ValueError("Reference tensor needs to be 3D") + + if len(context.shape) != 3: + raise ValueError("Context tensor needs to be 3D") + + reference_attention = self.reference_projection(reference) + context_attention = self.context_hidden_projection(self.context_projection(context).permute(0, 2, 1)).permute( + 0, 2, 1 + ) + alphas = self.alpha_projection(torch.tanh(reference_attention + context_attention)) + + output = reference * torch.unsqueeze(alphas, -1) + output = torch.sum(output, 1) if average_seq else torch.squeeze(output) + + return output, alphas + + +def gene_projection(num_genes, attention_size, ind_nonlin=None): + """Build the gene projection layer. + + :param num_genes: number of gene features + :param attention_size: size of the attention space + :param ind_nonlin: optional activation module + :return: sequential projection module + """ + if ind_nonlin is None: + ind_nonlin = nn.Sequential() + + return nn.Sequential( + OrderedDict( + [ + ("projection", nn.Linear(num_genes, attention_size)), + ("act_fn", ind_nonlin), + ("expand", Unsqueeze(1)), + ] + ) + ).to(DEVICE) + + +def smiles_projection(smiles_hidden_size, attention_size, ind_nonlin=None): + """Build the SMILES projection layer. + + :param smiles_hidden_size: size of the SMILES hidden representation + :param attention_size: size of the attention space + :param ind_nonlin: optional activation module + :return: sequential projection module + """ + if ind_nonlin is None: + ind_nonlin = nn.Sequential() + + return nn.Sequential( + OrderedDict( + [ + ("projection", nn.Linear(smiles_hidden_size, attention_size)), + ("act_fn", ind_nonlin), + ] + ) + ).to(DEVICE) + + +def alpha_projection(attention_size): + """Build the alpha projection layer. + + :param attention_size: size of the attention space + :return: sequential alpha projection module + """ + return nn.Sequential( + OrderedDict( + [ + ("projection", nn.Linear(attention_size, 1, bias=False)), + ("squeeze", Squeeze()), + ("softmax", nn.Softmax(dim=1)), + ] + ) + ).to(DEVICE) diff --git a/drevalpy/models/PaccMann/utils/loss_functions.py b/drevalpy/models/PaccMann/utils/loss_functions.py new file mode 100644 index 00000000..45a7e35d --- /dev/null +++ b/drevalpy/models/PaccMann/utils/loss_functions.py @@ -0,0 +1,61 @@ +"""Loss function definitions for PaccMann.""" + +import torch +import torch.nn as nn + + +def pearsonr(x, y): + """Compute Pearson correlation. + + :param x: 1D vector + :param y: 1D vector of the same size as x + :return: Pearson correlation coefficient + :raises TypeError: if inputs are not torch.Tensors + :raises ValueError: if inputs are not 1D, have different lengths, have length < 2 or are constant + """ + if not isinstance(x, torch.Tensor) or not isinstance(y, torch.Tensor): + raise TypeError("Function expects torch Tensors.") + + if len(x.shape) > 1 or len(y.shape) > 1: + raise ValueError("x and y must be 1D Tensors.") + + if len(x) != len(y): + raise ValueError("x and y must have the same length.") + + if len(x) < 2: + raise ValueError("x and y must have length at least 2.") + + # If an input is constant, the correlation coefficient is not defined. + if bool((x == x[0]).all()) or bool((y == y[0]).all()): + raise ValueError("Constant input, r is not defined.") + + mx = x - torch.mean(x) + my = y - torch.mean(y) + cost = torch.sum(mx * my) / (torch.sqrt(torch.sum(mx**2)) * torch.sqrt(torch.sum(my**2))) + return torch.clamp(cost, min=-1.0, max=1.0) + + +def correlation_coefficient_loss(labels, predictions): + """Compute loss based on Pearson correlation. + + :param labels: reference values + :param predictions: predicted values + :return: Loss value defined as 1 - r(labels, predictions)^2 + """ + return 1 - pearsonr(labels, predictions) ** 2 + + +def mse_cc_loss(labels, predictions): + """Compute loss based on MSE and Pearson correlation. + + The main assumption is that MSE lies in [0,1] range, i.e.: range is + comparable with Pearson correlation-based loss. + + :param labels: reference values + :param predictions: predicted values + :return: Loss defined as mse(labels, predictions) + 1 - r(labels, predictions)^2 + """ + mse_loss_fn = nn.MSELoss() + mse_loss = mse_loss_fn(predictions, labels) + cc_loss = correlation_coefficient_loss(labels, predictions) + return mse_loss + cc_loss diff --git a/drevalpy/models/PaccMann/utils/utils.py b/drevalpy/models/PaccMann/utils/utils.py new file mode 100644 index 00000000..d551252e --- /dev/null +++ b/drevalpy/models/PaccMann/utils/utils.py @@ -0,0 +1,106 @@ +"""Utility functions.""" + +import torch +import torch.nn as nn + + +def get_device(): + """Return the active torch device. + + :return: torch.device("cuda") if cuda is available otherwise torch.device("cpu") + """ + return torch.device("cuda" if cuda() else "cpu") + + +def cuda(): + """Check whether cuda is available. + + :return: True if cuda is available otherwise False. + """ + return torch.cuda.is_available() + + +def to_np(x): + """Convert a tensor to a NumPy array. + + :param x: Input tensor + :return: Tensor converted to a NumPy array on the CPU + """ + return x.data.cpu().numpy() + + +def attention_list_to_matrix(coding_tuple, dim=2): + """Convert a list of attention outputs to attention matrices. + + :param coding_tuple: iterable of (outputs, att_weights) tuples coming from the attention function + :param dim: The dimension along which expansion takes place to concatenate the attention weights. + Defaults to 2. + :return: Tuple (raw_coeff, coeff) where 'raw_coeff' contains all + attention weights concatenated along 'dim' and 'coeff' contains + the averaged attention weights. + """ + raw_coeff = torch.cat([torch.unsqueeze(tpl[1], 2) for tpl in coding_tuple], dim=dim) + return raw_coeff, torch.mean(raw_coeff, dim=dim) + + +def get_log_molar(y, ic50_max=None, ic50_min=None): + """Converts PaccMann predictions from [0,1] to log(micromolar) range. + + :param y: predicted values in the normalized range + :param ic50_max: maximum IC50 value used for scaling + :param ic50_min: minimum IC50 value used for scaling + :return: predictions transformed to the log-micromolar range + """ + return y * (ic50_max - ic50_min) + ic50_min + + +class Squeeze(nn.Module): + """Squeeze wrapper for nn.Sequential.""" + + def forward(self, data): + """Squeeze the last dimension of the input tensor. + + :param data: input tensor + :return: squeezed tensor + """ + return torch.squeeze(data, -1) + + +class Unsqueeze(nn.Module): + """Unsqueeze wrapper for nn.Sequential.""" + + def __init__(self, dim): + """Initialize the unsqueeze wrapper. + + :param dim: dimension at which to insert the new axis + """ + super().__init__() + self.dim = dim + + def forward(self, data): + """Unsqueeze the input tensor at the configured dimension. + + :param data: input tensor + :return: tensor with added dimension + """ + return torch.unsqueeze(data, self.dim) + + +class Temperature(nn.Module): + """Temperature wrapper for nn.Sequential.""" + + def __init__(self, temperature): + """Initialize the temperature wrapper. + + :param temperature: Temperature value used for scaling. + """ + super().__init__() + self.temperature = temperature + + def forward(self, data): + """Scale the input tensor by the temperature value. + + :param data: input tensor + :return: scaled tensor + """ + return data / self.temperature diff --git a/drevalpy/models/__init__.py b/drevalpy/models/__init__.py index 036d4df3..b0a43902 100644 --- a/drevalpy/models/__init__.py +++ b/drevalpy/models/__init__.py @@ -28,6 +28,7 @@ "KNNRegressor", "AdaBoostDecisionTree", "Lasso", + "PaccMann", ] from .baselines.multi_view_random_forest import MultiViewRandomForest @@ -53,6 +54,7 @@ from .drp_model import DRPModel from .DrugGNN import DrugGNN from .MOLIR.molir import MOLIR +from .PaccMann.paccmann import PaccMann from .PharmaFormer.pharmaformer import PharmaFormerModel from .SimpleNeuralNetwork.multi_view_neural_network import MultiViewNeuralNetwork from .SimpleNeuralNetwork.simple_neural_network import SimpleNeuralNetwork @@ -89,6 +91,7 @@ "KNNRegressor": KNNRegressor, "AdaBoostDecisionTree": AdaBoostDecisionTree, "Lasso": LassoModel, + "PaccMann": PaccMann, } # MODEL_FACTORY is used in the pipeline! diff --git a/tests/models/test_global_models.py b/tests/models/test_global_models.py index b643208a..95cdb81a 100644 --- a/tests/models/test_global_models.py +++ b/tests/models/test_global_models.py @@ -25,6 +25,7 @@ "SimpleNeuralNetwork[chemberta]", "MultiViewNeuralNetwork", "PharmaFormer", + "PaccMann", ], ) def test_global_models(