import numpy as np
import tensorflow as tf
from ..ppo_algo import PPOAlgo
from experiments.utils.mlp_policy import MlpPolicy
from experiments.utils.common import Dataset

# Rendering toggle (not referenced by the code in this module; importers may read it).
RENDER = False

# ---------------------------------------------------------------------------
# Default hyper-parameters, read by PPODpbaAlgo.set_algo_parameters when the
# corresponding keyword argument is not supplied.
# ---------------------------------------------------------------------------

# Optimizer settings: per-network learning rates and Adam's epsilon.
LR_ACTOR = 1e-4
LR_CRITIC = 2e-4
LR_PHI = 1e-3
ADAM_EPSILON = 1e-5

# Discount factor.
GAMMA = 0.999

# Global-norm gradient clipping thresholds.
ACTOR_GRADIENT_NORM_CLIP = CRITIC_GRADIENT_NORM_CLIP = 1.0
PHI_GRADIENT_NORM_CLIP = 50.0

# Policy-objective settings.
ENTROPY_COEFF = 0.0
RATIO_CLIP_PARAM = 0.2

# Optimization schedule: passes over each batch and minibatch size.
OPTIM_EPOCHS = 50
OPTIM_BATCH_SIZE = 1024

# Soft-update rate for the target PhiNet (its moving-average decay is 1 - tau).
TAU = 0.01

class PPODpbaAlgo(PPOAlgo):
    """PPO agent extended with a learned state-action potential network ``PhiNet``.

    The ``ppo_dpba`` name suggests Dynamic Potential-Based Advice (DPBA);
    NOTE(review): confirm against the experiment this module belongs to.

    On top of the PPO networks built by the inherited ``init_ppo_networks``,
    this class builds ``phi(s, a)`` and trains it by regressing towards
    ``-shaping_reward + gamma * (1 - done) * phi_target(s', a')``, where
    ``phi_target`` evaluates PhiNet with exponential-moving-average weights.
    During training, ``choose_action`` also reports ``phi(s, a)`` for the
    sampled action.
    """

    # Added to the advantage standard deviation before dividing, so a batch
    # whose advantages are all equal (std == 0) does not produce NaN/inf.
    _ADV_STD_EPS = 1e-8

    def __init__(self, sess, graph, state_space, action_space, algo_name="ppo_dpba",
                 **kwargs):
        super(PPODpbaAlgo, self).__init__(sess, graph, state_space, action_space,
                                          algo_name, **kwargs)

    def set_algo_parameters(self, **kwargs):
        """Read hyper-parameters from ``kwargs``, falling back to the module defaults.

        Every key is optional; keys not listed here are ignored.
        """
        # Discounting and learning rates.
        self.gamma = kwargs.get("gamma", GAMMA)
        self.lr_actor = kwargs.get("lr_actor", LR_ACTOR)
        self.lr_critic = kwargs.get("lr_critic", LR_CRITIC)
        self.lr_phi = kwargs.get("lr_phi", LR_PHI)

        # Gradient-clipping switches (on by default) and global-norm thresholds.
        # NOTE(review): the phi switch is stored as ``phi_gradient_clip`` while the
        # others are ``*_grad_clip``; init_phi_network reads ``phi_gradient_clip``.
        self.actor_grad_clip = kwargs.get("actor_gradient_clip", True)
        self.critic_grad_clip = kwargs.get("critic_gradient_clip", True)
        self.phi_gradient_clip = kwargs.get("phi_gradient_clip", True)
        self.actor_grad_norm_clip = kwargs.get("actor_gradient_norm_clip", ACTOR_GRADIENT_NORM_CLIP)
        self.critic_grad_norm_clip = kwargs.get("critic_gradient_norm_clip", CRITIC_GRADIENT_NORM_CLIP)
        self.phi_grad_norm_clip = kwargs.get("phi_gradient_norm_clip", PHI_GRADIENT_NORM_CLIP)

        # Objective, optimizer and schedule settings.
        self.entropy_coeff = kwargs.get("entropy_coeff", ENTROPY_COEFF)
        self.ratio_clip_param = kwargs.get("ratio_clip_param", RATIO_CLIP_PARAM)
        self.adam_epsilon = kwargs.get("adam_epsilon", ADAM_EPSILON)
        self.optim_epochs = kwargs.get("optim_epochs", OPTIM_EPOCHS)
        self.optim_batch_size = kwargs.get("optim_batch_size", OPTIM_BATCH_SIZE)

        # Network architecture (hidden-layer widths) and PhiNet settings.
        self.policy_net_layers = kwargs.get("policy_net_layers", [8, 8])
        self.v_net_layers = kwargs.get("v_net_layers", [32, 32])
        self.phi_net_layers = kwargs.get("phi_net_layers", [16, 8])
        self.gaussian_fixed_var = kwargs.get("gaussian_fixed_var", False)
        self.tau = kwargs.get("tau", TAU)
        self.phi_hidden_layer_act_func = kwargs.get("phi_hidden_layer_act_func", tf.nn.relu)

    def policy_fn(self, name):
        """Construct an ``MlpPolicy`` called ``name`` from this agent's spaces and layer settings."""
        return MlpPolicy(self.sess, self.graph, name=name, ob_space=self.state_space,
                         ac_space=self.action_space, policy_net_layers=self.policy_net_layers,
                         v_net_layers=self.v_net_layers, gaussian_fixed_var=self.gaussian_fixed_var)

    def init_networks(self):
        """Build the PPO networks, then PhiNet, then a ``tf.train.Saver``.

        The saver is created last so it also covers the PhiNet variables;
        it keeps at most 100 checkpoints.
        """
        self.init_ppo_networks()
        self.init_phi_network()

        with self.sess.as_default():
            with self.graph.as_default():
                self.saver = tf.train.Saver(max_to_keep=100)

    def init_phi_network(self):
        """Build PhiNet, its moving-average target copy, and the PhiNet training op.

        Creates:
            * ``self.phi_sa``: phi(s, a) on ``state_phd`` / ``action_phd_for_phi``.
            * ``self.phi_spap``: phi(s', a') on ``state_prime_phd`` /
              ``action_prime_phd``, evaluated with exponential-moving-average
              weights (decay ``1 - tau``).
            * ``self.phi_loss``: mean squared error between ``phi_sa`` and the
              gradient-stopped target
              ``-shaping_reward + gamma * (1 - done) * phi_spap``.
            * ``self.phi_optimizer``: an Adam step on ``phi_loss``, with gradients
              clipped by global norm when ``phi_gradient_clip`` is set.

        The moving-average update is attached as a control dependency of the
        loss/optimizer ops, so running ``phi_optimizer`` also advances the
        target weights.
        """
        with self.sess.as_default():
            with self.graph.as_default():
                self.state_prime_phd = tf.placeholder(tf.float32, [None, self.state_space.shape[0]], name='state_prime')

                # Discrete actions arrive as integer indices; continuous ones as vectors.
                if self.is_discrete:
                    action_dtype, action_shape = tf.int32, [None]
                else:
                    action_dtype, action_shape = tf.float32, [None, self.action_space.shape[0]]
                self.action_phd_for_phi = tf.placeholder(action_dtype, action_shape, name='action_for_phi')
                self.action_prime_phd = tf.placeholder(action_dtype, action_shape, name='action_prime')

                def encode_action(action_tensor):
                    # One-hot encode discrete actions so they can be concatenated with the state.
                    if self.is_discrete:
                        return tf.one_hot(action_tensor, self.action_dim)
                    return action_tensor

                self.phi_sa = self.build_phi_net(self.state_phd, encode_action(self.action_phd_for_phi))

                self.shaping_reward_phd = tf.placeholder(tf.float32, [None, 1], name="shaping_reward")
                self.done_phd = tf.placeholder(tf.float32, [None, 1], name='done')

                phi_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="PhiNet")
                # Soft target: shadow copies of the PhiNet weights, decayed by 1 - tau.
                ema = tf.train.ExponentialMovingAverage(decay=1 - self.tau)

                def ema_getter(getter, name, *args, **kwargs):
                    # Substitute each PhiNet variable with its moving-average shadow.
                    return ema.average(getter(name, *args, **kwargs))

                phi_target_update = [ema.apply(phi_params)]

                # Target potential network: same PhiNet variables, read through the EMA getter.
                self.phi_spap = self.build_phi_net(self.state_prime_phd,
                                                   encode_action(self.action_prime_phd),
                                                   reuse=True, custom_getter=ema_getter)

                with tf.control_dependencies(phi_target_update):
                    # TD-style regression of phi(s, a) towards the shaped target.
                    phi_target = -self.shaping_reward_phd + self.gamma * (1 - self.done_phd) * self.phi_spap
                    self.phi_target_op = tf.stop_gradient(phi_target)
                    self.phi_td_error_op = self.phi_target_op - self.phi_sa
                    self.phi_squared_error_op = tf.square(self.phi_td_error_op)
                    self.phi_loss = tf.reduce_mean(self.phi_squared_error_op, name="phi_loss")
                    if self.phi_gradient_clip:
                        self.phi_param_gradients = tf.gradients(self.phi_loss, phi_params)
                        self.phi_clipped_grad_op, _ = tf.clip_by_global_norm(self.phi_param_gradients,
                                                                             self.phi_grad_norm_clip)
                        self.phi_opt = tf.train.AdamOptimizer(self.lr_phi)
                        self.phi_optimizer = self.phi_opt.apply_gradients(zip(self.phi_clipped_grad_op, phi_params))
                    else:
                        # Restrict the update to PhiNet, matching the clipped branch above.
                        self.phi_optimizer = tf.train.AdamOptimizer(self.lr_phi).minimize(self.phi_loss,
                                                                                          var_list=phi_params,
                                                                                          name="phi_adam_optimizer")

    def choose_action(self, s, is_test):
        """Sample an action for observation ``s``.

        Args:
            s: a single observation (fed to PhiNet wrapped in a batch of one).
            is_test: when true, skip the PhiNet evaluation.

        Returns:
            ``(action, info)`` where ``info`` always has ``"v_pred"`` and, when
            ``is_test`` is false, also ``"phi_sa"``: phi(s, action), unwrapped to
            a scalar when the network returns a 2-D result.
        """
        with self.graph.as_default():
            # NOTE(review): the policy is sampled stochastically even when is_test is true.
            action, vpred = self.pi.act(stochastic=True, ob=s)

            if is_test:
                return action, {"v_pred": vpred}

            phi_sa = self.sess.run(self.phi_sa, feed_dict={
                self.state_phd: [s],
                self.action_phd_for_phi: [action],
            })
            if phi_sa.ndim == 2:
                phi_sa = phi_sa[0, 0]

            return action, {"v_pred": vpred, "phi_sa": phi_sa}

    def learn(self, **kwargs):
        """Run PPO and PhiNet updates on one on-policy batch.

        Keyword Args:
            ob: batch of states.
            ac: batch of actions taken in ``ob``.
            adv: advantage estimates; standardized here before use.
            td_lam_ret: TD(lambda) returns, used as value-function targets.
            bs_: next states, fed to the target PhiNet.
            ba_: next actions, fed to the target PhiNet.
            bdone: done flags, reshaped to a ``[-1, 1]`` column.
            c_batch: shaping reward used in the PhiNet target, reshaped to a
                ``[-1, 1]`` column.

        For ``optim_epochs`` passes over the data, every minibatch of
        ``optim_batch_size`` rows (the whole batch when that setting is falsy)
        gets one actor/critic step followed by one PhiNet step. Policy loss,
        value loss and mean KL are written as summaries, and ``update_cnt`` is
        incremented once per minibatch.
        """
        with self.graph.as_default():
            bs = kwargs.get("ob")
            ba = kwargs.get("ac")
            batch_adv = kwargs.get("adv")
            batch_td_lam_ret = kwargs.get("td_lam_ret")
            bs_ = kwargs.get("bs_")
            ba_ = kwargs.get("ba_")
            bdone = kwargs.get("bdone")
            c_batch = kwargs.get("c_batch")

            # Standardize the advantage estimates. The epsilon guards against a
            # zero standard deviation (e.g. a single-sample or constant batch),
            # which previously turned the whole batch into NaN/inf.
            batch_adv = np.asarray(batch_adv)
            batch_adv = (batch_adv - batch_adv.mean()) / (batch_adv.std() + self._ADV_STD_EPS)

            # PPO has no replay buffer, so the dataset is built from this batch only.
            d = Dataset(dict(ob=bs, ac=ba, atarg=batch_adv, vtarg=batch_td_lam_ret,
                             bs_=bs_, ba_=ba_, bdone=bdone, c_batch=c_batch),
                        deterministic=self.pi.recurrent)

            batch_size = self.optim_batch_size or bs.shape[0]

            if hasattr(self.pi, "ob_rms"):
                self.pi.ob_rms.update(bs)

            # Copy the current policy parameters into the "old" policy.
            self.assign_old_eq_new()

            for _ in range(self.optim_epochs):
                for batch in d.iterate_once(batch_size):
                    # Actor and critic step.
                    _, _, policy_loss, v_loss, kl = self.sess.run(
                        [self.policy_trainer, self.critic_trainer,
                         self.pol_surr, self.vf_loss, self.mean_kl],
                        feed_dict={
                            self.state_phd: batch["ob"],
                            self.action_phd: batch["ac"],
                            self.atarg: batch["atarg"],
                            self.ret: batch["vtarg"],
                        })

                    # PhiNet step: phi(s, a) reads state_phd / action_phd_for_phi,
                    # the target branch reads the next-state / next-action placeholders.
                    self.sess.run(self.phi_optimizer, feed_dict={
                        self.state_phd: batch["ob"],
                        self.action_phd_for_phi: batch["ac"],
                        self.state_prime_phd: batch["bs_"],
                        self.action_prime_phd: batch["ba_"],
                        self.done_phd: np.array(batch["bdone"]).reshape([-1, 1]),
                        self.shaping_reward_phd: np.array(batch["c_batch"]).reshape([-1, 1]),
                    })

                    self.write_summary_scalar(self.update_cnt, "policy_loss", policy_loss)
                    self.write_summary_scalar(self.update_cnt, "v_loss", v_loss)
                    self.write_summary_scalar(self.update_cnt, "mean_kl", kl)
                    self.update_cnt += 1

    def build_phi_net(self, state_phd, action_phd, reuse=None, custom_getter=None):
        """Build the potential network phi(s, a) under variable scope ``PhiNet``.

        Args:
            state_phd: state tensor, concatenated with ``action_phd`` along axis 1
                (both presumably rank-2 ``[batch, features]``).
            action_phd: action tensor, already one-hot encoded by the caller for
                discrete action spaces.
            reuse: ``None`` creates new, trainable variables; any other value
                reuses the existing ``PhiNet`` variables and builds the dense
                layers as non-trainable.
            custom_getter: optional variable getter, e.g. one that substitutes
                moving-average copies of the weights.

        Returns:
            A tensor with one unit per row holding phi(s, a).
        """
        trainable = reuse is None
        with tf.variable_scope('PhiNet', reuse=reuse, custom_getter=custom_getter):
            net = tf.concat([state_phd, action_phd], axis=1)
            # Hidden layers: dense -> layer norm -> activation.
            for ly_index, ly_cell_num in enumerate(self.phi_net_layers):
                net = tf.layers.dense(net, ly_cell_num,
                                      kernel_initializer=tf.random_uniform_initializer(-1 / 8.0, 1 / 8.0),
                                      bias_initializer=tf.random_uniform_initializer(-1 / 8.0, 1 / 8.0),
                                      name="l" + str(ly_index),
                                      trainable=trainable)
                net = tf.contrib.layers.layer_norm(net)
                net = self.phi_hidden_layer_act_func(net)

            # Output layer: a single linear unit giving the scalar potential
            # phi(s, a) for each (state, action) row. The tiny init range keeps
            # initial potentials close to zero.
            phisa = tf.layers.dense(net, 1, activation=None,
                                    kernel_initializer=tf.random_uniform_initializer(-1 / 1000.0, 1 / 1000.0),
                                    bias_initializer=tf.random_uniform_initializer(-1 / 1000.0, 1 / 1000.0),
                                    name='phi_sa',
                                    trainable=trainable
                                    )
            return phisa
