Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions deepem/loss/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,18 @@
class BCELoss(nn.Module):
"""
Binary cross entropy loss with logits.
With focal_gamma set, model will downweight well-classified examples.
"""
def __init__(self, size_average=True, margin0=0, margin1=0, inverse=True,
class_balancer=None, **kwargs):
class_balancer=None, focal_gamma=None, **kwargs):
super().__init__()
self.bce = F.binary_cross_entropy_with_logits
self.size_average = size_average
self.margin0 = float(np.clip(margin0, 0, 1))
self.margin1 = float(np.clip(margin1, 0, 1))
self.inverse = inverse
self.balancer = class_balancer
self.focal_gamma = focal_gamma

def forward(self, input, target, mask):
# Number of valid voxels
Expand Down Expand Up @@ -46,7 +48,15 @@ def forward(self, input, target, mask):
m_ext = torch.le(activ, m0) * torch.eq(target, 0)
mask *= 1 - (m_int + m_ext).type(mask.dtype)

loss = self.bce(input, tgt, weight=mask, reduction='sum')
if self.focal_gamma is not None:
# Focal loss
p = torch.sigmoid(input)
pt = p * tgt + (1.0 - p) * (1.0 - tgt)
focal_weight = (1.0 - pt) ** self.focal_gamma
bce = self.bce(input, tgt, reduction='none')
loss = (focal_weight * bce * mask).sum()
else:
loss = self.bce(input, tgt, weight=mask, reduction='sum')

if self.size_average:
loss = loss / nmsk.item()
Expand Down
3 changes: 2 additions & 1 deletion deepem/train/option.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def initialize(self):
self.parser.add_argument('--class_balancing', action='store_true')
self.parser.add_argument('--class_weight0', type=parse_class_weight, default=None)
self.parser.add_argument('--class_weight1', type=parse_class_weight, default=None)
self.parser.add_argument('--focal_gamma', type=float, default=None)
self.parser.add_argument('--default_aux', action='store_true')

# Mean-based loss
Expand Down Expand Up @@ -256,7 +257,7 @@ def parse(self):
args = vars(opt)

# Loss
loss_keys = ['size_average','margin0','margin1','inverse']
loss_keys = ['size_average','margin0','margin1','inverse','focal_gamma']
opt.loss_params = {k: args[k] for k in loss_keys}

# Metirc learning
Expand Down