import random 
import numpy as np
from scipy.special import expit
import pandas as pd
import xgboost as xgb
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
from sklearn.metrics import mean_squared_error
import copy
import matplotlib.pyplot as plt

from global_functions import *


def sigmoid(x):
	"""Return the logistic function 1 / (1 + exp(-x)) of x, element-wise.

	Delegates to scipy.special.expit (imported at the top of this file) so that
	large negative inputs do not trigger the overflow RuntimeWarning that the
	literal formula raises inside np.exp.  Scalars, NumPy arrays and pandas
	Series are all accepted.
	"""
	return expit(x)

def pi_star_X1(C1, C2):
	"""Probability that X1 = 1 given (C1, C2); generate_samples uses it when S == 0."""
	logit = 1*(C1+C2) - 2
	return sigmoid(logit)

def pi_X1(C1, C2):
	"""Probability that X1 = 1 given (C1, C2); generate_samples uses it when S != 0."""
	logit = 0.5*(C1+C2) - 1
	return sigmoid(logit)

def pi_star_X2(X1, W, C1, C2):
	"""Probability that X2 = 1 given (X1, W, C1, C2); generate_samples uses it when S == 0."""
	covariate_term = 0.5*(C1+C2)
	x1_term = 2*(2*X1-1)  # X1 in {0, 1} is recoded to {-1, +1} before scaling
	logit = covariate_term + x1_term - 0.5*W + 1
	return sigmoid(logit)

def pi_X2(X1, W, C1, C2):
	"""Probability that X2 = 1 given (X1, W, C1, C2); generate_samples uses it when S != 0."""
	covariate_term = 1*(C1+C2)
	x1_term = 1*(2*X1-1)  # X1 in {0, 1} is recoded to {-1, +1}
	logit = covariate_term + x1_term + 0.5*W - 1
	return sigmoid(logit)

def generate_samples(n0, n1, n2, seednum):
	"""Simulate one target sample (S=0) and two source samples (S=1, S=2).

	Each sample is returned as a DataFrame with the columns
	['C1', 'C2', 'X1', 'W', 'X2', 'Y']; n0, n1 and n2 are the respective row
	counts.  NumPy's global generator is seeded with seednum (and Python's
	`random` module with 123) before anything is drawn, and the samples are
	drawn in the order target, source 1, source 2.

	Returns the tuple (data_S0, data_S1, data_S2).
	"""
	np.random.seed(seednum)
	random.seed(123)

	column_names = ['C1', 'C2', 'X1', 'W', 'X2', 'Y']

	def draw_covariate(size, S):
		# The domain-dependent perturbation is drawn before the base noise;
		# the draw order matters for reproducibility under a fixed seed.
		domain_part = 0.25*S*np.random.normal(loc=0, scale=1, size=(size)) + 0.1*S
		base_part = np.random.normal(loc=0, scale=0.5, size=(size))
		return base_part + domain_part

	def draw_binary(prob):
		# One Bernoulli draw per entry of prob, returned as 0/1 integers.
		return (np.random.rand(len(prob)) < prob).astype(int)

	def simulate(size, S_C, S_X1, S_W, S_X2, S_Y):
		# S_* is the domain index that governs each mechanism of this sample.
		# Four latent noise vectors are drawn; only the first (enters W) and
		# the last (enters Y) appear in the equations below.
		U_X1W, _U_X1X2, _U_X2W, U_X2Y = [np.random.normal(size = size) for _ in range(4)]

		C1 = draw_covariate(size, S_C)
		C2 = draw_covariate(size, S_C)

		policy_X1 = pi_star_X1 if S_X1 == 0 else pi_X1
		X1 = draw_binary(policy_X1(C1, C2))

		noise_W = 0.25*S_W*np.random.normal(loc=0, scale=1, size=(size))
		W = sigmoid(0.5*(C1+C2)-1 + 3*X1 + 0.5*U_X1W + noise_W )

		policy_X2 = pi_star_X2 if S_X2 == 0 else pi_X2
		X2 = draw_binary(policy_X2(X1, W, C1, C2))

		noise_Y = 0.25*S_Y*np.random.normal(loc=0, scale=1, size=(size))
		Y = sigmoid( 0.5*(C1 + C2) + (2*X1-1) + (2*X2-1) - 0.5*W + 0.1*U_X2Y + noise_Y)

		return pd.DataFrame(np.column_stack((C1, C2, X1, W, X2, Y)), columns=column_names)

	# Target: every mechanism is the S=0 one.
	data_S0 = simulate(n0, S_C=0, S_X1=0, S_W=0, S_X2=0, S_Y=0)
	# Source 1: W is generated with the S=0 mechanism (W is invariant in S=0 and S=1).
	data_S1 = simulate(n1, S_C=1, S_X1=1, S_W=0, S_X2=1, S_Y=1)
	# Source 2: Y is generated with the S=0 mechanism (Y is invariant in S=0 and S=2).
	data_S2 = simulate(n2, S_C=2, S_X1=2, S_W=2, S_X2=2, S_Y=0)

	return data_S0, data_S1, data_S2

def evaluate_DML(data_S0_cov, data_S1, data_S2, seednum, L=5, add_noise_TF = False):
	"""Cross-fitted OM, PW and DML estimates of the mean of Y in the target domain.

	The rows of data_S1 are split into L shuffled folds and the same positional
	indices are applied to data_S2, so both source samples are expected to have
	the same number of rows.  In every fold the nuisance models (mu2, mu1 and
	the density-ratio classifiers) are fitted on the training rows only, and
	the PW weights and the DML correction terms are evaluated on the held-out
	rows, so that every weight is multiplied with the regression value of the
	same row.

	Per fold:
	- OM  = mean over data_S0_cov of check_mu1(C1, C2)
	- PW  = mean over held-out S=2 rows of omega_2 * Y
	- DML = OM + PW - mean(omega_2 * check_mu2) over held-out S=2 rows
	        + mean(omega_1 * (check_mu2 - check_mu1)) over held-out S=1 rows

	Parameters
	----------
	data_S0_cov : DataFrame holding the target-domain covariates 'C1' and 'C2'.
	data_S1, data_S2 : DataFrames with columns 'C1', 'C2', 'X1', 'W', 'X2', 'Y'.
	seednum : seed for NumPy's global generator; KFold is built without its own
		random_state, so the fold shuffling depends on this seed.
	L : number of cross-fitting folds.
	add_noise_TF : flag forwarded to add_noise() for every nuisance prediction
		and weight.

	Returns
	-------
	Tuple (OM, PW, DML), each averaged over the L folds.

	learn_mu, estimate_odds_ratio and add_noise are not defined in this file;
	presumably they come from `global_functions` -- see that module for their
	exact contracts.
	"""
	def compute_check_mu2(col_feature_mu2, mu2_model, data):
		'''
		Compute check_mu2(C1,C2,X1,W) := sum_{x2} mu2(C1,C2,X1,W,x2) * pi_star_X2(X1,W,C1,C2)
		for every row of data; returned as a NumPy array.
		'''
		def predict_with_X2(x2):
			# Evaluate mu2 with X2 forced to x2 on a copy, leaving data untouched
			data_x2 = data.copy()
			data_x2['X2'] = x2
			return add_noise( mu2_model.predict(xgb.DMatrix(data_x2[col_feature_mu2])), add_noise_TF )

		mu2_X2_x0 = predict_with_X2(0)
		mu2_X2_x1 = predict_with_X2(1)
		pi_star_X2_val = np.array( pi_star_X2(data['X1'], data['W'], data['C1'], data['C2']) )
		return (mu2_X2_x1 * pi_star_X2_val) + (mu2_X2_x0 * (1-pi_star_X2_val))

	def compute_check_mu1(col_feature_mu1, mu1_model, data):
		'''
		Compute check_mu1(C1,C2) := sum_{x1} mu1(C1,C2,x1) * pi_star_X1(C1,C2)
		for every row of data; returned as a NumPy array.
		'''
		def predict_with_X1(x1):
			# Evaluate mu1 with X1 forced to x1 on a copy, leaving data untouched
			data_x1 = data.copy()
			data_x1['X1'] = x1
			return add_noise( mu1_model.predict(xgb.DMatrix(data_x1[col_feature_mu1])), add_noise_TF )

		mu1_X1_x0 = predict_with_X1(0)
		mu1_X1_x1 = predict_with_X1(1)
		# Converted to an array so the result has the same type as compute_check_mu2's
		pi_star_X1_val = np.array( pi_star_X1(data['C1'], data['C2']) )
		return (mu1_X1_x1 * pi_star_X1_val) + mu1_X1_x0 * (1-pi_star_X1_val)

	def density_ratio(data_num, data_den, cols, n_ref, data_eval, params):
		'''
		Estimate P_num(cols) / P_den(cols) at the rows of data_eval.
		Following the convention used throughout this function, the classifier's
		prediction is read as P(S=1 | cols), where S=0 marks data_num and S=1
		marks data_den, so the ratio is (1 - pred) / pred.
		'''
		model = estimate_odds_ratio(data_num, data_den, cols, n_ref, params)
		pred = model.predict(xgb.DMatrix(data_eval[cols]))
		return (1-pred)/(pred)

	np.random.seed(seednum)
	random.seed(123)

	# Regression setup shared by mu2 (Y on C1,C2,X1,W,X2) and mu1 (check_mu2 on C1,C2,X1)
	col_feature_mu2 = ['C1','C2','X1','W','X2']
	col_label_mu2 = ['Y']
	col_feature_mu1 = ['C1','C2','X1']
	col_label_mu1 = ['check_mu2']
	mu_params = {
		'booster': 'gbtree',
		'eta': 0.3,
		'gamma': 0,
		'max_depth': 10,
		'min_child_weight': 1,
		'subsample': 1.0,
		'colsample_bytree': 1,
		'lambda': 0.0,
		'alpha': 0.0,
		'objective': 'reg:squarederror',
		'eval_metric': 'rmse',
		'n_jobs': 4  # Assuming you have 4 cores
	}
	# NOTE(review): 'subsample': 0.0 lies outside the (0, 1] range documented by
	# XGBoost.  The value is kept as it was -- confirm how estimate_odds_ratio
	# uses it before relying on the density-ratio estimates.
	lambda_params_template = {
		'booster': 'gbtree',
		'eta': 0.05,
		'gamma': 0,
		'max_depth': 6,
		'min_child_weight': 1,
		'subsample': 0.0,
		'colsample_bytree': 1,
		'objective': 'binary:logistic',  # Change as per your objective
		'eval_metric': 'logloss',  # Change as per your needs
		'reg_lambda': 0.0,
		'reg_alpha': 0.0,
		'nthread': 4
	}

	results_OM = []
	results_PW = []
	results_DML = []
	kf = KFold(n_splits=L, shuffle=True)

	for train_index, test_index in kf.split(data_S1):
		# Split the samples data_S1 and data_S2 with the same positional indices
		data_S1_train, data_S1_test = data_S1.iloc[train_index], data_S1.iloc[test_index]
		data_S2_train, data_S2_test = data_S2.iloc[train_index], data_S2.iloc[test_index]
		# Fresh dict per fold, shared by the density-ratio fits of this fold
		lambda_params = dict(lambda_params_template)

		'''
		Estimate OM
		'''
		# Learn mu2(C1,C2,X1,W,X2) := E[Y | C1,C2,X1,W,X2] from the S=2 training rows.
		# Only the training fold is used so that mu2 is not fitted on the held-out
		# S=2 rows it is evaluated on in the DML correction below.
		mu2_model = learn_mu(data_S2_train, col_feature_mu2, col_label_mu2, dict(mu_params))

		# check_mu2 on the S=1 training rows is the regression target for mu1
		check_mu2_S1_train = compute_check_mu2(col_feature_mu2, mu2_model, data_S1_train)

		# Learn mu1(C1,C2,X1) := E[check_mu2 | C1,C2,X1] from the S=1 training rows
		data_S1_train_mu1 = data_S1_train.copy()
		data_S1_train_mu1['check_mu2'] = check_mu2_S1_train
		mu1_model = learn_mu(data_S1_train_mu1, col_feature_mu1, col_label_mu1, dict(mu_params))

		# check_mu1 evaluated on the target covariates
		check_mu1_S0 = compute_check_mu1(col_feature_mu1, mu1_model, data_S0_cov)

		# OM
		result_OM = np.clip( np.mean(check_mu1_S0), 0, 1)
		results_OM.append(result_OM)

		'''
		Estimate PW
		'''
		# omega_2(C1,C2,X1,W,X2) = (P0(C1,C2) / P2(C1,C2)) * (pi_star_X1(C1,C2) / pi_X1(C1,C2))
		#   * {(P1(W,C1,C2,X1)/P2(W,C1,C2,X1)) * (P2(C1,C2,X1)/P1(C1,C2,X1))}
		#   * (pi_star_X2(X1,W,C1,C2) / pi_X2(X1,W,C1,C2)),
		# fitted on training rows and evaluated on the held-out S=2 rows.
		ratio_P0_over_P2_C1C2 = density_ratio(data_S0_cov, data_S2_train, ['C1','C2'], len(data_S2_train), data_S2_test, lambda_params)
		ratio_P1_over_P2_WC1C2X1 = density_ratio(data_S1_train, data_S2_train, ['W','C1','C2','X1'], len(data_S2_train), data_S2_test, lambda_params)
		ratio_P2_over_P1_C1C2X1 = density_ratio(data_S2_train, data_S1_train, ['C1','C2','X1'], len(data_S2_train), data_S2_test, lambda_params)

		## pi_star_X1(C1,C2) / pi_X1(C1,C2) on the held-out S=2 rows
		pi_X1_ratio_S2 = np.array( pi_star_X1(data_S2_test['C1'], data_S2_test['C2']) / pi_X1(data_S2_test['C1'], data_S2_test['C2']) )

		## pi_star_X2(X1,W,C1,C2) / pi_X2(X1,W,C1,C2) on the held-out S=2 rows
		pi_X2_ratio_S2 = np.array( pi_star_X2(data_S2_test['X1'], data_S2_test['W'], data_S2_test['C1'], data_S2_test['C2']) / pi_X2(data_S2_test['X1'], data_S2_test['W'], data_S2_test['C1'], data_S2_test['C2']) )

		# PW
		omega_2 = add_noise( ratio_P0_over_P2_C1C2 * (ratio_P1_over_P2_WC1C2X1 * ratio_P2_over_P1_C1C2X1) * pi_X1_ratio_S2 * pi_X2_ratio_S2, add_noise_TF)
		result_PW = np.clip( np.mean(data_S2_test['Y'] * omega_2), 0, 1)
		results_PW.append(result_PW)

		'''
		Estimate DML
		OM + PW - E_{S2}[omega2 check_mu2_S2] + E_{S1}[omega1 (check_mu2_S1 - check_mu1_S1)]
		'''
		# check_mu2 on the held-out S=2 rows: the same rows omega_2 was computed on,
		# so the element-wise product pairs each weight with its own row.
		check_mu2_S2_test = compute_check_mu2(col_feature_mu2, mu2_model, data_S2_test)

		# omega_1 := {P0(C) / P1(C)} * {pi_star_X1(C) / pi_X1(C)} on the held-out S=1 rows
		ratio_P0_over_P1_C1C2 = density_ratio(data_S0_cov, data_S1_train, ['C1','C2'], len(data_S1_train), data_S1_test, lambda_params)
		pi_X1_ratio_S1 = np.array( pi_star_X1(data_S1_test['C1'], data_S1_test['C2']) / pi_X1(data_S1_test['C1'], data_S1_test['C2']) )
		omega_1 = add_noise( ratio_P0_over_P1_C1C2 * pi_X1_ratio_S1, add_noise_TF)

		# check_mu2 and check_mu1 on the held-out S=1 rows, aligned with omega_1
		check_mu2_S1_test = compute_check_mu2(col_feature_mu2, mu2_model, data_S1_test)
		check_mu1_S1_test = compute_check_mu1(col_feature_mu1, mu1_model, data_S1_test)

		correction_S2 = np.mean(omega_2 * check_mu2_S2_test)
		correction_S1 = np.mean(omega_1 * (check_mu2_S1_test - check_mu1_S1_test))
		result_DML = np.clip( (result_OM + result_PW) - correction_S2 + correction_S1, 0, 1)
		results_DML.append(result_DML)

	return np.mean(results_OM), np.mean(results_PW), np.mean(results_DML)

def performance(truth, est_OM, est_PW, est_DML):
	"""Collect the estimates and their absolute errors against truth.

	Returns (table_data, error_data): table_data maps 'Truth', 'OM', 'PW' and
	'DML' to the given values; error_data maps 'OM', 'PW' and 'DML' to
	|truth - estimate|.
	"""
	estimates = {'OM': est_OM, 'PW': est_PW, 'DML': est_DML}
	table_data = {'Truth': truth, **estimates}
	error_data = {method: np.abs(truth - value) for method, value in estimates.items()}
	return table_data, error_data


if __name__ == "__main__":
	# Script entry point: for every sample size in n_list, run
	# rounds_simulations simulations, record the absolute error of the OM, PW
	# and DML estimates against the target-sample mean of Y, and plot the mean
	# error with its confidence margin per sample size.
	import os  # only needed here, to create the output directory

	experiment_seed = 190602
	n_list = [2500, 5000, 10000, 20000]
	rounds_simulations = 20
	L = 2
	add_noise_TF = False
	n0 = 1000000

	location_fig = "experiments/plot/"
	image_name = "plot_seed" + str(experiment_seed) + "_r" + str(rounds_simulations) + "_noise" + str(add_noise_TF) + "_gTR"

	# Create the output folder up front so a missing directory cannot make
	# savefig fail after the whole simulation has already run.
	os.makedirs(location_fig, exist_ok=True)

	# Seed before drawing the per-round seeds; otherwise the run is not
	# reproducible even though experiment_seed appears in the output filename.
	np.random.seed(experiment_seed)
	seednum_list = np.random.randint(1000000, size=rounds_simulations)
	avg_acc = {"OM":[], "PW":[], "DML": []}
	ci_acc = {"OM":[], "PW":[], "DML": []}

	total_runs = len(seednum_list) * len(n_list)
	completed_runs = 0

	for n in n_list:
		n1 = n2 = n
		avg_acc_at_n = {"OM":[], "PW":[], "DML": []}
		for seednum in seednum_list:
			data_S0, data_S1, data_S2 = generate_samples(n0, n1, n2, seednum)
			data_S0_cov = data_S0[['C1','C2']]

			# The mean of Y in the (large) simulated target sample serves as the ground truth
			truth = np.mean(data_S0['Y'])
			est_OM, est_PW, est_DML = evaluate_DML(data_S0_cov, data_S1, data_S2, seednum, L, add_noise_TF)

			table_data, error_data = performance(truth, est_OM, est_PW, est_DML)
			avg_acc_at_n["OM"].append( error_data["OM"] )
			avg_acc_at_n["PW"].append( error_data["PW"] )
			avg_acc_at_n["DML"].append( error_data["DML"] )

			completed_runs += 1
			print(("%.3f%% completed") % (completed_runs / total_runs * 100))

		for method in ['OM', 'PW', 'DML']:
			mean, margin_of_error = mean_confidence_interval(avg_acc_at_n[method])
			avg_acc[method].append(mean)
			ci_acc[method].append(margin_of_error)

	# Values of the last simulation run only
	print(table_data)
	print(error_data)

	mean_OM, err_OM = avg_acc['OM'], ci_acc['OM']
	mean_PW, err_PW = avg_acc['PW'], ci_acc['PW']
	mean_DML, err_DML = avg_acc['DML'], ci_acc['DML']

	# Plotting with confidence intervals
	plt.figure(figsize=(12, 10))  # 12 inches wide, 10 inches tall
	plt.errorbar(n_list, mean_OM, yerr=err_OM, label='OM', marker='o', capsize=10, markersize=16, linewidth=5, elinewidth=3)
	plt.errorbar(n_list, mean_PW, yerr=err_PW, label='PW', marker='o', capsize=10, markersize=16, linewidth=5, elinewidth=3)
	plt.errorbar(n_list, mean_DML, yerr=err_DML, label='DML', marker='o', capsize=10, markersize=16, linewidth=5, elinewidth=3)
	plt.xticks(ticks=n_list, labels=n_list, size=35)
	# plt.ylabel("Error", fontsize=35)
	plt.yticks(size=45)
	plt.legend(prop={'size': 30})
	plt.grid(False)
	plt.savefig(location_fig + image_name + ".pdf")
	plt.show()









