""" 1. Generate experiment data. """
import os
import shutil
import numpy as np
import itertools
import pickle as pk
from hashlib import md5
import sys
from CDExperimentSuite_DEV.src import Scalers
from CDExperimentSuite_DEV.src import utils

# Global NumPy print settings for this module: allow scientific notation and
# never summarize/truncate arrays when they are printed.
np.set_printoptions(suppress=False, threshold=sys.maxsize)


class DataGenerator:
    def __init__(self):
        pass

    def create_data(self, opt, parameters, W_true):
        """Create sample for dataset"""
        utils.set_random_seed(parameters.random_seed)
        l, u = parameters.noise_sigma_lims
        sigma = 1 / np.sqrt(12) * (u - l) ** 2
        match parameters.noise_sigma_dist:
            case "uniform":
                noise_scales = np.random.uniform(l, u, size=parameters.n_nodes)
            case "exp":
                noise_scales = np.random.exponential(
                    scale=sigma, size=parameters.n_nodes
                )
            case "gauss":
                noise_scales = np.random.normal(
                    loc=0, scale=sigma, size=parameters.n_nodes
                )
            case _:
                raise NotImplementedError()
        B_true = np.where(W_true != 0, 1, 0)
        data, _ = utils.simulate_linear_sem(
            W_true,
            parameters.n_obs,
            sem_type=parameters.noise_dist,
            noise_scale=noise_scales,
            harmonize=False,
        )

        vars = np.var(data, axis=0)

        if opt.scaler.name() == "Harmonizer":
            data, W_true = utils.simulate_linear_sem(
                W_true,
                parameters.n_obs,
                sem_type=parameters.noise_dist,
                noise_scale=noise_scales,
                harmonize=True,
            )

        data = opt.scaler.transform(data)

        # vsb
        scaled_vars = np.var(data, axis=0)
        varsortability = (
            opt.vsb_function(data, W_true) if opt.vsb_function else np.array([-1])
        )
        # rsb
        R2 = utils.r2coef(data.T)
        R2sortability = (
            opt.R2sb_function(data, W_true) if opt.R2sb_function else np.array([-1])
        )
        # cev-stb
        CEVsortability = (
            opt.CEVsb_function(data, W_true) if opt.CEVsb_function else np.array([-1])
        )

        return utils.Dataset(
            parameters=parameters,
            W_true=W_true,
            B_true=B_true,
            data=data,
            hash=md5(data).hexdigest()[0:10],
            scaler=opt.scaler.name(),
            scaling_factors=opt.scaler.scaling_factors,
            sigma=noise_scales,
            vars=vars,
            scaled_vars=scaled_vars,
            varsortability=varsortability,
            R2=R2,
            R2sortability=R2sortability,
            CEVsortability=CEVsortability,
        )

    def create_dataset(self, opt, parameters):
        """
        Create a single Dataset according to specifications
        Args:
            parameters(Dataset): Namedtuple of Dataset parameters
        """
        utils.set_random_seed(parameters.random_seed)
        B_true = utils.simulate_dag(parameters)
        w = parameters.edge_weight_range
        if isinstance(w, int):
            w_ranges = ((w, w), (w, w))
        else:
            w_ranges = (
                tuple([-i for i in reversed(list(parameters.edge_weight_range))]),
                parameters.edge_weight_range,
            )
        W_true = utils.simulate_parameter(B_true, w_ranges=w_ranges)
        return self.create_data(opt, parameters, W_true)

    def get_parameters(self, opt):
        """Generate Dataset parameterss as permutations of inputs
        Args:
            graphs(list): ["ER-2", "SF-4"]
            noise_distributions(list): Noise distribution with variance (c.f. utils)
            edge_weights(list): Intervals
            n_nodes(list): Number of nodes in graph
            n_obs(list): Number of observations in sample
        """
        combinations = list(
            itertools.product(
                opt.graphs,
                opt.edge_types,
                opt.edges,
                opt.noise_distributions,
                opt.edge_weights,
                opt.n_nodes,
                opt.n_obs,
                list(range(opt.n_repetitions)),
            )
        )
        return [
            utils.DatasetParameters(
                graph_type=gt,
                edge_type=et,
                x=ed,
                noise_dist=nd.noise_dist,
                noise_sigma_dist=nd.noise_sigma_dist,
                noise_sigma_lims=nd.noise_sigma_lims,
                edge_weight_range=ew,
                n_nodes=nn,
                n_obs=no,
                random_seed=rs,
            )
            for (gt, et, ed, nd, ew, nn, no, rs) in combinations
        ]

    def generate(self, opt):
        """Generate Data
        Args:
            All params are lists
        Returns:
            All Datasets
        """
        dataset_parameters = self.get_parameters(opt)
        return [self.create_dataset(opt, i) for i in dataset_parameters]

    def generate_and_save(self, opt):
        """Generate and save Dataset data"""

        # create output dir
        self.exp_dir = os.path.join(
            opt.base_dir, os.path.basename(opt.base_dir) + opt.exp_name
        )
        utils.create_folder(self.exp_dir)
        utils.snapshot(opt)
        # create output folder
        self.output_folder = os.path.join(self.exp_dir, "_data")
        utils.create_folder(self.output_folder, overwrite=True)

        # create Datasets
        all_datasets = self.generate(opt)

        # create parent directories
        unique_dirs = set(
            [
                os.path.join(self.output_folder, utils.dataset_dirname(dataset))
                for dataset in all_datasets
            ]
        )
        for d in unique_dirs:
            if os.path.isdir(d):
                shutil.rmtree(d)
            os.mkdir(d)

        # save datasets
        checksums = [("name,checksum")]
        prev_dir = ""
        for exp_idx, dataset in enumerate(all_datasets):
            dirname = utils.dataset_dirname(dataset)
            fname = utils.dataset_parameters(dataset)
            dir = os.path.join(self.output_folder, dirname)

            hashsum = md5(dataset.data).digest()
            checksums.append(f"{fname},{hashsum}")

            fdir = os.path.join(dir, fname) + ".pk"
            with open(fdir, "wb") as f:
                pk.dump(dataset, f)

            if dirname != prev_dir:
                print(exp_idx, "COMPLETED", dirname)
                prev_dir = dirname

        with open(os.path.join(self.output_folder, "checksums.csv"), "w") as f:
            f.write("\n".join(checksums))
