from lqr_env import LQREnv, RandController, LQRController
import darkroom_env
import lqr_env
import bandit_env
from ppo_data.ppo_darkroom import generate_dr_ppo_histories_for_envs

import numpy as np
import os
import pickle
from IPython import embed
import scipy
import matplotlib.pyplot as plt
import os
import random

# Module-level accumulator: rollin_bandit (non-``orig`` path) appends the
# largest arm probability of each sampled roll-in distribution here; the
# ``bandit`` branch of __main__ plots a histogram of it.
max_probs = []

def rollin(env):
    """Legacy LQR roll-in; disabled.

    Always raises ``NotImplementedError``. The unreachable code after the
    raise is the retired LQR implementation, kept for reference: it drew
    ``env.H`` uniformly random (state, control) pairs, pushed each through
    ``env.transit`` and returned the stacked states, controls, next states
    and rewards as numpy arrays.
    """
    raise NotImplementedError  # old lqr code
    states, controls, next_states, rewards = [], [], [], []
    for _ in range(env.H):
        state = np.random.uniform(-1, 1, env.dx)
        control = np.random.uniform(-.5, .5, env.du)
        next_state, reward = env.transit(state, control)
        for bucket, item in zip((states, controls, next_states, rewards),
                                (state, control, next_state, reward)):
            bucket.append(item)
    return tuple(np.array(seq) for seq in (states, controls, next_states, rewards))




def rollin_dr(env, rollin='uniform'):
    """Collect a length-``env.H`` roll-in context from a Darkroom-style env.

    Args:
        env: environment exposing ``H`` and ``transit(x, u) -> (xp, r)``,
            plus the samplers required by the chosen mode.
        rollin: how each (state, action) pair is chosen:
            * ``'uniform'`` - independent ``env.sample_x()`` / ``env.sample_u()``;
            * ``'stitch'`` - ``env.sample_stitch_x()`` with
              ``env.sample_stitch_opt_a(x)``;
            * ``'expert'`` - a single trajectory started at ``env.reset()``
              that follows ``env.opt_a``, each next state feeding the next step.

    Returns:
        Tuple ``(xs, us, xps, rs)`` of numpy arrays, one row per step.

    Raises:
        NotImplementedError: if ``rollin`` is not one of the modes above.
    """
    H = env.H

    if rollin not in ('uniform', 'stitch', 'expert'):
        raise NotImplementedError(f"Unknown darkroom rollin mode: {rollin!r}")

    xs, us, xps, rs = [], [], [], []
    # Only the expert mode carries state across steps; it starts from reset().
    x = env.reset() if rollin == 'expert' else None

    for _ in range(H):
        if rollin == 'uniform':
            x = env.sample_x()
            u = env.sample_u()
        elif rollin == 'stitch':
            x = env.sample_stitch_x()
            u = env.sample_stitch_opt_a(x)
        else:  # 'expert'
            u = env.opt_a(x)

        xp, r = env.transit(x, u)

        xs.append(x)
        us.append(u)
        xps.append(xp)
        rs.append(r)

        # Needed for 'expert'; harmless otherwise since x is resampled.
        x = xp

    return np.array(xs), np.array(us), np.array(xps), np.array(rs)

def rollin_bandit(env, cov=0.0, orig=False):
    """Sample a bandit roll-in context of ``env.H_context`` arm pulls.

    Arms are drawn i.i.d. from a per-call categorical distribution ``probs``
    (a hierarchical model that yields very different-looking histories):

    * ``orig=True``: ``probs ~ Dirichlet(20 * s * 1)`` with a single scalar
      ``s ~ U(1e-3, 1)``, plus ``20 * cov`` extra concentration on
      ``env.opt_a_index``.
    * otherwise: ``probs ~ Dirichlet(1, ..., 1)``; with probability ``cov``
      the largest entry is swapped onto ``env.opt_a_index``. The maximum of
      ``probs`` is appended to the module-level ``max_probs`` list.

    Returns:
        Tuple ``(xs, us, xps, rs)`` of numpy arrays; each ``xs`` row is the
        constant context ``[1]`` and each ``us`` row is a one-hot arm vector.
    """
    horizon = env.H_context
    best_arm = env.opt_a_index

    if not orig:
        probs = np.random.dirichlet(np.ones(env.dim))
        if random.random() < cov:
            # Move the most likely arm's mass onto the optimal arm.
            top = np.argmax(probs)
            probs[top], probs[best_arm] = probs[best_arm], probs[top]
        max_probs.append(np.max(probs))
    else:
        concentration = 20 * np.ones(env.dim) * np.random.uniform(1e-3, 1)
        concentration[best_arm] += 20 * cov
        probs = np.random.dirichlet(concentration)

    arm_ids = np.arange(env.dim)
    states, actions, next_states, rewards = [], [], [], []
    for _ in range(horizon):
        state = np.array([1])
        action = np.zeros(env.dim)
        action[np.random.choice(arm_ids, p=probs)] = 1.0
        next_state, reward = env.transit(state, action)

        states.append(state)
        actions.append(action)
        next_states.append(next_state)
        rewards.append(reward)

    return tuple(map(np.array, (states, actions, next_states, rewards)))

def rollin_bandit_topk(env):
    """Sample a top-k bandit roll-in context of ``env.H_context`` steps.

    At every step a random subset of ``env.k`` distinct arms is chosen (the
    first ``env.k`` entries of a random permutation of ``range(env.dim)``),
    encoded as a 0/1 mask and passed to ``env.transit``.

    Returns:
        Tuple ``(xs, us, xps, rs)`` of numpy arrays, one row per step; each
        ``xs`` row is the constant context ``[1]``.
    """
    steps = env.H_context
    states, masks, next_states, rewards = [], [], [], []

    for _ in range(steps):
        state = np.array([1])
        mask = np.zeros(env.dim)
        # Legacy np.random.permutation(n) shuffles np.arange(n) internally,
        # so the RNG is consumed exactly as by an explicit arange + shuffle.
        mask[np.random.permutation(env.dim)[:env.k]] = 1.0
        next_state, reward = env.transit(state, mask)

        states.append(state)
        masks.append(mask)
        next_states.append(next_state)
        rewards.append(reward)

    return np.array(states), np.array(masks), np.array(next_states), np.array(rewards)


def generate_histories(n_envs, n_hists, n_samples, H, dx=None, du=None):
    """Generate legacy LQR in-context trajectories.

    NOTE: ``rollin`` currently raises ``NotImplementedError`` unconditionally,
    so any call with ``n_envs > 0`` and ``n_hists > 0`` fails there.

    Args:
        n_envs: number of LQR systems sampled via ``lqr_env.sample``.
        n_hists: roll-in histories per system.
        n_samples: query states per history; each is labelled with the
            ``LQRController`` action.
        H: horizon passed to ``LQREnv``.
        dx, du: state / control dimensions. When omitted they fall back to
            module-level ``dx``/``du`` (assigned by this file's ``__main__``
            block), which is what this function used to rely on implicitly.

    Returns:
        List of trajectory dicts.

    Raises:
        ValueError: if ``dx``/``du`` are neither passed nor defined as module
            globals (this used to surface as an opaque ``NameError``).
    """
    module_globals = globals()
    if dx is None:
        dx = module_globals.get('dx')
    if du is None:
        du = module_globals.get('du')
    if dx is None or du is None:
        raise ValueError(
            "generate_histories needs the state/control dimensions: pass dx= and du=")

    trajs = []

    # iterate over environments
    for i in range(n_envs):
        if i % 100 == 0:
            print(f"Env: {i}")

        A, B, Q, R = lqr_env.sample(dx, du)

        opt = LQRController(A, B, Q, R)
        env = LQREnv(A, B, Q, R, H)

        for _ in range(n_hists):
            rollin_xs, rollin_us, rollin_xps, rollin_rs = rollin(env)
            for _ in range(n_samples):
                x = np.random.uniform(-1, 1, dx)
                u = opt.act(x)

                traj = {
                    'state': x,
                    'action': u,
                    'rollin_xs': rollin_xs,
                    'rollin_us': rollin_us,
                    'rollin_xps': rollin_xps,
                    'rollin_rs': rollin_rs,
                    'Q': Q,
                    'matrices': (A, B, Q, R),
                }
                trajs.append(traj)

    return trajs


def generate_dr_histories(n_envs, n_hists, n_samples, H, dim):
    """Build Darkroom histories for ``n_envs`` freshly sampled environments.

    Each environment comes from ``darkroom_env.sample(dim, H)``. The
    histories are produced by ``generate_dr_histories_from_envs`` with its
    default roll-in mode.
    """
    sampled_envs = [darkroom_env.sample(dim, H) for _ in range(n_envs)]
    return generate_dr_histories_from_envs(
        sampled_envs, n_hists=n_hists, n_samples=n_samples, H=H, dim=dim)


def generate_dr_histories_for_goals(goals, n_hists, n_samples, H, dim, rollin='uniform'):
    """Build Darkroom histories with one ``DarkroomEnv`` per entry of ``goals``.

    Repeating a goal in ``goals`` yields several environments with that
    goal. ``rollin`` is forwarded to ``generate_dr_histories_from_envs``.
    """
    goal_envs = []
    for goal in goals:
        goal_envs.append(darkroom_env.DarkroomEnv(dim, goal, H))
    return generate_dr_histories_from_envs(
        goal_envs, n_hists, n_samples, H, dim, rollin=rollin)


def generate_dr_permuted_histories_for_indices(indices, n_hists, n_samples, H, dim, rollin='uniform'):
    """Build Darkroom histories with one ``DarkroomEnvPermuted`` per index.

    Each entry of ``indices`` is passed to ``DarkroomEnvPermuted`` as its
    permutation index. ``rollin`` is forwarded to
    ``generate_dr_histories_from_envs``.
    """
    permuted_envs = []
    for perm_idx in indices:
        permuted_envs.append(darkroom_env.DarkroomEnvPermuted(dim, perm_idx, H))
    return generate_dr_histories_from_envs(
        permuted_envs, n_hists, n_samples, H, dim, rollin=rollin)


def generate_dr_histories_from_envs(envs, n_hists, n_samples, H, dim, rollin='uniform'):
    """Turn Darkroom environments into supervised in-context trajectories.

    For every env, ``n_hists`` roll-in contexts are drawn with ``rollin_dr``.
    For each context, ``n_samples`` query states come from ``env.sample_x()``
    and are labelled with ``env.opt_a(state)``. All queries of one context
    share the same roll-in arrays.

    ``H`` and ``dim`` are accepted for signature parity with the other
    generators but are not used here.

    Returns:
        List of dicts with keys ``state``, ``action``, ``rollin_xs``,
        ``rollin_us``, ``rollin_xps``, ``rollin_rs``, ``goal`` and, for envs
        that expose ``perm_index``, ``perm_index``.
    """
    trajs = []
    for env in envs:
        for _ in range(n_hists):
            context = dict(zip(
                ('rollin_xs', 'rollin_us', 'rollin_xps', 'rollin_rs'),
                rollin_dr(env, rollin=rollin),
            ))
            for _ in range(n_samples):
                query_state = env.sample_x()
                traj = {'state': query_state, 'action': env.opt_a(query_state)}
                traj.update(context)
                traj['goal'] = env.goal
                # Permuted darkroom envs also carry their permutation index.
                if hasattr(env, 'perm_index'):
                    traj['perm_index'] = env.perm_index
                trajs.append(traj)
    return trajs


def generate_dr_stitch_histories_for_goals(goals, n_hists, n_samples, H, dim, rollin=None, eval=False):
    """Build stitching-task Darkroom histories, one ``DarkroomEnvStitch`` per goal.

    Args:
        goals: goal per environment; repeat a goal to get several envs for it.
        n_hists: roll-in contexts per environment.
        n_samples: labelled query states per context.
        H: horizon passed to each env.
        dim: grid size passed to each env.
        rollin: roll-in mode forwarded to ``rollin_dr``. ``None`` (default)
            picks ``'stitch'`` when ``eval`` is true and ``'uniform'``
            otherwise. An explicit value is now honoured; it used to be
            silently ignored in favour of that eval-dependent choice.
        eval: forwarded to ``DarkroomEnvStitch``. It also selects the default
            roll-in mode.

    Returns:
        List of trajectory dicts from ``generate_dr_stitch_histories_from_envs``.
    """
    envs = [darkroom_env.DarkroomEnvStitch(dim, goal, H, eval=eval) for goal in goals]

    if rollin is None:
        rollin = 'stitch' if eval else 'uniform'

    trajs = generate_dr_stitch_histories_from_envs(
        envs=envs,
        n_hists=n_hists,
        n_samples=n_samples,
        H=H,
        dim=dim,
        rollin=rollin,
    )

    return trajs


def generate_dr_stitch_histories_from_envs(envs, n_hists, n_samples, H, dim, rollin='uniform'):
    """Turn stitching Darkroom environments into in-context trajectories.

    For each env and each of ``n_hists`` roll-in contexts (from
    ``rollin_dr``), ``n_samples`` query states are drawn with
    ``env.sample_opt_x()`` and labelled with ``env.opt_a(state)``.

    ``H`` and ``dim`` are accepted for signature parity but unused here.

    Returns:
        List of dicts with keys ``state``, ``action``, ``rollin_xs``,
        ``rollin_us``, ``rollin_xps``, ``rollin_rs`` and ``goal``.
    """
    trajs = []
    for env in envs:
        for _ in range(n_hists):
            context = dict(zip(
                ('rollin_xs', 'rollin_us', 'rollin_xps', 'rollin_rs'),
                rollin_dr(env, rollin=rollin),
            ))
            for _ in range(n_samples):
                query_state = env.sample_opt_x()
                traj = {'state': query_state, 'action': env.opt_a(query_state)}
                traj.update(context)
                traj['goal'] = env.goal
                trajs.append(traj)
    return trajs


def generate_bandit_histories(n_envs, n_hists, n_samples, H, dim, var=0.0, cov=0.0, type='uniform'):
    """Build bandit histories for ``n_envs`` freshly sampled bandit envs.

    Each env comes from ``bandit_env.sample(dim, H, var, type=type)``.
    ``cov`` controls the roll-in arm distribution (see ``rollin_bandit``).
    """
    sampled_envs = [bandit_env.sample(dim, H, var, type=type) for _ in range(n_envs)]
    return generate_bandit_histories_from_envs(
        sampled_envs, n_hists, n_samples, H, dim, var=var, cov=cov)

def generate_topk_bandit_histories(n_envs, n_hists, n_samples, H, dim, var=0.0, k=1):
    """Build top-k bandit histories for ``n_envs`` sampled top-k bandit envs.

    Each env comes from ``bandit_env.sample_topk(dim, H, var, k)``.
    """
    topk_envs = [bandit_env.sample_topk(dim, H, var, k) for _ in range(n_envs)]
    return generate_topk_bandit_histories_from_envs(topk_envs, n_hists, n_samples, H, dim)


def generate_topk_bandit_histories_from_envs(envs, n_hists, n_samples, H, dim):
    """Turn top-k bandit environments into in-context trajectories.

    For each env and each of ``n_hists`` roll-in contexts (from
    ``rollin_bandit_topk``), ``n_samples`` identical queries with state
    ``[1]`` are labelled with ``env.opt_a``.

    ``H`` and ``dim`` are accepted for signature parity but unused here.

    Returns:
        List of dicts with keys ``state``, ``action``, ``rollin_xs``,
        ``rollin_us``, ``rollin_xps``, ``rollin_rs``, ``means``, ``k`` and
        ``var``.
    """
    trajs = []
    for env in envs:
        for _ in range(n_hists):
            context = dict(zip(
                ('rollin_xs', 'rollin_us', 'rollin_xps', 'rollin_rs'),
                rollin_bandit_topk(env),
            ))
            for _ in range(n_samples):
                traj = {'state': np.array([1]), 'action': env.opt_a}
                traj.update(context)
                traj['means'] = env.means
                traj['k'] = env.k
                traj['var'] = env.var
                trajs.append(traj)
    return trajs


def generate_bandit_histories_for_arms(arms, n_hists, n_samples, H, dim, var=0.0, cov=0.0):
    """Generate bandit histories with one environment per entry of ``arms``.

    Each environment is built by ``bandit_env.sample_for_arm(arm, dim, H, var)``.
    To get several environments for the same arm, repeat it, e.g. pass
    ``[arm] * n_envs``.
    """
    arm_envs = []
    for arm in arms:
        arm_envs.append(bandit_env.sample_for_arm(arm, dim, H, var))
    return generate_bandit_histories_from_envs(
        arm_envs, n_hists, n_samples, H, dim, var=var, cov=cov)


def generate_bandit_histories_from_envs(envs, n_hists, n_samples, H, dim, var=0.0, cov=0.0):
    """Turn bandit environments into in-context trajectories.

    For each env and each of ``n_hists`` roll-in contexts (from
    ``rollin_bandit(env, cov=cov)``), ``n_samples`` queries with state
    ``[1]`` are labelled with ``env.opt_a``.

    ``H``, ``dim`` and ``var`` are accepted for signature parity but unused.

    Returns:
        List of dicts with keys ``state``, ``action``, ``rollin_xs``,
        ``rollin_us``, ``rollin_xps``, ``rollin_rs``, ``Q`` (zeros shaped
        like the state) and ``means``.
    """
    trajs = []
    for env in envs:
        for _ in range(n_hists):
            context = dict(zip(
                ('rollin_xs', 'rollin_us', 'rollin_xps', 'rollin_rs'),
                rollin_bandit(env, cov=cov),
            ))
            for _ in range(n_samples):
                query_state = np.array([1])
                traj = {'state': query_state, 'action': env.opt_a}
                traj.update(context)
                traj['Q'] = np.zeros(query_state.shape)
                traj['means'] = env.means
                trajs.append(traj)
    return trajs
            

if __name__ == '__main__':
    # Script entry point: parse CLI flags, generate train/test trajectory
    # datasets for the chosen environment family, and pickle them to disk.
    import multiprocessing
    multiprocessing.set_start_method('spawn')

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--envs", type=int, required=False, default=1000, help="Envs")
    parser.add_argument("--hists", type=int, required=False, default=1, help="Histories")
    parser.add_argument("--samples", type=int, required=False, default=1, help="Samples")
    parser.add_argument("--H", type=int, required=False, default=10, help="Context horizon")
    parser.add_argument("--dim", type=int, required=False, default=1, help="Dimension")
    parser.add_argument("--k", type=int, required=False, default=1, help="Top k subset")
    parser.add_argument("--var", type=float, required=False, default=0.0, help="Bandit arm variance")
    parser.add_argument("--cov", type=float, required=False, default=0.0, help="Coverage of optimal arm")
    parser.add_argument("--env", type=str, required=True, help="Environment")
    parser.add_argument("--alg", type=str, required=False, default="random", help="Algorithm to generate data")

    args = vars(parser.parse_args())
    print("Args:")
    print(args)

    env = args['env']
    n_envs = args['envs']
    n_hists = args['hists']
    n_samples = args['samples']
    H = args['H']

    # dx/du stay module-level names: generate_histories (legacy LQR) reads
    # them from the module globals.
    dx = args['dim']
    du = args['dim']
    dim = args['dim']
    var = args['var']
    cov = args['cov']
    k = args['k']
    alg = args['alg']

    # 80% of the requested environments go to training, the rest to test.
    n_envs_tr = int(.8 * n_envs)
    n_envs_te = n_envs - n_envs_tr

    # Filename stem shared by the datasets below; each branch appends its own
    # hyperparameter suffix and the _train/_test tag.
    stem = f'trajs_{env}_envs{n_envs}_hists{n_hists}_samples{n_samples}_H{H}_d{dim}'

    if env == 'bandit':
        train_trajs = generate_bandit_histories(n_envs_tr, n_hists, n_samples, H, dim, var=var, cov=cov)
        test_trajs = generate_bandit_histories(n_envs_te, n_hists, n_samples, H, dim, var=var, cov=cov)
        filepath_tr = f'datasets/{stem}_var{var}_cov{cov}_train.pkl'
        filepath_te = f'datasets/{stem}_var{var}_cov{cov}_test.pkl'

        # plot histogram of the max arm probabilities recorded by rollin_bandit
        plt.hist(max_probs, bins=100)
        os.makedirs('figs', exist_ok=True)
        plt.savefig('figs/max_probs.png')

        # compute fraction of max probs that are at least .99
        print(f'Fraction of max probs at least .99: {np.mean(np.array(max_probs) >= .99)}')

    elif env == 'bandit_ood':
        half = dim // 2
        if half == 0:
            raise ValueError("bandit_ood needs --dim >= 2 so the arms can be split into two halves")

        # Train: ~90% of envs are built for arms in range(half), the rest for
        # arms in range(half, dim).
        n_envs_first = int(.9 * n_envs_tr)
        n_envs_second = n_envs_tr - n_envs_first
        first_envs = list(range(half)) * (n_envs_first // half)
        second_envs = list(range(half, dim)) * (n_envs_second // half)
        train_trajs = generate_bandit_histories_for_arms(first_envs, n_hists, n_samples, H, dim, var=var, cov=cov)
        train_trajs += generate_bandit_histories_for_arms(second_envs, n_hists, n_samples, H, dim, var=var, cov=cov)

        # The integer divisions above can make the env count differ from
        # n_envs_tr; fail loudly (an assert would vanish under -O).
        expected = n_envs_tr * n_hists * n_samples
        if len(train_trajs) != expected:
            raise ValueError(
                f"bandit_ood produced {len(train_trajs)} train trajectories, expected {expected}; "
                f"pick --envs/--dim so the per-arm split divides evenly")

        # Test: envs split evenly between the two halves of the arms.
        n_envs_first = int(.5 * n_envs_te)
        n_envs_second = n_envs_te - n_envs_first
        first_envs = list(range(half)) * (n_envs_first // half)
        second_envs = list(range(half, dim)) * (n_envs_second // half)
        test_trajs = generate_bandit_histories_for_arms(first_envs, n_hists, n_samples, H, dim, var=var, cov=cov)
        test_trajs += generate_bandit_histories_for_arms(second_envs, n_hists, n_samples, H, dim, var=var, cov=cov)
        filepath_tr = f'datasets/{stem}_var{var}_cov{cov}_train.pkl'
        filepath_te = f'datasets/{stem}_var{var}_cov{cov}_test.pkl'

    elif env == 'bandit_topk':
        train_trajs = generate_topk_bandit_histories(n_envs_tr, n_hists, n_samples, H, dim, k=k, var=var)
        test_trajs = generate_topk_bandit_histories(n_envs_te, n_hists, n_samples, H, dim, k=k, var=var)
        filepath_tr = f'datasets/{stem}_var{var}_k{k}_train.pkl'
        filepath_te = f'datasets/{stem}_var{var}_k{k}_test.pkl'

    elif env == 'bandit_thompson':
        train_trajs = generate_bandit_histories(n_envs_tr, n_hists, n_samples, H, dim, var=var, cov=cov, type='bernoulli')
        test_trajs = generate_bandit_histories(n_envs_te, n_hists, n_samples, H, dim, var=var, cov=cov, type='bernoulli')
        filepath_tr = f'datasets/{stem}_var{var}_cov{cov}_train.pkl'
        filepath_te = f'datasets/{stem}_var{var}_cov{cov}_test.pkl'

    elif env == 'darkroom':
        train_trajs = generate_dr_histories(n_envs_tr, n_hists, n_samples, H, dim)
        test_trajs = generate_dr_histories(n_envs_te, n_hists, n_samples, H, dim)
        filepath_tr = f'datasets/{stem}_train.pkl'
        filepath_te = f'datasets/{stem}_test.pkl'

    elif env.startswith('darkroom_heldout'):
        # Goals are all (j, i) pairs with i, j in range(dim), shuffled with a
        # fixed seed so the 80/20 train/test goal split is reproducible.
        goals = np.array([[(j, i) for i in range(dim)] for j in range(dim)]).reshape(-1, 2)
        np.random.RandomState(seed=0).shuffle(goals)
        train_test_split = int(.8 * len(goals))
        train_goals = goals[:train_test_split]
        test_goals = goals[train_test_split:]

        if alg == "random":
            # Repeat every goal so that about n_envs environments are built overall.
            reps_per_goal = n_envs // (dim * dim)
            train_goals = np.repeat(train_goals, reps_per_goal, axis=0)
            test_goals = np.repeat(test_goals, reps_per_goal, axis=0)

            train_trajs = generate_dr_histories_for_goals(train_goals, n_hists, n_samples, H, dim)
            test_trajs = generate_dr_histories_for_goals(test_goals, n_hists, n_samples, H, dim)
            filepath_tr = f'datasets/{stem}_train.pkl'
            filepath_te = f'datasets/{stem}_test.pkl'

        elif alg == "ppo":
            random_init = env.startswith('darkroom_heldout_random_init')
            print("Random init: ", random_init)
            train_envs = [darkroom_env.DarkroomEnv(dim, goal, H, random_init=random_init) for goal in train_goals]
            train_trajs = generate_dr_ppo_histories_for_envs(train_envs, range(len(train_envs)))
            test_envs = [darkroom_env.DarkroomEnv(dim, goal, H, random_init=random_init) for goal in test_goals]
            # Test env ids continue after the train ids.
            test_trajs = generate_dr_ppo_histories_for_envs(test_envs, range(len(train_envs), len(train_envs) + len(test_envs)))
            filepath_tr = f'ppo_datasets/{stem}_train.pkl'
            filepath_te = f'ppo_datasets/{stem}_test.pkl'

        else:
            raise NotImplementedError(f"Unsupported --alg {alg!r} for --env {env!r}")

    elif env == 'darkroom_stitch':
        goals = [np.array([dim // 2, dim - 1]), np.array([dim - 1, dim // 2])]
        train_goals = np.repeat(goals, n_envs_tr // len(goals), axis=0)
        test_goals = np.repeat(goals, n_envs_te // len(goals), axis=0)
        if len(train_goals) + len(test_goals) != n_envs:
            raise ValueError(
                f"darkroom_stitch needs the train/test env counts ({n_envs_tr}/{n_envs_te}) "
                f"to be divisible by {len(goals)}; got --envs {n_envs}")
        train_trajs = generate_dr_stitch_histories_for_goals(train_goals, n_hists, n_samples, H, dim, eval=False)
        test_trajs = generate_dr_stitch_histories_for_goals(test_goals, n_hists, n_samples, H, dim, eval=True)

        filepath_tr = f'datasets/{stem}_train.pkl'
        filepath_te = f'datasets/{stem}_test.pkl'

    elif env == 'darkroom_permuted':
        indices = np.arange(120)    # 5! permutations in darkroom
        np.random.RandomState(seed=0).shuffle(indices)
        train_test_split = int(.8 * len(indices))
        train_indices = indices[:train_test_split]
        test_indices = indices[train_test_split:]
        train_indices = np.repeat(train_indices, n_envs_tr // len(train_indices), axis=0)
        test_indices = np.repeat(test_indices, n_envs_te // len(test_indices), axis=0)

        train_trajs = generate_dr_permuted_histories_for_indices(train_indices, n_hists, n_samples, H, dim)
        test_trajs = generate_dr_permuted_histories_for_indices(test_indices, n_hists, n_samples, H, dim)
        filepath_tr = f'datasets/{stem}_train.pkl'
        filepath_te = f'datasets/{stem}_test.pkl'

    elif env == 'darkroom_expert':
        raise NotImplementedError("darkroom_expert data generation is disabled")
        # Unreachable reference implementation (expert roll-in, held-out goals).
        goals = np.array([[(j, i) for i in range(dim)] for j in range(dim)]).reshape(-1, 2)
        np.random.RandomState(seed=0).shuffle(goals)
        train_test_split = int(.8 * len(goals))
        train_goals = goals[:train_test_split]
        test_goals = goals[train_test_split:]
        train_goals = np.repeat(train_goals, n_envs // (dim * dim), axis=0)
        test_goals = np.repeat(test_goals, n_envs // (dim * dim), axis=0)

        train_trajs = generate_dr_histories_for_goals(train_goals, n_hists, n_samples, H, dim, rollin='expert')
        test_trajs = generate_dr_histories_for_goals(test_goals, n_hists, n_samples, H, dim, rollin='expert')
        filepath_tr = f'datasets/{stem}_train.pkl'
        filepath_te = f'datasets/{stem}_test.pkl'

    else:
        # we're no longer doing lqr
        raise NotImplementedError(f"Unknown --env {env!r}")
        # Unreachable legacy LQR path (note: filenames omit the env name).
        train_trajs = generate_histories(n_envs_tr, n_hists, n_samples, H)
        test_trajs = generate_histories(n_envs_te, n_hists, n_samples, H)
        filepath_tr = f'datasets/trajs_envs{n_envs}_hists{n_hists}_samples{n_samples}_H{H}_d{dim}_train.pkl'
        filepath_te = f'datasets/trajs_envs{n_envs}_hists{n_hists}_samples{n_samples}_H{H}_d{dim}_test.pkl'

    # Create exactly the output directories the chosen paths live in.
    for out_dir in {os.path.dirname(filepath_tr), os.path.dirname(filepath_te)}:
        os.makedirs(out_dir, exist_ok=True)
    with open(filepath_tr, 'wb') as file:
        pickle.dump(train_trajs, file)
    with open(filepath_te, 'wb') as file:
        pickle.dump(test_trajs, file)

    print(f"Saved to {filepath_tr}.")
    print(f"Saved to {filepath_te}.")
