import itertools
import os
import argparse
import copy
import pickle

import numpy as np

import os
import argparse
import torch
import numpy as np

import pandas as pd
import numpy as np

import time
import copy
import numpy as np
import torch
import pandas as pd

import os
import csv
import torch
import numpy as np

import os
import time
import copy
import numpy as np
import torch
import pandas as pd
import torch.nn.functional as F

from path_constant import project_root

# Module-wide compute device: the CUDA device when one is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


def save_checkpoint(Exp, saved_path, label_generators, G_optimizer, label_discriminators, D_optimizers):
    """Save generator and discriminator checkpoints for the current epoch.

    Generator checkpoint (under ``<saved_path>/checkpoints_generators``) holds the epoch,
    one ``state_dict<i>`` entry per generator (``i`` follows the iteration order of
    ``label_generators``) and the shared generator optimizer state under ``optimizer``.
    Discriminator checkpoint (under ``<saved_path>/checkpoints_discriminator``) holds the
    epoch, ``dstate_dict`` and ``doptimizer``.

    Each checkpoint is written twice: to ``epochNNN.pth`` (zero-padded epoch) and to
    ``epochLast.pth``, which is overwritten every time.

    Args:
        Exp: object exposing ``cur_epochs`` (the epoch number stored in the checkpoint).
        saved_path: root output directory.
        label_generators: dict label -> generator module.
        G_optimizer: optimizer over all generator parameters.
        label_discriminators: the critic module.
        D_optimizers: the critic's optimizer.
    """
    print("=> Saving checkpoint")
    epoch = Exp.cur_epochs

    def _write_twice(payload, subdir):
        # Epoch-numbered copy plus a rolling "last" copy in the same folder.
        folder = saved_path + "/" + subdir
        os.makedirs(folder, exist_ok=True)
        torch.save(payload, folder + f"/epoch{epoch:03}.pth")
        torch.save(payload, folder + "/epochLast.pth")

    generator_payload = {"epoch": epoch}
    for index, label in enumerate(label_generators):
        generator_payload[f"state_dict{index}"] = label_generators[label].state_dict()
    generator_payload["optimizer"] = G_optimizer.state_dict()
    _write_twice(generator_payload, "checkpoints_generators")

    discriminator_payload = {
        "epoch": epoch,
        "dstate_dict": label_discriminators.state_dict(),
        "doptimizer": D_optimizers.state_dict(),
    }
    _write_twice(discriminator_payload, "checkpoints_discriminator")


def calculate_TVD(dist1, dist2, doPrint):
    """Return the total variation distance 0.5 * sum_x |dist1[x] - dist2[x]|.

    Args:
        dist1: dict outcome -> probability; the sum runs over its keys.
        dist2: dict outcome -> probability; must contain every key of ``dist1``
            (a missing key raises ``KeyError``).
        doPrint: when truthy, print each outcome whose absolute difference exceeds 0.01.

    Returns:
        The TVD. When the two dicts have different sizes the sentinel ``10000`` is returned
        instead of raising (kept for existing callers).
    """
    if len(dist1) != len(dist2):
        return 10000

    total = 0
    for outcome, prob in dist1.items():
        diff = abs(prob - dist2[outcome])
        total += diff
        if doPrint and diff > 0.01:
            print("perm:", outcome, "tvd", diff)
    return total * 0.5


def calculate_KL(gen, real, doPrint):
    """Return KL(real || gen) = sum_x real[x] * ln(real[x] / gen[x]).

    Note the argument order: ``gen`` is the approximating distribution and ``real`` the
    reference one. Terms where either probability is exactly 0 are skipped entirely (so a
    zero in ``gen`` where ``real`` has mass does NOT yield infinity). ``doPrint`` is accepted
    for signature parity with ``calculate_TVD`` but unused.

    Raises:
        ValueError: if the two dicts have different sizes.
    """
    if len(gen) != len(real):
        raise ValueError('distribution doesnt match size')

    divergence = 0
    for outcome, p_real in real.items():
        if p_real == 0:
            continue
        p_gen = gen[outcome]
        if p_gen == 0:
            continue
        divergence += p_real * np.log(p_real / p_gen)
    return divergence


def generate_permutations(dim_list):
    """Enumerate every joint outcome of discrete variables with the given cardinalities.

    Returns a numpy array with one row per outcome, in ``itertools.product`` (row-major)
    order; column ``j`` ranges over ``0 .. dim_list[j] - 1``. An empty ``dim_list`` yields a
    single empty row (shape ``(1, 0)``).
    """
    value_ranges = (range(size) for size in dim_list)
    return np.array(list(itertools.product(*value_ranges)))


def get_joint_distributions_from_samples(params, observed_var, corrensponding_samples):
    """Empirical joint distribution of ``observed_var`` from a matrix of integer samples.

    Args:
        params: object with ``label_dim`` (dict label -> number of categories).
        observed_var: labels whose joint is estimated; column ``j`` of the samples holds
            ``observed_var[j]``.
        corrensponding_samples: 2-D integer numpy array, one row per sample.

    Returns:
        dict mapping each outcome tuple to its relative frequency. Every possible outcome
        (from ``generate_permutations``) is present; outcomes never observed get ``1e-6``
        instead of 0, so the values can sum to slightly more than 1.
    """
    cardinalities = [params.label_dim[name] for name in observed_var]
    all_outcomes = generate_permutations(cardinalities)

    seen_rows, seen_counts = np.unique(corrensponding_samples, axis=0, return_counts=True)

    # Seed every outcome with a tiny pseudo-probability, then overwrite the observed ones.
    joint = {tuple(row): 1e-6 for row in all_outcomes}

    n_samples = corrensponding_samples.shape[0]
    for row, freq in zip(seen_rows, seen_counts):
        joint[tuple(row)] = freq / n_samples

    return joint


def get_multiple_labels_fill(data_input, dims_list):
    """One-hot encode each column of a label matrix and concatenate the encodings.

    Column ``c`` of ``data_input`` is cast to int64 and encoded with ``dims_list[c]``
    classes. The float encodings are concatenated along dim 1 and returned on the
    module-level ``device``.
    """
    encoded_columns = [
        F.one_hot(data_input[:, col].to(torch.int64).to(device), num_classes=dims_list[col]).float()
        for col in range(data_input.shape[1])
    ]
    return torch.cat(encoded_columns, 1).to(device)


def map_fill_to_discrete(Exp, ara, dims_list):
    """Collapse a concatenation of per-variable score blocks back to integer category ids.

    The columns of ``ara`` are read left-to-right in consecutive blocks of widths
    ``dims_list``; each block is reduced with ``argmax`` over its columns. Any columns beyond
    ``sum(dims_list)`` are ignored. ``Exp`` is unused.

    Returns:
        Integer tensor of shape ``(rows, len(dims_list))``.
    """
    picked = []
    offset = 0
    for width in dims_list:
        segment = ara[:, offset:offset + width]
        picked.append(segment.argmax(dim=1).unsqueeze(1))
        offset += width
    return torch.cat(picked, 1)


def calc_gradient_penalty(critic, real_data, fake_data, device='cpu', pac=10):
    """Compute the gradient penalty."""
    alpha = torch.rand(real_data.size(0) // pac, 1, 1, device=device)
    alpha = alpha.repeat(1, pac, real_data.size(1))
    alpha = alpha.view(-1, real_data.size(1))

    interpolates = alpha * real_data + ((1 - alpha) * fake_data)

    disc_interpolates = critic(interpolates)

    gradients = torch.autograd.grad(
        outputs=disc_interpolates, inputs=interpolates,
        grad_outputs=torch.ones(disc_interpolates.size(), device=device),
        create_graph=True, retain_graph=True, only_inputs=True
    )[0]

    gradients_view = gradients.view(-1, pac * real_data.size(1)).norm(2, dim=1) - 1
    gradient_penalty = ((gradients_view) ** 2).mean()
    # gradient_penalty = ((gradients_view) ** 2).mean() * lambda_

    return gradient_penalty


# def get_genxray_labels(params, generators, intv_onehot, sample_size=100, NOISE_DIM=64, device="cuda"):
#
#     noise = torch.randn(sample_size, NOISE_DIM).to(device)
#     gen_labels1 = generators[0](params, noise, [])
#     noise = torch.randn(sample_size, NOISE_DIM).to(device)
#     gen_label2 = generators[1](params, noise, torch.cat([gen_labels1, intv_onehot], dim=1))
#     generated_labels_onehot = torch.cat([gen_labels1, intv_onehot, gen_label2], 1)
#
#     return generated_labels_onehot



def get_genxray_labels(params, label_generators, intervened, mini_batch):
    """Ancestrally sample every label of ``params.sampling_network``.

    Labels are visited in the network's insertion order, which must list parents before
    children. A label present in ``intervened`` takes that tensor as-is (a hard
    intervention). Otherwise its generator is called with fresh Gaussian noise of shape
    ``(mini_batch, params.NOISE_DIM)`` and the list of its parents' sampled tensors.

    Returns:
        The per-label tensors concatenated along dim 1, in network order.
    """
    produced = {}
    for name, parent_names in params.sampling_network.items():
        parent_values = [produced[parent] for parent in parent_names]

        if name in intervened:
            produced[name] = intervened[name]
            continue

        z = torch.randn(mini_batch, params.NOISE_DIM).to(device)
        produced[name] = label_generators[name](params, z, parent_values)

    return torch.cat(list(produced.values()), dim=1)


def get_fake_distribution(params, generators, intv_onehot, compare_Var, sample_size):
    """Estimate the generators' joint distribution over the labels in ``compare_Var``.

    Draws ``sample_size`` ancestral samples (with the hard interventions in ``intv_onehot``
    applied) and discretises each label in ``compare_Var`` by argmax. Returns the empirical
    joint as a dict outcome-tuple -> frequency, columns ordered as in ``compare_Var``.

    The sampled block from ``get_genxray_labels`` is laid out in ``params.sampling_network``
    order. Each requested label's columns are therefore located by offset rather than
    assuming ``compare_Var`` is a prefix of that order.

    Args:
        params: needs ``sampling_network``, ``label_dim`` and whatever the generators read.
        generators: dict label -> generator module.
        intv_onehot: dict label -> tensor used verbatim for that label (presumably
            ``(sample_size, label_dim[label])`` one-hots — confirm with callers).
        compare_Var: labels whose joint distribution is returned.
        sample_size: number of samples drawn.
    """
    # Evaluation only: no autograd graph is needed (results were detached before anyway).
    with torch.no_grad():
        generated_labels_onehot = get_genxray_labels(params, generators, intv_onehot, sample_size)

    # Starting column of every label within the concatenated block.
    offsets = {}
    start = 0
    for lb in params.sampling_network:
        offsets[lb] = start
        start += params.label_dim[lb]

    dims_list = [params.label_dim[lb] for lb in compare_Var]
    selected = torch.cat(
        [generated_labels_onehot[:, offsets[lb]:offsets[lb] + params.label_dim[lb]] for lb in compare_Var],
        dim=1,
    )
    generated_labels_full = map_fill_to_discrete(params, selected, dims_list).detach().cpu().numpy().astype(int)
    return get_joint_distributions_from_samples(params, compare_Var, generated_labels_full)


def train(params, generators, G_optimizer, critic, D_optimizer,
          label_data):
    """Run one WGAN-GP update on a batch: 5 critic steps followed by 1 generator step.

    Args:
        params: needs ``label_dim``, ``LAMBDA_GP``, ``device`` and the fields used by
            ``get_genxray_labels``.
        generators: dict label -> generator module (all optimised by ``G_optimizer``).
        G_optimizer: optimizer over every generator parameter.
        critic: the discriminator/critic module.
        D_optimizer: the critic's optimizer.
        label_data: dict label -> ``(batch, 1)`` tensor of integer-valued labels; the
            concatenation order of its values defines the real one-hot layout.

    Returns:
        ``(G_loss, D_loss)``: the generator loss tensor and the mean critic loss over the 5
        critic steps (CPU scalar tensor).
    """
    current_real_label = torch.cat(list(label_data.values()), dim=1)
    # One-hot widths in the same key order as the concatenated columns.
    dims_list = [params.label_dim[lb] for lb in label_data]
    real_labels_onehot = get_multiple_labels_fill(current_real_label, dims_list)
    generated_labels_onehot = get_genxray_labels(params, generators, {}, current_real_label.shape[0])

    # Critic steps use a detached copy of the fake batch. Gradients into the generators from
    # these steps were always discarded by the zero_grad below, so this only avoids 5
    # redundant backward passes through the generators and the retained graph. The copy is
    # made grad-requiring because the gradient penalty differentiates w.r.t. the
    # interpolates, and the real one-hots carry no grad.
    fake_for_critic = generated_labels_onehot.detach().requires_grad_(True)

    D_losses = []
    for _ in range(5):
        D_real_decision_obs = critic(real_labels_onehot).squeeze()
        D_fake_decision_obs = critic(fake_for_critic).squeeze()
        gp_obs = calc_gradient_penalty(critic, real_labels_onehot, fake_for_critic, device=params.device)

        D_loss_obs = (-(torch.mean(D_real_decision_obs) - torch.mean(D_fake_decision_obs))
                      + params.LAMBDA_GP * gp_obs)
        D_losses.append(D_loss_obs.item())
        critic.zero_grad()
        D_loss_obs.backward()
        D_optimizer.step()

    # %%%%%%%%%%%%%%%%%%% generator training %%%%%%%%%%%%%%%%%%%
    for lb in generators:
        generators[lb].zero_grad()

    D_fake_decision_obs = critic(generated_labels_onehot).squeeze()
    G_loss = -torch.mean(D_fake_decision_obs)
    G_loss.backward()
    G_optimizer.step()

    D_loss = torch.mean(torch.FloatTensor(D_losses))  # mean critic loss over the critic steps

    return G_loss.data, D_loss.data


def init_weights(m):
    """Initialise a ``torch.nn.Linear`` module in place; other module types are left untouched.

    The weight gets Xavier-uniform initialisation and the bias (when the layer has one) is
    set to 0.01. Intended for use with ``module.apply(init_weights)``.
    """
    if isinstance(m, torch.nn.Linear):
        # xavier_uniform_ is the in-place API; the un-suffixed name is a deprecated alias.
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:  # Linear(..., bias=False) has no bias tensor
            torch.nn.init.constant_(m.bias, 0.01)


def anneal_temperature(params, tot_iters):
    """Decay ``params.Temperature`` in place and print the new value.

    The new value is ``max(Temperature * exp(-ANNEAL_RATE * tot_iters), temp_min)``. The
    *current* temperature is multiplied and ``tot_iters`` is a cumulative count, so
    repeated calls compound the decay rather than following ``T0 * exp(-r * t)``.
    """
    decay_factor = np.exp(-params.ANNEAL_RATE * tot_iters)
    params.Temperature = np.maximum(params.Temperature * decay_factor, params.temp_min)
    print(tot_iters, ":Temperature", params.Temperature)


def labelMain(params, label_generators, G_optimizers, discriminators, D_optimizers, data_loader, tvd_diff, kl_diff,
              intv_prob, label_data=None):
    """Train for one epoch, then evaluate and periodically checkpoint.

    1. Calls ``train`` once per batch of ``data_loader``, printing losses. Every 100 global
       iterations it anneals the Gumbel temperature.
    2. Every ``params.eval_every`` epochs (default 1):
       - compares each learned conditional P(node | parents) with the data (TVD/KL
         appended to ``tvd_diff`` / ``kl_diff``);
       - estimates the hard-coded interventional distributions (appended to ``intv_prob``
         and pickled);
       - saves the TVD/KL histories under ``params.SAVED_PATH``.
    3. Every ``params.save_every`` epochs (default 10): writes model checkpoints.

    Args:
        label_data: dict label -> ``(N, 1)`` tensor of the evaluation data. When omitted,
            the module-level ``label_data`` defined by the ``__main__`` script is used
            (the original behaviour).
    """
    n_batches = len(data_loader)
    for iteration, batch in enumerate(data_loader):
        g_loss, d_loss = train(params, label_generators, G_optimizers, discriminators,
                               D_optimizers, batch)

        print('Epoch [%d/%d], Step [%d/%d],' % (
            params.cur_epochs + 1, params.num_epochs, iteration + 1, n_batches),
              ' D_loss: %.4f, G_loss: %.4f' % (d_loss.data, g_loss.data))

        # Annealing
        tot_iter = params.cur_epochs * n_batches + iteration
        if tot_iter % 100 == 0:
            anneal_temperature(params, tot_iter)

    if (params.cur_epochs + 1) % getattr(params, "eval_every", 1) == 0:
        if label_data is None:
            # Legacy fallback: the training script keeps the evaluation data in a module global.
            label_data = globals()["label_data"]
        _evaluate_conditionals(params, label_generators, label_data, tvd_diff, kl_diff)
        _evaluate_interventions(params, label_generators, intv_prob)
        _report_and_save_divergences(params, tvd_diff, kl_diff)

    if (params.cur_epochs + 1) % getattr(params, "save_every", 10) == 0:
        save_checkpoint(params, params.SAVED_PATH, label_generators, G_optimizers, discriminators, D_optimizers)
        print(params.cur_epochs, ":model saved at ", params.SAVED_PATH)


def _evaluate_conditionals(params, label_generators, label_data, tvd_diff, kl_diff):
    """Compare each learned joint of (parents, node) with the data; append rounded TVD/KL per node."""
    for node, parents in params.sampling_network.items():
        print(f'P({node}|{parents}')

        # Clamp the parents to the first sample_size observed values. Float one-hots match
        # the dtype of the generator outputs and of the do() inputs.
        intv = {par: label_data[par].squeeze().to(torch.int64) for par in parents}
        for lb in intv:
            intv[lb] = F.one_hot(intv[lb], num_classes=params.label_dim[lb])[0:params.sample_size].float()

        compare_var = parents + [node]
        fake_dist_dict = get_fake_distribution(params, label_generators, intv, compare_var,
                                               sample_size=params.sample_size)

        current_real_label = torch.cat([label_data[lb] for lb in compare_var], dim=1)
        dataset_dist_dict = get_joint_distributions_from_samples(
            params, compare_var, current_real_label.detach().cpu().numpy().astype(int))
        for comb in fake_dist_dict:
            print(f"{comb}: {fake_dist_dict[comb]} vs {dataset_dist_dict[comb]}")

        obs_tvd = calculate_TVD(fake_dist_dict, dataset_dist_dict, doPrint=False)
        obs_kl = calculate_KL(fake_dist_dict, dataset_dist_dict, doPrint=False)
        tvd_diff[node].append(round(obs_tvd, 4))  # TODO: original author marked this "fix it"
        kl_diff[node].append(round(obs_kl, 4))


def _evaluate_interventions(params, label_generators, intv_prob):
    """Estimate P(V | do(...)) for the fixed interventions, record them in intv_prob and pickle it."""
    # (key in intv_prob, do() assignment)
    interventions = (
        ('Atelectasis', {'Pneumonia': 0, 'Atelectasis': 1}),  # do(Pneumonia=0, Atelectasis=1)
        ('Pleural Effusion', {'Pneumonia': 1, 'Pleural Effusion': 1}),  # do(Pneumonia=1, Effusion=1)
    )
    for key, symptoms in interventions:
        input_lb = {}
        for lb, value in symptoms.items():
            onehot = F.one_hot(torch.tensor(value).to(torch.int64).to(device),
                               num_classes=params.label_dim[lb]).float()
            input_lb[lb] = onehot.repeat(params.sample_size, 1)

        fake_dist_dict = get_fake_distribution(params, label_generators, input_lb, params.labels, params.sample_size)
        print(f'P(V|do{symptoms}: {fake_dist_dict}')
        intv_prob[key].append(fake_dist_dict)

    # This is the first file written under SAVED_PATH in an epoch, so create the folder.
    os.makedirs(params.SAVED_PATH, exist_ok=True)
    with open(params.SAVED_PATH + "/intv_prob.pickle", "wb") as f:
        pickle.dump(intv_prob, f)


def _report_and_save_divergences(params, tvd_diff, kl_diff):
    """Print the last (up to) 10 TVD values per node and save the TVD/KL histories as tensors."""
    ll = -min(10, len(list(tvd_diff.values())[0]))
    for dist in tvd_diff:
        print("###", dist, " loss%:", [round(val, 4) for val in tvd_diff[dist][ll:]])

    os.makedirs(params.SAVED_PATH + "/tvd", exist_ok=True)
    os.makedirs(params.SAVED_PATH + "/kl", exist_ok=True)
    for dist in tvd_diff:
        torch.save(torch.FloatTensor(tvd_diff[dist]), params.SAVED_PATH + "/tvd/" + dist)
        torch.save(torch.FloatTensor(kl_diff[dist]), params.SAVED_PATH + "/kl/" + dist)


from ControllerModels import ControllerGenerator, ControllerDiscriminator
from cfg.dataloader_pickle import PickleDataset

if __name__ == '__main__':

    def _str2bool(value):
        """Parse a boolean command-line value.

        argparse's ``type=bool`` calls ``bool(str)``, which is True for every non-empty
        string (including ``"False"``), so it cannot be used for boolean options.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    # several hyperparameters for model
    parser = argparse.ArgumentParser(description='train for xray model')
    parser.add_argument('--Temperature', type=float, default=1, help='Temperature')
    parser.add_argument('--ANNEAL_RATE', type=float, default=0.000003, help='ANNEAL_RATE')
    parser.add_argument('--temp_min', type=float, default=0.1, help='temp_min')
    parser.add_argument('--num_epochs', type=int, default=200, help='num_epochs')
    parser.add_argument('--cur_epochs', type=int, default=0, help='cur_epochs')
    # float so fractional penalty weights are accepted; integer values still parse.
    parser.add_argument('--LAMBDA_GP', type=float, default=1, help='LAMBDA_GP')
    parser.add_argument('--NOISE_DIM', type=int, default=64, help='NOISE_DIM')
    parser.add_argument('--sample_size', type=int, default=10000, help='sample_size')
    parser.add_argument('--num_samples', type=int, default=20000, help='num_samples')
    parser.add_argument('--batch_size', type=int, default=100, help='batch_size')
    parser.add_argument('--SAVED_PATH', type=str, default=f'{project_root}/XrayLLM/trained_models', help='SAVED_PATH')
    # NOTE(review): type=dict cannot parse a command-line string; params.label_dim is overwritten below anyway.
    parser.add_argument('--label_dim', type=dict, default={}, help='label_dim')
    parser.add_argument('--load_models', type=_str2bool, default=False, help='load_models')

    # Unrecognised command-line arguments are collected in `unknown` and ignored.
    params, unknown = parser.parse_known_args()

    params.device = device

    # ---- data ----
    df = pd.read_csv(f'{project_root}/XrayLLM/mimic-cxr-2.0.0-chexpert.csv')
    df = df[df['No Finding'] != 1]
    # Shift by +1 and map missing values to 0; combined with the binarisation below, a label
    # ends up 1 only if it was originally 1 (assuming labels are coded 1/0/-1/NaN — TODO confirm).
    df = df + 1
    df = df.fillna(0)

    # Causal order of the modelled findings: each label lists its parents; parents come first.
    sampling_network = {
        'Pneumonia': [],
        'Pleural Effusion': ['Pneumonia'],
        'Atelectasis': ['Pneumonia', 'Pleural Effusion'],
        'Lung Opacity': ['Pneumonia', 'Pleural Effusion', 'Atelectasis'],
    }

    params.sampling_network = sampling_network

    label_names = list(sampling_network.keys())
    params.labels = label_names
    df = df[label_names]
    # converting to binary
    df = df.replace(to_replace=1, value=0)
    df = df.replace(to_replace=2, value=1)

    label_dim = {lb: 2 for lb in label_names}
    params.label_dim = label_dim

    # ---- one generator per label, conditioned on its parents' labels plus noise ----
    models = {}
    for node in sampling_network:
        input_dim = sum(label_dim[lb] for lb in sampling_network[node]) + params.NOISE_DIM
        models[node] = ControllerGenerator(input_dim=input_dim, hid_dims=[256, 256],
                                           output_dim_list=[label_dim[node]]).to(device)
        models[node].apply(init_weights)

    nn_pars = []
    for nd in models:
        nn_pars += list(models[nd].parameters())
    goptimizer = torch.optim.Adam(nn_pars, lr=1e-3, betas=(0.5, 0.9), weight_decay=1e-6)

    # NOTE: checkpoint resuming is disabled; the code below is stale (it refers to an older
    # two-generator version with model1/model2).
    # last_model = f"{params.SAVED_PATH}/checkpoints_generators/epochLast.pth"
    # if os.path.exists(last_model) and params.load_models == True:
    #     checkpoint = torch.load(last_model)
    #     model1.load_state_dict(checkpoint['state_dict0'])
    #     model2.load_state_dict(checkpoint['state_dict1'])
    #     goptimizer.load_state_dict(checkpoint['optimizer'])
    #     params.cur_epochs = checkpoint["epoch"]

    # ---- a single critic over the concatenated one-hot encoding of all labels ----
    all_dim = sum(list(label_dim.values()))
    discriminator = ControllerDiscriminator(input_dim=all_dim, hid_dims=[256, 256]).to(device)
    doptimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-3,
                                  betas=(0.5, 0.9), weight_decay=1e-6)

    # last_model = f"{params.SAVED_PATH}/checkpoints_discriminator/epochLast.pth"
    # if os.path.exists(last_model) and params.load_models == True:
    #     checkpoint = torch.load(last_model)
    #     discriminator.load_state_dict(checkpoint['dstate_dict'])
    #     doptimizer.load_state_dict(checkpoint['doptimizer'])

    # First num_samples rows of each label as a (N, 1) tensor on the compute device.
    label_data = {}
    for lb in label_names:
        label_data[lb] = torch.tensor(df[lb].values)[0:params.num_samples].view(-1, 1).to(device)

    dataset = PickleDataset(data_dict=label_data)
    dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=params.batch_size, shuffle=False,
                                             drop_last=True)  # Todo: Keep shuffle true or false

    tvd_diff = {node: [] for node in params.sampling_network}
    kl_diff = {node: [] for node in params.sampling_network}
    intv_prob = {'Pleural Effusion': [], 'Atelectasis': []}
    print("Starting training new mechanism")

    for epoch in range(params.cur_epochs, params.num_epochs, 1):
        params.cur_epochs = epoch
        labelMain(params, models, goptimizer, discriminator, doptimizer, dataloader,
                  tvd_diff, kl_diff, intv_prob)
