from sswimlib.utils.enums import SEQUENCE, SCRAMBLING, QMC_KWARG, STANDARD_KERNEL
from sswimlib.features import samplers, transformers
from sswimlib.utils.sequences import Sequence
from sswimlib.utils.opt import Param


class Linear():
    """
    Linear feature map, optionally counting an extra bias feature.

    The mapping itself is delegated to ``transformer``; this class records the
    input dimensionality ``D`` and the resulting feature dimensionality ``M``.
    """

    def __init__(self,
                 D,
                 has_bias_term=True,
                 transformer=transformers.linear):
        """
        :param D:               int
                                The dimensionality of our input data
        :param has_bias_term:   If truthy, one extra feature is counted for the
                                bias term, i.e. M = D + 1; otherwise M = D
        :param transformer:     The transformer function, called as
                                transformer(X, has_bias_term) by transform()
        """
        super().__init__()
        # Store D itself (a trailing comma here would silently make it a tuple
        # and break the ``D + 1`` arithmetic below).
        self.D = D
        self.has_bias_term = has_bias_term
        self.__transformer = transformer
        # Truthiness rather than ``is True`` so values such as numpy.bool_(True)
        # or 1 also account for the bias feature.
        self.M = self.D + 1 if self.has_bias_term else self.D

    def transform(self, X):
        """
        Applies the instance's transformer function to some data X
        :param X: The input data we want to transform
        :return: Whatever the configured transformer returns when called with
                 X and this instance's has_bias_term flag
        """
        return self.__transformer(X, self.has_bias_term)


class LengthscaleKernel():
    """
    Fourier Features for standard shift invariant kernels
      - RBF
      - M12
      - M32
      - M52
    """

    def __init__(self,
                 M,
                 D,
                 ls=None,
                 ns_type=None,
                 meanshift=None,
                 kernel_type=STANDARD_KERNEL.RBF,
                 sequence_type=SEQUENCE.HALTON,
                 scramble_type=SCRAMBLING.OWEN17,
                 kwargs=None):
        """
        :param M:   int
                    The dimensionality of our features
                    M = 2m because we are using the [cos(wx),sin(wx)]
        :param D:   int
                    The dimensionality of our input data
        :param ls:  Param()  [Optimizable]
                    The lengthscale. When ns_type == "lebesgue_stieltjes" it
                    must be indexable, and entries 0 and 1 are read as Params;
                    otherwise it is a single Param. If omitted, a fresh
                    Param(init=1.0, forward_fn=None, gmin=0.02, gmax=2.00) is
                    created for this instance.
        :param ns_type:       None or "lebesgue_stieltjes". Selects the
                              transformer (cos_sin_ns for "lebesgue_stieltjes",
                              cos_sin_ui otherwise) and how ls/meanshift are
                              passed to the sampler.
        :param meanshift:     Only read when ns_type == "lebesgue_stieltjes";
                              entries 0 and 1 are read as Params by
                              sample_frequencies.
        :param kernel_type:   Kernel identifier passed to the sampler
        :param sequence_type: Passed to Sequence as sequence_type
        :param scramble_type: Passed to Sequence as scramble_type
        :param kwargs:        Passed to Sequence as kwargs. If omitted, a fresh
                              {QMC_KWARG.PERM: None} dict is used per instance.
        :raises ValueError:   If M is not even
        """
        super().__init__()
        if M % 2 != 0:
            raise ValueError("M must be an even number")
        # Build per-instance defaults here: default argument values are
        # evaluated once, so a Param/dict default would be shared (and mutated
        # during optimisation) across every kernel using it.
        if ls is None:
            ls = Param(init=1.0,
                       forward_fn=None,
                       gmin=0.02,
                       gmax=2.00)
        if kwargs is None:
            kwargs = {QMC_KWARG.PERM: None}
        self.M = M
        self.m = M // 2  # This is half the number of features
        self.D = D
        self.ns_type = ns_type
        self.kernel_type = kernel_type
        self.sequencer_type = sequence_type
        self.scramble_type = scramble_type
        self.kwargs = kwargs
        # m points in D dimensions: one frequency per cos/sin pair.
        self.sequence = Sequence(N=self.m, D=self.D,
                                 sequence_type=self.sequencer_type,
                                 scramble_type=self.scramble_type,
                                 kwargs=self.kwargs)

        self.S = None
        self.ls = ls
        self.meanshift = meanshift
        self.__sampler = samplers.standard_kernel
        if self.ns_type == "lebesgue_stieltjes":
            self.__transformer = transformers.cos_sin_ns
        else:
            self.__transformer = transformers.cos_sin_ui
        self.sample_frequencies()

    def get_params(self):
        """
        Returns the optimisable Param objects of this kernel.
        :return: For ns_type == "lebesgue_stieltjes": the unpacked entries of ls,
                 followed by those of meanshift when it is truthy.
                 For any other ns_type: [ls].
        """
        if self.ns_type == "lebesgue_stieltjes":
            params = [*self.ls]
            if self.meanshift:
                params += [*self.meanshift]
            return params
        # Every non-"lebesgue_stieltjes" ns_type uses a single lengthscale Param
        # (see sample_frequencies), so return it instead of falling through to
        # an implicit None.
        return [self.ls]

    def sample_frequencies(self):
        """
        Allows one to resample the internal spectral weights
        This would typically occur after an optimisation step
        The lengthscale can be optimized separately
        Assumes self.params order is known apriori
        I.e. the sampler and transformer should match each other
        :note: During optimisation, ensure that this method is called
               during the fitting process otherwise the parameters
               won't be updated!
        """

        if self.ns_type == "lebesgue_stieltjes":
            # NOTE(review): meanshift is indexed unconditionally here, although
            # get_params treats it as optional -- confirm whether None should
            # be supported in this mode.
            self.S = self.__sampler(sequence=self.sequence,
                                    kernel_type=self.kernel_type,
                                    ls=[self.ls[0].forward(), self.ls[1].forward()],
                                    meanshift=[self.meanshift[0].forward(), self.meanshift[1].forward()],
                                    ns_type=self.ns_type)
        else:
            self.S = self.__sampler(sequence=self.sequence,
                                    kernel_type=self.kernel_type,
                                    ls=self.ls.forward(),
                                    ns_type=self.ns_type)

    def transform(self, X, X_var=None):
        """
        Applies the instance's transformer function to some data X using the
        currently sampled spectral weights self.S
        :param X:       The input data we want to transform
        :param X_var:   The covariance of the X
                        NOTE: For this paper, we are dealing only with diagonal covariance
        :return: The fourier represented feature map for this particular kernel
        """
        return self.__transformer(X, X_var=X_var, S=self.S)
