import torch
from torch.optim import Optimizer
from pdb import set_trace as bp
from dowel import logger, tabular


class SCRNVROptimizer(Optimizer):
    def __init__(self, params, inner_itr=10, ro=0.1, l=0.5, epsilon=1e-3, c_prime=0.1, K=10,
                 C=10, S=10):

        self.ro = ro
        self.l = l
        self.S = S
        self.epsilon = epsilon
        self.c_prime = c_prime
        self.inner_itr = inner_itr
        self.step_size = 1 / (20 * l)
        self.iteration = -1
        defaults = dict()
        self.sqr_grads_norms = 0
        self.last_grad_norm = 0
        self.C = C
        self.power = 1.0 / 3.0
        self.K = 1.0 * K
        self.eta = 1.0 / self.C
        self.alpha = 2 * self.K * self.eta ** 2

        self.iteration = 0
        defaults = dict()
        super(SCRNVROptimizer, self).__init__(params, defaults)
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['displacement'] = torch.zeros_like(p)
                state['last_point'] = torch.zeros_like(p)
                state['current_point'] = torch.zeros_like(p)

    def compute_norm_of_list_var(self, array_):
        """
        Args:
        param array_: list of tensors
        return:
        norm of the flattened list
        """
        norm_square = 0
        for i in range(len(array_)):
            norm_square += array_[i].norm(2).item() ** 2
        return norm_square ** 0.5

    def inner_product_of_list_var(self, array1_, array2_):

        """
        Args:
        param array1_: list of tensors
        param array2_: list of tensors
        return:
        The inner product of the flattened list
        """

        sum_list = 0
        for i in range(len(array1_)):
            sum_list += torch.sum(array1_[i] * array2_[i])
        return sum_list

    def update_model_to_last_point(self, ):
        """
        update the parameter based on the displacement
        """

        for group in self.param_groups:
            with torch.no_grad():
                for p in group['params']:
                    state = self.state[p]
                    state['current_point'].copy_(p)
                    p.copy_(state['last_point'])

    def update_model_to_current_point(self, ):
        """
        update the parameter based on the displacement
        """

        for group in self.param_groups:
            with torch.no_grad():
                for p in group['params']:
                    state = self.state[p]
                    p.copy_(state['current_point'])

    def cubic_subsolver(self, g, grads, param, g_norm: float, epsilon: float, ro: float, l: float):
        """
        solve the sub problem with gradient decent
        """
        deltas = [0] * len(grads)
        g_tildas = [0] * len(grads)

        with torch.no_grad():
            if g_norm >= l ** 2 / self.ro:
                # compute hessian vector with respect to grads
                hvp = torch.autograd.grad(outputs=grads, inputs=param,
                                          grad_outputs=g, retain_graph=True)
                g_t_dot_bg_t = self.inner_product_of_list_var(g, hvp) / (ro * (g_norm ** 2))
                R_c = -g_t_dot_bg_t + (g_t_dot_bg_t ** 2 + 2 * g_norm / ro) ** 0.5
                for i in range(len(g)):
                    deltas[i] = -R_c * g[i].clone() / g_norm

            else:
                sigma = self.c_prime * (epsilon * ro) ** 0.5 / l
                for i in range(len(g)):
                    deltas[i] = torch.zeros(g[i].shape)
                    khi = torch.rand(g[i].shape)
                    g_tildas[i] = g[i].clone() + sigma * khi
                for t in range(self.inner_itr):
                    # compute hessian vector with respect to delta
                    hvp = torch.autograd.grad(outputs=grads, inputs=param,
                                              grad_outputs=deltas, retain_graph=True)
                    deltas_norm = self.compute_norm_of_list_var(deltas)
                    if self.compute_norm_of_list_var(hvp) > 150:
                        break

                    for i in range(len(g)):
                        deltas[i] = deltas[i] - self.step_size * (
                                g_tildas[i] + hvp[i].clone() + ro / 2 * deltas_norm * deltas[i])

        # compute hessian vector with respect to delta
        hvp = torch.autograd.grad(outputs=grads, inputs=param,
                                  grad_outputs=deltas, retain_graph=True)
        deltas_norm = self.compute_norm_of_list_var(deltas)
        delta_m = 0
        for i in range(len(grads)):
            delta_m += torch.sum(grads[i] * deltas[i]) + 0.5 * torch.sum(
                deltas[i] * hvp[i].clone()) + ro / 6 * deltas_norm ** 3

        deltas_norm = 0
        # update the displacement
        for group in self.param_groups:
            i = 0
            for p in group['params']:
                state = self.state[p]
                deltas_norm += deltas[i].norm(2).item() ** 2
                state['displacement'] = deltas[i]
                i += 1

        return delta_m.item(), deltas_norm

    def update_parameters(self, ):

        for group in self.param_groups:
            with torch.no_grad():
                for p in group['params']:
                    state = self.state[p]
                    displacement = state['displacement']
                    state['last_point'].copy_(p)
                    p.add_(displacement.clone())

    def step(self, closure=None):
        """Performs a single optimization step.

        The closure is evaluated twice: first with the parameters rolled back
        to the previous iterate (called as ``closure()``), then at the current
        iterate (called as ``closure(current_point=True)``). The two gradients
        feed a recursive gradient estimate, which is handed to
        ``cubic_subsolver``; the resulting displacement is applied with
        ``update_parameters`` and a few diagnostics are logged.

        Args:
            closure (callable): re-evaluates the model and fills ``p.grad``.
                Required despite the ``None`` default. The gradients left by
                the second call are differentiated again by the sub-solver,
                so they must carry an autograd graph (presumably the closure
                calls ``backward(create_graph=True)`` -- confirm with caller).

        Returns:
            None

        Raises:
            TypeError: if ``closure`` is ``None``.
        """
        if closure is None:
            raise TypeError("SCRNVROptimizer.step() requires a closure")

        # NOTE(review): this switches anomaly detection on globally on every
        # step; it slows autograd down and looks like debugging residue.
        # Kept so NaN failures surface exactly as before -- confirm and remove.
        torch.autograd.set_detect_anomaly(True)
        self.iteration += 1

        # update eta, alpha
        if self.iteration > 1:
            new_eta = 1. / (self.C * self.iteration ** self.power)
            self.alpha = 2 * self.K * self.eta * new_eta
            self.eta = new_eta

        params = [p for group in self.param_groups for p in group['params']]

        # gradient at the previous iterate, keyed by parameter so it cannot
        # be paired with the wrong tensor later on
        self.update_model_to_last_point()
        with torch.enable_grad():
            closure()
        previous_grads = {p: p.grad.detach().clone() for p in params if p.grad is not None}

        # gradient at the current iterate
        self.update_model_to_current_point()
        with torch.enable_grad():
            closure(current_point=True)

        grads = []
        g = []
        param = []
        grad_square_norm = 0
        g_square_norm = 0
        restart = self.alpha == 1 or self.iteration % self.S == 0

        for p in params:
            if p.grad is None:
                continue
            d_p = p.grad
            grads.append(d_p)
            param.append(p)
            grad_square_norm += d_p.norm(2).item() ** 2

            state = self.state[p]
            if restart or 'momentum_buffer' not in state or p not in previous_grads:
                # clone: the buffer must not share storage with p.grad, which
                # the next backward pass may overwrite in place
                buf = state['momentum_buffer'] = d_p.detach().clone()
            else:
                buf = state['momentum_buffer'].detach()
                with torch.no_grad():
                    # NOTE(review): the mixing weight is the literal 0.1 and
                    # self.alpha is only used for the ``== 1`` test above --
                    # confirm this is the intended recursion.
                    buf.add_(d_p - previous_grads[p]).mul_(0.1).add_(d_p, alpha=1)

            g.append(buf)
            g_square_norm += buf.norm(2).item() ** 2

        # store square of grad norm
        self.sqr_grads_norms += self.last_grad_norm
        self.last_grad_norm = grad_square_norm
        delta_m, deltas_norm = self.cubic_subsolver(g, grads, param, g_square_norm ** 0.5, self.epsilon, self.ro,
                                                    self.l)
        self.update_parameters()

        with tabular.prefix("SCRN" + '/'):
            tabular.record('delta of m', delta_m)
            tabular.record('norm of gradient', grad_square_norm ** (1. / 2))
            # NOTE(review): cubic_subsolver returns the sum of squared entries
            # of the step, so this logs the squared norm -- confirm intent.
            tabular.record('norm of deltas', deltas_norm)
            logger.log(tabular)
        return None
