diff --git a/drevalpy/models/XGDP/__init__.py b/drevalpy/models/XGDP/__init__.py new file mode 100644 index 00000000..808afa23 --- /dev/null +++ b/drevalpy/models/XGDP/__init__.py @@ -0,0 +1,5 @@ +"""A GNN and CNN based drug response prediction model.""" + +from .xgdp import XGDP + +__all__ = ["XGDP"] diff --git a/drevalpy/models/XGDP/_models.py b/drevalpy/models/XGDP/_models.py new file mode 100644 index 00000000..4d13ee66 --- /dev/null +++ b/drevalpy/models/XGDP/_models.py @@ -0,0 +1,1866 @@ +"""Models for XGDP model.""" + +import torch +import torch.nn as nn +import torch.nn.functional as f +from torch.nn import Linear, ReLU, Sequential +from torch_geometric.nn import ( + FiLMConv, + GATConv, + GATv2Conv, + GCNConv, + GINConv, + GINEConv, + RGATConv, + RGCNConv, + SAGEConv, + global_add_pool, +) +from torch_geometric.nn import global_max_pool as gmp + +""" + DeepChem feature set: 78 + ECFP4: 192 + ECFP4 + DeepChem: 270 + ECFP6: 256 + ECFP6 + DeepChem: 334 +""" + + +class GCNNet(torch.nn.Module): + """Standard graph convolutions to capture structural SMILES information.""" + + def __init__( + self, + n_output=1, + n_filters=32, + embed_dim=128, + num_features_xd=334, + num_features_xt=25, + output_dim=128, + dropout=0.5, + use_attn=False, + ): + """ + Initialization method for GCNNet. + + :param n_output: Number of output units (default: 1) + :param n_filters: Number of convolution filters for cell line CNN branch + :param embed_dim: Embedding dimension for optional embeddings + :param num_features_xd: Number of molecular graph node features + :param num_features_xt: Number of cell line features + :param output_dim: Dimensionality of the latent representation + :param dropout: Dropout probability + :param use_attn: Whether to use cross‑attention between drug and cell line features + """ + super().__init__() + self.use_attn = use_attn + + # SMILES graph branch + self.n_output = n_output + self.conv1 = GCNConv(num_features_xd, num_features_xd) + self.conv2 = GCNConv(num_features_xd, num_features_xd * 2) + self.conv3 = GCNConv(num_features_xd * 2, num_features_xd * 4) + self.fc_g1 = torch.nn.Linear(num_features_xd * 4, 1024) + self.fc_g2 = torch.nn.Linear(1024, output_dim) + self.relu = nn.ReLU() + self.dropout = nn.Dropout(dropout) + + # cell line feature + if num_features_xt < 50: + k = 3 + p = 2 + else: + k = 8 + p = 3 + + # cell line feature + self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=k) + self.pool_xt_1 = nn.MaxPool1d(p) + self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=k) + self.pool_xt_2 = nn.MaxPool1d(p) + self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=k) + self.pool_xt_3 = nn.MaxPool1d(p) + # self.fc1_xt = nn.Linear(2944, output_dim) + # self.fc1_xt = nn.Linear(4224, output_dim) + # self.fc1_xt = nn.Linear(61824, output_dim) + self.fc1_xt = nn.Linear(4096, output_dim) + + with torch.no_grad(): + dummy = torch.zeros(1, 1, num_features_xt) + conv_xt = self.pool_xt_1(self.conv_xt_1(dummy)) + conv_xt = self.pool_xt_2(self.conv_xt_2(conv_xt)) + conv_xt = self.pool_xt_3(self.conv_xt_3(conv_xt)) + flat_dim = conv_xt.shape[1] * conv_xt.shape[2] + + self.fc1_xt = nn.Linear(flat_dim, output_dim) + + # combined layers + if self.use_attn: + self.cross_attn1 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) + self.cross_attn2 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) + self.norm1 = nn.LayerNorm(output_dim) + self.norm2 = nn.LayerNorm(output_dim) + self.fc = nn.Linear(2 * output_dim, 128) + else: + self.fc1 = nn.Linear(2 * output_dim, 1024) + self.fc2 = nn.Linear(1024, 128) + self.out = nn.Linear(128, self.n_output) + + def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, edge_weight=None, return_attention_weights=False): + """ + Forward pass of the GCNNet model. + + :param x: Node feature matrix of the molecular graph + :param edge_index: Edge indices of the molecular graph + :param batch: Batch vector assigning nodes to graphs + :param x_cell_mut: Cell line omics features + :param edge_feat: Edge features (unused for GCN) + :param edge_weight: Optional edge weights + :param return_attention_weights: Whether to return attention weights + :returns: Predicted drug response + """ + # get graph input + # edge_weight is only used for decoding + + if edge_feat is not None: + pass + + # x, edge_index, batch = data.x, data.edge_index, data.batch + # edge_index = edge_index.long() + + MIN_CNN_INPUT = 22 + if x_cell_mut.shape[-1] < MIN_CNN_INPUT: + pad = MIN_CNN_INPUT - x_cell_mut.shape[-1] + x_cell_mut = torch.nn.functional.pad(x_cell_mut, (0, pad)) + + x = self.conv1(x, edge_index, edge_weight) + x = self.relu(x) + x = self.conv2(x, edge_index, edge_weight) + x = self.relu(x) + x = self.conv3(x, edge_index, edge_weight) + x = self.relu(x) + x = gmp(x, batch) # global max pooling + + # flatten + x = self.relu(self.fc_g1(x)) + x = self.dropout(x) + x = self.fc_g2(x) + x = self.dropout(x) + + # get protein input + # target = data.target + # print(x_cell_mut.shape) + + # add this line for CNV data, remove for gene expr data + # x_cell_mut = x_cell_mut[:,None,:] + + # 1d conv layers + conv_xt = self.conv_xt_1(x_cell_mut) + conv_xt = f.relu(conv_xt) + conv_xt = self.pool_xt_1(conv_xt) + conv_xt = self.conv_xt_2(conv_xt) + conv_xt = f.relu(conv_xt) + conv_xt = self.pool_xt_2(conv_xt) + conv_xt = self.conv_xt_3(conv_xt) + conv_xt = f.relu(conv_xt) + conv_xt = self.pool_xt_3(conv_xt) + + # flatten + xt = conv_xt.view(-1, conv_xt.shape[1] * conv_xt.shape[2]) + xt = self.fc1_xt(xt) + + if self.use_attn: + xc1, _ = self.cross_attn1(x, xt, xt) + xc1 = xc1 + x + xc1 = self.norm1(xc1) + xc2, _ = self.cross_attn2(xt, x, x) + xc2 = xc2 + xt + xc2 = self.norm2(xc2) + xc = torch.cat((xc1, xc2), 1) + xc = self.relu(xc) + xc = self.dropout(xc) + xc = self.fc(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + else: + # concat + xc = torch.cat((x, xt), 1) + # add some dense layers + xc = self.fc1(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + xc = self.fc2(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + + if self.use_attn: + pass + out = self.out(xc) + out = nn.Sigmoid()(out) + return out + + +class GATNet(torch.nn.Module): + """Uses attention to weigh importance of neighboring nodes.""" + + def __init__( + self, + num_features_xd=334, + n_output=1, + num_features_xt=25, + n_filters=32, + embed_dim=128, + output_dim=128, + dropout=0.5, + use_attn=False, + ): + """ + Initialize the GATNet model. + + :param num_features_xd: Number of molecular graph node features + :param n_output: Number of output units + :param num_features_xt: Number of cell line features + :param n_filters: Number of convolution filters for cell line CNN branch + :param embed_dim: Embedding dimension for optional embeddings + :param output_dim: Dimensionality of the latent representation + :param dropout: Dropout probability + :param use_attn: Whether to use cross‑attention between drug and cell line features + """ + super().__init__() + self.use_attn = use_attn + + # graph layers + self.gcn1 = GATConv(num_features_xd, num_features_xd, heads=10, dropout=dropout) + self.gcn2 = GATConv(num_features_xd * 10, output_dim, dropout=dropout) + self.fc_g1 = nn.Linear(output_dim, output_dim) + + if num_features_xt < 50: + k = 3 + p = 2 + else: + k = 8 + p = 3 + + self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=k) + self.pool_xt_1 = nn.MaxPool1d(p) + self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=k) + self.pool_xt_2 = nn.MaxPool1d(p) + self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=k) + self.pool_xt_3 = nn.MaxPool1d(p) + + with torch.no_grad(): + dummy = torch.zeros(1, 1, num_features_xt) + conv_xt = self.pool_xt_1(self.conv_xt_1(dummy)) + conv_xt = self.pool_xt_2(self.conv_xt_2(conv_xt)) + conv_xt = self.pool_xt_3(self.conv_xt_3(conv_xt)) + flat_dim = conv_xt.shape[1] * conv_xt.shape[2] + + self.fc1_xt = nn.Linear(flat_dim, output_dim) + + # combined layers + if self.use_attn: + self.cross_attn1 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) + self.cross_attn2 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) + self.norm1 = nn.LayerNorm(output_dim) + self.norm2 = nn.LayerNorm(output_dim) + self.fc = nn.Linear(2 * output_dim, 128) + else: + self.fc1 = nn.Linear(2 * output_dim, 1024) + self.fc2 = nn.Linear(1024, 128) + self.out = nn.Linear(128, n_output) + + # activation and regularization + self.relu = nn.ReLU() + self.dropout = nn.Dropout(0.5) + + def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_weights=False): + """ + Forward pass of the GATNet model. + + :param x: Node feature matrix of the molecular graph + :param edge_index: Edge indices of the molecular graph + :param batch: Batch vector assigning nodes to graphs + :param x_cell_mut: Cell line omics features + :param edge_feat: Edge features of the molecular graph + :param return_attention_weights: Whether to return attention weights + :returns: Predicted drug response or (prediction, attention weights) + """ + # graph input feed-forward + # x, edge_index, batch = data.x, data.edge_index, data.batch + if edge_feat is not None: + pass + + # x = self.dropout(x) + # x = f.dropout(x, p=0.2, training=self.training) + x = f.elu(self.gcn1(x, edge_index)) + # x = f.dropout(x, p=0.2, training=self.training) + x = self.dropout(x) + if return_attention_weights: + x, attn_weights = self.gcn2(x, edge_index, return_attention_weights=return_attention_weights) + else: + x = self.gcn2(x, edge_index) + x = self.relu(x) + x = gmp(x, batch) # global max pooling + x = self.fc_g1(x) + x = self.relu(x) + + # protein input feed-forward: + # target = data.target + # x_cell_mut = x_cell_mut[:,None,:] + # 1d conv layers + + MIN_CNN_INPUT = 22 + if x_cell_mut.shape[-1] < MIN_CNN_INPUT: + pad = MIN_CNN_INPUT - x_cell_mut.shape[-1] + x_cell_mut = torch.nn.functional.pad(x_cell_mut, (0, pad)) + + conv_xt = self.conv_xt_1(x_cell_mut) + conv_xt = f.relu(conv_xt) + conv_xt = self.pool_xt_1(conv_xt) + conv_xt = self.conv_xt_2(conv_xt) + conv_xt = f.relu(conv_xt) + conv_xt = self.pool_xt_2(conv_xt) + conv_xt = self.conv_xt_3(conv_xt) + conv_xt = f.relu(conv_xt) + conv_xt = self.pool_xt_3(conv_xt) + + # flatten + + xt = conv_xt.view(-1, conv_xt.shape[1] * conv_xt.shape[2]) + xt = self.fc1_xt(xt) + + if self.use_attn: + xc1, _ = self.cross_attn1(x, xt, xt) + xc1 = xc1 + x + xc1 = self.norm1(xc1) + xc2, _ = self.cross_attn2(xt, x, x) + xc2 = xc2 + xt + xc2 = self.norm2(xc2) + xc = torch.cat((xc1, xc2), 1) + xc = self.relu(xc) + xc = self.dropout(xc) + xc = self.fc(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + else: + # concat + xc = torch.cat((x, xt), 1) + # add some dense layers + xc = self.fc1(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + xc = self.fc2(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + + if self.use_attn: + pass + out = self.out(xc) + out = nn.Sigmoid()(out) + + if return_attention_weights: + return out, attn_weights + else: + # return out, x + return out + + +class GATv2Net(torch.nn.Module): + """More expressive attention mechanism that supports edge features.""" + + def __init__( + self, + num_features_xd=334, + n_output=1, + num_features_xt=25, + n_filters=32, + embed_dim=128, + output_dim=128, + dropout=0.5, + use_attn=False, + ): + """ + Initialize the GATv2Net model. + + :param num_features_xd: Number of molecular graph node features + :param n_output: Number of output units + :param num_features_xt: Number of cell line features + :param n_filters: Number of convolution filters for cell line CNN branch + :param embed_dim: Embedding dimension for optional embeddings + :param output_dim: Dimensionality of the latent representation + :param dropout: Dropout probability + :param use_attn: Whether to use cross‑attention between drug and cell line features + """ + super().__init__() + self.use_attn = use_attn + + # graph layers + self.gcn1 = GATv2Conv( + num_features_xd, num_features_xd, heads=25, dropout=dropout, edge_dim=7, add_self_loops=False + ) + self.gcn2 = GATv2Conv(num_features_xd * 25, output_dim, dropout=dropout, edge_dim=7, add_self_loops=False) + self.fc_g1 = nn.Linear(output_dim, output_dim) + + if num_features_xt < 50: + k = 3 + p = 2 + else: + k = 8 + p = 3 + + # cell line feature + self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=k) + self.pool_xt_1 = nn.MaxPool1d(p) + self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=k) + self.pool_xt_2 = nn.MaxPool1d(p) + self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=k) + self.pool_xt_3 = nn.MaxPool1d(p) + # self.fc1_xt = nn.Linear(2944, output_dim) + # self.fc1_xt = nn.Linear(4224, output_dim) + # self.fc1_xt = nn.Linear(61824, output_dim) + self.fc1_xt = nn.Linear(4096, output_dim) + + with torch.no_grad(): + dummy = torch.zeros(1, 1, num_features_xt) + conv_xt = self.pool_xt_1(self.conv_xt_1(dummy)) + conv_xt = self.pool_xt_2(self.conv_xt_2(conv_xt)) + conv_xt = self.pool_xt_3(self.conv_xt_3(conv_xt)) + flat_dim = conv_xt.shape[1] * conv_xt.shape[2] + + self.fc1_xt = nn.Linear(flat_dim, output_dim) + + # combined layers + if self.use_attn: + self.cross_attn1 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) + self.cross_attn2 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) + self.norm1 = nn.LayerNorm(output_dim) + self.norm2 = nn.LayerNorm(output_dim) + self.fc = nn.Linear(2 * output_dim, 128) + else: + self.fc1 = nn.Linear(2 * output_dim, 1024) + self.fc2 = nn.Linear(1024, 128) + self.out = nn.Linear(128, n_output) + + # activation and regularization + self.relu = nn.ReLU() + self.dropout = nn.Dropout(0.5) + + def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_weights=False): + """ + Forward pass of the GATv2Net model. + + :param x: Node feature matrix of the molecular graph + :param edge_index: Edge indices of the molecular graph + :param batch: Batch vector assigning nodes to graphs + :param x_cell_mut: Cell line omics features + :param edge_feat: Edge features of the molecular graph + :param return_attention_weights: Whether to return attention weights + :returns: Predicted drug response or (prediction, attention weights) + """ + # graph input feed-forward + # x, edge_index, batch, edge_feat = data.x, data.edge_index, data.batch, data.edge_features + # print(data.x.shape) + # print(edge_feat.shape) + if edge_feat is not None: + pass + + # x = f.dropout(x, p=0.2, training=self.training) + # x = self.dropout(x) + x = f.elu(self.gcn1(x, edge_index, edge_attr=edge_feat)) + x = self.dropout(x) + # x = f.dropout(x, p=0.2, training=self.training) + if return_attention_weights: + x, attn_weights = self.gcn2( + x, edge_index, edge_attr=edge_feat, return_attention_weights=return_attention_weights + ) + else: + x = self.gcn2(x, edge_index, edge_attr=edge_feat) + x = self.relu(x) + x = gmp(x, batch) # global max pooling + x = self.fc_g1(x) + x = self.relu(x) + + # protein input feed-forward: + # target = data.target + # x_cell_mut = x_cell_mut[:,None,:] + # 1d conv layers + + MIN_CNN_INPUT = 22 + if x_cell_mut.shape[-1] < MIN_CNN_INPUT: + pad = MIN_CNN_INPUT - x_cell_mut.shape[-1] + x_cell_mut = torch.nn.functional.pad(x_cell_mut, (0, pad)) + + conv_xt = self.conv_xt_1(x_cell_mut) + conv_xt = f.relu(conv_xt) + conv_xt = self.pool_xt_1(conv_xt) + conv_xt = self.conv_xt_2(conv_xt) + conv_xt = f.relu(conv_xt) + conv_xt = self.pool_xt_2(conv_xt) + conv_xt = self.conv_xt_3(conv_xt) + conv_xt = f.relu(conv_xt) + conv_xt = self.pool_xt_3(conv_xt) + + # flatten + xt = conv_xt.view(-1, conv_xt.shape[1] * conv_xt.shape[2]) + xt = self.fc1_xt(xt) + + if self.use_attn: + xc1, _ = self.cross_attn1(x, xt, xt) + xc1 = xc1 + x + xc1 = self.norm1(xc1) + xc2, _ = self.cross_attn2(xt, x, x) + xc2 = xc2 + xt + xc2 = self.norm2(xc2) + xc = torch.cat((xc1, xc2), 1) + xc = self.relu(xc) + xc = self.dropout(xc) + xc = self.fc(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + else: + # concat + xc = torch.cat((x, xt), 1) + # add some dense layers + xc = self.fc1(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + xc = self.fc2(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + + if self.use_attn: + pass + out = self.out(xc) + out = nn.Sigmoid()(out) + + if return_attention_weights: + return out, attn_weights + else: + return out + + +class GATNetE(torch.nn.Module): + """A GAT variant explicitly incorporating edge attributes.""" + + def __init__( + self, + num_features_xd=334, + n_output=1, + num_features_xt=25, + n_filters=32, + embed_dim=128, + output_dim=128, + dropout=0.5, + use_attn=False, + ): + """ + Initialize the GATNetE model. + + :param num_features_xd: Number of molecular graph node features + :param n_output: Number of output units + :param num_features_xt: Number of cell line features + :param n_filters: Number of convolution filters for cell line CNN branch + :param embed_dim: Embedding dimension for optional embeddings + :param output_dim: Dimensionality of the latent representation + :param dropout: Dropout probability + :param use_attn: Whether to use cross‑attention between drug and cell line features + """ + super().__init__() + self.use_attn = use_attn + + # graph layers + self.gcn1 = GATConv(num_features_xd, num_features_xd, heads=10, dropout=dropout, edge_dim=7) + self.gcn2 = GATConv(num_features_xd * 10, output_dim, dropout=dropout, edge_dim=7) + self.fc_g1 = nn.Linear(output_dim, output_dim) + + if num_features_xt < 50: + k = 3 + p = 2 + else: + k = 8 + p = 3 + + self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=k) + self.pool_xt_1 = nn.MaxPool1d(p) + self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=k) + self.pool_xt_2 = nn.MaxPool1d(p) + self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=k) + self.pool_xt_3 = nn.MaxPool1d(p) + + with torch.no_grad(): + dummy = torch.zeros(1, 1, num_features_xt) + conv_xt = self.pool_xt_1(self.conv_xt_1(dummy)) + conv_xt = self.pool_xt_2(self.conv_xt_2(conv_xt)) + conv_xt = self.pool_xt_3(self.conv_xt_3(conv_xt)) + flat_dim = conv_xt.shape[1] * conv_xt.shape[2] + + self.fc1_xt = nn.Linear(flat_dim, output_dim) + + if self.use_attn: + self.cross_attn1 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) + self.cross_attn2 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) + self.norm1 = nn.LayerNorm(output_dim) + self.norm2 = nn.LayerNorm(output_dim) + self.fc = nn.Linear(2 * output_dim, output_dim) + else: + # combined layers + self.fc1 = nn.Linear(2 * output_dim, 1024) + self.fc2 = nn.Linear(1024, 128) + self.out = nn.Linear(128, n_output) + + # activation and regularization + self.relu = nn.ReLU() + self.dropout = nn.Dropout(0.5) + + def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_weights=False): + """ + Forward pass of the GATNet_E model. + + :param x: feature matrix of molecular graph + :param edge_index: edges of molecular graph + :param batch: Batch vector assigning nodes to graphs + :param x_cell_mut: Cell line omics features + :param edge_feat: edge features of molecular graph + :param return_attention_weights: Whether to return attention weights + :return: out and attention weights/ out + """ + # graph input feed-forward + # x, edge_index, batch, edge_feat = data.x, data.edge_index, data.batch, data.edge_features + # print(data.x.shape) + if edge_feat is not None: + pass + + # x = f.dropout(x, p=0.2, training=self.training) + # x = self.dropout(x) + x = f.elu(self.gcn1(x, edge_index, edge_attr=edge_feat)) + # x = f.dropout(x, p=0.2, training=self.training) + x = self.dropout(x) + if return_attention_weights: + x, attn_weights = self.gcn2( + x, edge_index, edge_attr=edge_feat, return_attention_weights=return_attention_weights + ) + else: + x = self.gcn2(x, edge_index, edge_attr=edge_feat) + x = self.relu(x) + x = gmp(x, batch) # global max pooling + x = self.fc_g1(x) + x = self.relu(x) + + # protein input feed-forward: + # target = data.target + # x_cell_mut = x_cell_mut[:,None,:] + # 1d conv layers + + MIN_CNN_INPUT = 22 + if x_cell_mut.shape[-1] < MIN_CNN_INPUT: + pad = MIN_CNN_INPUT - x_cell_mut.shape[-1] + x_cell_mut = torch.nn.functional.pad(x_cell_mut, (0, pad)) + + conv_xt = self.conv_xt_1(x_cell_mut) + conv_xt = f.relu(conv_xt) + conv_xt = self.pool_xt_1(conv_xt) + conv_xt = self.conv_xt_2(conv_xt) + conv_xt = f.relu(conv_xt) + conv_xt = self.pool_xt_2(conv_xt) + conv_xt = self.conv_xt_3(conv_xt) + conv_xt = f.relu(conv_xt) + conv_xt = self.pool_xt_3(conv_xt) + + # flatten + xt = conv_xt.view(-1, conv_xt.shape[1] * conv_xt.shape[2]) + xt = self.fc1_xt(xt) + + if self.use_attn: + xc1, _ = self.cross_attn1(x, xt, xt) + xc1 = xc1 + x + xc1 = self.norm1(xc1) + xc2, _ = self.cross_attn2(xt, x, x) + xc2 = xc2 + xt + xc2 = self.norm2(xc2) + xc = torch.cat((xc1, xc2), 1) + xc = self.relu(xc) + xc = self.dropout(xc) + xc = self.fc(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + else: + # concat + xc = torch.cat((x, xt), 1) + # add some dense layers + xc = self.fc1(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + xc = self.fc2(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + + if self.use_attn: + pass + out = self.out(xc) + out = nn.Sigmoid()(out) + + if return_attention_weights: + return out, attn_weights + else: + return out + + +class SAGENet(torch.nn.Module): + """Focuses on sampling and aggregating local node neighborhoods.""" + + def __init__( + self, + n_output=1, + n_filters=32, + embed_dim=128, + num_features_xd=334, + num_features_xt=25, + output_dim=128, + dropout=0.5, + use_attn=False, + ): + """ + Initialize the SAGENet model. + + :param n_output: Number of output units + :param n_filters: Number of convolution filters for the cell line CNN branch + :param embed_dim: Embedding dimension (unused but kept for API consistency) + :param num_features_xd: Number of molecular graph node features + :param num_features_xt: Number of cell line features + :param output_dim: Dimensionality of the latent representation + :param dropout: Dropout probability + :param use_attn: Whether to enable cross-attention layers + """ + super().__init__() + self.use_attn = use_attn + # SMILES graph branch + + # GCNSAGE + self.n_output = n_output + self.conv1 = SAGEConv(num_features_xd, num_features_xd) + self.conv2 = SAGEConv(num_features_xd, num_features_xd * 2) + self.conv3 = SAGEConv(num_features_xd * 2, num_features_xd * 4) + self.fc_g1 = torch.nn.Linear(num_features_xd * 4, 1024) + self.fc_g2 = torch.nn.Linear(1024, output_dim) + self.relu = nn.ReLU() + self.dropout = nn.Dropout(dropout) + + if num_features_xt < 50: + k = 3 + p = 2 + else: + k = 8 + p = 3 + + self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=k) + self.pool_xt_1 = nn.MaxPool1d(p) + self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=k) + self.pool_xt_2 = nn.MaxPool1d(p) + self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=k) + self.pool_xt_3 = nn.MaxPool1d(p) + + with torch.no_grad(): + dummy = torch.zeros(1, 1, num_features_xt) + conv_xt = self.pool_xt_1(self.conv_xt_1(dummy)) + conv_xt = self.pool_xt_2(self.conv_xt_2(conv_xt)) + conv_xt = self.pool_xt_3(self.conv_xt_3(conv_xt)) + flat_dim = conv_xt.shape[1] * conv_xt.shape[2] + + self.fc1_xt = nn.Linear(flat_dim, output_dim) + + # combined layers + self.fc1 = nn.Linear(2 * output_dim, 1024) + self.fc2 = nn.Linear(1024, 128) + self.out = nn.Linear(128, n_output) + + # activation and regularization + self.relu = nn.ReLU() + self.dropout = nn.Dropout(0.5) + + def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_weights=False): + """ + Forward pass of the SAGENet model. + + :param x: Node feature matrix of the molecular graph + :param edge_index: Edge indices of the molecular graph + :param batch: Batch vector assigning nodes to graphs + :param x_cell_mut: Cell line omics features + :param edge_feat: Edge features (unused for SAGEConv) + :param return_attention_weights: Whether to return attention weights + :returns: Predicted drug response + """ + # get graph input + # x, edge_index, batch = data.x, data.edge_index, data.batch + + if edge_feat is not None: + pass + + MIN_CNN_INPUT = 22 + if x_cell_mut.shape[-1] < MIN_CNN_INPUT: + pad = MIN_CNN_INPUT - x_cell_mut.shape[-1] + x_cell_mut = torch.nn.functional.pad(x_cell_mut, (0, pad)) + + # GCNSAGE + x = self.conv1(x, edge_index) + x = self.relu(x) + x = self.conv2(x, edge_index) + x = self.relu(x) + x = self.conv3(x, edge_index) + x = self.relu(x) + x = gmp(x, batch) + # flatten + x = self.relu(self.fc_g1(x)) + x = self.dropout(x) + x = self.fc_g2(x) + x = self.dropout(x) + + # get protein input + # target = data.target + # x_cell_mut = x_cell_mut[:,None,:] + conv_xt = self.conv_xt_1(x_cell_mut) + conv_xt = f.relu(conv_xt) + conv_xt = self.pool_xt_1(conv_xt) + conv_xt = self.conv_xt_2(conv_xt) + conv_xt = f.relu(conv_xt) + conv_xt = self.pool_xt_2(conv_xt) + conv_xt = self.conv_xt_3(conv_xt) + conv_xt = f.relu(conv_xt) + conv_xt = self.pool_xt_3(conv_xt) + + # flatten + xt = conv_xt.view(-1, conv_xt.shape[1] * conv_xt.shape[2]) + xt = self.fc1_xt(xt) + + # concat + xc = torch.cat((x, xt), 1) + xc = self.fc1(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + xc = self.fc2(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + + if self.use_attn: + pass + out = self.out(xc) + out = nn.Sigmoid()(out) + return out + + +class GINNet(torch.nn.Module): + """Structurally powerful architecture for distinguishing complex graph patterns.""" + + def __init__( + self, + n_output=1, + num_features_xd=334, + num_features_xt=25, + n_filters=32, + embed_dim=128, + output_dim=128, + dropout=0.5, + use_attn=False, + ): + """ + Initialize the GINNet model. + + :param n_output: Number of output units + :param num_features_xd: Number of molecular graph node features + :param num_features_xt: Number of cell line features + :param n_filters: Number of convolution filters for the cell line CNN branch + :param embed_dim: Embedding dimension (unused but kept for API consistency) + :param output_dim: Dimensionality of the latent representation + :param dropout: Dropout probability + :param use_attn: Whether to enable cross-attention layers + """ + super().__init__() + self.use_attn = use_attn + + dim = 32 + self.dropout = nn.Dropout(dropout) + self.relu = nn.ReLU() + self.n_output = n_output + # convolution layers + nn1 = Sequential(Linear(num_features_xd, dim), ReLU(), Linear(dim, dim)) + self.conv1 = GINConv(nn1) + self.bn1 = torch.nn.BatchNorm1d(dim) + + nn2 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim)) + self.conv2 = GINConv(nn2) + self.bn2 = torch.nn.BatchNorm1d(dim) + + nn3 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim)) + self.conv3 = GINConv(nn3) + self.bn3 = torch.nn.BatchNorm1d(dim) + + nn4 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim)) + self.conv4 = GINConv(nn4) + self.bn4 = torch.nn.BatchNorm1d(dim) + + nn5 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim)) + self.conv5 = GINConv(nn5) + self.bn5 = torch.nn.BatchNorm1d(dim) + + self.fc1_xd = Linear(dim, output_dim) + + # 1D convolution on protein sequence + self.embedding_xt = nn.Embedding(num_features_xt + 1, embed_dim) + self.conv_xt_1 = nn.Conv1d(in_channels=1000, out_channels=n_filters, kernel_size=8) + + if num_features_xt < 50: + k = 3 + p = 2 + else: + k = 8 + p = 3 + + self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=k) + self.pool_xt_1 = nn.MaxPool1d(p) + self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=k) + self.pool_xt_2 = nn.MaxPool1d(p) + self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=k) + self.pool_xt_3 = nn.MaxPool1d(p) + + with torch.no_grad(): + dummy = torch.zeros(1, 1, num_features_xt) + conv_xt = self.pool_xt_1(self.conv_xt_1(dummy)) + conv_xt = self.pool_xt_2(self.conv_xt_2(conv_xt)) + conv_xt = self.pool_xt_3(self.conv_xt_3(conv_xt)) + flat_dim = conv_xt.shape[1] * conv_xt.shape[2] + + self.fc1_xt = nn.Linear(flat_dim, output_dim) + + # combined layers + self.fc1 = nn.Linear(2 * output_dim, 1024) + self.fc2 = nn.Linear(1024, 128) + self.out = nn.Linear(128, n_output) + + # activation and regularization + self.relu = nn.ReLU() + self.dropout = nn.Dropout(0.5) + + def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_weights=False): + """ + Forward pass of the GINNet model. + + :param x: Node feature matrix of the molecular graph + :param edge_index: Edge indices of the molecular graph + :param batch: Batch vector assigning nodes to graphs + :param x_cell_mut: Cell line omics features + :param edge_feat: Edge features (unused for GINConv) + :param return_attention_weights: Whether to return attention weights + :returns: Predicted drug response + """ + if edge_feat is not None: + pass + + MIN_CNN_INPUT = 22 + if x_cell_mut.shape[-1] < MIN_CNN_INPUT: + pad = MIN_CNN_INPUT - x_cell_mut.shape[-1] + x_cell_mut = torch.nn.functional.pad(x_cell_mut, (0, pad)) + + # x, edge_index, batch = data.x, data.edge_index, data.batch + # print(x) + # print(data.target) + x = f.relu(self.conv1(x, edge_index)) + x = self.bn1(x) + x = f.relu(self.conv2(x, edge_index)) + x = self.bn2(x) + x = f.relu(self.conv3(x, edge_index)) + x = self.bn3(x) + x = f.relu(self.conv4(x, edge_index)) + x = self.bn4(x) + x = f.relu(self.conv5(x, edge_index)) + x = self.bn5(x) + x = global_add_pool(x, batch) + x = f.relu(self.fc1_xd(x)) + # x = f.dropout(x, p=0.2, training=self.training) + x = self.dropout(x) + + # protein input feed-forward: + # target = data.target + # x_cell_mut = x_cell_mut[:,None,:] + + # 1d conv layers + conv_xt = self.conv_xt_1(x_cell_mut) + conv_xt = f.relu(conv_xt) + conv_xt = self.pool_xt_1(conv_xt) + conv_xt = self.conv_xt_2(conv_xt) + conv_xt = f.relu(conv_xt) + conv_xt = self.pool_xt_2(conv_xt) + conv_xt = self.conv_xt_3(conv_xt) + conv_xt = f.relu(conv_xt) + conv_xt = self.pool_xt_3(conv_xt) + + # flatten + xt = conv_xt.view(-1, conv_xt.shape[1] * conv_xt.shape[2]) + xt = self.fc1_xt(xt) + + # concat + xc = torch.cat((x, xt), 1) + # add some dense layers + xc = self.fc1(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + xc = self.fc2(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + + if self.use_attn: + pass + out = self.out(xc) + out = nn.Sigmoid()(out) + return out + + +class GINENet(torch.nn.Module): + """Combines GIN's structural power with edge feature integration.""" + + def __init__( + self, + n_output=1, + num_features_xd=334, + num_features_xt=25, + n_filters=32, + embed_dim=128, + output_dim=128, + dropout=0.5, + use_attn=False, + ): + """ + Initialize the GINENet model. + + :param n_output: Number of output units + :param num_features_xd: Number of molecular graph node features + :param num_features_xt: Number of cell line features + :param n_filters: Number of convolution filters for the cell line CNN branch + :param embed_dim: Embedding dimension (unused but kept for API consistency) + :param output_dim: Dimensionality of the latent representation + :param dropout: Dropout probability + :param use_attn: Whether to enable cross-attention layers + """ + super().__init__() + self.use_attn = use_attn + + dim = 32 + self.dropout = nn.Dropout(dropout) + self.relu = nn.ReLU() + self.n_output = n_output + # convolution layers + nn1 = Sequential(Linear(num_features_xd, dim), ReLU(), Linear(dim, dim)) + self.conv1 = GINEConv(nn1, edge_dim=7) + self.bn1 = torch.nn.BatchNorm1d(dim) + + nn2 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim)) + self.conv2 = GINEConv(nn2, edge_dim=7) + self.bn2 = torch.nn.BatchNorm1d(dim) + + nn3 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim)) + self.conv3 = GINEConv(nn3, edge_dim=7) + self.bn3 = torch.nn.BatchNorm1d(dim) + + nn4 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim)) + self.conv4 = GINEConv(nn4, edge_dim=7) + self.bn4 = torch.nn.BatchNorm1d(dim) + + nn5 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim)) + self.conv5 = GINEConv(nn5, edge_dim=7) + self.bn5 = torch.nn.BatchNorm1d(dim) + + self.fc1_xd = Linear(dim, output_dim) + + # 1D convolution on protein sequence + self.embedding_xt = nn.Embedding(num_features_xt + 1, embed_dim) + self.conv_xt_1 = nn.Conv1d(in_channels=1000, out_channels=n_filters, kernel_size=8) + + # cell line feature + if num_features_xt < 50: + k = 3 + p = 2 + else: + k = 8 + p = 3 + + self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=k) + self.pool_xt_1 = nn.MaxPool1d(p) + self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=k) + self.pool_xt_2 = nn.MaxPool1d(p) + self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=k) + self.pool_xt_3 = nn.MaxPool1d(p) + + with torch.no_grad(): + dummy = torch.zeros(1, 1, num_features_xt) + conv_xt = self.pool_xt_1(self.conv_xt_1(dummy)) + conv_xt = self.pool_xt_2(self.conv_xt_2(conv_xt)) + conv_xt = self.pool_xt_3(self.conv_xt_3(conv_xt)) + flat_dim = conv_xt.shape[1] * conv_xt.shape[2] + + self.fc1_xt = nn.Linear(flat_dim, output_dim) + + # combined layers + self.fc1 = nn.Linear(2 * output_dim, 1024) + self.fc2 = nn.Linear(1024, 128) + self.out = nn.Linear(128, n_output) + + # activation and regularization + self.relu = nn.ReLU() + self.dropout = nn.Dropout(0.5) + + def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_weights=False): + """ + Forward pass of the GINENet model. + + :param x: Node feature matrix of the molecular graph + :param edge_index: Edge indices of the molecular graph + :param batch: Batch vector assigning nodes to graphs + :param x_cell_mut: Cell line omics features + :param edge_feat: Edge features of the molecular graph + :param return_attention_weights: Whether to return attention weights + :returns: Predicted drug response + """ + if edge_feat is not None: + pass + + MIN_CNN_INPUT = 22 + if x_cell_mut.shape[-1] < MIN_CNN_INPUT: + pad = MIN_CNN_INPUT - x_cell_mut.shape[-1] + x_cell_mut = torch.nn.functional.pad(x_cell_mut, (0, pad)) + + # x, edge_index, batch = data.x, data.edge_index, data.batch + # print(x) + # print(data.target) + x = f.relu(self.conv1(x, edge_index, edge_attr=edge_feat)) + x = self.bn1(x) + x = f.relu(self.conv2(x, edge_index, edge_attr=edge_feat)) + x = self.bn2(x) + x = f.relu(self.conv3(x, edge_index, edge_attr=edge_feat)) + x = self.bn3(x) + x = f.relu(self.conv4(x, edge_index, edge_attr=edge_feat)) + x = self.bn4(x) + x = f.relu(self.conv5(x, edge_index, edge_attr=edge_feat)) + x = self.bn5(x) + x = global_add_pool(x, batch) + x = f.relu(self.fc1_xd(x)) + # x = f.dropout(x, p=0.2, training=self.training) + x = self.dropout(x) + + # protein input feed-forward: + # target = data.target + # x_cell_mut = x_cell_mut[:,None,:] + + # 1d conv layers + conv_xt = self.conv_xt_1(x_cell_mut) + conv_xt = f.relu(conv_xt) + conv_xt = self.pool_xt_1(conv_xt) + conv_xt = self.conv_xt_2(conv_xt) + conv_xt = f.relu(conv_xt) + conv_xt = self.pool_xt_2(conv_xt) + conv_xt = self.conv_xt_3(conv_xt) + conv_xt = f.relu(conv_xt) + conv_xt = self.pool_xt_3(conv_xt) + + # flatten + xt = conv_xt.view(-1, conv_xt.shape[1] * conv_xt.shape[2]) + xt = self.fc1_xt(xt) + + # concat + xc = torch.cat((x, xt), 1) + # add some dense layers + xc = self.fc1(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + xc = self.fc2(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + + if self.use_attn: + pass + out = self.out(xc) + out = nn.Sigmoid()(out) + return out + + +class RGCNNet(torch.nn.Module): + """Uses relation-specific weights for multi-relational drug graphs.""" + + def __init__( + self, + n_output=1, + n_filters=32, + embed_dim=128, + num_features_xd=334, + num_features_xt=25, + output_dim=128, + dropout=0.5, + use_attn=False, + ): + """ + Initialize the RGCNNet model. + + :param n_output: Number of output units + :param n_filters: Number of convolution filters for the cell line CNN branch + :param embed_dim: Embedding dimension (unused but kept for API consistency) + :param num_features_xd: Number of molecular graph node features + :param num_features_xt: Number of cell line features + :param output_dim: Dimensionality of the latent representation + :param dropout: Dropout probability + :param use_attn: Whether to use cross‑attention between drug and cell line features + """ + super().__init__() + self.use_attn = use_attn + + # SMILES graph branch + self.n_output = n_output + self.conv1 = RGCNConv(num_features_xd, num_features_xd, num_relations=4) + self.conv2 = RGCNConv(num_features_xd, num_features_xd * 2, num_relations=4) + self.conv3 = RGCNConv(num_features_xd * 2, num_features_xd * 4, num_relations=4) + self.fc_g1 = torch.nn.Linear(num_features_xd * 4, 1024) + self.fc_g2 = torch.nn.Linear(1024, output_dim) + self.relu = nn.ReLU() + self.dropout = nn.Dropout(dropout) + + # cell line feature + if num_features_xt < 50: + k = 3 + p = 2 + else: + k = 8 + p = 3 + + self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=k) + self.pool_xt_1 = nn.MaxPool1d(p) + self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=k) + self.pool_xt_2 = nn.MaxPool1d(p) + self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=k) + self.pool_xt_3 = nn.MaxPool1d(p) + + with torch.no_grad(): + dummy = torch.zeros(1, 1, num_features_xt) + conv_xt = self.pool_xt_1(self.conv_xt_1(dummy)) + conv_xt = self.pool_xt_2(self.conv_xt_2(conv_xt)) + conv_xt = self.pool_xt_3(self.conv_xt_3(conv_xt)) + flat_dim = conv_xt.shape[1] * conv_xt.shape[2] + + self.fc1_xt = nn.Linear(flat_dim, output_dim) + + if self.use_attn: + self.cross_attn1 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout, batch_first=True) + self.cross_attn2 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout, batch_first=True) + self.norm1 = nn.LayerNorm(output_dim) + self.norm2 = nn.LayerNorm(output_dim) + self.fc = nn.Linear(2 * output_dim, output_dim) + else: + # combined layers + self.fc1 = nn.Linear(2 * output_dim, 1024) + self.fc2 = nn.Linear(1024, 128) + self.out = nn.Linear(128, self.n_output) + + def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, edge_weight=None, return_attention_weights=False): + """ + Forward pass of the RGCNNet model. + + :param x: Node feature matrix of the molecular graph + :param edge_index: Edge indices of the molecular graph + :param batch: Batch vector assigning nodes to graphs + :param x_cell_mut: Cell line omics features + :param edge_feat: Edge type indices for relational graph convolution + :param edge_weight: Optional edge weights + :param return_attention_weights: Whether to return attention weights + :returns: Predicted drug response + """ + # get graph input + # edge_weight is only used for decoding + + MIN_CNN_INPUT = 22 + if x_cell_mut.shape[-1] < MIN_CNN_INPUT: + pad = MIN_CNN_INPUT - x_cell_mut.shape[-1] + x_cell_mut = torch.nn.functional.pad(x_cell_mut, (0, pad)) + + if edge_feat is not None: + edge_feat = edge_feat.long().view(-1) + + # x, edge_index, batch = data.x, data.edge_index, data.batch + # edge_index = edge_index.long() + # edge_feat = edge_feat.long().squeeze() + + x = self.conv1(x, edge_index, edge_type=edge_feat) + x = self.relu(x) + x = self.conv2(x, edge_index, edge_type=edge_feat) + x = self.relu(x) + x = self.conv3(x, edge_index, edge_type=edge_feat) + x = self.relu(x) + x = gmp(x, batch) # global max pooling + + # flatten + x = self.relu(self.fc_g1(x)) + x = self.dropout(x) + x = self.fc_g2(x) + x = self.dropout(x) + + # get protein input + # target = data.target + # print(x_cell_mut.shape) + + # add this line for CNV data, remove for gene expr data + # x_cell_mut = x_cell_mut[:,None,:] + if x_cell_mut.dim() == 2: + x_cell_mut = x_cell_mut.unsqueeze(1) + + # 1d conv layers + conv_xt = self.conv_xt_1(x_cell_mut) + conv_xt = f.relu(conv_xt) + conv_xt = self.pool_xt_1(conv_xt) + conv_xt = self.conv_xt_2(conv_xt) + conv_xt = f.relu(conv_xt) + conv_xt = self.pool_xt_2(conv_xt) + conv_xt = self.conv_xt_3(conv_xt) + conv_xt = f.relu(conv_xt) + conv_xt = self.pool_xt_3(conv_xt) + + # flatten + xt = conv_xt.view(-1, conv_xt.shape[1] * conv_xt.shape[2]) + xt = self.fc1_xt(xt) + + if self.use_attn: + xc1, _ = self.cross_attn1(x, xt, xt) + xc1 = xc1 + x + xc1 = self.norm1(xc1) + xc2, _ = self.cross_attn2(xt, x, x) + xc2 = xc2 + xt + xc2 = self.norm2(xc2) + xc = torch.cat((xc1, xc2), 1) + xc = self.relu(xc) + xc = self.dropout(xc) + xc = self.fc(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + else: + # concat + xc = torch.cat((x, xt), 1) + # add some dense layers + xc = self.fc1(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + xc = self.fc2(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + + if self.use_attn: + pass + out = self.out(xc) + out = nn.Sigmoid()(out) + return out + + +class WIRGATNet(torch.nn.Module): + """Relational attention focused on interactions within the same relation.""" + + def __init__( + self, + num_features_xd=334, + n_output=1, + num_features_xt=25, + n_filters=32, + embed_dim=128, + output_dim=128, + dropout=0.5, + use_attn=False, + ): + """ + Initialize the WIRGATNet model. + + :param num_features_xd: Number of molecular graph node features + :param n_output: Number of output units + :param num_features_xt: Number of cell line features + :param n_filters: Number of convolution filters for the cell line CNN branch + :param embed_dim: Embedding dimension (unused but kept for API consistency) + :param output_dim: Dimensionality of the latent representation + :param dropout: Dropout probability + :param use_attn: Whether to use cross‑attention between drug and cell line features + """ + super().__init__() + self.use_attn = use_attn + + # graph layers + self.gcn1 = RGATConv( + num_features_xd, + num_features_xd, + num_relations=4, + attention_mechanism="within-relation", + heads=10, + dropout=dropout, + ) + self.gcn2 = RGATConv( + num_features_xd * 10, output_dim, num_relations=4, attention_mechanism="within-relation", dropout=dropout + ) + self.fc_g1 = nn.Linear(output_dim, output_dim) + + # cell line feature + if num_features_xt < 50: + k = 3 + p = 2 + else: + k = 8 + p = 3 + + self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=k) + self.pool_xt_1 = nn.MaxPool1d(p) + self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=k) + self.pool_xt_2 = nn.MaxPool1d(p) + self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=k) + self.pool_xt_3 = nn.MaxPool1d(p) + + with torch.no_grad(): + dummy = torch.zeros(1, 1, num_features_xt) + conv_xt = self.pool_xt_1(self.conv_xt_1(dummy)) + conv_xt = self.pool_xt_2(self.conv_xt_2(conv_xt)) + conv_xt = self.pool_xt_3(self.conv_xt_3(conv_xt)) + flat_dim = conv_xt.shape[1] * conv_xt.shape[2] + + self.fc1_xt = nn.Linear(flat_dim, output_dim) + + if self.use_attn: + self.cross_attn1 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) + self.cross_attn2 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) + self.norm1 = nn.LayerNorm(output_dim) + self.norm2 = nn.LayerNorm(output_dim) + self.fc = nn.Linear(2 * output_dim, 128) + else: + # combined layers + self.fc1 = nn.Linear(2 * output_dim, 1024) + self.fc2 = nn.Linear(1024, 128) + self.out = nn.Linear(128, n_output) + + # activation and regularization + self.relu = nn.ReLU() + self.dropout = nn.Dropout(0.5) + + def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_weights=False): + """ + Forward pass of the WIRGATNet model. + + :param x: feature matrix of molecular graph + :param edge_index: edges of molecular graph + :param batch: Batch vector assigning nodes to graphs + :param x_cell_mut: Cell line omics features + :param edge_feat: edge features of molecular graph + :param return_attention_weights: Whether to return attention weights + :returns: Predicted drug response or (prediction, attention weights) + """ + # graph input feed-forward + # x, edge_index, batch, edge_feat = data.x, data.edge_index, data.batch, data.edge_features + # print(data.x.shape) + if edge_feat is not None: + pass + edge_feat = edge_feat.int().squeeze() + # print(edge_feat) + + # x = f.dropout(x, p=0.2, training=self.training) + # x = self.dropout(x) + x = f.elu(self.gcn1(x, edge_index, edge_type=edge_feat)) + # x = f.dropout(x, p=0.2, training=self.training) + x = self.dropout(x) + if return_attention_weights: + x, attn_weights = self.gcn2( + x, edge_index, edge_type=edge_feat, return_attention_weights=return_attention_weights + ) + else: + x = self.gcn2(x, edge_index, edge_type=edge_feat) + x = self.relu(x) + x = gmp(x, batch) # global max pooling + x = self.fc_g1(x) + x = self.relu(x) + + MIN_CNN_INPUT = 22 + if x_cell_mut.shape[-1] < MIN_CNN_INPUT: + pad = MIN_CNN_INPUT - x_cell_mut.shape[-1] + x_cell_mut = torch.nn.functional.pad(x_cell_mut, (0, pad)) + + # protein input feed-forward: + # target = data.target + # x_cell_mut = x_cell_mut[:,None,:] + # 1d conv layers + conv_xt = self.conv_xt_1(x_cell_mut) + conv_xt = f.relu(conv_xt) + conv_xt = self.pool_xt_1(conv_xt) + conv_xt = self.conv_xt_2(conv_xt) + conv_xt = f.relu(conv_xt) + conv_xt = self.pool_xt_2(conv_xt) + conv_xt = self.conv_xt_3(conv_xt) + conv_xt = f.relu(conv_xt) + conv_xt = self.pool_xt_3(conv_xt) + + # flatten + xt = conv_xt.view(-1, conv_xt.shape[1] * conv_xt.shape[2]) + xt = self.fc1_xt(xt) + + if self.use_attn: + xc1, _ = self.cross_attn1(x, xt, xt) + xc1 = xc1 + x + xc1 = self.norm1(xc1) + xc2, _ = self.cross_attn2(xt, x, x) + xc2 = xc2 + xt + xc2 = self.norm2(xc2) + xc = torch.cat((xc1, xc2), 1) + xc = self.relu(xc) + xc = self.dropout(xc) + xc = self.fc(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + else: + # concat + xc = torch.cat((x, xt), 1) + # add some dense layers + xc = self.fc1(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + xc = self.fc2(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + + if self.use_attn: + pass + out = self.out(xc) + out = nn.Sigmoid()(out) + + if return_attention_weights: + return out, attn_weights + else: + return out + + +class ARGATNet(torch.nn.Module): + """Relational attention designed to process features across different relations.""" + + def __init__( + self, + num_features_xd=334, + n_output=1, + num_features_xt=25, + n_filters=32, + embed_dim=128, + output_dim=128, + dropout=0.5, + use_attn=False, + ): + """ + Initialize the ARGATNet model. + + :param num_features_xd: Number of molecular graph node features + :param n_output: Number of output units + :param num_features_xt: Number of cell line features + :param n_filters: Number of convolution filters for the cell line CNN branch + :param embed_dim: Embedding dimension (unused but kept for API consistency) + :param output_dim: Dimensionality of the latent representation + :param dropout: Dropout probability + :param use_attn: Whether to use cross‑attention between drug and cell line features + """ + super().__init__() + self.use_attn = use_attn + + # graph layers + self.gcn1 = RGATConv( + num_features_xd, + num_features_xd, + num_relations=4, + attention_mechanism="across-relation", + heads=10, + dropout=dropout, + ) + self.gcn2 = RGATConv( + num_features_xd * 10, output_dim, num_relations=4, attention_mechanism="across-relation", dropout=dropout + ) + self.fc_g1 = nn.Linear(output_dim, output_dim) + + # cell line feature + if num_features_xt < 50: + k = 3 + p = 2 + else: + k = 8 + p = 3 + + self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=k) + self.pool_xt_1 = nn.MaxPool1d(p) + self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=k) + self.pool_xt_2 = nn.MaxPool1d(p) + self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=k) + self.pool_xt_3 = nn.MaxPool1d(p) + + with torch.no_grad(): + dummy = torch.zeros(1, 1, num_features_xt) + conv_xt = self.pool_xt_1(self.conv_xt_1(dummy)) + conv_xt = self.pool_xt_2(self.conv_xt_2(conv_xt)) + conv_xt = self.pool_xt_3(self.conv_xt_3(conv_xt)) + flat_dim = conv_xt.shape[1] * conv_xt.shape[2] + + self.fc1_xt = nn.Linear(flat_dim, output_dim) + + if self.use_attn: + self.cross_attn1 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) + self.cross_attn2 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) + self.norm1 = nn.LayerNorm(output_dim) + self.norm2 = nn.LayerNorm(output_dim) + self.fc = nn.Linear(2 * output_dim, 128) + else: + # combined layers + self.fc1 = nn.Linear(2 * output_dim, 1024) + self.fc2 = nn.Linear(1024, 128) + self.out = nn.Linear(128, n_output) + + # activation and regularization + self.relu = nn.ReLU() + self.dropout = nn.Dropout(0.5) + + def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_weights=False): + """ + Forward pass of the ARGATNet model. + + :param x: feature matrix of molecular graph + :param edge_index: edges of molecular graph + :param batch: Batch vector assigning nodes to graphs + :param x_cell_mut: Cell line omics features + :param edge_feat: edge features of molecular graph + :param return_attention_weights: Whether to return attention weights + :returns: Predicted drug response or (prediction, attention weights) + """ + # graph input feed-forward + # x, edge_index, batch, edge_feat = data.x, data.edge_index, data.batch, data.edge_features + # print(data.x.shape) + if edge_feat is not None: + pass + edge_feat = edge_feat.int().squeeze() + + # x = f.dropout(x, p=0.2, training=self.training) + # x = self.dropout(x) + x = f.elu(self.gcn1(x, edge_index, edge_type=edge_feat)) + # x = f.dropout(x, p=0.2, training=self.training) + x = self.dropout(x) + if return_attention_weights: + x, attn_weights = self.gcn2( + x, edge_index, edge_type=edge_feat, return_attention_weights=return_attention_weights + ) + else: + x = self.gcn2(x, edge_index, edge_type=edge_feat) + x = self.relu(x) + x = gmp(x, batch) # global max pooling + x = self.fc_g1(x) + x = self.relu(x) + + MIN_CNN_INPUT = 22 + if x_cell_mut.shape[-1] < MIN_CNN_INPUT: + pad = MIN_CNN_INPUT - x_cell_mut.shape[-1] + x_cell_mut = torch.nn.functional.pad(x_cell_mut, (0, pad)) + + # protein input feed-forward: + # target = data.target + # x_cell_mut = x_cell_mut[:,None,:] + # 1d conv layers + conv_xt = self.conv_xt_1(x_cell_mut) + conv_xt = f.relu(conv_xt) + conv_xt = self.pool_xt_1(conv_xt) + conv_xt = self.conv_xt_2(conv_xt) + conv_xt = f.relu(conv_xt) + conv_xt = self.pool_xt_2(conv_xt) + conv_xt = self.conv_xt_3(conv_xt) + conv_xt = f.relu(conv_xt) + conv_xt = self.pool_xt_3(conv_xt) + + # flatten + xt = conv_xt.view(-1, conv_xt.shape[1] * conv_xt.shape[2]) + xt = self.fc1_xt(xt) + + if self.use_attn: + xc1, _ = self.cross_attn1(x, xt, xt) + xc1 = xc1 + x + xc1 = self.norm1(xc1) + xc2, _ = self.cross_attn2(xt, x, x) + xc2 = xc2 + xt + xc2 = self.norm2(xc2) + xc = torch.cat((xc1, xc2), 1) + xc = self.relu(xc) + xc = self.dropout(xc) + xc = self.fc(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + else: + # concat + xc = torch.cat((x, xt), 1) + # add some dense layers + xc = self.fc1(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + xc = self.fc2(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + + if self.use_attn: + pass + out = self.out(xc) + out = nn.Sigmoid()(out) + + if return_attention_weights: + return out, attn_weights + else: + return out + + +class FiLMNet(torch.nn.Module): + """Adaptively modulates features based on specific graph relations.""" + + def __init__( + self, + num_features_xd=334, + n_output=1, + num_features_xt=25, + n_filters=32, + embed_dim=128, + output_dim=128, + dropout=0.5, + use_attn=False, + ): + """ + Initialize the FiLMNet model. + + :param num_features_xd: Number of molecular graph node features + :param n_output: Number of output units + :param num_features_xt: Number of cell line features + :param n_filters: Number of convolution filters for the cell line CNN branch + :param embed_dim: Embedding dimension (unused but kept for API consistency) + :param output_dim: Dimensionality of the latent representation + :param dropout: Dropout probability + :param use_attn: Whether to use cross‑attention between drug and cell line features + """ + super().__init__() + self.use_attn = use_attn + + # graph layers + self.gcn1 = FiLMConv(num_features_xd, num_features_xd, num_relations=4, act=nn.LeakyReLU(), edge_dim=7) + self.gcn2 = FiLMConv(num_features_xd, output_dim, num_relations=4, act=nn.LeakyReLU(), edge_dim=7) + + self.fc_g1 = nn.Linear(output_dim, output_dim) + + # cell line feature + if num_features_xt < 50: + k = 3 + p = 2 + else: + k = 8 + p = 3 + + self.conv_xt_1 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=k) + self.pool_xt_1 = nn.MaxPool1d(p) + self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=n_filters * 2, kernel_size=k) + self.pool_xt_2 = nn.MaxPool1d(p) + self.conv_xt_3 = nn.Conv1d(in_channels=n_filters * 2, out_channels=n_filters * 4, kernel_size=k) + self.pool_xt_3 = nn.MaxPool1d(p) + + with torch.no_grad(): + dummy = torch.zeros(1, 1, num_features_xt) + conv_xt = self.pool_xt_1(self.conv_xt_1(dummy)) + conv_xt = self.pool_xt_2(self.conv_xt_2(conv_xt)) + conv_xt = self.pool_xt_3(self.conv_xt_3(conv_xt)) + flat_dim = conv_xt.shape[1] * conv_xt.shape[2] + + self.fc1_xt = nn.Linear(flat_dim, output_dim) + + if self.use_attn: + self.cross_attn1 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) + self.cross_attn2 = nn.MultiheadAttention(output_dim, num_heads=8, dropout=dropout) + self.norm1 = nn.LayerNorm(output_dim) + self.norm2 = nn.LayerNorm(output_dim) + self.fc = nn.Linear(2 * output_dim, 128) + else: + # combined layers + self.fc1 = nn.Linear(2 * output_dim, 1024) + self.fc2 = nn.Linear(1024, 128) + self.out = nn.Linear(128, n_output) + + # activation and regularization + self.relu = nn.ReLU() + self.dropout = nn.Dropout(0.5) + + def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, return_attention_weights=False): + """ + Forward pass of the FiLMNet model. + + :param x: feature matrix of molecular graph + :param edge_index: edges of molecular graph + :param batch: Batch vector assigning nodes to graphs + :param x_cell_mut: Cell line omics features + :param edge_feat: Edge type indices for FiLM modulation + :param return_attention_weights: Whether to return attention weights + :returns: Predicted drug response + """ + # graph input feed-forward + # x, edge_index, batch, edge_feat = data.x, data.edge_index, data.batch, data.edge_features + # print(data.x.shape) + if edge_feat is not None: + pass + edge_feat = edge_feat.int().squeeze() + + # x = f.dropout(x, p=0.2, training=self.training) + # x = self.dropout(x) + self.gcn1(x, edge_index, edge_type=edge_feat) + # x = f.dropout(x, p=0.2, training=self.training) + x = self.dropout(x) + x = self.gcn2(x, edge_index, edge_type=edge_feat) + # x = self.relu(x) + x = gmp(x, batch) # global max pooling + x = self.fc_g1(x) + x = self.relu(x) + + MIN_CNN_INPUT = 22 + if x_cell_mut.shape[-1] < MIN_CNN_INPUT: + pad = MIN_CNN_INPUT - x_cell_mut.shape[-1] + x_cell_mut = torch.nn.functional.pad(x_cell_mut, (0, pad)) + + # protein input feed-forward: + # target = data.target + # x_cell_mut = x_cell_mut[:,None,:] + # 1d conv layers + conv_xt = self.conv_xt_1(x_cell_mut) + conv_xt = f.relu(conv_xt) + conv_xt = self.pool_xt_1(conv_xt) + conv_xt = self.conv_xt_2(conv_xt) + conv_xt = f.relu(conv_xt) + conv_xt = self.pool_xt_2(conv_xt) + conv_xt = self.conv_xt_3(conv_xt) + conv_xt = f.relu(conv_xt) + conv_xt = self.pool_xt_3(conv_xt) + + # flatten + xt = conv_xt.view(-1, conv_xt.shape[1] * conv_xt.shape[2]) + xt = self.fc1_xt(xt) + + if self.use_attn: + xc1, _ = self.cross_attn1(x, xt, xt) + xc1 = xc1 + x + xc1 = self.norm1(xc1) + xc2, _ = self.cross_attn2(xt, x, x) + xc2 = xc2 + xt + xc2 = self.norm2(xc2) + xc = torch.cat((xc1, xc2), 1) + xc = self.relu(xc) + xc = self.dropout(xc) + xc = self.fc(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + else: + # concat + xc = torch.cat((x, xt), 1) + # add some dense layers + xc = self.fc1(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + xc = self.fc2(xc) + xc = self.relu(xc) + xc = self.dropout(xc) + + if self.use_attn: + pass + out = self.out(xc) + out = nn.Sigmoid()(out) + + return out diff --git a/drevalpy/models/XGDP/hyperparameters.yaml b/drevalpy/models/XGDP/hyperparameters.yaml new file mode 100644 index 00000000..f1c797ca --- /dev/null +++ b/drevalpy/models/XGDP/hyperparameters.yaml @@ -0,0 +1,25 @@ +XGDP: + model_type: + - "GATNet" + - "GCNNet" + - "GATv2Net" + - "GATNetE" + - "SAGENet" + - "GINNet" + - "GINENet" + - "RGCNNet" + - "WIRGATNet" + - "ARGATNet" + - "FiLMNet" + learning_rate: + - 0.001 + epochs: + - 100 + batch_size: + - 64 + - 128 + weight_decay: + - 0.0001 + dropout_rate: + - 0.2 + - 0.3 diff --git a/drevalpy/models/XGDP/xgdp.py b/drevalpy/models/XGDP/xgdp.py new file mode 100644 index 00000000..1c538310 --- /dev/null +++ b/drevalpy/models/XGDP/xgdp.py @@ -0,0 +1,444 @@ +"""Contains XGDP, a GNN and CNN based drug response prediction model. + +Drug discovery and mechanism prediction with explainable graph neural networks. + +Original authors: Wang, C., Kumar, G.A. & Rajapakse, J.C. (2025, 10.1038/s41598-024-83090-3) +Code adapted from their Github: https://github.com/SCSE-Biomedical-Computing-Group/XGDP/blob/main/utils_tcnn.py +""" + +from pathlib import Path +from typing import Any + +import numpy as np +import pytorch_lightning as pl +import torch +import torch.nn as nn +from sklearn.preprocessing import MinMaxScaler, StandardScaler +from torch.optim import Adam +from torch.utils.data import Dataset as PytorchDataset +from torch_geometric.loader import DataLoader + +import drevalpy.models.XGDP._models as Models + +from ...datasets.dataset import DrugResponseDataset, FeatureDataset +from ..drp_model import DRPModel +from ..lightning_metrics_mixin import RegressionMetricsMixin +from ..utils import load_and_select_gene_features + + +class _XGDPDataset(PytorchDataset): + """A PyTorch Dataset for XGDP.""" + + def __init__( + self, + response: np.ndarray, + cell_line_ids: np.ndarray, + drug_ids: np.ndarray, + cell_line_features: FeatureDataset, + drug_features: FeatureDataset, + ): + """Initialize the dataset. + + :param response: The drug response values. + :param cell_line_ids: The cell line IDs. + :param drug_ids: The drug IDs. + :param cell_line_features: A FeatureDataset object with cell line features. + :param drug_features: A FeatureDataset object with drug features. + """ + self.response = response + self.cell_line_ids = cell_line_ids + self.drug_ids = drug_ids + + # preconvert to tensors to avoid per item tensor creation + self.cell_features = { + cl_id: torch.tensor(features["gene_expression"], dtype=torch.float32).unsqueeze(0) + for cl_id, features in cell_line_features.features.items() + } + self.response_tensor = torch.tensor(self.response, dtype=torch.float32) + + self.drug_graphs = { + drug_id: feature_views["drug_graph"] for drug_id, feature_views in drug_features.features.items() + } + + def __len__(self): + return len(self.response) + + def __getitem__(self, idx): + cell_line_id = self.cell_line_ids[idx] + drug_id = self.drug_ids[idx] + + drug_graph = self.drug_graphs[drug_id] + cell_feat = self.cell_features[cell_line_id] + response = self.response_tensor[idx] + + return drug_graph, cell_feat, response + + +class XGDPModule(pl.LightningModule): + """The LightningModule for the XGDP model.""" + + def __init__( + self, + model: str, + num_node_features: int, + num_cell_features: int, + do_att: bool, + hidden_dim: int = 128, + dropout: float = 0.5, # changed to 0.5 as per there default settings for there models + learning_rate: float = 0.001, + ): + """Initialize the LightningModule. + + :param model: Name of the XGDP backbone to use (e.g., 'gat', 'gcn', 'gatv2'). + :param do_att: Whether to enable cross-attention between drug and cell features. + :param num_node_features: Number of features for each node in the drug graph. + :param num_cell_features: Number of features for the cell line. + :param hidden_dim: The hidden dimension size. + :param dropout: The dropout rate. + :param learning_rate: The learning rate. + :raises ValueError: If drug_input is not provided. + """ + super().__init__() + self.save_hyperparameters() + self + # model_name = model + model_name = model + try: + model_class = getattr(Models, model_name) + print(type(model_class)) + except AttributeError: + # Specifically catch the error if the string name doesn't exist in 'models' + raise ValueError(f"Model '{model_name}' not found in the list of available models.") + + self.return_attention_weights = ( + do_att if "GAT" in model_name else False + ) # because no need to return attention weight because no models explanation through XAI + + self.model = model_class( + n_output=1, + num_features_xd=num_node_features, # gnn number of node features (ECFP& + DeepChem: 334) + num_features_xt=25, # only used in GINNet in embedding, mutation/protein data??? + n_filters=32, # number of filters for cnn for gene expression + embed_dim=128, # only used in GINNet in embedding + output_dim=hidden_dim, # size latend sapce as output of gnn + cnn -> size of shared size after combination + dropout=dropout, + use_attn=do_att, # not in GINNet, SAGENet, add to models to have same input + ) + self.criterion = nn.MSELoss() + + # Initialize metrics storage for epoch-end R^2 and PCC computation + # self._init_metrics_storage() + + def forward(self, batch): + """Forward pass of the module. + + :param batch: The batch. + :return: The output of the model. + """ + drug_graph, cell_features, _ = batch + if self.return_attention_weights: + # print("x", drug_graph.x.shape, type(drug_graph.x)) + # print("edge_index", drug_graph.edge.shape,type(drug_graph.edge)) + # print("edge_index", drug_graph.edge.shape,type(drug_graph.edge)) + + out, _ = self.model.forward( + x=drug_graph.x, + edge_index=drug_graph.edge_index, + batch=drug_graph.batch, + x_cell_mut=cell_features, + edge_feat=getattr(drug_graph, "edge_attr", None), + return_attention_weights=True, + ) + return out + else: + out = self.model.forward( + x=drug_graph.x, + edge_index=drug_graph.edge_index, + batch=drug_graph.batch, + x_cell_mut=cell_features, + edge_feat=getattr(drug_graph, "edge_attr", None), + ) + return out + + def training_step(self, batch, batch_idx): + """A single training step. + + :param batch: The batch. + :param batch_idx: The batch index. + :return: The loss. + """ + drug_graph, cell_features, responses = batch + outputs = self(batch) + loss = self.criterion(outputs, responses) + self.log("train_loss", loss, on_step=False, on_epoch=True, batch_size=responses.size(0)) + + # Store predictions and targets for epoch-end metrics via mixin + # self._store_predictions(outputs, responses, is_training=True) + + return loss + + def validation_step(self, batch, batch_idx): + """A single validation step. + + :param batch: The batch. + :param batch_idx: The batch index. + """ + drug_graph, cell_features, responses = batch + outputs = self.forward(batch) + loss = self.criterion(outputs, responses) + self.log("val_loss", loss, on_step=False, on_epoch=True, batch_size=responses.size(0)) + + # Store predictions and targets for epoch-end metrics via mixin + # self._store_predictions(outputs, responses, is_training=False) + pass + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + """A single prediction step. + + :param batch: The batch. + :param batch_idx: The batch index. + :param dataloader_idx: The dataloader index. + :return: The output of the model. + """ + drug_graph, cell_features, _ = batch + outputs = self.forward(batch) + outputs = outputs.squeeze(-1) # to flatten torch tensor so that it can be handeled by pandas + return outputs + + def configure_optimizers(self): + """Configure the optimizer. + + :return: The optimizer. + """ + return Adam(self.parameters(), lr=self.hparams.learning_rate) + + +class XGDP(DRPModel, RegressionMetricsMixin): + """XGDP model for ...""" + + cell_line_views = ["gene_expression"] + drug_views = ["drug_graph"] + + def __init__(self) -> None: + """Initialize the XGDP model.""" + super().__init__() + self.DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.model: XGDPModule | None = None + self.hyperparameters: dict[str, Any] = {} + self.gene_expression_scaler: StandardScaler | None = None + self.gene_expression_normalizer: MinMaxScaler | None = None + + @classmethod + def get_model_name(cls) -> str: + """ + Get the model name. + + :returns: XGDP + """ + return "XGDP" + + def build_model(self, hyperparameters: dict[str, Any]) -> None: + """ + Builds the XGDP model with the specified hyperparameters. + + :param hyperparameters: TODO: ADD HYPERPARAMETERS + """ + self.hyperparameters = hyperparameters + # init in train + # self.model = XGDPPredictor(name_hyperparameter=hyperparameter["name_hyperparameter"]) + """ + # Log hyperparameters to wandb if enabled + self.log_hyperparameters(hyperparameters) + + self.hyperparameters = hyperparameters + # init model in train + """ + + def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureDataset: + """Loads the cell line features. + + :param data_path: Path to the gene expression and landmark genes + :param dataset_name: name of the dataset + :return: FeatureDataset containing the cell line gene expression features. + """ + return load_and_select_gene_features( + feature_type="gene_expression", + gene_list=None, + data_path=data_path, + dataset_name=dataset_name, + ) + + def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDataset: + """Loads the pre-computed drug graph data. + + :param data_path: Path to the data directory. + :param dataset_name: Name of the dataset. + :raises FileNotFoundError: If the drug graph directory is not found. + :raises ValueError: If no drug graphs are loaded. + :return: FeatureDataset containing the drug graphs. + """ + graph_path = Path(data_path) / dataset_name / "drug_graphs" + if not graph_path.exists(): + raise FileNotFoundError( + f"Drug graph directory not found at {graph_path}. " + f"Please run 'create_drug_graphs.py' for the {dataset_name} dataset." + ) + + drug_graphs = {} + for p_file in graph_path.glob("*.pt"): + drug_id = p_file.stem + drug_graphs[drug_id] = torch.load(p_file, weights_only=False) + + if not drug_graphs: + raise ValueError(f"No drug graphs loaded from {graph_path}. Check the directory and file contents.") + + feature_dict = {drug_id: {"drug_graph": graph} for drug_id, graph in drug_graphs.items()} + + return FeatureDataset(features=feature_dict) + + def _loader_kwargs(self) -> dict[str, Any]: + num_workers = int(self.hyperparameters.get("num_workers", 4)) + kw = { + "num_workers": num_workers, + "pin_memory": True, + } + if num_workers > 0: + kw["persistent_workers"] = True + kw["prefetch_factor"] = int(self.hyperparameters.get("prefetch_factor", 2)) + return kw + + def train( + self, + output: DrugResponseDataset, + cell_line_input: FeatureDataset, + drug_input: FeatureDataset | None = None, + output_earlystopping: DrugResponseDataset | None = None, + **kwargs, + ): + """Train the model. + + :param output: The output dataset. + :param cell_line_input: The cell line input dataset. + :param drug_input: The drug input dataset. + :param output_earlystopping: The early stopping output dataset. + :param kwargs: Additional arguments. + :raises ValueError: If drug input is not provided. + """ + if drug_input is None: + raise ValueError("Drug input is required for XGDP") + + # Determine feature sizes + num_node_features = next(iter(drug_input.features.values()))["drug_graph"].num_node_features + num_cell_features = next(iter(cell_line_input.features.values()))["gene_expression"].shape[0] + + self.model = XGDPModule( + model=self.hyperparameters.get("model_type", "GATv2Net"), + num_node_features=num_node_features, + num_cell_features=num_cell_features, + hidden_dim=self.hyperparameters.get("hidden_dim", 128), + dropout=self.hyperparameters.get("dropout", 0.5), + learning_rate=self.hyperparameters.get("learning_rate", 0.001), + do_att=self.hyperparameters.get("do_att", True), + ) + + train_dataset = _XGDPDataset( + response=output.response, + cell_line_ids=output.cell_line_ids, + drug_ids=output.drug_ids, + cell_line_features=cell_line_input, + drug_features=drug_input, + ) + train_loader = DataLoader( + train_dataset, + batch_size=self.hyperparameters.get("batch_size", 64), + shuffle=True, + **self._loader_kwargs(), + ) + + val_loader = None + if output_earlystopping is not None and len(output_earlystopping) > 0: + val_dataset = _XGDPDataset( + response=output_earlystopping.response, + cell_line_ids=output_earlystopping.cell_line_ids, + drug_ids=output_earlystopping.drug_ids, + cell_line_features=cell_line_input, + drug_features=drug_input, + ) + val_loader = DataLoader( + val_dataset, + batch_size=self.hyperparameters.get("batch_size", 32), + **self._loader_kwargs(), + ) + + # Set up wandb logger if project is provided + loggers = [] + if self.wandb_project is not None: + from pytorch_lightning.loggers import WandbLogger + + logger = WandbLogger(project=self.wandb_project, log_model=False) + loggers.append(logger) + + trainer = pl.Trainer( + # max_epochs=self.hyperparameters.get("epochs", 100), #changed to 10 fro testing + max_epochs=1, + accelerator="auto", + devices="auto", + callbacks=[pl.callbacks.EarlyStopping(monitor="val_loss", mode="min", patience=5)] if val_loader else None, + logger=loggers if loggers else True, # Use default logger if no wandb + enable_progress_bar=True, + log_every_n_steps=int(self.hyperparameters.get("log_every_n_steps", 50)), + precision=self.hyperparameters.get("precision", 32), + ) + trainer.fit(self.model, train_dataloaders=train_loader, val_dataloaders=val_loader) + + def predict( + self, + cell_line_ids: np.ndarray, + drug_ids: np.ndarray, + cell_line_input: FeatureDataset, + drug_input: FeatureDataset | None = None, + ) -> np.ndarray: + """Predict drug response. + + :param cell_line_ids: The cell line IDs. + :param drug_ids: The drug IDs. + :param cell_line_input: The cell line input dataset. + :param drug_input: The drug input dataset. + :raises RuntimeError: If the model has not been trained yet. + :raises ValueError: If drug input is not provided. + :return: The predicted drug response. + """ + if len(drug_ids) == 0 or len(cell_line_ids) == 0: + print("XGDP predict: No drug or cell line IDs provided; returning empty array.") + return np.array([]) + if self.model is None: + raise RuntimeError("Model has not been trained yet.") + if drug_input is None: + raise ValueError("Drug input is required for XGDP") + + predict_dataset = _XGDPDataset( + response=np.zeros(len(cell_line_ids)), + cell_line_ids=cell_line_ids, + drug_ids=drug_ids, + cell_line_features=cell_line_input, + drug_features=drug_input, + ) + predict_loader = DataLoader( + predict_dataset, + batch_size=self.hyperparameters.get("batch_size", 32), + **self._loader_kwargs(), + ) + + trainer = pl.Trainer(accelerator="auto", devices="auto", enable_progress_bar=False) + predictions_list = trainer.predict(self.model, dataloaders=predict_loader) + + if not predictions_list: + print("XGDP predict: No predictions were made; returning empty array.") + return np.array([]) + + predictions_flat = [ + item for sublist in predictions_list for item in (sublist if isinstance(sublist, list) else [sublist]) + ] + + predictions = torch.cat(predictions_flat).view(-1).cpu().numpy() + return predictions diff --git a/drevalpy/models/__init__.py b/drevalpy/models/__init__.py index 7a8dec71..8224c3e9 100644 --- a/drevalpy/models/__init__.py +++ b/drevalpy/models/__init__.py @@ -25,6 +25,7 @@ "DIPKModel", "DrugGNN", "PharmaFormerModel", + "XGDP", "KNNRegressor", "AdaBoostDecisionTree", "Lasso", @@ -60,6 +61,7 @@ from .SimpleNeuralNetwork.simple_neural_network import SimpleNeuralNetwork from .SRMF.srmf import SRMF from .SuperFELTR.superfeltr import SuperFELTR +from .XGDP.xgdp import XGDP # SINGLE_DRUG_MODEL_FACTORY is used in the pipeline! SINGLE_DRUG_MODEL_FACTORY: dict[str, type[DRPModel]] = { @@ -88,6 +90,7 @@ "DIPK": DIPKModel, "DrugGNN": DrugGNN, "PharmaFormer": PharmaFormerModel, + "XGDP": XGDP, "KNNRegressor": KNNRegressor, "AdaBoostDecisionTree": AdaBoostDecisionTree, "Lasso": LassoModel, diff --git a/tests/models/test_global_models.py b/tests/models/test_global_models.py index b643208a..78284e38 100644 --- a/tests/models/test_global_models.py +++ b/tests/models/test_global_models.py @@ -25,6 +25,7 @@ "SimpleNeuralNetwork[chemberta]", "MultiViewNeuralNetwork", "PharmaFormer", + "XGDP", ], ) def test_global_models( diff --git a/tests/test_drp_model.py b/tests/test_drp_model.py index 535f7915..a423e7ec 100644 --- a/tests/test_drp_model.py +++ b/tests/test_drp_model.py @@ -43,6 +43,7 @@ def test_factory() -> None: assert "MOLIR" in MODEL_FACTORY assert "SuperFELTR" in MODEL_FACTORY assert "DIPK" in MODEL_FACTORY + assert "XGDP" in MODEL_FACTORY def test_load_cl_ids_from_csv() -> None: