Description
Thank you very much for your open-source work. I am confused about the setting of the mode parameter.
The Mist paper describes a fused algorithm that consists of two parts: a semantic loss and a textual loss. The semantic loss maximizes the LDM loss, i.e. the MSE between the predicted noise and the real noise. The textual loss minimizes the distance between the perturbed features and the target features.
The code does not seem to implement it this way. First, the semantic loss in the code is not computed as an MSE loss but as torch.sum(model_pred.float() * target.float()). Second, the sign of the loss seems to be incorrect. The source code is as follows:
Lines 614 to 618 in ef8001a
    if args.mode == 'fused':
        loss = -torch.sum(model_pred.float() * target.float())
        latent_attack = LatentAttack()
        loss = loss - 1e2 * latent_attack(latents, target_tensor=target_tensor)
The correct code seems to be:
    if args.mode == 'fused':
        semantic_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
        latent_attack = LatentAttack()
        textual_loss = -latent_attack(latents, target_tensor=target_tensor)
        loss = semantic_loss - textual_loss
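To make the difference between the two semantic terms concrete, here is a minimal sketch of how they relate algebraically. It uses plain Python lists instead of torch tensors, and the numbers in `model_pred` and `target` are made up for illustration; the helper functions `dot` and `mse` are hypothetical and not part of the repository.

```python
# Sketch only: plain-Python stand-ins for the torch expressions in the issue.

def dot(a, b):
    """Sum of elementwise products, like torch.sum(a * b)."""
    return sum(x * y for x, y in zip(a, b))

def mse(a, b):
    """Mean squared error, like F.mse_loss(a, b, reduction="mean")."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)

# Made-up values standing in for the predicted noise and the real noise.
model_pred = [0.5, -1.0, 2.0, 0.25]
target = [1.0, -0.5, 1.5, 0.0]
n = len(target)

code_term = -dot(model_pred, target)  # the term in the repository code
mse_term = mse(model_pred, target)    # the term described in the paper

# Expanding the square gives:
#   n * MSE = |pred|^2 + |target|^2 - 2 * <pred, target>
# so the repository's term is the cross term of the MSE, up to a factor.
expanded = (dot(model_pred, model_pred) + dot(target, target) + 2 * code_term) / n

print(code_term, mse_term, expanded)
```

Under these assumptions, the two terms differ by the squared norms of `model_pred` and `target` plus a scale factor, so they are related but not interchangeable.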
I'm not sure whether my understanding is correct, and I hope to receive your reply.