from matplotlib import rc
import pandas as pd
import shap
import matplotlib.pyplot as plt
from shap.plots import colors

# Feature names. They are used as the per-feature suffix of the
# reconstruction-error CSV files (see get_reconstruction_filename) and match
# the column names selected in make_plot.
FEATURES = [
    "x", "x2", "xSquared",
    "y", "y2", "ySquared",
    "z", "z2", "zSquared",
]
# Path prefix of the per-feature reconstruction-error CSVs; a full path is
# ``<stub><feature>.csv``.
RECONSTRUCTION_ERROR_FILENAME_STUB = "trained/csv_files/reconstruction_error_"

def get_reconstruction_filename(stub, featurename):
    """Build the CSV path ``<stub><featurename>.csv``.

    Args:
        stub: path prefix, e.g. RECONSTRUCTION_ERROR_FILENAME_STUB.
        featurename: feature name appended to the prefix.

    Returns:
        The concatenated path string.
    """
    parts = (stub, featurename, ".csv")
    return "".join(parts)

def plot_shap_vals(
    shapvals_filepath,
    data_filepath,
    indirect_output_filepath,
    n_instances):
    """Configure global matplotlib text settings, then plot SHAP values from CSV.

    Args:
        shapvals_filepath: CSV of SHAP values (read first).
        data_filepath: CSV of feature values.
        indirect_output_filepath: image path the plot is saved to.
        n_instances: number of leading rows passed on to make_plot.

    Side effects:
        Mutates the global matplotlib rc settings; these persist after the
        call and also apply to later plots made in the same process.
    """
    # The first call registers Helvetica as the sans-serif font; the second
    # switches the default family to serif (Palatino). font.family ends up
    # 'serif' while font.sans-serif keeps its Helvetica entry.
    rc('font', family='sans-serif', **{'sans-serif': ['Helvetica']})
    rc('font', family='serif', serif=['Palatino'])
    # Render all text through LaTeX.
    rc('text', usetex=True)

    shap.initjs()

    # Arguments are evaluated left to right, so the SHAP file is read first.
    make_plot(
        pd.read_csv(shapvals_filepath),
        pd.read_csv(data_filepath),
        indirect_output_filepath,
        n_instances,
    )

def make_plot(shap_df, feats_df, outfile, n_instances):
    """Draw a SHAP dot summary plot for the nine features and save it.

    Args:
        shap_df: DataFrame of SHAP values containing the feature columns below.
        feats_df: DataFrame of feature values containing the same columns;
            its absolute values are passed to shap as the feature matrix.
        outfile: path handed to ``plt.savefig``.
        n_instances: number of leading rows, by position, taken from each frame.
    """
    # Column name -> axis label, in plot order. Keeping both in one mapping
    # guarantees the labels stay aligned with the columns handed to shap.
    feature_labels = {
        "x": r"$x$",
        "x2": r"$2x$",
        "xSquared": r"$x^2$",
        "y": r"$y$",
        "y2": r"$2y$",
        "ySquared": r"$y^2$",
        "z": r"$c$",
        "z2": r"$2c$",
        "zSquared": r"$c^2$",
    }
    feat_names = list(feature_labels)

    # .iloc makes the row selection explicitly positional. Plain df[0:n] is
    # only positional for integer-like indexes, and callers may pass frames
    # built with pd.concat whose index is non-unique.
    feats = feats_df[feat_names].iloc[:n_instances].abs().values
    shap_vals = shap_df[feat_names].iloc[:n_instances].values

    shap.summary_plot(
        shap_vals,
        feats,
        show=False,
        plot_type="dot",
        sort=False,
        feature_names=list(feature_labels.values()),
    )

    plt.savefig(outfile)
    # Clear (but do not close) the current figure after saving.
    plt.clf()

def combine_error_percol(filename_stub, features):
    """Stack the per-feature reconstruction-error CSVs into one DataFrame.

    For each name in *features* (in order), reads
    ``<filename_stub><feature>.csv`` and concatenates the frames row-wise.

    Args:
        filename_stub: path prefix passed to get_reconstruction_filename.
        features: iterable of feature names.

    Returns:
        The concatenated DataFrame. Each file's index is kept as read (not
        reset), so row labels from different files can repeat.

    Raises:
        ValueError: from ``pd.concat`` when *features* is empty.
    """
    frames = [
        pd.read_csv(get_reconstruction_filename(filename_stub, feature))
        for feature in features
    ]
    return pd.concat(frames)

def copy_shap_df(shapfile, features):
    """Return the SHAP-value table stacked once per entry in *features*.

    This mirrors the per-feature stacking done by combine_error_percol.
    The CSV is read from disk only once; the same frame is then repeated.

    Args:
        shapfile: path of the SHAP-values CSV.
        features: iterable whose length sets the number of copies (its
            values are not used).

    Returns:
        DataFrame with the file's rows repeated once per feature. The index
        is not reset, so row labels repeat.

    Raises:
        ValueError: from ``pd.concat`` when *features* is empty.
    """
    shap_df = pd.read_csv(shapfile)
    # Iterate (rather than len()) so any iterable of features still works.
    frames = [shap_df for _ in features]
    return pd.concat(frames)


# Alternative inputs for plot_shap_vals, kept for reference:
#plot_shap_vals(shapvals_filepath="results/shap_values.csv",
#    data_filepath="../../data/synthetic/sum_synthetic.csv",
#    indirect_output_filepath="results/indirect_influence_distributions.png",
#    n_instances=3000) # @sorelle, you probably don't want to change the n_instances

#plot_shap_vals(shapvals_filepath="trained/csv_files/shap_values.csv",
#    data_filepath="trained/csv_files/discriminator_error.csv",
#    indirect_output_filepath="trained/figures/discriminatorerror_influence_distributions.png",
#    n_instances=3000) # @sorelle, you probably don't want to change the n_instances

# Superseded by the reconstruction-error plot in the __main__ block below
# (``filename`` is not defined anywhere in this script):
#plot_shap_vals(shapvals_filepath="trained/csv_files/shap_values.csv",
#    data_filepath=filename,
#    indirect_output_filepath="trained/figures/reconstructionerror_influence_distributions.png",
#    n_instances=3000) # @sorelle, you probably don't want to change the n_instances


# Only produce the figures when run as a script, not on import.
if __name__ == "__main__":
    # @sorelle, you probably don't want to change the n_instances
    plot_shap_vals(
        shapvals_filepath="trained/csv_files/shap_values.csv",
        data_filepath="trained/csv_files/prediction_error.csv",
        indirect_output_filepath="trained/figures/predictionerror_influence_distributions.png",
        n_instances=3000,
    )

    # make_plot is called directly here, so it relies on the global
    # matplotlib rc settings (fonts, usetex) applied by plot_shap_vals above.
    feats_df = combine_error_percol(RECONSTRUCTION_ERROR_FILENAME_STUB, FEATURES)
    shap_df = copy_shap_df("trained/csv_files/shap_values.csv", FEATURES)
    # NOTE(review): feats_df stacks one file per feature, but make_plot keeps
    # only the first 3000 rows by position. If the first file already has
    # >= 3000 rows, the other features' errors are never plotted. Confirm
    # this is intended.
    make_plot(
        shap_df,
        feats_df,
        "trained/figures/reconstructionerror_influence_distributions.png",
        3000,
    )

