From 99c18e1a68f5d8059f729f09e5a675c4bdd43028 Mon Sep 17 00:00:00 2001 From: nithinmanoj10 Date: Mon, 6 May 2024 22:11:37 +0530 Subject: [PATCH 1/7] =?UTF-8?q?=F0=9F=A7=B9=20Ruff=20lint=20fixed=20stgrap?= =?UTF-8?q?h=5Fbase.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ruff.toml | 1 + stgraph/graph/STGraphBase.py | 79 ----------------------- stgraph/graph/__init__.py | 2 +- stgraph/graph/dynamic/DynamicGraph.py | 23 ++++--- stgraph/graph/static/StaticGraph.py | 2 +- stgraph/graph/stgraph_base.py | 90 +++++++++++++++++++++++++++ 6 files changed, 108 insertions(+), 89 deletions(-) delete mode 100644 stgraph/graph/STGraphBase.py create mode 100644 stgraph/graph/stgraph_base.py diff --git a/ruff.toml b/ruff.toml index 80e82b09..58ba992b 100644 --- a/ruff.toml +++ b/ruff.toml @@ -21,6 +21,7 @@ ignore = [ "D211", "D212", "D213", + "PIE790", ] [lint.per-file-ignores] diff --git a/stgraph/graph/STGraphBase.py b/stgraph/graph/STGraphBase.py deleted file mode 100644 index cb9cb381..00000000 --- a/stgraph/graph/STGraphBase.py +++ /dev/null @@ -1,79 +0,0 @@ -from abc import ABC, abstractmethod - - -class STGraphBase(ABC): - r"""An abstract base class used to represent graphs in STGraph - - This abstract class outlines the interface for defining different types of graphs - used in STGraph. It provides the basic structure and methods for graph classes. - Subclasses should implement the abstract methods to provide specific graph functionality. 
- - Attributes - ---------- - - fwd_row_offset_ptr - Pointer to the forward graphs row offset array - - fwd_column_indices_ptr - Pointer to the forward graphs column indices array - - fwd_eids_ptr - Pointer to the forward graphs edge ID array - - fwd_node_ids_ptr - Pointer to the forward graphs node ID array - - bwd_row_offset_ptr - Pointer to the backward graphs row offset array - - bwd_column_indices_ptr - Pointer to the backward graphs column indices array - - bwd_eids_ptr - Pointer to the backward graphs edge ID array - - bwd_node_ids_ptr - Pointer to the backward graphs node ID array - - """ - - def __init__(self): - self._ndata = {} - - self._forward_graph = None - self._backward_graph = None - - self.fwd_row_offset_ptr = None - self.fwd_column_indices_ptr = None - self.fwd_eids_ptr = None - self.fwd_node_ids_ptr = None - - self.bwd_row_offset_ptr = None - self.bwd_column_indices_ptr = None - self.bwd_eids_ptr = None - self.bwd_node_ids_ptr = None - - @abstractmethod - def _get_graph_csr_ptrs(self): - pass - - @abstractmethod - def get_num_nodes(self): - pass - - @abstractmethod - def get_num_edges(self): - pass - - @abstractmethod - def get_ndata(self, field): - pass - - @abstractmethod - def set_ndata(self, field, val): - pass - - @property - @abstractmethod - def graph_type(self): - pass diff --git a/stgraph/graph/__init__.py b/stgraph/graph/__init__.py index 10d2b940..09c94606 100644 --- a/stgraph/graph/__init__.py +++ b/stgraph/graph/__init__.py @@ -1,4 +1,4 @@ -from stgraph.graph.STGraphBase import STGraphBase +from stgraph.graph.stgraph_base import STGraphBase from stgraph.graph.static.StaticGraph import StaticGraph diff --git a/stgraph/graph/dynamic/DynamicGraph.py b/stgraph/graph/dynamic/DynamicGraph.py index 030d4c8c..afd1af8b 100644 --- a/stgraph/graph/dynamic/DynamicGraph.py +++ b/stgraph/graph/dynamic/DynamicGraph.py @@ -1,13 +1,17 @@ -from stgraph.graph.STGraphBase import STGraphBase +from stgraph.graph.stgraph_base import STGraphBase from abc 
import abstractmethod import time + class DynamicGraph(STGraphBase): def __init__(self, edge_list, max_num_nodes): super().__init__() self.graph_updates = {} self.max_num_nodes = max_num_nodes - self.graph_attr = {str(t): (self.max_num_nodes, len(set(edge_list[t]))) for t in range(len(edge_list))} + self.graph_attr = { + str(t): (self.max_num_nodes, len(set(edge_list[t]))) + for t in range(len(edge_list)) + } # Indicates whether the graph is currently undergoing backprop self._is_backprop_state = False @@ -43,7 +47,7 @@ def _preprocess_graph_structure(self, edge_list): "add": additions, "delete": deletions, } - + def reset_graph(self): self._get_cached_graph("base") self.current_timestamp = 0 @@ -61,14 +65,14 @@ def get_graph(self, timestamp: int): raise Exception( "⏰ Invalid timestamp during STGraphBase.update_graph_forward()" ) - + if self._get_cached_graph(timestamp - 1): self.current_timestamp = timestamp - 1 while self.current_timestamp < timestamp: self._update_graph_forward() self.current_timestamp += 1 - + self.get_fwd_graph_time += time.time() - t0 def get_backward_graph(self, timestamp: int): @@ -87,7 +91,7 @@ def get_backward_graph(self, timestamp: int): while self.current_timestamp > timestamp: self._update_graph_backward() self.current_timestamp -= 1 - + self.get_bwd_graph_time += time.time() - t0 def get_num_nodes(self): @@ -95,9 +99,12 @@ def get_num_nodes(self): def get_num_edges(self): return self.graph_attr[str(self.current_timestamp)][1] - + def get_ndata(self, field): - if str(self.current_timestamp) in self._ndata and field in self._ndata[str(self.current_timestamp)]: + if ( + str(self.current_timestamp) in self._ndata + and field in self._ndata[str(self.current_timestamp)] + ): return self._ndata[str(self.current_timestamp)][field] else: return None diff --git a/stgraph/graph/static/StaticGraph.py b/stgraph/graph/static/StaticGraph.py index a7862a04..5e5b1ce4 100644 --- a/stgraph/graph/static/StaticGraph.py +++ 
b/stgraph/graph/static/StaticGraph.py @@ -7,7 +7,7 @@ console = Console() -from stgraph.graph.STGraphBase import STGraphBase +from stgraph.graph.stgraph_base import STGraphBase from stgraph.graph.static.csr import CSR diff --git a/stgraph/graph/stgraph_base.py b/stgraph/graph/stgraph_base.py new file mode 100644 index 00000000..37b07ae2 --- /dev/null +++ b/stgraph/graph/stgraph_base.py @@ -0,0 +1,90 @@ +"""Represent graphs in STGraph using this abstract base class.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod + + +class STGraphBase(ABC): + r"""Represent graphs in STGraph using this abstract base class.""" + + def __init__(self: STGraphBase) -> None: + r"""Represent graphs in STGraph using this abstract base class. + + This abstract class outlines the interface for defining different + types of graphs used in STGraph. It provides the basic structure + and methods for graph classes. Subclasses should implement the + abstract methods to provide specific graph functionality. 
+ + Attributes + ---------- + fwd_row_offset_ptr + Pointer to the forward graphs row offset array + + fwd_column_indices_ptr + Pointer to the forward graphs column indices array + + fwd_eids_ptr + Pointer to the forward graphs edge ID array + + fwd_node_ids_ptr + Pointer to the forward graphs node ID array + + bwd_row_offset_ptr + Pointer to the backward graphs row offset array + + bwd_column_indices_ptr + Pointer to the backward graphs column indices array + + bwd_eids_ptr + Pointer to the backward graphs edge ID array + + bwd_node_ids_ptr + Pointer to the backward graphs node ID array + + """ + self._ndata = {} + + self._forward_graph = None + self._backward_graph = None + + self.fwd_row_offset_ptr = None + self.fwd_column_indices_ptr = None + self.fwd_eids_ptr = None + self.fwd_node_ids_ptr = None + + self.bwd_row_offset_ptr = None + self.bwd_column_indices_ptr = None + self.bwd_eids_ptr = None + self.bwd_node_ids_ptr = None + + @abstractmethod + def _get_graph_csr_ptrs(self: STGraphBase) -> None: + r"""TODO:.""" + pass + + @abstractmethod + def get_num_nodes(self: STGraphBase) -> int: + r"""Return the number of nodes in the graph.""" + pass + + @abstractmethod + def get_num_edges(self: STGraphBase) -> int: + r"""Return the number of edges in the graph.""" + pass + + @abstractmethod + def get_ndata(self: STGraphBase, field: str) -> any: + r"""Return the graph metadata.""" + pass + + @abstractmethod + def set_ndata(self: STGraphBase, field: str, val: any) -> None: + r"""Set the graph metadata.""" + pass + + @property + @abstractmethod + def graph_type(self: STGraphBase) -> str: + r"""Return the graph type.""" + pass From 8b1c1345b4fe46025f70d5b6995160f42618abe8 Mon Sep 17 00:00:00 2001 From: nithinmanoj10 Date: Tue, 7 May 2024 21:56:16 +0530 Subject: [PATCH 2/7] =?UTF-8?q?=F0=9F=A7=B9=20Lint=20checked=20stgraph.gra?= =?UTF-8?q?ph.static=5Fgraph=20module?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- 
benchmarking/gat/seastar/train.py | 2 +- benchmarking/gcn/seastar/train.py | 6 +- .../static-temporal-tgcn/seastar/train.py | 187 +++++++++++++----- stgraph/graph/__init__.py | 8 +- stgraph/graph/static/__init__.py | 1 + .../{StaticGraph.py => static_graph.py} | 75 +++---- .../v1_1_0/gcn_dataloaders/gcn/train.py | 2 +- .../temporal_tgcn_dataloaders/tgcn/train.py | 2 +- 8 files changed, 185 insertions(+), 98 deletions(-) rename stgraph/graph/static/{StaticGraph.py => static_graph.py} (65%) diff --git a/benchmarking/gat/seastar/train.py b/benchmarking/gat/seastar/train.py index 857addc0..bd98c88d 100644 --- a/benchmarking/gat/seastar/train.py +++ b/benchmarking/gat/seastar/train.py @@ -13,7 +13,7 @@ import torch import torch.nn.functional as F import pynvml -from stgraph.graph.static.StaticGraph import StaticGraph +from stgraph.graph.static.static_graph import StaticGraph from stgraph.dataset.CoraDataLoader import CoraDataLoader from utils import EarlyStopping, accuracy import snoop diff --git a/benchmarking/gcn/seastar/train.py b/benchmarking/gcn/seastar/train.py index 0e5e9cb2..8c82203c 100644 --- a/benchmarking/gcn/seastar/train.py +++ b/benchmarking/gcn/seastar/train.py @@ -5,7 +5,7 @@ import pynvml import torch.nn as nn import torch.nn.functional as F -from stgraph.graph.static.StaticGraph import StaticGraph +from stgraph.graph.static.static_graph import StaticGraph from stgraph.dataset import CoraDataLoader from utils import to_default_device, accuracy, generate_test_mask, generate_train_mask from model import GCN @@ -47,9 +47,7 @@ def main(args): # A simple sanity check print("Measuerd Graph Size (pynvml): ", graph_mem, " B", flush=True) - print( - "Measuerd Graph Size (pynvml): ", (graph_mem) / (1024**2), " MB", flush=True - ) + print("Measuerd Graph Size (pynvml): ", (graph_mem) / (1024**2), " MB", flush=True) # normalization degs = torch.from_numpy(g.weighted_in_degrees()).type(torch.int32) diff --git a/benchmarking/static-temporal-tgcn/seastar/train.py 
b/benchmarking/static-temporal-tgcn/seastar/train.py index 1328b742..1b7d1950 100644 --- a/benchmarking/static-temporal-tgcn/seastar/train.py +++ b/benchmarking/static-temporal-tgcn/seastar/train.py @@ -9,7 +9,7 @@ import os from model import STGraphTGCN -from stgraph.graph.static.StaticGraph import StaticGraph +from stgraph.graph.static.static_graph import StaticGraph from stgraph.dataset.WindmillOutputDataLoader import WindmillOutputDataLoader from stgraph.dataset.WikiMathDataLoader import WikiMathDataLoader @@ -23,6 +23,7 @@ from rich import inspect + def main(args): if torch.cuda.is_available(): @@ -32,20 +33,63 @@ def main(args): quit() # Dummy object to account for CUDA context object - Graph = StaticGraph([(0,0)], [1], 1) - + Graph = StaticGraph([(0, 0)], [1], 1) + if args.dataset == "wiki": - dataloader = WikiMathDataLoader('static-temporal', 'wikivital_mathematics', args.feat_size, args.cutoff_time, verbose=True, for_stgraph=True) + dataloader = WikiMathDataLoader( + "static-temporal", + "wikivital_mathematics", + args.feat_size, + args.cutoff_time, + verbose=True, + for_stgraph=True, + ) elif args.dataset == "windmill": - dataloader = WindmillOutputDataLoader('static-temporal', 'windmill_output', args.feat_size, args.cutoff_time, verbose=True, for_stgraph=True) + dataloader = WindmillOutputDataLoader( + "static-temporal", + "windmill_output", + args.feat_size, + args.cutoff_time, + verbose=True, + for_stgraph=True, + ) elif args.dataset == "hungarycp": - dataloader = HungaryCPDataLoader('static-temporal', 'HungaryCP', args.feat_size, args.cutoff_time, verbose=True, for_stgraph=True) + dataloader = HungaryCPDataLoader( + "static-temporal", + "HungaryCP", + args.feat_size, + args.cutoff_time, + verbose=True, + for_stgraph=True, + ) elif args.dataset == "pedalme": - dataloader = PedalMeDataLoader('static-temporal', 'pedalme', args.feat_size, args.cutoff_time, verbose=True, for_stgraph=True) + dataloader = PedalMeDataLoader( + "static-temporal", + "pedalme", 
+ args.feat_size, + args.cutoff_time, + verbose=True, + for_stgraph=True, + ) elif args.dataset == "metrla": - dataloader = METRLADataLoader('static-temporal', 'METRLA', args.feat_size, args.feat_size, args.cutoff_time, verbose=True, for_stgraph=True) + dataloader = METRLADataLoader( + "static-temporal", + "METRLA", + args.feat_size, + args.feat_size, + args.cutoff_time, + verbose=True, + for_stgraph=True, + ) elif args.dataset == "monte": - dataloader = MontevideoBusDataLoader('static-temporal', 'montevideobus', args.feat_size, args.cutoff_time, verbose=True, for_stgraph=True) + dataloader = MontevideoBusDataLoader( + "static-temporal", + "montevideobus", + args.feat_size, + args.cutoff_time, + verbose=True, + for_stgraph=True, + ) else: print("😔 Unrecognized dataset") quit() @@ -53,19 +97,23 @@ def main(args): edge_list = dataloader.get_edges() edge_weight_list = dataloader.get_edge_weights() targets = dataloader.get_all_targets() - + pynvml.nvmlInit() handle = pynvml.nvmlDeviceGetHandleByIndex(0) initial_used_gpu_mem = pynvml.nvmlDeviceGetMemoryInfo(handle).used G = StaticGraph(edge_list, edge_weight_list, dataloader.num_nodes) graph_mem = pynvml.nvmlDeviceGetMemoryInfo(handle).used - initial_used_gpu_mem - edge_weight = to_default_device(torch.unsqueeze(torch.FloatTensor(edge_weight_list), 1)) + edge_weight = to_default_device( + torch.unsqueeze(torch.FloatTensor(edge_weight_list), 1) + ) targets = to_default_device(torch.FloatTensor(np.array(targets))) num_hidden_units = args.num_hidden num_outputs = 1 - model = to_default_device(STGraphTGCN(args.feat_size, num_hidden_units, num_outputs)) + model = to_default_device( + STGraphTGCN(args.feat_size, num_hidden_units, num_outputs) + ) optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) # Logging Output @@ -78,23 +126,26 @@ def main(args): backprop_every = args.backprop_every if backprop_every == 0: backprop_every = total_timestamps - + if total_timestamps % backprop_every == 0: - num_iter = 
int(total_timestamps/backprop_every) + num_iter = int(total_timestamps / backprop_every) else: - num_iter = int(total_timestamps/backprop_every) + 1 + num_iter = int(total_timestamps / backprop_every) + 1 # metrics dur = [] max_gpu = [] - table = BenchmarkTable(f"(STGraph Static-Temporal) TGCN on {dataloader.name} dataset", ["Epoch", "Time(s)", "MSE", "Used GPU Memory (Max MB)"]) - + table = BenchmarkTable( + f"(STGraph Static-Temporal) TGCN on {dataloader.name} dataset", + ["Epoch", "Time(s)", "MSE", "Used GPU Memory (Max MB)"], + ) + # normalization degs = torch.from_numpy(G.in_degrees()).type(torch.int32) norm = torch.pow(degs, -0.5) norm[torch.isinf(norm)] = 0 norm = to_default_device(norm) - G.set_ndata('norm', norm.unsqueeze(1)) + G.set_ndata("norm", norm.unsqueeze(1)) # train print("Training...\n") @@ -103,7 +154,7 @@ def main(args): torch.cuda.synchronize() torch.cuda.reset_peak_memory_stats(0) model.train() - + t0 = time.time() gpu_mem_arr = [] cost_arr = [] @@ -112,20 +163,24 @@ def main(args): optimizer.zero_grad() cost = 0 hidden_state = None - y_hat = torch.randn((dataloader.num_nodes, args.feat_size), device=get_default_device()) + y_hat = torch.randn( + (dataloader.num_nodes, args.feat_size), device=get_default_device() + ) for k in range(backprop_every): t = index * backprop_every + k if t >= total_timestamps: break - y_out, y_hat, hidden_state = model(G, y_hat, edge_weight, hidden_state) - cost = cost + torch.mean((y_out-targets[t])**2) - + y_out, y_hat, hidden_state = model( + G, y_hat, edge_weight, hidden_state + ) + cost = cost + torch.mean((y_out - targets[t]) ** 2) + if cost == 0: break - - cost = cost / (backprop_every+1) + + cost = cost / (backprop_every + 1) cost.backward() optimizer.step() torch.cuda.synchronize() @@ -140,56 +195,82 @@ def main(args): dur.append(run_time_this_epoch) max_gpu.append(max(gpu_mem_arr)) - table.add_row([epoch, "{:.5f}".format(run_time_this_epoch), "{:.4f}".format(sum(cost_arr)/len(cost_arr)), 
"{:.4f}".format((max(gpu_mem_arr) * 1.0 / (1024**2)))]) + table.add_row( + [ + epoch, + "{:.5f}".format(run_time_this_epoch), + "{:.4f}".format(sum(cost_arr) / len(cost_arr)), + "{:.4f}".format((max(gpu_mem_arr) * 1.0 / (1024**2))), + ] + ) table.display() - print('Average Time taken: {:6f}'.format(np.mean(dur))) + print("Average Time taken: {:6f}".format(np.mean(dur))) return np.mean(dur), (max(max_gpu) * 1.0 / (1024**2)) - + except RuntimeError as e: - if 'out of memory' in str(e): + if "out of memory" in str(e): table.add_row(["OOM", "OOM", "OOM", "OOM"]) table.display() else: print("😔 Something went wrong") return "OOM", "OOM" + def write_results(args, time_taken, max_gpu): cutoff = "whole" if args.cutoff_time < sys.maxsize: cutoff = str(args.cutoff_time) file_name = f"stgraph_{args.dataset}_T{cutoff}_B{args.backprop_every}_H{args.num_hidden}_F{args.feat_size}" - df_data = pd.DataFrame([{'Filename': file_name, 'Time Taken (s)': time_taken, 'Max GPU Usage (MB)': max_gpu}]) - - if os.path.exists('../../results/static-temporal.csv'): - df = pd.read_csv('../../results/static-temporal.csv') + df_data = pd.DataFrame( + [ + { + "Filename": file_name, + "Time Taken (s)": time_taken, + "Max GPU Usage (MB)": max_gpu, + } + ] + ) + + if os.path.exists("../../results/static-temporal.csv"): + df = pd.read_csv("../../results/static-temporal.csv") df = pd.concat([df, df_data]) else: df = df_data - - df.to_csv('../../results/static-temporal.csv', sep=',', index=False, encoding='utf-8') + + df.to_csv( + "../../results/static-temporal.csv", sep=",", index=False, encoding="utf-8" + ) -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='STGraph Static TGCN') +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="STGraph Static TGCN") snoop.install(enabled=False) - parser.add_argument("--dataset", type=str, default="wiki", - help="Name of the Dataset (wiki, windmill, hungary_cp, pedalme, metrla, monte)") - 
parser.add_argument("--backprop-every", type=int, default=0, - help="Feature size of nodes") - parser.add_argument("--feat-size", type=int, default=8, - help="Feature size of nodes") - parser.add_argument("--num-hidden", type=int, default=100, - help="Number of hidden units") - parser.add_argument("--lr", type=float, default=1e-2, - help="learning rate") - parser.add_argument("--cutoff-time", type=int, default=sys.maxsize, - help="learning rate") - parser.add_argument("--num-epochs", type=int, default=1, - help="number of training epochs") + parser.add_argument( + "--dataset", + type=str, + default="wiki", + help="Name of the Dataset (wiki, windmill, hungary_cp, pedalme, metrla, monte)", + ) + parser.add_argument( + "--backprop-every", type=int, default=0, help="Feature size of nodes" + ) + parser.add_argument( + "--feat-size", type=int, default=8, help="Feature size of nodes" + ) + parser.add_argument( + "--num-hidden", type=int, default=100, help="Number of hidden units" + ) + parser.add_argument("--lr", type=float, default=1e-2, help="learning rate") + parser.add_argument( + "--cutoff-time", type=int, default=sys.maxsize, help="learning rate" + ) + parser.add_argument( + "--num-epochs", type=int, default=1, help="number of training epochs" + ) args = parser.parse_args() - + print(args) time_taken, max_gpu = main(args) - write_results(args, time_taken, max_gpu) \ No newline at end of file + write_results(args, time_taken, max_gpu) diff --git a/stgraph/graph/__init__.py b/stgraph/graph/__init__.py index 09c94606..1b57f826 100644 --- a/stgraph/graph/__init__.py +++ b/stgraph/graph/__init__.py @@ -1,8 +1,8 @@ -from stgraph.graph.stgraph_base import STGraphBase - -from stgraph.graph.static.StaticGraph import StaticGraph +"""Graph representation modules provided by STGraph.""" from stgraph.graph.dynamic.DynamicGraph import DynamicGraph -from stgraph.graph.dynamic.pcsr.PCSRGraph import PCSRGraph from stgraph.graph.dynamic.gpma.GPMAGraph import GPMAGraph from 
stgraph.graph.dynamic.naive.NaiveGraph import NaiveGraph +from stgraph.graph.dynamic.pcsr.PCSRGraph import PCSRGraph +from stgraph.graph.static.static_graph import StaticGraph +from stgraph.graph.stgraph_base import STGraphBase diff --git a/stgraph/graph/static/__init__.py b/stgraph/graph/static/__init__.py index e69de29b..923160b9 100644 --- a/stgraph/graph/static/__init__.py +++ b/stgraph/graph/static/__init__.py @@ -0,0 +1 @@ +"""Static Graph respresentation modules provided by STGraph.""" diff --git a/stgraph/graph/static/StaticGraph.py b/stgraph/graph/static/static_graph.py similarity index 65% rename from stgraph/graph/static/StaticGraph.py rename to stgraph/graph/static/static_graph.py index 5e5b1ce4..fc20d5b0 100644 --- a/stgraph/graph/static/StaticGraph.py +++ b/stgraph/graph/static/static_graph.py @@ -1,28 +1,27 @@ -from abc import ABC, abstractmethod +"""Represent Static graphs in STGraph.""" + +from __future__ import annotations + import copy import numpy as np - from rich.console import Console -console = Console() - +from stgraph.graph.static.csr import CSR from stgraph.graph.stgraph_base import STGraphBase - -from stgraph.graph.static.csr import CSR +console = Console() class StaticGraph(STGraphBase): - r"""An abstract base class used to represent static graphs in STGraph. + r"""Represent Static graphs in STGraph. This abstract class outlines the interface for defining a static graphs used in STGraph. As of now the static graph is implemented using the Compressed Sparse Row (CSR) format. - Example + Example: ------- - .. 
code-block:: python from stgraph.graph import StaticGraph @@ -38,15 +37,23 @@ class StaticGraph(STGraphBase): """ - def __init__(self, edge_list, edge_weights, num_nodes): - """An abstract base class used to represent static graphs in STGraph.""" + def __init__( + self: StaticGraph, + edge_list: list, + edge_weights: list, + num_nodes: int, + ) -> None: + r"""Represent Static graphs in STGraph.""" super().__init__() self._num_nodes = num_nodes self._num_edges = len(set(edge_list)) self._prepare_edge_lst_fwd(edge_list) self._forward_graph = CSR( - self.fwd_edge_list, edge_weights, self._num_nodes, is_edge_reverse=True + self.fwd_edge_list, + edge_weights, + self._num_nodes, + is_edge_reverse=True, ) self._prepare_edge_lst_bwd(self.fwd_edge_list) @@ -54,8 +61,8 @@ def __init__(self, edge_list, edge_weights, num_nodes): self._get_graph_csr_ptrs() - def _prepare_edge_lst_fwd(self, edge_list): - r"""TODO:""" + def _prepare_edge_lst_fwd(self: STGraphBase, edge_list: list) -> None: + r"""TODO:.""" edge_list_for_t = edge_list edge_list_for_t.sort(key=lambda x: (x[1], x[0])) edge_list_for_t = [ @@ -64,14 +71,14 @@ def _prepare_edge_lst_fwd(self, edge_list): ] self.fwd_edge_list = edge_list_for_t - def _prepare_edge_lst_bwd(self, edge_list): - r"""TODO:""" + def _prepare_edge_lst_bwd(self: STGraphBase, edge_list: list) -> None: + r"""TODO:.""" edge_list_for_t = copy.deepcopy(edge_list) edge_list_for_t.sort() self.bwd_edge_list = edge_list_for_t - def _get_graph_csr_ptrs(self): - r"""TODO:""" + def _get_graph_csr_ptrs(self: STGraphBase) -> None: + r"""TODO:.""" self.fwd_row_offset_ptr = self._forward_graph.row_offset_ptr self.fwd_column_indices_ptr = self._forward_graph.column_indices_ptr self.fwd_eids_ptr = self._forward_graph.eids_ptr @@ -82,37 +89,37 @@ def _get_graph_csr_ptrs(self): self.bwd_eids_ptr = self._backward_graph.eids_ptr self.bwd_node_ids_ptr = self._backward_graph.node_ids_ptr - def get_num_nodes(self): + def get_num_nodes(self: STGraphBase) -> int: r"""Return 
the number of nodes in the static graph.""" return self._num_nodes - def get_num_edges(self): + def get_num_edges(self: STGraphBase) -> int: r"""Return the number of edges in the static graph.""" return self._num_edges - def get_ndata(self, field): - r"""Returns the graph metadata.""" + def get_ndata(self: STGraphBase, field: any) -> any: + r"""Return the graph metadata.""" if field in self._ndata: return self._ndata[field] - else: - return None - def set_ndata(self, field, val): - r"""Sets the graph metadata.""" + return None + + def set_ndata(self: STGraphBase, field: str, val: any) -> None: + r"""Set the graph metadata.""" self._ndata[field] = val - def graph_type(self): - r"""Returns the graph type.""" + def graph_type(self: STGraphBase) -> str: + r"""Return the graph type.""" return "csr_unsorted" - def in_degrees(self): - r"""Returns the graph inwards node degree array.""" + def in_degrees(self: STGraphBase) -> np.ndarray: + r"""Return the graph inwards node degree array.""" return np.array(self._forward_graph.out_degrees, dtype="int32") - def out_degrees(self): - r"""Returns the graph outwards node degree array.""" + def out_degrees(self: STGraphBase) -> np.ndarray: + r"""Return the graph outwards node degree array.""" return np.array(self._forward_graph.in_degrees, dtype="int32") - def weighted_in_degrees(self): - r"""TODO:""" + def weighted_in_degrees(self: STGraphBase) -> np.ndarray: + r"""TODO:.""" return np.array(self._forward_graph.weighted_out_degrees, dtype="int32") diff --git a/tests/scripts/v1_1_0/gcn_dataloaders/gcn/train.py b/tests/scripts/v1_1_0/gcn_dataloaders/gcn/train.py index 782d2f76..ee129dbc 100644 --- a/tests/scripts/v1_1_0/gcn_dataloaders/gcn/train.py +++ b/tests/scripts/v1_1_0/gcn_dataloaders/gcn/train.py @@ -13,7 +13,7 @@ from .utils import accuracy, generate_test_mask, generate_train_mask, to_default_device from stgraph.dataset import CoraDataLoader -from stgraph.graph.static.StaticGraph import StaticGraph +from 
stgraph.graph.static.static_graph import StaticGraph from stgraph.benchmark_tools.table import BenchmarkTable diff --git a/tests/scripts/v1_1_0/temporal_tgcn_dataloaders/tgcn/train.py b/tests/scripts/v1_1_0/temporal_tgcn_dataloaders/tgcn/train.py index d73d4bd4..2ad4832e 100644 --- a/tests/scripts/v1_1_0/temporal_tgcn_dataloaders/tgcn/train.py +++ b/tests/scripts/v1_1_0/temporal_tgcn_dataloaders/tgcn/train.py @@ -10,7 +10,7 @@ import traceback from .model import STGraphTGCN -from stgraph.graph.static.StaticGraph import StaticGraph +from stgraph.graph.static.static_graph import StaticGraph from stgraph.dataset import WindmillOutputDataLoader from stgraph.dataset import WikiMathDataLoader From 73713c6ba935d202e690a1d7a6c84b9a5053a9ea Mon Sep 17 00:00:00 2001 From: nithinmanoj10 Date: Tue, 7 May 2024 22:32:58 +0530 Subject: [PATCH 3/7] =?UTF-8?q?=F0=9F=A7=B9=20Lint=20checked=20stgraph.gra?= =?UTF-8?q?ph.dynamic=5Fgraph=20module?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- stgraph/compiler/executor.py | 234 +++++++++++------- stgraph/graph/__init__.py | 2 +- .../{DynamicGraph.py => dynamic_graph.py} | 90 +++++-- stgraph/graph/dynamic/gpma/GPMAGraph.py | 16 +- stgraph/graph/dynamic/naive/NaiveGraph.py | 16 +- stgraph/graph/dynamic/pcsr/PCSRGraph.py | 73 +++--- 6 files changed, 270 insertions(+), 161 deletions(-) rename stgraph/graph/dynamic/{DynamicGraph.py => dynamic_graph.py} (64%) diff --git a/stgraph/compiler/executor.py b/stgraph/compiler/executor.py index 799fd600..eab80c26 100644 --- a/stgraph/compiler/executor.py +++ b/stgraph/compiler/executor.py @@ -1,36 +1,37 @@ from .utils import is_const_scalar, ParallelMode import snoop from collections import deque -from ..graph.dynamic.DynamicGraph import DynamicGraph +from ..graph.dynamic.dynamic_graph import DynamicGraph from stgraph.compiler.debugging.stgraph_logger import print_log import torch + class Stack: def __init__(self, val=None): self.content = deque() if 
val is not None: self.content.append(val) - + def push(self, val): self.content.append(val) - + def pop(self): self.content.pop() - + def top(self): return self.content[-1] - + def print(self): for elem in self.content: print(elem) - + class ExeState(object): def __init__(self): # contains tensors for all previous execution of nb_compute self.tensor_map_stack = Stack() - + # contains timestamps of graphs that were forward propagated self.graph_timestamp_stack = Stack() @@ -43,11 +44,11 @@ def __init__(self): self.dep_map = {} self.executed_bunit = set() - + def reset(self, input_map, f_merged_units, bunits): # print("ENTERING RESET") self.dep_map = {} - #for mu in f_merged_units: + # for mu in f_merged_units: # if mu.compiled(): # for ret in mu.union_of_rets(): # self.dep_map[ret.id] = 0 @@ -71,28 +72,29 @@ def reset(self, input_map, f_merged_units, bunits): self.bwd_common_tensor_list.append(arg.id) # print("End of initializing bwd_common_tensor_list\n") - #print('dependency map', self.dep_map) + # print('dependency map', self.dep_map) self.num_bunits = len(bunits) # deletes all tensors that were previously stored here (verified) - self.current_tensor_map = {key: val for key,val in input_map.items()} + self.current_tensor_map = {key: val for key, val in input_map.items()} self.executed_bunit.clear() - + def track_executed_bu(self, bu): self.executed_bunit.add(bu) - + def is_executed_bu(self, bu): return bu in self.executed_bunit - + def all_bu_executed(self): return len(self.executed_bunit) == self.num_bunits - + def track_tensor(self, key, val): self.current_tensor_map[key] = val if key in self.bunit_arg_ids: self.bwd_common_tensor_list.append(key) - + def clear_current_tensor_state(self): self.current_tensor_map = {} + # def clear_cache(self): # rmv_list = [] # for k in self.tensor_map: @@ -102,25 +104,26 @@ def clear_current_tensor_state(self): # for k in rmv_list: # self.tensor_map.pop(k) + class MergedUnit(object): def __init__(self, units): self.units = 
units self._joint_inputs = None self._joint_args = None - self._joint_rets = None + self._joint_rets = None self._kernel_args = None self._union_of_rets = None - + def append(self, unit): self.units.append(unit) return self - + def last(self): return self.units[-1] - + def compiled(self): return self.units[-1].compiled - + def joint_inputs(self): if not self._joint_inputs: var_set = set() @@ -140,15 +143,15 @@ def joint_rets(self): var_set = var_set - u._args self._joint_rets = [var for var in var_set] return self._joint_rets - + def joint_args(self): if not self._joint_args: var_set = set() for u in self.units: var_set = var_set.union(u._args) - self._joint_args= [var for var in var_set] + self.joint_rets() + self._joint_args = [var for var in var_set] + self.joint_rets() return self._joint_args - + def union_of_rets(self): if not self._union_of_rets: var_set = set() @@ -156,7 +159,7 @@ def union_of_rets(self): var_set = var_set.union(u._rets) self._union_of_rets = var_set return self._union_of_rets - + def kernel_arg_list(self): if not self._kernel_args: args = self.joint_args() @@ -169,22 +172,27 @@ def kernel_arg_list(self): kernel_arg.append(i) self._kernel_args.append(kernel_arg) return self._kernel_args - + def __str__(self): return str(self.units) - + def __repr__(self): return self.__str__() - - def __iter__(self): + + def __iter__(self): for unit in self.units: yield unit + class Executor(object): - def __init__(self, graph, forward_exec_units, backward_exec_units, compiled_module, rets): + def __init__( + self, graph, forward_exec_units, backward_exec_units, compiled_module, rets + ): self.forward_exec_units = self.merge_units(forward_exec_units) self.bulist = backward_exec_units - self.var2bu = self.construct_backward_mappping(self.forward_exec_units,backward_exec_units) + self.var2bu = self.construct_backward_mappping( + self.forward_exec_units, backward_exec_units + ) self._rets = rets self.ts = ExeState() self.new_zeros = None @@ -199,7 +207,7 @@ def 
__init__(self, graph, forward_exec_units, backward_exec_units, compiled_modu for u in self.bulist: if u.compiled: u.prepare_compiled_kernel(graph, compiled_module) - + def construct_backward_mappping(self, funits, bunits): ret = {} for mu in funits: @@ -210,68 +218,66 @@ def construct_backward_mappping(self, funits, bunits): if arg._grad in bu.unit_rets(): ret[arg] = bu return ret - + def merge_units(self, exec_units): print_log("[green bold]Executor[/green bold]: Start merging units") - - assert len(exec_units) > 0, 'Error: empty exec units' + + assert len(exec_units) > 0, "Error: empty exec units" grouped_unit = [MergedUnit([exec_units[0]])] for i in range(1, len(exec_units)): if exec_units[i].compiled == grouped_unit[-1].last().compiled: grouped_unit[-1].append(exec_units[i]) else: grouped_unit.append(MergedUnit([exec_units[i]])) - + print_log("[green bold]Executor[/green bold]: Units merging completed") return grouped_unit - + def restart(self, input_map, graph=None): # print("ENTERING RESTART") self.ts.reset(input_map, self.forward_exec_units, self.bulist) if graph != None: - + # TODO: REMOVE # TODO: getting graph of current timestamp, probably better to move # this outside the compiler - # current_timestamp = self.ts.tensor_map_stack.len() + # current_timestamp = self.ts.tensor_map_stack.len() # self.graph.get_forward_graph_for_timestamp(current_timestamp) - + for mu in self.forward_exec_units: for u in mu: if u.compiled: # TODO: (Joel) Feel like this is going to be problematic for dynamic graphs u.reset_graph_info(graph) - + # NOTE: COMMENTED OUT NOW SINCE THIS IS HANDLED IN BACKWARD_CB # for u in self.bulist: # if u.compiled: # u.reset_graph_info(graph) - + self.num_nodes = graph.get_num_nodes() self.num_edges = graph.get_num_edges() - def set_raw_ptr_cb(self, cb): self.raw_ptr = cb def set_new_zeros_cb(self, cb): self.new_zeros = cb - + def execute(self, FuncWrapper): - ''' Execute forward pass''' - for i,unit in enumerate(self.forward_exec_units): + 
"""Execute forward pass""" + for i, unit in enumerate(self.forward_exec_units): if unit.last().compiled: self.execute_compiled(i, FuncWrapper) else: self.execute_prog(unit) ret = tuple([self.ts.current_tensor_map[ret.id] for ret in self._rets]) - + # TODO: Will need to uncomment this one line # self.ts.clear_cache() - # bytes_list = [v.numel() *4 for k,v in self.ts.tensor_map.items()] - #print('after forward', self.ts.tensor_map.keys(), ' bytes ', bytes_list, sum(bytes_list)) + # print('after forward', self.ts.tensor_map.keys(), ' bytes ', bytes_list, sum(bytes_list)) # Old position # self.ts.tensor_map_stack.push(self.ts.current_tensor_map) @@ -282,22 +288,36 @@ def execute(self, FuncWrapper): # print("Index: {}".format(index)) # print(self.ts.tensor_map_stack.content[index]) - return ret - + return ret + def create_tensor_for_vars(self, var_list): - ret_tensors = {var.id : self.new_zeros(size=[self.num_edges if var.is_edgevar() else self.num_nodes] + list(var.var_shape), - dtype=var.var_dtype, - device=var.device, - requires_grad=False) for var in var_list if var.id not in self.ts.current_tensor_map} - + ret_tensors = { + var.id: self.new_zeros( + size=[self.num_edges if var.is_edgevar() else self.num_nodes] + + list(var.var_shape), + dtype=var.var_dtype, + device=var.device, + requires_grad=False, + ) + for var in var_list + if var.id not in self.ts.current_tensor_map + } + for key, val in ret_tensors.items(): - self.ts.track_tensor(key,val) + self.ts.track_tensor(key, val) def create_tensor_for_grad_vars(self, var_list, tensor_map): - ret_tensors = {var.id : self.new_zeros(size=[self.num_edges if var.is_edgevar() else self.num_nodes] + list(var.var_shape), - dtype=var.var_dtype, - device=var.device, - requires_grad=False) for var in var_list if var.id not in tensor_map} + ret_tensors = { + var.id: self.new_zeros( + size=[self.num_edges if var.is_edgevar() else self.num_nodes] + + list(var.var_shape), + dtype=var.var_dtype, + device=var.device, + 
requires_grad=False, + ) + for var in var_list + if var.id not in tensor_map + } tensor_map = {**tensor_map, **ret_tensors} return tensor_map @@ -308,86 +328,118 @@ def execute_unit(self, unit, tensor_list): def execute_compiled(self, uid, FuncWrapper): units = self.forward_exec_units[uid] args = units.joint_args() - rets = units.joint_rets() + rets = units.joint_rets() for unit in units: self.create_tensor_for_vars(unit.unit_rets()) - + kernel_arg_list = units.kernel_arg_list() - ret_tensors = FuncWrapper.apply(self, uid, kernel_arg_list, rets, *[self.ts.current_tensor_map[var.id] for var in args]) + ret_tensors = FuncWrapper.apply( + self, + uid, + kernel_arg_list, + rets, + *[self.ts.current_tensor_map[var.id] for var in args], + ) # Only the return values returned by the function will have grad_fn set properly. # Therefore we need to replace the tensors in self.tensor_map with the return values - for i,ret in enumerate(rets): + for i, ret in enumerate(rets): self.ts.track_tensor(ret.id, ret_tensors[i]) - + def forward_cb(self, uid, kernel_args, rets, tensor_list): - '''FuncWrapper will call this function in forward pass''' + """FuncWrapper will call this function in forward pass""" units = self.forward_exec_units[uid] - for i,unit in enumerate(units): + for i, unit in enumerate(units): self.execute_unit(unit, [tensor_list[tidx] for tidx in kernel_args[i]]) - - self.ts.tensor_map_stack.push({key: self.ts.current_tensor_map[key] for key in self.ts.bwd_common_tensor_list}) + + self.ts.tensor_map_stack.push( + { + key: self.ts.current_tensor_map[key] + for key in self.ts.bwd_common_tensor_list + } + ) # self.ts.tensor_map_stack.push(self.ts.current_tensor_map) - - if isinstance(self.graph,DynamicGraph): + + if isinstance(self.graph, DynamicGraph): self.ts.graph_timestamp_stack.push(self.graph.current_timestamp) - + ret = tuple([self.ts.current_tensor_map[ret.id] for ret in rets]) self.ts.clear_current_tensor_state() return ret def backward_cb(self, kid, grad_list): 
- '''FuncWrapper will call this function in backward pass''' + """FuncWrapper will call this function in backward pass""" # print("BACKWARD CALLED") - # which backward kernel to call? un-executed kernel that has all dependency satisfied. + # which backward kernel to call? un-executed kernel that has all dependency satisfied. # We need to get the grad_map in order to properly set the variables in compiled kernels. funits = self.forward_exec_units[kid] args = funits.joint_args() rets = funits.joint_rets() inputs = funits.joint_inputs() - ret_grads = [ret._grad for ret in rets] # ret_grads corresponds vars in grad_list + ret_grads = [ + ret._grad for ret in rets + ] # ret_grads corresponds vars in grad_list tensor_map = self.ts.tensor_map_stack.top() - - if isinstance(self.graph,DynamicGraph): + + if isinstance(self.graph, DynamicGraph): current_timestamp = self.ts.graph_timestamp_stack.top() self.graph.get_backward_graph(current_timestamp) - for i,grad in enumerate(ret_grads): + for i, grad in enumerate(ret_grads): # We track the ret_grads as its value is fixed to grad_list tensor_map[grad.id] = grad_list[i] - arg_grads = [arg._grad if arg in inputs and arg.requires_grad else None for arg in args] # arg_grads corresponds to the grads of funit.unit_args + arg_grads = [ + arg._grad if arg in inputs and arg.requires_grad else None for arg in args + ] # arg_grads corresponds to the grads of funit.unit_args for bu in self.bulist: if bu.compiled: # if self.ts.is_executed_bu(bu): # continue - + # NOTE: Added to fix gpma bu.reset_graph_info(self.graph) - - tensor_map = self.create_tensor_for_grad_vars(bu.unit_rets(),tensor_map) + + tensor_map = self.create_tensor_for_grad_vars( + bu.unit_rets(), tensor_map + ) self.execute_unit(bu, [tensor_map[arg.id] for arg in bu.kernel_args()]) - + # self.ts.track_executed_bu(bu) else: # The backward pass of some forward unit may be splitted into compiled and uncompiled parts self.execute_prog([bu]) - ret = tuple([tensor_map[grad.id] if 
grad != None else None for grad in arg_grads] + [None for grad in ret_grads]) - + ret = tuple( + [tensor_map[grad.id] if grad != None else None for grad in arg_grads] + + [None for grad in ret_grads] + ) + del tensor_map self.ts.tensor_map_stack.pop() - - if isinstance(self.graph,DynamicGraph): + + if isinstance(self.graph, DynamicGraph): self.ts.graph_timestamp_stack.pop() - + return ret def execute_prog(self, units): current_tensor_map = self.ts.current_tensor_map self.ts.clear_current_tensor_state() - for unit in units: + for unit in units: for stmt in unit.program: - self.ts.track_tensor(stmt.ret.id, stmt.execute([current_tensor_map[arg.id] if not is_const_scalar(arg) else arg for arg in stmt.args])) \ No newline at end of file + self.ts.track_tensor( + stmt.ret.id, + stmt.execute( + [ + ( + current_tensor_map[arg.id] + if not is_const_scalar(arg) + else arg + ) + for arg in stmt.args + ] + ), + ) diff --git a/stgraph/graph/__init__.py b/stgraph/graph/__init__.py index 1b57f826..c6cbb4d5 100644 --- a/stgraph/graph/__init__.py +++ b/stgraph/graph/__init__.py @@ -1,6 +1,6 @@ """Graph representation modules provided by STGraph.""" -from stgraph.graph.dynamic.DynamicGraph import DynamicGraph +from stgraph.graph.dynamic.dynamic_graph import DynamicGraph from stgraph.graph.dynamic.gpma.GPMAGraph import GPMAGraph from stgraph.graph.dynamic.naive.NaiveGraph import NaiveGraph from stgraph.graph.dynamic.pcsr.PCSRGraph import PCSRGraph diff --git a/stgraph/graph/dynamic/DynamicGraph.py b/stgraph/graph/dynamic/dynamic_graph.py similarity index 64% rename from stgraph/graph/dynamic/DynamicGraph.py rename to stgraph/graph/dynamic/dynamic_graph.py index afd1af8b..7427db92 100644 --- a/stgraph/graph/dynamic/DynamicGraph.py +++ b/stgraph/graph/dynamic/dynamic_graph.py @@ -1,10 +1,37 @@ -from stgraph.graph.stgraph_base import STGraphBase -from abc import abstractmethod +"""Represent Dynamic Graphs in STGraph.""" + +from __future__ import annotations + +from typing import 
TYPE_CHECKING + +if TYPE_CHECKING: + import numpy as np + import time +from abc import abstractmethod + +from stgraph.graph.stgraph_base import STGraphBase class DynamicGraph(STGraphBase): - def __init__(self, edge_list, max_num_nodes): + r"""Represent Dynamic Graphs in STGraph. + + This abstract class outlines the interface for defining a dynamic graph + used in STGraph. As of now the dynamic graph is implemented using the + following graph representation format: + + 1. Compressed Sparse Row (CSR) + 2. Packed Compressed Sparse Row (PCSR) + 3. GPMA + + """ + + def __init__( + self: DynamicGraph, + edge_list: list, + max_num_nodes: int, + ) -> None: + r"""Represent Dynamic Graphs in STGraph.""" super().__init__() self.graph_updates = {} self.max_num_nodes = max_num_nodes @@ -24,7 +51,8 @@ def __init__(self, edge_list, max_num_nodes): self._preprocess_graph_structure(edge_list) - def _preprocess_graph_structure(self, edge_list): + def _preprocess_graph_structure(self: DynamicGraph, edge_list: list) -> None: + r"""TODO:.""" edge_dict = {} for i in range(len(edge_list)): edge_set = set() @@ -48,7 +76,8 @@ def _preprocess_graph_structure(self, edge_list): "delete": deletions, } - def reset_graph(self): + def reset_graph(self: DynamicGraph) -> None: + r"""TODO:.""" self._get_cached_graph("base") self.current_timestamp = 0 @@ -56,14 +85,15 @@ def reset_graph(self): self.get_bwd_graph_time = 0 self.move_to_gpu_time = 0 - def get_graph(self, timestamp: int): + def get_graph(self: DynamicGraph, timestamp: int) -> None: + r"""TODO:.""" t0 = time.time() self._is_backprop_state = False if timestamp < self.current_timestamp: - raise Exception( - "⏰ Invalid timestamp during STGraphBase.update_graph_forward()" + raise RuntimeError( + "⏰ Invalid timestamp during STGraphBase.update_graph_forward()", ) if self._get_cached_graph(timestamp - 1): @@ -75,7 +105,8 @@ def get_graph(self, timestamp: int): self.get_fwd_graph_time += time.time() - t0 - def get_backward_graph(self, timestamp: 
int): + def get_backward_graph(self: DynamicGraph, timestamp: int) -> None: + r"""TODO:.""" t0 = time.time() if not self._is_backprop_state: @@ -84,8 +115,8 @@ def get_backward_graph(self, timestamp: int): self._init_reverse_graph() if timestamp > self.current_timestamp: - raise Exception( - "⏰ Invalid timestamp during STGraphBase.update_graph_backward()" + raise RuntimeError( + "⏰ Invalid timestamp during STGraphBase.update_graph_backward()", ) while self.current_timestamp > timestamp: @@ -94,51 +125,62 @@ def get_backward_graph(self, timestamp: int): self.get_bwd_graph_time += time.time() - t0 - def get_num_nodes(self): + def get_num_nodes(self: DynamicGraph) -> int: + r"""TODO:.""" return self.graph_attr[str(self.current_timestamp)][0] - def get_num_edges(self): + def get_num_edges(self: DynamicGraph) -> int: + r"""TODO:.""" return self.graph_attr[str(self.current_timestamp)][1] - def get_ndata(self, field): + def get_ndata(self: DynamicGraph, field: str) -> any: + r"""TODO:.""" if ( str(self.current_timestamp) in self._ndata and field in self._ndata[str(self.current_timestamp)] ): return self._ndata[str(self.current_timestamp)][field] - else: - return None - def set_ndata(self, field, val): + return None + + def set_ndata(self: DynamicGraph, field: str, val: any) -> None: + r"""TODO:.""" if str(self.current_timestamp) in self._ndata: self._ndata[str(self.current_timestamp)][field] = val else: self._ndata[str(self.current_timestamp)] = {field: val} @abstractmethod - def in_degrees(self): + def in_degrees(self: DynamicGraph) -> np.ndarray: + r"""TODO:.""" pass @abstractmethod - def out_degrees(self): + def out_degrees(self: DynamicGraph) -> np.ndarray: + r"""TODO:.""" pass @abstractmethod - def _cache_graph(self): + def _cache_graph(self: DynamicGraph) -> None: + r"""TODO:.""" pass @abstractmethod - def _get_cached_graph(self, timestamp): + def _get_cached_graph(self: DynamicGraph, timestamp: str | int) -> None: + r"""TODO:.""" pass @abstractmethod - def 
_update_graph_forward(self): + def _update_graph_forward(self: DynamicGraph) -> None: + r"""TODO:.""" pass @abstractmethod - def _init_reverse_graph(self): + def _init_reverse_graph(self: DynamicGraph) -> None: + r"""TODO:.""" pass @abstractmethod - def _update_graph_backward(self): + def _update_graph_backward(self: DynamicGraph) -> None: + r"""TODO:.""" pass diff --git a/stgraph/graph/dynamic/gpma/GPMAGraph.py b/stgraph/graph/dynamic/gpma/GPMAGraph.py index 69ee4c7e..a04ac6b8 100644 --- a/stgraph/graph/dynamic/gpma/GPMAGraph.py +++ b/stgraph/graph/dynamic/gpma/GPMAGraph.py @@ -4,7 +4,7 @@ import numpy as np from rich import inspect -from stgraph.graph.dynamic.DynamicGraph import DynamicGraph +from stgraph.graph.dynamic.dynamic_graph import DynamicGraph from stgraph.graph.dynamic.gpma.gpma import ( GPMA, init_gpma, @@ -15,7 +15,7 @@ free_backward_csr, get_csr_ptrs, get_out_degrees, - get_in_degrees + get_in_degrees, ) @@ -41,7 +41,9 @@ def graph_type(self): return "gpma" def _cache_graph(self): - self.graph_cache[str(self.current_timestamp)] = copy.deepcopy(self._forward_graph) + self.graph_cache[str(self.current_timestamp)] = copy.deepcopy( + self._forward_graph + ) def _get_cached_graph(self, timestamp): if timestamp == "base": @@ -56,10 +58,10 @@ def _get_cached_graph(self, timestamp): return True else: return False - + def in_degrees(self): return np.array(get_out_degrees(self._forward_graph), dtype="int32") - + def out_degrees(self): return np.array(get_in_degrees(self._forward_graph), dtype="int32") @@ -103,9 +105,7 @@ def _update_graph_backward(self): # Freeing resources from previous CSR free_backward_csr(self._forward_graph) - edge_update_t( - self._forward_graph, self.current_timestamp, revert_update=True - ) + edge_update_t(self._forward_graph, self.current_timestamp, revert_update=True) label_edges(self._forward_graph) build_backward_csr(self._forward_graph) self._get_graph_csr_ptrs() diff --git a/stgraph/graph/dynamic/naive/NaiveGraph.py 
b/stgraph/graph/dynamic/naive/NaiveGraph.py index 2190515b..f70a8d03 100644 --- a/stgraph/graph/dynamic/naive/NaiveGraph.py +++ b/stgraph/graph/dynamic/naive/NaiveGraph.py @@ -2,7 +2,7 @@ import numpy as np from rich import inspect -from stgraph.graph.dynamic.DynamicGraph import DynamicGraph +from stgraph.graph.dynamic.dynamic_graph import DynamicGraph from stgraph.graph.static.csr import CSR from collections import deque import time @@ -59,7 +59,7 @@ def _prepare_edge_lst_bwd(self, edge_list): def graph_type(self): return "csr" - + def _cache_graph(self): pass @@ -67,11 +67,15 @@ def _get_cached_graph(self, timestamp): return False def in_degrees(self): - return np.array(self._forward_graph[self.current_timestamp].out_degrees, dtype="int32") - + return np.array( + self._forward_graph[self.current_timestamp].out_degrees, dtype="int32" + ) + def out_degrees(self): - return np.array(self._forward_graph[self.current_timestamp].in_degrees, dtype="int32") - + return np.array( + self._forward_graph[self.current_timestamp].in_degrees, dtype="int32" + ) + # def weighted_in_degrees(self): # return np.array(self._forward_graph[self.current_timestamp].weighted_out_degrees, dtype="float32") diff --git a/stgraph/graph/dynamic/pcsr/PCSRGraph.py b/stgraph/graph/dynamic/pcsr/PCSRGraph.py index 7c39aa73..9c9423ef 100644 --- a/stgraph/graph/dynamic/pcsr/PCSRGraph.py +++ b/stgraph/graph/dynamic/pcsr/PCSRGraph.py @@ -3,38 +3,43 @@ from rich import inspect import time -from stgraph.graph.dynamic.DynamicGraph import DynamicGraph +from stgraph.graph.dynamic.dynamic_graph import DynamicGraph from stgraph.graph.dynamic.pcsr.pcsr import PCSR + class PCSRGraph(DynamicGraph): def __init__(self, edge_list, max_num_nodes): super().__init__(edge_list, max_num_nodes) - + # Get the maximum number of edges self._get_max_num_edges() - + self._forward_graph = PCSR(self.max_num_nodes, self.max_num_edges) - self._forward_graph.edge_update_list(self.graph_updates["0"]["add"], is_reverse_edge=True) - 
self._forward_graph.label_edges() + self._forward_graph.edge_update_list( + self.graph_updates["0"]["add"], is_reverse_edge=True + ) + self._forward_graph.label_edges() self._forward_graph.build_csr() self._get_graph_csr_ptrs() self.graph_cache = {} self.graph_cache["base"] = copy.deepcopy(self._forward_graph) - + def _get_max_num_edges(self): updates = self.graph_updates edge_set = set() for i in range(len(updates)): - for j in range(len(updates[str(i)]["add"])): - edge_set.add(updates[str(i)]["add"][j]) + for j in range(len(updates[str(i)]["add"])): + edge_set.add(updates[str(i)]["add"][j]) self.max_num_edges = len(edge_set) - + def graph_type(self): return "pcsr" - + def _cache_graph(self): - self.graph_cache[str(self.current_timestamp)] = copy.deepcopy(self._forward_graph) + self.graph_cache[str(self.current_timestamp)] = copy.deepcopy( + self._forward_graph + ) def _get_cached_graph(self, timestamp): if timestamp == "base": @@ -49,14 +54,14 @@ def _get_cached_graph(self, timestamp): return True else: return False - + def in_degrees(self): - return np.array(self._forward_graph.out_degrees, dtype='int32') - + return np.array(self._forward_graph.out_degrees, dtype="int32") + def out_degrees(self): - return np.array(self._forward_graph.in_degrees, dtype='int32') - - def _get_graph_csr_ptrs(self): + return np.array(self._forward_graph.in_degrees, dtype="int32") + + def _get_graph_csr_ptrs(self): csr_ptrs = self._forward_graph.get_csr_ptrs() if self._is_backprop_state: self.bwd_row_offset_ptr = csr_ptrs[0] @@ -68,40 +73,46 @@ def _get_graph_csr_ptrs(self): self.fwd_column_indices_ptr = csr_ptrs[1] self.fwd_eids_ptr = csr_ptrs[2] self.fwd_node_ids_ptr = csr_ptrs[3] - + def _update_graph_forward(self): - ''' Updates the current base graph to the next timestamp - ''' + """Updates the current base graph to the next timestamp""" if str(self.current_timestamp + 1) not in self.graph_updates: - raise Exception("⏰ Invalid timestamp during STGraphBase.update_graph_forward()") 
- + raise Exception( + "⏰ Invalid timestamp during STGraphBase.update_graph_forward()" + ) + graph_additions = self.graph_updates[str(self.current_timestamp + 1)]["add"] graph_deletions = self.graph_updates[str(self.current_timestamp + 1)]["delete"] self._forward_graph.edge_update_list(graph_additions, is_reverse_edge=True) - self._forward_graph.edge_update_list(graph_deletions, is_delete=True, is_reverse_edge=True) - self._forward_graph.label_edges() + self._forward_graph.edge_update_list( + graph_deletions, is_delete=True, is_reverse_edge=True + ) + self._forward_graph.label_edges() move_to_gpu_time = self._forward_graph.build_csr() self.move_to_gpu_time += move_to_gpu_time self._get_graph_csr_ptrs() - + def _init_reverse_graph(self): - ''' Generates the reverse of the base graph''' + """Generates the reverse of the base graph""" move_to_gpu_time = self._forward_graph.build_reverse_csr() self.move_to_gpu_time += move_to_gpu_time self._get_graph_csr_ptrs() def _update_graph_backward(self): if self.current_timestamp < 0: - raise Exception("⏰ Invalid timestamp during STGraphBase.update_graph_backward()") - + raise Exception( + "⏰ Invalid timestamp during STGraphBase.update_graph_backward()" + ) + graph_additions = self.graph_updates[str(self.current_timestamp)]["delete"] graph_deletions = self.graph_updates[str(self.current_timestamp)]["add"] - self._forward_graph.edge_update_list(graph_additions, is_reverse_edge=True) - self._forward_graph.edge_update_list(graph_deletions, is_delete=True, is_reverse_edge=True) + self._forward_graph.edge_update_list(graph_additions, is_reverse_edge=True) + self._forward_graph.edge_update_list( + graph_deletions, is_delete=True, is_reverse_edge=True + ) self._forward_graph.label_edges() move_to_gpu_time = self._forward_graph.build_reverse_csr() self.move_to_gpu_time += move_to_gpu_time self._get_graph_csr_ptrs() - \ No newline at end of file From 66822239c5804b1c5d162d91c727f9e9a6ed5868 Mon Sep 17 00:00:00 2001 From: nithinmanoj10 
Date: Sat, 11 May 2024 16:31:40 +0530 Subject: [PATCH 4/7] =?UTF-8?q?=F0=9F=A7=B9=20Lint=20fixed=20stgraph.graph?= =?UTF-8?q?.gpma=20module?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../dynamic-temporal-tgcn/seastar/train.py | 224 +++++++++++++----- stgraph/graph/__init__.py | 2 +- stgraph/graph/dynamic/dynamic_graph.py | 2 +- stgraph/graph/dynamic/gpma/__init__.py | 2 +- .../gpma/{GPMAGraph.py => gpma_graph.py} | 107 ++++++--- 5 files changed, 241 insertions(+), 96 deletions(-) rename stgraph/graph/dynamic/gpma/{GPMAGraph.py => gpma_graph.py} (57%) diff --git a/benchmarking/dynamic-temporal-tgcn/seastar/train.py b/benchmarking/dynamic-temporal-tgcn/seastar/train.py index 6afd2b8e..6ed43065 100644 --- a/benchmarking/dynamic-temporal-tgcn/seastar/train.py +++ b/benchmarking/dynamic-temporal-tgcn/seastar/train.py @@ -9,12 +9,13 @@ import os from stgraph.dataset.LinkPredDataLoader import LinkPredDataLoader from stgraph.benchmark_tools.table import BenchmarkTable -from stgraph.graph.dynamic.gpma.GPMAGraph import GPMAGraph +from stgraph.graph.dynamic.gpma.gpma_graph import GPMAGraph from stgraph.graph.dynamic.pcsr.PCSRGraph import PCSRGraph from stgraph.graph.dynamic.naive.NaiveGraph import NaiveGraph from model import STGraphTGCN from utils import to_default_device, get_default_device + def main(args): if torch.cuda.is_available(): @@ -22,42 +23,102 @@ def main(args): else: print("😔 CUDA is not available") quit() - + # dummy object to account for initial CUDA context object Graph = None if args.type == "naive": - Graph = NaiveGraph([[(0,0)]],1) + Graph = NaiveGraph([[(0, 0)]], 1) elif args.type == "pcsr": - Graph = PCSRGraph([[(0,1)]],2) # PCSRGraph([[(0,1)]],2) + Graph = PCSRGraph([[(0, 1)]], 2) # PCSRGraph([[(0,1)]],2) elif args.type == "gpma": - Graph = GPMAGraph([[(0,0)]],1) - + Graph = GPMAGraph([[(0, 0)]], 1) + if args.dataset == "math": - dataloader = LinkPredDataLoader('dynamic-temporal', 
f'sx-mathoverflow-data-{args.slide_size}', args.cutoff_time, verbose=True, for_stgraph=True) + dataloader = LinkPredDataLoader( + "dynamic-temporal", + f"sx-mathoverflow-data-{args.slide_size}", + args.cutoff_time, + verbose=True, + for_stgraph=True, + ) elif args.dataset == "wikitalk": - dataloader = LinkPredDataLoader('dynamic-temporal', f'wiki-talk-temporal-data-{args.slide_size}', args.cutoff_time, verbose=True, for_stgraph=True) + dataloader = LinkPredDataLoader( + "dynamic-temporal", + f"wiki-talk-temporal-data-{args.slide_size}", + args.cutoff_time, + verbose=True, + for_stgraph=True, + ) elif args.dataset == "askubuntu": - dataloader = LinkPredDataLoader('dynamic-temporal', f'sx-askubuntu-data-{args.slide_size}', args.cutoff_time, verbose=True, for_stgraph=True) + dataloader = LinkPredDataLoader( + "dynamic-temporal", + f"sx-askubuntu-data-{args.slide_size}", + args.cutoff_time, + verbose=True, + for_stgraph=True, + ) elif args.dataset == "superuser": - dataloader = LinkPredDataLoader('dynamic-temporal', f'sx-superuser-data-{args.slide_size}', args.cutoff_time, verbose=True, for_stgraph=True) + dataloader = LinkPredDataLoader( + "dynamic-temporal", + f"sx-superuser-data-{args.slide_size}", + args.cutoff_time, + verbose=True, + for_stgraph=True, + ) elif args.dataset == "stackoverflow": - dataloader = LinkPredDataLoader('dynamic-temporal', f'sx-stackoverflow-data-{args.slide_size}', args.cutoff_time, verbose=True, for_stgraph=True) + dataloader = LinkPredDataLoader( + "dynamic-temporal", + f"sx-stackoverflow-data-{args.slide_size}", + args.cutoff_time, + verbose=True, + for_stgraph=True, + ) elif args.dataset == "reddit_title": - dataloader = LinkPredDataLoader('dynamic-temporal', f'reddit-title-data-{args.slide_size}', args.cutoff_time, verbose=True, for_stgraph=True) + dataloader = LinkPredDataLoader( + "dynamic-temporal", + f"reddit-title-data-{args.slide_size}", + args.cutoff_time, + verbose=True, + for_stgraph=True, + ) elif args.dataset == 
"reddit_body": - dataloader = LinkPredDataLoader('dynamic-temporal', f'reddit-body-data-{args.slide_size}', args.cutoff_time, verbose=True, for_stgraph=True) + dataloader = LinkPredDataLoader( + "dynamic-temporal", + f"reddit-body-data-{args.slide_size}", + args.cutoff_time, + verbose=True, + for_stgraph=True, + ) elif args.dataset == "email": - dataloader = LinkPredDataLoader('dynamic-temporal', f'email-eu-core-temporal-data-{args.slide_size}', args.cutoff_time, verbose=True, for_stgraph=True) + dataloader = LinkPredDataLoader( + "dynamic-temporal", + f"email-eu-core-temporal-data-{args.slide_size}", + args.cutoff_time, + verbose=True, + for_stgraph=True, + ) elif args.dataset == "bitcoin_otc": - dataloader = LinkPredDataLoader('dynamic-temporal', f'bitcoin-otc-data-{args.slide_size}', args.cutoff_time, verbose=True, for_stgraph=True) + dataloader = LinkPredDataLoader( + "dynamic-temporal", + f"bitcoin-otc-data-{args.slide_size}", + args.cutoff_time, + verbose=True, + for_stgraph=True, + ) else: print("😔 Unrecognized dataset") quit() edge_lists = dataloader.get_edges() pos_neg_edges_lists, pos_neg_targets_lists = dataloader.get_pos_neg_edges() - pos_neg_edges_lists = [to_default_device(torch.from_numpy(pos_neg_edges)) for pos_neg_edges in pos_neg_edges_lists] - pos_neg_targets_lists = [to_default_device(torch.from_numpy(pos_neg_targets).type(torch.float32)) for pos_neg_targets in pos_neg_targets_lists] + pos_neg_edges_lists = [ + to_default_device(torch.from_numpy(pos_neg_edges)) + for pos_neg_edges in pos_neg_edges_lists + ] + pos_neg_targets_lists = [ + to_default_device(torch.from_numpy(pos_neg_targets).type(torch.float32)) + for pos_neg_targets in pos_neg_targets_lists + ] pynvml.nvmlInit() handle = pynvml.nvmlDeviceGetHandleByIndex(0) @@ -89,16 +150,27 @@ def main(args): backprop_every = args.backprop_every if backprop_every == 0: backprop_every = total_timestamps - + if total_timestamps % backprop_every == 0: - num_iter = int(total_timestamps/backprop_every) 
+ num_iter = int(total_timestamps / backprop_every) else: - num_iter = int(total_timestamps/backprop_every) + 1 + num_iter = int(total_timestamps / backprop_every) + 1 # metrics dur = [] max_gpu = [] - table = BenchmarkTable(f"(STGraph Dynamic-Temporal) TGCN on {dataloader.name} dataset", ["Epoch", "Time(s)", "MSE", "Used GPU Memory (Max MB)", "Build FWD Graph Time(s)", "Build BWD Graph Time(s)", "Move to GPU Time(s)"]) + table = BenchmarkTable( + f"(STGraph Dynamic-Temporal) TGCN on {dataloader.name} dataset", + [ + "Epoch", + "Time(s)", + "MSE", + "Used GPU Memory (Max MB)", + "Build FWD Graph Time(s)", + "Build BWD Graph Time(s)", + "Move to GPU Time(s)", + ], + ) try: # train @@ -107,7 +179,7 @@ def main(args): torch.cuda.synchronize() torch.cuda.reset_peak_memory_stats(0) model.train() - + t0 = time.time() gpu_mem_arr = [] cost_arr = [] @@ -117,7 +189,10 @@ def main(args): optimizer.zero_grad() cost = 0 hidden_state = None - y_hat = torch.randn((dataloader.max_num_nodes, args.feat_size), device=get_default_device()) + y_hat = torch.randn( + (dataloader.max_num_nodes, args.feat_size), + device=get_default_device(), + ) G.get_graph(index * backprop_every) for k in range(backprop_every): t = index * backprop_every + k @@ -128,7 +203,10 @@ def main(args): initial_used_gpu_mem = pynvml.nvmlDeviceGetMemoryInfo(handle).used G.get_graph(t) - graph_mem_delta = pynvml.nvmlDeviceGetMemoryInfo(handle).used - initial_used_gpu_mem + graph_mem_delta = ( + pynvml.nvmlDeviceGetMemoryInfo(handle).used + - initial_used_gpu_mem + ) graph_mem = graph_mem + graph_mem_delta if G.get_ndata("norm") is None: @@ -141,11 +219,11 @@ def main(args): y_hat, hidden_state = model(G, y_hat, None, hidden_state) out = model.decode(y_hat, pos_neg_edges_lists[t]).view(-1) cost = cost + criterion(out, pos_neg_targets_lists[t]) - + if cost == 0: break - - cost = cost / (backprop_every+1) + + cost = cost / (backprop_every + 1) cost.backward() optimizer.step() torch.cuda.synchronize() @@ -159,61 
+237,87 @@ def main(args): dur.append(run_time_this_epoch) max_gpu.append(max(gpu_mem_arr)) - table.add_row([epoch, "{:.5f}".format(run_time_this_epoch), "{:.4f}".format(sum(cost_arr)/len(cost_arr)), "{:.4f}".format((max(gpu_mem_arr) * 1.0 / (1024**2))), "{:.5f}".format(G.get_fwd_graph_time), "{:.5f}".format(G.get_bwd_graph_time), "{:.5f}".format(G.move_to_gpu_time)]) + table.add_row( + [ + epoch, + "{:.5f}".format(run_time_this_epoch), + "{:.4f}".format(sum(cost_arr) / len(cost_arr)), + "{:.4f}".format((max(gpu_mem_arr) * 1.0 / (1024**2))), + "{:.5f}".format(G.get_fwd_graph_time), + "{:.5f}".format(G.get_bwd_graph_time), + "{:.5f}".format(G.move_to_gpu_time), + ] + ) table.display() - print('Average Time taken: {:6f}'.format(np.mean(dur))) + print("Average Time taken: {:6f}".format(np.mean(dur))) return np.mean(dur), (max(max_gpu) * 1.0 / (1024**2)) except RuntimeError as e: - if 'out of memory' in str(e): - table.add_row(["OOM", "OOM", "OOM", "OOM", "OOM", "OOM", "OOM"]) + if "out of memory" in str(e): + table.add_row(["OOM", "OOM", "OOM", "OOM", "OOM", "OOM", "OOM"]) table.display() else: print("😔 Something went wrong") return "OOM", "OOM" + def write_results(args, time_taken, max_gpu): cutoff = "whole" if args.cutoff_time < sys.maxsize: cutoff = str(args.cutoff_time) file_name = f"stgraph_{args.type}_{args.dataset}_T{cutoff}_S{args.slide_size}_B{args.backprop_every}_H{args.num_hidden}_F{args.feat_size}" - df_data = pd.DataFrame([{'Filename': file_name, 'Time Taken (s)': time_taken, 'Max GPU Usage (MB)': max_gpu}]) - - if os.path.exists('../../results/dynamic-temporal.csv'): - df = pd.read_csv('../../results/dynamic-temporal.csv') + df_data = pd.DataFrame( + [ + { + "Filename": file_name, + "Time Taken (s)": time_taken, + "Max GPU Usage (MB)": max_gpu, + } + ] + ) + + if os.path.exists("../../results/dynamic-temporal.csv"): + df = pd.read_csv("../../results/dynamic-temporal.csv") df = pd.concat([df, df_data]) else: df = df_data - - 
df.to_csv('../../results/dynamic-temporal.csv', sep=',', index=False, encoding='utf-8') -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='STGraph Static TGCN') + df.to_csv( + "../../results/dynamic-temporal.csv", sep=",", index=False, encoding="utf-8" + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="STGraph Static TGCN") snoop.install(enabled=False) - parser.add_argument("--dataset", type=str, default="math", - help="Name of the Dataset (math, wikitalk, askubuntu, superuser, stackoverflow, email, bitcoin_otc, reddit_title, reddit_body)") - parser.add_argument("--slide-size", type=str, default="1.0", - help="Slide Size") - parser.add_argument("--type", type=str, default="naive", - help="STGraph Type") - parser.add_argument("--backprop-every", type=int, default=0, - help="Feature size of nodes") - parser.add_argument("--feat-size", type=int, default=8, - help="Feature size of nodes") - parser.add_argument("--num-hidden", type=int, default=100, - help="Number of hidden units") - parser.add_argument("--lr", type=float, default=1e-2, - help="learning rate") - parser.add_argument("--cutoff-time", type=int, default=sys.maxsize, - help="cutoff time") - parser.add_argument("--num-epochs", type=int, default=1, - help="number of training epochs") + parser.add_argument( + "--dataset", + type=str, + default="math", + help="Name of the Dataset (math, wikitalk, askubuntu, superuser, stackoverflow, email, bitcoin_otc, reddit_title, reddit_body)", + ) + parser.add_argument("--slide-size", type=str, default="1.0", help="Slide Size") + parser.add_argument("--type", type=str, default="naive", help="STGraph Type") + parser.add_argument( + "--backprop-every", type=int, default=0, help="Feature size of nodes" + ) + parser.add_argument( + "--feat-size", type=int, default=8, help="Feature size of nodes" + ) + parser.add_argument( + "--num-hidden", type=int, default=100, help="Number of hidden units" + ) + 
parser.add_argument("--lr", type=float, default=1e-2, help="learning rate") + parser.add_argument( + "--cutoff-time", type=int, default=sys.maxsize, help="cutoff time" + ) + parser.add_argument( + "--num-epochs", type=int, default=1, help="number of training epochs" + ) args = parser.parse_args() - + print(args) time_taken, max_gpu = main(args) write_results(args, time_taken, max_gpu) - - \ No newline at end of file diff --git a/stgraph/graph/__init__.py b/stgraph/graph/__init__.py index c6cbb4d5..82986501 100644 --- a/stgraph/graph/__init__.py +++ b/stgraph/graph/__init__.py @@ -1,7 +1,7 @@ """Graph representation modules provided by STGraph.""" from stgraph.graph.dynamic.dynamic_graph import DynamicGraph -from stgraph.graph.dynamic.gpma.GPMAGraph import GPMAGraph +from stgraph.graph.dynamic.gpma.gpma_graph import GPMAGraph from stgraph.graph.dynamic.naive.NaiveGraph import NaiveGraph from stgraph.graph.dynamic.pcsr.PCSRGraph import PCSRGraph from stgraph.graph.static.static_graph import StaticGraph diff --git a/stgraph/graph/dynamic/dynamic_graph.py b/stgraph/graph/dynamic/dynamic_graph.py index 7427db92..0c33110c 100644 --- a/stgraph/graph/dynamic/dynamic_graph.py +++ b/stgraph/graph/dynamic/dynamic_graph.py @@ -166,7 +166,7 @@ def _cache_graph(self: DynamicGraph) -> None: pass @abstractmethod - def _get_cached_graph(self: DynamicGraph, timestamp: str | int) -> None: + def _get_cached_graph(self: DynamicGraph, timestamp: str | int) -> bool: r"""TODO:.""" pass diff --git a/stgraph/graph/dynamic/gpma/__init__.py b/stgraph/graph/dynamic/gpma/__init__.py index 741218bf..f23e4442 100644 --- a/stgraph/graph/dynamic/gpma/__init__.py +++ b/stgraph/graph/dynamic/gpma/__init__.py @@ -1 +1 @@ -'''GPMA data structure to represent Dynamic Graphs in GPU''' \ No newline at end of file +"""Represent Dynamic Graphs using GPMA in STGraph.""" diff --git a/stgraph/graph/dynamic/gpma/GPMAGraph.py b/stgraph/graph/dynamic/gpma/gpma_graph.py similarity index 57% rename from 
stgraph/graph/dynamic/gpma/GPMAGraph.py rename to stgraph/graph/dynamic/gpma/gpma_graph.py index a04ac6b8..abaaa7fc 100644 --- a/stgraph/graph/dynamic/gpma/GPMAGraph.py +++ b/stgraph/graph/dynamic/gpma/gpma_graph.py @@ -1,26 +1,60 @@ +"""Represent Dynamic Graphs using GPMA in STGraph.""" + +from __future__ import annotations + import copy -import time import numpy as np -from rich import inspect from stgraph.graph.dynamic.dynamic_graph import DynamicGraph from stgraph.graph.dynamic.gpma.gpma import ( GPMA, - init_gpma, - init_graph_updates, - edge_update_t, - label_edges, build_backward_csr, + edge_update_t, free_backward_csr, get_csr_ptrs, - get_out_degrees, get_in_degrees, + get_out_degrees, + init_gpma, + init_graph_updates, + label_edges, ) class GPMAGraph(DynamicGraph): - def __init__(self, edge_list, max_num_nodes): + r"""Represent Dynamic Graphs using GPMA in STGraph. + + TODO: Add a paragraph explaining about GPMA in brief. + + Example: + -------- + .. code-block:: python + + from stgraph.graph import GPMAGraph + from stgraph.dataset import EnglandCovidDataLoader + + eng_covid = EnglandCovidDataLoader() + + G = GPMAGraph( + edge_list = eng_covid.get_edges(), + max_num_nodes = max(eng_covid.gdata["num_nodes"]), + ) + + Parameters + ---------- + edge_list : list + Edge list of the graph across all timestamps + max_num_nodes : int + Maximum number of nodes present in the graph across all timestamps + + Attributes + ---------- + TODO:. 
+ + """ + + def __init__(self: GPMAGraph, edge_list: list, max_num_nodes: int) -> None: + r"""Represent Dynamic Graphs using GPMA in STGraph.""" super().__init__(edge_list, max_num_nodes) # forward graph for GPMA @@ -37,36 +71,41 @@ def __init__(self, edge_list, max_num_nodes): self.graph_cache = {} self.graph_cache["base"] = copy.deepcopy(self._forward_graph) - def graph_type(self): + def graph_type(self: GPMAGraph) -> str: + r"""Return the graph type.""" return "gpma" - def _cache_graph(self): + def _cache_graph(self: GPMAGraph) -> None: + r"""TODO:.""" self.graph_cache[str(self.current_timestamp)] = copy.deepcopy( - self._forward_graph + self._forward_graph, ) - def _get_cached_graph(self, timestamp): + def _get_cached_graph(self: GPMAGraph, timestamp: int | str) -> bool: + r"""TODO:.""" if timestamp == "base": self._forward_graph = copy.deepcopy(self.graph_cache["base"]) self._get_graph_csr_ptrs() return True - else: - if str(timestamp) in self.graph_cache: - self._forward_graph = self.graph_cache[str(timestamp)] - del self.graph_cache[str(timestamp)] - self._get_graph_csr_ptrs() - return True - else: - return False - - def in_degrees(self): + + if str(timestamp) in self.graph_cache: + self._forward_graph = self.graph_cache[str(timestamp)] + del self.graph_cache[str(timestamp)] + self._get_graph_csr_ptrs() + return True + + return False + + def in_degrees(self: GPMAGraph) -> np.ndarray: + r"""TODO:.""" return np.array(get_out_degrees(self._forward_graph), dtype="int32") - def out_degrees(self): + def out_degrees(self: GPMAGraph) -> np.ndarray: + r"""TODO:.""" return np.array(get_in_degrees(self._forward_graph), dtype="int32") - def _get_graph_csr_ptrs(self): - + def _get_graph_csr_ptrs(self: GPMAGraph) -> None: + r"""TODO:.""" forward_csr_ptrs = get_csr_ptrs(self._forward_graph) self.fwd_row_offset_ptr = forward_csr_ptrs[0] self.fwd_column_indices_ptr = forward_csr_ptrs[1] @@ -80,27 +119,29 @@ def _get_graph_csr_ptrs(self): self.bwd_eids_ptr = 
backward_csr_ptrs[2] self.bwd_node_ids_ptr = backward_csr_ptrs[3] - def _update_graph_forward(self): + def _update_graph_forward(self: GPMAGraph) -> None: + r"""TODO:.""" # if we went through the entire time-stamps if str(self.current_timestamp + 1) not in self.graph_updates: - raise Exception( - "⏰ Invalid timestamp during STGraphBase.update_graph_forward()" + raise RuntimeError( + "⏰ Invalid timestamp during STGraphBase.update_graph_forward()", ) edge_update_t(self._forward_graph, self.current_timestamp + 1) label_edges(self._forward_graph) self._get_graph_csr_ptrs() - def _init_reverse_graph(self): - """Generates the reverse of the base graph""" + def _init_reverse_graph(self: GPMAGraph) -> None: + r"""Generate the reverse of the base graph.""" free_backward_csr(self._forward_graph) build_backward_csr(self._forward_graph) self._get_graph_csr_ptrs() - def _update_graph_backward(self): + def _update_graph_backward(self: GPMAGraph) -> None: + r"""TODO:.""" if self.current_timestamp < 0: - raise Exception( - "⏰ Invalid timestamp during STGraphBase.update_graph_backward()" + raise RuntimeError( + "⏰ Invalid timestamp during STGraphBase.update_graph_backward()", ) # Freeing resources from previous CSR From 265e1f8fa503f96a6daab7c11241df516d2a6f69 Mon Sep 17 00:00:00 2001 From: nithinmanoj10 Date: Sat, 11 May 2024 19:19:27 +0530 Subject: [PATCH 5/7] =?UTF-8?q?=F0=9F=A7=B9=20Lint=20fixed=20stgraph.graph?= =?UTF-8?q?.naive?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../dynamic-temporal-tgcn/seastar/train.py | 2 +- stgraph/graph/__init__.py | 2 +- stgraph/graph/dynamic/naive/__init__.py | 1 + .../naive/{NaiveGraph.py => naive_graph.py} | 92 +++++++++++++------ 4 files changed, 68 insertions(+), 29 deletions(-) rename stgraph/graph/dynamic/naive/{NaiveGraph.py => naive_graph.py} (60%) diff --git a/benchmarking/dynamic-temporal-tgcn/seastar/train.py b/benchmarking/dynamic-temporal-tgcn/seastar/train.py index 
6ed43065..a6777463 100644 --- a/benchmarking/dynamic-temporal-tgcn/seastar/train.py +++ b/benchmarking/dynamic-temporal-tgcn/seastar/train.py @@ -11,7 +11,7 @@ from stgraph.benchmark_tools.table import BenchmarkTable from stgraph.graph.dynamic.gpma.gpma_graph import GPMAGraph from stgraph.graph.dynamic.pcsr.PCSRGraph import PCSRGraph -from stgraph.graph.dynamic.naive.NaiveGraph import NaiveGraph +from stgraph.graph.dynamic.naive.naive_graph import NaiveGraph from model import STGraphTGCN from utils import to_default_device, get_default_device diff --git a/stgraph/graph/__init__.py b/stgraph/graph/__init__.py index 82986501..30eef206 100644 --- a/stgraph/graph/__init__.py +++ b/stgraph/graph/__init__.py @@ -2,7 +2,7 @@ from stgraph.graph.dynamic.dynamic_graph import DynamicGraph from stgraph.graph.dynamic.gpma.gpma_graph import GPMAGraph -from stgraph.graph.dynamic.naive.NaiveGraph import NaiveGraph +from stgraph.graph.dynamic.naive.naive_graph import NaiveGraph from stgraph.graph.dynamic.pcsr.PCSRGraph import PCSRGraph from stgraph.graph.static.static_graph import StaticGraph from stgraph.graph.stgraph_base import STGraphBase diff --git a/stgraph/graph/dynamic/naive/__init__.py b/stgraph/graph/dynamic/naive/__init__.py index e69de29b..ceeb7cd2 100644 --- a/stgraph/graph/dynamic/naive/__init__.py +++ b/stgraph/graph/dynamic/naive/__init__.py @@ -0,0 +1 @@ +"""Represent Dynamic Graphs using CSR in STGraph.""" diff --git a/stgraph/graph/dynamic/naive/NaiveGraph.py b/stgraph/graph/dynamic/naive/naive_graph.py similarity index 60% rename from stgraph/graph/dynamic/naive/NaiveGraph.py rename to stgraph/graph/dynamic/naive/naive_graph.py index f70a8d03..7fe1cdd0 100644 --- a/stgraph/graph/dynamic/naive/NaiveGraph.py +++ b/stgraph/graph/dynamic/naive/naive_graph.py @@ -1,17 +1,51 @@ +"""Represent Dynamic Graphs using CSR in STGraph.""" + +from __future__ import annotations + import copy + import numpy as np -from rich import inspect from stgraph.graph.dynamic.dynamic_graph 
import DynamicGraph from stgraph.graph.static.csr import CSR -from collections import deque -import time class NaiveGraph(DynamicGraph): - def __init__(self, edge_list, max_num_nodes): + r"""Represent Dynamic Graphs using CSR in STGraph. + + TODO: Add a paragraph explaining about CSR in brief. + + Example: + -------- + .. code-block:: python + + from stgraph.graph import NaiveGraph + from stgraph.dataset import EnglandCovidDataLoader + + eng_covid = EnglandCovidDataLoader() + + G = NaiveGraph( + edge_list = eng_covid.get_edges(), + max_num_nodes = max(eng_covid.gdata["num_nodes"]), + ) + + Parameters + ---------- + edge_list : list + Edge list of the graph across all timestamps + max_num_nodes : int + Maximum number of nodes present in the graph across all timestamps + + Attributes + ---------- + TODO:. + + """ + + def __init__(self: NaiveGraph, edge_list: list, max_num_nodes: int) -> None: + r"""Represent Dynamic Graphs using CSR in STGraph.""" super().__init__(edge_list, max_num_nodes) - # inspect(edge_list) + self._prepare_edge_lst_fwd(edge_list) self._prepare_edge_lst_bwd(self.fwd_edge_list) @@ -39,7 +73,8 @@ def __init__(self, edge_list, max_num_nodes): self._get_graph_csr_ptrs(0) - def _prepare_edge_lst_fwd(self, edge_list): + def _prepare_edge_lst_fwd(self: NaiveGraph, edge_list: list) -> None: + r"""TODO:.""" self.fwd_edge_list = [] for i in range(len(edge_list)): edge_list_for_t = edge_list[i] @@ -50,36 +85,38 @@ def _prepare_edge_lst_fwd(self, edge_list): ] self.fwd_edge_list.append(edge_list_for_t) - def _prepare_edge_lst_bwd(self, edge_list): + def _prepare_edge_lst_bwd(self: NaiveGraph, edge_list: list) -> None: + r"""TODO:.""" self.bwd_edge_list = [] for i in range(len(edge_list)): edge_list_for_t = copy.deepcopy(edge_list[i]) edge_list_for_t.sort() self.bwd_edge_list.append(edge_list_for_t) - def graph_type(self): + def graph_type(self: NaiveGraph) -> str: + r"""Return the graph type.""" return "csr" - def _cache_graph(self): + def
_cache_graph(self: NaiveGraph) -> None: pass - def _get_cached_graph(self, timestamp): + def _get_cached_graph(self: NaiveGraph) -> None: return False - def in_degrees(self): + def in_degrees(self: NaiveGraph) -> np.ndarray: + r"""TODO:.""" return np.array( - self._forward_graph[self.current_timestamp].out_degrees, dtype="int32" + self._forward_graph[self.current_timestamp].out_degrees, dtype="int32", ) - def out_degrees(self): + def out_degrees(self: NaiveGraph) -> np.ndarray: + r"""TODO:.""" return np.array( - self._forward_graph[self.current_timestamp].in_degrees, dtype="int32" + self._forward_graph[self.current_timestamp].in_degrees, dtype="int32", ) - # def weighted_in_degrees(self): - # return np.array(self._forward_graph[self.current_timestamp].weighted_out_degrees, dtype="float32") - - def _get_graph_csr_ptrs(self, timestamp): + def _get_graph_csr_ptrs(self: NaiveGraph, timestamp: int) -> None: + r"""TODO:.""" if self._is_backprop_state: bwd_csr_ptrs = self._backward_graph[timestamp] self.bwd_row_offset_ptr = bwd_csr_ptrs.row_offset_ptr @@ -93,21 +130,22 @@ def _get_graph_csr_ptrs(self, timestamp): self.fwd_eids_ptr = fwd_csr_ptrs.eids_ptr self.fwd_node_ids_ptr = fwd_csr_ptrs.node_ids_ptr - def _update_graph_forward(self): - """Updates the current base graph to the next timestamp""" + def _update_graph_forward(self: NaiveGraph) -> None: + """Update the current base graph to the next timestamp.""" if str(self.current_timestamp + 1) not in self.graph_updates: - raise Exception( - "⏰ Invalid timestamp during STGraphBase.update_graph_forward()" + raise RuntimeError( + "⏰ Invalid timestamp during STGraphBase.update_graph_forward()", ) self._get_graph_csr_ptrs(self.current_timestamp + 1) - def _init_reverse_graph(self): - """Generates the reverse of the base graph""" + def _init_reverse_graph(self: NaiveGraph) -> None: + """Generate the reverse of the base graph.""" self._get_graph_csr_ptrs(self.current_timestamp) - def _update_graph_backward(self): + def 
_update_graph_backward(self: NaiveGraph) -> None: + r"""TODO:.""" if self.current_timestamp < 0: - raise Exception( - "⏰ Invalid timestamp during STGraphBase.update_graph_backward()" + raise RuntimeError( + "⏰ Invalid timestamp during STGraphBase.update_graph_backward()", ) self._get_graph_csr_ptrs(self.current_timestamp - 1) From d410eff18ecf600e4db250ce5156b658aa1210b6 Mon Sep 17 00:00:00 2001 From: nithinmanoj10 Date: Sat, 11 May 2024 19:45:02 +0530 Subject: [PATCH 6/7] =?UTF-8?q?=F0=9F=A7=B9=20Lint=20fixed=20stgraph.graph?= =?UTF-8?q?.pcsr?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../dynamic-temporal-tgcn/seastar/train.py | 2 +- stgraph/graph/__init__.py | 2 +- stgraph/graph/dynamic/pcsr/__init__.py | 1 + .../pcsr/{PCSRGraph.py => pcsr_graph.py} | 112 +++++++++++++----- 4 files changed, 83 insertions(+), 34 deletions(-) rename stgraph/graph/dynamic/pcsr/{PCSRGraph.py => pcsr_graph.py} (56%) diff --git a/benchmarking/dynamic-temporal-tgcn/seastar/train.py b/benchmarking/dynamic-temporal-tgcn/seastar/train.py index a6777463..c79fcc2e 100644 --- a/benchmarking/dynamic-temporal-tgcn/seastar/train.py +++ b/benchmarking/dynamic-temporal-tgcn/seastar/train.py @@ -10,7 +10,7 @@ from stgraph.dataset.LinkPredDataLoader import LinkPredDataLoader from stgraph.benchmark_tools.table import BenchmarkTable from stgraph.graph.dynamic.gpma.gpma_graph import GPMAGraph -from stgraph.graph.dynamic.pcsr.PCSRGraph import PCSRGraph +from stgraph.graph.dynamic.pcsr.pcsr_graph import PCSRGraph from stgraph.graph.dynamic.naive.naive_graph import NaiveGraph from model import STGraphTGCN from utils import to_default_device, get_default_device diff --git a/stgraph/graph/__init__.py b/stgraph/graph/__init__.py index 30eef206..1b4b46c5 100644 --- a/stgraph/graph/__init__.py +++ b/stgraph/graph/__init__.py @@ -3,6 +3,6 @@ from stgraph.graph.dynamic.dynamic_graph import DynamicGraph from stgraph.graph.dynamic.gpma.gpma_graph import 
GPMAGraph from stgraph.graph.dynamic.naive.naive_graph import NaiveGraph -from stgraph.graph.dynamic.pcsr.PCSRGraph import PCSRGraph +from stgraph.graph.dynamic.pcsr.pcsr_graph import PCSRGraph from stgraph.graph.static.static_graph import StaticGraph from stgraph.graph.stgraph_base import STGraphBase diff --git a/stgraph/graph/dynamic/pcsr/__init__.py b/stgraph/graph/dynamic/pcsr/__init__.py index e69de29b..3aa80822 100644 --- a/stgraph/graph/dynamic/pcsr/__init__.py +++ b/stgraph/graph/dynamic/pcsr/__init__.py @@ -0,0 +1 @@ +"""Represent Dynamic Graphs using PCSR in STGraph.""" diff --git a/stgraph/graph/dynamic/pcsr/PCSRGraph.py b/stgraph/graph/dynamic/pcsr/pcsr_graph.py similarity index 56% rename from stgraph/graph/dynamic/pcsr/PCSRGraph.py rename to stgraph/graph/dynamic/pcsr/pcsr_graph.py index 9c9423ef..3f6facff 100644 --- a/stgraph/graph/dynamic/pcsr/PCSRGraph.py +++ b/stgraph/graph/dynamic/pcsr/pcsr_graph.py @@ -1,14 +1,49 @@ +"""Represent Dynamic Graphs using PCSR in STGraph.""" + +from __future__ import annotations + import copy + import numpy as np -from rich import inspect -import time from stgraph.graph.dynamic.dynamic_graph import DynamicGraph from stgraph.graph.dynamic.pcsr.pcsr import PCSR class PCSRGraph(DynamicGraph): - def __init__(self, edge_list, max_num_nodes): + r"""Represent Dynamic Graphs using PCSR in STGraph. + + TODO: Add a paragraph explaining about PCSR in brief. + + Example: + -------- + .. code-block:: python + + from stgraph.graph import PCSRGraph + from stgraph.dataset import EnglandCovidDataLoader + + eng_covid = EnglandCovidDataLoader() + + G = PCSRGraph( + edge_list = eng_covid.get_edges(), + max_num_nodes = max(eng_covid.gdata["num_nodes"]), + ) + + Parameters + ---------- + edge_list : list + Edge list of the graph across all timestamps + max_num_nodes : int + Maximum number of nodes present in the graph across all timestamps + + Attributes + ---------- + TODO:. 
+ + """ + + def __init__(self: PCSRGraph, edge_list: list, max_num_nodes: int) -> None: + r"""Represent Dynamic Graphs using PCSR in STGraph.""" super().__init__(edge_list, max_num_nodes) # Get the maximum number of edges @@ -16,7 +51,8 @@ def __init__(self, edge_list, max_num_nodes): self._forward_graph = PCSR(self.max_num_nodes, self.max_num_edges) self._forward_graph.edge_update_list( - self.graph_updates["0"]["add"], is_reverse_edge=True + self.graph_updates["0"]["add"], + is_reverse_edge=True, ) self._forward_graph.label_edges() self._forward_graph.build_csr() @@ -25,7 +61,8 @@ def __init__(self, edge_list, max_num_nodes): self.graph_cache = {} self.graph_cache["base"] = copy.deepcopy(self._forward_graph) - def _get_max_num_edges(self): + def _get_max_num_edges(self: PCSRGraph) -> None: + r"""TODO:.""" updates = self.graph_updates edge_set = set() for i in range(len(updates)): @@ -33,35 +70,41 @@ def _get_max_num_edges(self): edge_set.add(updates[str(i)]["add"][j]) self.max_num_edges = len(edge_set) - def graph_type(self): + def graph_type(self: PCSRGraph) -> str: + r"""Return the graph type.""" return "pcsr" - def _cache_graph(self): + def _cache_graph(self: PCSRGraph) -> None: + r"""TODO:.""" self.graph_cache[str(self.current_timestamp)] = copy.deepcopy( - self._forward_graph + self._forward_graph, ) - def _get_cached_graph(self, timestamp): + def _get_cached_graph(self: PCSRGraph, timestamp: int | str) -> bool: + r"""TODO:.""" if timestamp == "base": self._forward_graph = copy.deepcopy(self.graph_cache["base"]) self._get_graph_csr_ptrs() return True - else: - if str(timestamp) in self.graph_cache: - self._forward_graph = self.graph_cache[str(timestamp)] - del self.graph_cache[str(timestamp)] - self._get_graph_csr_ptrs() - return True - else: - return False - - def in_degrees(self): + + if str(timestamp) in self.graph_cache: + self._forward_graph = self.graph_cache[str(timestamp)] + del self.graph_cache[str(timestamp)] + self._get_graph_csr_ptrs() + return 
True + + return False + + def in_degrees(self: PCSRGraph) -> np.ndarray: + r"""TODO:.""" return np.array(self._forward_graph.out_degrees, dtype="int32") - def out_degrees(self): + def out_degrees(self: PCSRGraph) -> np.ndarray: + r"""TODO:.""" return np.array(self._forward_graph.in_degrees, dtype="int32") - def _get_graph_csr_ptrs(self): + def _get_graph_csr_ptrs(self: PCSRGraph) -> None: + r"""TODO:.""" csr_ptrs = self._forward_graph.get_csr_ptrs() if self._is_backprop_state: self.bwd_row_offset_ptr = csr_ptrs[0] @@ -74,11 +117,11 @@ def _get_graph_csr_ptrs(self): self.fwd_eids_ptr = csr_ptrs[2] self.fwd_node_ids_ptr = csr_ptrs[3] - def _update_graph_forward(self): - """Updates the current base graph to the next timestamp""" + def _update_graph_forward(self: PCSRGraph) -> None: + r"""Update the current base graph to the next timestamp.""" if str(self.current_timestamp + 1) not in self.graph_updates: - raise Exception( - "⏰ Invalid timestamp during STGraphBase.update_graph_forward()" + raise RuntimeError( + "⏰ Invalid timestamp during STGraphBase.update_graph_forward()", ) graph_additions = self.graph_updates[str(self.current_timestamp + 1)]["add"] @@ -86,23 +129,26 @@ def _update_graph_forward(self): self._forward_graph.edge_update_list(graph_additions, is_reverse_edge=True) self._forward_graph.edge_update_list( - graph_deletions, is_delete=True, is_reverse_edge=True + graph_deletions, + is_delete=True, + is_reverse_edge=True, ) self._forward_graph.label_edges() move_to_gpu_time = self._forward_graph.build_csr() self.move_to_gpu_time += move_to_gpu_time self._get_graph_csr_ptrs() - def _init_reverse_graph(self): - """Generates the reverse of the base graph""" + def _init_reverse_graph(self: PCSRGraph) -> None: + """Generate the reverse of the base graph.""" move_to_gpu_time = self._forward_graph.build_reverse_csr() self.move_to_gpu_time += move_to_gpu_time self._get_graph_csr_ptrs() - def _update_graph_backward(self): + def _update_graph_backward(self: PCSRGraph) 
-> None: + r"""TODO:.""" if self.current_timestamp < 0: - raise Exception( - "⏰ Invalid timestamp during STGraphBase.update_graph_backward()" + raise RuntimeError( + "⏰ Invalid timestamp during STGraphBase.update_graph_backward()", ) graph_additions = self.graph_updates[str(self.current_timestamp)]["delete"] @@ -110,7 +156,9 @@ def _update_graph_backward(self): self._forward_graph.edge_update_list(graph_additions, is_reverse_edge=True) self._forward_graph.edge_update_list( - graph_deletions, is_delete=True, is_reverse_edge=True + graph_deletions, + is_delete=True, + is_reverse_edge=True, ) self._forward_graph.label_edges() move_to_gpu_time = self._forward_graph.build_reverse_csr() From 033fba7c82c842059082e5b6962770c6210799d9 Mon Sep 17 00:00:00 2001 From: nithinmanoj10 Date: Sat, 11 May 2024 21:29:14 +0530 Subject: [PATCH 7/7] =?UTF-8?q?=F0=9F=94=81=20Added=20stgraph.graph=20modu?= =?UTF-8?q?le=20check=20to=20ruff.yaml?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit From now on, the automated Ruff lint check run by the GitHub Actions workflow also covers the stgraph.graph module. --- .github/workflows/ruff.yaml | 3 +++ stgraph/graph/dynamic/__init__.py | 1 + 2 files changed, 4 insertions(+) diff --git a/.github/workflows/ruff.yaml b/.github/workflows/ruff.yaml index 94f35cdc..23e3c7eb 100644 --- a/.github/workflows/ruff.yaml +++ b/.github/workflows/ruff.yaml @@ -22,4 +22,7 @@ jobs: run: | cd stgraph/dataset/ ruff check . + cd ../../ + cd stgraph/graph + ruff check . cd ../../ \ No newline at end of file diff --git a/stgraph/graph/dynamic/__init__.py b/stgraph/graph/dynamic/__init__.py index e69de29b..fff87265 100644 --- a/stgraph/graph/dynamic/__init__.py +++ b/stgraph/graph/dynamic/__init__.py @@ -0,0 +1 @@ +"""Dynamic Graph representation modules provided by STGraph."""