Skip to content

Commit c39062b

Browse files
authored
Fix bugs and apply improvements across client library (#197)
- Fix PyatsException to inherit from VirlException instead of Exception - Fix ControllerNotFound to use __init__ instead of class attribute for message - Fix ControllerNotFound instantiation (raise ControllerNotFound()) - Fix sync_states staleness handling for nodes, interfaces, and links - Fix _import_link parameter order (iface_a_id before iface_b_id) - Fix incorrect log message in Interface._set_interface_property - Refactor Annotation validation using _VALID_KEYS frozensets - Refactor Version comparisons using _as_tuple helper - Refactor get_lab_list to pass params directly to session.get - Optimize Node.next_available_interface with direct slicing
1 parent 4d0c43f commit c39062b

7 files changed

Lines changed: 116 additions & 69 deletions

File tree

virl2_client/exceptions.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ class MethodNotActive(VirlException):
124124
pass
125125

126126

127-
class PyatsException(Exception):
127+
class PyatsException(VirlException):
128128
"""Base exception for pyATS integration errors."""
129129

130130
pass
@@ -145,7 +145,8 @@ class PyatsDeviceNotFound(PyatsException):
145145
class ControllerNotFound(VirlException):
146146
"""Raised when no CML controller node is found in the topology."""
147147

148-
message = "Controller not found"
148+
def __init__(self) -> None:
149+
super().__init__("Controller not found")
149150

150151

151152
class APIError(VirlException, httpx.HTTPStatusError):

virl2_client/models/annotation.py

Lines changed: 54 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from __future__ import annotations
2222

2323
import logging
24-
from typing import TYPE_CHECKING, Any, Literal
24+
from typing import TYPE_CHECKING, Any, Literal, Union
2525

2626
from ..exceptions import InvalidProperty
2727
from ..utils import check_stale, get_url_from_template, locked
@@ -33,13 +33,13 @@
3333
from .lab import Lab
3434

3535
AnnotationTypeString = Literal["text", "line", "ellipse", "rectangle"]
36-
AnnotationType = (
37-
"Annotation"
38-
| "AnnotationRectangle"
39-
| "AnnotationEllipse"
40-
| "AnnotationLine"
41-
| "AnnotationText"
42-
)
36+
AnnotationType = Union[
37+
"Annotation",
38+
"AnnotationRectangle",
39+
"AnnotationEllipse",
40+
"AnnotationLine",
41+
"AnnotationText",
42+
]
4343

4444
_LOGGER = logging.getLogger(__name__)
4545

@@ -123,6 +123,20 @@ class Annotation:
123123
"annotation": "labs/{lab_id}/annotations/{annotation_id}",
124124
}
125125

126+
_VALID_KEYS: frozenset[str] = frozenset(
127+
{
128+
"id",
129+
"type",
130+
"border_color",
131+
"border_style",
132+
"color",
133+
"thickness",
134+
"x1",
135+
"y1",
136+
"z_index",
137+
}
138+
)
139+
126140
def __init__(
127141
self,
128142
lab: Lab,
@@ -387,12 +401,14 @@ def is_valid_property(
387401
:param _property: The property name to validate.
388402
:returns: True if the property is valid for the given type, False otherwise.
389403
"""
390-
try:
391-
assert annotation_type in _ANNOTATION_TYPES
392-
assert _property in ANNOTATION_PROPERTY_MAP
393-
except AssertionError:
404+
if (
405+
annotation_type not in _ANNOTATION_TYPES
406+
or _property not in ANNOTATION_PROPERTY_MAP
407+
):
394408
return False
395-
return ANNOTATION_MAP[annotation_type] & ANNOTATION_PROPERTY_MAP[_property] > 0
409+
return (
410+
ANNOTATION_MAP[annotation_type] & ANNOTATION_PROPERTY_MAP[_property]
411+
) > 0
396412

397413
@locked
398414
def as_dict(self) -> dict[str, Any]:
@@ -449,9 +465,8 @@ def _update(self, annotation_data: dict[str, Any], push_to_server: bool) -> None
449465
raise ValueError("Can't change annotation type.")
450466

451467
# make sure all properties we want to update are valid
452-
existing_keys = dir(self)
453468
for key in annotation_data:
454-
if key not in existing_keys:
469+
if key not in self._VALID_KEYS:
455470
raise InvalidProperty(f"Invalid annotation property: {key}")
456471

457472
if push_to_server:
@@ -493,6 +508,10 @@ class AnnotationRectangle(Annotation):
493508
Annotation class representing rectangle annotation.
494509
"""
495510

511+
_VALID_KEYS = Annotation._VALID_KEYS | frozenset(
512+
{"border_radius", "x2", "y2", "rotation"}
513+
)
514+
496515
def __init__(
497516
self,
498517
lab: Lab,
@@ -600,6 +619,8 @@ class AnnotationEllipse(Annotation):
600619
Annotation class representing ellipse annotation.
601620
"""
602621

622+
_VALID_KEYS = Annotation._VALID_KEYS | frozenset({"x2", "y2", "rotation"})
623+
603624
def __init__(
604625
self,
605626
lab: Lab,
@@ -687,6 +708,10 @@ class AnnotationLine(Annotation):
687708
Annotation class representing line annotation.
688709
"""
689710

711+
_VALID_KEYS = Annotation._VALID_KEYS | frozenset(
712+
{"x2", "y2", "line_start", "line_end"}
713+
)
714+
690715
def __init__(
691716
self,
692717
lab: Lab,
@@ -794,6 +819,20 @@ class AnnotationText(Annotation):
794819
Annotation class representing text annotation.
795820
"""
796821

822+
_VALID_KEYS = Annotation._VALID_KEYS | frozenset(
823+
{
824+
"x2",
825+
"y2",
826+
"rotation",
827+
"text_bold",
828+
"text_content",
829+
"text_font",
830+
"text_italic",
831+
"text_size",
832+
"text_unit",
833+
}
834+
)
835+
797836
def __init__(
798837
self,
799838
lab: Lab,

virl2_client/models/interface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,7 @@ def _set_interface_property(self, key: str, val: Any) -> None:
437437
:param key: The key of the property to set.
438438
:param val: The value to set.
439439
"""
440-
_LOGGER.debug(f"Setting node property {self} {key}: {val}")
440+
_LOGGER.debug(f"Setting interface property {self} {key}: {val}")
441441
self._set_interface_properties({key: val})
442442

443443
@check_stale

virl2_client/models/lab.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1357,8 +1357,18 @@ def sync_states(self) -> None:
13571357
"""
13581358
url = self._url_for("lab_element_state")
13591359
states: dict[str, dict[str, str]] = self._session.get(url).json()
1360+
1361+
nodes = self._nodes.copy()
13601362
for node_id, node_state in states["nodes"].items():
1361-
self._nodes[node_id]._state = node_state
1363+
try:
1364+
node = nodes.pop(node_id)
1365+
except KeyError:
1366+
pass
1367+
else:
1368+
node._state = node_state
1369+
for stale_node in nodes.values():
1370+
stale_node._stale = True
1371+
13621372
ifaces = self._interfaces.copy()
13631373
for interface_id, interface_state in states["interfaces"].items():
13641374
try:
@@ -1367,10 +1377,19 @@ def sync_states(self) -> None:
13671377
pass
13681378
else:
13691379
iface._state = interface_state
1370-
for stale_iface in ifaces:
1371-
ifaces[stale_iface]._stale = True
1380+
for stale_iface in ifaces.values():
1381+
stale_iface._stale = True
1382+
1383+
links = self._links.copy()
13721384
for link_id, link_state in states["links"].items():
1373-
self._links[link_id]._state = link_state
1385+
try:
1386+
link = links.pop(link_id)
1387+
except KeyError:
1388+
pass
1389+
else:
1390+
link._state = link_state
1391+
for stale_link in links.values():
1392+
stale_link._stale = True
13741393

13751394
self._last_sync_state_time = time.time()
13761395

@@ -1714,7 +1733,7 @@ def _handle_import_links(self, topology: dict) -> None:
17141733
iface_b_id = link["interface_b"]
17151734
label = link.get("label")
17161735

1717-
self._import_link(link_id, iface_b_id, iface_a_id, label)
1736+
self._import_link(link_id, iface_a_id, iface_b_id, label)
17181737

17191738
@locked
17201739
def _handle_import_annotations(self, topology: dict) -> None:
@@ -1749,16 +1768,16 @@ def _handle_import_annotations(self, topology: dict) -> None:
17491768
def _import_link(
17501769
self,
17511770
link_id: str,
1752-
iface_b_id: str,
17531771
iface_a_id: str,
1772+
iface_b_id: str,
17541773
label: str | None = None,
17551774
) -> Link:
17561775
"""
17571776
Import a link with the given parameters.
17581777
17591778
:param link_id: The ID of the link.
1760-
:param iface_b_id: The ID of the second interface.
17611779
:param iface_a_id: The ID of the first interface.
1780+
:param iface_b_id: The ID of the second interface.
17621781
:param label: The label of the link.
17631782
:returns: The imported Link object.
17641783
"""
@@ -2069,7 +2088,7 @@ def _add_links(self, topology: dict, new_links: Iterable[str]) -> None:
20692088
iface_a_id = link_data["interface_a"]
20702089
iface_b_id = link_data["interface_b"]
20712090
label = link_data.get("label")
2072-
link = self._import_link(link_id, iface_b_id, iface_a_id, label)
2091+
link = self._import_link(link_id, iface_a_id, iface_b_id, label)
20732092
_LOGGER.info(f"Added link {link}")
20742093

20752094
@locked

virl2_client/models/node.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ def next_available_interface(self, index: int = 0) -> Interface | None:
300300
:returns: An available physical interface or None if all existing
301301
ones are connected.
302302
"""
303-
for _, iface in enumerate(self.interfaces(), index):
303+
for iface in self.interfaces()[index:]:
304304
if not iface.connected and iface.physical:
305305
return iface
306306
return None

virl2_client/models/system.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def controller(self) -> ComputeHost:
101101
for compute_host in self._compute_hosts.values():
102102
if compute_host.is_connector:
103103
return compute_host
104-
raise ControllerNotFound
104+
raise ControllerNotFound()
105105

106106
@property
107107
def system_notices(self) -> dict[str, SystemNotice]:

virl2_client/virl2_client.py

Lines changed: 29 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -84,43 +84,33 @@ def parse_version_str(version_str: str) -> tuple[str, str, str, str]:
8484
def __repr__(self) -> str:
8585
return self.version_str
8686

87-
def __eq__(self, other):
88-
return (
89-
isinstance(other, self.__class__)
90-
and self.major == other.major
91-
and self.minor == other.minor
92-
and self.patch == other.patch
93-
)
94-
95-
def __gt__(self, other):
96-
if isinstance(other, self.__class__):
97-
if self.major > other.major:
98-
return True
99-
elif self.major == other.major:
100-
if self.minor > other.minor:
101-
return True
102-
elif self.minor == other.minor:
103-
if self.patch > other.patch:
104-
return True
105-
return False
106-
107-
def __ge__(self, other):
108-
return self == other or self > other
109-
110-
def __lt__(self, other):
111-
if isinstance(other, self.__class__):
112-
if self.major < other.major:
113-
return True
114-
elif self.major == other.major:
115-
if self.minor < other.minor:
116-
return True
117-
elif self.minor == other.minor:
118-
if self.patch < other.patch:
119-
return True
120-
return False
121-
122-
def __le__(self, other):
123-
return self == other or self < other
87+
def _as_tuple(self) -> tuple[int, int, int]:
88+
return (self.major, self.minor, self.patch)
89+
90+
def __eq__(self, other: object) -> bool:
91+
if not isinstance(other, Version):
92+
return False
93+
return self._as_tuple() == other._as_tuple()
94+
95+
def __gt__(self, other: object) -> bool:
96+
if not isinstance(other, Version):
97+
return False
98+
return self._as_tuple() > other._as_tuple()
99+
100+
def __ge__(self, other: object) -> bool:
101+
if not isinstance(other, Version):
102+
return False
103+
return self._as_tuple() >= other._as_tuple()
104+
105+
def __lt__(self, other: object) -> bool:
106+
if not isinstance(other, Version):
107+
return False
108+
return self._as_tuple() < other._as_tuple()
109+
110+
def __le__(self, other: object) -> bool:
111+
if not isinstance(other, Version):
112+
return False
113+
return self._as_tuple() <= other._as_tuple()
124114

125115
def major_differs(self, other: Version) -> bool:
126116
return self.major != other.major
@@ -995,10 +985,8 @@ def get_lab_list(self, show_all: bool = False) -> list[str]:
995985
owned by the authenticated user (False).
996986
:returns: A list of lab IDs.
997987
"""
998-
url: dict[str, str | dict] = {"url": self._url_for("labs")}
999-
if show_all:
1000-
url["params"] = {"show_all": True}
1001-
return self._session.get(**url).json()
988+
params = {"show_all": True} if show_all else None
989+
return self._session.get(self._url_for("labs"), params=params).json()
1002990

1003991

1004992
def _prepare_url(url: str, allow_http: bool) -> tuple[str, str]:

0 commit comments

Comments
 (0)