From 6e20f3bc9556ba944f491319e213df13b8225039 Mon Sep 17 00:00:00 2001
From: Nico Kemnitz
Date: Fri, 13 Feb 2026 10:18:31 +0100
Subject: [PATCH] feat: optional Focal Loss parameter

---
 deepem/loss/loss.py    | 14 ++++++++++++--
 deepem/train/option.py |  3 ++-
 2 files changed, 14 insertions(+), 3 deletions(-)

diff --git a/deepem/loss/loss.py b/deepem/loss/loss.py
index 8c630ff..0cf2e09 100644
--- a/deepem/loss/loss.py
+++ b/deepem/loss/loss.py
@@ -8,9 +8,10 @@
 class BCELoss(nn.Module):
     """
     Binary cross entropy loss with logits.
+    With focal_gamma set, the loss downweights well-classified examples.
     """
     def __init__(self, size_average=True, margin0=0, margin1=0, inverse=True,
-                 class_balancer=None, **kwargs):
+                 class_balancer=None, focal_gamma=None, **kwargs):
         super().__init__()
         self.bce = F.binary_cross_entropy_with_logits
         self.size_average = size_average
@@ -18,6 +19,7 @@ def __init__(self, size_average=True, margin0=0, margin1=0, inverse=True,
         self.margin1 = float(np.clip(margin1, 0, 1))
         self.inverse = inverse
         self.balancer = class_balancer
+        self.focal_gamma = focal_gamma
 
     def forward(self, input, target, mask):
         # Number of valid voxels
@@ -46,7 +48,15 @@ def forward(self, input, target, mask):
             m_ext = torch.le(activ, m0) * torch.eq(target, 0)
             mask *= 1 - (m_int + m_ext).type(mask.dtype)
 
-        loss = self.bce(input, tgt, weight=mask, reduction='sum')
+        if self.focal_gamma is not None:
+            # Focal loss
+            p = torch.sigmoid(input)
+            pt = p * tgt + (1.0 - p) * (1.0 - tgt)
+            focal_weight = (1.0 - pt) ** self.focal_gamma
+            bce = self.bce(input, tgt, reduction='none')
+            loss = (focal_weight * bce * mask).sum()
+        else:
+            loss = self.bce(input, tgt, weight=mask, reduction='sum')
 
         if self.size_average:
             loss = loss / nmsk.item()
diff --git a/deepem/train/option.py b/deepem/train/option.py
index ed081bb..aa77f08 100644
--- a/deepem/train/option.py
+++ b/deepem/train/option.py
@@ -100,6 +100,7 @@ def initialize(self):
         self.parser.add_argument('--class_balancing', action='store_true')
         self.parser.add_argument('--class_weight0', type=parse_class_weight, default=None)
         self.parser.add_argument('--class_weight1', type=parse_class_weight, default=None)
+        self.parser.add_argument('--focal_gamma', type=float, default=None)
         self.parser.add_argument('--default_aux', action='store_true')
 
         # Mean-based loss
@@ -256,7 +257,7 @@ def parse(self):
         args = vars(opt)
 
         # Loss
-        loss_keys = ['size_average','margin0','margin1','inverse']
+        loss_keys = ['size_average','margin0','margin1','inverse','focal_gamma']
         opt.loss_params = {k: args[k] for k in loss_keys}
 
         # Metirc learning
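
Note for reviewers (not part of the patch): a minimal, self-contained sketch of
the focal weighting that BCELoss.forward applies when focal_gamma is set. The
helper name focal_bce and the toy tensors below are illustrative assumptions,
not deepem code.

import torch
import torch.nn.functional as F

def focal_bce(logits, target, mask, focal_gamma=None, size_average=True):
    # focal_gamma=None reproduces the original masked BCE-with-logits path.
    if focal_gamma is None:
        loss = F.binary_cross_entropy_with_logits(
            logits, target, weight=mask, reduction='sum')
    else:
        # pt is the predicted probability of the true class; easy voxels
        # (pt near 1) are scaled down by (1 - pt) ** focal_gamma.
        p = torch.sigmoid(logits)
        pt = p * target + (1.0 - p) * (1.0 - target)
        bce = F.binary_cross_entropy_with_logits(logits, target, reduction='none')
        loss = ((1.0 - pt) ** focal_gamma * bce * mask).sum()
    if size_average:
        # The patch divides by the number of valid voxels (nmsk); mask.sum()
        # plays that role in this toy example.
        loss = loss / mask.sum().clamp(min=1.0)
    return loss

logits = torch.tensor([4.0, -4.0, 0.1])   # two easy voxels, one hard voxel
target = torch.tensor([1.0, 0.0, 1.0])
mask = torch.ones_like(target)
print(focal_bce(logits, target, mask))                   # plain masked BCE
print(focal_bce(logits, target, mask, focal_gamma=2.0))  # easy voxels downweighted

With the option.py change, the behavior is opt-in: passing e.g. --focal_gamma 2.0
at training time enables the focal term, while omitting the flag (default None)
keeps the original loss path unchanged.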