"""
Script that runs all the quantitative tests and populates a table with them
"""
import contextlib
import logging
import os
import sys
from collections import defaultdict
from pathlib import Path

import dill
import numpy as np
import torch
from tqdm import tqdm

from run import train


@contextlib.contextmanager
def nostdout():
    """Context manager that silences ``sys.stdout`` for the enclosed block.

    While active, ``sys.stdout`` points at a handle opened on ``os.devnull``.
    On exit the previous ``sys.stdout`` object is put back and the devnull
    handle is closed, including when the enclosed block raises.
    """
    save_stdout = sys.stdout
    with open(os.devnull, 'w') as devnull:
        sys.stdout = devnull
        try:
            yield
        finally:
            # Restore even on error so a failing run does not leave the
            # whole process permanently muted.
            sys.stdout = save_stdout


def _classify_config(stem):
    """Derive ``(task, name_metric, exp)`` from a config file stem.

    The stem is matched by substring: ``uniform`` / ``interpolation`` /
    ``past`` select the task (default ``standard``), and ``noreencode`` /
    ``vrnn`` / ``uniftrain`` / ``trajectron`` select the experiment (default
    ``trajrep``). The vrnn and trajectron configs always read the
    ``future_prediction`` metric, whatever the task.
    """
    is_baseline = 'vrnn' in stem or 'trajectron' in stem

    task = 'standard'
    name_metric = 'future_prediction'
    if 'uniform' in stem:
        task = 'uniform'
    elif 'interpolation' in stem:
        task = 'interpolation'
        if not is_baseline:
            name_metric = 'interpolation_prediction'
    elif 'past' in stem:
        task = 'past'
        if not is_baseline:
            name_metric = 'past_prediction'

    exp = 'trajrep'
    if 'noreencode' in stem:
        exp = 'trajrep_noreencode'
    elif 'vrnn' in stem:
        exp = 'vrnn'
    elif 'uniftrain' in stem:
        exp = 'trajectron_uniftrain'
    elif 'trajectron' in stem:
        exp = 'trajectron'

    return task, name_metric, exp


def _iter_config_files(dataset_dir):
    """Yield ``(duration_dir, config_file)`` for files in ``dataset_dir/<duration>/``."""
    for duration in dataset_dir.iterdir():
        if not duration.is_dir():
            continue
        for file_name in duration.iterdir():
            if file_name.is_file():
                yield duration, file_name


def _scaled_stats(values):
    """Return ``(mean, std)`` of ``values`` scaled by 100, or ``None`` if there is nothing to show."""
    val = np.array(values) * 100
    if val.size == 0:
        # Avoids NumPy's "Mean of empty slice" RuntimeWarning; the caller
        # prints the same '-' placeholder it would print for a NaN mean.
        return None
    mean = val.mean()
    if np.isnan(mean):
        return None
    return mean, val.std()


def main():
    """Run every quantitative config for several seeds and print LaTeX table rows.

    Configs are discovered under
    ``configs/experiments/quantitative_results/<dataset>/<duration>/``. For
    each one, ``train()`` is invoked with ``sys.argv`` overrides, the metric is
    read back from ``/tmp/tmp_results.pth`` and appended to
    ``dict_results[duration][dataset][exp][task]``. The results dictionary is
    loaded from ``path_save`` if that file exists and is re-written after every
    dataset, so an interrupted run can be resumed. Finally two sets of tables
    are printed to stdout: one with means only, one with mean and std.
    """
    logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
    os.environ["WANDB_SILENT"] = "True"

    dir_configs = Path('configs/experiments/quantitative_results')
    path_save = '/path/to/save/results.pth'
    num_seeds = 10

    if os.path.isfile(path_save):
        # NOTE(security): dill, like pickle, can execute arbitrary code on
        # load. Only point path_save at files written by this script.
        with open(path_save, 'rb') as f:
            dict_results = dill.load(f)
    else:
        dict_results = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))))

    # Progress-bar total: every file under dir_configs, once per seed.
    num_experiments = sum(
        len(files)
        for root_dir, _, files in os.walk(dir_configs)
        if 'qualitative_results' not in root_dir
    )
    total_elements = num_experiments * num_seeds

    with tqdm(total=total_elements) as pbar:
        for seed in range(num_seeds):
            for dataset in dir_configs.iterdir():
                if dataset.is_dir() and 'qualitative_results' not in str(dataset):
                    for duration, file_name in _iter_config_files(dataset):
                        task, name_metric, exp = _classify_config(file_name.stem)
                        results = dict_results[duration.stem][dataset.stem][exp][task]

                        # Results are appended in seed order, so entry `seed`
                        # is missing exactly when fewer than seed + 1 values
                        # are stored. Checking against num_seeds instead would
                        # re-run (and duplicate) the low seeds when resuming
                        # from a partially filled results file.
                        if len(results) <= seed:
                            sys.argv = ['',
                                        f'+experiments=quantitative_results/{dataset.name}/{duration.name}/'
                                        f'{file_name.stem}', '+model.save_to_tmp=True', f'seed={seed}',
                                        '++wandb.save=False', '++verbose=False', '++wandb.offline=True']
                            with nostdout():
                                train()
                            val = torch.load('/tmp/tmp_results.pth')

                            # Record the metric; without this nothing is ever
                            # stored and the tables below stay empty.
                            # Assumes val[name_metric]['all_steps'] is a
                            # one-element tensor -- TODO confirm against train().
                            results.append(val[name_metric]['all_steps'].item())

                        pbar.update(1)

                # Save more frequently
                with open(path_save, 'wb') as f:
                    dill.dump(dict_results, f)

    durations = ['long', 'short']
    experiments = ['vrnn', 'trajectron_uniftrain', 'trajectron', 'trajrep_noreencode', 'trajrep']
    datasets = ['finegym', 'diving48', 'fisv']
    # The 'uniform' task is not included in the tables.
    tasks = ['standard', 'past', 'interpolation']

    experiments_show = {
        'vrnn': '\\textbf{VRNN} \\cite{vrnn}',
        'trajectron_uniftrain': '\\textbf{Trajectron++ uni.} \\cite{salzmann2020trajectron++} ',
        'trajectron': '\\textbf{Trajectron++} \\cite{salzmann2020trajectron++} ',
        'trajrep_noreencode': '\\textbf{TrajRep - re-enc.} ',
        'trajrep': '\\textbf{TrajRep (ours)}',
    }

    # Do not show std
    for duration in durations:
        print(f'\n\nTable {duration}')
        for experiment in experiments:
            print('')
            print(experiments_show[experiment], end='\t')
            for dataset in datasets:
                print('&', end='\t')
                for task in tasks:
                    stats = _scaled_stats(dict_results[duration][dataset][experiment][task])
                    if stats is None:
                        print('& -', end='')
                    else:
                        mean, _ = stats
                        # '\ ' pads single-digit means in the LaTeX output.
                        extra_mean = '\\ ' if mean < 10 else ''
                        print(f'& \\s{{ {extra_mean}{mean:.02f} }}', end='')
            print('\\\\', end='')

    # Show std
    for duration in durations:
        for dataset in datasets:
            print(f'\n\nTable {duration} {dataset}')
            for experiment in experiments:
                print('')
                print(experiments_show[experiment], end='\t')
                print('&', end='\t')
                for task in tasks:
                    stats = _scaled_stats(dict_results[duration][dataset][experiment][task])
                    if stats is None:
                        print('& -', end='')
                    else:
                        mean, std = stats
                        extra_mean = '\\ ' if mean < 10 else ''
                        print(f'& {extra_mean}{mean:.02f} ({std:0.02f})', end='')
                print('\\\\', end='')


# Run the experiment sweep only when this file is executed as a script.
if __name__ == '__main__':
    main()
