Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions astra/torch/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from astra.torch.al.acquisitions.base import RandomAcquisition
from astra.torch.al.acquisitions.furthest import Furthest
2 changes: 2 additions & 0 deletions astra/torch/al/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,5 @@

# Acquisition functions
from astra.torch.al.acquisitions.uniform_random import UniformRandomAcquisition
from astra.torch.al.acquisitions.furthest import Furthest
from astra.torch.al.acquisitions.centroid import Centroid
35 changes: 35 additions & 0 deletions astra/torch/al/acquisitions/centroid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import torch
from astra.torch.al.acquisitions.base import DiversityAcquisition


class Centroid(DiversityAcquisition):
    """Diversity acquisition seeded by the centroid of the labeled set.

    Performs greedy furthest-first traversal over the pool: the first pick is
    the pool point furthest from the mean (centroid) of the labeled
    embeddings, and each subsequent pick is the pool point furthest from its
    nearest already-selected point.
    """

    def acquire_scores(self, labeled_embeddings, unlabeled_embeddings, n):
        """Select ``n`` pool indices by greedy furthest-first traversal.

        Parameters
        ----------
        labeled_embeddings: tensor([n_train, embedding_size])
            Embedding of the train data. May be empty (``n_train == 0``).
        unlabeled_embeddings: tensor([n_pool, embedding_size])
            Embedding of the pool data
        n: int
            Number of data points to be selected
        Returns:
        ----------
        idxs: list
            List of selected data point indices with respect to unlabeled_embeddings
        """
        if labeled_embeddings.shape[0] == 0:
            # No labeled data yet: all pool points are "infinitely far", so
            # the first argmax pick is effectively arbitrary. Match the pool
            # tensor's device/dtype so torch.minimum below does not fail on
            # GPU inputs (bug fix: this tensor was previously always created
            # on the CPU with float32 dtype).
            min_dist = torch.full(
                (unlabeled_embeddings.shape[0],),
                float("inf"),
                device=unlabeled_embeddings.device,
                dtype=unlabeled_embeddings.dtype,
            )
        else:
            # Distance of each pool point to the centroid of the labeled set.
            centroid_embedding = torch.mean(labeled_embeddings, dim=0).unsqueeze(0)
            dist_ctr = torch.cdist(unlabeled_embeddings, centroid_embedding, p=2)
            min_dist = torch.min(dist_ctr, dim=1)[0]
        idxs = []
        for i in range(n):
            # Greedily take the pool point furthest from everything selected
            # so far (its own distance drops to 0 after selection, so it will
            # not be re-picked while other points remain at positive distance).
            idx = torch.argmax(min_dist)
            idxs.append(idx.item())
            # Refresh each point's distance to its nearest selected point.
            dist_new_ctr = torch.cdist(
                unlabeled_embeddings, unlabeled_embeddings[[idx], :], p=2
            )
            min_dist = torch.minimum(min_dist, dist_new_ctr[:, 0])
        return idxs
34 changes: 34 additions & 0 deletions astra/torch/al/acquisitions/furthest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import torch
from astra.torch.al.acquisitions.base import DiversityAcquisition


class Furthest(DiversityAcquisition):
    """Diversity acquisition via greedy furthest-first traversal (core-set style).

    The first pick is the pool point furthest from its nearest labeled point;
    each subsequent pick is the pool point furthest from its nearest
    already-selected (or labeled) point.
    """

    def acquire_scores(self, labeled_embeddings, unlabeled_embeddings, n):
        """Select ``n`` pool indices by greedy furthest-first traversal.

        Parameters
        ----------
        labeled_embeddings: tensor([n_train, embedding_size])
            Embedding of the train data. May be empty (``n_train == 0``).
        unlabeled_embeddings: tensor([n_pool, embedding_size])
            Embedding of the pool data
        n: int
            Number of data points to be selected
        Returns:
        ----------
        idxs: list
            List of selected data point indices with respect to unlabeled_embeddings
        """
        if labeled_embeddings.shape[0] == 0:
            # No labeled data yet: all pool points are "infinitely far", so
            # the first argmax pick is effectively arbitrary. Match the pool
            # tensor's device/dtype so torch.minimum below does not fail on
            # GPU inputs (bug fix: this tensor was previously always created
            # on the CPU with float32 dtype).
            min_dist = torch.full(
                (unlabeled_embeddings.shape[0],),
                float("inf"),
                device=unlabeled_embeddings.device,
                dtype=unlabeled_embeddings.dtype,
            )
        else:
            # Distance of each pool point to its nearest labeled point.
            dist_ctr = torch.cdist(unlabeled_embeddings, labeled_embeddings, p=2)
            min_dist = torch.min(dist_ctr, dim=1)[0]
        idxs = []
        for i in range(n):
            # Greedily take the pool point furthest from everything selected
            # so far (its own distance drops to 0 after selection, so it will
            # not be re-picked while other points remain at positive distance).
            idx = torch.argmax(min_dist)
            idxs.append(idx.item())
            # Refresh each point's distance to its nearest selected point.
            dist_new_ctr = torch.cdist(
                unlabeled_embeddings, unlabeled_embeddings[[idx], :], p=2
            )
            min_dist = torch.minimum(min_dist, dist_new_ctr[:, 0])
        return idxs
41 changes: 16 additions & 25 deletions astra/torch/al/strategies/diversity.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def query(
n_mc_samples: int = None,
batch_size: int = None,
) -> Dict[str, torch.Tensor]:
"""
"""Diversity query strategy with multiple neural networks

Args:
net: A neural network to extract features.
pool_indices: The indices of the pool set.
Expand All @@ -50,18 +51,17 @@ def query(
Returns:
best_indices: A dictionary of acquisition names and the corresponding best indices.
"""
assert isinstance(pool_indices, torch.Tensor), f"pool_indices must be a torch.Tensor, got {type(pool_indices)}"
assert isinstance(
pool_indices, torch.Tensor
), f"pool_indices must be a torch.Tensor, got {type(pool_indices)}"
assert isinstance(
context_indices, torch.Tensor
), f"context_indices must be a torch.Tensor, got {type(context_indices)}"

if batch_size is None:
batch_size = len(pool_indices)

data_loader = DataLoader(self.dataset)

# put model on eval mode
net.eval()
data_loader = DataLoader(self.dataset, batch_size=batch_size)

with torch.no_grad():
# Get all features
Expand All @@ -70,26 +70,17 @@ def query(
features = net(x)
features_list.append(features)
features = torch.cat(features_list, dim=0) # (data_dim, feature_dim)
# Get the features for the pool
pool_features = features[pool_indices] # (pool_dim, feature_dim)

best_indices = {}

if not isinstance(pool_indices, list):
# tolist() works for both numpy and torch tensors. It also work for tensors on GPU.
pool_indices = pool_indices.tolist()
if not isinstance(context_indices, list):
context_indices = context_indices.tolist()
# Get the features for the context
context_features = features[context_indices] # (context_dim, feature_dim)

best_indices = {}
for acq_name, acquisition in self.acquisitions.items():
selected_indices = []
# TODO: We can make this loop faster by computing scores only for updated indices. There can be a method in acquisition to update the scores.
for _ in range(n_query_samples):
scores = acquisition.acquire_scores(features, pool_indices, context_indices)
index = torch.argmax(scores)
selected_index = pool_indices[index]
selected_indices.append(selected_index)
pool_indices = torch.cat([pool_indices[:index], pool_indices[index + 1 :]])
context_indices = torch.cat([context_indices, selected_index])
selected_indices = torch.tensor(selected_indices, device=self.device)
best_indices[acq_name] = selected_indices

selected_indices = acquisition.acquire_scores(
context_features, pool_features, n_query_samples
)
selected_indices = torch.tensor(selected_indices)
best_indices[acq_name] = pool_indices[selected_indices]
return best_indices
1 change: 1 addition & 0 deletions notebooks/al/accuracy_AL_list.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"train_1000_pool_query_100_iter_50_AL_seed_0": [{"train": 1.0, "pool": 0.359743595123291, "test": 0.36239999532699585}, {"train": 1.0, "pool": 0.3673521876335144, "test": 0.36344999074935913}, {"train": 1.0, "pool": 0.38229382038116455, "test": 0.38119998574256897}, {"train": 0.9961538314819336, "pool": 0.3655813932418823, "test": 0.36739999055862427}, {"train": 1.0, "pool": 0.37953367829322815, "test": 0.3763499855995178}, {"train": 1.0, "pool": 0.3845194876194, "test": 0.3882499933242798}, {"train": 1.0, "pool": 0.3876822888851166, "test": 0.392549991607666}, {"train": 0.9999999403953552, "pool": 0.39339426159858704, "test": 0.38964998722076416}, {"train": 0.9816666841506958, "pool": 0.38900521397590637, "test": 0.3919000029563904}, {"train": 0.9999999403953552, "pool": 0.39257216453552246, "test": 0.39739999175071716}, {"train": 1.0, "pool": 0.3982894718647003, "test": 0.3962499797344208}, {"train": 0.9995238184928894, "pool": 0.38957783579826355, "test": 0.3977999985218048}, {"train": 0.9995454549789429, "pool": 0.40761905908584595, "test": 0.4105999767780304}, {"train": 0.994347870349884, "pool": 0.39708220958709717, "test": 0.40004998445510864}, {"train": 1.0, "pool": 0.4099467992782593, "test": 0.41589999198913574}, {"train": 0.9967999458312988, "pool": 0.4129333198070526, "test": 0.41574999690055847}, {"train": 0.9961538314819336, "pool": 0.4021390378475189, "test": 0.412200003862381}, {"train": 0.9937037229537964, "pool": 0.389249324798584, "test": 0.4004499912261963}, {"train": 1.0, "pool": 0.4118548333644867, "test": 0.4205999970436096}, {"train": 0.9993103742599487, "pool": 0.409191370010376, "test": 0.40904998779296875}, {"train": 0.9996666312217712, "pool": 0.41572973132133484, "test": 0.41574999690055847}, {"train": 0.9935484528541565, "pool": 0.41850948333740234, "test": 0.4250499904155731}, {"train": 0.9899999499320984, "pool": 0.4024728238582611, "test": 0.41304999589920044}, {"train": 0.9984848499298096, "pool": 0.43013623356819153, "test": 
0.42944997549057007}, {"train": 0.9955881834030151, "pool": 0.4202459156513214, "test": 0.42890000343322754}, {"train": 0.9971428513526917, "pool": 0.42136985063552856, "test": 0.428849995136261}, {"train": 0.9888889193534851, "pool": 0.42178571224212646, "test": 0.4289499819278717}, {"train": 1.0, "pool": 0.44046831130981445, "test": 0.44699999690055847}, {"train": 0.9855262637138367, "pool": 0.4252209961414337, "test": 0.42969998717308044}, {"train": 0.9971794486045837, "pool": 0.42786702513694763, "test": 0.4344500005245209}, {"train": 0.9882500171661377, "pool": 0.41972222924232483, "test": 0.42674997448921204}, {"train": 0.9992682933807373, "pool": 0.433426171541214, "test": 0.43764999508857727}, {"train": 0.9933333396911621, "pool": 0.43706706166267395, "test": 0.4499000012874603}, {"train": 0.9986046552658081, "pool": 0.4481792747974396, "test": 0.45559999346733093}, {"train": 0.9931818246841431, "pool": 0.437303364276886, "test": 0.4438999891281128}, {"train": 1.0, "pool": 0.45529577136039734, "test": 0.45934998989105225}, {"train": 0.973695695400238, "pool": 0.4378530979156494, "test": 0.4435499906539917}, {"train": 0.9987233877182007, "pool": 0.44028326869010925, "test": 0.4486999809741974}, {"train": 0.9927083849906921, "pool": 0.4510795474052429, "test": 0.4575499892234802}, {"train": 0.9991836547851562, "pool": 0.4533333480358124, "test": 0.46414998173713684}, {"train": 0.9979999661445618, "pool": 0.45160001516342163, "test": 0.45739999413490295}, {"train": 0.998235285282135, "pool": 0.4595129191875458, "test": 0.47384998202323914}, {"train": 0.989807665348053, "pool": 0.45428162813186646, "test": 0.4681999981403351}, {"train": 0.994717001914978, "pool": 0.4530547559261322, "test": 0.4634999930858612}, {"train": 0.995555579662323, "pool": 0.46179190278053284, "test": 0.4684999883174896}, {"train": 0.9985454082489014, "pool": 0.4749855101108551, "test": 0.4799499809741974}, {"train": 0.9901785850524902, "pool": 0.45107558369636536, "test": 
0.4607999920845032}, {"train": 0.9889473915100098, "pool": 0.46157434582710266, "test": 0.4737499952316284}, {"train": 0.9979310631752014, "pool": 0.4727485179901123, "test": 0.48429998755455017}, {"train": 0.9971186518669128, "pool": 0.48741933703422546, "test": 0.4948499798774719}, {"train": 0.9871666431427002, "pool": 0.4876176714897156, "test": 0.4944499731063843}], "train_1000_pool_query_100_iter_50_AL_seed_1": [{"train": 1.0, "pool": 0.3570256531238556, "test": 0.35679998993873596}, {"train": 0.9981818199157715, "pool": 0.35516709089279175, "test": 0.3454499840736389}, {"train": 0.9991666674613953, "pool": 0.3752061724662781, "test": 0.37974998354911804}, {"train": 0.9915384650230408, "pool": 0.3546253442764282, "test": 0.3507999777793884}, {"train": 1.0, "pool": 0.37826424837112427, "test": 0.37884998321533203}, {"train": 1.0, "pool": 0.38332468271255493, "test": 0.37814998626708984}, {"train": 0.9918749928474426, "pool": 0.3650781214237213, "test": 0.3571999967098236}, {"train": 0.9994117021560669, "pool": 0.3855091333389282, "test": 0.38609999418258667}, {"train": 0.9711111187934875, "pool": 0.37112563848495483, "test": 0.3728500008583069}, {"train": 0.9915788769721985, "pool": 0.3621784746646881, "test": 0.36774998903274536}, {"train": 0.955500066280365, "pool": 0.3879210352897644, "test": 0.3815999925136566}, {"train": 0.9904761910438538, "pool": 0.37335091829299927, "test": 0.37070000171661377}, {"train": 0.9981818199157715, "pool": 0.40317460894584656, "test": 0.4027499854564667}, {"train": 0.9947826266288757, "pool": 0.3878779709339142, "test": 0.3884499967098236}, {"train": 0.9987500309944153, "pool": 0.4042021334171295, "test": 0.41110000014305115}, {"train": 0.9991999864578247, "pool": 0.4037066698074341, "test": 0.399649977684021}, {"train": 1.0, "pool": 0.42195186018943787, "test": 0.41714999079704285}, {"train": 0.9870370626449585, "pool": 0.4088471829891205, "test": 0.4018999934196472}, {"train": 0.9989285469055176, "pool": 0.39505377411842346, 
"test": 0.3941499888896942}, {"train": 1.0, "pool": 0.40404313802719116, "test": 0.40264999866485596}, {"train": 1.0, "pool": 0.4173513352870941, "test": 0.4183499813079834}, {"train": 1.0, "pool": 0.42173442244529724, "test": 0.41964998841285706}, {"train": 0.99281245470047, "pool": 0.38383153080940247, "test": 0.3866499960422516}, {"train": 1.0, "pool": 0.4314441382884979, "test": 0.4320499897003174}, {"train": 0.9997058510780334, "pool": 0.42314207553863525, "test": 0.42170000076293945}, {"train": 0.9997142553329468, "pool": 0.43871232867240906, "test": 0.44224998354911804}, {"train": 1.0, "pool": 0.4419780373573303, "test": 0.4456000030040741}, {"train": 1.0, "pool": 0.43920108675956726, "test": 0.4417499899864197}, {"train": 0.9923684000968933, "pool": 0.42563536763191223, "test": 0.4275999963283539}, {"train": 0.9984614849090576, "pool": 0.4365096986293793, "test": 0.43469998240470886}, {"train": 0.9807500243186951, "pool": 0.421277791261673, "test": 0.42229998111724854}, {"train": 0.988048791885376, "pool": 0.4245125353336334, "test": 0.431799978017807}, {"train": 0.9952380657196045, "pool": 0.4323742985725403, "test": 0.4402499794960022}, {"train": 0.9995349049568176, "pool": 0.43577033281326294, "test": 0.44154998660087585}, {"train": 0.9900000095367432, "pool": 0.4253370761871338, "test": 0.430899977684021}, {"train": 1.0, "pool": 0.4522816836833954, "test": 0.4582499861717224}, {"train": 0.9930434823036194, "pool": 0.4439830482006073, "test": 0.4453499913215637}, {"train": 0.9978722929954529, "pool": 0.45682719349861145, "test": 0.4556500017642975}, {"train": 0.9904167056083679, "pool": 0.4419602155685425, "test": 0.4412999749183655}, {"train": 0.9861224293708801, "pool": 0.4492877423763275, "test": 0.45274999737739563}, {"train": 0.9833999872207642, "pool": 0.4604857265949249, "test": 0.4638499915599823}, {"train": 0.996274471282959, "pool": 0.446074515581131, "test": 0.4492499828338623}, {"train": 0.9986538290977478, "pool": 0.4528448283672333, "test": 
0.4557499885559082}, {"train": 0.9941509366035461, "pool": 0.44780978560447693, "test": 0.4483499825000763}, {"train": 0.9968518614768982, "pool": 0.45375722646713257, "test": 0.45524999499320984}, {"train": 0.9912726879119873, "pool": 0.4563768208026886, "test": 0.45879998803138733}, {"train": 0.9969642758369446, "pool": 0.46017444133758545, "test": 0.461249977350235}, {"train": 0.9977193474769592, "pool": 0.4575510025024414, "test": 0.46334999799728394}, {"train": 0.9898276329040527, "pool": 0.4673684239387512, "test": 0.47540000081062317}, {"train": 0.9998304843902588, "pool": 0.4774486720561981, "test": 0.48144999146461487}, {"train": 0.9906666278839111, "pool": 0.47747060656547546, "test": 0.48124998807907104}], "train_1000_pool_query_100_iter_50_AL_seed_2": [{"train": 1.0, "pool": 0.3610256314277649, "test": 0.3632499873638153}, {"train": 1.0, "pool": 0.36318764090538025, "test": 0.36134999990463257}, {"train": 0.9908333420753479, "pool": 0.3523195683956146, "test": 0.3562999963760376}, {"train": 0.9992307424545288, "pool": 0.35702842473983765, "test": 0.361299991607666}, {"train": 1.0, "pool": 0.3882383406162262, "test": 0.38670000433921814}, {"train": 1.0, "pool": 0.3858701288700104, "test": 0.3887999951839447}, {"train": 1.0, "pool": 0.3786458373069763, "test": 0.382099986076355}, {"train": 0.9994117021560669, "pool": 0.3812010586261749, "test": 0.3909499943256378}, {"train": 0.9294444918632507, "pool": 0.35374343395233154, "test": 0.35944998264312744}, {"train": 0.9999999403953552, "pool": 0.39314958453178406, "test": 0.3980499804019928}, {"train": 1.0, "pool": 0.3911578953266144, "test": 0.39729997515678406}, {"train": 0.9904761910438538, "pool": 0.37799471616744995, "test": 0.3786499798297882}, {"train": 0.9972727298736572, "pool": 0.3903968334197998, "test": 0.39569997787475586}, {"train": 1.0, "pool": 0.4017506539821625, "test": 0.413100004196167}, {"train": 0.9979166984558105, "pool": 0.3963829576969147, "test": 0.40779998898506165}, {"train": 
0.9995999932289124, "pool": 0.401226669549942, "test": 0.40709999203681946}, {"train": 0.998846173286438, "pool": 0.4005882143974304, "test": 0.4090999960899353}, {"train": 0.9985185265541077, "pool": 0.4027882218360901, "test": 0.4125500023365021}, {"train": 0.9996428489685059, "pool": 0.4031451642513275, "test": 0.41189998388290405}, {"train": 1.0, "pool": 0.4231536388397217, "test": 0.4311999976634979}, {"train": 0.9889999628067017, "pool": 0.4052162170410156, "test": 0.4063499867916107}, {"train": 0.9677419662475586, "pool": 0.3903523087501526, "test": 0.400549978017807}, {"train": 1.0, "pool": 0.42682066559791565, "test": 0.4373999834060669}, {"train": 0.9796969890594482, "pool": 0.4097275137901306, "test": 0.41484999656677246}, {"train": 0.99294114112854, "pool": 0.415245920419693, "test": 0.4190499782562256}, {"train": 0.9851428270339966, "pool": 0.41558903455734253, "test": 0.4227999746799469}, {"train": 0.9975000023841858, "pool": 0.42230769991874695, "test": 0.43629997968673706}, {"train": 0.9921621680259705, "pool": 0.427465558052063, "test": 0.43514999747276306}, {"train": 0.9978947043418884, "pool": 0.42154696583747864, "test": 0.43324998021125793}, {"train": 0.9999999403953552, "pool": 0.44105264544487, "test": 0.4493999779224396}, {"train": 0.9965000748634338, "pool": 0.42738890647888184, "test": 0.4347499907016754}, {"train": 0.9629268050193787, "pool": 0.42345401644706726, "test": 0.44235000014305115}, {"train": 0.9950000047683716, "pool": 0.4265642464160919, "test": 0.4380999803543091}, {"train": 0.9965116381645203, "pool": 0.4352381229400635, "test": 0.446399986743927}, {"train": 0.9988636374473572, "pool": 0.44651684165000916, "test": 0.4585999846458435}, {"train": 0.9948889017105103, "pool": 0.4376901388168335, "test": 0.44589999318122864}, {"train": 0.980434775352478, "pool": 0.43206214904785156, "test": 0.4382999837398529}, {"train": 0.9902127385139465, "pool": 0.43067988753318787, "test": 0.44259998202323914}, {"train": 0.9945833683013916, 
"pool": 0.43667614459991455, "test": 0.45034998655319214}, {"train": 0.9881632924079895, "pool": 0.4238176643848419, "test": 0.43209999799728394}, {"train": 0.9959999918937683, "pool": 0.44374287128448486, "test": 0.45409998297691345}, {"train": 0.9945098161697388, "pool": 0.4448997378349304, "test": 0.4566499888896942}, {"train": 0.9830769300460815, "pool": 0.42916667461395264, "test": 0.4426499903202057}, {"train": 0.9833962321281433, "pool": 0.4487896263599396, "test": 0.4575999975204468}, {"train": 0.9927777647972107, "pool": 0.45554912090301514, "test": 0.4675999879837036}, {"train": 0.986727237701416, "pool": 0.44171014428138733, "test": 0.46164998412132263}, {"train": 0.9887499809265137, "pool": 0.4571221172809601, "test": 0.46379998326301575}, {"train": 0.9954386353492737, "pool": 0.4502623677253723, "test": 0.46239998936653137}, {"train": 0.9960345029830933, "pool": 0.4535380005836487, "test": 0.46619999408721924}, {"train": 1.0, "pool": 0.47096773982048035, "test": 0.4897499978542328}, {"train": 0.9889999628067017, "pool": 0.47044119238853455, "test": 0.48864999413490295}]}
1 change: 1 addition & 0 deletions notebooks/al/accuracy_AL_list_diversity.json

Large diffs are not rendered by default.

Loading