-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
executable file
·69 lines (47 loc) · 1.94 KB
/
main.py
File metadata and controls
executable file
·69 lines (47 loc) · 1.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#!/usr/bin/env -S uv run
from typing import Tuple
from hnsw.head_nodes.DistributedHNSWNormalRemote import DistributedHNSWNormalRemote
from hnsw.head_nodes.DistributedHNSWRandomRemote import DistributedHNSWRandomRemote
from hnsw.utils import read_fbin, write_fbin
from hnsw.profiler.Profiler import Profiler
import sys
import numpy as np
def calculate_accuracy(
    xq: np.ndarray, xr: np.ndarray, k: int = 10
) -> Tuple[float, float]:
    """Return the fraction of queries found at rank 1 and within the top ``k``.

    A result counts as a "hit" when every coordinate of a returned vector is
    ``np.isclose`` to the query vector. This assumes each query is itself a
    member of the indexed data, as it is when ``main`` samples queries from
    the indexed vectors.

    Args:
        xq: Query vectors, shape ``(n_queries, dim)``.
        xr: Vectors returned by the search, shape ``(n_queries, n_results, dim)``.
            Row ``i`` holds the results for query ``i``, and column 0 is
            treated as the best (rank-1) result.
        k: How many leading results count towards the top-``k`` metric.
            Defaults to 10, the previously hard-coded cut-off.

    Returns:
        ``(recall_at_1, recall_at_k)``: the fraction of queries whose first
        result matches the query, and the fraction with a match anywhere in
        the first ``k`` results.

    Raises:
        ValueError: if ``xq`` contains no queries (the ratios are undefined)
            or ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError(f"k must be at least 1, got {k}")
    query_count = xq.shape[0]
    if query_count == 0:
        raise ValueError("calculate_accuracy needs at least one query vector")

    top_k = xr[:, :k, :]      # (n_queries, <=k, dim)
    queries = xq[:, None, :]  # (n_queries, 1, dim): broadcasts over results
    # hits[i, j] is True when result j of query i equals query i (within tolerance).
    hits = np.all(np.isclose(top_k, queries), axis=2)  # (n_queries, <=k)

    hits_at_1 = float(hits[:, 0].sum())
    hits_at_k = float(hits.any(axis=1).sum())
    return hits_at_1 / query_count, hits_at_k / query_count
def main():
    """Benchmark self-recall of a distributed HNSW index built from an .fbin file.

    Usage: ``main.py <vectors.fbin>``

    Steps:
    1. Load the first 100,000 vectors from the given file.
    2. Write them back out with ``write_fbin``.
    3. Add them to a ``DistributedHNSWRandomRemote`` spread over the nodes in
       ``node_list``.
    4. Query a random sample of the indexed vectors.
    5. Print per-node ``calculate_accuracy`` results.
    6. Dump all profiler reports.
    """
    max_vecs = 100_000  # cap on how many vectors from the file are indexed
    n_queries = 100     # number of indexed vectors re-used as queries

    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <vectors.fbin>")

    np.random.seed(42)  # For reproducibility of the sampled queries
    vec_path = sys.argv[1]
    x = read_fbin(vec_path)
    x = x[:max_vecs, :]
    # The truncated array is re-serialised; the returned path is what the
    # remote index ingests. NOTE(review): presumably a temp file; confirm cleanup.
    x_path = write_fbin(x)
    n_vecs, dim = x.shape

    node_list = ["localhost:50051", "localhost:50052", "localhost:50053"]
    hnsw_index = DistributedHNSWRandomRemote(dim, node_list)
    print(f"Clearing any previous indices, success: {hnsw_index.clear()}")
    print(f"Adding {n_vecs:,} vecs")
    hnsw_index.add(x_path, "float32")

    print("Random vecs")
    # Sample without replacement. Cap the size so inputs smaller than
    # n_queries do not make np.random.choice raise.
    sample = np.random.choice(n_vecs, min(n_queries, n_vecs), replace=False)
    xq = x[sample]
    distances, indices = hnsw_index.search(xq)
    print(f"Distances shape: {distances.shape}, Indices shape: {indices.shape}")

    # The leading axis of the search results is treated as one entry per
    # node; each node's vectors are fetched using that node's own indices.
    num_nodes = distances.shape[0]
    xr = np.stack([hnsw_index.get(indices[i, :], i) for i in range(num_nodes)])

    # Report every node that actually returned results instead of a
    # hard-coded 3, so editing node_list cannot crash or skip nodes.
    # Per the original note, each value should be around 0.3 "due to random
    # sampling". Presumably this is because vectors are spread randomly over
    # three nodes, so each holds about a third; verify if node_list changes.
    for i, node_results in enumerate(xr):
        print(f"Accuracy for node {i}: {calculate_accuracy(xq, node_results)}")

    Profiler.dump_all_reports()
if __name__ == "__main__":
    # Only run the benchmark when executed as a script, never on import.
    main()