import pickle
from tqdm import tqdm
import xgboost as xgb
import numpy as np
import time
import warnings
from lamcts import MCTS
import os
from hpob_handler import HPOBHandler
from lamcts.utils import standardization, minmax, build_model
from functions.bbob import *
from functools import partial
import json
from functions.lunar_lander import lunar
from functions.push_function import push
from functions.rover_function import rover

# Install a catch-all "ignore" warnings filter for the whole process.
warnings.filterwarnings('ignore')

# Directory containing this file. The HPO-B data, the saved surrogates and the
# MCTS model caches are resolved against it; other relative paths in this file
# (e.g. "lamcts/bbob/meta_dataset.json", "data/generated_data") are resolved
# against the current working directory instead.
this_dir = os.path.abspath(os.path.dirname(__file__))
# HPO-B benchmark data; passed as HPOBHandler's root_dir.
data_path = os.path.join(this_dir, "lamcts/hpob-data/")
# Saved XGBoost surrogates. Keep the trailing slash: load_surrogate appends
# file names to this path with plain string concatenation.
surrogate_path = os.path.join(this_dir, "lamcts/saved-surrogates/")

def _load_json(path):
    """Read the UTF-8 JSON document at ``path`` and return the decoded object."""
    with open(path, 'r', encoding='utf-8') as json_file:
        return json.load(json_file)

def get_bbob_data():
    """Load the BBOB meta-dataset from ``lamcts/bbob/meta_dataset.json``.

    The path is relative to the current working directory, not to this file.
    """
    return _load_json("lamcts/bbob/meta_dataset.json")

def get_real_data():
    """Load the real-world-task meta-dataset from ``lamcts/real/meta_dataset.json``.

    The path is relative to the current working directory, not to this file.
    """
    return _load_json("lamcts/real/meta_dataset.json")

def get_Sphere2D_data():
    """Load the Sphere2D meta-dataset from ``lamcts/Sphere2D/meta_dataset_2.json``.

    The path is relative to the current working directory, not to this file.
    """
    return _load_json("lamcts/Sphere2D/meta_dataset_2.json")

def read_data_from_json(path):
    """Load a generated dataset stored as a JSON object with keys "X" and "y".

    Returns ``(X, y)`` as NumPy arrays, with ``y`` reshaped into one column.
    """
    with open(path, 'r') as handle:
        payload = json.load(handle)
    features = np.array(payload['X'])
    targets = np.array(payload['y'])
    return features, targets.reshape(-1, 1)
        
def get_source_data(search_space_id, dataset_id, data, similar, dims, mode, lb, ub):
    """Assemble the source datasets used to pre-learn the MCTS tree.

    Args:
        search_space_id: target search-space key in ``data``.
        dataset_id: target dataset id; skipped when collecting datasets of the
            target search space (the non-"unsimilar"/"combine" branch).
        data: nested mapping ``data[sid][did]`` holding ``"X"`` and ``"y"``.
        similar: source selection strategy. "unsimilar" uses every other
            candidate search space and "combine" uses all of them (not for
            Sphere2D). Any other value uses the target search space's own
            datasets, optionally extended with generated "mix-*" data read from
            ``data/generated_data/<search_space_id>/``.
        dims: problem dimensionality; selects the candidate HPO-B search spaces.
        mode: "hpob", "bbob", "real" or "Sphere2D".
        lb, ub: search-space bounds forwarded to ``build_model``.

    Returns:
        ``(data_X, data_Y, source_id, history)`` where ``data_X``/``data_Y`` are
        the vertically stacked inputs and per-dataset standardized targets
        (targets multiplied by +1 for bbob/real, -1 otherwise), ``source_id``
        holds one ``"<sid>+<did>"`` label per row, and
        ``history[sid][did]`` stores "X", "y", "model" and "X_optimal".

    Raises:
        ValueError: for a mode this function cannot handle, or when no source
            data was found.
    """
    data_X = None
    data_Y = None

    history = dict()
    source_id = []

    # +1 for modes treated as maximisation (bbob, real), -1 otherwise. It picks
    # each dataset's best point and flips the sign of the stacked targets.
    is_maximize = 1 if mode in ["bbob", "real"] else -1

    def update_data(X, y, sid, did, data_X, data_Y, history, source_id):
        """Register one dataset in ``history`` and append it to the stacked arrays."""
        assert X.shape[0] == y.shape[0]
        source_id.extend([f"{sid}+{did}"] * X.shape[0])
        if sid not in history:
            history[sid] = dict()
        history[sid][did] = dict()

        idx = np.argmax(y) if is_maximize == 1 else np.argmin(y)

        history[sid][did]["X"] = X
        history[sid][did]["y"] = y
        history[sid][did]["model"] = build_model(X, y, lb, ub)
        history[sid][did]["X_optimal"] = X[idx, :]

        data_X = X if data_X is None else np.vstack((data_X, X))
        data_Y = y if data_Y is None else np.vstack((data_Y, y))
        return data_X, data_Y, history, source_id

    def generate_datasets(data, sid, did, data_X, data_Y, history, source_id):
        """Standardize ``data[sid][did]`` and register it via ``update_data``."""
        X = np.array(data[sid][did]["X"])
        y = np.array(data[sid][did]["y"]).reshape(-1, 1)
        y = standardization(y)
        return update_data(X, y, sid, did, data_X, data_Y, history, source_id)

    def generate_mix_datasets(sid, similar, data_X, data_Y, history, source_id):
        """Append the generated "mix" datasets found under data/generated_data/<sid>/."""
        assert similar in ["mix-similar", 'mix-unsimilar', 'mix-both']
        mix_data_dir = f"data/generated_data/{sid}"
        subsets = []
        if similar in ["mix-similar", "mix-both"]:
            subsets.append("similar")
        if similar in ["mix-unsimilar", "mix-both"]:
            subsets.append("unsimilar")
        for subset in subsets:
            subset_dir = os.path.join(mix_data_dir, f"{subset}/")
            for root, dirs, files in os.walk(subset_dir):
                # os.walk yields entries in filesystem order; sort so the stacked
                # data (and hence the learned tree) is reproducible across machines.
                dirs.sort()
                for file in sorted(files):
                    X, y = read_data_from_json(os.path.join(root, file))
                    y = standardization(y)
                    data_X, data_Y, history, source_id = update_data(
                        X, y, subset, os.path.splitext(file)[0],
                        data_X, data_Y, history, source_id)
        return data_X, data_Y, history, source_id

    if similar in ["unsimilar", "combine"] and mode != "Sphere2D":
        if mode == "hpob":
            # Candidate HPO-B search spaces grouped by dimensionality.
            dim_task = {
                2: ["5860", "5970"],
                3: ["4796"],
                6: ["5859", "5889"],
                8: ["5891"],
                9: ["7607", "7609"],
                16: ["5906", "5971"]
                }
            assert search_space_id in dim_task[dims]
            search_spaces = dim_task[dims][:]
        elif mode in ["bbob", "real"]:
            search_spaces = list(data.keys())
        else:
            raise ValueError(f"Unsupported mode {mode!r} for similar={similar!r}")

        if similar == "unsimilar":
            search_spaces.remove(search_space_id)

        # NOTE(review): unlike the branch below, dataset_id is not skipped here, so
        # with similar == "combine" the target dataset itself is included if it is
        # present in data[search_space_id] — confirm this is intended.
        for sid in search_spaces:
            history[sid] = dict()
            datasets = tqdm(data[sid].keys(), desc="Load Source Datasets...")
            for did in datasets:
                data_X, data_Y, history, source_id = generate_datasets(data, sid, did, data_X, data_Y, history, source_id)
        if data_Y is None:
            raise ValueError(f"No source datasets found for search space {search_space_id!r} (similar={similar!r}, mode={mode!r})")
        data_Y = data_Y * is_maximize
    else:
        datasets = tqdm(data[search_space_id].keys())
        datasets.set_description("Load Source Datasets...")

        history[search_space_id] = dict()
        for did in datasets:
            # Never use the target dataset as a source.
            if did == dataset_id:
                continue
            if similar == "unsimilar" and mode == "Sphere2D":
                if did == "(5.0, 5.0)":
                    continue
            data_X, data_Y, history, source_id = generate_datasets(data, search_space_id, did, data_X, data_Y, history, source_id)

        if data_Y is None:
            raise ValueError(f"No source datasets found for search space {search_space_id!r} (similar={similar!r}, mode={mode!r})")
        sample_num = data_Y.shape[0]
        if "mix" in similar:
            data_X, data_Y, history, source_id = generate_mix_datasets(search_space_id, similar, data_X, data_Y, history, source_id)
        data_Y = data_Y * is_maximize

        # Sanity checks on the expected amount of generated / Sphere2D data.
        if similar in ["mix-similar", "mix-unsimilar"]:
            assert data_Y.shape[0] == sample_num + 7*300
        elif similar == "mix-both":
            assert data_Y.shape[0] == sample_num + 14*300

        if mode == "Sphere2D":
            if similar == "similar":
                assert data_Y.shape[0] == 330
            else:
                assert data_Y.shape[0] == 220

    assert data_X.shape[0] == len(source_id)
    return data_X, data_Y, source_id, history

def generate_mcts(data_X, data_Y, agent):
    """Fit ``agent`` on the given samples by delegating to ``agent.learn``."""
    learn = agent.learn
    learn(data_X, data_Y)
    return None

def get_bounds(args):
    """Return ``(lb, ub)`` box bounds of the search space for ``args.mode``.

    Sphere2D uses [-10, 10], bbob uses [-5, 5], and hpob/real use the unit
    box [0, 1]; each bound has length ``args.dims``.

    Raises:
        ValueError: if ``args.mode`` is not one of the supported modes.
    """
    dims = args.dims
    if args.mode == 'Sphere2D':
        return np.full(dims, -10), np.full(dims, 10)
    if args.mode == 'bbob':
        return np.full(dims, -5), np.full(dims, 5)
    if args.mode in ("hpob", "real"):
        return np.zeros(dims), np.ones(dims)
    # Previously fell through to an UnboundLocalError on ``return lb, ub``.
    raise ValueError(f"Unknown mode {args.mode!r}; expected 'Sphere2D', 'bbob', 'hpob' or 'real'")

def mcts_prelearn(args, data, model_path):
    """Pre-learn an MCTS tree on the source data and pickle it under ``model_path``.

    The tree is cached as ``<model_path>mcts<search_space_id><similar>.pkl``; if
    that file already exists, nothing is recomputed. The file is written
    atomically (temp file + ``os.replace``), so an interrupted run never leaves a
    truncated pickle that the existence check would later accept.

    NOTE(review): the cache key omits ``dataset_id`` although ``get_source_data``
    may exclude ``dataset_id`` from the source data, so a tree cached for one
    target dataset can be reused for another — confirm this is intended.
    """
    search_space_id = args.search_space_id
    dataset_id = args.dataset_id
    similar = args.similar
    dims = args.dims
    mode = args.mode

    name = model_path + "mcts" + str(search_space_id) + similar + ".pkl"
    print(name)
    if os.path.exists(name):
        return

    lb, ub = get_bounds(args)
    data_X, data_Y, source_id, history = get_source_data(search_space_id, dataset_id, data, similar, dims, mode, lb, ub)
    # The tree uses the dimensionality of the loaded source data.
    dims = data_X.shape[1]
    # Presumably rescales inputs into the unit box using the task bounds — see lamcts.utils.minmax.
    data_X = minmax(data_X, min=lb, max=ub)
    agent = MCTS(
        lb = lb,              # the lower bound of each problem dimension
        ub = ub,              # the upper bound of each problem dimension
        dims = dims,          # the problem dimensions
        ninits = 0,           # the number of random samples used in initialization
        func = None,          # function object to be optimized (set later, at transfer time)
        Cp = 1,               # Cp for MCTS
        # Leaf capacity grows with the source size (n / 1500), clipped to [10, 30].
        leaf_size = min(max(10, np.round(data_X.shape[0]/1500)),30),
        kernel_type = args.kernel_type, # SVM configuration
        gamma_type = "auto",    # SVM configuration
        state = "learn-only",
        source_X = data_X,
        source_Y = data_Y,
        source_id = source_id,
        stage = 0,
        search_space_id = search_space_id,
        dataset_id = dataset_id,
        similar = similar,
        )

    agent.learn(data_X, data_Y)
    agent.history_sample = history

    tmp_name = f"{name}.{os.getpid()}.tmp"
    try:
        with open(tmp_name, "wb") as f:
            pickle.dump(agent, f)
        os.replace(tmp_name, name)
    finally:
        # On success the temp file has already been renamed; on failure discard it.
        if os.path.exists(tmp_name):
            os.remove(tmp_name)
        
def load_model(path):
    """Unpickle and return the object stored at ``path``, then log where it came from.

    SECURITY: ``pickle.load`` can execute arbitrary code; only load files that
    this pipeline wrote itself.
    """
    with open(path, 'rb') as handle:
        restored = pickle.load(handle)
    print(f"load model from {path}")
    return restored

def load_surrogate(search_space_id, dataset_id):
    """Load the saved XGBoost surrogate for ``(search_space_id, dataset_id)``.

    The model is read from
    ``<surrogate_path>surrogate-<search_space_id>-<dataset_id>.json``; both ids
    must be strings.
    """
    model_file = surrogate_path + "-".join(["surrogate", search_space_id, dataset_id]) + '.json'
    booster = xgb.Booster()
    booster.load_model(model_file)
    return booster

def get_func(search_space_id, dataset_id, mode, dims):
    """Return the objective callable for the target task.

    - "hpob": the saved XGBoost surrogate for ``(search_space_id, dataset_id)``;
      the input is reshaped to ``(-1, dims)`` and the first prediction returned.
    - "Sphere2D": ``Sphere2D`` centred at the tuple literal in ``dataset_id``,
      e.g. ``"(5.0, 5.0)"``.
    - "real": LunarLander / RobotPush / Rover, selected by ``search_space_id``.
    - anything else: the BBOB function named ``search_space_id`` with
      ``seed=int(dataset_id)``.

    Raises:
        KeyError: unknown real/BBOB function name.
        ValueError / SyntaxError: a Sphere2D ``dataset_id`` that is not a Python literal.
    """
    if mode == "hpob":
        surrogate = load_surrogate(search_space_id, dataset_id)

        def func(x):
            return surrogate.predict(xgb.DMatrix(x.reshape(-1, dims)))[0]
    elif mode == 'Sphere2D':
        import ast
        from functions.test_weight import Sphere2D
        # literal_eval parses tuples such as "(5.0, 5.0)" without executing
        # arbitrary code the way eval() would on a command-line string.
        func = partial(Sphere2D, center=ast.literal_eval(dataset_id))
    elif mode == "real":
        func_dict = {
            "LunarLander": lunar,
            "RobotPush": push,
            "Rover": rover,
        }
        func = func_dict[search_space_id]
    else:
        func_dict = {
            "Sphere": Sphere,
            "Rastrigin": Rastrigin,
            "BuecheRastrigin": BuecheRastrigin,
            "LinearSlope": LinearSlope,
            "AttractiveSector": AttractiveSector,
            "StepEllipsoidal": StepEllipsoidal,
            "RosenbrockRotated": RosenbrockRotated,
            "Ellipsoidal": Ellipsoidal,
            "Discus": Discus,
            "BentCigar": BentCigar,
            "SharpRidge": SharpRidge,
            "DifferentPowers": DifferentPowers,
            "Weierstrass": Weierstrass,
            "SchaffersF7": SchaffersF7,
            "SchaffersF7IllConditioned": SchaffersF7IllConditioned,
            "GriewankRosenbrock": GriewankRosenbrock,
            "Schwefel": Schwefel,
            "Katsuura": Katsuura,
            "Lunacek": Lunacek,
            "Gallagher101Me": Gallagher101Me,
            "Gallagher21Me": Gallagher21Me,
            "NegativeSphere": NegativeSphere,
            "NegativeMinDifference": NegativeMinDifference,
            "FonsecaFleming": FonsecaFleming
            }
        func = partial(func_dict[search_space_id], seed=int(dataset_id))

    return func

def run_mcts_transfer(args):
    """Transfer a pre-learned MCTS tree to the target task and search with it.

    Loads the source meta-dataset for ``args.mode`` and ensures the tree has
    been pre-learned (``mcts_prelearn``). Then, for each of ``args.rep``
    repetitions, it reloads the pickled tree, configures it from ``args`` for the
    target ``(search_space_id, dataset_id)`` and runs ``search_from_tree``.
    """
    space_id = args.search_space_id
    target_id = args.dataset_id
    mode = args.mode
    model_path = f"{this_dir}/lamcts/model_{mode}_new_all/{args.kernel_type}/"
    if not os.path.exists(model_path):
        os.makedirs(model_path, exist_ok=True)

    optimal = 0
    if mode == "hpob":
        handler = HPOBHandler(root_dir=data_path, mode="v3", surrogates_dir=surrogate_path)
        assert target_id in handler.meta_test_data[space_id].keys()
        data = handler.meta_train_data
        optimal = np.min(handler.meta_test_data[space_id][target_id]["y"]) * -1
    elif mode == "Sphere2D":
        data = get_Sphere2D_data()
    elif mode == "real":
        assert space_id in ["LunarLander", "RobotPush", "Rover"]
        assert args.similar in ["similar", "mix-similar", 'mix-unsimilar', 'mix-both']
        data = get_real_data()
    else:
        assert space_id in ["GriewankRosenbrock", "Lunacek", "Rastrigin", "RosenbrockRotated", "SharpRidge"]
        data = get_bbob_data()

    assert space_id in data.keys()

    print("search space id:", space_id)

    mcts_prelearn(args, data, model_path)
    print("=" * 10, f"{space_id} have been pretrained", "=" * 10)

    tree_file = model_path + "mcts" + str(space_id) + args.similar + ".pkl"
    for _ in range(args.rep):
        print("=" * 10, f"{space_id} train from existing tree", "=" * 10)
        # Reload from disk every repetition so each run starts from the tree as
        # mcts_prelearn saved it, not from the previous run's state.
        model = load_model(tree_file)
        overrides = {
            "Cp": args.Cp,
            "optimal": optimal,
            "weight_update": args.weight_update,
            "weight_decay": args.weight_decay,
            "func": get_func(space_id, target_id, mode, args.dims),
            "func_maximize": -1,
            "dataset_id": target_id,
            "decay_factor": 1.0,
            "mode": mode,
            "method": args.methods,
            "model_path": f"checkpoints/{args.similar}-{args.search_space_id}.pth",
            "kernel_type": args.kernel_type,
        }
        for attr, value in overrides.items():
            setattr(model, attr, value)
        model.search_from_tree(
            iterations=args.iteration,
            threshold=args.threshold,
            local=args.local,
            similarity=args.similarity,
            N=args.N,
            gamma=args.gamma,
            alpha=args.alpha,
        )

    
def run_lamcts(args, generate = False):
    """Run vanilla LA-MCTS (no transfer) on the target task.

    Builds a fresh MCTS agent for each of ``args.rep`` repetitions and calls
    ``agent.search``. With ``generate=True``, logging is disabled and the
    samples of the first repetition are returned immediately; otherwise the
    function returns None.
    """
    space_id, target_id = args.search_space_id, args.dataset_id
    optimal = 0
    if args.mode == "hpob":
        handler = HPOBHandler(root_dir=data_path, mode="v3", surrogates_dir=surrogate_path)
        optimal = np.min(handler.meta_test_data[space_id][target_id]["y"]) * -1
    lb, ub = get_bounds(args)
    for _ in range(args.rep):
        agent = MCTS(
            lb=lb,                      # lower bound of each problem dimension
            ub=ub,                      # upper bound of each problem dimension
            dims=args.dims,             # problem dimensionality
            ninits=3,                   # random samples used to initialise the search
            func=get_func(space_id, target_id, args.mode, args.dims),
            Cp=args.Cp,                 # exploration constant
            leaf_size=3,                # max samples per leaf
            kernel_type=args.kernel_type,  # SVM configuration
            gamma_type="auto",          # SVM configuration
            search_space_id=space_id,
            dataset_id=target_id,
            optimal=optimal,
            log=not generate,
            func_maximize=-1,
            mode=args.mode,
        )
        agent.search(iterations=args.iteration, threshold=args.threshold, local=args.local)
        if generate:
            return agent.samples
    return None

def random_search(args, lb, ub, f=None):
    """Uniform random-search baseline over the box between ``lb`` and ``ub``.

    Draws ``args.iteration`` points with ``np.random.uniform(lb, ub)`` and scores
    each with ``f``. When ``f`` is None, the task objective from ``get_func`` is
    used instead. Scores are negated (``f(x) * -1``).

    Returns:
        A list of ``(x, score)`` tuples in sampling order.
    """
    print(lb, ub)
    samples = []
    if args.iteration <= 0:
        return samples
    # Resolve the objective once: get_func may load a surrogate from disk
    # (hpob), so calling it per iteration repeated that work on every step.
    function = f if f is not None else get_func(args.search_space_id, args.dataset_id, args.mode, args.dims)
    for _ in range(args.iteration):
        x = np.random.uniform(lb, ub)
        samples.append((x, function(x) * -1))
    return samples
    