#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : get_best_hps.py
# Author : Anonymous1
# Email  : anonymous1@anon
#
# Distributed under terms of the MIT license.

"""
Usage: python3 scripts/get_best_hps.py scripts/bash/${CONFIG_NAME}.py -du ${DUMP_DIR}
"""
import os, sys

sys.path.append(os.getcwd())

import argparse
import collections
import numpy as np
import os.path as osp
import pickle

from sweeping import load_source
from sweeping import get_hps_choices, get_hps_dict, get_selected_names, get_num_choices
from sweeping import get_name_by_hps_dict, hps_kv_to_str
from sweeping import get_stats, get_stats_str, get_summary, is_better, print_list
from sweeping.utils import f2str

# Command-line interface. Arguments are parsed at import time and exposed as the
# module-level ``args`` used by ``get_results`` and ``main``.
parser = argparse.ArgumentParser(
    description="Find the best hyper-parameter combination of a finished sweep."
)

parser.add_argument(
    "config",
    type=str,
    help="The sweep config: a .py file defining CONFIG, or a .pkl file of sampled combinations",
)
parser.add_argument(
    "--dump-dir",
    "-du",
    type=str,
    default="dumps",
    help="The root dir holding one sub-dir per sweeping exp",
)
parser.add_argument(
    "--summary-fname",
    "-sf",
    type=str,
    default="progress.csv",
    help="The name of the summary file inside each exp dir",
)
# NOTE(review): this flag appears unused in this script; kept for CLI compatibility.
parser.add_argument(
    "--allow-multi-runs",
    "-al",
    action="store_true",
    help="allow multi runs in one summary file",
)
parser.add_argument(
    "--smaller-better",
    "-sb",
    action="store_true",
    help="smaller is better, useful for regression tasks",
)
parser.add_argument(
    "--acc",
    "-acc",
    action="store_true",
    help="read the 'acc' key from summaries instead of 'res' (backward compatible)",
)
parser.add_argument(
    "--control-variable",
    "-cv",
    action="store_true",
    help="interpret the sweep in control-variable mode (forwarded as control_variable to the sweeping helpers)",
)

args = parser.parse_args()


def get_results(ind, hps_dict, print_stat_dict=True):
    """Load one sweep experiment's summary and compute statistics of its test scores.

    Args:
        ind: Label printed at the start of the report line.
        hps_dict: Hyper-parameter assignment of the experiment; mapped to its
            directory name via ``get_name_by_hps_dict(..., prefix="sweep")``.
        print_stat_dict: If true, print a one-line summary (mean/std in percent,
            stats string, name, average epoch time) followed by the per-run
            epoch / train / val / test lists.

    Returns:
        ``(best_test_accs, stat_dict)`` where ``stat_dict`` is ``get_stats`` of
        the test scores.
    """
    exp_name = get_name_by_hps_dict(hps_dict, prefix="sweep")
    summary_path = osp.join(osp.join(args.dump_dir, exp_name), args.summary_fname)
    metric_key = "acc" if args.acc else "res"
    _runs, (epoch_ids, train_accs, val_accs, test_accs, avg_time) = get_summary(
        summary_path,
        smaller_better=args.smaller_better,
        key=metric_key,
        complete_runs=True,
    )
    stat_dict = get_stats(test_accs)
    if not print_stat_dict:
        return test_accs, stat_dict

    # Mean/std are shown as percentages with two decimals.
    mean_str = f2str(stat_dict["mean"] * 100, precision=2)
    std_str = f2str(stat_dict["std"] * 100, precision=2)
    print(
        f"{ind}: {mean_str} | {std_str}, {get_stats_str(stat_dict)}, name={exp_name}, "
        f"Avg Epoch time={avg_time:.4f} s"
    )
    for label, values in (
        ("Epoch", epoch_ids),
        ("Train", train_accs),
        ("Val", val_accs),
        ("Test", test_accs),
    ):
        print_list(label, values)
    return test_accs, stat_dict


def main():
    """Find and print the best hyper-parameter combination of a sweep.

    ``args.config`` is either a ``.py`` file exposing ``CONFIG`` (every
    combination is scanned) or a ``.pkl`` file holding a dict with ``config``
    and ``sampled_inds`` keys (only the sampled combination indices are
    scanned). Each combination's results are read through ``get_results``.
    Combinations whose results cannot be read are skipped, with a warning on
    stderr. The combination with the best mean test score is printed. For
    non-sampled sweeps, per-hyper-parameter statistics grouped by the selected
    value are printed as well.

    Raises:
        ValueError: if the config file is neither ``.py`` nor ``.pkl``, or if
            there are no combinations to evaluate.
        SystemExit: if none of the combinations could be loaded.
    """
    sampled_inds = None
    if args.config.endswith(".py"):
        config = load_source(args.config).CONFIG
        is_sampled = False
    elif args.config.endswith(".pkl"):
        # The sampled ones.
        # NOTE(review): pickle.load can execute arbitrary code; only pass
        # trusted .pkl files here.
        is_sampled = True
        with open(args.config, "rb") as f:
            read_dict = pickle.load(f)
        config = read_dict["config"]
        sampled_inds = read_dict["sampled_inds"]
    else:
        raise ValueError(f"Expected a .py or .pkl config file, got: {args.config}")

    cv = args.control_variable
    hps_choices = list(get_hps_choices(config, control_variable=cv))
    total = get_num_choices(hps_choices, control_variable=cv)
    all_inds = sampled_inds if is_sampled else list(range(total))
    if len(all_inds) == 0:
        raise ValueError("No hyper-parameter combinations to evaluate.")

    best_mean_stat = None
    best_hps = None
    # multi-hps in one combo count as one hp
    total_num_hps = sum(choice.length for choice in hps_choices)
    # One {selected value -> list of results} mapping per hyper-parameter.
    res_dicts = [collections.defaultdict(list) for _ in range(total_num_hps)]
    # (name, value) pairs of the last successfully loaded combination; used
    # afterwards only for the hyper-parameter names.
    selected_names = None
    num_skipped = 0

    for i, ind in enumerate(all_inds):
        hps_dict = get_hps_dict(ind, hps_choices, control_variable=cv)
        try:
            _, stat_dict = get_results(i, hps_dict)
        except Exception as e:
            # Best effort: a combination whose summary cannot be read
            # (presumably an unfinished or missing run) is skipped, but reported.
            num_skipped += 1
            print(f"{i}: skipped ({type(e).__name__}: {e})", file=sys.stderr)
            continue

        avg = stat_dict["mean"]
        if best_mean_stat is None or is_better(
            avg, best_mean_stat["mean"], smaller_better=args.smaller_better
        ):
            best_mean_stat = stat_dict
            best_hps = hps_dict
        if not is_sampled:
            selected_names = get_selected_names(ind, hps_choices, control_variable=cv)
            for res, kv_pairs in zip(res_dicts, selected_names):
                # Control-variable sweeps keep the whole stat dict; otherwise
                # only the mean is collected and aggregated later.
                res[kv_pairs[1]].append(stat_dict if cv else avg)

    if num_skipped:
        print(f"Skipped {num_skipped}/{len(all_inds)} combinations.", file=sys.stderr)
    if best_mean_stat is None:
        # Previously this crashed inside get_stats_str(None) / on an unbound name.
        sys.exit("No combination could be loaded; nothing to report.")

    print(f"best hps: {best_hps}")
    print(f"best mean stat: {get_stats_str(best_mean_stat)}")
    if not is_sampled:
        _print_per_value_stats(selected_names, res_dicts, total_num_hps, cv)


def _print_per_value_stats(selected_names, res_dicts, total_num_hps, cv):
    """Print, per hyper-parameter with more than one observed value, stats grouped by value."""
    names = [kv[0] for kv in selected_names]
    assert len(names) == total_num_hps
    for name, res in zip(names, res_dicts):
        if len(res) <= 1:
            continue
        print(name)
        for k, v in res.items():
            # Control variable exps only take 1 run for each choice
            stats = v[0] if cv else get_stats(v)
            print(f"selected: {k}, stat: {get_stats_str(stats)}")


if __name__ == "__main__":
    main()
