# Standard library imports
from argparse import ArgumentParser
import os, sys
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
PARENT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(PARENT_DIR)

# Third party imports
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
import numpy as np 
from tqdm import tqdm
import torch.nn as nn
import math
import torch.optim as optim
from torch.distributions import Normal
from sklearn.metrics import r2_score
from lag_caVAE.lag import Lag_Net
from lag_caVAE.leap import Leap_Net
from lag_caVAE.leap import TF_Block_EXP_Residual_TV_2
from lag_caVAE.leap import TransformerEncoderLayerCategoricalsCatPos
from lag_caVAE.leap import TransformerEncoderLayer_v2_CatPos
from lag_caVAE.leap import TransformerEncoderLayer_v2_CategoricalsCatPos
from lag_caVAE.leap import EstimatorNetwork
from lag_caVAE.leap import TransformerEncoderLayerCategoricals
from lag_caVAE.leap import TransformerEncoderLayer_v2_Categoricals
from lag_caVAE.leap import EstimatorNetwork_v2
from lag_caVAE.leap import EstimatorNetwork_NoScale
from lag_caVAE.leap import MLP_Mag
from lag_caVAE.leap import P_Neural_TIME_MultipleParameter
from lag_caVAE.leap import TransformerEncoderLayer
from lag_caVAE.leap import TransformerEncoderLayer_v2
from lag_caVAE.leap import  PositionalEncodingCat
from lag_caVAE.nn_models import MLP_Encoder, MLP, MLP_Decoder, PSD, Encoder
from hyperspherical_vae.distributions import VonMisesFisher
from hyperspherical_vae.distributions import HypersphericalUniform
from utils import arrange_data, from_pickle, my_collate, ImageDataset, HomoImageDataset
from torch.utils.tensorboard import SummaryWriter
from scipy.integrate import odeint
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.pylab as pylab
import warnings
from matplotlib.mathtext import MathTextWarning

# Define the plotting parameters
params = {
    'legend.fontsize': 32,
    'axes.labelsize': 32,
    'axes.titlesize': 32,
    'xtick.labelsize': 32,
    'ytick.labelsize': 32,
}
pylab.rcParams.update(params)
matplotlib.rcParams['mathtext.fontset'] = 'cm'
matplotlib.rcParams['mathtext.rm'] = 'serif'
plt.rcParams["font.family"] = "cmr10"
# Shared line style and export resolution for every figure in this script
linewidth = 4
line_alpha = .8
dpi_value = 300
# Number of simulated trajectories in each dataset
num_data = 200
# Number of training iterations
num_iter = 1000
# Learning rate of the optimizer
learning_rate = 1e-2

# Output directory for plots, TensorBoard logs and the saved model.
# The rest of the script builds paths as ``save_dir + 'Plot/'``, so the value
# must end with a path separator (the trailing '' in the join adds one).
# NOTE(review): ``save_dir`` was never assigned in this script; the location
# below is a placeholder default next to this file - adjust as needed.
save_dir = os.path.join(THIS_DIR, 'vehicle_steering_results', '')

# Create the output directories (no error when they already exist; genuine
# failures such as permission errors are no longer silently swallowed)
os.makedirs(save_dir + 'Plot/', exist_ok=True)

# Initialise the TensorBoard writer
writer = SummaryWriter(save_dir)

# Function implementing the fix
def fix(ax=None):
    r"""Remove the ``\mathdefault`` macro from the minor tick labels of an axes.

    Args:
        ax: Axes whose minor tick labels are rewritten. When ``None`` the
            current axes (``plt.gca()``) is used.
    """
    if ax is None:
        ax = plt.gca()
    fig = ax.get_figure()
    # Force the figure to be drawn before the tick labels are read back,
    # silencing the mathtext warnings that the draw may emit
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', category=MathTextWarning)
        fig.canvas.draw()
    # Raw string: '\m' is not a valid escape sequence in a normal literal
    macro = r'\mathdefault'
    # Remove the macro from all minor tick labels on both axes
    x_labels = [label.get_text().replace(macro, '')
                for label in ax.get_xminorticklabels()]
    ax.set_xticklabels(x_labels, minor=True)
    y_labels = [label.get_text().replace(macro, '')
                for label in ax.get_yminorticklabels()]
    ax.set_yticklabels(y_labels, minor=True)

class vehicle_steering():
    """Two-state vehicle steering model driven by a sampled steering input.

    The state is ``(x1, x2)`` where x1 is the lateral path deviation and x2
    is the turning rate.
    """

    def __init__(self):
        super(vehicle_steering, self).__init__()

    def dynamics(self, x, t, u_list, T):
        """Return the time derivative of the state at time ``t``.

        Args:
            x: Current state, a length-2 sequence ``(x1, x2)``.
            t: Time at which the derivative is evaluated.
            u_list: Sampled steering-angle input, indexed like ``T``.
            T: Array of sample times.

        Returns:
            The list ``[dx1/dt, dx2/dt]``.
        """
        # Only x2 enters the right-hand side; unpacking still checks that the
        # state has exactly two components
        _, x2 = x

        # Use the input sample whose time stamp is closest to t
        index = abs(T - t).argmin()

        # Model parameter
        gamma = 0.1

        # Current steering angle
        u = u_list[index]

        return [x2 + gamma * u, u]

    # Define observations
    def state_2_observation(self, x, return_acceleration=False):
        """Map a state trajectory to the observed signal.

        Args:
            x: Array of shape ``(steps, 2)`` holding one state per row.
            return_acceleration: When ``False`` (default) the second state
                column is returned unchanged. When ``True`` the per-sample
                finite difference of that column is returned instead
                (central differences inside, one-sided at both ends).

        Returns:
            A 1-D array with one entry per row of ``x``.
        """
        if not return_acceleration:
            # Return the state
            return x[:, 1]

        # From turning rate to the acceleration: pad the column with its own
        # end values so the shifted copies have the same length
        previous = np.concatenate((x[:1, 1], x[:-1, 1]), axis=0)
        following = np.concatenate((x[1:, 1], x[-1:, 1]), axis=0)
        o = (following - previous) / 2.0
        # The end points only span one sample, so undo the division by two
        o[0] = o[0] * 2
        o[-1] = o[-1] * 2
        return o

class vehicle_steering_observer():
    """State observer for the two-state vehicle steering model.

    The observer copies the model dynamics and adds a correction proportional
    to the difference between the sampled observation and the estimate of x1.
    """

    def __init__(self):
        super().__init__()

    # Define observer
    def dynamics(self, x, t, u_list, o_list, T):
        """Return the time derivative of the state estimate at time ``t``.

        Args:
            x: Current estimate, a length-2 sequence ``(x1, x2)`` where x1 is
                the lateral path deviation and x2 the turning rate.
            t: Time at which the derivative is evaluated.
            u_list: Sampled steering-angle input, indexed like ``T``.
            o_list: Sampled observations, indexed like ``T``.
            T: Array of sample times.

        Returns:
            The list ``[dx1/dt, dx2/dt]``.
        """
        x1, x2 = x

        # Sample whose time stamp is closest to t
        nearest = abs(T - t).argmin()

        # Model parameter
        gamma = 0.1

        # Steering angle and observation at that sample
        u = u_list[nearest]
        o = o_list[nearest]

        # Observer gains L = (l_1, l_2) built from c0 and w0
        c0 = 1
        w0 = 1
        l_1 = 2 * c0 * w0
        l_2 = w0 ** 2

        # Difference between the observation and the current x1 estimate
        innovation = o - x1

        dx1 = x2 + gamma * u + l_1 * innovation
        dx2 = u + l_2 * innovation
        return [dx1, dx2]

# Sampling frequency
Hz = 10
# Number of simulated time steps
sim_length = 100
# Mean and standard deviation of the random steering input
mu, sigma = 0.0, 0.5
# Time grid: sim_length samples spaced 1/Hz apart
T = np.arange(0, sim_length / Hz, 1 / Hz)
# Random input excitation, one sample per time step (the dynamics only ever
# index u_list with a position in T, so extra samples would never be read)
u_list = np.random.normal(mu, sigma, len(T))
# Initial condition of the true system
x0 = [0, 0]
# Simulated system; deliberately not named ``sys`` so the imported ``sys``
# module is not rebound here
plant = vehicle_steering()
dynamics_sys = plant.dynamics
# Integrate the true state trajectory
xt = odeint(dynamics_sys, x0, T, args=(u_list, T))
# Observation derived from the state trajectory
ot = plant.state_2_observation(xt)
# Random initial guess for the observer state
x_hat0 = np.random.rand(2)
# Observer dynamics
sys_observer = vehicle_steering_observer()
dynamics_sys_observer = sys_observer.dynamics
# Integrate the observer, driven by the same input and the observation
x_hatt = odeint(dynamics_sys_observer, x_hat0, T, args=(u_list, ot, T))

# Get the arrays
xt = np.array(xt)
x_hatt = np.array(x_hatt)

# Plot true vs. estimated state, one panel per state component.
# Each entry: (grid rows, row slice, state column, true label, estimate label, y label)
fig1 = plt.figure(constrained_layout=False, figsize=(16, 12))
panel_specs = (
    (7, slice(0, 3), 0, r'$x_1$', r'$\hat{x}_1$', 'The lateral path deviation'),
    (6, slice(4, 6), 1, r'$x_2$', r'$\hat{x}_2$', 'The turning rate'),
)
for grid_rows, row_span, state_col, true_label, est_label, y_label in panel_specs:
    gs = fig1.add_gridspec(grid_rows, 10)
    ax = fig1.add_subplot(gs[row_span, 0:10])
    ax.plot(xt[:, state_col], linewidth=linewidth, alpha=line_alpha,
            linestyle='-', label=true_label, color='red')
    ax.plot(x_hatt[:, state_col], linewidth=linewidth, alpha=line_alpha,
            linestyle='--', label=est_label, color='green')
    ax.grid(True, which='both', alpha=.3)
    ax.set_xlabel('Time steps')
    ax.set_ylabel(y_label)
    fix(ax)
    ax.legend(loc='lower right')
fig1.savefig(save_dir + 'Plot/' + 'vehicle_steering' + '.png', dpi=dpi_value, bbox_inches='tight')
plt.close(fig1)

class Model(nn.Module):
    """Self-attention network that maps an observation sequence to state estimates.

    The input is passed through two linear layers, the positional encoder and
    one multi-head self-attention layer, followed by two linear layers that
    produce ``output_size`` values per sequence position.
    """

    def __init__(self, input_size=1, encoder_size=20, attention_input_size=40, nhead=1, output_size=2, dropout=0.0, max_len=sim_length, pos_en_scale=1):
        super().__init__()
        # The input embedding is half as wide as the attention layer.
        # NOTE(review): presumably PositionalEncodingCat concatenates an
        # encoding of the same width to reach attention_input_size - confirm.
        half_width = int(attention_input_size / 2)

        # Positional encoding
        self.pos_encoder = PositionalEncodingCat(half_width, dropout=dropout, max_len=max_len, scale=pos_en_scale)

        # Multi-head attention layers (self_attn_1 is not used in forward)
        self.self_attn = nn.MultiheadAttention(attention_input_size, nhead, dropout=dropout)
        self.self_attn_1 = nn.MultiheadAttention(attention_input_size, nhead, dropout=dropout)

        # Input embedding
        self.linear_input_1 = nn.Linear(input_size, encoder_size)
        self.linear_input_2 = nn.Linear(encoder_size, half_width)

        # Output head
        self.linear1 = nn.Linear(attention_input_size, attention_input_size)
        self.linear2 = nn.Linear(attention_input_size, output_size)

        # Dropout and normalisation layers; none of them is used in forward,
        # they are kept so the registered modules and parameters stay the same
        self.dropout = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(attention_input_size)
        self.norm2 = nn.LayerNorm(attention_input_size)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """Run the network on one input sequence.

        Args:
            src: Input tensor whose last dimension is ``input_size``.
            src_mask: Optional mask forwarded to the attention layer as
                ``attn_mask``.
            src_key_padding_mask: Optional mask forwarded to the attention
                layer as ``key_padding_mask``.

        Returns:
            A tuple ``(src_final, src2_attn_weight)``: the output of the
            final linear layer and the attention weights returned by the
            attention layer.
        """
        # Embed the input with two stacked linear layers (no activation)
        embedded = self.linear_input_2(self.linear_input_1(src))
        # Add positional encoding
        encoded = self.pos_encoder(embedded)
        # Self-attention: the encoded sequence is query, key and value
        attended, src2_attn_weight = self.self_attn(
            encoded, encoded, encoded,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
        )
        # Output head: two linear layers (no activation)
        src_final = self.linear2(self.linear1(attended))
        return src_final, src2_attn_weight

def get_parameter(model):
    """Print every parameter of ``model.self_attn`` and terminate the program.

    Args:
        model: Object exposing a ``self_attn`` attribute with a
            ``named_parameters()`` method.

    Raises:
        SystemExit: Always, with the message ``'Done'``.
    """
    for name, param in model.self_attn.named_parameters():
        print(name, param)
    # Raise SystemExit directly (this is what ``sys.exit`` does) so the exit
    # does not depend on the name ``sys`` still referring to the standard
    # module at call time
    raise SystemExit('Done')

# Define the function for testing the model
def _plot_state_estimates(x_true, x_pred, loss_value, file_path):
    """Save a two-panel figure comparing true and predicted states of one sample."""
    fig1 = plt.figure(constrained_layout=False, figsize=(16, 12))

    # Top panel: first state column
    gs = fig1.add_gridspec(7, 10)
    ax = fig1.add_subplot(gs[0:3, 0:10])
    ax.plot(x_true[:, 0], linewidth=linewidth, alpha=line_alpha, linestyle='-', label=r'$x(1)$', color='red')
    ax.plot(x_pred[:, 0], linewidth=linewidth, alpha=line_alpha, linestyle='--', label=r'$\hat{x}(1)$', color='green')
    ax.grid(True, which='both', alpha=.3)
    ax.set_xlabel('Time steps')
    ax.set_ylabel(r'Lateral path deviation $d$')
    fix(ax)
    ax.legend(loc='lower right')

    # Bottom panel: second state column, titled with the overall loss
    gs = fig1.add_gridspec(6, 10)
    ax = fig1.add_subplot(gs[4:6, 0:10])
    ax.plot(x_true[:, 1], linewidth=linewidth, alpha=line_alpha, linestyle='-', label=r'$x(2)$', color='red')
    ax.plot(x_pred[:, 1], linewidth=linewidth, alpha=line_alpha, linestyle='--', label=r'$\hat{x}(2)$', color='green')
    ax.grid(True, which='both', alpha=.3)
    ax.set_title(r'Final MSE: ' + str(loss_value))
    ax.set_xlabel('Time steps')
    ax.set_ylabel(r'Turning rate $v$')
    fix(ax)
    ax.legend(loc='lower right')

    fig1.savefig(file_path, dpi=dpi_value, bbox_inches='tight')
    plt.close(fig1)


def _plot_attention(attn, file_path):
    """Save a heat map of the first matrix in one attention-weight array."""
    fig1 = plt.figure(constrained_layout=False, figsize=(5, 5))
    gs = fig1.add_gridspec(1, 1, width_ratios=[1.0])
    ax = fig1.add_subplot(gs[0, 0])
    # attn[0] is the first entry along the leading axis
    # NOTE(review): presumably the (target, source) weight matrix of the
    # single sequence in the batch - confirm against nn.MultiheadAttention.
    im = ax.imshow(attn[0], cmap='inferno', extent=[-.5, .5, .5, -.5], vmin=0.0, vmax=0.2)
    fig1.colorbar(im, orientation='vertical')
    ax.xaxis.tick_top()
    # Only label the two image borders
    extent = [-.5, .5]
    tick_labels = [fr'${t:g}$' for t in extent]
    ax.set_xticks(extent)
    ax.set_yticks(extent)
    ax.set_xticklabels(tick_labels)
    ax.set_yticklabels(tick_labels)
    ax.set_title('Attention weight')
    plt.tight_layout()
    fig1.savefig(file_path, dpi=dpi_value)
    plt.close(fig1)


# Define the function for testing the model
def testing(model, X, O, training_index):
    """Evaluate ``model`` on a dataset and save diagnostic plots.

    Every observation sequence in ``O`` is fed to the model one at a time.
    The loss on the first state column is computed over the whole dataset,
    and state-comparison and attention plots are written to
    ``save_dir + 'Plot/'`` for at most the first ten samples.

    Args:
        model: Callable returning ``(prediction, attention_weight)`` for an
            input of shape ``(steps, 1, 1)``.
        X: True state trajectories, indexable as ``[sample, step, state]``.
        O: Observation sequences, indexable as ``[sample, step]``.
        training_index: Label inserted into the output file names (the
            script passes the iteration number or ``'_final'``).
    """
    max_plots = 10

    # Convert the whole dataset once instead of once per sample
    X_true = torch.tensor(X).float()
    observations = torch.tensor(O).float()
    if len(observations) == 0:
        # Nothing to evaluate or plot
        return

    X_pred_list = []
    attention_weight_list = []

    # Evaluate in eval mode without autograd bookkeeping, then restore the
    # mode the model was in so an ongoing training run is unaffected
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for y_batch in observations:
                # (steps,) -> (steps, 1, 1)
                y_batch = y_batch.unsqueeze(1).unsqueeze(2)
                X_pred, attention_weight = model(y_batch)
                X_pred_list.append(X_pred.squeeze())
                attention_weight_list.append(attention_weight)
    finally:
        model.train(was_training)

    X_pred = torch.stack(X_pred_list, dim=0)

    # Compute the loss on the first state column only: squared errors are
    # summed over time steps and averaged over samples
    loss = loss_fn(X_pred[:, :, 0], X_true[:, :, 0])
    loss = loss.sum([1]).mean()
    loss_value = loss.item()

    X_true_np = X_true.detach().cpu().numpy()
    X_pred_np = X_pred.detach().cpu().numpy()

    # State plots for the first few samples
    for ii in range(min(max_plots, len(X_pred_np))):
        _plot_state_estimates(
            X_true_np[ii],
            X_pred_np[ii],
            loss_value,
            save_dir + 'Plot/' + 'State' + '_' + str(ii) + '_test_idx' + str(training_index) + '.png',
        )

    # Attention plots for the first few samples
    for ii, attention_weight in enumerate(attention_weight_list[:max_plots]):
        _plot_attention(
            attention_weight.detach().cpu().numpy(),
            save_dir + 'Plot/' + 'Attn' + '_' + str(ii) + '_test_idx' + str(training_index) + '.png',
        )

# Generate the datasets
def _simulate_dataset(num_trajectories, sim_length=100, Hz=10, mu=0.0, sigma=0.5):
    """Simulate steering trajectories driven by random input.

    Each trajectory starts at the state ``[0, 0]`` and is driven by
    independent normally distributed steering samples.

    Args:
        num_trajectories: Number of trajectories to simulate.
        sim_length: Number of time steps per trajectory.
        Hz: Sampling frequency; time steps are ``1/Hz`` apart.
        mu: Mean of the random steering input.
        sigma: Standard deviation of the random steering input.

    Returns:
        A tuple ``(states, observations)`` of arrays with one entry per
        trajectory: the integrated states and the observation derived from
        them by ``vehicle_steering.state_2_observation``.
    """
    # Time grid shared by all trajectories
    T = np.arange(0, sim_length / Hz, 1 / Hz)
    # The model holds no per-trajectory state, so one instance is reused
    # (and not named ``sys``, which would rebind the imported module)
    plant = vehicle_steering()
    states = []
    observations = []
    for _ in range(num_trajectories):
        # One input sample per time step: the dynamics only index the input
        # with a position in T
        u_list = np.random.normal(mu, sigma, len(T))
        # Integrate from the fixed initial condition
        xt = odeint(plant.dynamics, [0, 0], T, args=(u_list, T))
        states.append(xt)
        observations.append(plant.state_2_observation(xt))
    return np.array(states), np.array(observations)


# Simulation settings
# Sampling frequency
Hz = 10
# Number of simulated time steps
sim_length = 100
# Mean and standard deviation of the random steering input
mu, sigma = 0.0, 0.5

# Training dataset
X, O = _simulate_dataset(num_data, sim_length, Hz, mu, sigma)

# Testing dataset, drawn independently with the same settings
X_test, O_test = _simulate_dataset(num_data, sim_length, Hz, mu, sigma)

# Initialize the model
Attention_model = Model(max_len=sim_length, pos_en_scale=1.0)
# Element-wise squared error; the reduction is done explicitly below
loss_fn = torch.nn.MSELoss(reduction='none')
# Define the optimizer
optimizer = optim.Adam(Attention_model.parameters(), lr=learning_rate, weight_decay=0.0)

# Convert the training data to tensors once, outside the training loop
X_train = torch.tensor(X).float()
O_train = torch.tensor(O).float()

# Training iterations; every iteration is one full-batch gradient step
for i_iter in tqdm(range(num_iter)):

    # Run the model on every training sequence, one at a time
    X_pred_list = []
    for y_batch in O_train:
        # (steps,) -> (steps, 1, 1)
        y_batch = y_batch.unsqueeze(1).unsqueeze(2)
        X_pred, _ = Attention_model(y_batch)
        X_pred_list.append(X_pred.squeeze())
    X_pred_list = torch.stack(X_pred_list, dim=0)

    # Compute the loss on the first state column only: squared errors are
    # summed over time steps and averaged over samples
    loss = loss_fn(X_pred_list[:, :, 0], X_train[:, :, 0])
    loss = loss.sum([1]).mean()

    # Reset the gradient
    optimizer.zero_grad()
    # Compute the gradient
    loss.backward()
    # Update the parameters
    optimizer.step()
    print('i:', i_iter, ':', loss.item())

    # Log the running loss
    writer.add_scalar('Loss', loss.item(), i_iter)

    # Periodically evaluate on the test set and save plots
    if i_iter % 100 == 0:
        testing(Attention_model, X_test, O_test, i_iter)

# Final evaluation, then persist the trained weights
testing(Attention_model, X_test, O_test, '_final')
torch.save(Attention_model.state_dict(), save_dir + 'Model.pth')
writer.close()
