# -*- coding: utf-8 -*-
# cd conferences/NeuRIPS2020/sup_mat

import os
import os.path as op
import numpy as np
import mne
from mne.datasets import sample
from mne.datasets import somato
from mne.minimum_norm import make_inverse_operator
from mne.minimum_norm import apply_inverse

from mne.inverse_sparse.mxne_inverse import _prepare_gain, _make_sparse_stc

from apply_hd_inference_function import (
    apply_solver,
    clustered_group_inference_solver,
    ensemble_clustered_group_inference_solver)

# Last component of the current working directory.
cur_dir = os.getcwd().split('/')[-1]

# Passed as ``path`` to the MNE ``data_path`` helpers below.
MNE_DATA = None

# Inverse solver to apply. Values handled further down:
#   'clustered_group_inference', 'ensemble_clustered_group_inference',
#   'sLORETA', 'dSPM', 'MNE'.
solver_type = 'clustered_group_inference'

# Experimental condition: 'audio', 'visual' or 'somato'.
cond = 'audio'

# Runtime switches: plotting, saving, and number of parallel jobs.
plot, save, n_jobs = True, True, 1

# Load the evoked response, noise covariance and file locations for the
# selected condition. Paths are built with ``op.join`` throughout so they work
# whether ``data_path`` is returned as a ``str`` or a ``pathlib.Path``.
if cond == 'somato':
    data_path = somato.data_path(MNE_DATA)
    subject = '01'
    subjects_dir = op.join(data_path, 'derivatives', 'freesurfer', 'subjects')
    task = 'somato'
    raw_fname = op.join(data_path, 'sub-{}'.format(subject), 'meg',
                        'sub-{}_task-{}_meg.fif'.format(subject, task))
    fwd_fname = op.join(data_path, 'derivatives', 'sub-{}'.format(subject),
                        'sub-{}_task-{}-fwd.fif'.format(subject, task))
    condition = 'Unknown'

    # Epoch the raw recording around event id 1 and average into an evoked.
    raw = mne.io.read_raw_fif(raw_fname)
    events = mne.find_events(raw, stim_channel='STI 014')
    reject = dict(grad=4000e-13, eog=350e-6)
    picks = mne.pick_types(raw.info, meg=True, eog=True)

    event_id, tmin, tmax = 1, -.2, .25
    epochs = mne.Epochs(raw, events, event_id, tmin, tmax, picks=picks,
                        reject=reject, preload=True)
    evoked = epochs.average()
    # EOG channels were only needed for epoch rejection.
    evoked = evoked.pick_types(meg=True)

    # Noise covariance from the pre-stimulus part of the epochs.
    noise_cov = mne.compute_covariance(epochs, rank='info', tmax=0.)
    t_min, t_max = 0.03, 0.04
    # t_min, t_max = -0.06, -0.05
    t_step = 1.0 / 300
    rank = 'info'
    pca = True

else:
    subject = 'sample'
    data_path = sample.data_path(MNE_DATA)
    meg_dir = op.join(data_path, 'MEG', 'sample')
    fwd_fname = op.join(meg_dir, 'sample_audvis-meg-eeg-oct-6-fwd.fif')
    ave_fname = op.join(meg_dir, 'sample_audvis-ave.fif')
    raw_fname = op.join(meg_dir, 'sample_audvis_raw.fif')
    cov_fname = op.join(meg_dir, 'sample_audvis-shrunk-cov.fif')
    subjects_dir = op.join(data_path, 'subjects')
    # Precomputed noise covariance shipped with the sample dataset.
    noise_cov = mne.read_cov(cov_fname)
    # Pick the averaged response matching the requested condition.
    condition = 'Left visual' if cond == 'visual' else 'Left Auditory'
    evoked = mne.read_evokeds(ave_fname, condition=condition,
                              baseline=(None, 0))
    t_min, t_max = 0.05, 0.1
    t_step = 0.01
    rank = None
    pca = False

# Keep gradiometers only ('grad' is passed as the ``meg`` argument).
# TODO: evaluate whether adding magnetometers / EEG changes the results.
evoked.pick_types('grad')
if plot:
    # ``crop`` works in place: crop a copy so that the display window does
    # not also truncate the data used for source estimation further down.
    evoked.copy().crop(-0.05, 0.2).plot()
mne.viz.set_3d_backend("pyvista")

# Forward solution and the spatial adjacency of its source space (the
# adjacency is handed to the clustered solvers as ``connectivity``).
forward = mne.read_forward_solution(fwd_fname)
connectivity = mne.source_estimate.spatial_src_connectivity(forward['src'])

###############################################################################
###############################################################################
# Solver hyper-parameters
n_clusters = 1000
# Decimation factor giving a sampling period closest to ``t_step``.
# Rounding (instead of truncating) is robust to float error, and the floor of
# 1 keeps ``evoked.decimate`` valid when ``t_step`` is below one sample.
step = max(1, int(round(t_step * evoked.info['sfreq'])))
use_baseline = False
# Noise model options forwarded to the clustered solvers.
noise_method = 'AR'
order = 1
c = 0.005
# Ensemble-only option.
train_size = 0.1
celer = True
# Regularisation for ``apply_inverse`` (conventional 1 / SNR**2, SNR = 3).
lambda2 = 1. / 9
kwargs = {}
stat = 'F'
# Ensemble-only options.
n_rand = 100
aggregate = 'quantiles'
gamma_min = 0.25

# Select the solver and assemble its keyword arguments.
if solver_type in ('clustered_group_inference',
                   'ensemble_clustered_group_inference'):
    # Passed to ``apply_solver`` as loose=0 / depth=0.
    loose, depth = 0., 0.
    # Options shared by both clustered solvers. ``step`` is 1 presumably
    # because ``evoked`` is already decimated by ``step`` before solving.
    kwargs.update(
        connectivity=connectivity,
        n_clusters=n_clusters,
        reapply_source_weighting=False,
        celer=celer,
        n_jobs=n_jobs,
        c=c,
        step=1,
        noise_method=noise_method,
        order=order,
        stat=stat,
    )
    if solver_type == 'clustered_group_inference':
        solver_function = clustered_group_inference_solver
    else:
        solver_function = ensemble_clustered_group_inference_solver
        # Extra options specific to the ensemble variant.
        kwargs.update(
            train_size=train_size,
            aggregate=aggregate,
            gamma_min=gamma_min,
            n_rand=n_rand,
        )
elif solver_type in ('sLORETA', 'dSPM', 'MNE'):
    kwargs['method'] = solver_type
else:
    # Fail early instead of a later NameError on ``solver_function``.
    raise ValueError('Unknown solver_type: {!r}'.format(solver_type))

# Restrict the data to the analysis window, then decimate it in time.
evoked.crop(tmin=t_min, tmax=t_max)
evoked.decimate(step)

if solver_type in ['sLORETA', 'dSPM', 'MNE']:
    # Linear inverse solution restricted to the normal orientation.
    inv = make_inverse_operator(evoked.info, forward, noise_cov)
    stc_full = apply_inverse(evoked, inv, lambda2=lambda2,
                             method=solver_type, pick_ori='normal')
    # Collapse the analysis window into a single time-averaged map.
    stc_full = stc_full.mean()
    X_full_1D = np.copy(stc_full.data)
    # Sources whose time-averaged |value| exceeds 3 (hard-coded threshold)
    # form the support of the sparse estimate.
    active_set = (np.abs(X_full_1D) > 3).flatten()

    # Only the forward returned by _prepare_gain is used (it is passed to
    # _make_sparse_stc); presumably it is adapted to loose=0. — TODO confirm.
    forward = _prepare_gain(forward, evoked.info, noise_cov, pca=False,
                            depth=0., loose=0., weights=None,
                            weights_min=None, rank=None)[0]
    # 1-D vector of the retained source values.
    X = X_full_1D[active_set].flatten()

    stc = _make_sparse_stc(X, active_set, forward, stc_full.tmin,
                           tstep=stc_full.tstep)
    # Tag the estimate with the subject actually analysed (it used to be
    # hard-coded to 'sample', which is wrong for the somato condition).
    stc.subject = subject

else:
    stc, stc_full = apply_solver(solver_function, evoked, forward,
                                 noise_cov, loose=loose, depth=depth,
                                 rank=rank, pca=pca, **kwargs)

if plot:
    from scipy.stats import norm

    # Family-wise error rate targeted by the displayed threshold.
    fwer_target = 0.2
    # Bonferroni correction over the tested units: individual sources for
    # the per-source solvers, clusters for the clustered ones.
    if solver_type in ['original', 'hd_inference', 'sLORETA', 'dSPM', 'MNE',
                       'group_inference']:
        correction = 1. / stc_full.data.shape[0]
    else:
        correction = 1. / n_clusters
    # Two-sided z-score threshold at level fwer_target * correction.
    zscore_target = norm.isf((fwer_target / 2) * correction)

    # MNE requires non-decreasing colour-limit control points, so the upper
    # limit must never fall below the threshold.
    upper_lim = max(6., zscore_target)
    clim = dict(pos_lims=(zscore_target, zscore_target, upper_lim),
                kind='value')
    brain = stc.plot(subject=subject, hemi='both', subjects_dir=subjects_dir,
                     clim=clim)
