import argparse
import os
from typing import Callable, Dict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from IPython.lib.pretty import pprint
from matplotlib.ticker import FuncFormatter
from ray import tune


# Result directories of the three BBAC sweeps plotted by this script. All of
# them sit side by side in the same task directory below ``~/ray_results``.
_SPARSE_SWINGUP_RESULTS_DIR = os.path.expanduser(os.path.join(
    "~/ray_results", "dm_control", "cartpole", "custom_swingup_sparse"))

BBAC_PRIOR_SCALE_ABLATION_PATH = os.path.join(
    _SPARSE_SWINGUP_RESULTS_DIR,
    "2021-01-28T09-39-29-bbac-prior-scale-sweep-1")

BBAC_ENSEMBLE_SIZE_ABLATION_PATH = os.path.join(
    _SPARSE_SWINGUP_RESULTS_DIR,
    "2021-01-21T08-04-35-bbac-ensemble-size-sweep-1")

# NOTE(review): "regularizatin" is reproduced verbatim; presumably it matches
# the directory name on disk -- confirm before correcting the spelling.
BBAC_PRIOR_LOSS_WEIGHT_ABLATION_PATH = os.path.join(
    _SPARSE_SWINGUP_RESULTS_DIR,
    "2021-02-08T17-45-22-bbac-regularizatin-weight-sweep-1")

# print(BBAC_PRIOR_SCALE_ABLATION_PATH,
#       os.path.exists(BBAC_PRIOR_SCALE_ABLATION_PATH))
# print(BBAC_ENSEMBLE_SIZE_ABLATION_PATH,
#       os.path.exists(BBAC_ENSEMBLE_SIZE_ABLATION_PATH))
# print(BBAC_PRIOR_LOSS_WEIGHT_ABLATION_PATH,
#       os.path.exists(BBAC_PRIOR_LOSS_WEIGHT_ABLATION_PATH))
# exit()


def load_analysis(analysis_path) -> tune.Analysis:
    """Open the Ray Tune results stored under ``analysis_path``.

    Args:
        analysis_path: Experiment directory passed straight to
            ``tune.Analysis``.

    Returns:
        The single ``tune.Analysis`` object built for that directory (not a
        dictionary of analyses).
    """
    return tune.Analysis(analysis_path)


def smooth_dataframe(dataframe: pd.DataFrame,
                     window: int = 4) -> pd.DataFrame:
    """Downsample one trial's learning curve using a windowed mean.

    Rows at positions ``window``, ``2 * window``, ... are kept. For each kept
    row the x value is the raw ``'sampler/total-samples'`` entry of that row
    and the y value is the mean of ``'training/episode-reward-mean'`` over
    the ``window`` rows ending at that row. Positions below ``window`` are
    never emitted, so row 0 does not contribute to any output value.

    Args:
        dataframe: Trial progress table containing both columns named above.
        window: Number of rows averaged per output point, which is also the
            stride between output points. Defaults to 4.

    Returns:
        A new two-column dataframe (x and smoothed y) that keeps the index
        labels of the selected rows. The input is not modified.

    Raises:
        ValueError: If ``window`` is smaller than 1.
    """
    if window < 1:
        raise ValueError(f"window must be a positive integer, got {window}")

    x_key = 'sampler/total-samples'
    y_key = 'training/episode-reward-mean'

    # Select rows by position for both columns. Plain ``series[a::b]`` is
    # only positional for some index types, whereas ``.iloc`` always is, so
    # x and y stay in lockstep whatever index the trial dataframe carries.
    kept_rows = slice(window, None, window)
    x_values = dataframe[x_key].iloc[kept_rows]
    y_values = dataframe[y_key].rolling(window).mean().iloc[kept_rows]

    return pd.DataFrame({x_key: x_values, y_key: y_values})

def load_dataframe(analysis: tune.Analysis,
                   trial_filter: Callable[[dict], bool],
                   hue_key: str,
                   hue_getter: Callable[[dict], object]) -> pd.DataFrame:
    """Collect the smoothed learning curves of all selected trials.

    Args:
        analysis: Tune analysis whose ``dataframe().logdir`` values are used
            as keys into ``get_all_configs()`` and ``trial_dataframes``.
        trial_filter: Called with a trial's config; a truthy result means the
            trial is skipped.
        hue_key: Name of the column added to every trial's rows. It must not
            already exist in the trial's progress dataframe.
        hue_getter: Called with a trial's config; its result is stored in the
            ``hue_key`` column of that trial's rows.

    Returns:
        One dataframe holding the output of ``smooth_dataframe`` for every
        kept trial, stacked row-wise with a fresh 0..n-1 index.

    Raises:
        ValueError: If ``hue_key`` collides with an existing column, or if no
            trial passes ``trial_filter``.
    """
    trial_keys = analysis.dataframe().logdir
    configs = analysis.get_all_configs()

    visualization_dataframes = []
    for trial_key in trial_keys:
        trial_config = configs[trial_key]
        # A truthy filter result means "leave this trial out".
        if trial_filter(trial_config):
            continue

        trial_dataframe = analysis.trial_dataframes[trial_key]
        # Raised explicitly (rather than asserted) so the check also runs
        # when Python is started with -O.
        if hue_key in trial_dataframe.columns:
            raise ValueError(
                f"hue_key={hue_key!r} already exists in the dataframe of "
                f"trial {trial_key!r}")

        trial_dataframe = smooth_dataframe(trial_dataframe)
        trial_dataframe[hue_key] = hue_getter(trial_config)
        visualization_dataframes.append(trial_dataframe)

    if not visualization_dataframes:
        # pd.concat([]) would fail with an unspecific "No objects to
        # concatenate"; say what actually went wrong instead.
        raise ValueError(
            "No trials left after applying trial_filter; nothing to load.")

    # Every smoothed trial carries the same row labels, so a plain concat
    # would yield duplicated index labels. Renumber the rows to keep any
    # later label-based alignment unambiguous.
    return pd.concat(visualization_dataframes, ignore_index=True)


def bbac_prior_scale_ablation_trial_filter(config):
    """Tell whether a prior-scale sweep trial should be dropped.

    Args:
        config: Trial config; ``config['Q_params']['config']['prior_scale']``
            is read.

    Returns:
        ``False`` (keep) when the prior scale is greater than 4, ``True``
        (drop) otherwise.
    """
    prior_scale = config['Q_params']['config']['prior_scale']
    return not 4 < prior_scale


def bbac_ensemble_size_ablation_trial_filter(config):
    """Tell whether an ensemble-size sweep trial should be dropped.

    Args:
        config: Trial config; ``config['Q_params']['config']['N']`` is read.

    Returns:
        ``False`` (keep) when ``N`` is below 32, ``True`` (drop) otherwise.
    """
    ensemble_size = config['Q_params']['config']['N']
    return not ensemble_size < 32


def bbac_prior_loss_weight_ablation_trial_filter(config):
    """Tell whether a regularization-weight sweep trial should be dropped.

    Args:
        config: Trial config; the value at
            ``['Q_params']['config']['kernel_regularizer']['config']['l']``
            is read.

    Returns:
        ``True`` (drop) when that value equals 3e-8 or is not ``<= 3e-4``;
        ``False`` (keep) otherwise.
    """
    regularizer_config = config['Q_params']['config']['kernel_regularizer']
    weight = regularizer_config['config']['l']

    # The single value 3e-8 is left out even though it is within range.
    if weight == 3e-8:
        return True
    return not weight <= 3e-4


def _postprocess_legend_labels(labels):
    """Turn numeric hue values into compact strings for the legend.

    The first rule that applies wins:

    * every label converts to ``int`` without changing its value: the
      integers are returned as decimal strings;
    * every label converts to ``float`` and max - min is below 0.1: the
      values are formatted with ``'.0e'`` (zero is written as ``'0'``);
    * otherwise ``labels`` is returned unchanged.
    """
    try:
        int_labels = [int(label) for label in labels]
        all_ints = all(
            int_label == label
            for int_label, label in zip(int_labels, labels))
    except (TypeError, ValueError, OverflowError):
        # Non-numeric (or non-finite) labels are simply not integers.
        all_ints = False

    if all_ints:
        return [f'{int_label:d}' for int_label in int_labels]

    try:
        float_labels = [float(label) for label in labels]
        use_scientific_notation = (
            abs(np.min(float_labels) - np.max(float_labels)) < 0.1)
    except (TypeError, ValueError):
        use_scientific_notation = False

    if use_scientific_notation:
        return [
            f'{float_label:.0e}' if float_label != 0 else '0'
            for float_label in float_labels
        ]

    return labels


def visualize(dataframe: pd.DataFrame,
              hue_key: str,
              x_axis_unit: str = 'millions') -> None:
    """Plot return against samples, one line per hue value, and save a PDF.

    The figure is written to ``/tmp/cartpole-{hue_key}-ablation.pdf`` and
    closed afterwards.

    Args:
        dataframe: Table with the columns ``'sampler/total-samples'``,
            ``'training/episode-reward-mean'`` and ``hue_key``. It is left
            untouched; rescaling and label formatting happen on a copy.
        hue_key: Column that distinguishes the lines. It is also rendered as
            the legend title inside ``$...$`` and embedded in the file name.
        x_axis_unit: ``'thousands'`` or ``'millions'``; the unit the sample
            counts are divided by for the x axis.

    Raises:
        NotImplementedError: If ``x_axis_unit`` is not a supported unit.
    """
    x_key = 'sampler/total-samples'
    y_key = 'training/episode-reward-mean'

    unit_labels = {'thousands': '1e3', 'millions': '1e6'}
    unit_values = {'thousands': 1e3, 'millions': 1e6}
    if x_axis_unit not in unit_values:
        raise NotImplementedError(
            f"TODO: x_axis_unit={x_axis_unit}")

    label_map = {
        x_key: f'samples [{unit_labels[x_axis_unit]}]',
        y_key: 'return',
    }

    # Work on a copy: rescaling x and stringifying the hue column in the
    # caller's dataframe would make a second call divide the samples again.
    dataframe = dataframe.copy()
    dataframe[x_key] = dataframe[x_key] / unit_values[x_axis_unit]

    # Sort while the values are still numeric so the legend follows numeric
    # order, then format the order and the column with the same rules.
    hue_order = _postprocess_legend_labels(
        sorted(dataframe[hue_key].unique()))
    dataframe[hue_key] = _postprocess_legend_labels(dataframe[hue_key])

    figure_scale = 0.4
    figsize = figure_scale * np.array([7.2, 5.3])
    figure, axis = plt.subplots(1, 1, figsize=figsize)

    sns.lineplot(
        data=dataframe,
        x=x_key,
        y=y_key,
        hue=hue_key,
        hue_order=hue_order,
        ax=axis)

    handles, labels = axis.get_legend_handles_labels()
    # Shorten exponents such as '3e-05' to '3e-5'.
    labels = [label.replace('e-0', 'e-') for label in labels]
    axis.legend(
        handles=handles,
        labels=labels,
        handlelength=1.0,
        framealpha=0.6,
        loc='upper left',
        title=f'${hue_key}$')
    axis.set_xlabel(label_map[x_key])
    axis.set_ylabel(label_map[y_key])

    axis.set_ylim(-50.0, None)
    axis.set_xlim(0, 3e6 / unit_values[x_axis_unit])

    plt.yticks(rotation='vertical')

    plt.tight_layout()
    plt.savefig(
        f'/tmp/cartpole-{hue_key}-ablation.pdf',
        bbox_inches='tight',
        pad_inches=0.05)
    # pyplot keeps every figure alive until it is closed explicitly.
    plt.close(figure)


# def filter_data(data: Dict[str, tune.Analysis]):
#     pass


def get_parser():
    """Build the command-line parser for this script.

    Returns:
        An ``argparse.ArgumentParser`` with one required positional argument,
        ``visualization_type``, restricted to the three ablation names.
    """
    visualization_types = (
        'bbac_ensemble_size_ablation',
        'bbac_prior_loss_weight_ablation',
        'bbac_prior_scale_ablation',
    )

    argument_parser = argparse.ArgumentParser()
    argument_parser.add_argument(
        'visualization_type', type=str, choices=visualization_types)
    return argument_parser


def main() -> None:
    """Parse the command line, load the chosen ablation sweep and plot it.

    The parsed arguments are pretty-printed first. The selected sweep is then
    loaded through ``load_analysis`` and ``load_dataframe``, plotted with
    ``visualize``, and finally the loaded dataframe is printed.

    Raises:
        ValueError: If the visualization type has no matching ablation entry.
    """
    parser = get_parser()
    args = parser.parse_args()
    pprint(vars(args))

    visualization_type = args.visualization_type

    def q_config(config):
        return config['Q_params']['config']

    # One entry per ablation:
    #   (results directory, legend symbol, hue getter, trial filter).
    # The LaTeX symbols are raw strings: '\l' and '\s' are not valid escape
    # sequences, so the non-raw spelling triggers a warning on import.
    ablations = {
        'bbac_ensemble_size_ablation': (
            BBAC_ENSEMBLE_SIZE_ABLATION_PATH,
            'L',
            lambda config: q_config(config)['N'],
            bbac_ensemble_size_ablation_trial_filter,
        ),
        'bbac_prior_loss_weight_ablation': (
            BBAC_PRIOR_LOSS_WEIGHT_ABLATION_PATH,
            r'\lambda',
            lambda config: (
                q_config(config)['kernel_regularizer']['config']['l']),
            bbac_prior_loss_weight_ablation_trial_filter,
        ),
        'bbac_prior_scale_ablation': (
            BBAC_PRIOR_SCALE_ABLATION_PATH,
            r'\sigma',
            lambda config: q_config(config)['prior_scale'],
            bbac_prior_scale_ablation_trial_filter,
        ),
    }

    try:
        analysis_path, hue_key, hue_getter, trial_filter = (
            ablations[visualization_type])
    except KeyError:
        raise ValueError(visualization_type) from None

    analysis = load_analysis(analysis_path)
    dataframe = load_dataframe(
        analysis,
        trial_filter=trial_filter,
        hue_key=hue_key,
        hue_getter=hue_getter)

    visualize(dataframe, hue_key=hue_key)

    print(dataframe)

if __name__ == '__main__':
    main()
