import warnings
from typing import Any, List, Optional, Union

import torch
from botorch.models.model import Model
from gpytorch.distributions import MultivariateNormal
from gpytorch.likelihoods import _GaussianLikelihoodBase
from gpytorch.likelihoods.noise_models import FixedGaussianNoise, Noise
from gpytorch.mlls import ExactMarginalLogLikelihood
from gpytorch.priors import Prior
from linear_operator.operators import DiagLinearOperator, LinearOperator
from linear_operator.utils.cholesky import psd_safe_cholesky
from torch import Tensor
from torch.nn.parameter import Parameter

from ..constraints import NonTransformedInterval
from ..relevance_pursuit import RelevancePursuitMixin



class SparseOutlierGaussianLikelihood(_GaussianLikelihoodBase):
    """Gaussian likelihood whose noise covariance is a `SparseOutlierNoise`.

    The noise module adds a per-observation "robust" variance (rho) on top of
    a base noise model, so that a small set of training points can be treated
    as outliers.
    """

    def __init__(
        self,
        base_noise: Union[Noise, FixedGaussianNoise],
        dim: int,
        outlier_indices: Optional[List[int]] = None,
        rho_prior: Optional[Prior] = None,
        rho_constraint: Optional[NonTransformedInterval] = None,
        batch_shape: Optional[torch.Size] = None,
        convex_parameterization: bool = True,
        loo: bool = True,
    ) -> None:
        """Build the likelihood around a freshly constructed `SparseOutlierNoise`.

        All arguments are forwarded unchanged to `SparseOutlierNoise`.

        Args:
            base_noise: The base noise model.
            dim: The number of training observations the noise model applies to.
                It is passed explicitly rather than being inferred and cached
                during a training-mode forward pass.
            outlier_indices: The indices of the outliers.
            rho_prior: Prior for the rho parameter.
            rho_constraint: Constraint for the rho parameter. Must be a
                NonTransformedInterval, since exact sparsity cannot be
                represented with smooth transforms such as softplus or sigmoid.
            batch_shape: The batch shape of the learned noise parameter
                (default: []).
            convex_parameterization: Whether to use a convex parameterization
                of rho.
            loo: Whether to use leave-one-out (LOO) update equations that can
                compute the optimal value of each individual rho, keeping all
                else equal.
        """
        super().__init__(
            noise_covar=SparseOutlierNoise(
                base_noise=base_noise,
                dim=dim,
                outlier_indices=outlier_indices,
                rho_prior=rho_prior,
                rho_constraint=rho_constraint,
                batch_shape=batch_shape,
                convex_parameterization=convex_parameterization,
                loo=loo,
            )
        )

    def marginal(
        self,
        function_dist: MultivariateNormal,
        *params: Any,
    ) -> MultivariateNormal:
        """Add the (robust) noise covariance to the distribution over f.

        Args:
            function_dist: Distribution over the latent function values.
            *params: Positional arguments forwarded to the noise module.

        Returns:
            A distribution of the same class as `function_dist`, with the same
            mean and with the noise covariance added to its covariance.
        """
        noise_module = self.noise_covar
        latent_mean = function_dist.mean
        latent_covar = function_dist.lazy_covariance_matrix

        # The kernel diagonal is only handed to the noise module under the
        # convex parameterization, where it is used to scale rho.
        if noise_module.convex_parameterization:
            kernel_diag = latent_covar.diagonal()
        else:
            kernel_diag = None

        noise = noise_module.forward(
            *params, shape=latent_mean.shape, diag_K=kernel_diag
        )
        return function_dist.__class__(latent_mean, latent_covar + noise)

    def expected_log_prob(
        self, target: Tensor, input: MultivariateNormal, *params: Any, **kwargs: Any
    ) -> Tensor:
        """Not supported by this likelihood.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError


# Tangential to this: we could introduce a mixed variable / fixed noise class
# that would allow us to condition the process exactly on certain
# pseudo-observations that encode prior knowledge (e.g. concrete strength is
# zero at time zero).
class SparseOutlierNoise(Noise, RelevancePursuitMixin):
    """Noise model adding a sparse, per-observation "robust" variance rho to a
    base noise model. The rho values are stored in the `raw_rho` parameter,
    which is the sparse parameter managed through `RelevancePursuitMixin`.
    """

    def __init__(
        self,
        base_noise: Union[Noise, FixedGaussianNoise],
        dim: int,
        outlier_indices: Optional[List[int]] = None,
        rho_prior: Optional[Prior] = None,
        rho_constraint: Optional[NonTransformedInterval] = None,
        batch_shape: Optional[torch.Size] = None,
        convex_parameterization: bool = True,
        loo: bool = True,
    ):
        """A noise model that permits additional "robust" variance for a small set of
        outlier data points. See also SparseOutlierGaussianLikelihood, which leverages
        this noise model.

        NOTE: Let base_noise also use the non-transformed constraints, which is
        probably more stable but orthogonal implementation-wise.

        Args:
            base_noise: The base noise model.
            dim: The number of training observations on which to apply the noise model.
                We could also get this from the forward pass when the model is in
                training mode and cache it, but it's better to be explicit.
            outlier_indices: The indices of the outliers.
            rho_prior: Prior for the rho parameter.
            rho_constraint: Constraint for the rho parameter. Needs to be a
                NonTransformedInterval because exact sparsity cannot be represented
                using smooth transforms like a softplus or sigmoid. If None, the
                interval `[0, 1 - 1e-3]` is used with the convex parameterization
                and `[0, inf]` otherwise.
            batch_shape: The batch shape of the learned noise parameter. If None,
                it is taken from `base_noise.noise.shape[:-1]`.
            convex_parameterization: Whether to use a convex parameterization of rho.
            loo: Whether to use leave-one-out (LOO) update equations that can compute
                the optimal values of each individual rho, keeping all else equal.

        Raises:
            ValueError: If `rho_constraint` is not a NonTransformedInterval, if its
                lower bound is negative, or if its upper bound exceeds 1 while
                `convex_parameterization` is True.
        """
        super().__init__()
        RelevancePursuitMixin.__init__(self, dim=dim, support=outlier_indices)

        if batch_shape is None:
            batch_shape = base_noise.noise.shape[:-1]

        self.base_noise = base_noise
        device = base_noise.noise.device
        if rho_constraint is None:
            # Stay strictly below 1 in the convex parameterization, where
            # raw_rho -> 1 corresponds to rho -> inf.
            cvx_upper_bound = 1 - 1e-3
            rho_constraint = NonTransformedInterval(
                lower_bound=0.0,
                upper_bound=cvx_upper_bound if convex_parameterization else torch.inf,
                initial_value=0.0,
            )
        else:
            if not isinstance(rho_constraint, NonTransformedInterval):
                raise ValueError("No non-NonTransformedIntervals.")

            if rho_constraint.lower_bound < 0:
                raise ValueError(
                    "SparseOutlierNoise requires rho_constraint.lower_bound >= 0."
                )

            if convex_parameterization and rho_constraint.upper_bound > 1:
                raise ValueError(
                    "Convex parameterization requires rho_constraint.upper_bound <= 1."
                )

        # NOTE: Prefer to keep the initialization of the sparse_parameter in the
        # derived classes of the Mixin, because it might require additional logic
        # that we don't want to put into RelevancePursuitMixin.
        num_outliers = len(self.support)
        self.register_parameter(
            "raw_rho",
            parameter=Parameter(
                torch.zeros(
                    *batch_shape,
                    num_outliers,
                    dtype=base_noise.noise.dtype,
                    device=device,
                )
            ),
        )

        if rho_prior is not None:

            def _rho_param(m):
                return m.rho

            # NOTE(review): `_set_rho` is not defined in this class; presumably it
            # is expected from a base class or still needs to be added. Confirm
            # before relying on the setting closure of this prior.
            def _rho_closure(m, v):
                return m._set_rho(v)

            self.register_prior("rho_prior", rho_prior, _rho_param, _rho_closure)

        self.register_constraint("raw_rho", rho_constraint)
        self.convex_parameterization = convex_parameterization
        self.loo = loo

    @property
    def sparse_parameter(self) -> Parameter:
        """The parameter that holds the (potentially sparse) raw rho values."""
        return self.raw_rho

    def set_sparse_parameter(self, value: Parameter) -> None:
        """Sets the sparse parameter.

        The value is first moved to the dtype and device of the current `raw_rho`.
        If that conversion yields a plain Tensor (which happens whenever an actual
        dtype / device change takes place), it is re-wrapped in a Parameter,
        because a registered parameter can only be assigned a Parameter.

        NOTE: We can't use the property setter @sparse_parameter.setter because of
        the special way PyTorch treats Parameter types, including custom setters.

        Args:
            value: The new value of the sparse parameter.
        """
        converted = value.to(self.raw_rho)
        if not isinstance(converted, Parameter):
            converted = Parameter(
                converted.detach(), requires_grad=value.requires_grad
            )
        self.raw_rho = converted

    @staticmethod
    def _from_model(model: Model) -> RelevancePursuitMixin:
        """Retrieves the SparseOutlierNoise module from a model's likelihood.

        Args:
            model: A model whose `likelihood.noise_covar` is a SparseOutlierNoise.

        Returns:
            The `model.likelihood.noise_covar` module.

        Raises:
            ValueError: If `model.likelihood.noise_covar` is not a
                SparseOutlierNoise.
        """
        sparse_module = model.likelihood.noise_covar
        if not isinstance(sparse_module, SparseOutlierNoise):
            raise ValueError(
                "The model's likelihood does not have a SparseOutlierNoise noise "
                "as its noise_covar module."
            )
        return sparse_module

    @property
    def _convex_rho(self) -> Tensor:
        """Transforms the raw_rho parameter such that `rho ~= 1 / (1 - raw_rho) - 1`,
        which is a diffeomorphism from [0, 1] to [0, inf] whose derivative is nowhere
        zero. This transforms the marginal log likelihood to be a convex function of
        the raw_rho parameter, whenever the rest of the covariance matrix is well
        conditioned.

        Old idea, probably subsumed with this: inverse_parameterization? This would
        change the meaning of sparsity from an inlier to default to a complete outlier.
        """
        # TODO: want to scale this by the diagonal of the covariance matrix.
        eps = 1e-12  # guards against division by zero when raw_rho == 1
        return 1 / (1 - self.raw_rho + eps) - 1

    # these two don't need to be methods, can pass these as local closures
    @property
    def rho(self) -> Tensor:
        """Dense representation of the potentially sparsely represented raw_rho values,
        so that the last dimension is equal to the number of training points self.dim.

        NOTE: In this case the getter needs to be different than the sparse_parameter
        getter, because the latter must be able to return the parameter in its sparse
        representation. The rho property embeds the sparse representation in a dense
        tensor in order only to propagate gradients to the sparse rhos in the support.

        Returns:
            A `batch_shape x self.dim`-dim Tensor of robustness variances.
        """
        # TODO: don't need to do transform / untransform if we
        # enforce having NonTransformedIntervals.
        rho_outlier = self._convex_rho if self.convex_parameterization else self.raw_rho
        if not self.is_sparse:  # in the dense representation, we're done.
            return rho_outlier

        # If rho_outlier is in the sparse representation, we need to pad the
        # rho values with zeros at the correct positions. The difference
        # between this and calling RelevancePursuit's `to_dense` is that
        # the latter will propagate gradients through all rhos, whereas
        # the path here only propagates gradients to the sparse set of
        # outliers, which is important for the optimization of the support.
        rho_inlier = rho_outlier.new_zeros(rho_outlier.shape[:-1] + (1,))
        rho = torch.cat(
            [rho_outlier, rho_inlier], dim=-1
        )  # batch_shape x (num_outliers + 1)

        # The selection indices are one-dimensional and index the last dimension
        # only, so any leading batch dimensions are preserved.
        return rho[..., self._rho_selection_indices]

    @property
    def _rho_selection_indices(self) -> Tensor:
        """Maps every training point to its position in the zero-padded sparse rho.

        The support is shared across batch dimensions, so a single index vector
        is sufficient for all batches.

        Returns:
            A `self.dim`-dim long Tensor. Entry `j` is the position of training
            point `j` within `self.support` if it is in the support, and `-1`
            otherwise. In `rho`, `-1` selects the trailing zero padding.
        """
        # Assemble the indices in Python and create the tensor once, rather than
        # writing to a (possibly non-CPU) tensor element by element.
        selection = [-1] * self.dim
        for i, j in enumerate(self.support):
            selection[j] = i

        return torch.tensor(selection, dtype=torch.long, device=self.raw_rho.device)

    def forward(
        self,
        *params: Any,
        diag_K: Optional[Tensor] = None,
        shape: Optional[torch.Size] = None,
    ) -> Union[LinearOperator, Tensor]:
        """Computes the base noise covariance, plus the robust variances rho on its
        diagonal whenever the shapes are compatible.

        Args:
            *params: Positional arguments forwarded to the base noise model.
            diag_K: Optional diagonal of the kernel matrix. If given, rho is
                multiplied by `diag_K + diag(base noise covariance)` before it
                is added.
            shape: The shape passed on to the base noise model.

        Returns:
            The base noise covariance with `DiagLinearOperator(rho)` added if the
            last dimension of the base noise covariance equals that of rho.
            Otherwise, the base noise covariance unchanged (and a warning is
            issued).
        """
        noise_covar = self.base_noise(*params, shape=shape)
        # rho should always be applied to the training set, irrespective of whether or
        # not we are in training mode.  TODO: figure out where we call the predictive
        # noise variance to make adjustments to noisy predictive distribution.
        rho = self.rho
        if noise_covar.shape[-1] == rho.shape[-1]:
            if diag_K is not None:
                rho = (diag_K + noise_covar.diagonal()) * rho
            noise_covar = noise_covar + DiagLinearOperator(rho)
        else:
            warnings.warn(
                "Robust rho not applied because the shape of the base noise "
                "covariance is not compatible with the shape of rho. This usually "
                "happens when the model posterior is evaluated on test data.",
                RuntimeWarning,
                stacklevel=2,
            )
        return noise_covar

    # relevance pursuit method expansion and contraction related methods
    def expansion_objective(self, mll: ExactMarginalLogLikelihood) -> Tensor:
        """Computes the objective used to decide which elements to add to the
        support: the LOO-optimal rho deltas if `self.loo` is True, and
        `self._sparse_parameter_gradient(mll)` otherwise.

        Args:
            mll: The marginal likelihood, containing the model to optimize.

        Returns:
            The value of the selected objective.
        """
        # TODO: check if the biggest change in rho coincides with the largest
        # change in likelihood, if not, adjust the objective here.
        f = self._optimal_rhos if self.loo else self._sparse_parameter_gradient
        return f(mll)

    # NOTE: contraction_objective is deliberately not overridden here; per the
    # original author's note, the magnitude is used by default.

    def _optimal_rhos(self, mll: ExactMarginalLogLikelihood) -> Tensor:
        """Computes the optimal rho deltas for the given model.

        NOTE: This puts `mll` into training mode as a side effect.

        Args:
            mll: The marginal likelihood, containing the model to optimize.

        Returns:
            A `batch_shape x num_inactive`-dim Tensor of optimal rho deltas, clamped
            from below at zero and restricted to the training points for which
            `self.is_active` is False.
        """
        # train() is important, since we want to evaluate the prior with mll.model(X),
        # but in eval(), __call__ gives the posterior.
        mll.train()  # NOTE: this changes model.train_inputs to be unnormalized.
        X, Y = mll.model.train_inputs[0], mll.model.train_targets
        F = mll.model(X)
        L = mll.likelihood(F)
        S = L.covariance_matrix  # (Kernel Matrix + Noise Matrix)

        # NOTE: The following computation is mathematically equivalent to the formula
        # in this comment, but leverages the positive-definiteness of S via its
        # Cholesky factorization.
        # S_inv = S.inverse()
        # diag_S_inv = S_inv.diagonal(dim1=-1, dim2=-2)
        # loo_var = 1 / S_inv.diagonal(dim1=-1, dim2=-2)
        # loo_mean = Y - (S_inv @ Y) / diag_S_inv

        chol = psd_safe_cholesky(S, upper=True)
        eye = torch.eye(chol.size(-1), device=chol.device, dtype=chol.dtype)
        inv_root = torch.linalg.solve_triangular(chol, eye, upper=True)

        # test: inv_root.square().sum(dim=-1) - S.inverse().diag()
        diag_S_inv = inv_root.square().sum(dim=-1)
        loo_var = 1 / diag_S_inv
        S_inv_Y = torch.cholesky_solve(Y.unsqueeze(-1), chol, upper=True).squeeze(-1)
        loo_mean = Y - S_inv_Y / diag_S_inv

        loo_error = loo_mean - Y
        optimal_rho_deltas = loo_error.square() - loo_var
        # Mask along the last (data) dimension, so that batched models are handled
        # correctly as well.
        return (optimal_rho_deltas - self.rho).clamp(0)[..., ~self.is_active]
