import importlib
import torch
import numpy as np
import seaborn as sns

import matplotlib.pyplot as plt


# Lookup table: blackbox-function module name (the file func/<name>.py)
# -> name of the class that file implements.
# get_function_class_name reads it to turn a function name into a class name,
# which get_function_instance then loads from func.<name>.
module2class = dict(
    branin="Branin",
    goldstein="Goldstein",
    gp_sample="GPSample",
    lake_zurich="LakeZurich",
    phosphorus="Phosphorus",
    intel_lab_data_humidity="IntelLabHumidity",
    intel_lab_data_temperature="IntelLabTemperature",
)


def normalize_data(xs):
    """Min-max scale every column of ``xs`` into [0, 1].

    Parameters
    ----------
    xs : array-like of shape (n, d)
        Data to normalize; the min and max are taken over axis 0.

    Returns
    -------
    numpy.ndarray of shape (n, d)
        ``(xs - min) / (max - min)`` per column. A constant column
        (``max == min``) is mapped to all zeros instead of NaN.
    """
    min_xs = np.min(xs, axis=0, keepdims=True)
    max_xs = np.max(xs, axis=0, keepdims=True)
    span = max_xs - min_xs
    # A constant column would divide 0 by 0 (NaN). Its numerator is 0 there,
    # so dividing by 1 instead maps that column to 0.
    span = np.where(span == 0, 1, span)
    return (xs - min_xs) / span


def get_func_name_and_params(func_string):
    """Split a function spec string into its name and float parameters.

    Parameters
    ----------
    func_string : string
        The function name, optionally followed by ``___name__value`` groups.
        Examples:
        branin
        branin___xsize__100
        branin___xsize__100___noise_std__0.01
        (each param name must be the same as the corresponding argument of
        the blackbox function's __init__, e.g. func/branin.py : __init__)

    Returns
    -------
    string
        name of the blackbox function
    params
        dictionary mapping each param name to its value as a float

    Raises
    ------
    ValueError
        If a parameter group is not of the form ``name__value`` (missing
        separator or empty name), or if a value cannot be parsed as a float.
    """
    func_name, *param_infos = func_string.split("___")

    params = {}
    for param_info in param_infos:
        key, sep, val = param_info.partition("__")
        if not sep or not key:
            raise ValueError(
                f"Malformed parameter {param_info!r} in {func_string!r}; "
                "expected 'name__value'"
            )
        # the correct type will be cast in the blackbox function class's __init__
        params[key] = float(val)

    return func_name, params


def get_function_class_name(func_name):
    """Look up the class name implemented in ``func/{func_name}.py``.

    Parameters
    ----------
    func_name : string
        module name of the blackbox function under func/, e.g. ``branin``

    Returns
    -------
    string
        the class name registered for ``func_name`` in ``module2class``

    Raises
    ------
    Exception
        If ``func_name`` has no entry in ``module2class``.
    """
    class_name = module2class.get(func_name)
    if class_name is None:
        raise Exception(f"Function module {func_name} not found in module2class")
    return class_name


def get_function_instance(func_name, params):
    """Import ``func.{func_name}`` and instantiate its registered class.

    Parameters
    ----------
    func_name : string
        module name of the blackbox function under func/, e.g. ``branin``
    params : dictionary
        keyword arguments forwarded to the class's __init__,
        e.g. {'xsize': 100, 'noise_std': 0.01}

    Returns
    -------
    function instance
        ``ClassName(**params)`` where ClassName comes from
        get_function_class_name(func_name)
    """
    module_path = "func." + func_name
    class_name = get_function_class_name(func_name)
    blackbox_module = importlib.import_module(module_path)
    blackbox_class = getattr(blackbox_module, class_name)
    return blackbox_class(**params)


def get_complement_set(n, idx_set):
    """Return the indices in ``range(n)`` that are not in ``idx_set``.

    Parameters
    ----------
    n : int
        size of the universe ``{0, ..., n - 1}``
    idx_set : iterable of hashable ints
        indices to exclude; values outside ``range(n)`` are ignored

    Returns
    -------
    list of int
        the complement, in increasing order (deterministic, unlike the
        iteration order of a set difference)
    """
    excluded = set(idx_set)
    return [idx for idx in range(n) if idx not in excluded]


def get_iset_idxs(vals, iset_type="maximizer"):
    """Return the indices of the "interesting set" of ``vals``.

    Parameters
    ----------
    vals : torch.Tensor
        Values from which the interesting set is extracted. It is flattened
        first, so returned indices refer to the flattened tensor. May live on
        any device.
    iset_type : string
        Which set to extract:

        - ``"maximizer"``: every index whose value equals the maximum;
        - ``"minimizer"``: every index whose value equals the minimum;
        - ``"<p>_top_percentile"``: indices of the ``floor(n * p / 100) + 1``
          largest values (capped at n), in decreasing order of value;
        - ``"<p>_bottom_percentile"``: indices of the ``floor(n * p / 100) + 1``
          smallest values (capped at n), in increasing order of value.

    Returns
    -------
    list of int
        indices of the interesting elements in ``vals``

    Raises
    ------
    Exception
        If ``iset_type`` is not one of the forms above.
    """
    vals = vals.reshape(-1)

    if iset_type in ("maximizer", "minimizer"):
        extreme_idx = torch.argmax(vals) if iset_type == "maximizer" else torch.argmin(vals)
        # Keep every index tied with the extreme value, not just the first.
        # Tensor.tolist() works for CPU and CUDA tensors alike.
        return (vals == vals[extreme_idx]).nonzero().reshape(-1).tolist()

    for suffix, descending in (("_top_percentile", True), ("_bottom_percentile", False)):
        if iset_type.endswith(suffix):
            percentile = float(iset_type.split("_")[0]) / 100.0
            sep_idx = int(np.floor(len(vals) * percentile))

            # .cpu() so CUDA tensors can be converted to numpy
            sort_idxs = np.argsort(vals.detach().cpu().numpy())  # increasing
            if descending:
                sort_idxs = sort_idxs[::-1]
            # NOTE: sep_idx + 1 elements are kept, so at least one index is
            # always returned (even for a 0 percentile).
            return sort_idxs[: sep_idx + 1].tolist()

    raise Exception(f"Unknown iset_type: {iset_type}")


def get_regret_between_2_sets(gt_set, lt_set, mode="sum"):
    """Measure how badly two sets violate an expected ordering.

    Every element of ``gt_set`` is supposed to be larger than every element
    of ``lt_set``. Each pair ``(g, l)`` with ``l > g`` is an ordering
    violation that contributes ``l - g``; correctly ordered pairs add 0.

    Parameters
    ----------
    gt_set : sequence of float, length k1
        values expected to be the larger ones
    lt_set : sequence of float, length k2
        values expected to be the smaller ones
    mode : {"sum", "max"}
        ``"sum"``: total violation, sum_{g in gt_set, l in lt_set} max(0, l - g)
        ``"max"``: worst violation, max_{g in gt_set, l in lt_set} max(0, l - g)

    Returns
    -------
    a float
        the regret; 0 when either set is empty

    Raises
    ------
    Exception
        If ``mode`` is neither "sum" nor "max".
    """
    if mode == "sum":
        return sum(max(0, lt - gt) for gt in gt_set for lt in lt_set)

    # Previously a duplicated `mode == "sum"` check made this branch
    # unreachable, so mode="max" raised "Unknown mode".
    if mode == "max":
        return max(
            (max(0, lt - gt) for gt in gt_set for lt in lt_set),
            default=0.0,
        )

    raise Exception(f"Unknown mode: {mode}")


def get_pair_in_cl_val_mode(lower_f, upper_f, iset_idxs):
    """Pick a comparison pair: one interesting-set point and any point.

    Scores every (candidate r, interesting point i) pair with
    ``min(upper_f[r] - lower_f[i], upper_f[i] - lower_f[r])`` and returns the
    highest-scoring pair. Candidates range over ALL indices, including those
    in ``iset_idxs`` (so a point may be paired with itself).

    Parameters
    ----------
    lower_f, upper_f : torch.Tensor
        lower / upper bounds, one entry per point (indexed by position)
    iset_idxs : sequence of int
        indices of the interesting set

    Returns
    -------
    tuple
        ``(imax_idx, rmax_idx)``: the chosen element of ``iset_idxs`` and
        the chosen candidate index
    """
    candidate_idxs = list(range(len(lower_f)))

    # Candidates as a column, interesting points as a row, so the
    # subtractions broadcast to shape (len(candidates), len(iset)).
    cand_upper = upper_f[candidate_idxs].reshape(-1, 1)
    cand_lower = lower_f[candidate_idxs].reshape(-1, 1)
    iset_upper = upper_f[iset_idxs].reshape(1, -1)
    iset_lower = lower_f[iset_idxs].reshape(1, -1)

    relative_regret = torch.minimum(cand_upper - iset_lower, iset_upper - cand_lower)

    # Column holding the overall best score, then the best row in that column.
    best_col = torch.argmax(relative_regret.amax(dim=0))
    best_row = torch.argmax(relative_regret[:, best_col])

    return iset_idxs[best_col], candidate_idxs[best_row]


def get_pair_in_boundary_mode(lower_f, upper_f, iset_idxs):
    """Pick a comparison pair across the interesting-set boundary.

    Scores every (outside point r, interesting point i) pair with
    ``min(upper_f[r] - lower_f[i], upper_f[i] - lower_f[r])`` and returns the
    highest-scoring pair. Unlike the "cl_val"/"cl_ord" variants, candidates r
    are only the indices NOT in ``iset_idxs``.

    Parameters
    ----------
    lower_f, upper_f : torch.Tensor
        lower / upper bounds, one entry per point (indexed by position)
    iset_idxs : sequence of int
        indices of the interesting set

    Returns
    -------
    tuple
        ``(imax_idx, rmax_idx)``: the chosen element of ``iset_idxs`` and
        the chosen index outside it

    Notes
    -----
    Outside indices are taken in increasing order, so ties in the score are
    broken towards the smallest index. Previously they came from a set
    difference, whose iteration order depends on hash-table layout.
    If ``iset_idxs`` covers every index there are no candidates; the torch
    reductions below presumably raise in that case.
    """
    iset_members = set(iset_idxs)
    remaining_set_idxs = [idx for idx in range(len(lower_f)) if idx not in iset_members]

    # Remaining points as a column, interesting points as a row, so the
    # subtractions broadcast to shape (len(remaining), len(iset)).
    remaining_upper_f = upper_f[remaining_set_idxs].reshape(-1, 1)
    remaining_lower_f = lower_f[remaining_set_idxs].reshape(-1, 1)
    iset_lower_f = lower_f[iset_idxs].reshape(1, -1)
    iset_upper_f = upper_f[iset_idxs].reshape(1, -1)

    remaining_upper_minus_iset_lower = remaining_upper_f - iset_lower_f
    iset_upper_minus_remaining_lower = iset_upper_f - remaining_lower_f

    relative_regret = torch.minimum(
        remaining_upper_minus_iset_lower, iset_upper_minus_remaining_lower
    )

    # Column holding the overall best score, then the best row in that column.
    imax_idx_in_iset = torch.argmax(torch.amax(relative_regret, dim=0))
    rmax_idx_in_remaining = torch.argmax(relative_regret[:, imax_idx_in_iset])

    imax_idx = iset_idxs[imax_idx_in_iset]
    rmax_idx = remaining_set_idxs[rmax_idx_in_remaining]

    return (
        imax_idx,
        rmax_idx,
    )


def get_pair_in_cl_ord_mode(lower_f, upper_f, iset_idxs):
    """Pick a comparison pair: one interesting-set point and a different point.

    Scores every (candidate r, interesting point i) pair with
    ``min(upper_f[r] - lower_f[i], upper_f[i] - lower_f[r])`` over ALL
    candidate indices, but forbids pairing a point with itself by setting
    those entries to -inf, then returns the highest-scoring pair.

    Parameters
    ----------
    lower_f, upper_f : torch.Tensor
        lower / upper bounds, one entry per point (indexed by position)
    iset_idxs : sequence of int
        indices of the interesting set

    Returns
    -------
    tuple
        ``(imax_idx, rmax_idx)``: the chosen element of ``iset_idxs`` and
        the chosen candidate index
    """
    candidate_idxs = list(range(len(lower_f)))

    # Candidates as a column, interesting points as a row, so the
    # subtractions broadcast to shape (len(candidates), len(iset)).
    cand_upper = upper_f[candidate_idxs].reshape(-1, 1)
    cand_lower = lower_f[candidate_idxs].reshape(-1, 1)
    iset_upper = upper_f[iset_idxs].reshape(1, -1)
    iset_lower = lower_f[iset_idxs].reshape(1, -1)

    relative_regret = torch.minimum(cand_upper - iset_lower, iset_upper - cand_lower)

    # Row k is point k (candidates cover every index), so entry
    # (iset_idxs[j], j) is point iset_idxs[j] paired with itself.
    self_pair_cols = list(range(len(iset_idxs)))
    relative_regret[iset_idxs, self_pair_cols] = float("-inf")

    # Column holding the overall best score, then the best row in that column.
    best_col = torch.argmax(relative_regret.amax(dim=0))
    best_row = torch.argmax(relative_regret[:, best_col])

    return iset_idxs[best_col], candidate_idxs[best_row]


def plot_1d(
    x_domain,
    gt_func_evals,
    f_means,
    upper_f,
    lower_f,
    pred_iset_idxs,
    gt_iset_idxs,
    imax_idx,
    rmax_idx,
    query_idx,
    obs_at_query,
    obs_X,
    obs_y,
    iteration,
    is_show=False,
    save_filename=None,
    is_show_imax_rmax=True,
    show_legend=False,
    ylim=(-2.2, 1.8),
):
    """Plot one iteration of a 1-D run.

    Draws the ground-truth curve, GP mean and bounds, the predicted and
    ground-truth interesting sets, the comparison pair, the current and past
    observations, and a histogram of past inputs along the bottom.

    Parameters
    ----------
    x_domain, f_means, upper_f, lower_f : torch.Tensor
        Input grid, GP mean and upper/lower bounds on that grid. They are
        detached, moved to CPU, converted to numpy and squeezed.
    gt_func_evals : array-like
        Ground-truth values on ``x_domain``. It is indexed with a list of
        indices, so it must support fancy indexing (e.g. a numpy array).
    pred_iset_idxs, gt_iset_idxs : iterable of int
        Predicted and ground-truth interesting-set indices into ``x_domain``.
    imax_idx, rmax_idx : int or None
        Comparison-pair indices; each is drawn as a vertical line when it is
        not None and ``is_show_imax_rmax`` is True.
    query_idx : int
        Index into ``x_domain`` of the current query.
    obs_at_query : float
        Value plotted at the current query.
    obs_X, obs_y : array-like
        Past observations; ``obs_X`` also feeds the bottom histogram.
    iteration : int
        Shown in the title.
    is_show : bool
        Call ``plt.show()`` after drawing.
    save_filename : str or None
        If given, save the figure to this path at 300 dpi.
    is_show_imax_rmax : bool
        Whether to draw the comparison-pair lines. This flag was previously
        ignored.
    show_legend : bool
        Whether to draw the legend.
    ylim : tuple of float or None
        y-axis limits. The default is the previously hard-coded (-2.2, 1.8);
        pass None to use the limits matplotlib picks for the plotted data.

    Returns
    -------
    matplotlib.figure.Figure
        The figure, so callers that plot in a loop can ``plt.close`` it.
    """
    # .cpu() so tensors living on a GPU can be converted to numpy as well
    x_domain = x_domain.detach().cpu().numpy().squeeze()
    f_means = f_means.detach().cpu().numpy().squeeze()
    upper_f = upper_f.detach().cpu().numpy().squeeze()
    lower_f = lower_f.detach().cpu().numpy().squeeze()

    fig, ax = plt.subplots(figsize=(4.5, 3))
    ax.plot(x_domain, gt_func_evals, "--", c="#282828", zorder=49, label="Blackbox function")
    ax.plot(x_domain, f_means, c="#8F3F71", zorder=50, label="GP mean")
    ax.plot(x_domain, upper_f, ":", c="#8F3F71", zorder=50, label="Upper bound")
    ax.plot(x_domain, lower_f, ":", c="#8F3F71", zorder=50, label="Lower bound")

    # Split the predicted vs. ground-truth sets into correct / wrong / missed.
    intersect = list(set(pred_iset_idxs).intersection(gt_iset_idxs))
    pred_excl_iset = list(set(pred_iset_idxs).difference(intersect))
    iset_excl_pred = list(set(gt_iset_idxs).difference(intersect))

    ax.scatter(
        x_domain[intersect],
        f_means[intersect],
        c="#076778",
        edgecolors="black",
        label="Correct Predicted Top-k",
        zorder=99,
    )
    ax.scatter(
        x_domain[pred_excl_iset],
        f_means[pred_excl_iset],
        c="#CC241D",
        edgecolors="black",
        label="Incorrect Predicted Top-k",
        zorder=99,
    )
    ax.scatter(
        x_domain[iset_excl_pred],
        gt_func_evals[iset_excl_pred],
        c="#282828",
        edgecolors="black",
        label="Missing Top-k",
        zorder=98,
    )

    if ylim is None:
        ylim = ax.get_ylim()

    if is_show_imax_rmax:
        if imax_idx is not None:
            ax.plot(
                [x_domain[imax_idx], x_domain[imax_idx]],
                [ylim[0], ylim[1]],
                c="#fe8019",
                label="Comparison Pair",
                zorder=999,
            )
        if rmax_idx is not None:
            ax.plot(
                [x_domain[rmax_idx], x_domain[rmax_idx]],
                [ylim[0], ylim[1]],
                c="#fe8019",
                zorder=999,
            )

    ax.scatter(
        [x_domain[query_idx]],
        [obs_at_query],
        marker="o",
        s=28,
        c="#fe8019",
        edgecolors="#fe8019",
        label="Current Observation",
        zorder=999,
        alpha=1.0,
    )

    ax.scatter(
        obs_X,
        obs_y,
        marker="o",
        s=30,
        c="white",
        edgecolors="black",
        label="Past Observation",
        zorder=95,
        alpha=0.5,
    )

    _, _, rects = ax.hist(
        obs_X,
        range=(x_domain.min(), x_domain.max()),
        bins=100,
        color="#282828",
        lw=0,  # linewidth
        bottom=ylim[0],
        zorder=1000,
    )

    # Rescale the histogram bars so the tallest one spans 1/10 of the y-range.
    maxheight = max((r.get_height() for r in rects), default=0)
    # Every bin is empty (e.g. no past observations): nothing to rescale, and
    # dividing by a zero maxheight would produce NaN heights.
    if maxheight > 0:
        for r in rects:
            r.set_height(r.get_height() / maxheight * (ylim[1] - ylim[0]) / 10)
    ax.set_ylim(ylim[0], ylim[1])
    if show_legend:
        ax.legend()
    ax.set_title(f"Iteration {iteration}")

    if is_show:
        plt.show()
    if save_filename is not None:
        fig.savefig(save_filename, dpi=300)

    return fig
