# coding=utf-8
import numpy as np
import tensorflow as tf
from experiments.utils.memory import ReplayMemory
from experiments.algorithms.ddpg.ddpg_trainer import DDPGTrainer

# Rendering toggle. Not referenced anywhere in the code visible in this
# module; presumably kept for parity with sibling trainers -- TODO confirm.
RENDER = False

# Fallback hyper-parameters used by DDPGPbrsTrainer.set_trainer_parameters
# when the matching keyword argument is not supplied.
# Default for the "gamma" kwarg; also the factor applied to the next-state
# potential when the shaping reward is computed in experience().
GAMMA = 0.999
# Default for the "batch_size" kwarg.
BATCH_SIZE = 1024
# Default for the "replay_buffer_size" kwarg.
REPLAY_BUFFER_SIZE = 1000000
# Default for the "update_freq" kwarg.
UPDATE_FREQ = 4


class DDPGPbrsTrainer(DDPGTrainer):
    """DDPG trainer that stores potential-based shaped rewards (PBRS).

    Each transition's reward is augmented with the shaping term
    ``F(s, s') = gamma * phi(s') - phi(s)`` before it is written to the
    replay memory.  Two modes are supported by :meth:`experience`:

    * immediate -- both potentials are supplied with the transition, so the
      shaped reward is computed and stored right away;
    * deferred -- the caller passes the string ``"mujoco"`` as ``phi_sp``;
      the transition is held in ``self.last_exp`` until the next call
      provides the potential of the successor state.

    ``self.memory`` and ``self.algorithm`` are not created in this class;
    they are assumed to be provided by ``DDPGTrainer`` -- TODO confirm.
    """

    def __init__(self, state_dim, action_dim, algo_name="ddpg_pbrs", **kwargs):
        """Create the trainer.

        Args:
            state_dim: forwarded unchanged to ``DDPGTrainer.__init__``.
            action_dim: forwarded unchanged to ``DDPGTrainer.__init__``.
            algo_name: forwarded unchanged to ``DDPGTrainer.__init__``.
            **kwargs: forwarded unchanged to ``DDPGTrainer.__init__``.
        """
        # Pending transition for the deferred mode, stored as the list
        # [s, a, partially_shaped_reward, sp, terminal]; None when nothing is
        # pending.  Assigned before the base constructor runs, as in the
        # original ordering.
        self.last_exp = None
        super(DDPGPbrsTrainer, self).__init__(state_dim, action_dim, algo_name, **kwargs)

    def set_trainer_parameters(self, **kwargs):
        """Read the trainer hyper-parameters from ``kwargs``.

        Recognised keys are ``gamma``, ``batch_size``, ``replay_buffer_size``
        and ``update_freq``; each falls back to the module-level default of
        the same (upper-cased) name when absent.
        """
        self.gamma = kwargs.get("gamma", GAMMA)
        self.batch_size = kwargs.get("batch_size", BATCH_SIZE)
        self.replay_buffer_size = kwargs.get("replay_buffer_size", REPLAY_BUFFER_SIZE)
        self.update_freq = kwargs.get("update_freq", UPDATE_FREQ)

    def experience(self, s, a, r, sp, terminal, **kwargs):
        """Shape the reward of one transition and store it in the replay memory.

        Args:
            s: state the action was taken in.
            a: action taken.
            r: environment reward for the transition.
            sp: successor state.
            terminal: truthy when ``sp`` ends the episode.
            **kwargs: must contain ``phi_s`` (potential of ``s``) and
                ``phi_sp``.  ``phi_sp`` is either the potential of ``sp`` or
                the string ``"mujoco"``, which selects the deferred mode in
                which the potential of ``sp`` is taken from the ``phi_s`` of
                the following call.

        In the deferred mode a non-terminal transition is only written to
        memory when the next call arrives; a transition still pending when
        :meth:`episode_done` is called is discarded.
        """
        phi_s = kwargs.get("phi_s")
        phi_sp = kwargs.get("phi_sp")

        # Check the type first: comparing a numeric array with a string may
        # produce an elementwise result whose truth value is ambiguous.
        deferred = isinstance(phi_sp, str) and phi_sp == "mujoco"

        if not deferred:
            f_ssp = self.gamma * phi_sp - phi_s
            self.memory.add((s, a, r + f_ssp, sp, terminal))
            return

        if self.last_exp is not None:
            # The current state is the successor of the pending transition,
            # so its potential completes that transition's shaping term.
            self.last_exp[2] += self.gamma * phi_s
            self.memory.add(self.last_exp)

        # Only the "- phi(s)" half of the shaping term is known at this point.
        self.last_exp = [s, a, r - phi_s, sp, terminal]

        if terminal:
            # The potential of a terminal successor is taken to be 0, so the
            # transition is already complete.  Store it now -- including when
            # it is the very first transition of the episode -- and drop the
            # reference so a later call can neither mutate nor re-add it.
            self.memory.add(self.last_exp)
            self.last_exp = None

    def episode_done(self, test_model):
        """Notify the wrapped algorithm that the episode ended and reset state.

        Args:
            test_model: forwarded unchanged to ``self.algorithm.episode_done``.
        """
        self.algorithm.episode_done(test_model)
        # Any transition still pending in deferred mode is dropped here.
        self.last_exp = None



