"""
Author: Yonglong Tian (yonglong@mit.edu)
Date: May 07, 2020
"""
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F


class WeightedCrossEntropyLoss(torch.nn.Module):
    """Binary cross-entropy re-weighted by per-column label counts in the batch.

    For every column of ``targets``:
      * positive entries are weighted by ``(N / n_pos) ** power``;
      * negative entries are weighted by ``(N / n_neg) ** power``.
    Here ``N = n_pos + n_neg = targets.size(0)``, so the rarer outcome in each
    column receives the larger weight.

    Args:
        power: exponent applied to the count ratios.
    """

    def __init__(self, power=1.0):
        super(WeightedCrossEntropyLoss, self).__init__()
        self.power = power

    def forward(self, logits, targets):
        """Return the weighted BCE loss.

        Args:
            logits: raw scores with the same shape as ``targets``.
            targets: 0/1 labels; presumably (batch, num_classes) -- TODO confirm.

        Returns:
            The element-wise weighted loss, averaged over dim 0 and then over
            the last dim (a scalar for 2-D inputs).
        """
        eps = 1e-12  # keeps the ratios finite for columns with no pos/neg entries
        n_pos = targets.sum(dim=0)
        n_neg = targets.size(0) - n_pos
        n_all = n_pos + n_neg

        w_pos = (n_all / (n_pos + eps)) ** self.power
        w_neg = (n_all / (n_neg + eps)) ** self.power

        bce = F.binary_cross_entropy_with_logits(
            logits, targets.float(), reduction='none')
        weighted = w_pos * targets * bce + w_neg * (1 - targets) * bce
        per_column = weighted.mean(dim=0)
        return per_column.mean(dim=-1)
    
    
class FocalLoss(nn.Module):
    """Sigmoid focal loss: ``alpha_t * (1 - p_t) ** gamma * BCE(input, target)``.

    ``p_t`` is the predicted probability of the true label. The result is
    averaged over all elements.

    Args:
        alpha: weight for positive targets (negatives get ``1 - alpha``).
            A negative value disables the alpha weighting.
        gamma: focusing exponent applied to ``1 - p_t``.
    """

    def __init__(self, alpha=0.25, gamma=2):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, input, target):
        """
        input: logits tensor of shape (batch_size, num_classes)
        target: tensor of the same shape (float, as binary_cross_entropy_with_logits requires)
        """
        assert input.shape == target.shape, "Input and target shapes must be the same."
        prob = torch.sigmoid(input)
        bce = F.binary_cross_entropy_with_logits(input, target, reduction="none")
        # Probability the model assigns to the true label of each element.
        prob_of_truth = target * prob + (1 - target) * (1 - prob)
        modulated = bce * (1 - prob_of_truth) ** self.gamma

        if self.alpha >= 0:
            weight = self.alpha * target + (1 - self.alpha) * (1 - target)
            modulated = weight * modulated
        return modulated.mean()
    
class LabelSmoothingLoss(nn.Module):
    """Cross-entropy against a label-smoothed target distribution.

    The target class receives probability ``1 - smoothing``. Every other class
    receives ``smoothing / (classes - 1)``.

    Args:
        classes: number of classes (the size of ``pred`` along ``dim``).
        smoothing: probability mass spread over the non-target classes.
        dim: class dimension of ``pred``.
    """
    def __init__(self, classes, smoothing=0.1, dim=-1):
        super(LabelSmoothingLoss, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.dim = dim

    def forward(self, pred, target):
        """Return the mean smoothed cross-entropy.

        Args:
            pred: unnormalized scores; log-softmax is applied along ``self.dim``.
            target: integer class indices with ``pred``'s shape minus ``self.dim``.

        Returns:
            Scalar mean of the per-position losses.
        """
        pred = pred.log_softmax(dim=self.dim)
        with torch.no_grad():
            # With a single class there is no "other" class to receive mass.
            off_value = self.smoothing / (self.cls - 1) if self.cls > 1 else 0.0
            true_dist = torch.full_like(pred, off_value)
            # Scatter along the same dimension that log_softmax normalized.
            # A hard-coded dim 1 is wrong for e.g. (B, T, C) input with dim=-1.
            true_dist.scatter_(self.dim, target.unsqueeze(self.dim), self.confidence)
        return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))


def unique(x, dim=None):
    """Unique elements of x and the index of the first occurrence of each.
    https://github.com/pytorch/pytorch/issues/36748#issuecomment-619514810

    e.g.

    unique(tensor([
        [1, 2, 3],
        [1, 2, 4],
        [1, 2, 3],
        [1, 2, 5]
    ]), dim=0)
    => (tensor([[1, 2, 3],
                [1, 2, 4],
                [1, 2, 5]]),
        tensor([0, 1, 3]))

    Args:
        x: input tensor.
        dim: dimension along which unique slices are taken; ``None`` treats
            ``x`` as flattened.

    Returns:
        ``(values, first_indices)``. ``values`` is the sorted output of
        ``torch.unique``. ``first_indices[k]`` is the position of the first
        occurrence of ``values[k]``: along ``dim``, or in the flattened ``x``
        when ``dim`` is None.
    """
    values, inverse = torch.unique(
        x, sorted=True, return_inverse=True, dim=dim)
    # With dim=None the inverse has x's shape; flatten so positions refer to
    # the flattened input.
    inverse = inverse.reshape(-1)
    n = inverse.size(0)
    positions = torch.arange(n, dtype=inverse.dtype, device=inverse.device)
    # Order positions by (group id, position). Every key is distinct, so the
    # result is deterministic on any device. Scatter with duplicate indices
    # has no defined write order.
    order = torch.argsort(inverse * n + positions)
    sorted_groups = inverse[order]
    # The first element of each run of equal group ids is that group's
    # earliest position.
    is_first = torch.ones_like(sorted_groups, dtype=torch.bool)
    is_first[1:] = sorted_groups[1:] != sorted_groups[:-1]
    return values, order[is_first]

class HMLMAC(nn.Module):
    """Hierarchical contrastive loss for labels of shape (B, N, L).

    This follows the same procedure as HMLC, but with an extra label axis. The
    hierarchy levels run along the last axis (L). Two samples are positives at
    a level when all of their masked (N x L) label entries match.

    Args:
        temperature: temperature forwarded to the internal SupConLoss.
        base_temperature: stored on the module; not used by this class directly.
        layer_penalty: callable mapping a float tensor ``1/l`` to a level
            weight; defaults to ``pow_2`` (i.e. ``2 ** (1/l)``).
        loss_type: 'hmc', 'hce' or 'hmce'; anything else raises
            NotImplementedError in ``forward``.
    """
    def __init__(self, temperature=0.07,
                 base_temperature=0.07, layer_penalty=None, loss_type='hmc'):
        # Must name this class: super(HMLC, self) raises TypeError here because
        # HMLMAC is not a subclass of HMLC.
        super(HMLMAC, self).__init__()
        self.temperature = temperature
        self.base_temperature = base_temperature
        if not layer_penalty:
            # default layer penalty: 2 ** value
            self.layer_penalty = self.pow_2
        else:
            self.layer_penalty = layer_penalty
        self.sup_con_loss = SupConLoss(temperature)  # initialize SupConLoss object
        self.loss_type = loss_type

    def pow_2(self, value):
        """Return ``2 ** value`` (base 2, exponent ``value``)."""
        return torch.pow(2, value)

    def forward(self, features, labels):
        """Compute the hierarchical loss.

        Args:
            features: passed to SupConLoss, which requires [bsz, n_views, ...];
                its first dim must match ``labels``.
            labels: (B, N, L) label tensor whose last axis holds the levels.

        Returns:
            The accumulated per-level losses divided by L, the number of levels.
        """
        device = features.device
        num_levels = labels.shape[2]
        # 1 = label entry visible at the current level, 0 = masked out.
        mask = torch.ones(labels.shape, device=device)
        cumulative_loss = torch.tensor(0.0, device=device)
        # running max of the losses seen at earlier iterations
        max_loss_lower_layer = torch.tensor(float('-inf'))
        for l in range(1, num_levels):
            # hide the last ``l`` level columns
            mask[:, :, num_levels - l:] = 0
            layer_labels = labels * mask
            # Pairwise (B, B) positives: every masked (N x L) entry must agree.
            # Comparing only over dim 1 would give a (B, B, L) tensor that
            # SupConLoss cannot tile.
            flat = layer_labels.flatten(1)
            mask_labels = (flat.unsqueeze(1) == flat.unsqueeze(0)).all(dim=-1).to(torch.uint8)
            layer_loss = self.sup_con_loss(features, mask=mask_labels)
            if self.loss_type == 'hmc':
                # add weighted layer loss to cumulative loss
                cumulative_loss += self.layer_penalty(
                    torch.tensor(1 / l).type(torch.float)) * layer_loss
            elif self.loss_type == 'hce':
                # take max of lower layer loss and current layer loss
                layer_loss = torch.max(max_loss_lower_layer.to(layer_loss.device), layer_loss)
                cumulative_loss += layer_loss
            elif self.loss_type == 'hmce':
                # max with lower layer loss, then weight by the layer penalty
                layer_loss = torch.max(max_loss_lower_layer.to(layer_loss.device), layer_loss)
                cumulative_loss += self.layer_penalty(
                    torch.tensor(1 / l).type(torch.float)) * layer_loss
            else:
                raise NotImplementedError('Unknown loss')
            # keep one sample per distinct label pattern at this level
            _, unique_indices = unique(layer_labels, dim=0)
            max_loss_lower_layer = torch.max(
                max_loss_lower_layer.to(layer_loss.device), layer_loss)
            labels = labels[unique_indices]
            mask = mask[unique_indices]
            features = features[unique_indices]
        # Average over the hierarchy levels (axis 2), not over the N axis (axis 1).
        return cumulative_loss / num_levels

class HMLC(nn.Module):
    """Hierarchical contrastive loss over (B, L) labels.

    The loop runs for ``step`` in ``1 .. L-1``. At each step:
      * the last ``step`` label columns are zeroed;
      * samples whose remaining labels are identical are treated as positives
        for a SupConLoss term;
      * only one sample per distinct label row is kept for the next step.

    Args:
        temperature: temperature forwarded to the internal SupConLoss.
        base_temperature: stored on the module; not used by this class directly.
        layer_penalty: callable mapping a float tensor ``1/step`` to a weight;
            defaults to ``pow_2``, which computes ``2 ** value`` (base 2).
        loss_type: 'hmc' (penalty-weighted sum), 'hce' (sum of running maxima)
            or 'hmce' (penalty-weighted running maxima).
    """

    def __init__(self, temperature=0.07,
                 base_temperature=0.07, layer_penalty=None, loss_type='hmce'):
        super(HMLC, self).__init__()
        self.temperature = temperature
        self.base_temperature = base_temperature
        self.layer_penalty = layer_penalty if layer_penalty else self.pow_2
        self.sup_con_loss = SupConLoss(temperature)
        self.loss_type = loss_type

    def pow_2(self, value):
        """Return ``2 ** value`` (base 2, exponent ``value``)."""
        return torch.pow(2, value)

    def forward(self, features, labels):
        """Compute the hierarchical loss.

        Args:
            features: passed to SupConLoss, which requires [bsz, n_views, ...];
                its first dim must match ``labels``.
            labels: (B, L) label tensor; columns are hierarchy levels.

        Returns:
            The accumulated per-step losses divided by L.
        """
        device = labels.device
        num_levels = labels.shape[1]
        keep = torch.ones(labels.shape).to(device)
        total = torch.tensor(0.0).to(device)
        running_max = torch.tensor(float('-inf'))

        for step in range(1, num_levels):
            keep[:, num_levels - step:] = 0
            level_labels = labels * keep
            # pair_mask[i, j] == 1 iff rows i and j agree on every remaining column
            rows_equal = [(row == level_labels).all(dim=1) for row in level_labels]
            pair_mask = torch.stack(rows_equal).type(torch.uint8).to(device)
            level_loss = self.sup_con_loss(features, mask=pair_mask)

            if self.loss_type == 'hmc':
                total += self.layer_penalty(
                    torch.tensor(1 / step).type(torch.float)) * level_loss
            elif self.loss_type in ('hce', 'hmce'):
                level_loss = torch.max(running_max.to(level_loss.device), level_loss)
                if self.loss_type == 'hce':
                    total += level_loss
                else:
                    total += self.layer_penalty(
                        torch.tensor(1 / step).type(torch.float)) * level_loss
            else:
                raise NotImplementedError('Unknown loss')

            # one representative per distinct label row survives to the next step
            _, representatives = unique(level_labels, dim=0)
            running_max = torch.max(running_max.to(level_loss.device), level_loss)
            labels = labels[representatives]
            keep = keep[representatives]
            features = features[representatives]

        return total / num_levels


class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR"""
    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf

        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar. Anchors with no positive pair contribute 0 instead
            of NaN.
        Raises:
            ValueError: on too few feature dims, both `labels` and `mask` given,
                mismatched label count, or an unknown `contrast_mode`.
        """
        # Follow the features' actual device (e.g. cuda:1), not a generic 'cuda'.
        device = features.device

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        # stack all views: [n_views * bsz, D]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # Compute the mean log-likelihood over positives. An anchor with no
        # positive pair (e.g. a class seen once with a single view) would
        # divide 0 by 0 and make the whole loss NaN. Give it a denominator of
        # 1 so it contributes 0 instead.
        positives_per_anchor = mask.sum(1)
        positives_per_anchor = torch.where(positives_per_anchor < 1e-6,
                                           torch.ones_like(positives_per_anchor),
                                           positives_per_anchor)
        mean_log_prob_pos = (mask * log_prob).sum(1) / positives_per_anchor

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss
