from utils.data_loader import load_train_test_data
from utils.space import get_space, get_small_space
from utils.evaluate import evaluate
from ensembles.ensemble_selection import EnsembleSelection

import argparse
import numpy as np
import pickle as pkl
from functools import partial
from mindware.components.metrics.metric import get_metric

import sys

# ---------------------------------------------------------------------------
# Command-line interface and run-wide configuration.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
# `--datasets` is mandatory: it is split on ',' below, so leaving it unset
# used to fail with an AttributeError on None instead of a usage message.
parser.add_argument('--datasets', type=str, required=True,
                    help='Comma-separated list of dataset names.')
parser.add_argument('--task_type', type=str, default='cls', choices=['cls', 'rgs'],
                    help="Task type: 'cls' or 'rgs'.")
parser.add_argument('--rep_num', type=int, default=5,
                    help='Number of repetitions per dataset/algorithm pair.')
parser.add_argument('--algos', type=str, default='rs',
                    help='Comma-separated list of search algorithms.')
parser.add_argument('--start_id', type=int, default=0,
                    help='Index of the first repetition.')
parser.add_argument('--iter_num', type=int, default=200,
                    help='Value passed to each searcher as iter_num.')
parser.add_argument('--ens_size', type=int, default=25,
                    help='Ensemble size.')
parser.add_argument('--beta', type=float, default=0.025,
                    help="Forwarded as the 'alpha' argument of the bo_div searcher.")
parser.add_argument('--tau', type=float, default=0.2,
                    help="Forwarded as the 'beta' argument of the bo_div searcher.")

args = parser.parse_args()
datasets = args.datasets.split(',')
task_type = args.task_type
rep_num = args.rep_num
algos = args.algos.split(',')
start_id = args.start_id
iter_num = args.iter_num
ens_size = args.ens_size
# NOTE(review): the CLI names and the internal names are shifted by one:
# `--beta` feeds `alpha` and `--tau` feeds `beta`. Kept as-is because the
# rest of the script relies on these names; confirm the mapping is intended.
alpha = args.beta
beta = args.tau

config_space = get_space(task_type=task_type)
time_limit_per_trial = 60  # handed to optimizer.run() as time_limit_per_trial
scorer = get_metric('acc') if task_type == 'cls' else get_metric('mse')

# ---------------------------------------------------------------------------
# Main experiment loop: for every (dataset, algorithm) pair run `rep_num`
# searches, then build an ensemble from the collected predictions and report
# single-best vs. ensemble performance on validation and test data.
# ---------------------------------------------------------------------------

# Every searcher writes below the same per-task-type directory.
save_dir = './results/%s' % task_type

for dataset in datasets:
    for algo in algos:
        print("dataset: %s, algo: %s" % (dataset, algo))

        train_node, test_node = load_train_test_data(dataset=dataset, data_dir='./', test_size=0.2,
                                                     task_type=0 if task_type == 'cls' else 4)
        eval_func = partial(evaluate,
                            scorer=scorer,
                            data_node=train_node, test_node=test_node, task_type=task_type,
                            resample_ratio=1.0, seed=1)

        # Recover the validation labels. The split uses a fixed seed on the
        # same training data, so it is identical for every repetition and is
        # computed once here rather than inside the repetition loop.
        test_size = 0.25  # Consistent with evaluate.py
        seed = 1  # Consistent with evaluate.py
        if task_type == 'cls':
            from sklearn.model_selection import StratifiedShuffleSplit

            ss = StratifiedShuffleSplit(n_splits=1, test_size=test_size, random_state=seed)
        else:
            from sklearn.model_selection import ShuffleSplit

            ss = ShuffleSplit(n_splits=1, test_size=test_size, random_state=seed)
        # n_splits=1, so the first (and only) split is the one we need.
        _, val_index = next(ss.split(train_node.data[0], train_node.data[1]))
        _y_val = train_node.data[1][val_index]
        test_labels = test_node.data[1]

        # NOTE(review): the repetition index is not passed to the searcher or
        # to the evaluation function; confirm that repetitions are meant to
        # differ only by the searchers' internal randomness.
        for _rep in range(start_id, start_id + rep_num):
            # Searchers are imported lazily so only the requested one is loaded.
            if algo == 'rs':
                from searchers.random_search import RandomSearch

                optimizer = RandomSearch(config_space=config_space, eval_func=eval_func, iter_num=iter_num,
                                         task_name=dataset,
                                         save_dir=save_dir)
            elif algo == 'bo':
                from searchers.bayesian_optimization import BayesianOptimization

                optimizer = BayesianOptimization(config_space=config_space, eval_func=eval_func, iter_num=iter_num,
                                                 task_name=dataset,
                                                 save_dir=save_dir,
                                                 surrogate_type='prf')
            elif algo == 'rb':
                from searchers.rising_bandit import RisingBandit

                optimizer = RisingBandit(config_space=config_space, eval_func=eval_func, iter_num=iter_num,
                                         task_name=dataset,
                                         save_dir=save_dir,
                                         surrogate_type='prf')
            elif algo == 'eo':
                from searchers.bayesian_optimization_ensemble import BayesianOptimizationEnsemble

                optimizer = BayesianOptimizationEnsemble(config_space=config_space, eval_func=eval_func,
                                                         iter_num=iter_num,
                                                         task_name=dataset,
                                                         save_dir=save_dir,
                                                         surrogate_type='prf',
                                                         scorer=scorer,
                                                         task_type=task_type,
                                                         train_node=train_node,
                                                         test_node=test_node)
            elif algo == 'bo_div':
                from searchers.bayesian_optimization_diversity import BayesianOptimizationDiversity

                optimizer = BayesianOptimizationDiversity(config_space=config_space, eval_func=eval_func,
                                                          iter_num=iter_num,
                                                          task_name=dataset,
                                                          save_dir=save_dir,
                                                          surrogate_type='prf',
                                                          scorer=scorer,
                                                          task_type=task_type,
                                                          ens_size=ens_size,
                                                          val_y_labels=_y_val,
                                                          alpha=alpha,
                                                          beta=beta)
            elif algo == 'rea_es':
                from searchers.rea_es import RegularizedEAEnsemble

                optimizer = RegularizedEAEnsemble(config_space=config_space, eval_func=eval_func,
                                                  iter_num=iter_num,
                                                  task_name=dataset,
                                                  save_dir=save_dir,
                                                  scorer=scorer,
                                                  task_type=task_type,
                                                  ens_size=ens_size,
                                                  val_y_labels=_y_val)
            else:
                # Previously an unknown name left `optimizer` undefined (or
                # silently reused the one from the previous iteration).
                raise ValueError("Unsupported algo: %s" % algo)

            optimizer.run(time_limit_per_trial=time_limit_per_trial)

            # The searcher stores its observations at `save_path`; rewrite the
            # file so it also carries the validation and test labels.
            save_path = optimizer.save_path
            with open(save_path, 'rb') as f:
                observations = pkl.load(f)

            with open(save_path, 'wb') as f:
                pkl.dump([observations, _y_val, test_labels], f)

            # The values just written are already in memory, so the file is
            # not read back a second time.
            val_labels = _y_val

            # 'rea_es' only considers its last 20 observations; every other
            # algorithm considers all of them.
            candidates = observations[-20:] if algo == 'rea_es' else observations

            val_pred_list = []
            test_pred_list = []
            # Start from +inf for every algorithm, including 'eo', so the
            # values are always defined and never leak from an earlier run.
            best_val = np.inf
            best_test = np.inf
            for ob in candidates:
                _, val_perf, test_perf, val_pred, test_pred, _ = ob
                has_pred = val_pred is not None
                if has_pred:
                    val_pred_list.append(val_pred)
                    test_pred_list.append(test_pred)
                if algo == 'eo':
                    # For 'eo' the reported "best" is the most recent
                    # observation that carries predictions.
                    if has_pred:
                        best_val = val_perf
                        best_test = test_perf
                elif val_perf < best_val:
                    # Lower stored perf wins; the sign is flipped when printed.
                    best_val = val_perf
                    best_test = test_perf

            ensemble_builder = EnsembleSelection(ensemble_size=ens_size,
                                                 task_type=task_type,
                                                 scorer=scorer)
            ensemble_builder.fit(val_pred_list, val_labels)

            ens_val_pred = ensemble_builder.predict(val_pred_list)
            if task_type == 'cls':
                # Reduce the last axis of the ensemble output to label indices.
                ens_val_pred = np.argmax(ens_val_pred, axis=-1)

            print('Best validation perf: %s' % str(-best_val))
            print('Ensemble validation perf: %s' % str(ensemble_builder.scorer._score_func(ens_val_pred, val_labels)))

            ens_test_pred = ensemble_builder.predict(test_pred_list)
            if task_type == 'cls':
                ens_test_pred = np.argmax(ens_test_pred, axis=-1)
            print('Best test perf: %s' % str(-best_test))
            print(
                'Ensemble test perf: %s' % str(ensemble_builder.scorer._score_func(ens_test_pred, test_labels)))
            print(ensemble_builder.model_idx)
            sys.stdout.flush()
