import GPy
from bayesian_optimization_bts_red_mean_var import BTS_RED_mean_var
import pickle
import numpy as np

from multiprocessing.dummy import Pool as ThreadPool

# Seed NumPy's global RNG once, before any run, so the whole experiment
# (every repetition in run_list) draws from a single reproducible stream.
np.random.seed(0)

# Number of BO iterations per run (used both as n_iter and as the horizon T).
max_iter = 70

# The settings below are not passed to the optimiser in this script; they only
# label the results file name. Presumably they describe how the synthetic
# function in obj_funcs/synth_func.pkl was generated (length scales of the
# objective and of the noise-variance function, and the noise-variance range)
# -- TODO confirm.
ls = 0.04
ls_noise_var = 0.15
noise_var_min = 1e-4
noise_var_max = 0.2

# Load the synthetic objective. synth_func below snaps a query to the nearest
# point of `domain` and samples around f[ind] with variance f_noise_var[ind].
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
log_file_name = "obj_funcs/synth_func.pkl"
with open(log_file_name, "rb") as func_file:  # closes the handle after loading
    all_func_info = pickle.load(func_file)
domain = all_func_info["domain"]
f = all_func_info["f"]
f_noise_var = all_func_info["f_noise_var"]

def synth_func(param, n_t):
    """Draw ``n_t`` noisy observations of the synthetic objective at ``param``.

    The query ``param[0]`` is snapped to the nearest point of the module-level
    grid ``domain``. Then ``n_t`` samples are drawn from a normal distribution
    with mean ``f[ind]`` and standard deviation ``sqrt(f_noise_var[ind])``.

    Args:
        param: sequence whose first element is the query location.
        n_t: number of replicated observations to draw. The unbiased variance
            needs ``n_t >= 2``; with ``n_t == 1`` NumPy warns and yields nan.

    Returns:
        Tuple ``(empirical_mean, -empirical_var, f[ind], f_noise_var[ind])``.
        The empirical variance is negated, presumably so the caller can
        maximise it as a second objective -- TODO confirm against
        BTS_RED_mean_var.
    """
    x = param[0]
    ind = np.argmin(np.abs(domain - x))  # index of the nearest grid point
    samples = np.random.normal(f[ind], np.sqrt(f_noise_var[ind]), n_t)

    empirical_mean = np.mean(samples)
    # ddof=1 gives the unbiased sample variance: sum((s - mean)**2) / (n_t - 1).
    empirical_var = np.var(samples, ddof=1)
    return empirical_mean, -empirical_var, f[ind], f_noise_var[ind]

# Number of inputs queried per BO iteration.
batch_size = 50

# R2 passed to the optimiser (and recorded in the results file name).
# When estimate_sigma_max is True this value has no effect on the algorithm.
R2 = 0.02

# Ratio passed to BTS_RED_mean_var and, when estimate_sigma_max is True,
# appended to the results file name. An alternative setting uses a scale
# factor of 0.2 instead of 0.3.
ratio = (np.sqrt(batch_size) + 1) / (batch_size - 1) * 0.3

# Constant (all-ones) weight schedules for the mean and variance models,
# presumably indexed by iteration -- 5000 entries is a generous upper bound.
beta_t = np.ones(5000)
beta_t_var = np.ones(5000)

# Lower/upper limits passed to the optimiser as n_min/n_max -- presumably
# the allowed range of the replication count n_t per query.
n_min = 2
n_max = 50

# During initialisation, every initial input is queried with this fixed n_t.
fix_nt_init = 10
# Number of initial inputs.
init_size = 10

# Whether to estimate the maximum noise variance so that the theory-inspired
# choice of R2 can be used. If True, the R2 value set above has no effect.
estimate_sigma_max = True

# Passed through to BTS_RED_mean_var unchanged.
gp_opt_schedule = 5
M_TS = 50

# Indices of the independent repetitions to run (0..49).
run_list = np.arange(50)

for itr in run_list:
    # The results path encodes every setting plus the repetition index, so
    # each run writes to its own file. `!s` keeps the exact str() rendering
    # of the NumPy scalars (itr comes from np.arange, ratio from np.sqrt).
    run_tag = (
        f"res_ls_{ls}_ls_noise_var_{ls_noise_var}"
        f"_noise_range_{noise_var_min}_{noise_var_max}_iter_{itr!s}"
        f"_batch_size_{batch_size}_R2_{R2}"
        f"_R2var_0_n_min_{n_min}_n_max_{n_max}_init_{init_size}"
        f"_fix_nt_init_{fix_nt_init}"
    )
    if estimate_sigma_max:
        run_tag += f"_ratio_{ratio!s}"
    log_file_name = f"results_bts_red_unknown/{run_tag}.pkl"

    # Per-repetition initial-inputs file handed to the optimiser via use_init
    # (save_init=False, so presumably it is only read) -- TODO confirm.
    init_file = f"inits/init_itr_{itr!s}_init_{init_size}.p"

    bo_ts = BTS_RED_mean_var(
        f=synth_func,
        pbounds={'x1': (0, 1)},
        gp_opt_schedule=gp_opt_schedule,
        log_file=log_file_name,
        M_TS=M_TS,
        n_min=n_min,
        n_max=n_max,
        noise_var_func=f_noise_var,
        domain=domain,
        batch_size=batch_size,
        R2=R2,
        beta_t=beta_t,
        use_init=init_file,
        save_init=False,
        save_init_file=None,
        T=max_iter,
        beta_t_var=beta_t_var,
        mean_var_obj=False,
        estimate_sigma_max=estimate_sigma_max,
        ratio=ratio,
        fix_nt_init=fix_nt_init,
    )
    bo_ts.maximize(n_iter=max_iter, init_points=init_size)
