import numpy as np
import matplotlib.pyplot as plt
from IPython import embed
from estimator import Estimator, Slope, LossEstimator,LossTestEstimator


class Generator():
    """Synthetic data source with K actions and linear rewards.

    Each action ``a`` has its own Gaussian covariate distribution
    ``N(means[a], Sigmas[a])`` over ``d``-dimensional feature vectors.  The
    noiseless reward of a feature vector is its inner product with a hidden
    unit-norm parameter ``theta`` whose entries from index ``d_star`` onwards
    are zero.  Observed rewards add zero-mean Gaussian noise with standard
    deviation ``sigma``.

    On construction the generator draws a logged training set (one action
    per sample, chosen according to ``distr``) and a full-information test
    set that ``eval`` / ``eval_opt`` score against.

    Args:
        K: number of actions.
        d: ambient feature dimension.
        d_star: number of leading coordinates of ``theta`` allowed to be
            non-zero.
        n_train: number of logged training samples drawn at construction.
        n_eval: number of full-information test samples drawn at
            construction.
        sigma: standard deviation of the reward noise.
    """

    def __init__(self, K, d, d_star, n_train=200, n_eval=500, sigma=1.0):
        self.K = K
        self.d = d
        self.d_star = d_star

        # One random covariance (A^T A, hence positive semi-definite) and one
        # random mean per action.  The order of the random draws is kept
        # fixed so a seeded run reproduces the same problem instance.
        Sigmas = []
        means = []
        for _ in range(K):
            A = np.random.normal(0, 1, (d, d))
            Sigmas.append(A.T.dot(A))
            means.append(np.random.normal(0, 1, d))

        # Probability of each action under the logging policy.
        self.distr = np.random.dirichlet(np.ones(self.K))

        self.Sigmas = Sigmas
        self.means = means

        # Hidden reward parameter: supported on the first d_star coordinates
        # and scaled to unit Euclidean norm.
        self.theta = np.random.normal(0, 1, d)
        self.theta[d_star:] = 0.0
        self.theta = self.theta / np.linalg.norm(self.theta)
        self.sigma = sigma
        self.n_eval = n_eval
        self.n_train = n_train

        # Training data carries noisy rewards; the test set keeps the
        # noiseless rewards (third return value of generate_full_info).
        self.phis_train, self.Rs_train, _ = self.generate_mu(self.n_train)
        self.samples_test, _, self.rewards_test = self.generate_full_info(self.n_eval)

    def eval_opt(self):
        """Return the mean over test samples of the best noiseless reward."""
        return np.mean(np.max(self.rewards_test, axis=0))

    def eval(self, learner):
        """Return the mean noiseless test reward of the actions `learner` picks.

        Args:
            learner: object exposing ``action(phis)``, which receives the
                ``(K, d)`` feature matrix of one test sample and returns the
                index of the chosen action.
        """
        rewards = self.rewards_test
        all_phis = self.samples_test
        learner_rewards = [
            rewards[learner.action(all_phis[:, i, :]), i]
            for i in range(rewards.shape[1])
        ]
        return np.mean(learner_rewards)

    def eval2(self, learner, all_phis_valid):
        """Return the learner's own mean predicted value on `all_phis_valid`.

        For every sample the learner picks an action and the prediction
        ``learner.theta . phi`` of that action is recorded; the true
        parameter is not used.

        Args:
            learner: object exposing ``action(phis)`` and a ``theta`` vector
                that can be dotted with a full ``d``-dimensional feature.
            all_phis_valid: array indexed as ``(action, sample, feature)``.
        """
        values = []
        for i in range(all_phis_valid.shape[1]):
            phis = all_phis_valid[:, i, :]
            a = learner.action(phis)
            values.append(learner.theta.dot(phis[a]))
        return np.mean(values)

    def generate_full_info(self, n):
        """Draw `n` samples in which every action's features and reward are seen.

        Returns:
            Tuple ``(samples, rewards, noiseless_rewards)`` where ``samples``
            has shape ``(K, n, d)`` and both reward arrays have shape
            ``(K, n)``; ``rewards`` is ``noiseless_rewards`` plus Gaussian
            noise of standard deviation ``self.sigma``.
        """
        samples = np.array([
            np.random.multivariate_normal(self.means[a], self.Sigmas[a], n)
            for a in range(self.K)
        ])  # shape (K, n, d)
        noiseless_rewards = samples.dot(self.theta)  # shape (K, n)

        eta = np.random.normal(0, self.sigma, noiseless_rewards.shape)
        rewards = noiseless_rewards + eta

        return samples, rewards, noiseless_rewards

    def generate_mu(self, n):
        """Draw `n` logged samples, one action each, chosen from ``self.distr``.

        Also prints the smallest action probability of the logging policy.

        Returns:
            Tuple ``(xs, rs, noiseless_rs)``: the chosen actions' features
            with shape ``(n, d)``, their noisy rewards and their noiseless
            rewards, both with shape ``(n,)``.
        """
        samples, rewards, noiseless_rewards = self.generate_full_info(n)

        # Single-argument call form: valid and identical on Python 2 and 3.
        print("Dirichlet min: " + str(np.min(self.distr)))
        actions = np.random.choice(np.arange(self.K), n, p=self.distr)
        indices = np.arange(n)

        # Pair each sample index with its logged action.
        noiseless_rs = noiseless_rewards[actions, indices]
        rs = rewards[actions, indices]
        xs = samples[actions, indices]

        return xs, rs, noiseless_rs



class Learner():
    """Regularized least-squares reward model using only the first `d` features.

    Attributes set here: ``K`` (number of actions, stored but not used by the
    methods below), ``d`` (number of leading feature columns used) and
    ``theta`` (current weight vector, zero until ``fit`` is called).
    """

    def __init__(self, K, d):
        self.K, self.d = K, d
        self.theta = np.zeros(self.d)

    def fit(self, Phi, R):
        """Fit ``theta`` on features `Phi` and rewards `R`, and return it.

        Solves ``(X^T X + d * I) theta = X^T R`` where ``X`` is `Phi`
        truncated to its first ``d`` columns.  With no data, ``theta`` is
        reset to zeros and ``self.Phi`` is left untouched.
        """
        if not len(Phi):
            self.theta = np.zeros(self.d)
            return self.theta

        # Copy the inputs, then keep only the columns this model may use.
        features = np.array(Phi)[:, :self.d]
        targets = np.array(R)

        # Regularization strength equals the model dimension d.
        gram = np.dot(features.T, features) + self.d * np.eye(self.d)
        moment = np.dot(features.T, targets)

        self.theta = np.linalg.solve(gram, moment)
        self.Phi = features
        return self.theta

    def action(self, phis):
        """Return the row index of `phis` with the highest predicted reward."""
        scores = np.dot(phis[:, :self.d], self.theta)
        return np.argmax(scores)

# Experiment driver: for TRIALS independent problem instances, fit linear
# learners of several dimensions on growing prefixes of the training data,
# let each model-selection rule pick one learner per dataset size, and record
# the regret (optimal test value minus achieved test value) of every choice.
#
# NOTE: the 'figures/' and 'data/' directories are written to but never
# created here, so they have to exist before the script runs.
TRIALS = 10
data_opt = []  # never filled in this file
data_slope = []
data_all = []
data_holdout = []
data_test = []

for T in range(TRIALS):

    K, d, d_star = 10, 200, 30
    gen = Generator(K, d, d_star)
    Phi_all, R_all = gen.phis_train.copy(), gen.Rs_train.copy()
    # Extra draws used only for model selection: a full-information set for
    # the Slope estimator and a logged set for the loss-based estimators.
    all_phis_valid, _, _ = gen.generate_full_info(gen.n_train)
    phis_valid_mu, r_valid_mu, _ = gen.generate_mu(gen.n_train)

    # Candidate model dimensions and training-set sizes.
    ds = [15, 20, 25, 28, 29, 30, 50, 75, 100, 200]
    ns = np.arange(20, gen.n_train + 1, 10)
    ns_losses = []

    learners = [Learner(K, d_model) for d_model in ds]
    all_vs = [[] for _ in ds]
    slope_vs = []
    holdout_vs = []
    comp_vs = []  # tracked per n, but neither plotted nor saved below
    test_vs = []

    for n in ns:
        print("n: " + str(n))

        # FITTING AND EVALUATING ALL MODELS
        for k, learner in enumerate(learners):
            learner.fit(Phi_all[:n], R_all[:n])
            all_vs[k].append(gen.eval(learner))

        # BUILDING SLOPE ESTIMATOR AND SELECTING K
        # (Slope comes from the project's `estimator` module; the learner
        # with the largest estimated value is selected.)
        slope = Slope(learners, all_phis_valid)
        estimated_values = [slope.estimate(learner) for learner in learners]
        k_hat = np.argmax(estimated_values)

        # EVALUATE SELECTED K
        v_slope = gen.eval(learners[k_hat])
        slope_vs.append(v_slope)

        # ESTIMATING LOSSES
        # Hold-out set sized in a 20/80 ratio to the n training points (at
        # least one sample); the x-coordinate for this method is therefore
        # the total number of samples consumed, partial + n.
        partial = max(int(.2 * n / .8), 1)
        print("partial: " + str(partial))
        ns_losses.append(partial + n)
        loss_est = LossEstimator(phis_valid_mu[:partial], r_valid_mu[:partial])
        estimated_losses = [loss_est.estimate(learner) for learner in learners]
        k_hat_loss = np.argmin(estimated_losses)

        # EVALUATE SELECTED K
        v_loss = gen.eval(learners[k_hat_loss])
        holdout_vs.append(v_loss)

        # ESTIMATING LOSSES WITH TEST
        loss_est = LossTestEstimator(n, learners, phis_valid_mu[:partial], r_valid_mu[:partial])
        k_hat_test = loss_est.select()

        v_losstest = gen.eval(learners[k_hat_test])
        test_vs.append(v_losstest)

        print("length: " + str(len(test_vs)))

        # ESTIMATING LOSSES WITH COMPLEXITY REGULARIZATION
        # Training loss plus a sqrt(d / n) penalty on the model dimension.
        loss_est = LossEstimator(Phi_all[:n], R_all[:n])
        estimated_losses = []
        for learner in learners:
            bonus = np.sqrt(float(learner.d) / float(n))
            estimated_losses.append(loss_est.estimate(learner) + bonus)

        # EVALUATE SELECTED K
        k_hat_comp = np.argmin(estimated_losses)
        v_comp = gen.eval(learners[k_hat_comp])
        comp_vs.append(v_comp)

    v_opt = gen.eval_opt()

    # Regret curves, computed once and shared by the plot and the saved data.
    all_reg = [v_opt - np.asarray(vs) for vs in all_vs]
    reg_slope = v_opt - np.asarray(slope_vs)

    # The hold-out based methods consume partial + n samples; keep only the
    # points whose total stays within the training budget.
    ns_losses = np.array(ns_losses)
    indices = ns_losses <= gen.n_train
    ns_losses = ns_losses[indices]
    holdout_vs = np.array(holdout_vs)[indices]
    test_vs = np.array(test_vs)[indices]
    holdout_reg = v_opt - holdout_vs
    test_reg = v_opt - test_vs

    # PLOTTING CODE
    for d_model, reg in zip(ds, all_reg):
        plt.plot(ns, reg, label='Lnr d=' + str(d_model), linewidth=4.0)

    plt.plot(ns, reg_slope, label='Slope', linestyle='--', linewidth=3.0)
    plt.plot(ns_losses, holdout_reg, label='Hold-out', linestyle='--', linewidth=2.0)
    plt.plot(ns_losses, test_reg, label='Test', linestyle=':', linewidth=2.0)

    plt.legend()
    plt.ylabel('Regret')
    plt.xlabel('Dataset Size')
    plt.savefig('figures/tmp.png')
    plt.cla()
    plt.clf()
    plt.close()

    data_slope.append(reg_slope)
    data_holdout.append(holdout_reg)
    data_all.append(all_reg)
    data_test.append(test_reg)

    # Re-saved after every trial, so partial results survive an interrupted run.
    np.save('data/ns_holdout.npy', ns_losses)
    np.save('data/ns.npy', ns)
    np.save('data/ds.npy', ds)
    np.save('data/data_slope.npy', data_slope)
    np.save('data/data_holdout.npy', data_holdout)
    np.save('data/data_test.npy', data_test)
    np.save('data/data_all.npy', data_all)












