-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
239 lines (189 loc) · 7.03 KB
/
utils.py
File metadata and controls
239 lines (189 loc) · 7.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
"""Shared utilities for adversarial attack and text generation models."""
import base64
import hashlib
import json
import os
import random
from typing import Any, Dict, List, Union
import yaml
from omegaconf import OmegaConf
import wandb
from config_schema import MainConfig
def load_api_keys() -> Dict[str, List[str]]:
"""Load API keys from the api_keys file.
Returns:
Dict[str, List[str]]: Dictionary containing API keys for different models
Raises:
FileNotFoundError: If no api_keys file is found
"""
for ext in ["yaml", "yml", "json"]:
file_path = f"api_keys.{ext}"
if os.path.exists(file_path):
with open(file_path, "r") as f:
if ext in ["yaml", "yml"]:
return yaml.safe_load(f)
else:
return json.load(f)
raise FileNotFoundError(
"API keys file not found. Please create api_keys.yaml, api_keys.yml, or api_keys.json "
"in the root directory with your API keys."
)
def get_api_keys(model_name: str) -> List[str]:
"""Get all API keys for specified model.
Args:
model_name: Name of the model to get API keys for
Returns:
List[str]: List of all API keys for the specified model
Raises:
KeyError: If API keys for model are not found
"""
api_keys = load_api_keys()
if model_name not in api_keys:
raise KeyError(
f"API key for {model_name} not found in api_keys file. "
f"Available models: {list(api_keys.keys())}"
)
keys_list = api_keys[model_name]
if not keys_list:
raise ValueError(f"No API keys available for {model_name}")
return keys_list
def get_api_key_count(model_name: str) -> int:
"""Get the number of available API keys for a model.
Args:
model_name: Name of the model
Returns:
int: Number of available API keys
Raises:
KeyError: If API keys for model are not found
"""
return len(get_api_keys(model_name))
def hash_training_config(cfg: MainConfig) -> str:
"""Create a deterministic hash of training-relevant config parameters.
Args:
cfg: Configuration object containing model settings
Returns:
str: MD5 hash of the config parameters
"""
# sepcial case, use only in the rebuttal
if hasattr(cfg, "ata_noise"):
noise_val = cfg.ata_noise
return f'ata_noise_{noise_val}'
# Convert backbone list to plain Python list
if isinstance(cfg.model.backbone, (list, tuple)):
backbone = list(cfg.model.backbone)
else:
backbone = OmegaConf.to_container(cfg.model.backbone)
# Create config dict with converted values
train_config = {
"data": {
"batch_size": int(cfg.data.batch_size),
"num_samples": int(cfg.data.num_samples),
"cle_data_path": str(cfg.data.cle_data_path),
"tgt_data_path": str(cfg.data.tgt_data_path),
},
"optim": {
"alpha": float(cfg.optim.alpha),
"epsilon": int(cfg.optim.epsilon),
"steps": int(cfg.optim.steps),
"optimizer": str(cfg.optim.optimizer),
"momentum": float(cfg.optim.momentum),
"momentum_decay": float(cfg.optim.momentum_decay),
"align": str(cfg.optim.align),
"tm_idx": list(cfg.optim.tm_idx),
"beta": float(cfg.optim.beta),
"use_retrieval": bool(cfg.optim.use_retrieval),
},
"model": {
"input_res": int(cfg.model.input_res),
"crop_scale_1": tuple(float(x) for x in cfg.model.crop_scale_1),
"crop_scale_2": tuple(float(x) for x in cfg.model.crop_scale_2),
"crop_scale_3": tuple(float(x) for x in cfg.model.crop_scale_3),
"changer_iters": list(cfg.model.changer_iters),
"ensemble": bool(cfg.model.ensemble),
"backbone": backbone,
"target_num": int(cfg.model.target_num),
"target_crop": bool(cfg.model.target_crop),
},
"attack": cfg.attack,
}
# Conditionally add multi_pass_num if attack is multi-pass
if cfg.attack == "pgd_multi_pass": # Add other multi-pass attack types if needed
train_config["optim"]["multi_pass_num"] = int(cfg.optim.multi_pass_num)
# Convert to JSON string with sorted keys
json_str = json.dumps(train_config, sort_keys=True)
return hashlib.md5(json_str.encode()).hexdigest()
def setup_wandb(cfg: MainConfig, tags=None, name: str = None) -> None:
"""Initialize Weights & Biases logging.
Args:
cfg: Configuration object containing wandb settings
tags: Optional list of tags for the run.
name: Optional name for the run.
"""
config_dict = OmegaConf.to_container(cfg, resolve=True)
wandb.init(
project=cfg.wandb.project,
entity=cfg.wandb.entity,
config=config_dict,
tags=tags,
name=name,
reinit=True,
)
def encode_image(image_path: str) -> str:
"""Encode image file to base64 string.
Args:
image_path: Path to image file
Returns:
str: Base64 encoded image string
"""
with open(image_path, "rb") as image_file:
return base64.standard_b64encode(image_file.read()).decode("utf-8")
def ensure_dir(path: str) -> None:
"""Ensure directory exists, create if it doesn't.
Args:
path: Directory path to ensure exists
"""
os.makedirs(path, exist_ok=True)
def get_output_paths(cfg: MainConfig, config_hash: str) -> Dict[str, str]:
"""Get dictionary of output paths based on config.
Args:
cfg: Configuration object
config_hash: Hash of training config
Returns:
Dict[str, str]: Dictionary containing output paths
"""
return {
"output_dir": os.path.join(cfg.data.output, "img", config_hash),
"desc_output_dir": os.path.join(cfg.data.output, "description", config_hash),
}
# Create batches from a list
def create_batches(items, batch_size):
"""Split a list into batches of specified size.
Args:
items: List of items to split
batch_size: Size of each batch
Returns:
List of batches
"""
return [items[i : i + batch_size] for i in range(0, len(items), batch_size)]
def log_metrics(pbar, metrics, img_index, epoch=None, log_wandb=True):
"""
Log metrics to progress bar and wandb.
Args:
pbar: tqdm progress bar to update
metrics: Dictionary of metrics to log
img_index: Index of the image (for wandb logging)
epoch: Optional epoch number for logging
log_wandb: Whether to log to wandb
"""
# Format metrics for progress bar
pbar_metrics = {
k: f"{v:.5f}" if "sim" in k else f"{v:.3f}" for k, v in metrics.items()
}
pbar.set_postfix(pbar_metrics)
# Prepare wandb metrics with image index
wandb_metrics = {f"img{img_index}_{k}": v for k, v in metrics.items()}
if epoch is not None:
wandb_metrics["epoch"] = epoch
# Log to wandb
if log_wandb:
wandb.log(wandb_metrics)