import numpy as np
import tensorflow as tf
from experiments.utils.ou_noise import OUNoise
from experiments.algorithms.ddpg.ddpg_algo import DDPGAlgo

# ---- RL / optimizer defaults (overridable through set_algo_parameters kwargs) ----
GAMMA = 0.999          # discount factor
TAU = 0.01             # soft target-update rate
LR_ACTOR = 0.0001      # actor learning rate
LR_CRITIC = 0.0002     # learning rate of both critics
LR_F = 0.0002          # learning rate of the shaping weight function f_phi

# ---- exploration defaults ----
OU_NOISE_THETA = 0.15
OU_NOISE_SIGMA = 0.5
GAUSSIAN_EXPLORATION_SIGMA_RATIO_MAX = 1.0
GAUSSIAN_EXPLORATION_SIGMA_RATIO_MIN = 0.00001
GAUSSIAN_EXPLORATION_SIGMA_RATIO_FIX = 0.2
GAUSSIAN_EXPLORATION_SIGMA_RATIO_DECAY_EPISODE = 60000

# ---- global-norm gradient clipping thresholds ----
ACTOR_GRADIENT_NORM_CLIP = 1.0
CRITIC_GRADIENT_NORM_CLIP = 1.0
F_GRADIENT_NORM_CLIP = 50


def my_flatten(x):
    """Recursively flatten nested lists into one flat list.

    Only objects whose exact type is ``list`` are descended into; everything
    else (tuples, strings, arrays, list subclasses, scalars) is a leaf.
    A non-list argument is returned wrapped in a one-element list.
    """
    if type(x) is not list:
        return [x]
    flat = []
    for item in x:
        flat.extend(my_flatten(item))
    return flat


class DDPGOprsV2Algo(DDPGAlgo):
    r"""
    DDPG with optimization of parameterized reward shaping (OPRS), version 2.

    A shaping weight function f_phi(s) scales an additional reward F(s, a), so
    the policy is trained on the shaped reward R(s, a) + f_phi(s) * F(s, a).
    The parameters \phi of f_phi are optimized by chaining the gradient of the
    true objective (a critic trained on R only) w.r.t. the policy parameters
    \theta with the gradient of \theta w.r.t. \phi, under the assumption that
    \nabla_{\phi} \theta is only related to the policy update \Delta \theta.

    Training alternates between two phases, toggled by ``switch_optimization``:
      * policy phase (``optimize_policy`` is True): the actor and the shaped
        critic are updated, and \nabla_{\phi} \theta is accumulated;
      * weight-function phase: f_phi and the true critic are updated.
    """

    def __init__(self, sess, graph, state_dim, action_dim, algo_name="ddpg_oprs_v2", **kwargs):
        # Initialized before delegating to the base constructor.
        self.optimize_policy = True
        self.episode_traj = []
        super(DDPGOprsV2Algo, self).__init__(sess, graph, state_dim, action_dim, algo_name, **kwargs)

    def set_algo_parameters(self, **kwargs):
        """Read hyper-parameters from ``kwargs``, falling back to the module defaults.

        ``explo_method`` selects exploration: "OU" uses Ornstein-Uhlenbeck noise,
        "GAUSSIAN_STATIC" a fixed Gaussian sigma ratio. Any other value sets up a
        sigma ratio starting at its max; ``episode_done`` decays it geometrically
        towards the min (reached after ``gaussian_explo_sigma_ratio_decay_ep``
        decays) when ``explo_method == "GAUSSIAN_DYNAMIC"``.
        """
        self.gamma = kwargs.get("gamma", GAMMA)
        self.tau = kwargs.get("tau", TAU)
        self.lr_actor = kwargs.get("lr_actor", LR_ACTOR)
        self.lr_critic = kwargs.get("lr_critic", LR_CRITIC)
        self.lr_f = kwargs.get("lr_f", LR_F)

        self.explo_method = kwargs.get("explo_method", "OU")
        if self.explo_method == "OU":
            ou_noise_theta = kwargs.get("ou_noise_theta", OU_NOISE_THETA)
            ou_noise_sigma = kwargs.get("ou_noise_sigma", OU_NOISE_SIGMA)
            self.ou_noise = OUNoise(self.action_dim, mu=0, theta=ou_noise_theta, sigma=ou_noise_sigma)
        elif self.explo_method == "GAUSSIAN_STATIC":
            self.gaussian_explo_ratio = kwargs.get("gaussian_explo_sigma_ratio_fix",
                                                   GAUSSIAN_EXPLORATION_SIGMA_RATIO_FIX)
        else:
            self.gaussian_explo_sigma_ratio_max = kwargs.get("gaussian_explo_sigma_ratio_max",
                                                             GAUSSIAN_EXPLORATION_SIGMA_RATIO_MAX)
            self.gaussian_explo_sigma_ratio_min = kwargs.get("gaussian_explo_sigma_ratio_min",
                                                             GAUSSIAN_EXPLORATION_SIGMA_RATIO_MIN)
            gaussian_explo_sigma_ratio_decay_ep = kwargs.get("gaussian_explo_sigma_ratio_decay_ep",
                                                             GAUSSIAN_EXPLORATION_SIGMA_RATIO_DECAY_EPISODE)
            self.gaussian_explo_ratio = self.gaussian_explo_sigma_ratio_max
            # per-episode factor such that max * factor**decay_ep == min
            self.gaussian_explo_decay_factor = pow(self.gaussian_explo_sigma_ratio_min /
                                                   self.gaussian_explo_sigma_ratio_max,
                                                   1.0 / gaussian_explo_sigma_ratio_decay_ep)

        self.actor_grad_clip = kwargs.get("actor_gradient_clip", True)
        self.critic_grad_clip = kwargs.get("critic_gradient_clip", True)
        self.f_grad_clip = kwargs.get("f_gradient_clip", True)
        self.actor_grad_norm_clip = kwargs.get("actor_gradient_norm_clip", ACTOR_GRADIENT_NORM_CLIP)
        self.critic_grad_norm_clip = kwargs.get("critic_gradient_norm_clip", CRITIC_GRADIENT_NORM_CLIP)
        self.f_grad_norm_clip = kwargs.get("f_gradient_norm_clip", F_GRADIENT_NORM_CLIP)
        # Legacy (misspelled) attribute name, kept as an alias for existing readers.
        self.f_grad_norm_clilp = self.f_grad_norm_clip

        # hidden-layer sizes of the networks
        self.actor_net_layers = kwargs.get("actor_net_layers", [4, 4])
        self.critic_net_layers = kwargs.get("critic_net_layers", [32, 32])
        # index of the critic hidden layer whose input is concatenated with the action
        self.critic_act_in_ly_index = int(kwargs.get("critic_action_input_layer_index", 1))
        self.f_net_layers = kwargs.get("f_net_layers", [16, 8])

        # if True, f_phi outputs (network output + 1)
        self.net_add_one = kwargs.get("net_add_one", False)
        self.f_hidden_layer_act_func = kwargs.get("f_hidden_layer_act_func", tf.nn.relu)

    def init_networks(self):
        """Build placeholders, networks and trainers, then create the checkpoint saver."""
        self.define_place_holders()
        # actor, shaping weight function, shaped critic and true critic
        self.build_networks()
        # target networks, losses and optimizers
        self.define_trainers()

        with self.sess.as_default():
            with self.graph.as_default():
                self.saver = tf.train.Saver(max_to_keep=100)

    def define_place_holders(self):
        """Create the input placeholders used by all networks and trainers."""
        with self.sess.as_default():
            with self.graph.as_default():
                self.state_phd = tf.placeholder(tf.float32, [None, self.state_dim], name='state')
                self.state_prime_phd = tf.placeholder(tf.float32, [None, self.state_dim], name='state_prime')

                # the original reward R(s, a)
                self.reward_phd = tf.placeholder(tf.float32, [None, 1], name='reward')
                # the additional reward F(s, a)
                self.add_reward_phd = tf.placeholder(tf.float32, [None, 1], name='additional_reward')

                self.done_phd = tf.placeholder(tf.float32, [None, 1], name='done')
                # NOTE(review): not read anywhere in this class.
                self.temp = tf.Variable(1.0, name='temperature')

                # shaping weight value f_phi(s') of the next state
                self.f_phi_sp_phd = tf.placeholder(tf.float32, [None, 1], name='f_phi_sp')

                # Trajectory inputs for computing grad_phi R_tau(s, a): the states
                # along the trajectory, their step indices and their additional rewards F.
                self.traj_state_phd = tf.placeholder(tf.float32, [None, self.state_dim], name='traj_state')
                self.traj_step_phd = tf.placeholder(tf.float32, [None, 1], name='traj_step')
                self.traj_F_phd = tf.placeholder(tf.float32, [None, 1], name='traj_F')

    def build_networks(self):
        """Build f_phi (plus a reused copy over trajectory states), the actor and both critics."""
        with self.sess.as_default():
            with self.graph.as_default():
                self.f_phi = self._build_weight_func(self.state_phd)
                self.traj_f_phi = self._build_weight_func_reuse(self.traj_state_phd)

                # actor mu(s)
                self.actor_output = self._build_actor(self.state_phd)

                # shaped critic Q(s, mu(s), f_phi(s)), used to optimize the policy
                self.shaped_critic = self._build_shaped_critic(self.state_phd, self.actor_output, self.f_phi)

                # true critic Q_true(s, mu(s)), used to optimize f_phi
                self.true_critic = self._build_true_critic(self.state_phd, self.actor_output)

    def define_trainers(self):
        """Build target networks, losses and training ops for actor, both critics and f_phi."""
        with self.sess.as_default():
            with self.graph.as_default():
                a_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Actor')
                shaped_c_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Shaped_Critic')
                true_c_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='True_Critic')
                self.f_phi_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Weight_Func')

                # Target networks read exponential moving averages (decay 1 - tau)
                # of the online variables through a custom getter.
                ema = tf.train.ExponentialMovingAverage(decay=1 - self.tau)

                def ema_getter(getter, name, *args, **kwargs):
                    return ema.average(getter(name, *args, **kwargs))

                target_update_for_policy = [ema.apply(a_params), ema.apply(shaped_c_params)]
                target_update_for_f = [ema.apply(true_c_params)]

                # target actor mu'(s')
                target_actor_output = self._build_actor(self.state_prime_phd, reuse=True, custom_getter=ema_getter)
                # target shaped critic Q'(s', mu'(s'), f_phi(s')); f_phi(s') is fed via placeholder
                target_shaped_critic = self._build_shaped_critic(self.state_prime_phd, target_actor_output,
                                                                 self.f_phi_sp_phd, reuse=True,
                                                                 custom_getter=ema_getter)
                # target true critic Q'_true(s', mu'(s'))
                target_true_critic = self._build_true_critic(self.state_prime_phd, target_actor_output,
                                                             reuse=True, custom_getter=ema_getter)

                # ---- actor: maximize the shaped critic ----
                a_loss = tf.reduce_mean(self.shaped_critic)
                if self.actor_grad_clip:
                    self.actor_param_gradients = tf.gradients(-a_loss, a_params)
                    self.actor_clipped_grad_op, _ = tf.clip_by_global_norm(self.actor_param_gradients,
                                                                           self.actor_grad_norm_clip)
                    self.actor_opt = tf.train.AdamOptimizer(self.lr_actor)
                    self.trainer_actor = self.actor_opt.apply_gradients(zip(self.actor_clipped_grad_op, a_params))
                else:
                    self.trainer_actor = tf.train.AdamOptimizer(self.lr_actor).minimize(-a_loss, var_list=a_params)

                with tf.name_scope('Actor_Loss'):
                    tf.summary.scalar('actor_exp_Q', a_loss)

                # ---- shaped critic: TD target R + f_phi(s) * F + gamma * (1 - done) * Q'(...) ----
                # Ops built in this block run only after the policy-phase target updates.
                with tf.control_dependencies(target_update_for_policy):
                    shaped_q_target = self.reward_phd + self.f_phi * self.add_reward_phd + \
                                      self.gamma * (1 - self.done_phd) * target_shaped_critic
                    shaped_td_error = tf.losses.mean_squared_error(labels=tf.stop_gradient(shaped_q_target),
                                                                   predictions=self.shaped_critic)
                    # NOTE(review): tf.minimum passes no gradient once the loss exceeds 100,
                    # so very large TD errors produce no update at all; a Huber loss would
                    # bound the gradient instead of zeroing it.
                    clipped_shaped_td_error = tf.minimum(shaped_td_error, 100.0)
                    if self.critic_grad_clip:
                        self.shp_c_param_gradients = tf.gradients(clipped_shaped_td_error, shaped_c_params)
                        self.shp_c_clipped_grad_op, _ = tf.clip_by_global_norm(self.shp_c_param_gradients,
                                                                               self.critic_grad_norm_clip)
                        self.shp_c_opt = tf.train.AdamOptimizer(self.lr_critic)
                        self.trainer_shaped_critic = self.shp_c_opt.apply_gradients(zip(self.shp_c_clipped_grad_op,
                                                                                        shaped_c_params))
                    else:
                        self.trainer_shaped_critic = tf.train.AdamOptimizer(self.lr_critic).minimize(
                            clipped_shaped_td_error, var_list=shaped_c_params)

                # ---- shaping weight function f_phi ----
                # grad_phi J = grad_theta mu_theta(s) * grad_phi theta * Q_true(s, mu(s))
                #            = grad_theta mu_theta(s) * alpha * grad_theta' mu_theta' * grad_phi R_tau
                #              * Q_true(s, mu(s))
                #            = grad_theta mu_theta(s) * alpha * grad_theta' mu_theta'
                #              * sum_{i=0}^{|tau|-1} gamma^i F(s_i, a_i) grad_phi f_phi(s_i) * Q_true(s, mu(s))
                #
                # grad_return_wrt_phi = sum_i gamma^i F(s_i, a_i) grad_phi f_phi(s_i).
                # traj_f_phi has one output per fed state, but tf.gradients sums the
                # per-state gradients (each weighted by grad_ys = gamma^step * F), giving
                # one array per variable in f_phi_params.
                self.grad_return_wrt_phi = tf.gradients(ys=self.traj_f_phi, xs=self.f_phi_params,
                                                        grad_ys=tf.multiply(tf.pow(self.gamma, self.traj_step_phd),
                                                                            self.traj_F_phd))
                self.f_phi_params_shapes = [param.shape for param in self.f_phi_params]
                self.theta_params_shapes = [param.shape for param in a_params]
                for i, shape in enumerate(self.f_phi_params_shapes):
                    print("The f_phi param {} shape is {}".format(i, shape))
                for i, shape in enumerate(self.theta_params_shapes):
                    print("The actor param {} shape is {}".format(i, shape))

                # grad_theta mu_theta(s): gradient of the policy output w.r.t. its parameters
                # (tf.gradients sums over the output components).
                self.grad_mu_wrt_theta = tf.gradients(ys=self.actor_output, xs=a_params)

                # grad_phi theta, aggregated over every policy update of the policy phase
                # and averaged in switch_optimization before f_phi is updated.
                self.grad_theta_wrt_phi_aggr = None
                self.grad_aggr_num = 0

                # grad_theta of -mean Q_true(s, mu(s)), later chained with grad_phi theta.
                # NOTE(review): clipped with actor_grad_norm_clip regardless of actor_grad_clip.
                true_a_loss = tf.reduce_mean(self.true_critic)
                self.grad_true_loss_wrt_theta = tf.gradients(ys=-true_a_loss, xs=a_params)
                self.grad_true_loss_wrt_theta, _ = tf.clip_by_global_norm(self.grad_true_loss_wrt_theta,
                                                                          self.actor_grad_norm_clip)

                # zero gradients until the first weight-function update computes real ones
                self.grad_J_wrt_phi = [np.zeros(shape.as_list(), dtype=np.float32)
                                       for shape in self.f_phi_params_shapes]

                # grad_phi J is computed in NumPy and fed back through these placeholders
                self.optimizer_f = tf.train.AdamOptimizer(self.lr_f)
                self.f_phi_params_grad_phds = [
                    tf.placeholder(tf.float32, shape=shape, name="f_phi_phd_{}".format(ly))
                    for ly, shape in enumerate(self.f_phi_params_shapes)]
                if self.f_grad_clip:
                    self.f_clipped_grad_op, _ = tf.clip_by_global_norm(self.f_phi_params_grad_phds,
                                                                       self.f_grad_norm_clip)
                    self.trainer_f = self.optimizer_f.apply_gradients(zip(self.f_clipped_grad_op, self.f_phi_params))
                else:
                    self.trainer_f = self.optimizer_f.apply_gradients(zip(self.f_phi_params_grad_phds,
                                                                          self.f_phi_params))

                # ---- true critic: TD target uses the original reward R only ----
                with tf.control_dependencies(target_update_for_f):
                    true_q_target = self.reward_phd + self.gamma * (1 - self.done_phd) * target_true_critic
                    true_td_error = tf.losses.mean_squared_error(labels=tf.stop_gradient(true_q_target),
                                                                 predictions=self.true_critic)
                    # NOTE(review): same zero-gradient-above-100 caveat as the shaped critic.
                    clipped_true_td_error = tf.minimum(true_td_error, 100.0)
                    if self.critic_grad_clip:
                        self.true_c_param_gradients = tf.gradients(clipped_true_td_error, true_c_params)
                        self.true_c_clipped_grad_op, _ = tf.clip_by_global_norm(self.true_c_param_gradients,
                                                                                self.critic_grad_norm_clip)
                        self.true_c_opt = tf.train.AdamOptimizer(self.lr_critic)
                        self.trainer_true_critic = self.true_c_opt.apply_gradients(zip(self.true_c_clipped_grad_op,
                                                                                       true_c_params))
                    else:
                        self.trainer_true_critic = tf.train.AdamOptimizer(self.lr_critic).minimize(
                            clipped_true_td_error, var_list=true_c_params)

                with tf.name_scope('Critic_loss'):
                    tf.summary.scalar('shaped_td_error', shaped_td_error)
                    tf.summary.scalar('true_td_error', true_td_error)

    def choose_action(self, s, is_test):
        """Compute the action for a single state ``s``.

        Returns ``(action, {"f_phi_s": f_phi(s)})``. Exploration noise is added
        only outside tests and during the policy phase.
        """
        f_phi_s, action = self.sess.run([self.f_phi, self.actor_output], {self.state_phd: [s]})
        action = action[0]

        if len(f_phi_s.shape) == 2:
            f_phi_s = f_phi_s[0][0]

        if is_test or not self.optimize_policy:
            return action, {"f_phi_s": f_phi_s}

        if self.explo_method == "OU":
            action = action + self.ou_noise.noise()
        else:
            # The noise std is proportional to the action magnitude; np.random.normal
            # rejects a negative scale, so the absolute value must be used.
            action = np.random.normal(action, np.abs(action) * self.gaussian_explo_ratio)

        return action, {"f_phi_s": f_phi_s}

    def learn(self, bs, ba, br, bs_, bdone, **kwargs):
        """Run one update step of the current phase on a batch.

        Policy phase: updates the actor and the shaped critic, and accumulates
        grad_phi theta. Requires kwargs ``mini_batches`` (one
        ``[traj_states, traj_steps, traj_F]`` entry per state in ``bs``),
        ``f_phi_s`` (f_phi of ``bs``) and ``F`` (additional rewards).

        Weight-function phase: updates f_phi with the chained gradient and
        then the true critic.
        """
        with self.sess.graph.as_default():
            if self.optimize_policy:
                self._learn_policy(bs, ba, br, bs_, bdone,
                                   mini_batches=kwargs.get("mini_batches"),
                                   f_phi_s_batch=kwargs.get("f_phi_s"),
                                   add_reward_batch=kwargs.get("F"))
            else:
                self._learn_weight_func(bs, ba, br, bs_, bdone)

    @staticmethod
    def _flatten_grads(grads):
        """Concatenate per-variable gradient arrays into one float64 vector (row-major order)."""
        return np.concatenate([np.asarray(g, dtype=np.float64).ravel() for g in grads])

    def _learn_policy(self, bs, ba, br, bs_, bdone, mini_batches, f_phi_s_batch, add_reward_batch):
        """Policy phase: actor update, grad_phi theta accumulation, shaped-critic update."""
        # Actor update. The executed batch actions ba are fed in place of mu(s), so
        # the gradient is evaluated with ba substituted for the actor output.
        # (Original author's open question: why not use the exploration action??)
        self.sess.run(self.trainer_actor, {self.state_phd: bs,
                                           self.actor_output: ba})

        self._accumulate_grad_theta_wrt_phi(bs, mini_batches)

        # Shaped critic update; the target uses Q'(s', mu'(s'), f_phi(s')).
        f_phi_sp_batch = self.sess.run(self.f_phi, feed_dict={self.state_phd: bs_})
        self.sess.run(self.trainer_shaped_critic, feed_dict={self.state_phd: bs, self.actor_output: ba,
                                                             self.reward_phd: br,
                                                             self.add_reward_phd: add_reward_batch,
                                                             self.f_phi: f_phi_s_batch,
                                                             self.state_prime_phd: bs_,
                                                             self.done_phd: bdone,
                                                             self.f_phi_sp_phd: f_phi_sp_batch})

    def _accumulate_grad_theta_wrt_phi(self, bs, mini_batches):
        """Add lr_actor * mean_s(grad_theta mu(s) outer grad_phi R_tau(s)) to the aggregate.

        For each state s, grad_phi R_tau(s) is computed from its trajectory mini
        batch and grad_theta mu(s) at s; their outer product (|theta| x |phi|)
        is averaged over the batch and scaled by the actor learning rate.
        """
        assert len(mini_batches) == len(bs)
        grad_theta_wrt_phi = None
        for s, mini_batch_s in zip(bs, mini_batches):
            # mini_batch_s is [traj_states, traj_steps, traj_F]
            grad_return_wrt_phi_s = self.sess.run(self.grad_return_wrt_phi,
                                                  feed_dict={self.traj_state_phd: mini_batch_s[0],
                                                             self.traj_step_phd: mini_batch_s[1],
                                                             self.traj_F_phd: mini_batch_s[2]})
            grad_mu_wrt_theta_s = self.sess.run(self.grad_mu_wrt_theta, feed_dict={self.state_phd: [s]})

            grad_theta_wrt_phi_s = np.outer(self._flatten_grads(grad_mu_wrt_theta_s),
                                            self._flatten_grads(grad_return_wrt_phi_s))
            if grad_theta_wrt_phi is None:
                grad_theta_wrt_phi = grad_theta_wrt_phi_s
            else:
                grad_theta_wrt_phi = grad_theta_wrt_phi + grad_theta_wrt_phi_s

        grad_theta_wrt_phi = np.multiply(grad_theta_wrt_phi, self.lr_actor / len(bs))

        self.grad_aggr_num += 1
        if self.grad_theta_wrt_phi_aggr is None:
            self.grad_theta_wrt_phi_aggr = grad_theta_wrt_phi
        else:
            self.grad_theta_wrt_phi_aggr = self.grad_theta_wrt_phi_aggr + grad_theta_wrt_phi

    def _learn_weight_func(self, bs, ba, br, bs_, bdone):
        """Weight-function phase: update f_phi with the chained gradient, then the true critic."""
        assert self.grad_aggr_num != 0
        assert self.grad_theta_wrt_phi_aggr is not None

        # grad_theta of -E_s[Q_true(s, mu(s))], with ba fed in place of mu(s)
        grad_true_loss_wrt_theta = self.sess.run(self.grad_true_loss_wrt_theta, {self.state_phd: bs,
                                                                                 self.actor_output: ba})
        flatten_grad_true_loss_wrt_theta = self._flatten_grads(grad_true_loss_wrt_theta)

        # Chain rule: (1 x |theta|) @ (|theta| x |phi|) -> grad_phi J as a flat vector
        grad_J_wrt_phi = np.matmul(np.atleast_2d(flatten_grad_true_loss_wrt_theta),
                                   self.grad_theta_wrt_phi_aggr)
        if len(grad_J_wrt_phi.shape) == 2 and grad_J_wrt_phi.shape[0] == 1:
            grad_J_wrt_phi = grad_J_wrt_phi[0]

        # Split the flat vector back into per-variable arrays shaped like f_phi_params.
        self.grad_J_wrt_phi = []
        element_start = 0
        for ly_shape in self.f_phi_params_shapes:
            dims = ly_shape.as_list()
            element_end = element_start + int(np.prod(dims))
            param_grad_seg = np.asarray(grad_J_wrt_phi[element_start:element_end], dtype=np.float32)
            self.grad_J_wrt_phi.append(param_grad_seg.reshape(dims))
            element_start = element_end

        f_phi_dict = dict(zip(self.f_phi_params_grad_phds, self.grad_J_wrt_phi))
        self.sess.run(self.trainer_f, feed_dict=f_phi_dict)

        # True critic update; the target uses Q'_true(s', mu'(s')).
        # NOTE(review): f_phi_sp_phd looks unused by this trainer; kept to preserve the original feed.
        f_phi_sp_batch = self.sess.run(self.f_phi, feed_dict={self.state_phd: bs_})
        self.sess.run(self.trainer_true_critic, feed_dict={self.state_phd: bs, self.actor_output: ba,
                                                           self.reward_phd: br, self.state_prime_phd: bs_,
                                                           self.done_phd: bdone,
                                                           self.f_phi_sp_phd: f_phi_sp_batch})

    def _build_weight_func(self, s, reuse=None, custom_getter=None, trainable=None):
        """Build the shaping weight function f_phi(s) in variable scope 'Weight_Func'.

        Args:
            s: state tensor.
            reuse: variable-scope reuse flag.
            custom_getter: optional variable getter for the scope.
            trainable: layer ``trainable`` flag; defaults to ``reuse is None``.

        Returns:
            Tensor with one output per input row (plus 1 if ``net_add_one``).
        """
        if trainable is None:
            trainable = reuse is None
        with tf.variable_scope('Weight_Func', reuse=reuse, custom_getter=custom_getter):
            net = s
            for ly_index, ly_cell_num in enumerate(self.f_net_layers):
                net = tf.layers.dense(net, ly_cell_num,
                                      kernel_initializer=tf.random_uniform_initializer(-1 / 8.0, 1 / 8.0),
                                      bias_initializer=tf.random_uniform_initializer(-1 / 8.0, 1 / 8.0),
                                      name='l' + str(ly_index), trainable=trainable)
                net = tf.contrib.layers.layer_norm(net)
                net = self.f_hidden_layer_act_func(net)

            # tiny initial weights keep the initial f_phi output close to 0 (or 1 with net_add_one)
            net = tf.layers.dense(net, 1, activation=None,
                                  kernel_initializer=tf.random_uniform_initializer(-1 / 1000.0, 1 / 1000.0),
                                  bias_initializer=tf.random_uniform_initializer(-1 / 1000.0, 1 / 1000.0),
                                  name='f_value', trainable=trainable)

            if self.net_add_one:
                net = tf.add(net, 1)

            return net

    def _build_weight_func_reuse(self, s):
        """Build another copy of f_phi on input ``s`` that shares the existing 'Weight_Func' variables."""
        return self._build_weight_func(s, reuse=True, trainable=True)

    def _build_critic_layers(self, net, action, trainable):
        """Shared critic MLP body; must be called inside the critic's variable scope.

        The action is concatenated to the input of hidden layer
        ``critic_act_in_ly_index``; the output is a single linear unit.
        """
        for ly_index, ly_cell_num in enumerate(self.critic_net_layers):
            if ly_index == self.critic_act_in_ly_index:
                net = tf.concat([net, action], axis=1)
            net = tf.layers.dense(net, ly_cell_num,
                                  name='l' + str(ly_index), trainable=trainable)
            net = tf.contrib.layers.layer_norm(net)
            net = tf.nn.relu(net)

        return tf.layers.dense(net, 1, trainable=trainable)

    def _build_true_critic(self, s, action, reuse=None, custom_getter=None):
        """Build the true critic Q_true(s, a) in variable scope 'True_Critic'."""
        trainable = reuse is None
        with tf.variable_scope('True_Critic', reuse=reuse, custom_getter=custom_getter):
            return self._build_critic_layers(s, action, trainable)

    def _build_shaped_critic(self, s, action, f_phi, reuse=None, custom_getter=None):
        """Build the shaped critic Q(s, a, f_phi(s)) in variable scope 'Shaped_Critic'."""
        trainable = reuse is None
        with tf.variable_scope('Shaped_Critic', reuse=reuse, custom_getter=custom_getter):
            return self._build_critic_layers(tf.concat([s, f_phi], axis=1), action, trainable)

    def experience(self, one_exp):
        """Append one transition to the current episode trajectory."""
        assert one_exp is not None
        self.episode_traj.append(one_exp)

    def episode_done(self, is_test):
        """Finish an episode.

        For a training episode in the policy phase with "GAUSSIAN_DYNAMIC"
        exploration, decays the Gaussian sigma ratio (floored at its minimum).
        For a training episode, returns the recorded trajectory as a NumPy
        array and resets the buffer. For a test episode, returns None and
        leaves the buffer untouched.
        """
        if self.optimize_policy and not is_test and self.explo_method == "GAUSSIAN_DYNAMIC":
            self.gaussian_explo_ratio = max(self.gaussian_explo_sigma_ratio_min,
                                            self.gaussian_explo_ratio * self.gaussian_explo_decay_factor)

        if is_test:
            return None

        traj = np.array(self.episode_traj)
        self.episode_traj = []
        return traj

    def switch_optimization(self):
        """Toggle between the policy phase and the weight-function phase.

        Entering the policy phase clears the accumulated grad_phi theta.
        Entering the weight-function phase averages it over the number of
        accumulated policy updates (left unchanged if nothing was accumulated).
        """
        self.optimize_policy = not self.optimize_policy
        if self.optimize_policy:
            self.grad_theta_wrt_phi_aggr = None
            self.grad_aggr_num = 0
        elif self.grad_aggr_num > 0 and self.grad_theta_wrt_phi_aggr is not None:
            self.grad_theta_wrt_phi_aggr = np.divide(self.grad_theta_wrt_phi_aggr, self.grad_aggr_num)