from sklearn.gaussian_process.kernels import RBF, PairwiseKernel, DotProduct, Kernel
from sklearn.utils.random import sample_without_replacement
import matplotlib.pyplot as plt
import numpy as np
import tempfile
from pathlib import Path
import pickle
from random import SystemRandom
import time
from dppy.finite_dpps import FiniteDPP

class Timer:
    """Context manager that measures how long its ``with`` body takes.

    After the block exits, ``secs`` holds the elapsed time in seconds and
    ``msecs`` the same value in milliseconds. ``start`` and ``end`` are raw
    ``time.perf_counter()`` readings; only their difference is meaningful.

    Args:
        verbose: if True, print the elapsed time when the block exits.
    """

    def __init__(self, verbose=False):
        self.verbose = verbose

    def __enter__(self):
        # perf_counter is monotonic and high-resolution, so the measurement is
        # not distorted by system clock adjustments the way time.time() can be.
        self.start = time.perf_counter()
        return self

    def __exit__(self, *args):
        self.end = time.perf_counter()
        self.secs = self.end - self.start
        self.msecs = self.secs * 1000  # milliseconds
        if self.verbose:
            print('elapsed time: %f s' % self.secs)
        # Implicitly returns None, so exceptions raised in the body propagate.


class FastMixinKernel(Kernel):
    """Kernel adaptor that combines two kernels behind one ``Kernel`` interface.

    Full kernel evaluations (``__call__``) are forwarded to ``pairwise_kernel``
    together with any extra keyword arguments. ``diag`` and ``is_stationary``
    are answered by ``gp_kernel``.

    NOTE(review): presumably the split exists so that matrix evaluations go
    through the pairwise implementation while diagonal queries stay cheap;
    the two kernels are assumed to agree on the diagonal — confirm at call sites.

    Args:
        gp_kernel: kernel queried for ``diag`` and ``is_stationary``.
        pairwise_kernel: callable kernel used to evaluate ``K(X, Y)``.
    """

    def __init__(self, gp_kernel, pairwise_kernel):
        # sklearn's get_params()/clone() convention expects attributes named
        # exactly like the __init__ arguments, so keep these names unchanged.
        self.pairwise_kernel = pairwise_kernel
        self.gp_kernel = gp_kernel

    def __call__(self, X, Y=None, **kwargs):
        evaluate = self.pairwise_kernel
        return evaluate(X, Y, **kwargs)

    def diag(self, X):
        diag_source = self.gp_kernel
        return diag_source.diag(X)

    def is_stationary(self):
        diag_source = self.gp_kernel
        return diag_source.is_stationary()

# Benchmark: for each training-set size n, draw a random subset of MNIST8m,
# time exact k-DPP sampling with dppy's 'vfx' and 'alpha' samplers, and pickle
# the per-n statistics into ~/data/result.

# np.load on an .npz returns an NpzFile that holds an open file handle; the
# context manager closes it once 'X' has been read into memory.
with np.load('/dev/shm/mnist8m.npz') as mnist8m:
    X_all = mnist8m['X']
# Divide every row by the largest row norm so all points have norm <= 1.
# Done in place to avoid a second copy of the full dataset.
# NOTE(review): in-place true division requires 'X' to have a floating dtype — confirm.
X_all /= np.linalg.norm(X_all, axis=1).max()

# 10 geometrically spaced training-set sizes from 100000 to 999999.
n_list = np.geomspace(100000, 999999, 10)

# OS-entropy seed so every run differs; it is embedded in the output file
# names so each result file can be traced back to its seed.
urandom_seed = SystemRandom().randrange(99999)
r = np.random.RandomState(urandom_seed)

sigma = np.sqrt(10)
# Kernel matrices are evaluated by the PairwiseKernel; FastMixinKernel uses
# the RBF kernel only for diag() and is_stationary().
# NOTE(review): RBF(length_scale=sigma) and metric='rbf' with gamma=1/sigma**2
# correspond to different bandwidths per the sklearn docs. This is harmless
# here only because both diagonals are all ones — confirm the intended bandwidth.
dot_func = FastMixinKernel(
    RBF(length_scale=sigma),
    PairwiseKernel(gamma=1/np.square(sigma), metric='rbf', pairwise_kernels_kwargs={'n_jobs': 1})
)

desired_k = 10

# Create the output directory up front. Otherwise a missing directory only
# surfaces after the first (expensive) round of sampling, and that work is lost.
result_dir = Path.home() / 'data/result'
result_dir.mkdir(parents=True, exist_ok=True)


def _timed_k_dpp_sample(mode, X_train):
    """Build a likelihood-kernel FiniteDPP on X_train and draw one exact k-DPP sample.

    Sampler construction and sampling are timed together. The sampler draws
    from the shared global RandomState ``r``, so call order fixes the random
    stream. Returns ``(sampler, sample, elapsed_seconds)``.
    """
    with Timer(verbose=False) as timer:
        sampler = FiniteDPP(kernel_type='likelihood', L_eval_X_data=(dot_func, X_train))
        sample = sampler.sample_exact_k_dpp(desired_k, mode,
                                            rls_oversample_bless=4,
                                            rls_oversample_dppvfx=5,
                                            random_state=r,
                                            verbose=False)
    return sampler, sample, timer.secs


for n_cand in n_list:
    result = []
    print(n_cand)
    I_train = sample_without_replacement(n_population=X_all.shape[0],
                                         n_samples=int(n_cand),
                                         random_state=r)

    X_train = X_all[I_train, :]
    n = X_train.shape[0]

    # 'vfx' first, then 'alpha': both consume draws from r in this order.
    vfx_dpp_sampler, S_vfx, vfx_secs = _timed_k_dpp_sample('vfx', X_train)
    pc_state = vfx_dpp_sampler.intermediate_sample_info
    dict_dppvfx = pc_state.dict_dppvfx
    result.append({'n': n, 'alg': 'vfx', 'time': vfx_secs, 'k': len(S_vfx),
                   'alpha_hat': pc_state.alpha_star,
                   'm': len(dict_dppvfx.probs)/dict_dppvfx.rls_oversample,
                   'rej': pc_state.rej_to_first_sample})

    alpha_dpp_sampler, S_alpha, alpha_secs = _timed_k_dpp_sample('alpha', X_train)
    pc_state = alpha_dpp_sampler.intermediate_sample_info
    dict_dppvfx = pc_state.dict_dppvfx
    result.append({'n': n, 'alg': 'alpha', 'time': alpha_secs, 'k': len(S_alpha),
                   'alpha_switch': pc_state.alpha_switches,
                   'alpha_hat': pc_state.alpha_hat,
                   'm': len(dict_dppvfx.probs)/dict_dppvfx.rls_oversample,
                   'rej': pc_state.rej_to_first_sample,
                   'trial': pc_state.trial_to_first_sample})

    # One pickle file per training size; delete=False keeps it after closing.
    with tempfile.NamedTemporaryFile(prefix=f'run_k_succ_rbf_{urandom_seed:05d}_', suffix='.pickle',
                                     dir=result_dir, delete=False) as out_file:
        pickle.dump(result, out_file)
