forked from thoppe/alph-the-sacred-river
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapi.py
More file actions
114 lines (77 loc) · 2.59 KB
/
api.py
File metadata and controls
114 lines (77 loc) · 2.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import fastapi
import torch
import pandas as pd
import numpy as np
from src import clip
from fastapi import FastAPI
from typing import List
from pydantic import BaseModel
app_name = "alph-the-sacred-river"  # Service identifier reported by the root endpoint
__version__ = "0.1.0"  # API version string reported by the root endpoint
class CLIP:
    """Wrap a CLIP text encoder plus pre-computed image latents for
    text-to-image retrieval.

    Construction is cheap; the model and data are loaded explicitly via
    ``load()`` so the heavy work happens once, on demand.
    """

    def __init__(self):
        # Initialize all attributes so the object has a well-defined state
        # (and a clear AttributeError-free repr) before load() is called.
        self.model = None
        self.transform = None
        self.scale = None
        self.V = None
        self.keys = None

    def load(self, f_latents="data/img_latents.npy", f_keys="data/img_keys.csv"):
        """Load the CLIP model (CPU-only) and the pre-computed image latents.

        Args:
            f_latents: path to a .npy array of image latent vectors,
                one row per image.
            f_keys: path to a CSV whose "unsplashID" column maps rows of
                the latent array to image IDs.
        """
        # Load the model for inference on the CPU
        self.model, self.transform = clip.load("ViT-B/32", device="cpu")
        self.model.eval()
        self.scale = self._to_numpy(self.model.logit_scale.exp())

        # Load the pre-computed unsplash latent codes and L2-normalize rows
        # so that dot products below are cosine similarities (up to scale).
        self.V = np.load(f_latents)
        self.V /= np.linalg.norm(self.V, ord=2, axis=-1, keepdims=True)

        # Mapping of latent-code rows to unsplash image IDs
        self.keys = pd.read_csv(f_keys)["unsplashID"].values

    def _to_numpy(self, x):
        # Detach from the autograd graph, move to host memory, convert.
        return x.detach().cpu().numpy()

    def encode_text(self, lines):
        """Return L2-normalized CLIP text embeddings for a list of strings."""
        with torch.no_grad():
            tokens = clip.tokenize(lines)
            latents = self.model.encode_text(tokens)
        latents = self._to_numpy(latents)
        latents /= np.linalg.norm(latents, ord=2, axis=-1, keepdims=True)
        return latents

    def compute(self, lines):
        """Return the (n_images, n_lines) matrix of image/text logits."""
        TL = self.encode_text(lines)
        return (self.scale * self.V).dot(TL.T)

    def __call__(self, lines, top_k=4):
        """Score each line against all images and return the top matches.

        Args:
            lines: list of text strings to match.
            top_k: number of best-matching images to return per line.

        Returns:
            A list (one entry per line) of dicts with keys
            "text", "unsplashIDs", and "scores".
        """
        X = self.compute(lines)
        df = pd.DataFrame(data=X, columns=lines, index=self.keys)

        data = []
        # Iterate columns directly (the enumerate index was unused).
        for sent in df.columns:
            dx = df.sort_values(sent, ascending=False)[sent][:top_k]
            data.append(
                {
                    "text": sent,
                    "unsplashIDs": dx.index.values.tolist(),
                    "scores": dx.values.tolist(),
                }
            )
        return data
def load_sample_data():
    """Read the bundled sample poem and return its non-blank lines.

    Returns:
        A list of strings, one per non-empty line, with internal runs of
        whitespace collapsed to single spaces.
    """
    # Explicit encoding: the file (and its name) contain non-ASCII text, so
    # relying on the platform default (e.g. cp1252 on Windows) would break.
    with open("docs/collected_poems/Tentación.txt", encoding="utf-8") as FIN:
        sents = FIN.read().split("\n")
    # Normalize whitespace and drop blank lines
    return [" ".join(line.split()) for line in sents if line.strip()]
# Module-level singletons: the CLIP model and image latents are loaded once
# at import time, so importing this module is slow but each request is fast.
app = FastAPI()
clf = CLIP()
clf.load()
class TextListInput(BaseModel):
    # Request body schema: the text lines to score against the image latents
    lines: List[str]
@app.get("/")
def root():
    """Info endpoint: report the service name and API version."""
    payload = {"app_name": app_name, "version": __version__}
    return payload
@app.get("/infer")
def infer_multi(q: TextListInput):
    # Run the CLIP matcher over the submitted lines and return, per line,
    # the top-matching unsplash image IDs with their scores.
    # NOTE(review): a GET endpoint that requires a JSON request body is
    # unusual — many HTTP clients and proxies drop GET bodies. Consider
    # switching to POST; kept as GET here to avoid breaking callers.
    return clf(q.lines)
if __name__ == "__main__":
    # Smoke test: exercise both endpoints through the in-process test client.
    sents = load_sample_data()

    from fastapi.testclient import TestClient

    client = TestClient(app)
    print(client.get("/").json())

    # Use .request() instead of .get(): newer TestClient versions are backed
    # by httpx, whose .get() does not accept a `json=` body for GET requests.
    r = client.request("GET", "/infer", json={"lines": sents})
    print(r.json())