import d3rlpy
import torch
from IPython import embed
from d3rlpy.algos import DQN
from d3rlpy.datasets import get_cartpole
from d3rlpy.metrics.scorer import evaluate_on_environment
from d3rlpy.models.encoders import VectorEncoderFactory
from d3rlpy.base import _serialize_params
from sklearn.model_selection import train_test_split
import argparse
import os
import pickle
import json
import numpy as np
from alg import Alg
import gc

# Command-line interface. Both options are mandatory; argparse itself rejects
# a missing one, so the check also holds when Python runs with -O (which
# would strip assert statements).
parser = argparse.ArgumentParser()
parser.add_argument("--n", type=int, required=True)
parser.add_argument("--T", type=int, required=True)
res = parser.parse_args()

n = res.n  # requested number of episodes (capped at the dataset size below)
T = res.T  # run tag; in this file it only appears in model/output filenames



# Load the CartPole dataset, keep at most n episodes, and split them in
# order (no shuffling) into the first 80% for training and the rest for
# held-out evaluation.
dataset, env = get_cartpole()
n_cut = min(n, len(dataset.episodes))
episodes = dataset[:n_cut]
n_train = int(.8 * n_cut)
train_eps = episodes[:n_train]
test_eps = episodes[n_train:]

K = env.action_space.n  # action ids 0..K-1 are enumerated in pred_V
dim = env.observation_space.shape[0]  # first entry of the observation shape


def load_model(d, n):
    """Load the saved DQN stored under the name built from d, n and T.

    The files read are ``models/models_cp/dqn_d{d}_n{n}_T{T}.json`` and the
    ``.pt`` file with the same stem. ``T`` is taken from module scope.

    Args:
        d: value substituted for ``{d}`` in the file name.
        n: value substituted for ``{n}`` in the file name.

    Returns:
        The DQN instance created from the JSON file with the ``.pt`` file
        loaded into it.
    """
    path = 'models/models_cp/'
    filename = 'dqn_d{d}_n{n}_T{T}'.format(d=d,n=n,T=T)
    filepath_json = os.path.join(path, filename + '.json')
    filepath_model = os.path.join(path, filename + '.pt')

    # from_json creates the instance itself, so no placeholder DQN() is
    # constructed beforehand.
    dqn = DQN.from_json(filepath_json)
    dqn.load_model(filepath_model)
    return dqn


def pred_V(next_obs, dqn):
    """Return, for each observation, the largest predicted value over actions.

    ``dqn.predict_value`` is queried once per action id in ``range(K)`` with
    that action repeated for every observation; the element-wise maximum
    across those K results is returned.
    """
    # Constant vector of ones, scaled to each action id in turn.
    unit_actions = np.ones(len(next_obs))
    per_action = []
    for k in range(K):
        per_action.append(dqn.predict_value(next_obs, unit_actions * k))
    return np.max(np.array(per_action), axis=0)

def gen_dataset(eps, dqn):
    """Build regression inputs and one-step targets from a set of episodes.

    For each episode the target of step t is
    ``rewards[t] + dqn.gamma * V(observations[t + 1])`` with ``V`` given by
    ``pred_V``. The final step of an episode has no following observation in
    the episode, so its next value is taken to be 0.

    Args:
        eps: iterable of episodes exposing ``observations``, ``actions`` and
            ``rewards``.
        dqn: model exposing ``gamma``; also passed on to ``pred_V``.

    Returns:
        ``(xs, As, ys)``: observations, actions and targets of all episodes
        concatenated in episode order. Three empty lists when ``eps`` yields
        no episodes.
    """
    obs_chunks = []
    act_chunks = []
    label_chunks = []

    for ep in eps:
        obs_chunks.append(ep.observations)
        act_chunks.append(ep.actions)

        # Values of the successor observations, plus 0 for the last step.
        next_values = np.append(pred_V(ep.observations[1:], dqn), 0)
        label_chunks.append(ep.rewards + dqn.gamma * next_values)

    if not obs_chunks:
        return [], [], []

    # Concatenate once at the end instead of once per episode, which would
    # re-copy everything accumulated so far on every iteration.
    xs = np.concatenate(obs_chunks)
    As = np.concatenate(act_chunks)
    ys = np.concatenate(label_chunks)
    return xs, As, ys

def L(alg, xs, As, ys):
    """Return the mean squared error of ``alg.pred(xs, As)`` against ``ys``.

    The squared residuals are summed and divided by ``len(xs)``.
    """
    residuals = ys - alg.pred(xs, As)
    squared_error_sum = np.sum(np.square(residuals))
    return squared_error_sum / float(len(xs))

def L_dqn(eps, dqn):
    """Return the mean squared gap between dqn's values and its own targets.

    Predictions are ``dqn.predict_value(observations, actions)`` for every
    step of every episode; targets are the ``ys`` produced by
    ``gen_dataset(eps, dqn)``.

    Args:
        eps: non-empty collection of episodes (it is iterated twice).
        dqn: model exposing ``predict_value``; also passed to ``gen_dataset``.

    Returns:
        Sum of squared differences divided by the number of targets.

    Raises:
        ValueError: if ``eps`` contains no episodes.
    """
    q_chunks = [dqn.predict_value(ep.observations, ep.actions) for ep in eps]
    if not q_chunks:
        # Fail with a clear message instead of an obscure error further down.
        raise ValueError("L_dqn requires at least one episode")
    # Single concatenation rather than growing the array once per episode.
    preds = np.concatenate(q_chunks)

    _, _, ys = gen_dataset(eps, dqn)
    loss = np.sum(np.square(ys - preds))
    return loss / float(len(ys))




def holdout(ds):
    """Return the index in ``ds`` whose saved DQN has the lowest held-out loss.

    Each entry of ``ds`` is loaded with ``load_model(d, n)`` and scored with
    ``L_dqn`` on the module-level ``test_eps``; the losses and the winning
    index are printed.
    """
    dqn_losses = [L_dqn(test_eps, load_model(d, n)) for d in ds]
    k_hat = np.argmin(dqn_losses)

    print("\n\nHoldout losses: " + str(dqn_losses))
    print("Holdout best model: " + str(k_hat))
    print("\n\n")
    return k_hat


def learn_and_loss(d, train, test):
    """Fit an ``Alg`` built with ``d`` on ``train`` and return its loss on ``test``.

    ``train`` and ``test`` are ``(xs, As, ys)`` triples. The loss is computed
    with ``L`` and also printed.
    """
    train_xs, train_As, train_ys = train
    test_xs, test_As, test_ys = test

    model = Alg(dim, K, d)
    model.train(train_xs, train_As, train_ys)
    held_out_loss = L(model, test_xs, test_As, test_ys)

    print("\tLoss of d=" + str(d) +": " + str(held_out_loss))
    # Drop the fitted network attribute right away; presumably this is to
    # release its memory before the next fit (callers run gc afterwards).
    del model.net
    return held_out_loss

def ms(ds):
    """Pick an index into ``ds`` by sequential pairwise held-out loss tests.

    Candidates are visited in order. For candidate ``ds[i]`` the saved DQN is
    loaded with ``load_model(ds[i], n)``, regression data is generated from
    the module-level ``train_eps`` and ``test_eps`` using that DQN, and the
    test loss of an ``Alg`` fit with ``ds[i]`` is compared with the test loss
    of fits using every later ``ds[j]``. The candidate is rejected as soon as
    its loss exceeds a later one by more than ``ds[j] / m``, where ``m`` is
    the number of training samples. The first candidate that survives every
    comparison is returned; if none does, the last index is returned without
    being tested itself.

    Args:
        ds: non-empty sequence of candidates, in the order they are tried.

    Returns:
        Index into ``ds`` of the selected candidate.

    Raises:
        ValueError: if ``ds`` is empty.
    """
    if len(ds) == 0:
        raise ValueError("ms requires at least one candidate in ds")

    last = len(ds) - 1
    for i in range(last):
        d = ds[i]

        if i > 0:
            # Release the previous candidate's model before loading the next.
            del dqn
            gc.collect()
        dqn = load_model(d, n)

        train = gen_dataset(train_eps, dqn)
        test = gen_dataset(test_eps, dqn)
        m = len(train[0])

        loss_alg_d = learn_and_loss(d, train, test)
        gc.collect()
        torch.cuda.empty_cache()

        passed = True
        for dprime in ds[i + 1:]:
            loss_alg = learn_and_loss(dprime, train, test)
            gc.collect()
            torch.cuda.empty_cache()

            print("\tComparing to Loss of d=" + str(dprime) +": " + str(loss_alg))

            # Reject d when a later candidate beats it by more than dprime / m.
            if loss_alg_d - loss_alg > dprime / float(m):
                print("\tTest failed... moving to " + str(ds[i + 1]))
                passed = False
                break

        if passed:
            print("\n\nMS best model: " + str(i))
            print("\n\n")
            return i

    # Every earlier candidate was rejected (or there was only one candidate).
    print("\n\nMS best model: " + str(last))
    print("\n\n")
    return last




# Run both selection procedures over the same candidate list and store the
# chosen indices together with the inputs that produced them.
ds = [10, 50, 1000, 5000, 25000, 50000]
k_holdout = holdout(ds)
gc.collect()
k_ms = ms(ds)


selections = {
    'ds': ds,
    'n': n,
    'holdout': k_holdout,
    'ms': k_ms
}

# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs('selection/', exist_ok=True)


filename = 'selection/cp_n{n}_T{T}.pkl'.format(n=n,T=T)
with open(filename, 'wb') as f:
    pickle.dump(selections, f)




