import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from timm.loss import LabelSmoothingCrossEntropy


def cross_entropy_loss(z, zt, ytrue, label_smoothing=0):
    """Return the cross-entropy of both logit sets against the same targets.

    ``z`` and ``zt`` are concatenated along dim 0 and ``ytrue`` is repeated
    to match, so a single criterion call covers both sets of predictions.

    Args:
        z: First logits tensor.
        zt: Second logits tensor, scored against the same targets as ``z``.
        ytrue: Targets shared by ``z`` and ``zt``.
        label_smoothing: When greater than 0, timm's
            ``LabelSmoothingCrossEntropy`` is used with this smoothing
            value; otherwise ``nn.CrossEntropyLoss`` is used.

    Returns:
        The loss value produced by the selected criterion.
    """
    logits = torch.cat((z, zt))
    targets = torch.cat((ytrue, ytrue))
    criterion = (
        LabelSmoothingCrossEntropy(label_smoothing)
        if label_smoothing > 0
        else nn.CrossEntropyLoss()
    )
    return criterion(logits, targets)

def cross_entropy(z, zt):
    """Return the mean element-wise cross-entropy between two logit sets.

    Computes ``-(softmax(z) * log_softmax(zt)).mean()``, with both softmaxes
    taken over dim 1. The mean runs over every element of the product, not
    only over dim 0.

    Args:
        z: Logits whose softmax acts as the weighting distribution.
        zt: Logits whose log-softmax is being scored.

    Returns:
        A scalar tensor holding the cross-entropy value.
    """
    Pz = F.softmax(z, dim=1)
    # log_softmax is evaluated in log-space, so it stays finite where
    # log(softmax(zt)) would underflow to log(0) = -inf and turn the
    # product into inf or NaN (0 * -inf).
    log_Pzt = F.log_softmax(zt, dim=1)
    return -(Pz * log_Pzt).mean()

def agmax_loss(z, zt, y_true, d1_weight=1.0):
    """Return the entropy term and the weighted L1 agreement term.

    Three distributions are formed with a softmax over dim 1: one from
    ``z``, one from ``zt``, and one from the element-wise product
    ``z * zt``. The L1 term is the mean absolute difference between the
    product distribution and each of the other two. The entropy term is
    whatever ``entropy_loss`` returns for the three distributions.

    Args:
        z: First logits tensor.
        zt: Second logits tensor; must be multipliable element-wise with
            ``z``.
        y_true: Accepted for interface compatibility; not used here.
        d1_weight: Multiplier applied to the L1 term.

    Returns:
        A ``(entropy, d1)`` tuple.
    """
    probs_a = F.softmax(z, dim=1)
    probs_b = F.softmax(zt, dim=1)
    probs_joint = F.softmax(z * zt, dim=1)

    # Repeat the joint distribution so one L1 call compares it against
    # both per-input distributions at once.
    joint_twice = torch.cat((probs_joint, probs_joint))
    both_inputs = torch.cat((probs_a, probs_b))
    d1 = d1_weight * F.l1_loss(joint_twice, both_inputs)

    return entropy_loss(probs_a, probs_b, probs_joint), d1

def clamp_to_eps(Pz, Pzt, Pzzt):
    """Raise every entry below machine epsilon up to machine epsilon.

    The floor is ``np.finfo(float).eps`` (double-precision epsilon). The
    three tensors are modified in place and the same objects are returned.

    Args:
        Pz: First tensor to clamp.
        Pzt: Second tensor to clamp.
        Pzzt: Third tensor to clamp.

    Returns:
        The tuple ``(Pz, Pzt, Pzzt)`` after in-place clamping.
    """
    floor = np.finfo(float).eps
    for probs in (Pz, Pzt, Pzzt):
        # Masked assignment writes into the caller's tensor.
        probs[probs < floor] = floor
    return Pz, Pzt, Pzzt

def batch_probability(Pz, Pzt, Pzzt):
    """Collapse each input over dim 0 and renormalize it to sum to one.

    Each tensor is summed along dim 0, divided by the total of that sum,
    and finally passed through ``clamp_to_eps``.

    Args:
        Pz: First tensor to aggregate.
        Pzt: Second tensor to aggregate.
        Pzzt: Third tensor to aggregate.

    Returns:
        Whatever ``clamp_to_eps`` returns for the three normalized tensors.
    """
    summed = [p.sum(dim=0) for p in (Pz, Pzt, Pzzt)]
    normalized = [s / s.sum() for s in summed]
    return clamp_to_eps(*normalized)


def entropy_loss(Pz, Pzt, Pzzt):
    """Return the average of ``sum(p * log(p))`` over three distributions.

    The inputs are first reduced with ``batch_probability``. Each reduced
    distribution contributes ``(p * log(p)).sum()``, which is the negative
    of its Shannon entropy, and the three contributions are averaged.

    Args:
        Pz: First tensor passed to ``batch_probability``.
        Pzt: Second tensor passed to ``batch_probability``.
        Pzzt: Third tensor passed to ``batch_probability``.

    Returns:
        A scalar tensor with the averaged value.
    """
    p_z, p_zt, p_zzt = batch_probability(Pz, Pzt, Pzzt)
    term_z = (p_z * torch.log(p_z)).sum()
    term_zt = (p_zt * torch.log(p_zt)).sum()
    term_zzt = (p_zzt * torch.log(p_zzt)).sum()
    return (term_z + term_zt + term_zzt) / 3