import os
import copy
import wandb
import json
from spaghettini import quick_register, Configurable
from abc import ABC, abstractmethod
from collections.abc import Iterable

import torch
import numpy as np
import matplotlib.pyplot as plt

from src.utils.misc import stdlog, is_scalar
from src.dl.fixed_point_solvers.fixed_point_iterator import fixed_point_iterator
from src.analysis.logging.common_logging_utils import LoggingFunction

# Cosine-similarity cutoff used by PathIndependenceQuantifying: a fixed point reached from a
# different initialization whose cosine similarity with the default fixed point falls below
# this value marks the example as not path independent.
COSINE_SIMILARITY_THRESHOLD = 0.97


@quick_register
class PathIndependenceQuantifying(LoggingFunction):
    """Logging function that scores how path independent a model's fixed points are.

    For a batch drawn from every validation loader, the model is first run with its default
    initialization. It is then run again with the same inputs, but initialized from the fixed
    points found for other positions of the batch. An example counts as path independent when
    the default run satisfied both convergence thresholds and every re-initialized fixed point
    has a cosine similarity of at least ``COSINE_SIMILARITY_THRESHOLD`` with the default one.

    Settings read from ``self.initial_kwargs``: ``repeats``, ``use_training_forward``,
    ``num_forward_iter``, ``batch_size``, ``maximum_length``, ``rel_diff_threshold``,
    ``diff_l2_threshold`` and, optionally, ``solver``.
    """

    @staticmethod
    def _solver_depth(solver):
        """Return the solver's "threshold" setting if it has one, otherwise its "num_iters"."""
        if "threshold" in solver.initial_kwargs:
            return solver.initial_kwargs["threshold"]
        assert "num_iters" in solver.initial_kwargs
        return solver.initial_kwargs["num_iters"]

    @staticmethod
    def _figure_to_rgb_array(fig):
        """Render ``fig`` and return its pixels as a (height, width, 3) uint8 array."""
        canvas = fig.canvas
        canvas.draw()
        if hasattr(canvas, "buffer_rgba"):
            # Preferred on current matplotlib, where tostring_rgb is deprecated/removed.
            rgba = np.asarray(canvas.buffer_rgba())
            return np.ascontiguousarray(rgba[..., :3])
        data = np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8)
        return data.reshape(canvas.get_width_height()[::-1] + (3,))

    def __call__(self, metric_logs, pl_system, **kwargs):
        """Run the path-independence analysis and add its metrics to ``metric_logs``.

        The analysis always runs when ``kwargs["prepend_key"]`` contains "test". Otherwise it
        runs only for the first validation batch (and, when ``dataloader_idx`` is given, only
        for dataloader 0); in every other case ``metric_logs`` is returned untouched.

        Args:
            metric_logs: Dict of metrics; one entry per analysed length is added in place.
            pl_system: The system under analysis. Its forward solver, kwargs getter, z0
                initialization method and train/eval mode are modified temporarily and
                restored before returning, even if the analysis raises.
            **kwargs: Must provide ``prepend_key``, ``global_step``, ``epoch`` and
                ``save_dir`` (``batch_idx`` too unless this is a test run). Optional:
                ``dataloader_idx``, ``z0_init_method`` and ``save_locally``.

        Returns:
            ``metric_logs``, the same dict that was passed in.
        """
        if "test" in kwargs["prepend_key"]:
            pass
        elif "validation" not in kwargs["prepend_key"] or kwargs["batch_idx"] != 0:
            return metric_logs
        elif "dataloader_idx" in kwargs:
            if kwargs["dataloader_idx"] != 0:
                return metric_logs
        print(f"Running logger {self.__class__.__name__}")
        with torch.no_grad():
            # Set the number of times an input will be reinitialized with other fixed points.
            repeats = self.initial_kwargs["repeats"]
            if "z0_init_method" in kwargs:
                z0_init = kwargs["z0_init_method"]
            else:
                z0_init = copy.deepcopy(pl_system.model.z0_init_method)

            # Save the state of modified attributes, so that we can revert the modification later.
            original_forward_solver = copy.deepcopy(pl_system.model.forward_solver)
            original_backward_solver = copy.deepcopy(pl_system.model.backward_solver)
            original_pretraining_mode = copy.deepcopy(pl_system.model_kwargs_getter)
            original_z0_init_method = copy.deepcopy(pl_system.model.z0_init_method)
            mode = pl_system.training

            try:
                # Put the model in eval mode.
                pl_system.eval()

                # Get the loaders.
                loaders = [pl_system.valid_loader] if not isinstance(pl_system.valid_loader, Iterable) \
                    else pl_system.valid_loader

                # Prepare the model to run fixed point iterations.
                pls = pl_system
                pls.model_kwargs_getter.initial_kwargs['num_pretraining_steps'] = 0

                if "solver" in self.initial_kwargs:
                    assert not self.initial_kwargs["use_training_forward"], "solver must not be specified when use_training_forward is True. "
                    assert self.initial_kwargs["num_forward_iter"] is None, "If solver is provided, don't specify custom forward depth. "

                if not self.initial_kwargs["use_training_forward"]:
                    if "solver" not in self.initial_kwargs:
                        pls.model.forward_solver = Configurable(fixed_point_iterator, num_iters=-1)
                        pls.model.forward_solver.initial_kwargs["num_iters"] = self.initial_kwargs["num_forward_iter"]
                        num_forward_iter = self.initial_kwargs["num_forward_iter"]
                    else:
                        pls.model.forward_solver = self.initial_kwargs["solver"]
                        num_forward_iter = self._solver_depth(pls.model.forward_solver)
                else:
                    num_forward_iter = self._solver_depth(pls.model.forward_solver)

                solver_type = pls.model.forward_solver.f.__name__

                def normalize_across_last_axis(x):
                    # Scale to unit L2 norm along the last axis; the epsilon guards all-zero vectors.
                    return x / torch.sqrt(torch.sum(x**2, dim=-1, keepdim=True) + 1e-16)

                # Run the path-independence analysis on all different loaders.
                # `lengths` and `path_indep_metrics` are parallel: one entry per analysed loader.
                lengths = list()
                path_indep_metrics = list()
                for curr_loader in loaders:
                    # Take enough batches from the loader to cover the requested batch size.
                    num_iters = (self.initial_kwargs["batch_size"] // curr_loader.batch_size) + 1
                    xs, ys = list(), list()
                    for i, batch in enumerate(curr_loader):
                        if i == num_iters:
                            break
                        curr_xs, curr_ys = batch
                        xs.append(curr_xs)
                        ys.append(curr_ys)
                    curr_xs = torch.cat(xs, dim=0)[:self.initial_kwargs["batch_size"]]
                    curr_ys = torch.cat(ys, dim=0)[:self.initial_kwargs["batch_size"]]
                    curr_xs = pl_system.reconcile_input_and_model_types(tensor=curr_xs)
                    curr_ys = pl_system.reconcile_input_and_model_types(tensor=curr_ys)
                    curr_length = curr_xs.shape[-1]

                    # Skip loaders whose samples are too long; otherwise record the length.
                    if curr_length > self.initial_kwargs["maximum_length"]:
                        continue
                    lengths.append(curr_length)

                    # ____ Step 1: Run forward pass using the default initialization. ____
                    print(f"Get default fixed points. length: {curr_length}")
                    pls.model.z0_init_method = z0_init
                    default_outs, default_model_dict = pls(curr_xs)
                    default_fixed_points = default_model_dict["result"]

                    # Check whether a fixed point has been found. If not, declare that the network is not path independent.
                    rel_criterion = default_model_dict["rel_diff"] < float(self.initial_kwargs["rel_diff_threshold"])
                    abs_criterion = default_model_dict["diff_l2"] < float(self.initial_kwargs["diff_l2_threshold"])
                    convergence_criterion = torch.logical_and(input=rel_criterion, other=abs_criterion)

                    # ____ Step 2: Run with fixed points obtained from other samples. ____
                    print(f"Get new length: {curr_length}")
                    bs = curr_xs.shape[0]
                    pls.model.z0_init_method = "external"
                    # Inputs are repeated element-wise while the fixed points are tiled as a whole,
                    # so the copies of one input are paired with fixed points from different batch positions.
                    expanded_xs = torch.repeat_interleave(input=curr_xs, repeats=repeats, dim=0)
                    expanded_fps = default_fixed_points.repeat([repeats] + (len(default_fixed_points.shape) - 1)*[1])
                    _, expanded_model_dict = pls(expanded_xs, external_zs0=expanded_fps)
                    interleaved_fixed_points = expanded_model_dict["result"].view([bs] + [repeats] + list(default_fixed_points.shape[1:]))
                    interleaved_l2_diffs_avg = expanded_model_dict["diff_l2"].view([bs] + [repeats]).mean(dim=-1)
                    interleaved_rel_diffs_avg = expanded_model_dict["rel_diff"].view([bs] + [repeats]).mean(dim=-1)

                    # Make every neuron have zero expectation (a kind of batch normalization)
                    normalizer = (default_fixed_points.shape[0] + (interleaved_fixed_points.shape[0] * interleaved_fixed_points.shape[1]))
                    neuron_means = (interleaved_fixed_points.sum(dim=[0, 1]) + default_fixed_points.sum(dim=0)) / normalizer
                    # NOTE: these subtractions are in place, and interleaved_fixed_points is a view of
                    # expanded_model_dict["result"], so the similarities visualized in step 4 use the
                    # centered values as well.
                    default_fixed_points -= neuron_means
                    interleaved_fixed_points -= neuron_means

                    # ____ Step 3: Compute the cosine similarity between default_fixed_points and expanded_fixed_points. ____
                    # Flatten.
                    default_fixed_points = default_fixed_points.view(bs, -1)
                    interleaved_fixed_points = interleaved_fixed_points.view(bs, repeats, -1)

                    # Normalize across the last axis.
                    normalized_default_fixed_points = normalize_across_last_axis(default_fixed_points)
                    normalized_interleaved_fixed_points = normalize_across_last_axis(interleaved_fixed_points)

                    # Compute dot products and compare.
                    normalized_default_fixed_points = normalized_default_fixed_points[:, None, ...]
                    dot_prods = torch.sum(normalized_default_fixed_points * normalized_interleaved_fixed_points, dim=-1)

                    # Per example, count the new fixed points that are NOT aligned with the original
                    # one; the example is path independent only when that count is zero.
                    path_independence_per_example_and_fp = (dot_prods < COSINE_SIMILARITY_THRESHOLD).float()
                    path_independence_per_example = (torch.sum(path_independence_per_example_and_fp, dim=1) == 0)

                    # Fuse the path independence analysis with the convergence check.
                    path_independence_per_example = torch.logical_and(convergence_criterion, path_independence_per_example).float()

                    average_path_independence = path_independence_per_example.mean()
                    path_indep_metrics.append(average_path_independence)

                    # ____ Step 4: Log relevant metrics. ____
                    # Compute pairwise cosine similarities.
                    num_visualize = 8
                    vecs = expanded_model_dict["result"].view(bs*repeats, -1)[:num_visualize*repeats]
                    normalized_vecs = normalize_across_last_axis(vecs)
                    test_dot_ps = normalized_vecs @ normalized_vecs.T
                    plt.imshow(test_dot_ps.cpu().numpy())
                    plt.title(f"Cross Fixed Point Similarities \n Zero Init: {z0_init} \nGlobal Step: {kwargs['global_step']}")
                    plt.colorbar()

                    data = self._figure_to_rgb_array(plt.gcf())
                    wandb.log({f"{kwargs['prepend_key']}len_{curr_length}__zero_init_{z0_init}__solver_{solver_type}__tmp_cross_fp_cosine_similarities": wandb.Image(data),
                               "global_step": kwargs["global_step"]})

                    if kwargs.get("save_locally"):
                        save_dir = os.path.join(kwargs["save_dir"],
                                                "pairwise_cosine_similarities_with_tiled_fixed_points")
                        os.makedirs(save_dir, exist_ok=True)
                        plot_name = f"pairwise_cos_sim__length_{curr_length}__num_forward_steps_{num_forward_iter}__zero_init_{z0_init}__solver_{solver_type}.png"
                        plot_save_path = os.path.join(save_dir, plot_name)
                        plt.savefig(plot_save_path)
                    plt.close('all')

                    # ____ Step 5: Compute accuracies and save per-example metrics. ____
                    # An example is accurate only if the prediction matches the target at every
                    # position of the last axis (presumably dim 1 of the outputs is the class axis).
                    accs = ((default_outs.argmax(dim=1) == curr_ys).float().sum(dim=-1) == default_outs.shape[-1])
                    avg_cosine_sims = dot_prods.mean(dim=1)
                    acc_and_cosine_sim_list = list()

                    per_bit_cross_entropies = torch.nn.functional.cross_entropy(input=default_outs,  target=curr_ys.long(), reduction='none')
                    per_example_cross_entropies = per_bit_cross_entropies.mean(dim=-1)

                    for i in range(accs.shape[0]):
                        acc_and_cosine_sim_list.append(dict(acc=accs[i].item(),
                                                            shifted_cosine_sim=avg_cosine_sims[i].item(),
                                                            global_step=kwargs["global_step"],
                                                            cross_entropy=per_example_cross_entropies[i].item(),
                                                            diff_l2=default_model_dict['diff_l2'][i].item(),
                                                            fixed_point_init_diff_l2_avg=interleaved_l2_diffs_avg[i].item(),
                                                            fixed_point_init_rel_diff_avg=interleaved_rel_diffs_avg[i].item(),
                                                            rel_diff=default_model_dict['rel_diff'][i].item()))

                    # Save. The per-example metrics are written regardless of "save_locally".
                    per_example_save_dir = os.path.join(kwargs["save_dir"], "per_example_metrics")
                    os.makedirs(per_example_save_dir, exist_ok=True)
                    per_example_save_path = os.path.join(per_example_save_dir, f"per_example__length_{curr_length}__num_forward_steps_{num_forward_iter}__zero_init_{z0_init}__solver_{solver_type}.json")
                    with open(per_example_save_path, "w") as f:
                        json.dump(acc_and_cosine_sim_list, f, indent=4)

                # Put the computed values in metric logs, pairing each length with its own metric.
                path_indep_metrics = [v.item() for v in path_indep_metrics]
                for length, pi_metric in zip(lengths, path_indep_metrics):
                    key = f"cosine_based_pi_metric/length_{length}__num_forward_iters_{num_forward_iter}"
                    metric_logs[key] = pi_metric

                fig, axs = plt.subplots(nrows=1, ncols=1, figsize=(10, 10), squeeze=False)
                axs[0, 0].plot(lengths, path_indep_metrics, '-o')
                axs[0, 0].set_xlabel("Length")
                axs[0, 0].set_ylabel("Path independence metric")
                axs[0, 0].set_title(f"Path Independence Metric Plot \n Epoch: {kwargs['epoch']} \n Global Step: {kwargs['global_step']}")
                plt.legend()
                plt.tight_layout()

                if kwargs.get("save_locally"):
                    plot_name = f"cosine_based_path_indep_metric__num_forward_steps_{num_forward_iter}__zero_init_{z0_init}__solver_{solver_type}.png"
                    plot_save_path = os.path.join(kwargs["save_dir"], plot_name)
                    plt.savefig(plot_save_path)

                # Log the plot; the figure is closed only after it has been rendered and logged.
                data = self._figure_to_rgb_array(fig)
                wandb.log({f"{kwargs['prepend_key']}path_independence_plot": wandb.Image(data),
                           "global_step": kwargs["global_step"]})
                plt.close('all')
            finally:
                # Revert the changes made to the pl_system, even if the analysis failed.
                pl_system.model.forward_solver = original_forward_solver
                pl_system.model.backward_solver = original_backward_solver
                pl_system.model_kwargs_getter = original_pretraining_mode
                pl_system.model.z0_init_method = original_z0_init_method
                # train(mode) restores the flag on every submodule, mirroring the eval() call above.
                pl_system.train(mode)

            return metric_logs
