-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathenroller.py
More file actions
89 lines (72 loc) · 2.65 KB
/
enroller.py
File metadata and controls
89 lines (72 loc) · 2.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import itertools
from collections.abc import Sequence
from dataclasses import dataclass
from pathlib import Path
from face_recognition.api import np
from client import ClientParameters
from face_match import FaceEmbedding, norm_face_embedding
from openfhe import openfhe
@dataclass
class EnrollerDatabase:
    """Enroller database shared with Server.

    Holds the encrypted face-embedding batches produced by ``Enroller.enroll``
    together with the plaintext labels of the enrolled images.
    """

    # One entry per batch of embeddings; each entry is the list of ciphertexts
    # returned by ``Enroller._encode_batch`` (one ciphertext per diagonal row,
    # i.e. ``embedding_dim`` of them).
    database: list[list[openfhe.Ciphertext]]
    # Labels (image file stems) in enrollment order. NOTE(review):
    # ``Enroller.enroll`` zero-pads the embeddings to a multiple of
    # ``embedding_dim`` but does not pad this list, so trailing encrypted
    # slots may have no corresponding label.
    labels: list[str]
class Enroller:
    """Build an encrypted, diagonally packed face-embedding database.

    Reference images are read from a directory at construction time, and each
    file's stem becomes its label. ``enroll`` then packs the plaintext
    embeddings into CKKS ciphertexts using the client's crypto context and
    public key.
    """

    def __init__(
        self,
        database_dir: Path,
    ) -> None:
        """Load one normalized embedding per file in *database_dir*.

        Args:
            database_dir: Directory of reference images; each file's stem is
                used as its label.
        """
        # Parallel lists: self.labels[k] is the label of self.database[k].
        self.database, self.labels = self._init_database(database_dir)

    def _init_database(
        self,
        database_dir: Path,
    ) -> tuple[list[FaceEmbedding], list[str]]:
        """Return parallel lists of embeddings and labels for *database_dir*.

        Entries are processed in sorted path order so that the enrolled slot
        layout (and therefore the index of each label) is reproducible;
        ``Path.iterdir`` on its own yields entries in arbitrary order.
        Subdirectories are skipped.
        """
        database: list[FaceEmbedding] = []
        labels: list[str] = []
        for image_path in sorted(database_dir.iterdir()):
            if not image_path.is_file():
                continue
            database.append(norm_face_embedding(image_path))
            labels.append(image_path.stem)
        return database, labels

    def enroll(
        self,
        params: ClientParameters,
    ) -> EnrollerDatabase:
        """Encrypt the loaded embeddings with the client's public key.

        The embeddings are zero-padded to a multiple of
        ``params.embedding_dim`` and split into batches of
        ``params.batch_size``. Each batch is encoded by ``_encode_batch`` into
        ``params.embedding_dim`` ciphertexts.

        The returned ``labels`` list is not padded, so the encrypted database
        may contain up to ``embedding_dim - 1`` trailing all-zero entries that
        have no label.

        Args:
            params: Client parameters supplying the crypto context (``cc``),
                public key (``pk``), ``embedding_dim`` and ``batch_size``.

        Returns:
            The encrypted batches together with a copy of the labels.

        Raises:
            ValueError: If ``params.batch_size`` is not a positive multiple of
                ``params.embedding_dim``. The diagonal packing needs every
                batch to split into whole groups of ``embedding_dim``
                embeddings.
        """
        dim = params.embedding_dim
        batch_size = params.batch_size
        if dim <= 0 or batch_size < dim or batch_size % dim:
            raise ValueError(
                f"batch_size ({batch_size}) must be a positive multiple of "
                f"embedding_dim ({dim})"
            )

        embeddings = list(self.database)
        # Zero vectors pad the final group to exactly `dim` embeddings; they
        # contribute nothing to any inner product.
        embeddings.extend(np.zeros(dim) for _ in range(-len(embeddings) % dim))

        database = [
            self._encode_batch(params.cc, params.pk, dim, batch_size, batch)
            for batch in itertools.batched(embeddings, batch_size)
        ]
        return EnrollerDatabase(
            database=database,
            labels=list(self.labels),
        )

    def _encode_batch(
        self,
        cc: openfhe.CryptoContext,
        pk: openfhe.PublicKey,
        embedding_dim: int,
        batch_size: int,
        batch: Sequence[FaceEmbedding],
    ) -> list[openfhe.Ciphertext]:
        """Diagonally pack and encrypt one batch of embeddings.

        The batch is split into consecutive groups of ``embedding_dim``
        embeddings. In row ``i``, slot ``n * embedding_dim + j`` holds element
        ``(i + j) % embedding_dim`` of embedding ``j`` in group ``n``. Row
        ``i`` is therefore the ``i``-th generalized diagonal of every group's
        ``embedding_dim x embedding_dim`` block. Unused slots stay ``0.0``,
        which is equivalent to padding with zero embeddings.

        Args:
            cc: Crypto context used to encode and encrypt.
            pk: Public key to encrypt under.
            embedding_dim: Length of each embedding and size of each group.
            batch_size: Number of slots in each packed plaintext row.
            batch: At most ``batch_size`` embeddings.

        Returns:
            ``embedding_dim`` ciphertexts, one per diagonal row.
        """
        # Floats, matching the CKKS real-vector encoder's element type.
        rows = [[0.0] * batch_size for _ in range(embedding_dim)]
        for n, group in enumerate(itertools.batched(batch, embedding_dim)):
            offset = n * embedding_dim
            for j, embedding in enumerate(group):
                for i in range(embedding_dim):
                    rows[i][offset + j] = embedding[(i + j) % embedding_dim]
        return [cc.Encrypt(pk, cc.MakeCKKSPackedPlaintext(row)) for row in rows]