import argparse
import json
import os
import pdb
import re

import numpy as np
import pandas as pd


def get_expdirs(topdir):
    """Return the experiment directories found directly inside ``<topdir>/results``.

    Only immediate sub-directories are returned (plain files are skipped). Each
    one is returned as a path joined onto ``<topdir>/results``, in the order
    ``os.listdir`` yields them.
    """
    results_dir = os.path.join(topdir, "results")
    expdirs = []
    for entry in os.listdir(results_dir):
        candidate = os.path.join(results_dir, entry)
        if os.path.isdir(candidate):
            expdirs.append(candidate)
    return expdirs


def get_all_eval_result_directories(expdir, eval_dir_suffix="eval_results"):
    """Return the sorted paths of all directories under ``expdir`` named ``*<eval_dir_suffix>``.

    ``expdir`` must directly contain at least one ``.yaml`` file, otherwise an
    ``AssertionError`` ("This is not an expdir.") is raised. The search recurses
    through the whole tree via ``os.walk``.
    """
    # Kept as an assert (not a pre-computed flag) so the check is skipped
    # entirely under ``python -O``, exactly as before.
    assert any(name.endswith(".yaml") for name in os.listdir(expdir)), "This is not an expdir."
    return sorted(
        os.path.join(parent, child)
        for parent, subdirs, _files in os.walk(expdir)
        for child in subdirs
        if child.endswith(eval_dir_suffix)
    )


def step2metric_logs(expdir):
    """Map each eval's ``trainer/global_step`` value to that eval's metric-log dict.

    Every eval-results directory found by ``get_all_eval_result_directories``
    must contain a ``metric_logs.json`` file. Directories are visited in sorted
    order, so if two evals report the same step, the later directory wins.
    """
    logs_by_step = {}
    for results_dir in get_all_eval_result_directories(expdir=expdir):
        with open(os.path.join(results_dir, "metric_logs.json"), "r") as handle:
            logs = json.load(handle)
        logs_by_step[logs['trainer/global_step']] = logs
    return logs_by_step


def get_all_metric_logs_in_expdirs(expdirs):
    """Return ``{expdir: step2metric_logs(expdir)}`` for every expdir in ``expdirs``."""
    return {expdir: step2metric_logs(expdir=expdir) for expdir in expdirs}


def aggregate_metric_logs(expdir):
    """Aggregate the metric logs of every eval under ``expdir`` into one CSV.

    The per-eval metric dicts come from ``step2metric_logs`` and are keyed by
    their ``trainer/global_step`` value. They become one row per step, ordered
    by ascending step, with one column per metric name (columns sorted). If a
    metric is absent from some eval, that cell is NaN rather than a ``KeyError``.
    The table is written to ``<expdir>/aggregated_results/aggregated_metric_logs.csv``.

    Args:
        expdir: Experiment directory to aggregate.

    Returns:
        The aggregated ``pandas.DataFrame``, or ``None`` when ``expdir`` has no
        eval results. In that case no CSV is written.
    """
    step2metrics_dict = step2metric_logs(expdir)
    if not step2metrics_dict:
        # Nothing to aggregate. Skip instead of raising IndexError, which
        # would abort the loop over all expdirs in aggregate_results.
        print(f"No eval results found in {expdir}; skipping.")
        return None

    # Insertion order follows the sorted eval-directory *paths*, which need
    # not match numeric step order, so sort the rows by step explicitly.
    steps = sorted(step2metrics_dict)
    # Union of metric names across all steps, so metrics logged only by some
    # evals are kept (missing cells become NaN).
    all_metric_keys = sorted(set().union(*(step2metrics_dict[s] for s in steps)))
    rows = [step2metrics_dict[s] for s in steps]
    results_df = pd.DataFrame(rows, columns=all_metric_keys)

    save_dir = os.path.join(expdir, "aggregated_results")
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, "aggregated_metric_logs.csv")
    results_df.to_csv(save_path)
    return results_df


def get_aggregated_metrics_under_topdir(topdir):
    """Load each expdir's aggregated metrics CSV into a DataFrame.

    Returns ``{expdir: DataFrame}`` for every expdir found by ``get_expdirs``.
    The CSV is read from ``<expdir>/aggregated_results/aggregated_metric_logs.csv``,
    so a missing file raises whatever ``pandas.read_csv`` raises.
    """
    csv_relpath = os.path.join("aggregated_results", "aggregated_metric_logs.csv")
    return {
        expdir: pd.read_csv(os.path.join(expdir, csv_relpath))
        for expdir in get_expdirs(topdir=topdir)
    }


def get_expdir2global_step2per_example_metrics(topdir):
    """Map every expdir under ``topdir`` to ``{global_step: per_example_metrics}``.

    A file counts as a per-example metrics file when its full path contains
    ``"per_example_metrics"`` and ends with ``.json``. Its global step is parsed
    from the first ``__step_<digits>`` tag in that path, and the mapped value is
    the file's parsed JSON content. If two files carry the same step, the one
    visited later by ``os.walk`` wins. Expdirs without such files map to ``{}``.

    Raises:
        ValueError: if a per-example metrics path has no ``__step_<digits>`` tag.
    """
    # Take just the digits after "__step_". The previous parser split on the
    # literal "__2022" (presumably a timestamp year) and failed with a
    # ValueError from int() for any other suffix.
    step_pattern = re.compile(r"__step_(\d+)")

    expdir2gs2metrics = dict()
    for expdir in get_expdirs(topdir=topdir):
        gs2metrics = dict()
        for root, _dirs, files in os.walk(expdir):
            for file in files:
                filepath = os.path.join(root, file)
                if "per_example_metrics" not in filepath or not filepath.endswith(".json"):
                    continue
                match = step_pattern.search(filepath)
                if match is None:
                    raise ValueError(f"Cannot find a '__step_<N>' tag in {filepath!r}.")
                with open(filepath, "r") as f:
                    gs2metrics[int(match.group(1))] = json.load(f)
        expdir2gs2metrics[expdir] = gs2metrics

    return expdir2gs2metrics


def _average_per_example_metrics(metrics):
    """Return ``{"avg_<key>": mean}`` over every key seen in the per-example dicts.

    ``metrics`` is iterated as a sequence of dicts. Keys missing from some
    examples are averaged over the examples that do have them. An empty
    sequence yields ``{}``. Values are presumably numeric, since they go
    through ``np.array(...).mean()``.
    """
    values_by_key = dict()
    for example in metrics:
        for key, value in example.items():
            values_by_key.setdefault(f"avg_{key}", []).append(value)
    return {k: np.array(v).mean() for k, v in values_by_key.items()}


def get_expdir2per_example_metrics(topdir):
    """Collect every per-example metrics JSON under each expdir of ``topdir``.

    A file qualifies when its full path contains ``"per_example_metrics"`` and
    ends with ``.json``. For each qualifying file, a record is appended to
    ``result[expdir]``: a dict with ``step`` (parsed from the first
    ``__step_<digits>`` tag in the path), ``filepath``, ``topdir``, ``metrics``
    (the raw JSON, iterated as a list of per-example dicts) and ``avg_metrics``
    (``{"avg_<key>": mean}`` over the examples).

    Raises:
        ValueError: if a qualifying path has no ``__step_<digits>`` tag.
    """
    # Digits right after "__step_". The previous parser also required a literal
    # "__2022" suffix and failed for any other one.
    step_pattern = re.compile(r"__step_(\d+)")

    def is_per_example_metric_json(path):
        # Test the argument itself. The old lambda read the enclosing
        # ``filepath`` variable and only worked because it happened to be
        # called with that same variable.
        return "per_example_metrics" in path and path.endswith(".json")

    expdirs = get_expdirs(topdir)
    expdir2per_example_metrics = {k: list() for k in expdirs}
    for expdir in expdirs:
        for root, _dirs, files in os.walk(expdir):
            for file in files:
                filepath = os.path.join(root, file)
                if not is_per_example_metric_json(filepath):
                    continue

                match = step_pattern.search(filepath)
                if match is None:
                    raise ValueError(f"Cannot find a '__step_<N>' tag in {filepath!r}.")
                curr_step = int(match.group(1))

                with open(filepath, "r") as f:
                    metrics = json.load(f)

                expdir2per_example_metrics[expdir].append(dict(step=curr_step,
                                                               filepath=filepath,
                                                               topdir=topdir,
                                                               metrics=metrics,
                                                               avg_metrics=_average_per_example_metrics(metrics)))

    return expdir2per_example_metrics


def aggregate_results(topdir):
    """Run ``aggregate_metric_logs`` on every expdir under ``topdir``, printing progress."""
    for expdir in get_expdirs(topdir=topdir):
        print(f"Aggregating results within {expdir}.")
        aggregate_metric_logs(expdir)


if __name__ == "__main__":
    # Run from the repository root, e.g.:
    #   python -m src.mains.tasks.aggregate_results --topdir="runs/dev/tmp"
    parser = argparse.ArgumentParser()
    # Required: without it topdir would be None, and os.path.join(None, ...)
    # would fail later with an opaque TypeError instead of a usage message.
    parser.add_argument("--topdir", type=str, required=True,
                        help="The topdir that contains the expdirs. ")
    args = parser.parse_args()

    aggregate_results(topdir=args.topdir)
