import math
import numpy as np
import os
import torch
import torch.nn as nn
from tqdm import tqdm
from ..functions import discrete_klv2d, hist2d
from ..utils import save_scatterplot
from ..utils.train import DummyScheduler, RunningStatistics
import wandb
from matplotlib import pyplot as plt

class Trainer:
    """Epoch-based training loop for a diffusion model.

    Each epoch runs ``step`` on every batch of ``trainloader``. After the last
    batch of every ``eval_intv``-th epoch the model is sampled through
    ``evaluator`` (if given), a plot of the samples is written to
    ``image_dir`` (if given) and metrics are sent to wandb when
    ``args.log_results`` is truthy. The scheduler is stepped once per epoch and
    a checkpoint is written every ``chkpt_intv`` epochs when ``chkpt_path`` is
    given.
    """

    # Datasets whose samples fall back to a log-scale histogram when
    # ``save_scatterplot`` raises (presumably because the samples are 1-D).
    _HIST_1D_DATASETS = frozenset(
        ("gaussian1_1d", "gaussian2_1d", "gaussian3_1d", "gaussian4_1d"))

    def __init__(
            self,
            model,
            optimizer,
            diffusion,
            epochs,
            trainloader,
            scheduler=None,
            shape=None,
            grad_norm=0,
            device=torch.device("cpu"),
            eval_intv=1,
            chkpt_intv=10,
            gen=0, args=None,
    ):
        """Store the training components.

        ``shape`` is the per-sample shape used when sampling; when omitted it
        is taken from the first training batch (``(1,)`` for scalar samples).
        A truthy ``grad_norm`` enables global-norm gradient clipping. ``gen``
        is only used as a suffix of the wandb image key. ``args`` is read for
        ``log_results`` and ``dataset``; ``None`` disables wandb logging.
        """
        self.model = model
        self.optimizer = optimizer
        self.diffusion = diffusion
        self.epochs = epochs
        self.start_epoch = 0
        self.trainloader = trainloader
        if shape is None:
            shape = next(iter(trainloader)).shape[1:]
            if not shape:
                shape = (1, )
        self.shape = tuple(shape)
        print(self.shape)
        self.scheduler = DummyScheduler() if scheduler is None else scheduler
        self.grad_norm = grad_norm
        self.device = device
        self.eval_intv = eval_intv
        self.chkpt_intv = chkpt_intv
        self.args = args
        self.gen = gen

        self.stats = RunningStatistics(loss=None)

    def _log_enabled(self):
        """Return True when results should be sent to wandb.

        Tolerates ``args`` being ``None`` (the constructor default) or lacking
        a ``log_results`` attribute.
        """
        return bool(getattr(self.args, "log_results", False))

    def loss(self, x):
        """Return the per-sample diffusion losses for batch ``x`` (shape ``(B,)``).

        Timesteps are drawn uniformly from ``[0, diffusion.timesteps)``. A 1-D
        batch is treated as ``B`` scalar samples and given a feature axis.
        """
        B = x.shape[0]
        T = self.diffusion.timesteps
        t = torch.randint(T, size=(B, ), dtype=torch.int64, device=self.device)
        if len(x.shape) == 1:
            x = x.unsqueeze(1)
        loss = self.diffusion.train_losses(self.model, x_0=x, t=t)
        assert loss.shape == (B, )
        return loss

    def step(self, x):
        """Take one optimisation step on batch ``x`` and record its mean loss."""
        batch_size = x.shape[0]
        loss = self.loss(x).mean()
        self.optimizer.zero_grad(set_to_none=True)
        loss.backward()
        if self.grad_norm:
            # Gradient clipping by global norm.
            nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm)
        self.optimizer.step()
        loss_value = loss.item()
        # The batch-summed loss is passed with the batch size, presumably so
        # RunningStatistics can report a per-sample mean (verify in utils.train).
        self.stats.update(batch_size, loss=loss_value * batch_size)
        if self._log_enabled():
            wandb.log({"loss": loss_value})

    def train(self, evaluator=None, chkpt_path=None, image_dir=None, **plot_kwargs):
        """Run epochs ``start_epoch .. epochs-1`` and return a final sample set.

        Args:
            evaluator: object whose ``eval(sample_fn)`` returns a dict of
                results (``"x_gen"`` holds the samples), or None.
            chkpt_path: where to write checkpoints, or None to skip them.
            image_dir: directory for per-epoch sample plots, or None.
            **plot_kwargs: forwarded to ``save_scatterplot``.

        Returns:
            The ``"x_gen"`` entry of a final ``evaluator.eval`` call, or None
            when no evaluator is given.
        """

        def sample_fn(n):
            shape = (n,) + self.shape
            sample = self.diffusion.p_sample(
                denoise_fn=self.model, shape=shape, device=self.device, noise=None)
            return sample.cpu().numpy()

        for e in range(self.start_epoch, self.epochs):
            self.stats.reset()
            self.model.train()
            print(f"{e + 1}/{self.epochs} epochs")
            n_batches = 0
            for x in self.trainloader:
                self.step(x.to(self.device))
                n_batches += 1

            # End-of-epoch work runs once after the last batch; an empty loader
            # skips it instead of leaving ``results`` undefined.
            results = {}
            if n_batches:
                eval_results = self._end_of_epoch(
                    e, evaluator, sample_fn, image_dir, plot_kwargs)
                results.update(self.current_stats)
                results.update(eval_results)

            # adjust learning rate every epoch before checkpoint
            self.scheduler.step()
            if chkpt_path and (e + 1) % self.chkpt_intv == 0:
                self.save_checkpoint(chkpt_path, epoch=e + 1, **results)

        if evaluator is None:
            return None
        # Sample the final model in eval mode regardless of the last epoch's state.
        self.model.eval()
        return evaluator.eval(sample_fn).pop("x_gen", None)

    def _end_of_epoch(self, e, evaluator, sample_fn, image_dir, plot_kwargs):
        """Evaluate, plot and log after epoch ``e``; return results without ``x_gen``."""
        eval_results = {}
        if (e + 1) % self.eval_intv == 0:
            self.model.eval()
            if evaluator is not None:
                eval_results = evaluator.eval(sample_fn)
            if self._log_enabled():
                wandb.log(self.current_stats)

        x_gen = eval_results.pop("x_gen", None)
        if x_gen is not None and image_dir:
            self._save_samples_plot(x_gen, e, image_dir, plot_kwargs)

        # Count metrics are independent of whether a plot directory was given.
        if self._log_enabled() and "count" in eval_results:
            wandb.log({"count": eval_results["count"]})
            if getattr(self.args, "dataset", None) == "gaussian_nd_more_modes":
                wandb.log({"count_5": eval_results["count_5"]})
        return eval_results

    def _save_samples_plot(self, x_gen, e, image_dir, plot_kwargs):
        """Write a plot of ``x_gen`` to ``<image_dir>/<e+1>.jpg`` and log it to wandb."""
        path = os.path.join(image_dir, f"{e + 1}.jpg")
        try:
            save_scatterplot(path, x_gen, **plot_kwargs)
        except Exception as err:  # plotting is best-effort; never abort training
            if getattr(self.args, "dataset", None) not in self._HIST_1D_DATASETS:
                print(f"Could not save sample plot for epoch {e + 1}: {err}")
                return
            # Fresh figure so nothing left over by the failed scatterplot leaks in.
            plt.figure()
            plt.yscale("log")
            plt.hist(x_gen, bins=100, alpha=0.7, edgecolor='black')
            plt.tight_layout()
            plt.savefig(path)
            plt.close()
        if self._log_enabled():
            wandb.log({f"Samples_{self.gen}": wandb.Image(path, caption=f"Epoch {e + 1}")})

    @property
    def current_stats(self):
        """Running training statistics for the current epoch."""
        return self.stats.extract()

    @property
    def trainees(self):
        """Names of the attributes whose ``state_dict`` goes into checkpoints."""
        # Built explicitly: the previous one-liner parsed as
        # ``(a + b) if cond else []`` and dropped model/optimizer when the
        # scheduler was None.
        names = ["model", "optimizer"]
        if self.scheduler is not None:
            names.append("scheduler")
        return names

    def load_checkpoint(self, chkpt_path):
        """Restore model/optimizer/scheduler state and the epoch to resume from."""
        chkpt = torch.load(chkpt_path, map_location=self.device)
        self.model.load_state_dict(chkpt["model"])
        self.optimizer.load_state_dict(chkpt["optimizer"])
        if self.scheduler is not None:
            self.scheduler.load_state_dict(chkpt["scheduler"])
        self.start_epoch = chkpt["epoch"]

    def save_checkpoint(self, chkpt_path, **extra_info):
        """Save all trainee state dicts plus ``extra_info`` (extra keys win) to ``chkpt_path``."""
        chkpt = dict(self.named_state_dicts())
        chkpt.update(extra_info)
        torch.save(chkpt, chkpt_path)

    def named_state_dicts(self):
        """Yield ``(name, state_dict)`` for each trainee attribute."""
        for k in self.trainees:
            yield k, getattr(self, k).state_dict()


class Evaluator:
    """Score generated samples against reference data with a binned KL divergence.

    Both the reference data and the generated samples are binned with the
    project helper ``hist2d`` over ``value_range`` and compared with
    ``discrete_klv2d``.
    """

    def __init__(
            self,
            true_data,
            eval_batch_size=500,
            max_eval_count=30000,
            value_range=(-3, 3),
            eps=1e-9
    ):
        self.eval_batch_size = eval_batch_size
        self.max_eval_count = max_eval_count
        # bins**2 is about max_eval_count / 10, i.e. roughly ten samples per
        # cell if hist2d builds a bins x bins grid (assumption - verify hist2d).
        self.bins = math.floor(math.sqrt(self.max_eval_count // 10))
        self.value_range = value_range
        self.eps = eps
        self.true_hist = self.get_histogram(true_data)
        # The reference histogram is reused by every eval; freeze it.
        self.true_hist.setflags(write=False)  # noqa; make true_hist read-only

    def get_histogram(self, data):
        """Return the normalised histogram of ``data``, built ``eval_batch_size`` rows at a time."""
        chunk = self.eval_batch_size
        counts = 0
        for lo in range(0, len(data), chunk):
            hi = lo + chunk
            counts += hist2d(data[lo:hi], bins=self.bins, value_range=self.value_range)
        # eps keeps the division finite when every count is zero.
        counts /= np.sum(counts) + self.eps
        return counts

    def eval(self, sample_fn):
        """Draw batches of ``eval_batch_size`` samples and score them.

        ``ceil(max_eval_count / eval_batch_size)`` batches are drawn, so the
        total may slightly exceed ``max_eval_count``.

        Returns:
            dict with ``"kld"`` (from ``discrete_klv2d``) and ``"x_gen"`` (all
            generated samples stacked along axis 0).
        """
        drawn = []
        gen_counts = 0
        n_rounds = len(range(0, self.max_eval_count, self.eval_batch_size))
        for _ in range(n_rounds):
            batch = sample_fn(self.eval_batch_size)
            drawn.append(batch)
            gen_counts += hist2d(batch, bins=self.bins, value_range=self.value_range)
        gen_counts /= np.sum(gen_counts) + self.eps
        divergence = discrete_klv2d(gen_counts, self.true_hist)
        return {
            "kld": divergence,
            "x_gen": np.concatenate(drawn, axis=0),
        }

class Evaluator1D:
    """Collect samples from a 1-D model for plotting.

    No quantitative metric is computed at the moment: ``true_data`` is accepted
    for interface compatibility with the other evaluators but is unused, and
    ``value_range`` is stored but not used by ``get_histogram``.
    """

    def __init__(
        self,
        true_data,
        eval_batch_size=500,
        max_eval_count=30000,
        value_range=(-3, 3),
        eps=1e-9
    ):
        self.eval_batch_size = eval_batch_size
        self.max_eval_count = max_eval_count
        self.bins = math.floor(math.sqrt(self.max_eval_count // 10))
        self.value_range = value_range
        print("Value range: ", value_range)
        self.eps = eps

    def get_histogram(self, data):
        """Return a ``self.bins``-bin histogram of ``data`` normalised to sum to ~1."""
        hist, _ = np.histogram(data, bins=self.bins, density=True)
        # eps avoids division by zero for an all-zero histogram.
        return hist.astype(np.float32) / (np.sum(hist) + self.eps)

    def eval(self, sample_fn):
        """Draw ``ceil(max_eval_count / eval_batch_size)`` batches from ``sample_fn``.

        Returns:
            dict with ``"x_gen"``: all batches stacked along axis 0 (an empty
            array when no batch is drawn).
        """
        batches = [
            sample_fn(self.eval_batch_size)
            for _ in range(0, self.max_eval_count, self.eval_batch_size)
        ]
        if not batches:
            return {"x_gen": np.array([])}
        return {"x_gen": np.concatenate(batches, axis=0)}

    @staticmethod
    def kl_divergence(p, q):
        """Return ``sum(p * log(p / q))`` over entries where ``p != 0``.

        The masked-out entries would evaluate ``log(0)`` / ``0 * inf``; those
        numpy warnings are suppressed because ``np.where`` discards them.
        """
        with np.errstate(divide="ignore", invalid="ignore"):
            terms = p * np.log(p / (q + 1e-9))
        return np.sum(np.where(p != 0, terms, 0))

class Evaluator2D_Comp:
    """Count generated samples that land near fixed target points.

    A sample is counted when each of its first ``len(center)`` coordinates lies
    strictly within ``half_width`` of the matching coordinate of some entry of
    ``centers``. The defaults, (10, 10) and (20, 20) with half-width 5, are the
    previously hard-coded values.
    """

    def __init__(
        self,
        true_data,
        eval_batch_size=500,
        max_eval_count=30000,
        value_range=(-3, 3),
        eps=1e-9,
        centers=((10, 10), (20, 20)),
        half_width=5,
    ):
        self.eval_batch_size = eval_batch_size
        self.max_eval_count = max_eval_count
        self.bins = math.floor(math.sqrt(self.max_eval_count // 10))
        self.value_range = value_range
        print("Value range: ", value_range)
        self.eps = eps
        self.centers = centers
        self.half_width = half_width
        self.true_hist = self.get_histogram(true_data)
        self.true_hist.setflags(write=False)  # Make true_hist read-only

    def get_histogram(self, data):
        """Return a ``self.bins``-bin histogram of all values in ``data``, normalised to sum to ~1.

        ``np.histogram`` flattens multi-dimensional input.
        """
        hist, _ = np.histogram(data, bins=self.bins, density=True)
        # eps avoids division by zero for an all-zero histogram.
        return hist.astype(np.float32) / (np.sum(hist) + self.eps)

    def eval(self, sample_fn):
        """Draw ``ceil(max_eval_count / eval_batch_size)`` batches and count hits.

        Returns:
            dict with ``"count"`` (samples near any centre) and ``"x_gen"``
            (all batches stacked along axis 0).
        """
        batches = [
            sample_fn(self.eval_batch_size)
            for _ in range(0, self.max_eval_count, self.eval_batch_size)
        ]
        x_gen = np.concatenate(batches, axis=0) if batches else np.array([])
        # No per-batch histogram here: it was never used and np.histogram
        # raises ValueError on NaN samples.
        count = self._count_near_centers(x_gen)
        print("Count: ", count)
        return {
            "count": count,
            "x_gen": x_gen,
        }

    def _count_near_centers(self, x_gen):
        """Return how many rows of ``x_gen`` fall inside any centre's box."""
        if x_gen.size == 0:
            return 0
        centers = np.asarray(self.centers, dtype=float)
        dims = centers.shape[1]
        if x_gen.ndim != 2 or x_gen.shape[1] < dims:
            raise ValueError(
                f"expected samples of shape (N, >={dims}), got {x_gen.shape}")
        # (N, K, dims): per-coordinate closeness of every sample to every centre.
        near = np.abs(x_gen[:, None, :dims] - centers[None, :, :]) < self.half_width
        return int(np.count_nonzero(near.all(axis=2).any(axis=1)))

    @staticmethod
    def kl_divergence(p, q):
        """Return ``sum(p * log(p / q))`` over entries where ``p != 0``.

        Warnings from the masked-out ``log(0)`` / ``0 * inf`` terms are suppressed.
        """
        with np.errstate(divide="ignore", invalid="ignore"):
            terms = p * np.log(p / (q + 1e-9))
        return np.sum(np.where(p != 0, terms, 0))


class Evaluator_ND_Zeros:
    """Measure how often generated N-D samples activate more than one dimension.

    A coordinate is "active" unless its magnitude is below ``threshold`` (NaN
    coordinates count as active). ``count`` is the number of samples with more
    than one active coordinate. ``count_5`` is the number of those samples with
    more than one active coordinate lying farther than ``mean_dist`` from the
    matching entry of ``mu``.
    """

    def __init__(
        self,
        true_data,
        eval_batch_size=500,
        max_eval_count=30000,
        value_range=(-3, 3),
        eps=1e-9, mu=None,
        threshold=0.1,
        mean_dist=5,
    ):
        self.eval_batch_size = eval_batch_size
        self.max_eval_count = max_eval_count
        # Same bin rule as the sibling evaluators; without it get_histogram
        # raised AttributeError.
        self.bins = math.floor(math.sqrt(self.max_eval_count // 10))
        self.value_range = value_range
        self.mu = mu
        print(self.mu)
        self.eps = eps
        self.threshold = threshold
        self.mean_dist = mean_dist

    def get_histogram(self, data):
        """Return a ``self.bins``-bin histogram of ``data`` normalised to sum to ~1."""
        hist, _ = np.histogram(data, bins=self.bins, density=True)
        # eps avoids division by zero for an all-zero histogram.
        return hist.astype(np.float32) / (np.sum(hist) + self.eps)

    def eval(self, sample_fn):
        """Draw ``ceil(max_eval_count / eval_batch_size)`` batches and count active dims.

        Returns:
            dict with ``"count"``, ``"count_5"`` (see class docstring) and
            ``"x_gen"`` (all batches stacked along axis 0).
        """
        batches = []
        for _ in range(0, self.max_eval_count, self.eval_batch_size):
            samples = sample_fn(self.eval_batch_size)
            if np.isnan(samples).any():
                print("Nan found")
            batches.append(samples)

        if batches:
            x_gen = np.concatenate(batches, axis=0)
            count, count_5 = self._count_active(x_gen)
        else:
            x_gen = np.array([])
            count, count_5 = 0, 0

        print("Count: ", count)
        return {
            "count": count,
            "count_5": count_5,
            "x_gen": x_gen,
        }

    def _count_active(self, x_gen):
        """Return ``(count, count_5)`` for a stacked ``(N, D)`` sample array."""
        # Work on a copy so the returned samples are not modified.
        thresholded = x_gen.copy()
        thresholded[np.abs(thresholded) < self.threshold] = 0.0
        active = thresholded != 0
        multi = np.count_nonzero(active, axis=1) > 1
        count = int(np.count_nonzero(multi))
        if count == 0:
            return count, 0
        if self.mu is None:
            raise ValueError("mu is required to compute count_5")
        # Active coordinates that are far from the per-dimension mean mu.
        far = active & (np.abs(thresholded - np.asarray(self.mu)) > self.mean_dist)
        count_5 = int(np.count_nonzero(multi & (np.count_nonzero(far, axis=1) > 1)))
        return count, count_5

    @staticmethod
    def kl_divergence(p, q):
        """Return ``sum(p * log(p / q))`` over entries where ``p != 0``.

        Warnings from the masked-out ``log(0)`` / ``0 * inf`` terms are suppressed.
        """
        with np.errstate(divide="ignore", invalid="ignore"):
            terms = p * np.log(p / (q + 1e-9))
        return np.sum(np.where(p != 0, terms, 0))
