-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathutils.py
More file actions
28 lines (18 loc) · 953 Bytes
/
utils.py
File metadata and controls
28 lines (18 loc) · 953 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
import torch.nn as nn
class Trigger(nn.Module):
    """Learnable masked pattern alpha-blended onto an input tensor.

    The forward pass computes ``alpha * trigger + (1 - alpha) * x`` where
    ``alpha = transparency * mask``. Both ``mask`` and ``trigger`` are
    trainable parameters. This module does not clamp or squash them, so any
    range constraint (e.g. keeping ``mask`` within [0, 1]) is the caller's
    responsibility.

    Args:
        size: Spatial height and width of the mask and the trigger pattern.
        transparency: Global scale applied to the mask before blending.
        channels: Number of channels in the trigger pattern (default 3).
        device: Device for the parameters. ``None`` (the default) selects
            CUDA when it is available and falls back to CPU otherwise.
    """

    def __init__(
        self,
        size: int = 32,
        transparency: float = 1.,
        *,
        channels: int = 3,
        device: "torch.device | str | None" = None,
    ) -> None:
        super().__init__()
        if device is None:
            # Previously CUDA was hard-coded, which made construction fail on
            # CPU-only machines; keep CUDA as the preferred default.
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        device = torch.device(device)
        self.size = size
        # Per-pixel mask of shape (size, size), shared by all channels via
        # broadcasting; initialised uniformly in [0, 1). Created before the
        # trigger so the RNG draw order and parameter order are unchanged.
        self.mask = nn.Parameter(torch.rand(size, size, device=device), requires_grad=True)
        self.transparency = transparency
        # Pattern of shape (channels, size, size), initialised uniformly in [-2, 2).
        self.trigger = nn.Parameter(
            torch.rand(channels, size, size, device=device) * 4 - 2,
            requires_grad=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Blend the trigger pattern into ``x``.

        Args:
            x: Tensor that must broadcast against ``(channels, size, size)``,
                e.g. a ``(batch, channels, size, size)`` batch on the same
                device as the parameters.

        Returns:
            ``transparency * mask * trigger + (1 - transparency * mask) * x``.
        """
        alpha = self.transparency * self.mask
        return alpha * self.trigger + (1 - alpha) * x
class UAP(nn.Module):
    """Single learnable additive perturbation applied identically to every input.

    Presumably a universal adversarial perturbation (per the class name). The
    module itself only stores one trainable tensor and adds it to ``x``. No
    norm bound or clamping is applied here, so that is left to the caller.

    Args:
        size: Spatial height and width of the perturbation.
        channels: Number of channels in the perturbation (default 3).
        device: Device for the parameter. ``None`` (the default) selects
            CUDA when it is available and falls back to CPU otherwise.
    """

    def __init__(
        self,
        size: int = 32,
        *,
        channels: int = 3,
        device: "torch.device | str | None" = None,
    ) -> None:
        super().__init__()
        if device is None:
            # Previously CUDA was hard-coded, which made construction fail on
            # CPU-only machines; keep CUDA as the preferred default.
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.size = size
        # Zero-initialised, so the untrained module is the identity map.
        self.perturbation = nn.Parameter(
            torch.zeros(channels, size, size, device=torch.device(device)),
            requires_grad=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x + perturbation``.

        Args:
            x: Tensor that must broadcast against ``(channels, size, size)``,
                e.g. a ``(batch, channels, size, size)`` batch on the same
                device as the parameter.
        """
        return x + self.perturbation