diff --git a/thicket/tests/test_add_root_node.py b/thicket/tests/test_add_root_node.py new file mode 100644 index 00000000..e5b81747 --- /dev/null +++ b/thicket/tests/test_add_root_node.py @@ -0,0 +1,29 @@ +# Copyright 2022 Lawrence Livermore National Security, LLC and other +# Thicket Project Developers. See the top-level LICENSE file for details. +# +# SPDX-License-Identifier: MIT + +from hatchet.node import Node + + +def test_add_root_node(literal_thickets): + tk, _, _ = literal_thickets + + assert len(tk.graph) == 4 + + # Call add_root_node + tk.add_root_node({"name": "Test", "type": "function"}) + # Get node variable + test_node = tk.get_node("Test") + + # Check if node was inserted in all components + assert isinstance(test_node, Node) + assert test_node._hatchet_nid == 3 + assert test_node._depth == 0 + assert len(tk.graph) == 5 + assert len(tk.statsframe.graph) == 5 + assert test_node in tk.dataframe.index.get_level_values("node") + assert test_node in tk.statsframe.dataframe.index.get_level_values("node") + + assert tk.dataframe.loc[test_node, "name"].values[0] == "Test" + assert tk.statsframe.dataframe.loc[test_node, "name"] == "Test" diff --git a/thicket/tests/test_get_node.py b/thicket/tests/test_get_node.py new file mode 100644 index 00000000..06b0075a --- /dev/null +++ b/thicket/tests/test_get_node.py @@ -0,0 +1,20 @@ +# Copyright 2022 Lawrence Livermore National Security, LLC and other +# Thicket Project Developers. See the top-level LICENSE file for details. +# +# SPDX-License-Identifier: MIT + +import pytest + + +def test_get_node(literal_thickets): + tk, _, _ = literal_thickets + + with pytest.raises(KeyError): + tk.get_node("Foo") + + baz = tk.get_node("Baz") + + # Check node properties + assert baz.frame["name"] == "Baz" + assert baz.frame["type"] == "function" + assert baz._hatchet_nid == 0 diff --git a/thicket/thicket.py b/thicket/thicket.py index 8d0e2a33..afb30e9c 100644 --- a/thicket/thicket.py +++ b/thicket/thicket.py @@ -16,7 +16,9 @@ import pandas as pd import numpy as np from hatchet import GraphFrame +from hatchet.frame import Frame from hatchet.graph import Graph +from hatchet.node import Node from hatchet.query import QueryEngine from thicket.query import ( Query, @@ -1514,6 +1516,57 @@ def get_unique_metadata(self): return sorted_meta + def add_root_node(self, attrs): + """Add node at root level with given attributes. + + Arguments: + attrs (dict): attributes for the new node which will be used to initilize the + node.frame. + """ + + new_node = Node(frame_obj=Frame(attrs=attrs)) + + # graph and statsframe.graph + self.graph.roots.append(new_node) + + # Set hatchet nid and depth + self.graph.enumerate_traverse() + + # dataframe + idx_levels = self.dataframe.index.names + new_idx = [[new_node]] + [self.profile] + new_node_df = pd.DataFrame( + index=pd.MultiIndex.from_product(new_idx, names=idx_levels) + ) + new_node_df["name"] = attrs["name"] + self.dataframe = pd.concat([self.dataframe, new_node_df]) + + # statsframe.dataframe + self.statsframe.dataframe = helpers._new_statsframe_df(self.dataframe) + # Reapply stats operations after clearing statsframe dataframe + self.reapply_stats_operations() + + # Check Thicket state + validate_nodes(self) + + def get_node(self, name): + """Get a node object in the Thicket by its Node.frame['name']. If more than one + node has the same name, a list of nodes is returned. + + Arguments: + name (str): name of the node (Node.frame['name']). + + Returns: + (Node or list(Node)): Node object with the given name or list of Node objects + with the given name. + """ + node = [n for n in self.graph.traverse() if n.frame["name"] == name] + + if len(node) == 0: + raise KeyError(f'Node with name "{name}" not found.') + + return node[0] if len(node) == 1 else node + def _sync_profile_components(self, component): """Synchronize the Performance DataFrame, Metadata Dataframe, profile and profile mapping objects based on the component's index or a list of profiles.