import sys
import torch
import numpy as np
from tqdm import tqdm
import torch
from torch import nn
import torch.nn.functional as F

sys.path.append(".")
from image_uncertainty.cifar.cifar_evaluate import (
    load_model, get_eval_args, described_plot, cifar_test,
    misclassification_detection
)
from image_uncertainty.cifar.cifar_datasets import get_training_dataloader
from image_uncertainty.cifar import settings
from nuq import NuqClassifier
from experiments.imagenet_discrete import dump_ues


def get_embeddings(model, loader):
    """Run ``model`` over ``loader`` and collect features and labels.

    The return value of the forward pass is ignored; after each call the
    features are read from the ``model.feature`` attribute, which the model
    is expected to populate during ``forward``.

    Args:
        model: Network exposing ``parameters()`` and, after a forward pass,
            a ``feature`` tensor.
        loader: Iterable yielding ``(images, batch_labels)`` pairs.

    Returns:
        Tuple ``(embeddings, labels)``: the per-batch ``model.feature`` arrays
        concatenated along the first axis, and the labels as a numpy array.
    """
    # Follow the device the model already lives on rather than consulting a
    # module-level ``args`` object, which only exists when this file is run
    # as a script. A model without parameters leaves the images untouched.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    labels = []
    embeddings = []
    with torch.no_grad():
        for images, batch_labels in tqdm(loader):
            if device is not None:
                images = images.to(device)
            # The forward pass is run for its side effect on ``model.feature``.
            model(images)
            embeddings.append(model.feature.cpu().numpy())
            labels.extend(batch_labels.tolist())

    return np.concatenate(embeddings), np.array(labels)


def main(args):
    """Embed the data with the backbone, fit a kernel head, and plot test-vs-OOD scores.

    Args:
        args: Parsed evaluation arguments. This function reads the attributes
            ``b``, ``ood_name``, ``data_seed``, ``net``, ``weights``, ``gpu``
            and ``cached``.
    """
    train_loader, val_loader = get_training_dataloader(
        settings.CIFAR100_TRAIN_MEAN,
        settings.CIFAR100_TRAIN_STD,
        num_workers=4,
        batch_size=args.b,
        shuffle=True,
        ood_name=args.ood_name,
        seed=args.data_seed
    )

    test_loader = cifar_test(args.b, False, args.ood_name)
    ood_loader = cifar_test(args.b, True, args.ood_name)

    model = load_model(args.net, args.weights, args.gpu)
    model.eval()

    base_dir = './'
    cache_names = (
        'x_train', 'y_train', 'x_val', 'y_val',
        'x_test', 'y_test', 'x_ood', 'x_ood_val',
    )

    def cache_path(name):
        # Single source of truth for cache locations, so that writing and
        # reading always agree on the directory.
        return f'{base_dir}t_{name}.npy'

    if not args.cached:
        x_train, y_train = get_embeddings(model, train_loader)
        x_val, y_val = get_embeddings(model, val_loader)
        x_test, y_test = get_embeddings(model, test_loader)
        # Labels of the OOD sets are not cached, so they are discarded here.
        x_ood, _ = get_embeddings(model, ood_loader)

        for name, array in (
            ('x_train', x_train), ('y_train', y_train),
            ('x_val', x_val), ('y_val', y_val),
            ('x_test', x_test), ('y_test', y_test),
            ('x_ood', x_ood),
        ):
            np.save(cache_path(name), array)

        ood_val_loader = cifar_test(args.b, True, 'svhn')
        x_ood_val, _ = get_embeddings(model, ood_val_loader)
        np.save(cache_path('x_ood_val'), x_ood_val)
    else:
        (x_train, y_train, x_val, y_val,
         x_test, y_test, x_ood, x_ood_val) = [
            np.load(cache_path(name)) for name in cache_names
        ]

        print(x_ood_val.shape)


    def calc_gradient_penalty(x, y_pred):
        gradients = torch.autograd.grad(
            outputs=y_pred,
            inputs=x,
            grad_outputs=torch.ones_like(y_pred),
            create_graph=True,
        )[0]

        gradients = gradients.flatten(start_dim=1)

        # L2 norm
        grad_norm = gradients.norm(2, dim=1)

        # Two sided penalty
        gradient_penalty = ((grad_norm - 1) ** 2).mean()

        # One sided penalty - down
        #     gradient_penalty = F.relu(grad_norm - 1).mean()

        return gradient_penalty

    # Hyperparameters shared by the nested helpers defined in main().
    batch_size = 500  # mini-batch size of the training loader
    batch_size_ = 5000  # batch size of the val / test / OOD loaders
    l_gradient_penalty = 0.02  # weight multiplying the gradient penalty term
    sigma = 15  # NOTE(review): looks unused — train() receives sigma as its own parameter
    emb_size = 128  # size of the per-class embedding (first dim of Head.W and Head.m)

    class Head(nn.Module):
        """Kernel head that scores inputs against one centroid per class.

        ``W`` projects an input into a separate embedding for every class.
        The class centroid is the ratio ``m / N`` of two buffers that are
        maintained as exponential moving averages with decay ``gamma``.

        Args:
            features: Size of the input feature vectors.
            num_embeddings: Number of centroids (classes).
            sigma: Length scale of the exponential kernel.
        """

        def __init__(self, features, num_embeddings, sigma):
            super().__init__()

            self.gamma = 0.99
            self.sigma = sigma

            # Per-class projection: (embedding size, classes, input features).
            weight_mean = torch.zeros(emb_size, num_embeddings, features)
            self.W = nn.Parameter(torch.normal(weight_mean, 1))

            # N: running per-class count; m: running per-class embedding sum.
            self.register_buffer('N', torch.ones(num_embeddings) * 20)
            centroid_mean = torch.zeros(emb_size, num_embeddings)
            self.register_buffer('m', torch.normal(centroid_mean, 1))

            # Scale m by N so that the initial centroid m / N is the sample drawn above.
            self.m = self.m * self.N.unsqueeze(0)

        def embed(self, x):
            """Project ``x`` (batch, features) to (batch, embedding, classes)."""
            # i: batch, m: embedding size, n: number of classes, j: features
            return torch.einsum('ij,mnj->imn', x, self.W)

        def bilinear(self, z):
            """Return the kernel value of ``z`` against every class centroid."""
            centroids = self.m / self.N.unsqueeze(0)
            squared_diff = (z - centroids.unsqueeze(0)) ** 2
            # Average over the embedding axis, then apply the exponential kernel.
            scaled = (-squared_diff).mean(1).div(2 * self.sigma ** 2)
            return scaled.exp()

        def forward(self, x):
            """Return ``(z, y_pred)``: the embeddings and their kernel values."""
            z = self.embed(x)
            return z, self.bilinear(z)

        def update_embeddings(self, x, y):
            """Update the running buffers from a batch; ``y`` must be one-hot."""
            z = self.embed(x)
            keep = self.gamma
            fresh = 1 - self.gamma

            # Moving average of the per-class counts, floored at one.
            blended_count = keep * self.N + fresh * y.sum(0)
            self.N = torch.max(blended_count, torch.ones_like(self.N))

            # Sum the embeddings of each class over the batch.
            class_sums = torch.einsum('ijk,ik->jk', z, y)

            self.m = keep * self.m + fresh * class_sums

    def benchmark(dl_test, model2, epoch=0, loss=0):
        """Print metrics for the first batch of ``dl_test`` and return its accuracy.

        Only one batch is drawn from the loader. The printed line contains the
        epoch, the accuracy, the binary cross-entropy, the weighted gradient
        penalty and the ``loss`` value passed in by the caller.

        Args:
            dl_test: Loader yielding ``(features, one-hot targets)``.
            model2: Module returning ``(z, y_pred)`` when called.
            epoch: Value echoed at the start of the printed line.
            loss: Value echoed at the end of the printed line.

        Returns:
            Accuracy on the evaluated batch as a Python float.
        """
        features, targets = next(iter(dl_test))
        features = features.cuda()
        targets = targets.cuda()
        # Gradients w.r.t. the input are needed for the gradient penalty.
        features.requires_grad_(True)
        _, y_pred = model2(features)

        true_class = torch.argmax(targets, dim=-1)
        predicted_class = torch.argmax(y_pred, dim=-1)
        accuracy = (torch.sum(true_class == predicted_class) / len(targets)).item()

        bce = F.binary_cross_entropy(y_pred, targets).item()
        gp = l_gradient_penalty * calc_gradient_penalty(features, y_pred)

        print(f"{epoch}: {accuracy:.3f}, {bce:.3f}, {gp:.3f}, {loss:.3f}")
        return accuracy

    # Width of the one-hot targets. It is pinned rather than inferred from the
    # largest label present in each split, so every split yields targets of
    # the same width as the 100 outputs of the head built in train().
    num_classes = 100

    def one_hot_targets(labels):
        # F.one_hot requires an int64 index tensor, hence the explicit cast.
        indices = torch.from_numpy(labels).long()
        return F.one_hot(indices, num_classes=num_classes).float()

    ds_train = torch.utils.data.TensorDataset(torch.from_numpy(x_train).float(),
                                              one_hot_targets(y_train))
    dl_train = torch.utils.data.DataLoader(ds_train, batch_size=batch_size, shuffle=True, drop_last=True)

    ds_val = torch.utils.data.TensorDataset(torch.from_numpy(x_val).float(),
                                            one_hot_targets(y_val))
    dl_val = torch.utils.data.DataLoader(ds_val, batch_size=batch_size_, shuffle=False)

    ds_test = torch.utils.data.TensorDataset(torch.from_numpy(x_test).float(),
                                             one_hot_targets(y_test))
    dl_test = torch.utils.data.DataLoader(ds_test, batch_size=batch_size_, shuffle=False)

    # OOD samples carry placeholder targets (all class 0); only the features are meaningful.
    ds_ood = torch.utils.data.TensorDataset(torch.from_numpy(x_ood).float(),
                                            F.one_hot(torch.zeros(len(x_ood)).long()).float())
    dl_ood = torch.utils.data.DataLoader(ds_ood, batch_size=batch_size_, shuffle=False)


    def train(sigma):
        """Train a new Head with kernel length scale ``sigma`` for six epochs.

        Each step minimises binary cross-entropy plus the weighted gradient
        penalty, then refreshes the head's running centroid buffers without
        tracking gradients. After every epoch the head is benchmarked on
        ``dl_val``.

        Args:
            sigma: Length scale passed to ``Head``.

        Returns:
            Tuple ``(head, accuracy)`` with the trained head and the accuracy
            reported by the benchmark after the last epoch.
        """
        head = Head(2048, 100, sigma).cuda()
        head_optimizer = torch.optim.SGD(
            head.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4
        )

        for epoch in range(6):
            head.train()
            for features, targets in tqdm(dl_train):
                features = features.cuda()
                targets = targets.cuda()
                head_optimizer.zero_grad()
                # The gradient penalty differentiates the output w.r.t. the input.
                features.requires_grad_(True)

                _, y_pred = head(features)

                bce_term = F.binary_cross_entropy(y_pred, targets)
                penalty_term = l_gradient_penalty * calc_gradient_penalty(features, y_pred)
                loss = bce_term + penalty_term

                loss.backward()
                head_optimizer.step()

                # Centroid statistics are moving averages, not learned by SGD.
                with torch.no_grad():
                    head.update_embeddings(features, targets)

            accuracy = benchmark(dl_val, head, epoch, loss.item())
        return head, accuracy

    # Train one head per candidate length scale, pick the candidate with the
    # highest benchmark accuracy, then train the final head with it.
    sigmas = [0.1, 0.5, 1, 3, 5, 8, 12, 20, 50]
    accs = [train(candidate)[1] for candidate in sigmas]
    best_index = np.argmax(accs)
    best_sigma = sigmas[best_index]
    print(accs, sigmas, best_sigma)
    head, _ = train(best_sigma)

    def get_ues(dl, model):
        """Compute an uncertainty estimate for every sample served by ``dl``.

        The estimate is the negated maximum, over the class axis, of the
        second element returned by ``model``; a higher value therefore means
        a lower best kernel score.

        Args:
            dl: Loader yielding ``(features, targets)``; targets are ignored.
            model: Module returning ``(z, y_pred)`` when called.

        Returns:
            1-D numpy array with one value per sample, in loader order.
        """
        ues = []
        # Inference only: disabling autograd avoids building a graph for
        # every batch that would be thrown away right after.
        with torch.no_grad():
            for x, _ in tqdm(dl):
                output = model(x.cuda())[1]
                ues.extend((-1 * output.max(1)[0].cpu()).tolist())
        return np.array(ues)

    # Uncertainty on the in-distribution test set; reused for every OOD comparison.
    ues_test = get_ues(dl_test, head)

    for ood_name in ['svhn', 'lsun', 'smooth']:
        print()
        print('OOD', ood_name)

        ood_loader = cifar_test(args.b, True, ood_name)
        x_ood, _ = get_embeddings(model, ood_loader)
        # Targets are placeholders (all class 0); get_ues ignores them.
        ds_ood = torch.utils.data.TensorDataset(torch.from_numpy(x_ood).float(),
                                                F.one_hot(torch.zeros(len(x_ood)).long()).float())
        dl_ood = torch.utils.data.DataLoader(ds_ood, batch_size=batch_size_, shuffle=False)
        ues_ood = get_ues(dl_ood, head)

        # Pass the dataset evaluated in this iteration, not the command-line
        # ``args.ood_name``, so each plot is labelled with its own OOD set.
        described_plot(
            ues_test, ues_ood, ood_name, 'spectral', title_extras='DUQ SN'
        )


if __name__ == '__main__':
    # Parse the evaluation arguments, echo them, then run the experiment.
    args = get_eval_args()
    print(vars(args))
    print(args.weights)
    main(args)
