From b02a9346dee4114627b7152ce43106bcc8eebf52 Mon Sep 17 00:00:00 2001 From: eleded Date: Wed, 8 Apr 2026 20:31:16 +0200 Subject: [PATCH 01/10] add XGDP base --- drevalpy/models/XGDP/__init__.py | 0 drevalpy/models/XGDP/hyperparameters.yaml | 0 drevalpy/models/XGDP/xgdp.py | 0 3 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 drevalpy/models/XGDP/__init__.py create mode 100644 drevalpy/models/XGDP/hyperparameters.yaml create mode 100644 drevalpy/models/XGDP/xgdp.py diff --git a/drevalpy/models/XGDP/__init__.py b/drevalpy/models/XGDP/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/drevalpy/models/XGDP/hyperparameters.yaml b/drevalpy/models/XGDP/hyperparameters.yaml new file mode 100644 index 00000000..e69de29b diff --git a/drevalpy/models/XGDP/xgdp.py b/drevalpy/models/XGDP/xgdp.py new file mode 100644 index 00000000..e69de29b From 90237b59c35e91dcbd9b23ff09a1617bf3a4e9e1 Mon Sep 17 00:00:00 2001 From: eleded Date: Mon, 13 Apr 2026 17:56:31 +0200 Subject: [PATCH 02/10] Update XGDP --- drevalpy/models/XGDP/__init__.py | 5 ++ drevalpy/models/XGDP/hyperparameters.yaml | 13 +++ drevalpy/models/XGDP/xgdp.py | 104 ++++++++++++++++++++++ 3 files changed, 122 insertions(+) diff --git a/drevalpy/models/XGDP/__init__.py b/drevalpy/models/XGDP/__init__.py index e69de29b..004aba4c 100644 --- a/drevalpy/models/XGDP/__init__.py +++ b/drevalpy/models/XGDP/__init__.py @@ -0,0 +1,5 @@ +"""A GNN based drug response prediction model.""" + +from .xgdp import XGDP + +__all__ = ["XGDP"] \ No newline at end of file diff --git a/drevalpy/models/XGDP/hyperparameters.yaml b/drevalpy/models/XGDP/hyperparameters.yaml index e69de29b..d97a5cd0 100644 --- a/drevalpy/models/XGDP/hyperparameters.yaml +++ b/drevalpy/models/XGDP/hyperparameters.yaml @@ -0,0 +1,13 @@ +XGDP: + learning_rate: + - 0.001 + epochs: + - 100 + batch_size: + - 64 + - 128 + weight_decay: + - 0.0001 + dropout_rate: + - 0.2 + - 0.3 diff --git a/drevalpy/models/XGDP/xgdp.py b/drevalpy/models/XGDP/xgdp.py index e69de29b..97c954f4 100644 --- a/drevalpy/models/XGDP/xgdp.py +++ b/drevalpy/models/XGDP/xgdp.py @@ -0,0 +1,104 @@ +"""XGDP model.""" + +import json +import os +import secrets +from pathlib import Path +from typing import Any, cast + +import numpy as np +import pytorch_lightning as pl +import pandas as pd +import torch +import torch.nn as nn +from torch.optim import Adam +from torch.utils.data import Dataset as PytorchDataset +from torch_geometric.loader import DataLoader +from torch_geometric.nn import GCNConv, global_mean_pool + +from ...datasets.dataset import DrugResponseDataset, FeatureDataset +from ..drp_model import DRPModel +from ..lightning_metrics_mixin import RegressionMetricsMixin +from ..utils import load_and_select_gene_features + +from .model_utils import XGDPPredictor + +class XGDP(DRPModel): + """XGDP model for ...""" + + cell_line_views = ["gene_expression"] + drug_views = ["drug_graph"] + early_stopping = True + + def __init__(self) -> None: + """Initialize the XGDP model.""" + super().__init__() + self.DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.model: XGDPPredictor | None = None + self.hyperparameters: dict[str, Any] = {} + self.gene_expression_scaler: StandardScaler | None = None + self.gene_expression_normalizer: MinMaxScaler | None = None + + @classmethod + def get_model_name(cls) -> str: + """ + Get the model name. + + :returns: XGDP + """ + return "XGDP" + + def build_model(self, hyperparameters: dict[str, Any]) -> None: + """ + Builds the XGDP model with the specified hyperparameters. + + :param hyperparameters: TODO: ADD HYPERPARAMETERS + """ + + self.hyperparameters = hyperparameters + + self.model = XGDPPredictor( + name_hyperparameter = hyperparameter["name_hyperparameter"] + ) + + def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureDataset: + """Loads the cell line features. + + :param data_path: Path to the gene expression and landmark genes + :param dataset_name: name of the dataset + :return: FeatureDataset containing the cell line gene expression features. + """ + return load_and_select_gene_features( + feature_type="gene_expression", + gene_list="landmark_genes", + data_path=data_path, + dataset_name=dataset_name, + ) + + def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDataset: + """Loads the pre-computed drug graph data. + + :param data_path: Path to the data directory. + :param dataset_name: Name of the dataset. + :raises FileNotFoundError: If the drug graph directory is not found. + :raises ValueError: If no drug graphs are loaded. + :return: FeatureDataset containing the drug graphs. + """ + graph_path = Path(data_path) / dataset_name / "drug_graphs" + if not graph_path.exists(): + raise FileNotFoundError( + f"Drug graph directory not found at {graph_path}. " + f"Please run 'create_drug_graphs.py' for the {dataset_name} dataset." + ) + + drug_graphs = {} + for p_file in graph_path.glob("*.pt"): + drug_id = p_file.stem + drug_graphs[drug_id] = torch.load(p_file, weights_only=False) + + if not drug_graphs: + raise ValueError(f"No drug graphs loaded from {graph_path}. Check the directory and file contents.") + + feature_dict = {drug_id: {"drug_graph": graph} for drug_id, graph in drug_graphs.items()} + + return FeatureDataset(features=feature_dict) From b0499337cf85d28528e0ca41471da52d81cad7d6 Mon Sep 17 00:00:00 2001 From: Gregor-git1 Date: Tue, 14 Apr 2026 04:20:44 +0200 Subject: [PATCH 03/10] started with class XGDP, added hyperparameters, models and utils --- drevalpy/models/XGDP/models.py | 1274 ++++++++++++++++++++++++++++++++ drevalpy/models/XGDP/utils.py | 0 drevalpy/models/XGDP/xgdp.py | 139 ++++ 3 files changed, 1413 insertions(+) create mode 100644 drevalpy/models/XGDP/models.py create mode 100644 drevalpy/models/XGDP/utils.py diff --git a/drevalpy/models/XGDP/models.py b/drevalpy/models/XGDP/models.py new file mode 100644 index 00000000..45ec80f7 --- /dev/null +++ b/drevalpy/models/XGDP/models.py @@ -0,0 +1,1274 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Sequential, Linear, ReLU +from torch_geometric.nn import GCNConv, GATConv, GATv2Conv, SAGEConv, GINEConv, GINConv, RGATConv, RGCNConv, FiLMConv +from torch_geometric.nn import global_max_pool as gmp +from torch_geometric.nn import global_add_pool + +''' + DeepChem feature set: 78 + ECFP4: 192 + ECFP4 + DeepChem: 270 + ECFP6: 256 + ECFP6 + DeepChem: 334 +''' + +''' +TODO: (already done) + 1. align all models' forward arguments with GATNet (make sure batch is the 3rd one due to gnn_explainer's implementation) + 2. remove x in the return (tuple cannot be accepted by gnn_explainer) + 3. change the output and input in utils_train.py correspondingly +''' + +# change num_features_xd into 78 for ordinary atom features (benchmark) + + +class GCNNet(torch.nn.Module): + def __init__(self, n_output=1, n_filters=32, embed_dim=128, num_features_xd=334, num_features_xt=25, output_dim=128, dropout=0.5, use_attn=False): # qwe + + super(GCNNet, self).__init__() + self.use_attn = use_attn + + # SMILES graph branch + self.n_output = n_output + self.conv1 = GCNConv(num_features_xd, num_features_xd) + self.conv2 = GCNConv(num_features_xd, num_features_xd*2) + self.conv3 = GCNConv(num_features_xd*2, num_features_xd * 4) + self.fc_g1 = torch.nn.Linear(num_features_xd*4, 1024) + self.fc_g2 = torch.nn.Linear(1024, output_dim) + self.relu = nn.ReLU() + self.dropout = nn.Dropout(dropout) + + # cell line feature + self.conv_xt_1 = nn.Conv1d( + in_channels=1, out_channels=n_filters, kernel_size=8) + self.pool_xt_1 = nn.MaxPool1d(3) + self.conv_xt_2 = nn.Conv1d( + in_channels=n_filters, out_channels=n_filters*2, kernel_size=8) + self.pool_xt_2 = nn.MaxPool1d(3) + self.conv_xt_3 = nn.Conv1d( + in_channels=n_filters*2, out_channels=n_filters*4, kernel_size=8) + self.pool_xt_3 = nn.MaxPool1d(3) + # self.fc1_xt = nn.Linear(2944, output_dim) + # self.fc1_xt = nn.Linear(4224, output_dim) + # self.fc1_xt = nn.Linear(61824, output_dim) + self.fc1_xt = nn.Linear(4096, output_dim) + + # combined layers + if self.use_attn: + self.cross_attn1 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) + self.cross_attn2 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) + self.norm1 = nn.LayerNorm(output_dim) + self.norm2 = nn.LayerNorm(output_dim) + self.fc = nn.Linear(2*output_dim, 128) + else: + self.fc1 = nn.Linear(2*output_dim, 1024) + self.fc2 = nn.Linear(1024, 128) + self.out = nn.Linear(128, self.n_output) + + def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, edge_weight=None): + # get graph input + # edge_weight is only used for decoding + + # x, edge_index, batch = data.x, data.edge_index, data.batch + # edge_index = edge_index.long() + + x = self.conv1(x, edge_index, edge_weight) + x = self.relu(x) + x = self.conv2(x, edge_index, edge_weight) + x = self.relu(x) + x = self.conv3(x, edge_index, edge_weight) + x = self.relu(x) + x = gmp(x, batch) # global max pooling + + # flatten + x = self.relu(self.fc_g1(x)) + x = self.dropout(x) + x = self.fc_g2(x) + x = self.dropout(x) + + # get protein input + # target = data.target + # print(x_cell_mut.shape) + + # add this line for CNV data, remove for gene expr data + # x_cell_mut = x_cell_mut[:,None,:] + + # 1d conv layers + conv_xt = self.conv_xt_1(x_cell_mut) + conv_xt = F.relu(conv_xt) + conv_xt = self.pool_xt_1(conv_xt) + conv_xt = self.conv_xt_2(conv_xt) + conv_xt = F.relu(conv_xt) + conv_xt = self.pool_xt_2(conv_xt) + conv_xt = self.conv_xt_3(conv_xt) + conv_xt = F.relu(conv_xt) + conv_xt = self.pool_xt_3(conv_xt) + + # flatten + xt = conv_xt.view(-1, conv_xt.shape[1] * conv_xt.shape[2]) + xt = self.fc1_xt(xt) + + if self.use_attn: + xc1, _ = self.cross_attn1(x, xt, xt) + xc1 = xc1 + x + xc1 = self.norm1(xc1) + xc2, _ = self.cross_attn2(xt, x, x) + xc2 = xc2 + xt + xc2 = self.norm2(xc2) + xc = torch.cat((xc1, xc2), 1) + xc = self.relu(xc) + xc = self.dropout(xc) + xc = self.fc(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + else: + # concat + xc = torch.cat((x, xt), 1) + # add some dense layers + xc = self.fc1(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + xc = self.fc2(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + out = self.out(xc) + out = nn.Sigmoid()(out) + return out + + +class GATNet(torch.nn.Module): + def __init__(self, num_features_xd=334, n_output=1, num_features_xt=25, n_filters=32, embed_dim=128, output_dim=128, dropout=0.5, use_attn=False): + super(GATNet, self).__init__() + self.use_attn = use_attn + + # graph layers + self.gcn1 = GATConv(num_features_xd, num_features_xd, + heads=10, dropout=dropout) + self.gcn2 = GATConv(num_features_xd * 10, output_dim, dropout=dropout) + self.fc_g1 = nn.Linear(output_dim, output_dim) + + # cell line feature + self.conv_xt_1 = nn.Conv1d( + in_channels=1, out_channels=n_filters, kernel_size=8) + self.pool_xt_1 = nn.MaxPool1d(3) + self.conv_xt_2 = nn.Conv1d( + in_channels=n_filters, out_channels=n_filters*2, kernel_size=8) + self.pool_xt_2 = nn.MaxPool1d(3) + self.conv_xt_3 = nn.Conv1d( + in_channels=n_filters*2, out_channels=n_filters*4, kernel_size=8) + self.pool_xt_3 = nn.MaxPool1d(3) + # self.fc1_xt = nn.Linear(2944, output_dim) + # self.fc1_xt = nn.Linear(4224, output_dim) + # self.fc1_xt = nn.Linear(61824, output_dim) + self.fc1_xt = nn.Linear(4096, output_dim) + + # combined layers + if self.use_attn: + self.cross_attn1 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) + self.cross_attn2 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) + self.norm1 = nn.LayerNorm(output_dim) + self.norm2 = nn.LayerNorm(output_dim) + self.fc = nn.Linear(2*output_dim, 128) + else: + self.fc1 = nn.Linear(2*output_dim, 1024) + self.fc2 = nn.Linear(1024, 128) + self.out = nn.Linear(128, n_output) + + # activation and regularization + self.relu = nn.ReLU() + self.dropout = nn.Dropout(0.5) + + def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_weights=False): + # graph input feed-forward + # x, edge_index, batch = data.x, data.edge_index, data.batch + # x = self.dropout(x) + # x = F.dropout(x, p=0.2, training=self.training) + x = F.elu(self.gcn1(x, edge_index)) + # x = F.dropout(x, p=0.2, training=self.training) + x = self.dropout(x) + if return_attention_weights: + x, attn_weights = self.gcn2( + x, edge_index, return_attention_weights=return_attention_weights) + else: + x = self.gcn2(x, edge_index) + x = self.relu(x) + x = gmp(x, batch) # global max pooling + x = self.fc_g1(x) + x = self.relu(x) + + # protein input feed-forward: + # target = data.target + # x_cell_mut = x_cell_mut[:,None,:] + # 1d conv layers + conv_xt = self.conv_xt_1(x_cell_mut) + conv_xt = F.relu(conv_xt) + conv_xt = self.pool_xt_1(conv_xt) + conv_xt = self.conv_xt_2(conv_xt) + conv_xt = F.relu(conv_xt) + conv_xt = self.pool_xt_2(conv_xt) + conv_xt = self.conv_xt_3(conv_xt) + conv_xt = F.relu(conv_xt) + conv_xt = self.pool_xt_3(conv_xt) + + # flatten + xt = conv_xt.view(-1, conv_xt.shape[1] * conv_xt.shape[2]) + xt = self.fc1_xt(xt) + + if self.use_attn: + xc1, _ = self.cross_attn1(x, xt, xt) + xc1 = xc1 + x + xc1 = self.norm1(xc1) + xc2, _ = self.cross_attn2(xt, x, x) + xc2 = xc2 + xt + xc2 = self.norm2(xc2) + xc = torch.cat((xc1, xc2), 1) + xc = self.relu(xc) + xc = self.dropout(xc) + xc = self.fc(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + else: + # concat + xc = torch.cat((x, xt), 1) + # add some dense layers + xc = self.fc1(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + xc = self.fc2(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + out = self.out(xc) + out = nn.Sigmoid()(out) + + if return_attention_weights: + return out, attn_weights + else: + # return out, x + return out + + +class GATv2Net(torch.nn.Module): + def __init__(self, num_features_xd=334, n_output=1, num_features_xt=25, + n_filters=32, embed_dim=128, output_dim=128, dropout=0.5, use_attn=False): + super(GATv2Net, self).__init__() + self.use_attn = use_attn + + # graph layers + self.gcn1 = GATv2Conv(num_features_xd, num_features_xd, + heads=25, dropout=dropout, edge_dim=4, add_self_loops=False) + self.gcn2 = GATv2Conv(num_features_xd * 25, output_dim, + dropout=dropout, edge_dim=4, add_self_loops=False) + self.fc_g1 = nn.Linear(output_dim, output_dim) + + # cell line feature + self.conv_xt_1 = nn.Conv1d( + in_channels=1, out_channels=n_filters, kernel_size=8) + self.pool_xt_1 = nn.MaxPool1d(3) + self.conv_xt_2 = nn.Conv1d( + in_channels=n_filters, out_channels=n_filters*2, kernel_size=8) + self.pool_xt_2 = nn.MaxPool1d(3) + self.conv_xt_3 = nn.Conv1d( + in_channels=n_filters*2, out_channels=n_filters*4, kernel_size=8) + self.pool_xt_3 = nn.MaxPool1d(3) + # self.fc1_xt = nn.Linear(2944, output_dim) + # self.fc1_xt = nn.Linear(4224, output_dim) + # self.fc1_xt = nn.Linear(61824, output_dim) + self.fc1_xt = nn.Linear(4096, output_dim) + + # combined layers + if self.use_attn: + self.cross_attn1 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) + self.cross_attn2 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) + self.norm1 = nn.LayerNorm(output_dim) + self.norm2 = nn.LayerNorm(output_dim) + self.fc = nn.Linear(2*output_dim, 128) + else: + self.fc1 = nn.Linear(2*output_dim, 1024) + self.fc2 = nn.Linear(1024, 128) + self.out = nn.Linear(128, n_output) + + # activation and regularization + self.relu = nn.ReLU() + self.dropout = nn.Dropout(0.5) + + def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_weights=False): + # graph input feed-forward + # x, edge_index, batch, edge_feat = data.x, data.edge_index, data.batch, data.edge_features + # print(data.x.shape) + # print(edge_feat.shape) + + # x = F.dropout(x, p=0.2, training=self.training) + # x = self.dropout(x) + x = F.elu(self.gcn1(x, edge_index, edge_attr=edge_feat)) + x = self.dropout(x) + # x = F.dropout(x, p=0.2, training=self.training) + if return_attention_weights: + x, attn_weights = self.gcn2( + x, edge_index, edge_attr=edge_feat, return_attention_weights=return_attention_weights) + else: + x = self.gcn2(x, edge_index, edge_attr=edge_feat) + x = self.relu(x) + x = gmp(x, batch) # global max pooling + x = self.fc_g1(x) + x = self.relu(x) + + # protein input feed-forward: + # target = data.target + # x_cell_mut = x_cell_mut[:,None,:] + # 1d conv layers + conv_xt = self.conv_xt_1(x_cell_mut) + conv_xt = F.relu(conv_xt) + conv_xt = self.pool_xt_1(conv_xt) + conv_xt = self.conv_xt_2(conv_xt) + conv_xt = F.relu(conv_xt) + conv_xt = self.pool_xt_2(conv_xt) + conv_xt = self.conv_xt_3(conv_xt) + conv_xt = F.relu(conv_xt) + conv_xt = self.pool_xt_3(conv_xt) + + # flatten + xt = conv_xt.view(-1, conv_xt.shape[1] * conv_xt.shape[2]) + xt = self.fc1_xt(xt) + + if self.use_attn: + xc1, _ = self.cross_attn1(x, xt, xt) + xc1 = xc1 + x + xc1 = self.norm1(xc1) + xc2, _ = self.cross_attn2(xt, x, x) + xc2 = xc2 + xt + xc2 = self.norm2(xc2) + xc = torch.cat((xc1, xc2), 1) + xc = self.relu(xc) + xc = self.dropout(xc) + xc = self.fc(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + else: + # concat + xc = torch.cat((x, xt), 1) + # add some dense layers + xc = self.fc1(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + xc = self.fc2(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + out = self.out(xc) + out = nn.Sigmoid()(out) + + if return_attention_weights: + return out, attn_weights + else: + return out + + +class GATNet_E(torch.nn.Module): + def __init__(self, num_features_xd=334, n_output=1, num_features_xt=25, + n_filters=32, embed_dim=128, output_dim=128, dropout=0.5, use_attn=False): + super(GATNet_E, self).__init__() + self.use_attn = use_attn + + # graph layers + self.gcn1 = GATConv(num_features_xd, num_features_xd, + heads=10, dropout=dropout, edge_dim=4) + self.gcn2 = GATConv(num_features_xd * 10, output_dim, + dropout=dropout, edge_dim=4) + self.fc_g1 = nn.Linear(output_dim, output_dim) + + # cell line feature + self.conv_xt_1 = nn.Conv1d( + in_channels=1, out_channels=n_filters, kernel_size=8) + self.pool_xt_1 = nn.MaxPool1d(3) + self.conv_xt_2 = nn.Conv1d( + in_channels=n_filters, out_channels=n_filters*2, kernel_size=8) + self.pool_xt_2 = nn.MaxPool1d(3) + self.conv_xt_3 = nn.Conv1d( + in_channels=n_filters*2, out_channels=n_filters*4, kernel_size=8) + self.pool_xt_3 = nn.MaxPool1d(3) + # self.fc1_xt = nn.Linear(2944, output_dim) + # self.fc1_xt = nn.Linear(4224, output_dim) + # self.fc1_xt = nn.Linear(61824, output_dim) + self.fc1_xt = nn.Linear(4096, output_dim) + + if self.use_attn: + self.cross_attn1 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) + self.cross_attn2 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) + self.norm1 = nn.LayerNorm(output_dim) + self.norm2 = nn.LayerNorm(output_dim) + self.fc = nn.Linear(2*output_dim, output_dim) + else: + # combined layers + self.fc1 = nn.Linear(2*output_dim, 1024) + self.fc2 = nn.Linear(1024, 128) + self.out = nn.Linear(128, n_output) + + # activation and regularization + self.relu = nn.ReLU() + self.dropout = nn.Dropout(0.5) + + def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_weights=False): + ''' + x: feature matrix of molecular graph + target: gene mutation data + edge_index: edges of molecular graph + batch + edge_feat: edge features of molecular graph + ''' + # graph input feed-forward + # x, edge_index, batch, edge_feat = data.x, data.edge_index, data.batch, data.edge_features + # print(data.x.shape) + + # x = F.dropout(x, p=0.2, training=self.training) + # x = self.dropout(x) + x = F.elu(self.gcn1(x, edge_index, edge_attr=edge_feat)) + # x = F.dropout(x, p=0.2, training=self.training) + x = self.dropout(x) + if return_attention_weights: + x, attn_weights = self.gcn2( + x, edge_index, edge_attr=edge_feat, return_attention_weights=return_attention_weights) + else: + x = self.gcn2(x, edge_index, edge_attr=edge_feat) + x = self.relu(x) + x = gmp(x, batch) # global max pooling + x = self.fc_g1(x) + x = self.relu(x) + + # protein input feed-forward: + # target = data.target + # x_cell_mut = x_cell_mut[:,None,:] + # 1d conv layers + conv_xt = self.conv_xt_1(x_cell_mut) + conv_xt = F.relu(conv_xt) + conv_xt = self.pool_xt_1(conv_xt) + conv_xt = self.conv_xt_2(conv_xt) + conv_xt = F.relu(conv_xt) + conv_xt = self.pool_xt_2(conv_xt) + conv_xt = self.conv_xt_3(conv_xt) + conv_xt = F.relu(conv_xt) + conv_xt = self.pool_xt_3(conv_xt) + + # flatten + xt = conv_xt.view(-1, conv_xt.shape[1] * conv_xt.shape[2]) + xt = self.fc1_xt(xt) + + if self.use_attn: + xc1, _ = self.cross_attn1(x, xt, xt) + xc1 = xc1 + x + xc1 = self.norm1(xc1) + xc2, _ = self.cross_attn2(xt, x, x) + xc2 = xc2 + xt + xc2 = self.norm2(xc2) + xc = torch.cat((xc1, xc2), 1) + xc = self.relu(xc) + xc = self.dropout(xc) + xc = self.fc(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + else: + # concat + xc = torch.cat((x, xt), 1) + # add some dense layers + xc = self.fc1(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + xc = self.fc2(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + out = self.out(xc) + out = nn.Sigmoid()(out) + + if return_attention_weights: + return out, attn_weights + else: + return out + + +class SAGENet(torch.nn.Module): + def __init__(self, n_output=1, n_filters=32, embed_dim=128, num_features_xd=334, num_features_xt=25, output_dim=128, dropout=0.5): # qwe + + super(SAGENet, self).__init__() + + # SMILES graph branch + + # GCNSAGE + self.n_output = n_output + self.conv1 = SAGEConv(num_features_xd, num_features_xd) + self.conv2 = SAGEConv(num_features_xd, num_features_xd*2) + self.conv3 = SAGEConv(num_features_xd*2, num_features_xd * 4) + self.fc_g1 = torch.nn.Linear(num_features_xd*4, 1024) + self.fc_g2 = torch.nn.Linear(1024, output_dim) + self.relu = nn.ReLU() + self.dropout = nn.Dropout(dropout) + + # cell line feature + self.conv_xt_1 = nn.Conv1d( + in_channels=1, out_channels=n_filters, kernel_size=8) + self.pool_xt_1 = nn.MaxPool1d(3) + self.conv_xt_2 = nn.Conv1d( + in_channels=n_filters, out_channels=n_filters*2, kernel_size=8) + self.pool_xt_2 = nn.MaxPool1d(3) + self.conv_xt_3 = nn.Conv1d( + in_channels=n_filters*2, out_channels=n_filters*4, kernel_size=8) + self.pool_xt_3 = nn.MaxPool1d(3) + # self.fc1_xt = nn.Linear(2944, output_dim) + # self.fc1_xt = nn.Linear(4224, output_dim) + self.fc1_xt = nn.Linear(61824, output_dim) + + # combined layers + self.fc1 = nn.Linear(2*output_dim, 1024) + self.fc2 = nn.Linear(1024, 128) + self.out = nn.Linear(128, n_output) + + # activation and regularization + self.relu = nn.ReLU() + self.dropout = nn.Dropout(0.5) + + def forward(self, x, edge_index, batch, x_cell_mut, edge_feat): + # get graph input + # x, edge_index, batch = data.x, data.edge_index, data.batch + + # GCNSAGE + x = self.conv1(x, edge_index) + x = self.relu(x) + x = self.conv2(x, edge_index) + x = self.relu(x) + x = self.conv3(x, edge_index) + x = self.relu(x) + x = gmp(x, batch) + # flatten + x = self.relu(self.fc_g1(x)) + x = self.dropout(x) + x = self.fc_g2(x) + x = self.dropout(x) + + # get protein input + # target = data.target + # x_cell_mut = x_cell_mut[:,None,:] + conv_xt = self.conv_xt_1(x_cell_mut) + conv_xt = F.relu(conv_xt) + conv_xt = self.pool_xt_1(conv_xt) + conv_xt = self.conv_xt_2(conv_xt) + conv_xt = F.relu(conv_xt) + conv_xt = self.pool_xt_2(conv_xt) + conv_xt = self.conv_xt_3(conv_xt) + conv_xt = F.relu(conv_xt) + conv_xt = self.pool_xt_3(conv_xt) + + # flatten + xt = conv_xt.view(-1, conv_xt.shape[1] * conv_xt.shape[2]) + xt = self.fc1_xt(xt) + + # concat + xc = torch.cat((x, xt), 1) + xc = self.fc1(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + xc = self.fc2(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + out = self.out(xc) + out = nn.Sigmoid()(out) + return out + + +class GINNet(torch.nn.Module): + def __init__(self, n_output=1, num_features_xd=334, num_features_xt=25, + n_filters=32, embed_dim=128, output_dim=128, dropout=0.5): + + super(GINNet, self).__init__() + + dim = 32 + self.dropout = nn.Dropout(dropout) + self.relu = nn.ReLU() + self.n_output = n_output + # convolution layers + nn1 = Sequential(Linear(num_features_xd, dim), + ReLU(), Linear(dim, dim)) + self.conv1 = GINConv(nn1) + self.bn1 = torch.nn.BatchNorm1d(dim) + + nn2 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim)) + self.conv2 = GINConv(nn2) + self.bn2 = torch.nn.BatchNorm1d(dim) + + nn3 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim)) + self.conv3 = GINConv(nn3) + self.bn3 = torch.nn.BatchNorm1d(dim) + + nn4 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim)) + self.conv4 = GINConv(nn4) + self.bn4 = torch.nn.BatchNorm1d(dim) + + nn5 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim)) + self.conv5 = GINConv(nn5) + self.bn5 = torch.nn.BatchNorm1d(dim) + + self.fc1_xd = Linear(dim, output_dim) + + # 1D convolution on protein sequence + self.embedding_xt = nn.Embedding(num_features_xt + 1, embed_dim) + self.conv_xt_1 = nn.Conv1d( + in_channels=1000, out_channels=n_filters, kernel_size=8) + + # cell line feature + self.conv_xt_1 = nn.Conv1d( + in_channels=1, out_channels=n_filters, kernel_size=8) + self.pool_xt_1 = nn.MaxPool1d(3) + self.conv_xt_2 = nn.Conv1d( + in_channels=n_filters, out_channels=n_filters*2, kernel_size=8) + self.pool_xt_2 = nn.MaxPool1d(3) + self.conv_xt_3 = nn.Conv1d( + in_channels=n_filters*2, out_channels=n_filters*4, kernel_size=8) + self.pool_xt_3 = nn.MaxPool1d(3) + # self.fc1_xt = nn.Linear(2944, output_dim) + # self.fc1_xt = nn.Linear(4224, output_dim) + self.fc1_xt = nn.Linear(61824, output_dim) + + # combined layers + self.fc1 = nn.Linear(2*output_dim, 1024) + self.fc2 = nn.Linear(1024, 128) + self.out = nn.Linear(128, n_output) + + # activation and regularization + self.relu = nn.ReLU() + self.dropout = nn.Dropout(0.5) + + def forward(self, x, edge_index, batch, x_cell_mut, edge_feat): + # x, edge_index, batch = data.x, data.edge_index, data.batch + # print(x) + # print(data.target) + x = F.relu(self.conv1(x, edge_index)) + x = self.bn1(x) + x = F.relu(self.conv2(x, edge_index)) + x = self.bn2(x) + x = F.relu(self.conv3(x, edge_index)) + x = self.bn3(x) + x = F.relu(self.conv4(x, edge_index)) + x = self.bn4(x) + x = F.relu(self.conv5(x, edge_index)) + x = self.bn5(x) + x = global_add_pool(x, batch) + x = F.relu(self.fc1_xd(x)) + # x = F.dropout(x, p=0.2, training=self.training) + x = self.dropout(x) + + # protein input feed-forward: + # target = data.target + # x_cell_mut = x_cell_mut[:,None,:] + + # 1d conv layers + conv_xt = self.conv_xt_1(x_cell_mut) + conv_xt = F.relu(conv_xt) + conv_xt = self.pool_xt_1(conv_xt) + conv_xt = self.conv_xt_2(conv_xt) + conv_xt = F.relu(conv_xt) + conv_xt = self.pool_xt_2(conv_xt) + conv_xt = self.conv_xt_3(conv_xt) + conv_xt = F.relu(conv_xt) + conv_xt = self.pool_xt_3(conv_xt) + + # flatten + xt = conv_xt.view(-1, conv_xt.shape[1] * conv_xt.shape[2]) + xt = self.fc1_xt(xt) + + # concat + xc = torch.cat((x, xt), 1) + # add some dense layers + xc = self.fc1(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + xc = self.fc2(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + out = self.out(xc) + out = nn.Sigmoid()(out) + return out + + +class GINENet(torch.nn.Module): + def __init__(self, n_output=1, num_features_xd=334, num_features_xt=25, + n_filters=32, embed_dim=128, output_dim=128, dropout=0.5): + + super(GINENet, self).__init__() + + dim = 32 + self.dropout = nn.Dropout(dropout) + self.relu = nn.ReLU() + self.n_output = n_output + # convolution layers + nn1 = Sequential(Linear(num_features_xd, dim), + ReLU(), Linear(dim, dim)) + self.conv1 = GINEConv(nn1, edge_dim=4) + self.bn1 = torch.nn.BatchNorm1d(dim) + + nn2 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim)) + self.conv2 = GINEConv(nn2, edge_dim=4) + self.bn2 = torch.nn.BatchNorm1d(dim) + + nn3 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim)) + self.conv3 = GINEConv(nn3, edge_dim=4) + self.bn3 = torch.nn.BatchNorm1d(dim) + + nn4 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim)) + self.conv4 = GINEConv(nn4, edge_dim=4) + self.bn4 = torch.nn.BatchNorm1d(dim) + + nn5 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim)) + self.conv5 = GINEConv(nn5, edge_dim=4) + self.bn5 = torch.nn.BatchNorm1d(dim) + + self.fc1_xd = Linear(dim, output_dim) + + # 1D convolution on protein sequence + self.embedding_xt = nn.Embedding(num_features_xt + 1, embed_dim) + self.conv_xt_1 = nn.Conv1d( + in_channels=1000, out_channels=n_filters, kernel_size=8) + + # cell line feature + self.conv_xt_1 = nn.Conv1d( + in_channels=1, out_channels=n_filters, kernel_size=8) + self.pool_xt_1 = nn.MaxPool1d(3) + self.conv_xt_2 = nn.Conv1d( + in_channels=n_filters, out_channels=n_filters*2, kernel_size=8) + self.pool_xt_2 = nn.MaxPool1d(3) + self.conv_xt_3 = nn.Conv1d( + in_channels=n_filters*2, out_channels=n_filters*4, kernel_size=8) + self.pool_xt_3 = nn.MaxPool1d(3) + # self.fc1_xt = nn.Linear(2944, output_dim) + # self.fc1_xt = nn.Linear(4224, output_dim) + self.fc1_xt = nn.Linear(61824, output_dim) + + # combined layers + self.fc1 = nn.Linear(2*output_dim, 1024) + self.fc2 = nn.Linear(1024, 128) + self.out = nn.Linear(128, n_output) + + # activation and regularization + self.relu = nn.ReLU() + self.dropout = nn.Dropout(0.5) + + def forward(self, x, edge_index, batch, x_cell_mut, edge_feat): + # x, edge_index, batch = data.x, data.edge_index, data.batch + # print(x) + # print(data.target) + x = F.relu(self.conv1(x, edge_index, edge_attr=edge_feat)) + x = self.bn1(x) + x = F.relu(self.conv2(x, edge_index, edge_attr=edge_feat)) + x = self.bn2(x) + x = F.relu(self.conv3(x, edge_index, edge_attr=edge_feat)) + x = self.bn3(x) + x = F.relu(self.conv4(x, edge_index, edge_attr=edge_feat)) + x = self.bn4(x) + x = F.relu(self.conv5(x, edge_index, edge_attr=edge_feat)) + x = self.bn5(x) + x = global_add_pool(x, batch) + x = F.relu(self.fc1_xd(x)) + # x = F.dropout(x, p=0.2, training=self.training) + x = self.dropout(x) + + # protein input feed-forward: + # target = data.target + # x_cell_mut = x_cell_mut[:,None,:] + + # 1d conv layers + conv_xt = self.conv_xt_1(x_cell_mut) + conv_xt = F.relu(conv_xt) + conv_xt = self.pool_xt_1(conv_xt) + conv_xt = self.conv_xt_2(conv_xt) + conv_xt = F.relu(conv_xt) + conv_xt = self.pool_xt_2(conv_xt) + conv_xt = self.conv_xt_3(conv_xt) + conv_xt = F.relu(conv_xt) + conv_xt = self.pool_xt_3(conv_xt) + + # flatten + xt = conv_xt.view(-1, conv_xt.shape[1] * conv_xt.shape[2]) + xt = self.fc1_xt(xt) + + # concat + xc = torch.cat((x, xt), 1) + # add some dense layers + xc = self.fc1(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + xc = self.fc2(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + out = self.out(xc) + out = nn.Sigmoid()(out) + return out + + +class RGCNNet(torch.nn.Module): + def __init__(self, n_output=1, n_filters=32, embed_dim=128, num_features_xd=334, num_features_xt=25, output_dim=128, dropout=0.5, use_attn=False): # qwe + + super(RGCNNet, self).__init__() + self.use_attn = use_attn + + # SMILES graph branch + self.n_output = n_output + self.conv1 = RGCNConv( + num_features_xd, num_features_xd, num_relations=4) + self.conv2 = RGCNConv( + num_features_xd, num_features_xd*2, num_relations=4) + self.conv3 = RGCNConv( + num_features_xd*2, num_features_xd * 4, num_relations=4) + self.fc_g1 = torch.nn.Linear(num_features_xd*4, 1024) + self.fc_g2 = torch.nn.Linear(1024, output_dim) + self.relu = nn.ReLU() + self.dropout = nn.Dropout(dropout) + + # cell line feature + self.conv_xt_1 = nn.Conv1d( + in_channels=1, out_channels=n_filters, kernel_size=8) + self.pool_xt_1 = nn.MaxPool1d(3) + self.conv_xt_2 = nn.Conv1d( + in_channels=n_filters, out_channels=n_filters*2, kernel_size=8) + self.pool_xt_2 = nn.MaxPool1d(3) + self.conv_xt_3 = nn.Conv1d( + in_channels=n_filters*2, out_channels=n_filters*4, kernel_size=8) + self.pool_xt_3 = nn.MaxPool1d(3) + # self.fc1_xt = nn.Linear(2944, output_dim) + # self.fc1_xt = nn.Linear(4224, output_dim) + self.fc1_xt = nn.Linear(4096, output_dim) + + if self.use_attn: + self.cross_attn1 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) + self.cross_attn2 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) + self.norm1 = nn.LayerNorm(output_dim) + self.norm2 = nn.LayerNorm(output_dim) + self.fc = nn.Linear(2*output_dim, output_dim) + else: + # combined layers + self.fc1 = nn.Linear(2*output_dim, 1024) + self.fc2 = nn.Linear(1024, 128) + self.out = nn.Linear(128, self.n_output) + + def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, edge_weight=None): + # get graph input + # edge_weight is only used for decoding + + # x, edge_index, batch = data.x, data.edge_index, data.batch + # edge_index = edge_index.long() + edge_feat = edge_feat.long().squeeze() + + x = self.conv1(x, edge_index, edge_type=edge_feat) + x = self.relu(x) + x = self.conv2(x, edge_index, edge_type=edge_feat) + x = self.relu(x) + x = self.conv3(x, edge_index, edge_type=edge_feat) + x = self.relu(x) + x = gmp(x, batch) # global max pooling + + # flatten + x = self.relu(self.fc_g1(x)) + x = self.dropout(x) + x = self.fc_g2(x) + x = self.dropout(x) + + # get protein input + # target = data.target + # print(x_cell_mut.shape) + + # add this line for CNV data, remove for gene expr data + # x_cell_mut = x_cell_mut[:,None,:] + + # 1d conv layers + conv_xt = self.conv_xt_1(x_cell_mut) + conv_xt = F.relu(conv_xt) + conv_xt = self.pool_xt_1(conv_xt) + conv_xt = self.conv_xt_2(conv_xt) + conv_xt = F.relu(conv_xt) + conv_xt = self.pool_xt_2(conv_xt) + conv_xt = self.conv_xt_3(conv_xt) + conv_xt = F.relu(conv_xt) + conv_xt = self.pool_xt_3(conv_xt) + + # flatten + xt = conv_xt.view(-1, conv_xt.shape[1] * conv_xt.shape[2]) + xt = self.fc1_xt(xt) + + if self.use_attn: + xc1, _ = self.cross_attn1(x, xt, xt) + xc1 = xc1 + x + xc1 = self.norm1(xc1) + xc2, _ = self.cross_attn2(xt, x, x) + xc2 = xc2 + xt + xc2 = self.norm2(xc2) + xc = torch.cat((xc1, xc2), 1) + xc = self.relu(xc) + xc = self.dropout(xc) + xc = self.fc(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + else: + # concat + xc = torch.cat((x, xt), 1) + # add some dense layers + xc = self.fc1(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + xc = self.fc2(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + out = self.out(xc) + out = nn.Sigmoid()(out) + return out + + +class WIRGATNet(torch.nn.Module): + def __init__(self, num_features_xd=334, n_output=1, num_features_xt=25, + n_filters=32, embed_dim=128, output_dim=128, dropout=0.5, use_attn=False): + super(WIRGATNet, self).__init__() + self.use_attn = use_attn + + # graph layers + self.gcn1 = RGATConv(num_features_xd, num_features_xd, num_relations=4, + attention_mechanism='within-relation', heads=10, dropout=dropout) + self.gcn2 = RGATConv(num_features_xd * 10, output_dim, num_relations=4, + attention_mechanism='within-relation', dropout=dropout) + self.fc_g1 = nn.Linear(output_dim, output_dim) + + # cell line feature + self.conv_xt_1 = nn.Conv1d( + in_channels=1, out_channels=n_filters, kernel_size=8) + self.pool_xt_1 = nn.MaxPool1d(3) + self.conv_xt_2 = nn.Conv1d( + in_channels=n_filters, out_channels=n_filters*2, kernel_size=8) + self.pool_xt_2 = nn.MaxPool1d(3) + self.conv_xt_3 = nn.Conv1d( + in_channels=n_filters*2, out_channels=n_filters*4, kernel_size=8) + self.pool_xt_3 = nn.MaxPool1d(3) + # self.fc1_xt = nn.Linear(2944, output_dim) + # self.fc1_xt = nn.Linear(4224, output_dim) + self.fc1_xt = nn.Linear(4096, output_dim) + + if self.use_attn: + self.cross_attn1 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) + self.cross_attn2 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) + self.norm1 = nn.LayerNorm(output_dim) + self.norm2 = nn.LayerNorm(output_dim) + self.fc = nn.Linear(2*output_dim, 128) + else: + # combined layers + self.fc1 = nn.Linear(2*output_dim, 1024) + self.fc2 = nn.Linear(1024, 128) + self.out = nn.Linear(128, n_output) + + # activation and regularization + self.relu = nn.ReLU() + self.dropout = nn.Dropout(0.5) + + def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_weights=False): + ''' + x: feature matrix of molecular graph + target: gene mutation data + edge_index: edges of molecular graph + batch + edge_feat: edge features of molecular graph + ''' + # graph input feed-forward + # x, edge_index, batch, edge_feat = data.x, data.edge_index, data.batch, data.edge_features + # print(data.x.shape) + edge_feat = edge_feat.int().squeeze() + # print(edge_feat) + + # x = F.dropout(x, p=0.2, training=self.training) + # x = self.dropout(x) + x = F.elu(self.gcn1(x, edge_index, edge_type=edge_feat)) + # x = F.dropout(x, p=0.2, training=self.training) + x = self.dropout(x) + if return_attention_weights: + x, attn_weights = self.gcn2( + x, edge_index, edge_type=edge_feat, return_attention_weights=return_attention_weights) + else: + x = self.gcn2(x, edge_index, edge_type=edge_feat) + x = self.relu(x) + x = gmp(x, batch) # global max pooling + x = self.fc_g1(x) + x = self.relu(x) + + # protein input feed-forward: + # target = data.target + # x_cell_mut = x_cell_mut[:,None,:] + # 1d conv layers + conv_xt = self.conv_xt_1(x_cell_mut) + conv_xt = F.relu(conv_xt) + conv_xt = self.pool_xt_1(conv_xt) + conv_xt = self.conv_xt_2(conv_xt) + conv_xt = F.relu(conv_xt) + conv_xt = self.pool_xt_2(conv_xt) + conv_xt = self.conv_xt_3(conv_xt) + conv_xt = F.relu(conv_xt) + conv_xt = self.pool_xt_3(conv_xt) + + # flatten + xt = conv_xt.view(-1, conv_xt.shape[1] * conv_xt.shape[2]) + xt = self.fc1_xt(xt) + + if self.use_attn: + xc1, _ = self.cross_attn1(x, xt, xt) + xc1 = xc1 + x + xc1 = self.norm1(xc1) + xc2, _ = self.cross_attn2(xt, x, x) + xc2 = xc2 + xt + xc2 = self.norm2(xc2) + xc = torch.cat((xc1, xc2), 1) + xc = self.relu(xc) + xc = self.dropout(xc) + xc = self.fc(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + else: + # concat + xc = torch.cat((x, xt), 1) + # add some dense layers + xc = self.fc1(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + xc = self.fc2(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + out = self.out(xc) + out = nn.Sigmoid()(out) + + if return_attention_weights: + return out, attn_weights + else: + return out + + +class ARGATNet(torch.nn.Module): + def __init__(self, num_features_xd=334, n_output=1, num_features_xt=25, + n_filters=32, embed_dim=128, output_dim=128, dropout=0.5, use_attn=False): + super(ARGATNet, self).__init__() + self.use_attn = use_attn + + # graph layers + self.gcn1 = RGATConv(num_features_xd, num_features_xd, num_relations=4, + attention_mechanism='across-relation', heads=10, dropout=dropout) + self.gcn2 = RGATConv(num_features_xd * 10, output_dim, num_relations=4, + attention_mechanism='across-relation', dropout=dropout) + self.fc_g1 = nn.Linear(output_dim, output_dim) + + # cell line feature + self.conv_xt_1 = nn.Conv1d( + in_channels=1, out_channels=n_filters, kernel_size=8) + self.pool_xt_1 = nn.MaxPool1d(3) + self.conv_xt_2 = nn.Conv1d( + in_channels=n_filters, out_channels=n_filters*2, kernel_size=8) + self.pool_xt_2 = nn.MaxPool1d(3) + self.conv_xt_3 = nn.Conv1d( + in_channels=n_filters*2, out_channels=n_filters*4, kernel_size=8) + self.pool_xt_3 = nn.MaxPool1d(3) + # self.fc1_xt = nn.Linear(2944, output_dim) + # self.fc1_xt = nn.Linear(4224, output_dim) + self.fc1_xt = nn.Linear(4096, output_dim) + + if self.use_attn: + self.cross_attn1 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) + self.cross_attn2 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) + self.norm1 = nn.LayerNorm(output_dim) + self.norm2 = nn.LayerNorm(output_dim) + self.fc = nn.Linear(2*output_dim, 128) + else: + # combined layers + self.fc1 = nn.Linear(2*output_dim, 1024) + self.fc2 = nn.Linear(1024, 128) + self.out = nn.Linear(128, n_output) + + # activation and regularization + self.relu = nn.ReLU() + self.dropout = nn.Dropout(0.5) + + def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_weights=False): + ''' + x: feature matrix of molecular graph + target: gene mutation data + edge_index: edges of molecular graph + batch + edge_feat: edge features of molecular graph + ''' + # graph input feed-forward + # x, edge_index, batch, edge_feat = data.x, data.edge_index, data.batch, data.edge_features + # print(data.x.shape) + edge_feat = edge_feat.int().squeeze() + + # x = F.dropout(x, p=0.2, training=self.training) + # x = self.dropout(x) + x = F.elu(self.gcn1(x, edge_index, edge_type=edge_feat)) + # x = F.dropout(x, p=0.2, training=self.training) + x = self.dropout(x) + if return_attention_weights: + x, attn_weights = self.gcn2( + x, edge_index, edge_type=edge_feat, return_attention_weights=return_attention_weights) + else: + x = self.gcn2(x, edge_index, edge_type=edge_feat) + x = self.relu(x) + x = gmp(x, batch) # global max pooling + x = self.fc_g1(x) + x = self.relu(x) + + # protein input feed-forward: + # target = data.target + # x_cell_mut = x_cell_mut[:,None,:] + # 1d conv layers + conv_xt = self.conv_xt_1(x_cell_mut) + conv_xt = F.relu(conv_xt) + conv_xt = self.pool_xt_1(conv_xt) + conv_xt = self.conv_xt_2(conv_xt) + conv_xt = F.relu(conv_xt) + conv_xt = self.pool_xt_2(conv_xt) + conv_xt = self.conv_xt_3(conv_xt) + conv_xt = F.relu(conv_xt) + conv_xt = self.pool_xt_3(conv_xt) + + # flatten + xt = conv_xt.view(-1, conv_xt.shape[1] * conv_xt.shape[2]) + xt = self.fc1_xt(xt) + + if self.use_attn: + xc1, _ = self.cross_attn1(x, xt, xt) + xc1 = xc1 + x + xc1 = self.norm1(xc1) + xc2, _ = self.cross_attn2(xt, x, x) + xc2 = xc2 + xt + xc2 = self.norm2(xc2) + xc = torch.cat((xc1, xc2), 1) + xc = self.relu(xc) + xc = self.dropout(xc) + xc = self.fc(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + else: + # concat + xc = torch.cat((x, xt), 1) + # add some dense layers + xc = self.fc1(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + xc = self.fc2(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + out = self.out(xc) + out = nn.Sigmoid()(out) + + if return_attention_weights: + return out, attn_weights + else: + return out + + +class FiLMNet(torch.nn.Module): + def __init__(self, num_features_xd=334, n_output=1, num_features_xt=25, + n_filters=32, embed_dim=128, output_dim=128, dropout=0.5, use_attn=False): + super(FiLMNet, self).__init__() + self.use_attn = use_attn + + # graph layers + self.gcn1 = FiLMConv(num_features_xd, num_features_xd, num_relations=4, act=nn.LeakyReLU()) + self.gcn2 = FiLMConv(num_features_xd, output_dim, num_relations=4, act=nn.LeakyReLU()) + + self.fc_g1 = nn.Linear(output_dim, output_dim) + + # cell line feature + self.conv_xt_1 = nn.Conv1d( + in_channels=1, out_channels=n_filters, kernel_size=8) + self.pool_xt_1 = nn.MaxPool1d(3) + self.conv_xt_2 = nn.Conv1d( + in_channels=n_filters, out_channels=n_filters*2, kernel_size=8) + self.pool_xt_2 = nn.MaxPool1d(3) + self.conv_xt_3 = nn.Conv1d( + in_channels=n_filters*2, out_channels=n_filters*4, kernel_size=8) + self.pool_xt_3 = nn.MaxPool1d(3) + # self.fc1_xt = nn.Linear(2944, output_dim) + # self.fc1_xt = nn.Linear(4224, output_dim) + self.fc1_xt = nn.Linear(4096, output_dim) + + if self.use_attn: + self.cross_attn1 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) + self.cross_attn2 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) + self.norm1 = nn.LayerNorm(output_dim) + self.norm2 = nn.LayerNorm(output_dim) + self.fc = nn.Linear(2*output_dim, 128) + else: + # combined layers + self.fc1 = nn.Linear(2*output_dim, 1024) + self.fc2 = nn.Linear(1024, 128) + self.out = nn.Linear(128, n_output) + + # activation and regularization + self.relu = nn.ReLU() + self.dropout = nn.Dropout(0.5) + + def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_weights=False): + ''' + x: feature matrix of molecular graph + target: gene mutation data + edge_index: edges of molecular graph + batch + edge_feat: edge features of molecular graph + ''' + # graph input feed-forward + # x, edge_index, batch, edge_feat = data.x, data.edge_index, data.batch, data.edge_features + # print(data.x.shape) + edge_feat = edge_feat.int().squeeze() + + # x = F.dropout(x, p=0.2, training=self.training) + # x = self.dropout(x) + self.gcn1(x, edge_index, edge_type=edge_feat) + # x = F.dropout(x, p=0.2, training=self.training) + x = self.dropout(x) + x = self.gcn2(x, edge_index, edge_type=edge_feat) + # x = self.relu(x) + x = gmp(x, batch) # global max pooling + x = self.fc_g1(x) + x = self.relu(x) + + # protein input feed-forward: + # target = data.target + # x_cell_mut = x_cell_mut[:,None,:] + # 1d conv layers + conv_xt = self.conv_xt_1(x_cell_mut) + conv_xt = F.relu(conv_xt) + conv_xt = self.pool_xt_1(conv_xt) + conv_xt = self.conv_xt_2(conv_xt) + conv_xt = F.relu(conv_xt) + conv_xt = self.pool_xt_2(conv_xt) + conv_xt = self.conv_xt_3(conv_xt) + conv_xt = F.relu(conv_xt) + conv_xt = self.pool_xt_3(conv_xt) + + # flatten + xt = conv_xt.view(-1, conv_xt.shape[1] * conv_xt.shape[2]) + xt = self.fc1_xt(xt) + + if self.use_attn: + xc1, _ = self.cross_attn1(x, xt, xt) + xc1 = xc1 + x + xc1 = self.norm1(xc1) + xc2, _ = self.cross_attn2(xt, x, x) + xc2 = xc2 + xt + xc2 = self.norm2(xc2) + xc = torch.cat((xc1, xc2), 1) + xc = self.relu(xc) + xc = self.dropout(xc) + xc = self.fc(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + else: + # concat + xc = torch.cat((x, xt), 1) + # add some dense layers + xc = self.fc1(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + xc = self.fc2(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + out = self.out(xc) + out = nn.Sigmoid()(out) + + return out \ No newline at end of file diff --git a/drevalpy/models/XGDP/utils.py b/drevalpy/models/XGDP/utils.py new file mode 100644 index 00000000..e69de29b diff --git a/drevalpy/models/XGDP/xgdp.py b/drevalpy/models/XGDP/xgdp.py index e69de29b..406d9036 100644 --- a/drevalpy/models/XGDP/xgdp.py +++ b/drevalpy/models/XGDP/xgdp.py @@ -0,0 +1,139 @@ +import pandas as pd +import numpy as np +import sys +import os +import torch +import torch.nn as nn +import tqdm +import time + +import models +import utils as u + +class XGDP(DRPModel): + #to do + def __init__(self): + """Initialize the DrugGNN model.""" + pass + + @classmethod + def get_model_name(cls) -> str: + """Return the name of the model. + + :return: The name of the model. + """ + return "XGDP" + + @property + def cell_line_views(self) -> list[str]: + """Return the sources the model needs as input for describing the cell line. + + :return: The sources the model needs as input for describing the cell line. + """ + return ["gene_expression"] + + #to do + @property + def drug_views(self) -> list[str]: + """Return the sources the model needs as input for describing the drug. + + :return: The sources the model needs as input for describing the drug. + """ + return [""] + + def build_model(self, hyperparameters: dict[str, Any]) -> None: + """Build the model. + + :param hyperparameters: The hyperparameters. + """ + # Log hyperparameters to wandb if enabled + self.log_hyperparameters(hyperparameters) + + self.hyperparameters = hyperparameters + # init model in train + + def train( + self, + output: DrugResponseDataset, + cell_line_input: FeatureDataset, + drug_input: FeatureDataset | None = None, + output_earlystopping: DrugResponseDataset | None = None, + **kwargs, + ): + """Train the model. + + :param output: The output dataset. + :param cell_line_input: The cell line input dataset. + :param drug_input: The drug input dataset. + :param output_earlystopping: The early stopping output dataset. + :param kwargs: Additional arguments. + :raises ValueError: If drug input is not provided. + """ + if drug_input is None: + raise ValueError("Drug input is required for XGDP") + + #step1 load data + + #step 2: preprocess data + + #step 4: init model + train parameters + #to do: get feature size from data? -> + with_attention = [] + model_name = self.hyperparameters["model"] + model_class = getattr(models,model_name ) + #is done after cv split + #self.model = model_class() #check hyper parameter for each model + if 'GAT' in model_name: + return_attention_weights = True + else: + return_attention_weights = False + n_epochs = self.hyperparameters["n_epochs"] + lr = self.hyperparameters["lr"] + train_batch_size = self.hyperparameters["train_batch_size"] + test_batch_size = self.hyperparameters["test_batch_size"] + val_batch_size = self.hyperparameters["val_batch_size"] + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + #step 3: build data loaders + split data + test_loader = DataLoader(test_data, batch_size=test_batch_size, shuffle=False) + + #step 5: train loop + best_model_id = 0 + best_model = None + best_pearson_cv = 0 + ret_cv = [] + kf = KFold(n_splits=3) #creates splits indices for cross validation + #does the cv split + for i, (train_index, val_index) in enumerate(kf.split(cv_data)): + #splits data into val and train according to index form kf + train_data = Subset(data,train_index) + train_data = Subset(data,val_index) + train_loader = DataLoader(train_data, batch_size=train_batch_size, shuffle=False) + val_loader = DataLoader(test_data, batch_size=val_batch_size, shuffle=False) + + #creates model for each cv split + model = model_class().to(device) + optimizer = torch.optim.Adam(model.parameters(), lr=lr) + best_mse = 1000 + best_pearson = 0 + best_epoch = -1 + total_time = 0 + early_stop_tolerance = 30 + train_losses = [] + val_losses = [] + val_pearsons = [] + best_ret = [] + for epoch in tqdm(range(n_epochs)): + start_time = time.time() + train_loss = train(model, device, train_loader, optimizer, epoch+1, log_interval) + G,P = predicting(model, device, val_loader) + ret = [u.rmse(G,P),u.mse(G,P),u.pearson(G,P),u.spearman(G,P),coeffi_determ(G,P)] + train_losses.append(train_loss) + val_losses.append(ret[1]) + val_pearsons.append(ret[2]) + + self.model = best_model + #step 6: save mdoel (maybe not neccecary beacuse class) + + + + From d2a9a5a9bf1c16103ecb54cf6d28b0ddcaa9170f Mon Sep 17 00:00:00 2001 From: eleded Date: Fri, 17 Apr 2026 10:50:07 +0200 Subject: [PATCH 04/10] Cleanup --- drevalpy/models/XGDP/__init__.py | 4 +- drevalpy/models/XGDP/hyperparameters.yaml | 12 + drevalpy/models/XGDP/models.py | 879 +++++++++++++++------- drevalpy/models/XGDP/utils.py | 1 + drevalpy/models/XGDP/xgdp.py | 147 ++-- drevalpy/models/__init__.py | 3 + tests/models/test_global_models.py | 1 + tests/test_drp_model.py | 1 + 8 files changed, 688 insertions(+), 360 deletions(-) diff --git a/drevalpy/models/XGDP/__init__.py b/drevalpy/models/XGDP/__init__.py index 004aba4c..808afa23 100644 --- a/drevalpy/models/XGDP/__init__.py +++ b/drevalpy/models/XGDP/__init__.py @@ -1,5 +1,5 @@ -"""A GNN based drug response prediction model.""" +"""A GNN and CNN based drug response prediction model.""" from .xgdp import XGDP -__all__ = ["XGDP"] \ No newline at end of file +__all__ = ["XGDP"] diff --git a/drevalpy/models/XGDP/hyperparameters.yaml b/drevalpy/models/XGDP/hyperparameters.yaml index d97a5cd0..78a28841 100644 --- a/drevalpy/models/XGDP/hyperparameters.yaml +++ b/drevalpy/models/XGDP/hyperparameters.yaml @@ -1,4 +1,16 @@ XGDP: + model_type: + - "GATNet" + - "GCNNet" + - "GATv2Net" + - "GATNet_E" + - "SAGENet" + - "GINNet" + - "GINENet" + - "RGCNNet" + - "WIRGATNet" + - "ARGATNet" + - "FiLMNet" learning_rate: - 0.001 epochs: diff --git a/drevalpy/models/XGDP/models.py b/drevalpy/models/XGDP/models.py index 45ec80f7..3e047b1e 100644 --- a/drevalpy/models/XGDP/models.py +++ b/drevalpy/models/XGDP/models.py @@ -1,54 +1,77 @@ +"""Models for XGDP model.""" + import torch import torch.nn as nn -import torch.nn.functional as F -from torch.nn import Sequential, Linear, ReLU -from torch_geometric.nn import GCNConv, GATConv, GATv2Conv, SAGEConv, GINEConv, GINConv, RGATConv, RGCNConv, FiLMConv +import torch.nn.functional as f +from torch.nn import Linear, ReLU, Sequential +from torch_geometric.nn import ( + FiLMConv, + GATConv, + GATv2Conv, + GCNConv, + GINConv, + GINEConv, + RGATConv, + RGCNConv, + SAGEConv, + global_add_pool, +) from torch_geometric.nn import global_max_pool as gmp -from torch_geometric.nn import global_add_pool -''' +""" DeepChem feature set: 78 ECFP4: 192 ECFP4 + DeepChem: 270 ECFP6: 256 ECFP6 + DeepChem: 334 -''' - -''' -TODO: (already done) - 1. align all models' forward arguments with GATNet (make sure batch is the 3rd one due to gnn_explainer's implementation) - 2. remove x in the return (tuple cannot be accepted by gnn_explainer) - 3. change the output and input in utils_train.py correspondingly -''' - -# change num_features_xd into 78 for ordinary atom features (benchmark) +""" class GCNNet(torch.nn.Module): - def __init__(self, n_output=1, n_filters=32, embed_dim=128, num_features_xd=334, num_features_xt=25, output_dim=128, dropout=0.5, use_attn=False): # qwe - - super(GCNNet, self).__init__() + """Standard graph convolutions to capture structural SMILES information.""" + + def __init__( + self, + n_output=1, + n_filters=32, + embed_dim=128, + num_features_xd=334, + num_features_xt=25, + output_dim=128, + dropout=0.5, + use_attn=False, + ): + """ + Initialization method for GCNNet. + + :param n_output: Number of output units (default: 1) + :param n_filters: Number of convolution filters for cell line CNN branch + :param embed_dim: Embedding dimension for optional embeddings + :param num_features_xd: Number of molecular graph node features + :param num_features_xt: Number of cell line features + :param output_dim: Dimensionality of the latent representation + :param dropout: Dropout probability + :param use_attn: Whether to use cross‑attention between drug and cell line features + """ + super().__init__() self.use_attn = use_attn # SMILES graph branch self.n_output = n_output self.conv1 = GCNConv(num_features_xd, num_features_xd) - self.conv2 = GCNConv(num_features_xd, num_features_xd*2) - self.conv3 = GCNConv(num_features_xd*2, num_features_xd * 4) - self.fc_g1 = torch.nn.Linear(num_features_xd*4, 1024) + self.conv2 = GCNConv(num_features_xd, num_features_xd * 2) + self.conv3 = GCNConv(num_features_xd * 2, num_features_xd * 4) + self.fc_g1 = torch.nn.Linear(num_features_xd * 4, 1024) self.fc_g2 = torch.nn.Linear(1024, output_dim) self.relu = nn.ReLU() self.dropout = nn.Dropout(dropout) # cell line feature - self.conv_xt_1 = nn.Conv1d( - in_channels=1, out_channels=n_filters, kernel_size=8) + self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=8) self.pool_xt_1 = nn.MaxPool1d(3) - self.conv_xt_2 = nn.Conv1d( - in_channels=n_filters, out_channels=n_filters*2, kernel_size=8) + self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=8) self.pool_xt_2 = nn.MaxPool1d(3) - self.conv_xt_3 = nn.Conv1d( - in_channels=n_filters*2, out_channels=n_filters*4, kernel_size=8) + self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=8) self.pool_xt_3 = nn.MaxPool1d(3) # self.fc1_xt = nn.Linear(2944, output_dim) # self.fc1_xt = nn.Linear(4224, output_dim) @@ -61,13 +84,24 @@ def __init__(self, n_output=1, n_filters=32, embed_dim=128, num_features_xd=334, self.cross_attn2 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) self.norm1 = nn.LayerNorm(output_dim) self.norm2 = nn.LayerNorm(output_dim) - self.fc = nn.Linear(2*output_dim, 128) - else: - self.fc1 = nn.Linear(2*output_dim, 1024) + self.fc = nn.Linear(2 * output_dim, 128) + else: + self.fc1 = nn.Linear(2 * output_dim, 1024) self.fc2 = nn.Linear(1024, 128) self.out = nn.Linear(128, self.n_output) def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, edge_weight=None): + """ + Forward pass of the GCNNet model. + + :param x: Node feature matrix of the molecular graph + :param edge_index: Edge indices of the molecular graph + :param batch: Batch vector assigning nodes to graphs + :param x_cell_mut: Cell line omics features + :param edge_feat: Edge features (unused for GCN) + :param edge_weight: Optional edge weights + :returns: Predicted drug response + """ # get graph input # edge_weight is only used for decoding @@ -80,7 +114,7 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, edge_weight=None) x = self.relu(x) x = self.conv3(x, edge_index, edge_weight) x = self.relu(x) - x = gmp(x, batch) # global max pooling + x = gmp(x, batch) # global max pooling # flatten x = self.relu(self.fc_g1(x)) @@ -97,13 +131,13 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, edge_weight=None) # 1d conv layers conv_xt = self.conv_xt_1(x_cell_mut) - conv_xt = F.relu(conv_xt) + conv_xt = f.relu(conv_xt) conv_xt = self.pool_xt_1(conv_xt) conv_xt = self.conv_xt_2(conv_xt) - conv_xt = F.relu(conv_xt) + conv_xt = f.relu(conv_xt) conv_xt = self.pool_xt_2(conv_xt) conv_xt = self.conv_xt_3(conv_xt) - conv_xt = F.relu(conv_xt) + conv_xt = f.relu(conv_xt) conv_xt = self.pool_xt_3(conv_xt) # flatten @@ -123,8 +157,8 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, edge_weight=None) xc = self.fc(xc) xc = self.relu(xc) xc = self.dropout(xc) - else: - # concat + else: + # concat xc = torch.cat((x, xt), 1) # add some dense layers xc = self.fc1(xc) @@ -139,25 +173,45 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, edge_weight=None) class GATNet(torch.nn.Module): - def __init__(self, num_features_xd=334, n_output=1, num_features_xt=25, n_filters=32, embed_dim=128, output_dim=128, dropout=0.5, use_attn=False): - super(GATNet, self).__init__() + """Uses attention to weigh importance of neighboring nodes.""" + + def __init__( + self, + num_features_xd=334, + n_output=1, + num_features_xt=25, + n_filters=32, + embed_dim=128, + output_dim=128, + dropout=0.5, + use_attn=False, + ): + """ + Initialize the GATNet model. + + :param num_features_xd: Number of molecular graph node features + :param n_output: Number of output units + :param num_features_xt: Number of cell line features + :param n_filters: Number of convolution filters for cell line CNN branch + :param embed_dim: Embedding dimension for optional embeddings + :param output_dim: Dimensionality of the latent representation + :param dropout: Dropout probability + :param use_attn: Whether to use cross‑attention between drug and cell line features + """ + super().__init__() self.use_attn = use_attn # graph layers - self.gcn1 = GATConv(num_features_xd, num_features_xd, - heads=10, dropout=dropout) + self.gcn1 = GATConv(num_features_xd, num_features_xd, heads=10, dropout=dropout) self.gcn2 = GATConv(num_features_xd * 10, output_dim, dropout=dropout) self.fc_g1 = nn.Linear(output_dim, output_dim) # cell line feature - self.conv_xt_1 = nn.Conv1d( - in_channels=1, out_channels=n_filters, kernel_size=8) + self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=8) self.pool_xt_1 = nn.MaxPool1d(3) - self.conv_xt_2 = nn.Conv1d( - in_channels=n_filters, out_channels=n_filters*2, kernel_size=8) + self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=8) self.pool_xt_2 = nn.MaxPool1d(3) - self.conv_xt_3 = nn.Conv1d( - in_channels=n_filters*2, out_channels=n_filters*4, kernel_size=8) + self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=8) self.pool_xt_3 = nn.MaxPool1d(3) # self.fc1_xt = nn.Linear(2944, output_dim) # self.fc1_xt = nn.Linear(4224, output_dim) @@ -170,9 +224,9 @@ def __init__(self, num_features_xd=334, n_output=1, num_features_xt=25, n_filter self.cross_attn2 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) self.norm1 = nn.LayerNorm(output_dim) self.norm2 = nn.LayerNorm(output_dim) - self.fc = nn.Linear(2*output_dim, 128) + self.fc = nn.Linear(2 * output_dim, 128) else: - self.fc1 = nn.Linear(2*output_dim, 1024) + self.fc1 = nn.Linear(2 * output_dim, 1024) self.fc2 = nn.Linear(1024, 128) self.out = nn.Linear(128, n_output) @@ -181,20 +235,30 @@ def __init__(self, num_features_xd=334, n_output=1, num_features_xt=25, n_filter self.dropout = nn.Dropout(0.5) def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_weights=False): + """ + Forward pass of the GATNet model. + + :param x: Node feature matrix of the molecular graph + :param edge_index: Edge indices of the molecular graph + :param batch: Batch vector assigning nodes to graphs + :param x_cell_mut: Cell line omics features + :param edge_feat: Edge features of the molecular graph + :param return_attention_weights: Whether to return attention weights + :returns: Predicted drug response or (prediction, attention weights) + """ # graph input feed-forward # x, edge_index, batch = data.x, data.edge_index, data.batch # x = self.dropout(x) - # x = F.dropout(x, p=0.2, training=self.training) - x = F.elu(self.gcn1(x, edge_index)) - # x = F.dropout(x, p=0.2, training=self.training) + # x = f.dropout(x, p=0.2, training=self.training) + x = f.elu(self.gcn1(x, edge_index)) + # x = f.dropout(x, p=0.2, training=self.training) x = self.dropout(x) if return_attention_weights: - x, attn_weights = self.gcn2( - x, edge_index, return_attention_weights=return_attention_weights) + x, attn_weights = self.gcn2(x, edge_index, return_attention_weights=return_attention_weights) else: x = self.gcn2(x, edge_index) x = self.relu(x) - x = gmp(x, batch) # global max pooling + x = gmp(x, batch) # global max pooling x = self.fc_g1(x) x = self.relu(x) @@ -203,13 +267,13 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_ # x_cell_mut = x_cell_mut[:,None,:] # 1d conv layers conv_xt = self.conv_xt_1(x_cell_mut) - conv_xt = F.relu(conv_xt) + conv_xt = f.relu(conv_xt) conv_xt = self.pool_xt_1(conv_xt) conv_xt = self.conv_xt_2(conv_xt) - conv_xt = F.relu(conv_xt) + conv_xt = f.relu(conv_xt) conv_xt = self.pool_xt_2(conv_xt) conv_xt = self.conv_xt_3(conv_xt) - conv_xt = F.relu(conv_xt) + conv_xt = f.relu(conv_xt) conv_xt = self.pool_xt_3(conv_xt) # flatten @@ -250,27 +314,47 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_ class GATv2Net(torch.nn.Module): - def __init__(self, num_features_xd=334, n_output=1, num_features_xt=25, - n_filters=32, embed_dim=128, output_dim=128, dropout=0.5, use_attn=False): - super(GATv2Net, self).__init__() + """More expressive attention mechanism that supports edge features.""" + + def __init__( + self, + num_features_xd=334, + n_output=1, + num_features_xt=25, + n_filters=32, + embed_dim=128, + output_dim=128, + dropout=0.5, + use_attn=False, + ): + """ + Initialize the GATv2Net model. + + :param num_features_xd: Number of molecular graph node features + :param n_output: Number of output units + :param num_features_xt: Number of cell line features + :param n_filters: Number of convolution filters for cell line CNN branch + :param embed_dim: Embedding dimension for optional embeddings + :param output_dim: Dimensionality of the latent representation + :param dropout: Dropout probability + :param use_attn: Whether to use cross‑attention between drug and cell line features + """ + super().__init__() self.use_attn = use_attn # graph layers - self.gcn1 = GATv2Conv(num_features_xd, num_features_xd, - heads=25, dropout=dropout, edge_dim=4, add_self_loops=False) - self.gcn2 = GATv2Conv(num_features_xd * 25, output_dim, - dropout=dropout, edge_dim=4, add_self_loops=False) + self.gcn1 = GATv2Conv( + num_features_xd, num_features_xd, heads=25, dropout=dropout, edge_dim=4, add_self_loops=False + ) + self.gcn2 = GATv2Conv(num_features_xd * 25, output_dim, dropout=dropout, edge_dim=4, add_self_loops=False) self.fc_g1 = nn.Linear(output_dim, output_dim) # cell line feature - self.conv_xt_1 = nn.Conv1d( - in_channels=1, out_channels=n_filters, kernel_size=8) + self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=8) self.pool_xt_1 = nn.MaxPool1d(3) - self.conv_xt_2 = nn.Conv1d( - in_channels=n_filters, out_channels=n_filters*2, kernel_size=8) + self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=8) self.pool_xt_2 = nn.MaxPool1d(3) - self.conv_xt_3 = nn.Conv1d( - in_channels=n_filters*2, out_channels=n_filters*4, kernel_size=8) + self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=8) self.pool_xt_3 = nn.MaxPool1d(3) # self.fc1_xt = nn.Linear(2944, output_dim) # self.fc1_xt = nn.Linear(4224, output_dim) @@ -283,9 +367,9 @@ def __init__(self, num_features_xd=334, n_output=1, num_features_xt=25, self.cross_attn2 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) self.norm1 = nn.LayerNorm(output_dim) self.norm2 = nn.LayerNorm(output_dim) - self.fc = nn.Linear(2*output_dim, 128) + self.fc = nn.Linear(2 * output_dim, 128) else: - self.fc1 = nn.Linear(2*output_dim, 1024) + self.fc1 = nn.Linear(2 * output_dim, 1024) self.fc2 = nn.Linear(1024, 128) self.out = nn.Linear(128, n_output) @@ -294,23 +378,35 @@ def __init__(self, num_features_xd=334, n_output=1, num_features_xt=25, self.dropout = nn.Dropout(0.5) def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_weights=False): + """ + Forward pass of the GATv2Net model. + + :param x: Node feature matrix of the molecular graph + :param edge_index: Edge indices of the molecular graph + :param batch: Batch vector assigning nodes to graphs + :param x_cell_mut: Cell line omics features + :param edge_feat: Edge features of the molecular graph + :param return_attention_weights: Whether to return attention weights + :returns: Predicted drug response or (prediction, attention weights) + """ # graph input feed-forward # x, edge_index, batch, edge_feat = data.x, data.edge_index, data.batch, data.edge_features # print(data.x.shape) # print(edge_feat.shape) - # x = F.dropout(x, p=0.2, training=self.training) + # x = f.dropout(x, p=0.2, training=self.training) # x = self.dropout(x) - x = F.elu(self.gcn1(x, edge_index, edge_attr=edge_feat)) + x = f.elu(self.gcn1(x, edge_index, edge_attr=edge_feat)) x = self.dropout(x) - # x = F.dropout(x, p=0.2, training=self.training) + # x = f.dropout(x, p=0.2, training=self.training) if return_attention_weights: x, attn_weights = self.gcn2( - x, edge_index, edge_attr=edge_feat, return_attention_weights=return_attention_weights) + x, edge_index, edge_attr=edge_feat, return_attention_weights=return_attention_weights + ) else: x = self.gcn2(x, edge_index, edge_attr=edge_feat) x = self.relu(x) - x = gmp(x, batch) # global max pooling + x = gmp(x, batch) # global max pooling x = self.fc_g1(x) x = self.relu(x) @@ -319,13 +415,13 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_ # x_cell_mut = x_cell_mut[:,None,:] # 1d conv layers conv_xt = self.conv_xt_1(x_cell_mut) - conv_xt = F.relu(conv_xt) + conv_xt = f.relu(conv_xt) conv_xt = self.pool_xt_1(conv_xt) conv_xt = self.conv_xt_2(conv_xt) - conv_xt = F.relu(conv_xt) + conv_xt = f.relu(conv_xt) conv_xt = self.pool_xt_2(conv_xt) conv_xt = self.conv_xt_3(conv_xt) - conv_xt = F.relu(conv_xt) + conv_xt = f.relu(conv_xt) conv_xt = self.pool_xt_3(conv_xt) # flatten @@ -365,27 +461,45 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_ class GATNet_E(torch.nn.Module): - def __init__(self, num_features_xd=334, n_output=1, num_features_xt=25, - n_filters=32, embed_dim=128, output_dim=128, dropout=0.5, use_attn=False): - super(GATNet_E, self).__init__() + """A GAT variant explicitly incorporating edge attributes.""" + + def __init__( + self, + num_features_xd=334, + n_output=1, + num_features_xt=25, + n_filters=32, + embed_dim=128, + output_dim=128, + dropout=0.5, + use_attn=False, + ): + """ + Initialize the GATNet_E model. + + :param num_features_xd: Number of molecular graph node features + :param n_output: Number of output units + :param num_features_xt: Number of cell line features + :param n_filters: Number of convolution filters for cell line CNN branch + :param embed_dim: Embedding dimension for optional embeddings + :param output_dim: Dimensionality of the latent representation + :param dropout: Dropout probability + :param use_attn: Whether to use cross‑attention between drug and cell line features + """ + super().__init__() self.use_attn = use_attn # graph layers - self.gcn1 = GATConv(num_features_xd, num_features_xd, - heads=10, dropout=dropout, edge_dim=4) - self.gcn2 = GATConv(num_features_xd * 10, output_dim, - dropout=dropout, edge_dim=4) + self.gcn1 = GATConv(num_features_xd, num_features_xd, heads=10, dropout=dropout, edge_dim=4) + self.gcn2 = GATConv(num_features_xd * 10, output_dim, dropout=dropout, edge_dim=4) self.fc_g1 = nn.Linear(output_dim, output_dim) # cell line feature - self.conv_xt_1 = nn.Conv1d( - in_channels=1, out_channels=n_filters, kernel_size=8) + self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=8) self.pool_xt_1 = nn.MaxPool1d(3) - self.conv_xt_2 = nn.Conv1d( - in_channels=n_filters, out_channels=n_filters*2, kernel_size=8) + self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=8) self.pool_xt_2 = nn.MaxPool1d(3) - self.conv_xt_3 = nn.Conv1d( - in_channels=n_filters*2, out_channels=n_filters*4, kernel_size=8) + self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=8) self.pool_xt_3 = nn.MaxPool1d(3) # self.fc1_xt = nn.Linear(2944, output_dim) # self.fc1_xt = nn.Linear(4224, output_dim) @@ -397,10 +511,10 @@ def __init__(self, num_features_xd=334, n_output=1, num_features_xt=25, self.cross_attn2 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) self.norm1 = nn.LayerNorm(output_dim) self.norm2 = nn.LayerNorm(output_dim) - self.fc = nn.Linear(2*output_dim, output_dim) + self.fc = nn.Linear(2 * output_dim, output_dim) else: # combined layers - self.fc1 = nn.Linear(2*output_dim, 1024) + self.fc1 = nn.Linear(2 * output_dim, 1024) self.fc2 = nn.Linear(1024, 128) self.out = nn.Linear(128, n_output) @@ -409,29 +523,34 @@ def __init__(self, num_features_xd=334, n_output=1, num_features_xt=25, self.dropout = nn.Dropout(0.5) def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_weights=False): - ''' - x: feature matrix of molecular graph - target: gene mutation data - edge_index: edges of molecular graph - batch - edge_feat: edge features of molecular graph - ''' + """ + Forward pass of the GATNet_E model. + + :param x: feature matrix of molecular graph + :param edge_index: edges of molecular graph + :param batch: Batch vector assigning nodes to graphs + :param x_cell_mut: Cell line omics features + :param edge_feat: edge features of molecular graph + :param return_attention_weights: Whether to return attention weights + :return: out and attention weights/ out + """ # graph input feed-forward # x, edge_index, batch, edge_feat = data.x, data.edge_index, data.batch, data.edge_features # print(data.x.shape) - # x = F.dropout(x, p=0.2, training=self.training) + # x = f.dropout(x, p=0.2, training=self.training) # x = self.dropout(x) - x = F.elu(self.gcn1(x, edge_index, edge_attr=edge_feat)) - # x = F.dropout(x, p=0.2, training=self.training) + x = f.elu(self.gcn1(x, edge_index, edge_attr=edge_feat)) + # x = f.dropout(x, p=0.2, training=self.training) x = self.dropout(x) if return_attention_weights: x, attn_weights = self.gcn2( - x, edge_index, edge_attr=edge_feat, return_attention_weights=return_attention_weights) + x, edge_index, edge_attr=edge_feat, return_attention_weights=return_attention_weights + ) else: x = self.gcn2(x, edge_index, edge_attr=edge_feat) x = self.relu(x) - x = gmp(x, batch) # global max pooling + x = gmp(x, batch) # global max pooling x = self.fc_g1(x) x = self.relu(x) @@ -440,13 +559,13 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_ # x_cell_mut = x_cell_mut[:,None,:] # 1d conv layers conv_xt = self.conv_xt_1(x_cell_mut) - conv_xt = F.relu(conv_xt) + conv_xt = f.relu(conv_xt) conv_xt = self.pool_xt_1(conv_xt) conv_xt = self.conv_xt_2(conv_xt) - conv_xt = F.relu(conv_xt) + conv_xt = f.relu(conv_xt) conv_xt = self.pool_xt_2(conv_xt) conv_xt = self.conv_xt_3(conv_xt) - conv_xt = F.relu(conv_xt) + conv_xt = f.relu(conv_xt) conv_xt = self.pool_xt_3(conv_xt) # flatten @@ -486,38 +605,56 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_ class SAGENet(torch.nn.Module): - def __init__(self, n_output=1, n_filters=32, embed_dim=128, num_features_xd=334, num_features_xt=25, output_dim=128, dropout=0.5): # qwe - - super(SAGENet, self).__init__() + """Focuses on sampling and aggregating local node neighborhoods.""" + + def __init__( + self, + n_output=1, + n_filters=32, + embed_dim=128, + num_features_xd=334, + num_features_xt=25, + output_dim=128, + dropout=0.5, + ): + """ + Initialize the SAGENet model. + + :param n_output: Number of output units + :param n_filters: Number of convolution filters for the cell line CNN branch + :param embed_dim: Embedding dimension (unused but kept for API consistency) + :param num_features_xd: Number of molecular graph node features + :param num_features_xt: Number of cell line features + :param output_dim: Dimensionality of the latent representation + :param dropout: Dropout probability + """ + super().__init__() # SMILES graph branch # GCNSAGE self.n_output = n_output self.conv1 = SAGEConv(num_features_xd, num_features_xd) - self.conv2 = SAGEConv(num_features_xd, num_features_xd*2) - self.conv3 = SAGEConv(num_features_xd*2, num_features_xd * 4) - self.fc_g1 = torch.nn.Linear(num_features_xd*4, 1024) + self.conv2 = SAGEConv(num_features_xd, num_features_xd * 2) + self.conv3 = SAGEConv(num_features_xd * 2, num_features_xd * 4) + self.fc_g1 = torch.nn.Linear(num_features_xd * 4, 1024) self.fc_g2 = torch.nn.Linear(1024, output_dim) self.relu = nn.ReLU() self.dropout = nn.Dropout(dropout) # cell line feature - self.conv_xt_1 = nn.Conv1d( - in_channels=1, out_channels=n_filters, kernel_size=8) + self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=8) self.pool_xt_1 = nn.MaxPool1d(3) - self.conv_xt_2 = nn.Conv1d( - in_channels=n_filters, out_channels=n_filters*2, kernel_size=8) + self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=8) self.pool_xt_2 = nn.MaxPool1d(3) - self.conv_xt_3 = nn.Conv1d( - in_channels=n_filters*2, out_channels=n_filters*4, kernel_size=8) + self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=8) self.pool_xt_3 = nn.MaxPool1d(3) # self.fc1_xt = nn.Linear(2944, output_dim) # self.fc1_xt = nn.Linear(4224, output_dim) self.fc1_xt = nn.Linear(61824, output_dim) # combined layers - self.fc1 = nn.Linear(2*output_dim, 1024) + self.fc1 = nn.Linear(2 * output_dim, 1024) self.fc2 = nn.Linear(1024, 128) self.out = nn.Linear(128, n_output) @@ -526,6 +663,16 @@ def __init__(self, n_output=1, n_filters=32, embed_dim=128, num_features_xd=334, self.dropout = nn.Dropout(0.5) def forward(self, x, edge_index, batch, x_cell_mut, edge_feat): + """ + Forward pass of the SAGENet model. + + :param x: Node feature matrix of the molecular graph + :param edge_index: Edge indices of the molecular graph + :param batch: Batch vector assigning nodes to graphs + :param x_cell_mut: Cell line omics features + :param edge_feat: Edge features (unused for SAGEConv) + :returns: Predicted drug response + """ # get graph input # x, edge_index, batch = data.x, data.edge_index, data.batch @@ -547,13 +694,13 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat): # target = data.target # x_cell_mut = x_cell_mut[:,None,:] conv_xt = self.conv_xt_1(x_cell_mut) - conv_xt = F.relu(conv_xt) + conv_xt = f.relu(conv_xt) conv_xt = self.pool_xt_1(conv_xt) conv_xt = self.conv_xt_2(conv_xt) - conv_xt = F.relu(conv_xt) + conv_xt = f.relu(conv_xt) conv_xt = self.pool_xt_2(conv_xt) conv_xt = self.conv_xt_3(conv_xt) - conv_xt = F.relu(conv_xt) + conv_xt = f.relu(conv_xt) conv_xt = self.pool_xt_3(conv_xt) # flatten @@ -574,18 +721,37 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat): class GINNet(torch.nn.Module): - def __init__(self, n_output=1, num_features_xd=334, num_features_xt=25, - n_filters=32, embed_dim=128, output_dim=128, dropout=0.5): - - super(GINNet, self).__init__() + """Structurally powerful architecture for distinguishing complex graph patterns.""" + + def __init__( + self, + n_output=1, + num_features_xd=334, + num_features_xt=25, + n_filters=32, + embed_dim=128, + output_dim=128, + dropout=0.5, + ): + """ + Initialize the GINNet model. + + :param n_output: Number of output units + :param num_features_xd: Number of molecular graph node features + :param num_features_xt: Number of cell line features + :param n_filters: Number of convolution filters for the cell line CNN branch + :param embed_dim: Embedding dimension (unused but kept for API consistency) + :param output_dim: Dimensionality of the latent representation + :param dropout: Dropout probability + """ + super().__init__() dim = 32 self.dropout = nn.Dropout(dropout) self.relu = nn.ReLU() self.n_output = n_output # convolution layers - nn1 = Sequential(Linear(num_features_xd, dim), - ReLU(), Linear(dim, dim)) + nn1 = Sequential(Linear(num_features_xd, dim), ReLU(), Linear(dim, dim)) self.conv1 = GINConv(nn1) self.bn1 = torch.nn.BatchNorm1d(dim) @@ -609,25 +775,21 @@ def __init__(self, n_output=1, num_features_xd=334, num_features_xt=25, # 1D convolution on protein sequence self.embedding_xt = nn.Embedding(num_features_xt + 1, embed_dim) - self.conv_xt_1 = nn.Conv1d( - in_channels=1000, out_channels=n_filters, kernel_size=8) + self.conv_xt_1 = nn.Conv1d(in_channels=1000, out_channels=n_filters, kernel_size=8) # cell line feature - self.conv_xt_1 = nn.Conv1d( - in_channels=1, out_channels=n_filters, kernel_size=8) + self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=8) self.pool_xt_1 = nn.MaxPool1d(3) - self.conv_xt_2 = nn.Conv1d( - in_channels=n_filters, out_channels=n_filters*2, kernel_size=8) + self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=8) self.pool_xt_2 = nn.MaxPool1d(3) - self.conv_xt_3 = nn.Conv1d( - in_channels=n_filters*2, out_channels=n_filters*4, kernel_size=8) + self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=8) self.pool_xt_3 = nn.MaxPool1d(3) # self.fc1_xt = nn.Linear(2944, output_dim) # self.fc1_xt = nn.Linear(4224, output_dim) self.fc1_xt = nn.Linear(61824, output_dim) # combined layers - self.fc1 = nn.Linear(2*output_dim, 1024) + self.fc1 = nn.Linear(2 * output_dim, 1024) self.fc2 = nn.Linear(1024, 128) self.out = nn.Linear(128, n_output) @@ -636,22 +798,32 @@ def __init__(self, n_output=1, num_features_xd=334, num_features_xt=25, self.dropout = nn.Dropout(0.5) def forward(self, x, edge_index, batch, x_cell_mut, edge_feat): + """ + Forward pass of the GINNet model. + + :param x: Node feature matrix of the molecular graph + :param edge_index: Edge indices of the molecular graph + :param batch: Batch vector assigning nodes to graphs + :param x_cell_mut: Cell line omics features + :param edge_feat: Edge features (unused for GINConv) + :returns: Predicted drug response + """ # x, edge_index, batch = data.x, data.edge_index, data.batch # print(x) # print(data.target) - x = F.relu(self.conv1(x, edge_index)) + x = f.relu(self.conv1(x, edge_index)) x = self.bn1(x) - x = F.relu(self.conv2(x, edge_index)) + x = f.relu(self.conv2(x, edge_index)) x = self.bn2(x) - x = F.relu(self.conv3(x, edge_index)) + x = f.relu(self.conv3(x, edge_index)) x = self.bn3(x) - x = F.relu(self.conv4(x, edge_index)) + x = f.relu(self.conv4(x, edge_index)) x = self.bn4(x) - x = F.relu(self.conv5(x, edge_index)) + x = f.relu(self.conv5(x, edge_index)) x = self.bn5(x) x = global_add_pool(x, batch) - x = F.relu(self.fc1_xd(x)) - # x = F.dropout(x, p=0.2, training=self.training) + x = f.relu(self.fc1_xd(x)) + # x = f.dropout(x, p=0.2, training=self.training) x = self.dropout(x) # protein input feed-forward: @@ -660,13 +832,13 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat): # 1d conv layers conv_xt = self.conv_xt_1(x_cell_mut) - conv_xt = F.relu(conv_xt) + conv_xt = f.relu(conv_xt) conv_xt = self.pool_xt_1(conv_xt) conv_xt = self.conv_xt_2(conv_xt) - conv_xt = F.relu(conv_xt) + conv_xt = f.relu(conv_xt) conv_xt = self.pool_xt_2(conv_xt) conv_xt = self.conv_xt_3(conv_xt) - conv_xt = F.relu(conv_xt) + conv_xt = f.relu(conv_xt) conv_xt = self.pool_xt_3(conv_xt) # flatten @@ -688,18 +860,37 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat): class GINENet(torch.nn.Module): - def __init__(self, n_output=1, num_features_xd=334, num_features_xt=25, - n_filters=32, embed_dim=128, output_dim=128, dropout=0.5): - - super(GINENet, self).__init__() + """Combines GIN's structural power with edge feature integration.""" + + def __init__( + self, + n_output=1, + num_features_xd=334, + num_features_xt=25, + n_filters=32, + embed_dim=128, + output_dim=128, + dropout=0.5, + ): + """ + Initialize the GINENet model. + + :param n_output: Number of output units + :param num_features_xd: Number of molecular graph node features + :param num_features_xt: Number of cell line features + :param n_filters: Number of convolution filters for the cell line CNN branch + :param embed_dim: Embedding dimension (unused but kept for API consistency) + :param output_dim: Dimensionality of the latent representation + :param dropout: Dropout probability + """ + super().__init__() dim = 32 self.dropout = nn.Dropout(dropout) self.relu = nn.ReLU() self.n_output = n_output # convolution layers - nn1 = Sequential(Linear(num_features_xd, dim), - ReLU(), Linear(dim, dim)) + nn1 = Sequential(Linear(num_features_xd, dim), ReLU(), Linear(dim, dim)) self.conv1 = GINEConv(nn1, edge_dim=4) self.bn1 = torch.nn.BatchNorm1d(dim) @@ -723,25 +914,21 @@ def __init__(self, n_output=1, num_features_xd=334, num_features_xt=25, # 1D convolution on protein sequence self.embedding_xt = nn.Embedding(num_features_xt + 1, embed_dim) - self.conv_xt_1 = nn.Conv1d( - in_channels=1000, out_channels=n_filters, kernel_size=8) + self.conv_xt_1 = nn.Conv1d(in_channels=1000, out_channels=n_filters, kernel_size=8) # cell line feature - self.conv_xt_1 = nn.Conv1d( - in_channels=1, out_channels=n_filters, kernel_size=8) + self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=8) self.pool_xt_1 = nn.MaxPool1d(3) - self.conv_xt_2 = nn.Conv1d( - in_channels=n_filters, out_channels=n_filters*2, kernel_size=8) + self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=8) self.pool_xt_2 = nn.MaxPool1d(3) - self.conv_xt_3 = nn.Conv1d( - in_channels=n_filters*2, out_channels=n_filters*4, kernel_size=8) + self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=8) self.pool_xt_3 = nn.MaxPool1d(3) # self.fc1_xt = nn.Linear(2944, output_dim) # self.fc1_xt = nn.Linear(4224, output_dim) self.fc1_xt = nn.Linear(61824, output_dim) # combined layers - self.fc1 = nn.Linear(2*output_dim, 1024) + self.fc1 = nn.Linear(2 * output_dim, 1024) self.fc2 = nn.Linear(1024, 128) self.out = nn.Linear(128, n_output) @@ -750,22 +937,32 @@ def __init__(self, n_output=1, num_features_xd=334, num_features_xt=25, self.dropout = nn.Dropout(0.5) def forward(self, x, edge_index, batch, x_cell_mut, edge_feat): + """ + Forward pass of the GINENet model. + + :param x: Node feature matrix of the molecular graph + :param edge_index: Edge indices of the molecular graph + :param batch: Batch vector assigning nodes to graphs + :param x_cell_mut: Cell line omics features + :param edge_feat: Edge features of the molecular graph + :returns: Predicted drug response + """ # x, edge_index, batch = data.x, data.edge_index, data.batch # print(x) # print(data.target) - x = F.relu(self.conv1(x, edge_index, edge_attr=edge_feat)) + x = f.relu(self.conv1(x, edge_index, edge_attr=edge_feat)) x = self.bn1(x) - x = F.relu(self.conv2(x, edge_index, edge_attr=edge_feat)) + x = f.relu(self.conv2(x, edge_index, edge_attr=edge_feat)) x = self.bn2(x) - x = F.relu(self.conv3(x, edge_index, edge_attr=edge_feat)) + x = f.relu(self.conv3(x, edge_index, edge_attr=edge_feat)) x = self.bn3(x) - x = F.relu(self.conv4(x, edge_index, edge_attr=edge_feat)) + x = f.relu(self.conv4(x, edge_index, edge_attr=edge_feat)) x = self.bn4(x) - x = F.relu(self.conv5(x, edge_index, edge_attr=edge_feat)) + x = f.relu(self.conv5(x, edge_index, edge_attr=edge_feat)) x = self.bn5(x) x = global_add_pool(x, batch) - x = F.relu(self.fc1_xd(x)) - # x = F.dropout(x, p=0.2, training=self.training) + x = f.relu(self.fc1_xd(x)) + # x = f.dropout(x, p=0.2, training=self.training) x = self.dropout(x) # protein input feed-forward: @@ -774,13 +971,13 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat): # 1d conv layers conv_xt = self.conv_xt_1(x_cell_mut) - conv_xt = F.relu(conv_xt) + conv_xt = f.relu(conv_xt) conv_xt = self.pool_xt_1(conv_xt) conv_xt = self.conv_xt_2(conv_xt) - conv_xt = F.relu(conv_xt) + conv_xt = f.relu(conv_xt) conv_xt = self.pool_xt_2(conv_xt) conv_xt = self.conv_xt_3(conv_xt) - conv_xt = F.relu(conv_xt) + conv_xt = f.relu(conv_xt) conv_xt = self.pool_xt_3(conv_xt) # flatten @@ -802,33 +999,50 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat): class RGCNNet(torch.nn.Module): - def __init__(self, n_output=1, n_filters=32, embed_dim=128, num_features_xd=334, num_features_xt=25, output_dim=128, dropout=0.5, use_attn=False): # qwe - - super(RGCNNet, self).__init__() + """Uses relation-specific weights for multi-relational drug graphs.""" + + def __init__( + self, + n_output=1, + n_filters=32, + embed_dim=128, + num_features_xd=334, + num_features_xt=25, + output_dim=128, + dropout=0.5, + use_attn=False, + ): + """ + Initialize the RGCNNet model. + + :param n_output: Number of output units + :param n_filters: Number of convolution filters for the cell line CNN branch + :param embed_dim: Embedding dimension (unused but kept for API consistency) + :param num_features_xd: Number of molecular graph node features + :param num_features_xt: Number of cell line features + :param output_dim: Dimensionality of the latent representation + :param dropout: Dropout probability + :param use_attn: Whether to use cross‑attention between drug and cell line features + """ + super().__init__() self.use_attn = use_attn # SMILES graph branch self.n_output = n_output - self.conv1 = RGCNConv( - num_features_xd, num_features_xd, num_relations=4) - self.conv2 = RGCNConv( - num_features_xd, num_features_xd*2, num_relations=4) - self.conv3 = RGCNConv( - num_features_xd*2, num_features_xd * 4, num_relations=4) - self.fc_g1 = torch.nn.Linear(num_features_xd*4, 1024) + self.conv1 = RGCNConv(num_features_xd, num_features_xd, num_relations=4) + self.conv2 = RGCNConv(num_features_xd, num_features_xd * 2, num_relations=4) + self.conv3 = RGCNConv(num_features_xd * 2, num_features_xd * 4, num_relations=4) + self.fc_g1 = torch.nn.Linear(num_features_xd * 4, 1024) self.fc_g2 = torch.nn.Linear(1024, output_dim) self.relu = nn.ReLU() self.dropout = nn.Dropout(dropout) # cell line feature - self.conv_xt_1 = nn.Conv1d( - in_channels=1, out_channels=n_filters, kernel_size=8) + self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=8) self.pool_xt_1 = nn.MaxPool1d(3) - self.conv_xt_2 = nn.Conv1d( - in_channels=n_filters, out_channels=n_filters*2, kernel_size=8) + self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=8) self.pool_xt_2 = nn.MaxPool1d(3) - self.conv_xt_3 = nn.Conv1d( - in_channels=n_filters*2, out_channels=n_filters*4, kernel_size=8) + self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=8) self.pool_xt_3 = nn.MaxPool1d(3) # self.fc1_xt = nn.Linear(2944, output_dim) # self.fc1_xt = nn.Linear(4224, output_dim) @@ -839,14 +1053,25 @@ def __init__(self, n_output=1, n_filters=32, embed_dim=128, num_features_xd=334, self.cross_attn2 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) self.norm1 = nn.LayerNorm(output_dim) self.norm2 = nn.LayerNorm(output_dim) - self.fc = nn.Linear(2*output_dim, output_dim) + self.fc = nn.Linear(2 * output_dim, output_dim) else: # combined layers - self.fc1 = nn.Linear(2*output_dim, 1024) + self.fc1 = nn.Linear(2 * output_dim, 1024) self.fc2 = nn.Linear(1024, 128) self.out = nn.Linear(128, self.n_output) def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, edge_weight=None): + """ + Forward pass of the RGCNNet model. + + :param x: Node feature matrix of the molecular graph + :param edge_index: Edge indices of the molecular graph + :param batch: Batch vector assigning nodes to graphs + :param x_cell_mut: Cell line omics features + :param edge_feat: Edge type indices for relational graph convolution + :param edge_weight: Optional edge weights + :returns: Predicted drug response + """ # get graph input # edge_weight is only used for decoding @@ -860,7 +1085,7 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, edge_weight=None) x = self.relu(x) x = self.conv3(x, edge_index, edge_type=edge_feat) x = self.relu(x) - x = gmp(x, batch) # global max pooling + x = gmp(x, batch) # global max pooling # flatten x = self.relu(self.fc_g1(x)) @@ -877,13 +1102,13 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, edge_weight=None) # 1d conv layers conv_xt = self.conv_xt_1(x_cell_mut) - conv_xt = F.relu(conv_xt) + conv_xt = f.relu(conv_xt) conv_xt = self.pool_xt_1(conv_xt) conv_xt = self.conv_xt_2(conv_xt) - conv_xt = F.relu(conv_xt) + conv_xt = f.relu(conv_xt) conv_xt = self.pool_xt_2(conv_xt) conv_xt = self.conv_xt_3(conv_xt) - conv_xt = F.relu(conv_xt) + conv_xt = f.relu(conv_xt) conv_xt = self.pool_xt_3(conv_xt) # flatten @@ -919,27 +1144,54 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, edge_weight=None) class WIRGATNet(torch.nn.Module): - def __init__(self, num_features_xd=334, n_output=1, num_features_xt=25, - n_filters=32, embed_dim=128, output_dim=128, dropout=0.5, use_attn=False): - super(WIRGATNet, self).__init__() + """Relational attention focused on interactions within the same relation.""" + + def __init__( + self, + num_features_xd=334, + n_output=1, + num_features_xt=25, + n_filters=32, + embed_dim=128, + output_dim=128, + dropout=0.5, + use_attn=False, + ): + """ + Initialize the WIRGATNet model. + + :param num_features_xd: Number of molecular graph node features + :param n_output: Number of output units + :param num_features_xt: Number of cell line features + :param n_filters: Number of convolution filters for the cell line CNN branch + :param embed_dim: Embedding dimension (unused but kept for API consistency) + :param output_dim: Dimensionality of the latent representation + :param dropout: Dropout probability + :param use_attn: Whether to use cross‑attention between drug and cell line features + """ + super().__init__() self.use_attn = use_attn # graph layers - self.gcn1 = RGATConv(num_features_xd, num_features_xd, num_relations=4, - attention_mechanism='within-relation', heads=10, dropout=dropout) - self.gcn2 = RGATConv(num_features_xd * 10, output_dim, num_relations=4, - attention_mechanism='within-relation', dropout=dropout) + self.gcn1 = RGATConv( + num_features_xd, + num_features_xd, + num_relations=4, + attention_mechanism="within-relation", + heads=10, + dropout=dropout, + ) + self.gcn2 = RGATConv( + num_features_xd * 10, output_dim, num_relations=4, attention_mechanism="within-relation", dropout=dropout + ) self.fc_g1 = nn.Linear(output_dim, output_dim) # cell line feature - self.conv_xt_1 = nn.Conv1d( - in_channels=1, out_channels=n_filters, kernel_size=8) + self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=8) self.pool_xt_1 = nn.MaxPool1d(3) - self.conv_xt_2 = nn.Conv1d( - in_channels=n_filters, out_channels=n_filters*2, kernel_size=8) + self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=8) self.pool_xt_2 = nn.MaxPool1d(3) - self.conv_xt_3 = nn.Conv1d( - in_channels=n_filters*2, out_channels=n_filters*4, kernel_size=8) + self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=8) self.pool_xt_3 = nn.MaxPool1d(3) # self.fc1_xt = nn.Linear(2944, output_dim) # self.fc1_xt = nn.Linear(4224, output_dim) @@ -950,10 +1202,10 @@ def __init__(self, num_features_xd=334, n_output=1, num_features_xt=25, self.cross_attn2 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) self.norm1 = nn.LayerNorm(output_dim) self.norm2 = nn.LayerNorm(output_dim) - self.fc = nn.Linear(2*output_dim, 128) + self.fc = nn.Linear(2 * output_dim, 128) else: # combined layers - self.fc1 = nn.Linear(2*output_dim, 1024) + self.fc1 = nn.Linear(2 * output_dim, 1024) self.fc2 = nn.Linear(1024, 128) self.out = nn.Linear(128, n_output) @@ -962,31 +1214,36 @@ def __init__(self, num_features_xd=334, n_output=1, num_features_xt=25, self.dropout = nn.Dropout(0.5) def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_weights=False): - ''' - x: feature matrix of molecular graph - target: gene mutation data - edge_index: edges of molecular graph - batch - edge_feat: edge features of molecular graph - ''' + """ + Forward pass of the WIRGATNet model. + + :param x: feature matrix of molecular graph + :param edge_index: edges of molecular graph + :param batch: Batch vector assigning nodes to graphs + :param x_cell_mut: Cell line omics features + :param edge_feat: edge features of molecular graph + :param return_attention_weights: Whether to return attention weights + :returns: Predicted drug response or (prediction, attention weights) + """ # graph input feed-forward # x, edge_index, batch, edge_feat = data.x, data.edge_index, data.batch, data.edge_features # print(data.x.shape) edge_feat = edge_feat.int().squeeze() # print(edge_feat) - # x = F.dropout(x, p=0.2, training=self.training) + # x = f.dropout(x, p=0.2, training=self.training) # x = self.dropout(x) - x = F.elu(self.gcn1(x, edge_index, edge_type=edge_feat)) - # x = F.dropout(x, p=0.2, training=self.training) + x = f.elu(self.gcn1(x, edge_index, edge_type=edge_feat)) + # x = f.dropout(x, p=0.2, training=self.training) x = self.dropout(x) if return_attention_weights: x, attn_weights = self.gcn2( - x, edge_index, edge_type=edge_feat, return_attention_weights=return_attention_weights) + x, edge_index, edge_type=edge_feat, return_attention_weights=return_attention_weights + ) else: x = self.gcn2(x, edge_index, edge_type=edge_feat) x = self.relu(x) - x = gmp(x, batch) # global max pooling + x = gmp(x, batch) # global max pooling x = self.fc_g1(x) x = self.relu(x) @@ -995,13 +1252,13 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_ # x_cell_mut = x_cell_mut[:,None,:] # 1d conv layers conv_xt = self.conv_xt_1(x_cell_mut) - conv_xt = F.relu(conv_xt) + conv_xt = f.relu(conv_xt) conv_xt = self.pool_xt_1(conv_xt) conv_xt = self.conv_xt_2(conv_xt) - conv_xt = F.relu(conv_xt) + conv_xt = f.relu(conv_xt) conv_xt = self.pool_xt_2(conv_xt) conv_xt = self.conv_xt_3(conv_xt) - conv_xt = F.relu(conv_xt) + conv_xt = f.relu(conv_xt) conv_xt = self.pool_xt_3(conv_xt) # flatten @@ -1041,27 +1298,54 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_ class ARGATNet(torch.nn.Module): - def __init__(self, num_features_xd=334, n_output=1, num_features_xt=25, - n_filters=32, embed_dim=128, output_dim=128, dropout=0.5, use_attn=False): - super(ARGATNet, self).__init__() + """Relational attention designed to process features across different relations.""" + + def __init__( + self, + num_features_xd=334, + n_output=1, + num_features_xt=25, + n_filters=32, + embed_dim=128, + output_dim=128, + dropout=0.5, + use_attn=False, + ): + """ + Initialize the ARGATNet model. + + :param num_features_xd: Number of molecular graph node features + :param n_output: Number of output units + :param num_features_xt: Number of cell line features + :param n_filters: Number of convolution filters for the cell line CNN branch + :param embed_dim: Embedding dimension (unused but kept for API consistency) + :param output_dim: Dimensionality of the latent representation + :param dropout: Dropout probability + :param use_attn: Whether to use cross‑attention between drug and cell line features + """ + super().__init__() self.use_attn = use_attn # graph layers - self.gcn1 = RGATConv(num_features_xd, num_features_xd, num_relations=4, - attention_mechanism='across-relation', heads=10, dropout=dropout) - self.gcn2 = RGATConv(num_features_xd * 10, output_dim, num_relations=4, - attention_mechanism='across-relation', dropout=dropout) + self.gcn1 = RGATConv( + num_features_xd, + num_features_xd, + num_relations=4, + attention_mechanism="across-relation", + heads=10, + dropout=dropout, + ) + self.gcn2 = RGATConv( + num_features_xd * 10, output_dim, num_relations=4, attention_mechanism="across-relation", dropout=dropout + ) self.fc_g1 = nn.Linear(output_dim, output_dim) # cell line feature - self.conv_xt_1 = nn.Conv1d( - in_channels=1, out_channels=n_filters, kernel_size=8) + self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=8) self.pool_xt_1 = nn.MaxPool1d(3) - self.conv_xt_2 = nn.Conv1d( - in_channels=n_filters, out_channels=n_filters*2, kernel_size=8) + self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=8) self.pool_xt_2 = nn.MaxPool1d(3) - self.conv_xt_3 = nn.Conv1d( - in_channels=n_filters*2, out_channels=n_filters*4, kernel_size=8) + self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=8) self.pool_xt_3 = nn.MaxPool1d(3) # self.fc1_xt = nn.Linear(2944, output_dim) # self.fc1_xt = nn.Linear(4224, output_dim) @@ -1072,10 +1356,10 @@ def __init__(self, num_features_xd=334, n_output=1, num_features_xt=25, self.cross_attn2 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) self.norm1 = nn.LayerNorm(output_dim) self.norm2 = nn.LayerNorm(output_dim) - self.fc = nn.Linear(2*output_dim, 128) + self.fc = nn.Linear(2 * output_dim, 128) else: # combined layers - self.fc1 = nn.Linear(2*output_dim, 1024) + self.fc1 = nn.Linear(2 * output_dim, 1024) self.fc2 = nn.Linear(1024, 128) self.out = nn.Linear(128, n_output) @@ -1084,30 +1368,35 @@ def __init__(self, num_features_xd=334, n_output=1, num_features_xt=25, self.dropout = nn.Dropout(0.5) def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_weights=False): - ''' - x: feature matrix of molecular graph - target: gene mutation data - edge_index: edges of molecular graph - batch - edge_feat: edge features of molecular graph - ''' + """ + Forward pass of the ARGATNet model. + + :param x: feature matrix of molecular graph + :param edge_index: edges of molecular graph + :param batch: Batch vector assigning nodes to graphs + :param x_cell_mut: Cell line omics features + :param edge_feat: edge features of molecular graph + :param return_attention_weights: Whether to return attention weights + :returns: Predicted drug response or (prediction, attention weights) + """ # graph input feed-forward # x, edge_index, batch, edge_feat = data.x, data.edge_index, data.batch, data.edge_features # print(data.x.shape) edge_feat = edge_feat.int().squeeze() - # x = F.dropout(x, p=0.2, training=self.training) + # x = f.dropout(x, p=0.2, training=self.training) # x = self.dropout(x) - x = F.elu(self.gcn1(x, edge_index, edge_type=edge_feat)) - # x = F.dropout(x, p=0.2, training=self.training) + x = f.elu(self.gcn1(x, edge_index, edge_type=edge_feat)) + # x = f.dropout(x, p=0.2, training=self.training) x = self.dropout(x) if return_attention_weights: x, attn_weights = self.gcn2( - x, edge_index, edge_type=edge_feat, return_attention_weights=return_attention_weights) + x, edge_index, edge_type=edge_feat, return_attention_weights=return_attention_weights + ) else: x = self.gcn2(x, edge_index, edge_type=edge_feat) x = self.relu(x) - x = gmp(x, batch) # global max pooling + x = gmp(x, batch) # global max pooling x = self.fc_g1(x) x = self.relu(x) @@ -1116,13 +1405,13 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_ # x_cell_mut = x_cell_mut[:,None,:] # 1d conv layers conv_xt = self.conv_xt_1(x_cell_mut) - conv_xt = F.relu(conv_xt) + conv_xt = f.relu(conv_xt) conv_xt = self.pool_xt_1(conv_xt) conv_xt = self.conv_xt_2(conv_xt) - conv_xt = F.relu(conv_xt) + conv_xt = f.relu(conv_xt) conv_xt = self.pool_xt_2(conv_xt) conv_xt = self.conv_xt_3(conv_xt) - conv_xt = F.relu(conv_xt) + conv_xt = f.relu(conv_xt) conv_xt = self.pool_xt_3(conv_xt) # flatten @@ -1162,9 +1451,32 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_ class FiLMNet(torch.nn.Module): - def __init__(self, num_features_xd=334, n_output=1, num_features_xt=25, - n_filters=32, embed_dim=128, output_dim=128, dropout=0.5, use_attn=False): - super(FiLMNet, self).__init__() + """Adaptively modulates features based on specific graph relations.""" + + def __init__( + self, + num_features_xd=334, + n_output=1, + num_features_xt=25, + n_filters=32, + embed_dim=128, + output_dim=128, + dropout=0.5, + use_attn=False, + ): + """ + Initialize the FiLMNet model. + + :param num_features_xd: Number of molecular graph node features + :param n_output: Number of output units + :param num_features_xt: Number of cell line features + :param n_filters: Number of convolution filters for the cell line CNN branch + :param embed_dim: Embedding dimension (unused but kept for API consistency) + :param output_dim: Dimensionality of the latent representation + :param dropout: Dropout probability + :param use_attn: Whether to use cross‑attention between drug and cell line features + """ + super().__init__() self.use_attn = use_attn # graph layers @@ -1174,14 +1486,11 @@ def __init__(self, num_features_xd=334, n_output=1, num_features_xt=25, self.fc_g1 = nn.Linear(output_dim, output_dim) # cell line feature - self.conv_xt_1 = nn.Conv1d( - in_channels=1, out_channels=n_filters, kernel_size=8) + self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=8) self.pool_xt_1 = nn.MaxPool1d(3) - self.conv_xt_2 = nn.Conv1d( - in_channels=n_filters, out_channels=n_filters*2, kernel_size=8) + self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=8) self.pool_xt_2 = nn.MaxPool1d(3) - self.conv_xt_3 = nn.Conv1d( - in_channels=n_filters*2, out_channels=n_filters*4, kernel_size=8) + self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=8) self.pool_xt_3 = nn.MaxPool1d(3) # self.fc1_xt = nn.Linear(2944, output_dim) # self.fc1_xt = nn.Linear(4224, output_dim) @@ -1192,10 +1501,10 @@ def __init__(self, num_features_xd=334, n_output=1, num_features_xt=25, self.cross_attn2 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) self.norm1 = nn.LayerNorm(output_dim) self.norm2 = nn.LayerNorm(output_dim) - self.fc = nn.Linear(2*output_dim, 128) + self.fc = nn.Linear(2 * output_dim, 128) else: # combined layers - self.fc1 = nn.Linear(2*output_dim, 1024) + self.fc1 = nn.Linear(2 * output_dim, 1024) self.fc2 = nn.Linear(1024, 128) self.out = nn.Linear(128, n_output) @@ -1204,26 +1513,30 @@ def __init__(self, num_features_xd=334, n_output=1, num_features_xt=25, self.dropout = nn.Dropout(0.5) def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_weights=False): - ''' - x: feature matrix of molecular graph - target: gene mutation data - edge_index: edges of molecular graph - batch - edge_feat: edge features of molecular graph - ''' + """ + Forward pass of the FiLMNet model. + + :param x: feature matrix of molecular graph + :param edge_index: edges of molecular graph + :param batch: Batch vector assigning nodes to graphs + :param x_cell_mut: Cell line omics features + :param edge_feat: Edge type indices for FiLM modulation + :param return_attention_weights: Whether to return attention weights + :returns: Predicted drug response + """ # graph input feed-forward # x, edge_index, batch, edge_feat = data.x, data.edge_index, data.batch, data.edge_features # print(data.x.shape) edge_feat = edge_feat.int().squeeze() - # x = F.dropout(x, p=0.2, training=self.training) + # x = f.dropout(x, p=0.2, training=self.training) # x = self.dropout(x) self.gcn1(x, edge_index, edge_type=edge_feat) - # x = F.dropout(x, p=0.2, training=self.training) + # x = f.dropout(x, p=0.2, training=self.training) x = self.dropout(x) x = self.gcn2(x, edge_index, edge_type=edge_feat) # x = self.relu(x) - x = gmp(x, batch) # global max pooling + x = gmp(x, batch) # global max pooling x = self.fc_g1(x) x = self.relu(x) @@ -1232,13 +1545,13 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_ # x_cell_mut = x_cell_mut[:,None,:] # 1d conv layers conv_xt = self.conv_xt_1(x_cell_mut) - conv_xt = F.relu(conv_xt) + conv_xt = f.relu(conv_xt) conv_xt = self.pool_xt_1(conv_xt) conv_xt = self.conv_xt_2(conv_xt) - conv_xt = F.relu(conv_xt) + conv_xt = f.relu(conv_xt) conv_xt = self.pool_xt_2(conv_xt) conv_xt = self.conv_xt_3(conv_xt) - conv_xt = F.relu(conv_xt) + conv_xt = f.relu(conv_xt) conv_xt = self.pool_xt_3(conv_xt) # flatten @@ -1271,4 +1584,4 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_ out = self.out(xc) out = nn.Sigmoid()(out) - return out \ No newline at end of file + return out diff --git a/drevalpy/models/XGDP/utils.py b/drevalpy/models/XGDP/utils.py index e69de29b..c5fcd672 100644 --- a/drevalpy/models/XGDP/utils.py +++ b/drevalpy/models/XGDP/utils.py @@ -0,0 +1 @@ +"""Utility functions for the XGDP model.""" diff --git a/drevalpy/models/XGDP/xgdp.py b/drevalpy/models/XGDP/xgdp.py index ecca1336..3704518a 100644 --- a/drevalpy/models/XGDP/xgdp.py +++ b/drevalpy/models/XGDP/xgdp.py @@ -1,16 +1,28 @@ -"""XGDP model.""" +"""Contains XGDP, a GNN and CNN based drug response prediction model. + +Drug discovery and mechanism prediction with explainable graph neural networks. + +Original authors: Wang, C., Kumar, G.A. & Rajapakse, J.C. (2025, 10.1038/s41598-024-83090-3) +Code adapted from their Github: https://github.com/SCSE-Biomedical-Computing-Group/XGDP/blob/main/utils_tcnn.py +""" import json import os import secrets +import sys +import time from pathlib import Path -from typing import Any, cast +from typing import Any +import models import numpy as np -import pytorch_lightning as pl import pandas as pd +import pytorch_lightning as pl import torch import torch.nn as nn +import tqdm +import utils as u +from sklearn.preprocessing import MinMaxScaler, StandardScaler from torch.optim import Adam from torch.utils.data import Dataset as PytorchDataset from torch_geometric.loader import DataLoader @@ -20,25 +32,14 @@ from ..drp_model import DRPModel from ..lightning_metrics_mixin import RegressionMetricsMixin from ..utils import load_and_select_gene_features +from .utils import XGDPPredictor, predict_model, train_epoch -from .model_utils import XGDPPredictor - -import pandas as pd -import numpy as np -import sys -import torch.nn as nn -import tqdm -import time - -import models -import utils as u -class XGDP(DRPModel): +class XGDP(DRPModel, RegressionMetricsMixin): """XGDP model for ...""" cell_line_views = ["gene_expression"] drug_views = ["drug_graph"] - early_stopping = True def __init__(self) -> None: """Initialize the XGDP model.""" @@ -57,7 +58,7 @@ def get_model_name(cls) -> str: :returns: XGDP """ return "XGDP" - + def build_model(self, hyperparameters: dict[str, Any]) -> None: """ Builds the XGDP model with the specified hyperparameters. @@ -66,17 +67,16 @@ def build_model(self, hyperparameters: dict[str, Any]) -> None: """ self.hyperparameters = hyperparameters + model_name = hyperparameters.get("model_type", "GATNet") - self.model = XGDPPredictor( - name_hyperparameter = hyperparameter["name_hyperparameter"] - ) - ''' + self.model = XGDPPredictor(name_hyperparameter=hyperparameter["name_hyperparameter"]) + """ # Log hyperparameters to wandb if enabled self.log_hyperparameters(hyperparameters) self.hyperparameters = hyperparameters # init model in train - ''' + """ def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureDataset: """Loads the cell line features. @@ -91,7 +91,7 @@ def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureD data_path=data_path, dataset_name=dataset_name, ) - + def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDataset: """Loads the pre-computed drug graph data. @@ -119,7 +119,7 @@ def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDatase feature_dict = {drug_id: {"drug_graph": graph} for drug_id, graph in drug_graphs.items()} return FeatureDataset(features=feature_dict) - + def train( self, output: DrugResponseDataset, @@ -139,19 +139,20 @@ def train( """ if drug_input is None: raise ValueError("Drug input is required for XGDP") - - #step1 load data - #step 2: preprocess data + # step1 load data + train_loader = self._prepare_dataloader(output, cell_line_input, drug_input, is_train=True) - #step 4: init model + train parameters - #to do: get feature size from data? -> + # step 2: preprocess data + + # step 4: init model + train parameters + # to do: get feature size from data? -> with_attention = [] model_name = self.hyperparameters["model"] - model_class = getattr(models,model_name ) - #is done after cv split - #self.model = model_class() #check hyper parameter for each model - if 'GAT' in model_name: + model_class = getattr(models, model_name) + # is done after cv split + # self.model = model_class() #check hyper parameter for each model + if "GAT" in model_name: return_attention_weights = True else: return_attention_weights = False @@ -161,24 +162,24 @@ def train( test_batch_size = self.hyperparameters["test_batch_size"] val_batch_size = self.hyperparameters["val_batch_size"] device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - #step 3: build data loaders + split data + # step 3: build data loaders + split data test_loader = DataLoader(test_data, batch_size=test_batch_size, shuffle=False) - - #step 5: train loop + + # step 5: train loop best_model_id = 0 best_model = None best_pearson_cv = 0 ret_cv = [] - kf = KFold(n_splits=3) #creates splits indices for cross validation - #does the cv split + kf = KFold(n_splits=3) # creates splits indices for cross validation + # does the cv split for i, (train_index, val_index) in enumerate(kf.split(cv_data)): - #splits data into val and train according to index form kf - train_data = Subset(data,train_index) - train_data = Subset(data,val_index) + # splits data into val and train according to index form kf + train_data = Subset(data, train_index) + train_data = Subset(data, val_index) train_loader = DataLoader(train_data, batch_size=train_batch_size, shuffle=False) val_loader = DataLoader(test_data, batch_size=val_batch_size, shuffle=False) - #creates model for each cv split + # creates model for each cv split model = model_class().to(device) optimizer = torch.optim.Adam(model.parameters(), lr=lr) best_mse = 1000 @@ -192,19 +193,19 @@ def train( best_ret = [] for epoch in tqdm(range(n_epochs)): start_time = time.time() - train_loss = train(model, device, train_loader, optimizer, epoch+1, log_interval) - G,P = predicting(model, device, val_loader) - ret = [u.rmse(G,P),u.mse(G,P),u.pearson(G,P),u.spearman(G,P),coeffi_determ(G,P)] + train_loss = train(model, device, train_loader, optimizer, epoch + 1, log_interval) + G, P = predicting(model, device, val_loader) + ret = [u.rmse(G, P), u.mse(G, P), u.pearson(G, P), u.spearman(G, P), coeffi_determ(G, P)] train_losses.append(train_loss) val_losses.append(ret[1]) val_pearsons.append(ret[2]) self.model = best_model - #step 6: save mdoel (maybe not neccecary beacuse class) + # step 6: save mdoel (maybe not neccecary beacuse class) class XGDP(DRPModel): - #to do + # to do def __init__(self): """Initialize the DrugGNN model.""" pass @@ -224,8 +225,8 @@ def cell_line_views(self) -> list[str]: :return: The sources the model needs as input for describing the cell line. """ return ["gene_expression"] - - #to do + + # to do @property def drug_views(self) -> list[str]: """Return the sources the model needs as input for describing the drug. @@ -264,19 +265,19 @@ def train( """ if drug_input is None: raise ValueError("Drug input is required for XGDP") - - #step1 load data - #step 2: preprocess data + # step1 load data + + # step 2: preprocess data - #step 4: init model + train parameters - #to do: get feature size from data? -> + # step 4: init model + train parameters + # to do: get feature size from data? -> with_attention = [] model_name = self.hyperparameters["model"] - model_class = getattr(models,model_name ) - #is done after cv split - #self.model = model_class() #check hyper parameter for each model - if 'GAT' in model_name: + model_class = getattr(models, model_name) + # is done after cv split + # self.model = model_class() #check hyper parameter for each model + if "GAT" in model_name: return_attention_weights = True else: return_attention_weights = False @@ -286,24 +287,24 @@ def train( test_batch_size = self.hyperparameters["test_batch_size"] val_batch_size = self.hyperparameters["val_batch_size"] device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - #step 3: build data loaders + split data + # step 3: build data loaders + split data test_loader = DataLoader(test_data, batch_size=test_batch_size, shuffle=False) - - #step 5: train loop + + # step 5: train loop best_model_id = 0 best_model = None best_pearson_cv = 0 ret_cv = [] - kf = KFold(n_splits=3) #creates splits indices for cross validation - #does the cv split + kf = KFold(n_splits=3) # creates splits indices for cross validation + # does the cv split for i, (train_index, val_index) in enumerate(kf.split(cv_data)): - #splits data into val and train according to index form kf - train_data = Subset(data,train_index) - train_data = Subset(data,val_index) + # splits data into val and train according to index form kf + train_data = Subset(data, train_index) + train_data = Subset(data, val_index) train_loader = DataLoader(train_data, batch_size=train_batch_size, shuffle=False) val_loader = DataLoader(test_data, batch_size=val_batch_size, shuffle=False) - #creates model for each cv split + # creates model for each cv split model = model_class().to(device) optimizer = torch.optim.Adam(model.parameters(), lr=lr) best_mse = 1000 @@ -317,16 +318,12 @@ def train( best_ret = [] for epoch in tqdm(range(n_epochs)): start_time = time.time() - train_loss = train(model, device, train_loader, optimizer, epoch+1, log_interval) - G,P = predicting(model, device, val_loader) - ret = [u.rmse(G,P),u.mse(G,P),u.pearson(G,P),u.spearman(G,P),coeffi_determ(G,P)] + train_loss = train(model, device, train_loader, optimizer, epoch + 1, log_interval) + G, P = predicting(model, device, val_loader) + ret = [u.rmse(G, P), u.mse(G, P), u.pearson(G, P), u.spearman(G, P), coeffi_determ(G, P)] train_losses.append(train_loss) val_losses.append(ret[1]) val_pearsons.append(ret[2]) self.model = best_model - #step 6: save mdoel (maybe not neccecary beacuse class) - - - - + # step 6: save mdoel (maybe not neccecary beacuse class) diff --git a/drevalpy/models/__init__.py b/drevalpy/models/__init__.py index 5ecf2e4f..b6f22e29 100644 --- a/drevalpy/models/__init__.py +++ b/drevalpy/models/__init__.py @@ -30,6 +30,7 @@ "DrugGNN", "ChemBERTaNeuralNetwork", "PharmaFormerModel", + "XGDP", ] from .baselines.multi_omics_random_forest import MultiOmicsRandomForest @@ -60,6 +61,7 @@ from .SimpleNeuralNetwork.simple_neural_network import ChemBERTaNeuralNetwork, SimpleNeuralNetwork from .SRMF.srmf import SRMF from .SuperFELTR.superfeltr import SuperFELTR +from .XGDP.xgdp import XGDP # SINGLE_DRUG_MODEL_FACTORY is used in the pipeline! SINGLE_DRUG_MODEL_FACTORY: dict[str, type[DRPModel]] = { @@ -93,6 +95,7 @@ "DrugGNN": DrugGNN, "ChemBERTaNeuralNetwork": ChemBERTaNeuralNetwork, "PharmaFormer": PharmaFormerModel, + "XGDP": XGDP, } # MODEL_FACTORY is used in the pipeline! diff --git a/tests/models/test_global_models.py b/tests/models/test_global_models.py index 69e20b2b..be9319bc 100644 --- a/tests/models/test_global_models.py +++ b/tests/models/test_global_models.py @@ -25,6 +25,7 @@ "SimpleNeuralNetwork", "MultiOmicsNeuralNetwork", "PharmaFormer", + "XGDP", ], ) def test_global_models( diff --git a/tests/test_drp_model.py b/tests/test_drp_model.py index 82c035f2..390947aa 100644 --- a/tests/test_drp_model.py +++ b/tests/test_drp_model.py @@ -43,6 +43,7 @@ def test_factory() -> None: assert "MOLIR" in MODEL_FACTORY assert "SuperFELTR" in MODEL_FACTORY assert "DIPK" in MODEL_FACTORY + assert XGDP in MODEL_FACTORY def test_load_cl_ids_from_csv() -> None: From f6193829e5bef78b5e60bf5fc3e2a2280a1e70ac Mon Sep 17 00:00:00 2001 From: Gregor-git1 Date: Mon, 20 Apr 2026 15:12:20 +0200 Subject: [PATCH 05/10] test1 --- drevalpy/models/XGDP/xgdp.py | 132 ++++++++++++++++------------------- 1 file changed, 61 insertions(+), 71 deletions(-) diff --git a/drevalpy/models/XGDP/xgdp.py b/drevalpy/models/XGDP/xgdp.py index ecca1336..0ffd7036 100644 --- a/drevalpy/models/XGDP/xgdp.py +++ b/drevalpy/models/XGDP/xgdp.py @@ -11,6 +11,7 @@ import pandas as pd import torch import torch.nn as nn +import torch_geometric from torch.optim import Adam from torch.utils.data import Dataset as PytorchDataset from torch_geometric.loader import DataLoader @@ -163,41 +164,27 @@ def train( device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") #step 3: build data loaders + split data test_loader = DataLoader(test_data, batch_size=test_batch_size, shuffle=False) - + train_data = Subset(data,train_index) + train_data = Subset(data,val_index) + train_loader = DataLoader(train_data, batch_size=train_batch_size, shuffle=False) + val_loader = DataLoader(test_data, batch_size=val_batch_size, shuffle=False) #step 5: train loop - best_model_id = 0 - best_model = None - best_pearson_cv = 0 - ret_cv = [] - kf = KFold(n_splits=3) #creates splits indices for cross validation - #does the cv split - for i, (train_index, val_index) in enumerate(kf.split(cv_data)): - #splits data into val and train according to index form kf - train_data = Subset(data,train_index) - train_data = Subset(data,val_index) - train_loader = DataLoader(train_data, batch_size=train_batch_size, shuffle=False) - val_loader = DataLoader(test_data, batch_size=val_batch_size, shuffle=False) - - #creates model for each cv split - model = model_class().to(device) - optimizer = torch.optim.Adam(model.parameters(), lr=lr) - best_mse = 1000 - best_pearson = 0 - best_epoch = -1 - total_time = 0 - early_stop_tolerance = 30 - train_losses = [] - val_losses = [] - val_pearsons = [] - best_ret = [] - for epoch in tqdm(range(n_epochs)): - start_time = time.time() - train_loss = train(model, device, train_loader, optimizer, epoch+1, log_interval) - G,P = predicting(model, device, val_loader) - ret = [u.rmse(G,P),u.mse(G,P),u.pearson(G,P),u.spearman(G,P),coeffi_determ(G,P)] - train_losses.append(train_loss) - val_losses.append(ret[1]) - val_pearsons.append(ret[2]) + + #creates model for each cv split + self.model = model_class().to(device) + optimizer = torch.optim.Adam(model.parameters(), lr=lr) + train_losses = [] + val_losses = [] + val_pearsons = [] + best_ret = [] + for epoch in tqdm(range(n_epochs)): + start_time = time.time() + train_loss = train(model, device, train_loader, optimizer, epoch+1, log_interval) + G,P = predicting(model, device, val_loader) + ret = [u.rmse(G,P),u.mse(G,P),u.pearson(G,P),u.spearman(G,P),coeffi_determ(G,P)] + train_losses.append(train_loss) + val_losses.append(ret[1]) + val_pearsons.append(ret[2]) self.model = best_model #step 6: save mdoel (maybe not neccecary beacuse class) @@ -268,14 +255,18 @@ def train( #step1 load data #step 2: preprocess data + #to do + train_data = None + test_data = None + val_data = None #step 4: init model + train parameters #to do: get feature size from data? -> with_attention = [] model_name = self.hyperparameters["model"] model_class = getattr(models,model_name ) - #is done after cv split - #self.model = model_class() #check hyper parameter for each model + self.model = model_class().to(device) + if 'GAT' in model_name: return_attention_weights = True else: @@ -288,43 +279,42 @@ def train( device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") #step 3: build data loaders + split data test_loader = DataLoader(test_data, batch_size=test_batch_size, shuffle=False) + train_loader = DataLoader(train_data, batch_size=train_batch_size, shuffle=False) + val_loader = DataLoader(val_data, batch_size=val_batch_size, shuffle=False) #step 5: train loop - best_model_id = 0 - best_model = None - best_pearson_cv = 0 - ret_cv = [] - kf = KFold(n_splits=3) #creates splits indices for cross validation - #does the cv split - for i, (train_index, val_index) in enumerate(kf.split(cv_data)): - #splits data into val and train according to index form kf - train_data = Subset(data,train_index) - train_data = Subset(data,val_index) - train_loader = DataLoader(train_data, batch_size=train_batch_size, shuffle=False) - val_loader = DataLoader(test_data, batch_size=val_batch_size, shuffle=False) - - #creates model for each cv split - model = model_class().to(device) - optimizer = torch.optim.Adam(model.parameters(), lr=lr) - best_mse = 1000 - best_pearson = 0 - best_epoch = -1 - total_time = 0 - early_stop_tolerance = 30 - train_losses = [] - val_losses = [] - val_pearsons = [] - best_ret = [] - for epoch in tqdm(range(n_epochs)): - start_time = time.time() - train_loss = train(model, device, train_loader, optimizer, epoch+1, log_interval) - G,P = predicting(model, device, val_loader) - ret = [u.rmse(G,P),u.mse(G,P),u.pearson(G,P),u.spearman(G,P),coeffi_determ(G,P)] - train_losses.append(train_loss) - val_losses.append(ret[1]) - val_pearsons.append(ret[2]) - - self.model = best_model + self.model.train() #set model to train + loss_fn = nn.MSELoss() + optimizer = torch.optim.Adam(self.model.parameters(), lr=lr) + #metrics for checking how training is going + best_mse = 1000 + best_pearson = 0 + best_epoch = -1 + total_time = 0 + early_stop_tolerance = 30 + train_losses = [] + val_losses = [] + val_pearsons = [] + best_ret = [] + + for epoch in n_epochs: + #print what epoch is trained + start_time = time.time() + print(f"epoch : {epoch+1}/{n_epochs} ") + # + avg_loss = [] + for data in tqdm(train_loader): + data = data.to(device) + optimizer.zero_grad() + # output, _ = model(data) + x, x_cell_mut, edge_index, batch_drug, edge_feat = data.x, data.target, data.edge_index.long(), data.batch, data.edge_features + # output, _ = model(x, edge_index, x_cell_mut, batch_drug, edge_feat) + output = self.model(x, edge_index, batch_drug, x_cell_mut, edge_feat) + loss = loss_fn(output, data.y.view(-1, 1).float().to(device)) + loss.backward() + optimizer.step() + avg_loss.append(loss.item()) + train_loss = sum(avg_loss)/len(avg_loss) #step 6: save mdoel (maybe not neccecary beacuse class) From e75c7328df8d1d5e412f452afab8399c1e7661da Mon Sep 17 00:00:00 2001 From: Gregor-git1 Date: Tue, 21 Apr 2026 03:54:21 +0200 Subject: [PATCH 06/10] worked on predict and train, modified build_model --- drevalpy/models/XGDP/xgdp.py | 258 +++++++++++++---------------------- 1 file changed, 96 insertions(+), 162 deletions(-) diff --git a/drevalpy/models/XGDP/xgdp.py b/drevalpy/models/XGDP/xgdp.py index 3704518a..07966d5f 100644 --- a/drevalpy/models/XGDP/xgdp.py +++ b/drevalpy/models/XGDP/xgdp.py @@ -23,6 +23,9 @@ import tqdm import utils as u from sklearn.preprocessing import MinMaxScaler, StandardScaler +from sklearn.model_selection import KFold +from sklearn.model_selection import train_test_split +from torch.utils.data.dataset import Subset from torch.optim import Adam from torch.utils.data import Dataset as PytorchDataset from torch_geometric.loader import DataLoader @@ -68,8 +71,8 @@ def build_model(self, hyperparameters: dict[str, Any]) -> None: self.hyperparameters = hyperparameters model_name = hyperparameters.get("model_type", "GATNet") - - self.model = XGDPPredictor(name_hyperparameter=hyperparameter["name_hyperparameter"]) + #init in train + #self.model = XGDPPredictor(name_hyperparameter=hyperparameter["name_hyperparameter"]) """ # Log hyperparameters to wandb if enabled self.log_hyperparameters(hyperparameters) @@ -145,185 +148,116 @@ def train( # step 2: preprocess data + # step 3: build data loaders + split data + #to do create split indices + test_data = Subset(data, test_index) + train_data = Subset(data, train_index) + val_data = Subset(data, val_index) + test_loader = DataLoader(test_data, batch_size=test_batch_size, shuffle=False) + train_loader = DataLoader(train_data, batch_size=train_batch_size, shuffle=False) + val_loader = DataLoader(val_data, batch_size=val_batch_size, shuffle=False) + # step 4: init model + train parameters # to do: get feature size from data? -> + + #get model + init with_attention = [] model_name = self.hyperparameters["model"] model_class = getattr(models, model_name) - # is done after cv split - # self.model = model_class() #check hyper parameter for each model + self.model = model_class().to(device) #add feature size if "GAT" in model_name: - return_attention_weights = True + self.return_attention_weights = True else: - return_attention_weights = False + self.return_attention_weights = False + + #get model hyperparameters n_epochs = self.hyperparameters["n_epochs"] lr = self.hyperparameters["lr"] train_batch_size = self.hyperparameters["train_batch_size"] test_batch_size = self.hyperparameters["test_batch_size"] val_batch_size = self.hyperparameters["val_batch_size"] device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - # step 3: build data loaders + split data - test_loader = DataLoader(test_data, batch_size=test_batch_size, shuffle=False) - + optimizer = torch.optim.Adam(self.model.parameters(), lr=lr) + do_save = self.hyperparameters["do_save"] + + #performance + log_interval = 20 + best_mse = 1000 + best_pearson = 0 + best_epoch = -1 + total_time = 0 + early_stop_tolerance = 30 + train_losses = [] + val_losses = [] + val_pearsons = [] + best_ret = [] + + # step 5: train loop - best_model_id = 0 - best_model = None - best_pearson_cv = 0 - ret_cv = [] - kf = KFold(n_splits=3) # creates splits indices for cross validation - # does the cv split - for i, (train_index, val_index) in enumerate(kf.split(cv_data)): - # splits data into val and train according to index form kf - train_data = Subset(data, train_index) - train_data = Subset(data, val_index) - train_loader = DataLoader(train_data, batch_size=train_batch_size, shuffle=False) - val_loader = DataLoader(test_data, batch_size=val_batch_size, shuffle=False) - - # creates model for each cv split - model = model_class().to(device) - optimizer = torch.optim.Adam(model.parameters(), lr=lr) - best_mse = 1000 - best_pearson = 0 - best_epoch = -1 - total_time = 0 - early_stop_tolerance = 30 - train_losses = [] - val_losses = [] - val_pearsons = [] - best_ret = [] - for epoch in tqdm(range(n_epochs)): - start_time = time.time() - train_loss = train(model, device, train_loader, optimizer, epoch + 1, log_interval) - G, P = predicting(model, device, val_loader) - ret = [u.rmse(G, P), u.mse(G, P), u.pearson(G, P), u.spearman(G, P), coeffi_determ(G, P)] - train_losses.append(train_loss) - val_losses.append(ret[1]) - val_pearsons.append(ret[2]) - - self.model = best_model + for epoch in tqdm(range(n_epochs)): + start_time = time.time() + train_loss = u.train(self.model, device, train_loader, optimizer, epoch+1, log_interval) + G,P = u.predicting(self.model, device, val_loader) + ret = [u.rmse(G,P),u.mse(G,P),u.pearson(G,P),u.spearman(G,P),u.coeffi_determ(G,P)] + + train_losses.append(train_loss) + val_losses.append(ret[1]) + val_pearsons.append(ret[2]) + + if ret[1] early_stop_tolerance: + print('early stop at epoch ', epoch) + break + # test with the model with best validation performance + if self.return_attention_weights: + G_test, P_test, attn_weights = u.predicting(self.model, device, test_loader, self.return_attention_weights) + else: + G_test, P_test = u.predicting(self.model, device, test_loader) + #rework + #result_file_name = 'result_' + save_name + '_' + dataset + '_' + '.csv' + #ret_test = [rmse(G_test,P_test),mse(G_test,P_test),pearson(G_test,P_test),spearman(G_test,P_test),coeffi_determ(G_test,P_test)] + #if do_save: + # best_model_file_name = 'model_' + save_name + '_' + dataset + '_best' + str(best_model_id) + '.model' + # torch.save(best_model.state_dict(), model_folder + best_model_file_name) + # if return_attention_weights: + # np.save(br_fol + '/Saliency/AttnWeight/' + model_st + '.npy', attn_weights) # step 6: save mdoel (maybe not neccecary beacuse class) - -class XGDP(DRPModel): - # to do - def __init__(self): - """Initialize the DrugGNN model.""" - pass - - @classmethod - def get_model_name(cls) -> str: - """Return the name of the model. - - :return: The name of the model. - """ - return "XGDP" - - @property - def cell_line_views(self) -> list[str]: - """Return the sources the model needs as input for describing the cell line. - - :return: The sources the model needs as input for describing the cell line. - """ - return ["gene_expression"] - - # to do - @property - def drug_views(self) -> list[str]: - """Return the sources the model needs as input for describing the drug. - - :return: The sources the model needs as input for describing the drug. - """ - return [""] - - def build_model(self, hyperparameters: dict[str, Any]) -> None: - """Build the model. - - :param hyperparameters: The hyperparameters. - """ - # Log hyperparameters to wandb if enabled - self.log_hyperparameters(hyperparameters) - - self.hyperparameters = hyperparameters - # init model in train - - def train( +def predict( self, - output: DrugResponseDataset, + cell_line_ids: np.ndarray, + drug_ids: np.ndarray, cell_line_input: FeatureDataset, drug_input: FeatureDataset | None = None, - output_earlystopping: DrugResponseDataset | None = None, - **kwargs, - ): - """Train the model. - - :param output: The output dataset. - :param cell_line_input: The cell line input dataset. - :param drug_input: The drug input dataset. - :param output_earlystopping: The early stopping output dataset. - :param kwargs: Additional arguments. - :raises ValueError: If drug input is not provided. + ) -> np.ndarray: """ - if drug_input is None: - raise ValueError("Drug input is required for XGDP") - - # step1 load data - - # step 2: preprocess data + Predicts the response for the given input. + + :param drug_ids: list of drug ids, also used for single drug models, there it is just an array containing the + same drug id + :param cell_line_ids: list of cell line ids + :param cell_line_input: input associated with the cell line, required for all models + :param drug_input: input associated with the drug, optional because single drug models do not use drug features + :returns: predicted response + """ + #step 1 load data + data = None + #(step 2 preprocess) - # step 4: init model + train parameters - # to do: get feature size from data? -> - with_attention = [] - model_name = self.hyperparameters["model"] - model_class = getattr(models, model_name) - # is done after cv split - # self.model = model_class() #check hyper parameter for each model - if "GAT" in model_name: - return_attention_weights = True - else: - return_attention_weights = False - n_epochs = self.hyperparameters["n_epochs"] - lr = self.hyperparameters["lr"] - train_batch_size = self.hyperparameters["train_batch_size"] - test_batch_size = self.hyperparameters["test_batch_size"] - val_batch_size = self.hyperparameters["val_batch_size"] + #step 3 dataloaders + + loader = DataLoader(data, batch_size=1, shuffle=False) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - # step 3: build data loaders + split data - test_loader = DataLoader(test_data, batch_size=test_batch_size, shuffle=False) - - # step 5: train loop - best_model_id = 0 - best_model = None - best_pearson_cv = 0 - ret_cv = [] - kf = KFold(n_splits=3) # creates splits indices for cross validation - # does the cv split - for i, (train_index, val_index) in enumerate(kf.split(cv_data)): - # splits data into val and train according to index form kf - train_data = Subset(data, train_index) - train_data = Subset(data, val_index) - train_loader = DataLoader(train_data, batch_size=train_batch_size, shuffle=False) - val_loader = DataLoader(test_data, batch_size=val_batch_size, shuffle=False) - - # creates model for each cv split - model = model_class().to(device) - optimizer = torch.optim.Adam(model.parameters(), lr=lr) - best_mse = 1000 - best_pearson = 0 - best_epoch = -1 - total_time = 0 - early_stop_tolerance = 30 - train_losses = [] - val_losses = [] - val_pearsons = [] - best_ret = [] - for epoch in tqdm(range(n_epochs)): - start_time = time.time() - train_loss = train(model, device, train_loader, optimizer, epoch + 1, log_interval) - G, P = predicting(model, device, val_loader) - ret = [u.rmse(G, P), u.mse(G, P), u.pearson(G, P), u.spearman(G, P), coeffi_determ(G, P)] - train_losses.append(train_loss) - val_losses.append(ret[1]) - val_pearsons.append(ret[2]) - - self.model = best_model - # step 6: save mdoel (maybe not neccecary beacuse class) + #step 4 predict + ret_att = self.return_attention_weights + pred = u.predicting(model = self.model,device = device,loader = loader, return_attention_weights= ret_att) + #step 5 convert to dreval datatype + return \ No newline at end of file From 1dcf165dde076b58d9964df866e25e6d641d8072 Mon Sep 17 00:00:00 2001 From: eleded Date: Thu, 23 Apr 2026 21:37:37 +0200 Subject: [PATCH 07/10] refine train and predict, add classes --- drevalpy/models/XGDP/xgdp.py | 400 ++++++++++++++++++++++++----------- tests/test_drp_model.py | 2 +- 2 files changed, 273 insertions(+), 129 deletions(-) diff --git a/drevalpy/models/XGDP/xgdp.py b/drevalpy/models/XGDP/xgdp.py index 07966d5f..d4198575 100644 --- a/drevalpy/models/XGDP/xgdp.py +++ b/drevalpy/models/XGDP/xgdp.py @@ -6,36 +6,176 @@ Code adapted from their Github: https://github.com/SCSE-Biomedical-Computing-Group/XGDP/blob/main/utils_tcnn.py """ -import json -import os -import secrets -import sys -import time from pathlib import Path from typing import Any -import models import numpy as np -import pandas as pd import pytorch_lightning as pl import torch import torch.nn as nn -import tqdm -import utils as u from sklearn.preprocessing import MinMaxScaler, StandardScaler -from sklearn.model_selection import KFold -from sklearn.model_selection import train_test_split -from torch.utils.data.dataset import Subset from torch.optim import Adam from torch.utils.data import Dataset as PytorchDataset from torch_geometric.loader import DataLoader -from torch_geometric.nn import GCNConv, global_mean_pool from ...datasets.dataset import DrugResponseDataset, FeatureDataset from ..drp_model import DRPModel from ..lightning_metrics_mixin import RegressionMetricsMixin from ..utils import load_and_select_gene_features -from .utils import XGDPPredictor, predict_model, train_epoch +from .utils import XGDPPredictor + + +class _XGDPDataset(PytorchDataset): + """A PyTorch Dataset for XGDP.""" + + def __init__( + self, + response: np.ndarray, + cell_line_ids: np.ndarray, + drug_ids: np.ndarray, + cell_line_features: FeatureDataset, + drug_features: FeatureDataset, + ): + """Initialize the dataset. + + :param response: The drug response values. + :param cell_line_ids: The cell line IDs. + :param drug_ids: The drug IDs. + :param cell_line_features: A FeatureDataset object with cell line features. + :param drug_features: A FeatureDataset object with drug features. + """ + self.response = response + self.cell_line_ids = cell_line_ids + self.drug_ids = drug_ids + + # preconvert to tensors to avoid per item tensor creation + self.cell_features = { + cl_id: torch.tensor(features["gene_expression"], dtype=torch.float32).unsqueeze(0) + for cl_id, features in cell_line_features.features.items() + } + self.response_tensor = torch.tensor(self.response, dtype=torch.float32) + + self.drug_graphs = { + drug_id: feature_views["drug_graph"] for drug_id, feature_views in drug_features.features.items() + } + + def __len__(self): + return len(self.response) + + def __getitem__(self, idx): + cell_line_id = self.cell_line_ids[idx] + drug_id = self.drug_ids[idx] + + drug_graph = self.drug_graphs[drug_id] + cell_feat = self.cell_features[cell_line_id] + response = self.response_tensor[idx] + + return drug_graph, cell_feat, response + + +class XGDPModule(pl.LightningModule): + """The LightningModule for the XGDP model.""" + + def __init__( + self, + num_node_features: int, + num_cell_features: int, + hidden_dim: int = 64, + dropout: float = 0.2, + learning_rate: float = 0.001, + ): + """Initialize the LightningModule. + + :param num_node_features: Number of features for each node in the drug graph. + :param num_cell_features: Number of features for the cell line. + :param hidden_dim: The hidden dimension size. + :param dropout: The dropout rate. + :param learning_rate: The learning rate. + """ + super().__init__() + self.save_hyperparameters() + self.model = XGDPPredictor( + num_node_features=self.hparams["num_node_features"], + num_cell_features=self.hparams["num_cell_features"], + hidden_dim=self.hparams["hidden_dim"], + dropout=self.hparams["dropout"], + model_type=self.hparams.get("model_type", "GATNet"), + ) + self.criterion = nn.MSELoss() + + # Initialize metrics storage for epoch-end R^2 and PCC computation + self._init_metrics_storage() + + def forward(self, batch): + """Forward pass of the module. + + :param batch: The batch. + :return: The output of the model. + """ + drug_graph, cell_features, _ = batch + return self.model( + x=drug_graph.x, + edge_index=drug_graph.edge_index, + batch=drug_graph.batch, + x_cell_mut=cell_features, + edge_feat=getattr(drug_graph, "edge_attr", None), + ) + + def training_step(self, batch, batch_idx): + """A single training step. + + :param batch: The batch. + :param batch_idx: The batch index. + :return: The loss. + """ + drug_graph, cell_features, responses = batch + outputs = self.forward(batch) + loss = self.criterion(outputs, responses) + self.log("train_loss", loss, on_step=False, on_epoch=True, batch_size=responses.size(0)) + + # Store predictions and targets for epoch-end metrics via mixin + self._store_predictions(outputs, responses, is_training=True) + + return loss + + def validation_step(self, batch, batch_idx): + """A single validation step. + + :param batch: The batch. + :param batch_idx: The batch index. + """ + drug_graph, cell_features, responses = batch + outputs = self.model(drug_graph, cell_features) + loss = self.criterion(outputs, responses) + self.log("val_loss", loss, on_step=False, on_epoch=True, batch_size=responses.size(0)) + + # Store predictions and targets for epoch-end metrics via mixin + self._store_predictions(outputs, responses, is_training=False) + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + """A single prediction step. + + :param batch: The batch. + :param batch_idx: The batch index. + :param dataloader_idx: The dataloader index. + :return: The output of the model. + """ + drug_graph, cell_features, _ = batch + outputs = self.model( + x=drug_graph.x, + edge_index=drug_graph.edge_index, + batch=drug_graph.batch, + x_cell_mut=cell_features, + edge_feat=getattr(drug_graph, "edge_attr", None), + ) + return outputs + + def configure_optimizers(self): + """Configure the optimizer. + + :return: The optimizer. + """ + return Adam(self.parameters(), lr=self.hparams.learning_rate) class XGDP(DRPModel, RegressionMetricsMixin): @@ -68,11 +208,10 @@ def build_model(self, hyperparameters: dict[str, Any]) -> None: :param hyperparameters: TODO: ADD HYPERPARAMETERS """ - self.hyperparameters = hyperparameters model_name = hyperparameters.get("model_type", "GATNet") - #init in train - #self.model = XGDPPredictor(name_hyperparameter=hyperparameter["name_hyperparameter"]) + # init in train + # self.model = XGDPPredictor(name_hyperparameter=hyperparameter["name_hyperparameter"]) """ # Log hyperparameters to wandb if enabled self.log_hyperparameters(hyperparameters) @@ -123,6 +262,17 @@ def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDatase return FeatureDataset(features=feature_dict) + def _loader_kwargs(self) -> dict[str, Any]: + num_workers = int(self.hyperparameters.get("num_workers", 4)) + kw = { + "num_workers": num_workers, + "pin_memory": True, + } + if num_workers > 0: + kw["persistent_workers"] = True + kw["prefetch_factor"] = int(self.hyperparameters.get("prefetch_factor", 2)) + return kw + def train( self, output: DrugResponseDataset, @@ -143,121 +293,115 @@ def train( if drug_input is None: raise ValueError("Drug input is required for XGDP") - # step1 load data - train_loader = self._prepare_dataloader(output, cell_line_input, drug_input, is_train=True) - - # step 2: preprocess data - - # step 3: build data loaders + split data - #to do create split indices - test_data = Subset(data, test_index) - train_data = Subset(data, train_index) - val_data = Subset(data, val_index) - test_loader = DataLoader(test_data, batch_size=test_batch_size, shuffle=False) - train_loader = DataLoader(train_data, batch_size=train_batch_size, shuffle=False) - val_loader = DataLoader(val_data, batch_size=val_batch_size, shuffle=False) - - # step 4: init model + train parameters - # to do: get feature size from data? -> - - #get model + init - with_attention = [] - model_name = self.hyperparameters["model"] - model_class = getattr(models, model_name) - self.model = model_class().to(device) #add feature size - if "GAT" in model_name: - self.return_attention_weights = True - else: - self.return_attention_weights = False - - #get model hyperparameters - n_epochs = self.hyperparameters["n_epochs"] - lr = self.hyperparameters["lr"] - train_batch_size = self.hyperparameters["train_batch_size"] - test_batch_size = self.hyperparameters["test_batch_size"] - val_batch_size = self.hyperparameters["val_batch_size"] - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - optimizer = torch.optim.Adam(self.model.parameters(), lr=lr) - do_save = self.hyperparameters["do_save"] - - #performance - log_interval = 20 - best_mse = 1000 - best_pearson = 0 - best_epoch = -1 - total_time = 0 - early_stop_tolerance = 30 - train_losses = [] - val_losses = [] - val_pearsons = [] - best_ret = [] - - - # step 5: train loop - for epoch in tqdm(range(n_epochs)): - start_time = time.time() - train_loss = u.train(self.model, device, train_loader, optimizer, epoch+1, log_interval) - G,P = u.predicting(self.model, device, val_loader) - ret = [u.rmse(G,P),u.mse(G,P),u.pearson(G,P),u.spearman(G,P),u.coeffi_determ(G,P)] - - train_losses.append(train_loss) - val_losses.append(ret[1]) - val_pearsons.append(ret[2]) - - if ret[1] early_stop_tolerance: - print('early stop at epoch ', epoch) - break - # test with the model with best validation performance - if self.return_attention_weights: - G_test, P_test, attn_weights = u.predicting(self.model, device, test_loader, self.return_attention_weights) - else: - G_test, P_test = u.predicting(self.model, device, test_loader) - #rework - #result_file_name = 'result_' + save_name + '_' + dataset + '_' + '.csv' - #ret_test = [rmse(G_test,P_test),mse(G_test,P_test),pearson(G_test,P_test),spearman(G_test,P_test),coeffi_determ(G_test,P_test)] - #if do_save: - # best_model_file_name = 'model_' + save_name + '_' + dataset + '_best' + str(best_model_id) + '.model' - # torch.save(best_model.state_dict(), model_folder + best_model_file_name) - # if return_attention_weights: - # np.save(br_fol + '/Saliency/AttnWeight/' + model_st + '.npy', attn_weights) - # step 6: save mdoel (maybe not neccecary beacuse class) - -def predict( + # Determine feature sizes + num_node_features = next(iter(drug_input.features.values()))["drug_graph"].num_node_features + num_cell_features = next(iter(cell_line_input.features.values()))["gene_expression"].shape[0] + + self.model = XGDPModule( + num_node_features=num_node_features, + num_cell_features=num_cell_features, + hidden_dim=self.hyperparameters.get("hidden_dim", 64), + dropout=self.hyperparameters.get("dropout", 0.2), + learning_rate=self.hyperparameters.get("learning_rate", 0.001), + ) + + train_dataset = _XGDPDataset( + response=output.response, + cell_line_ids=output.cell_line_ids, + drug_ids=output.drug_ids, + cell_line_features=cell_line_input, + drug_features=drug_input, + ) + train_loader = DataLoader( + train_dataset, + batch_size=self.hyperparameters.get("batch_size", 64), + shuffle=True, + **self._loader_kwargs(), + ) + + val_loader = None + if output_earlystopping is not None and len(output_earlystopping) > 0: + val_dataset = _XGDPDataset( + response=output_earlystopping.response, + cell_line_ids=output.cell_line_ids, + drug_ids=output.drug_ids, + cell_line_features=cell_line_input, + drug_features=drug_input, + ) + val_loader = DataLoader( + val_dataset, + batch_size=self.hyperparameters.get("batch_size", 32), + **self._loader_kwargs(), + ) + + # Set up wandb logger if project is provided + loggers = [] + if self.wandb_project is not None: + from pytorch_lightning.loggers import WandbLogger + + logger = WandbLogger(project=self.wandb_project, log_model=False) + loggers.append(logger) + + trainer = pl.Trainer( + max_epochs=self.hyperparameters.get("epochs", 100), + accelerator="auto", + devices="auto", + callbacks=[pl.callbacks.EarlyStopping(monitor="val_loss", mode="min", patience=5)] if val_loader else None, + logger=loggers if loggers else True, # Use default logger if no wandb + enable_progress_bar=True, + log_every_n_steps=int(self.hyperparameters.get("log_every_n_steps", 50)), + precision=self.hyperparameters.get("precision", 32), + ) + trainer.fit(self.model, train_dataloaders=train_loader, val_dataloaders=val_loader) + + def predict( self, cell_line_ids: np.ndarray, drug_ids: np.ndarray, cell_line_input: FeatureDataset, drug_input: FeatureDataset | None = None, ) -> np.ndarray: + """Predict drug response. + + :param cell_line_ids: The cell line IDs. + :param drug_ids: The drug IDs. + :param cell_line_input: The cell line input dataset. + :param drug_input: The drug input dataset. + :raises RuntimeError: If the model has not been trained yet. + :raises ValueError: If drug input is not provided. + :return: The predicted drug response. """ - Predicts the response for the given input. - - :param drug_ids: list of drug ids, also used for single drug models, there it is just an array containing the - same drug id - :param cell_line_ids: list of cell line ids - :param cell_line_input: input associated with the cell line, required for all models - :param drug_input: input associated with the drug, optional because single drug models do not use drug features - :returns: predicted response - """ - #step 1 load data - data = None - #(step 2 preprocess) - - #step 3 dataloaders + - loader = DataLoader(data, batch_size=1, shuffle=False) - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - #step 4 predict - ret_att = self.return_attention_weights - pred = u.predicting(model = self.model,device = device,loader = loader, return_attention_weights= ret_att) - #step 5 convert to dreval datatype + return \ No newline at end of file + if len(drug_ids) == 0 or len(cell_line_ids) == 0: + print("XGDP predict: No drug or cell line IDs provided; returning empty array.") + return np.array([]) + if self.model is None: + raise RuntimeError("Model has not been trained yet.") + if drug_input is None: + raise ValueError("Drug input is required for XGDP") + + predict_dataset = _XGDPDataset( + response=np.zeros(len(cell_line_ids)), + cell_line_ids=cell_line_ids, + drug_ids=drug_ids, + cell_line_features=cell_line_input, + drug_features=drug_input, + ) + predict_loader = DataLoader( + predict_dataset, + batch_size=self.hyperparameters.get("batch_size", 32), + **self._loader_kwargs(), + ) + + trainer = pl.Trainer(accelerator="auto", devices="auto", enable_progress_bar=False) + predictions_list = trainer.predict(self.model, dataloaders=predict_loader) + + if not predictions_list: + print("XGDP predict: No predictions were made; returning empty array.") + return np.array([]) + + predictions_flat = [ + item for sublist in predictions_list for item in (sublist if isinstance(sublist, list) else [sublist]) + ] + + predictions = torch.cat(predictions_flat).cpu().numpy() + return predictions diff --git a/tests/test_drp_model.py b/tests/test_drp_model.py index 390947aa..0f4c3083 100644 --- a/tests/test_drp_model.py +++ b/tests/test_drp_model.py @@ -43,7 +43,7 @@ def test_factory() -> None: assert "MOLIR" in MODEL_FACTORY assert "SuperFELTR" in MODEL_FACTORY assert "DIPK" in MODEL_FACTORY - assert XGDP in MODEL_FACTORY + assert "XGDP" in MODEL_FACTORY def test_load_cl_ids_from_csv() -> None: From 673680716320dc1b9be74e95662d6d1ecdc10661 Mon Sep 17 00:00:00 2001 From: Gregor-git1 Date: Mon, 27 Apr 2026 18:23:18 +0200 Subject: [PATCH 08/10] test_1 --- .../models/XGDP/{models.py => _models.py} | 5 +- drevalpy/models/XGDP/utils.py | 82 ++++++++++++++++- drevalpy/models/XGDP/xgdp.py | 87 +++++++++++++------ 3 files changed, 146 insertions(+), 28 deletions(-) rename drevalpy/models/XGDP/{models.py => _models.py} (99%) diff --git a/drevalpy/models/XGDP/models.py b/drevalpy/models/XGDP/_models.py similarity index 99% rename from drevalpy/models/XGDP/models.py rename to drevalpy/models/XGDP/_models.py index 3e047b1e..0bc0cc24 100644 --- a/drevalpy/models/XGDP/models.py +++ b/drevalpy/models/XGDP/_models.py @@ -216,7 +216,8 @@ def __init__( # self.fc1_xt = nn.Linear(2944, output_dim) # self.fc1_xt = nn.Linear(4224, output_dim) # self.fc1_xt = nn.Linear(61824, output_dim) - self.fc1_xt = nn.Linear(4096, output_dim) + #self.fc1_xt = nn.Linear(4096, output_dim) + self.fc1_xt = nn.Linear(3584, output_dim) # combined layers if self.use_attn: @@ -266,6 +267,7 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_ # target = data.target # x_cell_mut = x_cell_mut[:,None,:] # 1d conv layers + conv_xt = self.conv_xt_1(x_cell_mut) conv_xt = f.relu(conv_xt) conv_xt = self.pool_xt_1(conv_xt) @@ -277,6 +279,7 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_ conv_xt = self.pool_xt_3(conv_xt) # flatten + xt = conv_xt.view(-1, conv_xt.shape[1] * conv_xt.shape[2]) xt = self.fc1_xt(xt) diff --git a/drevalpy/models/XGDP/utils.py b/drevalpy/models/XGDP/utils.py index c5fcd672..3d198619 100644 --- a/drevalpy/models/XGDP/utils.py +++ b/drevalpy/models/XGDP/utils.py @@ -1 +1,81 @@ -"""Utility functions for the XGDP model.""" +import torch +import torch.nn as nn +from torch.utils.data.dataset import Subset +import time +from sklearn.model_selection import KFold +from tqdm import tqdm +import pandas as pd + +def rmse(y, f): + rmse = sqrt(((y - f)**2).mean(axis=0)) + return rmse + + +def mse(y, f): + mse = ((y - f)**2).mean(axis=0) + return mse + + +def pearson(y, f): + rp = np.corrcoef(y, f)[0, 1] + return rp + + +def spearman(y, f): + rs = stats.spearmanr(y, f)[0] + return rs + + +def coeffi_determ(y, f): + r2 = r2_score(y, f) + return r2 + +def predicting(model, device, loader, return_attention_weights = False): + model.eval() + total_preds = torch.Tensor() + total_labels = torch.Tensor() + print('Make prediction for {} samples...'.format(len(loader.dataset))) + with torch.no_grad(): + for data in loader: + data = data.to(device) + + # output, _ = model(data) + x, x_cell_mut, edge_index, batch_drug, edge_feat = data.x, data.target, data.edge_index.long(), data.batch, data.edge_features + if return_attention_weights: + # output, _, attn_weights = model(x, edge_index, x_cell_mut, batch_drug, edge_feat, return_attention_weights) + output, attn_weights = model(x, edge_index, batch_drug, x_cell_mut, edge_feat, return_attention_weights) + attn_weights = [attn_weight.cpu().numpy() for attn_weight in attn_weights] + # print(attn_weights) + attn_weights = np.array(attn_weights) + # print(attn_weights.shape) + else: + # output, _ = model(x, edge_index, x_cell_mut, batch_drug, edge_feat) + output = model(x, edge_index, batch_drug, x_cell_mut, edge_feat) + + total_preds = torch.cat((total_preds, output.cpu()), 0) + total_labels = torch.cat((total_labels, data.y.view(-1, 1).cpu()), 0) + torch.cuda.empty_cache() ## no grad + if return_attention_weights: + return total_labels.numpy().flatten(), total_preds.numpy().flatten(), attn_weights + else: + return total_labels.numpy().flatten(), total_preds.numpy().flatten() + +# training function at each epoch +def train(model, device, train_loader, optimizer, epoch, log_interval, return_attention_weights=False): + print('Training on {} samples...'.format(len(train_loader.dataset))) + model.train() + loss_fn = nn.MSELoss() + avg_loss = [] + for data in tqdm(train_loader): + data = data.to(device) + optimizer.zero_grad() + + x, x_cell_mut, edge_index, batch_drug, edge_feat = data.x, data.target, data.edge_index.long(), data.batch, data.edge_features + + output = model(x, edge_index, batch_drug, x_cell_mut, edge_feat) + + loss = loss_fn(output, data.y.view(-1, 1).float().to(device)) + loss.backward() + optimizer.step() + avg_loss.append(loss.item()) + return sum(avg_loss)/len(avg_loss) diff --git a/drevalpy/models/XGDP/xgdp.py b/drevalpy/models/XGDP/xgdp.py index d4198575..61ab86dc 100644 --- a/drevalpy/models/XGDP/xgdp.py +++ b/drevalpy/models/XGDP/xgdp.py @@ -5,7 +5,7 @@ Original authors: Wang, C., Kumar, G.A. & Rajapakse, J.C. (2025, 10.1038/s41598-024-83090-3) Code adapted from their Github: https://github.com/SCSE-Biomedical-Computing-Group/XGDP/blob/main/utils_tcnn.py """ - +import drevalpy.models.XGDP._models as m from pathlib import Path from typing import Any @@ -22,7 +22,9 @@ from ..drp_model import DRPModel from ..lightning_metrics_mixin import RegressionMetricsMixin from ..utils import load_and_select_gene_features -from .utils import XGDPPredictor + + + class _XGDPDataset(PytorchDataset): @@ -78,10 +80,12 @@ class XGDPModule(pl.LightningModule): def __init__( self, + model: str, num_node_features: int, num_cell_features: int, - hidden_dim: int = 64, - dropout: float = 0.2, + do_att: bool, + hidden_dim: int = 128, + dropout: float = 0.5, #changed to 0.5 as per there default settings for there models learning_rate: float = 0.001, ): """Initialize the LightningModule. @@ -94,17 +98,37 @@ def __init__( """ super().__init__() self.save_hyperparameters() - self.model = XGDPPredictor( - num_node_features=self.hparams["num_node_features"], - num_cell_features=self.hparams["num_cell_features"], - hidden_dim=self.hparams["hidden_dim"], - dropout=self.hparams["dropout"], - model_type=self.hparams.get("model_type", "GATNet"), + self + #model_name = model + model_name = "GATNet" + try: + model_class = getattr(m, model_name) + print(type(model_class)) + except AttributeError: + # Specifically catch the error if the string name doesn't exist in 'models' + raise ValueError(f"Model '{model_name}' not found in the list of available models.") + + if "GAT" in model_name: + self.return_attention_weights = False # because no need to return attention weight because no models explanation through XAI + else: + self.return_attention_weights = False + #print("----dimensions-------") + #print(hidden_dim) + self.model = model_class( + n_output = 1, + num_features_xd = num_node_features, #gnn number of node features (ECFP& + DeepChem: 334) + num_features_xt=25, #only used in GINNet in embedding, mutation/protein data??? + n_filters=32, #number of filters for cnn for gene expression + embed_dim=128, #only used in GINNet in embedding + output_dim=hidden_dim, #size of latend sapce as output of gnn and cnn -> determines size of shared feature size after combination + dropout= dropout, + use_attn=do_att #not in GINNet, SAGENet, add to models to have same input + ) self.criterion = nn.MSELoss() # Initialize metrics storage for epoch-end R^2 and PCC computation - self._init_metrics_storage() + #self._init_metrics_storage() def forward(self, batch): """Forward pass of the module. @@ -113,13 +137,27 @@ def forward(self, batch): :return: The output of the model. """ drug_graph, cell_features, _ = batch - return self.model( + if self.return_attention_weights: + #print("x", drug_graph.x.shape, type(drug_graph.x)) + #print("edge_index", drug_graph.edge.shape,type(drug_graph.edge)) + #print("edge_index", drug_graph.edge.shape,type(drug_graph.edge)) + + return self.model.forward( + x=drug_graph.x, + edge_index=drug_graph.edge_index, + batch=drug_graph.batch, + x_cell_mut=cell_features, + edge_feat=getattr(drug_graph, "edge_attr", None), + return_attention_weights = True, + ) + else: + return self.model.forward( x=drug_graph.x, edge_index=drug_graph.edge_index, batch=drug_graph.batch, x_cell_mut=cell_features, edge_feat=getattr(drug_graph, "edge_attr", None), - ) + ) def training_step(self, batch, batch_idx): """A single training step. @@ -134,7 +172,7 @@ def training_step(self, batch, batch_idx): self.log("train_loss", loss, on_step=False, on_epoch=True, batch_size=responses.size(0)) # Store predictions and targets for epoch-end metrics via mixin - self._store_predictions(outputs, responses, is_training=True) + #self._store_predictions(outputs, responses, is_training=True) return loss @@ -145,7 +183,7 @@ def validation_step(self, batch, batch_idx): :param batch_idx: The batch index. """ drug_graph, cell_features, responses = batch - outputs = self.model(drug_graph, cell_features) + outputs = self.forward(batch) loss = self.criterion(outputs, responses) self.log("val_loss", loss, on_step=False, on_epoch=True, batch_size=responses.size(0)) @@ -161,13 +199,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): :return: The output of the model. """ drug_graph, cell_features, _ = batch - outputs = self.model( - x=drug_graph.x, - edge_index=drug_graph.edge_index, - batch=drug_graph.batch, - x_cell_mut=cell_features, - edge_feat=getattr(drug_graph, "edge_attr", None), - ) + outputs = self.forward(batch) return outputs def configure_optimizers(self): @@ -188,7 +220,7 @@ def __init__(self) -> None: """Initialize the XGDP model.""" super().__init__() self.DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.model: XGDPPredictor | None = None + self.model: XGDPModule | None = None self.hyperparameters: dict[str, Any] = {} self.gene_expression_scaler: StandardScaler | None = None self.gene_expression_normalizer: MinMaxScaler | None = None @@ -300,9 +332,11 @@ def train( self.model = XGDPModule( num_node_features=num_node_features, num_cell_features=num_cell_features, - hidden_dim=self.hyperparameters.get("hidden_dim", 64), - dropout=self.hyperparameters.get("dropout", 0.2), + hidden_dim=self.hyperparameters.get("hidden_dim", 128), + dropout=self.hyperparameters.get("dropout", 0.5), learning_rate=self.hyperparameters.get("learning_rate", 0.001), + model = self.hyperparameters.get("model", "GATv2Net"), + do_att= self.hyperparameters.get("do_att", True), ) train_dataset = _XGDPDataset( @@ -343,7 +377,8 @@ def train( loggers.append(logger) trainer = pl.Trainer( - max_epochs=self.hyperparameters.get("epochs", 100), + #max_epochs=self.hyperparameters.get("epochs", 100), #changed to 10 fro testing + max_epochs=10, #changed to 10 fro testing accelerator="auto", devices="auto", callbacks=[pl.callbacks.EarlyStopping(monitor="val_loss", mode="min", patience=5)] if val_loader else None, From d86b0ae10d45b0c632eef183df585ba87f15cd7e Mon Sep 17 00:00:00 2001 From: Gregor-git1 Date: Thu, 7 May 2026 12:19:00 +0200 Subject: [PATCH 09/10] still not working --- drevalpy/models/XGDP/xgdp.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/drevalpy/models/XGDP/xgdp.py b/drevalpy/models/XGDP/xgdp.py index 61ab86dc..89fc38b9 100644 --- a/drevalpy/models/XGDP/xgdp.py +++ b/drevalpy/models/XGDP/xgdp.py @@ -112,8 +112,7 @@ def __init__( self.return_attention_weights = False # because no need to return attention weight because no models explanation through XAI else: self.return_attention_weights = False - #print("----dimensions-------") - #print(hidden_dim) + self.model = model_class( n_output = 1, num_features_xd = num_node_features, #gnn number of node features (ECFP& + DeepChem: 334) @@ -167,7 +166,7 @@ def training_step(self, batch, batch_idx): :return: The loss. """ drug_graph, cell_features, responses = batch - outputs = self.forward(batch) + outputs = self(batch) loss = self.criterion(outputs, responses) self.log("train_loss", loss, on_step=False, on_epoch=True, batch_size=responses.size(0)) @@ -200,6 +199,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): """ drug_graph, cell_features, _ = batch outputs = self.forward(batch) + outputs = outputs.squeeze(-1) # to flatten torch tensor so that it can be handeled by pandas return outputs def configure_optimizers(self): @@ -378,7 +378,7 @@ def train( trainer = pl.Trainer( #max_epochs=self.hyperparameters.get("epochs", 100), #changed to 10 fro testing - max_epochs=10, #changed to 10 fro testing + max_epochs=1, #changed to 10 fro testing accelerator="auto", devices="auto", callbacks=[pl.callbacks.EarlyStopping(monitor="val_loss", mode="min", patience=5)] if val_loader else None, From a24adf7bb349235aa6630d05f142221acf908e96 Mon Sep 17 00:00:00 2001 From: eleded Date: Tue, 19 May 2026 11:26:27 +0200 Subject: [PATCH 10/10] update models --- drevalpy/models/XGDP/_models.py | 516 ++++++++++++++++------ drevalpy/models/XGDP/hyperparameters.yaml | 2 +- drevalpy/models/XGDP/utils.py | 81 ---- drevalpy/models/XGDP/xgdp.py | 96 ++-- 4 files changed, 442 insertions(+), 253 deletions(-) delete mode 100644 drevalpy/models/XGDP/utils.py diff --git a/drevalpy/models/XGDP/_models.py b/drevalpy/models/XGDP/_models.py index 0bc0cc24..20aa4f0f 100644 --- a/drevalpy/models/XGDP/_models.py +++ b/drevalpy/models/XGDP/_models.py @@ -67,17 +67,34 @@ def __init__( self.dropout = nn.Dropout(dropout) # cell line feature - self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=8) - self.pool_xt_1 = nn.MaxPool1d(3) - self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=8) - self.pool_xt_2 = nn.MaxPool1d(3) - self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=8) - self.pool_xt_3 = nn.MaxPool1d(3) + if num_features_xt < 50: + k = 3 + p = 2 + else: + k = 8 + p = 3 + + # cell line feature + self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=k) + self.pool_xt_1 = nn.MaxPool1d(p) + self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=k) + self.pool_xt_2 = nn.MaxPool1d(p) + self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=k) + self.pool_xt_3 = nn.MaxPool1d(p) # self.fc1_xt = nn.Linear(2944, output_dim) # self.fc1_xt = nn.Linear(4224, output_dim) # self.fc1_xt = nn.Linear(61824, output_dim) self.fc1_xt = nn.Linear(4096, output_dim) + with torch.no_grad(): + dummy = torch.zeros(1, 1, num_features_xt) + conv_xt = self.pool_xt_1(self.conv_xt_1(dummy)) + conv_xt = self.pool_xt_2(self.conv_xt_2(conv_xt)) + conv_xt = self.pool_xt_3(self.conv_xt_3(conv_xt)) + flat_dim = conv_xt.shape[1] * conv_xt.shape[2] + + self.fc1_xt = nn.Linear(flat_dim, output_dim) + # combined layers if self.use_attn: self.cross_attn1 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) @@ -90,7 +107,7 @@ def __init__( self.fc2 = nn.Linear(1024, 128) self.out = nn.Linear(128, self.n_output) - def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, edge_weight=None): + def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, edge_weight=None, return_attention_weights=False): """ Forward pass of the GCNNet model. @@ -105,9 +122,17 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, edge_weight=None) # get graph input # edge_weight is only used for decoding + if edge_feat is not None: + pass + # x, edge_index, batch = data.x, data.edge_index, data.batch # edge_index = edge_index.long() + MIN_CNN_INPUT = 22 + if x_cell_mut.shape[-1] < MIN_CNN_INPUT: + pad = MIN_CNN_INPUT - x_cell_mut.shape[-1] + x_cell_mut = torch.nn.functional.pad(x_cell_mut, (0, pad)) + x = self.conv1(x, edge_index, edge_weight) x = self.relu(x) x = self.conv2(x, edge_index, edge_weight) @@ -167,6 +192,9 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, edge_weight=None) xc = self.fc2(xc) xc = self.relu(xc) xc = self.dropout(xc) + + if self.use_attn: + pass out = self.out(xc) out = nn.Sigmoid()(out) return out @@ -206,18 +234,28 @@ def __init__( self.gcn2 = GATConv(num_features_xd * 10, output_dim, dropout=dropout) self.fc_g1 = nn.Linear(output_dim, output_dim) - # cell line feature - self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=8) - self.pool_xt_1 = nn.MaxPool1d(3) - self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=8) - self.pool_xt_2 = nn.MaxPool1d(3) - self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=8) - self.pool_xt_3 = nn.MaxPool1d(3) - # self.fc1_xt = nn.Linear(2944, output_dim) - # self.fc1_xt = nn.Linear(4224, output_dim) - # self.fc1_xt = nn.Linear(61824, output_dim) - #self.fc1_xt = nn.Linear(4096, output_dim) - self.fc1_xt = nn.Linear(3584, output_dim) + if num_features_xt < 50: + k = 3 + p = 2 + else: + k = 8 + p = 3 + + self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=k) + self.pool_xt_1 = nn.MaxPool1d(p) + self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=k) + self.pool_xt_2 = nn.MaxPool1d(p) + self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=k) + self.pool_xt_3 = nn.MaxPool1d(p) + + with torch.no_grad(): + dummy = torch.zeros(1, 1, num_features_xt) + conv_xt = self.pool_xt_1(self.conv_xt_1(dummy)) + conv_xt = self.pool_xt_2(self.conv_xt_2(conv_xt)) + conv_xt = self.pool_xt_3(self.conv_xt_3(conv_xt)) + flat_dim = conv_xt.shape[1] * conv_xt.shape[2] + + self.fc1_xt = nn.Linear(flat_dim, output_dim) # combined layers if self.use_attn: @@ -249,6 +287,9 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_ """ # graph input feed-forward # x, edge_index, batch = data.x, data.edge_index, data.batch + if edge_feat is not None: + pass + # x = self.dropout(x) # x = f.dropout(x, p=0.2, training=self.training) x = f.elu(self.gcn1(x, edge_index)) @@ -267,7 +308,12 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_ # target = data.target # x_cell_mut = x_cell_mut[:,None,:] # 1d conv layers - + + MIN_CNN_INPUT = 22 + if x_cell_mut.shape[-1] < MIN_CNN_INPUT: + pad = MIN_CNN_INPUT - x_cell_mut.shape[-1] + x_cell_mut = torch.nn.functional.pad(x_cell_mut, (0, pad)) + conv_xt = self.conv_xt_1(x_cell_mut) conv_xt = f.relu(conv_xt) conv_xt = self.pool_xt_1(conv_xt) @@ -279,7 +325,7 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_ conv_xt = self.pool_xt_3(conv_xt) # flatten - + xt = conv_xt.view(-1, conv_xt.shape[1] * conv_xt.shape[2]) xt = self.fc1_xt(xt) @@ -306,6 +352,9 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_ xc = self.fc2(xc) xc = self.relu(xc) xc = self.dropout(xc) + + if self.use_attn: + pass out = self.out(xc) out = nn.Sigmoid()(out) @@ -347,23 +396,39 @@ def __init__( # graph layers self.gcn1 = GATv2Conv( - num_features_xd, num_features_xd, heads=25, dropout=dropout, edge_dim=4, add_self_loops=False + num_features_xd, num_features_xd, heads=25, dropout=dropout, edge_dim=7, add_self_loops=False ) - self.gcn2 = GATv2Conv(num_features_xd * 25, output_dim, dropout=dropout, edge_dim=4, add_self_loops=False) + self.gcn2 = GATv2Conv(num_features_xd * 25, output_dim, dropout=dropout, edge_dim=7, add_self_loops=False) self.fc_g1 = nn.Linear(output_dim, output_dim) + if num_features_xt < 50: + k = 3 + p = 2 + else: + k = 8 + p = 3 + # cell line feature - self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=8) - self.pool_xt_1 = nn.MaxPool1d(3) - self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=8) - self.pool_xt_2 = nn.MaxPool1d(3) - self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=8) - self.pool_xt_3 = nn.MaxPool1d(3) + self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=k) + self.pool_xt_1 = nn.MaxPool1d(p) + self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=k) + self.pool_xt_2 = nn.MaxPool1d(p) + self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=k) + self.pool_xt_3 = nn.MaxPool1d(p) # self.fc1_xt = nn.Linear(2944, output_dim) # self.fc1_xt = nn.Linear(4224, output_dim) # self.fc1_xt = nn.Linear(61824, output_dim) self.fc1_xt = nn.Linear(4096, output_dim) + with torch.no_grad(): + dummy = torch.zeros(1, 1, num_features_xt) + conv_xt = self.pool_xt_1(self.conv_xt_1(dummy)) + conv_xt = self.pool_xt_2(self.conv_xt_2(conv_xt)) + conv_xt = self.pool_xt_3(self.conv_xt_3(conv_xt)) + flat_dim = conv_xt.shape[1] * conv_xt.shape[2] + + self.fc1_xt = nn.Linear(flat_dim, output_dim) + # combined layers if self.use_attn: self.cross_attn1 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) @@ -396,6 +461,8 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_ # x, edge_index, batch, edge_feat = data.x, data.edge_index, data.batch, data.edge_features # print(data.x.shape) # print(edge_feat.shape) + if edge_feat is not None: + pass # x = f.dropout(x, p=0.2, training=self.training) # x = self.dropout(x) @@ -417,6 +484,12 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_ # target = data.target # x_cell_mut = x_cell_mut[:,None,:] # 1d conv layers + + MIN_CNN_INPUT = 22 + if x_cell_mut.shape[-1] < MIN_CNN_INPUT: + pad = MIN_CNN_INPUT - x_cell_mut.shape[-1] + x_cell_mut = torch.nn.functional.pad(x_cell_mut, (0, pad)) + conv_xt = self.conv_xt_1(x_cell_mut) conv_xt = f.relu(conv_xt) conv_xt = self.pool_xt_1(conv_xt) @@ -454,6 +527,9 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_ xc = self.fc2(xc) xc = self.relu(xc) xc = self.dropout(xc) + + if self.use_attn: + pass out = self.out(xc) out = nn.Sigmoid()(out) @@ -463,7 +539,7 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_ return out -class GATNet_E(torch.nn.Module): +class GATNetE(torch.nn.Module): """A GAT variant explicitly incorporating edge attributes.""" def __init__( @@ -478,7 +554,7 @@ def __init__( use_attn=False, ): """ - Initialize the GATNet_E model. + Initialize the GATNetE model. :param num_features_xd: Number of molecular graph node features :param n_output: Number of output units @@ -493,21 +569,32 @@ def __init__( self.use_attn = use_attn # graph layers - self.gcn1 = GATConv(num_features_xd, num_features_xd, heads=10, dropout=dropout, edge_dim=4) - self.gcn2 = GATConv(num_features_xd * 10, output_dim, dropout=dropout, edge_dim=4) + self.gcn1 = GATConv(num_features_xd, num_features_xd, heads=10, dropout=dropout, edge_dim=7) + self.gcn2 = GATConv(num_features_xd * 10, output_dim, dropout=dropout, edge_dim=7) self.fc_g1 = nn.Linear(output_dim, output_dim) - # cell line feature - self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=8) - self.pool_xt_1 = nn.MaxPool1d(3) - self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=8) - self.pool_xt_2 = nn.MaxPool1d(3) - self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=8) - self.pool_xt_3 = nn.MaxPool1d(3) - # self.fc1_xt = nn.Linear(2944, output_dim) - # self.fc1_xt = nn.Linear(4224, output_dim) - # self.fc1_xt = nn.Linear(61824, output_dim) - self.fc1_xt = nn.Linear(4096, output_dim) + if num_features_xt < 50: + k = 3 + p = 2 + else: + k = 8 + p = 3 + + self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=k) + self.pool_xt_1 = nn.MaxPool1d(p) + self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=k) + self.pool_xt_2 = nn.MaxPool1d(p) + self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=k) + self.pool_xt_3 = nn.MaxPool1d(p) + + with torch.no_grad(): + dummy = torch.zeros(1, 1, num_features_xt) + conv_xt = self.pool_xt_1(self.conv_xt_1(dummy)) + conv_xt = self.pool_xt_2(self.conv_xt_2(conv_xt)) + conv_xt = self.pool_xt_3(self.conv_xt_3(conv_xt)) + flat_dim = conv_xt.shape[1] * conv_xt.shape[2] + + self.fc1_xt = nn.Linear(flat_dim, output_dim) if self.use_attn: self.cross_attn1 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) @@ -540,6 +627,8 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_ # graph input feed-forward # x, edge_index, batch, edge_feat = data.x, data.edge_index, data.batch, data.edge_features # print(data.x.shape) + if edge_feat is not None: + pass # x = f.dropout(x, p=0.2, training=self.training) # x = self.dropout(x) @@ -561,6 +650,12 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_ # target = data.target # x_cell_mut = x_cell_mut[:,None,:] # 1d conv layers + + MIN_CNN_INPUT = 22 + if x_cell_mut.shape[-1] < MIN_CNN_INPUT: + pad = MIN_CNN_INPUT - x_cell_mut.shape[-1] + x_cell_mut = torch.nn.functional.pad(x_cell_mut, (0, pad)) + conv_xt = self.conv_xt_1(x_cell_mut) conv_xt = f.relu(conv_xt) conv_xt = self.pool_xt_1(conv_xt) @@ -598,6 +693,9 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_ xc = self.fc2(xc) xc = self.relu(xc) xc = self.dropout(xc) + + if self.use_attn: + pass out = self.out(xc) out = nn.Sigmoid()(out) @@ -619,6 +717,7 @@ def __init__( num_features_xt=25, output_dim=128, dropout=0.5, + use_attn=False, ): """ Initialize the SAGENet model. @@ -632,7 +731,7 @@ def __init__( :param dropout: Dropout probability """ super().__init__() - + self.use_attn = use_attn # SMILES graph branch # GCNSAGE @@ -645,16 +744,28 @@ def __init__( self.relu = nn.ReLU() self.dropout = nn.Dropout(dropout) - # cell line feature - self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=8) - self.pool_xt_1 = nn.MaxPool1d(3) - self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=8) - self.pool_xt_2 = nn.MaxPool1d(3) - self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=8) - self.pool_xt_3 = nn.MaxPool1d(3) - # self.fc1_xt = nn.Linear(2944, output_dim) - # self.fc1_xt = nn.Linear(4224, output_dim) - self.fc1_xt = nn.Linear(61824, output_dim) + if num_features_xt < 50: + k = 3 + p = 2 + else: + k = 8 + p = 3 + + self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=k) + self.pool_xt_1 = nn.MaxPool1d(p) + self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=k) + self.pool_xt_2 = nn.MaxPool1d(p) + self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=k) + self.pool_xt_3 = nn.MaxPool1d(p) + + with torch.no_grad(): + dummy = torch.zeros(1, 1, num_features_xt) + conv_xt = self.pool_xt_1(self.conv_xt_1(dummy)) + conv_xt = self.pool_xt_2(self.conv_xt_2(conv_xt)) + conv_xt = self.pool_xt_3(self.conv_xt_3(conv_xt)) + flat_dim = conv_xt.shape[1] * conv_xt.shape[2] + + self.fc1_xt = nn.Linear(flat_dim, output_dim) # combined layers self.fc1 = nn.Linear(2 * output_dim, 1024) @@ -665,7 +776,7 @@ def __init__( self.relu = nn.ReLU() self.dropout = nn.Dropout(0.5) - def forward(self, x, edge_index, batch, x_cell_mut, edge_feat): + def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_weights=False): """ Forward pass of the SAGENet model. @@ -679,6 +790,14 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat): # get graph input # x, edge_index, batch = data.x, data.edge_index, data.batch + if edge_feat is not None: + pass + + MIN_CNN_INPUT = 22 + if x_cell_mut.shape[-1] < MIN_CNN_INPUT: + pad = MIN_CNN_INPUT - x_cell_mut.shape[-1] + x_cell_mut = torch.nn.functional.pad(x_cell_mut, (0, pad)) + # GCNSAGE x = self.conv1(x, edge_index) x = self.relu(x) @@ -718,6 +837,9 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat): xc = self.fc2(xc) xc = self.relu(xc) xc = self.dropout(xc) + + if self.use_attn: + pass out = self.out(xc) out = nn.Sigmoid()(out) return out @@ -735,6 +857,7 @@ def __init__( embed_dim=128, output_dim=128, dropout=0.5, + use_attn=False, ): """ Initialize the GINNet model. @@ -748,6 +871,7 @@ def __init__( :param dropout: Dropout probability """ super().__init__() + self.use_attn = use_attn dim = 32 self.dropout = nn.Dropout(dropout) @@ -780,16 +904,28 @@ def __init__( self.embedding_xt = nn.Embedding(num_features_xt + 1, embed_dim) self.conv_xt_1 = nn.Conv1d(in_channels=1000, out_channels=n_filters, kernel_size=8) - # cell line feature - self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=8) - self.pool_xt_1 = nn.MaxPool1d(3) - self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=8) - self.pool_xt_2 = nn.MaxPool1d(3) - self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=8) - self.pool_xt_3 = nn.MaxPool1d(3) - # self.fc1_xt = nn.Linear(2944, output_dim) - # self.fc1_xt = nn.Linear(4224, output_dim) - self.fc1_xt = nn.Linear(61824, output_dim) + if num_features_xt < 50: + k = 3 + p = 2 + else: + k = 8 + p = 3 + + self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=k) + self.pool_xt_1 = nn.MaxPool1d(p) + self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=k) + self.pool_xt_2 = nn.MaxPool1d(p) + self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=k) + self.pool_xt_3 = nn.MaxPool1d(p) + + with torch.no_grad(): + dummy = torch.zeros(1, 1, num_features_xt) + conv_xt = self.pool_xt_1(self.conv_xt_1(dummy)) + conv_xt = self.pool_xt_2(self.conv_xt_2(conv_xt)) + conv_xt = self.pool_xt_3(self.conv_xt_3(conv_xt)) + flat_dim = conv_xt.shape[1] * conv_xt.shape[2] + + self.fc1_xt = nn.Linear(flat_dim, output_dim) # combined layers self.fc1 = nn.Linear(2 * output_dim, 1024) @@ -800,7 +936,7 @@ def __init__( self.relu = nn.ReLU() self.dropout = nn.Dropout(0.5) - def forward(self, x, edge_index, batch, x_cell_mut, edge_feat): + def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_weights=False): """ Forward pass of the GINNet model. @@ -811,6 +947,14 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat): :param edge_feat: Edge features (unused for GINConv) :returns: Predicted drug response """ + if edge_feat is not None: + pass + + MIN_CNN_INPUT = 22 + if x_cell_mut.shape[-1] < MIN_CNN_INPUT: + pad = MIN_CNN_INPUT - x_cell_mut.shape[-1] + x_cell_mut = torch.nn.functional.pad(x_cell_mut, (0, pad)) + # x, edge_index, batch = data.x, data.edge_index, data.batch # print(x) # print(data.target) @@ -857,6 +1001,9 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat): xc = self.fc2(xc) xc = self.relu(xc) xc = self.dropout(xc) + + if self.use_attn: + pass out = self.out(xc) out = nn.Sigmoid()(out) return out @@ -874,6 +1021,7 @@ def __init__( embed_dim=128, output_dim=128, dropout=0.5, + use_attn=False, ): """ Initialize the GINENet model. @@ -887,6 +1035,7 @@ def __init__( :param dropout: Dropout probability """ super().__init__() + self.use_attn = use_attn dim = 32 self.dropout = nn.Dropout(dropout) @@ -894,23 +1043,23 @@ def __init__( self.n_output = n_output # convolution layers nn1 = Sequential(Linear(num_features_xd, dim), ReLU(), Linear(dim, dim)) - self.conv1 = GINEConv(nn1, edge_dim=4) + self.conv1 = GINEConv(nn1, edge_dim=7) self.bn1 = torch.nn.BatchNorm1d(dim) nn2 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim)) - self.conv2 = GINEConv(nn2, edge_dim=4) + self.conv2 = GINEConv(nn2, edge_dim=7) self.bn2 = torch.nn.BatchNorm1d(dim) nn3 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim)) - self.conv3 = GINEConv(nn3, edge_dim=4) + self.conv3 = GINEConv(nn3, edge_dim=7) self.bn3 = torch.nn.BatchNorm1d(dim) nn4 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim)) - self.conv4 = GINEConv(nn4, edge_dim=4) + self.conv4 = GINEConv(nn4, edge_dim=7) self.bn4 = torch.nn.BatchNorm1d(dim) nn5 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim)) - self.conv5 = GINEConv(nn5, edge_dim=4) + self.conv5 = GINEConv(nn5, edge_dim=7) self.bn5 = torch.nn.BatchNorm1d(dim) self.fc1_xd = Linear(dim, output_dim) @@ -920,15 +1069,28 @@ def __init__( self.conv_xt_1 = nn.Conv1d(in_channels=1000, out_channels=n_filters, kernel_size=8) # cell line feature - self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=8) - self.pool_xt_1 = nn.MaxPool1d(3) - self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=8) - self.pool_xt_2 = nn.MaxPool1d(3) - self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=8) - self.pool_xt_3 = nn.MaxPool1d(3) - # self.fc1_xt = nn.Linear(2944, output_dim) - # self.fc1_xt = nn.Linear(4224, output_dim) - self.fc1_xt = nn.Linear(61824, output_dim) + if num_features_xt < 50: + k = 3 + p = 2 + else: + k = 8 + p = 3 + + self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=k) + self.pool_xt_1 = nn.MaxPool1d(p) + self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=k) + self.pool_xt_2 = nn.MaxPool1d(p) + self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=k) + self.pool_xt_3 = nn.MaxPool1d(p) + + with torch.no_grad(): + dummy = torch.zeros(1, 1, num_features_xt) + conv_xt = self.pool_xt_1(self.conv_xt_1(dummy)) + conv_xt = self.pool_xt_2(self.conv_xt_2(conv_xt)) + conv_xt = self.pool_xt_3(self.conv_xt_3(conv_xt)) + flat_dim = conv_xt.shape[1] * conv_xt.shape[2] + + self.fc1_xt = nn.Linear(flat_dim, output_dim) # combined layers self.fc1 = nn.Linear(2 * output_dim, 1024) @@ -939,7 +1101,7 @@ def __init__( self.relu = nn.ReLU() self.dropout = nn.Dropout(0.5) - def forward(self, x, edge_index, batch, x_cell_mut, edge_feat): + def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_weights=False): """ Forward pass of the GINENet model. @@ -950,6 +1112,14 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat): :param edge_feat: Edge features of the molecular graph :returns: Predicted drug response """ + if edge_feat is not None: + pass + + MIN_CNN_INPUT = 22 + if x_cell_mut.shape[-1] < MIN_CNN_INPUT: + pad = MIN_CNN_INPUT - x_cell_mut.shape[-1] + x_cell_mut = torch.nn.functional.pad(x_cell_mut, (0, pad)) + # x, edge_index, batch = data.x, data.edge_index, data.batch # print(x) # print(data.target) @@ -996,6 +1166,9 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat): xc = self.fc2(xc) xc = self.relu(xc) xc = self.dropout(xc) + + if self.use_attn: + pass out = self.out(xc) out = nn.Sigmoid()(out) return out @@ -1041,19 +1214,32 @@ def __init__( self.dropout = nn.Dropout(dropout) # cell line feature - self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=8) - self.pool_xt_1 = nn.MaxPool1d(3) - self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=8) - self.pool_xt_2 = nn.MaxPool1d(3) - self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=8) - self.pool_xt_3 = nn.MaxPool1d(3) - # self.fc1_xt = nn.Linear(2944, output_dim) - # self.fc1_xt = nn.Linear(4224, output_dim) - self.fc1_xt = nn.Linear(4096, output_dim) + if num_features_xt < 50: + k = 3 + p = 2 + else: + k = 8 + p = 3 + + self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=k) + self.pool_xt_1 = nn.MaxPool1d(p) + self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=k) + self.pool_xt_2 = nn.MaxPool1d(p) + self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=k) + self.pool_xt_3 = nn.MaxPool1d(p) + + with torch.no_grad(): + dummy = torch.zeros(1, 1, num_features_xt) + conv_xt = self.pool_xt_1(self.conv_xt_1(dummy)) + conv_xt = self.pool_xt_2(self.conv_xt_2(conv_xt)) + conv_xt = self.pool_xt_3(self.conv_xt_3(conv_xt)) + flat_dim = conv_xt.shape[1] * conv_xt.shape[2] + + self.fc1_xt = nn.Linear(flat_dim, output_dim) if self.use_attn: - self.cross_attn1 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) - self.cross_attn2 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) + self.cross_attn1 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout, batch_first=True) + self.cross_attn2 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout, batch_first=True) self.norm1 = nn.LayerNorm(output_dim) self.norm2 = nn.LayerNorm(output_dim) self.fc = nn.Linear(2 * output_dim, output_dim) @@ -1063,7 +1249,7 @@ def __init__( self.fc2 = nn.Linear(1024, 128) self.out = nn.Linear(128, self.n_output) - def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, edge_weight=None): + def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, edge_weight=None, return_attention_weights=False): """ Forward pass of the RGCNNet model. @@ -1078,9 +1264,17 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, edge_weight=None) # get graph input # edge_weight is only used for decoding + MIN_CNN_INPUT = 22 + if x_cell_mut.shape[-1] < MIN_CNN_INPUT: + pad = MIN_CNN_INPUT - x_cell_mut.shape[-1] + x_cell_mut = torch.nn.functional.pad(x_cell_mut, (0, pad)) + + if edge_feat is not None: + edge_feat = edge_feat.long().view(-1) + # x, edge_index, batch = data.x, data.edge_index, data.batch # edge_index = edge_index.long() - edge_feat = edge_feat.long().squeeze() + # edge_feat = edge_feat.long().squeeze() x = self.conv1(x, edge_index, edge_type=edge_feat) x = self.relu(x) @@ -1102,6 +1296,8 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, edge_weight=None) # add this line for CNV data, remove for gene expr data # x_cell_mut = x_cell_mut[:,None,:] + if x_cell_mut.dim() == 2: + x_cell_mut = x_cell_mut.unsqueeze(1) # 1d conv layers conv_xt = self.conv_xt_1(x_cell_mut) @@ -1141,6 +1337,9 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, edge_weight=None) xc = self.fc2(xc) xc = self.relu(xc) xc = self.dropout(xc) + + if self.use_attn: + pass out = self.out(xc) out = nn.Sigmoid()(out) return out @@ -1190,15 +1389,28 @@ def __init__( self.fc_g1 = nn.Linear(output_dim, output_dim) # cell line feature - self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=8) - self.pool_xt_1 = nn.MaxPool1d(3) - self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=8) - self.pool_xt_2 = nn.MaxPool1d(3) - self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=8) - self.pool_xt_3 = nn.MaxPool1d(3) - # self.fc1_xt = nn.Linear(2944, output_dim) - # self.fc1_xt = nn.Linear(4224, output_dim) - self.fc1_xt = nn.Linear(4096, output_dim) + if num_features_xt < 50: + k = 3 + p = 2 + else: + k = 8 + p = 3 + + self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=k) + self.pool_xt_1 = nn.MaxPool1d(p) + self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=k) + self.pool_xt_2 = nn.MaxPool1d(p) + self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=k) + self.pool_xt_3 = nn.MaxPool1d(p) + + with torch.no_grad(): + dummy = torch.zeros(1, 1, num_features_xt) + conv_xt = self.pool_xt_1(self.conv_xt_1(dummy)) + conv_xt = self.pool_xt_2(self.conv_xt_2(conv_xt)) + conv_xt = self.pool_xt_3(self.conv_xt_3(conv_xt)) + flat_dim = conv_xt.shape[1] * conv_xt.shape[2] + + self.fc1_xt = nn.Linear(flat_dim, output_dim) if self.use_attn: self.cross_attn1 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) @@ -1231,6 +1443,8 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_ # graph input feed-forward # x, edge_index, batch, edge_feat = data.x, data.edge_index, data.batch, data.edge_features # print(data.x.shape) + if edge_feat is not None: + pass edge_feat = edge_feat.int().squeeze() # print(edge_feat) @@ -1250,6 +1464,11 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_ x = self.fc_g1(x) x = self.relu(x) + MIN_CNN_INPUT = 22 + if x_cell_mut.shape[-1] < MIN_CNN_INPUT: + pad = MIN_CNN_INPUT - x_cell_mut.shape[-1] + x_cell_mut = torch.nn.functional.pad(x_cell_mut, (0, pad)) + # protein input feed-forward: # target = data.target # x_cell_mut = x_cell_mut[:,None,:] @@ -1291,6 +1510,9 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_ xc = self.fc2(xc) xc = self.relu(xc) xc = self.dropout(xc) + + if self.use_attn: + pass out = self.out(xc) out = nn.Sigmoid()(out) @@ -1344,15 +1566,28 @@ def __init__( self.fc_g1 = nn.Linear(output_dim, output_dim) # cell line feature - self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=8) - self.pool_xt_1 = nn.MaxPool1d(3) - self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=8) - self.pool_xt_2 = nn.MaxPool1d(3) - self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=8) - self.pool_xt_3 = nn.MaxPool1d(3) - # self.fc1_xt = nn.Linear(2944, output_dim) - # self.fc1_xt = nn.Linear(4224, output_dim) - self.fc1_xt = nn.Linear(4096, output_dim) + if num_features_xt < 50: + k = 3 + p = 2 + else: + k = 8 + p = 3 + + self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=k) + self.pool_xt_1 = nn.MaxPool1d(p) + self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=k) + self.pool_xt_2 = nn.MaxPool1d(p) + self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=k) + self.pool_xt_3 = nn.MaxPool1d(p) + + with torch.no_grad(): + dummy = torch.zeros(1, 1, num_features_xt) + conv_xt = self.pool_xt_1(self.conv_xt_1(dummy)) + conv_xt = self.pool_xt_2(self.conv_xt_2(conv_xt)) + conv_xt = self.pool_xt_3(self.conv_xt_3(conv_xt)) + flat_dim = conv_xt.shape[1] * conv_xt.shape[2] + + self.fc1_xt = nn.Linear(flat_dim, output_dim) if self.use_attn: self.cross_attn1 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) @@ -1385,6 +1620,8 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_ # graph input feed-forward # x, edge_index, batch, edge_feat = data.x, data.edge_index, data.batch, data.edge_features # print(data.x.shape) + if edge_feat is not None: + pass edge_feat = edge_feat.int().squeeze() # x = f.dropout(x, p=0.2, training=self.training) @@ -1403,6 +1640,11 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_ x = self.fc_g1(x) x = self.relu(x) + MIN_CNN_INPUT = 22 + if x_cell_mut.shape[-1] < MIN_CNN_INPUT: + pad = MIN_CNN_INPUT - x_cell_mut.shape[-1] + x_cell_mut = torch.nn.functional.pad(x_cell_mut, (0, pad)) + # protein input feed-forward: # target = data.target # x_cell_mut = x_cell_mut[:,None,:] @@ -1444,6 +1686,9 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_ xc = self.fc2(xc) xc = self.relu(xc) xc = self.dropout(xc) + + if self.use_attn: + pass out = self.out(xc) out = nn.Sigmoid()(out) @@ -1483,21 +1728,34 @@ def __init__( self.use_attn = use_attn # graph layers - self.gcn1 = FiLMConv(num_features_xd, num_features_xd, num_relations=4, act=nn.LeakyReLU()) - self.gcn2 = FiLMConv(num_features_xd, output_dim, num_relations=4, act=nn.LeakyReLU()) + self.gcn1 = FiLMConv(num_features_xd, num_features_xd, num_relations=4, act=nn.LeakyReLU(), edge_dim=7) + self.gcn2 = FiLMConv(num_features_xd, output_dim, num_relations=4, act=nn.LeakyReLU(), edge_dim=7) self.fc_g1 = nn.Linear(output_dim, output_dim) # cell line feature - self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=8) - self.pool_xt_1 = nn.MaxPool1d(3) - self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=8) - self.pool_xt_2 = nn.MaxPool1d(3) - self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=8) - self.pool_xt_3 = nn.MaxPool1d(3) - # self.fc1_xt = nn.Linear(2944, output_dim) - # self.fc1_xt = nn.Linear(4224, output_dim) - self.fc1_xt = nn.Linear(4096, output_dim) + if num_features_xt < 50: + k = 3 + p = 2 + else: + k = 8 + p = 3 + + self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=k) + self.pool_xt_1 = nn.MaxPool1d(p) + self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=k) + self.pool_xt_2 = nn.MaxPool1d(p) + self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=k) + self.pool_xt_3 = nn.MaxPool1d(p) + + with torch.no_grad(): + dummy = torch.zeros(1, 1, num_features_xt) + conv_xt = self.pool_xt_1(self.conv_xt_1(dummy)) + conv_xt = self.pool_xt_2(self.conv_xt_2(conv_xt)) + conv_xt = self.pool_xt_3(self.conv_xt_3(conv_xt)) + flat_dim = conv_xt.shape[1] * conv_xt.shape[2] + + self.fc1_xt = nn.Linear(flat_dim, output_dim) if self.use_attn: self.cross_attn1 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) @@ -1530,6 +1788,8 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_ # graph input feed-forward # x, edge_index, batch, edge_feat = data.x, data.edge_index, data.batch, data.edge_features # print(data.x.shape) + if edge_feat is not None: + pass edge_feat = edge_feat.int().squeeze() # x = f.dropout(x, p=0.2, training=self.training) @@ -1543,6 +1803,11 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_ x = self.fc_g1(x) x = self.relu(x) + MIN_CNN_INPUT = 22 + if x_cell_mut.shape[-1] < MIN_CNN_INPUT: + pad = MIN_CNN_INPUT - x_cell_mut.shape[-1] + x_cell_mut = torch.nn.functional.pad(x_cell_mut, (0, pad)) + # protein input feed-forward: # target = data.target # x_cell_mut = x_cell_mut[:,None,:] @@ -1584,6 +1849,9 @@ def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_ xc = self.fc2(xc) xc = self.relu(xc) xc = self.dropout(xc) + + if self.use_attn: + pass out = self.out(xc) out = nn.Sigmoid()(out) diff --git a/drevalpy/models/XGDP/hyperparameters.yaml b/drevalpy/models/XGDP/hyperparameters.yaml index 78a28841..f1c797ca 100644 --- a/drevalpy/models/XGDP/hyperparameters.yaml +++ b/drevalpy/models/XGDP/hyperparameters.yaml @@ -3,7 +3,7 @@ XGDP: - "GATNet" - "GCNNet" - "GATv2Net" - - "GATNet_E" + - "GATNetE" - "SAGENet" - "GINNet" - "GINENet" diff --git a/drevalpy/models/XGDP/utils.py b/drevalpy/models/XGDP/utils.py deleted file mode 100644 index 3d198619..00000000 --- a/drevalpy/models/XGDP/utils.py +++ /dev/null @@ -1,81 +0,0 @@ -import torch -import torch.nn as nn -from torch.utils.data.dataset import Subset -import time -from sklearn.model_selection import KFold -from tqdm import tqdm -import pandas as pd - -def rmse(y, f): - rmse = sqrt(((y - f)**2).mean(axis=0)) - return rmse - - -def mse(y, f): - mse = ((y - f)**2).mean(axis=0) - return mse - - -def pearson(y, f): - rp = np.corrcoef(y, f)[0, 1] - return rp - - -def spearman(y, f): - rs = stats.spearmanr(y, f)[0] - return rs - - -def coeffi_determ(y, f): - r2 = r2_score(y, f) - return r2 - -def predicting(model, device, loader, return_attention_weights = False): - model.eval() - total_preds = torch.Tensor() - total_labels = torch.Tensor() - print('Make prediction for {} samples...'.format(len(loader.dataset))) - with torch.no_grad(): - for data in loader: - data = data.to(device) - - # output, _ = model(data) - x, x_cell_mut, edge_index, batch_drug, edge_feat = data.x, data.target, data.edge_index.long(), data.batch, data.edge_features - if return_attention_weights: - # output, _, attn_weights = model(x, edge_index, x_cell_mut, batch_drug, edge_feat, return_attention_weights) - output, attn_weights = model(x, edge_index, batch_drug, x_cell_mut, edge_feat, return_attention_weights) - attn_weights = [attn_weight.cpu().numpy() for attn_weight in attn_weights] - # print(attn_weights) - attn_weights = np.array(attn_weights) - # print(attn_weights.shape) - else: - # output, _ = model(x, edge_index, x_cell_mut, batch_drug, edge_feat) - output = model(x, edge_index, batch_drug, x_cell_mut, edge_feat) - - total_preds = torch.cat((total_preds, output.cpu()), 0) - total_labels = torch.cat((total_labels, data.y.view(-1, 1).cpu()), 0) - torch.cuda.empty_cache() ## no grad - if return_attention_weights: - return total_labels.numpy().flatten(), total_preds.numpy().flatten(), attn_weights - else: - return total_labels.numpy().flatten(), total_preds.numpy().flatten() - -# training function at each epoch -def train(model, device, train_loader, optimizer, epoch, log_interval, return_attention_weights=False): - print('Training on {} samples...'.format(len(train_loader.dataset))) - model.train() - loss_fn = nn.MSELoss() - avg_loss = [] - for data in tqdm(train_loader): - data = data.to(device) - optimizer.zero_grad() - - x, x_cell_mut, edge_index, batch_drug, edge_feat = data.x, data.target, data.edge_index.long(), data.batch, data.edge_features - - output = model(x, edge_index, batch_drug, x_cell_mut, edge_feat) - - loss = loss_fn(output, data.y.view(-1, 1).float().to(device)) - loss.backward() - optimizer.step() - avg_loss.append(loss.item()) - return sum(avg_loss)/len(avg_loss) diff --git a/drevalpy/models/XGDP/xgdp.py b/drevalpy/models/XGDP/xgdp.py index 61ab86dc..391fa83d 100644 --- a/drevalpy/models/XGDP/xgdp.py +++ b/drevalpy/models/XGDP/xgdp.py @@ -5,7 +5,7 @@ Original authors: Wang, C., Kumar, G.A. & Rajapakse, J.C. (2025, 10.1038/s41598-024-83090-3) Code adapted from their Github: https://github.com/SCSE-Biomedical-Computing-Group/XGDP/blob/main/utils_tcnn.py """ -import drevalpy.models.XGDP._models as m + from pathlib import Path from typing import Any @@ -18,15 +18,14 @@ from torch.utils.data import Dataset as PytorchDataset from torch_geometric.loader import DataLoader +import drevalpy.models.XGDP._models as Models + from ...datasets.dataset import DrugResponseDataset, FeatureDataset from ..drp_model import DRPModel from ..lightning_metrics_mixin import RegressionMetricsMixin from ..utils import load_and_select_gene_features - - - class _XGDPDataset(PytorchDataset): """A PyTorch Dataset for XGDP.""" @@ -85,50 +84,51 @@ def __init__( num_cell_features: int, do_att: bool, hidden_dim: int = 128, - dropout: float = 0.5, #changed to 0.5 as per there default settings for there models + dropout: float = 0.5, # changed to 0.5 as per there default settings for there models learning_rate: float = 0.001, ): """Initialize the LightningModule. + :param model: Name of the XGDP backbone to use (e.g., 'gat', 'gcn', 'gatv2'). + :param do_att: Whether to enable cross-attention between drug and cell features. :param num_node_features: Number of features for each node in the drug graph. :param num_cell_features: Number of features for the cell line. :param hidden_dim: The hidden dimension size. :param dropout: The dropout rate. :param learning_rate: The learning rate. + :raises ValueError: If drug_input is not provided. """ super().__init__() self.save_hyperparameters() self - #model_name = model - model_name = "GATNet" + # model_name = model + model_name = model try: - model_class = getattr(m, model_name) + model_class = getattr(Models, model_name) print(type(model_class)) except AttributeError: # Specifically catch the error if the string name doesn't exist in 'models' raise ValueError(f"Model '{model_name}' not found in the list of available models.") - - if "GAT" in model_name: - self.return_attention_weights = False # because no need to return attention weight because no models explanation through XAI - else: - self.return_attention_weights = False - #print("----dimensions-------") - #print(hidden_dim) + + self.return_attention_weights = ( + do_att if "GAT" in model_name else False + ) # because no need to return attention weight because no models explanation through XAI + # print("----dimensions-------") + # print(hidden_dim) self.model = model_class( - n_output = 1, - num_features_xd = num_node_features, #gnn number of node features (ECFP& + DeepChem: 334) - num_features_xt=25, #only used in GINNet in embedding, mutation/protein data??? - n_filters=32, #number of filters for cnn for gene expression - embed_dim=128, #only used in GINNet in embedding - output_dim=hidden_dim, #size of latend sapce as output of gnn and cnn -> determines size of shared feature size after combination - dropout= dropout, - use_attn=do_att #not in GINNet, SAGENet, add to models to have same input - + n_output=1, + num_features_xd=num_node_features, # gnn number of node features (ECFP& + DeepChem: 334) + num_features_xt=25, # only used in GINNet in embedding, mutation/protein data??? + n_filters=32, # number of filters for cnn for gene expression + embed_dim=128, # only used in GINNet in embedding + output_dim=hidden_dim, # size latend sapce as output of gnn + cnn -> size of shared size after combination + dropout=dropout, + use_attn=do_att, # not in GINNet, SAGENet, add to models to have same input ) self.criterion = nn.MSELoss() # Initialize metrics storage for epoch-end R^2 and PCC computation - #self._init_metrics_storage() + # self._init_metrics_storage() def forward(self, batch): """Forward pass of the module. @@ -138,26 +138,28 @@ def forward(self, batch): """ drug_graph, cell_features, _ = batch if self.return_attention_weights: - #print("x", drug_graph.x.shape, type(drug_graph.x)) - #print("edge_index", drug_graph.edge.shape,type(drug_graph.edge)) - #print("edge_index", drug_graph.edge.shape,type(drug_graph.edge)) + # print("x", drug_graph.x.shape, type(drug_graph.x)) + # print("edge_index", drug_graph.edge.shape,type(drug_graph.edge)) + # print("edge_index", drug_graph.edge.shape,type(drug_graph.edge)) - return self.model.forward( + out, _ = self.model.forward( x=drug_graph.x, edge_index=drug_graph.edge_index, batch=drug_graph.batch, x_cell_mut=cell_features, edge_feat=getattr(drug_graph, "edge_attr", None), - return_attention_weights = True, + return_attention_weights=True, ) + return out else: - return self.model.forward( - x=drug_graph.x, - edge_index=drug_graph.edge_index, - batch=drug_graph.batch, - x_cell_mut=cell_features, - edge_feat=getattr(drug_graph, "edge_attr", None), + out = self.model.forward( + x=drug_graph.x, + edge_index=drug_graph.edge_index, + batch=drug_graph.batch, + x_cell_mut=cell_features, + edge_feat=getattr(drug_graph, "edge_attr", None), ) + return out def training_step(self, batch, batch_idx): """A single training step. @@ -172,7 +174,7 @@ def training_step(self, batch, batch_idx): self.log("train_loss", loss, on_step=False, on_epoch=True, batch_size=responses.size(0)) # Store predictions and targets for epoch-end metrics via mixin - #self._store_predictions(outputs, responses, is_training=True) + # self._store_predictions(outputs, responses, is_training=True) return loss @@ -188,7 +190,8 @@ def validation_step(self, batch, batch_idx): self.log("val_loss", loss, on_step=False, on_epoch=True, batch_size=responses.size(0)) # Store predictions and targets for epoch-end metrics via mixin - self._store_predictions(outputs, responses, is_training=False) + # self._store_predictions(outputs, responses, is_training=False) + pass def predict_step(self, batch, batch_idx, dataloader_idx=0): """A single prediction step. @@ -241,7 +244,6 @@ def build_model(self, hyperparameters: dict[str, Any]) -> None: :param hyperparameters: TODO: ADD HYPERPARAMETERS """ self.hyperparameters = hyperparameters - model_name = hyperparameters.get("model_type", "GATNet") # init in train # self.model = XGDPPredictor(name_hyperparameter=hyperparameter["name_hyperparameter"]) """ @@ -261,7 +263,7 @@ def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureD """ return load_and_select_gene_features( feature_type="gene_expression", - gene_list="landmark_genes", + gene_list=None, data_path=data_path, dataset_name=dataset_name, ) @@ -330,13 +332,13 @@ def train( num_cell_features = next(iter(cell_line_input.features.values()))["gene_expression"].shape[0] self.model = XGDPModule( + model=self.hyperparameters.get("model_type", "GATv2Net"), num_node_features=num_node_features, num_cell_features=num_cell_features, hidden_dim=self.hyperparameters.get("hidden_dim", 128), dropout=self.hyperparameters.get("dropout", 0.5), learning_rate=self.hyperparameters.get("learning_rate", 0.001), - model = self.hyperparameters.get("model", "GATv2Net"), - do_att= self.hyperparameters.get("do_att", True), + do_att=self.hyperparameters.get("do_att", True), ) train_dataset = _XGDPDataset( @@ -357,8 +359,8 @@ def train( if output_earlystopping is not None and len(output_earlystopping) > 0: val_dataset = _XGDPDataset( response=output_earlystopping.response, - cell_line_ids=output.cell_line_ids, - drug_ids=output.drug_ids, + cell_line_ids=output_earlystopping.cell_line_ids, + drug_ids=output_earlystopping.drug_ids, cell_line_features=cell_line_input, drug_features=drug_input, ) @@ -377,8 +379,8 @@ def train( loggers.append(logger) trainer = pl.Trainer( - #max_epochs=self.hyperparameters.get("epochs", 100), #changed to 10 fro testing - max_epochs=10, #changed to 10 fro testing + # max_epochs=self.hyperparameters.get("epochs", 100), #changed to 10 fro testing + max_epochs=10, # changed to 10 fro testing accelerator="auto", devices="auto", callbacks=[pl.callbacks.EarlyStopping(monitor="val_loss", mode="min", patience=5)] if val_loader else None, @@ -438,5 +440,5 @@ def predict( item for sublist in predictions_list for item in (sublist if isinstance(sublist, list) else [sublist]) ] - predictions = torch.cat(predictions_flat).cpu().numpy() + predictions = torch.cat(predictions_flat).view(-1).cpu().numpy() return predictions