import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from xgboost import XGBClassifier
from sklearn.model_selection import train_test_split

# Select the interactive TkAgg backend for the plt.show() windows used below.
mpl.use('TkAgg')

############ shadow model result ###############
# Result tables of the shadow model's train / test splits; the attack
# classifier is fitted on these in ``__main__``.
path1 = r"inter_output/CLID/Atk_Impt_mydenoise_M_coco_real_split1_DATA_val17_split1_TRTE_train_MAXsmp_3_T_0506_145909.txt"
# NOTE(review): unlike path1 this name has a double underscore ("split1__TRTE");
# confirm it matches the file on disk.
path2 = r"inter_output/CLID/Atk_Impt_mydenoise_M_coco_real_split1_DATA_val17_split1__TRTE_test_MAXsmp_3_T_0506_145909.txt"

############ target model result ###############
# Result tables of the target model's train / test splits; the fitted attack
# classifier is evaluated on these in ``__main__``.
path_test1 = r"inter_output/CLID/Atk_Impt_mydenoise_M_coco_real_ori_DATA_val17_TRTE_train_MAXsmp_3_T_0506_145842.txt"
path_test2 = r"inter_output/CLID/Atk_Impt_mydenoise_M_coco_real_ori_DATA_val17_TRTE_test_MAXsmp_3_T_0506_145842.txt"

# Placeholders: overwritten by the deal_data_* functions (MetricName) and by
# get_ori_data (Dataname); both are embedded in printed output and file names.
MetricName = '---'
Dataname = '---'


def _read_float_table(path):
    """Parse a tab-separated text file of floats, skipping its header line.

    Blank lines (for example a trailing empty line at the end of the file)
    are ignored instead of raising ``ValueError`` from ``float``.
    """
    with open(path, 'r', encoding='utf8') as f:
        next(f, None)  # the first line is a header
        return [[float(field) for field in line.strip().split('\t')]
                for line in f if line.strip()]


def get_ori_data(path_train, path_test):
    """Load the train/test result tables and their global value range.

    Each file is a tab-separated text file whose first line is a header; every
    following non-blank line is a row of floats.  As a side effect the
    module-level ``Dataname`` is set to the name of the directory containing
    ``path_train`` (it is left unchanged when the path has no separator).

    Args:
        path_train: path of the table for the training samples.
        path_test: path of the table for the test samples.

    Returns:
        ``(train, test, max_v, min_v)``: the parsed tables as ``np.ndarray``
        and the largest / smallest value found across both tables.
    """
    global Dataname
    # The parent directory of the train file names the dataset; it is later
    # embedded in the saved figure filename.
    if '/' in path_train:
        Dataname = path_train.split('/')[-2]
    elif '\\' in path_train:
        Dataname = path_train.split('\\')[-2]
    print('dataname:', Dataname)

    train = np.array(_read_float_table(path_train))
    test = np.array(_read_float_table(path_test))

    max_v = max(train.max(), test.max())
    print("max_v", max_v)
    min_v = min(train.min(), test.min())
    print("min_v", min_v)

    return train, test, max_v, min_v


def deal_data_ratio(train, test):
    """Map every row to the ratio of its first column to its last column.

    Sets the module-level ``MetricName`` to ``'ratio1-4'`` and returns
    ``(train_ratios, test_ratios)`` as lists with one value per row.
    """
    global MetricName
    MetricName = 'ratio1-4'

    def _ratios(rows):
        return [row[0] / row[-1] for row in rows]

    return _ratios(train), _ratios(test)


def deal_data_allconds(train, test):
    """Keep every column: return shallow list copies of both row collections.

    Sets the module-level ``MetricName`` to ``'allconds'``.
    """
    global MetricName
    MetricName = 'allconds'
    return list(train), list(test)


def deal_data_max_4conds(train, test):
    """Keep columns 0, 1, 2 and the last column of every row.

    Sets the module-level ``MetricName`` (printed and used in the saved
    figure filename) to ``'max_4conds'``.

    Returns:
        ``(train_rows, test_rows)`` as lists of 4-element lists.
    """
    global MetricName
    # A distinct label per feature set, so outputs named after MetricName do
    # not overwrite those of deal_data_allconds.
    MetricName = 'max_4conds'
    train = [[e[0], e[1], e[2], e[-1]] for e in train]
    test = [[e[0], e[1], e[2], e[-1]] for e in test]
    return train, test


def deal_data_last_4conds(train, test):
    """Keep columns 1 to 4 (inclusive) of every row, dropping column 0.

    Sets the module-level ``MetricName`` (printed and used in the saved
    figure filename) to ``'last_4conds'``.

    Returns:
        ``(train_rows, test_rows)`` as lists of 4-element lists.
    """
    global MetricName
    # A distinct label per feature set, so outputs named after MetricName do
    # not overwrite those of deal_data_allconds.
    MetricName = 'last_4conds'
    train = [[e[1], e[2], e[3], e[4]] for e in train]
    test = [[e[1], e[2], e[3], e[4]] for e in test]
    return train, test


def deal_data_first_last(train, test):
    """Keep only the first and the last column of every row.

    Sets the module-level ``MetricName`` (printed and used in the saved
    figure filename) to ``'first_last'``.

    Returns:
        ``(train_rows, test_rows)`` as lists of 2-element lists.
    """
    global MetricName
    # A distinct label per feature set, so outputs named after MetricName do
    # not overwrite those of deal_data_allconds.
    MetricName = 'first_last'
    train = [[e[0], e[-1]] for e in train]
    test = [[e[0], e[-1]] for e in test]
    return train, test


def deal_data_deltacondas(train, test):
    """Express columns 1-4 of each row as differences from column 0.

    Each row becomes ``[r0, r0 - r1, r0 - r2, r0 - r3, r0 - r4]``; columns
    beyond index 4 are ignored. Sets the module-level ``MetricName`` to
    ``'deltacondas'``.
    """
    global MetricName
    MetricName = 'deltacondas'

    def _deltas(row):
        head = row[0]
        return [head] + [head - row[k] for k in range(1, 5)]

    return [_deltas(r) for r in train], [_deltas(r) for r in test]


def deal_data_first_lastavg(train, test):
    """Reduce every row to ``[first column, mean of all remaining columns]``.

    Sets the module-level ``MetricName`` (printed and used in the saved
    figure filename) to ``'first_lastavg'``. Rows need at least two columns.

    Returns:
        ``(train_rows, test_rows)`` as lists of 2-element lists.
    """
    global MetricName
    # A distinct label per feature set, so outputs named after MetricName do
    # not overwrite those of deal_data_allconds.
    MetricName = 'first_lastavg'

    def _first_and_rest_mean(row):
        rest = row[1:]  # sliced once per row
        return [row[0], sum(rest) / len(rest)]

    train = [_first_and_rest_mean(e) for e in train]
    test = [_first_and_rest_mean(e) for e in test]
    return train, test


def deal_data_1_last(train, test):
    """Keep only the first and the last column of every row.

    Sets the module-level ``MetricName`` (printed and used in the saved
    figure filename) to ``'1_last'``.

    Returns:
        ``(train_rows, test_rows)`` as lists of 2-element lists.
    """
    global MetricName
    # A distinct label per feature set, so outputs named after MetricName do
    # not overwrite those of deal_data_allconds.
    MetricName = '1_last'
    train = [[e[0], e[-1]] for e in train]
    test = [[e[0], e[-1]] for e in test]
    return train, test


def deal_data_2x1_4(train, test):
    """Map every row to ``r0 + 0.2 * (r0 - r_last)``.

    The first column is moved a further 20% of its gap away from the last
    column. Sets the module-level ``MetricName`` to ``'cond2x1-4'`` and
    returns ``(train_values, test_values)`` with one value per row.
    """
    global MetricName
    MetricName = 'cond2x1-4'

    def _extrapolate(rows):
        return [row[0] + 0.2 * (row[0] - row[-1]) for row in rows]

    return _extrapolate(train), _extrapolate(test)


def deal_data_first(train, test):
    """Reduce every row to its first column.

    Sets the module-level ``MetricName`` to ``'first_cond'`` and returns
    ``(train_values, test_values)`` with one value per row.
    """
    global MetricName
    MetricName = 'first_cond'

    def _first(rows):
        return [row[0] for row in rows]

    return _first(train), _first(test)


def deal_data_last(train, test):
    """Reduce every row to its last column.

    Sets the module-level ``MetricName`` to ``'last_cond'`` and returns
    ``(train_values, test_values)`` with one value per row.
    """
    global MetricName
    MetricName = 'last_cond'

    def _last(rows):
        return [row[-1] for row in rows]

    return _last(train), _last(test)


def metric_cond_minus_unc(losses):
    """Return ``losses[0] - losses[-1]`` (first entry minus last entry).

    On the first call the module-level ``targetname`` is set to
    ``'cond_minus_unc'`` and printed once; the module-level ``tag`` flag
    records that this has happened.
    """
    global tag
    global targetname
    # ``tag`` is never initialised at module level, so treat "missing" as 0
    # instead of raising NameError on the very first call.
    if globals().get('tag', 0) == 0:
        targetname = 'cond_minus_unc'
        print('\n--', targetname)
        tag = 1

    return losses[0] - losses[-1]


def get_cls_withTh(train, test, th):
    """Score a fixed-threshold membership rule and return its accuracy (ASR).

    A sample is predicted as a member when its value is ``<= th``; ``train``
    holds the values of the true members and ``test`` those of non-members.

    Args:
        train: sequence or array of values for the member samples.
        test: sequence or array of values for the non-member samples.
        th: decision threshold.

    Returns:
        ``(TP + TN) / total`` over both sets.

    Raises:
        ValueError: if ``train`` and ``test`` are both empty.
    """
    print('\n*** TEST Model ***')
    members = np.asarray(train, dtype=float).ravel()
    non_members = np.asarray(test, dtype=float).ravel()
    if members.size + non_members.size == 0:
        raise ValueError('get_cls_withTh needs at least one value')

    # Join explicitly: ``train + test`` concatenates lists but adds NumPy
    # arrays element-wise, and ``list <= float`` raises TypeError.
    pooled = np.concatenate((members, non_members))
    max_e = pooled.max()
    min_e = pooled.min()

    print("\ntrain_list[:3], test_list[:3]", train[:3], test[:3])
    print("\nmax_e, min_e:", max_e, min_e)

    TP = (members <= th).sum()
    TN = (non_members > th).sum()
    FP = (non_members <= th).sum()
    FN = (members > th).sum()
    ASR = (TP + TN) / (TP + TN + FP + FN)

    print('\n', 'TEST: ', 'ASR:', ASR, 'by the given threshold:', th)

    return ASR


def get_xgb(train, test, n_estimators=50):
    """Fit an XGBoost classifier that separates ``train`` rows from ``test`` rows.

    Rows of ``train`` are labelled 0 and rows of ``test`` labelled 1. The
    combined data is shuffled (with NumPy's global RNG) and the classifier is
    fitted on all of it; there is no hold-out split, so the accuracy and
    ROC-AUC printed here are training-set figures.

    Args:
        train: features of the label-0 samples, either a 2-D
            ``(n_samples, n_features)`` table or a 1-D sequence with one
            scalar feature per sample.
        test: features of the label-1 samples, same layout as ``train``.
        n_estimators: number of boosting rounds passed to ``XGBClassifier``.

    Returns:
        ``(clf, y_pred, y_train, proba_xgb)``: the fitted classifier, its hard
        predictions, the shuffled labels, and its ``predict_proba`` output,
        all on the shuffled training data.
    """
    from sklearn.metrics import roc_auc_score

    print('---------Train XGB with n_estimators_{}:\n'.format(n_estimators))

    train = np.asarray(train, dtype=float)
    test = np.asarray(test, dtype=float)
    # 1-D inputs (one scalar per sample, e.g. from deal_data_first) must
    # become a single feature column before they can be stacked with labels.
    if train.ndim == 1:
        train = train.reshape(-1, 1)
    if test.ndim == 1:
        test = test.reshape(-1, 1)

    label0 = np.zeros((train.shape[0], 1))
    label1 = np.ones((test.shape[0], 1))

    datas = np.concatenate((train, test))
    print(datas.shape)

    labels = np.concatenate((label0, label1))
    print(labels.shape)

    # Shuffle features and labels together; the label is the last column.
    data_with_label = np.hstack((datas, labels))
    np.random.shuffle(data_with_label)

    X_train = data_with_label[:, :-1]
    y_train = data_with_label[:, -1]

    clf = XGBClassifier(n_estimators=n_estimators,
                        gamma=0.7, max_depth=3, subsample=0.7, colsample_bytree=0.7, reg_alpha=1, reg_lambda=1)
    clf.fit(X_train, y_train)

    y_pred = clf.predict(X_train)
    print(y_pred.shape, y_pred[:5])

    train_acc = np.mean(y_pred == y_train)
    print('Train acc: [{}]'.format(train_acc))

    proba_xgb = clf.predict_proba(X_train)
    print('pred_xgb.shape:', proba_xgb.shape, 'pred_xgb[:, 1].shape', proba_xgb[:, 1].shape)
    auc = roc_auc_score(y_train, proba_xgb[:, 1])
    print('Train roc_auc: ', auc)

    return clf, y_pred, y_train, proba_xgb


def pre_xgb(train, test, cls):
    """Evaluate a fitted classifier on another model's ``train``/``test`` rows.

    Rows of ``train`` are labelled 0 and rows of ``test`` labelled 1, the same
    convention ``get_xgb`` uses for fitting. Prints accuracy and ROC-AUC,
    shows the ROC curve, and marks the operating point at 1% FPR.

    Args:
        train: features of the label-0 samples (2-D table, or a 1-D sequence
            with one scalar per sample).
        test: features of the label-1 samples, same layout as ``train``.
        cls: fitted classifier exposing ``predict`` and ``predict_proba``
            (e.g. the first element returned by ``get_xgb``).

    Returns:
        ``(cls, y_pred, y, proba_xgb)``: the classifier, its hard predictions,
        the true labels, and its ``predict_proba`` output.
    """
    from sklearn.metrics import roc_auc_score, roc_curve

    print('---------Test XGB on target model:\n')

    train = np.asarray(train, dtype=float)
    test = np.asarray(test, dtype=float)
    # Mirror get_xgb: 1-D inputs become a single feature column.
    if train.ndim == 1:
        train = train.reshape(-1, 1)
    if test.ndim == 1:
        test = test.reshape(-1, 1)

    X = np.concatenate((train, test))
    print(X.shape)

    labels = np.concatenate((np.zeros((train.shape[0], 1)), np.ones((test.shape[0], 1))))
    print(labels.shape)
    y = labels[:, 0]

    # Use the classifier passed in, not any module-level ``clf``.
    y_pred = cls.predict(X)
    print(y_pred.shape, y_pred[:5])

    test_acc = np.mean(y_pred == y)
    print('test acc: [{}]'.format(test_acc))

    proba_xgb = cls.predict_proba(X)
    print('pred_xgb.shape:', proba_xgb.shape, 'pred_xgb[:, 1].shape', proba_xgb[:, 1].shape)
    auc = roc_auc_score(y, proba_xgb[:, 1])
    print('Test roc_auc: ', auc)

    fpr, tpr, _ = roc_curve(y, proba_xgb[:, 1])

    # "TPR at 1% FPR": take the last ROC point whose FPR stays within the 1%
    # budget, so the reported TPR never comes from an FPR above 1%.
    # roc_curve starts at FPR 0, so a candidate exists unless FPR is undefined.
    within_budget = np.flatnonzero(fpr <= 0.01)
    idx_1_percent_fpr = within_budget[-1] if within_budget.size else 0
    tpr_at_1_percent_fpr = tpr[idx_1_percent_fpr]

    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, color='blue', lw=2, label='ROC curve (AUC = %0.4f)' % auc)
    plt.plot([0, 1], [0, 1], color='gray', lw=1, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(f'Receiver Operating Characteristic (ROC) ASR:{test_acc}')
    plt.scatter(fpr[idx_1_percent_fpr], tpr[idx_1_percent_fpr], marker='o', color='red',
                label='1%% FPR (TPR = %0.4f)' % tpr_at_1_percent_fpr)
    plt.legend(loc="lower right")

    plt.show()
    print('tpr_at_1_percent_fpr:', tpr_at_1_percent_fpr)

    return cls, y_pred, y, proba_xgb


def draw_distribute_auc(train, test, best_threshold, best_asr, auc, FPR_list, TPR_list, max_e, min_e, th_pred=None,
                        asr_pred=None):
    """Plot train/test value histograms next to a ROC curve, then save and show it.

    Left panel: overlaid histograms of ``train`` and ``test`` with a red line
    at ``best_threshold`` and, if given, a blue line at ``th_pred``.
    Right panel: the ROC curve ``TPR_list`` vs ``FPR_list`` labelled with
    ``auc``. The figure is saved as
    ``hists_conds/Dataname_<Dataname>_MetricName_<MetricName>.png`` using the
    module-level ``Dataname`` and ``MetricName``.

    Args:
        train, test: values of the two groups being compared.
        best_threshold, best_asr: threshold and its ASR shown in the title.
        auc: ROC area shown in the ROC legend.
        FPR_list, TPR_list: ROC curve coordinates.
        max_e, min_e: value range used to express ``best_threshold`` as a
            fraction ("Perc").
        th_pred, asr_pred: optional second threshold and its ASR;
            ``asr_pred`` must be given whenever ``th_pred`` is.
    """
    import os

    a = np.asarray(train, dtype=float)
    b = np.asarray(test, dtype=float)

    fig, axs = plt.subplots(1, 2, figsize=(10, 5))

    # Pool both groups explicitly: ``train + test`` would add NumPy arrays
    # element-wise instead of joining them.
    pooled = np.concatenate((a.ravel(), b.ravel()))
    bins = np.linspace(pooled.min(), pooled.max(), 100)

    axs[0].hist(a, bins, alpha=0.5, label='Train data')
    axs[0].hist(b, bins, alpha=0.5, label='Test data')
    axs[0].legend(loc='upper left')
    axs[0].axvline(x=best_threshold, color='r', linestyle='--')

    # Position of best_threshold inside [min_e, max_e]; NaN for an empty range.
    value_range = max_e - min_e
    perc = (best_threshold - min_e) / value_range if value_range != 0 else float('nan')

    print('th_pred:', th_pred)
    if th_pred is not None:
        axs[0].axvline(x=th_pred, color='blue', linestyle='--')
        print('th_pred 2:', th_pred)
        title_str = 'TrueAsr {:.4f}, PredAsr {:.4f}; TrueTh {:.3} Perc {:.3f}, PredTh {:.3}'.format(
            best_asr,
            asr_pred,
            best_threshold,
            perc,
            th_pred,
        )
        print('title_str', title_str)
    else:
        title_str = 'TrueAsr {:.3f}, TrueTh {:.3}, Perc {:.3f}'.format(
            best_asr,
            best_threshold,
            perc,
        )
    axs[0].set_title(title_str)

    axs[1].plot(FPR_list, TPR_list, 'k--', label='ROC = {0:.4f}'.format(auc), lw=2)
    axs[1].set_xlim([-0.05, 1.05])
    axs[1].set_ylim([-0.05, 1.05])
    axs[1].set_xlabel('False Positive Rate')
    axs[1].set_ylabel('True Positive Rate')
    axs[1].set_title('ROC Curve')
    axs[1].legend(loc="lower right")

    plt.tight_layout()

    print("\nDataName [{}], MetricName [{}]\n".format(Dataname, MetricName))
    out_dir = "hists_conds"
    out_path = out_dir + "/" + 'Dataname_{}_MetricName_{}.png'.format(Dataname, MetricName)
    # savefig does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(out_path)
    print("save in", out_path)
    plt.show()


def custom_format(x):
    """Render the number ``x`` in fixed-point notation with three decimals."""
    return f"{x:.3f}"


if __name__ == '__main__':
    # Shadow model: fit the attack classifier on its train/test outputs.
    print("\n**** Shadow: ****")
    shadow_train, shadow_test, shadow_max, shadow_min = get_ori_data(path1, path2)
    shadow_train, shadow_test = deal_data_allconds(shadow_train, shadow_test)
    # Keep the name ``clf``: it is a module-level global other functions may reference.
    clf, shadow_pred, shadow_labels, shadow_proba = get_xgb(shadow_train, shadow_test, n_estimators=5)

    # Target model: evaluate the fitted attack classifier on its outputs.
    print("\n**** Target: ****")
    target_train, target_test, target_max, target_min = get_ori_data(path_test1, path_test2)
    target_train, target_test = deal_data_allconds(target_train, target_test)
    pre_xgb(target_train, target_test, clf)


