import numpy as np
from probtransformer.utils.configuration import Config


def load_experiment(expt_dict, eval_keys, train_keys):
    """Load config, evaluation curves and training curves for one experiment.

    Reads ``config.yml``, ``eval_summary.npy``, ``log_file.txt`` and
    ``train_summary.npy`` from ``expt_dict['dir']`` and stores one entry per
    requested key directly in ``expt_dict``.

    Args:
        expt_dict: Dict with a ``'dir'`` entry supporting ``/`` path joins
            (presumably a ``pathlib.Path`` -- TODO confirm). It is mutated in
            place and also returned.
        eval_keys: Keys to copy from the evaluation summary; keys missing from
            the summary are skipped silently.
        train_keys: Keys to copy from the training summary. Keys containing
            ``"_test"`` are instead taken from the metrics parsed out of
            ``log_file.txt``.

    Returns:
        ``expt_dict``. Each loaded key maps to ``{"step": ..., "value": ...}``;
        training-summary keys additionally carry ``"raw"`` (the column before
        subsampling).

    Raises:
        Whatever ``Config`` / ``np.load`` raise for ``config.yml`` and
        ``eval_summary.npy``, and ``KeyError`` if the evaluation summary has no
        ``"step_count"`` entry. Failures while reading ``log_file.txt`` or
        ``train_summary.npy`` are only printed (best effort).
    """
    cfg = Config(config_file=expt_dict['dir'] / 'config.yml')

    # NOTE: allow_pickle=True can execute arbitrary code embedded in the file;
    # only point this at trusted experiment directories.
    eval_summary = np.load(expt_dict['dir'] / 'eval_summary.npy', allow_pickle=True).tolist()
    test_dict = _load_test_dict(expt_dict['dir'] / "log_file.txt")

    print("eval_summary keys:", eval_summary.keys())
    _load_eval_curves(expt_dict, eval_summary, eval_keys, cfg)
    _load_train_curves(expt_dict, train_keys, test_dict)

    return expt_dict


# Keys that are never stretched by ``eval_freq`` in the eval summary and whose
# first column (rather than the last) is used from the train summary.
_SPECIAL_KEYS = ('kappa', 'learning_rate')


def _load_test_dict(log_path):
    """Parse ``*_test`` metrics from a plain-text log; return ``{}`` on any failure.

    Every line containing ``"_test"`` is split on single spaces: the second
    token (up to the first ``:``) is the metric name and the third token, with
    ``,`` replaced by ``.``, is parsed as a float. Later lines overwrite earlier
    ones with the same name.
    """
    try:
        # `with` guarantees the handle is closed even if parsing fails.
        with open(log_path) as log_file:
            test_dict = {}
            for line in log_file:
                if "_test" not in line:
                    continue
                tokens = line.split(' ')
                test_dict[tokens[1].split(':')[0]] = float(tokens[2].replace(',', '.'))
    except (OSError, IndexError, ValueError):
        # Missing/unreadable log or a line with an unexpected layout.
        print("load test dict failed")
        return {}
    print("load test dict")
    return test_dict


def _load_eval_curves(expt_dict, eval_summary, eval_keys, cfg):
    """Copy the requested evaluation metrics (first column) into ``expt_dict``."""
    # x-axis: running total of the first "step_count" column.
    eval_steps = np.cumsum(eval_summary["step_count"][:, 0])

    for key in eval_keys:
        if key not in eval_summary:
            continue

        value = eval_summary[key][:, 0]
        if cfg.train.eval_freq > 1 and key not in _SPECIAL_KEYS:
            # Repeat each value eval_freq times and truncate to at most
            # len(eval_steps) entries so it can be plotted against eval_steps.
            value = value.repeat(cfg.train.eval_freq)[:eval_steps.shape[0]]

        expt_dict[key] = {"step": eval_steps, "value": value}


def _load_train_curves(expt_dict, train_keys, test_dict):
    """Copy training metrics (and parsed ``*_test`` scalars) into ``expt_dict``; best effort."""
    try:
        # NOTE: allow_pickle=True -- only load trusted experiment directories.
        train_summary = np.load(expt_dict['dir'] / "train_summary.npy", allow_pickle=True).tolist()
        print("train_summary keys:", train_summary.keys())
        # Subsample every 10th entry, starting at index 1.
        train_steps = train_summary["step"][:, 0][1::10]
    except Exception:
        print("issue with loading :", expt_dict['dir'])
        return

    for key in train_keys:
        try:
            if "_test" in key:
                # A single test score, stored as a two-point constant series.
                test_value = test_dict[key]
                expt_dict[key] = {"step": [0, 1], "value": [test_value, test_value]}
                continue

            series = train_summary[key]
            # NOTE(review): the previous implementation looped over every
            # column and overwrote the entry each time, so only the LAST column
            # was kept; that result is preserved here. Confirm whether
            # per-column entries were intended instead.
            column = -1 if series.shape[1] > 0 and key not in _SPECIAL_KEYS else 0
            expt_dict[key] = {"step": train_steps, "value": series[:, column][1::10],
                              "raw": series[:, column]}
        except Exception:
            # Missing key, unexpected array rank, or missing parsed test metric.
            print("issue with loading :", key, expt_dict['dir'])
