import torch
from torch import nn
from torch.nn import functional as F
import torch.utils.data
from torchvision.models.inception import inception_v3

import numpy as np
from scipy.stats import entropy

import warnings
# Silence every DeprecationWarning and UserWarning for the whole process.
# These filters are installed when this module is imported and modify the
# global warnings filter list, so they also hide warnings raised by other
# modules. NOTE(review): presumably meant to hide torchvision noise such as
# the `pretrained=True` deprecation; consider narrowing with `module=` or
# `message=` -- confirm which warnings were actually targeted.
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)


def inception_score(imgs, device, batch_size=32, resize=False, splits=1, model=None):
    """Compute the Inception Score (IS) of a set of images.

    imgs -- torch Dataset whose items are sequences with the image tensor
            first (e.g. ``torch.utils.data.TensorDataset``); each image is a
            (3xHxW) tensor normalized to the range [-1, 1].
    device -- torch device on which the classifier is run.
    batch_size -- number of images classified per forward pass.
    resize -- if True, bilinearly upsample every batch to 299x299 (the
              Inception v3 input size) before classification.
    splits -- number of equal, contiguous chunks the predictions are divided
              into; one score is produced per chunk. If len(imgs) is not a
              multiple of ``splits`` the trailing remainder is ignored.
    model -- optional classifier mapping a batch to (batch, num_classes)
             logits. Defaults to torchvision's pretrained Inception v3, which
             is loaded on every call; pass a preloaded model to avoid that.
             The model is moved to ``device`` and put in eval mode.

    Returns a list of ``splits`` floats, each exp(mean_x KL(p(y|x) || p(y)))
    over its chunk.

    Raises ValueError if ``splits`` < 1 or there are fewer images than splits.
    """
    N = len(imgs)
    if splits < 1:
        raise ValueError(f"splits must be >= 1, got {splits}")
    if N < splits:
        raise ValueError(f"need at least {splits} images for {splits} splits, got {N}")

    dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size)

    if model is None:
        # transform_input=False: images are fed to the network unchanged, so
        # they must already be in the [-1, 1] range documented above.
        model = inception_v3(pretrained=True, transform_input=False)
    model = model.to(device)
    model.eval()

    # Pure inference: no_grad stops autograd from recording the forward
    # passes (inputs produced by a generator may still require grad).
    chunks = []
    with torch.no_grad():
        for batch in dataloader:
            sample = batch[0].to(device)
            if resize:
                sample = F.interpolate(sample, size=(299, 299), mode='bilinear', align_corners=False)
            logits = model(sample)
            chunks.append(F.softmax(logits, dim=1).cpu().numpy())
    # float64 keeps the KL computation at the same precision as before.
    preds = np.concatenate(chunks, axis=0).astype(np.float64)

    # Mean KL divergence between each conditional p(y|x) and the chunk's
    # marginal p(y); the score is its exponential.
    chunk_len = N // splits
    split_scores = []
    for k in range(splits):
        part = preds[k * chunk_len:(k + 1) * chunk_len, :]
        py = np.mean(part, axis=0)
        scores = [entropy(pyx, py) for pyx in part]
        split_scores.append(np.exp(np.mean(scores)))
    return split_scores


class EstimateInceptionScore:
    """Estimate Inception Scores for samples produced by a chain sampler.

    chain_sampler -- object used through ``sample_chain(N)``, a ``chain``
                     sequence of image tensors, a generator callable ``G_net``
                     and, for the optional transition-chain score, a
                     ``flag_tr`` array. NOTE(review): this interface is
                     inferred from usage here; confirm against the sampler.
    device -- torch device forwarded to ``inception_score`` and used for the
              generator's latent noise.
    """

    def __init__(self, chain_sampler, device):
        self.chain_sampler = chain_sampler
        self.device = device

    def get_score(self, N, splits_n, simple_gan_flag=True, tr_chain_flag=False, nz=100):
        """Score the sampler's chain and, optionally, plain generator samples.

        N -- passed to ``chain_sampler.sample_chain`` and used as the number
             of plain generator samples.
        splits_n -- forwarded as ``splits`` to ``inception_score``.
        simple_gan_flag -- also score N samples drawn directly from ``G_net``.
        tr_chain_flag -- also score only the chain entries whose ``flag_tr``
                         equals 1.0 (off by default, matching the previous
                         behaviour where this code was disabled).
        nz -- latent size of the (N, nz, 1, 1) noise fed to ``G_net``.

        Returns (full_chain_scores, gan_scores) when ``simple_gan_flag`` is
        set, otherwise (full_chain_scores, transition_scores), where
        transition_scores is 0.0 unless ``tr_chain_flag`` is set. The mean and
        std of every computed score list are printed.
        """
        self.chain_sampler.sample_chain(N)

        sampled_chain = torch.utils.data.TensorDataset(torch.stack(self.chain_sampler.chain))
        score_full_mh_chain = inception_score(sampled_chain, self.device, batch_size=32, resize=True, splits=splits_n)
        print('full MH chain counted:')
        print(np.mean(score_full_mh_chain), np.std(score_full_mh_chain))

        del sampled_chain
        torch.cuda.empty_cache()

        score_tr_mh_chain = 0.0
        if tr_chain_flag:
            stacked = torch.stack(self.chain_sampler.chain)
            # index_select needs the index on the same device as the tensor.
            idx_tr = torch.as_tensor(
                np.where(self.chain_sampler.flag_tr == 1.0)[0],
                dtype=torch.long,
                device=stacked.device,
            )
            print('transition samples:', idx_tr.numel())
            sampled_chain_tr = torch.utils.data.TensorDataset(stacked.index_select(0, idx_tr))
            del stacked

            score_tr_mh_chain = inception_score(sampled_chain_tr, self.device, batch_size=32, resize=True, splits=splits_n)
            print('transition chain counted:')
            print(np.mean(score_tr_mh_chain), np.std(score_tr_mh_chain))

            del sampled_chain_tr
            torch.cuda.empty_cache()

        if not simple_gan_flag:
            return (score_full_mh_chain, score_tr_mh_chain)

        fixed_noise = torch.randn(N, nz, 1, 1, device=self.device)
        # Sampling is inference only; without no_grad autograd would keep
        # the generator's graph for all N samples alive.
        with torch.no_grad():
            gan_samples = self.chain_sampler.G_net(fixed_noise)
        simple_GAN_chain = torch.utils.data.TensorDataset(gan_samples)
        score_gan_chain = inception_score(simple_GAN_chain, self.device, batch_size=32, resize=True, splits=splits_n)

        del fixed_noise, gan_samples, simple_GAN_chain
        torch.cuda.empty_cache()

        print('simple GAN chain counted:')
        print(np.mean(score_gan_chain), np.std(score_gan_chain))
        return (score_full_mh_chain, score_gan_chain)
