-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel_test.py
More file actions
121 lines (109 loc) · 4.28 KB
/
model_test.py
File metadata and controls
121 lines (109 loc) · 4.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# -*-coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from utils import file_processing, image_processing, debug
import face_recognition
import predict
# Target crop size handed to image_processing.get_bboxes_image for every
# detected face before it is embedded.
resize_height = resize_width = 160
def face_face_recognition_batch(
    model_path, dataset_path, filename, filePath_list, label_list, threshold
):
    """Run face detection + recognition over labelled images and print metrics.

    For each image the MTCNN detector is run and the boxes are squared using the
    box height. Images with no face or with more than one face are skipped.
    For single-face images the face crop is prewhitened, embedded with FaceNet
    and matched against the stored embedding database. The predicted name is
    then compared with the ground-truth label.

    Args:
        model_path: FaceNet model path, passed to
            ``face_recognition.facenetEmbedding``.
        dataset_path: Stored-embedding path, passed to ``predict.load_dataset``.
        filename: Names-file path, passed to ``predict.load_dataset``.
        filePath_list: Image paths to evaluate.
        label_list: Ground-truth name for each entry of ``filePath_list``
            (paired by position via ``zip``).
        threshold: Matching threshold forwarded to
            ``predict.compare_embadding``.

    Prints one timing/prediction line per scored image. At the end it prints:

    * the accuracy over scored images;
    * the fraction of images that were skipped (no face or several faces).

    Both rates are reported as 0.0 when their denominator is zero, instead of
    raising ZeroDivisionError.
    """
    # Load the stored embedding database and the matching list of names.
    dataset_emb, names_list = predict.load_dataset(dataset_path, filename)
    print("loadind dataset...\n names_list:{}".format(names_list))
    # Initialise the MTCNN face detector.
    face_detect = face_recognition.Facedetection()
    # Initialise the FaceNet embedding model.
    face_net = face_recognition.facenetEmbedding(model_path)
    right_num = 0
    detection_num = 0
    test_num = len(filePath_list)
    for image_path, label_name in zip(filePath_list, label_list):
        print("image_path:{}".format(image_path))
        # Read the image (read_image_gbk presumably tolerates GBK-encoded
        # paths — TODO confirm).
        image = image_processing.read_image_gbk(image_path)
        # Face detection, then square the boxes based on their height.
        T0 = debug.TIME()
        bboxes, landmarks = face_detect.detect_face(image)
        bboxes, landmarks = face_detect.get_square_bboxes(
            bboxes, landmarks, fixed="height"
        )
        T1 = debug.TIME()
        # len() is a reliable emptiness test for both lists and NumPy arrays.
        # Comparing an array with [] is not: it can raise on shape mismatch
        # or give an ambiguous truth value.
        if len(bboxes) == 0 or len(landmarks) == 0:
            print("-----no face")
            continue
        # Only single-face images are scored; multi-face images are skipped.
        if len(bboxes) >= 2 or len(landmarks) >= 2:
            print("-----image have {} faces".format(len(bboxes)))
            continue
        T2 = debug.TIME()
        # Crop the face region at resize_height x resize_width, then
        # prewhiten it before embedding.
        face_images = image_processing.get_bboxes_image(
            image, bboxes, resize_height, resize_width
        )
        face_images = image_processing.get_prewhiten_images(
            face_images, normalization=True
        )
        pred_emb = face_net.get_embedding(face_images)
        T3 = debug.TIME()
        pred_name, pred_score = predict.compare_embadding(
            pred_emb, dataset_emb, names_list, threshold=threshold
        )
        # To visualise results, draw the boxes and labels on the image:
        # show_info = [n + ':' + str(s)[:5] for n, s in zip(pred_name, pred_score)]
        # image_processing.show_image_text("face_recognition", image, bboxes, show_info)
        # Exactly one face was detected, so use the first (only) prediction.
        pred_name = pred_name[0]
        pred_score = pred_score[0]
        if pred_name == label_name:
            right_num += 1
        detection_num += 1
        print(
            "--detect face time:{}ms,recognition:{}ms,label_name:{},pred_name:{},score:{:3.4f},status:{}".format(
                debug.RUN_TIME(T1 - T0),
                debug.RUN_TIME(T3 - T2),
                label_name,
                pred_name,
                pred_score,
                (label_name == pred_name),
            )
        )
    # Accuracy over scored (single-face) images. Guarded so a run in which no
    # image produced exactly one face does not raise ZeroDivisionError.
    accuracy = right_num / detection_num if detection_num else 0.0
    # Fraction of images skipped (no face or several faces); guarded for an
    # empty input list.
    misdetection = (test_num - detection_num) / test_num if test_num else 0.0
    print(
        "-------------right_num/detection_num:{}/{},accuracy rate:{}".format(
            right_num, detection_num, accuracy
        )
    )
    print(
        "-------------misdetection/all_num:{}/{},misdetection rate:{}".format(
            (test_num - detection_num), test_num, misdetection
        )
    )
def face_recognition_(model_path, dataset_path, test_dataset, filename, threshold):
    """Evaluate recognition on every image found under ``test_dataset``.

    Image paths and labels come from ``file_processing.gen_files_labels``.
    Each label is cut at its first underscore, and the leading part is used as
    the ground-truth name. The actual evaluation is delegated to
    ``face_face_recognition_batch``.
    """
    # Gather test image paths together with their raw labels.
    image_paths, raw_labels = file_processing.gen_files_labels(
        test_dataset, postfix=None
    )
    # The identity name is whatever precedes the first "_" in each label.
    names = [raw.partition("_")[0] for raw in raw_labels]
    print("filePath_list:{},label_list{}".format(len(image_paths), len(names)))
    face_face_recognition_batch(
        model_path=model_path,
        dataset_path=dataset_path,
        filename=filename,
        filePath_list=image_paths,
        label_list=names,
        threshold=threshold,
    )
if __name__ == "__main__":
    # FaceNet model directory and the pre-computed embedding database.
    model_path = "models/20200808-102900"
    dataset_path = "dataset/emb/faceEmbedding.npy"
    filename = "dataset/emb/name.txt"
    # Directory of labelled test images. Other inputs used previously:
    #   single image: 'dataset/test_images/1.jpg'
    #   'E:/Face/dataset/bzl/test2/test_dataset'
    #   'E:/Face/dataset/bzl/test3/test_dataset'
    test_dataset = "../database/lfw-align-128/Grady_Irvin_Jr"
    # Matching threshold, ultimately forwarded to predict.compare_embadding.
    threshold = 1.14
    face_recognition_(
        model_path,
        dataset_path,
        test_dataset,
        filename,
        threshold,
    )