from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import os
import argparse
from time import sleep
import numpy as np
import concurrent.futures

# Select the experiment domain: keep exactly one import/instantiate pair below
# uncommented (if several were active, the last assignment would win).
# ``config`` must remain a module-level global: ``my_job`` reads it, and
# ``my_job`` is executed inside ProcessPoolExecutor worker processes.
# Attributes of ``config`` used in this file: NT, ETA, SSIZE, num_trajectory,
# eta, subsample_size, truncate_size, result_path and interval_estimation.

# from hiv_domain.hiv import hiv_config
# config = hiv_config()

from toy_domain.toy import toy_config
config = toy_config()

# from pendulum_domain.pendulum import pendulum_config
# config = pendulum_config()

def my_job(args):
    """Worker entry point: call ``config.interval_estimation`` with ``args`` unpacked.

    Defined at module level so ``ProcessPoolExecutor`` can pickle it by
    reference when dispatching jobs to worker processes.
    """
    estimator = config.interval_estimation
    return estimator(*args)

def my_multiprocess(args, max_workers=None):
    """Evaluate ``my_job`` on every item of ``args`` in a pool of worker processes.

    Args:
        args: iterable of argument tuples; each tuple is forwarded to ``my_job``.
        max_workers: number of worker processes. ``None`` (the default) lets
            ``ProcessPoolExecutor`` pick its own default.

    Returns:
        ``np.array`` built from the job results, in the same order as ``args``
        (``Executor.map`` yields results in input order, not completion order).
    """
    with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
        # ``np.array`` cannot consume a generator directly, so materialize first.
        results = list(executor.map(my_job, args))
    return np.array(results)

def _estimate_and_save(num_trajectory, eta, subsample_size, num_seed, concurrent_size):
    """Run interval estimation for seeds ``0 .. num_seed - 1`` at one setting and save it.

    Seeds are processed in consecutive batches of at most ``concurrent_size``,
    each dispatched through ``my_multiprocess``. The last batch may be smaller,
    so every seed is evaluated even when ``num_seed`` is not a multiple of
    ``concurrent_size``.

    The results (one row of 4 values per seed; row index == seed) are written
    with ``np.save`` to ``config.result_path`` using the file-name pattern
    ``interval_nt={}_ts={}_eta={}_size={}.npy``.
    """
    intervals = np.zeros([num_seed, 4])
    for seed_start in range(0, num_seed, concurrent_size):
        # Clamp so a partial final batch is still run instead of being dropped.
        seed_end = min(seed_start + concurrent_size, num_seed)
        args = ((num_trajectory, eta, subsample_size, seed)
                for seed in range(seed_start, seed_end))
        intervals[seed_start:seed_end, :] = my_multiprocess(args)
    filename = 'interval_nt={}_ts={}_eta={}_size={}.npy'.format(
        num_trajectory, config.truncate_size, eta, subsample_size)
    # Plain concatenation (not os.path.join) keeps the original path layout;
    # result_path is presumably expected to end with a path separator.
    np.save(config.result_path + filename, intervals)


def run_config(num_seed, concurrent_size):
    """Run the three one-at-a-time parameter sweeps defined by ``config``.

    For every setting, interval estimation is run for seeds
    ``0 .. num_seed - 1`` in parallel batches of at most ``concurrent_size``.
    The resulting ``(num_seed, 4)`` array is saved under ``config.result_path``.

    * exp1 varies the trajectory count (the slot otherwise filled by
      ``config.num_trajectory``) over ``config.NT``;
    * exp2 varies ``eta`` over ``config.ETA``;
    * exp3 varies the subsample size over ``config.SSIZE``.

    Parameters not being swept take their default values from ``config``.

    Args:
        num_seed: number of random seeds evaluated per setting.
        concurrent_size: maximum number of seeds dispatched per batch.

    Raises:
        ValueError: if ``concurrent_size`` is not positive.
    """
    if concurrent_size <= 0:
        raise ValueError(
            'concurrent_size must be positive, got {}'.format(concurrent_size))

    # exp1: change nt in NT
    for nt in config.NT:
        _estimate_and_save(nt, config.eta, config.subsample_size,
                           num_seed, concurrent_size)

    # exp2: change et in ETA
    for et in config.ETA:
        _estimate_and_save(config.num_trajectory, et, config.subsample_size,
                           num_seed, concurrent_size)

    # exp3: change ss in SSIZE
    for ss in config.SSIZE:
        _estimate_and_save(config.num_trajectory, config.eta, ss,
                           num_seed, concurrent_size)

def test_config(config, seed):
    """Smoke-test ``config.interval_estimation`` for a single seed.

    Runs one estimation with the default ``num_trajectory``, ``eta`` and
    ``subsample_size`` taken from ``config``. The four-way unpacking fails
    (raises) if the estimator does not return exactly four values.

    Args:
        config: domain configuration object exposing ``interval_estimation``.
        seed: random seed forwarded to the estimator.

    Returns:
        Tuple ``(lower, upper, lower2, upper2)`` produced by the estimator.
    """
    lower, upper, lower2, upper2 = config.interval_estimation(
        config.num_trajectory, config.eta, config.subsample_size, seed)
    print('test successful!!')
    return lower, upper, lower2, upper2

if __name__ == '__main__':
    # Full experiment: 300 seeds per setting, dispatched 60 at a time.
    num_seed, concurrent_size = 300, 60
    # For a quick single-seed check, run this instead of run_config:
    # test_config(config, 44)
    run_config(num_seed=num_seed, concurrent_size=concurrent_size)
