import numpy as np
import copy

from Because.mpc.optimizers import RandomOptimizer, RandomOptimizer_Parallel

class MPC_Unlock(object):
    """Random-shooting MPC planner for the Unlock task.

    Candidate action sequences come from a ``RandomOptimizer`` and are scored
    by rolling them out through a learned dynamics model (see
    ``cost_function``).

    The flat state vector is sliced as
    ``[agent map | key map | door map | remainder]`` where every map holds
    ``room_size * room_size`` entries. The segment names follow the variable
    names used in this class; the remainder is treated as the "have key"
    indicator.
    """

    def __init__(self, mpc_args):
        """Read the planner configuration and build the optimizer.

        Args:
            mpc_args: dict with the keys ``type``, ``horizon``, ``gamma``,
                ``popsize``, ``num_envs``, ``oracle``, ``energy_coef`` and a
                nested ``env_params`` dict providing ``action_dim``,
                ``state_dim``, ``goal_dim``, ``room_size``, ``move_dim`` and
                ``pick_key_dim``.
        """
        self.type = mpc_args['type']
        self.horizon = mpc_args["horizon"]
        self.gamma = mpc_args["gamma"]
        self.popsize = mpc_args["popsize"]
        self.num_envs = mpc_args["num_envs"]
        print(self.num_envs)

        # parameters from the environment
        env_params = mpc_args['env_params']
        self.action_dim = env_params['action_dim']
        self.state_dim = env_params['state_dim']

        self.goal_dim = env_params['goal_dim']
        self.room_size = env_params['room_size']
        self.move_dim = env_params['move_dim']
        self.pick_key_dim = env_params['pick_key_dim']

        # callable applied to the predicted states; its output is scaled by
        # energy_coef and added to the task cost
        self.energy_oracle = mpc_args['oracle']
        self.energy_coef = mpc_args['energy_coef']

        # number of cells in one flattened room map
        self.map_size = self.room_size * self.room_size

        self.optimizer = RandomOptimizer(action_dim=self.action_dim, horizon=self.horizon, popsize=self.popsize)
        self.optimizer.setup(self.cost_function)
        self.reset()

    def reset(self):
        """Reset the underlying optimizer."""
        self.optimizer.reset()

    def act(self, model, state):
        """Plan with ``model`` from ``state`` and return the chosen action.

        The model and state are stored on the instance because the optimizer
        calls ``cost_function`` with the candidate actions only.

        Returns:
            ``best_solution[:, 0]`` of the optimizer's solution (presumably
            the first planned step for every environment -- TODO confirm
            against the optimizer).
        """
        self.model = model
        self.state = state

        best_solution = self.optimizer.obtain_solution_unlock()
        # take the first step as our action
        return best_solution[:, 0]

    def preprocess(self, state):
        """Repeat ``state`` ``popsize`` times along a new axis 1.

        Args:
            state: array whose first axis is kept; ``(a, ...)`` becomes
                ``(a, popsize, ...)``.

        Returns:
            The tiled copy of ``state``.
        """
        # Use the argument rather than self.state so the method honours
        # whatever the caller passes in.
        return np.repeat(state[:, None], self.popsize, axis=1)

    @staticmethod
    def _one_hot_argmax(block):
        """Return a one-hot array marking the per-row argmax of ``block``."""
        one_hot = np.zeros_like(block)
        rows = np.arange(block.shape[0])
        one_hot[rows, np.argmax(block, axis=1)] = 1
        return one_hot

    def rescale_state(self, state):
        """Snap every state segment to a one-hot vector.

        Rescaling avoids accumulating prediction error over the rollout.

        Args:
            state: 2-D array of flat states, one row per sample.

        Returns:
            A new array of the same shape whose four segments (agent, key,
            door, have-key) are each one-hot at their row-wise argmax.
        """
        m = self.map_size
        segments = (
            state[:, 0:m],          # agent position map
            state[:, m:m * 2],      # key position map
            state[:, m * 2:m * 3],  # door position map
            state[:, m * 3:],       # have-key indicator
        )
        return np.concatenate([self._one_hot_argmax(seg) for seg in segments], axis=1)

    def cost_function(self, actions):
        """Score candidate action sequences by rolling out the model.

        Args:
            actions: array of shape
                ``(num_envs, popsize, horizon, action_dim)``.

        Returns:
            Array of shape ``(num_envs, popsize)`` holding the summed
            per-step cost (door count plus the weighted energy-oracle term)
            of every candidate.
        """
        # the observation needs to be tiled since we use a common model
        state = self.preprocess(self.state)

        assert actions.shape == (self.num_envs, self.popsize, self.horizon, self.action_dim)
        assert state.shape == (self.num_envs, self.popsize, self.state_dim)

        costs = np.zeros(self.num_envs * self.popsize)

        actions = actions.reshape(-1, self.horizon, self.action_dim)    # (n_envs*popsize, horizon, action_dim)
        state = state.reshape(-1, self.state_dim)                       # (n_envs*popsize, state_dim)
        assert state.shape[0] == actions.shape[0] == costs.shape[0]

        for t_i in range(self.horizon):
            action = actions[:, t_i, :]  # (n_envs*popsize, action_dim)
            # the output of the prediction model is [state_next - state]
            state_next = self.model.predict(state, action) + state

            cost = (self.unlock_objective_state_multidoor(state_next)
                    + self.energy_coef * self.energy_oracle(state_next))
            costs += cost

            # state_next is a fresh array and is not used again this step,
            # so it can become the next state without copying.
            state = state_next

        return costs.reshape((self.num_envs, self.popsize))

    def unlock_objective_state(self, state_next):
        """Return -1 for rows with no door entry above 0.5, else 0.

        Args:
            state_next: 2-D array of flat states.

        Returns:
            Integer array with one cost per row.
        """
        door_pos = state_next[:, self.map_size * 2:self.map_size * 3]

        have_door = np.sum(door_pos > 0.5, axis=1)
        final_cost = np.zeros_like(have_door)

        # there is no door in the scene
        final_cost[have_door == 0] = -1
        return final_cost

    def unlock_objective_state_multidoor(self, state_next):
        """Return the number of door-map entries above 0.5 in every row.

        More doors mean more cost, since there may be more than one door.

        Args:
            state_next: 2-D array of flat states.

        Returns:
            Integer array with one door count per row.
        """
        door_pos = state_next[:, self.map_size * 2:self.map_size * 3]
        return np.sum(door_pos > 0.5, axis=1)

    def unlock_objective_logic(self, state, action):
        """Rule-based cost: -1 when the open-door conditions are all met.

        A row reaches the goal when the have-key argmax is non-zero, the
        open-door argmax is non-zero and the agent cell is at squared
        distance exactly 1 from the door cell; the three factors are
        multiplied together.

        Args:
            state: 2-D array of flat states.
            action: 2-D array of flat actions; the columns after
                ``move_dim + pick_key_dim`` are read as the open-door part.

        Returns:
            Float array equal to ``-1.0 * reach_goal`` per row.
        """
        m = self.map_size
        agent_pos = state[:, 0:m]
        door_pos = state[:, m * 2:m * 3]
        have_key = state[:, m * 3:]
        open_door = action[:, self.move_dim + self.pick_key_dim:]

        # finish conditions
        # 1. have key
        have_key = np.argmax(have_key, axis=1)

        # 2. the action is open_door
        open_door = np.argmax(open_door, axis=1)

        # 3. agent is near the door (flat index -> (col, row) on the grid)
        agent_idx = np.argmax(agent_pos, axis=1)
        door_idx = np.argmax(door_pos, axis=1)
        dx = agent_idx % self.room_size - door_idx % self.room_size
        dy = agent_idx // self.room_size - door_idx // self.room_size
        adjacent = (dx ** 2 + dy ** 2) == 1

        reach_goal = have_key * open_door * adjacent
        return -1.0 * reach_goal


class MPC_Unlock_Parallel(object):
    """Random-shooting MPC planner for the Unlock task over several envs.

    Candidate action sequences come from a ``RandomOptimizer_Parallel`` and
    are scored by rolling them out through a learned dynamics model (see
    ``cost_function``).

    The flat state vector is sliced as
    ``[agent map | key map | door map | remainder]`` where every map holds
    ``room_size * room_size`` entries. The segment names follow the variable
    names used in this class; the remainder is treated as the "have key"
    indicator.
    """

    def __init__(self, mpc_args):
        """Read the planner configuration and build the optimizer.

        Args:
            mpc_args: dict with the keys ``type``, ``horizon``, ``gamma``,
                ``popsize``, ``num_envs``, ``oracle``, ``energy_coef`` and a
                nested ``env_params`` dict providing ``action_dim``,
                ``state_dim``, ``goal_dim``, ``room_size``, ``move_dim`` and
                ``pick_key_dim``.
        """
        self.type = mpc_args['type']
        self.horizon = mpc_args["horizon"]
        self.gamma = mpc_args["gamma"]
        self.popsize = mpc_args["popsize"]
        self.num_envs = mpc_args["num_envs"]
        print(self.num_envs)

        # parameters from the environment
        env_params = mpc_args['env_params']
        self.action_dim = env_params['action_dim']
        self.state_dim = env_params['state_dim']

        self.goal_dim = env_params['goal_dim']
        self.room_size = env_params['room_size']
        self.move_dim = env_params['move_dim']
        self.pick_key_dim = env_params['pick_key_dim']

        # NOTE(review): stored but not used by cost_function in this class.
        self.energy_oracle = mpc_args['oracle']
        self.energy_coef = mpc_args['energy_coef']

        # number of cells in one flattened room map
        self.map_size = self.room_size * self.room_size

        self.optimizer = RandomOptimizer_Parallel(num_envs=self.num_envs, action_dim=self.action_dim, horizon=self.horizon, popsize=self.popsize)
        self.optimizer.setup(self.cost_function)
        self.reset()

    def reset(self):
        """Reset the underlying optimizer."""
        self.optimizer.reset()

    def act(self, model, state):
        """Plan with ``model`` from ``state`` and return the chosen action.

        The model and state are stored on the instance because the optimizer
        calls ``cost_function`` with the candidate actions only.

        Returns:
            ``best_solution[:, 0]`` of the optimizer's solution (presumably
            the first planned step for every environment -- TODO confirm
            against the optimizer).
        """
        self.model = model
        self.state = state

        best_solution = self.optimizer.obtain_solution_unlock()
        # take the first step as our action
        return best_solution[:, 0]

    def preprocess(self, state):
        """Repeat ``state`` ``popsize`` times along a new axis 1.

        Args:
            state: array whose first axis is kept; ``(a, ...)`` becomes
                ``(a, popsize, ...)``.

        Returns:
            The tiled copy of ``state``.
        """
        # Use the argument rather than self.state so the method honours
        # whatever the caller passes in.
        return np.repeat(state[:, None], self.popsize, axis=1)

    @staticmethod
    def _one_hot_argmax(block):
        """Return a one-hot array marking the per-row argmax of ``block``."""
        one_hot = np.zeros_like(block)
        rows = np.arange(block.shape[0])
        one_hot[rows, np.argmax(block, axis=1)] = 1
        return one_hot

    def rescale_state(self, state):
        """Snap every state segment to a one-hot vector.

        Rescaling avoids accumulating prediction error over the rollout.

        Args:
            state: 2-D array of flat states, one row per sample.

        Returns:
            A new array of the same shape whose four segments (agent, key,
            door, have-key) are each one-hot at their row-wise argmax.
        """
        m = self.map_size
        segments = (
            state[:, 0:m],          # agent position map
            state[:, m:m * 2],      # key position map
            state[:, m * 2:m * 3],  # door position map
            state[:, m * 3:],       # have-key indicator
        )
        return np.concatenate([self._one_hot_argmax(seg) for seg in segments], axis=1)

    def cost_function(self, actions):
        """Score candidate action sequences by rolling out the model.

        Args:
            actions: array of shape
                ``(num_envs, popsize, horizon, action_dim)``.

        Returns:
            Array of shape ``(num_envs, popsize)`` holding the summed
            per-step door count of every candidate.
        """
        # the observation needs to be tiled since we use a common model
        state = self.preprocess(self.state)

        assert actions.shape == (self.num_envs, self.popsize, self.horizon, self.action_dim)
        assert state.shape == (self.num_envs, self.popsize, self.state_dim)

        costs = np.zeros(self.num_envs * self.popsize)

        actions = actions.reshape(-1, self.horizon, self.action_dim)    # (n_envs*popsize, horizon, action_dim)
        state = state.reshape(-1, self.state_dim)                       # (n_envs*popsize, state_dim)
        assert state.shape[0] == actions.shape[0] == costs.shape[0]

        for t_i in range(self.horizon):
            action = actions[:, t_i, :]  # (n_envs*popsize, action_dim)
            # the output of the prediction model is [state_next - state]
            state_next = self.model.predict(state, action) + state

            costs += self.unlock_objective_state_multidoor(state_next)

            # state_next is a fresh array and is not used again this step,
            # so it can become the next state without copying.
            state = state_next

        return costs.reshape((self.num_envs, self.popsize))

    def unlock_objective_state(self, state_next):
        """Return -1 for rows with no door entry above 0.5, else 0.

        Args:
            state_next: 2-D array of flat states.

        Returns:
            Integer array with one cost per row.
        """
        door_pos = state_next[:, self.map_size * 2:self.map_size * 3]

        have_door = np.sum(door_pos > 0.5, axis=1)
        final_cost = np.zeros_like(have_door)

        # there is no door in the scene
        final_cost[have_door == 0] = -1
        return final_cost

    def unlock_objective_state_multidoor(self, state_next):
        """Return the number of door-map entries above 0.5 in every row.

        More doors mean more cost, since there may be more than one door.

        Args:
            state_next: 2-D array of flat states.

        Returns:
            Integer array with one door count per row.
        """
        door_pos = state_next[:, self.map_size * 2:self.map_size * 3]
        return np.sum(door_pos > 0.5, axis=1)

    def unlock_objective_logic(self, state, action):
        """Rule-based cost: -1 when the open-door conditions are all met.

        A row reaches the goal when the have-key argmax is non-zero, the
        open-door argmax is non-zero and the agent cell is at squared
        distance exactly 1 from the door cell; the three factors are
        multiplied together.

        Args:
            state: 2-D array of flat states.
            action: 2-D array of flat actions; the columns after
                ``move_dim + pick_key_dim`` are read as the open-door part.

        Returns:
            Float array equal to ``-1.0 * reach_goal`` per row.
        """
        m = self.map_size
        agent_pos = state[:, 0:m]
        door_pos = state[:, m * 2:m * 3]
        have_key = state[:, m * 3:]
        open_door = action[:, self.move_dim + self.pick_key_dim:]

        # finish conditions
        # 1. have key
        have_key = np.argmax(have_key, axis=1)

        # 2. the action is open_door
        open_door = np.argmax(open_door, axis=1)

        # 3. agent is near the door (flat index -> (col, row) on the grid)
        agent_idx = np.argmax(agent_pos, axis=1)
        door_idx = np.argmax(door_pos, axis=1)
        dx = agent_idx % self.room_size - door_idx % self.room_size
        dy = agent_idx // self.room_size - door_idx // self.room_size
        adjacent = (dx ** 2 + dy ** 2) == 1

        reach_goal = have_key * open_door * adjacent
        return -1.0 * reach_goal


class MPC_crash(object):
    """Random-shooting MPC planner for the crash task (single environment).

    Candidate action sequences come from a ``RandomOptimizer`` and are scored
    by rolling them out through a learned dynamics model (see
    ``cost_function``). Several alternative objective functions are provided;
    ``cost_function`` uses ``crash_objective_state_2``.
    """

    def __init__(self, mpc_args):
        """Read the planner configuration and build the optimizer.

        Args:
            mpc_args: dict with the keys ``type``, ``horizon``, ``popsize``,
                ``oracle``, ``energy_coef`` and a nested ``env_params`` dict
                providing ``action_dim``, ``goal_dim``, ``agent_state_dim``,
                ``agent_action_dim``, ``map_scale`` and
                ``collision_threshold``.
        """
        self.type = mpc_args['type']
        self.horizon = mpc_args["horizon"]
        self.popsize = mpc_args["popsize"]

        # parameters from the environment
        env_params = mpc_args['env_params']
        self.action_dim = env_params['action_dim']
        self.goal_dim = env_params['goal_dim']
        self.agent_state_dim = env_params['agent_state_dim']
        self.agent_action_dim = env_params['agent_action_dim']
        self.map_scale = env_params['map_scale']
        self.collision_threshold = env_params['collision_threshold']

        # callable applied to the predicted states; its output is scaled by
        # energy_coef and added to the task cost
        self.energy_oracle = mpc_args['oracle']
        self.energy_coef = mpc_args['energy_coef']

        self.optimizer = RandomOptimizer(action_dim=self.action_dim, horizon=self.horizon, popsize=self.popsize)
        self.optimizer.setup(self.cost_function)
        self.reset()

    def reset(self):
        """Reset the underlying optimizer."""
        self.optimizer.reset()

    def act(self, model, state):
        """Plan with ``model`` from ``state`` and return the chosen action.

        The model and state are stored on the instance because the optimizer
        calls ``cost_function`` with the candidate actions only.

        Returns:
            ``best_solution[0]`` of the optimizer's solution (presumably the
            first planned step -- TODO confirm against the optimizer).
        """
        self.model = model
        self.state = state

        best_solution = self.optimizer.obtain_solution_crash(self.action_dim)
        # take the first step as our action
        return best_solution[0]

    def preprocess(self, state):
        """Stack ``popsize`` copies of ``state`` along a new leading axis.

        Args:
            state: array of any shape ``(...)``.

        Returns:
            Array of shape ``(popsize, ...)``.
        """
        # Use the argument rather than self.state so the method honours
        # whatever the caller passes in.
        return np.repeat(state[None], self.popsize, axis=0)

    def rescale_state(self, state):
        """Binarise the last two (collision) columns at a 0.5 threshold.

        The collision entries are indicator-like, so prediction error should
        not accumulate in them.

        Args:
            state: 2-D array of flat states. It is not modified.

        Returns:
            A new array equal to ``state`` except that the last two columns
            are set to 1 where ``>= 0.5`` and to 0 where ``< 0.5``.
        """
        # Work on a copy: slicing yields views, so thresholding in place
        # would silently overwrite the caller's array.
        rescaled = np.array(state, copy=True)
        collision = rescaled[:, -2:]
        is_high = collision >= 0.5
        is_low = collision < 0.5
        collision[is_high] = 1
        collision[is_low] = 0
        return rescaled

    def cost_function(self, actions):
        """Score candidate action sequences by rolling out the model.

        The task cost of a candidate is discounted by ``0.99 ** t`` and is
        masked out once the candidate has been flagged as finished; the
        weighted energy-oracle term is added at every step without discount
        or mask.

        Args:
            actions: array of shape ``(popsize, horizon, action_dim)``.

        Returns:
            Tuple ``(costs, costs_list, finish_flag_list)``: the summed cost
            per candidate, the per-step costs, and the per-step cumulative
            finish flags.
        """
        # the observation needs to be tiled since we use a common model
        state = self.preprocess(self.state)

        assert actions.shape == (self.popsize, self.horizon, self.action_dim)
        costs = np.zeros(self.popsize)
        costs_list = []
        finish_flag_list = []

        gamma = 0.99
        finish_flag = np.zeros((self.popsize,))
        for t_i in range(self.horizon):
            action = actions[:, t_i, :]  # (popsize, action_dim)

            # the output of the prediction model is [state_next - state]
            state_next = self.model.predict(state, action) + state
            # defensive copy: keeps the rollout state independent of
            # whatever the objective / oracle do with state_next
            state = copy.deepcopy(state_next)

            cost, one_flag = self.crash_objective_state_2(state_next)
            cost = cost * (1 - finish_flag) * gamma ** t_i + self.energy_coef * self.energy_oracle(state_next)
            costs += cost

            finish_flag = np.logical_or(finish_flag, one_flag)
            costs_list.append(cost)
            finish_flag_list.append(finish_flag)

        return costs, costs_list, finish_flag_list

    def crash_objective_rule(self, state):
        """Distance-based cost computed from the agents' positions.

        The state is read as three consecutive ``agent_state_dim`` slices
        (named ped, o1 and ego here) whose first two entries are scaled by
        ``map_scale`` into x/y coordinates.

        Args:
            state: 2-D array of flat states.

        Returns:
            Tuple ``(final_cost, None)``. ``final_cost`` is a float array
            holding 10 where o1 is within range of ped, otherwise -10 where
            ego is within range of ped, otherwise 0. No finish flag is
            produced by this objective.
        """
        d = self.agent_state_dim
        ped = state[:, 0:d]
        o1 = state[:, d:d * 2]
        ego = state[:, d * 2:d * 3]

        ped_x = ped[:, 0] * self.map_scale[0]
        ped_y = ped[:, 1] * self.map_scale[1]
        o1_x = o1[:, 0] * self.map_scale[0]
        o1_y = o1[:, 1] * self.map_scale[1]
        ego_x = ego[:, 0] * self.map_scale[0]
        ego_y = ego[:, 1] * self.map_scale[1]

        dist_ego_ped = ((ped_x - ego_x) ** 2 + (ped_y - ego_y) ** 2) ** 0.5
        dist_o1_ped = ((ped_x - o1_x) ** 2 + (ped_y - o1_y) ** 2) ** 0.5

        # extra margin on top of the configured threshold
        # (NOTE(review): meaning/units of the 8 are not visible here)
        reach = self.collision_threshold + 8
        collision_o1_ped = dist_o1_ped < reach
        collision_ego_ped = dist_ego_ped < reach

        # The cost must be numeric: np.zeros_like on the boolean mask would
        # yield a bool array in which both -10 and 10 collapse to True.
        final_cost = np.zeros(collision_o1_ped.shape, dtype=np.float64)
        final_cost[collision_ego_ped] = -10
        final_cost[collision_o1_ped] = 10

        final_flag = None
        return final_cost, final_flag

    def crash_objective_state(self, state):
        """Cost from the last two state columns, penalising other collisions.

        Args:
            state: 2-D array of flat states.

        Returns:
            Tuple ``(final_cost, final_flag)``: cost is 10 where the last
            column exceeds 0.5, otherwise -10 where the second-to-last
            column exceeds 0.5, otherwise 0; the flag is True where either
            column exceeds 0.5.
        """
        ego_collide_pedestrain = state[:, -2]
        collide_others = state[:, -1]

        final_cost = np.zeros_like(ego_collide_pedestrain)
        flag_1 = ego_collide_pedestrain > 0.5
        flag_2 = collide_others > 0.5
        final_cost[flag_1] = -10
        final_cost[flag_2] = 10

        final_flag = np.logical_or(flag_1, flag_2)
        return final_cost, final_flag

    def crash_objective_state_2(self, state):
        """Cost rewarding an ego/pedestrian collision with no other collision.

        Args:
            state: 2-D array of flat states.

        Returns:
            Tuple ``(final_cost, final_ped_ego)``: the flag is True where the
            second-to-last column exceeds 0.5 and the last column is below
            0.5; the cost is -10 on flagged rows and 0 elsewhere.
        """
        ego_collide_pedestrain = state[:, -2]
        collide_others = state[:, -1]

        final_ped_ego = np.logical_and(ego_collide_pedestrain > 0.5, collide_others < 0.5)

        final_cost = np.zeros_like(ego_collide_pedestrain)
        final_cost[final_ped_ego] = -10
        return final_cost, final_ped_ego

    def crash_objective_softmax(self, state):
        """Cost from the argmax over the last three state columns.

        Args:
            state: 2-D array of flat states.

        Returns:
            Tuple ``(final_cost, final_flag)``: the integer cost is -10 where
            the argmax is 1, 10 where it is 2 and 0 where it is 0; the flag
            is True where the argmax is 1 or 2.
        """
        collision_idx = np.argmax(state[:, -3:], axis=1)

        ego_collide_pedestrain = collision_idx == 1
        collide_others = collision_idx == 2

        final_cost = np.zeros_like(collision_idx)
        final_cost[ego_collide_pedestrain] = -10
        final_cost[collide_others] = 10

        final_flag = np.logical_or(ego_collide_pedestrain, collide_others)
        return final_cost, final_flag


class MPC_crash_Parallel(object):
    """Random-shooting MPC planner for the crash task over several envs.

    Candidate action sequences come from a ``RandomOptimizer_Parallel`` and
    are scored by rolling them out through a learned dynamics model (see
    ``cost_function``). Several alternative objective functions are provided;
    ``cost_function`` uses ``crash_objective_state_2``.
    """

    def __init__(self, mpc_args):
        """Read the planner configuration and build the optimizer.

        Args:
            mpc_args: dict with the keys ``type``, ``horizon``, ``popsize``,
                ``num_envs``, ``oracle``, ``energy_coef`` and a nested
                ``env_params`` dict providing ``action_dim``, ``state_dim``,
                ``goal_dim``, ``agent_state_dim``, ``agent_action_dim``,
                ``map_scale`` and ``collision_threshold``.
        """
        self.type = mpc_args['type']
        self.horizon = mpc_args["horizon"]
        self.popsize = mpc_args["popsize"]
        self.num_envs = mpc_args["num_envs"]

        # parameters from the environment
        env_params = mpc_args['env_params']
        self.action_dim = env_params['action_dim']
        self.state_dim = env_params['state_dim']

        self.goal_dim = env_params['goal_dim']
        self.agent_state_dim = env_params['agent_state_dim']
        self.agent_action_dim = env_params['agent_action_dim']
        self.map_scale = env_params['map_scale']
        self.collision_threshold = env_params['collision_threshold']

        # NOTE(review): stored but not used by cost_function in this class.
        self.energy_oracle = mpc_args['oracle']
        self.energy_coef = mpc_args['energy_coef']

        self.optimizer = RandomOptimizer_Parallel(num_envs=self.num_envs, action_dim=self.action_dim, horizon=self.horizon, popsize=self.popsize)
        self.optimizer.setup(self.cost_function)
        self.reset()

    def reset(self):
        """Reset the underlying optimizer."""
        self.optimizer.reset()

    def act(self, model, state):
        """Plan with ``model`` from ``state`` and return the chosen action.

        The model and state are stored on the instance because the optimizer
        calls ``cost_function`` with the candidate actions only.

        Returns:
            ``best_solution[:, 0]`` of the optimizer's solution (presumably
            the first planned step for every environment -- TODO confirm
            against the optimizer).
        """
        self.model = model
        self.state = state

        best_solution = self.optimizer.obtain_solution_crash(self.action_dim)
        # take the first step as our action
        return best_solution[:, 0]

    def preprocess(self, state):
        """Repeat ``state`` ``popsize`` times along a new axis 1.

        Args:
            state: array whose first axis is kept; ``(a, ...)`` becomes
                ``(a, popsize, ...)``.

        Returns:
            The tiled copy of ``state``.
        """
        # Use the argument rather than self.state so the method honours
        # whatever the caller passes in.
        return np.repeat(state[:, None], self.popsize, axis=1)

    def rescale_state(self, state):
        """Binarise the last two (collision) columns at a 0.5 threshold.

        The collision entries are indicator-like, so prediction error should
        not accumulate in them.

        Args:
            state: 2-D array of flat states. It is not modified.

        Returns:
            A new array equal to ``state`` except that the last two columns
            are set to 1 where ``>= 0.5`` and to 0 where ``< 0.5``.
        """
        # Work on a copy: slicing yields views, so thresholding in place
        # would silently overwrite the caller's array.
        rescaled = np.array(state, copy=True)
        collision = rescaled[:, -2:]
        is_high = collision >= 0.5
        is_low = collision < 0.5
        collision[is_high] = 1
        collision[is_low] = 0
        return rescaled

    def cost_function(self, actions):
        """Score candidate action sequences by rolling out the model.

        The task cost of a candidate is discounted by ``0.99 ** t`` and is
        masked out once the candidate has been flagged as finished.

        Args:
            actions: array of shape
                ``(num_envs, popsize, horizon, action_dim)``.

        Returns:
            Tuple ``(costs, costs_list, finish_flag_list)``: the summed cost
            reshaped to ``(num_envs, popsize)``, the per-step flat costs,
            and the per-step cumulative flat finish flags.
        """
        # the observation needs to be tiled since we use a common model
        state = self.preprocess(self.state)

        assert actions.shape == (self.num_envs, self.popsize, self.horizon, self.action_dim)
        assert state.shape == (self.num_envs, self.popsize, self.state_dim)
        costs = np.zeros(self.num_envs * self.popsize)
        costs_list = []
        finish_flag_list = []

        gamma = 0.99
        finish_flag = np.zeros((self.num_envs * self.popsize,))
        actions = actions.reshape(-1, self.horizon, self.action_dim)    # (n_envs*popsize, horizon, action_dim)
        state = state.reshape(-1, self.state_dim)                       # (n_envs*popsize, state_dim)

        for t_i in range(self.horizon):
            action = actions[:, t_i, :]  # (n_envs*popsize, action_dim)

            # the output of the prediction model is [state_next - state]
            state_next = self.model.predict(state, action) + state
            # defensive copy: keeps the rollout state independent of
            # whatever the objective does with state_next
            state = copy.deepcopy(state_next)

            cost, one_flag = self.crash_objective_state_2(state_next)
            cost = cost * (1 - finish_flag) * gamma ** t_i
            costs += cost

            finish_flag = np.logical_or(finish_flag, one_flag)
            costs_list.append(cost)
            finish_flag_list.append(finish_flag)

        costs = costs.reshape(self.num_envs, self.popsize)

        return costs, costs_list, finish_flag_list

    def crash_objective_rule(self, state):
        """Distance-based cost computed from the agents' positions.

        The state is read as three consecutive ``agent_state_dim`` slices
        (named ped, o1 and ego here) whose first two entries are scaled by
        ``map_scale`` into x/y coordinates.

        Args:
            state: 2-D array of flat states.

        Returns:
            Tuple ``(final_cost, None)``. ``final_cost`` is a float array
            holding 10 where o1 is within range of ped, otherwise -10 where
            ego is within range of ped, otherwise 0. No finish flag is
            produced by this objective.
        """
        d = self.agent_state_dim
        ped = state[:, 0:d]
        o1 = state[:, d:d * 2]
        ego = state[:, d * 2:d * 3]

        ped_x = ped[:, 0] * self.map_scale[0]
        ped_y = ped[:, 1] * self.map_scale[1]
        o1_x = o1[:, 0] * self.map_scale[0]
        o1_y = o1[:, 1] * self.map_scale[1]
        ego_x = ego[:, 0] * self.map_scale[0]
        ego_y = ego[:, 1] * self.map_scale[1]

        dist_ego_ped = ((ped_x - ego_x) ** 2 + (ped_y - ego_y) ** 2) ** 0.5
        dist_o1_ped = ((ped_x - o1_x) ** 2 + (ped_y - o1_y) ** 2) ** 0.5

        # extra margin on top of the configured threshold
        # (NOTE(review): meaning/units of the 8 are not visible here)
        reach = self.collision_threshold + 8
        collision_o1_ped = dist_o1_ped < reach
        collision_ego_ped = dist_ego_ped < reach

        # The cost must be numeric: np.zeros_like on the boolean mask would
        # yield a bool array in which both -10 and 10 collapse to True.
        final_cost = np.zeros(collision_o1_ped.shape, dtype=np.float64)
        final_cost[collision_ego_ped] = -10
        final_cost[collision_o1_ped] = 10

        final_flag = None
        return final_cost, final_flag

    def crash_objective_state(self, state):
        """Cost from the last two state columns, penalising other collisions.

        Args:
            state: 2-D array of flat states.

        Returns:
            Tuple ``(final_cost, final_flag)``: cost is 10 where the last
            column exceeds 0.5, otherwise -10 where the second-to-last
            column exceeds 0.5, otherwise 0; the flag is True where either
            column exceeds 0.5.
        """
        ego_collide_pedestrain = state[:, -2]
        collide_others = state[:, -1]

        final_cost = np.zeros_like(ego_collide_pedestrain)
        flag_1 = ego_collide_pedestrain > 0.5
        flag_2 = collide_others > 0.5
        final_cost[flag_1] = -10
        final_cost[flag_2] = 10

        final_flag = np.logical_or(flag_1, flag_2)
        return final_cost, final_flag

    def crash_objective_state_2(self, state):
        """Cost rewarding an ego/pedestrian collision with no other collision.

        Args:
            state: 2-D array of flat states.

        Returns:
            Tuple ``(final_cost, final_ped_ego)``: the flag is True where the
            second-to-last column exceeds 0.5 and the last column is below
            0.5; the cost is -10 on flagged rows and 0 elsewhere.
        """
        ego_collide_pedestrain = state[:, -2]
        collide_others = state[:, -1]

        final_ped_ego = np.logical_and(ego_collide_pedestrain > 0.5, collide_others < 0.5)

        final_cost = np.zeros_like(ego_collide_pedestrain)
        final_cost[final_ped_ego] = -10
        return final_cost, final_ped_ego

    def crash_objective_softmax(self, state):
        """Cost from the argmax over the last three state columns.

        Args:
            state: 2-D array of flat states.

        Returns:
            Tuple ``(final_cost, final_flag)``: the integer cost is -10 where
            the argmax is 1, 10 where it is 2 and 0 where it is 0; the flag
            is True where the argmax is 1 or 2.
        """
        collision_idx = np.argmax(state[:, -3:], axis=1)

        ego_collide_pedestrain = collision_idx == 1
        collide_others = collision_idx == 2

        final_cost = np.zeros_like(collision_idx)
        final_cost[ego_collide_pedestrain] = -10
        final_cost[collide_others] = 10

        final_flag = np.logical_or(ego_collide_pedestrain, collide_others)
        return final_cost, final_flag


class MPC_Lift(object):
    """Random-shooting MPC planner for the Lift task (single environment).

    Candidate action sequences come from a ``RandomOptimizer`` and are scored
    by rolling them out through a learned dynamics model (see
    ``cost_function``).
    """

    # Column of the state read by lift_objective and the offset subtracted
    # from it (presumably the object's height and its resting height --
    # TODO confirm against the environment).
    _HEIGHT_INDEX = 21
    _HEIGHT_OFFSET = 0.82

    def __init__(self, mpc_args):
        """Read the planner configuration and build the optimizer.

        Args:
            mpc_args: dict with the keys ``type``, ``horizon``, ``gamma``,
                ``popsize``, ``oracle``, ``energy_coef`` and a nested
                ``env_params`` dict providing ``action_dim`` and
                ``state_dim``.
        """
        self.type = mpc_args['type']
        self.horizon = mpc_args['horizon']
        self.gamma = mpc_args['gamma']
        self.popsize = mpc_args['popsize']

        # parameters from the environment
        self.action_dim = mpc_args['env_params']['action_dim']
        self.state_dim = mpc_args['env_params']['state_dim']
        # callable applied to the predicted states; its output is scaled by
        # energy_coef and added to the task cost
        self.energy_oracle = mpc_args['oracle']
        self.energy_coef = mpc_args['energy_coef']

        self.optimizer = RandomOptimizer(action_dim=self.action_dim, horizon=self.horizon, popsize=self.popsize)
        self.optimizer.setup(self.cost_function)
        self.reset()

    def reset(self):
        """Reset the underlying optimizer."""
        self.optimizer.reset()

    def act(self, model, state):
        """Plan with ``model`` from ``state`` and return the chosen action.

        The model and state are stored on the instance because the optimizer
        calls ``cost_function`` with the candidate actions only.

        Returns:
            ``best_solution[0]`` of the optimizer's solution (presumably the
            first planned step -- TODO confirm against the optimizer).
        """
        self.model = model
        self.state = state

        best_solution = self.optimizer.obtain_solution_lift(low=-1., high=1.)  # Assuming the action range is [-1, 1]
        # take the first step as our action
        return best_solution[0]

    def preprocess(self, state):
        """Stack ``popsize`` copies of ``state`` along a new leading axis.

        Args:
            state: array of any shape ``(...)``.

        Returns:
            Array of shape ``(popsize, ...)``.
        """
        # Use the argument rather than self.state so the method honours
        # whatever the caller passes in.
        return np.repeat(state[None], self.popsize, axis=0)

    def cost_function(self, actions):
        """Score candidate action sequences by rolling out the model.

        Args:
            actions: array of shape ``(popsize, horizon, action_dim)``.

        Returns:
            Array of shape ``(popsize,)`` with the summed per-step cost
            (lift objective plus the weighted energy-oracle term).
        """
        # the observation needs to be tiled since we use a common model
        state = self.preprocess(self.state)
        assert actions.shape == (self.popsize, self.horizon, self.action_dim)
        costs = np.zeros(self.popsize)

        for t_i in range(self.horizon):
            action = actions[:, t_i, :]  # (popsize, action_dim)
            # the output of the prediction model is [state_next - state]
            state_next = self.model.predict(state, action) + state
            costs += self.lift_objective(state_next) + self.energy_coef * self.energy_oracle(state_next)
            # state_next is a fresh array and is not used again this step,
            # so it can become the next state without copying.
            state = state_next

        return costs

    def lift_objective(self, state):
        """Per-sample lift cost: negative clamped excess of one state column.

        Args:
            state: 2-D array of flat states with at least 22 columns.

        Returns:
            float32 array with one entry per row, equal to
            ``-max(state[:, 21] - 0.82, 0)``.
        """
        excess = state[:, self._HEIGHT_INDEX] - self._HEIGHT_OFFSET
        # np.maximum clamps element-wise. np.max(excess, 0) would instead
        # treat 0 as the axis and collapse every candidate to one shared
        # scalar, making the objective useless for ranking candidates.
        return -np.array(np.maximum(excess, 0), dtype=np.float32)


class MPC_Lift_Parallel(object):
    """Random-shooting MPC planner for the Lift task over several envs.

    Candidate action sequences come from a ``RandomOptimizer_Parallel`` and
    are scored by rolling them out through a learned dynamics model (see
    ``cost_function``).
    """

    # Column of the state read by lift_objective and the offset subtracted
    # from it (presumably the object's height and its resting height --
    # TODO confirm against the environment).
    _HEIGHT_INDEX = 21
    _HEIGHT_OFFSET = 0.82

    def __init__(self, mpc_args):
        """Read the planner configuration and build the optimizer.

        Args:
            mpc_args: dict with the keys ``type``, ``horizon``, ``gamma``,
                ``popsize``, ``num_envs``, ``oracle``, ``energy_coef`` and a
                nested ``env_params`` dict providing ``action_dim`` and
                ``state_dim``.
        """
        self.type = mpc_args['type']
        self.horizon = mpc_args['horizon']
        self.gamma = mpc_args['gamma']
        self.popsize = mpc_args['popsize']
        self.num_envs = mpc_args["num_envs"]
        # callable applied to the predicted states; its output is scaled by
        # energy_coef and added to the task cost
        self.energy_oracle = mpc_args['oracle']
        self.energy_coef = mpc_args['energy_coef']
        # parameters from the environment
        self.action_dim = mpc_args['env_params']['action_dim']
        self.state_dim = mpc_args['env_params']['state_dim']
        self.optimizer = RandomOptimizer_Parallel(num_envs=self.num_envs, action_dim=self.action_dim, horizon=self.horizon, popsize=self.popsize)
        self.optimizer.setup(self.cost_function)
        self.reset()

    def reset(self):
        """Reset the underlying optimizer."""
        self.optimizer.reset()

    def act(self, model, state):
        """Plan with ``model`` from ``state`` and return the chosen action.

        The model and state are stored on the instance because the optimizer
        calls ``cost_function`` with the candidate actions only.

        Returns:
            ``best_solution[:, 0]`` of the optimizer's solution (presumably
            the first planned step for every environment -- TODO confirm
            against the optimizer).
        """
        self.model = model
        self.state = state
        best_solution = self.optimizer.obtain_solution_lift(low=-1., high=1.)  # Assuming the action range is [-1, 1]
        # take the first step as our action
        return best_solution[:, 0]

    def preprocess(self, state):
        """Repeat ``state`` ``popsize`` times along a new axis 1.

        Args:
            state: array whose first axis is kept; ``(a, ...)`` becomes
                ``(a, popsize, ...)``.

        Returns:
            The tiled copy of ``state``.
        """
        # Use the argument rather than self.state so the method honours
        # whatever the caller passes in.
        return np.repeat(state[:, None], self.popsize, axis=1)

    def cost_function(self, actions):
        """Score candidate action sequences by rolling out the model.

        Each step's cost (lift objective plus the weighted energy-oracle
        term) is discounted by ``0.99 ** t``.

        NOTE(review): the discount is hard-coded to 0.99; the configured
        ``self.gamma`` is not used here.

        Args:
            actions: array of shape
                ``(num_envs, popsize, horizon, action_dim)``.

        Returns:
            Array of shape ``(num_envs, popsize)`` with the discounted
            summed cost of every candidate.
        """
        # the observation needs to be tiled since we use a common model
        state = self.preprocess(self.state)

        assert actions.shape == (self.num_envs, self.popsize, self.horizon, self.action_dim)
        assert state.shape == (self.num_envs, self.popsize, self.state_dim)
        costs = np.zeros(self.num_envs * self.popsize)
        gamma = 0.99
        actions = actions.reshape(-1, self.horizon, self.action_dim)    # (n_envs*popsize, horizon, action_dim)
        state = state.reshape(-1, self.state_dim)                       # (n_envs*popsize, state_dim)

        for t_i in range(self.horizon):
            action = actions[:, t_i, :]  # (n_envs*popsize, action_dim)

            # the output of the prediction model is [state_next - state]
            state_next = self.model.predict(state, action) + state
            # defensive copy: keeps the rollout state independent of
            # whatever the oracle does with state_next
            state = copy.deepcopy(state_next)

            cost = self.lift_objective(state_next) + self.energy_coef * self.energy_oracle(state_next)
            costs += cost * gamma ** t_i

        return costs.reshape(self.num_envs, self.popsize)

    def lift_objective(self, state):
        """Per-sample lift cost: negative clamped excess of one state column.

        Args:
            state: 2-D array of flat states with at least 22 columns.

        Returns:
            float32 array with one entry per row, equal to
            ``-max(state[:, 21] - 0.82, 0)``.
        """
        excess = state[:, self._HEIGHT_INDEX] - self._HEIGHT_OFFSET
        # np.maximum clamps element-wise. np.max(excess, 0) would instead
        # treat 0 as the axis and collapse every candidate to one shared
        # scalar, making the objective useless for ranking candidates.
        return -np.array(np.maximum(excess, 0), dtype=np.float32)
