Skip to content

Commit 148b3c2

Browse files
authored
[networkx] Add generic type for node and edge data (#15660)
1 parent f342822 commit 148b3c2

13 files changed

Lines changed: 191 additions & 151 deletions

File tree

stubs/networkx/networkx/algorithms/clique.pyi

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ from _typeshed import Incomplete
22
from collections.abc import Generator, Iterable, Iterator
33
from typing import overload
44

5-
from networkx.classes.graph import Graph, _Node
5+
from networkx.classes.graph import Graph, _EdgeData, _Node, _NodeData
66
from networkx.utils.backends import _dispatchable
77

88
__all__ = [
@@ -23,10 +23,15 @@ def find_cliques(G: Graph[_Node], nodes: Iterable[Incomplete] | None = None) ->
2323
@_dispatchable
2424
def find_cliques_recursive(G: Graph[_Node], nodes: Iterable[Incomplete] | None = None) -> Iterator[list[_Node]]: ...
2525
@_dispatchable
26-
def make_max_clique_graph(G: Graph[_Node], create_using: Graph[_Node] | None = None) -> Graph[_Node]: ...
26+
def make_max_clique_graph(
27+
G: Graph[_Node], create_using: Graph[_Node, _NodeData, _EdgeData] | None = None
28+
) -> Graph[_Node, _NodeData, _EdgeData]: ...
2729
@_dispatchable
2830
def make_clique_bipartite(
29-
G: Graph[_Node], fpos: bool | None = None, create_using: Graph[_Node] | None = None, name=None
31+
G: Graph[_Node, _NodeData, _EdgeData],
32+
fpos: bool | None = None,
33+
create_using: Graph[_Node, _NodeData, _EdgeData] | None = None,
34+
name=None,
3035
) -> Graph[_Node]: ...
3136
@overload
3237
def node_clique_number(

stubs/networkx/networkx/algorithms/dag.pyi

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ from _typeshed import Incomplete
22
from collections.abc import Callable, Generator, Iterable
33

44
from networkx.classes.digraph import DiGraph
5-
from networkx.classes.graph import Graph, _Node
5+
from networkx.classes.graph import Graph, _EdgeData, _Node, _NodeData
66
from networkx.utils.backends import _dispatchable
77

88
__all__ = [
@@ -40,11 +40,13 @@ def all_topological_sorts(G: DiGraph[_Node]) -> Generator[list[_Node]]: ...
4040
@_dispatchable
4141
def is_aperiodic(G: DiGraph[_Node]) -> bool: ...
4242
@_dispatchable
43-
def transitive_closure(G: Graph[_Node], reflexive=False) -> Graph[_Node]: ...
43+
def transitive_closure(G: Graph[_Node, _NodeData, _EdgeData], reflexive=False) -> Graph[_Node, _NodeData, _EdgeData]: ...
4444
@_dispatchable
45-
def transitive_closure_dag(G: DiGraph[_Node], topo_order: Iterable[Incomplete] | None = None) -> DiGraph[_Node]: ...
45+
def transitive_closure_dag(
46+
G: DiGraph[_Node, _NodeData, _EdgeData], topo_order: Iterable[Incomplete] | None = None
47+
) -> DiGraph[_Node, _NodeData, _EdgeData]: ...
4648
@_dispatchable
47-
def transitive_reduction(G: DiGraph[_Node]) -> DiGraph[_Node]: ...
49+
def transitive_reduction(G: DiGraph[_Node, _NodeData, _EdgeData]) -> DiGraph[_Node, _NodeData, _EdgeData]: ...
4850
@_dispatchable
4951
def antichains(G: DiGraph[_Node], topo_order: Iterable[Incomplete] | None = None) -> Generator[list[_Node]]: ...
5052
@_dispatchable
@@ -57,4 +59,4 @@ def dag_longest_path(
5759
@_dispatchable
5860
def dag_longest_path_length(G: DiGraph[_Node], weight: str | None = "weight", default_weight: int | None = 1) -> int: ...
5961
@_dispatchable
60-
def dag_to_branching(G: DiGraph[_Node]) -> DiGraph[_Node]: ...
62+
def dag_to_branching(G: DiGraph[_Node, _NodeData, _EdgeData]) -> DiGraph[_Node, _NodeData, _EdgeData]: ...

stubs/networkx/networkx/algorithms/planarity.pyi

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ from decimal import Decimal
44
from typing import NoReturn
55

66
from networkx.classes.digraph import DiGraph
7-
from networkx.classes.graph import Graph, _EdgePlus, _Node
7+
from networkx.classes.graph import Graph, _EdgeData, _EdgePlus, _Node, _NodeData
88
from networkx.utils.backends import _dispatchable
99

1010
__all__ = ["check_planarity", "is_planar", "PlanarEmbedding"]
@@ -14,9 +14,9 @@ def is_planar(G: Graph[_Node]) -> bool: ...
1414
@_dispatchable
1515
def check_planarity(G: Graph[_Node], counterexample: bool = False): ...
1616
@_dispatchable
17-
def get_counterexample(G: Graph[_Node]) -> Graph[_Node]: ...
17+
def get_counterexample(G: Graph[_Node, _NodeData, _EdgeData]) -> Graph[_Node, _NodeData, _EdgeData]: ...
1818
@_dispatchable
19-
def get_counterexample_recursive(G: Graph[_Node]) -> Graph[_Node]: ...
19+
def get_counterexample_recursive(G: Graph[_Node, _NodeData, _EdgeData]) -> Graph[_Node, _NodeData, _EdgeData]: ...
2020

2121
class Interval:
2222
low: Incomplete

stubs/networkx/networkx/algorithms/traversal/breadth_first_search.pyi

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ from collections.abc import Callable, Generator, Iterable, Iterator
22
from typing import Final, Literal
33

44
from networkx.classes.digraph import DiGraph
5-
from networkx.classes.graph import Graph, _Node
5+
from networkx.classes.graph import Graph, _EdgeData, _Node, _NodeData
66
from networkx.utils.backends import _dispatchable
77

88
__all__ = [
@@ -30,12 +30,12 @@ def bfs_edges(
3030
) -> Generator[tuple[_Node, _Node]]: ...
3131
@_dispatchable
3232
def bfs_tree(
33-
G: Graph[_Node],
33+
G: Graph[_Node, _NodeData, _EdgeData],
3434
source: _Node,
3535
reverse: bool | None = False,
3636
depth_limit: int | None = None,
3737
sort_neighbors: Callable[[Iterator[_Node]], Iterable[_Node]] | None = None,
38-
) -> DiGraph[_Node]: ...
38+
) -> DiGraph[_Node, _NodeData, _EdgeData]: ...
3939
@_dispatchable
4040
def bfs_predecessors(
4141
G: Graph[_Node],

stubs/networkx/networkx/algorithms/traversal/depth_first_search.pyi

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ from collections.abc import Callable, Generator, Iterable, Iterator
22
from typing import Literal
33

44
from networkx.classes.digraph import DiGraph
5-
from networkx.classes.graph import Graph, _Node
5+
from networkx.classes.graph import Graph, _EdgeData, _Node, _NodeData
66
from networkx.utils.backends import _dispatchable
77

88
__all__ = [
@@ -25,12 +25,12 @@ def dfs_edges(
2525
) -> Generator[tuple[_Node, _Node]]: ...
2626
@_dispatchable
2727
def dfs_tree(
28-
G: Graph[_Node],
28+
G: Graph[_Node, _NodeData, _EdgeData],
2929
source: _Node | None = None,
3030
depth_limit: int | None = None,
3131
*,
3232
sort_neighbors: Callable[[Iterator[_Node]], Iterable[_Node]] | None = None,
33-
) -> DiGraph[_Node]: ...
33+
) -> DiGraph[_Node, _NodeData, _EdgeData]: ...
3434
@_dispatchable
3535
def dfs_predecessors(
3636
G: Graph[_Node],

stubs/networkx/networkx/classes/digraph.pyi

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ from typing import Any
44
from typing_extensions import Self
55

66
from networkx.classes.coreviews import AdjacencyView
7-
from networkx.classes.graph import Graph, _Node
7+
from networkx.classes.graph import Graph, _EdgeData, _Node, _NodeData
88
from networkx.classes.reportviews import (
99
DiDegreeView,
1010
InDegreeView,
@@ -21,7 +21,7 @@ __all__ = ["DiGraph"]
2121
# NOTE: Graph subclasses relationships are so complex
2222
# we're only overriding methods that differ in signature from the base classes
2323
# to use inheritance to our advantage and reduce complexity
24-
class DiGraph(Graph[_Node]):
24+
class DiGraph(Graph[_Node, _NodeData, _EdgeData]):
2525
@cached_property
2626
def succ(self) -> AdjacencyView[_Node, _Node, dict[str, Any]]: ...
2727
@cached_property
@@ -34,19 +34,19 @@ class DiGraph(Graph[_Node]):
3434

3535
def predecessors(self, n: _Node) -> Iterator[_Node]: ...
3636
@cached_property
37-
def edges(self) -> OutEdgeView[_Node]: ...
37+
def edges(self) -> OutEdgeView[_Node, _NodeData, _EdgeData]: ...
3838
@cached_property
39-
def out_edges(self) -> OutEdgeView[_Node]: ...
39+
def out_edges(self) -> OutEdgeView[_Node, _NodeData, _EdgeData]: ...
4040
@cached_property
4141
# Including subtypes' possible return types for LSP
42-
def in_edges(self) -> InEdgeView[_Node] | InMultiEdgeView[_Node]: ...
42+
def in_edges(self) -> InEdgeView[_Node, _NodeData, _EdgeData] | InMultiEdgeView[_Node, _NodeData, _EdgeData]: ...
4343
@cached_property
44-
def degree(self) -> DiDegreeView[_Node]: ...
44+
def degree(self) -> DiDegreeView[_Node, _NodeData, _EdgeData]: ...
4545
@cached_property
4646
# Including subtypes' possible return types for LSP
47-
def in_degree(self) -> InDegreeView[_Node] | InMultiDegreeView[_Node]: ...
47+
def in_degree(self) -> InDegreeView[_Node, _NodeData, _EdgeData] | InMultiDegreeView[_Node, _NodeData, _EdgeData]: ...
4848
@cached_property
4949
# Including subtypes' possible return types for LSP
50-
def out_degree(self) -> OutDegreeView[_Node] | OutMultiDegreeView[_Node]: ...
51-
def to_undirected(self, reciprocal: bool = False, as_view: bool = False) -> Graph[_Node]: ... # type: ignore[override] # Has an additional `reciprocal` keyword argument
50+
def out_degree(self) -> OutDegreeView[_Node, _NodeData, _EdgeData] | OutMultiDegreeView[_Node, _NodeData, _EdgeData]: ...
51+
def to_undirected(self, reciprocal: bool = False, as_view: bool = False) -> Graph[_Node, _NodeData, _EdgeData]: ... # type: ignore[override] # Has an additional `reciprocal` keyword argument
5252
def reverse(self, copy: bool = True) -> Self: ...

stubs/networkx/networkx/classes/function.pyi

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ from typing import Literal, TypeVar, overload
55
from networkx import _dispatchable
66
from networkx.algorithms.planarity import PlanarEmbedding
77
from networkx.classes.digraph import DiGraph
8-
from networkx.classes.graph import Graph, _NBunch, _Node
8+
from networkx.classes.graph import Graph, _EdgeData, _NBunch, _Node, _NodeData
99
from networkx.classes.multigraph import MultiGraph
1010

1111
__all__ = [
@@ -73,7 +73,7 @@ def add_star(G_to_add_to, nodes_for_star, **attr) -> None: ...
7373
def add_path(G_to_add_to, nodes_for_path, **attr) -> None: ...
7474
def add_cycle(G_to_add_to, nodes_for_cycle, **attr) -> None: ...
7575
def subgraph(G: Graph[_Node], nbunch): ...
76-
def induced_subgraph(G: Graph[_Node], nbunch: _NBunch[_Node]) -> Graph[_Node]: ...
76+
def induced_subgraph(G: Graph[_Node, _NodeData, _EdgeData], nbunch: _NBunch[_Node]) -> Graph[_Node, _NodeData, _EdgeData]: ...
7777
def edge_subgraph(G: Graph[_Node], edges): ...
7878
def restricted_view(G: Graph[_Node], nodes, edges): ...
7979
def to_directed(graph): ...
@@ -146,8 +146,8 @@ def selfloop_edges(
146146
) -> Generator[tuple[_Node, _Node]]: ...
147147
@overload
148148
def selfloop_edges(
149-
G: Graph[_Node], data: Literal[True], keys: Literal[False] = False, default=None
150-
) -> Generator[tuple[_Node, _Node, dict[str, Incomplete]]]: ...
149+
G: Graph[_Node, _NodeData, _EdgeData], data: Literal[True], keys: Literal[False] = False, default=None
150+
) -> Generator[tuple[_Node, _Node, _EdgeData]]: ...
151151
@overload
152152
def selfloop_edges(
153153
G: Graph[_Node], data: str, keys: Literal[False] = False, default: _U | None = None
@@ -162,8 +162,8 @@ def selfloop_edges(
162162
) -> Generator[tuple[_Node, _Node, int]]: ...
163163
@overload
164164
def selfloop_edges(
165-
G: Graph[_Node], data: Literal[True], keys: Literal[True], default=None
166-
) -> Generator[tuple[_Node, _Node, int, dict[str, Incomplete]]]: ...
165+
G: Graph[_Node, _NodeData, _EdgeData], data: Literal[True], keys: Literal[True], default=None
166+
) -> Generator[tuple[_Node, _Node, int, _EdgeData]]: ...
167167
@overload
168168
def selfloop_edges(
169169
G: Graph[_Node], data: str, keys: Literal[True], default: _U | None = None

stubs/networkx/networkx/classes/graph.pyi

Lines changed: 37 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,39 @@
1-
from collections.abc import Callable, Collection, Hashable, Iterable, Iterator, MutableMapping
1+
from collections.abc import Callable, Collection, Hashable, Iterable, Iterator, Mapping, MutableMapping
22
from decimal import Decimal
33
from functools import cached_property
4-
from typing import Any, ClassVar, TypeAlias, TypeVar, overload
4+
from typing import Any, ClassVar, Generic, TypeAlias, TypeVar, overload
55
from typing_extensions import Self
66

77
import numpy
88
from networkx.classes.coreviews import AdjacencyView, AtlasView
99
from networkx.classes.digraph import DiGraph
1010
from networkx.classes.reportviews import DegreeView, DiDegreeView, EdgeView, NodeView, OutEdgeView
1111

12+
_DataBound: TypeAlias = Mapping[str, Any]
13+
1214
_Node = TypeVar("_Node", bound=Hashable)
13-
_NodeWithData: TypeAlias = tuple[_Node, dict[str, Any]]
14-
_NodePlus: TypeAlias = _Node | _NodeWithData[_Node]
15+
_NodeData = TypeVar("_NodeData", bound=_DataBound, default=dict[str, Any])
16+
_EdgeData = TypeVar("_EdgeData", bound=_DataBound, default=dict[str, Any])
17+
18+
_NodeWithData: TypeAlias = tuple[_Node, _NodeData]
19+
_NodePlus: TypeAlias = _Node | _NodeWithData[_Node, _NodeData]
1520
_Edge: TypeAlias = tuple[_Node, _Node]
16-
_EdgeWithData: TypeAlias = tuple[_Node, _Node, dict[str, Any]]
17-
_EdgePlus: TypeAlias = _Edge[_Node] | _EdgeWithData[_Node]
21+
_EdgeWithData: TypeAlias = tuple[_Node, _Node, _EdgeData]
22+
_EdgePlus: TypeAlias = _Edge[_Node] | _EdgeWithData[_Node, _EdgeData]
1823
_MapFactory: TypeAlias = Callable[[], MutableMapping[str, Any]]
1924
_NBunch: TypeAlias = _Node | Iterable[_Node] | None
2025
_Data: TypeAlias = (
21-
Graph[_Node]
22-
| dict[_Node, dict[_Node, dict[str, Any]]]
26+
Graph[_Node, _NodeData, _EdgeData]
27+
| dict[_Node, dict[_Node, _NodeData]]
2328
| dict[_Node, Iterable[_Node]]
24-
| Iterable[_EdgePlus[_Node]]
29+
| Iterable[_EdgePlus[_Node, _EdgeData]]
2530
| numpy.ndarray[Any, Any]
2631
# | scipy.sparse.base.spmatrix
2732
)
2833

2934
__all__ = ["Graph"]
3035

31-
class Graph(Collection[_Node]):
36+
class Graph(Collection[_Node], Generic[_Node, _NodeData, _EdgeData]):
3237
__networkx_backend__: ClassVar[str]
3338
node_dict_factory: ClassVar[_MapFactory]
3439
node_attr_dict_factory: ClassVar[_MapFactory]
@@ -40,14 +45,16 @@ class Graph(Collection[_Node]):
4045
graph: dict[str, Any]
4146
__networkx_cache__: dict[str, Any]
4247

43-
def to_directed_class(self) -> type[DiGraph[_Node]]: ...
44-
def to_undirected_class(self) -> type[Graph[_Node]]: ...
48+
def to_directed_class(self) -> type[DiGraph[_Node, _NodeData, _EdgeData]]: ...
49+
def to_undirected_class(self) -> type[Graph[_Node, _NodeData, _EdgeData]]: ...
4550
# @_dispatchable adds `backend` argument, but this decorated is unsupported constructor type here
4651
# and __init__() ignores this argument
4752
def __new__(cls, *args, backend=None, **kwargs) -> Self: ...
48-
def __init__(self, incoming_graph_data: _Data[_Node] | None = None, **attr: Any) -> None: ... # attr: key=value pairs
53+
def __init__(
54+
self, incoming_graph_data: _Data[_Node, _NodeData, _EdgeData] | None = None, **attr: Any
55+
) -> None: ... # attr: key=value pairs
4956
@cached_property
50-
def adj(self) -> AdjacencyView[_Node, _Node, dict[str, Any]]: ...
57+
def adj(self) -> AdjacencyView[_Node, _Node, _EdgeData]: ...
5158
# This object is a read-only dict-like structure
5259
@property
5360
def name(self) -> str: ...
@@ -58,49 +65,53 @@ class Graph(Collection[_Node]):
5865
def __len__(self) -> int: ...
5966
def __getitem__(self, n: _Node) -> AtlasView[_Node, str, Any]: ...
6067
def add_node(self, node_for_adding: _Node, **attr: Any) -> None: ... # attr: Set or change node attributes using key=value
61-
def add_nodes_from(self, nodes_for_adding: Iterable[_NodePlus[_Node]], **attr: Any) -> None: ... # attr: key=value pairs
68+
def add_nodes_from(
69+
self, nodes_for_adding: Iterable[_NodePlus[_Node, _NodeData]], **attr: Any
70+
) -> None: ... # attr: key=value pairs
6271
def remove_node(self, n: _Node) -> None: ...
6372
def remove_nodes_from(self, nodes: Iterable[_Node]) -> None: ...
6473
@cached_property
65-
def nodes(self) -> NodeView[_Node]: ...
74+
def nodes(self) -> NodeView[_Node, _NodeData, _EdgeData]: ...
6675
def number_of_nodes(self) -> int: ...
6776
def order(self) -> int: ...
6877
def has_node(self, n: _Node) -> bool: ...
6978
# Including subtypes' possible return types for LSP
7079
def add_edge(self, u_of_edge: _Node, v_of_edge: _Node, **attr: Any) -> Hashable | None: ...
7180
# attr: Edge data (or labels or objects) can be assigned using keyword arguments
72-
def add_edges_from(self, ebunch_to_add: Iterable[_EdgePlus[_Node]], **attr: Any) -> None: ...
81+
def add_edges_from(self, ebunch_to_add: Iterable[_EdgePlus[_Node, _EdgeData]], **attr: Any) -> None: ...
7382
# attr: Edge data (or labels or objects) can be assigned using keyword arguments
7483
def add_weighted_edges_from(
7584
self, ebunch_to_add: Iterable[tuple[_Node, _Node, float | Decimal | None]], weight: str = "weight", **attr: Any
7685
) -> None: ...
7786
# attr: Edge attributes to add/update for all edges.
7887
def remove_edge(self, u: _Node, v: _Node) -> None: ...
79-
def remove_edges_from(self, ebunch: Iterable[_EdgePlus[_Node]]) -> None: ...
88+
def remove_edges_from(self, ebunch: Iterable[_EdgePlus[_Node, _EdgeData]]) -> None: ...
8089
@overload
81-
def update(self, edges: Graph[_Node], nodes: None = None) -> None: ...
90+
def update(self, edges: Graph[_Node, _NodeData, _EdgeData], nodes: None = None) -> None: ...
8291
@overload
8392
def update(
84-
self, edges: Graph[_Node] | Iterable[_EdgePlus[_Node]] | None = None, nodes: Iterable[_Node] | None = None
93+
self,
94+
edges: Graph[_Node, _NodeData, _EdgeData] | Iterable[_EdgePlus[_Node, _EdgeData]] | None = None,
95+
nodes: Iterable[_Node] | None = None,
8596
) -> None: ...
8697
def has_edge(self, u: _Node, v: _Node) -> bool: ...
8798
def neighbors(self, n: _Node) -> Iterator[_Node]: ...
8899
@cached_property
89100
# Including subtypes' possible return types for LSP
90-
def edges(self) -> EdgeView[_Node] | OutEdgeView[_Node]: ...
91-
def get_edge_data(self, u: _Node, v: _Node, default: Any = None) -> dict[str, Any]: ...
101+
def edges(self) -> EdgeView[_Node, _NodeData, _EdgeData] | OutEdgeView[_Node, _NodeData, _EdgeData]: ...
102+
def get_edge_data(self, u: _Node, v: _Node, default: Any = None) -> _EdgeData: ...
92103
# default: any Python object
93-
def adjacency(self) -> Iterator[tuple[_Node, dict[_Node, dict[str, Any]]]]: ...
104+
def adjacency(self) -> Iterator[tuple[_Node, dict[_Node, _EdgeData]]]: ...
94105
@cached_property
95106
# Including subtypes' possible return types for LSP
96-
def degree(self) -> DegreeView[_Node] | DiDegreeView[_Node]: ...
107+
def degree(self) -> DegreeView[_Node, _NodeData, _EdgeData] | DiDegreeView[_Node, _NodeData, _EdgeData]: ...
97108
def clear(self) -> None: ...
98109
def clear_edges(self) -> None: ...
99110
def is_multigraph(self) -> bool: ...
100111
def is_directed(self) -> bool: ...
101112
def copy(self, as_view: bool = False) -> Self: ...
102-
def to_directed(self, as_view: bool = False) -> DiGraph[_Node]: ...
103-
def to_undirected(self, as_view: bool = False) -> Graph[_Node]: ...
113+
def to_directed(self, as_view: bool = False) -> DiGraph[_Node, _NodeData, _EdgeData]: ...
114+
def to_undirected(self, as_view: bool = False) -> Graph[_Node, _NodeData, _EdgeData]: ...
104115
def subgraph(self, nodes: _NBunch[_Node]) -> Self: ...
105116
def edge_subgraph(self, edges: Iterable[_Edge[_Node]]) -> Self: ...
106117
@overload

0 commit comments

Comments
 (0)