diff --git a/examples/config/graphormer_pretraining.yaml b/examples/config/graphormer_pretraining.yaml new file mode 100644 index 0000000..d740d34 --- /dev/null +++ b/examples/config/graphormer_pretraining.yaml @@ -0,0 +1,56 @@ +callbacks: + patience: 100 + tol: 0 +data: + baseMVA: 100 + learn_mask: false + mask_dim: 6 + mask_ratio: 0.5 + mask_type: rnd + mask_value: -1.0 + networks: + # - Texas2k_case1_2016summerpeak + - case24_ieee_rts + - case118_ieee + - case300_ieee + normalization: baseMVAnorm + scenarios: + - 5000 + - 5000 + - 5000 + test_ratio: 0.1 + val_ratio: 0.1 + workers: 4 + add_graphormer_encoding: true + max_node_num: 300 # necessary for Graphormer + max_hops: 6 # for the edge encoding, should match + edge_type: multi_hop # singlehop +model: + attention_head: 8 + dropout: 0.1 + edge_dim: 2 + hidden_size: 123 + input_dim: 9 + num_layers: 14 + output_dim: 6 + pe_dim: 20 + type: Graphormer #GPSTransformer # +optimizer: + beta1: 0.9 + beta2: 0.999 + learning_rate: 0.0001 + lr_decay: 0.7 + lr_patience: 10 +seed: 0 +training: + batch_size: 8 + epochs: 500 + loss_weights: + - 0.01 + - 0.99 + losses: + - MaskedMSE + - PBE + accelerator: auto + devices: auto + strategy: auto diff --git a/gridfm_graphkit/datasets/powergrid_datamodule.py b/gridfm_graphkit/datasets/powergrid_datamodule.py index c18c360..ff796a8 100644 --- a/gridfm_graphkit/datasets/powergrid_datamodule.py +++ b/gridfm_graphkit/datasets/powergrid_datamodule.py @@ -128,6 +128,7 @@ def setup(self, stage: str): pe_dim=self.args.model.pe_dim, mask_dim=self.args.data.mask_dim, transform=get_transform(args=self.args), + args=self.args.data ) self.datasets.append(dataset) diff --git a/gridfm_graphkit/datasets/powergrid_dataset.py b/gridfm_graphkit/datasets/powergrid_dataset.py index 026d9a8..309800a 100644 --- a/gridfm_graphkit/datasets/powergrid_dataset.py +++ b/gridfm_graphkit/datasets/powergrid_dataset.py @@ -2,6 +2,7 @@ from gridfm_graphkit.datasets.transforms import ( AddEdgeWeights, 
def __cat_dim__(self, key, value, *args, **kwargs):
    """Select the batching dimension for attribute ``key``.

    Returns ``None`` for the Graphormer tensors listed below and defers to
    the parent class for every other key.

    NOTE(review): PyG appears to look this hook up on the data object rather
    than on the dataset -- confirm this override is actually reached.
    NOTE(review): ``AddGraphormerEncodings`` also attaches ``attn_edge_type``,
    which is not listed here -- confirm whether that is intended.
    """
    graphormer_keys = ("attn_bias", "spatial_pos", "in_degree", "edge_input")
    if key not in graphormer_keys:
        return super().__cat_dim__(key, value, *args, **kwargs)
    return None
def add_node_attr(data: "Data", value: Any, attr_name: str) -> "Data":
    """Store ``value`` on ``data`` under the key ``attr_name``.

    Args:
        data: Object supporting item assignment (``data[attr_name] = value``).
        value: The value to attach.
        attr_name: Key under which ``value`` is stored.

    Returns:
        The same ``data`` object, after the assignment.

    Raises:
        ValueError: If ``attr_name`` is ``None``.
    """
    if attr_name is None:
        raise ValueError("Expected attr_name to be not None")
    data[attr_name] = value
    return data


def convert_to_single_emb(x, offset=512):
    """Map per-feature values to ids inside disjoint, per-feature id ranges.

    Values are clamped to [-512, 512] and rescaled to ``(x + 512) / 2``
    (truncated to an integer).  Feature ``f`` is then shifted by
    ``1 + f * offset`` so that every feature column addresses its own slice of
    a shared embedding table.  This range assumption may need to change in
    the future.

    Args:
        x (Tensor): 1-D tensor (a single feature) or 2-D tensor
            ``[num_items, num_features]``.
        offset (int): Width of the id range reserved for each feature.

    Returns:
        Tensor: ``long`` tensor of ids with the same shape as ``x``.
    """
    x = torch.clamp(x, -512, 512)
    x = (512 * (x + 512) / 1024).long()
    # The clamped maximum rescales to exactly ``offset`` (for the default 512),
    # which would collide with id 0 of the *next* feature.  Keep every id
    # strictly inside its own range.
    x = x.clamp(max=offset - 1)

    feature_num = x.size(1) if x.dim() > 1 else 1
    feature_offset = 1 + torch.arange(
        0,
        feature_num * offset,
        offset,
        dtype=torch.long,
        device=x.device,  # keep the shift on the same device as the input
    )
    return x + feature_offset


def get_edge_encoding(edge_attr, N, edge_index, max_dist, path):
    """Build the dense edge-type matrix and the per-path edge features.

    Args:
        edge_attr (Tensor): Edge features, ``[num_edges]`` or
            ``[num_edges, num_features]``.
        N (int): Number of nodes.
        edge_index (Tensor): ``[2, num_edges]`` source/target indices.
        max_dist (int): Maximum path length, forwarded to
            ``algos.gen_edge_input``.
        path: Intermediate-node matrix returned by ``algos.floyd_warshall``.

    Returns:
        tuple: ``(attn_edge_type, edge_input)`` -- a ``[N, N, num_features]``
        ``long`` tensor of edge ids (0 where there is no edge) and the
        ``long`` tensor produced from ``algos.gen_edge_input``.  Both live on
        the CPU.
    """
    # Everything below is CPU/NumPy work; ``.cpu()`` is a no-op for CPU inputs.
    edge_attr = edge_attr.cpu()
    edge_index = edge_index.cpu()
    if edge_attr.dim() == 1:
        edge_attr = edge_attr[:, None]

    attn_edge_type = torch.zeros([N, N, edge_attr.size(-1)], dtype=torch.long)
    # NOTE(review): ``.long()`` truncates toward zero, so fractional edge
    # features lose their fractional part before being embedded -- confirm
    # this is intended for the (presumably normalised) inputs.
    # The extra +1 keeps id 0 free for "no edge".
    attn_edge_type[edge_index[0, :], edge_index[1, :]] = (
        convert_to_single_emb(edge_attr.long()) + 1
    )

    # On Windows the Cython routines are called with an explicit 32-bit type.
    extra = {"localtype": np.int32} if os.name == "nt" else {}
    edge_input = algos.gen_edge_input(
        max_dist, path, attn_edge_type.numpy(), **extra
    )
    return attn_edge_type, torch.from_numpy(edge_input).long()


def preprocess_item(data, max_hops):
    """Compute the Graphormer structural inputs for one graph.

    From a Data-like object the shortest paths (in number of hops) between
    all node pairs are computed with ``algos.floyd_warshall``, cut off at
    ``max_hops``.  Together with the node degrees and a zero-initialised
    attention bias these are the inputs to the model's structural and
    positional encodings.

    Args:
        data: Object exposing ``edge_index``, ``edge_attr``, ``num_nodes``
            and ``x`` (only ``x.device`` is read).
        max_hops (int): Hop cut-off passed to the shortest-path routine.

    Returns:
        tuple: ``(attn_bias, spatial_pos, in_degree, out_degree,
        attn_edge_type, edge_input)``.  The last two are ``None`` when
        ``data.edge_attr`` is ``None``.
    """
    edge_index = data.edge_index
    edge_attr = data.edge_attr
    N = data.num_nodes
    device = data.x.device

    # The shortest-path routine consumes NumPy arrays, so the adjacency is
    # assembled on the CPU (``.numpy()`` fails for non-CPU tensors).
    cpu_edge_index = edge_index.cpu()
    edge_adj = torch.sparse_coo_tensor(
        cpu_edge_index,
        torch.ones(cpu_edge_index.shape[1]),
        [N, N],
    )
    # Duplicate (parallel) edges are summed by ``to_dense``: ``adj`` therefore
    # holds edge multiplicities.
    adj = edge_adj.to_dense().to(torch.int16)
    # Hop distances need a 0/1 matrix -- a multiplicity of 2 would otherwise
    # be read as "two hops apart" for directly connected nodes.
    hop_adj = (adj > 0).numpy().astype(np.int32)

    # Shortest paths in hops (shortest_path_result) and the intermediate
    # nodes of those paths (path).
    extra = {"localtype": np.int32} if os.name == "nt" else {}
    shortest_path_result, path = algos.floyd_warshall(hop_adj, max_hops, **extra)

    spatial_pos = torch.from_numpy(shortest_path_result).long().to(device)
    attn_bias = torch.zeros([N, N], dtype=torch.float, device=device)

    if edge_attr is not None:
        attn_edge_type, edge_input = get_edge_encoding(
            edge_attr, N, edge_index, max_hops, path
        )
    else:
        attn_edge_type = None
        edge_input = None

    # Degrees keep counting every listed edge, as before.
    adj_counts = adj.long().to(device)
    in_degree = adj_counts.sum(dim=1).view(-1)
    out_degree = adj_counts.sum(dim=0).view(-1)
    return attn_bias, spatial_pos, in_degree, out_degree, attn_edge_type, edge_input


def pad_1d_unsqueeze(x, padlen):
    """Zero-pad a 1-D tensor to length ``padlen`` and add a leading axis.

    Inputs that are already at least ``padlen`` long are only unsqueezed.
    """
    xlen = x.size(0)
    if xlen < padlen:
        padded = x.new_zeros([padlen])
        padded[:xlen] = x
        x = padded
    return x.unsqueeze(0)


def pad_attn_bias_unsqueeze(x, padlen):
    """Pad a square attention bias to ``padlen`` and add a leading axis.

    Padded columns are filled with ``-inf``; padded rows are 0 in the
    original columns.  Inputs that are already at least ``padlen`` wide are
    only unsqueezed.
    """
    xlen = x.size(0)
    if xlen < padlen:
        padded = x.new_full((padlen, padlen), float("-inf"))
        padded[:xlen, :xlen] = x
        padded[xlen:, :xlen] = 0
        x = padded
    return x.unsqueeze(0)


def pad_edge_bias_unsqueeze(x, padlen):
    """Zero-pad the two leading (node) axes to ``padlen`` and add a leading axis.

    Trailing axes are kept unchanged.  Inputs whose first axis is already at
    least ``padlen`` long are only unsqueezed.
    """
    xlen = x.size(0)
    if xlen < padlen:
        padded = x.new_zeros((padlen, padlen) + tuple(x.size()[2:]))
        padded[:xlen, :xlen] = x
        x = padded
    return x.unsqueeze(0)
def pad_spatial_pos_unsqueeze(x, padlen):
    """Zero-pad a square matrix to ``padlen`` x ``padlen`` and add a leading axis.

    Inputs that are already at least ``padlen`` wide are only unsqueezed.
    """
    xlen = x.size(0)
    if xlen < padlen:
        padded = x.new_zeros([padlen, padlen])
        padded[:xlen, :xlen] = x
        x = padded
    return x.unsqueeze(0)


class AddGraphormerEncodings(BaseTransform):
    """Adds a positional encoding (node centrality) to the given graph, as
    well as the attention and edge biases, as described in: Do transformers
    really perform badly for graph representation?, C. Ying et al., 2021.

    The transform attaches ``attn_bias``, ``spatial_pos``, ``in_degree``,
    ``edge_input`` and ``attn_edge_type`` to the graph, each padded to
    ``max_node_num`` nodes.

    Args:
        max_node_num (int): The number of nodes in the largest graph considered.
        max_hops (int): The maximum path length between nodes to consider for
            the edge encodings.
    """

    def __init__(
        self,
        max_node_num: int,
        max_hops: int,
    ) -> None:
        self.max_node_num = max_node_num
        self.max_hops = max_hops

    def forward(self, data: Data) -> Data:
        """Compute, pad and attach the Graphormer encodings.

        Raises:
            ValueError: If ``data.edge_index`` or ``data.num_nodes`` is
                ``None``, or if the graph has more than ``max_node_num`` nodes.
        """
        if data.edge_index is None:
            raise ValueError("Expected data.edge_index to be not None")

        N = data.num_nodes
        if N is None:
            raise ValueError("Expected data.num_nodes to be not None")
        if N > self.max_node_num:
            # The pad helpers leave over-sized inputs unpadded, which would
            # hand differently shaped tensors to the batching code.
            raise ValueError(
                f"Graph has {N} nodes but max_node_num is {self.max_node_num}"
            )

        # The out-degree is computed by preprocess_item but not attached.
        attn_bias, spatial_pos, in_degree, _, attn_edge_type, edge_input = \
            preprocess_item(data, self.max_hops)

        attn_bias = pad_attn_bias_unsqueeze(attn_bias, self.max_node_num)
        spatial_pos = pad_spatial_pos_unsqueeze(spatial_pos, self.max_node_num)
        # Drop only the axis added by the pad helper so a one-node padding
        # length still yields a 1-D tensor.
        in_degree = pad_1d_unsqueeze(in_degree, self.max_node_num).squeeze(0)

        # preprocess_item returns None for both edge encodings when the graph
        # carries no edge features; padding None would raise.
        if edge_input is not None:
            edge_input = pad_edge_bias_unsqueeze(edge_input, self.max_node_num)
        if attn_edge_type is not None:
            attn_edge_type = pad_edge_bias_unsqueeze(attn_edge_type, self.max_node_num)

        data = add_node_attr(data, attn_bias, attr_name='attn_bias')
        data = add_node_attr(data, spatial_pos, attr_name='spatial_pos')
        data = add_node_attr(data, in_degree, attr_name='in_degree')
        data = add_node_attr(data, edge_input, attr_name='edge_input')
        data = add_node_attr(data, attn_edge_type, attr_name='attn_edge_type')

        return data
def floyd_warshall(adjacency_matrix, max_hops, localtype=long):
    """All-pairs shortest hop distances, cut off at ``max_hops``.

    Args:
        adjacency_matrix: Square array; non-zero entries are edge costs and
            zero off-diagonal entries mean "no edge".
        max_hops: Candidate paths costing more than this are ignored.
        localtype: dtype used for the working arrays.

    Returns:
        ``(M, path)``: ``M[i][j]`` is the distance (510 when no path within
        ``max_hops`` exists).  ``path[i][j]`` is an intermediate node on the
        shortest path, -1 when the path is a direct edge, or 510 when
        ``M[i][j]`` is 510.
    """
    (nrows, ncols) = adjacency_matrix.shape
    assert nrows == ncols
    cdef unsigned int n = nrows
    cdef unsigned int max_hops_copy = max_hops

    adj_mat_copy = adjacency_matrix.astype(localtype, order='C', casting='safe', copy=True)
    assert adj_mat_copy.flags['C_CONTIGUOUS']
    cdef numpy.ndarray[long, ndim=2, mode='c'] M = adj_mat_copy
    # -1 marks "no intermediate node".  0 cannot serve as the marker because
    # node 0 is itself a valid intermediate node.
    cdef numpy.ndarray[long, ndim=2, mode='c'] path = -1 * numpy.ones([n, n], dtype=localtype)

    cdef unsigned int i, j, k
    cdef long M_ij, M_ik, cost_ikkj
    cdef long* M_ptr = &M[0,0]
    cdef long* M_i_ptr
    cdef long* M_k_ptr

    # set unreachable nodes distance to 510
    for i in range(n):
        for j in range(n):
            if i == j:
                M[i][j] = 0
            elif M[i][j] == 0:
                M[i][j] = 510

    # Floyd-Warshall relaxation
    for k in range(n):
        M_k_ptr = M_ptr + n*k
        for i in range(n):
            M_i_ptr = M_ptr + n*i
            M_ik = M_i_ptr[k]
            for j in range(n):
                cost_ikkj = M_ik + M_k_ptr[j]
                M_ij = M_i_ptr[j]
                if cost_ikkj > max_hops_copy:
                    continue
                if M_ij > cost_ikkj:
                    M_i_ptr[j] = cost_ikkj
                    path[i][j] = k

    # set unreachable path to 510
    for i in range(n):
        for j in range(n):
            if M[i][j] >= 510:
                path[i][j] = 510
                M[i][j] = 510

    return M, path


def get_all_edges(path, i, j):
    """Return the intermediate nodes on the shortest path from ``i`` to ``j``.

    ``path`` is the matrix returned by ``floyd_warshall``; an entry of -1
    means the two nodes are joined directly, so the result is empty.
    """
    cdef long k = path[i][j]
    if k == -1:
        return []
    else:
        return get_all_edges(path, i, k) + [k] + get_all_edges(path, k, j)


def gen_edge_input(max_dist, path, edge_feat, localtype=long):
    """Collect the edge features along every shortest path.

    Args:
        max_dist: Size of the path-position axis of the result.
        path: Intermediate-node matrix returned by ``floyd_warshall``.
        edge_feat: ``[n, n, num_features]`` array of per-edge features.
        localtype: dtype of the returned array.

    Returns:
        ``[n, n, max_dist, num_features]`` array; entry ``[i, j, k]`` holds
        the features of the ``k``-th edge on the path from ``i`` to ``j`` and
        is -1 where no such edge exists (``i == j``, unreachable pairs, or
        positions past the end of the path).
    """
    (nrows, ncols) = path.shape
    assert nrows == ncols
    cdef unsigned int n = nrows
    cdef unsigned int max_dist_copy = max_dist

    path_copy = path.astype(long, order='C', casting='safe', copy=True)
    edge_feat_copy = edge_feat.astype(long, order='C', casting='safe', copy=True)
    assert path_copy.flags['C_CONTIGUOUS']
    assert edge_feat_copy.flags['C_CONTIGUOUS']

    cdef numpy.ndarray[long, ndim=4, mode='c'] edge_fea_all = -1 * numpy.ones([n, n, max_dist_copy, edge_feat.shape[-1]], dtype=localtype)
    cdef unsigned int i, j, k, num_path, cur

    for i in range(n):
        for j in range(n):
            if i == j:
                continue
            if path_copy[i][j] == 510:
                continue
            # Full node sequence i -> ... -> j (kept in its own name so the
            # ``path`` argument is not rebound).
            node_seq = [i] + get_all_edges(path_copy, i, j) + [j]
            num_path = len(node_seq) - 1
            for k in range(num_path):
                edge_fea_all[i, j, k, :] = edge_feat_copy[node_seq[k], node_seq[k+1], :]

    return edge_fea_all
@MODELS_REGISTRY.register("Graphormer")
class Graphormer(nn.Module):
    """
    A Graph Transformer model based on the Graphormer architecture.

    This model directly modifies the attention between nodes based on
    its graph encodings. This requires padding the input nodes and propagating
    the associated mask as needed.

    Args:
        args (NestedNamespace): Parameters

    Attributes:
        n_node_features (int): Dimension of input node features. From ``args.model.input_dim``.
        hidden_dim (int): Hidden dimension size for all layers. From ``args.model.hidden_size``.
        output_dim (int): Dimension of the output node features. From ``args.model.output_dim``.
        n_encoder_layers (int): Number of transformer blocks. From ``args.model.num_layers``.
        num_heads (int): Number of attention heads. From ``args.model.attention_head``.
        n_edge_features (int): Dimension of edge features. From ``args.model.edge_dim``.
        dropout (float, optional): Dropout rate in attention blocks. From ``args.model.dropout``. Defaults to 0.0.
        mask_dim (int, optional): Dimension of the mask vector. From ``args.data.mask_dim``. Defaults to 6.
        mask_value (nn.Parameter): Mask values, initialised from ``args.data.mask_value`` (default -1.0).
        learn_mask (bool, optional): Whether to learn mask values as parameters. From ``args.data.learn_mask``. Defaults to False.
        edge_type (string, optional): ``"multi_hop"`` or any other value for single-hop edge encoding.
            From ``args.data.edge_type``, falling back to ``args.model.edge_type``. Defaults to ``"multi_hop"``.
        multi_hop_max_dist (int, optional): Maximum number of hops to consider at edges. From ``args.data.max_hops``. Defaults to 20.
        max_node_num (int, optional): Maximum number of nodes in the input graphs. From ``args.data.max_node_num``. Defaults to 24.
    """
    def __init__(self, args):
        super().__init__()

        self.n_node_features = args.model.input_dim
        self.hidden_dim = args.model.hidden_size
        self.output_dim = args.model.output_dim
        self.n_encoder_layers = args.model.num_layers
        self.num_heads = args.model.attention_head
        self.n_edge_features = args.model.edge_dim
        self.dropout = getattr(args.model, "dropout", 0.0)
        self.mask_dim = getattr(args.data, "mask_dim", 6)
        self.learn_mask = getattr(args.data, "learn_mask", False)
        # The example config and the class docs place ``edge_type`` in the
        # data section; the model section is still honoured as a fallback.
        self.edge_type = getattr(
            args.data,
            "edge_type",
            getattr(args.model, "edge_type", "multi_hop"),
        )
        self.multi_hop_max_dist = getattr(args.data, "max_hops", 20)
        self.max_node_num = getattr(args.data, "max_node_num", 24)

        initial_mask_value = getattr(args.data, "mask_value", -1.0)
        if self.learn_mask:
            self.mask_value = nn.Parameter(
                torch.randn(self.mask_dim) + initial_mask_value,
                requires_grad=True,
            )
        else:
            self.mask_value = nn.Parameter(
                torch.zeros(self.mask_dim) + initial_mask_value,
                requires_grad=False,
            )

        # model layers
        self.input_proj = nn.Linear(self.n_node_features, self.hidden_dim)
        self.input_dropout = nn.Dropout(self.dropout)
        self.encoder_layers = nn.ModuleList(
            [
                EncoderLayer(
                    self.hidden_dim,
                    self.hidden_dim,
                    self.dropout,
                    self.num_heads,
                )
                for _ in range(self.n_encoder_layers)
            ]
        )
        self.encoder_final_ln = nn.LayerNorm(self.hidden_dim)

        self.decoder_layers = nn.Sequential(
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.LeakyReLU(),
            nn.LayerNorm(self.hidden_dim),
            nn.Linear(self.hidden_dim, self.output_dim),
        )

        # for positional embeddings
        self.spatial_pos_encoder = nn.Embedding(
            512, self.num_heads, padding_idx=0)
        self.in_degree_encoder = nn.Embedding(
            512, self.hidden_dim, padding_idx=0)
        self.out_degree_encoder = nn.Embedding(
            512, self.hidden_dim, padding_idx=0)
        if self.n_edge_features is not None:
            self.edge_encoder = nn.Embedding(
                512 * self.n_edge_features + 1, self.num_heads, padding_idx=0)
            if self.edge_type == 'multi_hop':
                # One [num_heads, num_heads] matrix per path position (up to 128).
                self.edge_dis_encoder = nn.Embedding(
                    128 * self.num_heads * self.num_heads, 1)

    def compute_pos_embeddings(self, data, x):
        """
        Calculate Graphormer positional encodings, and attention biases

        Args:
            data (Data): Pytorch geometric Data/Batch object carrying
                ``attn_bias``, ``spatial_pos``, ``in_degree`` and, optionally,
                ``edge_input`` / ``attn_edge_type``.
            x (Tensor): The dense (padded) node feature tensor,
                ``[n_graph, n_node, input_dim]``.

        Returns:
            graph_node_feature (Tensor): projected node features with the
                degree embeddings added.
            graph_attn_bias (Tensor): attention bias terms,
                ``[n_graph, n_head, n_node, n_node]``.
        """
        attn_bias, spatial_pos = data.attn_bias, data.spatial_pos
        # Only ``in_degree`` is read from the batch; it feeds both degree
        # embeddings (the graphs are treated as undirected).
        in_degree = data.in_degree
        out_degree = in_degree

        # [n_graph, n_node, n_node] -> [n_graph, n_head, n_node, n_node]
        graph_attn_bias = attn_bias.clone().unsqueeze(1).repeat(
            1, self.num_heads, 1, 1)

        # spatial pos
        # [n_graph, n_node, n_node, n_head] -> [n_graph, n_head, n_node, n_node]
        spatial_pos_bias = self.spatial_pos_encoder(spatial_pos).permute(0, 3, 1, 2)
        graph_attn_bias = graph_attn_bias + spatial_pos_bias

        # Tolerate batches that carry no edge encodings at all.
        edge_input = getattr(data, "edge_input", None)
        if (edge_input is not None) and (self.edge_type is not None):
            # edge feature
            if self.edge_type == 'multi_hop':
                n_graph, n_node = edge_input.size()[:2]
                # ``spatial_pos`` holds raw hop counts (0 for padding and for
                # i == j), so the number of edges on a path is the hop count
                # itself.  Use at least 1 to avoid dividing by zero.
                path_len = spatial_pos.clamp(min=1)
                if self.multi_hop_max_dist > 0:
                    path_len = path_len.clamp(0, self.multi_hop_max_dist)
                    edge_input = edge_input[:, :, :, :self.multi_hop_max_dist, :]
                # [n_graph, n_node, n_node, max_dist, n_head]
                # (+1 turns the -1 "no edge" filler into padding index 0)
                edge_input = self.edge_encoder(edge_input + 1).mean(-2)
                max_dist = edge_input.size(-2)
                edge_input_flat = edge_input.permute(
                    3, 0, 1, 2, 4).reshape(max_dist, -1, self.num_heads)
                position_weights = self.edge_dis_encoder.weight.reshape(
                    -1, self.num_heads, self.num_heads)[:max_dist, :, :]
                edge_input_flat = torch.bmm(edge_input_flat, position_weights)
                edge_input = edge_input_flat.reshape(
                    max_dist, n_graph, n_node, n_node, self.num_heads
                ).permute(1, 2, 3, 0, 4)
                # average over the edges of each path
                edge_input = (
                    edge_input.sum(-2) / path_len.float().unsqueeze(-1)
                ).permute(0, 3, 1, 2)
            else:
                # [n_graph, n_node, n_node, n_head] -> [n_graph, n_head, n_node, n_node]
                edge_input = self.edge_encoder(
                    data.attn_edge_type).mean(-2).permute(0, 3, 1, 2)

            graph_attn_bias = graph_attn_bias + edge_input

        graph_attn_bias = graph_attn_bias + attn_bias.unsqueeze(1)  # reset

        node_feature = self.input_proj(x)
        graph_node_feature = node_feature.flatten(0, 1) + \
            self.in_degree_encoder(in_degree) + \
            self.out_degree_encoder(out_degree)
        graph_node_feature = graph_node_feature.reshape(node_feature.size())

        return graph_node_feature, graph_attn_bias

    def encoder(self, graph_node_feature, graph_attn_bias, mask=None):
        """Run the encoder stack and the final LayerNorm.

        Args:
            graph_node_feature (Tensor): Dense node features.
            graph_attn_bias (Tensor): Attention bias passed to every layer.
            mask (Tensor, optional): Boolean mask, True at padded positions.
                Padded rows are excluded from the final LayerNorm.  With
                ``None`` every row is normalised.
        """
        output = self.input_dropout(graph_node_feature)
        for enc_layer in self.encoder_layers:
            output = enc_layer(output, graph_attn_bias, mask=mask)
        if mask is None:
            return self.encoder_final_ln(output)
        valid = ~mask
        output[valid] = self.encoder_final_ln(output[valid])
        return output

    def decoder(self, x):
        """Map encoded node features to the output dimension."""
        return self.decoder_layers(x)

    def forward(self, x, pe=None, edge_index=None, edge_attr=None, batch=None, data=None):
        """
        Forward pass for Graphormer.

        Args:
            x (Tensor): Input node features of shape [num_nodes, input_dim].
            pe (Tensor): Positional encoding of shape [num_nodes, pe_dim]. Unused.
            edge_index (Tensor): Edge indices for graph convolution. Unused.
            edge_attr (Tensor): Edge feature tensor. Unused.
            batch (Tensor): Batch vector assigning nodes to graphs.
            data (Data): Pytorch Geometric Data/Batch object carrying the
                Graphormer encodings. Required.

        Returns:
            output (Tensor): Output node features of shape [num_nodes, output_dim].

        Raises:
            ValueError: If ``data`` is ``None``.
        """
        if data is None:
            raise ValueError(
                "Graphormer.forward requires `data` with the Graphormer encodings"
            )

        x, valid_nodes = to_dense_batch(x, batch, max_num_nodes=self.max_node_num)
        mask = ~valid_nodes  # True at padded positions

        graph_node_feature, graph_attn_bias = self.compute_pos_embeddings(data, x)
        output = self.encoder(graph_node_feature, graph_attn_bias, mask=mask)
        # Keep only the real nodes, restoring the [num_nodes, ...] layout.
        output = self.decoder(output[valid_nodes])

        return output
class FeedForwardNetwork(nn.Module):
    """Two linear layers with a GELU in between (hidden -> ffn -> hidden)."""

    def __init__(self, hidden_size, ffn_size):
        super().__init__()

        self.layer1 = nn.Linear(hidden_size, ffn_size)
        self.gelu = nn.GELU()
        self.layer2 = nn.Linear(ffn_size, hidden_size)

    def forward(self, x):
        return self.layer2(self.gelu(self.layer1(x)))


class MultiHeadAttention(nn.Module):
    """
    This is a slight modification of vanilla attention, to allow masking
    of buffer nodes, and the addition of biases to the attention mechanism.
    """
    def __init__(self, hidden_size, attention_dropout_rate, num_heads):
        super().__init__()

        self.num_heads = num_heads

        # Integer division: the per-head projections may be narrower than
        # hidden_size; output_layer maps back to hidden_size.
        self.att_size = att_size = hidden_size // num_heads
        self.scale = att_size ** -0.5

        self.linear_q = nn.Linear(hidden_size, num_heads * att_size)
        self.linear_k = nn.Linear(hidden_size, num_heads * att_size)
        self.linear_v = nn.Linear(hidden_size, num_heads * att_size)
        self.att_dropout = nn.Dropout(attention_dropout_rate)

        self.output_layer = nn.Linear(num_heads * att_size, hidden_size)

    def forward(self, q, k, v, attn_bias=None, mask=None):
        """Biased, masked scaled dot-product attention.

        Args:
            q, k, v (Tensor): ``[batch, length, hidden_size]``.
            attn_bias (Tensor, optional): Added to the attention scores;
                must broadcast against ``[batch, heads, q_len, k_len]``.
            mask (Tensor, optional): ``[batch, length]`` boolean mask, True
                at positions to ignore.

        Returns:
            Tensor with the same shape as ``q``.
        """
        orig_q_size = q.size()

        d_k = self.att_size
        d_v = self.att_size
        batch_size = q.size(0)

        # head_i = Attention(Q(W^Q)_i, K(W^K)_i, V(W^V)_i)
        q = self.linear_q(q).view(batch_size, -1, self.num_heads, d_k)
        k = self.linear_k(k).view(batch_size, -1, self.num_heads, d_k)
        v = self.linear_v(v).view(batch_size, -1, self.num_heads, d_v)

        q = q.transpose(1, 2)                  # [b, h, q_len, d_k]
        v = v.transpose(1, 2)                  # [b, h, v_len, d_v]
        k = k.transpose(1, 2).transpose(2, 3)  # [b, h, d_k, k_len]

        # Scaled Dot-Product Attention.
        # Attention(Q, K, V) = softmax((QK^T)/sqrt(d_k))V
        scores = torch.matmul(q * self.scale, k)  # [b, h, q_len, k_len]

        key_mask = None
        if mask is not None:
            key_mask = mask.unsqueeze(1).unsqueeze(2)    # [b, 1, 1, k_len]

        if attn_bias is not None:
            if mask is not None:
                query_mask = mask.unsqueeze(1).unsqueeze(3)  # [b, 1, q_len, 1]
                # Zero the bias wherever a padded node is involved.
                attn_bias = attn_bias.masked_fill(query_mask == 1, 0.0)
                attn_bias = attn_bias.masked_fill(key_mask == 1, 0.0)
            scores = scores + attn_bias

        # mask the padded keys before the softmax
        if key_mask is not None:
            scores = scores.masked_fill(key_mask == 1, -1e9)

        weights = self.att_dropout(torch.softmax(scores, dim=3))
        out = weights.matmul(v)                 # [b, h, q_len, attn]

        out = out.transpose(1, 2).contiguous()  # [b, q_len, h, attn]
        out = out.view(batch_size, -1, self.num_heads * d_v)
        out = self.output_layer(out)

        assert out.size() == orig_q_size
        return out


def _norm_valid_rows(norm, x, mask):
    """Apply ``norm`` to the unmasked rows of ``x`` without modifying ``x``.

    Masked rows are copied through unchanged; with ``mask=None`` every row
    is normalised.
    """
    if mask is None:
        return norm(x)
    valid = ~mask
    out = x.clone()
    out[valid] = norm(x[valid])
    return out


class EncoderLayer(nn.Module):
    """Pre-LayerNorm transformer block: biased self-attention, then a
    feed-forward network, each wrapped in a residual connection."""

    def __init__(self, hidden_size, ffn_size, dropout_rate, num_heads):
        super().__init__()

        self.self_attention_norm = nn.LayerNorm(hidden_size)
        self.self_attention = MultiHeadAttention(
            hidden_size, dropout_rate, num_heads)
        self.self_attention_dropout = nn.Dropout(dropout_rate)

        self.ffn_norm = nn.LayerNorm(hidden_size)
        self.ffn = FeedForwardNetwork(hidden_size, ffn_size)
        self.ffn_dropout = nn.Dropout(dropout_rate)

    def forward(self, x, attn_bias=None, mask=None):
        """
        It is assumed that the mask is 1 where values are to be ignored
        and then 0 where there are valid data.

        The input tensor ``x`` is not modified.  Masked rows skip both
        LayerNorms.
        """
        # The normalised copy feeds attention; the residual uses the
        # untouched input.
        y = _norm_valid_rows(self.self_attention_norm, x, mask)
        # attn_bias is passed through as-is: squeezing it would drop a
        # size-1 head or batch axis and break broadcasting.
        y = self.self_attention(y, y, y, attn_bias, mask)
        y = self.self_attention_dropout(y)
        x = x + torch.reshape(y, x.size())

        y = _norm_valid_rows(self.ffn_norm, x, mask)
        y = self.ffn(y)
        y = self.ffn_dropout(y)
        return x + y
def forward(self, x, pe, edge_index, edge_attr, batch, mask=None, data=None):
    """Apply the feature mask to ``x`` and run the wrapped model.

    Args:
        x: Node feature tensor; modified in place where ``mask`` is set.
        pe: Positional encoding, forwarded to the model.
        edge_index: Edge indices, forwarded to the model.
        edge_attr: Edge features, forwarded to the model.
        batch: Batch vector, forwarded to the model.
        mask: Optional boolean mask over the leading ``mask.shape[1]``
            feature columns of ``x``; masked entries are overwritten with
            ``self.model.mask_value``.
        data: Optional full batch object.  Forwarded as ``data=`` only to
            models whose ``forward`` accepts it.

    Returns:
        Whatever the wrapped model returns.
    """
    import inspect

    if mask is not None:
        mask_value_expanded = self.model.mask_value.expand(x.shape[0], -1)
        x[:, : mask.shape[1]][mask] = mask_value_expanded[mask]

    # Not every model takes the extra ``data`` argument; passing it
    # unconditionally raises TypeError for those that keep the five-argument
    # signature.
    forward_params = inspect.signature(self.model.forward).parameters
    accepts_data = "data" in forward_params or any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in forward_params.values()
    )
    if accepts_data:
        return self.model(x, pe, edge_index, edge_attr, batch, data=data)
    return self.model(x, pe, edge_index, edge_attr, batch)