import numpy as np
import random
import math
from algorithms.AutoTuning import *

class LinTS:
    def __init__(self, class_context, T):
        """Store the bandit environment and the horizon.

        ``class_context`` is the environment object the ``lints_*`` methods
        read from (its ``d``, ``fv``, ``optimal``, ``reward`` attributes and
        its ``random_sample`` method); ``T`` is the number of rounds to run.
        """
        self.T = T
        self.data = class_context
        # Cache the feature dimension exposed by the environment.
        self.d = class_context.d
        
    def lints_theoretical_explore(self, lamda=1, delta=0.1, explore=-1):
        """Run linear Thompson sampling for ``self.T`` rounds.

        Parameters
        ----------
        lamda : float
            Ridge regularisation; the inverse design matrix starts at
            ``I / lamda``.
        delta : float
            Confidence parameter used only by the theoretical exploration rate.
        explore : float
            Fixed exploration rate (scale of the sampling covariance). The
            sentinel ``-1`` (not a meaningful rate) selects the theoretical,
            round-dependent rate instead; a fixed value is meant for grid
            search.

        Returns
        -------
        numpy.ndarray
            Array of length ``T``; entry ``t`` is the cumulative regret
            through round ``t``.
        """
        T = self.T
        d = self.data.d
        regret = np.zeros(T)
        xr = np.zeros(d)
        # Only the inverse of the regularised design matrix is ever needed,
        # so the design matrix itself is not tracked.
        B_inv = np.identity(d) / lamda
        theta_hat = np.zeros(d)

        use_theoretical = explore == -1
        fixed_rate = explore
        # Explicit running total instead of relying on regret[-1] being zero
        # at t == 0.
        cumulative = 0.0
        for t in range(T):
            feature = self.data.fv[t]

            if use_theoretical:
                rate = self.data.sigma*math.sqrt( d*math.log((t*self.data.max_norm**2/lamda+1)/delta) ) + math.sqrt(lamda)
            else:
                rate = fixed_rate

            # Sample a parameter vector around the current estimate and play
            # the arm that looks best under the sample.
            theta_ts = np.random.multivariate_normal(theta_hat, rate**2*B_inv)
            scores = [arm_feature.dot(theta_ts) for arm_feature in feature]
            pull = np.argmax(scores)
            observe_r = self.data.random_sample(t, pull)

            # Rank-one (Sherman-Morrison) update of the inverse design matrix,
            # then refresh the ridge estimate.
            chosen = feature[pull]
            tmp = B_inv.dot(chosen)
            B_inv -= np.outer(tmp, tmp) / (1 + chosen.dot(tmp))
            xr += chosen * observe_r
            theta_hat = B_inv.dot(xr)

            cumulative = cumulative + self.data.optimal[t] - self.data.reward[t][pull]
            regret[t] = cumulative
        return regret
    
    def lints_tl(self, explore_rates, lamda=1):
        """Linear Thompson sampling with the exploration rate tuned online.

        After every round the exploration rate is re-selected from
        ``explore_rates`` by ``auto_tuning`` (the two-layer bandit tuner,
        an EXP3 layer over the candidate set).

        Parameters
        ----------
        explore_rates : sequence of float
            Candidate exploration rates (the tuning set).
        lamda : float
            Ridge regularisation; the inverse design matrix starts at
            ``I / lamda``.

        Returns
        -------
        numpy.ndarray
            Array of length ``T``; entry ``t`` is the cumulative regret
            through round ``t``.
        """
        T = self.T
        d = self.data.d
        regret = np.zeros(T)
        xr = np.zeros(d)
        # Only the inverse design matrix is used, so it is the only one kept.
        B_inv = np.identity(d) / lamda
        theta_hat = np.zeros(d)

        # EXP3 state over the tuning set.
        Kexp = len(explore_rates)
        logw = np.zeros(Kexp)
        p = np.ones(Kexp) / Kexp
        gamma = min(1, math.sqrt( Kexp*math.log(Kexp) / ( (np.exp(1)-1) * T ) ) )
        # Start from a uniformly random candidate.
        index = np.random.choice(Kexp)
        explore = explore_rates[index]

        # Explicit running total instead of relying on regret[-1] being zero
        # at t == 0.
        cumulative = 0.0
        for t in range(T):
            feature = self.data.fv[t]

            # Sample a parameter vector and play the arm that is best under it.
            theta_ts = np.random.multivariate_normal(theta_hat, explore**2*B_inv)
            scores = [arm_feature.dot(theta_ts) for arm_feature in feature]
            pull = np.argmax(scores)
            observe_r = self.data.random_sample(t, pull)

            # Rank-one (Sherman-Morrison) update of the inverse design matrix.
            chosen = feature[pull]
            tmp = B_inv.dot(chosen)
            B_inv -= np.outer(tmp, tmp) / (1 + chosen.dot(tmp))
            xr += chosen * observe_r
            theta_hat = B_inv.dot(xr)

            cumulative = cumulative + self.data.optimal[t] - self.data.reward[t][pull]
            regret[t] = cumulative

            # Feed the observed reward to the tuner and pick the next rate.
            logw, p, index = auto_tuning(logw, p, observe_r, index, gamma)
            explore = explore_rates[index]
        return regret
    
    def lints_op(self, explore_rates, lamda=1):
        """Linear Thompson sampling with the exploration rate chosen by ``op_tuning``.

        After every round ``op_tuning`` receives the two per-candidate count
        vectors ``s`` and ``f`` (both initialised to ones), the observed
        reward and the current candidate index, and returns updated vectors
        plus the next index into ``explore_rates``.

        Parameters
        ----------
        explore_rates : sequence of float
            Candidate exploration rates.
        lamda : float
            Ridge regularisation; the inverse design matrix starts at
            ``I / lamda``.

        Returns
        -------
        numpy.ndarray
            Array of length ``T``; entry ``t`` is the cumulative regret
            through round ``t``.
        """
        T = self.T
        d = self.data.d
        regret = np.zeros(T)
        xr = np.zeros(d)
        # Only the inverse design matrix is used, so it is the only one kept.
        B_inv = np.identity(d) / lamda
        theta_hat = np.zeros(d)

        # State for the hyper-parameter selector.
        # NOTE(review): s / f look like success / failure counts -- confirm
        # against op_tuning.
        Kexp = len(explore_rates)
        s = np.ones(Kexp)
        f = np.ones(Kexp)
        index = np.random.choice(Kexp)
        explore = explore_rates[index]

        # Explicit running total instead of relying on regret[-1] being zero
        # at t == 0.
        cumulative = 0.0
        for t in range(T):
            feature = self.data.fv[t]

            # Sample a parameter vector and play the arm that is best under it.
            theta_ts = np.random.multivariate_normal(theta_hat, explore**2*B_inv)
            scores = [arm_feature.dot(theta_ts) for arm_feature in feature]
            pull = np.argmax(scores)
            observe_r = self.data.random_sample(t, pull)

            # Rank-one (Sherman-Morrison) update of the inverse design matrix.
            chosen = feature[pull]
            tmp = B_inv.dot(chosen)
            B_inv -= np.outer(tmp, tmp) / (1 + chosen.dot(tmp))
            xr += chosen * observe_r
            theta_hat = B_inv.dot(xr)

            cumulative = cumulative + self.data.optimal[t] - self.data.reward[t][pull]
            regret[t] = cumulative

            # Update the selector and pick the next exploration rate.
            s, f, index = op_tuning(s, f, observe_r, index)
            explore = explore_rates[index]
        return regret
    
    def lints_syndicated(self, explore_rates, lamdas):
        """Linear Thompson sampling with two independent EXP3 tuning layers.

        One ``auto_tuning`` layer selects the exploration rate from
        ``explore_rates`` and a second one selects the ridge parameter from
        ``lamdas``; both are updated with the same observed reward each round.
        Because the ridge parameter can change between rounds, the inverse
        design matrix is recomputed from the accumulated outer products.

        Returns a length-``T`` array of cumulative regret.
        """
        horizon = self.T
        dim = self.data.d
        regret = np.zeros(horizon)
        xr = np.zeros(dim)
        theta_hat = np.zeros(dim)

        def exp3_rate(n_candidates):
            # EXP3 mixing rate for a candidate set of the given size.
            return min(1, math.sqrt( n_candidates*math.log(n_candidates) / ( (np.exp(1)-1) * horizon ) ) )

        # EXP3 layer for the exploration rate (random initial choice).
        n_explore = len(explore_rates)
        logw = np.zeros(n_explore)
        p = np.ones(n_explore) / n_explore
        gamma = exp3_rate(n_explore)
        index = np.random.choice(n_explore)
        explore = explore_rates[index]

        # EXP3 layer for lambda (random initial choice).
        n_lam = len(lamdas)
        loglamw = np.zeros(n_lam)
        plam = np.ones(n_lam) / n_lam
        gamma_lam = exp3_rate(n_lam)
        index_lam = np.random.choice(n_lam)
        lamda = lamdas[index_lam]

        gram = np.zeros((dim, dim))
        B_inv = np.identity(dim) / lamda
        running = 0.0
        for step in range(horizon):
            arms = self.data.fv[step]
            theta_ts = np.random.multivariate_normal(theta_hat, explore**2*B_inv)
            sampled_values = [arm_vec.dot(theta_ts) for arm_vec in arms]
            pull = np.argmax(sampled_values)
            observe_r = self.data.random_sample(step, pull)

            # Both tuning layers see the same reward; the new lambda is used
            # for this round's model update below.
            logw, p, index = auto_tuning(logw, p, observe_r, index, gamma)
            explore = explore_rates[index]
            loglamw, plam, index_lam = auto_tuning(loglamw, plam, observe_r, index_lam, gamma_lam)
            lamda = lamdas[index_lam]

            # Refit the ridge estimate with the freshly selected lambda.
            played = arms[pull]
            gram += np.outer(played, played)
            B_inv = np.linalg.inv(gram + lamda*np.identity(dim))
            xr += played * observe_r
            theta_hat = B_inv.dot(xr)

            running = running + self.data.optimal[step] - self.data.reward[step][pull]
            regret[step] = running
        return regret
    
    def lints_tl_combined(self, explore_rates, lamdas):
        """Linear Thompson sampling with one EXP3 layer over (explore, lambda) pairs.

        Every combination of an exploration rate and a ridge parameter is a
        candidate; ``auto_tuning`` selects one pair per round. The inverse
        design matrix is recomputed each round because lambda may change.

        Returns a length-``T`` array of cumulative regret.
        """
        horizon = self.T
        dim = self.data.d
        regret = np.zeros(horizon)
        xr = np.zeros(dim)
        theta_hat = np.zeros(dim)

        # Combination set: one row per (explore, lambda) pair.
        grid = np.meshgrid(explore_rates, lamdas)
        explore_lamda = np.array(grid).T.reshape(-1, 2)

        # EXP3 state over the combination set, with a random initial pair.
        n_pairs = len(explore_lamda)
        logw = np.zeros(n_pairs)
        p = np.ones(n_pairs) / n_pairs
        gamma = min(1, math.sqrt( n_pairs*math.log(n_pairs) / ( (np.exp(1)-1) * horizon ) ) )
        index = np.random.choice(n_pairs)
        explore, lamda = explore_lamda[index]

        gram = np.zeros((dim, dim))
        B_inv = np.identity(dim) / lamda
        running = 0.0
        for step in range(horizon):
            arms = self.data.fv[step]
            theta_ts = np.random.multivariate_normal(theta_hat, explore**2*B_inv)
            sampled_values = [arm_vec.dot(theta_ts) for arm_vec in arms]
            pull = np.argmax(sampled_values)
            observe_r = self.data.random_sample(step, pull)

            # Select the next pair; its lambda is used for this round's
            # model update below.
            logw, p, index = auto_tuning(logw, p, observe_r, index, gamma)
            explore, lamda = explore_lamda[index]

            # Refit the ridge estimate with the freshly selected lambda.
            played = arms[pull]
            gram += np.outer(played, played)
            B_inv = np.linalg.inv(gram + lamda*np.identity(dim))
            xr += played * observe_r
            theta_hat = B_inv.dot(xr)

            running = running + self.data.optimal[step] - self.data.reward[step][pull]
            regret[step] = running
        return regret
    
    def lints_corral(self, explore_rates, lamda=1):
        """Run one Thompson-sampling base learner per exploration rate under a master.

        Each entry of ``explore_rates`` gets its own base learner (own inverse
        design matrix and estimate). Every round all base learners propose an
        arm, one base is drawn from the smoothed distribution ``pbar``, and
        its proposal is played. The master distribution is then updated by
        ``log_barrier``.

        Parameters
        ----------
        explore_rates : sequence of float
            Exploration rate of each base learner.
        lamda : float
            Ridge regularisation shared by all base learners.

        Returns
        -------
        numpy.ndarray
            Array of length ``T``; entry ``t`` is the cumulative regret
            through round ``t``.
        """
        T = self.T
        d = self.data.d
        regret = np.zeros(T)

        # NOTE(review): the initial learning rate is computed from the number
        # of arms in the first round, not from the number of base learners --
        # confirm this is intended.
        K = len(self.data.fv[0])
        eta0 = 1/math.sqrt(K*T*math.log(K))

        # Master state.
        M = len(explore_rates)
        p = np.ones(M) / M
        pbar = np.ones(M) / M
        gamma = 1/T
        beta = np.exp(1/math.log(T))
        rho = [2*M] * M
        eta = [eta0] * M

        # Per-base state; only the inverse design matrices are ever used.
        xr = [np.zeros(d) for _ in range(M)]
        B_inv = [np.identity(d) / lamda for _ in range(M)]
        theta_hat = [np.zeros(d) for _ in range(M)]

        # Explicit running total instead of relying on regret[-1] being zero
        # at t == 0.
        cumulative = 0.0
        for t in range(T):
            feature = self.data.fv[t]

            # Every base learner samples its own parameter and proposes an arm.
            pull = []
            for base in range(M):
                explore = explore_rates[base]
                theta_ts = np.random.multivariate_normal(theta_hat[base], explore**2*B_inv[base])
                scores = [arm_feature.dot(theta_ts) for arm_feature in feature]
                pull.append(np.argmax(scores))

            chosen_base = np.random.choice(M, p=pbar)
            played = pull[chosen_base]
            observe_r = self.data.random_sample(t, played)
            cumulative = cumulative + self.data.optimal[t] - self.data.reward[t][played]
            regret[t] = cumulative

            # Every base updates on its own proposed arm; only the chosen base
            # is credited with the observed reward, the others with 0.
            for base in range(M):
                rew = observe_r if base == chosen_base else 0
                proposed = feature[pull[base]]
                tmp = B_inv[base].dot(proposed)
                B_inv[base] -= np.outer(tmp, tmp) / (1 + proposed.dot(tmp))
                xr[base] += proposed * rew
                theta_hat[base] = B_inv[base].dot(xr[base])

            # Master update: the loss vector is the negated reward at the
            # chosen base and zero elsewhere.
            passl = np.zeros(M)
            passl[chosen_base] = -observe_r
            p = log_barrier(p, passl, eta)
            pbar = (1-gamma) * p + gamma/M
            # Raise the threshold and learning rate of any base whose
            # probability has dropped below its current threshold.
            for base in range(M):
                if 1/pbar[base] >= rho[base]:
                    rho[base] = 2/pbar[base]
                    eta[base] *= beta
        return regret
    
    def lints_corral_combined(self, explore_rates, lamdas):
        """Run one Thompson-sampling base learner per (explore, lambda) pair under a master.

        Every combination of an exploration rate and a ridge parameter gets
        its own base learner. Every round all base learners propose an arm,
        one base is drawn from the smoothed distribution ``pbar``, and its
        proposal is played. The master distribution is then updated by
        ``log_barrier``.

        Parameters
        ----------
        explore_rates : sequence of float
            Candidate exploration rates.
        lamdas : sequence of float
            Candidate ridge regularisation values.

        Returns
        -------
        numpy.ndarray
            Array of length ``T``; entry ``t`` is the cumulative regret
            through round ``t``.
        """
        T = self.T
        d = self.data.d
        regret = np.zeros(T)

        # NOTE(review): the initial learning rate is computed from the number
        # of arms in the first round, not from the number of base learners --
        # confirm this is intended.
        K = len(self.data.fv[0])
        eta0 = 1/math.sqrt(K*T*math.log(K))

        # Combination set: one row per (explore, lambda) pair.
        explore_lamda = np.array(np.meshgrid(explore_rates, lamdas)).T.reshape(-1,2)

        # Master state.
        M = len(explore_lamda)
        p = np.ones(M) / M
        pbar = np.ones(M) / M
        gamma = 1/T
        beta = np.exp(1/math.log(T))
        rho = [2*M] * M
        eta = [eta0] * M

        # Per-base state; each base starts from its own lambda, and only the
        # inverse design matrices are ever used.
        xr = [np.zeros(d) for _ in range(M)]
        B_inv = [np.identity(d) / lamda for _, lamda in explore_lamda]
        theta_hat = [np.zeros(d) for _ in range(M)]

        # Explicit running total instead of relying on regret[-1] being zero
        # at t == 0.
        cumulative = 0.0
        for t in range(T):
            feature = self.data.fv[t]

            # Every base learner samples its own parameter and proposes an arm.
            pull = []
            for base in range(M):
                explore, _ = explore_lamda[base]
                theta_ts = np.random.multivariate_normal(theta_hat[base], explore**2*B_inv[base])
                scores = [arm_feature.dot(theta_ts) for arm_feature in feature]
                pull.append(np.argmax(scores))

            chosen_base = np.random.choice(M, p=pbar)
            played = pull[chosen_base]
            observe_r = self.data.random_sample(t, played)
            cumulative = cumulative + self.data.optimal[t] - self.data.reward[t][played]
            regret[t] = cumulative

            # Every base updates on its own proposed arm; only the chosen base
            # is credited with the observed reward, the others with 0.
            for base in range(M):
                rew = observe_r if base == chosen_base else 0
                proposed = feature[pull[base]]
                tmp = B_inv[base].dot(proposed)
                B_inv[base] -= np.outer(tmp, tmp) / (1 + proposed.dot(tmp))
                xr[base] += proposed * rew
                theta_hat[base] = B_inv[base].dot(xr[base])

            # Master update: the loss vector is the negated reward at the
            # chosen base and zero elsewhere.
            passl = np.zeros(M)
            passl[chosen_base] = -observe_r
            p = log_barrier(p, passl, eta)
            pbar = (1-gamma) * p + gamma/M
            # Raise the threshold and learning rate of any base whose
            # probability has dropped below its current threshold.
            for base in range(M):
                if 1/pbar[base] >= rho[base]:
                    rho[base] = 2/pbar[base]
                    eta[base] *= beta
        return regret