from functions.bbob import *
import random
import json
import os
from space_gen import run_lamcts, random_search
from tqdm import tqdm

def data_gen(args, *, algorithms=None):
    """Generate a BBOB meta-dataset and write it to ``lamcts/bbob/meta_dataset_{dims}.json``.

    For every benchmark function class listed below and every algorithm in
    ``algorithms``, the algorithm is called as ``alg(args, generate=True)``.
    The ``(x, y)`` samples it returns are stored as::

        meta_dataset[function.__name__][alg.__name__] = {"X": [...], "y": [...]}

    and the whole mapping is written as JSON.

    Args:
        args: Mutable configuration namespace. Sets ``args.mode = 'bbob'`` and
            ``args.rep = 1``, and on every iteration updates
            ``args.search_space_id`` (function class name) and
            ``args.dataset_id`` (algorithm name). Reads ``args.dims`` to name
            the output file. The function class itself is never passed to the
            algorithm; presumably the algorithm resolves the objective from
            ``args.search_space_id`` — confirm against the algorithm.
        algorithms: Optional sequence of callables used to generate samples.
            Each is called as ``alg(args, generate=True)``, must return an
            iterable of ``(x, y)`` pairs, and is keyed by its ``__name__``.
            Defaults to ``[simpleBO]``.

    Notes:
        Each ``x`` must provide ``.tolist()`` (presumably a NumPy array). Each
        ``y`` is negated (``y * -1``) before storing; the sign convention
        presumably matches the consumer of this file — confirm.
        The dataset is fully serialised before the file is opened, so a
        serialisation error leaves any existing output file untouched.
    """
    args.mode = 'bbob'
    args.rep = 1

    dims = args.dims
    directory = 'lamcts/bbob'
    file_path = os.path.join(directory, f"meta_dataset_{dims}.json")
    os.makedirs(directory, exist_ok=True)

    functions = [
        Sphere, Rastrigin, BuecheRastrigin, LinearSlope, AttractiveSector,
        StepEllipsoidal, RosenbrockRotated, Ellipsoidal, Discus, BentCigar,
        SharpRidge, DifferentPowers, Weierstrass, SchaffersF7, SchaffersF7IllConditioned,
        GriewankRosenbrock, Schwefel, Katsuura, Lunacek, Gallagher101Me,
        Gallagher21Me, NegativeSphere, NegativeMinDifference, FonsecaFleming,
    ]
    if algorithms is None:
        algorithms = [simpleBO]

    meta_dataset = {}
    for function in tqdm(functions, desc="Load functions..."):
        search_space_id = function.__name__
        args.search_space_id = search_space_id
        meta_dataset[search_space_id] = {}

        for alg in algorithms:
            dataset_id = alg.__name__
            args.dataset_id = dataset_id
            samples = alg(args, generate=True)

            # Single pass over the samples: two separate comprehensions would
            # exhaust an iterator/generator on the first pass and leave the
            # y list silently empty.
            X_samples, y_samples = [], []
            for x, y in samples:
                X_samples.append(x.tolist())
                y_samples.append(y * -1)

            meta_dataset[search_space_id][dataset_id] = {
                "X": X_samples,
                "y": y_samples,
            }

    # Serialise before opening: if this raises, the existing file is not truncated.
    meta_dataset_json = json.dumps(meta_dataset)
    with open(file_path, 'w', encoding='utf-8') as json_file:
        json_file.write(meta_dataset_json)

def data_gen_Sphere2D(args):
    """Generate a shifted 2-D Sphere meta-dataset and write it to
    ``lamcts/Sphere2D/meta_dataset_2.json``.

    For each of three centers, ``simpleBO_gen`` is run on ``Sphere2D`` with
    that center bound via ``functools.partial``. The ``(x, y)`` samples it
    returns are stored as::

        meta_dataset['Sphere2D'][str(center)] = {"X": [...], "y": [...]}

    Args:
        args: Mutable configuration namespace. Sets
            ``args.search_space_id = 'Sphere2D'``, ``args.dims = 2`` and
            ``args.iteration = 100``. Sets ``args.dataset_id = str(center)``
            per center; after the call it holds the last center.

    Notes:
        Each ``x`` must provide ``.tolist()`` (presumably a NumPy array). Each
        ``y`` is negated (``y * -1``) before storing; the sign convention
        presumably matches the consumer of this file — confirm.
        No search bounds are configured here. Presumably ``simpleBO_gen``
        defines its own domain — confirm.
    """
    from functions.test_weight import Sphere2D
    from functools import partial

    directory = 'lamcts/Sphere2D'
    file_path = os.path.join(directory, "meta_dataset_2.json")
    os.makedirs(directory, exist_ok=True)
    centers = [(-5.0, -5.0), (5.0, 5.0), (-5.0, 5.0)]
    alg = simpleBO_gen
    meta_dataset = {}

    args.search_space_id = 'Sphere2D'
    args.dims = 2
    args.iteration = 100
    meta_dataset[args.search_space_id] = {}
    for center in centers:
        args.dataset_id = str(center)
        function = partial(Sphere2D, center=center)
        samples = alg(args, f=function)

        # Single pass so an iterator/generator return value is not exhausted
        # before the y values are read.
        X_samples, y_samples = [], []
        for x, y in samples:
            X_samples.append(x.tolist())
            y_samples.append(y * -1)

        meta_dataset[args.search_space_id][args.dataset_id] = {
            "X": X_samples,
            "y": y_samples,
        }

    # Serialise before opening: if this raises, the existing file is not truncated.
    meta_dataset_json = json.dumps(meta_dataset)
    with open(file_path, 'w', encoding='utf-8') as json_file:
        json_file.write(meta_dataset_json)
    