import matplotlib.pyplot as plt
from qpsolvers import solve_qp
from scipy import optimize
from copy import deepcopy
import numpy as np

# Small regularizer used by Solver: added to the diagonal of the multiplier-QP
# matrix and to the dual variable nu before dividing by it.
EPS = 1e-8

def FDM(f, x, y, delta=1e-6):
    """Approximate the gradient of a two-argument scalar function.

    Uses forward finite differences: each partial derivative is
    ``(f(point + delta * e_i) - f(point)) / delta``.

    Args:
        f: callable ``f(x, y)`` returning a scalar.
        x, y: point at which the gradient is estimated.
        delta: finite-difference step size.

    Returns:
        ``np.ndarray`` ``[df/dx, df/dy]``.
    """
    base_value = f(x, y)
    d_dx = (f(x + delta, y) - base_value)/delta
    d_dy = (f(x, y + delta) - base_value)/delta
    return np.array([d_dx, d_dy])

class Solver:
    """Constrained trust-region step solver for a 2-D objective.

    Each call to :meth:`solve` linearizes the objective and the first
    ``num_costs`` cost functions at ``state`` with forward finite differences
    (``FDM``). It measures step size with the quadratic form
    ``0.5 * dx^T H_mat dx`` (bounded by ``max_kl``). It then picks a step:

    * feasible case: minimize the (negated) dual of the linearized problem
      with SciPy's ``trust-constr`` and take
      ``dx = H^-1 (g - B lam) / nu``;
    * recovery case: solve a multiplier QP and take a pure
      constraint-decreasing step ``dx = -scaling * H^-1 B lam``.

    The step is then shortened by a backtracking line search.

    Args:
        num_costs: number of cost (constraint) functions to linearize.
        limit_values: upper limit for each cost function.
        H_mat: 2x2 matrix of the quadratic step measure (should be positive
            definite); nested lists are accepted.
        max_kl: bound on ``0.5 * dx^T H_mat dx``.
        slack_decay: factor applied to ``1 - slack_max_kl`` after every solve.
        ls_tol: line-search tolerance on the allowed cost increase.
        ls_decay: line-search step shrink factor.
        zeta: extra margin added to the constraint violation in recovery mode.
        ls_min_beta: smallest step fraction tried by the line search; if no
            step is accepted before reaching it, the state is left unchanged.
        qp_solver: backend name passed as ``solver=`` to ``qpsolvers.solve_qp``;
            ``None`` calls ``solve_qp`` without a ``solver`` argument.
    """

    def __init__(self, num_costs, limit_values, H_mat, max_kl, slack_decay, ls_tol=1e-3, ls_decay=0.8, zeta=0.01,
                 ls_min_beta=1e-12, qp_solver=None) -> None:
        self.num_costs = num_costs
        self.limit_values = deepcopy(limit_values)
        # np.array makes an independent copy and also accepts nested lists,
        # which .flatten() below requires to be an ndarray.
        self.H_mat = np.array(H_mat, dtype=float)
        self.max_kl = max_kl
        self.slack_decay = slack_decay
        # Relaxed trust-region size tried when the first multiplier QP is
        # infeasible; moved toward 1.0 at the end of every solve().
        self.slack_max_kl = 0.0
        self.ls_tol = ls_tol
        self.ls_decay = ls_decay
        self.ls_min_beta = ls_min_beta
        self.zeta = zeta
        self.qp_solver = qp_solver

        # for solver: dual variables [lambda_1..lambda_k, nu] are all >= 0
        self.bounds = optimize.Bounds(np.zeros(num_costs + 1), np.ones(num_costs + 1)*np.inf)

        def dual(x, q_scalar, r_vector, S_mat, c_vector, max_kl):
            # Negated Lagrange dual of the linearized problem
            #   max g^T dx  s.t.  c + B^T dx <= 0,  0.5 dx^T H dx <= max_kl
            # evaluated at x = [lambda..., nu]; minimized subject to x >= 0.
            dual_eps = 1e-6  # keeps the 1/(2 nu) term finite at nu = 0
            lam_vector = x[:-1]
            nu_scalar = x[-1]
            quad_term = q_scalar - 2.0*np.dot(r_vector, lam_vector) + np.dot(lam_vector, S_mat@lam_vector)
            return quad_term/(2.0*nu_scalar + dual_eps) - np.dot(lam_vector, c_vector) + nu_scalar*max_kl
        self.dual = dual

        # To visualize the trust region 0.5 * [x, y] H [x, y]^T = max_kl.
        # Solving d*y^2 + (b + c)*x*y + a*x^2 - 2*max_kl = 0 for y gives
        #   y = -(b + c)/(2d) * x  +/-  sqrt(2*max_kl - e*x^2) / sqrt(d),
        # with e = a - (b + c)^2/(4d); the curve exists for |x| <= sqrt(2*max_kl/e).
        a, b, c, d = self.H_mat.flatten()
        e = a - (b + c)**2/(4*d)
        x_half_width = np.sqrt(2.0*self.max_kl/e)
        self.tr_xs = np.arange(-x_half_width, x_half_width + 0.01, 0.01)
        center_ys = -((b + c)/(2*d))*self.tr_xs
        # clip guards the endpoints, where rounding can make the radicand slightly negative
        half_heights = np.sqrt(np.clip(2.0*self.max_kl - e*self.tr_xs**2, 0.0, np.inf)/d)
        self.tr_ys1 = center_ys + half_heights
        self.tr_ys2 = center_ys - half_heights

    def _solve_lambda(self, S_mat, c_vector):
        """Solve ``min_{lam >= 0} 0.5 lam^T (S + EPS*I) lam - c^T lam``.

        Returns:
            ``(lam_vector, 0.5 * lam^T S lam)``.

        Raises:
            RuntimeError: if ``solve_qp`` returns no solution.
        """
        qp_kwargs = {} if self.qp_solver is None else {'solver': self.qp_solver}
        lam_vector = solve_qp(P=(S_mat + np.eye(self.num_costs)*EPS), q=-c_vector,
                              lb=np.zeros(self.num_costs), **qp_kwargs)
        if lam_vector is None:
            raise RuntimeError("solve_qp failed to find a multiplier vector")
        return lam_vector, 0.5*np.dot(lam_vector, S_mat@lam_vector)

    def _line_search(self, state, delta_state, objective, cost_functions, cost_vector, c_vector, recovery_mode):
        """Backtrack ``beta`` until ``state + beta*delta_state`` is acceptable.

        A trial point is accepted when every cost increases by at most
        ``max(-c_i, ls_tol)``. Outside recovery mode, the objective must
        additionally not decrease. If ``beta`` drops below ``ls_min_beta``
        without an accepted point (e.g. ``delta_state`` contains NaN), the
        unchanged ``state`` is returned as a float array instead of looping forever.
        """
        init_obj = objective(*state)
        cost_allowance = np.maximum(-c_vector, self.ls_tol)
        beta = 1.0
        while beta >= self.ls_min_beta:
            new_state = state + beta*delta_state
            new_cost_vector = np.array([cost_f(*new_state) for cost_f in cost_functions])
            if np.all(new_cost_vector - cost_vector <= cost_allowance):
                if recovery_mode or objective(*new_state) >= init_obj:
                    return new_state
            beta *= self.ls_decay
        return np.array(state, dtype=float)

    def solve(self, state, objective, cost_functions):
        """Compute one constrained trust-region step from ``state``.

        Args:
            state: current point ``(x, y)``; unpacked into ``objective`` and
                each cost function.
            objective: callable ``f(x, y)`` to be increased.
            cost_functions: sequence of callables ``c_i(x, y)``. The first
                ``num_costs`` are linearized; all of them are checked in the
                line search.

        Returns:
            ``(new_state, info)`` where ``info`` is the debugging list
            ``[g_vector, b_vectors, H^-1 B lam from the first QP, delta_state]``.

        Raises:
            RuntimeError: if the multiplier QP has no solution.
        """
        # for objectives
        g_vector = FDM(objective, *state)
        H_inv_g_vector = np.linalg.solve(self.H_mat, g_vector)
        g_H_inv_g_scalar = np.dot(g_vector, H_inv_g_vector)

        # For constraints: value, gradient b_i, and the largest first-order
        # change sqrt(2 * max_kl * b_i^T H^-1 b_i) reachable in the trust region.
        cost_scalars = []
        c_scalars = []
        b_vectors = []
        H_inv_b_vectors = []
        max_c_scalars = []
        for cost_idx in range(self.num_costs):
            cost_function = cost_functions[cost_idx]
            cost_scalar = cost_function(*state)
            b_vector = FDM(cost_function, *state)
            H_inv_b_vector = np.linalg.solve(self.H_mat, b_vector)
            cost_scalars.append(cost_scalar)
            c_scalars.append(cost_scalar - self.limit_values[cost_idx])
            b_vectors.append(b_vector)
            H_inv_b_vectors.append(H_inv_b_vector)
            max_c_scalars.append(np.sqrt(2.0*self.max_kl*np.dot(b_vector, H_inv_b_vector)))
        B_mat = np.array(b_vectors).T
        H_inv_B_mat = np.array(H_inv_b_vectors).T
        S_mat = B_mat.T@H_inv_B_mat
        r_vector = g_vector@H_inv_B_mat
        # float dtype so the in-place edits of c_vector below are never truncated to int
        c_vector = np.array(c_scalars, dtype=float)
        cost_vector = np.array(cost_scalars, dtype=float)

        # find scaling factor
        lam_vector, approx_kl = self._solve_lambda(S_mat, c_vector)

        # for debugging! #
        temp_vector = H_inv_B_mat@lam_vector
        ##################

        # feasibility check: the multiplier step already (nearly) fills the trust region
        recovery_mode = False
        if approx_kl/self.max_kl - 1.0 > -0.001:
            # retry with constraint targets capped by the relaxed (slack) trust region
            slack_ratio = np.sqrt(self.slack_max_kl/self.max_kl)
            for c_idx in range(self.num_costs):
                c_vector[c_idx] = min(max_c_scalars[c_idx]*slack_ratio,
                                      cost_vector[c_idx] - self.limit_values[c_idx])
            lam_vector, approx_kl = self._solve_lambda(S_mat, c_vector)
            recovery_mode = bool(approx_kl/self.max_kl - 1.0 > -0.001)

        if recovery_mode:
            # pure constraint-reduction step, shrunk back into the trust region if needed
            for c_idx in range(self.num_costs):
                c_vector[c_idx] = min(max_c_scalars[c_idx],
                                      cost_vector[c_idx] - self.limit_values[c_idx] + self.zeta)
            lam_vector, approx_kl = self._solve_lambda(S_mat, c_vector)
            scaling = 1.0 if approx_kl <= self.max_kl else np.sqrt(self.max_kl/approx_kl)
            delta_state = -scaling*H_inv_B_mat@lam_vector
        else:
            # feasible: optimize the dual over [lambda, nu] >= 0
            x0 = np.ones(self.num_costs + 1)
            res = optimize.minimize(
                self.dual, x0, method='trust-constr',
                args=(g_H_inv_g_scalar, r_vector, S_mat, c_vector, self.max_kl),
                bounds=self.bounds,
                options={'disp': False, 'initial_tr_radius': 0.1, 'xtol': 1e-5, 'gtol': 1e-5, 'barrier_tol': 1e-5},
            )
            lam_vector, nu_scalar = res.x[:-1], res.x[-1]
            delta_state = (H_inv_g_vector - H_inv_B_mat@lam_vector)/(nu_scalar + EPS)
        # debug output
        print(recovery_mode, approx_kl)

        # for debugging! #
        temp_vector2 = deepcopy(delta_state)
        ##################

        new_state = self._line_search(state, delta_state, objective, cost_functions,
                                      cost_vector, c_vector, recovery_mode)

        # move slack_max_kl toward 1.0
        self.slack_max_kl = 1.0 - (1.0 - self.slack_max_kl)*self.slack_decay

        info = [g_vector, b_vectors, temp_vector, temp_vector2]
        return new_state, info

    def draw(self, ax, state):
        """Plot the trust-region boundary centred at ``state`` on matplotlib axes ``ax``."""
        ax.plot(self.tr_xs + state[0], self.tr_ys1 + state[1], 'b')
        ax.plot(self.tr_xs + state[0], self.tr_ys2 + state[1], 'b')

