import os
import urllib.request
import tarfile
import torch
from torch import nn as nn
import math
import copy
import torch.optim as optim
import numpy as np
import torchvision.datasets as dset
import torchvision.transforms as transforms

class MiniDataLoader:
    """Tiny in-memory loader yielding aligned mini-batches from several tensors.

    Every tensor passed positionally must have the same size along dim 0.
    With a positive ``batch_size`` each iteration draws a fresh
    ``torch.randperm`` and yields lists ``[t0_batch, t1_batch, ...]`` of at
    most ``batch_size`` rows (the last batch may be smaller). With a negative
    ``batch_size`` the whole, unshuffled ``dataset`` tuple is yielded once.
    """

    def __init__(self, *tensor_dataset, batch_size):
        num_data = tensor_dataset[0].shape[0]
        for tensors in tensor_dataset:
            assert num_data == tensors.shape[0], "{} vs {}".format(tensor_dataset[0].shape, tensors.shape)
        self.num_data = num_data        # rows per tensor; kept in sync by truncate_data
        self.dataset = tensor_dataset   # tuple of tensors sharing dim 0
        self.batch_size = batch_size    # negative => a single full, unshuffled batch

    def __iter__(self):
        if self.batch_size < 0:
            return iter([self.dataset])
        if self.batch_size == 0:
            raise ValueError("batch_size must be non-zero")
        curr_perm = torch.randperm(self.num_data)
        curr_data = [data[curr_perm] for data in self.dataset]
        num_batches = math.ceil(self.num_data / self.batch_size)
        return iter([
            [data[b * self.batch_size:(b + 1) * self.batch_size] for data in curr_data]
            for b in range(num_batches)
        ])

    def truncate_data(self, length):
        """Keep a reproducible random subset of ``length`` rows (seed 0).

        The same row indices are applied to every tensor. The caller's CPU RNG
        state is restored afterwards, and ``num_data`` is updated so that later
        iteration permutes the truncated data rather than the original size.
        """
        rdn_state = torch.get_rng_state()
        torch.random.manual_seed(0)
        try:
            data_idx = torch.randperm(self.dataset[0].shape[0])[:length]
        finally:
            torch.set_rng_state(rdn_state)
        self.dataset = tuple(tensor[data_idx] for tensor in self.dataset)
        self.num_data = self.dataset[0].shape[0]

def get_concentric_data(num_data, dim=2):
    """Return ``num_data`` points on two concentric circles with -1/+1 labels.

    Even indices lie on the radius-3 circle (label -1), odd indices on the
    radius-8 circle (label +1); point ``idx`` sits at angle
    ``2*pi*idx/num_data``. Columns beyond the first two stay zero.
    Both returned tensors are float64.
    """
    points = torch.zeros(num_data, dim)
    targets = torch.zeros(num_data, 1)
    for idx in range(num_data):
        ring = idx % 2
        scale = ring * 5 + 3
        theta = math.pi * 2 * (idx / num_data)
        points[idx, 0] = scale * math.cos(theta)
        points[idx, 1] = scale * math.sin(theta)
        targets[idx, 0] = ring * 2 - 1
    return points.double(), targets.double()

def get_gaussian_data(num_data, dim=2):
    """Sample points from two Gaussian blobs centred at x=-5 and x=+5 (y around 1).

    For every row a rounded ``torch.rand`` draw picks the blob sign (-1 or +1),
    which is also the row's label. Standard-normal noise from ``torch.randn`` is
    added to the first two coordinates; other columns stay zero. Uses the
    global torch RNG. Returns float64 ``(data, labels)``.
    """
    samples = torch.zeros(num_data, dim)
    signs = torch.zeros(num_data, 1)
    for row in range(num_data):
        sign = torch.rand(1).round().item() * 2 - 1
        x_center = sign * 5
        x_noise = torch.randn(1).item()
        samples[row, 0] += x_center + x_noise
        y_noise = torch.randn(1).item()
        samples[row, 1] += 1 + y_noise
        signs[row, 0] = sign
    return samples.double(), signs.double()

def get_unit_circle_data(num_data, frac, dim=2, offset=0):
    """Return ``num_data`` points on an arc of the unit circle with alternating labels.

    Point ``i`` is placed at angle ``2*pi*i/num_data*frac + offset``, so ``frac``
    is the fraction of the full circle covered (1 = whole circle, 0.5 = half).
    Labels alternate +1 (even ``i``) / -1 (odd ``i``). Columns beyond the first
    two stay zero. Returns float64 ``(data, labels)`` with labels shaped
    ``(num_data, 1)``.
    """
    data = torch.zeros(num_data, dim)
    labels = torch.zeros(num_data, 1)
    for i in range(num_data):
        angle = math.pi*2*i/num_data*frac + offset
        data[i, 0] = math.cos(angle)
        data[i, 1] = math.sin(angle)
        labels[i] = 1 - (i % 2) * 2
    return data.double(), labels.double()
def get_vertical_data(num_data, min, max, dim=2):
    """Place ``num_data`` points along the second axis, evenly spaced in [min, max).

    Returns ``(data, xs)``: a float64 tensor of shape ``(num_data, dim)`` whose
    column 1 holds the positions (all other columns are zero), and the
    positions as a Python list.
    """
    # linspace(endpoint=False) yields exactly num_data samples; np.arange with a
    # float step can yield num_data + 1 through rounding and overflow ``data``.
    positions = np.linspace(min, max, num_data, endpoint=False)
    data = torch.zeros(num_data, dim)
    data[:, 1] = torch.from_numpy(positions)
    return data.double(), list(positions)

def get_orthogonal_data(num_data, min, max, dim=3):
    """Place ``num_data`` points along the third axis, evenly spaced in [min, max).

    ``dim`` must be at least 3; the default is 3 because smaller values are
    rejected. Returns ``(data, xs)``: a float64 tensor of shape
    ``(num_data, dim)`` whose column 2 holds the positions (all other columns
    are zero), and the positions as a Python list.
    """
    assert dim >= 3, "get_orthogonal_data needs dim >= 3, got {}".format(dim)
    # linspace(endpoint=False) yields exactly num_data samples; np.arange with a
    # float step can yield num_data + 1 through rounding and overflow ``data``.
    positions = np.linspace(min, max, num_data, endpoint=False)
    data = torch.zeros(num_data, dim)
    data[:, 2] = torch.from_numpy(positions)
    return data.double(), list(positions)

def get_fmnist(batch_size, data_dir='./data', num_train_data=1000, binary=False, dtype=torch.double):
    """Build FashionMNIST train/test loaders plus MNIST, EMNIST-letters and KMNIST OOD loaders.

    Images are read from each dataset's ``.data`` attribute, scaled to [0, 1]
    and normalized with mean 0.1307 / std 0.3081, then cast to ``dtype``.
    ``num_train_data`` is accepted but not used here. With ``binary=True`` the
    in-distribution labels become class-index parity mapped to {-1, +1} with
    a trailing singleton dim; otherwise the raw targets are used. OOD loaders
    keep their datasets' original targets (not cast to ``dtype``).

    Returns:
        ``(train_dloader, test_dloader, ood_dloaders)``.
    """
    normalize = transforms.Compose([transforms.Normalize((0.1307,), (0.3081,))])

    def to_inputs(raw_images):
        # presumably a uint8 (N, H, W) image stack -> (N, 1, H, W) normalized, in ``dtype``
        return normalize(raw_images.float().unsqueeze(1) / 255.).type(dtype)

    fmnist_train = dset.FashionMNIST(data_dir, train=True, transform=normalize, download=True)
    fmnist_test = dset.FashionMNIST(data_dir, train=False, transform=normalize, download=True)
    ood_sources = [
        dset.MNIST(data_dir, train=False, transform=normalize, download=True),
        dset.EMNIST(data_dir, train=False, transform=normalize, download=True, split='letters'),
        dset.KMNIST(data_dir, train=False, transform=normalize, download=True),
    ]

    y_train = fmnist_train.targets
    y_test = fmnist_test.targets
    if binary:
        # Parity of the class index, mapped {0, 1} -> {-1, +1}.
        y_train = (y_train % 2).unsqueeze(-1) * 2 - 1
        y_test = (y_test % 2).unsqueeze(-1) * 2 - 1

    train_dloader = MiniDataLoader(to_inputs(fmnist_train.data), y_train.type(dtype), batch_size=batch_size)
    test_dloader = MiniDataLoader(to_inputs(fmnist_test.data), y_test.type(dtype), batch_size=batch_size)
    ood_dloaders = [
        MiniDataLoader(to_inputs(source.data), source.targets, batch_size=batch_size)
        for source in ood_sources
    ]
    return train_dloader, test_dloader, ood_dloaders


def get_mnist(batch_size, data_dir='./data', num_train_data=1000, binary=False, dtype=torch.double):
    """Build MNIST train/test loaders and four OOD loaders.

    A fixed (seed 0) random subset of ``num_train_data`` MNIST training images
    is used, together with the full MNIST test split. The OOD loaders are, in
    order: FashionMNIST test, EMNIST-letters test, KMNIST test, and MNIST test
    with a fixed random pixel permutation (same seed-0 draw). The caller's CPU
    RNG state is restored after the seeded draws.

    Images are scaled to [0, 1], normalized with mean 0.1307 / std 0.3081 and
    cast to ``dtype``. In-distribution labels are class-index parity mapped to
    {-1, +1} with a trailing singleton dim when ``binary``; otherwise they are
    10-way one-hot. OOD loaders keep their datasets' raw targets.

    Returns:
        ``(train_dloader, test_dloader, ood_dloaders)``.
    """
    tensor_normalize = transforms.Compose([transforms.Normalize(mean=(0.1307,), std=(0.3081,))])
    pil_normalize = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=(0.1307,), std=(0.3081,))])

    def to_inputs(raw_images):
        # presumably a uint8 (N, H, W) image stack -> (N, 1, H, W) normalized, in ``dtype``
        return tensor_normalize(raw_images.float().unsqueeze(1) / 255.).type(dtype)

    mnist_train = dset.MNIST(data_dir, train=True, download=True)

    # Seeded draws: training subset first, then the pixel permutation.
    saved_state = torch.get_rng_state()
    torch.random.manual_seed(0)
    subset_idx = torch.randperm(len(mnist_train))[:num_train_data]
    pixel_perm = torch.randperm(28 * 28)
    torch.set_rng_state(saved_state)

    mnist_test = dset.MNIST(data_dir, train=False, transform=pil_normalize, download=True)
    fmnist_test = dset.FashionMNIST(data_dir, train=False, transform=pil_normalize, download=True)
    emnist_test = dset.EMNIST(data_dir, train=False, transform=pil_normalize, download=True, split='letters')
    kmnist_test = dset.KMNIST(data_dir, train=False, transform=pil_normalize, download=True)

    # Flatten each test image, shuffle its pixels with the fixed permutation, restore the shape.
    test_images = mnist_test.data
    permuted_images = test_images.view(test_images.shape[0], -1)[:, pixel_perm].view(*test_images.shape)

    y_train = mnist_train.targets[subset_idx]
    y_test = mnist_test.targets
    if binary:
        y_train = (y_train % 2).unsqueeze(-1) * 2 - 1
        y_test = (y_test % 2).unsqueeze(-1) * 2 - 1
    else:
        y_train = torch.nn.functional.one_hot(y_train.long(), 10)
        y_test = torch.nn.functional.one_hot(y_test.long(), 10)

    train_dloader = MiniDataLoader(to_inputs(mnist_train.data[subset_idx]), y_train.type(dtype), batch_size=batch_size)
    test_dloader = MiniDataLoader(to_inputs(mnist_test.data), y_test.type(dtype), batch_size=batch_size)
    ood_pairs = (
        (fmnist_test.data, fmnist_test.targets),
        (emnist_test.data, emnist_test.targets),
        (kmnist_test.data, kmnist_test.targets),
        (permuted_images, mnist_test.targets),
    )
    ood_dloaders = [
        MiniDataLoader(to_inputs(images), targets, batch_size=batch_size)
        for images, targets in ood_pairs
    ]
    return train_dloader, test_dloader, ood_dloaders

def load_cifar_ood_data(dataset_name, transform, num_test_data=1000, data_dir='./data'):
    """Return fixed random samples from the OOD test sets used for a CIFAR model.

    The "other" CIFAR (CIFAR-100 for ``'cifar10'`` and vice versa) and SVHN
    test splits are fetched through torchvision. LSUN, TIN (the
    ``Imagenet_resize`` archive) and iSUN are downloaded as tar.gz archives into
    ``data_dir`` when their folder is missing, then read with ``ImageFolder``.
    From each source one shuffled batch of ``num_test_data`` items is drawn
    under a fixed seed (1); the caller's CPU RNG state is restored afterwards.

    Args:
        dataset_name: ``'cifar10'`` or ``'cifar100'`` (the in-distribution set).
        transform: transform applied by every OOD dataset.
        num_test_data: number of samples drawn from each OOD source.
        data_dir: download/cache directory.

    Returns:
        List with the image component of one batch per source, in the order
        [other CIFAR, SVHN, LSUN, TIN, iSUN].

    Raises:
        Exception: ``'Unknown dataset'`` for any other ``dataset_name``.
    """
    if dataset_name == 'cifar10':
        other_cifar = dset.CIFAR100
    elif dataset_name == 'cifar100':
        other_cifar = dset.CIFAR10
    else:
        raise Exception('Unknown dataset')

    cifar_ood_test_data = other_cifar(data_dir, train=False, transform=transform, download=True)
    svhn_ood_test_data = dset.SVHN(data_dir, split='test', transform=transform, download=True)

    def fetch_image_folder(url, archive_name, folder_name):
        """Download and extract ``url`` into ``data_dir/folder_name`` once; wrap it in ImageFolder."""
        folder = os.path.join(data_dir, folder_name)
        if not os.path.exists(folder):
            archive = os.path.join(data_dir, archive_name)
            urllib.request.urlretrieve(url, archive)
            with tarfile.open(archive, "r:gz") as tar:
                # The archive comes from the network: use the 'data' extraction
                # filter (rejects absolute paths and members escaping ``folder``)
                # on Python versions that provide it.
                if hasattr(tarfile, 'data_filter'):
                    tar.extractall(path=folder, filter='data')
                else:
                    tar.extractall(path=folder)
        return dset.ImageFolder(folder, transform=transform)

    lsun_ood_test_data = fetch_image_folder(
        'https://www.dropbox.com/s/moqh2wh8696c3yl/LSUN_resize.tar.gz?dl=1', 'LSUN_resize.tar.gz', 'LSUN')
    tin_ood_test_data = fetch_image_folder(
        'https://www.dropbox.com/s/kp3my3412u5k9rl/Imagenet_resize.tar.gz?dl=1', 'Imagenet_resize.tar.gz', 'TIN')
    isun_ood_test_data = fetch_image_folder(
        'https://www.dropbox.com/s/ssz7qxfqae0cca5/iSUN.tar.gz?dl=1', 'iSUN.tar.gz', 'iSUN')

    def sample_random(num_sample, data):
        """Return the first element (images) of one shuffled batch, drawn with seed 1."""
        rdn_state = torch.get_rng_state()
        torch.random.manual_seed(1)
        try:
            dloader = torch.utils.data.DataLoader(data, batch_size=num_sample, shuffle=True)
            # DataLoader iterators have no ``.next()`` method in current PyTorch.
            return next(iter(dloader))[0]
        finally:
            torch.set_rng_state(rdn_state)

    return [sample_random(num_test_data, data)
            for data in [cifar_ood_test_data, svhn_ood_test_data, lsun_ood_test_data,
                         tin_ood_test_data, isun_ood_test_data]]


def get_cifar(dataset_name, batch_size, data_dir='./data', num_train_data=1000, num_test_data=1000, binary=False, dtype=torch.double):
    """Build CIFAR-10/100 train/test loaders and the matching OOD loaders.

    Fixed (seed 0) random subsets of ``num_train_data`` training and
    ``num_test_data`` test images are taken; the caller's CPU RNG state is
    restored after each draw. Images are scaled to [0, 1] and then
    standardized by ``normalize``, which uses statistics of the tensor it is
    given, so train, test and each OOD set are normalized with their own
    statistics.

    Args:
        dataset_name: ``'cifar10'`` or ``'cifar100'``.
        batch_size: passed to every MiniDataLoader (negative = one full batch).
        data_dir: download/cache directory.
        num_train_data, num_test_data: subset sizes.
        binary: if True labels are class-index parity mapped to {-1, +1} with a
            trailing singleton dim; otherwise one-hot with 10 (CIFAR-10) or
            100 (CIFAR-100) columns.
        dtype: dtype of the image tensors and in-distribution labels.

    Returns:
        ``(train_dloader, test_dloader, ood_dloaders)``; OOD loaders carry
        all-zero labels.

    Raises:
        Exception: ``'Unknown dataset'`` for any other ``dataset_name``.
    """
    if dataset_name == 'cifar10':
        dataset = dset.CIFAR10
        num_classes = 10
    elif dataset_name == 'cifar100':
        dataset = dset.CIFAR100
        num_classes = 100
    else:
        raise Exception('Unknown dataset')

    def normalize(images):
        # Subtract the channel mean over dims (0, 2, 3), then divide by the
        # per-image spatial std averaged over dim 0. Modifies ``images`` in place.
        print("Normalizing mean by", images.mean([0,2,3]))
        print("Normalizing std by", images.std([2,3]).mean(0))

        images -= images.mean([0, 2, 3]).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
        images /= images.std([2, 3]).mean(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
        return images

    transform = transforms.Compose([transforms.ToTensor()])

    def load_split(train, num_samples):
        """Download one split and return a seed-0 random subset of (images, int labels)."""
        split = dataset(data_dir, transform=transform, train=train, download=True)
        rdn_state = torch.get_rng_state()
        torch.random.manual_seed(0)
        try:
            idx = torch.randperm(len(split))[:num_samples]
        finally:
            torch.set_rng_state(rdn_state)
        images = torch.stack([transform(d / 255.) for d in split.data[idx]])
        labels = torch.Tensor(split.targets).int()[idx]
        return images, labels

    train_image, train_label = load_split(True, num_train_data)
    test_image, test_label = load_split(False, num_test_data)

    if binary:
        train_label = (train_label % 2).unsqueeze(-1) * 2 - 1
        test_label = (test_label % 2).unsqueeze(-1) * 2 - 1
    else:
        # CIFAR-100 labels go up to 99; one_hot raises if num_classes is too small.
        train_label = torch.nn.functional.one_hot(train_label.long(), num_classes)
        test_label = torch.nn.functional.one_hot(test_label.long(), num_classes)

    ood_data_list = load_cifar_ood_data(dataset_name, transform, num_test_data=num_test_data, data_dir=data_dir)

    train_dloader = MiniDataLoader(normalize(train_image).type(dtype), train_label.type(dtype), batch_size=batch_size)
    test_dloader = MiniDataLoader(normalize(test_image).type(dtype), test_label.type(dtype), batch_size=batch_size)
    ood_dloaders = [MiniDataLoader(normalize(image_data).type(dtype), torch.zeros(image_data.shape[0]), batch_size=batch_size)
                    for image_data in ood_data_list]

    return train_dloader, test_dloader, ood_dloaders


def get_datasets(dataset, num_train_data, num_test_data, dim=2, binary=True, data_dir='./data'):
    """Build train / in-distribution test / OOD loaders for a named dataset.

    The global torch CPU RNG is seeded with 0 while the data is built and its
    previous state is restored afterwards, even if building fails.

    Args:
        dataset: one of ``"concentric"``, ``"full_circle"``, ``"half_circle"``,
            ``"gaussian"``, ``"mnist"``, ``"fmnist"``, ``"cifar10"``,
            ``"cifar100"``.
        num_train_data: number of training points.
        num_test_data: number of in-distribution test points (and OOD points
            for the image datasets).
        dim: ambient dimension for the synthetic datasets.
        binary: for mnist/cifar, use parity {-1, +1} labels (out_dim 1)
            instead of one-hot labels. fmnist is always binary.
        data_dir: download/cache directory for the image datasets.

    Returns:
        ``(train_dloader, ind_dloader, ood_dloaders, in_dim, out_dim)``;
        ``train_dloader`` additionally gets a ``dataname`` attribute.

    Raises:
        ValueError: for an unknown dataset name.
    """
    rdn_state = torch.get_rng_state()
    torch.random.manual_seed(0)
    try:
        if dataset in ("concentric", "full_circle", "half_circle", "gaussian"):
            # Train is generated before test so RNG-driven generators
            # (gaussian) consume the seeded stream in a fixed order.
            if dataset == "concentric":
                train_data, labels = get_concentric_data(num_train_data, dim=dim)
                test_data, test_labels = get_concentric_data(num_test_data, dim=dim)
            elif dataset == "full_circle":
                train_data, labels = get_unit_circle_data(num_train_data, frac=1, dim=dim)
                # Test points lie on a half circle, offset by half a training step.
                test_data, test_labels = get_unit_circle_data(num_test_data, frac=0.5, offset=math.pi/num_train_data/2, dim=dim)
            elif dataset == "half_circle":
                train_data, labels = get_unit_circle_data(num_train_data, frac=0.5, dim=dim)
                test_data, test_labels = get_unit_circle_data(num_test_data, frac=0.5, offset=math.pi/num_train_data/2, dim=dim)
            else:
                train_data, labels = get_gaussian_data(num_train_data, dim=dim)
                test_data, test_labels = get_gaussian_data(num_test_data, dim=dim)
            train_dloader = MiniDataLoader(train_data, labels, batch_size=-1)
            ind_dloader = MiniDataLoader(test_data, test_labels, batch_size=-1)
            ood_dloaders = []
            in_dim = [dim]
            out_dim = 1

        elif dataset == "mnist":
            train_dloader, ind_dloader, ood_dloaders = get_mnist(batch_size=-1, data_dir=data_dir,
                                                                 num_train_data=num_train_data, binary=binary)
            ind_dloader.truncate_data(num_test_data)
            for ood_dloader in ood_dloaders:
                ood_dloader.truncate_data(num_test_data)
            in_dim = [1, 28, 28]
            out_dim = 1 if binary else 10

        elif dataset == "fmnist":
            # binary is forced on: get_fmnist has no one-hot branch for the
            # non-binary case, and it does not subsample the training set itself.
            train_dloader, ind_dloader, ood_dloaders = get_fmnist(batch_size=-1, data_dir=data_dir,
                                                                  num_train_data=num_train_data, binary=True)
            train_dloader.truncate_data(num_train_data)
            ind_dloader.truncate_data(num_test_data)
            for ood_dloader in ood_dloaders:
                ood_dloader.truncate_data(num_test_data)
            in_dim = [1, 28, 28]
            out_dim = 1

        elif dataset in ("cifar10", "cifar100"):
            # get_cifar already subsamples train, test and OOD data.
            train_dloader, ind_dloader, ood_dloaders = get_cifar(dataset, batch_size=-1, data_dir=data_dir,
                                                                 num_train_data=num_train_data,
                                                                 num_test_data=num_test_data, binary=binary)
            in_dim = [3, 32, 32]
            if binary:
                out_dim = 1
            else:
                out_dim = 100 if dataset == "cifar100" else 10

        else:
            raise ValueError("Unknown dataset: {}".format(dataset))

        print("Data shapes")
        print(train_dloader.dataset[0].shape)
        print(ind_dloader.dataset[0].shape)
        for dloader in ood_dloaders:
            print(dloader.dataset[0].shape)
    finally:
        torch.set_rng_state(rdn_state)

    train_dloader.dataname = dataset
    return train_dloader, ind_dloader, ood_dloaders, in_dim, out_dim
