#######################################################################
# Copyright (C) 2017 Shangtong Zhang(zhangshangtong.cpp@gmail.com)    #
# Permission given to modify the code as long as you keep this        #
# declaration at the top                                              #
#######################################################################

from ..network import *
from ..component import *
from ..utils import *
import time
from .BaseAgent import *
from .DQN_agent import *


class CategoricalDQNActor(DQNActor):
    """Actor that scores each action by the mean of its predicted return distribution."""

    def __init__(self, config):
        super(CategoricalDQNActor, self).__init__(config)

    def _set_up(self):
        # Replace the support stored on the (shared) config with a tensor version.
        cfg = self.config
        cfg.atoms = tensor(cfg.atoms)

    def compute_q(self, prediction):
        """Return the expected value of every action's distribution as a numpy array.

        ``prediction['prob']`` is weighted by the support atoms and summed over
        the last (atom) axis.
        """
        weighted_support = prediction['prob'] * self.config.atoms
        expected_q = weighted_support.sum(-1)
        return to_np(expected_q)


class CategoricalDQNAgent(DQNAgent):
    """Categorical (C51-style) distributional DQN agent.

    The network is expected to return ``(prediction, feature)`` where
    ``prediction['prob']`` / ``prediction['log_prob']`` are indexed as
    ``[batch, action, atom]`` over ``config.categorical_n_atoms`` fixed,
    evenly spaced support atoms in ``[categorical_v_min, categorical_v_max]``.
    """

    def __init__(self, config):
        BaseAgent.__init__(self, config)
        self.config = config
        config.lock = mp.Lock()
        # Fixed, evenly spaced support of the return distribution.
        config.atoms = np.linspace(config.categorical_v_min, config.categorical_v_max, config.categorical_n_atoms)

        self.replay = config.replay_fn()
        self.actor = CategoricalDQNActor(config)

        self.network = config.network_fn()
        self.network.share_memory()
        self.target_network = config.network_fn()
        self.target_network.load_state_dict(self.network.state_dict())
        self.optimizer = config.optimizer_fn(self.network.parameters())

        self.actor.set_network(self.network)

        self.total_steps = 0
        self.batch_indices = range_tensor(config.batch_size)
        self.atoms = tensor(config.atoms)
        # Spacing between neighbouring atoms; used by the projection in compute_loss.
        self.delta_atom = (config.categorical_v_max - config.categorical_v_min) / float(config.categorical_n_atoms - 1)

    def eval_step(self, state):
        """Return ``(action, feature)`` for ``state`` without updating normalizer statistics.

        ``action`` is the numpy argmax over the last axis of the expected
        Q-values. ``feature`` is the second output of the network.
        """
        normalizer = self.config.state_normalizer
        normalizer.set_read_only()
        try:
            state = normalizer(state)
            prediction, feature = self.network(state)
            q = (prediction['prob'] * self.atoms).sum(-1)
            action = to_np(q.argmax(-1))
        finally:
            # Restore the normalizer even if the forward pass raises.
            normalizer.unset_read_only()
        return action, feature

    def TSNE(self, samples=200, n_classes=2):
        """Embed network features of saved states with t-SNE and save 2-D/3-D plots.

        For every label ``i`` in ``range(config.action_dim)`` this method loads
        ``figure/<config.game_file><i>.pt``, a stack of states indexed along
        dim 0. It runs each state through ``eval_step`` and collects the
        returned feature vectors. Up to ``samples`` rows are then drawn without
        replacement from each of the first ``n_classes`` labels. The rows are
        embedded with t-SNE and written to ``figure/C51-tSNE-2D.png`` and
        ``figure/C51-tSNE-3D.png``.

        Args:
            samples: maximum number of feature rows drawn per label.
            n_classes: number of labels (0 .. n_classes-1) included in the
                plots. At most 4, because of the fixed colour palette used
                when plotting.
        """
        os.makedirs('figure', exist_ok=True)

        feature_dict = self._collect_tsne_features()
        for i in range(self.config.action_dim):
            print(feature_dict[i].shape)

        labels = []
        chunks = []
        for i in range(n_classes):
            # Sample from the real number of rows instead of a hard-coded 512.
            n_rows = feature_dict[i].shape[0]
            n_take = min(samples, n_rows)
            index = np.random.choice(n_rows, n_take, replace=False)
            chunks.append(feature_dict[i][index, :])
            labels.extend(np.ones(n_take) * i)
        data = np.vstack(chunks)
        print('Final data shape: ', data.shape)

        # Inside this method the bare name ``TSNE`` resolves to the module-level
        # estimator (presumably sklearn.manifold.TSNE via the star imports), not
        # to this method.
        result2 = TSNE(n_components=2, init='pca', random_state=0).fit_transform(data)
        result3 = TSNE(n_components=3, init='pca', random_state=0).fit_transform(data)
        self._plot_tsne_embedding(result2, labels, 't-SNE on C51 Features', 'C51-tSNE-2D')
        self._plot_tsne_embedding(result3, labels, 't-SNE on C51 Features', 'C51-tSNE-3D')

    def _collect_tsne_features(self):
        """Return ``{label: ndarray}`` of stacked feature vectors for every action label."""
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        feature_dict = {}
        for i in range(self.config.action_dim):
            states = torch.load('figure/' + self.config.game_file + str(i) + '.pt', map_location='cpu')
            rows = []
            # Inference only: do not build autograd graphs for every state.
            with torch.no_grad():
                for j in range(states.shape[0]):
                    # Add a leading batch dimension to the j-th state.
                    state = states[j, :].unsqueeze(0).float().to(device)
                    _, feature_vectors = self.eval_step(state)
                    rows.append(feature_vectors.detach().cpu().numpy())
            # Stack once at the end instead of re-allocating with np.vstack per state.
            feature_dict[i] = np.vstack(rows) if rows else np.empty((0, 0))
        return feature_dict

    @staticmethod
    def _plot_tsne_embedding(result, label, title, filename):
        """Scatter a 2-D or 3-D embedding coloured by ``label`` and save it as ``figure/<filename>.png``."""
        palette = ('red', 'blue', 'lime', 'yellow')
        x_min, x_max = np.min(result, 0), np.max(result, 0)
        # Min-max scale each axis to [0, 1]; avoid dividing by zero on a flat axis.
        span = np.where(x_max > x_min, x_max - x_min, 1.0)
        data = (result - x_min) / span
        point_colors = [palette[int(lab)] for lab in label]

        fig = plt.figure()
        if data.shape[1] == 3:
            # Axes3D(fig) no longer attaches itself to the figure on recent Matplotlib.
            ax = fig.add_subplot(111, projection='3d')
            ax.scatter(data[:, 0], data[:, 1], data[:, 2], c=point_colors, alpha=0.5)
        else:
            ax = fig.add_subplot(111)
            ax.scatter(data[:, 0], data[:, 1], marker='o', c=point_colors, alpha=0.5)
        ax.set_title(title)
        fig.savefig('figure/' + filename + '.png', bbox_inches='tight')
        plt.close(fig)

    def compute_loss(self, transitions):
        """Per-sample KL divergence between the projected target and the online distribution.

        The target support is ``r + discount**n_step * mask * z``, clamped to
        ``[v_min, v_max]``. The next-state distribution of the selected action
        is projected back onto the fixed atoms with the triangular kernel
        ``clamp(1 - |Tz_j - z_i| / delta_z, 0, 1)``. The next action is chosen
        by the online network when ``config.double_q`` is set, otherwise by the
        target network. Returns a tensor with one loss value per transition.
        """
        config = self.config
        states = config.state_normalizer(transitions.state)
        next_states = config.state_normalizer(transitions.next_state)
        with torch.no_grad():
            prediction_next, _ = self.target_network(next_states)
            prob_next = prediction_next['prob']
            if config.double_q:
                # The network returns (prediction, feature): unpack before indexing.
                prediction_online, _ = self.network(next_states)
                a_next = torch.argmax((prediction_online['prob'] * self.atoms).sum(-1), dim=-1)
            else:
                q_next = (prob_next * self.atoms).sum(-1)
                a_next = torch.argmax(q_next, dim=-1)
            prob_next = prob_next[self.batch_indices, a_next, :]

        rewards = tensor(transitions.reward).unsqueeze(-1)
        masks = tensor(transitions.mask).unsqueeze(-1)
        atoms_target = rewards + config.discount ** config.n_step * masks * self.atoms.view(1, -1)
        atoms_target.clamp_(config.categorical_v_min, config.categorical_v_max)
        atoms_target = atoms_target.unsqueeze(1)
        # dim 1 indexes the fixed support atom i, dim 2 the shifted target atom j.
        target_prob = (1 - (atoms_target - self.atoms.view(1, -1, 1)).abs() / self.delta_atom).clamp(0, 1) * \
                      prob_next.unsqueeze(1)
        target_prob = target_prob.sum(-1)

        prediction, _ = self.network(states)
        log_prob = prediction['log_prob']
        actions = tensor(transitions.action).long()
        log_prob = log_prob[self.batch_indices, actions, :]
        # The 1e-5 offset keeps log() finite where the target mass is zero.
        KL = (target_prob * target_prob.add(1e-5).log() - target_prob * log_prob).sum(-1)
        return KL

    def reduce_loss(self, loss):
        """Average the per-sample losses into a scalar."""
        return loss.mean()
