Skip to content

Commit 6409e4a

Browse files
committed
Added tests; reimplemented feed_threads to use larger batch sizes.
1 parent 4ff3386 commit 6409e4a

10 files changed

Lines changed: 2435 additions & 98 deletions

File tree

docs/examples/Variance to rotations of a CNN trained on MNIST with PyTorch.ipynb

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
"\n",
3737
"torch.manual_seed(0)\n",
3838
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
39-
"results_path = Path(\"~/tm_example_pytorch/\").expanduser()\n",
39+
"results_path = Path(\"~/.tmeasures/\").expanduser()\n",
4040
"results_path.mkdir(parents=True, exist_ok=True)"
4141
]
4242
},
@@ -175,7 +175,8 @@
175175
"train_augmentation = [transforms.RandomRotation(degree_range)]\n",
176176
"train_transform = transforms.Compose(train_augmentation + base_preprocessing)\n",
177177
"measure_transform = transforms.Compose(base_preprocessing)\n",
178-
"path = results_path / 'mnist'\n",
178+
"path = Path('~/.datasets/mnist').expanduser()\n",
179+
"path.mkdir(exist_ok=True,parents=True)\n",
179180
"\n",
180181
"train_dataset = datasets.MNIST(path, train=True, download=True,\n",
181182
" transform=train_transform)\n",
@@ -353,7 +354,7 @@
353354
"\n",
354355
"Last step before computing the measure: we need to define a PyTorchMeasureOptions object to configure where and how the measure will be computed. The `batch_size` and `num_workers` keywords are analogous to the ones used in [PyTorch's DataLoader](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html). \n",
355356
"\n",
356-
"The `data_device`, `model_device` and `measure_device` indicate, respectively, where the transformations and data preprocessing is performed, where the activations of the model are computed, and finally where the actual measure is computed. In simple cases, these devices could all be the same.\n",
357+
"The `data_device`, `model_device` and `measure_device` indicate, respectively, where the transformations and data preprocessing are performed, where the activations of the model are computed, and finally where the actual measure is computed. In most cases, using the same device for all three gives the best performance; however, it is sometimes necessary or desirable to perform data preprocessing on the `cpu` and model and measure computations on a `gpu` or other accelerator.\n",
357358
"\n",
358359
"Finally, we can `eval` the measure with the dataset, transformation, model and options, obtaining a `PyTorchMeasureResult`, which can be handily converted to a `numpy` version for easy visualization.\n"
359360
]

docs/examples/basic_example_pytorch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def forward(self, x):
5656

5757
torch.manual_seed(0)
5858
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
59-
results_path = Path("~/tm_example_pytorch/").expanduser()
59+
results_path = Path("~/.tm_example_pytorch/").expanduser()
6060
results_path.mkdir(parents=True, exist_ok=True)
6161

6262
# DATASET
@@ -147,7 +147,7 @@ def __getitem__(self, index):
147147
for measure,model in measures:
148148
exp_id = f"rot{degree_range}_{measure}"
149149
result_filepath = results_path / f'{exp_id}_result.pickle'
150-
if os.path.exists(result_filepath) and False:
150+
if os.path.exists(result_filepath):
151151
print(f"Measure {measure} already evaluated, loading...")
152152
# Load result (optional, in case you don't want to run the above or your session died)
153153
with open(result_filepath, 'rb') as f:

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,15 @@ dependencies = [
1515
"scipy",
1616
"scikit-image>=0.25.0",
1717
"scikit-learn",
18-
"data-science-types",
1918
"statsmodels>=0.14.4",
2019
"tqdm>=4.67.1",
2120
]
2221

2322

2423
[dependency-groups]
2524
dev = [
25+
"torch>=2",
26+
"torchvision",
2627
"pandas-stubs>=2.2.3.250308",
2728
"scipy-stubs>=1.15.2.1",
2829
"microsoft-python-type-stubs",
@@ -32,6 +33,8 @@ dev = [
3233
"types-tqdm>=4.67.0.20250401",
3334
"pre-commit>=4.2.0",
3435
"poethepoet>=0.35.0",
36+
"data-science-types",
37+
"poutyne",
3538
]
3639
docs = [
3740
"sphinx>=8.2.3",

tests/pytorch/test_measure.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
import torch
2+
import pytest
3+
import tmeasures as tm
4+
import numpy as np
5+
from numpy.testing import assert_allclose
6+
7+
8+
class ConstantModel(torch.nn.Module):
    """Module that ignores its input values and returns a fixed tensor,
    expanded along a new leading batch dimension taken from the input."""

    def __init__(self, value=None) -> None:
        """
        Args:
            value: the constant activation tensor to return for every sample.
                Defaults to an empty (size-0) tensor.
        """
        super().__init__()
        # Avoid a mutable default argument: a tensor default expression is
        # evaluated once and shared by every instance built without `value`.
        self.value = value if value is not None else torch.empty(0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n = x.shape[0]
        # expand() returns a broadcasted view (no copy) of shape (n, *value.shape).
        return self.value.expand(n, *self.value.shape)
16+
class IdentityModel(torch.nn.Module):
    """Module whose output is exactly its input, unchanged."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x
21+
22+
class RandomModel(torch.nn.Module):
    """Module that ignores its input values and draws fresh normal samples.

    Every forward pass returns a tensor of shape (batch, *shape) sampled
    from N(mean, std**2); only the batch size is taken from the input.
    """

    def __init__(self, shape: tuple, mean=0.0, std=1.0) -> None:
        super().__init__()
        self.mean = mean
        self.std = std
        self.shape = shape

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch = x.shape[0]
        out_shape = (batch, *self.shape)
        return torch.normal(mean=self.mean, std=self.std, size=out_shape)
32+
33+
34+
35+
class ConstantDataset(torch.utils.data.Dataset):
    """Dataset whose every sample is the same constant tensor.

    `shape[0]` is the number of samples; the remaining dimensions give the
    shape of each individual sample.
    """

    def __init__(self, value=0, shape=(10, 10)):
        super().__init__()
        data = torch.ones(shape) * value
        self.dataset = torch.utils.data.TensorDataset(data)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        # TensorDataset wraps each item in a 1-tuple; unwrap it here.
        return self.dataset[index][0]
43+
44+
# Shared evaluation options for the tests below.
default_options = tm.pytorch.PyTorchMeasureOptions(batch_size=1024)
# Large batches and many workers stress the batched feeding implementation.
large_options = tm.pytorch.PyTorchMeasureOptions(num_workers=128, batch_size=2**14)
46+
47+
def assert_instance(measure, dataset, transformations, activations_model,
                    expected_result, atol=1e-5, options=default_options):
    """Evaluate `measure` on the given inputs and compare per-layer results.

    Args:
        measure: a tmeasures PyTorch measure with an `eval` method.
        dataset, transformations, activations_model: inputs for `measure.eval`.
        expected_result: list of arrays, one per activation/layer.
        atol: absolute tolerance for `assert_allclose`.
        options: PyTorchMeasureOptions controlling batching/devices.
    """
    # NOTE: removed a leftover debug `print(options.batch_size)`.
    result = measure.eval(dataset, transformations, activations_model, options)
    result = result.numpy()
    for name, layer, expected_layer in zip(result.layer_names, result.layers, expected_result):
        assert_allclose(layer, expected_layer, atol=atol,
                        err_msg=f"Error in {measure} for activation '{name}'")
53+
54+
55+
def test_constant_model_invariance():
    """A constant model: raw variance measures should be 0 everywhere and
    the normalized variance measure should be 1."""
    output_shape = (2, 2)
    output = torch.rand(output_shape)
    zeros = np.zeros(output.shape)
    ones = np.ones(output.shape)
    model = torch.nn.Sequential(ConstantModel(output))
    activations_model = tm.pytorch.AutoActivationsModule(model)
    transformations = tm.pytorch.transformations.IdentityTransformationSet()
    dataset = ConstantDataset(2, (100, 5))
    cases = [
        (tm.pytorch.SampleVarianceInvariance(), [zeros]),
        (tm.pytorch.TransformationVarianceInvariance(), [zeros]),
        (tm.pytorch.NormalizedVarianceInvariance(), [ones]),
    ]
    for measure, expected in cases:
        assert_instance(measure, dataset, transformations, activations_model, expected)
71+
72+
class RepeatedIdentitySet(tm.pytorch.transformations.PyTorchTransformationSet):
    """Transformation set of `transformations` copies of the identity."""

    def __init__(self, transformations=1):
        identity = tm.pytorch.transformations.IdentityTransformation()
        super().__init__([identity] * transformations)

    def valid_input(self):
        return True

    def copy(self):
        # The set is effectively immutable, so sharing one instance is fine.
        return self

    def id(self):
        return "Identity"
81+
82+
def test_random_model_invariance():
    """With activations drawn i.i.d. from N(mean, std**2), the raw measures
    are expected to equal `std` per activation and the normalized measure 1,
    up to sampling tolerance."""
    output_shape = (2, 2)
    mean, std = 2.0, 3
    # BUG FIX: pass mean/std instead of repeating the literals 2, 3, so the
    # expected values below stay in sync with the model's parameters.
    model = torch.nn.Sequential(RandomModel(output_shape, mean, std))
    expected_results = np.ones(output_shape) * std
    expected_results_normalized = np.ones(output_shape)
    measures_results = [
        (tm.pytorch.SampleVarianceInvariance(), [expected_results]),
        (tm.pytorch.TransformationVarianceInvariance(), [expected_results]),
        (tm.pytorch.NormalizedVarianceInvariance(), [expected_results_normalized]),
    ]
    sample_size_order = 2
    n = 10 ** sample_size_order
    # Tolerance shrinks with the sample size; for order 2 this equals the
    # previously hard-coded 1e-1 (the computed value was unused before).
    atol = 10 ** (-np.sqrt(sample_size_order // 2))
    transformations = RepeatedIdentitySet(n)
    dataset = ConstantDataset(2, (n, 2))
    activations_model = tm.pytorch.AutoActivationsModule(model)
    for measure, expected_result in measures_results:
        assert_instance(measure, dataset, transformations, activations_model,
                        expected_result, atol=atol, options=large_options)
100+
101+
102+
if __name__ == "__main__":
    # Allow running this test module directly, outside pytest.
    test_random_model_invariance()

tmeasures/measure.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import re
66
from .utils import get_all
77

8+
## todo change np.ndarray to something more general
89
ActivationsByLayer = List[np.ndarray]
910

1011
# TODO change `layer` for `activation` in variable/methods to unify vocabulary

tmeasures/pytorch/activations_iterator.py

Lines changed: 134 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
from collections.abc import Generator
2+
import typing
3+
4+
from tmeasures.pytorch.transformations import PyTorchTransformation
15
from .dataset2d import STDataset, Dataset2D
26
import torch
37
from torch.utils.data import DataLoader
@@ -25,12 +29,12 @@
2529
class ActivationsTransformer(abc.ABC):
2630

2731
@abc.abstractmethod
28-
def transform(self, activations: torch.Tensor, x: torch.Tensor, transformations: List[Transformation]) -> torch.Tensor:
32+
def transform(self, activations: torch.Tensor, x: torch.Tensor, transformations: List[PyTorchTransformation]) -> torch.Tensor:
2933
pass
3034

3135

3236
class IdentityActivationsTransformer(ActivationsTransformer):
33-
def transform(self, activations: torch.Tensor, x: torch.Tensor, transformations: List[Transformation]) -> torch.Tensor:
37+
def transform(self, activations: torch.Tensor, x: torch.Tensor, transformations: List[PyTorchTransformation]) -> torch.Tensor:
3438
return activations
3539

3640
from tmeasures import logger
@@ -44,6 +48,7 @@ def __init__(self,layers:list[str],rows:int,n_batch:int,stop=False) -> None:
4448
self.layers = layers
4549
self.qs = {l: IterableQueue(rows,maxsize=1,name=f"q({l})") for l in layers}
4650
self.row_qs = {l: IterableQueue(n_batch,maxsize=1,name=f"q({l}_row)") for l in layers}
51+
4752
@property
4853
def queues(self):
4954
return list(self.qs.values())+list(self.row_qs.values())
@@ -115,55 +120,136 @@ def check_finished(self,worker_futures,server_future,tm:ThreadsManager):
115120
if not e is None:
116121
logger.info(f"Worker exception, about to re raise from main thread\n{e}\n thread id {threading.get_ident()}\n")
117122
raise e
118-
123+
124+
def move_activations_to_measure_device(self, activations: list[torch.Tensor]) -> list[torch.Tensor]:
    """Move each layer's activations to the measure device.

    No-op when model and measure devices coincide. Mutates `activations`
    in place and returns the same list.
    """
    if self.o.model_device != self.o.measure_device:
        for i, layer_activations in enumerate(activations):
            # BUG FIX: `.to()` returns a new tensor; the result must be
            # stored back — tensors are not moved in place.
            activations[i] = layer_activations.to(self.o.measure_device, non_blocking=True)
    return activations

def transform_activations(self, activations: list[torch.Tensor], x_transformed, transformations) -> list[torch.Tensor]:
    """Apply the activations transformer to every layer, in place.

    Returns the same list so call sites may assign the return value.
    """
    for i, layer_activations in enumerate(activations):
        activations[i] = self.activations_transformer.transform(layer_activations, x_transformed, transformations)
    # BUG FIX: the caller assigns this method's result; without a return
    # the activations list was silently replaced by None.
    return activations

@torch.no_grad
def feed_threads2(self, tm: ThreadsManager):
    """Feed per-layer activation batches to the queues in `tm`, using one
    DataLoader over the whole 2D (sample × transformation) dataset so that
    batches can be larger than a single row (unlike `feed_threads`)."""
    layers = self.model.activation_names()
    rows, cols = self.dataset.len0, self.dataset.len1

    dataloader = DataLoader(self.dataset, batch_size=self.o.batch_size,
                            shuffle=False, num_workers=self.o.num_workers,
                            pin_memory=True)

    # Announce every row queue up front: a single batch may span rows.
    for row in range(rows):
        for k, q in tm.qs.items():
            logger.info(f"AI: putting row {row} dataloader for layer {k}")
            q.put(tm.row_qs[k])
        if tm.stop:
            logger.info("Server thread stopping, exception detected")
            return

    n_samples = rows * cols
    for batch_i, x_transformed in tqdm.tqdm(enumerate(dataloader), disable=not self.o.verbose, leave=False):
        sample_i_start = batch_i * self.o.batch_size
        # BUG FIX: clamp to the dataset size; the last batch may be smaller
        # than batch_size, and indexing past the end breaks d1tod2.
        sample_i_end = min(sample_i_start + self.o.batch_size, n_samples)
        i_samples = [self.dataset.d1tod2(i) for i in range(sample_i_start, sample_i_end)]
        i_rows, i_cols = typing.cast(tuple[list[int], list[int]], zip(*i_samples))
        i_rows, i_cols = list(i_rows), list(i_cols)

        x_transformed = x_transformed.to(self.o.model_device, non_blocking=True)
        activations = self.model.forward_activations(x_transformed)
        transformations = self.dataset.get_transformations(i_rows, i_cols)

        activations = self.move_activations_to_measure_device(activations)
        activations = self.transform_activations(activations, x_transformed, transformations)
        if tm.stop:
            logger.info("Server thread stopping, exception detected")
            return

        # Dispatch each row's slice of the batch to that row's queues.
        for row, row_activations in self.split_row_activations(activations, i_rows):
            for i, layer_activations in enumerate(row_activations):
                tm.row_qs[layers[i]].put(layer_activations)

def split_row_activations(self, activations: list[torch.Tensor], i_rows: list[int]) -> Generator[tuple[int, list[torch.Tensor]]]:
    """Split batched activations into contiguous per-row slices.

    `i_rows` holds the row index of each sample in the batch; assumes the
    indices are ascending and contiguous (as produced by d1tod2 over a
    sequential range — TODO confirm). Yields (row, per-layer slices).
    """
    start = 0
    first, last = min(i_rows), max(i_rows)
    for current_row in range(first, last + 1):
        if current_row == last:
            # BUG FIX: slice ends are exclusive; the final slice must end
            # at len(i_rows), not len(i_rows) + 1.
            end = len(i_rows)
        else:
            end = i_rows.index(current_row + 1)
        yield current_row, [a[start:end] for a in activations]
        # BUG FIX: the next slice starts where this one ended;
        # `start = end + 1` dropped one sample at every row boundary.
        start = end
197+
198+
@torch.no_grad
def feed_threads(self, tm: ThreadsManager):
    """Feed per-layer activation batches to the queues in `tm`, one row of
    the sample × transformation grid at a time (per-row DataLoaders, so
    batches never span rows)."""
    layers = self.model.activation_names()
    rows, cols = self.dataset.len0, self.dataset.len1

    for row in tqdm.trange(rows, disable=not self.o.verbose, leave=False):
        row_dataset = self.dataset.row_dataset(row)
        row_dataloader = DataLoader(row_dataset, batch_size=self.o.batch_size,
                                    shuffle=False, num_workers=0, pin_memory=True)

        # Hand each layer its queue for this row before producing batches.
        for k, q in tm.qs.items():
            logger.info(f"AI: putting row {row} dataloader for layer {k}")
            q.put(tm.row_qs[k])

        if tm.stop:
            logger.info("Server thread stopping, exception detected")
            return

        col = 0
        for batch_i, x_transformed in enumerate(row_dataloader):
            x_transformed = x_transformed.to(self.o.model_device, non_blocking=True)
            activations = self.model.forward_activations(x_transformed)

            n_batch = x_transformed.shape[0]
            col_to = col + n_batch
            i_rows = [row] * n_batch
            i_cols = list(range(col, col_to))
            transformations = self.dataset.get_transformations(i_rows, i_cols)

            for i, layer_activations in enumerate(activations):
                if self.o.model_device != self.o.measure_device:
                    layer_activations = layer_activations.to(self.o.measure_device, non_blocking=True)
                layer_activations = self.activations_transformer.transform(layer_activations, x_transformed, transformations)
                tm.row_qs[layers[i]].put(layer_activations)

            # Bail out promptly if a worker reported an exception.
            if tm.stop:
                logger.info("Server thread stopping, exception detected")
                return
            col = col_to
167253

168254

169255
def evaluate(self, m: PyTorchLayerMeasure):

0 commit comments

Comments
 (0)