# coding=utf-8
import random
import tensorflow as tf
from .dqn_algo import DqnAlgo

# ---------------------------------------------------------------------------
# Hyper-parameters
# ---------------------------------------------------------------------------

# Optimisation
ALPHA = 1e-4  # Adam learning rate for both Q-networks

# Bootstrapped target / soft target networks
GAMMA = 0.9   # discount factor applied to the next-state value
TAU = 0.01    # soft target update rate; the EMA decay used is 1 - TAU

RENDER = False

# Epsilon-greedy exploration. EPSILON_MIN is the floor applied when epsilon
# decays at the end of an episode; the remaining values are presumably read by
# the DqnAlgo base class (TODO confirm).
EPSILON = 1e-4
EPSILON_MAX = 1.0
EPSILON_MIN = 0.1
EPSILON_DECAY_EPISODE = 6000

# Eta: probability of taking the greedy action from the shaping Q-network
# instead of the main one. It decays exponentially from ETA_MAX towards
# ETA_MIN over ETA_DECAY_EPISODE training episodes.
ETA_MAX = 1.0
ETA_MIN = 0.9
ETA_DECAY_EPISODE = 10000

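# Best-known settings for single-agent DQN, kept for reference. Several of the
# values configured above differ from these.
# - epsilon decay: exponential, from 1.0 to 0.02 (0.2 for larger agent counts
#   such as 6, 7 or 8) over 60000 episodes
# - soft target update via an exponential moving average with tau = 0.01
# - learning rate alpha = 1e-3 with the Adam optimizer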

class DqnDpsAlgo(DqnAlgo):
    """DQN variant that trains a reward-shaped Q-network alongside the main one.

    Two Q-networks see the same state input:

    * ``netQ`` (variable scope ``Qnet_``, built by the inherited ``build_q_net``)
      is trained on the environment reward only.
    * ``shp_netQ`` (variable scope ``Shaping_Qnet_``) is trained on the
      environment reward plus a shaping reward passed to :meth:`learn` as
      ``c_batch``.

    Each network has a soft target network: an exponential moving average of
    its trainable variables with decay ``1 - TAU``. When a greedy action is
    needed, whether for acting or for picking the next-state action a' that
    both targets evaluate, it comes from the shaping network with probability
    ``self.eta`` and from the main network otherwise. ``eta`` decays
    exponentially from ``ETA_MAX`` to ``ETA_MIN`` over ``ETA_DECAY_EPISODE``
    training episodes.
    """

    def __init__(self, sess, graph, state_dim, action_dim, epsilon_decay=None,
                 is_test=False, algo_name="dqn_dps"):
        super(DqnDpsAlgo, self).__init__(sess, graph, state_dim, action_dim,
                                         epsilon_decay, is_test, algo_name)

        # Probability of following the shaping network. It is multiplied by
        # eta_decay_factor after every training episode, so it reaches ETA_MIN
        # after ETA_DECAY_EPISODE episodes (and is clamped there).
        self.eta = ETA_MAX
        self.eta_decay_factor = pow(ETA_MIN / ETA_MAX, 1.0 / ETA_DECAY_EPISODE)

    def init_networks(self):
        """Build placeholders, both Q-networks, their EMA targets, losses and optimizers."""
        with self.sess.as_default():
            with self.graph.as_default():
                self.state_phd = tf.placeholder(tf.float32, [None, self.state_dim], name="s_")
                self.action_phd = tf.placeholder(tf.int32, [None, ], name="a_")
                self.action_prime_phd = tf.placeholder(tf.int32, [None, ], name="ap_")
                self.state_prime_phd = tf.placeholder(tf.float32, [None, self.state_dim],
                                                      name="s_prime_")
                self.reward_phd = tf.placeholder(tf.float32, [None, 1], name="reward_")
                self.shp_reward_phd = tf.placeholder(tf.float32, [None, 1], name="shp_reward_")
                self.done_phd = tf.placeholder(tf.float32, [None, 1], name="done_")

                # Main Q-network and shaping Q-network (Q') on the current state.
                self.netQ = self.build_q_net(self.state_phd)
                self.shp_netQ = self.build_shaping_q_net(self.state_phd)

                # Q(s, a) and Q'(s, a) for the fed actions, and the greedy
                # actions argmax_a Q(s, a) and argmax_a Q'(s, a).
                self.q_sa_op = tf.reduce_sum(self.netQ * tf.one_hot(self.action_phd, self.action_dim),
                                             axis=1, keepdims=True, name="q_sa_")
                self.shp_q_sa_op = tf.reduce_sum(self.shp_netQ * tf.one_hot(self.action_phd, self.action_dim),
                                                 axis=1, keepdims=True, name="shp_q_sa_")
                self.argmax_qsa = tf.argmax(self.netQ, axis=1, name="argmax_qsa_")
                self.argmax_shp_qsa = tf.argmax(self.shp_netQ, axis=1, name="argmax_shp_qsa_")

                # Soft target networks: an EMA (decay 1 - TAU) of each network's
                # trainable variables. The target nets read those averages back
                # through a custom getter.
                q_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="Qnet_")
                shp_q_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="Shaping_Qnet_")
                ema = tf.train.ExponentialMovingAverage(decay=1 - TAU)

                def ema_getter(getter, name, *args, **kwargs):
                    return ema.average(getter(name, *args, **kwargs))

                target_update = [ema.apply(q_params)]
                shp_target_update = [ema.apply(shp_q_params)]

                # Ops built inside each control-dependency scope (including the
                # optimizer) run the matching EMA update first, so every training
                # step also refreshes that target network.
                with tf.control_dependencies(target_update):
                    self.netQ_target = self.build_q_net(self.state_prime_phd, reuse=True, custom_getter=ema_getter)
                    # Target-network value Q_target(s', a') of the next action a'
                    # fed via action_prime_phd. Despite the op name this is not a
                    # max over actions; the name is kept unchanged.
                    self.q_spap = tf.reduce_sum(self.netQ_target * tf.one_hot(self.action_prime_phd, self.action_dim),
                                                name="q_max_sp_",
                                                axis=1,
                                                keepdims=True)
                    q_target = self.reward_phd + GAMMA * (1 - self.done_phd) * self.q_spap
                    self.target_op = tf.stop_gradient(q_target)
                    self.td_error_op = self.target_op - self.q_sa_op

                    self.squared_error_op = tf.square(self.td_error_op)
                    self.loss = tf.reduce_mean(self.squared_error_op, name="loss_")

                    self.optimizer = tf.train.AdamOptimizer(ALPHA).minimize(self.loss,
                                                                            name="adam_optimizer_")

                with tf.control_dependencies(shp_target_update):
                    self.shp_netQ_target = self.build_shaping_q_net(self.state_prime_phd, reuse=True,
                                                                    custom_getter=ema_getter)
                    # Shaping target-network value Q'_target(s', a') for the same a'.
                    self.shp_q_spap = tf.reduce_sum(self.shp_netQ_target *
                                                    tf.one_hot(self.action_prime_phd, self.action_dim),
                                                    name="shp_q_max_sp_",
                                                    axis=1, keepdims=True
                                                    )
                    # The shaping network's target adds the shaping reward to the
                    # environment reward.
                    shp_q_target = self.reward_phd + self.shp_reward_phd + \
                                   GAMMA * (1 - self.done_phd) * self.shp_q_spap
                    self.shp_target_op = tf.stop_gradient(shp_q_target)
                    self.shp_td_error_op = self.shp_target_op - self.shp_q_sa_op
                    self.shp_squared_error_op = tf.square(self.shp_td_error_op)
                    self.shp_loss = tf.reduce_mean(self.shp_squared_error_op, name="shp_loss_")
                    self.shp_optimizer = tf.train.AdamOptimizer(ALPHA).minimize(self.shp_loss,
                                                                                name="shp_adam_optimizer_")

                with tf.name_scope("q_loss"):
                    tf.summary.scalar("q_loss_", self.loss)
                    tf.summary.scalar("shaping_q_loss_", self.shp_loss)

                self.merged = tf.summary.merge_all()
                self.saver = tf.train.Saver(max_to_keep=100)

    def _greedy_op(self):
        """Return the shaping-network argmax op with probability eta, else the main one."""
        return self.argmax_shp_qsa if random.uniform(0, 1.0) < self.eta else self.argmax_qsa

    def choose_action(self, s, test_model):
        """Choose an action for a single state ``s``.

        :param s: one state; it is fed to the network as a batch of size 1.
        :param test_model: when truthy, epsilon exploration is disabled.
        :return: a uniformly random action index with probability epsilon
            (training only), otherwise the greedy action from either the
            shaping or the main network (see :meth:`_greedy_op`).
        """
        if not test_model and random.random() < self.epsilon:
            return random.randint(0, self.action_dim - 1)

        # NOTE: the eta mix is applied in test mode too; an earlier, disabled
        # variant used only argmax_qsa when testing.
        max_act = self.sess.run(self._greedy_op(), feed_dict={self.state_phd: [s]})
        return max_act[0]

    def learn(self, state_batch, action_batch, reward_batch,
              state_prime_batch, done_batch, **kwargs):
        """Run one optimizer step on both Q-networks from a batch of transitions.

        Does nothing when ``self.is_test`` is set.

        :param state_batch: states s, fed to ``state_phd``.
        :param action_batch: actions a taken in s, fed to ``action_phd``.
        :param reward_batch: environment rewards, fed to ``reward_phd`` (shape [None, 1]).
        :param state_prime_batch: next states s'.
        :param done_batch: terminal flags fed to ``done_phd`` (shape [None, 1]);
            a value of 1 removes the bootstrapped term from both targets.
        :param kwargs: must contain ``c_batch``, the shaping rewards fed to
            ``shp_reward_phd`` (shape [None, 1]).
        :raises ValueError: if ``c_batch`` is missing or None.
        """
        if self.is_test:
            return

        # Shaping rewards for the current transitions. They are required by the
        # shaping network's target, so fail clearly instead of feeding None to TF.
        c_batch = kwargs.get("c_batch")
        if c_batch is None:
            raise ValueError("DqnDpsAlgo.learn requires the shaping reward batch 'c_batch'")

        with self.sess.graph.as_default():
            # Next-state action a' used by both targets: greedy w.r.t. the
            # shaping network with probability eta, otherwise the main network.
            max_act_sp_batch = self.sess.run(self._greedy_op(),
                                             feed_dict={self.state_phd: state_prime_batch})

            _, _, summary = self.sess.run([self.optimizer, self.shp_optimizer, self.merged],
                                          feed_dict={self.state_phd: state_batch,
                                                     self.action_phd: action_batch,
                                                     self.state_prime_phd: state_prime_batch,
                                                     self.reward_phd: reward_batch,
                                                     self.done_phd: done_batch,
                                                     self.shp_reward_phd: c_batch,
                                                     self.action_prime_phd: max_act_sp_batch
                                                     })

        self.train_writer.add_summary(summary, self.update_cnt)
        self.update_cnt += 1

    def episode_done(self, test_model):
        """Per-episode bookkeeping. Only training episodes (falsy ``test_model``) count.

        Decays epsilon according to ``self.epsilon_decay`` ("exponential" or
        "linear", floored at EPSILON_MIN). Decays eta (floored at ETA_MIN),
        increments ``episode_cnt`` and clears ``episode_reward``.
        """
        if test_model:
            return

        if self.epsilon_decay == "exponential":
            self.epsilon = max(EPSILON_MIN, self.epsilon * self.epsilon_decay_factor)
        elif self.epsilon_decay == "linear":
            self.epsilon = max(EPSILON_MIN, self.epsilon - self.epsilon_inc)

        self.eta = max(ETA_MIN, self.eta * self.eta_decay_factor)
        self.episode_cnt += 1
        self.episode_reward = []

    def build_shaping_q_net(self, state_phd, reuse=None, custom_getter=None):
        """Build the shaping Q-network under variable scope ``Shaping_Qnet_``.

        Architecture: dense(64) -> layer norm -> ReLU -> dense(32) -> layer norm
        -> ReLU -> dense(action_dim), using the default initializers.

        :param state_phd: state input tensor.
        :param reuse: passed to ``tf.variable_scope``. A non-None value (used for
            the target network) also makes the dense layers non-trainable.
        :param custom_getter: optional variable getter, e.g. the EMA getter
            used for the target network.
        :return: tensor of Q-values, one column per action.
        """
        trainable = reuse is None
        with tf.variable_scope('Shaping_Qnet_', reuse=reuse, custom_getter=custom_getter):
            # First hidden layer.
            net = tf.layers.dense(state_phd, 64, name="l1", trainable=trainable)
            net = tf.contrib.layers.layer_norm(net)
            net = tf.nn.relu(net)

            # Second hidden layer.
            net = tf.layers.dense(net, 32, name="l2", trainable=trainable)
            net = tf.contrib.layers.layer_norm(net)
            net = tf.nn.relu(net)

            # Linear output layer: Q-values of all actions for each state.
            qsa = tf.layers.dense(net, self.action_dim, activation=None,
                                  name='qs', trainable=trainable)
            return qsa


