"""
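This script runs, for each experimental setting, the solver along a grid of
regularization parameters and saves the resulting support masks and dense
estimates, which are later used to build ROC curves (and compute their AUC).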
"""

import socket
import os
import numpy as np
from joblib import Parallel, delayed
from joblib import parallel_backend
from itertools import product
import time

from sgcl.solvers import solver, get_path
from sgcl.utils import get_alpha_max, get_sigma_min
from data.artificial import get_data_me
from expes.utils import check_and_create_dirs, get_path_expe
from data.semi_real import get_semi_real_data


if __name__ == '__main__':
    print("enter main", flush=True)

    # Select the parameter module with --expe: the script imports
    # ``expes.expe3.params_<expe>``. For instance you can run:
    #     $ python <this_script>.py --expe expe3D
    import argparse
    import importlib
    # ``description=`` must be passed by keyword: the first positional
    # argument of ArgumentParser is ``prog`` (the program name in usage).
    parser = argparse.ArgumentParser(
        description='Main script for experiments on synthetic data')
    parser.add_argument('--expe', type=str, default='expe3D',
                        help='Choose the parameters for the experiment.')
    args = parser.parse_args()
    expe = importlib.import_module("expes.expe3.params_{}".format(args.expe))

    # parameters of the problem (data generation)
    n_channels = expe.n_channels
    n_times = expe.n_times
    n_sources = expe.n_sources
    dictionary_type = expe.dictionary_type
    noise_type = expe.noise_type
    n_active = expe.n_active
    list_rho_X = expe.list_rho_X
    list_SNR = expe.list_SNR
    list_rho_noise = expe.list_rho_noise
    list_n_epochs = expe.list_n_epochs
    n_jobs = expe.n_jobs

    # number of repetitions (one data seed each) used for the ROC curves
    n_repet_roc_curves = expe.n_repet_roc_curves

    # parameters of the solver
    n_iter = expe.n_iter
    # NOTE(review): historical alias of ``dict_gap_freq`` below; despite its
    # name it is a dict indexed by problem name, not a list. Kept only for
    # backward compatibility.
    list_gap_freq = expe.dict_gap_freq
    S_freq = expe.S_freq
    active_set_freq = expe.active_set_freq
    list_pb_name = expe.list_pb_name
    # per-problem settings, indexed by pb_name in parallel_function
    dict_tol = expe.dict_tol
    dict_list_p_alpha = expe.dict_list_p_alpha
    dict_gap_freq = expe.dict_gap_freq

    # parameters to store results: one raw-results directory per --expe value
    name_dir_raw_res = expe.name_dir_raw_res
    name_dir_raw_res = name_dir_raw_res + "_" + args.expe

    name_expe = expe.name_expe
    path_expe = "sgcl/expes/%s/" % name_expe
    check_and_create_dirs(
        name_expe=name_expe, name_dir_raw_res=name_dir_raw_res)


def parallel_function(
        pb_name, rho_X, rho_noise, SNR, n_epochs, seed_number):
    """Compute and save the solver path for one experimental setting.

    This function reads module-level settings (``name_expe``,
    ``name_dir_raw_res``, ``dictionary_type``, ``n_iter``, ``dict_tol``, ...)
    that are defined in the ``__main__`` block of this script. It is therefore
    only meant to be called from within this script.

    Parameters
    ----------
    pb_name : str
        Name of the problem/estimator; it selects the tolerance, alpha grid
        and gap frequency from the per-problem dicts.
    rho_X : float
        Passed as ``rho`` to ``get_data_me``.
    rho_noise : float
        Passed as ``rho_noise`` to ``get_data_me``.
    SNR : float
        Passed as ``SNR`` to ``get_data_me``.
    n_epochs : int
        Number of epochs generated by ``get_data_me``.
    seed_number : int
        Seed passed to ``get_data_me``; also used in the saved file names.

    Nothing is recomputed when both result files (masks and dense estimates)
    already exist. Failures during the solver part are reported (setting and
    traceback) but not re-raised, so one failing setting does not abort the
    whole sweep.
    """
    params = (pb_name, rho_X, rho_noise, SNR, n_epochs, seed_number)
    path_dense_Bs = get_path_expe(
        name_expe, "", name_dir_raw_res, params,
        extension='npy', obj="dense_Bs")
    path_masks_Bs = get_path_expe(
        name_expe, "", name_dir_raw_res, params,
        extension='npy', obj="masks_Bs")
    if os.path.isfile(path_masks_Bs) and os.path.isfile(path_dense_Bs):
        # results already cached for this setting
        return

    # generate data
    X, all_epochs, B_star, (_, S_star) = get_data_me(
        dictionary_type=dictionary_type, rho=rho_X,
        noise_type=noise_type, n_channels=n_channels,
        n_times=n_times, n_sources=n_sources, n_epochs=n_epochs,
        n_active=n_active, rho_noise=rho_noise, SNR=SNR,
        seed=seed_number)
    Y = all_epochs.mean(axis=0)
    # these estimators are fed all epochs; the others get the epoch average
    if pb_name in ("CLaR", "MTLME", "MLER", "MRCER"):
        measurement = all_epochs
    else:
        measurement = Y
    try:
        sigma_min = get_sigma_min(Y)
        alpha_max = get_alpha_max(X, measurement, sigma_min, pb_name)
        print("alpha_max = %.2f" % alpha_max)
        print("sigma_min = %.2f" % sigma_min)

        # and save them
        # NOTE(review): these file names depend on the seed only, so settings
        # sharing a seed (different rho_X, SNR, pb_name, ...) write to the
        # same path, possibly concurrently; the last writer wins. Confirm
        # this is intended before relying on these files.
        path = name_dir_raw_res + ("/B_star_%i.npy" % seed_number)
        np.save(path, B_star)
        path = name_dir_raw_res + ("/S_star_%i.npy" % seed_number)
        np.save(path, S_star)
        path = name_dir_raw_res + ("/X_%i.npy" % seed_number)
        np.save(path, X)

        tol = dict_tol[pb_name]
        list_p_alpha = dict_list_p_alpha[pb_name]
        gap_freq = dict_gap_freq[pb_name]
        dict_masks, dict_dense_Bs = get_path(
            X, measurement, list_p_alpha, alpha_max,
            sigma_min, B0=None, n_iter=n_iter, tol=tol, gap_freq=gap_freq,
            active_set_freq=active_set_freq, S_freq=S_freq, pb_name=pb_name,
            use_accel=False, heur_stop=True)

        # dicts are saved as pickled object arrays (load with allow_pickle)
        np.save(path_masks_Bs, dict_masks)
        np.save(path_dense_Bs, dict_dense_Bs)
    except Exception:
        # Best-effort: report which setting failed and why, but keep the
        # sweep going. Unlike a bare ``except:``, this lets
        # KeyboardInterrupt/SystemExit propagate.
        import traceback
        print("Computation failed for params %s" % (params,), flush=True)
        traceback.print_exc()

if __name__ == '__main__':
    # Hosts whose name starts with 'node' or 'margaret' run the jobs on a
    # dask SLURM cluster; any other host uses local multiprocessing.
    on_cluster = socket.gethostname().startswith(('node', 'margaret'))
    backend = 'multiprocessing'
    if on_cluster:
        from dask_jobqueue import SLURMCluster
        from dask.distributed import Client
        cluster = SLURMCluster(182)  # choose your own to avoid conflicts
        cluster.scale(10)
        client = Client(cluster)  # registers as the default "dask" backend
        # The backend must be passed to Parallel explicitly: an explicit
        # ``backend=`` argument overrides any ``parallel_backend(...)``
        # setting, so the dask workers would otherwise sit idle.
        backend = 'dask'
        time.sleep(20)  # give the SLURM workers time to start
        print("cluster setup")

    try:
        print("enter parallel")
        # cap the number of concurrent jobs
        n_jobs = min(n_jobs, 45)
        # For debugging, n_jobs=1 makes joblib run the jobs sequentially.
        Parallel(n_jobs=n_jobs, verbose=100, backend=backend)(
            delayed(parallel_function)(
                pb_name, rho_X, rho_noise, SNR, n_epochs, seed_number)
            for rho_X, rho_noise, SNR, seed_number, n_epochs, pb_name in
            product(list_rho_X, list_rho_noise, list_SNR,
                    range(n_repet_roc_curves), list_n_epochs, list_pb_name))
        print('OK finished parallel')
    finally:
        # release the SLURM jobs even if the sweep raised
        if on_cluster:
            client.close()
            cluster.stop_all_jobs()
            cluster.close()
