# Standard library imports
from argparse import ArgumentParser
import os, sys
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
PARENT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(PARENT_DIR)

# Third party imports
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from torchdiffeq import odeint
import matplotlib.pyplot as plt
import numpy as np 
from tqdm import tqdm
from torch.distributions import Normal
from sklearn.metrics import r2_score

# local application imports
from lag_caVAE.lag import Lag_Net
from lag_caVAE.leap import Leap_Net_TB
from lag_caVAE.leap import TF_Block_EXP_Residual_TV_2
from lag_caVAE.leap import TransformerEncoderLayerCategoricalsCatPos
from lag_caVAE.leap import EstimatorNetworkTB
from lag_caVAE.leap import TransformerEncoderLayerCategoricals
from lag_caVAE.leap import TransformerEncoderLayer_v2_Categoricals
from lag_caVAE.leap import TransformerEncoderLayer_v2_CatPos_TB
from lag_caVAE.leap import EstimatorNetwork_v2
from lag_caVAE.leap import EstimatorNetwork_NoScale
from lag_caVAE.leap import EstimatorNetworkTB_v2
from lag_caVAE.leap import MLP_Mag
from lag_caVAE.leap import P_Neural_TIME_MultipleParameterTB
from lag_caVAE.leap import TransformerEncoderLayer
from lag_caVAE.leap import TransformerEncoderLayer_v2_2C
from lag_caVAE.nn_models import MLP_Encoder, MLP, MLP_Decoder, PSD, Encoder, MLP_Encoder_Sigmoid
from hyperspherical_vae.distributions import VonMisesFisher
from hyperspherical_vae.distributions import HypersphericalUniform
from utils import arrange_data, from_pickle, my_collate, ImageDataset, HomoImageDataset

# ---------------------------------------------------------------------------
# Experiment configuration. These module-level constants are read by the
# Model class below and several of them are encoded into ``save_dir``.
# ---------------------------------------------------------------------------

# --- data / horizon ---
T_pred = 50               # prediction length (time steps)
Loss_first_index = 50     # the loss is computed over the first n indices
cutoff_freq_input = 10    # number of lowest FFT bins (low -> high) kept as features
dataset_type = 3          # which dataset to use

# --- optimisation ---
lr = 1e-3                 # learning rate
num_batch = 32            # batch size (gap = 20 -> 50 batches)
gradient_clip = 1         # gradient clipping value

# --- model switches ---
MLP_Encoder_nonLinear = 'tanh'  # non-linearity of the q net and recons net: tanh, elu, softplus or relu
MODEL_TYPE = 0            # ODE-net input: 0 -> FFT features, 2 -> time-series data
Plot_enable = 0           # plotting on/off
enable_attn = 1           # use the attention encoder
model_variant = 4         # 2 -> position via MLP, velocity via attention; 4 -> position and velocity from one network
model_variant_attn = 4    # 1 -> attention v2 with extra layer norm and skip connection; 4 -> as 1, with concatenated positions
enable_physics = 0        # 1 -> integrate with the ODE solver; 0 -> no physics

# --- loss weights ---
weight_recons = 1             # reconstruction loss weight
weight_recons_255 = 1e1       # reconstruction weight (encoded in save_dir as _Rw)
Time_loss_weight = 1e0        # aligns the ODE-solver states with the encoder states
Freq_loss_weight = 0.0        # frequency loss weight
enable_weight_sparse = 1e0    # weight promoting sparsity
FFT_loss_weight = 0           # FFT cosine-theta loss weight
velocity_loss_enable = 0.0    # frame-velocity loss weight
gap_interval = 200            # step between the start offsets of the source masks

# --- attention encoder ---
att_len = 7               # mask half-width; band width is 2*att_len+1 (1->3, 2->5, 7->15, 12->25, 50->101)
d_model_attn = 500        # feature size after the CNN and flattening (previously 300)
pos_size = 250            # size of the positional encoding
nhead_attn = 10           # number of attention heads
d_middle_attn = 100       # passed as d_middle; the original notes mark it as unused
d_final_attn = 6          # not read by Model.__init__ (which passes d_final=3*2); still part of save_dir
dropout_attn = 0.0
pos_en_scale_attn = 0.5
attn_nonlinearity = 'relu'
attn_nonlinearity_1 = 'tanh'
cnn_pooling = 'max'       # 'max' or 'avg'

# --- estimator / decoder networks ---
nonlinearity = 0          # estimator non-linearity: 0 tanh, 1 softplus (did not work well), 2 elu
dec_size = 100            # middle size argument of the decoder MLPs (presumably the hidden width)
weight_sparse = 0         # weight of the sparse loss
# Settings for the attention encoder. Default pos_en_scale is 1.0;
# with a Gaussian head use d_final=5.
NN_parameter = dict(
    d_model=d_model_attn,
    pos_size=pos_size,
    nhead=nhead_attn,
    d_middle=d_middle_attn,
    d_final=d_final_attn,
    dropout=dropout_attn,
    pos_en_scale=pos_en_scale_attn,
    nonlinearity=attn_nonlinearity,
    nonlinearity_1=attn_nonlinearity_1,
    pooling=cnn_pooling,
)
gpu_number = 0

# --- misc ---
small_value = 1e-5            # small constant added to norms to prevent NaN
Hz = 10                       # simulation frequency
NN_input_layer = [100, 100]   # passed as input_layer to the estimator networks
training_loss_interval = 15   # interval at which the training loss alternates
mean_over_everything = 0
TEST_True_INIT_STATE = 0      # forward's initial state: 1 -> true state from S, 0 -> built from the encoder, 2 -> encoder state as-is

# ---------------------------------------------------------------------------
# Run / checkpoint directory name. Physics-enabled runs additionally encode
# the sparsity weight (_SL), the frequency loss weight (_FL) and pos_size;
# the no-physics name omits them and ends with '_NoPhysics'. The exact
# format is kept stable so existing run directories keep matching.
# ---------------------------------------------------------------------------
_sl_tag = f'_SL{enable_weight_sparse}' if enable_physics else ''
_fl_tag = f'_FL{Freq_loss_weight}' if enable_physics else ''
_pos_tag = f'_{pos_size}' if enable_physics else ''
_run_suffix = '' if enable_physics else '_NoPhysics'
save_dir = (
    f'_Da{dataset_type}'
    f'_STL{Time_loss_weight}'
    f'_SFL{FFT_loss_weight}'
    f'_VFL{velocity_loss_enable}'
    f'{_sl_tag}'
    f'_S{mean_over_everything}'
    f'_SPL{weight_sparse}'
    f'{_fl_tag}'
    f'_Model{model_variant}-{model_variant_attn}-{MODEL_TYPE}'
    f'_Attn{enable_attn}_{d_model_attn}{_pos_tag}'
    f'-{nhead_attn}-{d_middle_attn}-{d_final_attn}-{dropout_attn}'
    f'-{pos_en_scale_attn}-{attn_nonlinearity}-{attn_nonlinearity_1}'
    f'-{cnn_pooling}-{att_len}'
    f'_Est{nonlinearity}-{NN_input_layer[0]}-{NN_input_layer[1]}'
    f'_Dec{dec_size}-{MLP_Encoder_nonLinear}'
    f'_Lrate{lr}_batch{num_batch}_gradC{gradient_clip}_gapInt{gap_interval}'
    f'_INITS{TEST_True_INIT_STATE}_Rw{weight_recons_255}_h1C_len{T_pred}'
    f'{_run_suffix}'
)
del _sl_tag, _fl_tag, _pos_tag, _run_suffix

# Seed the random number generators through Lightning's seed_everything so
# that runs are reproducible; this happens before the Model (and its
# sub-networks) is constructed.
seed_everything(42)

class Model(pl.LightningModule):

    def __init__(self, hparams, data_path=None):
        """Build the encoders, decoders and ODE/estimator networks.

        Args:
            hparams: hyper-parameter namespace. ``T_pred`` is read here;
                ``train_dataloader`` additionally reads ``homo_u`` and
                ``batch_size``.
            data_path: dataset location handed to the dataset constructors in
                ``train_dataloader``.

        Which sub-networks exist depends on the module-level switches
        ``enable_attn``, ``model_variant``, ``model_variant_attn`` and
        ``MODEL_TYPE``.

        NOTE: sub-modules are created in a fixed order after
        ``seed_everything(42)``. If their constructors draw from the global RNG
        (as standard torch layers do), reordering them changes the initial
        weights.
        """
        super(Model, self).__init__()

        # NOTE(review): direct assignment to self.hparams depends on the
        # installed PyTorch Lightning version — confirm.
        self.hparams = hparams
        self.data_path = data_path
        self.T_pred = self.hparams.T_pred
        # Element-wise MSE (for custom reductions) and plain mean MSE.
        self.loss_fn = torch.nn.MSELoss(reduction='none')
        self.loss_fn_mean = torch.nn.MSELoss()
        # Flattened 64x64 frame.
        self.size_image = 64*64
        # State width handed to the ODE nets; forward builds 8 columns per step
        # (q_m1, p_m1, q_m2, p_m2, each 2-D).
        self.input_dim  = 8
        # Output offset / magnitude handed to the estimator and ODE networks.
        self.plu_output = 1.0
        self.mul_output = 0.5
        # For plotting purpose
        self.count = 0
        # Number of low-frequency FFT bins used as features.
        self.cutoff_index_input = cutoff_freq_input

        # Define encoder and decoder
        if enable_attn:
            if model_variant == 2:
                # Position encoder (MLP) over flattened frames, 3*2 outputs.
                self.recog_q_net_m1 = MLP_Encoder(self.size_image, 500, 3*2, nonlinearity=MLP_Encoder_nonLinear)

                # NOTE(review): when model_variant_attn != 1 no velocity network is
                # created, yet encode_self_attention calls self.recog_q_net_velocity
                # for model_variant == 2.
                if model_variant_attn == 1:
                    #3C -> three channels for r g b images
                    self.recog_q_net_velocity = TransformerEncoderLayer_v2_CatPos_TB(\
                                               pos_size=NN_parameter['pos_size'],\
                                               d_model=NN_parameter['d_model'],\
                                               nhead=NN_parameter['nhead'], \
                                               d_middle=NN_parameter['d_middle'], \
                                               d_final=3*2, \
                                               dropout=NN_parameter['dropout'], \
                                               max_len=self.T_pred+1,
                                               pos_en_scale=NN_parameter['pos_en_scale'], \
                                               activation=NN_parameter['nonlinearity'],\
                                               activation_1=NN_parameter['nonlinearity_1'],\
                                               pooling=NN_parameter['pooling'])
            elif model_variant == 4:
                if model_variant_attn == 4:
                    # A single NN that process all the information
                    self.recog_q_net_state = TransformerEncoderLayer_v2_CatPos_TB(\
                                               pos_size=NN_parameter['pos_size'],\
                                               d_model=NN_parameter['d_model'],\
                                               nhead=NN_parameter['nhead'], \
                                               d_middle=NN_parameter['d_middle'], \
                                               d_final=3*2, \
                                               dropout=NN_parameter['dropout'], \
                                               max_len=self.T_pred+1,
                                               pos_en_scale=NN_parameter['pos_en_scale'], \
                                               activation=NN_parameter['nonlinearity'],\
                                               activation_1=NN_parameter['nonlinearity_1'],\
                                               pooling=NN_parameter['pooling'])
        else:
            # only predict the position, the velocity is computed by subtraction
            self.recog_q_net_m1 = MLP_Encoder(self.size_image, 300, 3*1, nonlinearity=MLP_Encoder_nonLinear)

        # Decoders: 1-dimensional input -> flattened image (size_image outputs).
        self.obs_net_m1 = MLP_Encoder(1, dec_size, self.size_image, nonlinearity=MLP_Encoder_nonLinear)
        self.obs_net_m2 = MLP_Encoder(1, dec_size, self.size_image, nonlinearity=MLP_Encoder_nonLinear)

        # NOTE(review): self.device is read at construction time below;
        # presumably the CPU at this point — confirm the ODE nets follow later device moves.
        if MODEL_TYPE == 0:
            # Define ODE Net for the input of FFT features
            self.MLP_Spec_m1 = EstimatorNetworkTB(input_size=[self.cutoff_index_input,self.cutoff_index_input],input_layer=NN_input_layer,nonlinearity=nonlinearity,mag=self.mul_output)#.to(self.device)
            self.MLP_Spec_m2 = EstimatorNetworkTB(input_size=[self.cutoff_index_input,self.cutoff_index_input],input_layer=NN_input_layer,nonlinearity=nonlinearity,mag=self.mul_output)#.to(self.device)
            self.ode = Leap_Net_TB(MLP_Spec_m1=self.MLP_Spec_m1, \
                                MLP_Spec_m2=self.MLP_Spec_m2, \
                                mul_output=self.mul_output,\
                                plu_output=self.plu_output, \
                                cutoff_index=self.cutoff_index_input,\
                                device=self.device,\
                                input_dim=self.input_dim)
        elif MODEL_TYPE == 2:
            # Define ODE Net for the input of time features
            # NOTE(review): sized with the module-level T_pred, not self.T_pred.
            self.MLP_Spec_m1 = MLP_Mag(input_size=8*T_pred,input_layer=NN_input_layer,nonlinearity=nonlinearity,mag=self.mul_output)
            self.MLP_Spec_m2 = MLP_Mag(input_size=8*T_pred,input_layer=NN_input_layer,nonlinearity=nonlinearity,mag=self.mul_output)
            self.ode = P_Neural_TIME_MultipleParameterTB(MLP_Spec_m1=self.MLP_Spec_m1, \
                                                       MLP_Spec_m2=self.MLP_Spec_m2,\
                                                       mul_output=self.mul_output, \
                                                       plu_output=self.plu_output, \
                                                       cutoff_index=self.cutoff_index_input, \
                                                       device=self.device, \
                                                       input_dim=self.input_dim)

        # Created lazily by train_dataloader when hparams.homo_u is set.
        self.train_dataset = None
        # Next non-zero control index used by train_dataloader (cycles 1..8).
        self.non_ctrl_ind = 1


        # Generate the src mask: one band mask per start offset i (step
        # gap_interval), each of size T_pred+1-i.
        # NOTE(review): uses the module-level T_pred rather than self.T_pred
        # (from hparams), while the attention nets use max_len=self.T_pred+1 —
        # confirm the two always agree.
        self.SRC_MAS_V = []
        for i in np.arange(0,T_pred-cutoff_freq_input ,gap_interval):
            self.SRC_MAS_V.append(self.src_mask(T_pred+1-i))

        self.training_loss_flag = 0

        # Parameter snapshots used by print_parameter to detect updates.
        self.h_para = None
        self.f_para = None
        self.g_para = None

    def train_dataloader(self):
        """Return the training DataLoader and refresh ``self.t_eval``.

        Without ``hparams.homo_u`` (the default, all controls u are zero) a new
        ImageDataset is built on every call.

        With ``hparams.homo_u`` one HomoImageDataset is created lazily, kept in
        ``self.train_dataset``, and its ``u_idx`` is updated on every call:

        * epochs below 1000: even epochs use ``u_idx = 0``; odd epochs use
          ``self.non_ctrl_ind``, which then advances through 1..8;
        * later epochs: ``u_idx = current_epoch % 9``.

        This schedule only takes effect if the trainer reloads dataloaders
        every epoch (reload_dataloaders_every_epoch=True).
        """
        if not self.hparams.homo_u:
            # This is our default setting since all the u is zero
            dataset = ImageDataset(self.data_path, self.hparams.T_pred)
            self.t_eval = torch.from_numpy(dataset.t_eval)
            return DataLoader(dataset, batch_size=self.hparams.batch_size,
                              shuffle=True, collate_fn=my_collate)

        if self.train_dataset is None:
            self.train_dataset = HomoImageDataset(self.data_path, self.hparams.T_pred)

        epoch = self.current_epoch
        if epoch >= 1000:
            u_idx = epoch % 9
        elif epoch % 2 == 0:
            # Zero-control dataset on even epochs.
            u_idx = 0
        else:
            # Controlled datasets in turn on odd epochs.
            u_idx = self.non_ctrl_ind
            self.non_ctrl_ind += 1
            if self.non_ctrl_ind == 9:
                self.non_ctrl_ind = 1

        self.train_dataset.u_idx = u_idx
        self.t_eval = torch.from_numpy(self.train_dataset.t_eval)
        return DataLoader(self.train_dataset, batch_size=self.hparams.batch_size,
                          shuffle=True, collate_fn=my_collate)

    def angle_vel_est(self, q0_m_n, q1_m_n, delta_t):
        q_dot0 = (q1_m_n - q0_m_n) / delta_t
        return q_dot0

    def angle_vel_est_euler(self, q0, delta_t):

        T = q0.shape[0]
        theta_dot = (q0[1:T]-q0[0:T-1]) / delta_t

        return theta_dot.unsqueeze(1)

    def get_system_parameter(self, x):

        # Set the index
        i0 = torch.LongTensor([0])
        i1 = torch.LongTensor([1])

        # Get the state
        x1 = x[i0]
        x2 = x[i1]

        # The state is [cos,sin,theta_dot,u]

        # For cos\theta
        # Add i part
        state = torch.unsqueeze(x[2:], 1).to(self.device)
        state = torch.cat((state,torch.zeros((state.shape[0],1)).to(self.device)),axis=1)
        cos_theta = state[0::self.input_dim,:]
        # Use FTT
        cos_theta_fft = torch.fft(cos_theta,1,normalized=False).to(self.device)
        # Get mag
        cos_theta_mag = cos_theta_fft[:,0] ** 2 + cos_theta_fft[:,1] ** 2
        cos_theta_mag = cos_theta_mag[0:self.cutoff_index_input]**0.5

        # For sin\theta
        # Add i part
        state = torch.unsqueeze(x[2:], 1).to(self.device)
        state = torch.cat((state,torch.zeros((state.shape[0],1)).to(self.device)),axis=1)
        sin_theta = state[1::self.input_dim,:]
        # Use FTT
        sin_theta_fft = torch.fft(sin_theta,1,normalized=False).to(self.device)
        # Get mag
        sin_theta_mag = sin_theta_fft[:,0] ** 2 + sin_theta_fft[:,1] ** 2
        sin_theta_mag = sin_theta_mag[0:self.cutoff_index_input]**0.5

        # For theta_dot
        # Add i part
        state = torch.unsqueeze(x[2:], 1).to(self.device)
        state = torch.cat((state,torch.zeros((state.shape[0],1)).to(self.device)),axis=1)
        theta_dot = state[2::self.input_dim,:] / 7.0 # do normalization
        # Use FTT
        theta_dot_fft = torch.fft(theta_dot,1,normalized=False).to(self.device)
        # Get mag
        theta_dot_mag = theta_dot_fft[:,0] ** 2 + theta_dot_fft[:,1] ** 2
        theta_dot_mag = theta_dot_mag[0:self.cutoff_index_input]**0.5

        # Only Freq
        mu = self.MLP_Spec_mu(torch.log(cos_theta_mag),\
                              torch.log(sin_theta_mag),\
                              torch.log(theta_dot_mag)) + self.plu_output
        L = self.MLP_Spec_L(torch.log(cos_theta_mag),\
                            torch.log(sin_theta_mag),\
                            torch.log(theta_dot_mag)) + self.plu_output

        return mu, L, cos_theta_mag, theta_dot_mag

    def get_system_parameter_time(self, x):

        # Set the index
        i0 = torch.LongTensor([0])
        i1 = torch.LongTensor([1])

        # Get the state
        x1 = x[i0]
        x2 = x[i1]

        x_input = x[2:]

        # Remove u information in x_input
        cos_theta = x_input[0::self.input_dim]
        sin_theta = x_input[1::self.input_dim]
        theta_dot = x_input[2::self.input_dim]
        x_input = torch.cat((cos_theta,sin_theta,theta_dot),axis=0)

        # Get the prediction
        mu = self.MLP_Spec_mu(x_input) + self.plu_output
        L = self.MLP_Spec_L(x_input) + self.plu_output

        return mu, L, None, None

    def get_theta_inv(self, cos, sin, x, y, bs=None):
        bs = self.bs if bs is None else bs
        theta = torch.zeros([bs, 2, 3], dtype=self.dtype, device=self.device)
        theta[:, 0, 0] += cos ; theta[:, 0, 1] += -sin ; theta[:, 0, 2] += - x * cos + y * sin
        theta[:, 1, 0] += sin ; theta[:, 1, 1] += cos ;  theta[:, 1, 2] += - x * sin - y * cos
        return theta

    def encode(self, batch_image, d, T):
        """Encode one frame sequence into position distribution parameters (no attention).

        Channel 1 of each frame is flattened and passed through
        ``self.recog_q_net_m1``. The squeezed output is split along dim 1 into
        a 2-D mean and a 1-D raw concentration. ``forward`` uses these as the
        parameters of a VonMisesFisher distribution.

        Args:
            batch_image: frames reshapeable to ``(T, 3, d*d)``.
            d: image side length.
            T: number of frames.

        Returns:
            ``(mean, concentration, unit_mean)``. ``unit_mean`` is
            ``mean / (||mean|| + small_value)`` and ``concentration`` is
            ``softplus(raw) + 1``.
        """
        frames = batch_image.reshape(T, 3, d * d)
        encoded = self.recog_q_net_m1(frames[:, 1]).squeeze()

        mean, raw_concentration = encoded.split([2, 1], dim=1)
        unit_mean = mean / (mean.norm(dim=-1, keepdim=True) + small_value)
        concentration = F.softplus(raw_concentration) + 1

        return mean, concentration, unit_mean

    def src_mask(self, dim):
        #https://discuss.pytorch.org/t/how-to-add-padding-mask-to-nn-transformerencoder-module/63390/2
        mask = torch.zeros(dim,dim).float() + float('-inf')
        # Define attend half range
        for i in range(dim):
            min_ = max(0, i-att_len)
            max_ = min(dim, i+att_len+1)
            for j in range(min_,max_):
                mask[i,j] = 0.0
        return mask


    def get_FourierFeat(self,x):
        state = torch.unsqueeze(x, 1).to(self.device)
        state = torch.cat((state,torch.zeros((state.shape[0],1)).to(self.device)),axis=1)
        ss_list = []
        for ff in [0,1,2,3]:
            ss = state[ff::4,:]
            # Use FTT
            ss_fft = torch.fft(ss,1,normalized=False).to(self.device)
            # Get mag
            ss_fft_mag = ss_fft[:,0] ** 2 + ss_fft[:,1] ** 2
            ss_fft_mag = (ss_fft_mag[0:self.cutoff_index_input]+1e-5)**0.5
            ss_fft_mag = torch.log(ss_fft_mag)
            ss_list.append(ss_fft_mag)

        return ss_list

    def print_parameter(self):
        """Report whether the encoder ("h") and the ODE estimator ("f") changed.

        Compares the first parameter tensor of ``self.recog_q_net_state`` (h)
        and of ``self.MLP_Spec_m1`` (f) with the snapshots stored on the
        previous call. For each, it prints "<name> has been updated" or
        "<name> is fixed". Nothing is printed on the first call. Snapshots are
        refreshed on every call.

        NOTE(review): ``recog_q_net_state`` is only created for
        ``model_variant == 4`` with ``model_variant_attn == 4``.

        Raises:
            ValueError: if either network has no parameters.
        """
        def first_param(module, label):
            # First registered parameter of ``module``.
            for param in module.parameters():
                return param
            raise ValueError(f'{label} has no parameters')

        h_param = first_param(self.recog_q_net_state, 'recog_q_net_state')
        f_param = first_param(self.MLP_Spec_m1, 'MLP_Spec_m1')

        for label, previous, current in (('h', self.h_para, h_param),
                                         ('f', self.f_para, f_param)):
            # Identity check: comparing a tensor with ``!= None`` is not a
            # reliable None test.
            if previous is None:
                continue
            if (previous != current).any():
                print(f'{label} has been updated')
            else:
                print(f'{label} is fixed')

        # Detached copies so the snapshots carry no autograd history.
        self.h_para = h_param.detach().clone()
        self.f_para = f_param.detach().clone()

    def get_position(self, x, y, scale, offset, ratio, bs=None):
        # https://en.wikipedia.org/wiki/Transformation_matrix#/media/File:2D_affine_transformation_matrix.svg
        bs = self.bs if bs is None else bs
        M = torch.zeros([bs, 2, 3], dtype=self.dtype, device=self.device)
        M[:, 0, 0] +=  -1 ; M[:, 0, 1] += 0 ;  M[:, 0, 2] += ((x * scale) + offset) / ratio
        M[:, 1, 0] +=   0 ; M[:, 1, 1] += 1 ;  M[:, 1, 2] += ((y * scale) + offset) / ratio
        # Remark: The ratio here is to transform between two plane. The data generation one is [-2.2,2.2]. 
        # Now, we are plotting in the range of [-1,1] 
        return M       

    def get_position_noreverse(self, x, y, scale, offset, ratio, bs=None):
        # https://en.wikipedia.org/wiki/Transformation_matrix#/media/File:2D_affine_transformation_matrix.svg
        bs = self.bs if bs is None else bs
        M = torch.zeros([bs, 2, 3], dtype=self.dtype, device=self.device)
        M[:, 0, 0] +=   1 ; M[:, 0, 1] += 0 ;  M[:, 0, 2] += ((x * scale) + offset) / ratio
        M[:, 1, 0] +=   0 ; M[:, 1, 1] += 1 ;  M[:, 1, 2] += ((y * scale) + offset) / ratio
        # Remark: The ratio here is to transform between two plane. The data generation one is [-2.2,2.2]. 
        # Now, we are plotting in the range of [-1,1] 
        return M   

    def encode_self_attention(self, batch_image, src_mask_v, d, T):
        """Encode one frame sequence with the attention encoder.

        Channel 1 of each frame is used. ``src_mask_v`` is moved to
        ``self.device`` and passed as ``src_mask`` to the attention network.
        Each 2-D mean is returned as ``(mean, concentration, unit_mean)`` with
        ``concentration = softplus(raw) + 1`` and
        ``unit_mean = mean / (||mean|| + small_value)``.

        Behaviour depends on the module-level ``model_variant``:

        * ``2``: positions (qx, qy) from ``self.recog_q_net_m1`` and
          velocities (px, py) from ``self.recog_q_net_velocity``. Returns 14
          values: the triples for qx, qy, px and py, then ``[]`` and the
          velocity attention weights.
        * ``4``: position and velocity from ``self.recog_q_net_state``.
          Returns 8 values: the position triple, the velocity triple, and the
          attention weights twice.
        * otherwise: returns ``None``.

        NOTE(review): ``forward`` unpacks exactly 8 values, so variant 2 does
        not match the visible caller — confirm before enabling it.
        """
        def unit(mean):
            return mean / (mean.norm(dim=-1, keepdim=True) + small_value)

        def concentration(raw):
            return F.softplus(raw) + 1

        if model_variant == 2:
            frames = batch_image.reshape(T, 3, d * d)
            position_weights = []
            position_out = self.recog_q_net_m1(frames[:, 1])
            # The source mask restricts which time steps each step attends to.
            velocity_out, velocity_weights = self.recog_q_net_velocity(
                frames[:, 1, :].unsqueeze(1), src_mask=src_mask_v.to(self.device))
            position_out = position_out.squeeze()
            velocity_out = velocity_out.squeeze()

            qx, qx_raw, qy, qy_raw = position_out.split([2, 1, 2, 1], dim=1)
            px, px_raw, py, py_raw = velocity_out.split([2, 1, 2, 1], dim=1)

            return (qx, concentration(qx_raw), unit(qx),
                    qy, concentration(qy_raw), unit(qy),
                    px, concentration(px_raw), unit(px),
                    py, concentration(py_raw), unit(py),
                    position_weights, velocity_weights)

        if model_variant == 4:
            frames = batch_image.reshape(T, 3, d * d)
            sequence = frames[:, 1, :].unsqueeze(1)
            # The source mask restricts which time steps each step attends to.
            state_out, weights = self.recog_q_net_state(
                sequence, src_mask=src_mask_v.to(self.device))
            state_out = state_out.squeeze()

            pos, pos_raw, vel, vel_raw = state_out.split([2, 1, 2, 1], dim=1)

            return (pos, concentration(pos_raw), unit(pos),
                    vel, concentration(vel_raw), unit(vel),
                    weights, weights)

        return None

    def forward(self, X, u, S, TIME_INDEX, src_mask_v, mass_full_length_list=None):




        # Freeze the parameter
        #for param in self.obs_net_m1.parameters():
        #    param.requires_grad = False
        #for param in self.obs_net_m2.parameters():
        #    param.requires_grad = False
        
        X = X / 255.0

        [T, self.bs, c, d, d] = X.shape
        #T = len(self.t_eval)

        # Get the content
        #self.content_m1 = self.get_content(X, S, d)
        #self.content_m2 = self.get_content(X, S, d)

        x_enc_list = []
        x_sim_list = []
        x_sim_Fourier_list = []
        x_enc_Fourier_list = []
        m1_list = []
        m2_list = []

        Attn_output_weight_list = []
        Attn_output_weight_velocity_list = []

        for batch_ii in tqdm(range(self.bs)):

        #for batch_ii in range(self.bs):

            u = torch.zeros((T,1)).to(self.device)

            # =======Encode=======
            # Get the mean and the variance of the distribution
            if enable_attn:
                self.m1_mean_pos, self.m1_std_pos, self.m1_mean_pos_normal, \
                    self.m1_mean_vel, self.m1_std_vel, self.m1_mean_vel_normal, \
                        self.attn_output_weight, self.attn_output_weight_velocity \
                            = self.encode_self_attention(X[:,batch_ii],src_mask_v,d,T)
                Attn_output_weight_velocity_list.append(self.attn_output_weight_velocity[0])
            else:
                self.m1_mean_pos, self.m1_std_pos, self.m1_mean_pos_normal = self.encode(X[:,batch_ii],d,T)


            #m1 qx: tensor(1.3453) ~ tensor(-1.3454)
            #m1 qy: tensor(1.3459) ~ tensor(-1.3433)
            #m1 px: tensor(2.0469) ~ tensor(-1.7322)
            #m1 py: tensor(1.5796) ~ tensor(-1.8818)
            #m2 qx: tensor(1.3454) ~ tensor(-1.3453)
            #m2 qy: tensor(1.3433) ~ tensor(-1.3459)
            #m2 px: tensor(1.7322) ~ tensor(-2.0469)
            #m2 py: tensor(1.8818) ~ tensor(-1.5796)

            # Sample mean and the variance for the position
            # The prior
            self.P_VM = HypersphericalUniform(1, device=self.device)
            # The likelihood
            self.Q_m1_pos = VonMisesFisher(self.m1_mean_pos_normal, self.m1_std_pos) 
            self.q_m1_q = self.Q_m1_pos.rsample().to(self.device)
            while torch.isnan(self.q_m1_q).any():#avoid nan
                print('isnan for self.q_m1_q')
                self.q_m1_q = self.Q_m1_pos.rsample().to(self.device)
            self.q_m2_q = - self.q_m1_q

            if enable_attn:
                # Using attention to estimate the velocity
                # Sample mean and the variance
                self.Q_m1_vel = VonMisesFisher(self.m1_mean_vel_normal, self.m1_std_vel) 
                # Need to resample to ensure the one sign of velocity is position, and the other is negative
                self.q_m1_p = self.Q_m1_vel.rsample().to(self.device)
                while torch.isnan(self.q_m1_p).any():#avoid nan
                    print('isnan for self.q_m1_p')
                    self.q_m1_p = self.Q_m1_vel.rsample().to(self.device)
                # Flip the sign
                self.q_m1_p = torch.flip(self.q_m1_p, [1])
                self.q_m1_p[:,0] *= -1
                self.q_m2_p = -1 * self.q_m1_p

                # Compute the velocity using finit element
                # This is achieved by comparing two frames
                self.q_dot0_compareFrame = []
            else:
                # Estimate velocity using finit element
                self.q_m1_p = self.angle_vel_est(self.q_m1_q[0:T-1], self.q_m1_q[1:T], self.t_eval[1]-self.t_eval[0]).to(self.device)
                self.q_m1_p = torch.flip(self.q_m1_p, [1])
                self.q_m1_p[:,0] *= -1
                self.q_m2_p = - self.q_m1_p
                self.q_dot0_compareFrame = self.angle_vel_est(self.q_m1_q[0:T-1], self.q_m1_q[1:T], self.t_eval[1]-self.t_eval[0]).to(self.device)

            # Estimate euler velocity
            #self.q_dot0 = self.angle_vel_est_euler(torch.atan2(self.q0[:,1],self.q0[:,0]) + np.pi, \
            #                                       self.t_eval[1]-self.t_eval[0]).to(self.device)
            # predict
            z0_u = torch.cat((self.q_m1_q[0:T-1,:],self.q_m1_p[0:T-1,:],\
                              self.q_m2_q[0:T-1,:],self.q_m2_p[0:T-1,:]),axis=1) #torch.Size([simulation_length, 4])

            # Do the simulation
            # The following ifelse ensures that the prediction is only made when given the full length of the data
            if TIME_INDEX == 0:
                self.ode.reset_prediction(prediction_enable=1)
                self.ode.predict_mass(z0_u.flatten())
                r_pred = self.ode.obtain_predict_mass()[0]
            else:
                self.ode.reset_prediction(prediction_enable=0,m1=mass_full_length_list[batch_ii],m2=mass_full_length_list[batch_ii])
            # Change the state
            if enable_physics:
                z0_u = torch.cat((self.q_m1_q[0:T-1,:]*r_pred,self.q_m1_p[0:T-1,:]/(2.0*r_pred**0.5),\
                                  self.q_m2_q[0:T-1,:]*r_pred,self.q_m2_p[0:T-1,:]/(2.0*r_pred**0.5)),axis=1) #torch.Size([simulation_length, 4])
            else:
                z0_u = torch.cat((self.q_m1_q[0:T-1,:]*1.5,self.q_m1_p[0:T-1,:]*0.7,\
                                  self.q_m2_q[0:T-1,:]*1.5,self.q_m2_p[0:T-1,:]*0.7),axis=1) #torch.Size([simulation_length, 4])

            x_enc_list.append(z0_u.reshape((-1,8)))

            z0_u = z0_u.reshape((1,-1))
            # Append with the inital state
            #s_init = torch.cat((S[0,batch_ii,0:5],S[0,batch_ii,5:10])).unsqueeze(0)
            #print('true s_init:',s_init)
            #s_init = torch.cat((mass_1,S[0,batch_ii,1:5],mass_2,S[0,batch_ii,6:10])).unsqueeze(0)

            #s_init = z0_u[0,0:8].unsqueeze(0)
            #s_init = torch.cat((mass_1,z0_u[0,0:4],mass_2,z0_u[0,4:8])).unsqueeze(0)
            
            r_pos = 2 * (r_pred ** 1.5)
            r_pos = r_pos.unsqueeze(0)

            mass_1 = S[0,batch_ii,0].unsqueeze(0)
            mass_2 = S[0,batch_ii,5].unsqueeze(0)
            if TEST_True_INIT_STATE == 1:
                s_init = torch.cat((mass_1,S[0,batch_ii,1:5],mass_2,S[0,batch_ii,6:10])).unsqueeze(0)
            elif TEST_True_INIT_STATE == 0:
                s_init = torch.cat((mass_1,z0_u[0,0:2],(-1*z0_u[0,1]/r_pos),(   z0_u[0,0]/r_pos),\
                                    mass_2,z0_u[0,4:6],( 1*z0_u[0,1]/r_pos),(-1*z0_u[0,0]/r_pos))).unsqueeze(0)
            elif TEST_True_INIT_STATE == 2:
                s_init = torch.cat((mass_1,z0_u[0,0:4],\
                                    mass_2,z0_u[0,4:8])).unsqueeze(0)
            self.t_eval_ = self.t_eval[TIME_INDEX:]

            if enable_physics:
                zT_u = odeint(self.ode, s_init[0,:], self.t_eval_, method=self.hparams.solver) # T,299
            else:
                zT_u = z0_u.reshape((-1,8))
                zT_u = torch.cat((zT_u,zT_u[-1].unsqueeze(0)),axis=0)

            # Append the number on the list
            # We only need to store the mu and L in the first place
            if TIME_INDEX == 0:
                r_pred = self.ode.obtain_predict_mass()
                m1_list.append(r_pred.detach().cpu().numpy()[0])

            # Get the state
            if enable_physics:
                zT_u = torch.cat((zT_u[:,1:5],zT_u[:,6:10]),axis=1)

            x_sim_list.append(zT_u) # T, bs, 4


            if TIME_INDEX == 0:
                FourierFeat = self.get_FourierFeat(zT_u.flatten())
                FourierFeat = torch.stack(FourierFeat, axis=0)
                FourierFeat.retain_grad()
                x_sim_Fourier_list.append(FourierFeat)

                FourierFeat = self.get_FourierFeat(z0_u.flatten())
                FourierFeat = torch.stack(FourierFeat, axis=0)
                FourierFeat.retain_grad()
                # Save the Fourier features
                x_enc_Fourier_list.append(FourierFeat)


        # Stack the data, and retain_grad() after using stack function
        x_sim_list = torch.stack(x_sim_list, axis=0)
        x_sim_list.retain_grad()
        x_sim_list = x_sim_list.permute(1,0,2)

        x_enc_list = torch.stack(x_enc_list, axis=0)
        x_enc_list.retain_grad()
        x_enc_list = x_enc_list.permute(1,0,2)

        self.x_sim_Fourier_list = torch.stack(x_sim_Fourier_list, axis=0)
        self.x_sim_Fourier_list.retain_grad()
        self.x_enc_Fourier_list = torch.stack(x_enc_Fourier_list, axis=0)
        self.x_enc_Fourier_list.retain_grad()

        if MODEL_TYPE == 0 or MODEL_TYPE == 1:
            if enable_attn == 1:
                self.Attn_output_weight_velocity_list = torch.stack(Attn_output_weight_velocity_list, axis=0)

        # We only need to store the mu and L in the first place
        if TIME_INDEX == 0:
            self.m1_list = m1_list
            self.m2_list = m2_list
        
        self.m1_qx, self.m1_qy, self.m1_px,self.m1_py,\
            self.m2_qx, self.m2_qy, self.m2_px,self.m2_py = x_sim_list.split([1, 1, 1, 1, 1, 1, 1, 1], dim=-1)

        self.m1_qx = self.m1_qx.contiguous()
        self.m1_qx = self.m1_qx.view(T*self.bs, 1)
        self.m1_qy = self.m1_qy.contiguous()
        self.m1_qy = self.m1_qy.view(T*self.bs, 1)

        self.m2_qx = self.m2_qx.contiguous()
        self.m2_qx = self.m2_qx.view(T*self.bs, 1)
        self.m2_qy = self.m2_qy.contiguous()
        self.m2_qy = self.m2_qy.view(T*self.bs, 1)

        # Force the position to be within small range
        #self.m1_qx = torch.clamp(self.m1_qx,min=-1.5, max=1.5)
        #self.m1_qy = torch.clamp(self.m1_qy,min=-1.5, max=1.5)
        #self.m2_qx = torch.clamp(self.m2_qx,min=-1.5, max=1.5)
        #self.m2_qy = torch.clamp(self.m2_qy,min=-1.5, max=1.5)
        self.m1_qx_enc, self.m1_qy_enc, self.m1_px_enc, self.m1_py_enc, \
            self.m2_qx_enc, self.m2_qy_enc, self.m2_px_enc, self.m2_py_enc = x_enc_list.split([1, 1, 1, 1, 1, 1, 1, 1], dim=-1)
        
        self.m1_qx_enc = self.m1_qx_enc.contiguous()
        self.m1_qx_enc = self.m1_qx_enc.view((T-1)*self.bs, 1)
        self.m1_qy_enc = self.m1_qy_enc.contiguous()
        self.m1_qy_enc = self.m1_qy_enc.view((T-1)*self.bs, 1)

        self.m2_qx_enc = self.m2_qx_enc.contiguous()
        self.m2_qx_enc = self.m2_qx_enc.view((T-1)*self.bs, 1)
        self.m2_qy_enc = self.m2_qy_enc.contiguous()
        self.m2_qy_enc = self.m2_qy_enc.view((T-1)*self.bs, 1)

        # =======Decode=======
        # Here we want to get the content of the pole
        ones = torch.ones_like(self.m1_qx[:,0:1])
        self.content_m1 = self.obs_net_m1(ones)
        self.content_m2 = self.obs_net_m2(ones)
        #self.content_m1 = self.content_m1.flatten().unsqueeze(0)
        #self.content_m2 = self.content_m2.flatten().unsqueeze(0)
        #self.content_m1 = self.content_m1.repeat(T*self.bs, 1)
        #self.content_m2 = self.content_m2.repeat(T*self.bs, 1)

        # Get the theta information to place the pole
        translate_matrix_m1 = self.get_position(x=self.m1_qx[:,0],y=self.m1_qy[:,0], scale=1, offset=0,\
                                                    ratio=2.2, bs=T*self.bs) 
        grid_m1 = F.affine_grid(translate_matrix_m1, torch.Size((T*self.bs, 1, d, d)))
        translate_matrix_m2 = self.get_position(x=self.m2_qx[:,0],y=self.m2_qy[:,0], scale=1, offset=0,\
                                                    ratio=2.2, bs=T*self.bs) 
        grid_m2 = F.affine_grid(translate_matrix_m2, torch.Size((T*self.bs, 1, d, d)))
        # Place the content
        content_m1 = F.grid_sample(self.content_m1.view(T*self.bs, 1, d, d), grid_m1)
        content_m2 = F.grid_sample(self.content_m2.view(T*self.bs, 1, d, d), grid_m2)

        # Get the reconstruction images
        self.Xrec = torch.cat([torch.zeros_like(content_m1), content_m1, content_m2], dim=1)
        self.Xrec = self.Xrec.view([T, self.bs, 3, d, d])

        # Plot something to track the performance
        if self.count % 25 == 0 and Plot_enable == True:
            for tt in range(T):
                fig1 = plt.figure(constrained_layout=False, figsize=(10,4))
                gs = fig1.add_gridspec(1, 2, width_ratios=[1.0,1.0])
                ax = fig1.add_subplot(gs[0, 0])
                from torchvision import utils
                grid = utils.make_grid(X[tt, 0].view(-1, 3, 64, 64))
                X_ = np.array(grid.permute(1,2,0).detach().cpu().numpy())
                ax.imshow(X_)
                ax = fig1.add_subplot(gs[0, 1])

                grid = utils.make_grid(self.Xrec[tt, 0].view(-1, 3, 64, 64))
                X_ = np.array(grid.permute(1,2,0).detach().cpu().numpy())
                ax.imshow(X_)
                


       
        if TIME_INDEX == 0:
            self.count += 1

        return None

    def get_content(self, x, s, d):
        """Warp the first frame's m1 image channel to the first-step m1 position.

        Takes channel 1 of time step 0 / batch element 0 of *x* and resamples it
        with an affine grid built by ``self.get_position`` from ``s[0, 0, 1]``
        (x coordinate) and ``-s[0, 0, 2]`` (y coordinate, sign flipped).

        Args:
            x: image tensor indexed as ``x[t, b, channel, H, W]``; the author
                noted a shape like ``[101, 50, 3, 64, 64]`` -- TODO confirm.
            s: state tensor indexed as ``s[t, b, k]``; noted shape
                ``[101, 50, 11]`` -- TODO confirm.
            d: side length of the square image the content is reshaped to.

        Returns:
            Tensor of shape ``(1, 1, d, d)`` produced by ``F.grid_sample``.
        """
        # Single image channel of the first frame of the first sample.
        content = x[0, 0, 1, :, :]
        sx = s[0, 0, 1]
        sy = s[0, 0, 2]
        # Affine matrix from the (sx, -sy) position; presumably a translation
        # (see get_position) -- NOTE(review): confirm the y sign flip is intended.
        translate_matrix = self.get_position(x=sx, y=-sy, scale=1, offset=0,
                                             ratio=2.2, bs=1)
        grid = F.affine_grid(translate_matrix, torch.Size((1, 1, d, d)))
        # Place the content on that grid.
        content = F.grid_sample(content.view(1, 1, d, d), grid)
        return content

    def print_max(self,x):
        """Print the max ~ min range of each m1/m2 state component, then exit.

        Debug helper: reads indices 1-4 (m1 qx, qy, px, py) and 6-9
        (m2 qx, qy, px, py) along the last axis of *x*, prints one line per
        component and terminates the process with ``sys.exit('Done')``.
        """
        components = (
            ('m1 qx:', 1), ('m1 qy:', 2), ('m1 px:', 3), ('m1 py:', 4),
            ('m2 qx:', 6), ('m2 qy:', 7), ('m2 px:', 8), ('m2 py:', 9),
        )
        for label, idx in components:
            column = x[:, :, idx]
            print(label, torch.max(column), '~', torch.min(column))

        sys.exit('Done')

    def training_step(self, train_batch, batch_idx):
        """Compute the training loss for one batch of image sequences.

        The batch is processed once per window start ``TIME_INDEX`` in
        ``np.arange(0, T_pred - cutoff_freq_input, gap_interval)``.
        ``self.forward`` is called on the sequence from ``TIME_INDEX`` onward,
        and these terms are collected per window:

        * reconstruction log-likelihood: negative ``self.loss_fn`` between
          ``self.Xrec`` and the target frames (channels ``1:``, scaled by 1/255),
          minus ``enable_weight_sparse`` times an image-mass (sparsity) term;
        * state alignment between the (detached) ODE rollout states and the
          encoder states, split into position and velocity parts;
        * a Fourier-feature loss (only when ``MODEL_TYPE`` is 0 or 1);
        * a KL term against ``self.P_VM`` and a norm penalty on encoded states.

        The norm penalty is weighted once by ``lambda_`` (epoch/8000 with
        annealing, else 1/100).

        Args:
            train_batch: tuple ``(X, u, State)``. ``X`` is permuted with
                ``permute(0, 1, 4, 2, 3)``, so it presumably arrives as
                (time, batch, H, W, channel) -- TODO confirm against the dataset.
            batch_idx: index of the batch (unused).

        Returns:
            dict with the scalar ``'loss'`` and ``'log'`` / ``'progress_bar'``
            dictionaries of the individual loss terms.

        Raises:
            ValueError: if the window range is empty (``T_pred <= cutoff_freq_input``).
        """
        X, u, State = train_batch
        # Author's example shapes: X torch.Size([101, 2, 64, 64, 3]),
        # State torch.Size([101, 2, 11]).
        X = X.permute(0, 1, 4, 2, 3)

        use_freq_loss = MODEL_TYPE == 0 or MODEL_TYPE == 1
        # The norm-penalty weight depends only on the epoch, so compute it once.
        lambda_ = self.current_epoch/8000 if self.hparams.annealing else 1/100

        window_starts = np.arange(0, T_pred - cutoff_freq_input, gap_interval)  # Default: 20 is the gap interval
        if len(window_starts) == 0:
            raise ValueError('training_step: empty window range; '
                             'T_pred must be larger than cutoff_freq_input')

        lhood_list = []
        lhood_true_list = []
        kl_q_list = []
        penalty_list = []
        Time_loss_list = []
        Time_pos_loss_list = []
        Time_vel_loss_list = []
        l_sparse_list = []
        Freq_loss_list = []

        align = self._state_alignment_loss
        for iii, TIME_INDEX in enumerate(window_starts):
            X_ = X[TIME_INDEX:, :, :, :]
            State_ = State[TIME_INDEX:, :, :]
            if TIME_INDEX != 0:
                self.forward(X_, u, State_, TIME_INDEX, self.SRC_MAS_V[iii], self.mu_list)
            else:
                self.forward(X_, u, State_, TIME_INDEX, self.SRC_MAS_V[iii])

            # System-parameter (radius) error; only the full-length window
            # (TIME_INDEX == 0, always the first iteration) predicts it.
            if TIME_INDEX == 0:
                true_m1 = (State[0, :, 1] ** 2 + State[0, :, 2] ** 2) ** 0.5
                pred_m1 = torch.tensor(np.array(self.m1_list))
                mse_m1 = torch.mean(self.loss_fn(pred_m1.to(self.device), true_m1.to(self.device)))

                print('pred_r:', pred_m1)
                print('true_r:', true_m1)

            # The ODE rollout has one more step than the encoder sequence.
            T_pred_ = X_.shape[0] - 1
            print('ode state pos x:', self.m1_qx.view(T_pred_+1, self.bs, -1)[0:T_pred_][0:20, 0, :])
            print('ode state pos y:', self.m1_qy.view(T_pred_+1, self.bs, -1)[0:T_pred_][0:20, 0, :])
            print('enc state pos x:', self.m1_qx_enc.view(T_pred_, self.bs, -1)[0:20, 0, :])
            print('enc state pos y:', self.m1_qy_enc.view(T_pred_, self.bs, -1)[0:20, 0, :])
            print('ode state vel:', self.m1_px.view(T_pred_+1, self.bs, -1)[0:T_pred_][0:10, 0, :])
            print('enc state vel:', self.m1_px_enc.view(T_pred_, self.bs, -1)[0:10, 0, :])

            # State alignment between ODE solver and encoder.
            Time_pos_loss = (align(self.m1_qx, self.m1_qx_enc, T_pred_)
                             + align(self.m1_qy, self.m1_qy_enc, T_pred_)
                             + align(self.m2_qx, self.m2_qx_enc, T_pred_)
                             + align(self.m2_qy, self.m2_qy_enc, T_pred_))
            Time_vel_loss = (align(self.m1_px, self.m1_px_enc, T_pred_)
                             + align(self.m1_py, self.m1_py_enc, T_pred_)
                             + align(self.m2_px, self.m2_px_enc, T_pred_)
                             + align(self.m2_py, self.m2_py_enc, T_pred_))
            Time_loss = Time_pos_loss + Time_vel_loss

            # Reconstruction on the first Loss_first_index frames, channels 1:
            # (channel 0 of Xrec is an all-zero placeholder).
            X_rec = self.Xrec[0:Loss_first_index][:, :, 1:]
            X_target = X_[0:Loss_first_index][:, :, 1:] / 255.0

            lhood = - self.loss_fn(X_rec, X_target * weight_recons_255)
            lhood = lhood.sum([0, 2, 3, 4]).mean()  # (step, batch, channel, height, width)

            # Same likelihood without the weight_recons_255 target scaling.
            lhood_true = - self.loss_fn(X_rec, X_target)
            lhood_true = lhood_true.sum([0, 2, 3, 4]).mean()

            # Squared difference of the mean per-image mass.
            l_sparse = (X_rec.sum([3, 4]).mean() - X_target.sum([3, 4]).mean()) ** 2
            lhood = lhood - enable_weight_sparse * l_sparse

            if use_freq_loss:
                # Fourier features are (batch, state, freq_bin); the ODE side is
                # detached so only the encoder side receives gradients.
                Freq_loss = self.loss_fn(torch.exp(self.x_sim_Fourier_list.detach()),
                                         torch.exp(self.x_enc_Fourier_list))
                Freq_loss = Freq_loss[:, 0, :] + Freq_loss[:, 1, :] + Freq_loss[:, 2, :] + Freq_loss[:, 3, :]
                Freq_loss = Freq_loss.sum([1]).mean()

            if enable_attn == 1:
                kl_q = torch.distributions.kl.kl_divergence(self.Q_m1_pos, self.P_VM).mean() \
                     + torch.distributions.kl.kl_divergence(self.Q_m1_vel, self.P_VM).mean()
                norm_penalty = (self.q_m1_p.norm(dim=-1).mean() - 1) ** 2 \
                             + (self.q_m1_q.norm(dim=-1).mean() - 1) ** 2
            else:
                kl_q = torch.distributions.kl.kl_divergence(self.Q_m1_pos, self.P_VM).mean()
                norm_penalty = (self.q_m1_p.norm(dim=-1).mean() - 1) ** 2

            lhood_list.append(lhood)
            lhood_true_list.append(lhood_true)
            kl_q_list.append(kl_q)
            # Store the raw penalty; lambda_ is applied exactly once in the loss.
            penalty_list.append(norm_penalty)
            Time_loss_list.append(Time_loss)
            Time_pos_loss_list.append(Time_pos_loss)
            Time_vel_loss_list.append(Time_vel_loss)
            l_sparse_list.append(l_sparse)
            if use_freq_loss:
                Freq_loss_list.append(Freq_loss)

        lhood_list = torch.stack(lhood_list, dim=0)
        lhood_true_list = torch.stack(lhood_true_list, dim=0)
        kl_q_list = torch.stack(kl_q_list, dim=0)
        penalty_list = torch.stack(penalty_list, dim=0)
        Time_loss_list = torch.stack(Time_loss_list, dim=0)
        Time_pos_loss_list = torch.stack(Time_pos_loss_list, dim=0)
        Time_vel_loss_list = torch.stack(Time_vel_loss_list, dim=0)
        l_sparse_list = torch.stack(l_sparse_list, dim=0)

        # Final loss: reconstruction + weighted norm penalty + KL
        # (+ Fourier feature loss) + state alignment.
        if use_freq_loss:
            Freq_loss_list = torch.stack(Freq_loss_list, dim=0)
            loss = (- weight_recons * lhood_list.mean()
                    + lambda_ * penalty_list.mean()
                    + 1.0 * kl_q_list.mean()
                    + Freq_loss_weight * Freq_loss_list.mean()
                    + Time_loss_weight * Time_loss_list.mean())
        else:
            loss = (- weight_recons * lhood_list.mean()
                    + lambda_ * penalty_list.mean()
                    + 1.0 * kl_q_list.mean()
                    + Time_loss_weight * Time_loss_list.mean())

        logs = {'MSE_r': mse_m1,
                'Recons_Loss': -lhood_list.mean(),
                'Recons_True_Loss': -lhood_true_list.mean(),
                'State_Loss': Time_loss_list.mean()}
        if use_freq_loss:
            logs['Freq_Loss'] = Freq_loss_list.mean()
        logs.update({'State_Pos_Loss': Time_pos_loss_list.mean(),
                     'State_Vel_Loss': Time_vel_loss_list.mean(),
                     'Image_Sparse_Loss': l_sparse_list.mean(),
                     'KL_loss': kl_q_list.mean(),
                     'Regulization_loss': penalty_list.mean(),
                     'Regulization_loss_lambda': lambda_,
                     'loss': loss,
                     'monitor': loss})
        return {'loss': loss, 'log': logs, 'progress_bar': logs}

    def _state_alignment_loss(self, ode_state, enc_state, T_pred_):
        """Alignment loss between one ODE-rollout state and its encoder counterpart.

        ``ode_state`` holds ``T_pred_ + 1`` steps and ``enc_state`` holds
        ``T_pred_`` steps; the first ``T_pred_`` ODE steps are compared. The
        ODE side is detached, so gradients flow only through ``enc_state``.
        With ``mean_over_everything == 1`` ``self.loss_fn_mean`` is returned
        directly; otherwise the element-wise ``self.loss_fn`` is summed over
        time and state dims and averaged over the batch.
        """
        ode = ode_state.view(T_pred_ + 1, self.bs, -1)[0:T_pred_].detach()
        enc = enc_state.view(T_pred_, self.bs, -1)
        if mean_over_everything == 1:
            return self.loss_fn_mean(ode, enc)
        # (time, batch, state) -> per-sample sum over steps and states, batch mean.
        return self.loss_fn(ode, enc).sum([0, 2]).mean()

    def configure_optimizers(self):
        """Create an Adam optimizer over every parameter of this module.

        The step size is read from ``self.hparams.learning_rate``.
        """
        trainable = self.parameters()
        step_size = self.hparams.learning_rate
        return torch.optim.Adam(trainable, lr=step_size)

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a child parser that adds this model's hyperparameter flags.

        The returned parser inherits every option of *parent_parser* and adds
        ``--learning_rate`` (float, default ``lr``) and ``--batch_size``
        (int, default ``num_batch``), both defaults being module-level values.
        """
        model_parser = ArgumentParser(parents=[parent_parser], add_help=False)
        model_options = (
            ('--learning_rate', lr, float),
            ('--batch_size', num_batch, int),
        )
        for flag, default_value, flag_type in model_options:
            model_parser.add_argument(flag, default=default_value, type=flag_type)
        return model_parser

def main(args):
    """Build the model, checkpoint callback and trainer from *args*, then train.

    The dataset file is chosen by the module-level ``dataset_type`` and is
    looked up under ``PARENT_DIR/datasets/<dataset_folder><file>``.

    Args:
        args: parsed argument namespace. ``args.name`` names the checkpoint
            directory and the log sub-directory; the whole namespace is
            forwarded to ``Model`` and ``Trainer.from_argparse_args``.

    Exits via ``sys.exit('No match dataset!!')`` when ``dataset_type`` does
    not select one of the known datasets.
    """
    dataset_files = {
        0: 'twobody-gym-image-dataset-train_m11.0-1.0_m21.0-1.0_Hz6_sL101_nS200_DaTrue_R0.4-1.0.pkl',
        1: 'twobody-gym-image-dataset-train_m11.0-1.0_m21.0-1.0_Hz6_sL101_nS200_DaTrue_R0.4-1.0_smallBall.pkl',
        2: 'twobody-gym-image-dataset-train_m11.0-1.0_m21.0-1.0_Hz6_sL101_nS200_DaTrue_R0.4-1.0_smallBall_extreme.pkl',
        3: 'twobody-gym-image-dataset-train_m11.0-1.0_m21.0-1.0_Hz6_sL125_nS1000_DaFalse_R0.4-1.0_smallBall.pkl',
    }
    if dataset_type not in dataset_files:
        sys.exit('No match dataset!!')
    dataset_name = dataset_folder + dataset_files[dataset_type]

    model = Model(hparams=args, data_path=os.path.join(PARENT_DIR, 'datasets', dataset_name))

    # doc link for "ModelCheckpoint"
    # https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/callbacks/model_checkpoint.py
    checkpoint_callback = ModelCheckpoint(monitor='monitor',
                                          dirpath=args.name + '/',
                                          filename='Model-{epoch:05d}-{loss:.2f}',
                                          save_top_k=5,
                                          save_last=True)

    # doc link for "Trainer"
    # https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/trainer/trainer.py
    trainer = Trainer.from_argparse_args(args,
                                         limit_train_batches=1,
                                         max_epochs=10000,
                                         deterministic=True,
                                         terminate_on_nan=True,
                                         log_every_n_steps=1,
                                         default_root_dir=os.path.join(PARENT_DIR, 'logs', args.name),
                                         checkpoint_callback=checkpoint_callback,
                                         gradient_clip_val=gradient_clip,
                                         track_grad_norm=2)

    # Run training. NOTE(review): main() must be invoked by a script entry
    # point that builds `args`; none is visible in this part of the file.
    trainer.fit(model)


    # make gif
    #https://ezgif.com/maker/ezgif-6-69fc2d3f-gif