from apo_precond.utils import *
import numpy as np


class Preconditioner(nn.Module):
    """Learnable Kronecker-factored gradient preconditioner for a single tensor.

    Supported ``parameterization`` / rank combinations:

    * ``"ekfac_psd"``: rank 1, 2 and 4 tensors.
    * ``"scale_ekfac_psd"``: rank 1 and 2 tensors, plus a learnable scalar
      ``params_dict["scale"]`` that multiplies the result.

    Any other rank creates no factors in ``__init__``; ``precondition_gradient``
    then raises ``NotImplementedError``.

    For a 2-D view ``G`` of the gradient (rank-4 tensors are flattened to
    ``(shape[0], -1)``) the preconditioned gradient is::

        Q_out @ (diag ** 2 * (Q_out.T @ G @ Q_in)) @ Q_in.T

    ``Q_out`` / ``Q_in`` are dense square matrices (keys ``"out"`` / ``"in"``).
    Once a side reaches the diagonal threshold, they are instead diagonal
    matrices stored as vectors (keys ``"d_out"`` / ``"d_in"``).

    Rank-1 tensors use ``(P @ P.T) @ g`` with a dense ``P``. Under
    ``"ekfac_psd"``, vectors at or above the threshold use ``d_out ** 2 * g``.

    At initialization the factors are identities (or ones) and ``diag`` equals
    ``sqrt(scale)``, so the preconditioner starts out as ``scale * I``.
    """

    _FC_DIAG_DIM_THRESHOLD = 4096
    _CONV_DIAG_DIM_THRESHOLD = 4096

    def __init__(self, var, parameterization, scale, debug=False):
        """Create the learnable factors for a tensor shaped like ``var``.

        Args:
            var: Tensor whose gradients will be preconditioned. Only its shape
                and device (``var.get_device()``) are read.
            parameterization: ``"ekfac_psd"`` or ``"scale_ekfac_psd"``.
            scale: Initial magnitude; the preconditioner starts as ``scale * I``.
            debug: Forwarded unchanged to ``make_parameter``.

        Raises:
            NotImplementedError: If ``parameterization`` is not recognised.
        """
        super().__init__()
        self._shape = var.shape
        self._rank = len(self._shape)
        self.parameterization = parameterization
        # Plain dict/list, not ParameterDict/ParameterList: nn.Module only
        # registers the tensors that are also assigned as attributes below.
        self.params_dict = {}
        self.params_lst = []

        device = var.get_device()

        if self.parameterization == "ekfac_psd":
            if self._rank == 1:
                num_features = self._shape[0]
                if num_features < self._FC_DIAG_DIM_THRESHOLD:
                    param = self._add_param("out", (num_features, num_features), device, debug,
                                            nn.init.eye_, attr="out")
                else:
                    # Stored under key "d_out" but registered as attribute "out";
                    # kept that way so existing state_dict keys stay valid.
                    param = self._add_param("d_out", (num_features,), device, debug,
                                            nn.init.ones_, attr="out")
                param.data = param.data * np.sqrt(scale)

            elif self._rank == 2:
                self._init_factored(self._shape[0], self._shape[1], device, debug, scale,
                                    self._FC_DIAG_DIM_THRESHOLD)

            elif self._rank == 4:
                # Derive the flattened width from the shape instead of var.view():
                # view() raises on non-contiguous weights (e.g. channels_last).
                self._init_factored(self._shape[0], self._shape[1:].numel(), device, debug,
                                    scale, self._CONV_DIAG_DIM_THRESHOLD)

        elif self.parameterization == "scale_ekfac_psd":
            if self._rank == 1:
                num_features = self._shape[0]
                param = self._add_param("out", (num_features, num_features), device, debug,
                                        nn.init.eye_, attr="out")
                param.data = param.data * np.sqrt(scale)
                # NOTE(review): "scale" is not assigned as a module attribute, so it
                # is absent from parameters()/state_dict(); it is only reachable via
                # params_dict / params_lst. Left as-is for checkpoint compatibility.
                self._add_param("scale", (1,), device, debug, nn.init.ones_)

            elif self._rank == 2:
                self._init_factored(self._shape[0], self._shape[1], device, debug, scale,
                                    self._FC_DIAG_DIM_THRESHOLD)
                self._add_param("scale", (1,), device, debug, nn.init.ones_)

        else:
            raise NotImplementedError(f"Unknown parameterization {parameterization!r}")

    def _add_param(self, key, shape, device, debug, init_fn, attr=None):
        """Create a parameter with ``make_parameter``, initialize it and record it.

        The tensor is stored as ``params_dict[key]`` and appended to
        ``params_lst``. When ``attr`` is given it is also assigned as a module
        attribute, which registers it with ``nn.Module``.
        """
        param = make_parameter(shape, device, debug)
        init_fn(param)
        self.params_dict[key] = param
        if attr is not None:
            setattr(self, attr, param)
        self.params_lst.append(param)
        return param

    def _init_factored(self, out_features, in_features, device, debug, scale, threshold):
        """Create the output factor, input factor and ``diag`` for a 2-D view.

        A side whose size is ``>= threshold`` gets a diagonal factor (a vector
        of ones). Otherwise it gets a dense identity matrix. ``diag`` is filled
        with ``sqrt(scale)``.
        """
        if out_features >= threshold:
            self._add_param("d_out", (out_features,), device, debug, nn.init.ones_, attr="d_out")
        else:
            self._add_param("out", (out_features, out_features), device, debug,
                            nn.init.eye_, attr="out")

        if in_features >= threshold:
            self._add_param("d_in", (in_features,), device, debug, nn.init.ones_, attr="d_in")
        else:
            self._add_param("in", (in_features, in_features), device, debug,
                            nn.init.eye_, attr="inp")

        sqrt_scale = np.sqrt(scale)
        self._add_param("diag", (out_features, in_features), device, debug,
                        lambda p: nn.init.constant_(p, sqrt_scale), attr="diag")

    def _mul_out_factor(self, mat, transpose):
        """Left-multiply ``mat`` by the output factor (or its transpose)."""
        if "d_out" in self.params_dict:
            # A diagonal factor is symmetric, so ``transpose`` makes no difference.
            return self.params_dict["d_out"].unsqueeze(-1) * mat
        out = self.params_dict["out"]
        return (out.T if transpose else out) @ mat

    def _mul_in_factor(self, mat, transpose):
        """Right-multiply ``mat`` by the input factor (or its transpose)."""
        if "d_in" in self.params_dict:
            return mat * self.params_dict["d_in"].unsqueeze(0)
        inp = self.params_dict["in"]
        return mat @ (inp.T if transpose else inp)

    def _apply_factored(self, grad2d):
        """Return ``Q_out @ (diag ** 2 * (Q_out.T @ G @ Q_in)) @ Q_in.T`` for 2-D ``G``.

        The dense or diagonal form of each factor is chosen by which keys
        ``__init__`` created, so this always matches the stored parameters.
        """
        rotated = self._mul_in_factor(self._mul_out_factor(grad2d, transpose=True),
                                      transpose=False)
        scaled = (self.params_dict["diag"] ** 2.) * rotated
        return self._mul_in_factor(self._mul_out_factor(scaled, transpose=False),
                                   transpose=True)

    def precondition_gradient(self, grad):
        """Return the preconditioned version of ``grad``.

        Args:
            grad: Gradient shaped like the tensor passed to ``__init__``. It may
                be non-contiguous (e.g. a channels_last conv gradient).

        Returns:
            A tensor with the same shape as ``grad``.

        Raises:
            NotImplementedError: If no factors exist for this parameterization
                and rank (e.g. ``"scale_ekfac_psd"`` on a rank-4 tensor).
        """
        params = self.params_dict

        if self.parameterization == "ekfac_psd":
            if self._rank == 1:
                if "out" in params:
                    out = params["out"] @ params["out"].T
                    return out @ grad
                return (params["d_out"] ** 2.) * grad
            if self._rank == 2:
                return self._apply_factored(grad)
            if self._rank == 4:
                # reshape, not view: conv gradients can be non-contiguous.
                grad2d = grad.reshape(grad.shape[0], -1)
                return self._apply_factored(grad2d).reshape(grad.shape)

        elif self.parameterization == "scale_ekfac_psd":
            if self._rank == 1:
                out = params["out"] @ params["out"].T
                return (out @ grad) * params["scale"]
            if self._rank == 2:
                return self._apply_factored(grad) * params["scale"]

        raise NotImplementedError(
            f"precondition_gradient does not support parameterization "
            f"{self.parameterization!r} for rank-{self._rank} tensors"
        )
