Skip to content

Commit bb1bb0f

Browse files
committed
removed old activations iterator. Added basic tests for same equivariance
1 parent c14da2d commit bb1bb0f

12 files changed

Lines changed: 77 additions & 364 deletions

tests/pytorch/test_measure.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import tmeasures as tm
44
import numpy as np
55
from numpy.testing import assert_allclose
6-
from .utils import ConstantModel,ConstantDataset,RandomModel,RepeatedIdentitySet,MeasureFixture
6+
from .utils import ConstantModel,ConstantDataset, IdentityModel,RandomModel,RepeatedIdentitySet,MeasureFixture
77
from torch import nn
88

99
default_options = tm.pytorch.PyTorchMeasureOptions(batch_size=32)
@@ -14,13 +14,12 @@ def constant_model_invariance(n:int,bs:int):
1414
o = tm.pytorch.PyTorchMeasureOptions(batch_size=bs)
1515
transformations = RepeatedIdentitySet(n)
1616
dataset = ConstantDataset(n,torch.Tensor((1,)))
17-
transformations = RepeatedIdentitySet(n)
1817
output_shape = (2,2)
1918
output = torch.rand(output_shape)
2019
result = np.zeros(output.shape)
2120
result_nv = np.ones(output.shape)
2221
model = torch.nn.Sequential(ConstantModel(output))
23-
sv,tv,nv = measures
22+
sv,tv,nv = invariance_measures
2423
return [
2524
MeasureFixture(model, sv,[result],dataset,transformations,options=o),
2625
MeasureFixture(model, tv,[result],dataset,transformations,options=o),
@@ -34,11 +33,16 @@ def constant_model_invariance_options():
3433
def test_constant_model_invariance(f:MeasureFixture):
3534
f.assert_fixture(atol=1e-5)
3635

37-
measures = [tm.pytorch.SampleVarianceInvariance(),
36+
invariance_measures = [tm.pytorch.SampleVarianceInvariance(),
3837
tm.pytorch.TransformationVarianceInvariance(),
3938
tm.pytorch.NormalizedVarianceInvariance()
4039
]
4140

41+
same_equivariance_measures = [tm.pytorch.SampleVarianceSameEquivariance(),
42+
tm.pytorch.TransformationVarianceSameEquivariance(),
43+
tm.pytorch.NormalizedVarianceSameEquivariance()
44+
]
45+
4246
def random_model_invariance_options():
4347
sample_size_order = 2
4448
n = 10**sample_size_order
@@ -52,7 +56,7 @@ def random_model_invariance_options():
5256
model = torch.nn.Sequential(RandomModel(output_shape,2,3))
5357
result = np.ones(output_shape)*std
5458
result_nv = np.ones(output_shape)
55-
sv,tv,nv = measures
59+
sv,tv,nv = invariance_measures
5660
return [
5761
MeasureFixture(model, sv,[result],dataset,transformations,options=o),
5862
MeasureFixture(model, tv,[result],dataset,transformations,options=o),
@@ -63,6 +67,26 @@ def random_model_invariance_options():
6367
def test_random_model_invariance(f:MeasureFixture):
6468
f.assert_fixture(atol=1e-1)
6569

70+
71+
72+
def constant_model_same_equivariance_options(n:int=16,bs:int=4):
73+
o = tm.pytorch.PyTorchMeasureOptions(batch_size=bs)
74+
transformations = RepeatedIdentitySet(n)
75+
value = torch.Tensor(((2,3),(4,5)))
76+
dataset = ConstantDataset(n,value)
77+
result = np.zeros(value.shape)
78+
result_nv = np.ones(value.shape)
79+
model = torch.nn.Sequential(IdentityModel())
80+
sev,tev,nev = same_equivariance_measures
81+
return [
82+
MeasureFixture(model, sev,[result],dataset,transformations,options=o),
83+
MeasureFixture(model, tev,[result],dataset,transformations,options=o),
84+
MeasureFixture(model, nev,[result_nv],dataset,transformations,options=o),
85+
]
86+
@pytest.mark.parametrize("f",constant_model_same_equivariance_options())
87+
def test_constant_model_same_equivariance(f:MeasureFixture):
88+
f.assert_fixture(atol=1e-5)
89+
6690
if __name__ == "__main__":
6791
import logging
6892
# logging.basicConfig()

tests/pytorch/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def __getitem__(self, index):
6060
return self.value
6161

6262

63-
class RepeatedIdentitySet(tm.pytorch.transformations.PyTorchTransformationSet):
63+
class RepeatedIdentitySet(tm.pytorch.transformations.PyTorchInvertibleTransformationSet):
6464
def __init__(self,transformations=1):
6565
super().__init__([tm.pytorch.transformations.IdentityTransformation()]*transformations)
6666
def valid_input(self):

tmeasures/pytorch/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@
2222
"""
2323

2424
from .model import ActivationsModule,AutoActivationsModule,get_activations,ManualActivationsModule
25-
from .base import PyTorchMeasure,PyTorchMeasureOptions,ActivationsByLayer,PyTorchMeasureResult,STMatrixIterator,PyTorchLayerMeasure
25+
from .base import PyTorchMeasure,PyTorchMeasureOptions,ActivationsByLayer,PyTorchMeasureResult,STMatrixIterator,PyTorchActivationMeasure
2626
from .transformations import PyTorchTransformationSet,PyTorchTransformation
2727

28-
from .activations_iterator_base import PytorchActivationsIterator,InvertedPytorchActivationsIterator,BothPytorchActivationsIterator,NormalPytorchActivationsIterator
28+
#from .activations_iterator_base import PytorchActivationsIterator,InvertedPytorchActivationsIterator,BothPytorchActivationsIterator,NormalPytorchActivationsIterator
2929

3030
from .by_layer import PyTorchMeasureByLayer
3131

tmeasures/pytorch/activations_iterator.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# from .activations_transformer import ActivationsTransformer
1212
from .. import InvertibleTransformation, Transformation
1313
from . import ActivationsModule
14-
from .base import PyTorchLayerMeasure, PyTorchMeasure, PyTorchMeasureOptions
14+
from .base import PyTorchActivationMeasure, PyTorchMeasure, PyTorchMeasureOptions
1515
from .computation_model import ThreadsComputationModel
1616
from .dataset2d import Dataset2D, STDataset
1717
from .transformations import PyTorchTransformation
@@ -124,12 +124,12 @@ def split_activations_by_row(self,activations:ActivationValues,i_rows:list[int])
124124
start=end
125125
yield current_row,activations_row
126126

127-
def evaluate(self, m: PyTorchLayerMeasure):
128-
layers = self.model.activation_names()
127+
def evaluate(self, m: PyTorchActivationMeasure):
128+
activation_names = self.model.activation_names()
129129
rows, cols = self.dataset.len0, self.dataset.len1
130130
logger.debug(f"Main thread {threading.get_ident()}")
131-
measure_functions = {l:m.eval for l in layers}
131+
measure_functions = {l:m.eval for l in activation_names}
132132
model_evaluating_function = self.feed_measures
133-
max_workers = len(layers)+1
133+
max_workers = len(activation_names)+1
134134
tm = ThreadsComputationModel(model_evaluating_function,measure_functions,max_workers,rows,cols,self.o.batch_size)
135135
return tm.execute()

tmeasures/pytorch/activations_iterator_base.py

Lines changed: 0 additions & 250 deletions
This file was deleted.

0 commit comments

Comments
 (0)