-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathcompcon_utils.py
More file actions
198 lines (155 loc) · 8.07 KB
/
compcon_utils.py
File metadata and controls
198 lines (155 loc) · 8.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import functools
import random
import re
import time
from typing import Dict, List

import matplotlib.pyplot as plt
import numpy as np
import torch
import wandb
import weave
def timeit(func):
    """Decorator that records *func*'s runtime in the active W&B run summary.

    After every successful call, the elapsed time in seconds is stored under
    ``wandb.summary[f"{func.__name__}_execution_time"]``. Each call overwrites
    the previous call's value. If *func* raises, nothing is recorded and the
    exception propagates unchanged.

    The wrapper keeps *func*'s metadata (``__name__``, ``__doc__``,
    ``__wrapped__``) via ``functools.wraps``.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic and high-resolution, so the measurement
        # is not skewed by system clock adjustments the way time.time() is.
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        execution_time = time.perf_counter() - start_time
        wandb.summary[f"{func.__name__}_execution_time"] = execution_time
        return result

    return wrapper
# Module-wide torch device: the CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
@timeit
def get_cosine_similarity_batch(
    batch_embeddings: np.ndarray,
    single_embedding: np.ndarray,
) -> np.ndarray:
    """Return the cosine similarity between each row of a batch and one vector.

    Args:
        batch_embeddings: Array of shape (N, D). Any extra leading dimensions
            are flattened into N. A ``torch.Tensor`` is also accepted; it is
            detached and copied to CPU first.
        single_embedding: Vector of D elements, flattened if it has more
            dimensions. A NumPy array or a ``torch.Tensor``.

    Returns:
        ``float32`` array of shape (N,). No epsilon is added to the
        denominator, so a zero-norm row or query yields ``nan``.
    """

    def as_numpy(value):
        # Tensors may track gradients or live on the GPU; NumPy needs a plain CPU array.
        if torch.is_tensor(value):
            return value.detach().cpu().numpy()
        return np.asarray(value)

    matrix = as_numpy(batch_embeddings)
    matrix = matrix.reshape(-1, matrix.shape[-1])
    query = as_numpy(single_embedding).reshape(-1)

    row_norms = np.linalg.norm(matrix, axis=1)
    query_norm = np.linalg.norm(query)
    dots = matrix @ query
    return (dots / (row_norms * query_norm)).astype(np.float32)
def get_cosine_similarity_batch2(embeddings1, embeddings2):
    """Compute the cosine similarity between one embedding and each row of a batch, using torch.

    Args:
        embeddings1: The single query embedding. It is flattened to 1-D, so it
            must hold D values. Anything ``torch.as_tensor`` accepts works
            (list, NumPy array, tensor of any dtype or device).
        embeddings2: The batch, shape (N, D). A single 1-D embedding of
            length D is treated as a batch of one.

    Returns:
        NumPy ``float32`` array of shape (N,).
    """
    # torch.as_tensor (unlike the legacy torch.Tensor constructor) accepts
    # tensors of any dtype/device and never interprets a bare int as a size.
    # Both inputs are moved to the module-level ``device`` as float32.
    query = torch.as_tensor(embeddings1, dtype=torch.float32, device=device).reshape(-1)
    batch = torch.as_tensor(embeddings2, dtype=torch.float32, device=device)
    if batch.dim() == 1:
        # A lone embedding would make normalize(dim=1) fail; treat it as (1, D).
        batch = batch.unsqueeze(0)

    # no_grad: inputs that require grad would otherwise make .numpy() raise.
    with torch.no_grad():
        query_norm = torch.nn.functional.normalize(query, p=2, dim=0)
        batch_norm = torch.nn.functional.normalize(batch, p=2, dim=1)
        # (1, D) @ (D, N) -> (1, N) -> (N,)
        cos_sims = torch.mm(query_norm.unsqueeze(0), batch_norm.T).squeeze(0)
    return cos_sims.cpu().numpy()
@timeit
def plot_prompts(prompts: List[Dict], attribute: str, num_prompts: int, caption: str):
    """Log to W&B a grid of generated images for a random sample of prompts.

    Each row shows one sampled prompt and each column one model's image for
    that prompt. Each image is titled with the model name and its cosine
    similarity. The prompt text is shown as the y-label of the row's first
    image.

    Args:
        prompts: Records with keys ``"prompt"``, ``"models"``, ``"paths"`` and
            ``"cos_sims"``. ``paths[j]`` and ``cos_sims[j]`` belong to
            ``models[j]``. The grid width is taken from ``prompts[0]``, so
            records are assumed to list the same number of models (TODO
            confirm with callers).
        attribute: Attribute name, used in the W&B log key.
        num_prompts: Maximum number of prompts to sample. It is capped at
            ``len(prompts)``.
        caption: Prefix for the W&B log key.
    """
    num_prompts = min(num_prompts, len(prompts))
    if num_prompts <= 0 or not prompts[0]["models"]:
        # Nothing to draw: plt.subplots cannot build a grid with zero rows or columns.
        return

    prompts_to_plot = random.sample(prompts, num_prompts)
    num_models = len(prompts[0]["models"])
    # squeeze=False always returns a 2-D array of Axes, so axs[i][j] works for
    # every grid shape, including a single row, a single column, or 1x1.
    fig, axs = plt.subplots(num_prompts, num_models, figsize=(20, 20), squeeze=False)
    try:
        for i, prompt in enumerate(prompts_to_plot):
            for j, (model, path) in enumerate(zip(prompt["models"], prompt["paths"])):
                ax = axs[i][j]
                ax.imshow(plt.imread(path))
                ax.set_title(f"Model: {model}\nCos Sim: {prompt['cos_sims'][j]:.3f}")
                # Hide ticks and frame rather than calling axis("off"):
                # axis("off") also hides axis labels, which would swallow the
                # prompt y-label set below.
                ax.set_xticks([])
                ax.set_yticks([])
                for spine in ax.spines.values():
                    spine.set_visible(False)
            axs[i][0].set_ylabel(f"Prompt: {prompt['prompt']}...")
        plt.tight_layout()
        wandb.log({f"{caption} Prompts - {attribute}": wandb.Image(fig)})
    finally:
        # Release the figure even if an image fails to load or logging fails.
        plt.close(fig)
# def parse_string(input_string):
# # Split the input string into thought process, description, and prompts
# parts = input_string.split("Description:")
# thought_process = parts[0].replace("Thought Process:", "").strip()
# remaining = parts[1] if len(parts) > 1 else ""
# description_and_prompts = remaining.split("New Prompts:")
# description = description_and_prompts[0].strip()
# prompts_text = description_and_prompts[1] if len(description_and_prompts) > 1 else ""
# prompts = re.findall(r"\d+\.\s*(.*?)(?=\n\d+\.|\Z)", prompts_text, re.DOTALL)
# # if any of the prompts contain a new line and there is not a new prompt number, remove the new line and any following text
# prompts = [re.sub(r"\n.*", "", prompt) for prompt in prompts]
# prompts = [prompt.strip() for prompt in prompts]
# return thought_process, description, prompts
def parse_string(input_string):
    """Split an LLM response into its labelled sections.

    The expected layout is::

        Thought Process: ...
        Description: ...
        Key Concepts: [...]
        New Prompts:
        1. ...
        2. ...

    Each section runs until the next marker (markers are matched with
    ``str.split``). A missing section yields an empty string or empty list.
    Square brackets around the key concepts are stripped. For each numbered
    prompt, only its first line is kept, with surrounding whitespace removed.

    Returns:
        Tuple ``(thought_process, description, key_concepts, prompts)``.
    """
    head, *after_description = input_string.split("Description:")
    thought_process = head.replace("Thought Process:", "").strip()
    description_section = after_description[0] if after_description else ""

    description_part, *after_concepts = description_section.split("Key Concepts:")
    description = description_part.strip()
    concepts_section = after_concepts[0] if after_concepts else ""

    concepts_part, *after_prompts = concepts_section.split("New Prompts:")
    key_concepts = concepts_part.strip().strip("[]")
    prompts_text = after_prompts[0] if after_prompts else ""

    # Capture each "N. text" item up to the next numbered line or end of text.
    numbered_items = re.findall(r"\d+\.\s*(.*?)(?=\n\d+\.|\Z)", prompts_text, re.DOTALL)
    # Drop any continuation lines after the first, then trim whitespace.
    prompts = [re.sub(r"\n.*", "", item).strip() for item in numbered_items]
    return thought_process, description, key_concepts, prompts
def parse_string_benchmark(input_string):
    """Split an LLM response into thought process, description and key concepts.

    Like ``parse_string`` but without a "New Prompts:" section: everything
    after "Key Concepts:" is returned as the key concepts, with surrounding
    whitespace and square brackets stripped. Missing sections become ``""``.

    Returns:
        Tuple ``(thought_process, description, key_concepts)``.
    """
    before, *after_description = input_string.split("Description:")
    remainder = after_description[0] if after_description else ""

    description_chunk, *after_concepts = remainder.split("Key Concepts:")
    concepts_chunk = after_concepts[0] if after_concepts else ""

    return (
        before.replace("Thought Process:", "").strip(),
        description_chunk.strip(),
        concepts_chunk.strip().strip("[]"),
    )
# from serve.utils_general import save_data_diff_image
# import hashlib
# import json
# def update_attribute_description(
# prompts: List[Dict],
# attribute: str,
# separable_prompts: List[Dict],
# max_num_prompts=5,
# iteration=0,
# model="gpt-4o",
# ):
# """
# Given separable and inseparable prompts, update the description of the attribute by asking a VLM to refine its description of the differences, using images of the separable prompts.
# """
# filenames = [prompt["paths"][0] for prompt in separable_prompts]
# save_name = hashlib.sha256(json.dumps(filenames).encode()).hexdigest()
# image_path = f"cache/iteration_images/{save_name}.png"
# batch_prompts = random.sample(separable_prompts, min(max_num_prompts, len(separable_prompts)))
# images_top_row = [{"path": prompt["paths"][0]} for prompt in batch_prompts]
# images_bottom_row = [{"path": prompt["paths"][1]} for prompt in batch_prompts]
# save_data_diff_image(dataset1=images_top_row, dataset2=images_bottom_row, save_path=image_path)
# update_attribute_prompt = """You are given a set of images generated by two different models. The images in the top row were generated by Model A, and the images in the bottom row were generated by Model B. Images in the same column were generated using the same prompt.
# The images in the top row contain an artifact related to "{attribute}," while the images in the bottom row do not contain this artifact. Based on these images, please update your description of the artifact "{attribute}" if you can provide a better description of the differences between the models. Would a person consider {attribute} to be the best description of the attributes that appear in the top row and not the bottom row? Are there any edits you would make to the description of {attribute} to make it more clear, accurate, or precise?
# Keep the updated description under 5 words. Output your response in the following format: <artifact_description>NEW_DESCRIPTION</artifact_description>"""
# update_attribute_response = get_vlm_output(
# image_path, update_attribute_prompt.format(attribute=attribute), model
# )
# if "error" in update_attribute_response.lower():
# print("Error in updating attribute description")
# return attribute
# # parse response
# new_description = (
# update_attribute_response.lower()
# .replace("<artifact_description>", "")
# .replace("</artifact_description>", "")
# .strip()
# )
# print(update_attribute_response)
# wandb.log({"attribute_iteration": wandb.Image(image_path, caption=new_description)})
# return new_description