from BaseAgent import BaseAgent
import numpy as np
import copy

class Qtarget(BaseAgent):
    """Q-learning-style agent with an online and a slowly-tracking target weight vector.

    The online weights are moved along ``td_error * phi`` with L2 weight decay,
    and the target weights follow the online weights by Polyak averaging:
    ``target <- target + beta * (weights - target)``.  Bootstrapped values in
    the TD error are computed from the target weights.

    Expected ``config`` keys:
        gamma: discount factor applied to the bootstrapped next-state value.
        alpha: step size for the online weight update.
        beta:  tracking rate of the target weights towards the online weights.
        eta:   coefficient of the ``- eta * weights`` decay term in the update.

    NOTE(review): ``self.features``, ``self.num_features``, ``self.action_value``
    and ``self.greedy_policy`` are not defined here; presumably they are set up
    by ``BaseAgent.__init__`` — confirm there.
    """

    def __init__(self, env, config):
        super().__init__(env)

        self.env = env

        self.gamma = config["gamma"]
        self.alpha = config["alpha"]
        self.beta = config["beta"]
        self.eta = config["eta"]
        self.updates = 0  # number of completed calls to update()
        self.weights = None
        self.target_weights = None

        # The initial weights are chosen per environment (see init_weight).
        self.init_weight(self.env.env_name)

    def primal_weight(self):
        """Return the online weight vector."""
        return self.weights

    def dual_weight(self):
        """Return the target weight vector."""
        return self.target_weights

    def init_weight(self, weight_initializer):
        """Initialise online and target weights for the named environment.

        Args:
            weight_initializer: environment name; one of ``"Baird"`` or
                ``"ThetaTwoTheta"``.

        Raises:
            KeyError: if ``weight_initializer`` is not a supported name.
        """
        initializers = {"Baird": self.baird_weight, "ThetaTwoTheta": self.theta_two_theta}
        try:
            initializer = initializers[weight_initializer]
        except KeyError:
            # Keep the KeyError type callers may rely on, but say what is valid.
            raise KeyError(
                f"Unknown weight initializer {weight_initializer!r}; "
                f"expected one of {sorted(initializers)}"
            ) from None
        initializer()

    def theta_two_theta(self):
        """Set both weight vectors to independent all-ones arrays."""
        self.weights = np.ones(self.num_features)
        self.target_weights = np.ones(self.num_features)

    def baird_weight(self):
        """All-ones weights except index ``env.SEVENTH_STATE``, which is set to 10.

        The target weights start as an independent copy of the online weights.
        """
        self.weights = np.ones(self.num_features)
        self.weights[self.env.SEVENTH_STATE] = 10

        # Independent copy so later online updates do not alias the target.
        self.target_weights = copy.deepcopy(self.weights)

    def td_error(self, state, action, next_state, reward, done_mask):
        """Return the TD error for one transition.

        ``reward + done_mask * gamma * Q_target(s', a') - Q(s, a)`` where ``a'``
        is the greedy action in ``next_state`` under the *target* weights and
        ``Q(s, a)`` uses the online weights.

        ``done_mask`` multiplies the bootstrap term; presumably 0 for terminal
        transitions and 1 otherwise — confirm against the caller.
        """
        q_sa_theta = self.action_value(state, action, self.weights)
        next_action = self.greedy_policy(next_state, self.target_weights)
        next_q_sa_theta = self.action_value(next_state, next_action, self.target_weights)

        return reward + done_mask * self.gamma * next_q_sa_theta - q_sa_theta

    def update_weight(self, state, action, next_state, reward, done_mask):
        """Apply one semi-gradient step with weight decay to the online weights.

        ``weights <- weights + alpha * (td_error * phi - eta * weights)`` where
        ``phi = features[action, state]``.
        """
        td_error = self.td_error(state, action, next_state, reward, done_mask)
        phi = self.features[action, state]

        gradient = td_error * phi

        # Rebinds to a new array rather than updating in place, so any external
        # reference previously returned by primal_weight() is left unchanged.
        self.weights = self.weights + self.alpha * (gradient - self.eta * self.weights)

    def update_target(self):
        """Move the target weights a fraction ``beta`` towards the online weights."""
        self.target_weights = self.target_weights + self.beta * (self.weights - self.target_weights)

    def update(self, state, next_state, action, reward, done_mask):
        """Perform one full learning step: online update, then target update.

        NOTE: this method takes ``(state, next_state, action, ...)`` while
        ``update_weight``/``td_error`` take ``(state, action, next_state, ...)``;
        the arguments are reordered below, so callers must use this order.
        """
        self.update_weight(state, action, next_state, reward, done_mask)
        self.update_target()

        self.updates += 1