from collections import defaultdict
import os
import argparse
import pprint
import numpy as np


def parse_log(filename):
    """Return the best validation loss recorded in an executor training log.

    Lines are scanned in order, and each recorded value goes into a "bucket":
    - A line starting with ``eval`` selects the validation bucket.
    - A line starting with ``train`` selects the training bucket.
    - A line containing ``Time spent`` flips between the two buckets.

    For each name in ``loss_names``, a value is appended to the current bucket
    when either:

    * the line starts with the name (value = last space-separated token), or
    * the line contains ``':' + name`` (value = 4th ``:``-separated field
      with its last 5 characters dropped).

    Values that appear before any bucket has been selected go to the training
    bucket.

    Args:
        filename: path of the log file to parse.

    Returns:
        A tuple ``(min_loss, min_epoch, num_epochs)``:
        - ``min_loss``: the smallest validation ``loss``.
        - ``min_epoch``: its 1-based position among the recorded validation
          ``loss`` values.
        - ``num_epochs``: how many validation ``loss`` values were recorded.
        When no validation ``loss`` was recorded, the sentinel ``(100, 0, 0)``
        is returned.
    """
    loss_names = [
        'loss',
        'cmd_type_loss',
        'move_loss',
        'attack_loss',
        'gather_loss',
        'build_unit_loss',
        'build_building_loss',
    ]

    train_loss = defaultdict(list)
    valid_loss = defaultdict(list)
    # Context manager so the file handle is closed deterministically.
    with open(filename, 'r') as f:
        lines = f.readlines()
    loss_dict = None

    for line in lines:
        line = line.strip()
        if line.startswith('eval'):
            loss_dict = valid_loss
        elif line.startswith('train'):
            loss_dict = train_loss

        for name in loss_names:
            if line.startswith(name):
                val = float(line.split(' ')[-1])
            elif ':' + name in line:
                # Format-specific: the value sits in the 4th ':'-separated
                # field, followed by 5 trailing characters that are dropped.
                val = float(line.split(':')[3][:-5])
            else:
                continue
            # Both match styles share this default, so a loss line seen
            # before any eval/train header goes to the training bucket
            # instead of failing on a None bucket.
            if loss_dict is None:
                loss_dict = train_loss
            loss_dict[name].append(val)

        # A 'Time spent' line flips the active bucket.
        if 'Time spent' in line:
            if loss_dict is train_loss:
                loss_dict = valid_loss
            elif loss_dict is valid_loss:
                loss_dict = train_loss

    # Look up 'loss' specifically: the validation bucket may hold other
    # losses without any 'loss' entry, and min([]) would raise.
    valid = valid_loss.get('loss')
    if not valid:
        # Same arity as the success path so callers can unpack uniformly.
        return 100, 0, 0

    min_loss = min(valid)
    min_epoch = int(np.argmin(valid)) + 1
    return min_loss, min_epoch, len(valid)


def average_across_seed(logs):
    """Group runs by experiment name and print mean/std of their first metric.

    Each key of ``logs`` is split at its last underscore into
    ``<group>_<seed>``. Keys with no underscore all fall into the
    ``'default'`` group. A notice is printed for every run whose seed part
    does not start with ``'s'``, but the run is still included.

    Args:
        logs: mapping from run name to a result sequence; only element 0 is
            used.

    Returns:
        A ``defaultdict`` mapping each group name to ``[mean, std]`` of the
        element-0 values in that group. The groups are also printed, ordered
        by descending mean.
    """
    grouped = defaultdict(list)
    for run_name, result in logs.items():
        parts = run_name.rsplit('_', 1)
        if len(parts) == 1:
            group, seed = 'default', parts[0]
        else:
            group, seed = parts
        if not seed.startswith('s'):
            print('no multiple seeds, omit averaging')
        grouped[group].append(result[0])

    for group in list(grouped):
        samples = grouped[group]
        grouped[group] = [np.mean(samples), np.std(samples)]

    ranking = sorted(grouped.items(), key=lambda item: -item[1][0])
    for group, (mean, std) in ranking:
        print('%s: %.4f, %.4f' % (group, mean, std))

    return grouped


def parse_from_root(root):
    """Summarize every ``<root>/<exp>/train.log`` parsed with ``parse_log``.

    Only sub-directories of ``root`` that contain a ``train.log`` are
    considered. The function prints three things, in order:
    - every per-experiment result, sorted by experiment name;
    - the seed-averaged summary from ``average_across_seed``;
    - the experiment whose element-0 value (the minimum validation loss) is
      smallest and below 1000.
    """
    root = os.path.abspath(root)
    logs = {}
    for exp in os.listdir(root):
        folder = os.path.join(root, exp)
        log_path = os.path.join(folder, 'train.log')
        if os.path.isdir(folder) and os.path.exists(log_path):
            logs[exp] = parse_log(log_path)

    # Keys are unique, so ordering the item tuples orders by name alone.
    pprint.pprint(sorted(logs.items()), width=150)

    average_across_seed(logs)

    # First experiment (in iteration order) with the lowest value under 1000.
    eligible = [(key, result[0]) for key, result in logs.items()
                if result[0] < 1000]
    if eligible:
        min_key, min_val = min(eligible, key=lambda pair: pair[1])
    else:
        min_key, min_val = None, 1000

    print('best model: %s, perf: %f' % (min_key, min_val))


def _read_next_nll(lines, i):
    """Return the float in the last token of ``lines[i + 1]``, or None if that line is missing or blank."""
    if i + 1 >= len(lines):
        return None
    tokens = lines[i + 1].split()
    if not tokens:
        return None
    return float(tokens[-1])


def parse_generator_log(filename):
    """Find the epoch with the lowest ``val`` NLL in a generator training log.

    Each line starting with ``val epoch`` counts as one epoch. The NLL is read
    from the last token of the line after the header. A line starting with
    ``eval epoch`` contributes its following-line NLL only when the current
    epoch equals the best-so-far validation epoch.

    Args:
        filename: path of the log file.

    Returns:
        A tuple ``(min_nll, min_eval_nll, min_epoch, num_epochs)``. Missing
        values default to ``100`` (NLLs) and ``-1`` (best epoch). If the file
        does not exist, ``(100, 100, -1, -1)`` is returned.
    """
    if not os.path.exists(filename):
        return 100, 100, -1, -1
    # Context manager so the file handle is closed deterministically.
    with open(filename, 'r') as f:
        lines = f.readlines()
    min_nll = 100
    min_eval_nll = 100
    min_epoch = -1
    epoch = 0
    for i, line in enumerate(lines):
        line = line.strip()
        if line.startswith('val epoch'):
            epoch += 1
            # A header on the final line (truncated log) has no value.
            nll = _read_next_nll(lines, i)
            if nll is not None and nll < min_nll:
                min_nll = nll
                min_epoch = epoch
        if line.startswith('eval epoch') and epoch == min_epoch:
            nll = _read_next_nll(lines, i)
            if nll is not None:
                min_eval_nll = nll

    return min_nll, min_eval_nll, min_epoch, epoch


def average_generator_across_seed(logs):
    """Group generator runs by name and print mean/std of both NLL metrics.

    Keys are split at their last underscore into ``<group>_<seed>``. Keys
    without an underscore fall into ``'default'``. A notice is printed for
    each seed part not starting with ``'s'``, but the run is still counted.

    Args:
        logs: mapping from run name to a result sequence; element 0 is the
            validation metric and element 1 the eval metric.

    Returns:
        A ``defaultdict`` mapping group name to ``[mean, std]`` of element 0.
        The printed summary also includes the element-1 mean/std and is
        ordered by descending element-0 mean.
    """
    train_by_group = defaultdict(list)
    eval_by_group = defaultdict(list)
    for run_name, result in logs.items():
        parts = run_name.rsplit('_', 1)
        if len(parts) == 1:
            group, seed = 'default', parts[0]
        else:
            group, seed = parts
        if not seed.startswith('s'):
            print('no multiple seeds, omit averaging')
        train_by_group[group].append(result[0])
        eval_by_group[group].append(result[1])

    eval_stats = {}
    for group in list(train_by_group):
        samples = train_by_group[group]
        train_by_group[group] = [np.mean(samples), np.std(samples)]
        eval_samples = eval_by_group[group]
        eval_stats[group] = [np.mean(eval_samples), np.std(eval_samples)]

    ranking = sorted(train_by_group.items(), key=lambda item: -item[1][0])
    for group, (mean, std) in ranking:
        eval_mean, eval_std = eval_stats[group]
        print('%s: %.4f, %.4f, eval: %.4f, %.4f'
              % (group, mean, std, eval_mean, eval_std))

    return train_by_group


def parse_generator(root):
    """Summarize ``<root>/<exp>/train.log`` for every sub-directory of ``root``.

    Every sub-directory is parsed with ``parse_generator_log``, which yields a
    sentinel result when ``train.log`` is missing. The function prints three
    things, in order:
    - all results, sorted by experiment name;
    - the seed-averaged summary;
    - the experiment with the lowest validation NLL below 1000.
    """
    root = os.path.abspath(root)
    logs = {}
    for exp in os.listdir(root):
        folder = os.path.join(root, exp)
        if not os.path.isdir(folder):
            continue
        logs[exp] = parse_generator_log(os.path.join(folder, 'train.log'))

    # Keys are unique, so ordering the item tuples orders by name alone.
    pprint.pprint(sorted(logs.items()), width=150)
    print('=====avg, stderr======')
    average_generator_across_seed(logs)

    # First experiment (in iteration order) with the lowest NLL under 1000.
    eligible = [(key, nll) for key, (nll, _, _, _) in logs.items()
                if nll < 1000]
    if eligible:
        min_key, min_val = min(eligible, key=lambda pair: pair[1])
    else:
        min_key, min_val = None, 1000

    print('best model: %s, perf: %f' % (min_key, min_val))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Both arguments are mandatory: without --root, os.path.abspath(None)
    # fails, and an unknown --type used to make the script silently do nothing.
    parser.add_argument('--root', type=str, required=True,
                        help='directory whose sub-directories hold train.log')
    parser.add_argument('--type', type=str, required=True,
                        choices=['executor', 'coach'],
                        help='which log format to summarize')
    args = parser.parse_args()

    if args.type == 'executor':
        parse_from_root(args.root)
    elif args.type == 'coach':
        parse_generator(args.root)
