Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions escnn/nn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@

from .field_type import FieldType
from .field_type import FieldType, FourierFieldType
from .geometric_tensor import GeometricTensor, tensor_directsum
from .grid_tensor import GridTensor

from .modules import *
from .modules import __all__ as modules_list

__all__ = [
"FieldType",
"FourierFieldType",
"GeometricTensor",
"GridTensor",
"tensor_directsum",
# Modules
] + modules_list + [
# init
*modules_list,
"init",
]
76 changes: 75 additions & 1 deletion escnn/nn/field_type.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

from typing import List, Dict, Tuple, Union
from typing import List, Dict, Tuple, Union, Optional

from collections import defaultdict
from itertools import groupby
Expand All @@ -17,6 +17,11 @@

__all__ = ["FieldType"]

# TODO:
# I think the band-limit frequency should be the argument to this class, not
# the irreps. I can get the irreps from the frequency, because I have the
# group. And it's nice to be able to compare against the frequency, e.g. for
# determining how many grid points to use.

class FieldType:

Expand Down Expand Up @@ -67,6 +72,7 @@ def __init__(self,


"""
assert isinstance(gspace, GSpace)
assert len(representations) > 0

assert isinstance(representations, tuple) or isinstance(representations, list)
Expand Down Expand Up @@ -605,3 +611,71 @@ def testing_elements(self):

def __call__(self, tensor: torch.Tensor, coords: torch.Tensor = None) -> 'escnn.nn.GeometricTensor':
return escnn.nn.GeometricTensor(tensor, self, coords)

class FourierFieldType(FieldType):
"""
A field type that is compatible with Fourier transforms.
"""

def __init__(
self,
gspace: GSpace,
channels: int,
bl_irreps: List,
*,
subgroup_id: Optional[Tuple] = None,
unpack=False
):
r"""
A ``FieldType`` that is compatible with the Fourier transform modules.

More specifically, this is a field type that is guaranteed to use only
spectral regular representations. Feature vectors transformed by such
representations can be interpreted as the coefficients of a
band-limited set of Fourier basis vectors.

Args:
gspace (GSpace): the gspace describing the symmetries of the data. The Fourier transform is
performed over the group ```gspace.fibergroup```
channels (int): the number of band-limited spectral regular representations that comprise each fiber.
irreps (list): list of irreps' ids to construct the band-limited representation
subgroup_id (tuple): ...
unpack (bool): Whether to treat the representation as a single entity (True) or as an set of irreps (False). This affect nonlinearities like `GatedNonLinearity1`.

Attributes:
~.gspace (GSpace)
~.representations (tuple)
~.size (int): dimensionality of the feature space described by the :class:`~escnn.nn.FieldType`.
It corresponds to the sum of the dimensionalities of the individual feature fields or
group representations (:attr:`escnn.group.Representation.size`).

Example:

>>> gspace = rot3DonR3()
>>> so3 = gspace.fibergroup
>>> in_type = FourierFieldType(gspace, 10, so3.bl_irreps(2))
"""
self.channels = channels
self.bl_irreps = bl_irreps
self.subgroup_id = subgroup_id
self.rho = make_fourier_representation(
gspace.fibergroup,
bl_irreps,
subgroup_id,
)

if unpack:
rho = [gspace.fibergroup.irrep(*n) for n in self.rho.irreps]
else:
rho = [self.rho]

super().__init__(gspace, rho * channels)


def make_fourier_representation(group, bl_irreps, subgroup_id=None):
if subgroup_id is None:
return group.spectral_regular_representation(*bl_irreps, name=None)
else:
return group.spectral_quotient_representation(subgroup_id, *bl_irreps, name=None)


6 changes: 6 additions & 0 deletions escnn/nn/grid_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
class GridTensor:

def __init__(self, tensor, grid, coords):
self.tensor = tensor
self.grid = grid
self.coords = coords
6 changes: 5 additions & 1 deletion escnn/nn/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

from .linear import Linear

from .fourier import FourierTransform, InverseFourierTransform

from .nonlinearities import GatedNonLinearity1
from .nonlinearities import GatedNonLinearity2
from .nonlinearities import GatedNonLinearityUniform
Expand Down Expand Up @@ -90,7 +92,7 @@
"MergeModule",
"MultipleModule",
"Linear",
] + _point_conv_modules + [
*_point_conv_modules,
"R3Conv",
"R2Conv",
"R2ConvTransposed",
Expand Down Expand Up @@ -150,4 +152,6 @@
"IdentityModule",
"MaskModule",
"HarmonicPolynomialR3",
"FourierTransform",
"InverseFourierTransform",
]
159 changes: 159 additions & 0 deletions escnn/nn/modules/fourier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import numpy as np
import torch

from escnn.nn import FourierFieldType, GeometricTensor, GridTensor
from escnn.group import Group, GroupElement
from torch.nn import Module

from typing import Sequence, Optional

__all__ = ['FourierTransform', 'InverseFourierTransform']

# Docs:
# - Low level class, meant for building other modules
# - If use directly, possible to break equivariance

class InverseFourierTransform(Module):

def __init__(
self,
in_type: FourierFieldType,
out_grid: Sequence[GroupElement],
*,
normalize: bool = True,
):
super().__init__()

assert isinstance(in_type, FourierFieldType)

self.in_type = in_type
self.out_grid = list(out_grid)

A = _build_ift(in_type, out_grid, normalize)
A = torch.tensor(A, dtype=torch.get_default_dtype())

self.register_buffer('A', A)

def forward(self, input: GeometricTensor) -> GridTensor:
assert input.type == self.in_type

x_hat = input.tensor.view(
input.shape[0],
self.in_type.channels,
self.in_type.rho.size,
*input.shape[2:],
)

x = torch.einsum('bcf...,gf->bcg...', x_hat, self.A)

return GridTensor(x, self.out_grid, input.coords)

class FourierTransform(Module):

def __init__(
self,
in_grid,
out_type,
*,
extra_irreps: Optional[list] = None,
normalize: bool = True,
):
super().__init__()

assert isinstance(out_type, FourierFieldType)

self.in_grid = in_grid
self.out_type = out_type

if extra_irreps is None:
ift_type = out_type
else:
extra_irreps = [
x
for x in extra_irreps
if x not in out_type.bl_irreps
]
ift_type = FourierFieldType(
out_type.gspace,
out_type.channels,
out_type.bl_irreps + extra_irreps,
subgroup_id=out_type.subgroup_id,
)

A = _build_ift(ift_type, in_grid, normalize)

eps = 1e-8
n = ift_type.rho.size
Ainv = np.linalg.inv(A.T @ A + eps * np.eye(n)) @ A.T

if extra_irreps is not None:
Ainv = Ainv[:out_type.rho.size, :]

Ainv = torch.tensor(Ainv, dtype=torch.get_default_dtype())

self.register_buffer('Ainv', Ainv)

def forward(self, input: GridTensor) -> GeometricTensor:
assert input.grid == self.in_grid

y = input.tensor

y_hat = torch.einsum('bcg...,fg->bcf...', y, self.Ainv)

y_hat = y_hat.reshape(y.shape[0], self.out_type.size, *y.shape[3:])

return GeometricTensor(y_hat, self.out_type, input.coords)


def _build_ift(in_type: FourierFieldType, out_grid, normalize: bool):
"""
Create a matrix that will apply an inverse Fourier transform to a feature
vector of the given *in_type*.
"""
assert isinstance(in_type, FourierFieldType)

G = in_type.fibergroup

if in_type.subgroup_id is None:
kernel = _build_regular_kernel(G, in_type.bl_irreps)
else:
kernel = _build_quotient_kernel(G, in_type.subgroup_id, in_type.bl_irreps)

assert kernel.shape[0] == in_type.rho.size

if normalize:
kernel = kernel / np.linalg.norm(kernel)

kernel = kernel.reshape(-1, 1)

return np.concatenate(
[
in_type.rho(g) @ kernel
for g in out_grid
], axis=1
).T

def _build_regular_kernel(G: Group, irreps: list[tuple]):
kernel = []

for irr in irreps:
irr = G.irrep(*irr)

c = int(irr.size//irr.sum_of_squares_constituents)
k = irr(G.identity)[:, :c] * np.sqrt(irr.size)
kernel.append(k.T.reshape(-1))

kernel = np.concatenate(kernel)
return kernel

def _build_quotient_kernel(G: Group, subgroup_id: tuple, irreps: list[tuple]):
kernel = []

X: HomSpace = G.homspace(subgroup_id)

for irr in irreps:
k = X._dirac_kernel_ft(irr, X.H.trivial_representation.id)
kernel.append(k.T.reshape(-1))

kernel = np.concatenate(kernel)
return kernel
Loading