import numpy as np
# from mne.viz import plot_sparse_source_estimates
# import matplotlib.pyplot as plt

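# Run this script from the conferences/NeuRIPS2020/sup_mat directory
# (presumably so that the local hd_inference_function module is importable).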

from sklearn.cluster import FeatureAgglomeration
from sklearn.preprocessing import StandardScaler

from hd_inference_function import zscore_from_sf_and_cdf
from hd_inference_function import group_reid
from hd_inference_function import desparsified_group_lasso
from hd_inference_function import ensemble_clustered_inference


def apply_solver(solver, evoked, forward, noise_cov, evoked_baseline=None,
                 loose=0.2, depth=0.8, reapply_source_weighting=True,
                 rank=None, pca=False, **kwargs):

    """Call a custom solver on evoked data.

    This function does all the necessary computation:

    - to select the channels in the forward given the available ones in
      the data
    - to take into account the noise covariance and do the spatial whitening
    - to apply loose orientation constraint as MNE solvers
    - to apply a weighting of the columns of the forward operator as in the
      weighted Minimum Norm formulation in order to limit the problem
      of depth bias.

    Parameters
    ----------
    solver : callable
        Called as ``solver(M, G, n_orient, **kwargs)``. M is the whitened
        data restricted to the channels of the forward, G is the gain matrix
        returned by MNE's ``_prepare_gain``, and n_orient is the number of
        dipole orientations per location (1 or 3). The solver must return
        3 variables:

        - X, the values of the active dipoles;
        - active_set, a boolean mask specifying which dipoles are present
          in X;
        - X_full, the values for all dipoles.
    evoked : instance of mne.Evoked
        The evoked data.
    forward : instance of Forward
        The forward solution.
    noise_cov : instance of Covariance
        The noise covariance.
    evoked_baseline : instance of mne.Evoked | None
        Optional baseline data. When given, its data is restricted to the
        same channels and whitened with the same whitener as ``evoked``. It
        is then passed to the solver as the ``M_baseline`` keyword argument.
    loose : float in [0, 1] | 'auto'
        Value that weights the source variances of the dipole components
        that are parallel (tangential) to the cortical surface. If loose
        is 0 then the solution is computed with fixed orientation.
        If loose is 1, it corresponds to free orientations.
        With 'auto', MNE uses 0.2 for surface-oriented source spaces and
        1.0 for volume or discrete source spaces. This function's default
        is 0.2.
    depth : None | float in [0, 1]
        Depth weighting coefficients. If None, no depth weighting is performed.
    reapply_source_weighting : bool
        If True, the source weighting computed by ``_prepare_gain`` is
        re-applied to the solver outputs with MNE's
        ``_reapply_source_weighting``.
    rank : None | dict | str
        Forwarded to ``_prepare_gain``.
    pca : bool
        Forwarded to ``_prepare_gain``.
    **kwargs
        Extra keyword arguments forwarded to ``solver``.

    Returns
    -------
    stc : instance of SourceEstimate
        The source estimates, restricted to the solver's active set.
    stc_full : instance of SourceEstimate
        The source estimates built from ``X_full`` over all sources.
    """
    # Import the necessary private functions
    from mne.inverse_sparse.mxne_inverse import \
        (_prepare_gain, is_fixed_orient,
         _reapply_source_weighting, _make_sparse_stc)

    all_ch_names = evoked.ch_names

    # Handle depth weighting and whitening (no extra source weights are used
    # here: weights=None).
    forward, gain, gain_info, whitener, source_weighting, mask = _prepare_gain(
        forward, evoked.info, noise_cov, pca=pca, depth=depth,
        loose=loose, weights=None, weights_min=None, rank=rank)

    # Keep only the data channels that are present in the gain matrix, in
    # the gain's channel order.
    sel = [all_ch_names.index(name) for name in gain_info['ch_names']]

    M = np.dot(whitener, evoked.data[sel])

    if evoked_baseline is not None:
        # The baseline must go through exactly the same channel selection
        # and whitening as the data.
        kwargs['M_baseline'] = np.dot(whitener, evoked_baseline.data[sel])

    n_orient = 1 if is_fixed_orient(forward) else 3

    X, active_set, X_full = solver(M, gain, n_orient, **kwargs)
    # X_full covers every source, so its "active set" is all True (same
    # dtype as the solver's mask).
    active_set_full = np.ones_like(active_set)

    # The solvers in this file return 1-D arrays. atleast_2d(...).T turns a
    # 1-D result into a single-column (n, 1) array, as _make_sparse_stc
    # expects rows = sources and columns = time points.
    X = np.atleast_2d(X).T
    X_full = np.atleast_2d(X_full).T

    if reapply_source_weighting:
        X = _reapply_source_weighting(X, source_weighting, active_set)
        X_full = _reapply_source_weighting(X_full, source_weighting,
                                           active_set_full)

    tmin = evoked.times[0]
    tstep = 1. / evoked.info['sfreq']
    stc = _make_sparse_stc(X, active_set, forward, tmin=tmin, tstep=tstep)
    stc_full = _make_sparse_stc(X_full, active_set_full, forward,
                                tmin=tmin, tstep=tstep)
    return stc, stc_full


def clustered_group_inference_solver(M, G, n_orient, connectivity,
                                     M_baseline=None,
                                     step=1,
                                     ward=None,
                                     n_clusters=1000,
                                     noise_method='simple',
                                     order=None,
                                     n_jobs=1,
                                     fwer_target=1,
                                     **kwargs):
    """Clustered desparsified group Lasso solver, usable with ``apply_solver``.

    The columns of the standardized gain matrix are grouped with a Ward
    ``FeatureAgglomeration``. ``desparsified_group_lasso`` is run on the
    cluster-level design. Its outputs are mapped back to the original
    columns with ``ward.inverse_transform``.

    Parameters
    ----------
    M : ndarray, 2-D
        Data matrix (rows indexed like the rows of G, columns subsampled
        with ``step``).
    G : ndarray, 2-D
        Gain matrix. Its columns are the features that get clustered.
    n_orient : int
        Must be 1. Any other value raises ValueError.
    connectivity : array-like | sparse matrix | None
        Passed to ``FeatureAgglomeration`` when ``ward`` is None.
    M_baseline : ndarray | None
        Baseline data with the same shape as M. When given, it is used with
        ``group_reid`` to estimate the noise covariance passed as ``Sigma``.
    step : int
        Keep every ``step``-th column of M (and M_baseline).
    ward : FeatureAgglomeration | None
        Clustering object. If None, a Ward agglomeration with ``n_clusters``
        clusters is built. It is (re)fitted on the standardized G either way.
    n_clusters : int
        Number of clusters when ``ward`` is None.
    noise_method, order
        Forwarded to ``group_reid`` (as ``method``/``order``) and to
        ``desparsified_group_lasso``.
    n_jobs : int
        Forwarded to ``group_reid`` and ``desparsified_group_lasso``.
    fwer_target : float
        A column is declared active when either of the ``*_corr`` outputs
        (presumably multiplicity-corrected one-sided p-values) is below
        ``fwer_target / 2``.
    **kwargs
        Forwarded to ``desparsified_group_lasso``. If it contains
        ``'celer'``, that value is also forwarded to ``group_reid``.

    Returns
    -------
    zscore_active_set : ndarray
        ``zscore`` restricted to ``active_set``.
    active_set : ndarray of bool
        Mask of active columns of G.
    zscore : ndarray
        Z-scores for every column of G.

    Raises
    ------
    ValueError
        If ``n_orient != 1`` or if ``M_baseline`` and ``M`` differ in shape.
    """
    if n_orient != 1:
        raise ValueError('n_orient must be equal to 1.')

    # Validate cheaply before the (expensive) clustering step; an explicit
    # raise is not stripped under ``python -O`` like an assert would be.
    if M_baseline is not None and M_baseline.shape != M.shape:
        raise ValueError('M_baseline must have the same shape as M, got '
                         f'{M_baseline.shape} and {M.shape}.')

    G = StandardScaler().fit_transform(G)
    # NOTE: subtracts one global mean (over all entries), not a per-row mean.
    M = M - np.mean(M)
    Y = np.copy(M[:, ::step])

    if ward is None:
        ward = FeatureAgglomeration(n_clusters=n_clusters,
                                    connectivity=connectivity,
                                    linkage='ward',
                                    memory=None)

    ward.fit(G)
    X = StandardScaler().fit_transform(ward.transform(G))

    if M_baseline is not None:
        Y_baseline = np.copy(M_baseline[:, ::step])
        Sigma_hat = group_reid(X, Y_baseline,
                               fit_Y=False,
                               return_Beta=False,
                               n_jobs=n_jobs,
                               method=noise_method,
                               order=order,
                               celer=kwargs.get('celer'))
    else:
        Sigma_hat = None

    sf, sf_corr, cdf, cdf_corr = \
        desparsified_group_lasso(X, Y, Sigma=Sigma_hat, n_jobs=n_jobs,
                                 memory='.', verbose=1,
                                 noise_method=noise_method,
                                 order=order, **kwargs)

    # Broadcast cluster-level statistics back to every original column.
    sf_full = ward.inverse_transform(sf)
    sf_corr_full = ward.inverse_transform(sf_corr)
    cdf_full = ward.inverse_transform(cdf)
    cdf_corr_full = ward.inverse_transform(cdf_corr)

    active_set = np.logical_or(sf_corr_full < fwer_target / 2,
                               cdf_corr_full < fwer_target / 2)
    zscore = zscore_from_sf_and_cdf(sf_full, cdf_full)
    zscore_active_set = zscore[active_set]

    return zscore_active_set, active_set, zscore


def ensemble_clustered_group_inference_solver(M, G, n_orient, connectivity,
                                              M_baseline=None,
                                              step=1,
                                              ward=None,
                                              n_clusters=1000,
                                              n_rand=100,
                                              n_jobs=1,
                                              fwer_target=1,
                                              **kwargs):
    """Ensemble of clustered desparsified group Lasso, usable with ``apply_solver``.

    Delegates to ``ensemble_clustered_inference`` with ``method='DGL'``.
    Optionally, it first estimates a noise covariance from baseline data
    with ``group_reid``.

    Parameters
    ----------
    M : ndarray, 2-D
        Data matrix (rows indexed like the rows of G, columns subsampled
        with ``step``).
    G : ndarray, 2-D
        Gain matrix. It is standardized before use.
    n_orient : int
        Must be 1. Any other value raises ValueError.
    connectivity : array-like | sparse matrix | None
        Passed to ``FeatureAgglomeration`` when ``ward`` is None.
    M_baseline : ndarray | None
        Baseline data with the same shape as M. When given, ``ward`` is
        fitted on the standardized G. ``group_reid`` then estimates a noise
        covariance on the compressed design, which is passed on as the
        ``Sigma_hat`` keyword argument.
    step : int
        Keep every ``step``-th column of M (and M_baseline).
    ward : FeatureAgglomeration | None
        Clustering object. If None, a Ward agglomeration with ``n_clusters``
        clusters is built.
    n_clusters, n_rand, n_jobs
        Forwarded to ``ensemble_clustered_inference``. ``n_jobs`` is also
        forwarded to ``group_reid``.
    fwer_target : float
        A column is declared active when either of the ``*_corr`` outputs
        (presumably multiplicity-corrected one-sided p-values) is below
        ``fwer_target / 2``.
    **kwargs
        Forwarded to ``ensemble_clustered_inference``. The keys ``'celer'``
        (default False), ``'noise_method'`` (default 'simple') and
        ``'order'`` (default 1) are also read for the ``group_reid`` call.

    Returns
    -------
    zscore_active_set : ndarray
        ``zscore`` restricted to ``active_set``.
    active_set : ndarray of bool
        Mask of active columns of G.
    zscore : ndarray
        Z-scores for every column of G.

    Raises
    ------
    ValueError
        If ``n_orient != 1`` or if ``M_baseline`` and ``M`` differ in shape.
    """
    if n_orient != 1:
        raise ValueError('n_orient must be equal to 1.')

    # Validate cheaply before any heavy computation; an explicit raise is
    # not stripped under ``python -O`` like an assert would be.
    if M_baseline is not None and M_baseline.shape != M.shape:
        raise ValueError('M_baseline must have the same shape as M, got '
                         f'{M_baseline.shape} and {M.shape}.')

    X = StandardScaler().fit_transform(G)
    # NOTE: subtracts one global mean (over all entries), not a per-row mean.
    M = M - np.mean(M)
    Y = np.copy(M[:, ::step])

    if ward is None:
        ward = FeatureAgglomeration(n_clusters=n_clusters,
                                    connectivity=connectivity,
                                    linkage='ward',
                                    memory=None)

    if M_baseline is not None:
        ward.fit(X)
        X_compressed = StandardScaler().fit_transform(ward.transform(X))
        Y_baseline = np.copy(M_baseline[:, ::step])
        # The estimate reaches ensemble_clustered_inference through kwargs.
        kwargs['Sigma_hat'] = group_reid(X_compressed, Y_baseline,
                                         fit_Y=False,
                                         return_Beta=False,
                                         celer=kwargs.get('celer', False),
                                         n_jobs=n_jobs,
                                         method=kwargs.get('noise_method',
                                                           'simple'),
                                         order=kwargs.get('order', 1))

    sf, sf_corr, cdf, cdf_corr = \
        ensemble_clustered_inference(X, Y, ward,
                                     n_clusters=n_clusters,
                                     method='DGL',
                                     n_rand=n_rand,
                                     n_jobs=n_jobs,
                                     memory='.',
                                     verbose=1,
                                     **kwargs)

    active_set = np.logical_or(sf_corr < fwer_target / 2,
                               cdf_corr < fwer_target / 2)
    zscore = zscore_from_sf_and_cdf(sf, cdf)
    zscore_active_set = zscore[active_set]

    return zscore_active_set, active_set, zscore
