import numpy as np
from numpy.linalg import norm
from numpy.linalg import multi_dot
from joblib import Parallel, delayed
from sklearn.utils import resample
import scipy.stats as st
from scipy.linalg import inv, toeplitz, solve
from sklearn.preprocessing import StandardScaler
from sklearn.utils.validation import check_memory
from sklearn.linear_model import Lasso, LassoCV, MultiTaskLassoCV
from sklearn.model_selection import KFold
from celer import MultiTaskLassoCV as MTLCV_celer


def ensemble_clustered_inference(X_init, y, ward, n_clusters=1000,
                                 method='DGL', aggregate='quantiles',
                                 gamma_min=0.25, train_size=0.1,
                                 condition_mask=None, groups=None, seed=0,
                                 n_rand=100, predict=False, n_jobs=1,
                                 memory=None, verbose=0, **kwargs):
    """Ensemble of clustered inferences (ecd-MTL).

    Runs ``clustered_inference`` once per random state in
    ``range(seed, seed + n_rand)`` and aggregates the per-run outputs with
    ``aggregate_quantiles``.

    Parameters
    -----------
        X_init, y, ward, n_clusters, method, train_size, condition_mask,
        groups, memory, verbose, **kwargs :
            Forwarded unchanged to ``clustered_inference``.
        aggregate : string, optional
            Aggregation scheme; only 'quantiles' is supported.
        gamma_min : float, optional
            Passed to ``aggregate_quantiles``.
        seed : int, optional
            First random state; run ``i`` uses ``seed + i``.
        n_rand : int, optional
            Number of randomized runs.
        predict : bool, optional
            Must be False; prediction is not supported.
        n_jobs : int, optional
            Number of joblib workers used to distribute the runs.

    Returns
    -----------
        sf, sf_corr, cdf, cdf_corr : ndarray
            Aggregated survival / cumulative distribution function values,
            raw and corrected.

    Raises
    -----------
        ValueError
            If ``predict`` is true or ``aggregate`` is not 'quantiles'.
    """

    # Reject unsupported options up front: the runs below are expensive and
    # used to be executed in full before the arguments were checked.
    if predict:
        raise ValueError('Use predict = False')

    if aggregate != 'quantiles':
        raise ValueError('Use aggregate = "quantiles"')

    results = Parallel(n_jobs=n_jobs)(
        delayed(clustered_inference)(X_init, y, ward, n_clusters,
                                     method=method, train_size=train_size,
                                     condition_mask=condition_mask,
                                     groups=groups, rand=rand,
                                     predict=predict, memory=memory,
                                     verbose=verbose, **kwargs)
        for rand in np.arange(seed, seed + n_rand))

    if verbose > 0:
        print('Ensembling Step')

    # Axis 0 indexes the runs, axis 1 the four statistics returned by
    # clustered_inference (sf, sf_corr, cdf, cdf_corr).
    results = np.asarray(results)

    sf = aggregate_quantiles(results[:, 0, :], gamma_min)
    sf_corr = aggregate_quantiles(results[:, 1, :], gamma_min)
    cdf = aggregate_quantiles(results[:, 2, :], gamma_min)
    cdf_corr = aggregate_quantiles(results[:, 3, :], gamma_min)

    return sf, sf_corr, cdf, cdf_corr


def ward_clustering(X_init, ward, train_index, condition_mask=None):
    """Fit the clustering on the training rows and reduce the whole data.

    Parameters
    -----------
        X_init : ndarray, (n_samples, n_features)
            Data before dimension reduction.
        ward : object exposing ``fit`` and ``transform``
            Clustering / feature-agglomeration estimator.
        train_index : array of indices
            Rows of ``X_init`` used to fit ``ward``.
        condition_mask : index or boolean mask, optional
            When given, only these rows of the reduced data are returned.

    Returns
    -----------
        X : ndarray
            Reduced data (optionally restricted by ``condition_mask``).
        ward : object
            The value returned by ``ward.fit``.
    """
    fitted_ward = ward.fit(X_init[train_index, :])
    reduced = fitted_ward.transform(X_init)

    if condition_mask is not None:
        reduced = reduced[condition_mask]

    return np.asarray(reduced), fitted_ward


def clustered_inference(X_init, y, ward, n_clusters=1000, method='DGL',
                        train_size=1.0, condition_mask=None,
                        groups=None, rand=0, predict=False,
                        n_jobs=1, memory=None, verbose=0, **kwargs):
    """Clustered inference (cd-MTL).

    Subsamples the rows (or groups of rows) of ``X_init``, fits ``ward`` on
    that subsample, reduces the data, runs ``hd_inference`` on the reduced
    data and maps the resulting statistics back with
    ``ward.inverse_transform``.

    Parameters
    -----------
        X_init : ndarray, (n_trials, n_voxels)
            Data before clustering.
        y : ndarray
            Target; its global mean is removed before inference.
        ward : object exposing ``fit``, ``transform``, ``inverse_transform``
            Clustering estimator.
        n_clusters : int, optional
            Only reported in the progress message.
        method : string, optional
            Inference method forwarded to ``hd_inference``.
        train_size : float, optional
            Fraction of the rows (or of the groups) used to fit ``ward``.
        condition_mask : index or boolean mask, optional
            Rows of the reduced data kept for inference.
            NOTE(review): ``y`` must then match the retained rows -- confirm
            with callers.
        groups : array, optional
            Group label of each row; when given, whole groups are sampled.
        rand : int, optional
            Random state of the subsampling.
        predict : bool, optional
            Must be False; prediction is not supported.
        n_jobs, memory, **kwargs :
            Forwarded to ``hd_inference``; ``memory`` also caches the
            clustering step.
        verbose : int, optional
            Print the step names when positive.

    Returns
    -----------
        sf, sf_corr, cdf, cdf_corr :
            Outputs of ``hd_inference`` passed through
            ``ward.inverse_transform``.

    Raises
    -----------
        ValueError
            If ``predict`` is true.
    """

    # Fail before the costly clustering rather than after it.
    if predict:
        raise ValueError('Use predict = False')

    memory = check_memory(memory)

    print('n_clusters = ', n_clusters, 'method = ', method, 'rand = ', rand)
    n_trials, n_voxels = X_init.shape

    if verbose > 0:
        print('Clustering step')

    # Sampling
    if groups is None:

        train_index = resample(np.arange(n_trials),
                               n_samples=int(n_trials * train_size),
                               replace=False, random_state=rand)

    else:

        # Sample whole groups so that rows of one group stay together.
        unique_groups = np.unique(groups)
        n_groups = unique_groups.size
        train_group = resample(unique_groups,
                               n_samples=int(n_groups * train_size),
                               replace=False, random_state=rand)
        train_index = np.arange(n_trials)[np.isin(groups, train_group)]

    # The caller's mask is forwarded; it used to be silently replaced by
    # None, so the argument had no effect.
    X, ward = memory.cache(ward_clustering)(
        X_init, ward, train_index, condition_mask=condition_mask)

    if verbose > 0:
        print('Inference step')

    # Experiment
    X = StandardScaler().fit_transform(X)
    y = y - np.mean(y)

    sf, sf_corr, cdf, cdf_corr = hd_inference(X, y, method,
                                              n_jobs=n_jobs,
                                              memory=memory,
                                              **kwargs)

    # Map the cluster-level statistics back to the original feature space.
    sf_compressed = ward.inverse_transform(sf)
    sf_corr_compressed = ward.inverse_transform(sf_corr)
    cdf_compressed = ward.inverse_transform(cdf)
    cdf_corr_compressed = ward.inverse_transform(cdf_corr)

    return (sf_compressed, sf_corr_compressed, cdf_compressed,
            cdf_corr_compressed)


def hd_inference(X, y, method='DGL', predict=False, n_jobs=1, memory=None,
                 verbose=0, **kwargs):
    """High-dimensional inference dispatcher (d-MTL).

    Parameters
    -----------
        X : ndarray, (n_samples, n_features)
            Data.
        y : ndarray
            Target, handed to the inference routine as ``Y``.
        method : string, optional
            Inference method; only 'DGL' (desparsified group Lasso) is
            available.
        predict : bool, optional
            Unused by this function.
        n_jobs, memory, verbose, **kwargs :
            Forwarded to ``desparsified_group_lasso``.

    Returns
    -----------
        sf, sf_corr, cdf, cdf_corr :
            Outputs of ``desparsified_group_lasso``.

    Raises
    -----------
        ValueError
            If ``method`` is not 'DGL'.
    """

    if method != 'DGL':
        raise ValueError('Unknown method: {!r}'.format(method))

    sf, sf_corr, cdf, cdf_corr = \
        desparsified_group_lasso(X, y, n_jobs=n_jobs, memory=memory,
                                 verbose=verbose, **kwargs)

    return sf, sf_corr, cdf, cdf_corr


def desparsified_group_lasso(X, Y, Sigma=None, max_iter=5000,
                             tol=1e-3, stat='F', method="lasso", c=0.005,
                             celer=True, noise_method='AR', order=1,
                             n_jobs=1,  memory=None, verbose=0):
    """Desparsified MTLasso

    Parameters
    -----------
        X : ndarray, (n_samples, n_features)
            Data
        Y : ndarray, shape (n_samples, n_targets)
            Target
        Sigma : ndarray, optional
            Covariance of the noise. When given it replaces the estimate
            returned by ``group_reid`` and must have the same shape.
        max_iter : int, optional
            Maximum number of iterations of the nodewise lasso.
        tol : float, optional
            The tolerance for the optimization of the nodewise lasso.
        stat : string, optional
            Test statistic: 'F' or 'chi2'.
        method : string, optional
            The method for the nodewise lasso: "lasso", "lasso_cv"
        c : float, optional
            Only used if method="lasso". Then alpha = c * alpha_max.
        celer : bool, optional
            Forwarded to ``group_reid``.
        noise_method : string, optional
            Forwarded to ``group_reid`` as ``method``.
        order : int, optional
            Forwarded to ``group_reid``.
        n_jobs : int, optional
            Number of parallel jobs of the nodewise lasso and of
            ``group_reid``.
        memory : str or object with the joblib.Memory interface, default=None
            Used to cache the estimated residuals.
        verbose : int, optional
            Verbosity of the parallel nodewise lasso.

    Returns
    -----------
        sf, sf_corr, cdf, cdf_corr : ndarray, shape (n_features,)
            Survival / cumulative distribution function values, raw and
            corrected, as built by ``sf_and_cdf_from_pval_and_sign``.

    Raises
    -----------
        ValueError
            If ``stat`` is not 'F' or 'chi2', or if ``Sigma`` does not have
            the shape of the estimated covariance.
        """

    # An unknown statistic used to surface only at the very end, as a
    # NameError on ``pval``.
    if stat not in ('chi2', 'F'):
        raise ValueError(
            "Unknown stat {!r}: use 'F' or 'chi2'".format(stat))

    X = np.asarray(X)

    n_samples, n_features = X.shape
    n_targets = Y.shape[1]

    memory = check_memory(memory)

    if method == "lasso":

        Gram = np.dot(X.T, X)

        # One penalty per feature: c / n_samples times its largest absolute
        # off-diagonal Gram entry.
        k = c * (1. / n_samples)
        alphas = k * np.max(np.abs(Gram - np.diag(np.diag(Gram))), axis=0)

    else:

        Gram = None
        alphas = None

    # Calculating Omega Matrix
    Z, omega_diag = memory.cache(compute_all_residuals, ignore=['n_jobs'])(
        X, alphas, Gram=Gram, max_iter=max_iter,
        tol=tol, method=method, c=c, n_jobs=n_jobs,
        verbose=verbose)

    # Lasso regression
    Sigma_hat, Beta_mtl = group_reid(X, Y, return_Beta=True, celer=celer,
                                     method=noise_method, order=order,
                                     n_jobs=n_jobs)

    if Sigma is not None:
        if Sigma.shape == Sigma_hat.shape:
            Sigma_hat = Sigma
        else:
            raise ValueError('Sigma has not the good shape')

    Theta_hat = n_samples * inv(Sigma_hat)

    # Estimating the coefficient vector
    xz = np.sum(X * Z, axis=0)
    Beta_bias = Y.T.dot(Z) / xz

    Beta_mtl = Beta_mtl.T
    Beta_bias = Beta_bias.T

    P = ((Z.T.dot(X)).T / xz).T
    P_nodiag = P - np.diag(np.diag(P))

    Beta_hat = Beta_bias - P_nodiag.dot(Beta_mtl)

    # Row-wise quadratic form Beta_i Theta Beta_i^T, i.e. the diagonal of
    # Beta Theta Beta^T, without building the n_features x n_features matrix.
    quad_form = np.sum(Beta_hat.dot(Theta_hat) * Beta_hat, axis=1)

    if stat == 'chi2':
        chi2_scores = quad_form / omega_diag
        pval = np.minimum(st.chi2.sf(chi2_scores, df=n_targets) * 2, 1.0)
    else:
        f_scores = quad_form / omega_diag / n_targets
        pval = np.minimum(st.f.sf(f_scores, dfd=n_samples, dfn=n_targets) * 2,
                          1.0)

    sign_Beta = np.sign(np.sum(Beta_hat, axis=1))
    sf, sf_corr, cdf, cdf_corr = sf_and_cdf_from_pval_and_sign(pval, sign_Beta)

    return sf, sf_corr, cdf, cdf_corr


def compute_all_residuals(X, alphas, Gram=None, max_iter=5000, tol=1e-3,
                          method='lasso', c=0.005, n_jobs=1, verbose=0):
    """Run ``compute_residuals`` for every column of ``X`` in parallel.

    Parameters
    -----------
        X : ndarray, (n_samples, n_features)
            Data
        alphas : ndarray, shape (n_features,) or None
            Penalty of each nodewise lasso. With None, each call receives
            ``alpha=None`` and picks its own value.
        Gram : ndarray, (n_features, n_features), optional
            Precomputed ``X.T.dot(X)``.
        max_iter, tol, method, c :
            Forwarded to ``compute_residuals``.
        n_jobs : int, optional
            Number of joblib workers.
        verbose : int, optional
            Verbosity of joblib.

    Returns
    -----------
        Z : ndarray, (n_samples, n_features)
            Column ``i`` is the residual of the nodewise regression of
            column ``i`` of ``X``.
        omega_diag : ndarray, (n_features,)
            Second value returned by ``compute_residuals`` for each column.
    """

    n_samples, n_features = X.shape

    results = \
        Parallel(n_jobs=n_jobs, verbose=verbose)(
            delayed(compute_residuals)
                (X=X,
                 column_index=i,
                 # alphas is None on the 'lasso_cv' path; indexing it there
                 # used to raise a TypeError.
                 alpha=None if alphas is None else alphas[i],
                 Gram=Gram,
                 max_iter=max_iter,
                 tol=tol,
                 method=method,
                 c=c)
            for i in range(n_features))

    if not results:
        return np.zeros((n_samples, 0)), np.zeros(0)

    # Each result is a (vector, scalar) pair. Packing the pairs into one
    # array is a ragged conversion, which recent NumPy rejects, so the two
    # components are separated first.
    residuals, omegas = zip(*results)
    Z = np.stack(residuals, axis=1)
    omega_diag = np.asarray(omegas)

    return Z, omega_diag


def compute_residuals(X, column_index, alpha=None, Gram=None, max_iter=5000,
                      tol=1e-3, method='lasso', c=0.005):
    """Nodewise lasso: regress one column of ``X`` on all the others.

    Parameters
    -----------
        X : ndarray, (n_samples, n_features)
            Data
        column_index : int
            Index of the column used as the regression target.
        alpha : float, optional
            Lasso penalty (method="lasso" only). When None it is set to
            ``c / n_samples`` times the largest absolute inner product
            between the target column and the other columns.
        Gram : ndarray, (n_features, n_features), optional
            Precomputed ``X.T.dot(X)`` (method="lasso" only); recomputed on
            the remaining columns when None.
        max_iter : int, optional
            Maximum number of iterations of the lasso solver.
        tol : float, optional
            Tolerance of the lasso solver.
        method : string, optional
            "lasso" or "lasso_cv" (3-fold cross-validated penalty).
        c : float, optional
            Only used if method="lasso" and ``alpha`` is None.

    Returns
    -----------
        z : ndarray, (n_samples,)
            Residual of the regression.
        omega_diag_i : float
            ``n_samples * sum(z ** 2) / dot(y, z) ** 2`` with ``y`` the
            target column.

    Raises
    -----------
        ValueError
            If ``method`` is neither "lasso" nor "lasso_cv".
    """

    if method not in ('lasso', 'lasso_cv'):
        raise ValueError(
            "Unknown method {!r}: use 'lasso' or 'lasso_cv'".format(method))

    n_samples = X.shape[0]
    i = column_index

    X_new = np.delete(X, i, axis=1)
    y = np.copy(X[:, i])

    if method == 'lasso':

        if Gram is None:
            Gram_loc = np.dot(X_new.T, X_new)
        else:
            # Drop row and column i so the Gram matrix matches X_new.
            Gram_loc = np.delete(np.delete(Gram, i, axis=0), i, axis=1)

        if alpha is None:
            k = c * (1. / n_samples)
            # Correlate each remaining column with the target: X_new.T,
            # not X_new, whose shape does not match y.
            alpha = k * np.max(np.abs(np.dot(X_new.T, y)))

        clf_lasso_loc = Lasso(alpha=alpha, precompute=Gram_loc,
                              max_iter=max_iter, tol=tol)

    else:
        clf_lasso_loc = LassoCV(max_iter=max_iter, tol=tol, cv=3)

    clf_lasso_loc.fit(X_new, y)
    z = y - clf_lasso_loc.predict(X_new)

    omega_diag_i = n_samples * np.sum(z ** 2) / np.dot(y, z) ** 2

    return z, omega_diag_i


def group_reid(X, Y, eps=1e-2, fit_Y=True, return_Beta=False,
               celer=True, adjust_celer_tol=True,
               tol=1e-4, max_iter=1e+4, max_iter_celer=50,
               max_epochs_celer=5000, stationary=True,
               method='simple', order=1,
               n_jobs=1, random_state=0):
    """Estimation of covariance matrix

    Fits a cross-validated multi-task Lasso of ``Y`` on ``X`` (unless
    ``fit_Y`` is false) and estimates the covariance of the residuals
    across targets, using a Toeplitz correlation structure.

    Parameters
    -----------
        X : ndarray, (n_samples, n_features)
            Data
        Y : ndarray, shape (n_samples, n_targets)
            Target
        eps : float, optional
            ``eps`` of the cross-validated multi-task Lasso.
        fit_Y : bool, optional
            If False no model is fitted and ``Y`` itself is the residual.
        return_Beta : bool, optional
            If True also return the multi-task Lasso coefficients.
        celer : bool, optional
            Use the celer solver instead of the scikit-learn one.
        adjust_celer_tol : bool, optional
            With celer, multiply ``tol`` by 1e-2.
        tol : float, optional
            Solver tolerance; also the relative threshold used to count the
            support size.
        max_iter : int, optional
            Maximum number of iterations of the scikit-learn solver; raised
            to ``5 * n_features`` when smaller.
        max_iter_celer, max_epochs_celer : int, optional
            ``max_iter`` and ``max_epochs`` of the celer solver.
        stationary : bool, optional
            If True a single noise level (median over targets) is used.
        method : string, optional
            Correlation model when ``stationary`` is true: 'simple' or 'AR'.
            Ignored otherwise ('simple' is always used).
        order : int, optional
            Order of the AR model (method='AR').
        n_jobs : int, optional
            Number of jobs of the cross-validation.
        random_state : int, optional
            Seed of the shuffled 5-fold cross-validation.

    Returns
    -----------
        Cov_hat : ndarray, (n_targets, n_targets)
            Estimated covariance matrix.
        Beta_hat : ndarray
            Only if ``return_Beta`` is true: fitted coefficients.

    Raises
    -----------
        ValueError
            If ``stationary`` is true and ``method`` is neither 'simple' nor
            'AR', or if ``order`` is too large for the AR model.
        """

    # With an unknown method neither covariance branch below would run and
    # the function used to end on a NameError.
    if stationary and method not in ('simple', 'AR'):
        raise ValueError(
            "Unknown method {!r}: use 'simple' or 'AR'".format(method))

    X = np.asarray(X)
    n_samples, n_features = X.shape
    n_targets = Y.shape[1]

    print('Group reid: ' + method + str(order))

    if (max_iter // 5) <= n_features:
        max_iter = n_features * 5

    # The default (1e+4) is a float; hand the solver an integer.
    max_iter = int(max_iter)

    cv = KFold(n_splits=5, shuffle=True, random_state=random_state)

    if fit_Y:
        if celer:
            if adjust_celer_tol:
                tol = tol * 1e-2

            clf_mtlcv = \
                MTLCV_celer(eps=eps, normalize=False, fit_intercept=False,
                            cv=cv, tol=tol, max_iter=max_iter_celer,
                            max_epochs=max_epochs_celer, n_jobs=n_jobs,
                            verbose=0)
        else:
            clf_mtlcv = \
                MultiTaskLassoCV(eps=eps, normalize=False, fit_intercept=False,
                                 cv=cv, tol=tol, max_iter=max_iter,
                                 n_jobs=n_jobs)

        clf_mtlcv.fit(X, Y)
        Beta_hat = clf_mtlcv.coef_
        Error = clf_mtlcv.predict(X) - Y
        coef_max = np.max(np.abs(Beta_hat))

        if coef_max == 0:
            support = 0
        else:
            # Number of non-negligible coefficients, averaged over the first
            # axis of coef_.
            support = int(np.sum(np.abs(Beta_hat) > tol * coef_max) /
                          clf_mtlcv.coef_.shape[0])

    else:
        # NOTE(review): this is (n_features, n_targets) whereas the fitted
        # branch returns ``coef_`` -- confirm which orientation callers
        # expect.
        Beta_hat = np.zeros((n_features, n_targets))
        Error = np.copy(Y)
        support = 0

    # Per-target noise level, with the support size removed from the
    # degrees of freedom.
    sigma_hat_raw = \
        np.sqrt((1. / (n_samples - support)) * norm(Error, axis=0) ** 2)

    if stationary:
        sigma_hat = np.median(sigma_hat_raw) * np.ones(n_targets)
        Corr_emp = np.corrcoef(Error.T)

    else:
        sigma_hat = sigma_hat_raw
        Error_resc = Error / sigma_hat
        Corr_emp = np.corrcoef(Error_resc.T)

    if not stationary or method == 'simple':

        # Geometric decay of the correlation with the lag, driven by the
        # median lag-1 empirical correlation.
        rho_hat = np.median(np.diag(Corr_emp, 1))
        Corr_hat = \
            toeplitz(np.geomspace(1, rho_hat ** (n_targets - 1), n_targets))
        Cov_hat = np.outer(sigma_hat, sigma_hat) * Corr_hat

    if stationary and method == 'AR':

        if order > n_targets - 1:
            raise ValueError('Impossible to estimate AR model')

        # Median empirical correlation at lags 0..order.
        rho_ar = np.zeros(order + 1)
        rho_ar[0] = 1

        for i in range(1, order + 1):
            rho_ar[i] = np.median(np.diag(Corr_emp, i))

        # AR coefficients from the linear system built on those
        # correlations.
        A = toeplitz(rho_ar[:-1])
        coef_ar = solve(A, rho_ar[1:])

        # One-step AR prediction of the residuals from their `order`
        # previous targets.
        Error_estimate = np.zeros((n_samples, n_targets - order))

        for i in range(order):
            Error_estimate += coef_ar[i] * Error[:, order-i-1:-i-1]

        epsilon = Error[:, order:] - Error_estimate
        sigma_epsilon = \
            np.median(np.sqrt(norm(epsilon, axis=0) ** 2 / n_samples))

        # Extend the correlations to every lag with the AR recursion.
        rho_ar_completed = np.zeros(n_targets)
        rho_ar_completed[:order+1] = rho_ar

        for i in range(order + 1, n_targets):
            rho_ar_completed[i] = \
                np.dot(coef_ar[::-1], rho_ar_completed[i-order:i])

        Corr_hat = toeplitz(rho_ar_completed)
        sigma_hat[:] = \
            np.sqrt(sigma_epsilon ** 2 / (1 - np.dot(coef_ar, rho_ar[1:])))
        Cov_hat = np.outer(sigma_hat, sigma_hat) * Corr_hat

    if return_Beta:
        return Cov_hat, Beta_hat

    return Cov_hat


def aggregate_quantiles(list_sf, gamma_min=0.25):
    """Aggregation of survival function values with the Meinshausen procedure

    Parameters
    -----------
        list_sf : ndarray, (n_iter, n_features)
            List of survival function values
        gamma_min : float, optional
            Lowest quantile considered; the first
            ``max(1, floor(gamma_min * n_iter))`` order statistics are
            skipped.

    Returns
    -----------
        sf : ndarray, (n_features,)
            Aggregated survival function values.

    Raises
    -----------
        ValueError
            If ``n_iter`` is too small to leave any order statistic to
            aggregate.
    """

    n_iter, n_features = list_sf.shape

    m = n_iter + 1
    k = max(1, int(np.floor(gamma_min * n_iter)))
    r = 1 - np.log(gamma_min)

    if k >= n_iter:
        raise ValueError(
            'Not enough iterations to aggregate: n_iter = {} with '
            'gamma_min = {}'.format(n_iter, gamma_min))

    # 1-based rank of each retained order statistic, as a column so that it
    # broadcasts over the features.
    ranks = np.arange(k, n_iter)[:, np.newaxis] + 1

    asc_sf = np.sort(list_sf, axis=0)
    dsc_sf = asc_sf[::-1]

    # All features are handled at once instead of one Python loop iteration
    # (and one list) per feature.
    sf_neg = np.minimum(0.5, np.min(asc_sf[k:] * m / ranks, axis=0))
    sf_pos = np.maximum(
        0.5, np.max(1 - (1 - dsc_sf[k:]) * m / ranks, axis=0))

    # Keep whichever side is the most extreme, inflated by r and clipped so
    # that it stays on its side of 0.5.
    sf = np.where((1 - sf_pos) < sf_neg,
                  np.maximum(0.5, 1 - (1 - sf_pos) * r),
                  np.minimum(0.5, sf_neg * r))

    return sf


def sf_and_cdf_from_pval_and_sign(pval, sign):
    """Survival and cumulative distribution function values, raw and
    corrected, from p-values and signs.

    Returns the tuple ``(sf, sf_corr, cdf, cdf_corr)`` built with
    ``sf_from_pval_and_sign``, ``cdf_from_pval_and_sign``,
    ``sf_corr_from_sf`` and ``cdf_corr_from_cdf``.
    """
    raw_sf = sf_from_pval_and_sign(pval, sign)
    raw_cdf = cdf_from_pval_and_sign(pval, sign)

    return (raw_sf, sf_corr_from_sf(raw_sf),
            raw_cdf, cdf_corr_from_cdf(raw_cdf))


def sf_from_pval_and_sign(pval, sign, eps=1e-14):
    """Survival function values from two-sided p-values and signs.

    Entries with a positive sign get ``pval / 2``, entries with a negative
    sign get ``1 - pval / 2`` and the others stay at 0.5. Values are capped
    at ``1 - eps``.
    """
    positive = sign > 0
    negative = sign < 0
    upper_bound = 1 - eps

    values = np.full(pval.size, 0.5)
    values[positive] = pval[positive] / 2
    values[negative] = 1 - pval[negative] / 2
    values[values > upper_bound] = upper_bound

    return values


def cdf_from_pval_and_sign(pval, sign, eps=1e-14):
    """Cumulative distribution function values from two-sided p-values and
    signs.

    Entries with a positive sign get ``1 - pval / 2``, entries with a
    negative sign get ``pval / 2`` and the others stay at 0.5. Values are
    capped at ``1 - eps``.
    """
    is_pos = sign > 0
    is_neg = sign < 0
    cap = 1 - eps

    out = np.full(pval.size, 0.5)
    out[is_pos] = 1 - pval[is_pos] / 2
    out[is_neg] = pval[is_neg] / 2
    out[out > cap] = cap

    return out


def sf_corr_from_sf(sf):
    """Correct survival function values by the number of features.

    The distance of each value to its nearest bound (0 below 0.5, 1 above
    0.5) is multiplied by ``sf.size``; the result never crosses 0.5.
    """
    n_tests = sf.size
    below = sf < 0.5
    above = sf > 0.5

    corrected = np.full(n_tests, 0.5)
    corrected[below] = np.minimum(0.5, sf[below] * n_tests)
    corrected[above] = np.maximum(0.5, 1 - (1 - sf[above]) * n_tests)

    return corrected


def cdf_corr_from_cdf(cdf):
    """Correct cumulative distribution function values by the number of
    features.

    The distance of each value to its nearest bound (0 below 0.5, 1 above
    0.5) is multiplied by ``cdf.size``; the result never crosses 0.5.
    """
    count = cdf.size
    low_side = cdf < 0.5
    high_side = cdf > 0.5

    adjusted = np.full(count, 0.5)
    adjusted[low_side] = np.minimum(0.5, cdf[low_side] * count)
    adjusted[high_side] = np.maximum(0.5, 1 - (1 - cdf[high_side]) * count)

    return adjusted


def zscore_from_sf(sf, distrib='Norm'):
    """z-scores from survival function values

    Parameters
    -----------
        sf : float or ndarray
            Survival function values
        distrib : string, optional
            Reference distribution; only 'Norm' (standard normal) is
            available.

    Returns
    -----------
        t_stat : float or ndarray
            Inverse survival function of the standard normal at ``sf``.

    Raises
    -----------
        ValueError
            If ``distrib`` is not 'Norm'.
    """
    if distrib != 'Norm':
        raise ValueError(
            "Unsupported distrib {!r}: only 'Norm' is available"
            .format(distrib))

    return st.norm.isf(sf)


def zscore_from_cdf(cdf, distrib='Norm'):
    """z-scores from cumulative distribution function values

    Parameters
    -----------
        cdf : float or ndarray
            Cumulative distribution function values
        distrib : string, optional
            Reference distribution; only 'Norm' (standard normal) is
            available.

    Returns
    -----------
        t_stat : float or ndarray
            Percent point function of the standard normal at ``cdf``.

    Raises
    -----------
        ValueError
            If ``distrib`` is not 'Norm'.
    """
    if distrib != 'Norm':
        raise ValueError(
            "Unsupported distrib {!r}: only 'Norm' is available"
            .format(distrib))

    return st.norm.ppf(cdf)


def zscore_from_sf_and_cdf(sf, cdf, distrib='Norm'):
    """z-scores from survival function and cumulative distribution function
     values

    Parameters
    -----------
        sf : ndarray
            Survival function values
        cdf : ndarray
            Cumulative distribution function values
        distrib : string, optional
            Reference distribution; only 'Norm' is available.

    Returns
    -----------
        t_stat : ndarray, shape (sf.size,)
            z-scores; infinite values are replaced by ``replace_infinity``
            with ``replace_val=40`` and ``method='plus-one'``.

    Raises
    -----------
        ValueError
            If ``distrib`` is not 'Norm'.
    """
    if distrib != 'Norm':
        raise ValueError(
            "Unsupported distrib {!r}: only 'Norm' is available"
            .format(distrib))

    t_stat_sf = zscore_from_sf(sf, distrib=distrib)
    t_stat_cdf = zscore_from_cdf(cdf, distrib=distrib)

    # The sf-based score is used where sf < 0.5 and the cdf-based one where
    # sf > 0.5 (presumably so that each tail is computed from the small
    # probability); entries with sf == 0.5 stay at 0.
    t_stat = np.zeros(sf.size)
    t_stat[sf < 0.5] = t_stat_sf[sf < 0.5]
    t_stat[sf > 0.5] = t_stat_cdf[sf > 0.5]

    t_stat = replace_infinity(t_stat, replace_val=40, method='plus-one')

    return t_stat


def replace_infinity(x, replace_val=None, method='times-two'):
    """Replace infinite entries of ``x`` by a finite value of the same sign.

    Parameters
    -----------
        x : ndarray
            Input values.
        replace_val : float, optional
            Magnitude substituted for the infinite entries. It is ignored,
            and the fallback below is used instead, when it is smaller than
            the largest finite magnitude of ``x``.
        method : string, optional
            Fallback used when ``replace_val`` is None or too small:
            'times-two' (twice the largest finite magnitude) or 'plus-one'
            (the largest finite magnitude plus one).

    Returns
    -----------
        x_new : ndarray
            Copy of ``x`` where ``+inf`` became ``replace_val`` and ``-inf``
            became ``-replace_val``. NaN entries are left untouched.

    Raises
    -----------
        ValueError
            If the fallback is needed and ``method`` is unknown, or if ``x``
            has no finite entry and ``replace_val`` is None.
    """

    # isfinite also discards NaN, which would otherwise turn the maximum
    # (and so the replacement value) into NaN.
    finite_abs = np.abs(x)[np.isfinite(x)]

    if finite_abs.size == 0:
        # Nothing to compare against: an explicit replace_val is used as is.
        if replace_val is None:
            raise ValueError(
                'x has no finite value: replace_val must be given')

    else:
        largest_non_inf = np.max(finite_abs)

        if replace_val is None or replace_val < largest_non_inf:
            if method == 'times-two':
                replace_val = largest_non_inf * 2
            elif method == 'plus-one':
                replace_val = largest_non_inf + 1
            else:
                raise ValueError(
                    "Unknown method {!r}: use 'times-two' or 'plus-one'"
                    .format(method))

    x_new = np.copy(x)
    x_new[x_new == np.inf] = replace_val
    x_new[x_new == -np.inf] = -replace_val

    return x_new
