from matplotlib import rc
import pandas as pd
import numpy as np
import shap
import matplotlib.pyplot as plt
from keras.models import load_model


def plot_shap_vals(
	shapvals_filepath, 
	data_filepath, 
	data_indices_filepath,
	indirect_output_filepath, 
	n_instances):
	"""Save a SHAP summary (dot) plot for the first ``n_instances`` instances.

	Args:
		shapvals_filepath: Path to a CSV of SHAP values; its first
			``n_instances`` rows are plotted.
		data_filepath: Path to a NumPy archive containing a
			``"latents_values"`` array. The first column of that array is
			dropped before use.
		data_indices_filepath: Path to a CSV with a ``"data_indices"`` column
			holding the row indices to select from the latent values.
			Presumably these are in the same order as the rows of the SHAP
			CSV -- verify against the code that produced both files.
		indirect_output_filepath: Destination passed to ``plt.savefig``.
		n_instances: Number of leading rows of both the features and the SHAP
			values to include in the plot.

	Side effects:
		Changes global matplotlib rc settings (font, ``text.usetex``), calls
		``shap.initjs()``, prints the shapes of the two plotted arrays, writes
		the figure to ``indirect_output_filepath`` and clears the current
		figure.
	"""
	rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']})
	# This second call replaces the font family selected above with serif /
	# Palatino; the sans-serif font list set by the first call stays in place.
	rc('font',**{'family':'serif','serif':['Palatino']})
	rc('text', usetex=True)

	shap.initjs()

	# Use the archive as a context manager so its file handle is released as
	# soon as the array has been read (the original never closed it).
	with np.load(data_filepath, mmap_mode="r", encoding="bytes") as zip_package:
		latents_values = zip_package["latents_values"][:,1:]

	idxs = pd.read_csv(data_indices_filepath)["data_indices"].values

	# match indices with training
	# Integer-array indexing returns a copy, so the in-place column write
	# below does not touch the array loaded from the archive.
	latents_values = latents_values[idxs]

	# replace shape with a 0/1 indicator (shape == Heart)
	# NOTE(review): this treats the value 1 in column 0 as "heart" -- confirm
	# against the encoding used when the dataset file was generated.
	latents_values[:,0] = (latents_values[:,0] == 1).astype(int)

	shap_vals_df = pd.read_csv(shapvals_filepath)
	# Five display names; assumes five latent columns remain after dropping
	# the first one -- TODO confirm against the dataset.
	feat_names = ["shape", "scale", "orient.", "x pos.", "y pos."]

	feats = latents_values[:n_instances]
	shap_vals = shap_vals_df.iloc[:n_instances].values
	print(feats.shape)
	print(shap_vals.shape)

	# show=False keeps the figure current so it can be saved below.
	shap.summary_plot(shap_vals, feats, show=False, plot_type="dot", sort=False, feature_names= feat_names)

	plt.savefig(indirect_output_filepath)
	plt.clf()



if __name__ == "__main__":
	# Only produce the plot when this file is executed as a script, so that
	# importing the module does not trigger file reads and plotting.
	plot_shap_vals(shapvals_filepath="results_2/shap_values.csv", 
		data_filepath="../../data/dsprites/reduced_dsprites_16.npz",
		data_indices_filepath="results_2/data_indices.csv",
		indirect_output_filepath="results_2/indirect_influence_distributions2.png",
		n_instances=3000)




