import torch
import torch.nn
import torch.nn.functional as F
from IPython import embed
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
from matplotlib.cm import viridis
import imageio
from functools import partial
import time
import argparse
import torch.nn.functional as F
import time
import os
import pickle
from dataset import TrajDataset
from net import TransformerVision
import pandas as pd
from evals import eval_miniworld
from miniworld_env import (
    MiniworldOptPolicy,
    MiniworldRandCommit,
    MiniworldRandPolicy,
    MiniworldTransformerController,
    MiniworldEnvVec,
)
import gymnasium as gym

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def str2bool(value):
    """Parse a boolean command-line value given as text.

    Accepts the usual spellings, case-insensitively: ``true/t/yes/y/1`` map
    to ``True`` and ``false/f/no/n/0`` map to ``False``. Anything else is
    rejected rather than being silently treated as ``False``, so that a typo
    on the command line cannot quietly flip which evaluation is run.

    Args:
        value: The raw argument (normally a string; a ``bool`` is passed
            through unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean (true/false), got {value!r}")


if __name__ == '__main__':
    # exist_ok=True makes these calls idempotent, so no prior existence check
    # is needed (and none can race with another process creating the dir).
    os.makedirs('figs/loss', exist_ok=True)
    os.makedirs('models', exist_ok=True)

    # Most of the options below are only used to reconstruct the checkpoint
    # file name further down, so they must match the values used in training.
    parser = argparse.ArgumentParser()
    parser.add_argument("--envs", type=int, required=False, default=1000, help="Envs")
    parser.add_argument("--hists", type=int, required=False, default=1, help="Histories")
    parser.add_argument("--samples", type=int, required=False, default=1, help="Samples")
    parser.add_argument("--H", type=int, required=False, default=10, help="Context horizon")
    parser.add_argument("--embd", type=int, required=False, default=32, help="Embedding dimension")
    parser.add_argument("--head", type=int, required=False, default=1, help="Number of attention heads")
    parser.add_argument("--layer", type=int, required=False, default=3, help="Number of transformer layers")
    parser.add_argument("--lr", type=float, required=False, default=1e-3, help="Learning rate")
    parser.add_argument("--epoch", type=int, required=False, default=-1, help="Epoch to evaluate")
    parser.add_argument("--opt", type=int, required=False, default=0, help="Optimizer type")
    # NOTE: the default is deliberately the int 0 (argparse does not run
    # non-string defaults through `type`), because it is formatted verbatim
    # into the checkpoint file name.
    parser.add_argument("--dropout", type=float, required=False, default=0, help="Dropout")
    parser.add_argument("--trans", type=int, required=False, default=0, help="Transformer type")
    parser.add_argument("--hor", type=int, required=False, default=-1, help="Episode horizon (for mdp)")
    parser.add_argument("--envname", type=str, required=False, default='mini', help="Environment name")
    parser.add_argument('--dataset_prefix', type=str, required=False, default='', help="Dataset prefix")
    parser.add_argument('--model_prefix', type=str, required=False, default='models', help="Model prefix")
    parser.add_argument(
        "--eval_with_expert_trajs",
        type=str2bool,
        required=False,
        default=False,
        help="Whether to evaluate with expert context")
    parser.add_argument(
        "--eval_in_train_tasks",
        type=str2bool,
        required=False,
        default=False,
        help="Whether to evaluate in train tasks")

    parser.add_argument("--sample", type=int, required=False, default=0, help="Read in the data?")
    parser.add_argument("--small", action='store_true', help="Use small images")
    parser.add_argument("--shift", action='store_true', help="Whether to shift the images")
    parser.add_argument('--full', default=False, action='store_true')
    parser.add_argument('--shuffle', default=False, action='store_true')
    parser.add_argument('--test', default=False, action='store_true')
    parser.add_argument("--include_partial_hist", default=False, action='store_true')
    parser.add_argument("--grow_context", default=False, action='store_true')
    parser.add_argument("--random_init", default=False, action='store_true')

    # Work with a plain dict from here on; echo it so runs are reproducible
    # from their logs.
    args = vars(parser.parse_args())
    print("Args:")
    print(args)

    # Pull the parsed options into local names, grouped by what they describe.
    # Dataset sizes used when the model was trained.
    n_envs, n_hists, n_samples = args['envs'], args['hists'], args['samples']
    # Context length and (optional) episode horizon.
    H, horizon = args['H'], args['hor']
    # Network shape.
    n_embd, n_head, n_layer = args['embd'], args['head'], args['layer']
    # Optimisation settings (only used below to rebuild the run name).
    lr, dropout = args['lr'], args['dropout']
    opt, trans = args['opt'], args['trans']
    shuffle, full = args['shuffle'], args['full']
    # Data-handling switches.
    sample, small, shift = args['sample'], args['small'], args['shift']

    epoch = args['epoch']
    use_test = args['test']
    envname = args['envname']
    dataset_prefix, model_prefix = args['dataset_prefix'], args['model_prefix']

    # Evaluation-time behaviour switches.
    include_partial_hist = args['include_partial_hist']
    grow_context = args['grow_context']
    random_init = args['random_init']
    eval_with_expert_trajs = args['eval_with_expert_trajs']
    eval_in_train_tasks = args['eval_in_train_tasks']

    # Script-level toggles that are not exposed on the command line.
    use_net = False
    save_video = True

    # A negative --hor means "use the context horizon as the episode horizon".
    horizon = H if horizon < 0 else horizon

    # Run name shared by the checkpoint file and the result files.
    filename = (
        f'{envname}_trans{trans}_full{full}_shuf{shuffle}'
        f'_opt{opt}_lr{lr}_do{dropout}'
        f'_embd{n_embd}_layer{n_layer}_head{n_head}'
        f'_envs{n_envs}_hists{n_hists}_samples{n_samples}'
        f'_H{H}_sample{sample}_shift{shift}'
    )

    # Settings handed to TransformerVision.
    config = {
        'H': H,
        'du': 4,            # NOTE(review): presumably the action dimension — confirm in net.py
        'n_layer': n_layer,
        'n_embd': n_embd,
        'n_head': n_head,
        'dropout': dropout,
        'small': small,
        'shift': False,     # Do not shift images at test time
    }

    model = TransformerVision(config).to(device)

    # A negative epoch selects the un-suffixed checkpoint; otherwise the
    # checkpoint saved for that specific epoch is used.
    if epoch < 0:
        model_path = f'{model_prefix}/{filename}.pt'
    else:
        model_path = f'{model_prefix}/{filename}_epoch{epoch}.pt'

    # map_location remaps the stored tensors onto the device chosen at the top
    # of the script; without it a checkpoint saved on a GPU cannot be loaded
    # on a CPU-only machine.
    # NOTE: torch.load unpickles the file — only load checkpoints you trust.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint)
    model.eval()
    model.config['full'] = False
    
    # Upper bound on the number of evaluation trajectories; it is clamped to
    # the number actually present in the dataset once that has been loaded.
    n_eval = 100
    H_eval = 10  # NOTE(review): not referenced anywhere else in this script

    # Choose which pre-generated trajectory file to evaluate on:
    # {expert_trajs|trajs}_{train|test}.
    traj_str = 'expert_trajs' if eval_with_expert_trajs else 'trajs'
    traj_str += '_train' if eval_in_train_tasks else '_test'
    eval_filepath = f'datasets/{traj_str}_{envname}_envs{n_eval}_H{horizon}_small{small}.pkl'
    save_filename = f'{filename}_hor{horizon}_test{use_test}_iph{include_partial_hist}_gc{grow_context}_train{eval_in_train_tasks}_expert{eval_with_expert_trajs}_ri{random_init}.pkl'

    # The context manager closes the handle even if unpickling fails.
    # NOTE: pickle.load executes arbitrary code — only load trusted datasets.
    with open(eval_filepath, 'rb') as traj_file:
        eval_trajs = pickle.load(traj_file)

    n_eval = min(n_eval, len(eval_trajs))

    # Output locations are grouped per evaluated epoch. makedirs creates the
    # intermediate `figs/{evals_filename}` directory too, and exist_ok=True
    # makes the calls safe to repeat.
    evals_filename = f"evals_epoch{epoch}"
    for subdir in ('pess', 'bar', 'lines', 'online'):
        os.makedirs(f'figs/{evals_filename}/{subdir}', exist_ok=True)

    if save_video:
        os.makedirs(f'videos/{save_filename}/{evals_filename}', exist_ok=True)

    os.makedirs(f'data_results/{evals_filename}/online/', exist_ok=True)

    # ONLINE EVALUATION
    # The video path keeps {controller}/{env_id}/{ep} as literal placeholders;
    # they are left for the evaluation code to fill in.
    online_video_template = (
        f'videos/{save_filename}/{evals_filename}'
        '/{controller}_env{env_id}_ep{ep}_online.gif')

    # Keyword arguments forwarded to eval_miniworld.online_vec.
    config = dict(
        Heps=40,
        horizon=horizon,
        H=H,
        n_eval=min(20, n_eval),     # the online evaluation uses at most 20 trajectories
        include_partial_hist=include_partial_hist,
        grow_context=grow_context,
        filename=online_video_template,
        envname=envname,
    )
    online_results = eval_miniworld.online_vec(eval_trajs, model, **config)

    # Save whatever is on the current matplotlib figure (presumably drawn by
    # online_vec — TODO confirm), then clear it for the next plot.
    online_fig_path = f'figs/{evals_filename}/online/{save_filename}.png'
    plt.savefig(online_fig_path)
    plt.clf()

    # Persist the raw results next to the figure.
    online_results_path = f'data_results/{evals_filename}/online/{save_filename}'
    with open(online_results_path, 'wb') as out_file:
        pickle.dump(online_results, out_file)

    envs = []
    trajs = []

    # Map the dataset's environment family onto the registered gym id prefix.
    if envname.startswith('mini_two_boxes'):
        env_name = 'MiniWorld-OneRoomS6FastMulti'
    elif envname.startswith('mini_three_boxes'):
        env_name = 'MiniWorld-OneRoomS6FastMultiThreeBoxes'
    elif envname.startswith('mini_four_boxes'):
        env_name = 'MiniWorld-OneRoomS6FastMultiFourBoxes'
    elif envname.startswith('mini_blue'):
        env_name = 'MiniWorld-OneRoomS6FastMultiBlue'
    else:
        raise ValueError(f"Invalid envname: {envname!r}")

    # OFFLINE EVALUATION SINGLE
    # Build one environment per evaluation trajectory and pair it with that
    # trajectory, which serves as the in-context data.
    for i_eval in range(n_eval):
        print(f"Eval traj: {i_eval}")

        env = gym.make(f'{env_name}FixedInit-v0')
        # NOTE(review): the 8000 offset presumably selects held-out task ids;
        # it is applied even when --eval_in_train_tasks is set — confirm.
        env.set_task(env_id=8000 + i_eval)
        envs.append(env)

        trajs.append(eval_trajs[i_eval])

    print("Running miniworld offline evaluations in parallel")
    vec_env = MiniworldEnvVec(envs)

    offline_video_template = (
        'videos/{save_filename}/{evals_filename}/{controller}_env{env_id}_offline.gif')

    def make_offline_template(controller):
        """Return the offline video path formatter for one controller.

        Everything except ``env_id`` is bound here; the returned callable is
        expected to be called with the remaining ``env_id`` keyword.
        """
        return partial(
            offline_video_template.format,
            controller=controller,
            save_filename=save_filename,
            evals_filename=evals_filename)

    # Two learner variants share the same model: one constructed with
    # sample=True and one with sample=False.
    lnr_filename_template = make_offline_template('lnr')
    lnr = MiniworldTransformerController(
        model,
        batch_size=n_eval,
        sample=True,
        save_video=save_video,
        filename_template=lnr_filename_template)
    lnr_greedy_filename_template = make_offline_template('lnr_greedy')
    lnr_greedy = MiniworldTransformerController(
        model,
        batch_size=n_eval,
        sample=False,
        save_video=save_video,
        filename_template=lnr_greedy_filename_template)
    # Reference policies. `opt` honours the same save_video switch as the
    # learners instead of hard-coding it.
    opt_filename_template = make_offline_template('opt')
    opt = MiniworldOptPolicy(
        vec_env,
        batch_size=n_eval,
        save_video=save_video,
        filename_template=opt_filename_template)
    rand = MiniworldRandPolicy(vec_env, batch_size=n_eval)

    batch = eval_miniworld.process_trajs_into_batch(trajs, lnr)

    # Export each context trajectory as a GIF alongside the rollout videos.
    # The videos directory is only created when save_video is set, so the
    # export is guarded by the same flag.
    if save_video:
        for i, traj in enumerate(trajs):
            context_gif_path = f'videos/{save_filename}/{evals_filename}/context_env{i}_offline.gif'
            # 'rollin_obs' holds a path to a saved array of frames.
            images = np.load(traj['rollin_obs'])
            imageio.mimsave(context_gif_path, images)

    # Every controller receives the same context batch.
    for controller in (lnr, lnr_greedy, opt, rand):
        controller.set_batch(batch)

    # Only the two learner rollouts receive the context-handling flags; the
    # reference policies are deployed with deploy_eval's defaults.
    context_kwargs = {
        'include_partial_hist': include_partial_hist,
        'grow_context': grow_context,
    }

    # Each deploy_eval call returns an 8-tuple. Only the last element (the
    # rewards) is used further down; the remaining names are presumably
    # observations, poses, angles, actions and their next-step counterparts
    # — TODO confirm against MiniworldEnvVec.deploy_eval.
    (xs_lnr, poses_lnr, angles_lnr, us_lnr,
     xps_lnr, next_poses_lnr, next_angles_lnr, rs_lnr) = vec_env.deploy_eval(
        lnr, **context_kwargs)

    (xs_lnr_greedy, poses_lnr_greedy, angles_lnr_greedy, us_lnr_greedy,
     xps_lnr_greedy, next_poses_lnr_greedy, next_angles_lnr_greedy,
     rs_lnr_greedy) = vec_env.deploy_eval(lnr_greedy, **context_kwargs)

    (xs_opt, poses_opt, angles_opt, us_opt,
     xps_opt, next_poses_opt, next_angles_opt, rs_opt) = vec_env.deploy_eval(opt)

    (xs_rand, poses_rand, angles_rand, us_rand,
     xps_rand, next_poses_rand, next_angles_rand, rs_rand) = vec_env.deploy_eval(rand)

    print("Computing random-commit performance from opt and rand controllers")
    # "Random-commit" baseline: follow the random policy's rewards up to and
    # including the first step with a positive reward, then take the optimal
    # policy's rewards for every later step. Rows that never see a positive
    # reward stay purely random.
    rs_rand_commit = rs_rand.copy()
    for row in range(n_eval):
        goal_steps = np.where(rs_rand[row] > 0)[0]
        if len(goal_steps) == 0:
            continue
        resume_at = goal_steps[0] + 1
        rs_rand_commit[row, resume_at:] = rs_opt[row, resume_at:]

    # Collapse the last axis to get one total return per trajectory.
    (
        all_rs_lnr,
        all_rs_lnr_greedy,
        all_rs_opt,
        all_rs_rand_commit,
        all_rs_rand,
    ) = (
        np.sum(rewards, axis=-1)
        for rewards in (rs_lnr, rs_lnr_greedy, rs_opt, rs_rand_commit, rs_rand)
    )

    # Per-trajectory total returns for every controller, keyed by name. The
    # insertion order also fixes the left-to-right order of the bars.
    baselines = {
        name: np.array(totals)
        for name, totals in (
            ('opt', all_rs_opt),
            ('lnr', all_rs_lnr),
            ('lnr_greedy', all_rs_lnr_greedy),
            ('rand_commit', all_rs_rand_commit),
            ('rand', all_rs_rand),
        )
    }

    # Average return per controller.
    baselines_means = {}
    for name, totals in baselines.items():
        baselines_means[name] = np.mean(totals)

    # One viridis colour per bar, evenly spaced over the colormap.
    colors = plt.cm.viridis(np.linspace(0, 1, len(baselines_means)))
    plt.bar(baselines_means.keys(), baselines_means.values(), color=colors)
    plt.title(f'Mean Reward on {n_eval} Trajectories')
    plt.savefig(f'figs/{evals_filename}/bar/{save_filename}_bar.png')
    plt.clf()

    # Store the un-averaged returns so the statistics can be recomputed later.
    bar_results_dir = f'data_results/{evals_filename}/bar'
    os.makedirs(bar_results_dir, exist_ok=True)
    with open(f'{bar_results_dir}/{save_filename}_bar.pkl', 'wb') as out_file:
        pickle.dump(baselines, out_file)
