diff --git a/escnn/group/groups/cyclicgroup.py b/escnn/group/groups/cyclicgroup.py index 57afff79..10e54ac2 100644 --- a/escnn/group/groups/cyclicgroup.py +++ b/escnn/group/groups/cyclicgroup.py @@ -283,8 +283,8 @@ def _subgroup(self, id: int) -> Tuple[ # parent_mapping = lambda e, ratio=ratio: self.element(e._element * ratio) # child_mapping = lambda e, ratio=ratio, sg=sg: None if e._element % ratio != 0 else sg.element(int(e._element // ratio)) - parent_mapping = _build_parent_map(self, order) - child_mapping = _build_child_map(self, sg) + parent_mapping = ParentMapping(self, order) + child_mapping = ChildMapping(self, sg) return sg, parent_mapping, child_mapping @@ -411,12 +411,12 @@ def irrep(self, k: int) -> IrreducibleRepresentation: name = f"irrep_{k}" n = self.order() + + irrep = IrrepBuilder(k) + character = CharacterBuilder(k) if k == 0: # Trivial representation - - irrep = _build_irrep_cn(0) - character = _build_char_cn(0) supported_nonlinearities = ['pointwise', 'gate', 'norm', 'gated', 'concatenated'] self._irreps[id] = IrreducibleRepresentation(self, id, name, irrep, 1, 'R', supported_nonlinearities=supported_nonlinearities, @@ -425,8 +425,6 @@ def irrep(self, k: int) -> IrreducibleRepresentation: frequency=k) elif n % 2 == 0 and k == int(n/2): # 1 dimensional Irreducible representation (only for even order groups) - irrep = _build_irrep_cn(k) - character = _build_char_cn(k) supported_nonlinearities = ['norm', 'gated', 'concatenated'] self._irreps[id] = IrreducibleRepresentation(self, id, name, irrep, 1, 'R', supported_nonlinearities=supported_nonlinearities, @@ -434,10 +432,6 @@ def irrep(self, k: int) -> IrreducibleRepresentation: frequency=k) else: # 2 dimensional Irreducible Representations - - irrep = _build_irrep_cn(k) - character = _build_char_cn(k) - supported_nonlinearities = ['norm', 'gated'] self._irreps[id] = IrreducibleRepresentation(self, id, name, irrep, 2, 'C', supported_nonlinearities=supported_nonlinearities, @@ -555,6 +549,26 @@ def irrep(element: GroupElement, k:int =k) -> np.ndarray: return irrep +class IrrepBuilder: + def __init__(self, k: int): + self.k = k + + def __call__(self, element: GroupElement) -> np.ndarray: + k = self.k + + if k == 0: + return np.eye(1) + + n = element.group.order() + + if n % 2 == 0 and k == int(n / 2): + # 1 dimensional Irreducible representation (only for even order groups) + return np.array([[np.cos(k * element.to('radians'))]]) + else: + # 2 dimensional Irreducible Representations + return utils.psi(element.to('radians'), k=k) + + def _build_char_cn(k: int): def character(element: GroupElement, k=k) -> float: @@ -573,12 +587,40 @@ def character(element: GroupElement, k=k) -> float: return character +class CharacterBuilder: + def __init__(self, k: int): + self.k = k + + def __call__(self, element: GroupElement) -> float: + k = self.k + + if k == 0: + return 1. + + n = element.group.order() + + if n % 2 == 0 and k == int(n / 2): + # 1 dimensional Irreducible representation (only for even order groups) + return np.cos(k * element.to('radians')) + else: + # 2 dimensional Irreducible Representations + return 2*np.cos(k * element.to('radians')) + + def _build_parent_map(G: CyclicGroup, order: int): def parent_mapping(e: GroupElement, G: Group = G, order=order) -> GroupElement: return G.element(e.to('int') * G.order() // order) return parent_mapping +class ParentMapping: + def __init__(self, G: CyclicGroup, order: int): + self.G = G + self.order = order + + def __call__(self, e: GroupElement): + return self.G.element(e.to('int') * self.G.order() // self.order) + def _build_child_map(G: CyclicGroup, sg: CyclicGroup): assert G.order() % sg.order() == 0 @@ -595,3 +637,18 @@ def child_mapping(e: GroupElement, G=G, sg: Group = sg) -> GroupElement: return child_mapping +class ChildMapping: + def __init__(self, G: CyclicGroup, sg: CyclicGroup): + assert G.order() % sg.order() == 0 + self.G = G + self.sg = sg + + def __call__(self, e: GroupElement): + assert e.group == self.G + i = e.to('int') + ratio = self.G.order() // self.sg.order() + if i % ratio != 0: + return None + else: + return self.sg.element(i // ratio) + diff --git a/escnn/group/groups/dihedralgroup.py b/escnn/group/groups/dihedralgroup.py index 786a0d25..8644a494 100644 --- a/escnn/group/groups/dihedralgroup.py +++ b/escnn/group/groups/dihedralgroup.py @@ -328,24 +328,24 @@ def _subgroup(self, id: Tuple[int, int]) -> Tuple[ # parent_mapping = lambda e, axis=axis: self.element((e._element, axis * e._element)) # child_mapping = lambda e, axis=axis, sg=sg: None if e._element[1] != e._element[0] * axis else sg.element(e._element[0]) - parent_mapping = flip_to_dn(axis, self) - child_mapping = dn_to_flip(axis, sg) + parent_mapping = FlipToDnMapping(axis, self) + child_mapping = DnToFlipMapping(axis, sg) elif id[0] is None: # take the elements of the group generated by "r^ratio" sg = escnn.group.cyclic_group(order) # parent_mapping = lambda e, ratio=ratio: self.element((0, e._element * ratio)) # child_mapping = lambda e, ratio=ratio, sg=sg: None if (e._element[0] != 0 or e._element[1] % ratio > 0) else sg.element(int(e._element[1] / ratio)) - parent_mapping = cn_to_dn(self) - child_mapping = dn_to_cn(sg) + parent_mapping = CnToDnMapping(self) + child_mapping = DnToCnMapping(sg) else: # take the elements of the group generated by "r^ratio" and "r^axis f" sg = escnn.group.dihedral_group(order) # parent_mapping = lambda e, ratio=ratio, axis=axis: self.element((e._element[0], e._element[1] * ratio + e._element[0] * axis)) # child_mapping = lambda e, ratio=ratio, axis=axis, sg=sg: None if (e._element[1] - e._element[0] * axis) % ratio > 0 else sg.element((e._element[0], int((e._element[1] - e._element[0] * axis) / ratio))) - parent_mapping = dm_to_dn(axis, self) - child_mapping = dn_to_dm(axis, sg) + parent_mapping = DmToDnMapping(axis, self) + child_mapping = DnToDmMapping(axis, sg) return sg, parent_mapping, child_mapping @@ -557,8 +557,8 @@ def irrep(self, j: int, k: int) -> IrreducibleRepresentation: if id not in self._irreps: - irrep = _build_irrep_dn(j, k) - character = _build_char_dn(j, k) + irrep = IrrepBuilder(j, k) + character = CharacterBuilder(j, k) if j == 0: @@ -687,6 +687,70 @@ def character(element: GroupElement, j=j, k=k) -> float: return character +class IrrepBuilder: + def __init__(self, j: int, k: int): + self.j = j + self.k = k + + def __call__(self, element: GroupElement) -> np.ndarray: + j = self.j + k = self.k + + N = element.group.rotation_order + + if j == 0: + if k == 0: + return np.eye(1) + elif N % 2 == 0 and k == N // 2: + return np.array([[np.cos(k * element.to('radians')[1])]]) + else: + raise ValueError( + f"Error! Flip frequency {j} and rotational frequency {k} don't correspond to any irrep of the group {element.group.name}!") + else: + if k == 0: + # Trivial on Cyclic subgroup Representation + return np.array([[-1 if element.to('int')[0] else 1]]) + elif N % 2 == 0 and k == N / 2: + e = element.to('radians') + # 1 dimensional Irreducible representation (only for groups with an even number of rotations) + return np.array([[np.cos(k * e[1]) * (-1 if e[0] else 1)]]) + else: + # 2 dimensional Irreducible Representations + e = element.to('radians') + return utils.psichi(e[1], e[0], k=k) + + +class CharacterBuilder: + def __init__(self, j: int, k: int): + self.j = j + self.k = k + + def __call__(self, element: GroupElement): + N = element.group.rotation_order + j = self.j + k = self.k + + if j == 0: + if k == 0: + return 1. + elif N % 2 == 0 and k == N // 2: + return np.cos(k * element.to('radians')[1]) + else: + raise ValueError( + f"Error! Flip frequency {j} and rotational frequency {k} don't correspond to any irrep of the group {element.group.name}!") + else: + if k == 0: + # Trivial on Cyclic subgroup Representation + return -1 if element.to('int')[0] else 1 + elif N % 2 == 0 and k == N / 2: + e = element.to('radians') + # 1 dimensional Irreducible representation (only for groups with an even number of rotations) + return np.cos(k * e[1]) * (-1 if e[0] else 1) + else: + # 2 dimensional Irreducible Representations + e = element.to('radians') + return 0 if e[0] else (2 * np.cos(k * e[1])) + # Cyclic ############################### def dn_to_cn(cn: escnn.group.CyclicGroup): @@ -705,6 +769,22 @@ def _map(e: GroupElement, cn=cn): return _map +class DnToCnMapping: + def __init__(self, cn: escnn.group.CyclicGroup): + self.cn = cn + + def __call__(self, e: GroupElement): + assert isinstance(e.group, DihedralGroup) + + flip, rotation = e.to('int') + + ratio = e.group.rotation_order // self.cn.order() + + if flip == 0 and rotation % ratio == 0: + return self.cn.element(rotation // ratio, 'int') + else: + return None + def cn_to_dn(dn: DihedralGroup): @@ -719,6 +799,19 @@ def _map(e: GroupElement, dn=dn): return _map +class CnToDnMapping: + def __init__(self, dn: DihedralGroup): + self.dn = dn + + def __call__(self, e: GroupElement): + assert isinstance(e.group, escnn.group.CyclicGroup) + + ratio = self.dn.rotation_order // e.group.order() + + return self.dn.element( + (0, e.to('int') * ratio), 'int' + ) + # Flip wrt an axis ###################################### @@ -741,6 +834,25 @@ def _map(e: GroupElement, flip=flip, axis=axis): return _map +class DnToFlipMapping: + def __init__(self, axis: int, flip: escnn.group.CyclicGroup): + assert isinstance(flip, escnn.group.CyclicGroup) and flip.order() == 2 + self.axis = axis + self.flip = flip + + def _map(e: GroupElement): + assert isinstance(e.group, DihedralGroup) + + f, rot = e.to('int') + + if f == 0 and rot == 0: + return self.flip.identity + elif f == 1 and rot == self.axis: + return self.flip.element(1) + else: + return None + + def flip_to_dn(axis: int, dn: DihedralGroup): def _map(e: GroupElement, axis=axis, dn=dn): @@ -755,6 +867,21 @@ def _map(e: GroupElement, axis=axis, dn=dn): return _map +class FlipToDnMapping: + def __init__(self, axis: int, dn: DihedralGroup): + self.axis = axis + self.dn = dn + + def __call__(self, e: GroupElement): + assert isinstance(e.group, escnn.group.CyclicGroup) and e.group.order() == 2 + + f = e.to('int') + + if f == 0: + return self.dn.identity + else: + return self.dn.element((1, self.axis)) + # Dihedral Group ###################################### @@ -776,6 +903,24 @@ def _map(e: GroupElement, dm=dm, axis=axis): return _map +class DnToDmMapping: + def __init__(self, axis: int, dm: escnn.group.DihedralGroup): + assert isinstance(dm, escnn.group.DihedralGroup) + self.axis = axis + self.dm = dm + + def __call__(self, e: GroupElement): + assert isinstance(e.group, DihedralGroup) + + f, rot = e.to('int') + + ratio = e.group.rotation_order // self.dm.rotation_order + if (rot - f*self.axis) % ratio != 0: + return None + else: + return self.dm.element((f, (rot - f*self.axis) // ratio), 'int') + + def dm_to_dn(axis: int, dn: DihedralGroup): def _map(e: GroupElement, axis=axis, dn=dn): @@ -789,5 +934,18 @@ def _map(e: GroupElement, axis=axis, dn=dn): return _map +class DmToDnMapping: + def __init__(self, axis: int, dn: DihedralGroup): + self.axis = axis + self.dn = dn + + def __call__(self, e: GroupElement): + assert isinstance(e.group, escnn.group.DihedralGroup) + + f, rot = e.to('int') + + ratio = self.dn.rotation_order // e.group.rotation_order + + return self.dn.element((f, rot * ratio + f * self.axis), 'int') diff --git a/escnn/group/groups/o2group.py b/escnn/group/groups/o2group.py index bdf381c3..1564c2e9 100644 --- a/escnn/group/groups/o2group.py +++ b/escnn/group/groups/o2group.py @@ -373,16 +373,16 @@ def _subgroup(self, id: Tuple[float, int]) -> Tuple[ sg = escnn.group.so2_group(self._maximum_frequency) # parent_mapping = lambda e: self.element((0, e._element)) # child_mapping = lambda e, sg=sg: None if e._element[0] != 0 else sg.element(e._element[1]) - parent_mapping = so2_to_o2(self) - child_mapping = o2_to_so2(sg) + parent_mapping = SO2ToO2Mapping(self) + child_mapping = O2ToSO2Mapping(sg) elif id[0] is not None and id[1] == 1: # take the elements of the group generated by "2pi/k f" sg = escnn.group.cyclic_group(2) # parent_mapping = lambda e, axis=axis: self.element((e._element, axis * e._element)) # child_mapping = lambda e, axis=axis, sg=sg: None if not utils.cycle_isclose(e._element[1], axis * e._element[0], 2 * np.pi) else sg.element(e._element[0]) - parent_mapping = flip_to_o2(axis, self) - child_mapping = o2_to_flip(axis, sg) + parent_mapping = FlipToO2Mapping(axis, self) + child_mapping = O2ToFlipMapping(axis, sg) elif id[0] is None: # take the elements of the group generated by "2pi/order" @@ -390,8 +390,8 @@ def _subgroup(self, id: Tuple[float, int]) -> Tuple[ # parent_mapping = lambda e, order=order: self.element((0, e._element * 2. * np.pi / order)) # child_mapping = lambda e, order=order, sg=sg: None if (e._element[0] != 0 or not utils.cycle_isclose(e._element[1], 0., 2. * np.pi / order)) else \ # sg.element(int(round(e._element[1] * order / (2. * np.pi)))) - parent_mapping = so2_to_o2(self) - child_mapping = o2_to_so2(sg) + parent_mapping = SO2ToO2Mapping(self) + child_mapping = O2ToSO2Mapping(sg) elif id[0] is not None and id[1] > 1: # take the elements of the group generated by "2pi/order" and "2pi/k f" @@ -400,8 +400,8 @@ def _subgroup(self, id: Tuple[float, int]) -> Tuple[ # parent_mapping = lambda e, order=order, axis=axis: self.element((e._element[0], e._element[1] * 2. * np.pi / order + e._element[0] * axis)) # child_mapping = lambda e, order=order, axis=axis, sg=sg: None if not utils.cycle_isclose(e._element[1] - e._element[0] * axis, 0., 2. * np.pi / order) else \ # sg.element((e._element[0], int(round((e._element[1] - e._element[0] * axis) * order / (2. * np.pi))))) - parent_mapping = dn_to_o2(axis, self) - child_mapping = o2_to_dn(axis, sg) + parent_mapping = DnToO2Mapping(axis, self) + child_mapping = O2ToDnMapping(axis, sg) else: raise ValueError(f"id '{id}' not recognized") @@ -649,10 +649,10 @@ def irrep(self, j: int, k: int) -> IrreducibleRepresentation: name = f"irrep_{j},{k}" id = (j, k) + irrep = IrrepBuilder(j, k) + character = CharacterBuilder(j, k) + if id not in self._irreps: - irrep = _build_irrep_o2(j, k) - character = _build_char_o2(j, k) - if j == 0: if k == 0: # Trivial representation @@ -668,7 +668,6 @@ def irrep(self, j: int, k: int) -> IrreducibleRepresentation: raise ValueError(f"Error! Flip frequency {j} and rotational frequency {k} don't correspond to any irrep of the group {self.name}!") elif k == 0: - # add Trivial on SO(2) subgroup Representation supported_nonlinearities = ['norm', 'gated'] self._irreps[id] = IrreducibleRepresentation(self, id, name, irrep, 1, 'R', @@ -743,6 +742,33 @@ def irrep(element: GroupElement, j=j, k=k) -> np.ndarray: return irrep +class IrrepBuilder: + def __init__(self, j: int, k: int): + self.j = j + self.k = k + + def __call__(self, element: GroupElement) -> np.ndarray: + j = self.j + k = self.k + + assert j in [0, 1] + assert k >= 0 + + if j == 0: + if k == 0: + # Trivial representation + return np.eye(1) + else: + raise ValueError( + f"Error! Flip frequency {j} and rotational frequency {k} don't correspond to any irrep of the group {element.group.name}!") + elif k == 0: + # Trivial on SO(2) subgroup Representation + return np.array([[-1 if element.to('radians')[0] else 1]]) + else: + e = element.to('radians') + # 2 dimensional Irreducible Representations + return utils.psichi(e[1], e[0], k=k) + def _build_char_o2(j: int, k: int): @@ -768,6 +794,33 @@ def character(element: GroupElement, j=j, k=k) -> float: return character +class CharacterBuilder: + def __init__(self, j: int, k: int): + self.j = j + self.k = k + + def __call__(self, element: GroupElement) -> float: + j = self.j + k = self.k + assert j in [0, 1] + assert k >= 0 + + if j == 0: + if k == 0: + # Trivial representation + return 1. + else: + raise ValueError( + f"Error! Flip frequency {j} and rotational frequency {k} don't correspond to any irrep of the group {element.group.name}!") + elif k == 0: + # add Trivial on SO(2) subgroup Representation + return -1 if element.to('radians')[0] else 1 + else: + e = element.to('radians') + # 2 dimensional Irreducible Representations + return 0 if e[0] else (2 * np.cos(k * e[1])) + + # SO2 (and Cyclic) ############################### def o2_to_so2(so2: Union[escnn.group.SO2, escnn.group.CyclicGroup]): @@ -787,6 +840,22 @@ def _map(e: GroupElement, so2=so2): return _map +class O2ToSO2Mapping: + def __init__(self, so2: Union[escnn.group.SO2, escnn.group.CyclicGroup]): + self.so2 = so2 + + def __call__(self, e: GroupElement): + assert isinstance(e.group, O2) + flip, rotation = e.to('radians') + + if flip == 0: + try: + return self.so2.element(rotation, 'radians') + except ValueError: + return None + else: + return None + def so2_to_o2(o2: O2): @@ -799,6 +868,17 @@ def _map(e: GroupElement, o2=o2): return _map +class SO2ToO2Mapping: + def __init__(self, o2): + self.o2 = o2 + + def __call__(self, e: GroupElement): + assert isinstance(e.group, escnn.group.SO2) or isinstance(e.group, escnn.group.CyclicGroup) + return self.o2.element( + (0, e.to('radians')), 'radians' + ) + + # Flip wrt an axis ###################################### @@ -820,6 +900,25 @@ def _map(e: GroupElement, flip=flip, axis=axis): return _map +class O2ToFlipMapping: + def __init__(self, axis: float, flip: escnn.group.CyclicGroup): + assert isinstance(flip, escnn.group.CyclicGroup) and flip.order() == 2 + self.axis = axis + self.flip = flip + + def __call__(self, e: GroupElement): + assert isinstance(e.group, O2) + + f, rot = e.to('radians') + + if f == 0 and utils.cycle_isclose(rot, 0, 2*np.pi): + return self.flip.identity + elif f == 1 and utils.cycle_isclose(rot, self.axis, 2*np.pi): + return self.flip.element(1) + else: + return None + + def flip_to_o2(axis: float, o2: O2): def _map(e: GroupElement, axis=axis, o2=o2): @@ -835,6 +934,22 @@ def _map(e: GroupElement, axis=axis, o2=o2): return _map +class FlipToO2Mapping: + def __init__(self, axis: float, o2: O2): + self.axis = axis + self.o2 = o2 + + def __call__(self, e: GroupElement): + assert isinstance(e.group, escnn.group.CyclicGroup) and e.group.order() == 2 + + f = e.to('int') + + if f == 0: + return self.o2.identity + else: + return self.o2.element((1, self.axis)) + + # Dihedral Group ###################################### @@ -854,6 +969,23 @@ def _map(e: GroupElement, dn=dn, axis=axis): return _map +class O2ToDnMapping: + def __init__(self, axis: float, dn: escnn.group.DihedralGroup): + assert isinstance(dn, escnn.group.DihedralGroup) + self.axis = axis + self.dn = dn + + def __call__(self, e: GroupElement): + assert isinstance(e.group, O2) + + f, rot = e.to('radians') + + if utils.cycle_isclose(rot - f * self.axis, 0., 2. * np.pi / self.dn.rotation_order): + return self.dn.element((f, int(round((rot - f * self.axis) * self.dn.rotation_order / (2. * np.pi))))) + else: + return None + + def dn_to_o2(axis: float, o2: O2): def _map(e: GroupElement, axis=axis, o2=o2): @@ -866,4 +998,14 @@ def _map(e: GroupElement, axis=axis, o2=o2): return _map +class DnToO2Mapping: + def __init__(self, axis: float, o2: O2): + self.axis = axis + self.o2 = o2 + + def __call__(self, e: GroupElement): + assert isinstance(e.group, escnn.group.DihedralGroup) + f, rot = e.to('int') + + return self.o2.element((f, rot * 2. * np.pi / e.group.rotation_order + f * self.axis)) diff --git a/escnn/group/groups/so2group.py b/escnn/group/groups/so2group.py index a550b22e..c7fef9b7 100644 --- a/escnn/group/groups/so2group.py +++ b/escnn/group/groups/so2group.py @@ -288,10 +288,10 @@ def _subgroup(self, id: int) -> Tuple[ # take the elements of the group generated by "2pi/order" sg = escnn.group.cyclic_group(order) # parent_mapping = lambda e, order=order: self.element(e._element * 2 * np.pi / order) - parent_mapping = _build_parent_map(self, order) + parent_mapping = ParentMapping(self, order) # child_mapping = lambda e, order=order, sg=sg: None if divmod(e.g, 2.*np.pi/order)[1] > 1e-15 else sg.element(int(round(e.g * order / (2.*np.pi)))) # child_mapping = lambda e, order=order, sg=sg: None if not utils.cycle_isclose(e._element, 0., 2. * np.pi / order) else sg.element(int(round(e._element * order / (2. * np.pi)))) - child_mapping = _build_child_map(sg) + child_mapping = ChildMapping(sg) elif id == -1: sg = self parent_mapping = build_identity_map() @@ -439,8 +439,8 @@ def irrep(self, k: int) -> IrreducibleRepresentation: if id not in self._irreps: - irrep = _build_irrep_so2(k) - character = _build_char_so2(k) + irrep = IrrepBuilder(k) + character = CharacterBuilder(k) if k == 0: # Trivial representation @@ -452,7 +452,6 @@ def irrep(self, k: int) -> IrreducibleRepresentation: frequency=0 ) else: - # 2 dimensional Irreducible Representations supported_nonlinearities = ['norm', 'gated'] self._irreps[id] = IrreducibleRepresentation(self, id, name, irrep, 2, 'C', @@ -544,6 +543,20 @@ def irrep(e: GroupElement, k: int = k): return irrep +class IrrepBuilder: + def __init__(self, k: int): + assert k >= 0 + self.k = k + + def __call__(self, e: GroupElement): + k = self.k + if k == 0: + # Trivial representation + return np.eye(1) + else: + # 2 dimensional Irreducible Representations + return utils.psi(e.to('radians'), k=k) + def _build_char_so2(k: int): assert k >= 0 @@ -558,6 +571,20 @@ def character(e: GroupElement, k: int = k) -> float: return character +class CharacterBuilder: + def __init__(self, k: int): + assert k >= 0 + self.k = k + + def __call__(self, e: GroupElement) -> float: + k = self.k + if k == 0: + # Trivial representation + return 1. + else: + # 2 dimensional Irreducible Representations + return 2 * np.cos(k * e.to('radians')) + def _build_parent_map(G: SO2, order: int): @@ -580,3 +607,22 @@ def child_mapping(e: GroupElement, sg: Group = sg) -> GroupElement: return child_mapping +class ParentMapping: + def __init__(self, G: SO2, order: int): + self.G = G + self.order = order + + def __call__(self, e: GroupElement): + return self.G.element(e.to('int') * 2 * np.pi / self.order) + + +class ChildMapping: + def __init__(self, sg: CyclicGroup): + self.sg = sg + + def __call__(self, e: GroupElement): + radians = e.to('radians') + if not utils.cycle_isclose(radians, 0., 2. * np.pi / self.sg.order()): + return None + else: + return self.sg.element(int(round(radians * self.sg.order() / (2. * np.pi)))) diff --git a/escnn/group/representation.py b/escnn/group/representation.py index 1b59be8f..f0cf982c 100644 --- a/escnn/group/representation.py +++ b/escnn/group/representation.py @@ -1041,26 +1041,26 @@ def direct_sum_factory(irreps: List[escnn.group.IrreducibleRepresentation], unique_irreps = list({irr.id: irr for irr in irreps}.items()) irreps_ids = [irr.id for irr in irreps] - def direct_sum(element: GroupElement, - irreps_ids=irreps_ids, change_of_basis=change_of_basis, - change_of_basis_inv=change_of_basis_inv, unique_irreps=unique_irreps): - reprs = {} - for n, irr in unique_irreps: - reprs[n] = irr(element) + # def direct_sum(element: GroupElement, + # irreps_ids=irreps_ids, change_of_basis=change_of_basis, + # change_of_basis_inv=change_of_basis_inv, unique_irreps=unique_irreps): + # reprs = {} + # for n, irr in unique_irreps: + # reprs[n] = irr(element) - blocks = [] - for irrep_id in irreps_ids: - repr = reprs[irrep_id] - blocks.append(repr) + # blocks = [] + # for irrep_id in irreps_ids: + # repr = reprs[irrep_id] + # blocks.append(repr) - P = sparse.block_diag(blocks, format='csc') + # P = sparse.block_diag(blocks, format='csc') - if change_of_basis is None: - return np.asarray(P.todense()) - else: - return change_of_basis @ P @ change_of_basis_inv + # if change_of_basis is None: + # return np.asarray(P.todense()) + # else: + # return change_of_basis @ P @ change_of_basis_inv - return direct_sum + return DirectSum(irreps_ids, change_of_basis, change_of_basis_inv, unique_irreps) def homomorphism_space(rho1: Representation, rho2: Representation) -> np.ndarray: @@ -1112,3 +1112,28 @@ def homomorphism_space(rho1: Representation, rho2: Representation) -> np.ndarray basis = np.einsum('Mm,kmn,nN->kMN', rho2.change_of_basis, basis, rho1.change_of_basis_inv) return basis + + +class DirectSum: + def __init__(self, irreps_ids, change_of_basis, change_of_basis_inv, unique_irreps): + self.irreps_ids = irreps_ids + self.change_of_basis = change_of_basis + self.change_of_basis_inv = change_of_basis_inv + self.unique_irreps = unique_irreps + + def __call__(self, element: GroupElement): + reprs = {} + for n, irr in self.unique_irreps: + reprs[n] = irr(element) + + blocks = [] + for irrep_id in self.irreps_ids: + repr = reprs[irrep_id] + blocks.append(repr) + + P = sparse.block_diag(blocks, format='csc') + + if self.change_of_basis is None: + return np.asarray(P.todense()) + else: + return self.change_of_basis @ P @ self.change_of_basis_inv diff --git a/escnn/gspaces/gspace.py b/escnn/gspaces/gspace.py index 55a61148..9364fc16 100644 --- a/escnn/gspaces/gspace.py +++ b/escnn/gspaces/gspace.py @@ -78,7 +78,7 @@ def __init__(self, fibergroup: escnn.group.Group, dimensionality: int, name: str # Store the computed intertwiners between irreps # - key = (filter size, sigma, rings) # - value = dictionary mapping (input_irrep, output_irrep) pairs to the corresponding basis - self._irreps_intertwiners_basis_memory = defaultdict(lambda: dict()) + self._irreps_intertwiners_basis_memory = defaultdict(dict) # Store the computed intertwiners between general representations # - key = (filter size, sigma, rings) diff --git a/test_pickling.py b/test_pickling.py new file mode 100644 index 00000000..b01b5a8e --- /dev/null +++ b/test_pickling.py @@ -0,0 +1,30 @@ +import pickle +import tempfile +import escnn +from escnn import nn, gspaces +import torch + + +def test_pickling(r2_act): + in_type = nn.FieldType(r2_act, 1 * [r2_act.trivial_repr]) + out_type = nn.FieldType(r2_act, 1 * [r2_act.regular_repr]) + + module = torch.nn.Sequential( + nn.R2Conv(in_type, out_type, 1), + ) + + with tempfile.TemporaryFile() as f: + pickle.dump(module, f) + + +if __name__ == "__main__": + r2_acts = [ + gspaces.trivialOnR2(), + gspaces.rot2dOnR2(N=4), + gspaces.flip2dOnR2(), + gspaces.flipRot2dOnR2(N=4), + ] + for r2_act in r2_acts: + print(str(r2_act)) + test_pickling(r2_act) +