Source code for pathflowai.losses

"""
losses.py
=======================
Additional loss functions that can be called from the pipeline; some of them are still to be implemented.
"""

import torch, numpy as np
from typing import Any, Callable, Iterable, List, Set, Tuple, TypeVar, Union
from torch import Tensor, einsum
import torch.nn.functional as F
from scipy.ndimage import distance_transform_edt as distance
from torch import nn

def assert_(condition, message='', exception_type=AssertionError):
    """Raise ``exception_type(message)`` when ``condition`` is falsy.

    Behaves like the ``assert`` statement, but the caller chooses the
    exception type, and the check is a plain function call rather than an
    ``assert`` statement.

    Adapted from
    https://raw.githubusercontent.com/inferno-pytorch/inferno/0561e8a95cde6bfc5e10a3609841b7b0ca5b03ca/inferno/utils/exceptions.py

    :param condition: value tested for truthiness.
    :param message: argument passed to the exception constructor.
    :param exception_type: exception class to raise on failure.
    :raises exception_type: if ``condition`` is falsy.
    """
    if condition:
        return
    raise exception_type(message)
class ShapeError(ValueError):
    """Raised when a tensor or variable has an unsupported shape.

    ``flatten_samples`` raises it for inputs with fewer than two dimensions.

    Adapted from
    https://raw.githubusercontent.com/inferno-pytorch/inferno/0561e8a95cde6bfc5e10a3609841b7b0ca5b03ca/inferno/utils/exceptions.py
    """
def flatten_samples(input_):
    """Put the channel axis first and merge all remaining axes into one.

    Adapted from
    https://raw.githubusercontent.com/inferno-pytorch/inferno/0561e8a95cde6bfc5e10a3609841b7b0ca5b03ca/inferno/utils/torch_utils.py

    Shape mapping::

        (N, C, H, W)    --> (C, N * H * W)
        (N, C, D, H, W) --> (C, N * D * H * W)
        (N, C)          --> (C, N)

    :param input_: tensor with at least 2 dimensions; axis 1 is the channel axis.
    :returns: 2D tensor of shape (C, everything_else).
    :raises ShapeError: if ``input_`` has fewer than 2 dimensions.
    """
    ndim = input_.dim()
    assert_(ndim >= 2,
            "Tensor or variable must be atleast 2D. Got one of dim {}.".format(ndim),
            ShapeError)
    channels = input_.size(1)
    # Swap only the first two axes, e.g. NCHW -> CNHW; the rest keep their order.
    axis_order = [1, 0] + list(range(2, ndim))
    # contiguous() is required before view() after a permute.
    return input_.permute(*axis_order).contiguous().view(channels, -1)
class GeneralizedDiceLoss(nn.Module):
    """Scalar Generalized Dice Loss, as defined in https://arxiv.org/abs/1707.03237.

    Adapted from
    https://raw.githubusercontent.com/inferno-pytorch/inferno/0561e8a95cde6bfc5e10a3609841b7b0ca5b03ca/inferno/extensions/criteria/set_similarity_measures.py

    Works for multiple classes. It expects a prediction for every class
    (e.g. a softmax output, or raw scores together with ``add_softmax=True``)
    and a one-hot target for every class.

    :param weight: optional per-channel weights of shape (nb_channels,),
        registered as a buffer. Only applied when ``channelwise=True``; it is
        ignored otherwise.
    :param channelwise: if False, inputs are (batch_size, nb_classes, ...);
        if True, inputs are (batch_size, nb_channels, nb_classes, ...) and one
        loss per channel is computed, optionally weighted, then summed.
    :param eps: lower clamp for the class-weight and Dice denominators.
    :param add_softmax: apply a softmax over the class axis of ``input``
        before computing the loss.
    """

    def __init__(self, weight=None, channelwise=False, eps=1e-6, add_softmax=False):
        super(GeneralizedDiceLoss, self).__init__()
        self.register_buffer('weight', weight)
        self.channelwise = channelwise
        self.eps = eps
        self.add_softmax = add_softmax

    @staticmethod
    def _flatten_and_preserve_channels(tensor):
        """Reshape (N, C, K, ...) into (C, K, N * ...)."""
        tensor_dim = tensor.dim()
        assert tensor_dim >= 3
        num_channels = tensor.size(1)
        num_classes = tensor.size(2)
        # (N, C, K, d1, ...) -> (C, K, N, d1, ...)
        permute_axes = [1, 2, 0] + list(range(3, tensor_dim))
        permuted = tensor.permute(*permute_axes).contiguous()
        return permuted.view(num_channels, num_classes, -1)

    def forward(self, input, target):
        """Compute the loss.

        :param input: torch.FloatTensor or torch.cuda.FloatTensor of predictions.
        :param target: one-hot tensor with the same size as ``input``; it is
            cast to float before use.
        Expected shape of the inputs:
            - if not channelwise: (batch_size, nb_classes, ...)
            - if channelwise: (batch_size, nb_channels, nb_classes, ...)
        :returns: scalar loss tensor.
        """
        assert input.size() == target.size()
        if self.add_softmax:
            # The class axis is 1 for (batch, classes, ...) inputs but 2 for
            # (batch, channels, classes, ...) inputs. Normalizing over axis 1
            # in channelwise mode would mix channels instead of classes.
            input = F.softmax(input, dim=2 if self.channelwise else 1)

        if not self.channelwise:
            # Flatten input and target to shape (nb_classes, N), where N is
            # the number of samples.
            input = flatten_samples(input)
            target = flatten_samples(target).float()
            # Class weights: inverse squared target volume per class.
            sum_targets = target.sum(-1)
            class_weights = 1. / (sum_targets * sum_targets).clamp(min=self.eps)
            numer = ((input * target).sum(-1) * class_weights).sum()
            denom = ((input + target).sum(-1) * class_weights).sum()
            return 1. - 2. * numer / denom.clamp(min=self.eps)

        # Flatten input and target to shape (nb_channels, nb_classes, N).
        input = self._flatten_and_preserve_channels(input)
        target = self._flatten_and_preserve_channels(target).float()
        sum_targets = target.sum(-1)
        class_weights = 1. / (sum_targets * sum_targets).clamp(min=self.eps)
        # Sum over classes only, which keeps one Dice value per channel.
        numer = ((input * target).sum(-1) * class_weights).sum(-1)
        denom = ((input + target).sum(-1) * class_weights).sum(-1)
        channelwise_loss = 1. - 2. * numer / denom.clamp(min=self.eps)

        if self.weight is not None:
            if channelwise_loss.dim() == 2:
                channelwise_loss = channelwise_loss.squeeze(1)
            assert self.weight.size() == channelwise_loss.size(), \
                """`weight` should have shape (nb_channels, ), `target` should have shape (batch_size, nb_channels, nb_classes, ...)"""
            # Apply channel weights:
            channelwise_loss = self.weight * channelwise_loss

        return channelwise_loss.sum()
class FocalLoss(nn.Module):
    """Focal loss with optional label smoothing.

    Adapted from
    https://raw.githubusercontent.com/Hsuxu/Loss_ToolBox-PyTorch/master/FocalLoss/FocalLoss.py,
    an implementation of 'Focal Loss for Dense Object Detection'
    (https://arxiv.org/abs/1708.02002)::

        Focal_Loss = -1 * alpha * (1 - pt) ** gamma * log(pt)

    ``forward`` does NOT apply a softmax. ``logit`` must already contain class
    probabilities, because ``pt`` is read directly from it and then logged.

    :param num_class: number of classes C.
    :param alpha: per-class weighting factor. ``None`` gives all ones. A list or
        ndarray of length C is normalized to sum to 1. A float gives ``alpha``
        to class ``balance_index`` and ``1 - alpha`` to every other class.
    :param gamma: (float) focusing parameter; gamma > 0 reduces the relative
        loss for well-classified examples (p > 0.5), putting more focus on
        hard, misclassified examples.
    :param balance_index: (int) class index that receives ``alpha`` when
        ``alpha`` is a float.
    :param smooth: (float, optional) label-smoothing value in [0, 1]. The
        one-hot target is clamped to [smooth / (C - 1), 1 - smooth].
    :param size_average: (bool) average the per-element losses if True,
        otherwise sum them.
    :raises TypeError: if ``alpha`` has an unsupported type.
    :raises ValueError: if ``smooth`` lies outside [0, 1].
    """

    def __init__(self, num_class, alpha=None, gamma=2, balance_index=-1, smooth=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.num_class = num_class
        self.alpha = alpha
        self.gamma = gamma
        self.smooth = smooth
        self.size_average = size_average

        # Normalize alpha into a (num_class, 1) float tensor.
        if self.alpha is None:
            self.alpha = torch.ones(self.num_class, 1)
        elif isinstance(self.alpha, (list, np.ndarray)):
            assert len(self.alpha) == self.num_class
            self.alpha = torch.FloatTensor(alpha).view(self.num_class, 1)
            self.alpha = self.alpha / self.alpha.sum()
        elif isinstance(self.alpha, float):
            alpha = torch.ones(self.num_class, 1)
            alpha = alpha * (1 - self.alpha)
            alpha[balance_index] = self.alpha
            self.alpha = alpha
        else:
            raise TypeError('Not support alpha type')

        if self.smooth is not None:
            if self.smooth < 0 or self.smooth > 1.0:
                raise ValueError('smooth value should be in [0,1]')

    def forward(self, logit, target):
        """Compute the focal loss.

        :param logit: class probabilities of shape (N, C) or (N, C, d1, d2, ...).
        :param target: class labels with N (or N * d1 * d2 * ...) elements.
        :returns: scalar tensor, the mean or sum of the per-element losses.
        """
        if logit.dim() > 2:
            # N,C,d1,d2,... -> N,C,m -> N,m,C -> N*m,C  (m = d1*d2*...)
            num_channels = logit.size(1)
            logit = logit.reshape(logit.size(0), num_channels, -1)
            logit = logit.permute(0, 2, 1).reshape(-1, num_channels)
        target = target.reshape(-1, 1)

        epsilon = 1e-10
        # .to() is a no-op when alpha already lives on logit's device.
        alpha = self.alpha.to(logit.device)
        idx = target.long().to(logit.device)

        one_hot_key = torch.zeros(target.size(0), self.num_class, device=logit.device)
        one_hot_key.scatter_(1, idx, 1)

        if self.smooth:
            one_hot_key = torch.clamp(
                one_hot_key, self.smooth / (self.num_class - 1), 1.0 - self.smooth)
        pt = (one_hot_key * logit).sum(1) + epsilon
        logpt = pt.log()

        # Gather one weight per element so alpha has shape (M,), like pt.
        # Indexing the (C, 1) alpha with the (M, 1) labels directly gives
        # (M, 1, 1), which would broadcast the product below to (M, 1, M).
        alpha = alpha.view(-1)[idx.view(-1)]
        loss = -1 * alpha * torch.pow((1 - pt), self.gamma) * logpt
        if self.size_average:
            return loss.mean()
        return loss.sum()
def uniq(a: Tensor) -> Set:
    """Return the set of distinct values in ``a``, computed on the CPU.

    Adapted from https://raw.githubusercontent.com/LIVIAETS/surface-loss/master/utils.py
    """
    distinct = torch.unique(a.cpu())
    return set(distinct.numpy())
def sset(a: Tensor, sub: Iterable) -> bool:
    """Return True if every distinct value of ``a`` is a member of ``sub``.

    Adapted from https://raw.githubusercontent.com/LIVIAETS/surface-loss/master/utils.py
    """
    values = uniq(a)
    return values.issubset(sub)
def eq(a: Tensor, b) -> bool:
    """Return whether every element of ``a`` equals ``b`` (``torch.eq`` broadcasting rules).

    Despite the annotation, the result is the 0-dim boolean tensor produced by
    ``Tensor.all()``, not a Python ``bool``.

    Adapted from https://raw.githubusercontent.com/LIVIAETS/surface-loss/master/utils.py
    """
    elementwise = torch.eq(a, b)
    return elementwise.all()
def simplex(t: Tensor, axis=1) -> bool:
    """Return True if ``t`` sums to 1 along ``axis`` (within ``torch.allclose`` tolerance).

    Sums are cast to float32 before the comparison.

    Adapted from https://raw.githubusercontent.com/LIVIAETS/surface-loss/master/utils.py
    """
    totals = t.sum(axis).type(torch.float32)
    target_ones = torch.ones_like(totals, dtype=torch.float32)
    return torch.allclose(totals, target_ones)
def one_hot(t: Tensor, axis=1) -> bool:
    """Return True if ``t`` is one-hot along ``axis``.

    That means ``t`` sums to 1 along ``axis`` and contains only the values 0 and 1.

    Adapted from https://raw.githubusercontent.com/LIVIAETS/surface-loss/master/utils.py
    """
    if not simplex(t, axis):
        return False
    return sset(t, [0, 1])
def class2one_hot(seg: Tensor, C: int) -> Tensor:
    """Convert a label map into a one-hot int32 tensor.

    Adapted from https://raw.githubusercontent.com/LIVIAETS/surface-loss/master/utils.py

    :param seg: labels in range(C), shape (w, h) or (b, w, h); a 2D map gets a
        leading batch axis of size 1.
    :param C: number of classes.
    :returns: int32 tensor of shape (b, C, w, h).
    """
    if seg.dim() == 2:  # Only w, h, used by the dataloader
        seg = seg.unsqueeze(dim=0)
    assert sset(seg, list(range(C)))

    b, w, h = seg.shape  # type: Tuple[int, int, int]
    # One boolean mask per class, stacked along a new class axis.
    masks = [seg == label for label in range(C)]
    res = torch.stack(masks, dim=1).to(torch.int32)

    assert res.shape == (b, C, w, h)
    assert one_hot(res)
    return res
def one_hot2dist(seg: np.ndarray, dtype=None) -> np.ndarray:
    """Turn a one-hot mask into per-class signed distance maps.

    Adapted from https://raw.githubusercontent.com/LIVIAETS/surface-loss/master/utils.py

    For every class whose mask is non-empty:
      * pixels outside the mask get their Euclidean distance to the mask;
      * pixels inside get ``-(d - 1)``, where ``d`` is their distance to the
        nearest pixel outside the mask (0 on the inner boundary, negative deeper in).
    Classes with an empty mask keep an all-zero map.

    :param seg: one-hot array with the class axis first (checked with
        ``one_hot(..., axis=0)``).
    :param dtype: dtype of the returned maps. Defaults to ``seg.dtype`` when
        that is a floating type, and to float64 otherwise, so that distances
        are not truncated by an integer or boolean ``seg``.
    :returns: array with the same shape as ``seg``.
    """
    assert one_hot(torch.Tensor(seg), axis=0)
    if dtype is None:
        dtype = seg.dtype if np.issubdtype(seg.dtype, np.floating) else np.float64
    res = np.zeros_like(seg, dtype=dtype)
    for c in range(len(seg)):
        # Builtin bool: the np.bool alias was removed in NumPy 1.24.
        posmask = seg[c].astype(bool)
        if posmask.any():
            negmask = ~posmask
            res[c] = distance(negmask) * negmask - (distance(posmask) - 1) * posmask
    return res
class SurfaceLoss:
    """Surface loss: mean of class probabilities multiplied by distance maps.

    Adapted from https://raw.githubusercontent.com/LIVIAETS/surface-loss/master/losses.py

    The distance maps are presumably produced by ``one_hot2dist`` (TODO
    confirm at the call site). The einsum subscripts restrict both tensors to
    4-D (b, c, w, h).

    :param idc: (required keyword) class indices selected from axis 1 of the
        probability and distance-map tensors.
    """

    def __init__(self, **kwargs):
        # idc picks the classes that take part in the loss (fancy indexing).
        self.idc: List[int] = kwargs["idc"]
        print(f"Initialized {self.__class__.__name__} with {kwargs}")

    def __call__(self, probs: Tensor, dist_maps: Tensor, _: Tensor) -> Tensor:
        # probs must sum to 1 over classes; dist_maps must not be one-hot.
        assert simplex(probs)
        assert not one_hot(dist_maps)
        # The third argument is accepted for interface compatibility and unused.
        class_probs = probs[:, self.idc, ...].type(torch.float32)
        class_dists = dist_maps[:, self.idc, ...].type(torch.float32)
        weighted = einsum("bcwh,bcwh->bcwh", class_probs, class_dists)
        return weighted.mean()
class GeneralizedDice:
    """Generalized Dice loss over a selected subset of classes.

    Adapted from https://raw.githubusercontent.com/LIVIAETS/surface-loss/master/losses.py

    Each class is weighted by the inverse squared target volume (plus 1e-10).
    The per-image Dice scores are averaged over the batch. The einsum
    subscripts restrict inputs to 4-D (b, c, w, h).

    :param idc: (required keyword) class indices selected from axis 1 of the
        probability and target tensors.
    """

    def __init__(self, **kwargs):
        # idc picks the classes that take part in the loss (fancy indexing).
        self.idc: List[int] = kwargs["idc"]
        print(f"Initialized {self.__class__.__name__} with {kwargs}")

    def __call__(self, probs: Tensor, target: Tensor, _: Tensor) -> Tensor:
        # Both tensors must sum to 1 over the class axis.
        assert simplex(probs) and simplex(target)
        # The third argument is accepted for interface compatibility and unused.
        pc = probs[:, self.idc, ...].type(torch.float32)
        tc = target[:, self.idc, ...].type(torch.float32)

        target_mass = einsum("bcwh->bc", tc).type(torch.float32)
        w: Tensor = 1 / ((target_mass + 1e-10) ** 2)
        intersection: Tensor = w * einsum("bcwh,bcwh->bc", pc, tc)
        union: Tensor = w * (einsum("bcwh->bc", pc) + target_mass)

        numerator = einsum("bc->b", intersection) + 1e-10
        denominator = einsum("bc->b", union) + 1e-10
        per_image: Tensor = 1 - 2 * numerator / denominator
        return per_image.mean()