Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -459,14 +459,14 @@ def get_config_objects(self, typ: Type, prev_config: FwConfigNormalized) -> tupl
return prev_config.service_objects, self.normalized_config.service_objects
return prev_config.users, self.normalized_config.users

def get_id(self, typ: Type, uid: str, before_update: bool = False) -> int | None:
def get_id(self, typ: Type, uid: str, before_update: bool = False) -> int:
if typ == Type.NETWORK_OBJECT:
return self.uid2id_mapper.get_network_object_id(uid, before_update)
if typ == Type.SERVICE_OBJECT:
return self.uid2id_mapper.get_service_object_id(uid, before_update)
return self.uid2id_mapper.get_user_id(uid, before_update)

def get_local_id(self, typ: Type, uid: str, before_update: bool = False) -> int | None:
def get_local_id(self, typ: Type, uid: str, before_update: bool = False) -> int:
if typ == Type.NETWORK_OBJECT:
return self.uid2id_mapper.get_network_object_id(uid, before_update, local_only=True)
if typ == Type.SERVICE_OBJECT:
Expand Down Expand Up @@ -496,14 +496,14 @@ def get_members(self, typ: Type, refs: str | None) -> list[str]:
)
return refs.split(fwo_const.LIST_DELIMITER) if refs else []

def get_flats(self, typ: Type, uid: str) -> list[str]:
def get_flats(self, typ: Type, uid: str) -> set[str]:
if typ == Type.NETWORK_OBJECT:
return self.group_flats_mapper.get_network_object_flats([uid])
if typ == Type.SERVICE_OBJECT:
return self.group_flats_mapper.get_service_object_flats([uid])
return self.group_flats_mapper.get_user_flats([uid])

def get_prev_flats(self, typ: Type, uid: str) -> list[str]:
def get_prev_flats(self, typ: Type, uid: str) -> set[str]:
if typ == Type.NETWORK_OBJECT:
return self.prev_group_flats_mapper.get_network_object_flats([uid])
if typ == Type.SERVICE_OBJECT:
Expand Down Expand Up @@ -648,16 +648,13 @@ def add_group_memberships(self, prev_config: FwConfigNormalized, obj_type: Type)
continue
member_uids = self.get_members(obj_type, self.get_refs(obj_type, current_config_objects[uid]))
prev_member_uids = [] # all members need to be added if group added or changed
prev_flat_member_uids = []
prev_flat_member_uids: set[str] = set()
if uid in prev_config_objects and current_config_objects[uid] == prev_config_objects[uid]:
# group not changed -> check for changes in members
prev_member_uids = self.get_members(obj_type, self.get_refs(obj_type, prev_config_objects[uid]))
prev_flat_member_uids = self.get_prev_flats(obj_type, uid)

group_id = self.get_id(obj_type, uid)
if group_id is None:
FWOLogger.error(f"failed to add group memberships: no id found for group uid '{uid}'")
continue

self.collect_group_members(
group_id,
Expand Down Expand Up @@ -691,13 +688,13 @@ def collect_flat_group_members(
group_id: int,
current_config_objects: dict[str, Any],
new_group_member_flats: list[dict[str, Any]],
flat_member_uids: list[str],
flat_member_uids: set[str],
obj_type: Type,
prefix: str,
prev_flat_member_uids: list[str],
prev_flat_member_uids: set[str],
prev_config_objects: dict[str, Any],
):
for flat_member_uid in flat_member_uids:
for flat_member_uid in sorted(flat_member_uids): # deterministic order for better debugging and testing
if (
flat_member_uid in prev_flat_member_uids
and prev_config_objects[flat_member_uid] == current_config_objects[flat_member_uid]
Expand All @@ -724,10 +721,14 @@ def collect_group_members(
prev_member_uids: list[str],
prev_config_objects: dict[str, Any],
):
added_member_ids: set[int] = set()
for member_uid in member_uids:
if member_uid in prev_member_uids and prev_config_objects[member_uid] == current_config_objects[member_uid]:
continue # member was not added or changed
member_id = self.get_id(obj_type, member_uid)
if member_id in added_member_ids:
continue # avoid duplicate entries for same member and group (e.g. if same member is contained twice)
added_member_ids.add(member_id)
new_group_members.append(
{
f"{prefix}_id": group_id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -306,13 +306,15 @@ def get_rule_refs(
tos.append((dst_ref, user_ref))
svcs = rule.rule_svc_refs.split(fwo_const.LIST_DELIMITER)
if is_prev:
nwobj_resolveds = self.prev_group_flats_mapper.get_network_object_flats([ref[0] for ref in froms + tos])
svc_resolveds = self.prev_group_flats_mapper.get_service_object_flats(svcs)
user_resolveds = self.prev_group_flats_mapper.get_user_flats(users)
nwobj_resolveds = sorted(
self.prev_group_flats_mapper.get_network_object_flats([ref[0] for ref in froms + tos])
)
svc_resolveds = sorted(self.prev_group_flats_mapper.get_service_object_flats(svcs))
user_resolveds = sorted(self.prev_group_flats_mapper.get_user_flats(users))
else:
nwobj_resolveds = self.group_flats_mapper.get_network_object_flats([ref[0] for ref in froms + tos])
svc_resolveds = self.group_flats_mapper.get_service_object_flats(svcs)
user_resolveds = self.group_flats_mapper.get_user_flats(users)
nwobj_resolveds = sorted(self.group_flats_mapper.get_network_object_flats([ref[0] for ref in froms + tos]))
svc_resolveds = sorted(self.group_flats_mapper.get_service_object_flats(svcs))
user_resolveds = sorted(self.group_flats_mapper.get_user_flats(users))
from_zones = rule.rule_src_zone.split(fwo_const.LIST_DELIMITER) if rule.rule_src_zone else []
to_zones = rule.rule_dst_zone.split(fwo_const.LIST_DELIMITER) if rule.rule_dst_zone else []
times = rule.rule_time.split(fwo_const.LIST_DELIMITER) if rule.rule_time else []
Expand Down
24 changes: 12 additions & 12 deletions roles/importer/files/importer/services/group_flats_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def init_config(
self.service_object_flats = {}
self.user_flats = {}

def get_network_object_flats(self, uids: list[str]) -> list[str]:
def get_network_object_flats(self, uids: list[str]) -> set[str]:
"""
Flatten the network object UIDs to all members, including group objects, and the top-level group object itself.
Does not check if the given objects are group objects or not.
Expand All @@ -59,18 +59,18 @@ def get_network_object_flats(self, uids: list[str]) -> list[str]:
uids (list[str]): The list of network object UIDs to flatten.

Returns:
list[str]: The flattened network object UIDs.
set[str]: The flattened network object UIDs.

"""
if self.normalized_config is None:
self.log_error(f"{CONFIG_NOT_SET_MESSAGE} - networks")
return []
return set()
all_members: set[str] = set()
for uid in uids:
members = self.flat_nwobj_members_recursive(uid)
if members is not None:
all_members.update(members)
return list(all_members)
return all_members

def flat_nwobj_members_recursive(self, group_uid: str, recursion_level: int = 0) -> set[str] | None:
if recursion_level > MAX_RECURSION_LEVEL:
Expand Down Expand Up @@ -102,7 +102,7 @@ def get_nwobj(self, group_uid: str) -> NetworkObject | None:
nwobj = self.global_normalized_config.network_objects.get(group_uid, None)
return nwobj

def get_service_object_flats(self, uids: list[str]) -> list[str]:
def get_service_object_flats(self, uids: list[str]) -> set[str]:
"""
Flatten the service object UIDs to all members, including group objects, and the top-level group object itself.
Does not check if the given objects are group objects or not.
Expand All @@ -111,18 +111,18 @@ def get_service_object_flats(self, uids: list[str]) -> list[str]:
uids (list[str]): The list of service object UIDs to flatten.

Returns:
list[str]: The flattened service object UIDs.
set[str]: The flattened service object UIDs.

"""
if self.normalized_config is None:
self.log_error(f"{CONFIG_NOT_SET_MESSAGE} - services")
return []
return set()
all_members: set[str] = set()
for uid in uids:
members = self.flat_svcobj_members_recursive(uid)
if members is not None:
all_members.update(members)
return list(all_members)
return all_members

def flat_svcobj_members_recursive(self, group_uid: str, recursion_level: int = 0) -> set[str] | None:
if recursion_level > MAX_RECURSION_LEVEL:
Expand Down Expand Up @@ -154,7 +154,7 @@ def get_svcobj(self, group_uid: str) -> ServiceObject | None:
svcobj = self.global_normalized_config.service_objects.get(group_uid, None)
return svcobj

def get_user_flats(self, uids: list[str]) -> list[str]:
def get_user_flats(self, uids: list[str]) -> set[str]:
"""
Flatten the user UIDs to all members, including groups, and the top-level group itself.
Does not check if the given users are groups or not.
Expand All @@ -163,18 +163,18 @@ def get_user_flats(self, uids: list[str]) -> list[str]:
uids (list[str]): The list of user UIDs to flatten.

Returns:
list[str]: The flattened user UIDs.
set[str]: The flattened user UIDs.

"""
if self.normalized_config is None:
self.log_error(f"{CONFIG_NOT_SET_MESSAGE} - users")
return []
return set()
all_members: set[str] = set()
for uid in uids:
members = self.flat_user_members_recursive(uid)
if members is not None:
all_members.update(members)
return list(all_members)
return all_members

def flat_user_members_recursive(self, group_uid: str, recursion_level: int = 0) -> set[str] | None:
if recursion_level > MAX_RECURSION_LEVEL:
Expand Down
Loading
Loading