# NOTE(review): these statements appear to be the body of an enclosing method
# (presumably __init__) whose header is outside this view; `device` and
# `weight_dtype` are presumably defined/passed there — confirm.
# Face detector used to find face bounding boxes ('retinaface_resnet50'),
# loaded from models/pulid/facexlib. Its parameters are frozen with
# requires_grad_(False), then it is moved to `device` and cast to `weight_dtype`.
# NOTE(review): half=False is passed to the loader, but the subsequent
# .to(dtype=weight_dtype) cast overrides it if weight_dtype is a half-precision
# dtype — confirm this is intended.
self.det_net = init_detection_model('retinaface_resnet50', half=False, device=device, model_rootpath='models/pulid/facexlib').requires_grad_(False).to(device, dtype=weight_dtype)
# Face-embedding model: ArcFace-style 'r50' network loaded from
# models/pulid/glint_cosface_res50.pth, then moved to `device` / `weight_dtype`.
# NOTE(review): unlike det_net, requires_grad_(False) is not applied here, so
# its parameters still require grad unless frozen elsewhere. Gradients w.r.t.
# the input image do not need parameter grads — verify whether freezing (and
# .eval()) was intended.
self.arcface_model = get_arcface_model(model_name='r50', pretrained_path='models/pulid/glint_cosface_res50.pth').to(device, dtype=weight_dtype)
def score_grad(self, image: torch.Tensor, image_gt: torch.Tensor, **kwargs):