'''
Author: 
Email: 
Date: 2021-12-21 11:57:44
LastEditTime: 2022-12-28 13:31:19
Description: 
'''

from copy import deepcopy
import numpy as np
import os

import torch
import torch.nn as nn
import networkx as nx

from Because.mpc.mpc import MPC_Unlock, MPC_Unlock_Parallel, MPC_crash, MPC_crash_Parallel, MPC_Lift, MPC_Lift_Parallel
from Because.dynamics_modules import RGCN, MLP, GRU_SCM, TICSA_GRU, TICSA
from Because.Because_utils import CUDA, kaiming_init


class WorldModel(object):
    def __init__(self, args):
        """Build the dynamics model, its loss/optimiser and the env-specific graph helpers.

        Args:
            args: configuration dict. Keys read here: ``env_params`` (``state_dim``,
                ``action_dim``, ``goal_dim``, ``env_name`` plus env-specific sizes),
                ``Because_model``, ``use_discover``, ``use_gt``, ``num_envs``,
                ``n_epochs``, ``lr``, ``weight_decay``, ``batch_size``,
                ``validation_flag``, ``validation_freq``, ``validation_ratio`` and
                ``hidden_size`` (mopo/cdl) or ``hidden_dim`` (causal/full/gnn).

        Raises:
            AssertionError: if ``Because_model`` is not one of the recognised names.
            ValueError: if ``env_name`` is unknown, or if ``Because_model`` has no
                model implementation here (currently ``'offline'``).
        """
        self.state_dim = args['env_params']['state_dim']
        self.action_dim = args['env_params']['action_dim']
        self.goal_dim = args['env_params']['goal_dim']
        self.env_name = args['env_params']['env_name']
        self.Because_model = args['Because_model']
        self.use_discover = args['use_discover']
        self.use_gt = args['use_gt']
        self.num_envs = args['num_envs']

        assert self.Because_model in ['causal', 'full', 'mopo', 'offline', 'gnn', 'cdl']

        self.n_epochs = args['n_epochs']
        self.lr = args['lr']
        self.weight_decay = args['weight_decay']
        self.batch_size = args['batch_size']

        self.validation_flag = args['validation_flag']
        self.validate_freq = args['validation_freq']
        self.validation_ratio = args['validation_ratio']
        self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

        # process things that are different in environments
        if self.env_name == 'unlock':
            self.build_node_and_edge = self.build_node_and_edge_unlock
            self.organize_nodes = self.organize_nodes_unlock_crash
            self.room_size = args['env_params']['room_size']
            self.max_key_num = args['env_params']['max_key_num']
            self.move_dim = args['env_params']['move_dim']
            self.pick_key_dim = args['env_params']['pick_key_dim']
            self.open_door_dim = args['env_params']['open_door_dim']
            self.map_size = self.room_size * self.room_size
            self.state_dim_list = [self.map_size, self.map_size, self.map_size, self.max_key_num]  # [agent_pos, key_pos, door_pos, have_key]
            self.action_dim_list = [self.move_dim, self.pick_key_dim, self.open_door_dim]  # [move, pick_key, open_door]

        elif self.env_name == 'crash':
            self.build_node_and_edge = self.build_node_and_edge_crash_large
            self.organize_nodes = self.organize_nodes_unlock_crash
            self.agent_state_dim = args['env_params']['agent_state_dim']
            self.agent_action_dim = args['env_params']['agent_action_dim']
            self.n_agents = args['env_params']['n_agents']
            self.collision_dim = args['env_params']['collision_dim']
            self.state_dim_list = [self.agent_state_dim] * self.n_agents + [self.collision_dim]  # the last one is the collision node
            self.action_dim_list = [self.agent_action_dim] * (self.n_agents - 1)

        elif self.env_name == 'lift':
            self.build_node_and_edge = self.build_node_and_edge_lift
            self.organize_nodes = self.organize_nodes_unlock_crash
            self.state_dim_list = [1]*3 + [4, 6, 6] + [1]*3 + [4] + [1]*3 + [3, 1]  # the last one is the contact node
            self.action_dim_list = [1]*4

        else:
            raise ValueError('Unknown environment name')

        self.use_full = False
        self.use_mlp = False
        self.use_ticsa = False

        if self.Because_model == 'mopo':
            self.model_name = 'mopo'
            self.use_mlp = True
        elif self.Because_model == 'causal':
            self.model_name = 'gru'
        elif self.Because_model == 'full':
            self.model_name = 'gru'
            self.use_full = True
        elif self.Because_model == 'gnn':
            self.model_name = 'gnn'
            self.use_full = True
        elif self.Because_model == 'cdl':
            self.model_name = 'cdl'
            self.use_mlp = True
        else:
            # 'offline' passes the assert above but has no backbone below; fail with a
            # clear message instead of an AttributeError on self.model_name.
            raise ValueError(
                "Because_model '{}' has no model implementation".format(self.Because_model))

        scm_noise = False  # forwarded to GRU_SCM as its `random` argument
        if self.model_name == 'mopo':
            input_dim = self.state_dim - self.goal_dim + self.action_dim
            output_dim = self.state_dim - self.goal_dim
            self.model = CUDA(MLP(input_dim, output_dim, 1, args["hidden_size"], dropout_p=0.0))
            hidden_dim = args["hidden_size"]
        elif self.model_name == 'cdl':
            input_dim = self.state_dim - self.goal_dim + self.action_dim
            output_dim = self.state_dim - self.goal_dim
            self.model = CUDA(TICSA(input_dim, output_dim, 3, args["hidden_size"], dropout_p=0.1))
            hidden_dim = args["hidden_size"]
        else:  # 'gru' or 'gnn': graph models over one node per action/state component
            edge_dim = 1
            hidden_num = 1
            hidden_dim = args["hidden_dim"]
            self.node_num = len(self.action_dim_list) + len(self.state_dim_list)
            # every node is zero-padded to the widest component
            self.node_dim = int(np.max(self.state_dim_list + self.action_dim_list))
            if self.model_name == 'gnn':
                self.model = CUDA(RGCN(self.node_dim, self.node_num, 'mean', args["hidden_dim"], self.node_dim, edge_dim, hidden_num))
            else:
                self.model = CUDA(GRU_SCM(self.action_dim_list, self.state_dim_list, self.node_num, 'mean', args["hidden_dim"], edge_dim, hidden_num, dropout=0.0, random=scm_noise))
        print(self.model)
        print('----------------------------')
        print('Env:', self.env_name)
        print('Because model:', self.Because_model)
        print('Model_name:', self.model_name)
        print('Full:', self.use_full)
        print('SCM noise:', scm_noise)
        print('Hidden dim:', hidden_dim)
        print('----------------------------')

        self.model.apply(kaiming_init)
        self.mse_loss = nn.MSELoss()
        self.ce_loss = nn.CrossEntropyLoss()
        self.buffer_length = 0
        self.criterion = self.mse_loss

        if self.model_name == 'cdl':
            self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        else:
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)

        self.data = None
        self.label = None
        self.reward = None

        self.eps = 1e-30

        if self.Because_model == 'causal':
            # the initial graph is lower triangular (ones on and below the diagonal)
            self.causal_graph = np.tril(np.ones((self.node_num, self.node_num)))
        self.best_test_loss = np.inf
        # snapshot of the best validated weights; filled in by fit()
        self.best_model = None

    def fused_mse_loss_crash(self, predict, groundtruth):
        collision = groundtruth[:, -self.collision_dim:]
        collision_ = predict[:, -self.collision_dim:]
        collision_loss = self.mse_loss(collision_, collision)

        pose = groundtruth[:, :-self.collision_dim]
        pose_ = predict[:, :-self.collision_dim]
        pose_loss = self.mse_loss(pose, pose_)

        return collision_loss + pose_loss * 10

    def fused_loss_crash(self, predict, groundtruth):
        collision = groundtruth[:, -self.collision_dim:]
        assert collision.shape[1] == self.collision_dim
        collision = torch.argmax(collision, dim=1)
        collision_ = predict[:, -self.collision_dim:]
        collision_loss = self.ce_loss(collision_, collision)

        pose = groundtruth[:, :-self.collision_dim]
        assert pose.shape[1] == np.sum(self.state_dim_list) - self.collision_dim
        pose_ = predict[:, :-self.collision_dim]
        pose_loss = self.mse_loss(pose, pose_)

        #print(pose_loss, collision_loss)
        return collision_loss + pose_loss
        
    def onehot_loss_unlock(self, predict, groundtruth):
        agent_pos = groundtruth[:, 0:self.map_size]
        agent_pos = torch.argmax(agent_pos, dim=1)
        key_pos = groundtruth[:, self.map_size:self.map_size*2]
        key_pos = torch.argmax(key_pos, dim=1)
        door_pos = groundtruth[:, self.map_size*2:self.map_size*3]
        door_pos = torch.argmax(door_pos, dim=1)
        have_key = groundtruth[:, self.map_size*3:]
        have_key = torch.argmax(have_key, dim=1)

        agent_pos_ = predict[:, 0:self.map_size]
        key_pos_ = predict[:, self.map_size:self.map_size*2]
        door_pos_ = predict[:, self.map_size*2:self.map_size*3]
        have_key_ = predict[:, self.map_size*3:]

        loss_1 = self.ce_loss(agent_pos_, agent_pos)
        loss_2 = self.ce_loss(key_pos_, key_pos)
        loss_3 = self.ce_loss(door_pos_, door_pos)
        loss_4 = self.ce_loss(have_key_, have_key)

        return loss_1 + loss_2 + loss_3 + loss_4
    
    def build_node_and_edge_unlock(self, data):
        """
            Note that the order of GNN is [A, S], which is different from the order [S, A] in data.
        """
        # create the node matrix. the last node is the output node therefore should always be 0.
        batch_size = data.shape[0]
        x = torch.zeros((batch_size, self.node_num, self.node_dim), device=torch.device(self.device))

        # build the nodes of action
        action = data[:, sum(self.state_dim_list):]
        start_ = 0
        for a_i in range(len(self.action_dim_list)):
            end_ = self.action_dim_list[a_i] + start_
            x[:, a_i, 0:end_-start_] = action[:, start_:end_] # pad 0 for remaining places
            start_ = end_

        # build the nodes of state
        state = data[:, 0:sum(self.state_dim_list)]
        start_ = 0
        for s_i in range(len(self.state_dim_list)):
            end_ = self.state_dim_list[s_i] + start_
            x[:, s_i+len(self.action_dim_list), 0:end_-start_] = state[:, start_:end_] # pad 0 for remaining places
            start_ = end_

        # using GT causal graph
        # nodes - [move, pick_up, open_door, agent_pos, key_door_obs, have_key] - [M, P, O, A, K, D, H]
        adj = [
            [1, 0, 0, 0, 0, 0, 0], # M
            [0, 1, 0, 0, 0, 0, 0], # P
            [0, 0, 1, 0, 0, 0, 0], # O
            [1, 0, 0, 1, 0, 0, 0], # A
            [0, 0, 0, 0, 1, 0, 0], # K
            [0, 0, 1, 1, 0, 1, 1], # D
            [0, 1, 0, 1, 1, 0, 1]  # H
        ]

        adj_offline = [
            [1, 0, 0, 0, 0, 0, 0], # M
            [0, 1, 0, 0, 0, 0, 0], # P
            [0, 0, 1, 0, 0, 0, 0], # O
            [1, 0, 0, 1, 0, 1, 0], # A
            [0, 0, 0, 0, 1, 0, 0], # K
            [0, 0, 0, 1, 0, 1, 1], # D
            [0, 1, 0, 1, 0, 0, 1]  # H
        ]

        # for non-causal cases, the output delta_state should be 0
        adj_anti = [
            [1, 0, 0, 0, 0, 0, 0], # M
            [0, 1, 0, 0, 0, 0, 0], # P
            [0, 0, 1, 0, 0, 0, 0], # O
            [0, 0, 0, 1, 0, 0, 0], # A
            [0, 0, 0, 0, 1, 0, 0], # K
            [0, 0, 0, 0, 0, 1, 0], # D
            [0, 0, 0, 0, 0, 0, 1]  # H
        ]

        # full graph
        full = [
            [1, 0, 0, 0, 0, 0, 0], # M
            [0, 1, 0, 0, 0, 0, 0], # P
            [0, 0, 1, 0, 0, 0, 0], # O
            [1, 1, 1, 1, 1, 1, 1], # A
            [1, 1, 1, 1, 1, 1, 1], # K
            [1, 1, 1, 1, 1, 1, 1], # D
            [1, 1, 1, 1, 1, 1, 1]  # H
        ]

        if self.use_full:
            adj = full
        # if self.use_offline:
        #     adj = adj_offline

        adj = np.array(adj)[None, None, :, :]
        adj = CUDA(torch.Tensor(adj)).repeat(batch_size, 1, 1, 1)
        adj_anti = np.array(adj_anti)[None, None, :, :]
        adj_anti = CUDA(torch.Tensor(adj_anti)).repeat(batch_size, 1, 1, 1)
        label_anti = torch.zeros_like(state, device=self.device)

        return x, adj

    def build_node_and_edge_crash_large(self, data):
        """
            Note that the order of GNN is [A, S], which is different from the order [S, A] in data.
        """
        # create the node matrix.
        batch_size = data.shape[0]
        x = torch.zeros((batch_size, self.node_num, self.node_dim), device=torch.device(self.device))


        # build the nodes of action
        action = data[:, sum(self.state_dim_list):]
        # print('action: ', action.shape)
        
        start_ = 0
        for a_i in range(len(self.action_dim_list)):
            end_ = self.action_dim_list[a_i] + start_
            x[:, a_i, 0:end_-start_] = action[:, start_:end_] # pad 0 for remaining places
            start_ = end_

        # build the nodes of state
        state = data[:, 0:sum(self.state_dim_list)]
        start_ = 0
        for s_i in range(len(self.state_dim_list)):
            end_ = self.state_dim_list[s_i] + start_
            x[:, s_i+len(self.action_dim_list), 0:end_-start_] = state[:, start_:end_] # pad 0 for remaining places
            start_ = end_

        # using GT causal graph
        # nodes - [A_ped, A_o1, A_o2, A_o3, S_ped, S_o1, S_o2, S_o3, S_ego, C]
        adj = [
            [1, 0, 0, 0, 0, 0, 0, 0, 0, 0], # A_ped
            [0, 1, 0, 0, 0, 0, 0, 0, 0, 0], # A_o1
            [0, 0, 1, 0, 0, 0, 0, 0, 0, 0], # A_o2
            [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], # A_o3
            [1, 0, 0, 0, 1, 0, 0, 0, 0, 0], # S_ped
            [0, 1, 0, 0, 1, 1, 0, 0, 0, 0], # S_o1
            [0, 0, 1, 0, 0, 0, 1, 0, 0, 0], # S_o2
            [0, 0, 0, 1, 0, 0, 0, 1, 0, 0], # S_o3
            [0, 0, 0, 0, 1, 1, 0, 0, 1, 0], # S_ego
            [0, 0, 0, 0, 1, 1, 0, 0, 1, 1], # C
        ]

        # for non-causal cases, the output delta_state should be 0
        adj_offline = [
            [1, 0, 0, 0, 0, 0, 0, 0, 0, 0], # A_ped
            [0, 1, 0, 0, 0, 0, 0, 0, 0, 0], # A_o1
            [0, 0, 1, 0, 0, 0, 0, 0, 0, 0], # A_o2
            [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], # A_o3
            [1, 0, 0, 0, 1, 1, 0, 0, 1, 0], # S_ped
            [0, 1, 0, 0, 1, 1, 0, 1, 1, 0], # S_o1
            [0, 0, 1, 0, 0, 0, 1, 0, 0, 0], # S_o2
            [0, 0, 0, 1, 0, 0, 0, 1, 0, 0], # S_o3
            [0, 0, 0, 0, 1, 1, 0, 0, 1, 0], # S_ego
            [0, 0, 0, 0, 0, 1, 0, 0, 0, 1], # C
        ]

        # full graph
        full = [
            [1, 0, 0, 0, 0, 0, 0, 0, 0, 0], # A_ped
            [0, 1, 0, 0, 0, 0, 0, 0, 0, 0], # A_o1
            [0, 0, 1, 0, 0, 0, 0, 0, 0, 0], # A_o2
            [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], # A_o3
            [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], # S_ped
            [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], # S_o1
            [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], # S_o2
            [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], # S_o3
            [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], # S_ego
            [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], # C
        ]

        # anti-causal graph
        adj_anti = [
            [1, 0, 0, 0, 0, 0, 0, 0, 0, 0], # A_ped
            [0, 1, 0, 0, 0, 0, 0, 0, 0, 0], # A_o1
            [0, 0, 1, 0, 0, 0, 0, 0, 0, 0], # A_o2
            [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], # A_o3
            [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], # S_ped
            [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], # S_o1
            [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], # S_o2
            [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], # S_o3
            [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], # S_ego
            [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], # C
        ]

        if self.use_full:
            adj = full
        # if self.use_offline:
        #     adj = adj_offline

        adj = np.array(adj)[None, None, :, :]
        adj = CUDA(torch.Tensor(adj)).repeat(batch_size, 1, 1, 1)
        adj_anti = np.array(adj_anti)[None, None, :, :]
        adj_anti = CUDA(torch.Tensor(adj_anti)).repeat(batch_size, 1, 1, 1)
        label_anti = torch.zeros_like(state, device=self.device)

        return x, adj
    
    def build_node_and_edge_lift(self, data):
        """
            Note that the order of GNN is [A, S], which is different from the order [S, A] in data.
        """
        # create the node matrix.
        batch_size = data.shape[0]
        x = torch.zeros((batch_size, self.node_num, self.node_dim), device=torch.device(self.device))


        # build the nodes of action
        action = data[:, sum(self.state_dim_list):]
        # print('action: ', action.shape)
        
        start_ = 0
        for a_i in range(len(self.action_dim_list)):
            end_ = self.action_dim_list[a_i] + start_
            x[:, a_i, 0:end_-start_] = action[:, start_:end_] # pad 0 for remaining places
            start_ = end_

        # build the nodes of state
        state = data[:, 0:sum(self.state_dim_list)]
        start_ = 0
        for s_i in range(len(self.state_dim_list)):
            end_ = self.state_dim_list[s_i] + start_
            x[:, s_i+len(self.action_dim_list), 0:end_-start_] = state[:, start_:end_] # pad 0 for remaining places
            start_ = end_

        # using GT causal graph
        # nodes - [A_ped, A_o1, A_o2, A_o3, S_ped, S_o1, S_o2, S_o3, S_ego, C]
        # adj = [
        #     [1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0], # A_x
        #     [0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0], # A_y
        #     [0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0], # A_z
        #     [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], # A_gripper
        #     [0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0], # S_eef_pos_x
        #     [0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0], # S_eef_pos_y
        #     [0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0], # S_eef_pos_z
        #     [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # S_eef_quat (4)
        #     [0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0], # S_gripper_qpos (6)
        #     [0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0], # S_gripper_qvel (6)
        #     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0], # S_cube_pos_x
        #     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0], # S_cube_pos_y
        #     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0], # S_cube_pos_z
        #     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], # S_cube_quat (4)
        #     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1], # S_gripper_to_cube_pos_x
        #     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1], # S_gripper_to_cube_pos_y
        #     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1], # S_gripper_to_cube_pos_z
        #     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], # S_cube_color (3)
        #     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1], # S_contact
        # ]
        adj = [
            [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # A_x
            [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # A_y
            [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # A_z
            [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # A_gripper
            [1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], # S_eef_pos_x
            [0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], # S_eef_pos_y
            [0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], # S_eef_pos_z
            [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # s_eef_quat (4)
            [1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], # S_gripper_qpos (6)
            [1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0 ,0, 0, 0, 0, 0], # S_gripper_qvel (6)
            [1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1], # s_cube_pos_x
            [0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1], # s_cube_pos_y
            [0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1], # s_cube_pos_z
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], # s_cube_quat (4)
            [1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0], # S_gripper_to_cube_pos_x
            [0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0], # S_gripper_to_cube_pos_y
            [0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0], # S_gripper_to_cube_pos_z
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], # S_cube_color (3)
            [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1], # S_contact
        ]
        # for non-causal cases, the output delta_state should be 0
        # adj_offline = [
        #     [1, 0, 0, 0, 0, 0, 0, 0, 0, 0], # A_ped
        #     [0, 1, 0, 0, 0, 0, 0, 0, 0, 0], # A_o1
        #     [0, 0, 1, 0, 0, 0, 0, 0, 0, 0], # A_o2
        #     [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], # A_o3
        #     [1, 0, 0, 0, 1, 1, 0, 0, 1, 0], # S_ped
        #     [0, 1, 0, 0, 1, 1, 0, 1, 1, 0], # S_o1
        #     [0, 0, 1, 0, 0, 0, 1, 0, 0, 0], # S_o2
        #     [0, 0, 0, 1, 0, 0, 0, 1, 0, 0], # S_o3
        #     [0, 0, 0, 0, 1, 1, 0, 0, 1, 0], # S_ego
        #     [0, 0, 0, 0, 0, 1, 0, 0, 0, 1], # C
        # ]

        # # full graph
        # full = [
        #     [1, 0, 0, 0, 0, 0, 0, 0, 0, 0], # A_ped
        #     [0, 1, 0, 0, 0, 0, 0, 0, 0, 0], # A_o1
        #     [0, 0, 1, 0, 0, 0, 0, 0, 0, 0], # A_o2
        #     [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], # A_o3
        #     [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], # S_ped
        #     [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], # S_o1
        #     [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], # S_o2
        #     [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], # S_o3
        #     [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], # S_ego
        #     [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], # C
        # ]

        # # anti-causal graph
        # adj_anti = [
        #     [1, 0, 0, 0, 0, 0, 0, 0, 0, 0], # A_ped
        #     [0, 1, 0, 0, 0, 0, 0, 0, 0, 0], # A_o1
        #     [0, 0, 1, 0, 0, 0, 0, 0, 0, 0], # A_o2
        #     [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], # A_o3
        #     [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], # S_ped
        #     [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], # S_o1
        #     [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], # S_o2
        #     [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], # S_o3
        #     [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], # S_ego
        #     [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], # C
        # ]

        # if self.use_full:
        #     adj = full
        # if self.use_offline:
        #     adj = adj_offline

        adj = np.array(adj)[None, None, :, :]
        adj = CUDA(torch.Tensor(adj)).repeat(batch_size, 1, 1, 1)
        
        return x, adj
    
    def build_node_and_edge_crash_small(self, data):
        """
            Note that the order of GNN is [A, S], which is different from the order [S, A] in data.
        """
        # create the node matrix.
        batch_size = data.shape[0]
        x = torch.zeros((batch_size, self.node_num, self.node_dim), device=torch.device(self.device))

        # build the nodes of action
        action = data[:, sum(self.state_dim_list):]
        start_ = 0
        for a_i in range(len(self.action_dim_list)):
            end_ = self.action_dim_list[a_i] + start_
            x[:, a_i, 0:end_-start_] = action[:, start_:end_] # pad 0 for remaining places
            start_ = end_

        # build the nodes of state
        state = data[:, 0:sum(self.state_dim_list)]
        start_ = 0
        for s_i in range(len(self.state_dim_list)):
            end_ = self.state_dim_list[s_i] + start_
            x[:, s_i+len(self.action_dim_list), 0:end_-start_] = state[:, start_:end_] # pad 0 for remaining places
            start_ = end_

        # using GT causal graph
        # nodes - [A_ped, A_o1, S_ped, S_o1, S_ego, C]
        adj = [
            [1, 0, 0, 0, 0, 0], # A_ped
            [0, 1, 0, 0, 0, 0], # A_o1
            [1, 0, 1, 0, 0, 0], # S_ped
            [0, 1, 0, 1, 0, 0], # S_o1
            [0, 0, 1, 1, 1, 0], # S_ego
            [0, 0, 1, 1, 1, 1], # C
        ]

        # for non-causal cases, the output delta_state should be 0
        adj_anti = [
            [1, 0, 0, 0, 0, 0], # A_ped
            [0, 1, 0, 0, 0, 0], # A_o1
            [0, 0, 1, 0, 0, 0], # S_ped
            [0, 0, 0, 1, 0, 0], # S_o1
            [0, 0, 0, 0, 1, 0], # S_ego
            [0, 0, 0, 0, 0, 1], # C
        ]

        # full graph
        full = [
            [1, 0, 0, 0, 0, 0], # A_ped
            [0, 1, 0, 0, 0, 0], # A_o1
            [1, 1, 1, 1, 1, 1], # S_ped
            [1, 1, 1, 1, 1, 1], # S_o1
            [1, 1, 1, 1, 1, 1], # S_ego
            [1, 1, 1, 1, 1, 1], # C
        ]

        adj = adj

        adj = np.array(adj)[None, None, :, :]
        adj = CUDA(torch.Tensor(adj)).repeat(batch_size, 1, 1, 1)
        adj_anti = np.array(adj_anti)[None, None, :, :]
        adj_anti = CUDA(torch.Tensor(adj_anti)).repeat(batch_size, 1, 1, 1)
        label_anti = torch.zeros_like(state, device=self.device)

        return x, adj

    def organize_nodes_unlock_crash(self, x):
        # x - [B, node_num, node_dim], the nodes of next_state are in the end
        delta_state_node = x[:, -len(self.state_dim_list):, :]
        delta_state = []
        for s_i in range(len(self.state_dim_list)):
            state_i = delta_state_node[:, s_i, 0:self.state_dim_list[s_i]] 
            delta_state.append(state_i)

        delta_state = torch.cat(delta_state, dim=1)
        return delta_state

    def data_process(self, data, max_buffer_size):
        """Add one (input, label) sample to the in-memory training buffer.

        Args:
            data: sequence whose element 0 is the model input and element 1 the label,
                both numpy arrays (converted to float32 tensors via ``CUDA``).
            max_buffer_size: capacity. Once the buffer holds this many samples, a
                uniformly random slot is overwritten instead of appending.
        """
        sample = data[0][None]
        target = data[1][None]
        self.buffer_length += 1

        def to_tensor(arr):
            return CUDA(torch.from_numpy(arr.astype(np.float32)))

        if self.data is None:
            # first sample initialises the buffer
            self.data = to_tensor(sample)
            self.label = to_tensor(target)
            return

        if self.data.shape[0] >= max_buffer_size:
            # buffer full: replace a random slot
            slot = np.random.randint(0, max_buffer_size)
            self.data[slot] = to_tensor(sample)
            self.label[slot] = to_tensor(target)
        else:
            self.data = torch.cat((self.data, to_tensor(sample)), dim=0)
            self.label = torch.cat((self.label, to_tensor(target)), dim=0)

    def split_train_validation(self):
        num_data = len(self.data)

        # use validation
        if self.validation_flag:
            indices = list(range(num_data))
            split = int(np.floor(self.validation_ratio * num_data))
            np.random.shuffle(indices)
            train_idx, test_idx = indices[split:], indices[:split]

            train_set = [[self.data[idx], self.label[idx]] for idx in train_idx]
            test_set = [[self.data[idx], self.label[idx]] for idx in test_idx]

            train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, batch_size=self.batch_size)
            test_loader = torch.utils.data.DataLoader(test_set, shuffle=True, batch_size=self.batch_size)
        else:
            train_set = [[self.data[idx], self.label[idx]] for idx in range(num_data)]
            train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, batch_size=self.batch_size)
            test_loader = None
        return train_loader, test_loader

    def fit(self):
        self.model.train()
        train_loader, test_loader = self.split_train_validation()

        self.best_test_loss = np.inf
        for epoch in range(self.n_epochs):
            for datas, labels in train_loader:
                self.optimizer.zero_grad()
                if self.use_mlp or self.use_ticsa:
                    delta = self.model(datas)
                    loss = self.criterion(delta, labels)
                else:
                    x, adj = self.build_node_and_edge(datas)
                    x = self.model(x, adj)
                    delta = self.organize_nodes(x)
                    loss = self.criterion(delta, labels)
                loss.backward()
                self.optimizer.step()
                
            if self.validation_flag and (epoch+1) % self.validate_freq == 0:
                with torch.no_grad():
                    loss_test = self.validate_model(test_loader)
                if loss_test < self.best_test_loss:
                    self.best_test_loss = loss_test
                    self.best_model = deepcopy(self.model.state_dict())
        
        # load the best model if we use validation
        if self.validation_flag:
            self.model.load_state_dict(self.best_model)
        return self.best_test_loss


    def fit_offline(self, dataloader):
        self.model.train()
        # train_loader, test_loader = self.split_train_validation()

        self.best_test_loss = np.inf
        for epoch in range(self.n_epochs):
            loss_train = 0.
            for datas in dataloader: 
                actions = CUDA(datas['actions'])
                states = CUDA(datas['states'])
                rewards = CUDA(datas['rewards'])
                # print(actions.shape, states.shape)
                datas_1 = torch.cat([actions[:, -2], states[:, -2]], dim=1)
                datas_2 = torch.cat([actions[:, -3], states[:, -3]], dim=1)
                labels_1 = states[:, -1] - states[:, -2]
                labels_2 = states[:, -2] - states[:, -3]
                datas = torch.cat([datas_1, datas_2])
                labels = torch.cat([labels_1, labels_2])
                
                # print(datas.shape, labels.shape)
                self.optimizer.zero_grad()
                
                if self.use_mlp or self.use_ticsa:
                    delta = self.model(datas)
                    loss = self.criterion(delta, labels)
                else:
                    x, adj = self.build_node_and_edge(datas)
                    x = self.model(x, adj)
                    delta = self.organize_nodes(x)
                    loss = self.criterion(delta, labels)
                loss.backward()
                self.optimizer.step()
                loss_train += loss.item()
            loss_train /= len(dataloader)
            if self.validation_flag and (epoch+1) % self.validate_freq == 0:
                # with torch.no_grad():
                #     loss_test = self.validate_model(dataloader)
                if loss_train < self.best_test_loss:
                    self.best_test_loss = loss_train
                    # self.best_model = deepcopy(self.model.state_dict())
        
        # load the best model if we use validation
        # if self.validation_flag:
        #     self.model.load_state_dict(self.best_model)
        return self.best_test_loss
    
    def validate_model(self, testloader):
        self.model.eval()
        loss_list = []
        for datas, labels in testloader:
            if self.use_mlp:
                delta = self.model(datas)
                loss = self.criterion(delta, labels)
            else:
                x, adj = self.build_node_and_edge(datas)
                x = self.model(x, adj)
                delta = self.organize_nodes(x)
                loss = self.criterion(delta, labels)

            loss_list.append(loss.item())
        self.model.train()
        return np.mean(loss_list)

    def predict(self, s, a):
        self.model.eval()
        # convert to torch format
        if isinstance(s, np.ndarray):
            s = CUDA(torch.from_numpy(s.astype(np.float32)))
        if isinstance(a, np.ndarray):
            a = CUDA(torch.from_numpy(a.astype(np.float32)))

        inputs = torch.cat((s, a), axis=1)

        with torch.no_grad():
            if self.use_mlp:
                delta = self.model(inputs)
            else:
                x, adj = self.build_node_and_edge(inputs)
                x = self.model(x, adj)
                delta = self.organize_nodes(x)

            delta = delta.cpu().detach().numpy()
        return delta

    def save_model(self, model_path, model_id):
        states = {'model': self.model.state_dict()}
        filepath = os.path.join(model_path, 'grade.'+str(model_id)+'.torch')
        with open(filepath, 'wb') as f:
            torch.save(states, f)

    def load_model(self, model_path, model_id):
        filepath = os.path.join(model_path, 'grade.'+str(model_id)+'.torch')
        if os.path.isfile(filepath):
            with open(filepath, 'rb') as f:
                checkpoint = torch.load(f)
            self.model.load_state_dict(checkpoint['model'])
        else:
            raise Exception('No GRADE model found!')


class Planner(object):
    """Selects actions with MPC over a learned ``WorldModel`` and trains it.

    Holds the world model (``self.model``) plus two environment-specific MPC
    controllers: one used by ``select_action`` and a parallel one used by
    ``select_action_parallel``. Until the model's buffer holds at least
    ``pretrain_buffer_size`` transitions, actions are random and training is
    skipped.
    """

    def __init__(self, args):
        """Build the MPC controllers and the world model from ``args``.

        ``args['mpc']`` is updated in place with ``env_params`` and
        ``num_envs`` so the controllers see the same environment settings.

        Raises:
            Exception: if ``args['env_params']['env_name']`` is not
                ``'unlock'``, ``'crash'`` or ``'lift'``.
        """
        self.pretrain_buffer_size = args['pretrain_buffer_size']
        self.max_buffer_size = args['max_buffer_size']
        self.epsilon = args['epsilon']
        self.goal_dim = args['env_params']['goal_dim']
        self.num_envs = args['num_envs']

        args['mpc']['env_params'] = args['env_params']
        args['mpc']['num_envs'] = args['num_envs']

        env_name = args['env_params']['env_name']
        if env_name == 'unlock':
            self.mpc_controller = MPC_Unlock(args['mpc'])
            self.mpc_controller_parallel = MPC_Unlock_Parallel(args['mpc'])
        elif env_name == 'crash':
            self.mpc_controller = MPC_crash(args['mpc'])
            self.mpc_controller_parallel = MPC_crash_Parallel(args['mpc'])
        elif env_name == 'lift':
            self.mpc_controller = MPC_Lift(args['mpc'])
            self.mpc_controller_parallel = MPC_Lift_Parallel(args['mpc'])
        else:
            raise Exception('No MPC model found!')

        self.mpc_controller.reset()

        self.model = WorldModel(args)

    def _in_warmup(self):
        """True while the model buffer is empty or smaller than ``pretrain_buffer_size``."""
        data = self.model.data
        return data is None or data.shape[0] < self.pretrain_buffer_size

    def select_action(self, env, state, deterministic):
        """Choose an action for a single environment.

        During warm-up the action comes from ``env.random_action()``.
        Afterwards the MPC controller's action is used when ``deterministic``
        or when a uniform draw exceeds ``self.epsilon``; otherwise a random
        action is used.
        """
        if self._in_warmup():
            return env.random_action()
        # The uniform draw is evaluated before ``deterministic``, so it
        # consumes one NumPy RNG sample even in deterministic mode.
        if np.random.uniform(0, 1) > self.epsilon or deterministic:
            return self.mpc_controller.act(model=self.model, state=state)
        return env.random_action()

    def select_action_parallel(self, env, state, deterministic):
        """Choose actions for ``self.num_envs`` environments at once.

        During warm-up the actions come from ``env.random_action()``.
        Otherwise the parallel MPC controller plans one action per env. When
        not ``deterministic``, each env's action is independently replaced by
        the matching entry of ``env.random_action()`` if its uniform draw
        does not exceed ``self.epsilon``.
        """
        if self._in_warmup():
            return env.random_action()
        action = self.mpc_controller_parallel.act(model=self.model, state=state)
        if deterministic:
            return action
        action = np.asarray(action)
        action_random = np.asarray(env.random_action())
        keep_mpc = np.random.uniform(0, 1, size=(self.num_envs, )) > self.epsilon
        # Give the per-env mask trailing singleton axes so it selects whole
        # per-env actions. A flat (num_envs,) mask would broadcast against
        # the LAST axis of a (num_envs, action_dim) action instead.
        keep_mpc = keep_mpc.reshape(keep_mpc.shape + (1, ) * (action.ndim - 1))
        return np.where(keep_mpc, action, action_random)

    def store_transition(self, data):
        """Add one ``[state, action, next_state]`` transition to the model buffer.

        The last ``self.goal_dim`` entries (goal information) are stripped
        from both states. The buffer receives ``x = [state, action]`` and
        ``label = next_state - state``.
        """
        obs_len = len(data[0]) - self.goal_dim
        pure_state = data[0][:obs_len]
        action = data[1]
        pure_next_state = data[2][:obs_len]
        x = np.concatenate([pure_state, action])
        label = pure_next_state - pure_state
        self.model.data_process([x, label], self.max_buffer_size)

    def train(self):
        """Fit the model once enough data is collected; sets ``best_test_loss`` (0 during warm-up)."""
        if self._in_warmup():
            self.best_test_loss = 0
        else:
            self.best_test_loss = self.model.fit()

    def train_offline(self, dataloader):
        """Fit the model on ``dataloader`` via ``fit_offline``; sets ``best_test_loss``."""
        self.best_test_loss = self.model.fit_offline(dataloader)

    def set_causal_graph(self, causal_graph):
        """Replace the world model's causal graph."""
        self.model.causal_graph = causal_graph

    def save_model(self, model_path, model_id):
        """Delegate to ``WorldModel.save_model``."""
        self.model.save_model(model_path, model_id)

    def load_model(self, model_path, model_id):
        """Delegate to ``WorldModel.load_model``."""
        self.model.load_model(model_path, model_id)
