# coding=utf-8
from .dqn_trainer import DqnTrainer
import numpy as np
import tensorflow as tf
from ...utils.memory import ReplayMemory
from .dqn_myp_pbrs_algo import DqnMypPbrsAlgo

# --- learning hyper-parameters ---------------------------------------------
GAMMA = 0.9                # discount factor applied to future rewards
TAU = 0.01                 # interpolation weight for soft target-network updates
RENDER = False             # whether to render the environment
BATCH_SIZE = 1024          # transitions per learning mini-batch

# --- replay / scheduling ---------------------------------------------------
REPLAY_BUFFER_SIZE = 10 ** 6       # capacity of the replay memory
UPDATE_FREQ = 100                  # learn once every UPDATE_FREQ steps
FIRST_UPDATE_SAMPLE_NUM = 25600    # stored transitions required before learning starts
MODEL_UPDATE_FREQ = 1000

# soft-update the target network after every learning step
# (frequent soft updates work better for DQN)
TARGET_UPDATE_FREQ = 1


class DqnMypPbrsTrainer(DqnTrainer):
    """DQN trainer that passes extra per-transition values to DqnMypPbrsAlgo.

    Each transition is stored in a ReplayMemory together with two extra values,
    ``c`` and ``c_sp``. Once more than FIRST_UPDATE_SAMPLE_NUM transitions have
    been stored, a BATCH_SIZE mini-batch is replayed every UPDATE_FREQ steps
    through ``DqnMypPbrsAlgo.learn``. The target network is then soft-updated
    with weight TAU.
    """

    def __init__(self, state_dim, action_num, algo_name="dqn_myp_pbrs"):
        """Forward the construction arguments to DqnTrainer.

        Args:
            state_dim: state dimensionality. Presumably stored by the base class
                as ``self.state_dim``, which ``init_algo`` reads.
            action_num: number of discrete actions (read as ``self.action_num``).
            algo_name: name handed to the base class and to DqnMypPbrsAlgo.
        """
        super(DqnMypPbrsTrainer, self).__init__(state_dim, action_num, algo_name)

    def init_algo(self):
        """Create the TF graph/session, the learning algorithm and the replay buffer."""
        self.graph = tf.Graph()
        self.session = tf.Session(graph=self.graph)
        self.algorithm = DqnMypPbrsAlgo(self.session, self.graph, self.state_dim, self.action_num,
                                        epsilon_decay="exponential", is_test=False, algo_name=self.algo_name)

        # number of learning steps performed; drives the target-update schedule
        self.update_cnt = 0

        # the replay buffer
        self.memory = ReplayMemory(REPLAY_BUFFER_SIZE)

        # NOTE(review): writer_graph is not used anywhere in this file; it is kept
        # in case other code relies on the attribute. The writer actually used
        # for extra information is the algorithm's own train_writer.
        self.writer_graph = tf.Graph()
        self.my_writer = self.algorithm.train_writer

    def experience(self, s, a_n, r_n, s_n, terminal, **kwargs):
        """Store one transition in the replay buffer.

        No reward shaping is computed here. The optional keyword values ``c``
        and ``c_sp`` are stored alongside the transition and later passed to
        ``DqnMypPbrsAlgo.learn`` as ``c_batch`` / ``c_sp_batch``. Presumably
        they relate to ``s`` and ``s_n`` respectively and feed the shaping
        term; TODO confirm against DqnMypPbrsAlgo.

        Args:
            s: current state.
            a_n: action taken.
            r_n: reward received.
            s_n: next state.
            terminal: episode-termination flag for this transition.
            **kwargs: may contain ``c`` and ``c_sp``; a missing key is stored as None.
        """
        c = kwargs.get("c")
        c_sp = kwargs.get("c_sp")
        self.memory.add((s, a_n, r_n, s_n, terminal, c, c_sp))

    def update(self, t):
        """Run one learning step if the warm-up and frequency conditions hold.

        Nothing happens until the buffer holds more than FIRST_UPDATE_SAMPLE_NUM
        transitions. After that, a step runs only when ``t`` is a multiple of
        UPDATE_FREQ.

        Args:
            t: current global step counter.
        """
        # warm-up: wait until the replay buffer is sufficiently filled
        if len(self.memory.store) <= FIRST_UPDATE_SAMPLE_NUM:
            return
        # update frequency
        if t % UPDATE_FREQ != 0:
            return

        self.update_cnt += 1

        # get mini batch from replay buffer; each entry is the 7-tuple built in
        # experience(): (s, a, r, s', done, c, c_sp)
        sample = self.memory.get_minibatch(BATCH_SIZE)
        s_batch = [tr[0] for tr in sample]
        a_batch = [tr[1] for tr in sample]
        r_batch = [tr[2] for tr in sample]
        sp_batch = [tr[3] for tr in sample]
        done_batch = [tr[4] for tr in sample]
        c_batch = [tr[5] for tr in sample]
        c_sp_batch = [tr[6] for tr in sample]

        # rewards, done flags and c become column vectors; c_sp keeps its own
        # per-sample shape (as in the original code)
        self.algorithm.learn(np.array(s_batch), np.array(a_batch),
                             np.array(r_batch).reshape([-1, 1]),
                             np.array(sp_batch), np.array(done_batch).reshape([-1, 1]),
                             c_batch=np.array(c_batch).reshape([-1, 1]),
                             c_sp_batch=np.array(c_sp_batch))

        # update target network (use the module-level TAU rather than a literal)
        if self.update_cnt % TARGET_UPDATE_FREQ == 0:
            self.algorithm.update_target_soft(tau=TAU)

        # save param
        # NOTE(review): this runs after every learning step and MODEL_UPDATE_FREQ is
        # unused in this file -- confirm whether save_params throttles internally.
        self.save_params()
