import numpy as np
from typing import Callable, Union
import sys
sys.path.append('./')
from est.models.Model import Model
from est.fit.Minimizer import ScipyMinimizer


class EstimatedResult(object):
    def __init__(self,
                 params: np.ndarray,
                 log_like: float,
                 status: int,
                 success: bool,
                 message: str):
        """
        Container for the result of estimation
        :param params: array, the estimated (optimal) params
        :param log_like: float, the final log-likelihood value (at optimum)
        :param status: int, termination status reported by the minimizer. Its meaning depends on the
            optimization method; for scipy.optimize.minimize (see the scipy docs for the authoritative list):
            for 'trust-constr' method:
                0: The maximum number of function evaluations is exceeded.
                1: gtol termination condition is satisfied.
                2: xtol termination condition is satisfied.
                3: callback function requested termination.
            for 'BFGS' method:
                0: The optimization terminated successfully.
                1: The maximum number of iterations was exceeded.
                2: Desired error not necessarily achieved due to precision loss (e.g. line search failure).
                3: NaN result encountered.
        :param success: bool, whether the minimizer reports a successful termination
        :param message: str, the minimizer's description of the cause of termination
        """
        self.params = params
        self.log_like = log_like
        self.status = status
        self.success = success
        self.message = message

    @property
    def likelihood(self) -> float:
        """
        The likelihood with estimated params, i.e. exp(log_like).
        Note: for large |log_like| this overflows to inf or underflows to 0; prefer `log_like` for comparisons.
        """
        return np.exp(self.log_like)

    def __str__(self):
        """ String representation of the class (for pretty printing the results) """
        # The value printed is the LOG-likelihood; label it as such (the `likelihood` property is exp of it)
        return f'\nparams    | {self.params} \n' \
               f'log_like  | {self.log_like} '


class LikelihoodEstimator:
    def __init__(self,
                 sample: np.ndarray,
                 dt: Union[float, np.ndarray],
                 model: Model,
                 minimizer: Union[ScipyMinimizer, None] = None,
                 t0: Union[float, np.ndarray] = 0):
        """
        Base class for likelihood-based diffusion estimators.
        :param sample: np.ndarray, a univariate time series sample from the diffusion (ascending order of time).
            Array-likes (e.g. lists) are converted with np.asarray; singleton dimensions are squeezed out.
        :param dt: Union[float, np.ndarray], time step (time between diffusion steps)
            Either supply a constant dt for all time steps, or supply a set of dt's equal in length to the sample
        :param model: the diffusion model. This defines the parametric family/model,
            the parameters of which will be fitted during estimation
        :param minimizer: Minimizer, the minimizer that is used to maximize the likelihood function (by minimizing
            the negative log-likelihood). If None (the default), a new ScipyMinimizer is created for this estimator.
        :param t0: Union[float, np.ndarray], optional parameter, if you are working with a time-homogeneous model,
            then this doesn't matter.
        """
        self._sample = np.asarray(sample).squeeze()
        self._dt = dt
        self._model = model
        # Build the default per instance: a `ScipyMinimizer()` default in the signature would be evaluated once,
        # at class-definition time, and the same minimizer object would be shared by every estimator.
        self._minimizer = minimizer if minimizer is not None else ScipyMinimizer()
        self._t0 = t0
        self._min_prob = 1e-30  # used to floor probabilities when evaluating the log

    def estimate_params(self, params0: np.ndarray) -> EstimatedResult:
        """
        Main estimation function: fit the model params by maximizing the log-likelihood.
        :param params0: array, the initial guess params
        :return: EstimatedResult, holding the estimated params, the final log-likelihood, and the
            minimizer's termination status, success flag and message
        """
        # Maximizing the log-likelihood is done by minimizing its negative
        objective = self.log_likelihood_negative
        return self._estimate_params(params0=params0, likelihood=objective)


    def _estimate_params(self, params0: np.ndarray, likelihood: Callable) -> EstimatedResult:
        """
        Run the configured minimizer on an objective and package the outcome.
        :param params0: array, the initial guess params
        :param likelihood: Callable, the objective to minimize, i.e. the NEGATIVE log-likelihood as a
            function of the params (its minimum value is negated back into a log-likelihood below)
        :return: EstimatedResult, the estimated params, final log-likelihood and the minimizer's
            termination status, success flag and message
        """
        res = self._minimizer.minimize(function=likelihood, guess=params0)
        # The minimizer reports the minimum of the negative log-likelihood; flip the sign back
        return EstimatedResult(params=res.params,
                               log_like=-res.value,
                               status=res.status,
                               success=res.success,
                               message=res.message)