import sys
from typing import Optional
import logging
from logging import info, basicConfig, INFO, error
from mylog import log, logto, log_arr, remove_picklewriter
from parse import setup_args
from config import configure, configure_env
import torch
from os.path import join, exists
from tqdm import tqdm
from envs.env import Env
from agents.agent import Agent
from utils import compute_return
import numpy as np
from interact import interact
import gtimer as gt
from os import makedirs
import glob


@gt.wrap
def eval_episode(dt: float, env: Env, agent: Agent, 
                time_limit: Optional[float] = None,
                progress_bar: bool = False,
                run_idx: int = 0,
                inner_idx: int = 0):
    """Roll the agent out in ``env`` for one episode with step size ``dt``.

    The agent is switched to eval mode and reset, the env is reset, and
    ``interact`` is then called ``int(time_limit / dt)`` times.  Any
    parallelism lives inside ``env`` itself, which can be a vector of envs.

    Args:
        dt: step size used to convert the physical time limit into steps.
        env: environment to evaluate in.
        agent: agent to evaluate.
        time_limit: physical time horizon; a falsy value falls back to 10.
        progress_bar: wrap the step loop in a tqdm progress bar.
        run_idx: run index, only used in log messages and the timer stamp.
        inner_idx: inner-iteration index, used the same way as ``run_idx``.

    Returns:
        ``(rewards_arr, dones_arr)``: the per-step rewards and dones stacked
        along axis 0, so the caller can store them once enough inner
        iterations have been gathered.
    """
    agent.eval()
    agent.reset()

    if not time_limit:
        # e.g. for pendulum the physical time is 200 * 0.05 = 10 (seconds)
        time_limit = 10
    nb_steps = int(time_limit / dt)
    info(f"infer> run the policy on a physical time {time_limit}"
         f" ({nb_steps} steps in total)")

    obs = env.reset()
    steps = range(nb_steps)
    if progress_bar:
        steps = tqdm(steps)
    # Only needed by the (disabled) per-env logging; kept so the attribute
    # access on the env stays part of the execution.
    nb_env = len(env.envs)

    reward_trace, done_trace = [], []
    for _ in steps:  # until the task horizon (in steps)
        # per the original note, the reward has shape [nb_eval_env,]
        obs, step_reward, step_done = interact(env, agent, obs)
        reward_trace.append(step_reward)
        done_trace.append(step_done)
        # No per-step/per-env logging here: it takes too much storage on CC.

    # Nothing is written to disk at this point; the arrays are returned and
    # stored by the caller.
    rewards_arr = np.stack(reward_trace, axis=0)
    dones_arr = np.stack(done_trace, axis=0)

    R = compute_return(rewards_arr, dones_arr)
    info(f"infer> return for run {run_idx}, inner iter {inner_idx}: {R}")
    info(f"infer> return scaled to physical time {time_limit}: {R*dt}")
    gt.stamp(f'compute return for run {run_idx}, inner iter {inner_idx}')
    return rewards_arr, dones_arr

def collect_data(args):
    """Collect reward/done trajectories from the CT env and store them.

    For the run index ``args.nb_runs``, ``args.nb_inner_iters`` batches are
    rolled out one after another; within a batch the env runs
    ``args.nb_eval_env`` episodes in parallel.  The per-batch arrays are
    concatenated column-wise and written with ``log_arr`` under the names
    ``Rewards_{rows}_{cols}_{nb_runs}`` and ``Dones_{rows}_{cols}_{nb_runs}``.

    The agent is built once with ``configure`` and restored from
    ``args.agent_ckpt`` (or ``<logdir>/best_agent.pt`` when that is empty).
    Later batches only rebuild the env with ``configure_env``.

    Side effects: ``args.eval_seed`` is overwritten for every batch, and the
    torch / numpy global seeds are set to 0.

    Args:
        args: parsed arguments; reads ``logdir``, ``dt``, ``time_limit``,
            ``nb_runs``, ``nb_eval_env``, ``nb_inner_iters``, ``eval_seed``
            and ``agent_ckpt``.
    """
    logdir = args.logdir
    dt = args.dt
    T = args.time_limit
    nb_runs = args.nb_runs
    nb_eval_env = args.nb_eval_env
    # number of internal iterations: effectively nb_eval_env episodes in
    # parallel, repeated nb_inner_iters times serially
    nb_inner_iters = args.nb_inner_iters

    # set random seed (for training as there's no alg randomness during eval)
    torch.manual_seed(0)
    np.random.seed(0)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Every batch gets its seed as an offset from the ORIGINAL seed.  The
    # previous code added the offset to the already-shifted args.eval_seed,
    # so offsets accumulated and different (run, inner) pairs could end up
    # with the same seed (e.g. with 10 inner iterations, run 0 / inner 4 and
    # run 1 / inner 0 both got offset 10 * nb_eval_env).
    base_seed = args.eval_seed
    run_idx = nb_runs  # nb_runs is used as the index of this run

    agent = None
    agent_created = False
    rewards, dones = [], []
    for j in range(nb_inner_iters):
        args.eval_seed = base_seed + (run_idx*nb_inner_iters + j)*nb_eval_env

        # set up the agent and the eval_env
        if agent_created:
            eval_env = configure_env(args)
        else:
            agent, _, eval_env = configure(args)
            # load the checkpoint from agent_ckpt if given, otherwise from logdir
            if args.agent_ckpt == '':
                agent_file = join(logdir, 'best_agent.pt')
            else:
                agent_file = args.agent_ckpt
            if exists(agent_file):
                # A single load is enough: map_location=device covers both
                # the gpu->gpu and the gpu->cpu case.
                state_dict = torch.load(agent_file, map_location=device)
                R = state_dict["return"]
                cur_e = state_dict["epoch"]
                info(f"infer> Loading agent {agent_file} with return {R}/scaled return {R*dt} at epoch {cur_e}...")
                agent.load_state_dict(state_dict)
                agent = agent.to(device)
                agent_created = True
            else:
                # NOTE(review): the rollout below still runs, with the agent
                # exactly as configure() returned it, and configure() is
                # called again on the next inner iteration.
                error(f"infer> cannot load policy")

        try:
            # eval_episode resets the env again; this extra reset is kept so
            # the sequence of env resets matches the previous behaviour.
            eval_env.reset()  # obs shape [nb_envs, dim_obs]
            # run the policy in eval_env, and record the reward and done
            rewards_single_iter, dones_single_iter = eval_episode(
                dt, eval_env, agent, T, True, run_idx, j)
        finally:
            # release the env even if the rollout fails
            eval_env.close()
        del eval_env
        rewards.append(rewards_single_iter)
        dones.append(dones_single_iter)
        gt.stamp(f'inner iteration {j}')

    # store once all inner iterations are gathered; batches are concatenated
    # on the columns, giving shape [nb_steps, nb_env]
    rewards_arr = np.hstack(rewards)
    dones_arr = np.hstack(dones)
    log_arr(f'Rewards_{rewards_arr.shape[0]}_{rewards_arr.shape[1]}_{nb_runs}', rewards_arr)
    log_arr(f'Dones_{dones_arr.shape[0]}_{dones_arr.shape[1]}_{nb_runs}', dones_arr)
    gt.stamp(f'collect_data function, run: {nb_runs}')
    info(gt.report())

def load_data_from_npy(data_dir, T, nb_runs):
    """Load the stored reward arrays of all runs and join them column-wise.

    Files are looked up as ``Rewards_{int(T*1000)}_1000_{i}.npy`` inside
    ``data_dir`` for ``i`` in ``range(nb_runs)``; the step count and the
    column count (1000) are hard-coded in that name.  Only rewards are
    loaded, not dones.

    Args:
        data_dir: directory containing the ``.npy`` files.
        T: time limit, used to build the file names.
        nb_runs: number of run files to load (indices ``0 .. nb_runs-1``).

    Returns:
        The loaded arrays stacked horizontally with ``np.hstack``.
    """
    paths = []
    for run in range(nb_runs):
        fname = f'Rewards_{int(T*1000)}_1000_{run}.npy'
        paths.append(join(data_dir, fname))

    stacked = np.hstack([np.load(p) for p in paths])
    info(f'Loaded from {paths}')
    return stacked

def compute_V(args):
    """Estimate the reference value V from all stored reward trajectories.

    Each column of the loaded reward array is treated as one episode.  Its
    return is a left Riemann sum with step ``args.dt`` over all rows but the
    last, plus the last row weighted by whatever part of ``T`` is left.
    V is the mean of those returns.  An empty marker file named
    ``V_{V}_{env_id}.txt`` is created in ``args.logdir``.

    Args:
        args: parsed arguments; reads ``logdir``, ``time_limit``, ``env_id``,
            ``nb_runs``, ``data_dir`` and ``dt``.

    Returns:
        The mean return V over all episodes.
    """
    logdir, T = args.logdir, args.time_limit
    env_id, nb_runs = args.env_id, args.nb_runs

    # load the stored rewards
    rewards = load_data_from_npy(args.data_dir, T, nb_runs)
    info(f'rewards shape: {rewards.shape}')
    gt.stamp('Load the rewards data')

    h = args.dt
    nb_rows = rewards.shape[0]
    Jh = h * rewards[:-1, :].sum(axis=0)
    # the last sample covers the remaining time, in case the data cannot
    # fill a full episode
    Jh += (T - (nb_rows - 1) * h) * rewards[-1, :]
    # Jh holds one return per episode (column)
    info(f"shape of Jh {Jh.shape}")

    V = np.mean(Jh)
    info(f"true value is {V}")
    # the value is recorded in the name of an empty file
    with open(join(logdir, f"V_{V}_{env_id}.txt"), 'w'):
        pass
    return V


def run_MSE(args):
    """Compute and save the value estimates Vhat for each data budget D.

    For every ``D`` in ``args.D`` and every step size ``h`` in
    ``args.h_list``, 10 trials are run.  Trial ``j`` calls ``compute_Vhat``
    on the block of ``M = D*h/T`` episodes starting at column ``M*j`` of the
    loaded rewards.  The resulting ``(10, len(h_list))`` array is saved to
    ``<logdir>/Vhat_D_{D}.npy``.

    Args:
        args: parsed arguments; reads ``D``, ``logdir``, ``time_limit``,
            ``dt``, ``data_dir``, ``nb_runs`` and ``h_list``.
    """
    D_list = args.D
    info(f'D_list is {D_list}')
    logdir = args.logdir
    T = args.time_limit
    h0 = args.dt

    # load the stored rewards
    rewards = load_data_from_npy(args.data_dir, T, args.nb_runs)
    info(f'rewards shape: {rewards.shape}')
    gt.stamp('Loaded the rewards data')

    # step sizes for which Vhat is computed
    h_list = args.h_list
    N_trials = 10
    for D in D_list:
        # Number of episodes per (D, h).  D*h/T is assumed to be an integer
        # (enough data to cover full episodes); it is rounded rather than
        # truncated because the float product can land just below that
        # integer, which would silently drop one episode.
        M_list = [int(round(D * h / T)) for h in h_list]

        Vhat = np.zeros((N_trials, len(h_list)))
        for j in range(N_trials):
            for i, (h, M) in enumerate(zip(h_list, M_list)):
                # trial j uses the j-th block of M episodes
                Vhat[j, i] = compute_Vhat(h0, h, M*j, M, T, rewards, True)
            gt.stamp(f'compute Vhat for D={D}, trial {j}')

        # save Vhat
        out_path = join(logdir, f'Vhat_D_{D}.npy')
        np.save(out_path, Vhat)
        gt.stamp(f'saved Vhat for h={h_list} and D={D}')

    info(gt.report())

def compute_Vhat(h0, h, base_M, M, T, rewards, verbose=False):
    """Estimate the value from ``M`` episodes sampled with step size ``h``.

    ``rewards`` holds one episode per column, sampled with step ``h0``.  The
    rows are subsampled with stride ``h / h0`` and the columns
    ``base_M .. base_M + M - 1`` are selected.  Each episode's return is a
    left Riemann sum with step ``h`` over all subsampled rows but the last,
    plus the last row weighted by the remaining time ``T - (rows - 1) * h``,
    so it works for partial and full trajectories.

    Args:
        h0: step size the data in ``rewards`` was collected with.
        h: step size to evaluate; expected to be an integer multiple of ``h0``.
        base_M: index of the first episode (column) to use.
        M: number of episodes to average over.
        T: time limit of an episode.
        rewards: 2-D array of rewards, one column per episode.
        verbose: print the shapes of the intermediate arrays.

    Returns:
        The mean return over the selected episodes.

    Raises:
        ValueError: if ``h / h0`` rounds to less than 1.
    """
    # Round instead of truncating: h / h0 is meant to be an integer, but the
    # float division can land just below it (0.3 / 0.1 == 2.9999999999999996),
    # and int() would then pick the wrong stride.
    h_ratio = int(round(h / h0))
    if h_ratio < 1:
        raise ValueError(f'h={h} must be at least h0={h0}')
    info(f'h={h}, h_ratio={h_ratio}, M={M}')

    xh = rewards[::h_ratio, base_M:base_M+M]
    if xh.shape[1] < M:
        # slicing past the last column silently yields fewer (or no) episodes
        logging.warning(
            f'only {xh.shape[1]} of the requested {M} episodes are available '
            f'from column {base_M}')

    # the last point accounts for the remaining part of the integral
    Jh = (h * np.sum(xh[:-1, :], axis=0)
          + (T - (xh.shape[0]-1)*h) * xh[-1, :])

    # each Jh[i] is the return for i-th trajectory
    if verbose:
        print(f"shape of xh: {xh.shape}")
        print(f"shape of Jh: {Jh.shape}")

    Vhat = np.mean(Jh) # avg over M trajectories

    return Vhat


def main(args):
    """Entry point: run exactly one of the three modes selected by ``args``.

    ``args.gen_data_mode`` takes precedence and generates the CT data.
    Otherwise ``args.compute_V`` selects the value computation.  If neither
    is set, the MSE computation runs.
    """
    if args.gen_data_mode:
        collect_data(args)  # generate the CT data
        return
    if args.compute_V:
        compute_V(args)
        return
    run_MSE(args)  # compute the MSE

if __name__ == '__main__':
    args = setup_args()

    # Configure logging.  With redirect_stdout the records go to a log file
    # in logdir and are additionally echoed through a console handler;
    # otherwise they go to stdout only.
    if args.redirect_stdout:
        # named log_format / date_format so the builtin format() is not shadowed
        log_format = '[%(asctime)s] %(levelname)s:%(name)s:%(message)s'
        date_format = '%Y-%m-%d,%H:%M:%S'
        basicConfig(filename=join(args.logdir, f'MSE_out_{args.nb_runs}.log'),
                    level=INFO, format=log_format, datefmt=date_format)
        console = logging.StreamHandler()
        console.setLevel(INFO)
        console.setFormatter(logging.Formatter(log_format, datefmt=date_format))
        logging.getLogger('').addHandler(console)
    else:
        basicConfig(stream=sys.stdout, level=INFO)

    logto(args.logdir, reload=True)
    remove_picklewriter()

    main(args)
