# coding=utf-8
import os

import numpy as np
import tensorflow as tf

from ...utils.memory import ReplayMemory
from .ddpg_algo import DDPGAlgo

# Rendering flag; it is not referenced anywhere in the visible part of this module.
RENDER = False

# Fallback values for the optional trainer keyword arguments
# "batch_size", "replay_buffer_size" and "update_freq".
# BATCH_SIZE: number of transitions requested from the replay memory per learning step.
BATCH_SIZE = 1024
# REPLAY_BUFFER_SIZE: size argument handed to ReplayMemory.
REPLAY_BUFFER_SIZE = 1000000
# UPDATE_FREQ: a learning step only runs when the step counter is a multiple of this.
UPDATE_FREQ = 4

# The replay memory must hold more than this many transitions before learning starts.
FIRST_UPDATE_SAMPLE_NUM = 25600
# A model checkpoint is written every this many learning updates.
MODEL_UPDATE_FREQ = 10000


class DDPGTrainer(object):
    """Training driver that couples a ``DDPGAlgo`` instance with a replay memory.

    The trainer stores transitions, periodically samples mini-batches from the
    replay memory and passes them to the algorithm, and saves / restores model
    checkpoints under ``./data/<algo_name>/model/``.
    """

    def __init__(self, state_dim, action_dim, algo_name="ddpg", **kwargs):
        """Configure the trainer and build the underlying algorithm.

        Args:
            state_dim: state dimension, forwarded to ``DDPGAlgo``.
            action_dim: action dimension, forwarded to ``DDPGAlgo``.
            algo_name: name forwarded to ``DDPGAlgo`` and used as the
                sub-directory of the checkpoint path.
            **kwargs: optional trainer settings (see
                ``set_trainer_parameters``); the same keyword arguments are
                also forwarded unchanged to ``DDPGAlgo``.
        """
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.algo_name = algo_name

        # Trainer settings must be in place before the algorithm and the
        # replay memory are created, because init_algo reads them.
        self.set_trainer_parameters(**kwargs)

        self.init_algo(**kwargs)

    def set_trainer_parameters(self, **kwargs):
        """Read trainer settings from ``kwargs``, using module defaults if absent.

        Recognised keys:
            batch_size: mini-batch size requested from the replay memory.
            replay_buffer_size: size argument passed to ``ReplayMemory``.
            update_freq: learning only runs on steps that are a multiple of it.
            first_update_sample_num: the replay memory must hold more than
                this many transitions before learning starts.
            model_update_freq: a checkpoint is saved every this many updates.
        """
        self.batch_size = kwargs.get("batch_size", BATCH_SIZE)
        self.replay_buffer_size = kwargs.get("replay_buffer_size", REPLAY_BUFFER_SIZE)
        self.update_freq = kwargs.get("update_freq", UPDATE_FREQ)
        self.first_update_sample_num = kwargs.get("first_update_sample_num",
                                                  FIRST_UPDATE_SAMPLE_NUM)
        self.model_update_freq = kwargs.get("model_update_freq", MODEL_UPDATE_FREQ)

    def init_algo(self, **kwargs):
        """Create the TensorFlow graph/session, the algorithm and the replay memory."""
        self.graph = tf.Graph()
        self.session = tf.Session(graph=self.graph)
        self.algorithm = DDPGAlgo(self.session, self.graph, self.state_dim, self.action_dim,
                                  algo_name=self.algo_name, **kwargs)

        # Number of learning updates performed so far.
        self.update_cnt = 0

        # The replay buffer.
        self.memory = ReplayMemory(self.replay_buffer_size)

        # Reuse the algorithm's summary writer for trainer-level summaries.
        self.my_writer = self.algorithm.train_writer

    def action(self, state, test_model):
        """Ask the algorithm for an action.

        Returns:
            A pair ``(action, None)``; the second value returned by
            ``choose_action`` is discarded.
        """
        chosen, _ = self.algorithm.choose_action(state, test_model)
        return chosen, None

    def experience(self, s, a, r, sp, terminal, **kwargs):
        """Store the transition ``(s, a, r, sp, terminal)`` in the replay memory."""
        self.memory.add((s, a, r, sp, terminal))

    def update(self, t):
        """Run one learning step if the warm-up and frequency conditions hold.

        Args:
            t: step counter supplied by the caller; learning only happens
                when it is a multiple of ``self.update_freq``.
        """
        # Warm-up: wait until the replay memory holds enough transitions.
        if len(self.memory.store) <= self.first_update_sample_num:
            return
        # Only learn every ``update_freq`` steps.
        if t % self.update_freq != 0:
            return

        self.update_cnt += 1

        # Split the sampled transitions column-wise into
        # (state, action, reward, next state, terminal flag).
        batch = self.memory.get_minibatch(self.batch_size)
        states, actions, rewards, next_states, dones = (
            [transition[field] for transition in batch] for field in range(5)
        )

        # Rewards and terminal flags are passed as column vectors.
        self.algorithm.learn(np.array(states), np.array(actions),
                             np.array(rewards).reshape([-1, 1]),
                             np.array(next_states),
                             np.array(dones).reshape([-1, 1]))

        # Checkpoint if this update falls on the save interval.
        self.save_params()

    def save_params(self):
        """Save a checkpoint every ``self.model_update_freq`` updates.

        Nothing happens before the first update or between save intervals.
        """
        if self.update_cnt <= 0 or self.update_cnt % self.model_update_freq != 0:
            return

        print('model saved for update', self.update_cnt)
        save_path = './data/' + self.algo_name + '/model/{}.ckpt'.format(self.update_cnt)

        # Make sure the target directory exists so that saving cannot fail on
        # a missing parent directory.
        model_dir = os.path.dirname(save_path)
        if not os.path.isdir(model_dir):
            os.makedirs(model_dir)

        self.algorithm.saver.save(self.algorithm.sess, save_path)

    def load_params(self, load_cnt):
        """Restore the checkpoint that was saved at update number ``load_cnt``."""
        load_path = './data/' + self.algo_name + '/model/{}.ckpt'.format(load_cnt)
        self.algorithm.saver.restore(self.algorithm.sess, load_path)
        print("load model for update %s " % load_cnt)

    def episode_done(self, test_model):
        """Forward the end-of-episode notification to the algorithm."""
        self.algorithm.episode_done(test_model)

    def write_summary_scalar(self, iteration, tag, value, train_info):
        """Write a scalar summary.

        When ``train_info`` is truthy the call is delegated to the algorithm;
        otherwise a simple-value summary is written through the trainer's
        writer at step ``iteration``.
        """
        if train_info:
            self.algorithm.write_summary_scalar(iteration, tag, value)
            return

        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
        self.my_writer.add_summary(summary, iteration)
