"""
Visualizes the dynamics model.
- Rolls out ground truth trajectory using the simulator
- Compares that with trajectory that uses the learned model
- Visualizes the result in meshcat
"""

import os
import sys
import numpy as np
import pickle
import yaml
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import pyplot as plt
import cv2

import torch
import torch.nn.functional as F

from pwa import PWA
from model import MLP

# Load the experiment configuration. The context manager closes the file
# handle deterministically instead of leaving it to the garbage collector.
with open("configs/PWA.yaml", 'r') as config_file:
    config = yaml.safe_load(config_file)

# Checkpoints live under <cwd>/experiments/<experiment_name>/models.
cwd = os.getcwd()
train_dir = os.path.join(cwd, 'experiments', config['experiment_name'])
model_dir = os.path.join(train_dir, 'models')
model_name = config['evaluate']['model_name']
model_file = os.path.join(model_dir, model_name)

# torch.load returns the whole serialized model object (it is used as a
# module right below), so constructing a throw-away MLP(config) first is
# unnecessary.
# NOTE(review): torch.load unpickles the file; only point it at checkpoints
# you trust.
model = torch.load(model_file)
model = model.eval()
model = model.cuda()
model.gumbel = False

# Output folder for everything this script writes. os.makedirs replaces the
# shell call to `mkdir -p`, so paths containing spaces or shell
# metacharacters are handled correctly and no subprocess is spawned.
eval_folder = os.path.join(model_dir, '%s_eval' % model_name)
os.makedirs(eval_folder, exist_ok=True)

### ground truth dynamics
pwa = PWA()

n_unstable = 0

# Evaluation grid: both state coordinates are sampled on [-1, 1) with the
# same spacing.
dx = dy = 0.04
xrange = np.arange(-1., 1., dx)
yrange = np.arange(-1., 1., dy)

# Number of rollout steps evaluated per grid point.
N_ROLL = 1

# Per grid point: one loss value per rollout step, and four numbers
# (2 predicted + 2 ground-truth state components).
rec_loss = np.zeros((xrange.size, yrange.size, N_ROLL))
rec_data = np.zeros((xrange.size, yrange.size, 2 * 2))

# Activations returned by the model, one entry per grid point.
relu_activation = list()

# Running sum of the per-step MSE over the whole grid.
loss_total = 0.


# Loop-invariant objects are created once instead of once per grid point.
mse_loss = torch.nn.MSELoss()

# [N_ROLL, action_dim=2]; the learned model is rolled out with zero actions.
action_seq = np.zeros((N_ROLL, 2))

# Evaluation only: without no_grad every rollout would record an autograd
# graph, and relu_activation would keep those graphs alive for all grid
# points.
with torch.no_grad():
    for idx_x, x in enumerate(xrange):
        for idx_y, y in enumerate(yrange):

            print("Evaluate (%d, %d) / (%d, %d)" % (
                idx_x, idx_y, len(xrange), len(yrange)))

            # [B=1, n_his=1, state_dim=2]
            state_init = torch.FloatTensor(
                np.array([[x, y]])).unsqueeze(0).cuda()
            # [B=1, N_ROLL, action_dim=2]
            action_seq_torch = torch.FloatTensor(
                action_seq).unsqueeze(0).cuda()

            rollout_pred = model.rollout_model(
                state_init=state_init,
                action_seq=action_seq_torch)

            # expected [B=1, n_roll, state_dim=2] -- indexed as [0, i] below
            state_pred_rollout = rollout_pred['state_pred']

            # Ground truth from the simulator; the first returned state is
            # dropped by the [1:] slice.
            # NOTE(review): the simulator is driven with u = ones((1, 1))
            # while the learned model above receives all-zero actions --
            # confirm that both describe the same input.
            state_gt_rollout = pwa.simulate(
                np.array([x, y]), u=np.ones((1, 1)))[1:]
            state_gt_rollout = torch.FloatTensor(
                state_gt_rollout).unsqueeze(0).cuda()

            relu_activation.append(rollout_pred['activation'])

            for i in range(N_ROLL):
                # Compute the step MSE once and reuse it for both records.
                step_mse = mse_loss(
                    state_gt_rollout[0, i], state_pred_rollout[0, i]).item()
                rec_loss[idx_x, idx_y, i] = np.sqrt(step_mse)
                loss_total += step_mse

            # Only the first rollout step is stored: prediction in the first
            # two slots, ground truth in the last two.
            rec_data[idx_x, idx_y, :2] = \
                state_pred_rollout[0, 0].detach().cpu().numpy()
            rec_data[idx_x, idx_y, 2:] = \
                state_gt_rollout[0, 0].detach().cpu().numpy()

# Summed per-step MSE divided by the number of grid points.
print('MSE', loss_total / (len(xrange) * len(yrange)))

# Persist the per-grid-point arrays inside the evaluation folder.
for array_file, array in (('rec_loss.npy', rec_loss),
                          ('rec_data.npy', rec_data)):
    np.save(os.path.join(eval_folder, array_file), array)

# The collected activations are pickled as-is.
activation_path = os.path.join(eval_folder, 'relu_activation.p')
with open(activation_path, 'wb') as fp:
    pickle.dump(relu_activation, fp)

### calculate the ub and lb for each neuron

# d is indexed as d[episode][timestep][relu]; the sizes are read off the
# first episode / first timestep.
d = relu_activation
n_relu = len(d[0][0])
n_timestep = len(d[0])
n_episode = len(d)

print('n_episode', n_episode)
print('n_timestep', n_timestep)
print('n_relu', n_relu)

lb, ub = [], []

for relu_i in range(n_relu):
    # Stack this entry's values over every (timestep, episode) pair.
    rec = np.array([
        d[episode_i][timestep_i][relu_i]
        for timestep_i in range(n_timestep)
        for episode_i in range(n_episode)
    ])
    # Reduce over the stacking axis and the entry's own leading axis, so
    # each stored entry has to be at least one-dimensional.
    lb.append(rec.min(axis=(0, 1)))
    ub.append(rec.max(axis=(0, 1)))

# Saved as [lower bounds, upper bounds].
bound = [lb, ub]
bound_path = os.path.join(eval_folder, 'bound.p')
with open(bound_path, 'wb') as fp:
    pickle.dump(bound, fp)


# rec_data is indexed as [idx_x, idx_y], so the coordinate grids must use
# matrix ('ij') indexing: X[i, j] == xrange[i] and Y[i, j] == yrange[j].
# With the default 'xy' indexing the plotted surfaces are transposed
# relative to the data.
X, Y = np.meshgrid(xrange, yrange, indexing='ij')
colors = ['c', 'r']

# Viewing angles in degrees, 0 to 342 in steps of 18. Despite the name,
# the plotting code passes them as the second argument of view_init
# (the azimuth).
elevation = np.arange(0, 360, 18)

# Sub-folder of eval_folder that receives the rendered images. os.makedirs
# replaces the shell call to `mkdir -p` and is safe for paths with spaces.
img_folder = 'viz'
os.makedirs(os.path.join(eval_folder, img_folder), exist_ok=True)

my_red = np.zeros((50, 50, 4))
for i in range(50):
    for j in range(50):
        if i > j and i + j < 50:
            my_red[i, j] = np.array([1., 0., 0., .9])
        if i > j and i + j >= 50:
            my_red[i, j] = np.array([1., 0., 0., .7])
        if i <= j and i + j < 50:
            my_red[i, j] = np.array([1., 0., 0., .5])
        if i <= j and i + j >= 50:
            my_red[i, j] = np.array([1., 0., 0., .3])


my_blue = np.zeros((50, 50, 4))
for i in range(50):
    for j in range(50):
        if i > j and i + j < 50:
            my_blue[i, j] = np.array([0., 0., 1., .9])
        if i > j and i + j >= 50:
            my_blue[i, j] = np.array([0., 0., 1., .7])
        if i <= j and i + j < 50:
            my_blue[i, j] = np.array([0., 0., 1., .5])
        if i <= j and i + j >= 50:
            my_blue[i, j] = np.array([0., 0., 1., .3])

my_cols = [my_red, my_blue]


def _save_rotating_views(surface, facecolors, name_fmt, k):
    """Render one surface over the (X, Y) grid and save it from every angle.

    Args:
        surface: 2D array of heights, plotted against the module-level X, Y.
        facecolors: RGBA array passed to plot_surface as the face colours.
        name_fmt: %-format for the image file name; it receives the state
            component index and the view index, in that order.
        k: index of the state component being plotted (used in the name).

    One image per entry of the module-level `elevation` array is written to
    <eval_folder>/<img_folder>/.
    """
    fig = plt.figure()
    try:
        ax = fig.add_subplot(projection='3d')
        # Hide the tick labels on all three axes.
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_zticklabels([])

        ax.plot_surface(X, Y, surface, facecolors=facecolors,
                        linewidth=0, shade=True)

        # Rotate the camera around the surface at a fixed first view angle
        # of 40 and save one frame per angle.
        for idx_e, e in enumerate(elevation):
            ax.view_init(40, e)
            fig.savefig(os.path.join(
                eval_folder, img_folder, name_fmt % (k, idx_e)))
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)


for k in range(2):
    # visualize the ground truth (last two slots of rec_data)
    _save_rotating_views(
        rec_data[:, :, 2 + k], my_cols[k], 'gt_%d_%d.png', k)

    # visualize the prediction (first two slots of rec_data)
    _save_rotating_views(
        rec_data[:, :, k], my_cols[k], 'pred_%d_%d.png', k)