import seaborn as sns
import numpy as np
import time
import matplotlib.pyplot as plt

from numpy.linalg import norm
from sklearn.utils import check_random_state
from sklearn.metrics import auc

from data.artificial import (get_data_me)
from clar.solvers import solver
from clar.utils import get_sigma_min, get_alpha_max, get_S_Sinv
from expes.utils import (
    get_rates_from_listes, get_roc_auc_scores,
    configure_plt, get_precision_from_array)

# ----- Simulated data settings (forwarded to get_data_me) -----
dictionary_type = "Toeplitz"
rho_X = 0.6  # forwarded to get_data_me as `rho`
noise_type = "Gaussian_multivariate"
# NOTE(review): rho_noise is not forwarded to get_data_me anywhere in this
# script — confirm whether the noise correlation should be set explicitly.
rho_noise = 0.9

# ----- Problem size and randomness -----
n_channels, n_sources, n_times = 30, 100, 20
n_active, n_epochs = 12, 20
SNR, seed = 0.1, 0

# ----- Solver settings (forwarded to solver()) -----
tol = 1e-3
n_iter = 1000
gap_freq_mtl, gap_freq_me = 30, 50
S_freq = 1
active_set_freq = 10

# ----- MRCE setting (forwarded to get_alpha_max and solver) -----
alpha_Sigma_inv = 1

# Simulate the multi-epoch problem. B_star is what the estimated supports
# are scored against when building the ROC curves.
X, all_epochs, B_star, _ = get_data_me(
    dictionary_type=dictionary_type, noise_type=noise_type,
    n_channels=n_channels, n_sources=n_sources, n_times=n_times,
    n_epochs=n_epochs, n_active=n_active, rho=rho_X, SNR=SNR, seed=seed)

# Epoch-averaged measurements, computed once and reused below.
Y = all_epochs.mean(axis=0)
sigma_min = get_sigma_min(Y)

# Number of regularization levels (points) per ROC curve.
n_points_roc = 20

# Measurements given to each problem: some receive the epoch average Y,
# the others receive every epoch (all_epochs).
dict_measurements = {
    "MTL": Y,
    "SGCL": Y,
    "MRCE": Y,
    "MLER": all_epochs,
    "MLE": Y,
    "CLaR": all_epochs,
    "NNCVX": all_epochs,
    "MRCER": all_epochs,
    "levina": Y,
}

# Regularization grid of each problem, expressed as ratios of its alpha_max
# (geometrically spaced from 1 down to the listed lower ratio).
dict_p_alpha = {
    "MTL": np.geomspace(1, 0.4, n_points_roc),
    "MRCE": np.geomspace(1, 0.85, n_points_roc),
    "CLaR": np.geomspace(1, 0.5, n_points_roc),
    "MLER": np.geomspace(1, 0.5, n_points_roc),
    "MLE": np.geomspace(1, 0.8, n_points_roc),
    "SGCL": np.geomspace(1, 0.9, n_points_roc),
    "MRCER": np.geomspace(1, 0.3, n_points_roc),
    "NNCVX": np.geomspace(1, 0.7, n_points_roc),
}
# "levina" has measurements defined but no grid here (previously tried with
# a lower ratio of 0.5); add an entry before listing it in list_pb_name.


dict_raw_B = {}  # pb_name -> stacked estimates, one per alpha on the grid
dict_time = {}   # pb_name -> elapsed seconds for the whole path

########################################################################
list_pb_name = ["CLaR", "SGCL", "MRCER", "MTL", "MLER", "MLE"]

# alpha_max of each problem; every grid ratio is multiplied by it.
dict_alpha_max = {
    pb_name: get_alpha_max(
        X, dict_measurements[pb_name],
        sigma_min=sigma_min, pb_name=pb_name,
        alpha_Sigma_inv=alpha_Sigma_inv)
    for pb_name in list_pb_name}

for pb_name in list_pb_name:
    time_start = time.perf_counter()
    measurements = dict_measurements[pb_name]
    alpha_max = dict_alpha_max[pb_name]
    liste_per_alpha = dict_p_alpha[pb_name]  # loop-invariant, hoisted
    n_alphas = len(liste_per_alpha)
    raw_B = np.empty((n_alphas, n_sources, n_times))
    # Warm start: each fit on the path starts from the previous solution.
    B_algo = None
    for i, per in enumerate(liste_per_alpha):
        print("-----------------------------------------------------")
        print("%i-th lambda / %i " % (i, n_alphas))
        alpha = per * alpha_max
        # NOTE(review): gap_freq_me is defined but never used; every problem
        # is solved with gap_freq_mtl — confirm this is intended.
        B_algo = solver(
            X, measurements, alpha, alpha_max,
            sigma_min, B0=B_algo, n_iter=n_iter,
            gap_freq=gap_freq_mtl, active_set_freq=active_set_freq,
            S_freq=S_freq, pb_name=pb_name,
            heur_stop=True, tol=tol, alpha_Sigma_inv=alpha_Sigma_inv)[0]
        raw_B[i] = B_algo
    dict_time[pb_name] = time.perf_counter() - time_start
    dict_raw_B[pb_name] = raw_B

# (false positive rates, true positive rates) per problem. Each curve is
# closed with the point (1, 1); previously this closing point was computed
# and then discarded, so it never reached the stored curve.
dict_fp_tp = {}
one = np.ones(1)
for pb_name in list_pb_name:
    fp, tp = get_rates_from_listes(dict_raw_B[pb_name], B_star)
    dict_fp_tp[pb_name] = (
        np.concatenate((fp, one), axis=0),
        np.concatenate((tp, one), axis=0))

# Plot one ROC curve (false positive rate vs true positive rate) per problem.
configure_plt()
c_list = sns.color_palette()

plt.close('all')
for i, pb_name in enumerate(list_pb_name):
    false_pos_rates, true_pos_rates = dict_fp_tp[pb_name]
    # Wrap around the palette so extra problems cannot raise IndexError.
    plt.plot(false_pos_rates, true_pos_rates,
             color=c_list[i % len(c_list)], label=pb_name, marker='o')
plt.xlabel("False positive rate")
plt.ylabel("True positive rate")
plt.ylim([-0.05, 1.05])
plt.legend()
plt.grid()
plt.tight_layout()
plt.show(block=False)
