import pickle as pkl

import botorch
import torch
from botorch import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from gpytorch import ExactMarginalLogLikelihood
from botorch.optim.fit import  fit_gpytorch_mll_torch

from pathlib import Path
import os, sys

# Repository root: grandparent directory of this script after resolving symlinks.
# It is prepended to sys.path so the project-local `nap` import below resolves.
_script_path = Path(os.path.realpath(__file__))
ROOT = os.fspath(_script_path.parent.parent)
sys.path.insert(0, ROOT)

from nap.RL.utils_gp import TransformedCategorical


def clean(seq_string):
    """Parse a comma-separated string of integers into a list of ints.

    Every comma-delimited token is converted with ``int``; a token that is not
    a valid integer literal (including an empty token) raises ``ValueError``.
    """
    return [int(token) for token in seq_string.split(',')]


if __name__ == '__main__':
    # EDA data produced by a genetic algorithm run on circuits.
    eda_data_root = os.path.join(ROOT, 'eda_data')
    # NOTE(review): pickle can execute arbitrary code on load; only use trusted files.
    with open(os.path.join(eda_data_root, 'eda_ga500_data.pkl'), 'rb') as f:
        results_dict = pkl.load(f)
    # Top-level keys that hold metadata rather than per-circuit results.
    non_circuit_keys = {'bar', 'action_space_id', 'cat_to_cmd', 'Description'}
    circuits = [k for k in results_dict if k not in non_circuit_keys]

    # Fall back to CPU so the script still runs on machines without a GPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    for circuit in sorted(circuits):
        res = results_dict[circuit]
        # Objective: count and level, each normalised by its reference value, summed.
        Y = res['Count'] / res['ref_count'] + res['Level'] / res['level_ref']
        Y = -Y  # minimize y to solve task but in codebase we maximise, so we flip the sign here
        # Standardise the targets. A constant objective has std == 0, which would turn
        # every target into NaN; in that case only centre the values.
        Y_std = Y.std()
        stdY = (Y - Y.mean()) / (Y_std if Y_std > 0 else 1.0)
        X = res['X']

        dataset_path = os.path.join(eda_data_root, f'dataset_{circuit}.pkl')
        with open(dataset_path, 'wb') as f:
            pkl.dump({'domain': X, 'accs': stdY}, f)
        print(f"saved {dataset_path}")

        gp_path = os.path.join(eda_data_root, f'gp_{circuit}.pt')
        if os.path.exists(gp_path):
            print(f'{circuit}-circuit GP already fit and saved.')
            continue

        # Fit and save a GP on the full dataset (dtype=float -> torch.float64).
        train_X = torch.from_numpy(X).to(dtype=float, device=device)
        train_Y = torch.from_numpy(stdY).to(dtype=float, device=device).view(-1, 1)
        model = SingleTaskGP(
            train_X=train_X,
            train_Y=train_Y,
            covar_module=TransformedCategorical(ard_num_dims=train_X.shape[-1]).to(device),
        ).to(device)
        mll = ExactMarginalLogLikelihood(model.likelihood, model)

        try:
            fit_gpytorch_mll(mll=mll, optimizer=fit_gpytorch_mll_torch)
        except (RuntimeError, botorch.exceptions.errors.ModelFittingError) as e:
            # A failed fit on one circuit should not abort the remaining circuits.
            print(e)
            print(f'Error during the GP fit on {circuit}-circuit.')
        else:
            torch.save(model, gp_path)
            print(f"saved model at {gp_path}")
        finally:
            # Drop the references to the tensors and model so the cached GPU
            # memory can be released before the next circuit is processed.
            del train_X, train_Y, model, mll
            torch.cuda.empty_cache()
