import os

import numpy as np
import pandas as pd

from predict_sum import sum_predictor
from DisentanglingInfluence.Influence.influence import DR_influence
from DisentanglingInfluence.Disentangling.disentangle import disentangle
from DisentanglingInfluence.utils import DataGenerator

# How many leading data points get an influence explanation
# (explaining the full dataset takes a while).
N = 1000

# Directory handed to disentangle() as its output_dir.
output_dir = "../../outputs/synthetic_test/"

# Load the synthetic dataset and name the columns this script works with.
data = pd.read_csv("../../data/synthetic/sum_synthetic.csv")
feature_names = [
    "x", "x2", "xSquared",
    "y", "y2", "ySquared",
    "z", "z2", "zSquared",
]
label_name = "xPy_Label"

# Black-box regressor for the label (x+y); it is saved to disk right away.
# NOTE(review): presumably sum_predictor() trains (or loads) the model when
# constructed -- confirm in predict_sum.
clf = sum_predictor()
clf.save("./models/sum_predictor.h5")


def mse(x, y):
    """Return the mean squared error between *x* and *y*.

    Both arguments may be NumPy arrays or plain sequences of numbers; they
    are passed through ``np.asarray`` and must broadcast against each other.
    The mean is taken over every element of the squared difference, so the
    result is a single scalar regardless of the input dimensionality.
    """
    diff = np.asarray(x) - np.asarray(y)
    return np.mean(diff ** 2)

# x and y are treated jointly as the protected attributes.
protected_feats = ["x", "y"]
print("Calculating Influence for [x,y] jointly...")

# Pull the arrays the models consume out of the data frame. Note that
# `features` holds every column in feature_names, x and y included.
features = data[feature_names].values
protected = data[protected_feats].values
# Look the label column up through label_name so the column name is
# defined in exactly one place.
labels = data[label_name].values
n_instances, n_feats = features.shape
gen = DataGenerator([features, protected], batch_size=16)

# train the models to disentangle the data
FullModel, Enc, Dec, Disc, AutoEncoder = disentangle(
    data_generator=gen,
    latent_dim=1,
    disc_weight=0.5,
    n_feats=n_feats,
    n_protected=len(protected_feats),
    train_steps=8000,
    enc_layer_sizes=[],
    dec_layer_sizes=[3],
    disc_layer_sizes=[5, 5],
    dec_final_activ=None,
    disc_final_activ=None,
    output_dir=output_dir,
)

# Encode the features, rebuild them from the disentangled representation,
# and score how faithful the reconstruction is.
unprotected = Enc.predict(features)  # latent representation
dis_rep = [unprotected, protected]  # full disentangled representation
autoencoded = Dec.predict(dis_rep)  # reconstructed original features
reconstruction_error = mse(features, autoencoded)

# How well the discriminator recovers the protected values from the
# latent representation alone.
phat = Disc.predict(unprotected)  # revealed protected information
discriminator_mse = mse(phat, protected)

# Compare the black box's output on the original features with its output
# on the reconstructed ones.
preds = clf.predict(features)
reconstructed_preds = clf.predict(autoencoded)
prediction_error = mse(preds, reconstructed_preds)

# Only the first N rows of each representation are explained.
explain = [rep[:N] for rep in dis_rep]

# calculate the influence!
influence = DR_influence(
    decoder=Dec,
    black_box=clf,
    disentangled_reps=dis_rep,
    labels=labels,
    explain=explain,
)

# NOTE(review): influence[0][1] is taken to be the influence of the protected
# part of the representation, with columns ordered like protected_feats
# (x, then y) -- confirm against DR_influence.
protected_influence = influence[0][1]
x_joint_influence, y_joint_influence = (
    protected_influence[:, col].flatten() for col in (0, 1)
)


# Error metrics are only printed to the console, not written to disk.
print("Reconstruction Error:", reconstruction_error)
print("Prediction Error:", prediction_error)
print("discriminator MSE:", discriminator_mse)

# Collect the joint influences of x and y, one column each.
influence_dict = {
    "x_joint_influence": x_joint_influence,
    "y_joint_influence": y_joint_influence,
}
influences = pd.DataFrame.from_dict(influence_dict)

# to_csv does not create missing directories; make sure the target exists
# so the results of the long training run are not lost.
os.makedirs("results", exist_ok=True)
influences.to_csv("results/XY_shap_values.csv", index=False)


