import pprint
import os
import argparse
import json
from collections import defaultdict

import rlpytorch.behavior_clone.global_consts as gc
from rlpytorch.behavior_clone.gen_dataset import get_all_files
import pprint


def process_json(json_file, frame_skip):
    """Bucket the commands and instructions stored in a generated json dump.

    Args:
        json_file: path to a json file containing a list of entries. ``None``
            entries are skipped. Every other entry must provide ``'tick'``
            (a string whose integer part starts at index 4 -- presumably a
            ``'tick'`` prefix, confirm against the generator) and
            ``'instruction'``; ``'targets'`` is optional.
        frame_skip: number of ticks grouped into one bucket.

    Returns:
        A tuple ``(tick2cmds, tick2ins, last_tick_bucket)`` where
        ``tick2cmds`` maps a bucket to the sorted list of lower-cased command
        type names seen in it (every name containing ``'build'`` is collapsed
        to ``'build'``), ``tick2ins`` maps a bucket to the instruction of the
        last entry that fell into it, and ``last_tick_bucket`` is the bucket
        of the last non-``None`` entry (0 if there is none).
    """
    # Use a context manager so the file handle is closed deterministically
    # instead of being left to the garbage collector.
    with open(json_file, 'r') as f:
        js = json.load(f)

    tick2cmds = defaultdict(list)
    # Kept as a defaultdict: callers index it with keys that may be missing.
    tick2ins = defaultdict(list)
    last_tick_bucket = 0
    for entry in js:
        if entry is None:
            continue

        tick = int(entry['tick'][4:])
        tick_bucket = tick // frame_skip
        last_tick_bucket = tick_bucket

        # Later entries in the same bucket overwrite earlier ones.
        tick2ins[tick_bucket] = entry['instruction']

        if 'targets' not in entry:
            continue

        for cmd in entry['targets']:
            cmd_type = gc.CmdTypes(cmd['cmd_type']).name.lower()
            if 'build' in cmd_type:
                cmd_type = 'build'

            tick2cmds[tick_bucket].append(cmd_type)

    # Sort so that per-bucket comparisons do not depend on command order.
    for key, val in tick2cmds.items():
        tick2cmds[key] = sorted(val)

    return tick2cmds, tick2ins, last_tick_bucket


def process_replay(replay_file, player_id, frame_skip):
    """Bucket the commands and instructions recorded in a replay file.

    Each line is split on whitespace; lines with fewer than four fields are
    ignored. Field 0 is the command type, field 1 the tick and field 3 the
    unit id, from which the issuing player is derived as
    ``unit_id // 16000000``.

    Args:
        replay_file: path to the text replay file.
        player_id: only commands issued by this player are collected.
        frame_skip: number of ticks grouped into one bucket.

    Returns:
        A tuple ``(tick2cmds, tick2ins, end_game, last_tick_bucket)`` where
        ``tick2cmds`` maps a bucket to the sorted list of
        ``'gather'``/``'attack'``/``'build'``/``'move'`` labels issued by
        ``player_id``, ``tick2ins`` maps a bucket to the instruction in
        effect at the last parsed line of that bucket (from any player),
        ``end_game`` tells whether the final line contains the token
        ``'WON'``, and ``last_tick_bucket`` is the bucket of the last parsed
        line (0 if none was parsed).
    """
    # Read all lines up front (the final line needs special treatment) and
    # close the handle right away instead of leaking it.
    with open(replay_file, 'r') as f:
        lines = f.readlines()

    # Substring of the command type -> label. The checks are deliberately
    # not exclusive: a type matching several substrings gets several labels.
    cmd_labels = (
        ('Gather', 'gather'),
        ('Attack', 'attack'),
        ('Build', 'build'),
        ('Move', 'move'),
    )

    tick2cmds = defaultdict(list)
    # Kept as a defaultdict: callers index it with keys that may be missing.
    tick2ins = defaultdict(list)
    end_game = False
    last_tick_bucket = 0
    current_instruction = None
    next_instruction = None
    last_index = len(lines) - 1
    for i, l in enumerate(lines):
        attrs = l.split()
        # The game counts as finished only if the very last line reports it.
        if i == last_index and 'WON' in attrs:
            end_game = True

        if len(attrs) < 4:
            continue

        cmd_type = attrs[0]
        tick = int(attrs[1])
        tick_bucket = tick // frame_skip
        # A pending instruction takes effect when a later bucket starts.
        if tick_bucket > last_tick_bucket and next_instruction is not None:
            current_instruction = next_instruction
            next_instruction = None

        last_tick_bucket = tick_bucket

        # Only instruction lines whose last field is '1' are used
        # (NOTE(review): the meaning of that flag is not visible here).
        if 'CmdIssueInstruction' in cmd_type and attrs[-1] == '1':
            # The instruction text is the last double-quoted span.
            instruction = l.split('"')[-2]
            # The first instruction, or one issued within the first ticks,
            # applies immediately; later ones wait for the next bucket.
            if current_instruction is None or tick <= 3:
                current_instruction = instruction
            else:
                next_instruction = instruction

        tick2ins[tick_bucket] = current_instruction

        unit_id = int(attrs[3])
        cmd_player_id = unit_id // 16000000
        if cmd_player_id != player_id:
            continue

        for keyword, label in cmd_labels:
            if keyword in cmd_type:
                tick2cmds[tick_bucket].append(label)

    # Sort so that per-bucket comparisons do not depend on command order.
    for key, val in tick2cmds.items():
        tick2cmds[key] = sorted(val)

    return tick2cmds, tick2ins, end_game, last_tick_bucket


def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is 0."""
    if denominator == 0:
        return float('nan')
    return numerator / denominator


def process_folder(folder, player, frame_skip, show_cmd_breakdown=False):
    """Compare every replay in a folder against its generated json dump.

    For each file returned by ``get_all_files(folder, '.rep')`` that has a
    sibling ``<replay>.p<player>.json``, the replay and the json are parsed
    and compared on three levels: per-bucket instructions, last tick bucket
    and per-bucket command lists. Mismatching games are reported as they are
    found and aggregate statistics are printed at the end.

    Args:
        folder: folder handed to ``get_all_files`` to look up replays.
        player: player id; selects both the json file and the replay
            commands that are compared.
        frame_skip: number of ticks grouped into one bucket.
        show_cmd_breakdown: when true, additionally print, for every total
            command count, the number of games with that count, how many of
            them finished, and the finished fraction.

    Returns:
        None. All results are printed. Rates whose denominator is zero (no
        comparable game, or no bucket with commands) are printed as ``nan``
        instead of raising ``ZeroDivisionError``.
    """
    replay_files = get_all_files(folder, '.rep')
    total_num_frames = 0
    match_frames = 0

    num_cmd2num_game = defaultdict(int)
    num_cmd2finished_game = defaultdict(int)

    length_match = 0
    instruction_match = 0  # per game
    total_game = 0
    end_game = 0

    for replay_file in replay_files:
        json_file = '%s.p%d.json' % (replay_file, player)
        if not os.path.exists(json_file):
            continue

        total_game += 1

        replay_cmds, replay_ins, end, replay_last_tick = process_replay(
            replay_file, player, frame_skip)
        # Named json_cmds rather than json to avoid shadowing the module.
        json_cmds, json_ins, json_last_tick = process_json(
            json_file, frame_skip)

        if replay_ins == json_ins:
            instruction_match += 1
        else:
            print('instruction_mismatch, replay: %s' % replay_file)

        if replay_last_tick != json_last_tick:
            print('length mismatch: replay: %d, json %d, file: %s' %
                  (replay_last_tick, json_last_tick, json_file))
        else:
            length_match += 1

        end_game += end

        # Only buckets in which the replay has commands are counted.
        for key, cmds in replay_cmds.items():
            total_num_frames += 1
            # .get avoids inserting missing buckets into the json mapping.
            if cmds == json_cmds.get(key, []):
                match_frames += 1

        total_cmd = sum(len(cmds) for cmds in replay_cmds.values())
        num_cmd2num_game[total_cmd] += 1
        num_cmd2finished_game[total_cmd] += end

    print('total frames:', total_num_frames)
    print('match frame:', match_frames)
    print('match rate:', _safe_ratio(match_frames, total_num_frames))

    print('instruction match rate:',
          _safe_ratio(instruction_match, total_game))
    print('length match rate:', _safe_ratio(length_match, total_game))
    print('total game: %d, end game: %d, percent: %f'
          % (total_game, end_game, _safe_ratio(end_game, total_game)))

    if show_cmd_breakdown:
        for num_cmd in sorted(num_cmd2num_game):
            num_game = num_cmd2num_game[num_cmd]
            finished_game = num_cmd2finished_game[num_cmd]
            print(num_cmd, ': ', num_game, ',', finished_game, ',',
                  finished_game / num_game)


if __name__ == '__main__':
    # Command-line entry point:
    #   --folder  print aggregate statistics for every replay in a folder
    #   --replay  print a detailed instruction comparison for one replay
    parser = argparse.ArgumentParser(description='sanity check for generated states')
    parser.add_argument('--replay', type=str)
    parser.add_argument('--folder', type=str)
    parser.add_argument('--player', type=int, default=0)
    parser.add_argument('--frame_skip', type=int, default=50)

    args = parser.parse_args()
    if args.folder:
        # process_folder only prints; it has no return value to keep.
        process_folder(args.folder, args.player, args.frame_skip)

    if args.replay:
        # Build the json path from --player, the same way process_folder
        # does, instead of always reading the player-0 dump.
        json_file = '%s.p%d.json' % (args.replay, args.player)
        replay, replay_ins, end_game, replay_len = process_replay(
            args.replay, args.player, args.frame_skip)
        js, json_ins, json_len = process_json(json_file, args.frame_skip)

        pprint.pprint(replay_ins)
        pprint.pprint(json_ins)
        print(json_ins == replay_ins)
        # Both mappings are defaultdicts, so indexing a bucket missing on
        # one side yields [] and inserts it; the second pass below therefore
        # also visits buckets that originally existed only in the replay.
        for key in sorted(replay_ins.keys()):
            rins = replay_ins[key]
            jins = json_ins[key]
            if rins != jins:
                print(key)
                print(rins)
                print(jins)
        print('----------')
        for key in sorted(json_ins.keys()):
            rins = replay_ins[key]
            jins = json_ins[key]
            if rins != jins:
                print(key)
                print('in replay:', rins)
                print('in json:', jins)
