import random
import numpy as np
import tensorflow as tf
from ....utils.mlp_policy import MlpPolicy
import experiments.utils.tf_util as U
from experiments.utils.common import Dataset, zipsame
from gym import spaces
RENDER = False

# ---------------------------------------------------------------------------
# Default algorithm parameters.
# Most of these are the fallback values for the ``kwargs.get`` look-ups in
# RcpoPpoLagFuncAlgo.set_algo_parameters.
# NOTE(review): RENDER, LR_LAGRANGE_MULTI_INIT_DECAY_FACTOR and CONSTRAINT are
# not read anywhere in this module -- confirm whether they are still needed.
# ---------------------------------------------------------------------------

# Learning rates (actor, critic, Lagrange-multiplier network) and the
# (currently unused) decay factor for the multiplier's learning rate.
LR_ACTOR = 1e-4
LR_CRITIC = 2e-4
LR_LAGRANGE_MULTI_INIT = 1e-4
LR_LAGRANGE_MULTI_INIT_DECAY_FACTOR = 0.999999

# Discount factor.
GAMMA = 0.999

# Global-norm gradient clipping thresholds for actor and critic.
ACTOR_GRADIENT_NORM_CLIP = 1.0
CRITIC_GRADIENT_NORM_CLIP = 1.0

# PPO objective / optimisation settings.
ENTROPY_COEFF = 0.0
RATIO_CLIP_PARAM = 0.2
ADAM_EPSILON = 1e-5
OPTIM_EPOCHS = 50
OPTIM_BATCH_SIZE = 1024

# Clamp range for the Lagrange multiplier and the constraint threshold.
LAGRANGE_MULTI_MIN = 0
LAGRANGE_MULTI_MAX = 10000
CONSTRAINT = 0.25

# Global-norm gradient clipping threshold for the Lagrange-multiplier network.
L_GRADIENT_NORM_CLIP = 1.0

# Reward Constrained Policy Optimization (RCPO) built on PPO, where the
# Lagrange multiplier is a learned function of (state, action) rather than a
# single scalar.

class RcpoPpoLagFuncAlgo(object):
    """RCPO built on PPO, with a learned Lagrange multiplier *function*.

    Instead of a single scalar multiplier, a small network
    lambda_phi(state, action) (variable scope ``Lag_Func``) produces the
    multiplier for each state-action pair.

    Everything is built inside ``graph`` and run with ``sess``:

    * ``pi`` / ``oldpi``: two ``MlpPolicy`` networks. The actor is trained with
      the PPO clipped surrogate (plus entropy penalty and value loss). The
      critic is trained with a squared-error value loss.
    * ``Lag_Func``: an MLP over the concatenated (state, action). Its output is
      ``relu(dense(.)) + 1``, so the raw multiplier is always >= 1.
    * An Adam optimizer for ``Lag_Func``. It consumes gradients fed through
      per-parameter placeholders, so the gradients can be averaged over many
      trajectory samples in numpy before being applied.
    """

    def __init__(self, sess, graph, state_space, action_space, algo_name="rcpo_ppo_lag_func",
                 **kwargs):
        """Create all networks and optimizers, then initialize every global variable.

        Args:
            sess: ``tf.Session`` used to run every op.
            graph: ``tf.Graph`` in which every op is created.
            state_space: gym observation space. Its ``shape`` sizes the state
                placeholders.
            action_space: gym action space. ``spaces.Discrete`` selects the
                discrete code path; anything else is read via ``shape[0]``.
            algo_name: name used for the summary directory
                ``./data/<algo_name>/summary/``.
            **kwargs: hyper-parameter overrides, see ``set_algo_parameters``.
        """
        self.pointer = 0
        self.sess = sess
        self.graph = graph
        self.algo_name = algo_name
        self.action_space, self.state_space = action_space, state_space

        if isinstance(action_space, spaces.Discrete):
            self.is_discrete = True
            self.action_dim = self.action_space.n
            print("Discrete Action Space, action num is {}".format(self.action_dim))
        else:
            self.is_discrete = False
            self.action_dim = self.action_space.shape[0]
            print("Continuous Action Space, action num is {}".format(self.action_dim))

        # Hyper-parameters must exist before the networks that read them are built.
        self.set_algo_parameters(**kwargs)

        print("Now the sess is {}".format(self.sess))
        print("Now the graph is {}".format(self.graph))

        self.init_networks()

        # Global minibatch counter, used as the step for TensorBoard summaries.
        self.update_cnt = 0
        self.train_writer = tf.summary.FileWriter("./data/" + self.algo_name + "/summary/", self.sess.graph)

        # Initialize every variable created above (network weights, Adam slots, ...).
        with self.sess.as_default():
            with self.graph.as_default():
                tf.global_variables_initializer().run()

    def set_algo_parameters(self, **kwargs):
        """Read hyper-parameters from ``kwargs``, falling back to the module defaults."""
        self.gamma = kwargs.get("gamma", GAMMA)
        self.lr_actor = kwargs.get("lr_actor", LR_ACTOR)
        self.lr_critic = kwargs.get("lr_critic", LR_CRITIC)
        self.actor_grad_clip = kwargs.get("actor_gradient_clip", True)
        self.critic_grad_clip = kwargs.get("critic_gradient_clip", True)
        self.actor_grad_norm_clip = kwargs.get("actor_gradient_norm_clip", ACTOR_GRADIENT_NORM_CLIP)
        self.critic_grad_norm_clip = kwargs.get("critic_gradient_norm_clip", CRITIC_GRADIENT_NORM_CLIP)
        self.entropy_coeff = kwargs.get("entropy_coeff", ENTROPY_COEFF)
        self.ratio_clip_param = kwargs.get("ratio_clip_param", RATIO_CLIP_PARAM)
        # NOTE(review): stored but never passed to any AdamOptimizer in this class.
        self.adam_epsilon = kwargs.get("adam_epsilon", ADAM_EPSILON)
        self.optim_epochs = kwargs.get("optim_epochs", OPTIM_EPOCHS)
        self.optim_batch_size = kwargs.get("optim_batch_size", OPTIM_BATCH_SIZE)
        self.policy_net_layers = kwargs.get("policy_net_layers", [8, 8])
        self.v_net_layers = kwargs.get("v_net_layers", [32, 32])
        self.gaussian_fixed_var = kwargs.get("gaussian_fixed_var", False)

        # Learning rate of the Lagrange-multiplier network. Only the initial
        # value is read here; no decay factor is applied by this class.
        self.lr_lagrange_multi = kwargs.get("lr_lagrange_mul_init", LR_LAGRANGE_MULTI_INIT)

        # Clamp range applied to the multiplier value returned by choose_action.
        self.lagrange_multiplier_min = kwargs.get("lagrange_multiplier_min", LAGRANGE_MULTI_MIN)
        self.lagrange_multiplier_max = kwargs.get("lagrange_multiplier_max", LAGRANGE_MULTI_MAX)

        # Lag_Func architecture and its gradient clipping.
        self.lag_net_layers = kwargs.get("lag_net_layers", [16, 8])
        self.lag_grad_clip = kwargs.get("lag_gradient_clip", True)
        self.lag_grad_norm_clip = kwargs.get("lag_gradient_norm_clip", L_GRADIENT_NORM_CLIP)

    def policy_fn(self, name):
        """Build an ``MlpPolicy`` named ``name`` with this algorithm's settings."""
        return MlpPolicy(self.sess, self.graph, name=name, ob_space=self.state_space,
                         ac_space=self.action_space, policy_net_layers=self.policy_net_layers,
                         v_net_layers=self.v_net_layers, gaussian_fixed_var=self.gaussian_fixed_var)

    def init_networks(self):
        """Build, in order, every graph component of the algorithm.

        The components are the Lag_Func network, the PPO networks, the
        Lag_Func gradient/optimizer ops, and the checkpoint saver.
        """
        self.init_lag_network()
        self.init_rcpo_ppo_networks()

        with self.sess.as_default():
            with self.graph.as_default():
                # NOTE(review): no op in this class consumes this placeholder.
                self.lr_lag_phd = tf.placeholder(tf.float32, shape=[None, ], name="lr_lag_phd")
                self.lag_func_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Lag_Func')

                # With grad_ys = gamma^step * cons, this is the gradient w.r.t. phi of
                #     -sum_t gamma^{step_t} * cons_t * lambda_phi(s_t, a_t).
                # The Adam step below *descends* this. It therefore raises lambda
                # on pairs with positive discounted constraint cost.
                # NOTE(review): CONSTRAINT (the threshold) is not subtracted
                # anywhere -- confirm this is the intended Lagrangian.
                self.grad_cons_return_wrt_lag_param = tf.gradients(ys=-self.traj_lag_func, xs=self.lag_func_params,
                                                                   grad_ys=tf.multiply(tf.pow(self.gamma, self.traj_step_phd),
                                                                                       self.traj_cons_phd))
                self.flatten_grad_cons_return_wrt_lag_param = U.flatten_tensors(self.grad_cons_return_wrt_lag_param)

                # One placeholder per Lag_Func parameter. learn() feeds the
                # averaged, un-flattened gradients through these.
                self.lag_func_params_shapes = [param.shape for param in self.lag_func_params]
                self.lag_func_params_grad_phds = [
                    tf.placeholder(tf.float32, shape=param_shape, name="lag_func_phd_{}".format(ly))
                    for ly, param_shape in enumerate(self.lag_func_params_shapes)
                ]

                self.lag_opt = tf.train.AdamOptimizer(self.lr_lagrange_multi)
                if self.lag_grad_clip:
                    self.lag_clipped_grad_op, _ = tf.clip_by_global_norm(self.lag_func_params_grad_phds,
                                                                         self.lag_grad_norm_clip)
                    self.lag_trainer = self.lag_opt.apply_gradients(zip(self.lag_clipped_grad_op,
                                                                        self.lag_func_params))
                else:
                    self.lag_trainer = self.lag_opt.apply_gradients(zip(self.lag_func_params_grad_phds,
                                                                        self.lag_func_params))

                # Created last so it also covers the Adam slot variables above.
                self.saver = tf.train.Saver(max_to_keep=100)

    def init_lag_network(self):
        """Build the Lag_Func network twice, sharing variables.

        The first build (``lag_func``) evaluates single (state, action) pairs
        in choose_action. The second build (``traj_lag_func``) is fed with
        trajectory samples to compute Lag_Func parameter gradients.
        """
        with self.sess.as_default():
            with self.graph.as_default():
                self.state_phd = U.get_placeholder_with_graph(name="ob", dtype=tf.float32,
                                                              shape=[None] + list(self.state_space.shape),
                                                              graph=self.graph)

                # Lagrange multiplier function lambda_phi(s, a). Discrete actions
                # are one-hot encoded before being concatenated with the state.
                if self.is_discrete:
                    self.action_phd_for_l = tf.placeholder(tf.int32, [None, ], name='action_for_l')
                    self.lag_func = self._build_lagrange_multiplier_func(self.state_phd,
                                                                         tf.one_hot(self.action_phd_for_l,
                                                                                    self.action_dim))
                else:
                    self.action_phd_for_l = tf.placeholder(tf.float32, [None, self.action_space.shape[0]],
                                                           name='action_for_l')
                    self.lag_func = self._build_lagrange_multiplier_func(self.state_phd, self.action_phd_for_l)

                # Trajectory placeholders. They hold (s_t, a_t, t, lambda_t, c_t)
                # samples used to compute the gradient of the discounted
                # constraint return with respect to the Lag_Func parameters.
                self.traj_state_phd = tf.placeholder(tf.float32, [None, self.state_space.shape[0]], name='traj_state')
                self.traj_step_phd = tf.placeholder(tf.float32, [None, 1], name='traj_step')
                self.traj_lag_phd = tf.placeholder(tf.float32, [None, 1], name='traj_lag_phd')
                self.traj_cons_phd = tf.placeholder(tf.float32, [None, 1], name='traj_cons_phd')
                if self.is_discrete:
                    self.traj_action_phd = tf.placeholder(tf.int32, [None, ], name='traj_action')
                    self.traj_lag_func = self._build_lagrange_multiplier_func_reuse(self.traj_state_phd,
                                                                                    tf.one_hot(self.traj_action_phd,
                                                                                               self.action_dim))
                else:
                    self.traj_action_phd = tf.placeholder(tf.float32, [None, self.action_space.shape[0]],
                                                          name='traj_action')
                    self.traj_lag_func = self._build_lagrange_multiplier_func_reuse(self.traj_state_phd,
                                                                                    self.traj_action_phd)

    def init_rcpo_ppo_networks(self):
        """Build the PPO actor/critic, their losses, and their trainers.

        Also builds the old-policy sync op and the loss-evaluation function.
        """
        with self.sess.as_default():
            with self.graph.as_default():
                self.pi = self.policy_fn("pi")  # policy being optimized
                self.pi_old = self.policy_fn("oldpi")  # frozen copy used in the probability ratio

                self.atarg = tf.placeholder(dtype=tf.float32, shape=[None])  # target advantage
                self.ret = tf.placeholder(dtype=tf.float32, shape=[None])  # empirical return (value target)
                # NOTE(review): lrmult is fed by compute_losses but no loss or trainer uses it.
                self.lrmult = tf.placeholder(name='lrmult', dtype=tf.float32,
                                             shape=[])

                self.state_phd = U.get_placeholder_cached(name="ob")
                self.action_phd = self.pi.pdtype.sample_placeholder([None])

                # PPO losses.
                self.kl_old_new = self.pi_old.pd.kl(self.pi.pd)
                self.ent = self.pi.pd.entropy()
                self.mean_kl = tf.reduce_mean(self.kl_old_new)
                self.mean_ent = tf.reduce_mean(self.ent)
                self.pol_ent_pen = (-self.entropy_coeff) * self.mean_ent
                # Probability ratio pi_new(a|s) / pi_old(a|s).
                self.ratio = tf.exp(
                    self.pi.pd.logp(self.action_phd) - self.pi_old.pd.logp(self.action_phd))
                self.surr1 = self.ratio * self.atarg
                self.surr2 = tf.clip_by_value(self.ratio, 1.0 - self.ratio_clip_param,
                                              1.0 + self.ratio_clip_param) * self.atarg
                # PPO's pessimistic clipped surrogate (L^CLIP), negated for minimization.
                self.pol_surr = - tf.reduce_mean(
                    tf.minimum(self.surr1, self.surr2))
                self.vf_loss = tf.reduce_mean(tf.square(self.pi.vpred - self.ret))
                self.total_loss = self.pol_surr + self.pol_ent_pen + self.vf_loss
                self.losses = [self.pol_surr, self.pol_ent_pen, self.vf_loss, self.mean_kl, self.mean_ent]
                self.loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

                self.policy_params = self.pi.get_policy_variables()
                self.critic_params = self.pi.get_critic_variables()

                # oldpi <- pi, variable by variable.
                self.assign_old_eq_new = U.function([], [],
                                                    updates=[tf.assign(oldv, newv) for (oldv, newv) in
                                                             zipsame(self.pi_old.get_variables(),
                                                                     self.pi.get_variables())],
                                                    sess=self.sess,
                                                    graph=self.graph)

                # Actor: minimizes total_loss over the policy variables only.
                self.policy_opt = tf.train.AdamOptimizer(self.lr_actor)
                if self.actor_grad_clip:
                    self.policy_gradients = tf.gradients(self.total_loss, self.policy_params)
                    self.policy_clipped_grad_op, _ = tf.clip_by_global_norm(self.policy_gradients,
                                                                            self.actor_grad_norm_clip)
                    self.policy_trainer = self.policy_opt.apply_gradients(zip(self.policy_clipped_grad_op,
                                                                              self.policy_params))
                else:
                    self.policy_trainer = self.policy_opt.minimize(loss=self.total_loss, var_list=self.policy_params)

                # Critic: minimizes vf_loss over the critic variables only.
                self.critic_opt = tf.train.AdamOptimizer(self.lr_critic)
                if self.critic_grad_clip:
                    self.critic_gradients = tf.gradients(self.vf_loss, self.critic_params)
                    self.critic_clipped_grad_op, _ = tf.clip_by_global_norm(self.critic_gradients,
                                                                            self.critic_grad_norm_clip)
                    self.critic_trainer = self.critic_opt.apply_gradients(zip(self.critic_clipped_grad_op,
                                                                              self.critic_params))
                else:
                    self.critic_trainer = self.critic_opt.minimize(self.vf_loss, var_list=self.critic_params)

                self.compute_losses = U.function([self.state_phd, self.action_phd,
                                                  self.atarg, self.ret, self.lrmult],
                                                 self.losses,
                                                 sess=self.sess,
                                                 graph=self.graph)

    def choose_action(self, s, is_test):
        """Sample an action for ``s`` and evaluate the Lagrange multiplier for it.

        Args:
            s: a single (un-batched) observation.
            is_test: NOTE(review): currently unused. The action is always
                sampled with ``stochastic=True``; confirm whether evaluation
                should act deterministically.

        Returns:
            ``(action, vpred, lagrange_multi_value)``. The multiplier is
            ``Lag_Func(s, action)`` clamped to
            ``[lagrange_multiplier_min, lagrange_multiplier_max]``.
        """
        with self.graph.as_default():
            action, vpred = self.pi.act(stochastic=True, ob=s)

            lagrange_multi_value = self.sess.run(self.lag_func, feed_dict={
                self.state_phd: [s],
                self.action_phd_for_l: [action]
            })
            # A batch of one through a 1-unit output layer gives shape [1, 1].
            lagrange_multi_value = lagrange_multi_value[0][0]
            lagrange_multi_value = min(max(lagrange_multi_value, self.lagrange_multiplier_min),
                                       self.lagrange_multiplier_max)

            return action, vpred, lagrange_multi_value

    def learn(self, **kwargs):
        """Run one PPO update, then one update of the Lagrange-multiplier network.

        Keyword Args:
            ob: batch of observations (array-like with ``.shape``).
            ac: batch of actions.
            adv: array of advantage estimates; standardized before use.
            td_lam_ret: TD(lambda) returns used as value targets.
            mini_batches: sequence of trajectory samples, each indexable as
                ``[state, action, step, Lagrange_multiplier, cons]``. If it is
                missing or empty, the Lagrange-network update is skipped.
        """
        with self.graph.as_default():
            self._update_policy_and_critic(kwargs.get("ob"), kwargs.get("ac"),
                                           kwargs.get("adv"), kwargs.get("td_lam_ret"))
            self._update_lag_func(kwargs.get("mini_batches"))

    def _update_policy_and_critic(self, bs, ba, batch_adv, batch_td_lam_ret):
        """PPO phase: sync oldpi <- pi, then run optim_epochs of minibatch updates."""
        # Standardize the advantages. The epsilon keeps this finite when every
        # advantage is identical (e.g. a batch of one), which would otherwise
        # yield NaNs and corrupt both optimizers.
        batch_adv = (batch_adv - batch_adv.mean()) / (batch_adv.std() + 1e-8)

        # PPO has no replay buffer, so the dataset is built straight from this batch.
        d = Dataset(dict(ob=bs, ac=ba, atarg=batch_adv, vtarg=batch_td_lam_ret),
                    deterministic=self.pi.recurrent)

        batch_size = self.optim_batch_size or bs.shape[0]

        if hasattr(self.pi, "ob_rms"):
            self.pi.ob_rms.update(bs)

        # Freeze the current policy as the "old" policy used by the ratio.
        self.assign_old_eq_new()

        for _ in range(self.optim_epochs):
            for batch in d.iterate_once(batch_size):
                _, _, policy_loss, v_loss, kl = self.sess.run(
                    [self.policy_trainer, self.critic_trainer,
                     self.pol_surr, self.vf_loss, self.mean_kl],
                    feed_dict={
                        self.state_phd: batch["ob"],
                        self.action_phd: batch["ac"],
                        self.atarg: batch["atarg"],
                        self.ret: batch["vtarg"]
                    })

                self.write_summary_scalar(self.update_cnt, "policy_loss", policy_loss)
                self.write_summary_scalar(self.update_cnt, "v_loss", v_loss)
                self.write_summary_scalar(self.update_cnt, "mean_kl", kl)
                self.update_cnt += 1

    def _update_lag_func(self, mini_batches):
        """Lagrange phase: average the Lag_Func gradient over samples and apply it once."""
        if mini_batches is None or len(mini_batches) == 0:
            # Nothing to average over; without this guard the division below
            # fails on ``None``.
            return

        flatten_grad_cons_return_sum = None
        for mini_batch in mini_batches:
            # Each sample is [state, action, step, Lagrange_multiplier, cons].
            mini_feed_dict = {
                self.traj_state_phd: mini_batch[0],
                self.traj_action_phd: mini_batch[1],
                self.traj_step_phd: mini_batch[2],
                # The recorded multiplier goes into its own placeholder. It must
                # not be fed into ``traj_lag_func``: that is the network output,
                # and feeding it would override the network for any op reading it.
                self.traj_lag_phd: mini_batch[3],
                self.traj_cons_phd: mini_batch[4],
            }

            flatten_grad_cons_return = self.sess.run(self.flatten_grad_cons_return_wrt_lag_param,
                                                     feed_dict=mini_feed_dict)
            if flatten_grad_cons_return_sum is None:
                flatten_grad_cons_return_sum = flatten_grad_cons_return
            else:
                flatten_grad_cons_return_sum = flatten_grad_cons_return_sum + flatten_grad_cons_return

        flatten_grad_cons_return_mean = flatten_grad_cons_return_sum / len(mini_batches)
        print("The flatten gradient is {}".format(flatten_grad_cons_return_mean))

        grad_cons_return_values = self._unflatten_lag_func_grads(flatten_grad_cons_return_mean)
        lag_func_grad_dict = dict(zip(self.lag_func_params_grad_phds, grad_cons_return_values))

        # Finally, optimize the Lagrange multiplier function.
        self.sess.run(self.lag_trainer, feed_dict=lag_func_grad_dict)

    def _unflatten_lag_func_grads(self, flat_grad):
        """Split a flat gradient vector into float32 arrays shaped like each Lag_Func parameter."""
        grad_values = []
        element_start = 0
        for ly_shape in self.lag_func_params_shapes:
            element_num = 1
            for dim in ly_shape:
                element_num *= int(dim)
            element_end = element_start + element_num
            param_grad_seg = flat_grad[element_start:element_end]
            grad_values.append(np.array(param_grad_seg, dtype=np.float32).reshape(ly_shape))
            element_start = element_end
        return grad_values

    def _build_lagrange_multiplier_func(self, s, a, reuse=None, custom_getter=None):
        """Build Lag_Func on ``concat([s, a], axis=1)``.

        The layers are trainable only when ``reuse`` is None (the first,
        variable-creating build).
        """
        trainable = reuse is None
        return self._lag_func_layers(s, a, reuse=reuse, custom_getter=custom_getter, trainable=trainable)

    def _build_lagrange_multiplier_func_reuse(self, s, a):
        """Build another Lag_Func head that shares the existing ``Lag_Func`` variables."""
        return self._lag_func_layers(s, a, reuse=True, custom_getter=None, trainable=True)

    def _lag_func_layers(self, s, a, reuse, custom_getter, trainable):
        """Shared Lag_Func body.

        Each hidden layer is dense -> layer_norm -> relu. The output layer is a
        dense(1) + relu followed by +1, so the result has shape [batch, 1] and
        every value is >= 1.
        """
        with tf.variable_scope('Lag_Func', reuse=reuse, custom_getter=custom_getter):
            net = tf.concat([s, a], axis=1)
            for ly_index, ly_cell_num in enumerate(self.lag_net_layers):
                net = tf.layers.dense(net, ly_cell_num,
                                      kernel_initializer=tf.random_uniform_initializer(-1 / 8.0, 1 / 8.0),
                                      bias_initializer=tf.random_uniform_initializer(-1 / 8.0, 1 / 8.0),
                                      name='l' + str(ly_index), trainable=trainable)
                net = tf.contrib.layers.layer_norm(net)
                net = tf.nn.relu(net)

            net = tf.layers.dense(net, 1, activation=tf.nn.relu,
                                  kernel_initializer=tf.random_uniform_initializer(-1 / 1000.0, 1 / 1000.0),
                                  bias_initializer=tf.random_uniform_initializer(-1 / 1000.0, 1 / 1000.0),
                                  name='lag_value', trainable=trainable)
            net = tf.add(net, 1)
            return net

    def write_summary_scalar(self, iteration, tag, value):
        """Write one scalar ``value`` under ``tag`` at step ``iteration`` to TensorBoard."""
        self.train_writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]), iteration)
