from scipy.stats import gmean
import numpy as np
from magnipy import Magnipy
from distances import get_dist
from vendi_score import vendi
import pandas as pd
import pickle
from sklearn.model_selection import cross_val_score
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
import numpy as np
from sklearn.model_selection import RepeatedKFold, cross_validate
from sklearn.ensemble import RandomForestRegressor
from sklearn.isotonic import IsotonicRegression
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import pairwise_distances
from n_gram_sim import *

def std_emb_vect_div(X):
    """Geometric mean of the non-zero per-column standard deviations of ``X``.

    The standard deviation is taken along axis 0; entries that are exactly
    zero are discarded before the geometric mean is computed.

    Args:
        X: array-like accepted by ``np.std`` (presumably samples x features
            -- confirm against callers).
    Returns:
        The value returned by ``scipy.stats.gmean`` for the remaining
        standard deviations.
    """
    column_spread = np.std(X, axis=0)
    # Drop constant dimensions so they cannot pull the geometric mean to 0.
    informative = column_spread[np.nonzero(column_spread)]
    return gmean(informative)

def std_emb_vect_div_zero(X):
    """Geometric mean of all per-column standard deviations of ``X``.

    Unlike ``std_emb_vect_div`` no zero entries are filtered out, so the
    standard deviations are passed to ``gmean`` exactly as ``np.std``
    returns them.

    Args:
        X: array-like accepted by ``np.std``; reduced along axis 0.
    Returns:
        The value returned by ``scipy.stats.gmean``.
    """
    return gmean(np.std(X, axis=0))

def std_emb_vect_div_error(X):
    """Geometric mean of per-column standard deviations with zeros floored.

    Standard deviations (along axis 0) that are exactly zero are replaced by
    ``0.0001`` before the geometric mean is taken.

    Args:
        X: array-like accepted by ``np.std``; reduced along axis 0.
    Returns:
        The value returned by ``scipy.stats.gmean`` for the adjusted values.
    """
    column_spread = np.std(X, axis=0)
    is_constant = column_spread == 0
    # np.std returned a fresh array, so this write does not touch X.
    column_spread[is_constant] = 0.0001
    return gmean(column_spread)

def similarity2negmeansimilarity(X):
    """Return the negated mean of the strictly-lower-triangular entries of ``X``.

    Each unordered pair ``(i, j)`` with ``j < i`` contributes ``X[i, j]``
    once. Counting every pair twice, as an explicit double-append would,
    leaves the mean unchanged, so it is not done here.

    Args:
        X: 2-D array of pairwise similarities; only entries below the main
            diagonal are read.
    Returns:
        ``-mean(X[i, j] for j < i)``. With fewer than two rows there are no
        pairs and NumPy returns ``nan`` (with a RuntimeWarning).
    """
    X = np.asarray(X)
    # k=-1 excludes the diagonal, i.e. self-similarities are ignored.
    rows, cols = np.tril_indices(X.shape[0], k=-1)
    return -np.mean(X[rows, cols])

def dist2similarity(X):
    """Map distances to similarities via ``1 - X``.

    Args:
        X: scalar or array supporting unary minus and addition.
    Returns:
        ``(-X) + 1``, element-wise for arrays.
    """
    negated = -X
    return negated + 1

def vendi_dot(X, q=1, normalize=False):
    """Compute a Vendi score directly from the data matrix ``X``.

    Dispatches to ``vendi.score_X`` when ``X`` has fewer rows than columns
    and to ``vendi.score_dual`` otherwise (presumably to work with the
    smaller of the two Gram matrices -- confirm against vendi_score docs).

    Args:
        X: 2-D array; ``X.shape`` is unpacked as ``(n, d)``.
        q: forwarded unchanged to the vendi scorer.
        normalize: forwarded unchanged to the vendi scorer.
    Returns:
        Whatever the selected vendi scorer returns.
    """
    n_rows, n_cols = X.shape
    scorer = vendi.score_X if n_rows < n_cols else vendi.score_dual
    return scorer(X, q=q, normalize=normalize)

def calc_metrics_from_embedding(X, n_ts=65, metrics=["cosine", "L2", "L1"], target_scale=None):
    """Compute diversity scores and Magnipy objects for one embedding matrix.

    Always computes the three standard-deviation based scores and the two
    ``vendi_dot`` scores. For each of "cosine", "L2" and "L1" present in
    ``metrics`` it additionally builds a ``Magnipy`` object and derives a
    negative mean similarity plus (normalised) Vendi scores from a
    similarity matrix; metrics that are not requested yield ``None``.

    Args:
        X: embedding matrix passed to the score helpers, ``get_dist`` and
            ``Magnipy`` (presumably samples x features -- confirm).
        n_ts: forwarded to ``Magnipy`` as ``n_ts``.
        metrics: collection of metric names; only membership of "cosine",
            "L2" and "L1" is tested. The default list is never mutated here.
        target_scale: if not ``None``, ``target_value`` passed to ``Magnipy``
            is ``n * target_scale`` where ``n`` is the number of rows of the
            euclidean distance matrix returned by ``get_dist``.
    Returns:
        A 5-tuple ``(Mag, MagL2, MagL1, values, names)`` where ``values`` and
        ``names`` are parallel lists of 14 scores and their labels.
    """
    # Scores computed directly on the embedding, independent of `metrics`.
    stds_div = std_emb_vect_div(X)
    stds_div_zero = std_emb_vect_div_zero(X)
    stds_div_error = std_emb_vect_div_error(X)
    vendi_d = vendi_dot(X, q=1, normalize=True)
    vendi_d_nn = vendi_dot(X, q=1, normalize=False)

    if target_scale is not None:
        # The distance matrix is only used here to obtain its row count.
        # NOTE(review): get_dist is called without check_for_duplicates=False,
        # so n may differ from X.shape[0] -- confirm against get_dist.
        D=get_dist(X, metric="euclidean")
        n=D.shape[0]
        target_value=n*target_scale
    else:
        target_value=None

    if "cosine" in metrics:
        Mag = Magnipy(X=X, target_value=target_value, n_ts=n_ts, log_scale = False, metric="cosine")
        # Similarity is 1 - cosine distance.
        sim = dist2similarity(get_dist(X, metric="cosine", check_for_duplicates=False))
        neg_means_cosine=similarity2negmeansimilarity(sim)
        vendi_cosine=vendi.score_K(sim)
        vendi_cosine_n=vendi.score_K(sim, normalize=True)
    else:
        Mag = None
        neg_means_cosine=None
        vendi_cosine=None
        vendi_cosine_n=None

    if "L2" in metrics:
        MagL2 = Magnipy(X=X, target_value=target_value, n_ts=n_ts, log_scale = False, metric="euclidean", p=2)
        D=get_dist(X, metric="euclidean", p=2, check_for_duplicates=False)#MagL2.get_dist()
        #neg_means_l2=similarity2negmeansimilarity(dist2similarity(D))
        # Similarity is exp(-distance) for the euclidean metric.
        sim = np.exp(-D)
        neg_means_exp=similarity2negmeansimilarity(sim)
        vendi_exp=vendi.score_K(sim)
        vendi_exp_n=vendi.score_K(sim, normalize=True)
    else:
        #print("ups")
        MagL2 = None
        #neg_means_l2=None
        neg_means_exp=None
        vendi_exp=None
        vendi_exp_n=None

    if "L1" in metrics:
        MagL1 = Magnipy(X=X, target_value=target_value, n_ts=n_ts, log_scale = False, metric="cityblock", p=1)
        D=get_dist(X, metric="cityblock", p=1, check_for_duplicates=False)#MagL1.get_dist()
        #neg_means_l1=similarity2negmeansimilarity(dist2similarity(D))
        # Similarity is exp(-distance) for the cityblock metric.
        sim=np.exp(-D)
        neg_means_expl1=similarity2negmeansimilarity(sim)
        vendi_expl1=vendi.score_K(sim)
        vendi_exp1_n=vendi.score_K(sim, normalize=True)
    else:
        MagL1 = None
        #neg_means_l1=None
        neg_means_expl1=None
        vendi_expl1=None
        vendi_exp1_n=None
    #print(MagL2)
    # The value list and the name list below must stay in the same order.
    return Mag, MagL2, MagL1, [neg_means_cosine, vendi_cosine, vendi_cosine_n, 
                               vendi_d,vendi_d_nn, stds_div, stds_div_zero, stds_div_error,
                               #neg_means_l2, 
                               neg_means_exp, vendi_exp, vendi_exp_n, #neg_means_l1, 
                               neg_means_expl1, vendi_expl1, vendi_exp1_n], ["neg_mean_cosine", "vendi_cosine", "vendi_cosine_norm", 
                                                               "vendi_dot","vendi_dot_no_norm",
                                                               "stds_div", "stds_div_zero", "stds_div_error",
                                                               #"neg_means_l2", 
                                                               "neg_mean_exp", "vendi_exp","vendi_exp_norm", #"neg_means_l1", 
                                                               "neg_mean_expl1", "vendi_expl1", "vendi_expl1_norm"]

def is_singular(matrix):
    """Report whether a square matrix is (numerically) singular.

    Uses a rank test instead of comparing the determinant with exactly 0:
    floating-point determinants of singular matrices are frequently tiny
    non-zero numbers, so ``det == 0`` misses them.

    Args:
        matrix: array-like whose last two dimensions are square.
    Returns:
        Truthy if the rank is lower than the matrix size. Empty matrices and
        matrices containing NaN/inf are reported as not singular, which is
        what the determinant comparison produced for them as well.
    Raises:
        numpy.linalg.LinAlgError: if the input is not square (same exception
            type ``np.linalg.det`` raises).
    """
    matrix = np.asarray(matrix)
    if matrix.ndim < 2 or matrix.shape[-1] != matrix.shape[-2]:
        raise np.linalg.LinAlgError(
            "Last 2 dimensions of the array must be square"
        )
    if matrix.size == 0:
        # det of a 0x0 matrix is 1; matrix_rank cannot handle empty input.
        return False
    if not np.all(np.isfinite(matrix)):
        # The SVD behind matrix_rank may fail to converge on NaN/inf input;
        # keep this diagnostic helper from raising in that case.
        return False
    return np.linalg.matrix_rank(matrix) < matrix.shape[-1]

def add_pertubations(matrix, epsilon=1e-4):
    """Return ``matrix`` plus element-wise Gaussian noise scaled by ``epsilon``.

    The noise is an ``n x n`` draw from ``np.random.randn`` (global NumPy
    RNG), where ``n`` is the number of rows of ``matrix``. The noise is not
    symmetrised and the diagonal is perturbed as well.

    Args:
        matrix: array whose first dimension gives ``n``.
        epsilon: multiplier applied to the standard-normal noise.
    Returns:
        A new array; ``matrix`` itself is not modified.
    """
    size = matrix.shape[0]
    noise = np.random.randn(size, size)
    return matrix + epsilon * noise

def remove_zero_rows_and_columns(matrix):
    """
    Drop every row and every column of a matrix that contains only zeros.

    Both masks are computed on the input matrix, i.e. a column is judged
    before any rows have been removed.

    Parameters:
        matrix (numpy.ndarray): The input matrix.

    Returns:
        numpy.ndarray: A copy of the matrix without all-zero rows/columns.
    """
    nonzero_entries = matrix != 0
    keep_rows = nonzero_entries.any(axis=1)
    keep_columns = nonzero_entries.any(axis=0)

    return matrix[keep_rows, :][:, keep_columns]

def calc_metrics_from_text(resp_set, n_ts=65, metrics=["cosine_ngrams"], n_ngrams=[1,2,3]):
    """Compute n-gram based diversity scores and Magnipy objects for texts.

    For every n-gram order in ``n_ngrams`` (and only if "cosine_ngrams" is in
    ``metrics``) this:

    * builds a ``Magnipy`` object from the duplicate-free n-gram distance
      matrix with all-zero rows/columns removed; if construction or
      ``get_t_conv`` raises, the matrix is perturbed with random noise and
      the build is retried once;
    * computes ``unique_ngrams``, the negative mean similarity and the Vendi
      score from the distance matrix that keeps duplicates.

    Args:
        resp_set: collection of responses forwarded to ``unique_ngrams`` and
            ``ngram_distance_matrix``.
        n_ts: forwarded to ``Magnipy`` as ``n_ts``.
        metrics: only membership of "cosine_ngrams" is tested.
        n_ngrams: n-gram orders to evaluate.
    Returns:
        ``(mags, values, names)``: the Magnipy objects (one per processed
        order) and two parallel flat lists of scores and score names.
    Raises:
        ValueError: if the diagonal of the reduced distance matrix does not
            sum to zero.
    """
    def _build_magnitude(dist):
        # Construct the magnitude object and trigger the convergence-scale
        # computation whenever there is more than one point.
        magnitude = Magnipy(X=None, D=dist, target_value=None, n_ts=n_ts, log_scale = False, metric="cosine", method="scipy")
        if dist.shape[0] > 1:
            magnitude.get_t_conv()
        return magnitude

    score_names = []
    score_values = []
    magnitudes = []
    for order in n_ngrams:
        suffix = str(round(order))
        if "cosine_ngrams" not in metrics:
            continue

        n_unique = unique_ngrams(resp_set, n=order)
        dedup_dist = remove_zero_rows_and_columns(
            ngram_distance_matrix(resp_set, n=order, drop_duplicates=True)
        )
        if dedup_dist.diagonal().sum() != 0:
            raise ValueError("There are non-zero diagonal elements in the n-gram similarity matrix.")
        if is_singular(dedup_dist) & (dedup_dist.shape[0] > 1):
            print("singular matrix")

        try:
            magnitude = _build_magnitude(dedup_dist)
        except Exception:
            # Best-effort fallback: retry once on a slightly perturbed matrix.
            magnitude = _build_magnitude(add_pertubations(dedup_dist))
        magnitudes.append(magnitude)

        # The similarity-based scores use the matrix that keeps duplicates.
        similarity = dist2similarity(
            ngram_distance_matrix(resp_set, n=order, drop_duplicates=False)
        )
        neg_mean = similarity2negmeansimilarity(similarity)
        vendi_score = vendi.score_K(similarity)
        score_values.extend([n_unique, neg_mean, vendi_score])
        score_names.extend(["unique_ngrams"+suffix, "neg_mean_cosine_ngrams"+suffix, "vendi_cosine_ngrams"+suffix])

    return magnitudes, score_values, score_names


from sklearn.datasets import make_blobs
def load_blob_data(idx, seed=0):
    """Build a two-blob mixture whose composition is controlled by ``idx``.

    Seeds the global NumPy RNG, draws two blobs with ``make_blobs``
    (``n_samples=[600, 300]``, 5 features) and stacks the first
    ``600 - idx`` points labelled 0 on top of the first ``idx`` points
    labelled 1.

    Args:
        idx: number of points taken from the second blob; the same number is
            removed from the first blob.
        seed: seed passed to ``np.random.seed``.
    Returns:
        2-D array made of the selected rows of both blobs.
    """
    np.random.seed(seed)
    points, labels, _centers = make_blobs(n_samples=[600, 300], n_features=5, return_centers=True)
    first_blob = points[labels == 0]
    second_blob = points[labels == 1]
    mixture = (first_blob[:(600 - idx), :], second_blob[:idx, :])
    return np.concatenate(mixture, axis=0)

def calc_metrics_from_responses(load_data, n_samples, n_ts=65, metrics=["cosine_ngrams"], 
                                 #reference_summaries = False,
                                 reference_scale=0.5, scale=True, absolute_area=True):
    """Compute n-gram diversity scores for ``n_samples`` response sets.

    For every index in ``range(n_samples)`` the response set is loaded with
    ``load_data(idx)`` and scored with ``calc_metrics_from_text`` for n-gram
    orders 1, 2 and 3. The 3-gram magnitude functions are then re-evaluated
    on a common scale and summarised by their area and pairwise differences.

    Args:
        load_data: callable mapping an index to a response set.
        n_samples: number of indices to load.
        n_ts: forwarded to ``calc_metrics_from_text`` and used as the number
            of columns of the magnitude-function array.
        metrics: forwarded to ``calc_metrics_from_text``.
        reference_scale: the string "reference" to use the convergence scale
            of the first sample, otherwise the quantile ``q`` passed to
            ``find_common_scale``.
        scale: forwarded to the Magnipy area / difference computations.
        absolute_area: forwarded to the Magnipy area / difference computations.
    Returns:
        dict with keys "summary_statistics" (DataFrame of scores),
        "magnitude_differences", "magnitude_function_dfs" (both keyed by
        metric name) and "mag_cosine_ngrams".
    """
    mags=[]
    mags1=[]
    mags2=[]
    Xs=[]
    rows={}
    for idx in range(n_samples):
        print(idx)
        X = load_data(idx)
        #print(X)
        magss, scores, scores_names = calc_metrics_from_text(X, n_ts=n_ts, metrics=metrics, n_ngrams=[1,2,3])
        # NOTE(review): indexing 0..2 relies on calc_metrics_from_text having
        # produced one magnitude object per n-gram order, which it only does
        # when "cosine_ngrams" is in `metrics`.
        mags.append(magss[0])
        mags1.append(magss[1])
        mags2.append(magss[2])
        rows[idx]=scores
        Xs.append(X)
    
    #ref_scores = pd.DataFrame()

    # scores_names is taken from the last loop iteration.
    df_scores = pd.DataFrame.from_dict(rows, orient='index', columns=scores_names)
    #df_scores=df_scores.loc[:,df_scores.isnull().sum()<10]
    mag_diffs={}
    mag_dfs={}

    # Only the 3-gram magnitudes are post-processed; the other two orders are
    # commented out of this list.
    for met in [#"cosine_ngrams1", 
                #"cosine_ngrams2", 
                "cosine_ngrams3"]:
        t_cut=0
        print(met)
        # .copy() is a shallow list copy: the Magnipy objects themselves are
        # shared and modified in place by change_scales below.
        if met == "cosine_ngrams1":
            funs = mags.copy()
        elif met == "cosine_ngrams2":
            funs = mags1.copy()
        elif met == "cosine_ngrams3":
            funs = mags2.copy()
        else:
            raise ValueError("metric not implemented")
        #print(len(funs))
        if reference_scale == "reference":
            t_cut = funs[0].get_t_conv()
        else:
            t_cut = find_common_scale(funs, n_samples=n_samples, q=reference_scale)
            
        mag_area=[]
        for idx, Mag in enumerate(funs):
            # Re-evaluate every magnitude function up to the shared cut-off.
            Mag.change_scales(t_cut=t_cut)
            mag_area.append(Mag.get_magnitude_area(integration="trapz",
                absolute_area=absolute_area, scale=scale, plot=False))
        df_scores["mag_area_"+met]=mag_area
        df_scores["t_cut_"+met]=t_cut

        mag_diffs[met], mag_dfs[met]=get_mag_diffs_and_dfs(funs, n_samples, n_ts, scale=scale, absolute_area=absolute_area)
        if met == "cosine_ngrams1":
            mags = funs.copy()
        elif met == "cosine_ngrams2":
            mags1 = funs.copy()
        elif met == "cosine_ngrams3":
            mags2 = funs.copy()
        
    # NOTE(review): "mag_cosine_ngrams" returns `mags`, i.e. the 1-gram
    # magnitude objects, which the loop above never rescales; the rescaled
    # 3-gram objects live in `mags2`. Confirm this is intended.
    return {"summary_statistics": df_scores, #"reference_scores": ref_scores, 
            "magnitude_differences": mag_diffs, 
            "magnitude_function_dfs": mag_dfs, "mag_cosine_ngrams": mags}

def find_common_scale(funs, n_samples=None, q=0.5):
    """Pick a shared evaluation scale from a list of magnitude objects.

    Args:
        funs: indexable collection of objects exposing ``get_t_conv()`` and
            ``get_dist()``.
        n_samples: how many leading entries of ``funs`` to consider;
            defaults to ``len(funs)``.
        q: quantile handed to ``np.quantile``. If ``None`` the convergence
            scale of the first entry is returned directly.
    Returns:
        The ``q``-quantile of the convergence scales of all considered
        entries whose distance matrix has more than one row.
    """
    if n_samples is None:
        n_samples = len(funs)

    if q is None:
        return funs[0].get_t_conv()

    # Entries with a single point are skipped.
    convergence_scales = [
        funs[position].get_t_conv()
        for position in range(n_samples)
        if funs[position].get_dist().shape[0] > 1
    ]
    return np.quantile(convergence_scales, q=q)

def calc_metrics_from_embeddings(load_data, n_samples, n_ts=65, metrics=["cosine", "L2", "L1"], 
                                 reference_summaries = False,
                                 reference_scale=0.5, scale=True, absolute_area=True, nearest_k=10, target_scale=None):
    """Compute embedding diversity scores for ``n_samples`` embedding sets.

    Each set is loaded with ``load_data(idx)`` and scored with
    ``calc_metrics_from_embedding``. For every entry of ``metrics`` the
    corresponding magnitude functions are re-evaluated on a common scale and
    summarised by their area, their value at three scales and their
    pairwise differences.

    Args:
        load_data: callable mapping an index to an embedding matrix.
        n_samples: number of indices to load.
        n_ts: forwarded to ``calc_metrics_from_embedding`` and used as the
            number of columns of the magnitude-function array.
        metrics: metric names to process; each must be "cosine", "L2" or "L1".
        reference_summaries: if true, ``compute_ref_metrics`` is run with the
            first loaded set as the reference.
        reference_scale: the string "reference" to use the convergence scale
            of the first sample, otherwise the quantile ``q`` passed to
            ``find_common_scale``.
        scale: forwarded to the Magnipy area / difference computations.
        absolute_area: forwarded to the Magnipy area / difference computations.
        nearest_k: forwarded to ``compute_ref_metrics``.
        target_scale: forwarded to ``calc_metrics_from_embedding``.
    Returns:
        dict with keys "summary_statistics", "reference_scores",
        "magnitude_differences", "magnitude_function_dfs", "mag_cosine",
        "mag_l2" and "mag_l1".
    Raises:
        ValueError: if ``metrics`` contains an unknown name.
    """
    mags=[]
    magsl1=[]
    magsl2=[]
    rows={}
    Xs=[]
    for idx in range(n_samples):
        X = load_data(idx)
        #print(X)
        Mag, MagL2, MagL1, scores, scores_names = calc_metrics_from_embedding(X, n_ts=n_ts, metrics=metrics, target_scale=target_scale)
        mags.append(Mag)
        magsl2.append(MagL2)
        magsl1.append(MagL1)
        rows[idx]=scores
        Xs.append(X)
    
    if reference_summaries:
        # The first loaded set acts as the reference ("real") distribution.
        ref_scores = compute_ref_metrics(Xs, Xs[0], nearest_k=nearest_k)
    else:
        ref_scores = pd.DataFrame()

    # scores_names is taken from the last loop iteration.
    df_scores = pd.DataFrame.from_dict(rows, orient='index', columns=scores_names)
    #df_scores=df_scores.loc[:,df_scores.isnull().sum()<10]
    mag_diffs={}
    mag_dfs={}

    for met in metrics:
        t_cut=0
        print(met)
        # .copy() is a shallow list copy: the Magnipy objects themselves are
        # shared and modified in place by change_scales below.
        if met == "cosine":
            funs = mags.copy()
        elif met == "L2":
            funs = magsl2.copy()
        elif met == "L1":
            funs = magsl1.copy()
        else:
            raise ValueError("metric not implemented")

        if reference_scale == "reference":
            t_cut = funs[0].get_t_conv()
        else:
            t_cut = find_common_scale(funs, n_samples=n_samples, q=reference_scale)
            # t_min / t_med / t_max exist only on this branch; they are read
            # below under the same `reference_scale != "reference"` condition.
            t_min = find_common_scale(funs, n_samples=n_samples, q=0)
            t_med = find_common_scale(funs, n_samples=n_samples, q=0.5)
            t_max = find_common_scale(funs, n_samples=n_samples, q=1)
            
        mag_area=[]
        mag_t_min=[]
        mag_t_med=[]
        mag_t_max=[]
        for idx, Mag in enumerate(funs):
            # Re-evaluate every magnitude function up to the shared cut-off.
            Mag.change_scales(t_cut=t_cut)
            mag_area.append(Mag.get_magnitude_area(integration="trapz",
                absolute_area=absolute_area, scale=scale, plot=False))
            # NOTE(review): _eval_at_scales is a private Magnipy method.
            if reference_scale != "reference":
                mag_t_min.append(Mag._eval_at_scales([t_min], get_weights=False)[0][0])
                mag_t_med.append(Mag._eval_at_scales([t_med], get_weights=False)[0][0])
                mag_t_max.append(Mag._eval_at_scales([t_max], get_weights=False)[0][0])
            else:
                # In "reference" mode the three scales are fixed fractions of
                # the reference cut-off.
                mag_t_min.append(Mag._eval_at_scales([t_cut*0.33], get_weights=False)[0][0])
                mag_t_med.append(Mag._eval_at_scales([t_cut*0.5], get_weights=False)[0][0])
                mag_t_max.append(Mag._eval_at_scales([t_cut], get_weights=False)[0][0])
        #if reference_scale != "reference":
        df_scores["mag_t_min_"+met]=mag_t_min
        df_scores["mag_t_med_"+met]=mag_t_med
        df_scores["mag_t_max_"+met]=mag_t_max
        df_scores["mag_area_"+met]=mag_area
        df_scores["t_cut_"+met]=t_cut

        mag_diffs[met], mag_dfs[met]=get_mag_diffs_and_dfs(funs, n_samples, n_ts, scale=scale, absolute_area=absolute_area)
        # Magnitude difference of every sample to the first one, relative to
        # the first sample's area. This column is written even when
        # reference_summaries is false (ref_scores is then an empty frame).
        ref_scores["mag_diff_"+met]=mag_diffs[met][0,:]/mag_area[0]
        if met == "cosine":
            mags = funs.copy()
        elif met == "L2":
            magsl2 = funs.copy()
        elif met == "L1":
            magsl1 = funs.copy()

    return {"summary_statistics": df_scores, "reference_scores": ref_scores, "magnitude_differences": mag_diffs, 
            "magnitude_function_dfs": mag_dfs, "mag_cosine": mags, "mag_l2": magsl2, "mag_l1":magsl1}

def get_mag_diffs_and_dfs(funs, n_samples, n_ts, scale=True, absolute_area=True):
    """Collect magnitude functions and their pairwise differences.

    Args:
        funs: sequence of objects exposing ``get_magnitude()`` and
            ``get_magnitude_difference(other, scale=, plot=, absolute_area=)``.
        n_samples: side length of the difference matrix and number of rows of
            the magnitude-function array.
        n_ts: number of columns of the magnitude-function array.
        scale: forwarded to ``get_magnitude_difference``.
        absolute_area: forwarded to ``get_magnitude_difference``.
    Returns:
        ``(differences, functions)``: a symmetric ``n_samples x n_samples``
        array with a zero diagonal and an ``n_samples x n_ts`` array whose
        row ``j`` is ``funs[j].get_magnitude()[0]``.
    """
    differences = np.zeros((n_samples, n_samples))
    functions = np.zeros((n_samples, n_ts))
    for row, reference in enumerate(funs):
        functions[row, :] = reference.get_magnitude()[0]
        # Only the upper triangle is computed; the result is mirrored.
        for col, other in enumerate(funs[row + 1:], start=row + 1):
            difference = other.get_magnitude_difference(
                reference, scale=scale, plot=False, absolute_area=absolute_area
            )
            differences[row, col] = difference
            differences[col, row] = difference
    return differences, functions

def save_magnitude_results(mag_results, output_path, reference_summaries = False):
    """Write the results dictionary to disk.

    Files written (all prefixed by ``output_path``):

    * ``_scores.csv`` -- ``mag_results["summary_statistics"]``
    * ``_reference_scores.csv`` -- ``mag_results["reference_scores"]``,
      only if ``reference_summaries`` is true
    * ``_magnitude_results.pkl`` -- pickle of the whole dictionary

    Args:
        mag_results: dict whose "summary_statistics" (and optionally
            "reference_scores") entries provide ``to_csv``.
        output_path: string prefix for the output files.
        reference_summaries: whether to also write the reference scores.
    """
    scores_file = output_path+"_scores.csv"
    mag_results["summary_statistics"].to_csv(scores_file)

    if reference_summaries:
        reference_file = output_path+"_reference_scores.csv"
        mag_results["reference_scores"].to_csv(reference_file)

    pickle_file = output_path+"_magnitude_results"+'.pkl'
    with open(pickle_file, 'wb') as handle:
        pickle.dump(mag_results, handle)
        print('magnitude saved successfully to file')

def compute_pairwise_distance(data_x, data_y=None):
    """
    Euclidean distances between every row of ``data_x`` and ``data_y``.

    Args:
        data_x: numpy.ndarray([N, feature_dim])
        data_y: numpy.ndarray([M, feature_dim]); when omitted, ``data_x`` is
            compared with itself.
    Returns:
        The matrix returned by sklearn's ``pairwise_distances`` (one row per
        row of ``data_x``, one column per row of ``data_y``).
    """
    other = data_x if data_y is None else data_y
    # n_jobs=8: the computation is spread over up to 8 workers.
    return pairwise_distances(data_x, other, metric='euclidean', n_jobs=8)


def get_kth_value(unsorted, k, axis=-1):
    """
    Return the k-th smallest value (1-based) along ``axis``.

    A partial sort places the element of rank ``k`` at index ``k - 1``; that
    element is then selected along the requested axis. Compared with slicing
    the first ``k`` partition indices on the last axis, this works for any
    ``axis`` and also when ``k`` equals the length of that axis.

    Args:
        unsorted: numpy.ndarray of any dimensionality.
        k: int, ``1 <= k <= unsorted.shape[axis]``.
        axis: axis along which to select.
    Returns:
        kth values along the designated axis (that axis is removed).
    Raises:
        ValueError: if ``k`` is outside ``[1, unsorted.shape[axis]]``.
    """
    values = np.asarray(unsorted)
    length = values.shape[axis]
    if not 1 <= k <= length:
        raise ValueError(
            f"k(={k}) out of bounds for axis of length {length}"
        )
    partitioned = np.partition(values, k - 1, axis=axis)
    return np.take(partitioned, k - 1, axis=axis)


def compute_nearest_neighbour_distances(input_features, nearest_k):
    """
    Radius of each sample's k-nearest-neighbour ball.

    Args:
        input_features: numpy.ndarray([N, feature_dim], dtype=np.float32)
        nearest_k: int
    Returns:
        For every sample, the (nearest_k + 1)-th smallest entry of its row in
        the self-distance matrix; the extra 1 accounts for the zero distance
        of a sample to itself.
    """
    self_distances = compute_pairwise_distance(input_features)
    return get_kth_value(self_distances, k=nearest_k + 1, axis=-1)


def compute_ref_metrics(Xs, real_features, nearest_k=10):
    """
    Precision, recall, density and coverage of several sets vs. a reference.

    For each candidate set the metrics are derived from the real-to-candidate
    distance matrix and the k-nearest-neighbour radii of both sets:

    * precision: share of candidate samples lying inside at least one real
      sample's radius;
    * recall: share of real samples lying inside at least one candidate
      sample's radius;
    * density: mean number of real radii containing a candidate sample,
      divided by ``nearest_k``;
    * coverage: share of real samples whose closest candidate sample is
      nearer than their own radius.

    Args:
        Xs: iterable of candidate feature matrices.
        real_features: reference feature matrix.
        nearest_k: int, neighbourhood size.
    Returns:
        pandas.DataFrame indexed by position in ``Xs`` with the columns
        "precision", "recall", "density" and "coverage".
    """
    real_radii = compute_nearest_neighbour_distances(real_features, nearest_k)
    real_radii_column = np.expand_dims(real_radii, axis=1)

    rows = {}
    for position, candidate in enumerate(Xs):
        candidate_radii = compute_nearest_neighbour_distances(candidate, nearest_k)
        cross_distances = compute_pairwise_distance(real_features, candidate)

        # Rows are real samples, columns are candidate samples.
        inside_real_ball = cross_distances < real_radii_column
        inside_candidate_ball = cross_distances < np.expand_dims(candidate_radii, axis=0)

        precision = inside_real_ball.any(axis=0).mean()
        recall = inside_candidate_ball.any(axis=1).mean()
        density = (1. / float(nearest_k)) * inside_real_ball.sum(axis=0).mean()
        coverage = (cross_distances.min(axis=1) < real_radii).mean()

        rows[position] = [precision, recall, density, coverage]

    return pd.DataFrame.from_dict(
        rows, orient='index', columns=["precision", "recall", "density", "coverage"]
    )

def compute_prdc(real_features, fake_features, nearest_k=10):
    """
    Precision, recall, density and coverage of one set vs. a reference.

    Args:
        real_features: reference feature matrix.
        fake_features: feature matrix to evaluate.
        nearest_k: int, neighbourhood size used for the radii.
    Returns:
        Tuple ``(precision, recall, density, coverage)``:
        precision -- share of fake samples inside at least one real radius;
        recall -- share of real samples inside at least one fake radius;
        density -- mean number of real radii containing a fake sample,
        divided by ``nearest_k``;
        coverage -- share of real samples whose closest fake sample is
        nearer than their own radius.
    """
    real_radii = compute_nearest_neighbour_distances(real_features, nearest_k)
    fake_radii = compute_nearest_neighbour_distances(fake_features, nearest_k)
    cross_distances = compute_pairwise_distance(real_features, fake_features)

    # Rows are real samples, columns are fake samples.
    inside_real_ball = cross_distances < np.expand_dims(real_radii, axis=1)
    inside_fake_ball = cross_distances < np.expand_dims(fake_radii, axis=0)

    precision = inside_real_ball.any(axis=0).mean()
    recall = inside_fake_ball.any(axis=1).mean()
    density = (1. / float(nearest_k)) * inside_real_ball.sum(axis=0).mean()
    coverage = (cross_distances.min(axis=1) < real_radii).mean()

    return precision, recall, density, coverage