# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch 
import torch.nn as nn

from torch.nn import functional as F


def ce_loss_logit(logits, targets, reduction='none', mask=None):
    """
    Cross entropy loss computed from raw logits.

    Args:
        logits: logit values, shape=[Batch size, # of classes]
        targets: either integer class indices, shape=[Batch size], or a
            per-class target vector (one-hot / soft label) with the same
            shape as ``logits``
        reduction: for vector targets, 'none' returns one loss per sample,
            'sum' returns their sum and any other value returns their mean.
            For integer targets without a mask it is forwarded unchanged to
            ``F.nll_loss``.
        mask: optional tensor broadcastable to [Batch size, # of classes]
            (e.g. one weight per class). Only used with integer targets.
            When given, ``reduction`` is ignored and the result is the mean
            over every entry of the masked loss matrix.

    Returns:
        A per-sample loss tensor or a scalar tensor, depending on the branch
        taken and on ``reduction``.
    """
    # Every branch needs the log-probabilities, so compute them once.
    log_pred = F.log_softmax(logits, dim=-1)

    if logits.shape == targets.shape:
        # Vector target: reduce over the class axis that was normalised above.
        nll_loss = torch.sum(-targets * log_pred, dim=-1)
        if reduction == 'none':
            return nll_loss
        if reduction == 'sum':
            return nll_loss.sum()
        # Any other value keeps the historical fallback of averaging.
        return nll_loss.mean()

    if mask is None:
        return F.nll_loss(log_pred, targets, reduction=reduction)

    # Integer targets with a mask: expand targets to one-hot so the mask can
    # weight individual classes. shape[-1] (rather than len(logits[0])) also
    # works for an empty batch.
    one_hot_targets = F.one_hot(targets, num_classes=logits.shape[-1])
    loss_matrix = -log_pred * one_hot_targets
    # detach() keeps gradients from flowing into the mask; broadcasting
    # replaces the explicit per-sample repeat.
    return (loss_matrix * mask.detach()).mean()


def ce_loss_prob(probs, targets, reduction='none'):
    """
    Negative-likelihood style loss computed directly from ``probs``.

    No softmax or logarithm is applied here; ``probs`` is used exactly as
    given.
    NOTE(review): ``F.nll_loss`` expects log-probabilities, so callers
    presumably pass log-probabilities or deliberately want the negative
    probability -- confirm against the call sites.

    Args:
        probs: per-class scores, shape=[Batch size, # of classes]
        targets: either integer class indices, shape=[Batch size], or a
            per-class target vector (one-hot / soft label) with the same
            shape as ``probs``
        reduction: for vector targets, 'none' returns one loss per sample,
            'sum' returns their sum and any other value returns their mean.
            For integer targets it is forwarded unchanged to ``F.nll_loss``.

    Returns:
        A per-sample loss tensor or a scalar tensor, depending on
        ``reduction``.
    """
    if probs.shape == targets.shape:
        # Vector target: weight each class score by its target value and
        # reduce over the class axis.
        nll_loss = torch.sum(-targets * probs, dim=-1)
        if reduction == 'none':
            return nll_loss
        if reduction == 'sum':
            return nll_loss.sum()
        # Any other value keeps the historical fallback of averaging.
        return nll_loss.mean()

    # Integer targets: pick out the negated score of the target class.
    return F.nll_loss(probs, targets, reduction=reduction)


class CELoss(nn.Module):
    """
    nn.Module front-end that delegates to ``ce_loss_logit``.
    """
    def forward(self, logits, targets, reduction='none', mask=None):
        # Pass every argument straight through to the functional loss.
        loss = ce_loss_logit(logits, targets, reduction=reduction, mask=mask)
        return loss


class NCMLoss(nn.Module):
    """
    nn.Module front-end that delegates to ``ce_loss_prob``.
    """
    def forward(self, probs, targets, reduction='none'):
        # Pass every argument straight through to the functional loss.
        loss = ce_loss_prob(probs, targets, reduction=reduction)
        return loss
