# MIT License

# Copyright (c) [2023] [Anima-Lab]

# This code is adapted from https://github.com/NVlabs/edm/blob/main/training/loss.py. 
# The original code is licensed under a Creative Commons 
# Attribution-NonCommercial-ShareAlike 4.0 International License, which can be found at licenses/LICENSE_EDM.txt.

"""Loss functions used in the paper
"Elucidating the Design Space of Diffusion-Based Generative Models"."""

import torch
import torch.nn.functional as F

from utils import *
from train_utils.helper import unwrap_model
from torchmetrics.functional.clustering import mutual_info_score


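# Loss based on the improved loss proposed in "Elucidating the Design Space of
# Diffusion-Based Generative Models" (EDM), extended with patch masking, an
# optional MAE reconstruction term, and consistency terms against an EMA model.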


class EDMLoss_mi:
    """EDM denoising loss with patch masking, an MAE term and EMA consistency terms.

    The base objective is the EDM weighted denoising loss: log(sigma) is drawn
    from N(P_mean, P_std^2) and each sample's squared error is weighted by
    ``(sigma^2 + sigma_data^2) / (sigma * sigma_data)^2``.

    Args:
        P_mean: Mean of log(sigma).
        P_std: Standard deviation of log(sigma).
        sigma_data: Data standard deviation used in the loss weighting.
        kl_coef1: Weight of the consistency term against
            ``ema.decode(..., clean_x=new_out['clean_x'], ...)``.
        kl_coef2: Weight of the consistency term against
            ``ema.decode(..., clean_x=None, ...)``.
    """

    def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5,
                 kl_coef1=0.05, kl_coef2=0.1):
        self.P_mean = P_mean
        self.P_std = P_std
        self.sigma_data = sigma_data
        self.kl_coef1 = kl_coef1
        self.kl_coef2 = kl_coef2

    def __call__(self, net,
                 images,
                 labels=None,
                 mask_ratio=0,
                 mae_loss_coef=0,
                 feat=None, augment_pipe=None, ema=None):
        """Compute the per-sample training loss.

        Args:
            net: Model called as ``net(x_t, sigma, labels, clean_x=y,
                mask_ratio=..., mask_dict=None, feat=..., return_feature=True)``.
                It returns ``(model_out, new_out)``. ``model_out['x']`` is the
                denoised prediction, with the same shape as the (augmented)
                images. When ``mask_ratio > 0``, ``model_out['mask']`` holds
                the per-patch mask.
            images: Clean training images, (N, C, H, W).
            labels: Conditioning labels forwarded to ``net``.
            mask_ratio: Forwarded to ``net``. When > 0, the denoising loss is
                averaged over visible patches only, and the auxiliary terms
                are added.
            mae_loss_coef: Weight of the MAE term on masked patches (0 disables it).
            feat: Extra features forwarded to ``net``.
            augment_pipe: Optional callable returning ``(augmented, augment_labels)``.
            ema: Optional object exposing
                ``decode(x_t, sigma, new_out, clean_x, mask_ratio)``. When it is
                None, the two consistency terms are zero.

        Returns:
            Tuple ``(loss, mae_loss_cur, KL_loss1, KL_loss2)`` of (N,) tensors.
            ``loss`` already includes the auxiliary terms; they are also
            returned separately and are zero when disabled. Despite the name,
            the ``KL_loss*`` terms are masked squared errors computed by
            ``mae_loss``, not KL divergences.
        """
        # Works with or without a DDP-style wrapper around the model.
        raw_net = unwrap_model(net)

        # Sample noise level sigma (log-normal) and the noisy input x_t = y + n.
        rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device)
        sigma = (rnd_normal * self.P_std + self.P_mean).exp()
        weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2
        y, _ = augment_pipe(images) if augment_pipe is not None else (images, None)
        n = torch.randn_like(y) * sigma
        x_t = y + n

        model_out, new_out = net(x_t, sigma, labels, clean_x=y, mask_ratio=mask_ratio,
                                 mask_dict=None, feat=feat, return_feature=True)
        D_yn = model_out['x']
        assert D_yn.shape == y.shape

        loss = weight * ((D_yn - y) ** 2)  # (N, C, H, W)
        if mask_ratio > 0:
            assert net.training and 'mask' in model_out
            # Per-patch mean of the weighted error, then average over visible patches.
            loss = F.avg_pool2d(loss.mean(dim=1), raw_net.model.patch_size).flatten(1)  # (N, L)
            unmask = 1 - model_out['mask']
            loss = (loss * unmask).sum(dim=1) / unmask.sum(dim=1)  # (N)
        else:
            loss = mean_flat(loss)  # (N)

        # Auxiliary terms default to zero so they are always defined and returned.
        mae_loss_cur = torch.zeros_like(loss)
        KL_loss1 = torch.zeros_like(loss)
        KL_loss2 = torch.zeros_like(loss)

        if mask_ratio > 0:
            if mae_loss_coef > 0:
                # Reconstruction loss on the masked patches (mask = 1 - unmask).
                mae_loss_cur = mae_loss_coef * mae_loss(
                    raw_net, x_t, D_yn, 1 - unmask, norm_pix_loss=True)
            if ema is not None:
                # Consistency of the student's prediction with two EMA decodings,
                # restricted to the visible patches.
                ema_out = ema.decode(x_t, sigma, new_out, new_out['clean_x'], mask_ratio)
                ema_out2 = ema.decode(x_t, sigma, new_out, None, mask_ratio)
                KL_loss1 = self.kl_coef1 * mae_loss(
                    raw_net, ema_out, D_yn, unmask, norm_pix_loss=True)
                KL_loss2 = self.kl_coef2 * mae_loss(
                    raw_net, ema_out2, D_yn, unmask, norm_pix_loss=True)
            loss = loss + mae_loss_cur + KL_loss2 + KL_loss1

        if mask_ratio == 0.0 and raw_net.model.mask_token is not None:
            # Touch mask_token so it always takes part in the graph when masking is off.
            loss += 0 * torch.sum(raw_net.model.mask_token)
        assert loss.ndim == 1
        return loss, mae_loss_cur, KL_loss1, KL_loss2

# ----------------------------------------------------------------------------


# Registry mapping a loss name (e.g. from a config) to its loss class.
Losses = dict(edm_mi=EDMLoss_mi)


# ----------------------------------------------------------------------------

def patchify(imgs, patch_size=2, num_channels=4):
    """Split a batch of images into flattened, non-overlapping square patches.

    Args:
        imgs: Tensor of shape (N, C, H, W) with ``C == num_channels``. H and W
            must each be divisible by ``patch_size``; they need not be equal.
        patch_size: Side length ``p`` of each patch.
        num_channels: Channel count ``C`` of ``imgs``.

    Returns:
        Tensor of shape (N, L, p*p*C) with ``L = (H // p) * (W // p)``. Patches
        are in row-major order, and within a patch values are ordered by
        (row, column, channel).

    Raises:
        ValueError: If H or W is not divisible by ``patch_size``.
    """
    p, c = patch_size, num_channels
    batch, _, height, width = imgs.shape
    if height % p or width % p:
        raise ValueError(
            f"patchify: spatial size ({height}, {width}) must be divisible "
            f"by patch_size={p}")

    h, w = height // p, width // p
    x = imgs.reshape(batch, c, h, p, w, p)
    x = torch.einsum('nchpwq->nhwpqc', x)
    return x.reshape(batch, h * w, p ** 2 * c)


def mae_loss(net, target, pred, mask, norm_pix_loss=True):
    """Masked per-patch mean squared error between ``pred`` and ``target``.

    Both images are patchified using ``net.model.patch_size`` and
    ``net.model.out_channels``. With ``norm_pix_loss``, each target patch is
    standardized over its own values (mean 0, unit variance, eps 1e-6) before
    comparison; ``pred`` is left as is.

    Args:
        net: Model object; only ``net.model.patch_size`` and
            ``net.model.out_channels`` are read.
        target: Target images, (N, C, H, W).
        pred: Predicted images, same shape as ``target``.
        mask: (N, L) per-patch weights selecting which patches contribute
            (presumably binary -- TODO confirm against callers).
        norm_pix_loss: Whether to standardize each target patch.

    Returns:
        (N,) tensor: mask-weighted mean of the per-patch MSE. A sample whose
        mask row sums to zero gets 0 instead of NaN.
    """
    patch_size, channels = net.model.patch_size, net.model.out_channels
    target = patchify(target, patch_size, channels)
    pred = patchify(pred, patch_size, channels)
    if norm_pix_loss:
        mean = target.mean(dim=-1, keepdim=True)
        var = target.var(dim=-1, keepdim=True)
        target = (target - mean) / (var + 1.e-6) ** .5

    loss = ((pred - target) ** 2).mean(dim=-1)  # (N, L), mean loss per patch

    # Avoid 0/0 = NaN for rows that select no patches; other rows are unchanged.
    denom = mask.sum(dim=1)
    denom = denom.masked_fill(denom == 0, 1)
    loss = (loss * mask).sum(dim=1) / denom  # mean loss on selected patches, (N)
    assert loss.ndim == 1
    return loss
