import argparse
import json
import logging
import os
from time import time

import importlib
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

from pathlib import Path

import wandb
from pythae.data.preprocessors import DataProcessor
from pythae.models import AutoModel
from pythae.pipelines import GenerationPipeline
from pythae.trainers import BaseTrainerConfig
from TTUR.fid import compute_fid_wo_paths


logger = logging.getLogger(__name__)
console = logging.StreamHandler()
logger.addHandler(console)
logger.setLevel(logging.INFO)

# Directory containing this script; the ``data/<dataset>`` folders are resolved relative to it.
PATH = os.path.dirname(os.path.abspath(__file__))

ap = argparse.ArgumentParser(
    description="Compute reconstruction MSE and FID for a trained pythae model."
)

ap.add_argument(
    "--models_path",
    help=(
        "Path to the training output folder whose model sub-folder contains a "
        "'final_model' directory"
    ),
    required=True,
)
ap.add_argument(
    "--use_wandb",
    help="whether to log the metrics in wandb",
    action="store_true",
)
ap.add_argument(
    "--wandb_project",
    help="wandb project name",
    default="reconstruction_metrics",
)
ap.add_argument(
    "--wandb_entity",
    help="wandb entity name",
    default="benchmark_team",
)

# NOTE: arguments are parsed at import time, so merely importing this module reads sys.argv.
args = ap.parse_args()

# NOTE(review): main() sets CUDA_VISIBLE_DEVICES only after this availability check has
# already run, so that setting may not restrict the visible GPUs — confirm.
device = "cuda" if torch.cuda.is_available() else "cpu"

def _infer_dataset(input_dim):
    """Return the dataset name matching a model's ``input_dim``.

    Args:
        input_dim: The ``model_config.input_dim`` of the reloaded model.

    Returns:
        One of ``'mnist'``, ``'cifar10'`` or ``'celeba'``.

    Raises:
        ValueError: If ``input_dim`` matches none of the supported shapes.
    """
    known_shapes = {
        (1, 28, 28): "mnist",
        (3, 32, 32): "cifar10",
        (3, 64, 64): "celeba",
    }
    # Normalise to a tuple so the lookup also works if input_dim is stored as a list.
    key = tuple(input_dim) if input_dim is not None else None
    if key not in known_shapes:
        raise ValueError(
            f"Unsupported model input_dim {input_dim}; expected one of "
            f"{list(known_shapes)}."
        )
    return known_shapes[key]


def _load_split(dataset, split):
    """Load ``data/<dataset>/<split>_data.npz`` (array under key ``'data'``) divided by 255."""
    path = os.path.join(PATH, f"data/{dataset}", f"{split}_data.npz")
    # The NpzFile keeps the archive open until closed; the context manager releases it.
    with np.load(path) as archive:
        return archive["data"] / 255.0


def _reconstruct(model, loader):
    """Return the concatenated ``recon_x`` of ``model`` over every batch of ``loader``.

    Inference is first attempted under ``torch.no_grad()``. If that raises a
    ``RuntimeError`` (presumably a model whose forward pass needs autograd — TODO
    confirm), the whole pass is redone with gradients enabled and outputs detached.
    """

    def _run_all_batches():
        # Build a fresh list each time so a failed attempt leaves no partial results behind.
        return [model(inputs).recon_x.detach() for inputs in loader]

    try:
        with torch.no_grad():
            recons = _run_all_batches()
    except RuntimeError:
        recons = _run_all_batches()
    return torch.cat(recons)


def _mse(recon, target):
    """Return the mean squared error between ``recon`` and ``target`` as a Python float.

    Both tensors are flattened per sample first, so their trailing shapes only need
    to agree in element count.
    """
    n_samples = target.shape[0]
    return torch.nn.functional.mse_loss(
        recon.reshape(n_samples, -1), target.reshape(n_samples, -1)
    ).item()


def _to_hwc_255(images):
    """Convert a channel-first image tensor in [0, 1] to a channel-last numpy array in [0, 255].

    Single-channel images are repeated to 3 channels before being passed to the FID code.
    """
    array = 255.0 * torch.movedim(images, 1, 3).detach().cpu().numpy()
    if array.shape[-1] == 1:
        array = np.repeat(array, repeats=3, axis=-1)
    return array


def _reconstruction_rows(recon, data, n_cols=10, max_images=40):
    """Build wandb table rows that alternate original images and their reconstructions.

    Row ``2k`` holds ``data[2k*n_cols : (2k+1)*n_cols]`` and row ``2k+1`` holds the
    reconstructions of those same samples. At most ``min(max_images, len(data))``
    slots are used; a trailing incomplete row is dropped.
    """
    n_rows = min(max_images, data.shape[0]) // n_cols
    rows = []
    for row in range(n_rows):
        source = data if row % 2 == 0 else recon
        # Reconstruction rows reuse the start index of the originals row just above them.
        start = (row - row % 2) * n_cols
        rows.append([wandb.Image(img) for img in source[start:start + n_cols]])
    return rows


def main(args):
    """Compute reconstruction MSE (eval and test splits) and FID (test split) for a model.

    The model is reloaded from ``<args.models_path>/<signature>/final_model``, where
    ``<signature>`` is the first sub-directory of ``args.models_path`` in sorted order.
    The dataset is inferred from the model's ``input_dim``. Results are printed and,
    when ``args.use_wandb`` is set, logged to Weights & Biases together with a table
    of test images and their reconstructions.

    Args:
        args: Parsed command-line arguments (``models_path``, ``use_wandb``,
            ``wandb_project``, ``wandb_entity``).

    Raises:
        FileNotFoundError: If ``args.models_path`` has no sub-directory, or the data
            files cannot be loaded.
        ValueError: If the model's ``input_dim`` matches no supported dataset.
        ModuleNotFoundError: If wandb logging is requested but ``wandb`` cannot be found.
    """
    # NOTE(review): CUDA availability was already queried at import time (module-level
    # ``device``), so this may come too late to restrict the visible GPUs — confirm.
    os.environ['CUDA_VISIBLE_DEVICES'] = "0"

    # Sorting makes the choice deterministic; non-directories (stray files) are skipped.
    signatures = sorted(
        entry
        for entry in os.listdir(args.models_path)
        if os.path.isdir(os.path.join(args.models_path, entry))
    )
    if not signatures:
        raise FileNotFoundError(
            f"No model sub-folder found in '{args.models_path}'."
        )
    model_path = os.path.join(args.models_path, signatures[0], "final_model")

    # reload the model
    trained_model = AutoModel.load_from_folder(model_path).to(device).eval()
    logger.info(f"Successfully reloaded {trained_model.model_name.upper()} model !\n")
    logger.debug(f"Model in training mode: {trained_model.training}")

    # Resolved before the try block so the error message below can always reference it.
    dataset = _infer_dataset(trained_model.model_config.input_dim)

    try:
        logger.info(f"\nLoading {dataset} data...\n")
        eval_data = _load_split(dataset, "eval")
        test_data = _load_split(dataset, "test")
    except Exception as e:
        raise FileNotFoundError(
            f"Unable to load the data from 'data/{dataset}' folder. Please check that both an "
            "'eval_data.npz' and a 'test_data.npz' are present in the folder.\n Data must be "
            "under the key 'data', in the range [0-255] and shaped with channel in first "
            "position\n"
            f"Exception raised: {type(e)} with message: " + str(e)
        ) from e

    logger.info("Successfully loaded data !\n")
    logger.info("------------------------------------------------------------")
    logger.info("Dataset \t \t Shape \t \t \t Range")
    logger.info(
        f"{dataset.upper()} eval data: \t {eval_data.shape} \t [{eval_data.min()}-{eval_data.max()}] "
    )
    logger.info(
        f"{dataset.upper()} test data: \t {test_data.shape} \t [{test_data.min()}-{test_data.max()}]"
    )
    logger.info("------------------------------------------------------------\n")

    # FactorVAE consumes two batches per step, hence its dedicated dataset type.
    dataset_type = (
        "DoubleBatchDataset"
        if trained_model.model_name == "FactorVAE"
        else "BaseDataset"
    )

    data_processor = DataProcessor()

    eval_data = data_processor.process_data(eval_data).to(device)
    eval_loader = DataLoader(
        dataset=data_processor.to_dataset(eval_data, dataset_type=dataset_type),
        batch_size=100,
        shuffle=False,
    )
    eval_recon = _reconstruct(trained_model, eval_loader)

    test_data = data_processor.process_data(test_data).to(device)
    test_loader = DataLoader(
        dataset=data_processor.to_dataset(test_data, dataset_type=dataset_type),
        batch_size=100,
        shuffle=False,
    )
    test_recon = _reconstruct(trained_model, test_loader)

    logger.debug(
        f"Reconstruction shapes: {tuple(eval_recon.shape)} {tuple(test_recon.shape)}"
    )

    eval_mse = _mse(eval_recon, eval_data)
    test_mse = _mse(test_recon, test_data)

    # FID is currently computed on the test split only.
    test_recon_np = _to_hwc_255(test_recon)
    test_data_np = _to_hwc_255(test_data)
    logger.debug(
        f"Max pixel values (test recon / test data): "
        f"{test_recon_np.max()} {test_data_np.max()}"
    )

    test_fid = compute_fid_wo_paths(
        gen_data=test_recon_np,
        ref_data=test_data_np,
        inception_path='.'
    )
    print("----------without save----------")
    print(f"mse vs eval : {eval_mse}")
    print("----------without save----------")
    print(f"mse vs test : {test_mse}")
    print(f"fid vs test : {test_fid}")
    print("----------without save----------")

    if not args.use_wandb:
        return

    # `import importlib` alone does not guarantee the `util` submodule is loaded.
    import importlib.util

    if importlib.util.find_spec("wandb") is None:
        raise ModuleNotFoundError(
            "`wandb` package must be installed. Run `pip install wandb`"
        )

    wandb.init(project=args.wandb_project, entity=args.wandb_entity)
    wandb.config.update(
        {
            "model_path": model_path,
            "model_config": trained_model.model_config.to_dict()
        }
    )

    # Log a grid of test images, each originals row followed by its reconstructions.
    column_names = [str(i) for i in range(min(10, test_data_np.shape[0]))]
    sampling_table = wandb.Table(
        data=_reconstruction_rows(test_recon_np, test_data_np),
        columns=column_names,
    )
    wandb.log(
        {
            "test_reconstructions": sampling_table,
            "eval/mse": eval_mse,
            "test/mse": test_mse,
            "test/fid": test_fid,
        }
    )

if __name__ == "__main__":
    # Script entry point: ``args`` was already parsed at module import time.
    main(args=args)
