import numpy
import numpy as np
from algorithms import *
from examples import *
import matplotlib.pyplot as plt

def compute_error_bars(x,low,high):
    """Return asymmetric error-bar offsets around the median of each sample.

    Args:
        x: sequence of samples; ``x[i]`` is a 1-D collection of numbers.
        low: lower percentile (0-100) for the bottom of each bar.
        high: upper percentile (0-100) for the top of each bar.

    Returns:
        A ``(2, len(x))`` float array. Row 0 holds ``median - percentile(low)``
        and row 1 holds ``percentile(high) - median``. This is the layout
        ``matplotlib.pyplot.errorbar`` accepts for ``yerr``.
    """
    rows = len(x)
    bars = numpy.zeros((2, rows))
    for idx, sample in enumerate(x):
        centre = numpy.median(sample)
        bars[0, idx] = centre - numpy.percentile(sample, low)
        bars[1, idx] = numpy.percentile(sample, high) - centre
    return bars
    
def test_HE_eigenbound(reps = 10):
    """Plot GMM-Sever l2 recovery error on heterogeneous-effects data vs. L.

    For each L in ``numpy.logspace(-2, 2, 9)``, the function draws ``reps``
    datasets from ``heterogeneous_effects_data(10000, 20, 0.1)``. It runs
    ``repeated_gmm_sever`` with that L and records ``||w0 - theta||_2``. It
    then shows the per-L medians with 25th-75th percentile error bars.

    Args:
        reps: number of independent datasets drawn per value of L.
    """
    eigenbound = numpy.logspace(-2,2,9)
    error = []
    for L in eigenbound:
        print(L)
        total_error = []
        for i in range(reps):
            X,Y,Z,theta = heterogeneous_effects_data(10000, 20, 0.1)
            moments = IV_Moments(X,Y,Z)
            # L is passed as both the 2nd and 4th argument, as in the other
            # L sweeps in this file.
            w0 = repeated_gmm_sever(moments,L,20,L,10)
            total_error.append(numpy.linalg.norm(w0 - theta))
        error.append(total_error)
    error_bars = compute_error_bars(error,25,75)
    median_error = [numpy.median(l) for l in error]
    # The x-values are the swept L grid (one point per entry of `error`).
    plt.errorbar(eigenbound,median_error,yerr=error_bars,label="GMM-Sever error",capsize=3)
    plt.legend()
    plt.ylim(bottom=0)
    # The grid is log-spaced, so a log x-axis spaces the points evenly.
    plt.xscale("log")
    plt.xlabel("Choice of hyperparameter L")
    plt.ylabel("l2 recovery error")
    plt.show()
    

def test_HE_corruption(d = 20, reps = 10):
    """Plot l2 recovery error vs. corruption fraction on heterogeneous-effects data.

    For each corruption level ``epsilon`` in ``numpy.linspace(0.01, 0.1, 2)``,
    the function draws ``reps`` datasets of 10000 samples with dimension ``d``.
    It compares three estimators:

    * plain IV (``IV_Moments.optimize``);
    * iterated GMM-Sever (``repeated_gmm_sever``);
    * two-stage Huber (``two_stage_robust_sls``).

    It shows per-epsilon medians with 25th-75th percentile error bars.

    Args:
        d: dimension passed to ``heterogeneous_effects_data`` and as the 3rd
            argument of ``repeated_gmm_sever``.
        reps: number of independent datasets drawn per corruption level.
    """
    corruption = numpy.linspace(0.01, 0.1, 2)
    error = []
    iv_error = []
    huber_error = []
    for epsilon in corruption:
        print(epsilon)
        total_error = []
        # Per-epsilon plain-IV errors. The name must match the append and
        # aggregation below.
        total_iv_error = []
        total_huber_error = []
        for i in range(reps):
            X,Y,Z,theta = heterogeneous_effects_data(10000, d, epsilon)
            moments = IV_Moments(X,Y,Z)
            wnaive = moments.optimize()
            w0 = repeated_gmm_sever(moments,0.25,d,0.25,10)
            total_error.append(numpy.linalg.norm(w0 - theta))
            total_iv_error.append(numpy.linalg.norm(wnaive - theta))
            w_huber = two_stage_robust_sls(X,Y,Z)
            total_huber_error.append(numpy.linalg.norm(w_huber - theta))
        error.append(total_error)
        iv_error.append(total_iv_error)
        huber_error.append(total_huber_error)
    error_bars = compute_error_bars(error,25,75)
    iv_error_bars = compute_error_bars(iv_error,25,75)
    huber_error_bars = compute_error_bars(huber_error,25,75)
    median_error = [numpy.median(l) for l in error]
    median_iv = [numpy.median(l) for l in iv_error]
    median_huber = [numpy.median(l) for l in huber_error]
    plt.errorbar(corruption,median_error,yerr=error_bars,label="GMM-Sever error",capsize=3)
    plt.errorbar(corruption,median_iv,yerr=iv_error_bars,label="IV error",capsize=3)
    plt.errorbar(corruption,median_huber,yerr=huber_error_bars,label="2S Huber error",capsize=3)
    plt.legend()
    plt.ylim(bottom=0)
    plt.xlabel("Fraction of corrupted samples")
    plt.ylabel("l2 recovery error")
    plt.show()

def test_weak_instruments(d=20,reps=10):
    """Plot recovery error of several estimators against instrument strength.

    For each ``alpha`` in ``numpy.logspace(-1, 1, 10)``, the function draws
    ``reps`` datasets via
    ``corrupted_weak_instruments_data(10000, d, alpha, 0.01)``. It records
    ``||w - theta||_2`` for plain IV, GMM-Sever and two-stage Huber. It also
    records the fifth value returned by the data generator, which is plotted
    as "Clean IV error". All four curves are drawn on log-log axes with
    25th-75th percentile error bars.

    Args:
        d: dimension passed to the data generator and as the 3rd argument
            of ``repeated_gmm_sever``.
        reps: number of independent datasets drawn per alpha.
    """
    alpha_list = numpy.logspace(-1, 1, 10)
    # Each estimator gets one list of per-alpha lists of repetition errors.
    keys = ("gmm", "iv", "huber", "orig")
    results = {key: [] for key in keys}
    for alpha in alpha_list:
        per_alpha = {key: [] for key in keys}
        for _ in range(reps):
            X, Y, Z, theta, clean_err = corrupted_weak_instruments_data(10000, d, alpha, 0.01)
            moments = IV_Moments(X, Y, Z)
            w_iv = moments.optimize()
            w_sever = repeated_gmm_sever(moments, 0.1, d, 0.1, 10)
            per_alpha["gmm"].append(numpy.linalg.norm(w_sever - theta))
            per_alpha["iv"].append(numpy.linalg.norm(w_iv - theta))
            w_huber = two_stage_robust_sls(X, Y, Z)
            per_alpha["huber"].append(numpy.linalg.norm(w_huber - theta))
            per_alpha["orig"].append(clean_err)
        for key in keys:
            results[key].append(per_alpha[key])

    labels = {
        "gmm": "GMM-Sever error",
        "iv": "IV error",
        "huber": "Huber 2SLS error",
        "orig": "Clean IV error",
    }
    # Compute every curve's bars and medians first, then draw them in order.
    bars = {key: compute_error_bars(results[key], 25, 75) for key in keys}
    medians = {key: [numpy.median(runs) for runs in results[key]] for key in keys}
    for key in keys:
        plt.errorbar(alpha_list, medians[key], yerr=bars[key], label=labels[key], capsize=3)
    plt.legend()
    plt.xlabel("Instrument strength (alpha)")
    plt.ylabel("l2 recovery error")
    plt.yscale("log")
    plt.xscale("log")
    plt.show()

def test_uncorrupted_HE(n=1000, d=20, trials=100):
    """Print error summaries for several estimators on uncorrupted data.

    Each of ``trials`` runs draws ``heterogeneous_effects_data(n, d, 0)``
    (corruption level 0). It records ``||w - theta||_2`` for iterated
    GMM-Sever, plain IV and two-stage Huber, and also ``||theta||_2``, which
    is the error of always predicting zero. For each estimator it then prints
    the median, 25th and 75th percentiles.

    Args:
        n: number of samples per dataset.
        d: dimension passed to the data generator and as the 3rd argument
            of ``repeated_gmm_sever``.
        trials: number of independent datasets.
    """
    sever_errs, iv_errs, huber_errs, zero_errs = [], [], [], []
    for _ in range(trials):
        X, Y, Z, theta = heterogeneous_effects_data(n, d, 0)
        moments = IV_Moments(X, Y, Z)
        w_iv = moments.optimize()
        w_sever = repeated_gmm_sever(moments, 0.25, d, 0.25, 10)
        sever_errs.append(numpy.linalg.norm(w_sever - theta))
        iv_errs.append(numpy.linalg.norm(w_iv - theta))
        w_huber = two_stage_robust_sls(X, Y, Z)
        huber_errs.append(numpy.linalg.norm(w_huber - theta))
        zero_errs.append(numpy.linalg.norm(theta))

    summaries = (
        ("Iterated GMM Sever error (median, 25th percentile, 75th percentile):", sever_errs),
        ("IV error (median, 25th percentile, 75th percentile):", iv_errs),
        ("2S Huber error (median, 25th percentile, 75th percentile):", huber_errs),
        ("Zero estimator error (median, 25th percentile, 75th percentile):", zero_errs),
    )
    for heading, values in summaries:
        print(heading, numpy.median(values), numpy.percentile(values, 25), numpy.percentile(values, 75))

def uncorrupted_nlsym_eigenbound(trials=50):
    """Plot the GMM-Sever ATE estimate on clean NLSYM data vs. the hyperparameter L.

    The function builds interaction design matrices once from ``nlsym_data()``.
    For each L in ``numpy.logspace(-3, 2, 30)`` it runs ``repeated_gmm_sever``
    ``trials`` times on a fresh ``IV_Moments``. Each run gives an ATE as the
    mean of ``X0 @ w``. Medians with 25th-75th percentile bars are plotted
    on a log x-axis.

    Args:
        trials: number of GMM-Sever runs per value of L.
    """
    covs, outcome, instr, treat = nlsym_data()
    # X0 keeps the treatment-interaction columns and zeroes the rest.
    X0 = np.hstack([covs * np.ones((covs.shape[0], 1)), np.zeros(covs.shape)])
    XT = np.hstack([covs * treat[:, None], covs])
    XZ = np.hstack([covs * instr[:, None], covs])
    eigenbound_list = numpy.logspace(-3,2,30)

    def one_run(L):
        """Fit GMM-Sever once with hyperparameter L and return the ATE."""
        w = repeated_gmm_sever(IV_Moments(XT, outcome, XZ), L, 20, L, 10)
        return np.mean(X0 @ w)

    ate_list = [[one_run(L) for _ in range(trials)] for L in eigenbound_list]
    error_bars = compute_error_bars(ate_list, 25, 75)
    median_ate = [np.median(runs) for runs in ate_list]
    plt.errorbar(eigenbound_list, median_ate, yerr=error_bars, label="GMM-Sever ATE", capsize=3)
    plt.legend()
    plt.xlabel("Choice of hyperparameter L")
    plt.ylabel("Estimated average treatment effect")
    plt.xscale("log")
    plt.show()

def corrupted_nlsym_eigenbound(epsilon=0.1, trials=50):
    """Plot absolute ATE error on corrupted NLSYM data vs. the hyperparameter L.

    The reference ATE ("truth") is the plain IV estimate on clean
    ``nlsym_data()``. For each L in ``numpy.logspace(-3, 2, 30)`` and each of
    ``trials`` draws of ``corrupted_nlsym_data(epsilon)``, the function records
    two errors: ``|truth - ATE|`` for GMM-Sever and for plain IV. Medians with
    25th-75th percentile bars are plotted on log-log axes.

    Args:
        epsilon: corruption level passed to ``corrupted_nlsym_data``.
        trials: number of corrupted datasets drawn per value of L.
    """
    def design(covs, instr, treat):
        """Build the (X0, XT, XZ) interaction design matrices."""
        x0 = np.hstack([covs * np.ones((covs.shape[0], 1)), np.zeros(covs.shape)])
        xt = np.hstack([covs * treat[:, None], covs])
        xz = np.hstack([covs * instr[:, None], covs])
        return x0, xt, xz

    eigenbound_list = numpy.logspace(-3,2,30)
    covs, outcome, instr, treat = nlsym_data()
    X0, XT, XZ = design(covs, instr, treat)
    truth = np.mean(X0 @ IV_Moments(XT, outcome, XZ).optimize())

    ate_list, iv_list = [], []
    for L in eigenbound_list:
        sever_errs, iv_errs = [], []
        for _ in range(trials):
            covs, outcome, instr, treat = corrupted_nlsym_data(epsilon)
            X0, XT, XZ = design(covs, instr, treat)
            moments = IV_Moments(XT, outcome, XZ)
            w_iv = moments.optimize()
            w_sever = repeated_gmm_sever(moments, L, 20, L, 10)
            sever_errs.append(abs(truth - np.mean(X0 @ w_sever)))
            iv_errs.append(abs(truth - np.mean(X0 @ w_iv)))
        ate_list.append(sever_errs)
        iv_list.append(iv_errs)

    error_bars = compute_error_bars(ate_list, 25, 75)
    iv_error_bars = compute_error_bars(iv_list, 25, 75)
    median_ate = [np.median(runs) for runs in ate_list]
    median_iv = [np.median(runs) for runs in iv_list]
    plt.errorbar(eigenbound_list, median_ate, yerr=error_bars, label="GMM-Sever error", capsize=3)
    plt.errorbar(eigenbound_list, median_iv, yerr=iv_error_bars, label="IV error", capsize=3)
    plt.legend()
    plt.xlabel("Choice of hyperparameter L")
    plt.ylabel("Median (absolute) error in ATE")
    plt.xscale("log")
    plt.yscale("log")
    plt.show()

def corrupted_nlsym_corruption(trials=10, gmm_sever_reps = 50):
    """Plot absolute ATE error on corrupted NLSYM data vs. corruption fraction.

    The reference ATE ("truth") is the plain IV estimate on clean
    ``nlsym_data()``. For each ``epsilon`` in ``numpy.linspace(0.01, 0.2, 10)``
    and each of ``trials`` draws of ``corrupted_nlsym_data(epsilon)``, the
    function records ``|truth - ATE|`` for three estimators:

    * GMM-Sever: the median ATE over up to ``gmm_sever_reps`` runs of
      ``repeated_gmm_sever``;
    * plain IV (``IV_Moments.optimize``);
    * ``two_stage_robust_sls``.

    Medians with 25th-75th percentile bars are plotted on a log y-axis.

    Args:
        trials: number of corrupted datasets drawn per corruption level.
        gmm_sever_reps: number of GMM-Sever runs per dataset. Runs that raise
            an ``Exception`` are skipped. If all of them fail, the GMM-Sever
            error for that dataset is recorded as NaN.
    """
    corruption_list = numpy.linspace(0.01, 0.2, 10)
    ate_list = []
    iv_list = []
    huber_list = []
    X,y,Z,T = nlsym_data()
    X0 = np.hstack([X * np.ones((X.shape[0], 1)), np.zeros(X.shape)])
    XT = np.hstack([X * T[:, None], X])
    XZ = np.hstack([X * Z[:, None], X])
    moments = IV_Moments(XT, y, XZ)
    w = moments.optimize()
    truth = np.mean(X0 @ w)
    for epsilon in corruption_list:
        ate_sub_list = []
        iv_sub_list = []
        huber_sub_list = []
        for trial in range(trials):
            X, y, Z, T = corrupted_nlsym_data(epsilon)
            X0 = np.hstack([X * np.ones((X.shape[0], 1)), np.zeros(X.shape)])
            XT = np.hstack([X * T[:, None], X])
            XZ = np.hstack([X * Z[:, None], X])
            moments = IV_Moments(XT, y, XZ)
            wnaive = moments.optimize()
            sever_ates = []
            failures = 0
            last_err = None
            for j in range(gmm_sever_reps):
                try:
                    w = repeated_gmm_sever(moments, 0.01, 20, 0.01, 10)
                except Exception as err:
                    # Skip failed runs on a best-effort basis. Catching
                    # Exception rather than using a bare except lets
                    # KeyboardInterrupt and SystemExit still stop the sweep.
                    failures += 1
                    last_err = err
                    continue
                sever_ates.append(np.mean(X0@w))
            if failures:
                print("repeated_gmm_sever failed %d/%d runs (epsilon=%s, trial=%d); last error: %r"
                      % (failures, gmm_sever_reps, epsilon, trial, last_err))
            if sever_ates:
                ate_sub_list.append(abs(truth-np.median(sever_ates)))
            else:
                # No successful run: record NaN explicitly instead of relying
                # on np.median([]) returning NaN with a RuntimeWarning.
                ate_sub_list.append(np.nan)
            iv_sub_list.append(abs(truth-np.mean(X0@wnaive)))
            whuber = two_stage_robust_sls(XT,y,XZ)
            huber_sub_list.append(abs(truth-np.mean(X0@whuber)))
        ate_list.append(ate_sub_list)
        iv_list.append(iv_sub_list)
        huber_list.append(huber_sub_list)
    error_bars = compute_error_bars(ate_list, 25, 75)
    iv_error_bars = compute_error_bars(iv_list, 25, 75)
    huber_error_bars = compute_error_bars(huber_list, 25, 75)
    median_ate = [np.median(runs) for runs in ate_list]
    median_iv = [np.median(runs) for runs in iv_list]
    median_huber = [np.median(runs) for runs in huber_list]
    plt.errorbar(corruption_list, median_ate, yerr=error_bars, label="GMM-Sever error", capsize=3)
    plt.errorbar(corruption_list, median_iv, yerr=iv_error_bars, label="IV error", capsize=3)
    plt.errorbar(corruption_list, median_huber, yerr=huber_error_bars, label = "Huber 2SLS error", capsize=3)
    plt.legend()
    plt.xlabel("Fraction of corrupted samples")
    plt.ylabel("Median (absolute) error in ATE")
    plt.yscale("log")
    plt.show()
    
                        
            
    
