Skip to content

Commit 475e836

Browse files
committed
wip new batching method
1 parent 6409e4a commit 475e836

9 files changed

Lines changed: 315 additions & 198 deletions

File tree

docs/examples/ResNet Invariance with TinyImageNet.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -563,7 +563,7 @@
563563
"provenance": []
564564
},
565565
"kernelspec": {
566-
"display_name": ".venv",
566+
"display_name": "tm",
567567
"language": "python",
568568
"name": "python3"
569569
},

tests/pytorch/test_measure.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def forward(self,x:torch.Tensor):
1313
n = x.shape[0]
1414
result = self.value.expand(n,*self.value.shape)
1515
return result
16+
1617
class IdentityModel(torch.nn.Module):
1718
def __init__(self,) -> None:
1819
super().__init__()
@@ -42,11 +43,11 @@ def __getitem__(self, index):
4243
return self.dataset[index][0]
4344

4445
default_options = tm.pytorch.PyTorchMeasureOptions(batch_size=1024)
45-
large_options = tm.pytorch.PyTorchMeasureOptions(batch_size=2**14,num_workers=128)
46+
large_options = tm.pytorch.PyTorchMeasureOptions(batch_size=2**14,num_workers=12)
4647

4748
def assert_instance(measure,dataset,transformations,activations_model,expected_result,atol=1e-5,options=default_options):
48-
print(options.batch_size)
4949
result = measure.eval(dataset,transformations,activations_model,options)
50+
5051
result = result.numpy()
5152
for name,layer,expected_layer in zip(result.layer_names,result.layers,expected_result):
5253
assert_allclose(layer,expected_layer,err_msg=f"Error in {measure} for activation '{name}'",atol=atol)
@@ -58,16 +59,18 @@ def test_constant_model_invariance():
5859
expected_results = np.zeros(output.shape)
5960
expected_results_normalized = np.ones(output.shape)
6061
model = torch.nn.Sequential(ConstantModel(output))
61-
measures_results = [(tm.pytorch.SampleVarianceInvariance(),[expected_results]),
62+
measures_results = [
63+
# (tm.pytorch.SampleVarianceInvariance(),[expected_results]),
6264
(tm.pytorch.TransformationVarianceInvariance(),[expected_results]),
63-
(tm.pytorch.NormalizedVarianceInvariance(),[expected_results_normalized]),
65+
# (tm.pytorch.NormalizedVarianceInvariance(),[expected_results_normalized]),
6466
]
65-
transformations = tm.pytorch.transformations.IdentityTransformationSet()
66-
67-
dataset = ConstantDataset(2,(100,5))
67+
n = 5
68+
transformations = RepeatedIdentitySet(n)
69+
default_options.batch_size=3
70+
dataset = ConstantDataset(2,(n,5))
6871
activations_model = tm.pytorch.AutoActivationsModule(model)
6972
for measure,expected_result in measures_results:
70-
assert_instance(measure,dataset,transformations,activations_model,expected_result)
73+
assert_instance(measure,dataset,transformations,activations_model,expected_result,options=default_options)
7174

7275
class RepeatedIdentitySet(tm.pytorch.transformations.PyTorchTransformationSet):
7376
def __init__(self,transformations=1):
@@ -79,7 +82,7 @@ def copy(self):
7982
def id(self):
8083
return "Identity"
8184

82-
def test_random_model_invariance():
85+
def atest_random_model_invariance():
8386
output_shape = (2,2)
8487
mean,std=2.0,3
8588
model = torch.nn.Sequential(RandomModel(output_shape,2,3))
@@ -89,15 +92,24 @@ def test_random_model_invariance():
8992
(tm.pytorch.TransformationVarianceInvariance(),[expected_results]),
9093
(tm.pytorch.NormalizedVarianceInvariance(),[expected_results_normalized]),
9194
]
92-
sample_size_order = 2
95+
sample_size_order = 8
9396
n = 10**sample_size_order
9497
atol = 10**(-np.sqrt(sample_size_order//2))
9598
transformations = RepeatedIdentitySet(n)
9699
dataset = ConstantDataset(2,(n,2))
97100
activations_model = tm.pytorch.AutoActivationsModule(model)
101+
large_options.batch_size = n
98102
for measure,expected_result in measures_results:
99103
assert_instance(measure,dataset,transformations,activations_model,expected_result,atol=1e-1,options=large_options)
100104

101105

106+
102107
if __name__ == "__main__":
103-
test_random_model_invariance()
108+
import logging
109+
logging.basicConfig()
110+
111+
#set logger to info level
112+
tm.logger.setLevel(logging.INFO)
113+
114+
test_constant_model_invariance()
115+
#atest_random_model_invariance()

tmeasures/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33
"""
44

55
import logging
6+
67
logger = logging.getLogger(__name__)
8+
9+
710
log_level = logging.WARN
811
logger.setLevel(log_level)
912

0 commit comments

Comments
 (0)