from cProfile import label
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

def uniform_mix_C(mixing_ratio, num_classes):
    """Blend a uniform matrix with the identity matrix.

    Args:
        mixing_ratio: Weight given to the uniform matrix; the identity
            matrix receives ``1 - mixing_ratio``.
        num_classes: Side length of the square matrices.

    Returns:
        A ``(num_classes, num_classes)`` NumPy array equal to
        ``mixing_ratio * U + (1 - mixing_ratio) * I`` where every entry of
        ``U`` is ``1 / num_classes``.
    """
    uniform_part = np.full((num_classes, num_classes), 1 / num_classes)
    identity_part = np.eye(num_classes)
    return mixing_ratio * uniform_part + (1 - mixing_ratio) * identity_part

class Hill(nn.Module):
    r"""Hill loss from "Robust Loss Design for Multi-Label Learning with Missing Labels".

    For binary targets the value computed per element is

    .. math::
        Loss = -y \times (1-p_{m})^\gamma\log(p_{m}) + (1-y) \times (\lambda-p){p}^2

    where :math:`p = \sigma(x)` is the sigmoid of the raw logit,
          :math:`p_{m} = \sigma(x - m)` is the sigmoid of the margin-shifted logit,
          :math:`\lambda-p` is the weighting term to down-weight the loss for possibly false negatives,
          :math:`m` is a margin parameter,
          :math:`\gamma` is a commonly used value same as Focal loss.

    .. note::
        Sigmoid will be done in loss.

    Args:
        lamb (float): Specifies the down-weight term :math:`\lambda`. Default: 1.5.
        margin (float): Margin subtracted from the logits for the positive term. Default: 1.0.
        gamma (float): Focal exponent applied to the positive term. Default: 2.0.
        reduction (str): ``'mean'`` averages over all elements, ``'sum'`` sums
            them, any other value returns the unreduced loss. Default: ``'sum'``.
    """

    def __init__(self, lamb: float = 1.5, margin: float = 1.0, gamma: float = 2.0,  reduction: str = 'sum') -> None:
        super(Hill, self).__init__()
        self.lamb = lamb
        self.margin = margin
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, targets):
        """
        Compute the Hill loss.

        Args:
            logits : The predicted logits before sigmoid with shape of :math:`(N, C)`
            targets : Multi-label binarized vector with shape of :math:`(N, C)`

        Returns:
            torch.Tensor: the reduced loss (scalar) for ``'mean'``/``'sum'``,
            otherwise the element-wise loss with the shape of ``logits``.
        """
        # Predicted probabilities: the positive term uses margin-shifted logits,
        # the negative term uses the raw logits.
        logits_margin = logits - self.margin
        pred_pos = torch.sigmoid(logits_margin)
        pred_neg = torch.sigmoid(logits)

        # Focal weight: (1 - p_m)^gamma where targets == 1, and 1 where targets == 0.
        pt = (1 - pred_pos) * targets + (1 - targets)
        focal_weight = pt ** self.gamma

        # log(sigmoid(x)) is evaluated with logsigmoid: torch.log(torch.sigmoid(x))
        # becomes -inf once the sigmoid underflows to 0, and the product with a
        # zero target would then be NaN and poison the whole reduction.
        los_pos = targets * F.logsigmoid(logits_margin)
        los_neg = (1 - targets) * -(self.lamb - pred_neg) * pred_neg ** 2

        loss = -(los_pos + los_neg) * focal_weight

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss

class SPLNC(nn.Module):
    r""" Thresholded binary cross-entropy derived from the SPLC loss of
    "Simple Loss Design for Multi-Label Learning with Missing Labels".

    With :math:`p = \sigma(x)` the per-element loss is

    .. math::
        &L^+ = -[\mathbb{I}(p\geq \tau)\log(p) + (1-\mathbb{I}(p\geq \tau))\log(1-p)]
        &L^- = -[\mathbb{I}(p\leq \tau')\log(1-p) + (1-\mathbb{I}(p\leq \tau'))\log(p)]

    where :math:`L^+` applies where ``targets == 1`` and :math:`L^-` where ``targets == 0``,
          :math:`\tau` is ``tau`` and :math:`\tau'` is ``taup``,
          :math:`\mathbb{I}(\cdot)\in\{0,1\}` is the indicator function.

    .. note::
        Sigmoid will be done in loss.
        ``change_epoch``, ``margin`` and ``gamma`` are stored but not read by
        ``forward``; the ``epoch`` argument of ``forward`` is likewise unused.

    Args:
        tau (float): threshold applied to entries with ``targets == 1``. Default: 0.6
        taup (float): threshold applied to entries with ``targets == 0``. Default: 0.6
        change_epoch (int): stored only. Default: 1
        margin (float): stored only. Default: 1
        gamma (float): stored only. Default: 2
        eps (float): lower clamp applied before taking logarithms. Default: 1e-8
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed. Default: ``'sum'``

        """

    def __init__(self,
                 tau: float = 0.6,
                 taup: float = 0.6,
                 change_epoch: int = 1,
                 margin: float = 1.0,
                 gamma: float = 2.0,
                 eps: float = 1e-8,
                 reduction: str = 'sum') -> None:
        super(SPLNC, self).__init__()
        self.tau = tau
        self.taup = taup
        self.change_epoch = change_epoch
        self.margin = margin
        self.gamma = gamma
        self.eps = eps
        self.reduction = reduction

    def forward(self, logits: torch.Tensor, targets: torch.LongTensor,
                epoch) -> torch.Tensor:
        """
        Compute the thresholded cross-entropy.

        Args:
            logits : The predicted logits before sigmoid with shape of :math:`(N, C)`
            targets : Multi-label binarized vector with shape of :math:`(N, C)`
            epoch : The epoch of current training (accepted but not used).

        Returns:
            torch.Tensor: loss
        """
        logits = logits.float()
        pred = torch.sigmoid(logits)

        # Indicator masks are built from the comparison itself so they live on
        # the same device as the logits (no hard-coded .cuda()).
        lp = (pred >= self.tau).to(pred.dtype)
        ln = (pred <= self.taup).to(pred.dtype)

        # Clamp before the log so a saturated sigmoid cannot produce -inf.
        log_p = torch.log(pred.clamp(min=self.eps))
        log_not_p = torch.log((1 - pred).clamp(min=self.eps))

        # Where the mask is 1 the label is used as given; where it is 0 the
        # opposite cross-entropy term is used instead.
        los_pos = targets * (lp * log_p + (1 - lp) * log_not_p)
        los_neg = (1 - targets) * (ln * log_not_p + (1 - ln) * log_p)

        loss = -(los_pos + los_neg)

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss


class SPLNCF(nn.Module):
    r""" Focal-margin loss with label correction, derived from the SPLC loss of
    "Simple Loss Design for Multi-Label Learning with Missing Labels".

    The margin is subtracted from the logits where ``targets == 1`` and
    :math:`p` is the sigmoid of the shifted logits.

    * For ``epoch < change_epoch`` the loss is binary cross-entropy weighted by
      :math:`(1-p)^\gamma` where ``targets == 1`` and :math:`p^\gamma` elsewhere.
    * Otherwise, entries with ``targets == 0`` use the negative term
      :math:`-p^\gamma\log(1-p)` when :math:`p\leq` ``taup`` and the positive
      term :math:`-(1-p)^\gamma\log(p)` when :math:`p >` ``taup``; entries
      with ``targets == 1`` use the positive term.

    .. note::
        Sigmoid will be done in loss.
        NOTE(review): for targets restricted to {0, 1} the positive mask is 1
        wherever ``targets == 1``, so ``tau`` does not influence the result.
        ``eps`` is stored but not read by ``forward``.

    Args:
        tau (float): threshold used for the positive mask. Default: 0.6
        taup (float): threshold used for the negative mask. Default: 0.6
        change_epoch (int): first epoch at which label correction is applied. Default: 2
        margin (float): Margin value. Default: 1
        gamma (float): Hard mining value. Default: 2
        eps (float): stored only. Default: 1e-8
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed. Default: ``'sum'``

        """

    def __init__(self,
                 tau: float = 0.60,
                 taup: float = 0.60,
                 change_epoch: int = 2,
                 margin: float = 1.0,
                 gamma: float = 2.0,
                 eps: float = 1e-8,
                 reduction: str = 'sum') -> None:
        super(SPLNCF, self).__init__()
        self.tau = tau
        self.taup = taup
        self.change_epoch = change_epoch
        self.margin = margin
        self.gamma = gamma
        self.eps = eps
        self.reduction = reduction

    def forward(self, logits: torch.Tensor, targets: torch.LongTensor,
                epoch) -> torch.Tensor:
        """
        Compute the loss for the given training epoch.

        Args:
            logits : The predicted logits before sigmoid with shape of :math:`(N, C)`
            targets : Multi-label binarized vector with shape of :math:`(N, C)`
            epoch : The epoch of current training.

        Returns:
            torch.Tensor: loss
        """
        # Subtract margin for positive logits
        logits = torch.where(targets == 1, logits - self.margin, logits)
        logits = logits.float()

        # Shared by both branches; computed once.
        pred = torch.sigmoid(logits)
        log_p = F.logsigmoid(logits)
        log_not_p = F.logsigmoid(-logits)

        if epoch < self.change_epoch:
            # Warm-up: focal-weighted binary cross-entropy, labels used as given.
            pt = (1 - pred) * targets + pred * (1 - targets)
            focal_weight = pt ** self.gamma

            los_pos = targets * log_p
            los_neg = (1 - targets) * log_not_p

            loss = -(los_pos + los_neg) * focal_weight
        else:
            # ones_like keeps the constant on the device of the logits instead
            # of forcing CUDA, so the loss also runs on CPU-only hosts.
            one = torch.ones_like(pred)
            targets_f = targets.float()
            lp = torch.where(pred > self.tau, one, targets_f)
            ln = torch.where(pred <= self.taup, one, targets_f)

            # Focal factors for the positive and negative terms.
            pt = (1 - pred) ** self.gamma
            pn = pred ** self.gamma

            los_pos = targets * (lp * pt * log_p + (1 - lp) * pn * log_not_p)
            los_neg = (1 - targets) * (ln * pn * log_not_p + (1 - ln) * pt * log_p)
            loss = -(los_pos + los_neg)

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss

class SPLNCFW(nn.Module):
    r""" Focal-margin loss with soft (``beta``-weighted) label correction,
    derived from the SPLC loss of
    "Simple Loss Design for Multi-Label Learning with Missing Labels".

    The margin is subtracted from the logits where ``targets == 1`` and
    :math:`p` is the sigmoid of the shifted logits.  With
    :math:`\ell^+ = (1-p)^\gamma\log(p)` and :math:`\ell^- = p^\gamma\log(1-p)`:

    * For ``epoch < change_epoch`` the loss is binary cross-entropy weighted by
      :math:`(1-p)^\gamma` where ``targets == 1`` and :math:`p^\gamma` elsewhere.
    * Otherwise entries with ``targets == 1`` contribute :math:`-\ell^+` when
      :math:`p >` ``tau`` and :math:`-(\beta\ell^+ + (1-\beta)\ell^-)` when not;
      entries with ``targets == 0`` contribute :math:`-\ell^-` when
      :math:`p\leq` ``taup`` and :math:`-((1-\beta)\ell^+ + \beta\ell^-)` when not.

    .. note::
        Sigmoid will be done in loss.
        ``eps`` is stored but not read by ``forward``.

    Args:
        tau (float): threshold for entries with ``targets == 1``. Default: 0.6
        taup (float): threshold for entries with ``targets == 0``. Default: 0.6
        change_epoch (int): first epoch at which the blended terms are used. Default: 2
        margin (float): Margin value. Default: 1
        gamma (float): Hard mining value. Default: 2
        eps (float): stored only. Default: 1e-8
        beta (float): weight of the term that agrees with the given label in the blended loss. Default: 0.6
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed. Default: ``'sum'``

        """

    def __init__(self,
                 tau: float = 0.60,
                 taup: float = 0.60,
                 change_epoch: int = 2,
                 margin: float = 1.0,
                 gamma: float = 2.0,
                 eps: float = 1e-8,
                 beta: float = 0.6,
                 reduction: str = 'sum') -> None:
        super(SPLNCFW, self).__init__()
        self.tau, self.taup = tau, taup
        self.change_epoch = change_epoch
        self.margin, self.gamma = margin, gamma
        self.eps = eps
        self.beta = beta
        self.reduction = reduction

    def forward(self, logits: torch.Tensor, targets: torch.LongTensor,
                epoch) -> torch.Tensor:
        """
        Evaluate the loss for one batch.

        Args:
            logits : The predicted logits before sigmoid with shape of :math:`(N, C)`
            targets : Multi-label binarized vector with shape of :math:`(N, C)`
            epoch : The epoch of current training.

        Returns:
            torch.Tensor: loss
        """
        # Margin is applied to the positive entries only, then cast to float32.
        shifted = torch.where(targets == 1, logits - self.margin, logits).float()

        prob = torch.sigmoid(shifted)
        log_prob = F.logsigmoid(shifted)
        log_inv_prob = F.logsigmoid(-shifted)
        negatives = 1 - targets

        if epoch < self.change_epoch:
            # Warm-up phase: focal-weighted binary cross-entropy.
            modulating = ((1 - prob) * targets + prob * negatives) ** self.gamma
            loss = -(targets * log_prob + negatives * log_inv_prob)
            loss = loss * modulating
        else:
            # Focal factors for the positive / negative log terms.
            easy_pos = (1 - prob) ** self.gamma
            easy_neg = prob ** self.gamma

            # Beta-weighted mixtures of the two terms.
            blended_pos = self.beta * easy_pos * log_prob + (1 - self.beta) * easy_neg * log_inv_prob
            blended_neg = (1 - self.beta) * easy_pos * log_prob + self.beta * easy_neg * log_inv_prob

            # Unmixed terms, used when the prediction is on the label's side of
            # the threshold.
            plain_pos = easy_pos * log_prob
            plain_neg = easy_neg * log_inv_prob

            pos_term = torch.where(prob > self.tau, plain_pos, blended_pos)
            neg_term = torch.where(prob <= self.taup, plain_neg, blended_neg)

            loss = -(targets * pos_term + negatives * neg_term)

        if self.reduction == 'sum':
            return loss.sum()
        if self.reduction == 'mean':
            return loss.mean()
        return loss


class SPLNCFC(nn.Module):
    r""" Focal-margin loss with label correction that also prints correction
    diagnostics against a second label matrix, derived from the SPLC loss of
    "Simple Loss Design for Multi-Label Learning with Missing Labels".

    The margin is subtracted from the logits where ``targets == 1`` and
    :math:`p` is the sigmoid of the shifted logits.

    * For ``epoch < change_epoch`` the loss is binary cross-entropy weighted by
      :math:`(1-p)^\gamma` where ``targets == 1`` and :math:`p^\gamma` elsewhere.
    * Otherwise, entries with ``targets == 0`` use the negative term
      :math:`-p^\gamma\log(1-p)` when :math:`p\leq` ``taup`` and the positive
      term :math:`-(1-p)^\gamma\log(p)` when :math:`p >` ``taup``; entries
      with ``targets == 1`` use the positive term.  In this branch two counts
      derived from the masks and ``target_c`` are printed on every call.

    .. note::
        Sigmoid will be done in loss.
        ``eps`` is stored but not read by ``forward``.

    Args:
        tau (float): threshold used for the positive mask. Default: 0.6
        taup (float): threshold used for the negative mask. Default: 0.6
        change_epoch (int): first epoch at which label correction is applied. Default: 4
        margin (float): Margin value. Default: 1
        gamma (float): Hard mining value. Default: 2
        eps (float): stored only. Default: 1e-8
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed. Default: ``'sum'``

        """

    def __init__(self,
                 tau: float = 0.6,
                 taup: float = 0.6,
                 change_epoch: int = 4,
                 margin: float = 1.0,
                 gamma: float = 2.0,
                 eps: float = 1e-8,
                 reduction: str = 'sum') -> None:
        super(SPLNCFC, self).__init__()
        self.tau = tau
        self.taup = taup
        self.change_epoch = change_epoch
        self.margin = margin
        self.gamma = gamma
        self.eps = eps
        self.reduction = reduction
        # Per-call diagnostic sums; filled and emptied again inside forward().
        self.list_fp = []
        self.list_fn = []

    def forward(self, logits: torch.Tensor, targets: torch.LongTensor, target_c: torch.LongTensor,
                epoch) -> torch.Tensor:
        """
        Compute the loss and, once label correction is active, print diagnostics.

        Args:
            logits : The predicted logits before sigmoid with shape of :math:`(N, C)`
            targets : Multi-label binarized vector with shape of :math:`(N, C)`
            target_c : Reference label matrix the masks are compared against;
                presumably the clean labels with shape :math:`(N, C)` - TODO confirm with caller.
            epoch : The epoch of current training.

        Returns:
            torch.Tensor: loss
        """
        # Subtract margin for positive logits
        logits = torch.where(targets == 1, logits - self.margin, logits)
        logits = logits.float()

        # Shared by both branches; computed once.
        pred = torch.sigmoid(logits)
        log_p = F.logsigmoid(logits)
        log_not_p = F.logsigmoid(-logits)

        if epoch < self.change_epoch:
            # Warm-up: focal-weighted binary cross-entropy, labels used as given.
            pt = (1 - pred) * targets + pred * (1 - targets)
            focal_weight = pt ** self.gamma

            los_pos = targets * log_p
            los_neg = (1 - targets) * log_not_p

            loss = -(los_pos + los_neg) * focal_weight
        else:
            # ones_like keeps the constant on the device of the logits instead
            # of forcing CUDA, so the loss also runs on CPU-only hosts.
            one = torch.ones_like(pred)
            targets_f = targets.float()
            lp = torch.where(pred > self.tau, one, targets_f)
            ln = torch.where(pred <= self.taup, one, targets_f)

            # Diagnostics: the masks compared with target_c are exactly lp / ln,
            # so they are reused rather than recomputed.
            reference = target_c.detach().cpu().numpy()
            pred_s = lp.detach().cpu().numpy()
            pred_n = ln.detach().cpu().numpy()

            # Per-sample counts: mask set where the reference label is 0 (fp)
            # and negative mask set where the reference label is non-zero (fn).
            fp = np.sum(np.logical_not(reference) * pred_s, axis=1).astype('float32')
            fn = np.sum(reference * pred_n, axis=1).astype('float32')

            self.list_fp.append(np.sum(fp))
            self.list_fn.append(np.sum(fn))

            # Focal factors for the positive and negative terms.
            pt = (1 - pred) ** self.gamma
            pn = pred ** self.gamma

            los_pos = targets * (lp * pt * log_p + (1 - lp) * pn * log_not_p)
            los_neg = (1 - targets) * (ln * pn * log_not_p + (1 - ln) * pt * log_p)
            loss = -(los_pos + los_neg)

        if self.list_fn or self.list_fp:
            print("FP List: ", np.sum(self.list_fp))
            print("FN List: ", np.sum(self.list_fn))
            print("sum FP and FN: ", np.sum(self.list_fp)+np.sum(self.list_fn))
            # Total number of predictions in the batch; previously hard-coded
            # as len(pred) * 80, which is only right for 80 classes.
            print("Len Pred: ", pred.numel())
            self.list_fn = []
            self.list_fp = []

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss

class SPLNCB(nn.Module):
    r""" Binary cross-entropy with label correction (no margin, no focal
    weighting), derived from the SPLC loss of
    "Simple Loss Design for Multi-Label Learning with Missing Labels".

    With :math:`p = \sigma(x)`:

    * For ``epoch < change_epoch`` the loss is plain binary cross-entropy.
    * Otherwise, entries with ``targets == 0`` use :math:`-\log(1-p)` when
      :math:`p\leq` ``taup`` and :math:`-\log(p)` when :math:`p >` ``taup``;
      entries with ``targets == 1`` use :math:`-\log(p)`.

    .. note::
        Sigmoid will be done in loss.
        NOTE(review): for targets restricted to {0, 1} the positive mask is 1
        wherever ``targets == 1``, so ``tau`` does not influence the result.
        ``margin``, ``gamma`` and ``eps`` are stored but not read by ``forward``.

    Args:
        tau (float): threshold used for the positive mask. Default: 0.6
        taup (float): threshold used for the negative mask. Default: 0.6
        change_epoch (int): first epoch at which label correction is applied. Default: 2
        margin (float): stored only. Default: 1
        gamma (float): stored only. Default: 2
        eps (float): stored only. Default: 1e-8
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed. Default: ``'sum'``

        """

    def __init__(self,
                 tau: float = 0.6,
                 taup: float = 0.6,
                 change_epoch: int = 2,
                 margin: float = 1.0,
                 gamma: float = 2.0,
                 eps: float = 1e-8,
                 reduction: str = 'sum') -> None:
        super(SPLNCB, self).__init__()
        self.tau = tau
        self.taup = taup
        self.change_epoch = change_epoch
        self.margin = margin
        self.gamma = gamma
        self.eps = eps
        self.reduction = reduction

    def forward(self, logits: torch.Tensor, targets: torch.LongTensor,
                epoch) -> torch.Tensor:
        """
        Compute the loss for the given training epoch.

        Args:
            logits : The predicted logits before sigmoid with shape of :math:`(N, C)`
            targets : Multi-label binarized vector with shape of :math:`(N, C)`
            epoch : The epoch of current training.

        Returns:
            torch.Tensor: loss
        """
        logits = logits.float()

        # Shared by both branches; computed once.
        log_p = F.logsigmoid(logits)
        log_not_p = F.logsigmoid(-logits)

        if epoch < self.change_epoch:
            # Warm-up: plain binary cross-entropy, labels used as given.
            los_pos = targets * log_p
            los_neg = (1 - targets) * log_not_p
        else:
            pred = torch.sigmoid(logits)

            # ones_like keeps the constant on the device of the logits instead
            # of forcing CUDA, so the loss also runs on CPU-only hosts.
            one = torch.ones_like(pred)
            targets_f = targets.float()
            lp = torch.where(pred > self.tau, one, targets_f)
            ln = torch.where(pred <= self.taup, one, targets_f)

            los_pos = targets * (lp * log_p + (1 - lp) * log_not_p)
            los_neg = (1 - targets) * (ln * log_not_p + (1 - ln) * log_p)

        loss = -(los_pos + los_neg)

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss


class SPLNCBK(nn.Module):
    r"""SPLC-style binary cross-entropy with a predicted-positive count regulariser.

    Before ``change_epoch`` the loss is plain element-wise binary
    cross-entropy on the logits. From ``change_epoch`` onwards the negative
    term is corrected with the current predictions, following the SPLC idea
    from "Simple Loss Design for Multi-Label Learning with Missing Labels":

    .. math::
        &L^+ = loss^+(p)
        &L^- = \mathbb{I}(p\leq \tau')loss^-(p) + (1-\mathbb{I}(p\leq \tau'))loss^+(p)

    and a scalar penalty

    .. math::
        reg = \left(\frac{\#\{p > \tau\} - k_L}{N}\right)^2

    scaled by ``lamda`` is added to every element of the loss, where
    :math:`N` is the size of the first dimension of ``logits``.

    .. note::
        Sigmoid is applied inside the loss.
        The penalty is built from a hard threshold, so it contributes no
        gradient; it only shifts the loss value.
        ``margin``, ``gamma`` and ``eps`` are stored but not used by
        :meth:`forward`.

    Args:
        tau (float): threshold above which a prediction is treated as
            positive (label correction and count penalty). Default: 0.6
        taup (float): threshold at or below which a negative is kept as a
            negative. Default: 0.6
        change_epoch (int): first epoch at which the correction and the
            penalty are applied. Default: 2
        margin (float): stored only. Default: 1.0
        gamma (float): stored only. Default: 2.0
        eps (float): stored only. Default: 1e-8
        kL (float): target number of above-threshold predictions used by the
            penalty. Default: 2.9
        lamda (float): weight of the penalty. Default: 2.0
        reduction (string, optional): ``'none'`` | ``'mean'`` | ``'sum'``.
            Any value other than ``'mean'`` or ``'sum'`` returns the
            unreduced loss. Default: ``'sum'``
    """

    def __init__(self,
                 tau: float = 0.6,
                 taup: float = 0.6,
                 change_epoch: int = 2,
                 margin: float = 1.0,
                 gamma: float = 2.0,
                 eps: float = 1e-8,
                 kL: float = 2.9,
                 lamda: float = 2.0,
                 reduction: str = 'sum') -> None:
        super(SPLNCBK, self).__init__()
        self.tau = tau
        self.taup = taup
        self.change_epoch = change_epoch
        self.margin = margin
        self.gamma = gamma
        self.eps = eps
        self.kL = kL
        self.lamda = lamda
        self.reduction = reduction

    def forward(self, logits: torch.Tensor, targets: torch.LongTensor,
                epoch) -> torch.Tensor:
        """
        Compute the loss for one batch.

        Args:
            logits : The predicted logits before sigmoid with shape of :math:`(N, C)`
            targets : Multi-label binarized vector with shape of :math:`(N, C)`
            epoch : The epoch of current training.

        Returns:
            torch.Tensor: loss, reduced according to ``self.reduction``
        """
        logits = logits.float()
        log_p = F.logsigmoid(logits)        # log(sigmoid(x))
        log_not_p = F.logsigmoid(-logits)   # log(1 - sigmoid(x))

        if epoch < self.change_epoch:
            # Warm-up: plain binary cross-entropy.
            loss = -(targets * log_p + (1 - targets) * log_not_p)
        else:
            pred = torch.sigmoid(logits)
            targets_f = targets.float()
            # Built from `pred` so the constants live on the same device as
            # the logits instead of being forced onto CUDA.
            ones = torch.ones_like(pred)
            lp = torch.where(pred > self.tau, ones, targets_f)
            ln = torch.where(pred <= self.taup, ones, targets_f)

            # Count of above-threshold predictions, kept in torch: converting
            # a CUDA tensor with np.array() raises.
            num_kL = (pred > self.tau).sum().to(pred.dtype)
            reg = ((num_kL - self.kL) / pred.shape[0]) ** 2

            los_pos = targets * (lp * log_p + (1 - lp) * log_not_p)
            los_neg = (1 - targets) * (ln * log_not_p + (1 - ln) * log_p)
            # `reg` is a scalar and is broadcast onto every element.
            loss = -(los_pos + los_neg) + self.lamda * reg

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss


class SPCBCM(nn.Module):
    r"""Binary cross-entropy re-weighted by a fixed class-mixing matrix.

    A mixing matrix ``C = uniform_mix_C(0.5, num_class)`` is built and, for
    every class :math:`i`, the off-diagonal mass of its column and row is
    used as a per-class weight:

    .. math::
        &p^{fp}_i = \sum_{j \neq i} C_{ji}, \quad p^{tp}_i = 1 - p^{fp}_i
        &p^{fn}_i = \sum_{j \neq i} C_{ij}, \quad p^{tn}_i = 1 - p^{fn}_i

    .. math::
        &L^+ = -y\,(p^{tp}\log p + p^{fp}\log(1-p))
        &L^- = -(1-y)\,(p^{tn}\log(1-p) + p^{fn}\log p)

    .. note::
        Sigmoid is applied inside the loss.
        ``tau``, ``taup``, ``change_epoch``, ``margin``, ``gamma`` and
        ``eps`` are stored but not used by :meth:`forward`; the ``epoch``
        argument is accepted for interface parity with the other losses and
        is ignored.

    Args:
        tau (float): stored only. Default: 0.6
        taup (float): stored only. Default: 0.6
        change_epoch (int): stored only. Default: 2
        margin (float): stored only. Default: 1.0
        gamma (float): stored only. Default: 2.0
        eps (float): stored only. Default: 1e-8
        num_class (int): number of classes; must equal the last dimension
            of ``logits`` for the weights to broadcast. Default: 38
        reduction (string, optional): ``'none'`` | ``'mean'`` | ``'sum'``.
            Any value other than ``'mean'`` or ``'sum'`` returns the
            unreduced loss. Default: ``'sum'``
    """

    def __init__(self,
                 tau: float = 0.6,
                 taup: float = 0.6,
                 change_epoch: int = 2,
                 margin: float = 1.0,
                 gamma: float = 2.0,
                 eps: float = 1e-8,
                 num_class: int = 38,
                 reduction: str = 'sum') -> None:
        super(SPCBCM, self).__init__()
        self.tau = tau
        self.taup = taup
        self.change_epoch = change_epoch
        self.margin = margin
        self.gamma = gamma
        self.eps = eps
        self.reduction = reduction
        self.num_class = num_class

    def forward(self, logits: torch.Tensor, targets: torch.LongTensor,
                epoch) -> torch.Tensor:
        """
        Compute the loss for one batch.

        Args:
            logits : The predicted logits before sigmoid with shape of :math:`(N, C)`
            targets : Multi-label binarized vector with shape of :math:`(N, C)`
            epoch : The epoch of current training (unused).

        Returns:
            torch.Tensor: loss, reduced according to ``self.reduction``
        """
        logits = logits.float()

        # Per-class off-diagonal column/row sums of the mixing matrix,
        # computed in one vectorised step instead of a C x C Python loop.
        C = uniform_mix_C(0.5, self.num_class)
        off_diag = C - np.diag(np.diag(C))
        p_fp = off_diag.sum(axis=0)
        p_fn = off_diag.sum(axis=1)

        def as_row(values):
            # Shape (1, C), on the logits' device and in the logits' dtype so
            # the loss neither requires CUDA nor gets promoted to float64.
            return torch.as_tensor(values, dtype=logits.dtype,
                                   device=logits.device).unsqueeze(0)

        t_p_fp = as_row(p_fp)
        t_p_fn = as_row(p_fn)
        t_p_tp = as_row(1 - p_fp)
        t_p_tn = as_row(1 - p_fn)

        log_p = F.logsigmoid(logits)
        log_not_p = F.logsigmoid(-logits)
        los_pos = targets * (t_p_tp * log_p + t_p_fp * log_not_p)
        los_neg = (1 - targets) * (t_p_tn * log_not_p + t_p_fn * log_p)
        loss = -(los_pos + los_neg)

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss


class SPLNCT(nn.Module):
    r"""SPLC-style label-corrected loss built on a tanh-bounded cross-entropy.

    Each cross-entropy term :math:`-\log q` is replaced by the bounded form

    .. math::
        g(q) = a \tanh(-\log q - b) + c

    with fixed constants ``a = 4.2``, ``b = 2.5``, ``c = 4.2``. Before
    ``change_epoch`` the loss is :math:`y\,g(p) + (1-y)\,g(1-p)`. From
    ``change_epoch`` onwards the negative term is corrected with the current
    predictions, following the SPLC idea from "Simple Loss Design for
    Multi-Label Learning with Missing Labels":

    .. math::
        &L^+ = g(p)
        &L^- = \mathbb{I}(p\leq \tau')\,g(1-p) + (1-\mathbb{I}(p\leq \tau'))\,g(p)

    .. note::
        Sigmoid is applied inside the loss.
        ``margin``, ``gamma`` and ``eps`` are stored but not used by
        :meth:`forward`.

    Args:
        tau (float): threshold above which a prediction is treated as
            positive. Default: 0.6
        taup (float): threshold at or below which a negative is kept as a
            negative. Default: 0.6
        change_epoch (int): first epoch at which the correction is applied.
            Default: 2
        margin (float): stored only. Default: 1.0
        gamma (float): stored only. Default: 2.0
        eps (float): stored only. Default: 1e-8
        reduction (string, optional): ``'none'`` | ``'mean'`` | ``'sum'``.
            Any value other than ``'mean'`` or ``'sum'`` returns the
            unreduced loss. Default: ``'sum'``
    """

    def __init__(self,
                 tau: float = 0.6,
                 taup: float = 0.6,
                 change_epoch: int = 2,
                 margin: float = 1.0,
                 gamma: float = 2.0,
                 eps: float = 1e-8,
                 reduction: str = 'sum') -> None:
        super(SPLNCT, self).__init__()
        self.tau = tau
        self.taup = taup
        self.change_epoch = change_epoch
        self.margin = margin
        self.gamma = gamma
        self.eps = eps
        self.reduction = reduction
        # Constants of the bounded loss g(q) = a * tanh(-log q - b) + c.
        self.a = 4.2
        self.b = 2.5
        self.c = 4.2

    def _bounded(self, log_prob: torch.Tensor) -> torch.Tensor:
        """Map a log-probability to the tanh-bounded loss ``a*tanh(-log_prob - b) + c``."""
        return self.a * torch.tanh(-log_prob - self.b) + self.c

    def forward(self, logits: torch.Tensor, targets: torch.LongTensor,
                epoch) -> torch.Tensor:
        """
        Compute the loss for one batch.

        Args:
            logits : The predicted logits before sigmoid with shape of :math:`(N, C)`
            targets : Multi-label binarized vector with shape of :math:`(N, C)`
            epoch : The epoch of current training.

        Returns:
            torch.Tensor: loss, reduced according to ``self.reduction``
        """
        logits = logits.float()
        pos_term = self._bounded(F.logsigmoid(logits))    # g(p)
        neg_term = self._bounded(F.logsigmoid(-logits))   # g(1 - p)

        if epoch < self.change_epoch:
            # Warm-up: no label correction.
            loss = targets * pos_term + (1 - targets) * neg_term
        else:
            pred = torch.sigmoid(logits)
            targets_f = targets.float()
            # Built from `pred` so the constants live on the same device as
            # the logits instead of being forced onto CUDA.
            ones = torch.ones_like(pred)
            lp = torch.where(pred > self.tau, ones, targets_f)
            ln = torch.where(pred <= self.taup, ones, targets_f)

            los_pos = targets * (lp * pos_term + (1 - lp) * neg_term)
            los_neg = (1 - targets) * (ln * neg_term + (1 - ln) * pos_term)
            loss = los_pos + los_neg

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss

class SPLNCFASL(nn.Module):
    r"""SPLC-style label-corrected loss with asymmetric focal weighting.

    The margin is subtracted from the logits of positive targets, then the
    positive term is weighted by :math:`(1-p)^{\gamma_{pos}}` and the
    negative term by :math:`p^{\gamma_{neg}}`. From ``change_epoch`` onwards
    the negative term is corrected with the current predictions, following
    the SPLC idea from "Simple Loss Design for Multi-Label Learning with
    Missing Labels":

    .. math::
        &L^+ = loss^+(p)
        &L^- = \mathbb{I}(p\leq \tau')loss^-(p) + (1-\mathbb{I}(p\leq \tau'))loss^+(p)

    where :math:`\tau'` is ``taup``,
          :math:`\mathbb{I}(\cdot)\in\{0,1\}` is the indicator function,
          :math:`loss^+(\cdot), loss^-(\cdot)` are the focal-weighted losses
          for positives and negatives, respectively.

    .. note::
        Sigmoid is applied inside the loss.
        ``clip``, ``eps`` and ``disable_torch_grad_focal_loss`` are stored
        but not used by :meth:`forward`.

    Args:
        gamma_neg (float): focusing exponent for the negative term. Default: 4
        gamma_pos (float): focusing exponent for the positive term. Default: 2
        clip (float): stored only. Default: 0.05
        eps (float): stored only. Default: 1e-8
        tau (float): threshold above which a prediction is treated as
            positive. Default: 0.6
        taup (float): threshold at or below which a negative is kept as a
            negative. Default: 0.6
        change_epoch (int): first epoch at which the correction is applied.
            Default: 2
        margin (float): value subtracted from the logits of positive
            targets. Default: 1.0
        reduction (string, optional): ``'none'`` | ``'mean'`` | ``'sum'``.
            Any value other than ``'mean'`` or ``'sum'`` returns the
            unreduced loss. Default: ``'sum'``
        disable_torch_grad_focal_loss (bool): stored only. Default: True
    """

    def __init__(self, gamma_neg=4, gamma_pos=2, clip=0.05, eps=1e-8, tau=0.6, taup=0.6, change_epoch=2, margin=1.0, reduction='sum', disable_torch_grad_focal_loss=True):
        super(SPLNCFASL, self).__init__()
        self.tau = tau
        self.taup = taup
        self.change_epoch = change_epoch
        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.margin = margin
        self.eps = eps
        self.reduction = reduction

    def forward(self, logits: torch.Tensor, targets: torch.LongTensor,
                epoch) -> torch.Tensor:
        """
        Compute the loss for one batch.

        Args:
            logits : The predicted logits before sigmoid with shape of :math:`(N, C)`
            targets : Multi-label binarized vector with shape of :math:`(N, C)`
            epoch : The epoch of current training.

        Returns:
            torch.Tensor: loss, reduced according to ``self.reduction``
        """
        # Subtract margin for positive logits
        logits = torch.where(targets == 1, logits - self.margin, logits)
        logits = logits.float()

        pred = torch.sigmoid(logits)
        log_p = F.logsigmoid(logits)        # log(sigmoid(x))
        log_not_p = F.logsigmoid(-logits)   # log(1 - sigmoid(x))

        if epoch < self.change_epoch:
            # Warm-up: asymmetric focal loss on the given labels. The masks
            # are applied before the power, as in the correction-free form.
            pos_weight = ((1 - pred) * targets) ** self.gamma_pos
            neg_weight = (pred * (1 - targets)) ** self.gamma_neg

            los_pos = targets * log_p
            los_neg = (1 - targets) * log_not_p
            loss = -los_pos * pos_weight - los_neg * neg_weight
        else:
            targets_f = targets.float()
            # Built from `pred` so the constants live on the same device as
            # the logits instead of being forced onto CUDA.
            ones = torch.ones_like(pred)
            lp = torch.where(pred > self.tau, ones, targets_f)
            ln = torch.where(pred <= self.taup, ones, targets_f)

            pt = (1 - pred) ** self.gamma_pos
            pn = pred ** self.gamma_neg

            los_pos = targets * (lp * pt * log_p + (1 - lp) * pn * log_not_p)
            los_neg = (1 - targets) * (ln * pn * log_not_p + (1 - ln) * pt * log_p)
            loss = -(los_pos + los_neg)

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss
    
    
    # def forward(self, x, y, epoch):
    #     """
    #     call function as forward

    #     Args:
    #         logits : The predicted logits before sigmoid with shape of :math:`(N, C)`
    #         targets : Multi-label binarized vector with shape of :math:`(N, C)`
    #         epoch : The epoch of current training.

    #     Returns:
    #         torch.Tensor: loss
    #     """
    #     # Subtract margin for positive logits
    #     if epoch < self.change_epoch:
    #         x_sigmoid = torch.sigmoid(x)
    #         xs_pos = x_sigmoid
    #         xs_neg = 1 - x_sigmoid

    #         # Asymmetric Clipping
    #         if self.clip is not None and self.clip > 0:
    #             xs_neg = (xs_neg + self.clip).clamp(max=1)
            
            
    #         los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
    #         los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
    #         loss = los_pos + los_neg

    #         # Asymmetric Focusing
    #         if self.gamma_neg > 0 or self.gamma_pos > 0:
    #             if self.disable_torch_grad_focal_loss:
    #                 torch.set_grad_enabled(False)
    #             pt0 = xs_pos * y
    #             pt1 = xs_neg * (1 - y)  # pt = p if t > 0 else 1-p
    #             pt = pt0 + pt1
    #             one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
    #             one_sided_w = torch.pow(1 - pt, one_sided_gamma)
    #             if self.disable_torch_grad_focal_loss:
    #                 torch.set_grad_enabled(True)
    #             loss *= one_sided_w

    #         return -loss.sum()
    #     else:
    #         x_sigmoid = torch.sigmoid(x)
    #         xs_pos = x_sigmoid
    #         xs_neg = 1 - x_sigmoid

    #         # Asymmetric Clipping
    #         if self.clip is not None and self.clip > 0:
    #             xs_neg = (xs_neg + self.clip).clamp(max=1)            

    #         lp = torch.where(torch.sigmoid(x) > self.tau, torch.tensor(1.).cuda(), y.float())
    #         ln = torch.where(torch.sigmoid(x) <= self.taup, torch.tensor(1.).cuda(), y.float())
            
    #         # # # SPLC missing label correction
    #         # los_pos = targets * (lp * pt * F.logsigmoid(logits) + (1 - lp) * pn * F.logsigmoid(-logits))
    #         # los_neg = (1 - targets) * (ln * pn * F.logsigmoid(-logits) + (1 - ln) * pt * F.logsigmoid(logits))
    #         los_pos = y * (lp * torch.log(xs_pos.clamp(min=self.eps)) + (1 - lp) * torch.log(xs_neg.clamp(min=self.eps)))
    #         los_neg = (1 - y) * (ln * torch.log(xs_neg.clamp(min=self.eps)) + (1 - ln) * torch.log(xs_pos.clamp(min=self.eps)))
        
    #         loss = los_pos + los_neg
    #         # Asymmetric Focusing
    #         if self.gamma_neg > 0 or self.gamma_pos > 0:
    #             if self.disable_torch_grad_focal_loss:
    #                 torch.set_grad_enabled(False)
    #             pt0 = xs_pos * y
    #             pt1 = xs_neg * (1 - y)  # pt = p if t > 0 else 1-p
    #             pt = pt0 + pt1
    #             one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
    #             one_sided_w = torch.pow(1 - pt, one_sided_gamma)
    #             if self.disable_torch_grad_focal_loss:
    #                 torch.set_grad_enabled(True)
    #             loss *= one_sided_w

    #         return -loss.sum()

        # los_pos = targets * (lp * torch.log(pred.clamp(min=self.eps)) + (1 - lp) * torch.log((1 - pred).clamp(min=self.eps)))
        # los_neg = (1 - targets) * ((ln * torch.log((1 - pred).clamp(min=self.eps))) + (1 - ln) * torch.log(pred.clamp(min=self.eps)))
        
        # Focal margin for postive loss
        # pt = (1 - pred) * targets + pred * (1 - targets)
        # focal_weight = pt**self.gamma

        # los_pos = targets * F.logsigmoid(logits)
        # los_neg = (1 - targets) * F.logsigmoid(-logits)

            
        # loss *= focal_weight



class SPLC(nn.Module):
    r""" SPLC loss as described in the paper "Simple Loss Design for Multi-Label Learning with Missing Labels"

    .. math::
        &L_{SPLC}^+ = loss^+(p)
        &L_{SPLC}^- = \mathbb{I}(p\leq \tau)loss^-(p) + (1-\mathbb{I}(p\leq \tau))loss^+(p)

    where :math:`\tau` is a threshold to identify missing labels,
          :math:`\mathbb{I}(\cdot)\in\{0,1\}` is the indicator function,
          :math:`loss^+(\cdot), loss^-(\cdot)` refer to loss functions for positives and negatives, respectively.

    .. note::
        SPLC can be combined with various multi-label loss functions.
        SPLC performs best combined with Focal margin loss in our paper. Code of SPLC with Focal margin loss is released here.
        Since the first epoch can recall few missing labels with high precision, SPLC can be used after the first epoch.
        Sigmoid will be done in loss.

    Args:
        tau (float): threshold value. Default: 0.6
        change_epoch (int): which epoch to combine SPLC. Default: 1
        margin (float): Margin value. Default: 1
        gamma (float): Hard mining value. Default: 2
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed. Default: ``'sum'``

        """

    def __init__(self,
                 tau: float = 0.6,
                 change_epoch: int = 1,
                 margin: float = 1.0,
                 gamma: float = 2.0,
                 reduction: str = 'sum') -> None:
        super(SPLC, self).__init__()
        self.tau = tau
        self.change_epoch = change_epoch
        self.margin = margin
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits: torch.Tensor, targets: torch.LongTensor,
                epoch) -> torch.Tensor:
        """
        Compute the loss for one batch.

        Args:
            logits : The predicted logits before sigmoid with shape of :math:`(N, C)`
            targets : Multi-label binarized vector with shape of :math:`(N, C)`
            epoch : The epoch of current training.

        Returns:
            torch.Tensor: loss, reduced according to ``self.reduction``
        """
        # Subtract margin for positive logits
        logits = torch.where(targets == 1, logits - self.margin, logits)
        logits = logits.float()
        pred = torch.sigmoid(logits)

        # SPLC missing label correction: predictions above tau are treated as
        # positives. `ones_like(pred)` keeps the constant on the logits'
        # device instead of forcing it onto CUDA.
        if epoch >= self.change_epoch:
            targets = torch.where(pred > self.tau, torch.ones_like(pred), targets.float())

        # Focal margin for positive loss
        pt = (1 - pred) * targets + pred * (1 - targets)
        focal_weight = pt ** self.gamma

        los_pos = targets * F.logsigmoid(logits)
        los_neg = (1 - targets) * F.logsigmoid(-logits)

        loss = -(los_pos + los_neg) * focal_weight

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss


class AsymmetricLoss(nn.Module):
    """Asymmetric focal loss for multi-label classification.

    Binary cross-entropy in which the negative probability is shifted by
    ``clip`` (and clamped to 1) and each term is weighted by
    ``(1 - pt) ** gamma``, with ``gamma_pos`` for positive targets and
    ``gamma_neg`` for negative targets. The summed loss is returned.

    Args:
        gamma_neg: focusing exponent for negative targets. Default: 4
        gamma_pos: focusing exponent for positive targets. Default: 1
        clip: shift added to the negative probability; ``None`` or ``0``
            disables it. Default: 0.05
        eps: lower clamp applied before taking logs. Default: 1e-8
        disable_torch_grad_focal_loss: when true, the focusing weight is
            computed without gradient tracking. Default: True
    """

    def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=True):
        super(AsymmetricLoss, self).__init__()

        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps

    def forward(self, x, y):
        """
        Parameters
        ----------
        x: input logits
        y: targets (multi-label binarized vector)

        Returns
        -------
        Scalar tensor: the negated sum of the weighted log-likelihood terms.
        """

        # Calculating Probabilities
        x_sigmoid = torch.sigmoid(x)
        xs_pos = x_sigmoid
        xs_neg = 1 - x_sigmoid

        # Asymmetric Clipping
        if self.clip is not None and self.clip > 0:
            xs_neg = (xs_neg + self.clip).clamp(max=1)

        # Basic CE calculation
        los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
        los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
        loss = los_pos + los_neg

        # Asymmetric Focusing
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            # Scoped grad switch: the caller's grad mode is restored on exit,
            # so a surrounding torch.no_grad() is not silently undone.
            track_grad = torch.is_grad_enabled() and not self.disable_torch_grad_focal_loss
            with torch.set_grad_enabled(track_grad):
                pt0 = xs_pos * y
                pt1 = xs_neg * (1 - y)  # pt = p if t > 0 else 1-p
                pt = pt0 + pt1
                one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
                one_sided_w = torch.pow(1 - pt, one_sided_gamma)
            loss = loss * one_sided_w

        return -loss.sum()


class AsymmetricLossOptimized(nn.Module):
    ''' Notice - optimized version, minimizes memory allocation and gpu uploading,
    favors inplace operations

    Computes the same asymmetric focal loss as ``AsymmetricLoss`` but keeps
    its intermediates on the module and updates them in place.

    Args:
        gamma_neg: focusing exponent for negative targets. Default: 4
        gamma_pos: focusing exponent for positive targets. Default: 1
        clip: shift added to the negative probability; ``None`` or ``0``
            disables it. Default: 0.05
        eps: lower clamp applied before taking logs. Default: 1e-8
        disable_torch_grad_focal_loss: when true, the focusing weight is
            computed without gradient tracking. Default: False
    '''

    def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False):
        super(AsymmetricLossOptimized, self).__init__()

        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps

        # prevent memory allocation and gpu uploading every iteration, and encourages inplace operations
        self.targets = self.anti_targets = self.xs_pos = self.xs_neg = self.asymmetric_w = self.loss = None

    def forward(self, x, y):
        """
        Parameters
        ----------
        x: input logits
        y: targets (multi-label binarized vector)

        Returns
        -------
        Scalar tensor: the negated sum of the weighted log-likelihood terms.
        """

        self.targets = y
        self.anti_targets = 1 - y

        # Calculating Probabilities
        self.xs_pos = torch.sigmoid(x)
        self.xs_neg = 1.0 - self.xs_pos

        # Asymmetric Clipping
        if self.clip is not None and self.clip > 0:
            self.xs_neg.add_(self.clip).clamp_(max=1)

        # Basic CE calculation
        self.loss = self.targets * torch.log(self.xs_pos.clamp(min=self.eps))
        self.loss.add_(self.anti_targets * torch.log(self.xs_neg.clamp(min=self.eps)))

        # Asymmetric Focusing
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            # Scoped grad switch: the caller's grad mode is restored on exit,
            # so a surrounding torch.no_grad() is not silently undone.
            track_grad = torch.is_grad_enabled() and not self.disable_torch_grad_focal_loss
            with torch.set_grad_enabled(track_grad):
                self.xs_pos = self.xs_pos * self.targets
                self.xs_neg = self.xs_neg * self.anti_targets
                self.asymmetric_w = torch.pow(1 - self.xs_pos - self.xs_neg,
                                              self.gamma_pos * self.targets + self.gamma_neg * self.anti_targets)
            self.loss *= self.asymmetric_w

        return -self.loss.sum()


class ASLSingleLabel(nn.Module):
    '''
    This loss is intended for single-label classification problems.

    The log-softmax of the logits is re-weighted element-wise by an asymmetric
    focusing term and then combined with (optionally label-smoothed) one-hot
    targets.

    Args:
        gamma_pos: focusing exponent applied at the target class. Default: 0
        gamma_neg: focusing exponent applied at all other classes. Default: 4
        eps (float): label-smoothing factor; smoothing is skipped when it is
            not positive. Default: 0.1
        reduction: ``'mean'`` averages the per-sample losses, ``'sum'`` adds
            them up, any other value returns the per-sample losses unreduced.
            Default: ``'mean'``
    '''
    def __init__(self, gamma_pos=0, gamma_neg=4, eps: float = 0.1, reduction='mean'):
        super(ASLSingleLabel, self).__init__()

        self.eps = eps
        self.logsoftmax = nn.LogSoftmax(dim=-1)
        # Overwritten by every forward() call with the (smoothed) one-hot targets.
        self.targets_classes = []
        self.gamma_pos = gamma_pos
        self.gamma_neg = gamma_neg
        self.reduction = reduction

    def forward(self, inputs, target):
        '''
        "input" dimensions: - (batch_size,number_classes)
        "target" dimensions: - (batch_size)

        Returns the reduced loss (scalar) for ``'mean'`` / ``'sum'``, otherwise
        a tensor with one loss value per sample.
        '''
        num_classes = inputs.size()[-1]
        log_preds = self.logsoftmax(inputs)
        # One-hot encode the class indices along dim 1.
        self.targets_classes = torch.zeros_like(inputs).scatter_(1, target.long().unsqueeze(1), 1)

        # ASL weights: base is (1 - p) at the target class and p elsewhere,
        # raised to gamma_pos / gamma_neg respectively.
        targets = self.targets_classes
        anti_targets = 1 - targets
        probs = torch.exp(log_preds)
        xs_pos = probs * targets
        xs_neg = (1 - probs) * anti_targets
        asymmetric_w = torch.pow(1 - xs_pos - xs_neg,
                                 self.gamma_pos * targets + self.gamma_neg * anti_targets)
        log_preds = log_preds * asymmetric_w

        if self.eps > 0:  # label smoothing
            self.targets_classes = self.targets_classes.mul(1 - self.eps).add(self.eps / num_classes)

        # loss calculation: one value per sample
        loss = -(self.targets_classes.mul(log_preds)).sum(dim=-1)

        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            # Previously 'sum' fell through and returned the unreduced vector.
            return loss.sum()
        return loss

class GLC(nn.Module):
    """Asymmetric multi-label loss computed on class-mixed sigmoid probabilities.

    The sigmoid of the logits is multiplied by a fixed
    ``(num_classes, num_classes)`` matrix ``C`` (built with ``uniform_mix_C``)
    before asymmetric clipping and focusing are applied.

    Args:
        gamma_neg: focusing exponent applied where the target is 0. Default: 4
        gamma_pos: focusing exponent applied where the target is 1. Default: 1
        clip: value added to the negative probability (result capped at 1);
            ``None`` or a non-positive value skips clipping. Default: 0.05
        eps: lower clamp applied to probabilities before the log. Default: 1e-8
        disable_torch_grad_focal_loss: compute the focusing weights with
            autograd switched off. Default: True
        num_classes: number of classes, i.e. the side length of ``C``.
            Default: 38 (the value that used to be hard-coded)
        mixing_ratio: mixing ratio handed to ``uniform_mix_C``.
            Default: 0.5 (the value that used to be hard-coded)
    """

    def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=True,
                 num_classes=38, mixing_ratio=0.5):
        super(GLC, self).__init__()

        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps
        # Kept as a NumPy array; forward() converts it on the input's device.
        self.C = uniform_mix_C(mixing_ratio, num_classes)

    def forward(self, x, y):
        """Compute the summed loss for a batch.

        Parameters
        ----------
        x: input logits; the last dimension must match the size of ``C``
        y: targets (multi-label binarized vector)

        Returns
        -------
        Scalar tensor: the negated sum of the weighted log-likelihood terms.
        """
        # Build the mixing matrix on the same device as the logits instead of
        # forcing CUDA, so the loss also works on CPU or a non-default GPU.
        # NOTE: the matrix is used un-transposed. The earlier code transposed
        # it into a temporary that was immediately overwritten, so this matches
        # the behaviour that was actually in effect.
        C_hat = torch.as_tensor(self.C, dtype=torch.float32, device=x.device)

        # Calculating Probabilities
        xs_pos = torch.matmul(torch.sigmoid(x).float(), C_hat)
        xs_neg = 1 - xs_pos

        # Asymmetric Clipping
        if self.clip is not None and self.clip > 0:
            xs_neg = (xs_neg + self.clip).clamp(max=1)

        # Basic CE calculation
        los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
        los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
        loss = los_pos + los_neg

        # Asymmetric Focusing
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            # Scoped grad-mode switch that restores the caller's mode on exit;
            # unconditionally re-enabling grad would turn autograd back on
            # inside an enclosing torch.no_grad() block.
            weights_need_grad = torch.is_grad_enabled() and not self.disable_torch_grad_focal_loss
            with torch.set_grad_enabled(weights_need_grad):
                pt = xs_pos * y + xs_neg * (1 - y)  # pt = p if t > 0 else 1-p
                one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
                one_sided_w = torch.pow(1 - pt, one_sided_gamma)
            loss = loss * one_sided_w

        return -loss.sum()

class SPBAG(nn.Module):
    """Binary cross-entropy with augmentation-based missing-label correction.

    Derived from the SPLC loss ("Simple Loss Design for Multi-Label Learning
    with Missing Labels"). The correction threshold is not a fixed ``tau``: it
    is computed per element as the mean of the sigmoid probabilities of the
    original and the augmented logits.

    * While ``epoch < change_epoch`` the loss is plain sigmoid cross-entropy.
    * Afterwards, an element whose predicted probability is strictly above the
      threshold is scored with the positive-class term even if its target is 0.

    .. note::
        Sigmoid will be done in loss.
        ``tau``, ``taup``, ``margin``, ``gamma`` and ``eps`` are stored on the
        module but are not read by ``forward`` in its current form.

    Args:
        tau (float): stored only. Default: 0.6
        taup (float): stored only. Default: 0.6
        change_epoch (int): first epoch at which the correction is applied.
            Default: 2
        margin (float): stored only. Default: 1.0
        gamma (float): stored only. Default: 2.0
        eps (float): stored only. Default: 1e-8
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'mean'``: the mean over all elements, ``'sum'``: the output will
            be summed, any other value: no reduction will be applied.
            Default: ``'sum'``
    """

    def __init__(self,
                tau: float = 0.6,
                taup: float = 0.6,
                change_epoch: int = 2,
                margin: float = 1.0,
                gamma: float = 2.0,
                eps: float = 1e-8,
                reduction: str = 'sum') -> None:
        super(SPBAG, self).__init__()
        self.tau = tau
        self.taup = taup
        self.change_epoch = change_epoch
        self.margin = margin
        self.gamma = gamma
        self.eps = eps
        self.reduction = reduction

    def forward(self, logits: torch.Tensor, logits_aug: torch.Tensor, targets: torch.LongTensor,
                epoch) -> torch.Tensor:
        """
        call function as forward

        Args:
            logits : The predicted logits before sigmoid with shape of :math:`(N, C)`
            logits_aug : Logits for the augmented view of the same samples; only
                used once ``epoch >= change_epoch``. Assumed to have the same
                shape as ``logits``.
            targets : Multi-label binarized vector with shape of :math:`(N, C)`
            epoch : The epoch of current training.

        Returns:
            torch.Tensor: loss
        """
        logits = logits.float()
        log_p = F.logsigmoid(logits)       # log(sigmoid(logits))
        log_not_p = F.logsigmoid(-logits)  # log(1 - sigmoid(logits))

        if epoch < self.change_epoch:
            # Warm-up: plain sigmoid cross-entropy against the given targets.
            loss = -(targets * log_p + (1 - targets) * log_not_p)
        else:
            pred = torch.sigmoid(logits)
            # Per-element threshold: mean of the two views' probabilities.
            thresh_aug = 0.5 * (pred + torch.sigmoid(logits_aug))

            # Build the constant from the targets so it lives on their device;
            # hard-coding .cuda() made this branch unusable on CPU.
            targets_f = targets.float()
            ones = torch.ones_like(targets_f)
            lp = torch.where(pred > thresh_aug, ones, targets_f)
            ln = torch.where(pred <= thresh_aug, ones, targets_f)

            # lp / ln choose, per element, between the positive-class term
            # (log_p) and the negative-class term (log_not_p).
            los_pos = targets * (lp * log_p + (1 - lp) * log_not_p)
            los_neg = (1 - targets) * (ln * log_not_p + (1 - ln) * log_p)
            loss = -(los_pos + los_neg)

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss
