import torch
from torch.utils.data import DataLoader
from typing import Callable, List, Optional, Tuple
import numpy as np
import torch.nn.functional as F
import torch.nn as nn
import random 
from torch import Tensor
from torch.optim import Optimizer

from typing_extensions import Literal, TypedDict
from tqdm import tqdm

import gzip
import os
from urllib.request import urlretrieve
import numpy as np
import torch
import torchvision
import matplotlib.pyplot as plt

# Compute device shared by the whole script: CUDA when PyTorch reports it as
# available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def mnist(path=None):
    r"""Load the MNIST images as tensors, downloading any missing archive.

    Args:
        path (str): Directory containing the MNIST ``.gz`` archives. Default
            is ``<home>/data/mnist`` (``os.path.expanduser('~')``). The
            directory is created if it does not exist, and each of the four
            archives that is missing from it is downloaded.

    Returns:
        Tuple ``(train_images, test_images)``. Each is a float32 tensor of
        shape ``(n, 1, 28, 28)`` with pixel values scaled to ``[0, 1]``.
        The label archives are downloaded alongside the images but labels
        are not returned.
    """
    # NOTE(review): this host has reportedly stopped serving these files to
    # scripted clients -- confirm it is reachable or point this at a mirror.
    url = 'http://yann.lecun.com/exdb/mnist/'
    files = ['train-images-idx3-ubyte.gz',
             'train-labels-idx1-ubyte.gz',
             't10k-images-idx3-ubyte.gz',
             't10k-labels-idx1-ubyte.gz']

    if path is None:
        # Set path to /home/USER/data/mnist or C:\Users\USER\data\mnist
        path = os.path.join(os.path.expanduser('~'), 'data', 'mnist')

    # Create path if it doesn't exist
    os.makedirs(path, exist_ok=True)

    # Download any missing files
    for file in files:
        target = os.path.join(path, file)
        if os.path.exists(target):
            continue
        # Fetch under a temporary name and rename afterwards, so that an
        # interrupted download is never mistaken for a complete archive.
        partial = target + '.part'
        urlretrieve(url + file, partial)
        os.replace(partial, target)
        print("Downloaded %s to %s" % (file, path))

    def _images(path):
        """Return images loaded locally as an (n, 784) float32 array."""
        with gzip.open(path) as f:
            # First 16 bytes are magic_number, n_imgs, n_rows, n_cols
            pixels = np.frombuffer(f.read(), 'B', offset=16)
        return pixels.reshape(-1, 784).astype('float32') / 255

    train_images = _images(os.path.join(path, files[0]))
    test_images = _images(os.path.join(path, files[2]))

    # The example count is inferred (-1) instead of being hard-coded.
    return (Tensor(train_images).reshape((-1, 1, 28, 28)),
            Tensor(test_images).reshape((-1, 1, 28, 28)))

# Load MNIST once at import time. mnist() returns only the train and test
# image tensors, each reshaped to (-1, 1, 28, 28); these two names are read as
# module-level globals by train_latent_space().
training_data, test_data = mnist()

def p_expmap0(x, epsilon=1e-6):
  """Rescale each row of ``x`` to ``tanh(||x||) * x / ||x||``.

  Presumably the exponential map at the origin of the Poincare ball
  (curvature -1) -- TODO confirm against the intended model.

  Args:
      x: 2-D tensor; the L2 norm is taken over ``dim=1``.
      epsilon: Lower bound applied to the norm so an all-zero row maps to
          zero instead of producing ``0 / 0 = NaN``.

  Returns:
      Tensor with the same shape as ``x``.
  """
  # Compute the norm once and keep it away from zero before dividing.
  norm = torch.norm(x, p=2, dim=1, keepdim=True).clamp_min(epsilon)
  return torch.tanh(norm) * (x / norm)
def d_expmap0(x,epsilon=1e-6):
  """Rescale each row of ``x`` to ``tan(||x||) * x / ||x||``.

  Presumably the exponential map at the origin of a projected-sphere model
  (curvature +1) -- TODO confirm against the intended model. ``tan``
  diverges as the row norm approaches ``pi / 2``; inputs are not clipped
  here.

  Args:
      x: 2-D tensor; the L2 norm is taken over ``dim=1``.
      epsilon: Lower bound applied to the norm so an all-zero row maps to
          zero instead of producing ``0 / 0 = NaN``.

  Returns:
      Tensor with the same shape as ``x``.
  """
  # Compute the norm once and keep it away from zero before dividing.
  norm = torch.norm(x, p=2, dim=1, keepdim=True).clamp_min(epsilon)
  return torch.tan(norm) * (x / norm)

class Encoder(nn.Module):
    """Convolutional encoder producing a flat code of ``latent_dim * 6 * 6`` values.

    Eleven unpadded 3x3 convolutions (each followed by BatchNorm2d and SiLU)
    shrink a 28x28 single-channel image to 6x6; the result is flattened and
    passed through one square Linear layer plus SiLU.

    Args:
        hidden_dim: Channel count of the intermediate convolutions.
        latent_dim: Channel count of the last convolution.
    """

    def __init__(self, hidden_dim=20, latent_dim=2):
        super().__init__()

        # (in_channels, out_channels) of every convolution, in layer order.
        channel_plan = (
            [(1, hidden_dim)]
            + [(hidden_dim, hidden_dim)] * 9
            + [(hidden_dim, latent_dim)]
        )

        stack = []
        for in_channels, out_channels in channel_plan:
            stack.append(nn.Conv2d(in_channels, out_channels, 3))
            stack.append(nn.BatchNorm2d(out_channels))
            stack.append(nn.SiLU())

        flat_features = latent_dim * 6 * 6
        stack.append(nn.Flatten())
        stack.append(nn.Linear(flat_features, flat_features))
        stack.append(nn.SiLU())

        self.encoder = nn.Sequential(*stack)

    def forward(self, x):
        """Return the flattened code for the image batch ``x``."""
        return self.encoder(x)

class Decoder(nn.Module):
    """Transposed-convolution decoder mapping a flat code back to an image.

    A square Linear layer plus SiLU is followed by an unflatten to
    ``(latent_dim, 6, 6)``; eleven unpadded 3x3 transposed convolutions then
    grow the map to 28x28. All but the last are followed by BatchNorm2d and
    SiLU; the last one outputs a single channel through a Sigmoid.

    Args:
        hidden_dim: Channel count of the intermediate transposed convolutions.
        latent_dim: Channel count of the unflattened code.
    """

    def __init__(self, hidden_dim=20, latent_dim=2):
        super().__init__()

        flat_features = latent_dim * 6 * 6
        stack = [
            nn.Linear(flat_features, flat_features),
            nn.SiLU(),
            nn.Unflatten(dim=1, unflattened_size=(latent_dim, 6, 6)),
        ]

        # Ten normalised up-sampling stages; only the first one starts from
        # latent_dim channels.
        in_channels = latent_dim
        for _ in range(10):
            stack.append(torch.nn.ConvTranspose2d(in_channels, hidden_dim, 3))
            stack.append(nn.BatchNorm2d(hidden_dim))
            stack.append(nn.SiLU())
            in_channels = hidden_dim

        stack.append(torch.nn.ConvTranspose2d(hidden_dim, 1, 3))
        stack.append(nn.Sigmoid())

        self.decoder = nn.Sequential(*stack)

    def forward(self, z):
        """Return the reconstructed image batch for the flat codes ``z``."""
        return self.decoder(z)

class Autoencoder(nn.Module):
    """Autoencoder whose flat latent code is split into one chunk per model space.

    ``latent_space`` is a string with one letter per model space
    (case-insensitive). The encoder output has
    ``base_dimensionality * 6 * 6`` values per letter. In ``forward`` the
    code is cut into equally sized chunks which are handled in the fixed
    order e, p, d (regardless of letter order in the string):

    * ``e`` chunks are passed through unchanged,
    * ``p`` chunks go through ``p_expmap0``,
    * ``d`` chunks go through ``d_expmap0``.

    ``h`` and ``s`` are counted (``num_h``, ``num_s``) but no projection is
    implemented for them, so requesting either raises at construction.

    Args:
        latent_space: Product string, e.g. ``'eep'``.
        hidden_dim: Channel width forwarded to ``Encoder`` and ``Decoder``.
        base_dimensionality: Latent channels contributed by each model space.

    Raises:
        NotImplementedError: If ``latent_space`` contains 'h' or 's'.
        ValueError: If ``latent_space`` contains no recognised letter.
    """

    def __init__(self, latent_space='e', hidden_dim = 20, base_dimensionality=2):
        super(Autoencoder, self).__init__()

        # Count how many times each model space appears. Unknown letters are
        # reported and skipped.
        counts = dict.fromkeys('ehspd', 0)
        for letter in latent_space:
            key = letter.lower()
            if key in counts:
                counts[key] += 1
            else:
                print('Invalid Sequence')

        self.hidden_dim = hidden_dim
        self.base_dimensionality = base_dimensionality
        self.num_e = counts['e']
        self.num_h = counts['h']
        self.num_s = counts['s']
        self.num_p = counts['p']
        self.num_d = counts['d']

        if self.num_h or self.num_s:
            # These spaces would widen the code but forward() has no map for
            # them, so fail early instead of with a shape error later.
            raise NotImplementedError(
                "model spaces 'h' and 's' have no projection implemented")

        total_number_of_products = sum(counts.values())
        if total_number_of_products == 0:
            raise ValueError(
                "latent_space must contain at least one of 'e', 'p', 'd'")

        self.latent_dim = self.base_dimensionality*total_number_of_products
        self.encoder = Encoder(hidden_dim=self.hidden_dim,latent_dim=self.latent_dim)
        self.decoder = Decoder(hidden_dim=self.hidden_dim,latent_dim=self.latent_dim)
        # Width of ONE model space inside the flat code. Using the full code
        # width here would hand the entire code to the first model space and
        # leave empty slices for all the others.
        self.latent_chunk = self.base_dimensionality*6*6

    def forward(self, x):
        """Encode ``x``, project each latent chunk, and decode the result."""
        # Follow the module's own device instead of a module-level global.
        x = x.to(next(self.parameters()).device)
        z = self.encoder(x)

        # One equally sized chunk per model space, assigned in e, p, d order.
        chunks = torch.split(z, self.latent_chunk, dim=-1)
        maps = ([None] * self.num_e
                + [p_expmap0] * self.num_p
                + [d_expmap0] * self.num_d)
        projected = [chunk if fn is None else fn(chunk)
                     for chunk, fn in zip(chunks, maps)]

        return self.decoder(torch.cat(projected, dim=-1))

def train_step_d(model,
               data_train,
               data_test,
               optimizer,
               loss_fn,
               epoch
               ):
  """Run one optimisation step; every 100th call also report the test loss.

  Args:
      model: Module called as ``model(batch)`` to produce a reconstruction.
      data_train: Batch used both as input and as reconstruction target.
      data_test: Batch used for the periodic evaluation.
      optimizer: Optimizer holding the parameters of ``model``.
      loss_fn: Callable ``loss_fn(reconstruction, target)``.
      epoch: Step counter; evaluation runs when ``epoch % 100 == 0``.

  Returns:
      None. After an evaluation step the model is left in eval mode (the
      next call switches it back to training mode).
  """
  # Set model to training mode
  model.train()
  # Clear gradients left over from the previous step
  optimizer.zero_grad()

  # Make Prediction
  reconstruction = model(data_train)

  # Compute reconstruction loss
  loss = loss_fn(reconstruction, data_train)

  # Backpropagate and update the model parameters
  loss.backward()
  optimizer.step()

  if epoch % 100 == 0:
    model.eval()
    # Evaluation needs no gradients; skipping the autograd graph saves
    # memory and time.
    with torch.no_grad():
      test_reconstruction = model(data_test)
      test_loss = loss_fn(test_reconstruction, data_test)
    print('Train loss: ',loss.item(),'Test loss: ',test_loss.item())

def plot_reconstruction(model,data_test):
  """Show a 3x3 grid of reconstructions of ``data_test``.

  The model is switched to eval mode and run once on the whole batch.
  Reconstructions at indices 1..9 are displayed (index 0 is not shown), so
  the batch needs at least 10 samples.

  Args:
      model: Module called as ``model(data_test)``.
      data_test: Batch of images to reconstruct.
  """
  model.eval()
  # The forward pass does not depend on the loop variable: run it once,
  # without building an autograd graph.
  with torch.no_grad():
    img = model(data_test)

  cols, rows = 3, 3
  figure = plt.figure(figsize=(8, 8))
  for i in range(1, cols * rows + 1):
    figure.add_subplot(rows, cols, i)
    plt.axis("off")
    plt.imshow(img[i].cpu().detach().numpy().squeeze(), cmap="gray")
  plt.show()

# Training loop

def train_latent_space(latent_space,
                       MAX_EPOCHS = 5000,
                       LEARNING_RATE = 1e-4,
                       WEIGHT_DECAY = 1e-4,
                       BATCH_SIZE = 500):
  """Train an ``Autoencoder`` for ``latent_space`` and record its test loss.

  Each step draws ``BATCH_SIZE`` random training images and ``BATCH_SIZE``
  random test images from the module-level ``training_data`` /
  ``test_data`` and calls ``train_step_d``. Afterwards the whole test set
  is evaluated in batches of 100.

  Args:
      latent_space: Product string passed to ``Autoencoder``.
      MAX_EPOCHS: Upper bound of the step counter; the loop runs over
          ``range(1, MAX_EPOCHS)``, i.e. ``MAX_EPOCHS - 1`` steps.
      LEARNING_RATE: Adam learning rate.
      WEIGHT_DECAY: Adam weight decay.
      BATCH_SIZE: Number of samples drawn per step.

  Returns:
      The final test loss: the SUM (not the mean) of the per-batch MSE
      values. Its string form is also printed and appended to the
      module-level ``results`` list.
  """
  loss_fn = nn.MSELoss()
  model = Autoencoder(latent_space=latent_space)
  model.to(device)

  optimizer = torch.optim.Adam(model.parameters(),
                              lr = LEARNING_RATE,
                              weight_decay = WEIGHT_DECAY)

  # NOTE(review): range(1, MAX_EPOCHS) stops one step short of MAX_EPOCHS --
  # confirm whether that is intended.
  for epoch in range(1,MAX_EPOCHS):

    training_idx = random.sample(range(0, training_data.shape[0]), BATCH_SIZE)
    data_train = training_data[training_idx].to(device)
    test_idx = random.sample(range(0, test_data.shape[0]), BATCH_SIZE)
    data_test = test_data[test_idx].to(device)

    train_step_d(model,
                data_train,
                data_test,
                optimizer,
                loss_fn,
                epoch
                )

  model.eval()

  # Walk over the whole test set whatever its size (100 batches for the
  # 10000 MNIST test images). No gradients are needed for evaluation.
  eval_batch_size = 100
  final_test_loss = 0
  with torch.no_grad():
    for start_idx in range(0, test_data.shape[0], eval_batch_size):
      batch = test_data[start_idx:start_idx + eval_batch_size].to(device)
      final_test_loss += loss_fn(model(batch), batch).item()

  print(str(final_test_loss))
  results.append(str(final_test_loss))

  torch.cuda.empty_cache()
  return final_test_loss

import random

def calculateCombinations(number_of_modelspaces):
  """Return ``(n + 1) * (n + 2) / 2`` for ``n = number_of_modelspaces``.

  This is the number of distinct products of ``n`` factors chosen, with
  repetition and ignoring order, from three kinds of model space
  (3, 6, 10, ... for n = 1, 2, 3, ...). It is algebraically identical to
  the original expression ``1 + (1 + h) + (1 + h + h*(h+1)/2)`` with
  ``h = n - 1``, but uses integer arithmetic so the result stays exact for
  large ``n``.

  Args:
      number_of_modelspaces: Number of factors in the product.

  Returns:
      The count as an ``int``.
  """
  n = number_of_modelspaces
  # One of two consecutive integers is even, so floor division is exact.
  return int((n + 1) * (n + 2) // 2)

def generate_all_combinations(numModelSpaces=3):
  """Enumerate every product string of ``numModelSpaces`` model spaces.

  Model spaces are drawn from 'e', 'p' and 'd' with repetition; order is
  ignored, so each product is represented by its alphabetically sorted
  string (e.g. 'dep'). The strings are enumerated directly and in a fixed
  order rather than discovered by random sampling.

  Args:
      numModelSpaces: Number of letters in every product string.

  Returns:
      Dict mapping each product string to a fresh empty list. Empty when
      ``numModelSpaces`` is negative.
  """
  from itertools import combinations_with_replacement

  if numModelSpaces < 0:
    return {}

  # Alphabetical order, so every joined tuple is already in sorted form.
  choseModelSpace = ['d', 'e', 'p']
  return {
      ''.join(product): []
      for product in combinations_with_replacement(choseModelSpace, numModelSpaces)
  }

# Experiment driver: for every product length from 1 to number_of_modelspaces,
# train one autoencoder per product string and collect the outcomes.
number_of_modelspaces = 7
# Flat log shared with train_latent_space(): this loop appends the product
# string, then train_latent_space() appends the matching final test loss (as a
# string), so entries alternate name, loss, name, loss, ...
results = []
for spaces in range(1,number_of_modelspaces+1): 
  # Create dictionary with all product manifold combinations for a given dimensionality
  combinations = generate_all_combinations(numModelSpaces=spaces)

  for latent_space in combinations:
    results.append(latent_space)
    print(latent_space)
    # Overrides the function's default of 5000 for MAX_EPOCHS.
    train_latent_space(latent_space,MAX_EPOCHS = 2500)
  
# Written only after every run has finished; the file holds the str() of the
# whole list.
with open("file.txt", "w") as output:
    output.write(str(results))
