#from magnitude_embedding_experiment import *
from magnitude import magnitude_from_distances
import pickle
from timeit import default_timer as timer
import matplotlib.pylab as pl
from scipy.stats.stats import pearsonr
from sklearn import datasets, linear_model
#from magnitude_difference import *
from sklearn.linear_model import LinearRegression
from matplotlib import cm
from scipy.spatial import distance_matrix
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import balanced_accuracy_score, f1_score, mean_absolute_error, r2_score, mean_squared_error
from sklearn.linear_model import *
import pylab as pl
from sklearn.linear_model import QuantileRegressor
from sklearn.preprocessing import PolynomialFeatures
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.model_selection import KFold

from lineartree import LinearTreeRegressor, LinearTreeClassifier
import piecewise_regression
from sklearn.model_selection import StratifiedKFold
from sklearn.svm import SVR


from sklearn.model_selection import GridSearchCV, cross_val_score, KFold
from sklearn.svm import SVC
import pandas as pd
from sklearn.dummy import DummyClassifier
from sklearn.model_selection import cross_validate

from sklearn.model_selection import cross_val_predict

from sklearn.metrics import precision_recall_fscore_support as score

from sklearn.linear_model import LogisticRegression
from sklearn import metrics
from sklearn.metrics import accuracy_score, recall_score
from sklearn.model_selection import train_test_split
from sklearn import svm
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
from sklearn.model_selection import StratifiedShuffleSplit
#import scikitplot as skplt

from sklearn import neighbors, datasets, preprocessing

from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_curve

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix

from sklearn.decomposition import PCA
from sklearn.metrics import roc_auc_score
import warnings
from sklearn.metrics import make_scorer
import numpy as np, scipy.stats as st

from sklearn.model_selection import RepeatedStratifiedKFold

from sklearn.model_selection import GridSearchCV, cross_val_score, KFold
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score, recall_score, f1_score, roc_auc_score, make_scorer
from sklearn.ensemble import ExtraTreesClassifier

from sklearn.svm import SVC
from sklearn.metrics import precision_recall_fscore_support as score
from sklearn.preprocessing import StandardScaler


import ph
import ml
import model
import ph_ml
#import nn_shallow
#import nn_deep

from matplotlib import pyplot as plt


### Reference: Turkes, Renata, Guido F. Montufar, and Nina Otter. 
###"On the effectiveness of persistent homology." Advances in Neural Information Processing Systems 35 (2022): 35432-35448.


# Global matplotlib appearance for every figure produced by this script.
# (The legend font size was originally 24.)
plt.rcParams.update({
    'mathtext.default': 'regular',
    'font.size': 12,            # default text size
    'axes.titlesize': 24,       # axes title
    'axes.labelsize': 24,       # x and y axis labels
    'xtick.labelsize': 22,      # x tick labels
    'ytick.labelsize': 22,      # y tick labels
    'legend.fontsize': 16,      # legend entries
    'figure.titlesize': 12,     # figure title
    "text.usetex": True,        # render text through LaTeX
})

from sklearn.utils.fixes import parse_version, sp_version

# Solver handed to QuantileRegressor. "highs" is used when the installed
# SciPy is at least 1.6.0; older SciPy versions fall back to
# "interior-point" to stay compatible.
if sp_version < parse_version("1.6.0"):
    solver = "interior-point"
else:
    solver = "highs"

def get_accuracy(X, y):
    """Score a median regression on degree-2 polynomial features of ``X``.

    ``X`` is expanded with ``PolynomialFeatures(degree=2, include_bias=False)``
    and a ``QuantileRegressor`` (quantile 0.5, ``alpha=0``) is fitted on 1000
    random 80/20 train/test splits. Every split prints its MAE, R2 and MSE and
    shows a scatter plot of true against predicted targets; the mean and
    standard deviation of each metric are printed at the end.

    Args:
        X: Feature values, one entry (or row) per sample; anything accepted
            by ``pd.DataFrame``.
        y: Regression targets aligned with ``X``.

    Returns:
        None. Results are only printed and plotted.
    """
    # Build the features from the argument itself so the function does not
    # depend on variables that only exist when the file runs as a script.
    poly = PolynomialFeatures(degree=2, include_bias=False)
    poly_features = poly.fit_transform(pd.DataFrame(data=X))

    # train_test_split draws from NumPy's global RNG when no random_state is
    # given, so seeding here makes the sequence of splits reproducible.
    np.random.seed(0)
    mae = []
    r2 = []
    mses = []
    quantiles = [0.5]
    for _ in range(1000):
        X_train, X_test, y_train, y_test = train_test_split(poly_features, y, test_size=0.2)
        for quantile in quantiles:
            print('Current quantile')
            qr = QuantileRegressor(quantile=quantile, alpha=0, solver=solver)
            y_pred = qr.fit(X_train, y_train).predict(X_test)

            m = mean_absolute_error(y_test, y_pred)
            r = r2_score(y_test, y_pred)
            mse = mean_squared_error(y_test, y_pred)
            mae.append(m)
            r2.append(r)
            mses.append(mse)
            print(m, r)
            print('The mse is,', mse)

            plt.title('Curvature')
            plt.xlabel('True curvature')
            plt.ylabel('Predicted curvature')
            plt.scatter(y_test, y_pred)
            plt.show()

            # NOTE(review): these calls act on whatever figure is current
            # after the show() above and plot no data; kept as in the
            # original, confirm whether this second figure is still wanted.
            plt.xlabel("x")
            plt.ylabel("y")
            plt.show()
            plt.close()

    print('The MSEs with Piecewise regression are,', mses)
    print("mean MAE:", "%.2f%% (+/- %.2f%%)" % (np.mean(mae), np.std(mae)))
    print("mean R2:", "%.2f%% (+/- %.2f%%)" % (np.mean(r2), np.std(r2)))
    print("mean MSE:", "%.2f%% (+/- %.2f%%)" % (np.mean(mses), np.std(mses)))


def import_labels(name=None):
    """Load the pickled labels of one of the saved datasets.

    Args:
        name: Dataset name; one of "holes", "curvature", "convexity" or
            "flavia". The labels are read from ``data/<name>/labels.pkl``
            relative to the current working directory.

    Returns:
        The object stored in the pickle file.

    Raises:
        ValueError: If ``name`` is not one of the known dataset names.
    """
    print("Importing labels...")
    if name not in ("holes", "curvature", "convexity", "flavia"):
        # Fail with a clear error instead of falling through to a return of
        # a variable that was never assigned.
        raise ValueError(
            "Error: The data to import can only be one of the saved datasets: "
            "holes, curvature, convexity or flavia!"
        )
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open("data/" + name + "/labels.pkl", "rb") as f:
        labels = pickle.load(f)
    return labels

def import_point_clouds(name=None):
    """Load the pickled point clouds of one of the saved datasets.

    Args:
        name: Dataset name; one of "holes", "curvature", "convexity" or
            "flavia". The point clouds are read from
            ``data/<name>/point_clouds.pkl`` relative to the current working
            directory.

    Returns:
        The object stored in the pickle file. It must support ``len()`` and
        its first element must expose ``.shape``, because both are printed
        as a sanity check.

    Raises:
        ValueError: If ``name`` is not one of the known dataset names.
    """
    print("Importing point clouds...")
    if name not in ("holes", "curvature", "convexity", "flavia"):
        # Fail with a clear error instead of falling through to code that
        # uses a variable that was never assigned.
        raise ValueError(
            "Error: The data to import can only be one of the saved datasets: "
            "holes, curvature, convexity or flavia!"
        )
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open("data/" + name + "/point_clouds.pkl", "rb") as f:
        point_clouds = pickle.load(f)
    print("type(point_clouds) = ", type(point_clouds))
    print("len(point_clouds) = ", len(point_clouds))
    print("point_clouds[0].shape = ", point_clouds[0].shape)
    return point_clouds

def calculate_distance_matrices_flat(point_clouds, max_points=100):
    """Compute one rounded Euclidean distance matrix per point cloud.

    Only the first ``max_points`` points of every cloud are used, which keeps
    the matrices small for large clouds. Despite the function name the
    matrices are returned square, not flattened.

    Args:
        point_clouds: Sequence of 2-D arrays with one point per row.
        max_points: Number of leading points taken from each cloud.

    Returns:
        Array of shape ``(len(point_clouds), max_points, max_points)`` with
        pairwise distances rounded to 3 decimals.

    Raises:
        ValueError: If a point cloud has fewer than ``max_points`` points.
    """
    num_samples = len(point_clouds)
    distance_matrices = np.zeros((num_samples, max_points, max_points))
    for s, cloud in enumerate(point_clouds):
        subset = cloud[0:max_points, :]
        if subset.shape[0] != max_points:
            # A smaller matrix could not be stored in the fixed-size output.
            raise ValueError(
                "point cloud %d has only %d points, but %d are required"
                % (s, subset.shape[0], max_points)
            )
        distance_matrices[s] = np.around(euclidean_distances(subset), 3)
    print("distance_matrices_flat.shape = ", distance_matrices.shape)
    return distance_matrices

def flatten_symmetric_matrices(matrices):
    """Flatten the strict upper triangle of every matrix in a stack.

    Args:
        matrices: Array of shape ``(num_matrices, size, size)``.

    Returns:
        Float array of shape ``(num_matrices, size * (size - 1) / 2)`` whose
        row ``m`` holds the entries of ``matrices[m]`` above the diagonal, in
        row-major order.
    """
    count, size = matrices.shape[0], matrices.shape[1]
    upper_rows, upper_cols = np.triu_indices(size, k=1)
    flattened = np.zeros((count, size * (size - 1) // 2))
    # One fancy-indexing pass picks the upper triangle of all matrices.
    flattened[:] = matrices[:, upper_rows, upper_cols]
    return flattened

if __name__ == '__main__':
    # Compare magnitude-based features with persistent-homology (PH) and raw
    # distance-matrix features for regressing the labels of the "curvature"
    # dataset. Every feature set is scored by MSE with 5-fold cross-validation.

    # Size passed to ph.sorted_lifespans_pds for the lifespan vectors.
    n = 500  # 1000

    # Scales at which magnitude is evaluated; identical for every point cloud.
    ts_cut_off = 73
    num_intervals = 30
    tss_new_values = np.linspace(0, ts_cut_off, num=num_intervals)

    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; use
    # whichever name the installed NumPy provides.
    trapezoid = np.trapezoid if hasattr(np, "trapezoid") else np.trapz

    # The legacy "revised simplex" linprog method is no longer available from
    # SciPy 1.11 on, so keep it only where it still exists and use "highs"
    # otherwise.
    mag_solver = "revised simplex" if sp_version < parse_version("1.11.0") else "highs"

    point_clouds = import_point_clouds("curvature")
    labels = import_labels("curvature")

    # ---------- MAGNITUDE COMPUTATION ----------
    current_labels = []
    unique_labels = []
    mag_diffs_same_scales = []
    indices_point_clouds = []
    for i in range(len(point_clouds)):
        print('The labels are,', i)
        # Only the first point cloud seen for each distinct label is used.
        if labels[i] in unique_labels:
            continue
        indices_point_clouds.append(i)
        unique_labels.append(labels[i])
        current_labels.append(labels[i])

        dist_matrix_ref = distance_matrix(point_clouds[i], point_clouds[i], p=2)
        magnitude_cutoff = magnitude_from_distances(dist_matrix_ref, ts=tss_new_values)

        # Single magnitude feature per point cloud: the trapezoidal integral
        # of the magnitude values over the evaluation scales.
        mag_diff_same_scale = trapezoid(y=magnitude_cutoff, x=tss_new_values)
        mag_diffs_same_scales.append(mag_diff_same_scale)

    pd.DataFrame(data=mag_diffs_same_scales).to_csv('mag_diffs_curvature.csv')
    pd.DataFrame(data=current_labels).to_csv('labels_curvature.csv')

    # ---------- PH AND DISTANCE-MATRIX FEATURES ----------
    point_clouds_for_consideration = [point_clouds[index_pc] for index_pc in indices_point_clouds]

    distance_matrices = calculate_distance_matrices_flat(point_clouds_for_consideration)
    # Flattened upper triangles are the raw "ML" features.
    data_dis_mat_flat = flatten_symmetric_matrices(distance_matrices)
    print(distance_matrices.shape)
    # NOTE(review): presumably the 0- and 1-dimensional persistence diagrams;
    # confirm against the ph module.
    data_pd0, data_pd1 = ph.calculate_pds_distance_matrices_ripser(distance_matrices)
    data_ph0 = ph.sorted_lifespans_pds(data_pd0, size=n)
    data_ph1 = ph.sorted_lifespans_pds(data_pd1, size=n)
    print(data_pd0)
    print(data_ph0)

    # First 10 columns of the sorted lifespans (presumably the longest
    # intervals; depends on the sort order of ph.sorted_lifespans_pds).
    data_ph0_longest = data_ph0[:, :10]
    print('we are here!!!')
    print('we are here')

    df_ph = pd.DataFrame(data=data_ph0)
    df_ph_10_longest = pd.DataFrame(data=data_ph0_longest)
    df_ml = pd.DataFrame(data=data_dis_mat_flat)
    df_mag_raw = pd.DataFrame(data=mag_diffs_same_scales)
    # All feature sets share the same labels, one per selected point cloud.
    df_labels = pd.DataFrame(data=current_labels)

    # ---------- CROSS-VALIDATED COMPARISON ----------
    mse_ph = []
    mse_ph_10_longest = []
    mse_mag = []
    mse_ph_non_simple = []
    mse_ml_all = []

    kf = KFold(n_splits=5)
    for fold, (train_index, test_index) in enumerate(kf.split(df_ph)):
        print(f"Fold {fold}:")
        print(f"  Train: index={train_index}")
        print(f"  Test:  index={test_index}")

        y_train = df_labels.iloc[train_index].to_numpy().ravel()
        y_test = df_labels.iloc[test_index].to_numpy().ravel()

        # Raw persistence diagrams are passed as plain lists.
        X_train_ph_pd_raw = [data_pd0[idx] for idx in train_index]
        X_test_ph_pd_raw = [data_pd0[idx] for idx in test_index]

        X_train_ph = df_ph.iloc[train_index]
        X_test_ph = df_ph.iloc[test_index]

        X_train_ph_10_longest = df_ph_10_longest.iloc[train_index]
        X_test_ph_10_longest = df_ph_10_longest.iloc[test_index]

        X_train_ml = df_ml.iloc[train_index]
        X_test_ml = df_ml.iloc[test_index]

        # Degree-2 polynomial expansion of the magnitude feature, fitted on
        # the training fold and then applied to the test fold.
        poly = PolynomialFeatures(degree=2, include_bias=False)
        X_train_mag = pd.DataFrame(data=poly.fit_transform(df_mag_raw.iloc[train_index]))
        X_test_mag = poly.transform(df_mag_raw.iloc[test_index])

        # Simple 0-dim PH (sorted lifespans).
        print("\n\nTuning the hyperparameters of ML on simple 0-dim PH...")
        model_ml_on_ph0 = ml.tune_hyperparameters(X_train_ph, y_train)

        print("\n\nTraining ML on simple 0-dim PH...")
        model_trained_ml_on_ph0, _ = model.fit(X_train_ph, y_train, model_ml_on_ph0)
        mse_ml_on_ph0 = model.get_score(X_test_ph, y_test, model_trained_ml_on_ph0)

        print('The MSE PH_0 is,', mse_ml_on_ph0)
        mse_ph.append(mse_ml_on_ph0)

        # Non-simple 0-dim PH (raw persistence diagrams).
        print("\n\nTuning the hyperparameters of 0-dim PH...")
        model_ph0_ml = ph_ml.tune_hyperparameters(X_train_ph_pd_raw, y_train)
        model_trained_ph0_ml, _ = model.fit(X_train_ph_pd_raw, y_train, model_ph0_ml)
        mse_ph0_ml = model.get_score(X_test_ph_pd_raw, y_test, model_trained_ph0_ml)

        # Report the score of this model, not the simple-PH score again.
        print('The MSE PH_0 is,', mse_ph0_ml)
        mse_ph_non_simple.append(mse_ph0_ml)

        # ML on the flattened distance matrices.
        model_ml_on_dis_mat = ml.tune_hyperparameters(X_train_ml, y_train)
        model_trained_ml_on_dis_mat, _ = model.fit(X_train_ml, y_train, model_ml_on_dis_mat)
        mse_ml = model.get_score(X_test_ml, y_test, model_trained_ml_on_dis_mat)
        mse_ml_all.append(mse_ml)

        # Simple 0-dim PH restricted to the first 10 lifespans.
        model_ml_on_ph0_longest = ml.tune_hyperparameters(X_train_ph_10_longest, y_train)
        model_trained_ml_on_ph0_longest, _ = model.fit(
            X_train_ph_10_longest, y_train, model_ml_on_ph0_longest
        )
        mse_ml_on_ph0_longest = model.get_score(
            X_test_ph_10_longest, y_test, model_trained_ml_on_ph0_longest
        )
        mse_ph_10_longest.append(mse_ml_on_ph0_longest)

        # Magnitude: unregularised median regression on the polynomial features.
        qr = QuantileRegressor(quantile=0.5, alpha=0, solver=mag_solver)
        y_pred = qr.fit(X_train_mag, y_train).predict(X_test_mag)
        mse_mag_fold = mean_squared_error(y_test, y_pred)
        print('The MSE Mag is,', mse_mag_fold)
        mse_mag.append(mse_mag_fold)

    # ---------- FINAL EVALUATION ----------
    # Per-fold scores followed by their mean and standard deviation.
    for header, summary_label, scores in (
        ('The MSEs for PH are are,', "PH mean MSE:", mse_ph),
        ('The MSEs for PH 10 longest are are,', "PH mean MSE 10 longest:", mse_ph_10_longest),
        ('The MSEs for PH Non-simple are are,', "PH mean MSE:", mse_ph_non_simple),
        ('The MSEs for ML are,', "ML mean MSE:", mse_ml_all),
        ('The MSEs for Mag are are,', "Mag mean MSE:", mse_mag),
    ):
        print(header, scores)
        print(summary_label, "%.2f%% (+/- %.2f%%)" % (np.mean(scores), np.std(scores)))


