# coding=utf-8
import numpy as np
import tensorflow as tf
from ..ppo_trainer import PPOTrainer
from ..ppo_dpba.ppo_dpba_algo import PPODpbaAlgo


# Rendering switch (not referenced by the trainer code in this module).
RENDER = False

# Save model parameters once every MODEL_UPDATE_FREQ PPO updates. The count is
# ``update_cnt`` in ``PPODpbaTrainer.ppo_update``, i.e. PPO updates, not episodes.
MODEL_UPDATE_FREQ = 1000


class PPODpbaTrainer(PPOTrainer):
    """PPO trainer whose advantages use a potential-based shaping term.

    The shaping term is F = gamma * phi(s', a') - phi(s, a), with phi supplied
    per transition as ``phi_sa`` (presumably dynamic potential-based advice,
    DPBA, given the algorithm name -- confirm against ``PPODpbaAlgo``).
    """

    def __init__(self, state_space, action_space, algo_name="ppo_dpba", **kwargs):
        super(PPODpbaTrainer, self).__init__(state_space, action_space, algo_name, **kwargs)

    def init_algo(self, **kwargs):
        """Create the TF graph/session, the PPO-DPBA algorithm and sample buffers.

        Extra keyword arguments are forwarded to ``PPODpbaAlgo``.
        """
        self.graph = tf.Graph()
        self.session = tf.Session(graph=self.graph)
        self.algorithm = PPODpbaAlgo(self.session, self.graph, self.state_space,
                                     self.action_space, algo_name=self.algo_name,
                                     **kwargs)

        # Storage for collected samples: one PPO update is performed every
        # ``truncation_size`` transitions.
        self.exp_mini_buffer = [None] * self.truncation_size
        # Most recent transition, held back until its successor is known.
        self.last_exp = None
        self.exp_cnt = 0
        self.update_cnt = 0

        # Also keep a handle to the algorithm's TF file writer for extra logging.
        self.my_writer = self.algorithm.train_writer

    def action(self, state, test_model):
        """Return ``(action, action_info)`` from the algorithm's policy."""
        a, action_info = self.algorithm.choose_action(state, test_model)
        return a, action_info

    def experience(self, s, a, r, sp, terminal, **kwargs):
        """Record one transition; run a PPO update when the buffer is full.

        Each transition is delayed by one call (kept in ``self.last_exp``).
        This way, when it is written to ``exp_mini_buffer``, the *current* call
        provides its successor's value prediction, action and potential for
        bootstrapping.

        Args:
            s, a, r, sp, terminal: state, action, reward, next state, done flag.
            **kwargs:
                v_pred: value prediction for ``s`` (used as V(s) in GAE).
                c: per-transition value forwarded to ``algorithm.learn`` as
                    ``c_batch``.
                phi_sa: shaping potential of ``(s, a)``.
        """
        v_pred = kwargs.get("v_pred")
        c = kwargs.get("c")
        phi_sa = kwargs.get("phi_sa")
        current = (s, a, r, sp, terminal, v_pred, c, phi_sa)

        if self.last_exp is None:
            # First transition ever: nothing to pair with a successor yet.
            self.last_exp = current
            return

        self.exp_mini_buffer[self.exp_cnt % self.truncation_size] = self.last_exp
        self.last_exp = current
        self.exp_cnt += 1

        if self.exp_cnt % self.truncation_size == 0:
            # The current transition is the successor of the last buffered one.
            self.ppo_update(next_v_pred=v_pred, next_ac=a, phi_spap=phi_sa)

    def _scalar_column(self, idx, dtype):
        """Gather field ``idx`` of every buffered transition into a 1-D array."""
        column = np.zeros(self.truncation_size, dtype=dtype)
        for t, exp in enumerate(self.exp_mini_buffer):
            column[t] = exp[idx]
        return column

    def _shaped_gae(self, rewards, dones, vpred, potentials):
        """Return GAE advantages computed on the potential-shaped reward.

        ``vpred`` and ``potentials`` have one extra trailing entry: the value
        and potential of the state(-action) following the last transition.
        """
        adv = np.empty(self.truncation_size, dtype=float)
        last_gae_lam = 0
        for t in reversed(range(self.truncation_size)):
            non_terminal = 1 - dones[t]
            # F = gamma * phi(s', a') - phi(s, a). A terminal successor has
            # potential 0; without the mask, potentials[t + 1] would be the
            # first (s, a) of the next episode and leak across the boundary.
            shaping = self.gamma * potentials[t + 1] * non_terminal - potentials[t]
            delta = (rewards[t] + shaping
                     + self.gamma * vpred[t + 1] * non_terminal - vpred[t])
            last_gae_lam = delta + self.gamma * self.lmda * non_terminal * last_gae_lam
            adv[t] = last_gae_lam
        return adv

    def ppo_update(self, **kwargs):
        """Run one PPO update on the transitions in ``exp_mini_buffer``.

        Keyword Args:
            next_v_pred: value prediction for the state following the last
                buffered transition (bootstrap value).
            next_ac: action taken after the last buffered transition.
            phi_spap: potential of that following state-action pair.
        """
        self.update_cnt += 1
        buffer = self.exp_mini_buffer

        # Build array columns from *all* entries so numpy picks a dtype that
        # fits every value (pre-filling from the first entry could truncate,
        # e.g. an int-typed first observation followed by float ones).
        obs = np.array([exp[0] for exp in buffer])
        acts = np.array([exp[1] for exp in buffer])
        next_obs = np.array([exp[3] for exp in buffer])
        # a'_t is the action of transition t + 1; the final one comes from the
        # transition that triggered this update.
        next_acts = np.array([exp[1] for exp in buffer[1:]] + [kwargs.get("next_ac")])

        rewards = self._scalar_column(2, float)
        dones = self._scalar_column(4, int)
        v_preds = self._scalar_column(5, float)
        cs = self._scalar_column(6, float)
        phis = self._scalar_column(7, float)

        # Append the bootstrap value / potential of the successor.
        vpred = np.append(v_preds, kwargs.get("next_v_pred"))
        potentials = np.append(phis, kwargs.get("phi_spap"))

        adv = self._shaped_gae(rewards, dones, vpred, potentials)
        td_lam_ret = adv + v_preds

        self.algorithm.learn(ob=obs, ac=acts, adv=adv,
                             td_lam_ret=td_lam_ret,
                             bs_=next_obs, ba_=next_acts,
                             bdone=dones,
                             c_batch=cs)

        # update_cnt >= 1 here, so the modulo alone decides when to save.
        if self.update_cnt % MODEL_UPDATE_FREQ == 0:
            self.save_params()

