import os

# Limit the common BLAS / OpenMP / numexpr backends to one thread each.
# These variables must be set before numpy (imported further below) loads,
# which is why this block sits above the remaining imports.
NUM_THREADS = "1"
for _thread_var in (
    "OMP_NUM_THREADS",
    "OPENBLAS_NUM_THREADS",
    "MKL_NUM_THREADS",
    "VECLIB_MAXIMUM_THREADS",
    "NUMEXPR_NUM_THREADS",
):
    os.environ[_thread_var] = NUM_THREADS  # equivalent to `export <VAR>=1`
del _thread_var

# NOTE(review): a W&B API key used to be hard-coded here. Secrets must not be
# committed to source control, and the exposed key should be revoked. Offline
# runs do not need to authenticate; for online runs export WANDB_API_KEY in
# the shell instead of writing it into this file.
os.environ["WANDB_MODE"] = "offline"

import re
import sys
import time
import pickle
import argparse
import numpy as np
from functools import partial
from tqdm import tqdm, trange
import wandb
import datetime
import xgboost as xgb

sys.path.insert(0, '.')
from tlbo.facade.notl import NoTL
from tlbo.facade.rgpe import RGPE
from tlbo.facade.tst import TST
from tlbo.facade.rgpe_space import RGPESPACE, RGPESPACE_BO, RGPESPACE_RS, RGPESPACE_TST
from tlbo.facade.tst_space import TSTSPACE
from tlbo.facade.random_surrogate import RandomSearch
from tlbo.framework.smbo_offline import SMBO_OFFLINE
from tlbo.framework.smbo_baseline import SMBO_SEARCH_SPACE_Enlarge
from tlbo.framework.smbo_continous import SMBO_SEARCH_SPACE_Enlarge_Continuous
from tlbo.config_space.space_instance import get_configspace_instance
from tools.utils import seeds, convert_method_name
from functions.bbob import *
from functools import partial
import json
from utils import convert_data
from functions.lunar_lander import lunar
from functions.push_function import push
from functions.rover_function import rover
import functions

parser = argparse.ArgumentParser()
parser.add_argument('--search-space-id', type=str, default='Sphere')
parser.add_argument('--dataset-id', type=str, default='72')
# Only read directly in bbob mode; the other modes derive/override it.
parser.add_argument('--dims', type=int, default=None)
# Comma-separated list of baseline names, e.g. "rs,notl".
parser.add_argument('--methods', type=str, default='rs')
parser.add_argument('--surrogate_type', type=str, default='gp')
parser.add_argument('--trial_num', type=int, default=100)
parser.add_argument('--init_num', type=int, default=0)
parser.add_argument('--num_source_trial', type=int, default=100)
parser.add_argument('--num_source_problem', type=int, default=-1)
parser.add_argument('--task_set', type=str, default='full')
parser.add_argument('--target_set', type=str, default='full')
parser.add_argument('--rep', type=int, default=3)
parser.add_argument('--start_id', type=int, default=0)
parser.add_argument('--discrete', action='store_true')
parser.add_argument('--similar', type=str, choices=["similar", "unsimilar", "combine", "mix-similar", 'mix-unsimilar', 'mix-both'], required=True)
parser.add_argument('--mode', type=str, default='hpob')

default_pmin, default_pmax = 0.0, 1.0
# The defaults are floats, so parse the CLI values as floats too; with
# ``type=int`` a fractional value such as ``--pmin 0.2`` was rejected.
parser.add_argument('--pmin', type=float, default=default_pmin)
parser.add_argument('--pmax', type=float, default=default_pmax)
args = parser.parse_args()

# Mirror the parsed CLI options as module-level settings read by the code below.
SERACH_SPACE_ID, DATASET_ID = args.search_space_id, args.dataset_id
task_set, targets = args.task_set, args.target_set
surrogate_type = args.surrogate_type
n_src_trial, num_source_problem = args.num_source_trial, args.num_source_problem
trial_num, init_num = args.trial_num, args.init_num
test_mode = 'random'
baselines = args.methods.split(',')  # "rs,notl" -> ["rs", "notl"]
rep, start_id = args.rep, args.start_id
discrete, similar, mode = args.discrete, args.similar, args.mode

# Toggle for Weights & Biases logging in the experiment loop.
wandb_flag = True

pmin, pmax = args.pmin, args.pmax

# Sign handed to convert_data: 1 for Sphere2D/hpob, -1 for every other mode
# (presumably objectives that are maximised -- confirm in utils.convert_data).
is_minimize = -1 if mode not in ("Sphere2D", "hpob") else 1

# HPO-B search-space ids grouped by input dimensionality. In non-"similar"
# hpob runs, the ids that share the target's dimensionality form the pool of
# candidate source search spaces.
dim_task = {
    2: ["5860", "5970"],
    3: ["4796"],
    6: ["5859", "5889"],
    8: ["5891"],
    9: ["7607", "7609"],
    16: ["5906", "5971"],
}

# Resolve the data directory and the problem dimensionality for the selected mode.
if mode == 'hpob':
    data_dir = "data/hpob-data/"
    from data.hpob_handler import HPOBHandler
    hpob_hdlr = HPOBHandler(root_dir=data_dir, mode="v3", surrogates_dir='data/saved-surrogates/')
    # Dimensionality comes from the target dataset's test inputs, not --dims.
    dims = np.array(hpob_hdlr.meta_test_data[SERACH_SPACE_ID][DATASET_ID]["X"]).shape[1]
elif mode == 'bbob':
    data_dir = "data/bbob/"
    dims = args.dims
    # bbob is the only mode that reads --dims (default None). Without it the
    # cache file names would contain "None" and the dimensionality assertion
    # in extract_data could never pass, so fail early with a clear message.
    if dims is None:
        parser.error("--dims is required when --mode is 'bbob'")
elif mode == "real":
    data_dir = "data/real/"
    dims_functions = {
        "LunarLander": functions.lunar_lander.get_dim(),
        "RobotPush": functions.push_function.get_dim(),
        "Rover": functions.rover_function.get_dim()
    }
    args.dims = dims_functions[args.search_space_id]
    dims = args.dims
elif mode == "Sphere2D":
    data_dir = "data/Sphere2D/"
    args.dims = 2
    dims = args.dims
else:
    raise ValueError(
        f"Unknown --mode {mode!r}; expected 'hpob', 'bbob', 'real' or 'Sphere2D'."
    )

assert test_mode in ['random']
# An explicit initial design is used only when --init_num is positive.
enable_init_design = init_num > 0
if not enable_init_design:
    # Otherwise init_num is reset to a default of 3 random configurations
    # (it is still forwarded to the SMBO framework as ``initial_runs``).
    init_num = 3

def should_process_file(filename, dims, search_space_id, similar, search_spaces):
    """Decide whether a cached training pickle should be loaded as a source task.

    Cache files are named ``{sid}-{did}-{dims}-{state}.pkl`` (see ``dump_data``).

    Args:
        filename: Bare file name (no directory part).
        dims: Dimensionality the file must have been generated for.
        search_space_id: Target search-space id; the only accepted prefix when
            ``similar == "similar"``.
        similar: Transfer setting from ``--similar``.
        search_spaces: Search-space ids whose files are accepted in the other
            settings.

    Returns:
        True if the file should be used as source data.
    """
    # Anchor both checks on the "-" separators. A bare suffix/prefix test
    # would let dims=2 match "...-12-train.pkl" and let "SchaffersF7" match
    # "SchaffersF7IllConditioned-...", pulling in unrelated tasks.
    if not filename.endswith(f'-{dims}-train.pkl'):
        return False
    if similar == "similar":
        return filename.startswith(f'{search_space_id}-')
    if mode == "Sphere2D" and similar == "unsimilar":
        # Every matching training file except the hard-coded target one.
        return filename != "Sphere2D-(5.0, 5.0)-2-train.pkl"
    return any(filename.startswith(f'{sid}-') for sid in search_spaces)

def get_search_spaces_unsimilar():
    """Return the candidate source search-space ids for non-"similar" settings.

    hpob: the ids in ``dim_task`` that share the target's dimensionality.
    bbob / real / Sphere2D: every top-level key of that mode's
    ``meta_dataset.json``.

    The target search space itself is dropped unless ``similar == "combine"``;
    ``list.remove`` raises ValueError if it is not among the candidates.
    """
    if mode == "hpob":
        candidates = list(dim_task[dims])
    elif mode == "bbob":
        candidates = list(get_bbob_data().keys())
    elif mode == "real":
        candidates = list(get_real_data().keys())
    elif mode == "Sphere2D":
        candidates = list(get_Sphere2D_data().keys())
    # "combine" keeps the target's own family alongside the others.
    if similar != "combine":
        candidates.remove(SERACH_SPACE_ID)
    return candidates

def load_hpo_history():
    """Load the cached source-task histories that match the current setting.

    Scans ``data_dir`` for training pickles accepted by ``should_process_file``
    and keeps every task whose recorded performances are non-empty and not all
    identical.

    Returns:
        (source_hpo_ids, source_hpo_data, random_hpo_data): the dataset ids,
        one ``{config: perf}`` dict per kept task, and ``None`` (no separate
        random target data is loaded here).
    """
    source_hpo_ids, source_hpo_data = [], []

    # Search-space prefixes accepted as sources for the non-"similar" settings.
    search_spaces_to_check = []
    if similar in ["unsimilar", "combine"]:
        search_spaces_to_check = get_search_spaces_unsimilar()
    elif similar == "mix-similar":
        search_spaces_to_check = ['similar', SERACH_SPACE_ID]
    elif similar == 'mix-unsimilar':
        search_spaces_to_check = ['unsimilar', SERACH_SPACE_ID]
    elif similar == 'mix-both':
        search_spaces_to_check = ['similar', 'unsimilar', SERACH_SPACE_ID]

    for _file in tqdm(sorted(os.listdir(data_dir))):
        if not should_process_file(_file, dims, SERACH_SPACE_ID, similar, search_spaces_to_check):
            continue
        # Names are "{sid}-{did}-{dims}-{state}". Strip the last two fields
        # first so a dataset id containing "-" (e.g. "(-5.0, 5.0)") survives.
        basename = os.path.splitext(_file)[0]
        dataset_id = basename.rsplit('-', 2)[0].partition('-')[2]
        # NOTE: pickle.load can execute code from crafted files; these caches
        # are produced by dump_data in this script, so only load trusted files.
        with open(os.path.join(data_dir, _file), 'rb') as f:
            data = pickle.load(f)
        perfs = np.array(list(data.values()))
        # Skip empty tasks (np.max/perfs[0] would raise on them) and tasks whose
        # performance rows are all identical.
        if perfs.size == 0 or (perfs == perfs[0]).all():
            continue
        source_hpo_ids.append(dataset_id)
        if perfs.ndim == 2:
            # Two values are stored per configuration; keep only the first.
            assert perfs.shape[1] == 2
            _data = {k: v[0] for k, v in data.items()}
        else:
            _data = data
        source_hpo_data.append(_data)

    random_hpo_data = None

    assert len(source_hpo_ids) == len(source_hpo_data)
    print('Load %s source hpo problems for search space %s.' % (len(source_hpo_ids), SERACH_SPACE_ID))
    return source_hpo_ids, source_hpo_data, random_hpo_data


def _load_meta_dataset(directory=None):
    """Parse ``meta_dataset.json`` from *directory* (default: the mode's ``data_dir``).

    The result is indexed elsewhere as ``data[sid][did]["X"]`` / ``["y"]``.
    """
    if directory is None:
        directory = data_dir
    path = os.path.join(directory, "meta_dataset.json")
    # JSON is UTF-8 by specification; don't depend on the platform locale.
    with open(path, 'r', encoding='utf-8') as json_file:
        return json.load(json_file)


def get_bbob_data(directory=None):
    """Load the bbob meta-dataset (``data_dir`` unless *directory* is given)."""
    return _load_meta_dataset(directory)


def get_real_data(directory=None):
    """Load the real-world meta-dataset (``data_dir`` unless *directory* is given)."""
    return _load_meta_dataset(directory)


def get_Sphere2D_data(directory=None):
    """Load the Sphere2D meta-dataset (``data_dir`` unless *directory* is given)."""
    return _load_meta_dataset(directory)

def dump_data(data, sid, did, config_space, state='train', X=None, Y=None):
    """Convert one task's observations and cache them as a pickle.

    The file is ``data_dir/{sid}-{did}-{dims}-{state}.pkl``; if it already
    exists, nothing is done. When both ``X`` and ``Y`` are omitted they are
    read from ``data[sid][did]["X"]`` and ``data[sid][did]["y"]``. The
    conversion itself is delegated to ``convert_data``.
    """
    target = os.path.join(data_dir, f"{sid}-{did}-{dims}-{state}.pkl")
    # An existing cache file is reused as-is and never regenerated.
    if os.path.exists(target):
        return
    if X is None and Y is None:
        task = data[sid][did]
        X, Y = task["X"], task["y"]
    converted = convert_data(X, Y, config_space, mode=mode, is_minimize=is_minimize)
    with open(target, 'wb') as handle:
        pickle.dump(converted, handle)

def read_data_from_json(path):
    """Read a JSON file with keys ``"X"`` and ``"y"`` and return both as NumPy arrays."""
    with open(path, 'r') as fh:
        payload = json.load(fh)
    return np.array(payload['X']), np.array(payload['y'])

def extract_data(mode):
    """Prepare the cached per-task pickles for this run and load the source tasks.

    First converts the raw data of the selected mode into the pickle format
    written by ``dump_data`` (existing files are skipped), then loads the
    source histories via ``load_hpo_history`` and checks the number of source
    tasks found for bbob / real / Sphere2D against fixed expected counts.

    Args:
        mode: One of 'hpob', 'bbob', 'real', 'Sphere2D'.

    Returns:
        (hpo_ids, hpo_data, random_test_data, optimal): the three outputs of
        ``load_hpo_history`` plus the reference value used for the regret
        (minimum test ``y`` of the target dataset for hpob, 0 otherwise).

    Raises:
        ValueError: for an unknown ``mode``.
        AssertionError: when the data layout or source-task count is unexpected.
    """
    # First convert the data into the corresponding (cached pickle) format.
    config_space = get_configspace_instance(algo_id=SERACH_SPACE_ID, dims = dims, mode=mode)
    optimal = None
    if mode == 'hpob':
        data = hpob_hdlr.meta_train_data
        if similar=="similar":
            # Sources: every training dataset of the target search space.
            for dataset_id in data[SERACH_SPACE_ID].keys():
                dump_data(data, SERACH_SPACE_ID, dataset_id, config_space, state='train')
        else:
            # Sources: the datasets of the search spaces returned by
            # get_search_spaces_unsimilar (same dimensionality as the target).
            assert SERACH_SPACE_ID in dim_task[dims]
            search_spaces = get_search_spaces_unsimilar()
            for sid in search_spaces:
                for did in data[sid].keys():
                    dump_data(data, sid, did, config_space, state='train')
        assert DATASET_ID in hpob_hdlr.meta_test_data[SERACH_SPACE_ID].keys()
        assert dims == np.array(hpob_hdlr.meta_test_data[SERACH_SPACE_ID][DATASET_ID]["X"]).shape[1]
        optimal = np.min(hpob_hdlr.meta_test_data[SERACH_SPACE_ID][DATASET_ID]["y"])
        # Also cache every test dataset of the target search space.
        for dataset_id in hpob_hdlr.meta_test_data[SERACH_SPACE_ID].keys():
            dump_data(hpob_hdlr.meta_test_data, SERACH_SPACE_ID, dataset_id, config_space, state='test')
    elif mode == 'real':
        data = get_real_data()
        for did in data[SERACH_SPACE_ID].keys():
            dump_data(data, SERACH_SPACE_ID, did, config_space, state='train')

        # For the "mix-*" settings, also convert the generated JSON files under
        # data/generated_data/<sid>/{similar,unsimilar}/ into pickles whose
        # search-space id is 'similar' / 'unsimilar'.
        # NOTE(review): `dir` shadows the builtin of the same name.
        if 'mix' in args.similar:
            mix_data_dir = f"data/generated_data/{SERACH_SPACE_ID}"
            if similar in ["mix-similar", "mix-both"]:
                dir = os.path.join(mix_data_dir,"similar/")
                for root, dirs, files in os.walk(dir):
                    for file in files:
                        file_path = os.path.join(root, file)
                        X, Y = read_data_from_json(file_path)
                        dump_data(data, 'similar', os.path.splitext(file)[0], config_space, state='train', X=X, Y=Y)
            if similar in ["mix-unsimilar", "mix-both"]:
                dir = os.path.join(mix_data_dir,"unsimilar/")
                for root, dirs, files in os.walk(dir):
                    for file in files:
                        file_path = os.path.join(root, file)
                        X, Y = read_data_from_json(file_path)
                        dump_data(data, 'unsimilar', os.path.splitext(file)[0], config_space, state='train', X=X, Y=Y)
        # The first dataset of the target search space must have `dims` columns.
        assert dims == np.array(data[SERACH_SPACE_ID][list(data[SERACH_SPACE_ID].keys())[0]]["X"]).shape[1]
        optimal = 0
    elif mode == 'bbob':
        data = get_bbob_data()
        if similar=="similar":
            for dataset_id in data[SERACH_SPACE_ID].keys():
                dump_data(data, SERACH_SPACE_ID, dataset_id, config_space, state='train')
        else:
            search_spaces = get_search_spaces_unsimilar()
            for sid in search_spaces:
                for did in data[sid].keys():
                    dump_data(data, sid, did, config_space, state='train')
        assert dims == np.array(data[SERACH_SPACE_ID][list(data[SERACH_SPACE_ID].keys())[0]]["X"]).shape[1]
        optimal = 0
    elif mode == "Sphere2D":
        data = get_Sphere2D_data()
        for did in data[SERACH_SPACE_ID].keys():
            dump_data(data, SERACH_SPACE_ID, did, config_space, state='train')
        assert dims == np.array(data[SERACH_SPACE_ID][list(data[SERACH_SPACE_ID].keys())[0]]["X"]).shape[1]
        optimal = 0
    else:
        raise ValueError
    hpo_ids, hpo_data, random_test_data = load_hpo_history()
    # Sanity-check the number of source tasks against the expected dataset
    # sizes; the constants are specific to the shipped meta-datasets.
    if mode == "bbob":
        num_datasets = len(hpo_ids)
        if similar == "similar":
            assert num_datasets == 20
        elif similar == "combine":
            assert num_datasets == 100
        elif similar == "unsimilar":
            assert num_datasets == 80
        else:
            # The "mix-*" settings are not supported in bbob mode.
            assert 0
    if mode == "real":
        num_datasets = len(hpo_ids)
        if similar == "similar":
            assert num_datasets == 20
        elif similar in ['mix-similar', 'mix-unsimilar']:
            assert num_datasets == 20+7
        elif similar == 'mix-both':
            assert num_datasets == 20+7+7
    if mode == "Sphere2D":
        if similar == "similar":
            assert len(hpo_ids) == 3
        elif similar == "unsimilar":
            assert len(hpo_ids) == 2
    return hpo_ids, hpo_data, random_test_data, optimal

def load_surrogate(search_space_id, dataset_id):
    """Load the saved XGBoost surrogate for one HPO-B task.

    The model file is ``data/saved-surrogates/surrogate-<sid>-<did>.json``.
    Both ids must be strings.
    """
    name = '-'.join(('surrogate', search_space_id, dataset_id))
    booster = xgb.Booster()
    booster.load_model('data/saved-surrogates/' + name + '.json')
    return booster

def get_func(search_space_id, dataset_id, mode, dims):
    """Build the target objective for one task.

    Args:
        search_space_id: Search-space / function-family id.
        dataset_id: Dataset id within it. In 'Sphere2D' mode it is the center
            written as a literal such as "(5.0, 5.0)"; in bbob mode it is
            converted to an int and passed as ``seed``.
        mode: 'hpob', 'real' or 'Sphere2D'; any other value selects the bbob
            functions.
        dims: Input dimensionality, used to reshape x for the hpob surrogate.

    Returns:
        A callable taking a configuration array and returning the objective.

    Raises:
        ValueError: in 'Sphere2D' mode if ``dataset_id`` is not a Python literal.
    """
    if mode == "hpob":
        surrogate = load_surrogate(search_space_id, dataset_id)
        # The surrogate predicts a batch; return the prediction for one point.
        func = lambda x: surrogate.predict(xgb.DMatrix(x.reshape(-1, dims)))[0]
    elif mode == "real":
        func_dict = {
            "LunarLander": lunar,
            "RobotPush": push,
            "Rover": rover,
        }
        func = func_dict[search_space_id]
    elif mode == "Sphere2D":
        import ast

        def Sphere2D(arr: np.ndarray, center: tuple):
            """Squared Euclidean distance from ``arr`` to ``center``."""
            # Unpack the center coordinates.
            a, b = center
            return float(np.sum((arr - np.array([a, b]))**2))

        # literal_eval only accepts Python literals; eval() would execute any
        # code embedded in a dataset id coming from the CLI or a file name.
        func = partial(Sphere2D, center=ast.literal_eval(dataset_id))
    else:
        func_dict = {
            "Sphere": Sphere,
            "Rastrigin": Rastrigin,
            "BuecheRastrigin": BuecheRastrigin,
            "LinearSlope": LinearSlope,
            "AttractiveSector": AttractiveSector,
            "StepEllipsoidal": StepEllipsoidal,
            "RosenbrockRotated": RosenbrockRotated,
            "Ellipsoidal": Ellipsoidal,
            "Discus": Discus,
            "BentCigar": BentCigar,
            "SharpRidge": SharpRidge,
            "DifferentPowers": DifferentPowers,
            "Weierstrass": Weierstrass,
            "SchaffersF7": SchaffersF7,
            "SchaffersF7IllConditioned": SchaffersF7IllConditioned,
            "GriewankRosenbrock": GriewankRosenbrock,
            "Schwefel": Schwefel,
            "Katsuura": Katsuura,
            "Lunacek": Lunacek,
            "Gallagher101Me": Gallagher101Me,
            "Gallagher21Me": Gallagher21Me,
            "NegativeSphere": NegativeSphere,
            "NegativeMinDifference": NegativeMinDifference,
            "FonsecaFleming": FonsecaFleming,
        }
        # bbob dataset ids are integer seeds for the function instance.
        func = partial(func_dict[search_space_id], seed=int(dataset_id))
    return func

if __name__ == "__main__":
    # Cache/convert the data and load the source tasks for this run.
    hpo_ids, hpo_data, random_test_data, optimal= extract_data(mode)
    algo_name = SERACH_SPACE_ID
    config_space = get_configspace_instance(algo_id=algo_name, dims = dims, mode=mode)
    # All loaded source tasks are used; this overrides --num_source_problem.
    num_source_problem = len(hpo_ids)

    # Only the single target dataset given by --dataset-id is evaluated.
    run_id = [DATASET_ID]

    # Exp folder to save results.
    # NOTE(review): the "bbob-exp_results" folder is used for every mode.
    exp_dir = 'data/bbob-exp_results/%s' % (SERACH_SPACE_ID)
    if not os.path.exists(exp_dir):
        os.makedirs(exp_dir)

    pbar = tqdm(total=rep * len(baselines) * len(run_id) * trial_num)
    for rep_id in range(start_id, start_id + rep):
        for id in run_id:
            for mth in baselines:
                # Display name for the W&B run; "ours" variants are labelled "supervised-gp".
                name = "supervised-gp" if "ours" in mth else mth
                if wandb_flag:
                    # Run-name timestamp: UTC shifted by +8 hours.
                    ts = datetime.datetime.utcnow() + datetime.timedelta(hours=+8)
                    ts_name = f'-ts{ts.month}-{ts.day}-{ts.hour}-{ts.minute}-{ts.second}'
                    wandb.init(
                        project="real100",
                        name=f"{name}-{SERACH_SPACE_ID}-{DATASET_ID}-{ts_name}",
                        job_type=f"{name}-final-v3",
                        tags=[f"dim={dims}", f"similar={similar}", f"search_space_id={SERACH_SPACE_ID}", f"dataset_id={DATASET_ID}"]
                    )
                # NOTE(review): the seed is the same constant for every rep, and
                # the result file name below does not contain rep_id, so each rep
                # overwrites the previous rep's pickle. `seeds` is imported from
                # tools.utils but unused -- confirm whether seeds[rep_id] was intended.
                seed = int(2 ** 31 - 1)
                print('=== start rep', rep_id, 'seed', seed)

                print('=' * 20)
                print('[%s-%s] Evaluate the seed=%s problem - [%d].' % (SERACH_SPACE_ID, mth, id, rep_id))
                pbar.set_description('[%s-%s] seed=%s - [%d]' % (SERACH_SPACE_ID, mth, id, rep_id))
                start_time = time.time()

                # Generate the source and target hpo data.
                source_hpo_data = list()
                if test_mode == 'bo':
                    raise NotImplementedError
                else:
                    target_hpo_data = random_test_data
                # NOTE(review): `id` is DATASET_ID (a str) while `_id` is an int
                # index, so this test is always true and every task is kept.
                for _id, data in enumerate(hpo_data):
                    if _id != id:
                        source_hpo_data.append(data)

                func = get_func(SERACH_SPACE_ID, DATASET_ID, mode, dims)

                # Select a subset of source problems to transfer.
                # (num_source_problem == len(hpo_ids) above, so in practice this
                # only shuffles the order of the source tasks.)
                rng = np.random.RandomState(seed)
                shuffled_ids = np.arange(len(source_hpo_data))
                rng.shuffle(shuffled_ids)
                source_hpo_data = [source_hpo_data[id] for id in shuffled_ids[:num_source_problem]]

                mth = convert_method_name(mth, SERACH_SPACE_ID)

                # Map the (converted) method name to a surrogate class.
                if mth == 'rgpe':
                    surrogate_class = RGPE
                elif mth == 'notl':
                    surrogate_class = NoTL
                elif mth == 'tst':
                    surrogate_class = TST
                elif mth == 'rs':
                    surrogate_class = RandomSearch
                elif mth.startswith('rgpe-space'):
                    if 'gp' in mth or 'smac' in mth:
                        surrogate_class = RGPESPACE_BO
                    elif 'rs' in mth:
                        surrogate_class = RGPESPACE_RS
                    elif 'tst' in mth:
                        surrogate_class = RGPESPACE_TST
                    else:
                        surrogate_class = RGPESPACE     # rgpe
                elif mth.startswith('tst-space'):
                    surrogate_class = TSTSPACE
                elif mth in ['box-gp', 'ellipsoid-gp']:
                    surrogate_class = NoTL  # BO
                elif mth in ['box-rs', 'ellipsoid-rs']:
                    surrogate_class = RandomSearch
                else:
                    raise ValueError('Invalid baseline name - %s.' % mth)

                # "smac" variants use a random-forest surrogate, all others a GP;
                # this overrides --surrogate_type.
                if 'smac' in mth:
                    surrogate_type = 'rf'
                else:
                    surrogate_type = 'gp'

                print('surrogate_class:', surrogate_class.__name__)
                surrogate = surrogate_class(config_space, source_hpo_data, target_hpo_data, 
                                            seed=seed,
                                            surrogate_type=surrogate_type,
                                            num_src_hpo_trial=n_src_trial)

                # Model type handed to the search-space-enlargement frameworks,
                # chosen by substring of the method name.
                if 'rf' in mth:
                    model = 'rf'
                elif 'knn' in mth:
                    model = 'knn'
                elif 'svm' in mth:
                    model = 'svm'
                elif 'lr' in mth:
                    model = 'lr'
                else:
                    model = 'gp'

                # Pick the SMBO framework (discrete vs continuous variant) and its
                # search-space mode from the method name; default is SMBO_OFFLINE.
                if 'final' in mth:
                    smbo_framework = partial(SMBO_SEARCH_SPACE_Enlarge if discrete else SMBO_SEARCH_SPACE_Enlarge_Continuous, mode='all+-sample+-threshold', model=model)
                elif 'sample' in mth:
                    smbo_framework = partial(SMBO_SEARCH_SPACE_Enlarge if discrete else SMBO_SEARCH_SPACE_Enlarge_Continuous, mode='sample', model=model)
                elif 'best' in mth:
                    smbo_framework = partial(SMBO_SEARCH_SPACE_Enlarge if discrete else SMBO_SEARCH_SPACE_Enlarge_Continuous, mode='best', model=model)
                elif 'box' in mth:
                        smbo_framework = partial(SMBO_SEARCH_SPACE_Enlarge if discrete else SMBO_SEARCH_SPACE_Enlarge_Continuous, mode='box', model=model)
                elif 'ellipsoid' in mth:
                        smbo_framework = partial(SMBO_SEARCH_SPACE_Enlarge if discrete else SMBO_SEARCH_SPACE_Enlarge_Continuous, mode='ellipsoid', model=model)
                else:
                    smbo_framework = SMBO_OFFLINE

                smbo = smbo_framework(target_hpo_data, config_space, surrogate, target_func=func, dims=dims,
                                      random_seed=seed, max_runs=trial_num,
                                      source_hpo_data=source_hpo_data,
                                      num_src_hpo_trial=n_src_trial,
                                      surrogate_type=surrogate_type,
                                      enable_init_design=enable_init_design,
                                      initial_runs=init_num,
                                      acq_func='ei')

                # Only frameworks exposing p_min receive the --pmin/--pmax bounds.
                if hasattr(smbo, 'p_min'):
                    smbo.p_min = pmin
                    smbo.p_max = pmax
                smbo.use_correct_rate = True

                # Run trial_num iterations, recording [incumbent y, elapsed seconds].
                result = list()
                for _iter_id in range(trial_num):
                    config, _, perf, _ = smbo.iterate()
                    time_taken = time.time() - start_time
                    y_inc = smbo.get_inc_y()
                    result.append([y_inc, time_taken])
                    print(f"best value={y_inc}")
                    if wandb_flag:
                        # In "real" mode the logged sample/best values are
                        # sign-flipped. NOTE(review): "regret" is not flipped --
                        # confirm this is intended.
                        factor = -1 if mode == "real" else 1
                        wandb.log({
                            "sample counter": _iter_id+1,
                            "sample value": perf * factor,
                            "best value": y_inc * factor,
                            "regret": y_inc-optimal
                        })
                    pbar.update(1)
                    print(perf)
                # print('In problem: %s' % (id), 'nce, y_inc', result[-1])
                # print('min/max', smbo.y_min, smbo.y_max)
                # print('mean,std', np.mean(smbo.ys), np.std(smbo.ys))

                # Persist the trajectory as an (trial_num, 2) array.
                mth_file = '%s_%s_%s_%d_%d_%d.pkl' % (
                    mth, id, SERACH_SPACE_ID, n_src_trial, trial_num, seed)
                with open(os.path.join(exp_dir, mth_file), 'wb') as f:
                    data = np.array(result)
                    pickle.dump(data, f)
                if wandb_flag:
                    wandb.finish()
    pbar.close()
