from darkroom_env import DarkroomEnv, DarkroomEnvStitch, DarkroomEnvPermuted, DarkroomEnvVec, DarkroomOptPolicy, DarkroomTransformerController, RandCommit
import torch
import numpy as np
import scipy
import scipy.stats
import matplotlib.pyplot as plt
from matplotlib.cm import viridis
from IPython import embed
import time
# Torch device that this module's context buffers are allocated on: the current
# CUDA device when PyTorch reports CUDA as available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")



# Arrow offset drawn for each action id (the argmax of an action vector) in
# saved frames. Each entry is (offset along state[0], offset along state[1]);
# state[1] is plotted on the x-axis and state[0] on the y-axis.
_ONLINE_ACTION_OFFSETS = {
    0: (-0.1, 0),
    1: (0.1, 0),
    2: (0, 0.1),
    3: (0, -0.1),
    4: (0, 0),
}


def _save_online_frame(context_xs, context_us, rollout_xs, goal, path):
    """Draw one online-evaluation frame on the current figure, save it, clear it.

    Args:
        context_xs: Context states, one row per step (column 0 is drawn on the
            y-axis, column 1 on the x-axis).
        context_us: Context actions, one row per step; the drawn action is the
            argmax of each row (presumably a one-hot encoding - TODO confirm).
        rollout_xs: States visited by the new rollout, one row per step.
        goal: Indexable goal position; ``goal[1]`` is drawn on x, ``goal[0]`` on y.
        path: Output image path.
    """
    states = np.asarray(context_xs, dtype=np.float64)
    action_ids = np.argmax(np.asarray(context_us, dtype=np.float64), axis=-1)
    n_ctx = len(states)
    # One quiver call per arrow: each arrow gets its own viridis colour by
    # position in the context, and arrow width matches a single-arrow quiver.
    for j, (state, action_id) in enumerate(zip(states, action_ids)):
        d0, d1 = _ONLINE_ACTION_OFFSETS[action_id]
        plt.quiver(state[1], state[0], d1, d0, color=viridis(j / n_ctx), alpha=0.5, scale=3)

    # States visited by the new rollout.
    rollout = np.asarray(rollout_xs, dtype=np.float64)
    plt.scatter(rollout[:, 1], rollout[:, 0], c='g', marker='x', s=200)

    plt.scatter(goal[1], goal[0], marker='o', facecolors='none', edgecolors='b', s=200)
    plt.ylim(-1, 10)
    plt.xlim(-1, 10)
    plt.gca().invert_yaxis()
    plt.savefig(path)
    plt.clf()


def deploy_online(env, controller, save_video=False, **kwargs):
    """Evaluate ``controller`` online in ``env`` for ``Heps`` episodes.

    The context holds up to ``H // horizon`` whole episodes. During warm-up
    the controller sees every episode collected so far (none for the first
    one). Afterwards it sees the most recent ``H // horizon`` episodes,
    flattened into one length-``H`` sequence. Each new episode is appended and
    the oldest one dropped.

    Args:
        env: Environment exposing ``dx``, ``du``, ``goal`` and
            ``deploy_eval(controller, ...)``, which returns the states,
            actions, next states and rewards of one episode.
        controller: Object with ``set_batch(batch)``. It receives a dict of
            context tensors of shape ``(1, T, dim)`` with keys ``rollin_xs``,
            ``rollin_us``, ``rollin_xps`` and ``rollin_rs``.
        save_video: If true, save one PNG per post-warm-up episode to
            ``videos/<filename>/test_<i_eval>_online_traj<episode>.png``
            (the directory is created if needed).
        **kwargs: Must contain ``Heps`` (total episodes), ``H`` (context
            length in steps), ``horizon`` (steps per episode; must divide
            ``H``), ``include_partial_hist`` and ``grow_context`` (forwarded
            to ``env.deploy_eval`` after warm-up), ``filename`` and ``i_eval``
            (used to name saved frames). Extra keys are ignored.

    Returns:
        np.ndarray with one entry per episode: that episode's summed reward.
    """
    Heps = kwargs['Heps']
    H = kwargs['H']
    horizon = kwargs['horizon']
    include_partial_hist = kwargs['include_partial_hist']
    grow_context = kwargs['grow_context']
    filename = kwargs['filename']
    i_eval = kwargs['i_eval']
    assert H % horizon == 0

    ctx_rollouts = H // horizon
    dx, du = env.dx, env.du

    # Context buffers: (batch=1, episodes, steps per episode, feature dim).
    rollin_xs = torch.zeros((1, ctx_rollouts, horizon, dx)).float().to(device)
    rollin_us = torch.zeros((1, ctx_rollouts, horizon, du)).float().to(device)
    rollin_xps = torch.zeros((1, ctx_rollouts, horizon, dx)).float().to(device)
    rollin_rs = torch.zeros((1, ctx_rollouts, horizon, 1)).float().to(device)

    if save_video:
        # savefig does not create missing directories.
        import os
        os.makedirs(f'videos/{filename}', exist_ok=True)

    cum_means = []

    # Warm-up: fill the context one episode at a time. Capped at Heps so that
    # exactly Heps returns are produced even when Heps < H // horizon.
    for i in range(min(ctx_rollouts, Heps)):
        batch = {
            'rollin_xs': rollin_xs[:, :i].reshape(1, -1, dx),
            'rollin_us': rollin_us[:, :i].reshape(1, -1, du),
            'rollin_xps': rollin_xps[:, :i].reshape(1, -1, dx),
            'rollin_rs': rollin_rs[:, :i].reshape(1, -1, 1),
        }
        controller.set_batch(batch)
        xs_lnr, us_lnr, xps_lnr, rs_lnr = env.deploy_eval(controller)
        rollin_xs[:, i] = torch.tensor(xs_lnr)
        rollin_us[:, i] = torch.tensor(us_lnr)
        rollin_xps[:, i] = torch.tensor(xps_lnr)
        rollin_rs[:, i] = torch.tensor(rs_lnr[None, :, None])
        cum_means.append(np.sum(rs_lnr))

    # Steady state: sliding window over the last H // horizon episodes.
    for h_ep in range(ctx_rollouts, Heps):
        # Reshape the batch as a singular length H = ctx_rollouts * horizon sequence.
        batch = {
            'rollin_xs': rollin_xs.reshape(1, -1, dx),
            'rollin_us': rollin_us.reshape(1, -1, du),
            'rollin_xps': rollin_xps.reshape(1, -1, dx),
            'rollin_rs': rollin_rs.reshape(1, -1, 1),
        }
        controller.set_batch(batch)
        xs_lnr, us_lnr, xps_lnr, rs_lnr = env.deploy_eval(
            controller,
            include_partial_hist=include_partial_hist,
            grow_context=grow_context)
        cum_means.append(np.sum(rs_lnr))

        if save_video:
            _save_online_frame(
                batch['rollin_xs'][0].cpu().numpy(),
                batch['rollin_us'][0].cpu().numpy(),
                xs_lnr,
                env.goal,
                f'videos/{filename}/test_{i_eval}_online_traj{h_ep}.png')

        xs_t = torch.tensor(xs_lnr).float().to(device)
        us_t = torch.tensor(us_lnr).float().to(device)
        xps_t = torch.tensor(xps_lnr).float().to(device)
        rs_t = torch.tensor(rs_lnr[:, None]).float().to(device)

        # Roll in new data by dropping the oldest episode and appending the new one.
        rollin_xs = torch.cat((rollin_xs[:, 1:], xs_t[None, None]), dim=1)
        rollin_us = torch.cat((rollin_us[:, 1:], us_t[None, None]), dim=1)
        rollin_xps = torch.cat((rollin_xps[:, 1:], xps_t[None, None]), dim=1)
        rollin_rs = torch.cat((rollin_rs[:, 1:], rs_t[None, None]), dim=1)

    return np.array(cum_means)



def online(eval_trajs, model, **kwargs):
    """Compare the transformer controller with a random baseline, one env at a time.

    For each of the first ``n_eval`` entries of ``eval_trajs`` an environment
    is built. It is ``DarkroomEnvStitch`` if ``stitch``, else
    ``DarkroomEnvPermuted`` if ``permuted``, else ``DarkroomEnv``. The
    environment is then evaluated with ``deploy_online`` twice: once with
    ``DarkroomTransformerController`` (frames saved) and once with
    ``RandCommit`` (no frames).

    Per-episode returns are drawn on the current matplotlib figure. For each
    method this is a faint curve per environment plus the mean with a +/- SEM
    band, in matching colours (blue = LNR, red = RND). Nothing is returned,
    and the figure is neither saved nor shown here.

    Args:
        eval_trajs: Sequence of dicts with at least ``'goal'`` (and
            ``'perm_index'`` when ``permuted``).
        model: Model wrapped by ``DarkroomTransformerController``.
        **kwargs: Must contain ``H``, ``n_eval``, ``dim``, ``horizon``,
            ``stitch``, ``permuted`` and ``random_init``, plus every key
            ``deploy_online`` requires (e.g. ``Heps``, ``filename``).
    """
    H = kwargs['H']
    n_eval = kwargs['n_eval']
    dim = kwargs['dim']
    horizon = kwargs['horizon']
    stitch = kwargs['stitch']
    permuted = kwargs['permuted']
    random_init = kwargs['random_init']
    assert H % horizon == 0

    all_means_lnr = []
    all_means_rnd = []

    for i_eval in range(n_eval):
        print(f"Eval traj: {i_eval}")
        traj = eval_trajs[i_eval]
        kwargs['traj'] = traj
        # deploy_online uses i_eval to name saved frames.
        kwargs['i_eval'] = i_eval
        goal = traj['goal']
        if stitch:
            env = DarkroomEnvStitch(dim, goal, horizon, eval=True)
        elif permuted:
            env = DarkroomEnvPermuted(dim, traj['perm_index'], horizon)
        else:
            env = DarkroomEnv(dim, goal, horizon, random_init=random_init)

        lnr_controller = DarkroomTransformerController(model, sample=True)
        rnd_controller = RandCommit(env)

        all_means_lnr.append(deploy_online(env, lnr_controller, save_video=True, **kwargs))
        all_means_rnd.append(deploy_online(env, rnd_controller, save_video=False, **kwargs))

    all_means_lnr = np.array(all_means_lnr)
    all_means_rnd = np.array(all_means_rnd)

    means_lnr = np.mean(all_means_lnr, axis=0)
    means_rnd = np.mean(all_means_rnd, axis=0)

    sems_lnr = scipy.stats.sem(all_means_lnr, axis=0)
    sems_rnd = scipy.stats.sem(all_means_rnd, axis=0)

    # Individual curves.
    for curve_lnr, curve_rnd in zip(all_means_lnr, all_means_rnd):
        plt.plot(curve_lnr, color='blue', alpha=0.2)
        plt.plot(curve_rnd, color='red', alpha=0.2)

    # Means with SEM bands. Colours are explicit so each mean/band matches its
    # individual curves instead of following the default colour cycle.
    episodes = np.arange(len(means_lnr))
    plt.plot(episodes, means_lnr, color='blue', label='LNR')
    plt.fill_between(episodes, means_lnr - sems_lnr, means_lnr + sems_lnr, color='blue', alpha=0.2)
    plt.plot(episodes, means_rnd, color='red', label='RND')
    plt.fill_between(episodes, means_rnd - sems_rnd, means_rnd + sems_rnd, color='red', alpha=0.2)
    plt.legend()
    plt.xlabel('t')
    plt.ylabel('Average Reward')
    plt.title(f'Online Evaluation on {n_eval} envs')


def deploy_online_vec(vec_env, controller, save_video=False, **kwargs):
    """Batched counterpart of ``deploy_online`` for a vectorised environment.

    All ``vec_env.num_envs`` environments are rolled out together. The context
    tensors carry a leading ``num_envs`` dimension, and ``vec_env.deploy_eval``
    is expected to return per-environment arrays with a leading ``num_envs``
    axis (TODO confirm against DarkroomEnvVec).

    During warm-up the context grows by one episode at a time. Afterwards it is
    a sliding window over the last ``H // horizon`` episodes.

    Args:
        vec_env: Vectorised env exposing ``num_envs``, ``envs`` (each with
            ``dx``, ``du``, ``goal``) and ``deploy_eval(controller, ...)``.
        controller: Object with ``set_batch(batch)``. It receives context
            tensors of shape ``(num_envs, T, dim)``.
        save_video: If true, save one PNG per environment per episode to
            ``videos/<filename>/test_<env index>_online_traj<episode>.png``
            (the directory is created if needed).
        **kwargs: Must contain ``Heps``, ``H``, ``horizon`` (must divide
            ``H``), ``include_partial_hist``, ``grow_context`` and
            ``filename``. Extra keys (e.g. ``n_eval``) are ignored.

    Returns:
        np.ndarray of shape ``(num_envs, episodes)`` with each episode's
        summed reward per environment.
    """
    Heps = kwargs['Heps']
    H = kwargs['H']
    horizon = kwargs['horizon']
    include_partial_hist = kwargs['include_partial_hist']
    grow_context = kwargs['grow_context']
    filename = kwargs['filename']
    assert H % horizon == 0

    ctx_rollouts = H // horizon
    num_envs = vec_env.num_envs
    dx = vec_env.envs[0].dx
    du = vec_env.envs[0].du

    # Arrow offset per action id: (offset along state[0], offset along state[1]).
    directions = {
        0: (-0.1, 0),
        1: (0.1, 0),
        2: (0, 0.1),
        3: (0, -0.1),
        4: (0, 0),
    }

    if save_video:
        # savefig does not create missing directories.
        import os
        os.makedirs(f'videos/{filename}', exist_ok=True)

    def save_video_vec(batch, xs_lnr, h_ep):
        """Save one frame per environment for episode ``h_ep`` (no-op unless save_video)."""
        if not save_video:
            return

        # Iterate over the environments actually present in the batch.
        for i_eval in range(num_envs):
            # Visualize the context (arrows coloured by position in the context).
            states = batch['rollin_xs'][i_eval].cpu().numpy().astype(np.float64)
            actions = batch['rollin_us'][i_eval].cpu().numpy().astype(np.float64)
            states0 = states[:, 0]
            states1 = states[:, 1]
            actions = np.argmax(actions, axis=-1)

            n_ctx = len(states0)
            if n_ctx:  # the first warm-up episode has an empty context
                colors = [viridis(j / n_ctx) for j in range(n_ctx)]
                us = [directions[a][1] for a in actions]
                vs = [directions[a][0] for a in actions]
                plt.quiver(states1, states0, us, vs, color=colors, alpha=0.5, scale=3)

            # Visualize the rollout.
            states = xs_lnr[i_eval].astype(np.float64)
            plt.scatter(states[:, 1], states[:, 0], c='g', marker='x', s=200)

            plt.scatter(vec_env.envs[i_eval].goal[1], vec_env.envs[i_eval].goal[0], marker='o', facecolors='none', edgecolors='b', s=200)
            plt.ylim(-1, 10)
            plt.xlim(-1, 10)
            plt.gca().invert_yaxis()
            plt.savefig(f'videos/{filename}/test_{i_eval}_online_traj{h_ep}.png')
            plt.clf()

    # Context buffers: (num_envs, episodes, steps per episode, feature dim).
    rollin_xs = torch.zeros((num_envs, ctx_rollouts, horizon, dx)).float().to(device)
    rollin_us = torch.zeros((num_envs, ctx_rollouts, horizon, du)).float().to(device)
    rollin_xps = torch.zeros((num_envs, ctx_rollouts, horizon, dx)).float().to(device)
    rollin_rs = torch.zeros((num_envs, ctx_rollouts, horizon, 1)).float().to(device)

    cum_means = []

    # Warm-up: context grows one episode at a time; capped at Heps.
    for i in range(min(ctx_rollouts, Heps)):
        batch = {
            'rollin_xs': rollin_xs[:, :i].reshape(num_envs, -1, dx),
            'rollin_us': rollin_us[:, :i].reshape(num_envs, -1, du),
            'rollin_xps': rollin_xps[:, :i].reshape(num_envs, -1, dx),
            'rollin_rs': rollin_rs[:, :i].reshape(num_envs, -1, 1),
        }
        controller.set_batch(batch)
        xs_lnr, us_lnr, xps_lnr, rs_lnr = vec_env.deploy_eval(controller)
        rollin_xs[:, i] = torch.tensor(xs_lnr)
        rollin_us[:, i] = torch.tensor(us_lnr)
        rollin_xps[:, i] = torch.tensor(xps_lnr)
        rollin_rs[:, i] = torch.tensor(rs_lnr[:, :, None])

        cum_means.append(np.sum(rs_lnr, axis=-1))

        # Use the episode index so warm-up frames do not overwrite each other.
        save_video_vec(batch, xs_lnr, i)

    # Steady state: sliding window over the last H // horizon episodes.
    for h_ep in range(ctx_rollouts, Heps):
        # Reshape the batch as a singular length H = ctx_rollouts * horizon sequence.
        batch = {
            'rollin_xs': rollin_xs.reshape(num_envs, -1, dx),
            'rollin_us': rollin_us.reshape(num_envs, -1, du),
            'rollin_xps': rollin_xps.reshape(num_envs, -1, dx),
            'rollin_rs': rollin_rs.reshape(num_envs, -1, 1),
        }
        controller.set_batch(batch)
        xs_lnr, us_lnr, xps_lnr, rs_lnr = vec_env.deploy_eval(
            controller,
            include_partial_hist=include_partial_hist,
            grow_context=grow_context)

        cum_means.append(np.sum(rs_lnr, axis=-1))

        save_video_vec(batch, xs_lnr, h_ep)

        xs_t = torch.tensor(xs_lnr).float().to(device)
        us_t = torch.tensor(us_lnr).float().to(device)
        xps_t = torch.tensor(xps_lnr).float().to(device)
        rs_t = torch.tensor(rs_lnr[:, :, None]).float().to(device)

        # Roll in new data by dropping the oldest episode and appending the new one.
        rollin_xs = torch.cat((rollin_xs[:, 1:], xs_t[:, None]), dim=1)
        rollin_us = torch.cat((rollin_us[:, 1:], us_t[:, None]), dim=1)
        rollin_xps = torch.cat((rollin_xps[:, 1:], xps_t[:, None]), dim=1)
        rollin_rs = torch.cat((rollin_rs[:, 1:], rs_t[:, None]), dim=1)

    return np.stack(cum_means, axis=1)


def online_vec(eval_trajs, model, **kwargs):
    """Compare the transformer controller with a random baseline, LNR batched.

    An environment is built for each of the first ``n_eval`` entries of
    ``eval_trajs``: ``DarkroomEnvStitch`` if ``stitch``, else
    ``DarkroomEnvPermuted`` if ``permuted``, else ``DarkroomEnv``. The
    ``RandCommit`` baseline is run on each one separately through
    ``deploy_online``. The transformer controller is then run on all of them
    at once through ``deploy_online_vec`` on a ``DarkroomEnvVec``, with frames
    saved.

    Per-episode returns are drawn on the current matplotlib figure. For each
    method this is a faint curve per environment plus the mean with a +/- SEM
    band, in matching colours (blue = LNR, red = RND). Nothing is returned,
    and the figure is neither saved nor shown here.

    Args:
        eval_trajs: Sequence of dicts with at least ``'goal'`` (and
            ``'perm_index'`` when ``permuted``).
        model: Model wrapped by ``DarkroomTransformerController``.
        **kwargs: Must contain ``H``, ``n_eval``, ``dim``, ``horizon``,
            ``stitch``, ``permuted`` and ``random_init``, plus every key
            ``deploy_online`` / ``deploy_online_vec`` require (e.g. ``Heps``,
            ``filename``).
    """
    H = kwargs['H']
    n_eval = kwargs['n_eval']
    dim = kwargs['dim']
    horizon = kwargs['horizon']
    stitch = kwargs['stitch']
    permuted = kwargs['permuted']
    random_init = kwargs['random_init']
    assert H % horizon == 0

    all_means_rnd = []
    envs = []
    for i_eval in range(n_eval):
        print(f"Eval traj: {i_eval}")
        traj = eval_trajs[i_eval]
        kwargs['traj'] = traj
        # deploy_online reads i_eval (used to name saved frames).
        kwargs['i_eval'] = i_eval
        goal = traj['goal']
        if stitch:
            env = DarkroomEnvStitch(dim, goal, horizon, eval=True)
        elif permuted:
            env = DarkroomEnvPermuted(dim, traj['perm_index'], horizon)
        else:
            env = DarkroomEnv(dim, goal, horizon, random_init=random_init)

        # The random baseline is evaluated one environment at a time.
        rnd_controller = RandCommit(env)
        all_means_rnd.append(deploy_online(env, rnd_controller, save_video=False, **kwargs))

        envs.append(env)

    # The learned controller is evaluated on all environments in one batch.
    lnr_controller = DarkroomTransformerController(model, batch_size=n_eval, sample=True)
    vec_env = DarkroomEnvVec(envs)
    cum_means_lnr = deploy_online_vec(vec_env, lnr_controller, save_video=True, **kwargs)

    all_means_lnr = np.array(cum_means_lnr)
    all_means_rnd = np.array(all_means_rnd)

    means_lnr = np.mean(all_means_lnr, axis=0)
    means_rnd = np.mean(all_means_rnd, axis=0)

    sems_lnr = scipy.stats.sem(all_means_lnr, axis=0)
    sems_rnd = scipy.stats.sem(all_means_rnd, axis=0)

    # Individual curves.
    for curve_lnr, curve_rnd in zip(all_means_lnr, all_means_rnd):
        plt.plot(curve_lnr, color='blue', alpha=0.2)
        plt.plot(curve_rnd, color='red', alpha=0.2)

    # Means with SEM bands. Colours are explicit so each mean/band matches its
    # individual curves instead of following the default colour cycle.
    episodes = np.arange(len(means_lnr))
    plt.plot(episodes, means_lnr, color='blue', label='LNR')
    plt.fill_between(episodes, means_lnr - sems_lnr, means_lnr + sems_lnr, color='blue', alpha=0.2)
    plt.plot(episodes, means_rnd, color='red', label='RND')
    plt.fill_between(episodes, means_rnd - sems_rnd, means_rnd + sems_rnd, color='red', alpha=0.2)
    plt.legend()
    plt.xlabel('t')
    plt.ylabel('Average Reward')
    plt.title(f'Online Evaluation on {n_eval} envs')
