Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 68 additions & 11 deletions escnn/group/groups/cyclicgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,8 @@ def _subgroup(self, id: int) -> Tuple[

# parent_mapping = lambda e, ratio=ratio: self.element(e._element * ratio)
# child_mapping = lambda e, ratio=ratio, sg=sg: None if e._element % ratio != 0 else sg.element(int(e._element // ratio))
parent_mapping = _build_parent_map(self, order)
child_mapping = _build_child_map(self, sg)
parent_mapping = ParentMapping(self, order)
child_mapping = ChildMapping(self, sg)

return sg, parent_mapping, child_mapping

Expand Down Expand Up @@ -411,12 +411,12 @@ def irrep(self, k: int) -> IrreducibleRepresentation:
name = f"irrep_{k}"

n = self.order()

irrep = IrrepBuilder(k)
character = CharacterBuilder(k)

if k == 0:
# Trivial representation

irrep = _build_irrep_cn(0)
character = _build_char_cn(0)
supported_nonlinearities = ['pointwise', 'gate', 'norm', 'gated', 'concatenated']
self._irreps[id] = IrreducibleRepresentation(self, id, name, irrep, 1, 'R',
supported_nonlinearities=supported_nonlinearities,
Expand All @@ -425,19 +425,13 @@ def irrep(self, k: int) -> IrreducibleRepresentation:
frequency=k)
elif n % 2 == 0 and k == int(n/2):
# 1 dimensional Irreducible representation (only for even order groups)
irrep = _build_irrep_cn(k)
character = _build_char_cn(k)
supported_nonlinearities = ['norm', 'gated', 'concatenated']
self._irreps[id] = IrreducibleRepresentation(self, id, name, irrep, 1, 'R',
supported_nonlinearities=supported_nonlinearities,
character=character,
frequency=k)
else:
# 2 dimensional Irreducible Representations

irrep = _build_irrep_cn(k)
character = _build_char_cn(k)

supported_nonlinearities = ['norm', 'gated']
self._irreps[id] = IrreducibleRepresentation(self, id, name, irrep, 2, 'C',
supported_nonlinearities=supported_nonlinearities,
Expand Down Expand Up @@ -555,6 +549,26 @@ def irrep(element: GroupElement, k:int =k) -> np.ndarray:
return irrep


class IrrepBuilder:
def __init__(self, k: int):
self.k = k

def __call__(self, element: GroupElement) -> np.ndarray:
k = self.k

if k == 0:
return np.eye(1)

n = element.group.order()

if n % 2 == 0 and k == int(n / 2):
# 1 dimensional Irreducible representation (only for even order groups)
return np.array([[np.cos(k * element.to('radians'))]])
else:
# 2 dimensional Irreducible Representations
return utils.psi(element.to('radians'), k=k)


def _build_char_cn(k: int):

def character(element: GroupElement, k=k) -> float:
Expand All @@ -573,12 +587,40 @@ def character(element: GroupElement, k=k) -> float:
return character


class CharacterBuilder:
def __init__(self, k: int):
self.k = k

def __call__(self, element: GroupElement) -> float:
k = self.k

if k == 0:
return 1.

n = element.group.order()

if n % 2 == 0 and k == int(n / 2):
# 1 dimensional Irreducible representation (only for even order groups)
return np.cos(k * element.to('radians'))
else:
# 2 dimensional Irreducible Representations
return 2*np.cos(k * element.to('radians'))


def _build_parent_map(G: CyclicGroup, order: int):
def parent_mapping(e: GroupElement, G: Group = G, order=order) -> GroupElement:
return G.element(e.to('int') * G.order() // order)

return parent_mapping

class ParentMapping:
def __init__(self, G: CyclicGroup, order: int):
self.G = G
self.order = order

def __call__(self, e: GroupElement):
return self.G.element(e.to('int') * self.G.order() // self.order)


def _build_child_map(G: CyclicGroup, sg: CyclicGroup):
assert G.order() % sg.order() == 0
Expand All @@ -595,3 +637,18 @@ def child_mapping(e: GroupElement, G=G, sg: Group = sg) -> GroupElement:
return child_mapping


class ChildMapping:
def __init__(self, G: CyclicGroup, sg: CyclicGroup):
assert G.order() % sg.order() == 0
self.G = G
self.sg = sg

def __call__(self, e: GroupElement):
assert e.group == self.G
i = e.to('int')
ratio = self.G.order() // self.sg.order()
if i % ratio != 0:
return None
else:
return self.sg.element(i // ratio)

174 changes: 166 additions & 8 deletions escnn/group/groups/dihedralgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,24 +328,24 @@ def _subgroup(self, id: Tuple[int, int]) -> Tuple[

# parent_mapping = lambda e, axis=axis: self.element((e._element, axis * e._element))
# child_mapping = lambda e, axis=axis, sg=sg: None if e._element[1] != e._element[0] * axis else sg.element(e._element[0])
parent_mapping = flip_to_dn(axis, self)
child_mapping = dn_to_flip(axis, sg)
parent_mapping = FlipToDnMapping(axis, self)
child_mapping = DnToFlipMapping(axis, sg)

elif id[0] is None:
# take the elements of the group generated by "r^ratio"
sg = escnn.group.cyclic_group(order)
# parent_mapping = lambda e, ratio=ratio: self.element((0, e._element * ratio))
# child_mapping = lambda e, ratio=ratio, sg=sg: None if (e._element[0] != 0 or e._element[1] % ratio > 0) else sg.element(int(e._element[1] / ratio))
parent_mapping = cn_to_dn(self)
child_mapping = dn_to_cn(sg)
parent_mapping = CnToDnMapping(self)
child_mapping = DnToCnMapping(sg)

else:
# take the elements of the group generated by "r^ratio" and "r^axis f"
sg = escnn.group.dihedral_group(order)
# parent_mapping = lambda e, ratio=ratio, axis=axis: self.element((e._element[0], e._element[1] * ratio + e._element[0] * axis))
# child_mapping = lambda e, ratio=ratio, axis=axis, sg=sg: None if (e._element[1] - e._element[0] * axis) % ratio > 0 else sg.element((e._element[0], int((e._element[1] - e._element[0] * axis) / ratio)))
parent_mapping = dm_to_dn(axis, self)
child_mapping = dn_to_dm(axis, sg)
parent_mapping = DmToDnMapping(axis, self)
child_mapping = DnToDmMapping(axis, sg)

return sg, parent_mapping, child_mapping

Expand Down Expand Up @@ -557,8 +557,8 @@ def irrep(self, j: int, k: int) -> IrreducibleRepresentation:

if id not in self._irreps:

irrep = _build_irrep_dn(j, k)
character = _build_char_dn(j, k)
irrep = IrrepBuilder(j, k)
character = CharacterBuilder(j, k)

if j == 0:

Expand Down Expand Up @@ -687,6 +687,70 @@ def character(element: GroupElement, j=j, k=k) -> float:
return character


class IrrepBuilder:
def __init__(self, j: int, k: int):
self.j = j
self.k = k

def __call__(self, element: GroupElement) -> np.ndarray:
j = self.j
k = self.k

N = element.group.rotation_order

if j == 0:
if k == 0:
return np.eye(1)
elif N % 2 == 0 and k == N // 2:
return np.array([[np.cos(k * element.to('radians')[1])]])
else:
raise ValueError(
f"Error! Flip frequency {j} and rotational frequency {k} don't correspond to any irrep of the group {element.group.name}!")
else:
if k == 0:
# Trivial on Cyclic subgroup Representation
return np.array([[-1 if element.to('int')[0] else 1]])
elif N % 2 == 0 and k == N / 2:
e = element.to('radians')
# 1 dimensional Irreducible representation (only for groups with an even number of rotations)
return np.array([[np.cos(k * e[1]) * (-1 if e[0] else 1)]])
else:
# 2 dimensional Irreducible Representations
e = element.to('radians')
return utils.psichi(e[1], e[0], k=k)


class CharacterBuilder:
def __init__(self, j: int, k: int):
self.j = j
self.k = k

def __call__(self, element: GroupElement):
N = element.group.rotation_order
j = self.j
k = self.k

if j == 0:
if k == 0:
return 1.
elif N % 2 == 0 and k == N // 2:
return np.cos(k * element.to('radians')[1])
else:
raise ValueError(
f"Error! Flip frequency {j} and rotational frequency {k} don't correspond to any irrep of the group {element.group.name}!")
else:
if k == 0:
# Trivial on Cyclic subgroup Representation
return -1 if element.to('int')[0] else 1
elif N % 2 == 0 and k == N / 2:
e = element.to('radians')
# 1 dimensional Irreducible representation (only for groups with an even number of rotations)
return np.cos(k * e[1]) * (-1 if e[0] else 1)
else:
# 2 dimensional Irreducible Representations
e = element.to('radians')
return 0 if e[0] else (2 * np.cos(k * e[1]))

# Cyclic ###############################

def dn_to_cn(cn: escnn.group.CyclicGroup):
Expand All @@ -705,6 +769,22 @@ def _map(e: GroupElement, cn=cn):

return _map

class DnToCnMapping:
def __init__(self, cn: escnn.group.CyclicGroup):
self.cn = cn

def __call__(self, e: GroupElement):
assert isinstance(e.group, DihedralGroup)

flip, rotation = e.to('int')

ratio = e.group.rotation_order // self.cn.order()

if flip == 0 and rotation % ratio == 0:
return self.cn.element(rotation // ratio, 'int')
else:
return None


def cn_to_dn(dn: DihedralGroup):

Expand All @@ -719,6 +799,19 @@ def _map(e: GroupElement, dn=dn):

return _map

class CnToDnMapping:
def __init__(self, dn: DihedralGroup):
self.dn = dn

def __call__(self, e: GroupElement):
assert isinstance(e.group, escnn.group.CyclicGroup)

ratio = self.dn.rotation_order // e.group.order()

return self.dn.element(
(0, e.to('int') * ratio), 'int'
)


# Flip wrt an axis ######################################

Expand All @@ -741,6 +834,25 @@ def _map(e: GroupElement, flip=flip, axis=axis):
return _map


class DnToFlipMapping:
def __init__(self, axis: int, flip: escnn.group.CyclicGroup):
assert isinstance(flip, escnn.group.CyclicGroup) and flip.order() == 2
self.axis = axis
self.flip = flip

def _map(e: GroupElement):
assert isinstance(e.group, DihedralGroup)

f, rot = e.to('int')

if f == 0 and rot == 0:
return self.flip.identity
elif f == 1 and rot == self.axis:
return self.flip.element(1)
else:
return None


def flip_to_dn(axis: int, dn: DihedralGroup):

def _map(e: GroupElement, axis=axis, dn=dn):
Expand All @@ -755,6 +867,21 @@ def _map(e: GroupElement, axis=axis, dn=dn):

return _map

class FlipToDnMapping:
def __init__(self, axis: int, dn: DihedralGroup):
self.axis = axis
self.dn = dn

def __call__(self, e: GroupElement):
assert isinstance(e.group, escnn.group.CyclicGroup) and e.group.order() == 2

f = e.to('int')

if f == 0:
return self.dn.identity
else:
return self.dn.element((1, self.axis))


# Dihedral Group ######################################

Expand All @@ -776,6 +903,24 @@ def _map(e: GroupElement, dm=dm, axis=axis):
return _map


class DnToDmMapping:
def __init__(self, axis: int, dm: escnn.group.DihedralGroup):
assert isinstance(dm, escnn.group.DihedralGroup)
self.axis = axis
self.dm = dm

def __call__(self, e: GroupElement):
assert isinstance(e.group, DihedralGroup)

f, rot = e.to('int')

ratio = e.group.rotation_order // self.dm.rotation_order
if (rot - f*self.axis) % ratio != 0:
return None
else:
return self.dm.element((f, (rot - f*self.axis) // ratio), 'int')


def dm_to_dn(axis: int, dn: DihedralGroup):

def _map(e: GroupElement, axis=axis, dn=dn):
Expand All @@ -789,5 +934,18 @@ def _map(e: GroupElement, axis=axis, dn=dn):

return _map

class DmToDnMapping:
def __init__(self, axis: int, dn: DihedralGroup):
self.axis = axis
self.dn = dn

def __call__(self, e: GroupElement):
assert isinstance(e.group, escnn.group.DihedralGroup)

f, rot = e.to('int')

ratio = self.dn.rotation_order // e.group.rotation_order

return self.dn.element((f, rot * ratio + f * self.axis), 'int')


Loading