# -*- coding:utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import * 

class CrossEntropyLabelSmooth(nn.Module):
	"""Cross entropy loss with label smoothing regularizer.

	Reference:
	Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
	Equation: y = (1 - epsilon) * y + epsilon / K.

	Args:
		num_classes (int): number of classes K.
		epsilon (float): smoothing weight; 0 gives plain cross entropy.
	"""

	def __init__(self, num_classes, epsilon=0.1):
		super(CrossEntropyLabelSmooth, self).__init__()
		self.num_classes = num_classes
		self.epsilon = epsilon
		# LogSoftmax holds no parameters, so no device placement is needed; its
		# output simply lives on the device of ``inputs``.
		self.logsoftmax = nn.LogSoftmax(dim=1)

	def forward(self, inputs, targets):
		"""
		Args:
			inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
			targets: ground truth class indices with shape (batch_size,)
		Returns:
			Scalar loss: per-sample smoothed cross entropy, averaged over the batch.
		"""
		log_probs = self.logsoftmax(inputs)
		# one-hot encode the integer labels along the class dimension
		one_hot = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
		smoothed = (1 - self.epsilon) * one_hot + self.epsilon / self.num_classes
		# mean over the batch (dim 0), then sum over classes
		return (-smoothed * log_probs).mean(0).sum()

class SoftEntropy(nn.Module):
	"""Cross entropy of ``inputs`` against a soft target distribution.

	``targets`` are logits: they are turned into probabilities with a softmax
	over dim 1 and detached, so gradients flow only through ``inputs``. The
	elementwise loss is averaged over dim 0 and summed over the other dims.
	"""

	def __init__(self):
		super(SoftEntropy, self).__init__()
		self.logsoftmax = nn.LogSoftmax(dim=1).cuda()

	def forward(self, inputs, targets):
		soft_targets = F.softmax(targets, dim=1).detach()
		log_p = self.logsoftmax(inputs)
		elementwise = -soft_targets * log_p
		return elementwise.mean(0).sum()

class NegEntropy(nn.Module):
	"""Negative Shannon entropy of ``softmax(inputs)`` over the last dim.

	Computes ``p * log p`` elementwise, averages over dim 0 and sums the
	remaining dims; minimising it pushes predictions towards higher entropy.
	"""

	def __init__(self):
		super(NegEntropy, self).__init__()
		self.logsoftmax = nn.LogSoftmax(dim=-1).cuda()

	def forward(self, inputs):
		log_p = self.logsoftmax(inputs)
		p = F.softmax(inputs, dim=-1)
		per_position = (p * log_p).mean(dim=0)
		return per_position.sum()

class FocalLoss(nn.Module):
	"""Focal loss (Lin et al., "Focal Loss for Dense Object Detection").

	Per sample the loss is ``(1 - p_t) ** gamma * CE``, where ``CE`` is the
	(optionally class-weighted) cross entropy of that sample and ``p_t`` is the
	unweighted softmax probability of its target class. The focal factor is
	applied per sample *before* reduction; with ``gamma=0`` the result equals
	``nn.CrossEntropyLoss(weight=weight, reduction=reduction)``.

	Args:
		weight (Tensor, optional): per-class weights, as in ``nn.CrossEntropyLoss``.
		reduction (str): ``'mean'``, ``'sum'`` or ``'none'``.
		gamma (float): focusing parameter (0 disables focusing).
		eps (float): stored but not used.
	"""

	def __init__(self, weight=None, reduction='mean', gamma=0, eps=1e-7):
		super(FocalLoss, self).__init__()
		self.gamma = gamma
		self.eps = eps
		self.reduction = reduction
		# Unreduced CE: the focal factor must modulate each sample, not the
		# already-averaged batch loss.
		self.ce = nn.CrossEntropyLoss(weight=weight, reduction='none')

	def forward(self, inputs, targets):
		# weighted per-sample CE (0 at ignore_index positions)
		ce = self.ce(inputs, targets)
		# log p_t from the *unweighted* CE so class weights do not distort the focal factor
		log_pt = -F.cross_entropy(inputs, targets, reduction='none', ignore_index=self.ce.ignore_index)
		loss = (1 - torch.exp(log_pt)) ** self.gamma * ce
		if self.reduction == 'none':
			return loss
		if self.reduction == 'sum':
			return loss.sum()
		if self.reduction != 'mean':
			raise ValueError('{} is not a valid value for reduction'.format(self.reduction))
		if targets.is_floating_point():
			# class-probability targets: CrossEntropyLoss averages over all elements
			return loss.mean()
		# class-index targets: normalise like CrossEntropyLoss, i.e. by the
		# (weighted) count of non-ignored targets
		valid = targets != self.ce.ignore_index
		if self.ce.weight is None:
			return loss.sum() / valid.sum()
		return loss.sum() / self.ce.weight[targets[valid]].sum()

class SoftFocalLoss(nn.Module):
	"""Focal-weighted cross entropy against a soft target distribution.

	Elementwise loss: ``-softmax(targets) * log_softmax(inputs) * (1 - softmax(inputs)) ** gamma``
	(softmax over dim 1, target distribution detached), averaged over dim 0 and
	summed over the remaining dims. ``weight``, ``reduction`` and ``eps`` are
	accepted but not used by ``forward``.
	"""

	def __init__(self, weight=None, reduction='mean', gamma=0, eps=1e-7):
		super(SoftFocalLoss, self).__init__()
		self.gamma = gamma
		self.eps = eps
		# not used by forward(); kept so the module's attributes are unchanged
		self.ce = nn.CrossEntropyLoss(weight=weight, reduction=reduction)
		self.logsoftmax = nn.LogSoftmax(dim=1).cuda()

	def forward(self, inputs, targets):
		soft_targets = F.softmax(targets, dim=1).detach()
		log_p = self.logsoftmax(inputs)
		modulator = (1 - F.softmax(inputs, dim=1)) ** self.gamma
		weighted = -soft_targets * log_p * modulator
		return weighted.mean(0).sum()

# *Generalized Cross Entropy Loss
class GCELoss(nn.Module):
    """Generalized Cross Entropy loss.

    For the predicted probability ``p`` of the target class the per-sample
    loss is ``(1 - p ** q) / q``; ``q == 0`` falls back to plain cross entropy
    (binary cross entropy with logits when the class dimension has size 1).
    Per-sample losses are combined as a weighted mean.

    Args:
        q (float): GCE exponent; 0 selects the cross-entropy path.
        ignore_index (int): target value whose samples are dropped.
    """

    def __init__(self, q=0.7, ignore_index=-100):
        super(GCELoss, self).__init__()
        self.q = q
        self.ignore_index = ignore_index

    def forward(self, logits, targets, weights=None):
        """
        Args:
            logits: ``(..., C)`` scores; ``C == 1`` selects the sigmoid (binary) path.
            targets: integer labels with the leading shape of ``logits``.
            weights: per-sample weights shaped like ``targets``; uniform when ``None``.
        Returns:
            Scalar weighted mean loss over the non-ignored samples.
        """
        valid_idx = targets != self.ignore_index
        logits = logits[valid_idx]
        targets = targets[valid_idx]
        if weights is None:
            weights = torch.ones_like(targets, dtype=logits.dtype)
        else:
            weights = weights[valid_idx]
        binary = logits.size(-1) == 1
        if self.q == 0:
            # vanilla cross entropy when q = 0 (ignored targets are already filtered out)
            if binary:
                loss = F.binary_cross_entropy_with_logits(logits.view(-1), targets.float(), reduction='none')
            else:
                loss = F.cross_entropy(logits, targets, reduction='none')
        else:
            if binary:
                pos = torch.sigmoid(logits)
                # two-column [P(class 0), P(class 1)] so gather works for both labels
                pred = torch.cat((1 - pos, pos), dim=-1)
            else:
                pred = F.softmax(logits, dim=-1)
            pred = torch.gather(pred, dim=-1, index=torch.unsqueeze(targets, -1))
            loss = (1 - pred ** self.q) / self.q
        return (loss.view(-1) * weights).sum() / weights.sum()

# *Debiased Self-training loss
def shift_log(x, offset=1e-6):
    """Return ``log(min(x + offset, 1))`` elementwise.

    The offset keeps the log finite at ``x == 0``; clamping at 1 keeps the
    result non-positive.
    """
    shifted = (x + offset).clamp(max=1.)
    return shifted.log()

class WorstCaseEstimationLoss(nn.Module):
	"""Worst-case estimation loss used with debiased self-training.

	Pseudo labels are the argmax over dim 1 of ``y_l`` and ``y_u``
	(presumably the labeled / unlabeled batches). The result is
	``eta_prime * CE(y_l_adv, pseudo_l)`` plus ``F.nll_loss`` of
	``shift_log(1 - softmax(y_u_adv))`` at ``pseudo_u``.

	Args:
		eta_prime (float): weight of the ``y_l`` term.
	"""

	def __init__(self, eta_prime):
		super(WorstCaseEstimationLoss, self).__init__()
		self.eta_prime = eta_prime

	def forward(self, y_l, y_l_adv, y_u, y_u_adv):
		pseudo_l = y_l.max(dim=1)[1]
		labeled_term = self.eta_prime * F.cross_entropy(y_l_adv, pseudo_l)

		pseudo_u = y_u.max(dim=1)[1]
		complement = 1. - F.softmax(y_u_adv, dim=1)
		unlabeled_term = F.nll_loss(shift_log(complement), pseudo_u)
		return labeled_term + unlabeled_term

# *CONTAiNER Contrastive loss
def nt_xent(loss, num, denom, temperature = 1):
    """Per-anchor contrastive loss in base-2 logs.

    Each row of ``loss`` holds an anchor's scores; ``num`` masks its positive
    pairs and ``denom`` the pairs in the normaliser. For every row with at
    least one positive it returns
    ``log2(sum_denom exp(s/t)) - log2(sum_pos exp(s/t)) + log2(#pos)``;
    rows without positives are dropped.
    """
    sim = torch.exp(loss / temperature)
    pos_counts = num.sum(dim=1)
    pos_mass = (sim * num).sum(dim=1)
    all_mass = (sim * denom).sum(dim=1)
    keep = pos_counts > 0
    pos_mass, all_mass, pos_counts = pos_mass[keep], all_mass[keep], pos_counts[keep]
    return torch.log2(all_mass) - torch.log2(pos_mass) + torch.log2(pos_counts)

def loss_kl(mu_i, sigma_i, mu_j, sigma_j, embed_dimension):
    """Row-wise symmetrised KL divergence between diagonal Gaussians.

    ``sigma_*`` are used directly as per-dimension variances (squared mean
    differences are divided by them). Each direction follows the closed-form
    KL for diagonal covariances, with ``1e-14`` added inside the log for
    stability; the two directions are averaged.
    """
    def _one_way(mu_a, sigma_a, mu_b, sigma_b):
        # KL(N(mu_b, sigma_b) || N(mu_a, sigma_a)), summed over dim 1
        ratio = sigma_b / sigma_a
        trace_term = torch.sum(ratio, 1)
        log_det_term = torch.sum(torch.log(ratio + 1e-14), dim=1)
        maha_term = torch.sum((mu_a - mu_b) ** 2 / sigma_a, dim=1)
        return 0.5 * (trace_term + maha_term - embed_dimension - log_det_term)

    forward_kl = _one_way(mu_i, sigma_i, mu_j, sigma_j)
    backward_kl = _one_way(mu_j, sigma_j, mu_i, sigma_i)
    return 0.5 * (forward_kl + backward_kl)

def euclidean_distance(a, b, normalize=False):
    """Row-wise *squared* Euclidean distance between ``a`` and ``b``.

    With ``normalize=True`` both inputs are first L2-normalised with
    ``F.normalize`` (default: along dim 1).
    """
    if not normalize:
        return (a - b).pow(2).sum(dim=1)
    a_hat, b_hat = F.normalize(a), F.normalize(b)
    return (a_hat - b_hat).pow(2).sum(dim=1)

def CONTAiNERLoss(embedding_dim, output_embedding_mu, output_embedding_sigma, labels_straightened,
                              consider_mutual_O=False, loss_type="KL"):
    """Contrastive token loss (CONTAiNER) over a flattened batch of tokens.

    Every anchor token is scored against every token with ``-distance``
    (symmetric KL between Gaussian embeddings, or squared Euclidean distance
    between L2-normalised means). Tokens with the anchor's label are
    positives; ``nt_xent`` turns the scores into one loss per anchor and the
    mean over anchors is returned.

    Args:
        embedding_dim: embedding size, forwarded to ``loss_kl``.
        output_embedding_mu: ``(T, D)`` per-token means.
        output_embedding_sigma: ``(T, D)`` per-token values used by ``loss_kl``
            as diagonal variances.
        labels_straightened: ``(T,)`` per-token labels.
        consider_mutual_O: if False, only tokens with label > 0 are anchors.
        loss_type: ``"KL"`` or ``"euclidean"``.

    Returns:
        Scalar loss. When there is no anchor, or no anchor has a positive
        partner, a zero still attached to the autograd graph is returned
        instead of failing in the reshape / producing NaN from an empty mean.

    Raises:
        ValueError: ``loss_type`` is not recognised.
    """
    if not consider_mutual_O:
        filter_indices = torch.where(labels_straightened > 0)[0]
        filtered_embedding_mu = output_embedding_mu[filter_indices]
        filtered_embedding_sigma = output_embedding_sigma[filter_indices]
        filtered_labels = labels_straightened[filter_indices]
    else:
        filtered_embedding_mu = output_embedding_mu
        filtered_embedding_sigma = output_embedding_sigma
        filtered_labels = labels_straightened
    filtered_instances_nos = len(filtered_labels)
    if filtered_instances_nos == 0:
        # view(0, -1) below cannot infer -1 for an empty tensor; return a
        # graph-connected zero so callers can still add/backward it.
        return output_embedding_mu.sum() * 0.0

    # repeat interleave: anchor a occupies rows [a*T, (a+1)*T)
    filtered_embedding_mu = torch.repeat_interleave(filtered_embedding_mu, len(output_embedding_mu), dim=0)
    filtered_embedding_sigma = torch.repeat_interleave(filtered_embedding_sigma, len(output_embedding_sigma), dim=0)
    filtered_labels = torch.repeat_interleave(filtered_labels, len(output_embedding_mu), dim=0)

    # only repeat: the full token list is tiled once per anchor, aligning with the rows above
    repeated_output_embeddings_mu = output_embedding_mu.repeat(filtered_instances_nos, 1)
    repeated_output_embeddings_sigma = output_embedding_sigma.repeat(filtered_instances_nos, 1)
    repeated_labels = labels_straightened.repeat(filtered_instances_nos)

    # avoid losses with own self
    # NOTE(review): torch.all(... != ...) keeps a pair only when *every*
    # coordinate differs, so distinct tokens sharing any single coordinate are
    # masked as well; torch.any would exclude only identical rows. Left as-is
    # to preserve current training behaviour — confirm intent.
    loss_mask = torch.all(filtered_embedding_mu != repeated_output_embeddings_mu, dim=-1).int()
    loss_weights = (filtered_labels == repeated_labels).int()
    loss_weights = loss_weights * loss_mask

    if loss_type == "euclidean":
        loss = -euclidean_distance(filtered_embedding_mu, repeated_output_embeddings_mu, normalize=True)
    elif loss_type == "KL":  # KL_divergence
        loss = -loss_kl(filtered_embedding_mu, filtered_embedding_sigma,
                            repeated_output_embeddings_mu, repeated_output_embeddings_sigma,
                            embed_dimension=embedding_dim)
    else:
        raise ValueError("unknown loss")

    # one row per anchor, one column per token
    loss = loss.view(filtered_instances_nos, -1)
    loss_mask = loss_mask.view(filtered_instances_nos, -1)
    loss_weights = loss_weights.view(filtered_instances_nos, -1)

    loss_final = nt_xent(loss, loss_weights, loss_mask, temperature = 1)
    if loss_final.numel() == 0:
        # no anchor has a positive partner: the mean of an empty tensor is NaN
        return output_embedding_mu.sum() * 0.0
    return torch.mean(loss_final)

class elr_loss(nn.Module):
    """Early-learning regularisation (ELR) term.

    Keeps one running target distribution per training example: an
    exponential moving average of the model's own detached, clamped softmax
    predictions. Returns ``mean(log(1 - <target, p>))`` over the batch.

    Args:
        num_example (int): number of training examples (rows of the target table).
        num_classes (int): number of classes.
        beta (float): EMA momentum, i.e. the weight kept on the previous target.
    """

    def __init__(self, num_example, num_classes=10, beta=0.3):
        super(elr_loss, self).__init__()
        self.num_classes = num_classes
        # ``USE_CUDA`` is not defined anywhere on this class; read it
        # defensively (a caller/subclass may set it) and otherwise fall back
        # to CUDA availability.
        use_cuda = getattr(self, 'USE_CUDA', torch.cuda.is_available())
        self.target = torch.zeros(num_example, self.num_classes)
        if use_cuda:
            self.target = self.target.cuda()
        self.beta = beta

    def forward(self, index, output):
        """Update the running targets of ``index`` and return the ELR term.

        Args:
            index: indices of this batch's examples into the target table.
            output: ``(batch, num_classes)`` logits.
        Returns:
            Scalar regularisation term.
        """
        # ``target`` is a plain attribute (not a buffer), so keep it on the logits' device
        if self.target.device != output.device:
            self.target = self.target.to(output.device)
        y_pred = F.softmax(output, dim=1)
        y_pred = torch.clamp(y_pred, 1e-4, 1.0 - 1e-4)
        y_pred_ = y_pred.detach()
        # EMA update from detached, renormalised predictions; no gradient reaches the targets
        self.target[index] = self.beta * self.target[index] + (1 - self.beta) * (y_pred_ / y_pred_.sum(dim=1, keepdim=True))
        elr_reg = ((1 - (self.target[index] * y_pred).sum(dim=1)).log()).mean()
        return elr_reg