import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image, make_grid
import cvxpy as cvx
import numpy as np
import matplotlib.pyplot as plt
import scipy.fftpack as fftpack

# the structure of the VAE model
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    The encoder maps an ``x_dim`` input through two ReLU hidden layers
    (``h_dim1`` then ``h_dim2``) to a mean and a log-variance vector, each of
    size ``z_dim``. The decoder mirrors it (``z_dim`` -> ``h_dim2`` ->
    ``h_dim1`` -> ``x_dim``) and ends with a sigmoid, so outputs lie in [0, 1].

    Args:
        x_dim: Number of features of a flattened input sample.
        h_dim1: Width of the first encoder / last decoder hidden layer.
        h_dim2: Width of the second encoder / first decoder hidden layer.
        z_dim: Dimension of the latent code.
    """

    def __init__(self, x_dim, h_dim1, h_dim2, z_dim):
        super().__init__()

        # Encoder layers
        self.fc1 = nn.Linear(x_dim, h_dim1)
        self.fc2 = nn.Linear(h_dim1, h_dim2)
        self.fc31 = nn.Linear(h_dim2, z_dim)  # head producing mu
        self.fc32 = nn.Linear(h_dim2, z_dim)  # head producing log_var
        # Decoder layers
        self.fc4 = nn.Linear(z_dim, h_dim2)
        self.fc5 = nn.Linear(h_dim2, h_dim1)
        self.fc6 = nn.Linear(h_dim1, x_dim)

    def encoder(self, x):
        """Return ``(mu, log_var)`` for an already flattened batch ``x``."""
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return self.fc31(h), self.fc32(h)

    def sampling(self, mu, log_var):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)`` (reparameterization)."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def decoder(self, z):
        """Map latent codes ``z`` back to input space; outputs are in [0, 1]."""
        h = F.relu(self.fc4(z))
        h = F.relu(self.fc5(h))
        return torch.sigmoid(self.fc6(h))

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        ``x`` is flattened to ``(-1, x_dim)`` before encoding. The width is
        read from the first encoder layer instead of the previously
        hard-coded 784, so models built with a different ``x_dim`` work too.

        Returns:
            Tuple ``(reconstruction, mu, log_var)``.
        """
        mu, log_var = self.encoder(x.view(-1, self.fc1.in_features))
        z = self.sampling(mu, log_var)
        return self.decoder(z), mu, log_var

# find the index sets of each class
def find_index_classes(test_dataset):
    """Group the positions of ``test_dataset`` by label.

    Each item is read as ``test_dataset[i]`` and its element at index 1 is
    used as the label.

    Returns:
        Dict mapping ``'class' + str(label)`` to the list of positions having
        that label, in increasing order.
    """
    buckets = {}
    for position in range(len(test_dataset)):
        key = 'class' + str(test_dataset[position][1])
        # setdefault creates the bucket on first sight of a label
        buckets.setdefault(key, []).append(position)
    return buckets

# this function is to select the index sets of the test images to be used
def select_testImageIndexes(index_dict, num_images=30, num_classes = 10):
    """Pick ``num_images`` dataset positions, evenly spread over the classes.

    Args:
        index_dict: Mapping ``'class<k>'`` -> list of positions, for
            ``k`` in ``range(num_classes)``.
        num_images: Total number of positions to select.
        num_classes: Number of classes to draw from.

    Returns:
        List ordered round by round: the first entry of every class
        (class 0, 1, ...), then the second entry of every class, and so on.
        If ``num_images`` is not a multiple of ``num_classes`` a message is
        printed and an empty list is returned.
    """
    if num_images % num_classes != 0:
        print('N must be divisable by ' + str(num_classes))
        return []

    per_class = num_images // num_classes
    return [
        index_dict['class' + str(label)][rank]
        for rank in range(per_class)
        for label in range(num_classes)
    ]

# display all the test images of MNIST
def showImages(images, num_classes = 10):
    """Show ``images`` on a grid with ``num_classes`` columns.

    Every image is reshaped with ``.view(28, 28)``, moved to the CPU and drawn
    in grayscale without axis ticks. The grid has
    ``len(images) // num_classes`` rows; images beyond the last full row are
    not drawn.

    Args:
        images: Indexable collection of tensors, each viewable as 28 x 28.
        num_classes: Number of columns of the grid.
    """
    num_images = len(images)
    nrows = num_images // num_classes
    # squeeze=False keeps the axes array 2-D for every grid shape, so a single
    # loop handles one-row grids and one-column grids alike (the latter used
    # to fail on indexing).
    fig, ax = plt.subplots(nrows, num_classes, figsize=(15, 5), squeeze=False)
    for row_num in range(nrows):
        for j in range(num_classes):
            cell = ax[row_num, j]
            cell.imshow(images[j + row_num * num_classes].view(28, 28).cpu().numpy(), cmap='gray')
            cell.set_xticks([])
            cell.set_yticks([])
    fig.suptitle('All the test images')
    plt.show()
    
# display all the test images of CelebA
def showImages_celeba(images, num_images_per_row = 10):
    """Show colour ``images`` on a grid with ``num_images_per_row`` columns.

    Every image is detached, moved to the CPU, converted to NumPy and its axes
    are permuted with ``(1, 2, 0)`` before drawing (presumably channel-first
    C x H x W tensors turned into H x W x C; confirm against the caller).
    The grid has ``len(images) // num_images_per_row`` rows; images beyond the
    last full row are not drawn.

    Args:
        images: Indexable collection of 3-D tensors.
        num_images_per_row: Number of columns of the grid.
    """
    num_images = len(images)
    nrows = num_images // num_images_per_row
    # squeeze=False keeps the axes array 2-D for every grid shape, so one loop
    # serves single-row and single-column grids.
    fig, ax = plt.subplots(nrows, num_images_per_row, figsize=(15, 5), squeeze=False)
    for row_num in range(nrows):
        for j in range(num_images_per_row):
            currImg = images[j + row_num * num_images_per_row]
            currImg = currImg.cpu().detach().numpy()
            currImg = currImg.transpose((1, 2, 0))  # move axis 0 to the end
            cell = ax[row_num, j]
            cell.imshow(currImg, interpolation='bilinear')
            # Hide ticks on both axes; the multi-row path used to leave the
            # x ticks visible, unlike the single-row path.
            cell.set_xticks([])
            cell.set_yticks([])
    fig.suptitle('All the test images')
    plt.show()
    
# uniform quantization
def unif_quant(x, Delta):
    """Quantize ``x`` uniformly with step ``Delta``.

    Each value is mapped to the midpoint of the cell
    ``[k * Delta, (k + 1) * Delta)`` that contains it.
    """
    cell_index = torch.floor(x / Delta)
    return Delta * (0.5 + cell_index)

# generate an m * n standard Gaussian matrix
def genA(m, n):
    """Sample a matrix with i.i.d. standard normal entries.

    The draw has shape ``(m, n, 1)`` because the distribution parameters are
    one-element tensors; ``squeeze()`` then removes every singleton axis.
    NOTE: as a consequence the result has fewer than two dimensions when
    ``m == 1`` or ``n == 1``.
    """
    std_normal = torch.distributions.Normal(
        loc=torch.tensor([0.0]),
        scale=torch.tensor([1.0]),
    )
    return std_normal.sample((m, n)).squeeze()

def cos_sim(tensor1, tensor2):
    """Cosine similarity between two tensors after squeezing them.

    Returns:
        ``dot(t1, t2) / (||t1|| * ||t2||)`` as a tensor, or the Python float
        ``0.`` when either input does not have a strictly positive norm.
    """
    vec_a = torch.squeeze(tensor1)
    vec_b = torch.squeeze(tensor2)

    # Guard: the ratio is undefined unless both norms are strictly positive.
    if not (torch.norm(tensor1) > 0.) or not (torch.norm(tensor2) > 0.):
        return 0.

    inner = torch.dot(vec_a, vec_b)
    scale = torch.norm(vec_a) * torch.norm(vec_b)
    return inner / scale

def rel_dist(tensor1, tensor2):
    """Relative distance ``||t1 - t2|| / ||t2||`` of the squeezed inputs.

    ``tensor2`` acts as the reference; no guard is applied when its norm
    is zero.
    """
    candidate = torch.squeeze(tensor1)
    reference = torch.squeeze(tensor2)
    gap = torch.norm(candidate - reference)
    return gap / torch.norm(reference)

# images contains all the test images; res contains all the reconstructed vectors; n is the ambient dimension
def dist_vec(res, images, n, dist_measure = 'cos_sim'):
    """Score every reconstructed vector against its ground-truth image.

    ``res`` is read as consecutive groups of ``len(images)`` reconstructions;
    entry ``k`` is compared with ``images[k % len(images)]`` reshaped by
    ``.view(n)``.

    Args:
        res: Reconstructed vectors; converted with ``torch.FloatTensor``.
        images: Test images (tensors).
        n: Ambient dimension used to flatten each test image.
        dist_measure: ``'cos_sim'`` or ``'rel_dist'``.

    Returns:
        ``np.array`` of the scores in the order of ``res``. When ``len(res)``
        is not a multiple of ``len(images)``, or the measure is unknown, a
        message is printed and an empty list is returned.
    """
    n_truth = len(images)
    n_recon = len(res)
    if n_recon % n_truth != 0:
        print('The number of reconstructed images must be divisable by the number of test images!')
        return []

    recon = torch.FloatTensor(res)
    scores = []
    # One flat pass: position k belongs to group k // n_truth, image k % n_truth.
    for k in range(n_recon):
        truth_idx = k % n_truth
        if dist_measure == 'cos_sim':
            score = cos_sim(recon[k], images[truth_idx].view(n))
        elif dist_measure == 'rel_dist':
            score = rel_dist(recon[k], images[truth_idx].view(n))
        else:
            print('The measure of distance must be cos_sim or rel_dist!')
            return []
        scores.append(score)
    return np.array(scores)

# calculate T for various nonlinear models
def model_T(model):
    """Return the constant ``T`` associated with a nonlinear measurement model.

    Supported names: ``'1bit'`` (``sqrt(2 / pi)``), ``'relu'`` (``0.5``),
    ``'unifQuant'`` (``1``) and ``'1bitDither'`` (placeholder ``0.``; the
    actual value depends on R and is set by the caller later).

    Returns:
        The constant, or ``None`` after printing a message for an unknown name.
    """
    if model == '1bit':
        return np.sqrt(2/np.pi)
    if model == 'relu':
        return 0.5
    if model == 'unifQuant':
        return 1
    if model == '1bitDither':
        return 0.  # it is dependent on R and will be specified later
    print('Model not found!')
    return None

def shrinkage(x, alpha):
    """Soft-threshold ``x`` elementwise: ``sign(x) * max(|x| - alpha, 0)``.

    The zero tensor used for the comparison is created on ``x``'s own device
    (it used to be forced onto CUDA, which failed for CPU inputs and on
    machines without a GPU).

    Args:
        x: Input tensor.
        alpha: Threshold subtracted from the magnitudes.

    Returns:
        Tensor of the same shape as ``x``.
    """
    floor = torch.zeros(x.shape, device=x.device)
    return torch.max(torch.abs(x) - alpha, floor) * torch.sign(x)
