import random
import numpy as np


class RecurrentMemory(object):
    """Episodic replay buffer that samples fixed-length traces for recurrent learners.

    Transitions are accumulated into ``current_episode`` by :meth:`add`. When a
    transition flagged as terminal arrives, the whole episode is committed to a
    fixed-capacity ring buffer ``D`` (the oldest episode is overwritten first).
    :meth:`get_minibatch` then draws ``trace_length``-step windows from randomly
    chosen stored episodes, zero-padding steps that fall outside the episode.
    """

    def __init__(self, agent_num, state_dim, action_dim, episode_capacity, trace_length, trace_batch_size):
        """Create an empty memory.

        Args:
            agent_num: number of agents; each transition carries one action
                index per agent.
            state_dim: length of the zero vector used to pad states.
            action_dim: number of discrete actions (length of each one-hot list).
            episode_capacity: maximum number of complete episodes kept.
            trace_length: number of consecutive steps in each sampled trace.
            trace_batch_size: maximum number of traces per minibatch.
        """
        self.agent_num = agent_num
        self.state_dim = state_dim
        self.action_dim = action_dim
        # how many complete episodes the memory pool can hold
        self.D_episode_capacity = episode_capacity
        # write cursor: slot in D that receives the next finished episode
        self.D_epi_pos = 0
        # number of slots in D currently holding an episode
        self.D_epi_size = 0
        # trace length: how many consecutive steps are taken from one episode
        self.tau = trace_length
        # maximum number of traces fetched per minibatch
        self.trace_batch_size = trace_batch_size
        # ring buffer of stored episodes
        self.D = [None] * episode_capacity
        # transitions of the episode still in progress
        self.current_episode = []

    def mem_size(self):
        """Return the number of complete episodes currently stored."""
        return self.D_epi_size

    def add(self, transition):
        """Append one transition to the episode in progress.

        ``transition`` is indexed as ``(state, actions, reward, next_state,
        terminal)``, where ``actions`` holds one action index per agent. When
        ``terminal`` is truthy, the buffered episode is committed to the ring
        buffer and a new episode is started. Once the buffer is full, the
        oldest episode is overwritten. Fields given as Python lists are stored
        as NumPy arrays; all other fields are stored as-is.
        """
        self.current_episode.append(transition)
        if not transition[4]:
            return

        # Freeze the finished episode; list-valued fields become arrays.
        self.D[self.D_epi_pos] = [
            [np.array(field) if isinstance(field, list) else field for field in sample]
            for sample in self.current_episode
        ]
        # FIFO ring buffer: advance the write cursor, wrapping at capacity.
        self.D_epi_pos = (self.D_epi_pos + 1) % len(self.D)
        self.D_epi_size = min(self.D_epi_size + 1, len(self.D))
        self.current_episode = []

    def get_minibatch(self):
        """Sample a batch of fixed-length traces.

        ``min(mem_size(), trace_batch_size)`` episodes are drawn uniformly at
        random, with replacement, from the stored ones. A window of ``tau``
        consecutive steps is then taken from each.

        Returns:
            A tuple ``(state_batch, action_batch, reward_batch,
            next_state_batch, done_batch)``. Each element is a list of ``tau``
            entries indexed by time step. Entry ``t`` is a list with one item
            per sampled trace ``i``:

            * ``state_batch[t][i]`` / ``next_state_batch[t][i]``: the state;
            * ``action_batch[t][i]``: one one-hot list per agent;
            * ``reward_batch[t][i]``: ``[reward]``;
            * ``done_batch[t][i]``: ``[terminal]``.

            Steps that fall outside the sampled episode are zero-padded: zero
            states, all-zero action lists, reward 0 and terminal 0. With an
            empty memory, every per-step entry is an empty list.
        """
        tau = self.tau
        n_traces = min(self.D_epi_size, self.trace_batch_size)

        state_batch = [[None] * n_traces for _ in range(tau)]
        action_batch = [[None] * n_traces for _ in range(tau)]
        reward_batch = [[None] * n_traces for _ in range(tau)]
        next_state_batch = [[None] * n_traces for _ in range(tau)]
        done_batch = [[None] * n_traces for _ in range(tau)]

        for i in range(n_traces):
            # Only slots [0, D_epi_size) have been written, so this index is
            # always valid.
            episode = self.D[random.randint(0, self.D_epi_size - 1)]
            epi_len = len(episode)

            # The window start is drawn from [1 - tau, epi_len - 1]; randint is
            # inclusive at both ends. Every step of the episode, including the
            # final one, is therefore covered by exactly tau of the possible
            # windows, so all steps are sampled with equal probability.
            # Windows that overhang either end of the episode are zero-padded.
            start_pos = random.randint(1 - tau, epi_len - 1)

            for t in range(tau):
                pos = start_pos + t
                if 0 <= pos < epi_len:
                    state, actions, reward, next_state, terminal = episode[pos][:5]
                    action_onehot_n = [self.compute_one_hot(actions[ag]) for ag in range(self.agent_num)]
                else:
                    # Padding step; fresh lists so callers never share storage.
                    state = [0] * self.state_dim
                    next_state = [0] * self.state_dim
                    reward = 0
                    terminal = 0
                    action_onehot_n = [[0] * self.action_dim for _ in range(self.agent_num)]

                state_batch[t][i] = state
                action_batch[t][i] = action_onehot_n
                reward_batch[t][i] = [reward]
                next_state_batch[t][i] = next_state
                done_batch[t][i] = [terminal]

        return state_batch, action_batch, reward_batch, next_state_batch, done_batch

    def compute_one_hot(self, index):
        """Return a length-``action_dim`` list with ``1.0`` at ``index`` and ``0`` elsewhere."""
        one_hot_code = [0] * self.action_dim
        one_hot_code[index] = 1.0
        return one_hot_code
