1+ from collections .abc import Generator
2+ import typing
3+
4+ from tmeasures .pytorch .transformations import PyTorchTransformation
15from .dataset2d import STDataset , Dataset2D
26import torch
37from torch .utils .data import DataLoader
class ActivationsTransformer(abc.ABC):
    """Strategy interface for post-processing a batch of layer activations.

    Implementations receive the raw activations of one layer, the (already
    transformed) input batch, and the transformations that produced it, and
    return the activations to forward to the measure.
    """

    @abc.abstractmethod
    def transform(self, activations: torch.Tensor, x: torch.Tensor,
                  transformations: "list[PyTorchTransformation]") -> torch.Tensor:
        """Return the transformed activations for `x` under `transformations`."""
        pass
3034
3135
class IdentityActivationsTransformer(ActivationsTransformer):
    """No-op transformer: returns the activations unchanged."""

    def transform(self, activations: torch.Tensor, x: torch.Tensor,
                  transformations: "list[PyTorchTransformation]") -> torch.Tensor:
        return activations
3539
3640from tmeasures import logger
@@ -44,6 +48,7 @@ def __init__(self,layers:list[str],rows:int,n_batch:int,stop=False) -> None:
4448 self .layers = layers
4549 self .qs = {l : IterableQueue (rows ,maxsize = 1 ,name = f"q({ l } )" ) for l in layers }
4650 self .row_qs = {l : IterableQueue (n_batch ,maxsize = 1 ,name = f"q({ l } _row)" ) for l in layers }
51+
@property
def queues(self):
    """All queues managed here: the per-layer queues followed by the per-row queues."""
    return [*self.qs.values(), *self.row_qs.values()]
@@ -115,55 +120,136 @@ def check_finished(self,worker_futures,server_future,tm:ThreadsManager):
115120 if not e is None :
116121 logger .info (f"Worker exception, about to re raise from main thread\n { e } \n thread id { threading .get_ident ()} \n " )
117122 raise e
118-
123+
def move_activations_to_measure_device(self, activations: list[torch.Tensor]) -> None:
    """Move each layer's activations to the measure device, mutating `activations` in place.

    Fixes a bug in the previous version, which rebound the loop-local
    `layer_activations` to the moved tensor and never stored it back, so the
    device transfer was silently discarded.
    """
    # Invariant condition hoisted out of the loop.
    if self.o.model_device == self.o.measure_device:
        return
    for i, layer_activations in enumerate(activations):
        # non_blocking pairs with pin_memory=True in the DataLoader.
        activations[i] = layer_activations.to(self.o.measure_device, non_blocking=True)
128+
def transform_activations(self, activations: list[torch.Tensor], x_transformed,
                          transformations) -> list[torch.Tensor]:
    """Apply `self.activations_transformer` to every layer, in place, and return the list.

    Fixes a bug in the previous version, which was annotated to return
    list[torch.Tensor] but returned None — so a caller doing
    `activations = self.transform_activations(...)` lost the activations.
    """
    for i, layer_activations in enumerate(activations):
        activations[i] = self.activations_transformer.transform(
            layer_activations, x_transformed, transformations)
    return activations
132+
@torch.no_grad()
def feed_threads2(self, tm: ThreadsManager):
    """Feed per-layer activation batches into `tm`'s queues (experimental path).

    Unlike `feed_threads`, this builds ONE DataLoader over the full 2D dataset
    and uses `split_row_activations` to regroup each batch by grid row before
    putting slices into the per-layer row queues. Aborts early when `tm.stop`
    is set (a worker raised).

    NOTE(review): the DataLoader loop sits inside the per-row loop, mirroring
    the original; if the intent is a single pass over the data, the batch loop
    likely belongs outside the row loop — confirm before relying on this path.
    """
    layers = self.model.activation_names()
    rows, cols = self.dataset.len0, self.dataset.len1

    dataloader = DataLoader(self.dataset, batch_size=self.o.batch_size, shuffle=False,
                            num_workers=self.o.num_workers, pin_memory=True)

    for row in range(rows):
        # Hand each layer's row queue to its consumer before streaming activations.
        for k, q in tm.qs.items():
            logger.info(f"AI: putting row {row} dataloader for layer {k}")
            q.put(tm.row_qs[k])

        if tm.stop:
            logger.info("Server thread stopping, exception detected")
            return
        col = 0
        for batch_i, x_transformed in tqdm.tqdm(enumerate(dataloader),
                                                disable=not self.o.verbose, leave=False):
            sample_i_start = batch_i * self.o.batch_size
            # Use the actual batch length: the last batch can be smaller than
            # batch_size, and indexing past the dataset would break d1tod2.
            n_batch = x_transformed.shape[0]
            i_samples = [self.dataset.d1tod2(i)
                         for i in range(sample_i_start, sample_i_start + n_batch)]
            i_rows, i_cols = typing.cast(tuple[list[int], list[int]], zip(*i_samples))
            x_transformed = x_transformed.to(self.o.model_device, non_blocking=True)
            activations = self.model.forward_activations(x_transformed)
            transformations = self.dataset.get_transformations(i_rows, i_cols)
            col_to = col + n_batch
            self.move_activations_to_measure_device(activations)
            # transform_activations mutates the list in place; do NOT rebind
            # `activations` to its return value (the previous code assigned the
            # None return here, wiping the activations).
            self.transform_activations(activations, x_transformed, transformations)
            if tm.stop:
                logger.info("Server thread stopping, exception detected")
                return

            for batch_row, row_activations in self.split_row_activations(activations, list(i_rows)):
                for i, layer_activations in enumerate(row_activations):
                    tm.row_qs[layers[i]].put(layer_activations)
            col = col_to
182+
def split_row_activations(self, activations: list[torch.Tensor], i_rows: list[int]
                          ) -> "Generator[tuple[int, list[torch.Tensor]], None, None]":
    """Split batched activations into contiguous per-row slices.

    Yields (row, [per-layer activation slice]) for each row index present in
    `i_rows`. Assumes `i_rows` is sorted ascending over a contiguous range of
    rows (true for a sequential DataLoader over the 2D dataset) — TODO confirm.

    Fixes two off-by-ones in the previous version: the end index of the last
    row was `len(i_rows) + 1` instead of `len(i_rows)`, and `start = end + 1`
    skipped the first sample of every row after the first.
    """
    first, last = min(i_rows), max(i_rows)
    start = 0
    for current_row in range(first, last + 1):
        if current_row == last:
            end = len(i_rows)
        else:
            # First position belonging to the next row bounds this row's slice.
            end = i_rows.index(current_row + 1)
        yield current_row, [a[start:end] for a in activations]
        start = end
196+
197+
@torch.no_grad()
def feed_threads(self, tm: ThreadsManager):
    """Stream activations row by row into `tm`'s per-layer queues.

    For each row of the transformation grid: builds a DataLoader over that row,
    hands each layer's row queue to its consumer, runs the model on every batch,
    moves/transforms each layer's activations and puts them into the matching
    row queue. Aborts early whenever `tm.stop` is set (a worker raised).
    """
    layers = self.model.activation_names()
    rows, cols = self.dataset.len0, self.dataset.len1

    for row in tqdm.trange(rows, disable=not self.o.verbose, leave=False):
        row_dataset = self.dataset.row_dataset(row)
        row_dataloader = DataLoader(row_dataset, batch_size=self.o.batch_size,
                                    shuffle=False, num_workers=0, pin_memory=True)

        for k, q in tm.qs.items():
            logger.info(f"AI: putting row {row} dataloader for layer {k}")
            q.put(tm.row_qs[k])

        if tm.stop:
            logger.info("Server thread stopping, exception detected")
            return
        col = 0

        for batch_i, x_transformed in enumerate(row_dataloader):
            x_transformed = x_transformed.to(self.o.model_device, non_blocking=True)
            activations = self.model.forward_activations(x_transformed)

            n_batch = x_transformed.shape[0]
            col_to = col + n_batch
            i_rows = [row] * n_batch
            i_cols = list(range(col, col_to))

            transformations = self.dataset.get_transformations(i_rows, i_cols)

            for i, layer_activations in enumerate(activations):
                if self.o.model_device != self.o.measure_device:
                    # Local rebinding is fine here: the moved tensor is used
                    # immediately below, unlike the in-place helper variant.
                    layer_activations = layer_activations.to(self.o.measure_device,
                                                             non_blocking=True)

                layer_activations = self.activations_transformer.transform(
                    layer_activations, x_transformed, transformations)
                tm.row_qs[layers[i]].put(layer_activations)
            # Check if a worker thread reported an exception.
            if tm.stop:
                logger.info("Server thread stopping, exception detected")
                return
            col = col_to
167253
168254
169255 def evaluate (self , m : PyTorchLayerMeasure ):
0 commit comments