forked from tfwu/ordered-topk-attack-release
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathsample_subset_of_pred.py
More file actions
127 lines (106 loc) · 3.79 KB
/
sample_subset_of_pred.py
File metadata and controls
127 lines (106 loc) · 3.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import argparse
import os
import json
import pickle
import random
from collections import defaultdict
def load_results(path):
    """Load a per-model results file, choosing the decoder by file extension.

    Args:
        path: Path (str or os.PathLike) to a ``.pkl`` or ``.json`` file.

    Returns:
        The deserialized object. Downstream, ``get_filenames`` iterates it as a
        sequence of dicts carrying ``"filename"``, ``"gt"`` and ``"pred"`` keys.

    Raises:
        ValueError: If the extension is neither ``.pkl`` nor ``.json``.
    """
    # Accept pathlib.Path and other PathLike objects, not just str.
    path_str = os.fspath(path)
    if path_str.endswith(".pkl"):
        # SECURITY: pickle.load can execute arbitrary code; only load .pkl
        # files produced by a trusted source.
        with open(path_str, "rb") as f:
            return pickle.load(f)
    if path_str.endswith(".json"):
        # Explicit UTF-8 (the JSON standard encoding) so decoding does not
        # depend on the platform's locale default.
        with open(path_str, "r", encoding="utf-8") as f:
            return json.load(f)
    raise ValueError(f"Unsupported file format: {path_str}")
def get_filenames(results, use_correct_pred: bool = True):
    """Index result entries by filename, keeping either hits or misses.

    Args:
        results: Iterable of dicts, each with "filename", "gt" and "pred" keys.
        use_correct_pred: When truthy, keep entries whose prediction equals the
            ground truth; otherwise keep entries where the two differ.

    Returns:
        Dict mapping filename -> entry. If a filename occurs more than once
        among the kept entries, the last one wins.
    """
    # Filter first, then index the survivors by filename.
    if use_correct_pred:
        kept = [row for row in results if row["gt"] == row["pred"]]
    else:
        kept = [row for row in results if row["gt"] != row["pred"]]
    return {row["filename"]: row for row in kept}
def intersect_filenames(models, use_correct_pred: bool = True):
    """Find filenames that every model classifies correctly (or incorrectly).

    Args:
        models: Iterable of per-model results, each accepted by get_filenames.
        use_correct_pred: Forwarded to get_filenames; True selects correctly
            predicted images, False selects incorrectly predicted ones.

    Returns:
        ``(common_filenames, reference_dict)``: the set of filenames selected
        for every model, and the first model's filename -> entry mapping, which
        contains every common filename and is used as the lookup for them.

    Raises:
        ValueError: If ``models`` is empty.
    """
    filenames_per_model = [get_filenames(m, use_correct_pred) for m in models]
    if not filenames_per_model:
        # Fail with a clear message instead of an opaque TypeError/IndexError.
        raise ValueError("intersect_filenames() requires at least one model")
    reference_dict = filenames_per_model[0]  # any model works as reference
    # Iterating a dict yields its keys, so the dicts can be intersected directly.
    common_filenames = set(reference_dict).intersection(*filenames_per_model[1:])
    return common_filenames, reference_dict
def sample_topk_per_class(common_filenames, reference_dict, k, seed):
    """Randomly keep at most ``k`` filenames per ground-truth class.

    Args:
        common_filenames: Iterable of filenames to sample from (typically a set).
        reference_dict: Mapping filename -> entry; ``entry["gt"]`` is the class
            label used for grouping.
        k: Maximum number of filenames kept per class. Classes with at most
            ``k`` filenames are kept in full.
        seed: Seed for a private ``random.Random`` instance, so the result is
            reproducible for the same inputs.

    Returns:
        List of sampled filenames, grouped class by class.

    Raises:
        ValueError: If ``k`` is negative.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")

    # Group filenames by gt class. Iterate in sorted order: the iteration order
    # of a set of str depends on the per-process hash seed (PYTHONHASHSEED).
    # That order would otherwise fix both each class's list order and the order
    # in which classes consume the RNG, and so which files rng.sample() picks.
    # Without sorting, a fixed seed would not reproduce the same subset.
    class_to_fnames = defaultdict(list)
    for fname in sorted(common_filenames):
        class_to_fnames[reference_dict[fname]["gt"]].append(fname)

    # Sample up to k per class
    rng = random.Random(seed)
    sampled = []
    for fnames in class_to_fnames.values():
        if len(fnames) <= k:
            sampled.extend(fnames)
        else:
            sampled.extend(rng.sample(fnames, k))
    return sampled
def _parse_args():
    """Build the command-line parser and return ``(parser, parsed_args)``."""
    parser = argparse.ArgumentParser(
        description=(
            "Find images that all given models classify correctly (or, with "
            "--use-incorrect-pred, incorrectly) and sample up to k of them "
            "per ground-truth class."
        )
    )
    parser.add_argument(
        "--models",
        nargs="+",
        required=True,
        help="model names (e.g., resnet50)",
    )
    parser.add_argument(
        "--results-folder",
        type=str,
        required=True,
        help="output folder where results are saved",
    )
    parser.add_argument(
        "--results-tag", default="3-224-224-bicubic-1.0", type=str, help="results tag"
    )
    parser.add_argument(
        "--k", type=int, default=1, help="Number of samples per class (default: 1)"
    )
    parser.add_argument(
        "--use-incorrect-pred",
        action="store_true",
        help="Use images with incorrect pred, so we can change top-1 attack target to the gt",
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Random seed for shuffling (default: 42)"
    )
    return parser, parser.parse_args()


def main():
    """Command-line entry point: sample common images across models.

    Loads ``<results-folder>/<model>-<results-tag>.pkl`` for every model. It
    keeps the images that all models classify correctly (or, with
    ``--use-incorrect-pred``, incorrectly) and samples at most ``--k`` of them
    per ground-truth class using ``--seed``. The sorted filenames are written
    to ``<results-folder>/<seed>-<k>-<use_correct_pred>-<models>-<tag>.json``.
    Returns early if that output file already exists.
    """
    parser, args = _parse_args()
    if args.k < 0:
        parser.error(f"--k must be non-negative, got {args.k}")
    use_correct_pred = not args.use_incorrect_pred

    models_tag = "+".join(args.models)
    save_filename = os.path.join(
        args.results_folder,
        f"{args.seed}-{args.k}-{use_correct_pred}-{models_tag}-{args.results_tag}.json",
    )
    if os.path.exists(save_filename):
        print(f"[INFO] |-- This job is done already, check {save_filename}")
        return

    result_files = [
        os.path.join(args.results_folder, f"{model}-{args.results_tag}.pkl")
        for model in args.models
    ]
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    missing = [f for f in result_files if not os.path.exists(f)]
    if missing:
        parser.error("Not found " + ", ".join(missing))

    model_results = [load_results(path) for path in result_files]
    common_filenames, reference_dict = intersect_filenames(
        model_results, use_correct_pred
    )
    print(
        f"[INFO] Found {len(common_filenames)} common predicted images across all models."
    )

    sampled_filenames = sample_topk_per_class(
        common_filenames, reference_dict, args.k, args.seed
    )
    print(f"[INFO] Sampled {len(sampled_filenames)} images ({args.k} per class max).")

    # Write to a temp file, then rename it into place. The "already done" check
    # above keys off save_filename, so an interrupted write must not leave a
    # truncated file there that would block every rerun.
    tmp_filename = f"{save_filename}.tmp"
    with open(tmp_filename, "w", encoding="utf-8") as f:
        json.dump(sorted(sampled_filenames), f, indent=2)
    os.replace(tmp_filename, save_filename)
    print(f"[INFO] Saved sampled filenames to: {save_filename}")
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()