import os

os.environ['MKL_THREADING_LAYER'] = 'GNU'

# import internal classes
from env_wrapper_bl import *
from inner_model import *

import argparse

from distutils.util import strtobool

# Command-line interface of the experiment.
#
# Boolean switches declared with `type=lambda x: bool(strtobool(x))`, `nargs='?'` and
# `const=True` accept three forms: omitted (-> default), bare `--flag` (-> True) and
# `--flag <true|false|1|0|yes|no|...>` (-> parsed value).
# Several int-typed options below are used as 0/1 switches in main()
# (e.g. `--wandb`, `--debug-cuda`, `--fine-optimizer`).
parser = argparse.ArgumentParser(description='PPO agent')
# Common arguments
# NOTE: the default uses splitext() rather than str.rstrip(".py"): rstrip removes any trailing
# run of the characters '.', 'p', 'y', which would mangle a script name such as "happy.py".
parser.add_argument('--exp-name', type=str, default=os.path.splitext(os.path.basename(__file__))[0],
                    help='the name of this experiment')
parser.add_argument('--gym-id', type=str, default="MsPacmanNoFrameskip-v4",
                    help='the id of the gym environment')
parser.add_argument('--learning-rate', type=float, default=2.5e-4,
                    help='the learning rate of the optimizer')
parser.add_argument('--seed', type=int, default=1,
                    help='seed of the experiment')
parser.add_argument('--total-timesteps', type=int, default=100001,
                    help='total timesteps of the experiments')
parser.add_argument('--torch-deterministic', type=lambda x: bool(strtobool(x)), default=True, nargs='?', const=True,
                    help='if toggled, `torch.backends.cudnn.deterministic=False`')
parser.add_argument('--cuda', type=lambda x: bool(strtobool(x)), default=True, nargs='?', const=True,
                    help='if toggled, cuda will not be enabled by default')
parser.add_argument('--prod-mode', type=lambda x: bool(strtobool(x)), default=True, nargs='?', const=True,
                    help='run the script in production mode and use wandb to log outputs')
parser.add_argument('--capture-video', type=lambda x: bool(strtobool(x)), default=False, nargs='?', const=True,
                    help='weather to capture videos of the agent performances (check out `videos` folder)')
parser.add_argument('--wandb-project-name', type=str, default="meta_search",
                    help="the wandb's project name")
parser.add_argument('--wandb-entity', type=str, default=None,
                    help="the entity (team) of wandb's project")

# Algorithm specific arguments
parser.add_argument('--minibatch-size', type=int, default=256,
                    help='the size of mini batch')
parser.add_argument('--num-envs', type=int, default=16,
                    help='the number of parallel game environment')
parser.add_argument('--gamma', type=float, default=0.99,
                    help='the discount factor gamma')
parser.add_argument('--gae-lambda', type=float, default=0.95,
                    help='the lambda for the general advantage estimation')
parser.add_argument('--ent-coef', type=float, default=0.01,
                    help="coefficient of the entropy")
parser.add_argument('--vf-coef', type=float, default=0.5,
                    help="coefficient of the value function")
parser.add_argument('--max-grad-norm', type=float, default=0.5,
                    help='the maximum norm for the gradient clipping')
parser.add_argument('--clip-coef', type=float, default=0.2,
                    help="the surrogate clipping coefficient")
parser.add_argument('--update-epochs-actor', type=int, default=4,
                    help="the K epochs to update the actor")
parser.add_argument('--update-epochs-critic', type=int, default=1,
                    help="the K epochs to update the critic")
parser.add_argument('--kle-stop', type=lambda x: bool(strtobool(x)), default=False, nargs='?', const=True,
                    help='If toggled, the policy updates will be early stopped w.r.t target-kl')
parser.add_argument('--kle-rollback', type=lambda x: bool(strtobool(x)), default=False, nargs='?', const=True,
                    help='If toggled, the policy updates will roll back to previous policy if KL exceeds target-kl')
parser.add_argument('--target-kl', type=float, default=0.03,
                    help='the target-kl variable that is referred by --kl')
parser.add_argument('--gae', type=lambda x: bool(strtobool(x)), default=True, nargs='?', const=True,
                    help='Use GAE for advantage computation')
parser.add_argument('--norm-adv', type=lambda x: bool(strtobool(x)), default=True, nargs='?', const=True,
                    help="Toggles advantages normalization")
parser.add_argument('--anneal-lr', type=lambda x: bool(strtobool(x)), default=True, nargs='?', const=True,
                    help="Toggle learning rate annealing for policy and value networks")
parser.add_argument('--clip-vloss', type=lambda x: bool(strtobool(x)), default=True, nargs='?', const=True,
                    help='Toggles wheter or not to use a clipped loss for the value function, as per the paper.')

# Video / GIF logging
parser.add_argument('--gif-frequency', type=int, default=500)
parser.add_argument('--gif-length', type=int, default=500)

parser.add_argument('--log-video', type=int, default=0)

# NOTE(review): this is the only flag spelled with an underscore; it is kept as-is so that
# existing command lines continue to work.
parser.add_argument('--fast_gradient', type=int, default=0)
parser.add_argument('--use-sgd', type=int, default=0)
parser.add_argument('--num-update-value', type=int, default=1)
parser.add_argument('--num-update-policy', type=int, default=1)
# spinningup 80
parser.add_argument('--lr-policy', type=float, default=2.5e-4,
                    help='the learning rate of the optimizer')
parser.add_argument('--lr-value', type=float, default=2.5e-4,
                    help='the learning rate of the optimizer')
# spinningup pi_lr = 3e-4,
# vf_lr = 1e-3
# Forwarded as `eps=` to the Adam optimizers built in main().
parser.add_argument('--eps-adam', type=float, default=1e-5,
                    help='the epsilon term of the Adam optimizer')

parser.add_argument('--add-to-state-name', type=str, default='su')

parser.add_argument('--deterministic', type=int, default=1)
parser.add_argument('--steps-blueprint', type=int, default=1001472)

# parser.add_argument('--steps_finetune', type=int, default=349)
parser.add_argument('--steps-finetune', type=int, default=159)
# parser.add_argument('--steps_visualized', type=int, default=8) #349-356 #999:done at step 808
parser.add_argument('--steps-visualized', type=int, default=10)  # 349-356
parser.add_argument('--nb-updates-visualized', type=int, default=1000)

parser.add_argument('--nb-policy-visualized', type=int, default=3)
parser.add_argument('--stopping-criteria-factor', type=int, default=10)

parser.add_argument('--finetune-lr-policy', type=float, default=-1)  # -1: using state dic optimizer blueprint
parser.add_argument('--finetune-lr-value', type=float, default=-1)

# Fine-tuning rollout sizes: main() allocates rollout buffers of shape
# (num_steps_per_update, num_finetune_envs, ...).
parser.add_argument('--num-steps-per-update', type=int, default=128)
parser.add_argument('--max-updates-per-finetune', type=int, default=30)
parser.add_argument('--nb-backward-steps', type=int, default=100)
parser.add_argument('--lr-decay', type=int, default=100)
parser.add_argument('--num-finetune-envs', type=int, default=8)

parser.add_argument('--finetune-horizon', type=int, default=32)
parser.add_argument('--finetune-frequency', type=int, default=32)

parser.add_argument('--debug-cuda', type=int, default=0)
parser.add_argument('--only-head', type=int, default=0)
parser.add_argument('--other-eval-seed', type=int, default=0)
parser.add_argument('--other-finetune-seed', type=int, default=0)

parser.add_argument('--init-global-state', type=int, default=1)

parser.add_argument('--wandb', type=int, default=1)

# Blueprint checkpoint saving / loading (the index and step values are interpolated into
# the checkpoint file names in main()).
parser.add_argument('--save-blueprint-frequency', type=int, default=-1)
parser.add_argument('--save-initial-blueprint', type=int, default=0)
parser.add_argument('--load-initial-blueprint', type=int, default=0)
parser.add_argument('--index-load-blueprint', type=int, default=9)
parser.add_argument('--step-load-blueprint', type=int, default=10240000)
parser.add_argument('--index-save-blueprint', type=int, default=0)

# Environment construction options (forwarded to make_env in main()).
parser.add_argument('--sticky-prob', type=float, default=0)
parser.add_argument('--max-nb-random-actions', type=int, default=0)
parser.add_argument('--use-only-first-action', type=int, default=0)
parser.add_argument('--reward-clipping', type=int, default=0)

parser.add_argument('--nb-tests', type=int, default=1)
parser.add_argument('--test-frequency', type=int, default=-1)

parser.add_argument('--clone-full-state', type=int, default=1)
parser.add_argument('--render-ft', type=int, default=0)
parser.add_argument('--render-main', type=int, default=0)

parser.add_argument('--one-time-update', type=int, default=0)
parser.add_argument('--index-update', type=int, default=0)

parser.add_argument('--reload-blue', type=int, default=1)
parser.add_argument('--load-optimizer', type=int, default=0)

parser.add_argument('--max-update', type=int, default=10)
parser.add_argument('--finetune-anneal-lr', type=int, default=1)

# Learning rates of the exploration-bonus models (only --lr-vae is consumed in the visible
# part of main(), and only when --intrinsic is set).
parser.add_argument('--lr-rnd', type=float, default=1e-4)
parser.add_argument('--lr-icm', type=float, default=1e-4)
parser.add_argument('--lr-vae', type=float, default=1e-4)

# When >= 0, this single value overrides both --finetune-horizon and --finetune-frequency
# after parsing.
parser.add_argument('--horizon-and-frequency', type=int, default=-1)

parser.add_argument('--intrinsic', type=int, default=0)

parser.add_argument('--type', type=str, default="online network finetuning")

parser.add_argument('--nb-rollout', type=int, default=1)

parser.add_argument('--episode-length', type=int, default=-1)

parser.add_argument('--fine-optimizer', type=int, default=0)
parser.add_argument('--finetune-critic', type=int, default=1)

parser.add_argument('--finite-horizon-problem', type=int, default=1)

parser.add_argument('--blue-if-fail', type=int, default=0)

# Parse the process command line once, at import time, into the module-level `args`.
args = parser.parse_args()
if not args.seed:
    # A falsy seed (i.e. 0) is replaced by a time-based seed.
    # `time` is imported explicitly here: no module-level `import time` exists in this file,
    # so the original line only worked if one of the `import *` statements happened to
    # export the `time` module.
    import time
    args.seed = int(time.time())

# Derived size of one fine-tuning batch: one transition per environment per collected step.
args.batch_size = int(args.num_finetune_envs * args.num_steps_per_update)
# A non-negative --horizon-and-frequency overrides both fine-tuning settings at once.
if args.horizon_and_frequency >= 0:
    args.finetune_horizon = args.horizon_and_frequency
    args.finetune_frequency = args.horizon_and_frequency


def main(args):
    import os
    os.environ['MKL_THREADING_LAYER'] = 'GNU'

    if args.debug_cuda:
        os.environ["CUDA_LAUNCH_BLOCKING"] = '1'

    if not args.finetune_critic and args.update_epochs_critic>1:
        return

    if args.max_update ==0 and not (args.num_steps_per_update==32):
        return

    import cv2
    cv2.ocl.setUseOpenCL(False)
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torch.nn.functional as F

    if args.debug_cuda:
        torch.autograd.set_detect_anomaly(True)
    import numpy as np

    import time
    import random

    import matplotlib.pyplot as plt
    if args.wandb:
        import wandb


    device = torch.device('cuda' if torch.cuda.is_available() and args.cuda else 'cpu')

    # init variable
    envs = VecPyTorch(ParallelEnv([make_env(args.gym_id, args.seed + i,
                                            sticky_prob=args.sticky_prob,
                                            max_nb_random_actions=args.max_nb_random_actions,
                                            use_only_first_action=args.use_only_first_action,
                                            reward_clipping=args.reward_clipping)
                                   for i in range(1)]), device)

    simul_envs = VecPyTorch(ParallelEnv([make_env(args.gym_id, args.seed + i,
                                                  sticky_prob=args.sticky_prob,
                                                  max_nb_random_actions=args.max_nb_random_actions,
                                                  use_only_first_action=args.use_only_first_action,
                                                  reward_clipping=args.reward_clipping)
                                         # for i in range(1)]), device)
                                         for i in range(args.num_finetune_envs)]), device)

    simul_envs.reset()
    test_envs = VecPyTorch(ParallelEnv([make_env(args.gym_id, args.seed + i,
                                                  sticky_prob=args.sticky_prob,
                                                  max_nb_random_actions=args.max_nb_random_actions,
                                                  use_only_first_action=args.use_only_first_action,
                                                  reward_clipping=args.reward_clipping)
                                         # for i in range(1)]), device)
                                         for i in range(1)]), device)

    test_envs.reset()


    if args.wandb:
        wandb.init(project=args.wandb_project_name,
                   name=f"{args.exp_name}_horizon{args.finetune_horizon if args.finetune_horizon > 0 else 'no'}_frequency{args.finetune_frequency}{'_head' if args.only_head else ''}{'_evalseed' if args.other_eval_seed else ''}{'_ftnseed' if args.other_finetune_seed else ''}",
                   config=vars(args))

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    global_init_obs, global_numpy_obs = envs.reset()
    global_init_state = envs.clone_full_state()

    obs = torch.zeros((args.num_steps_per_update, args.num_finetune_envs) + envs.observation_space.shape).to(device)
    actions = torch.zeros((args.num_steps_per_update, args.num_finetune_envs) + envs.action_space.shape).to(device)
    logprobs = torch.zeros((args.num_steps_per_update, args.num_finetune_envs)).to(device)
    rewards = torch.zeros((args.num_steps_per_update, args.num_finetune_envs)).to(device)
    dones = torch.zeros((args.num_steps_per_update, args.num_finetune_envs)).to(device)
    values = torch.zeros((args.num_steps_per_update, args.num_finetune_envs)).to(device)

    agent = AgentTwoNetworksHead(envs).to(device)
    if args.fine_optimizer:
        optimizer_value = optim.Adam(agent.critic.parameters(), lr=args.lr_value, eps=args.eps_adam)
        optimizer_policy = optim.Adam(agent.actor.parameters(), lr=args.lr_policy, eps=args.eps_adam)
    if args.load_initial_blueprint:
        agent.actor.load_state_dict(torch.load(
            f"vppo_atari_actor_{args.index_load_blueprint}_init.pth"))
        agent.critic.load_state_dict(torch.load(
            f"vppo_agent_critic_{args.index_load_blueprint}_init.pth"))
        if args.fine_optimizer:
            optimizer_policy.load_state_dict(torch.load(
                f"vppo_opt_actor_{args.index_load_blueprint}_init.pth"))
            optimizer_value.load_state_dict(torch.load(
                f"vppo_opt_critic_{args.index_load_blueprint}_init.pth"))

    elif args.step_load_blueprint > 0:
        agent.actor.load_state_dict(torch.load(
            f"vppo_atari_actor_{args.index_load_blueprint}_{args.step_load_blueprint}.pth"))
        agent.critic.load_state_dict(torch.load(
            f"vppo_agent_critic_{args.index_load_blueprint}_{args.step_load_blueprint}.pth"))
        if args.fine_optimizer:
            optimizer_policy.load_state_dict(torch.load(
                f"vppo_opt_actor_{args.index_load_blueprint}_{args.step_load_blueprint}.pth"))
            optimizer_value.load_state_dict(torch.load(
                f"vppo_opt_critic_{args.index_load_blueprint}_{args.step_load_blueprint}.pth"))


    blueprint = AgentTwoNetworksHead(envs).to(device)
    if args.load_initial_blueprint:
        blueprint.actor.load_state_dict(torch.load(
            f"vppo_atari_actor_{args.index_load_blueprint}_init.pth"))
        blueprint.critic.load_state_dict(torch.load(
            f"vppo_agent_critic_{args.index_load_blueprint}_init.pth"))

    elif args.step_load_blueprint > 0:
        blueprint.actor.load_state_dict(torch.load(
            f"vppo_atari_actor_{args.index_load_blueprint}_{args.step_load_blueprint}.pth"))
        blueprint.critic.load_state_dict(torch.load(
            f"vppo_agent_critic_{args.index_load_blueprint}_{args.step_load_blueprint}.pth"))

    else:
        if args.save_initial_blueprint:
            torch.save(blueprint.actor.state_dict(),
                       f"vppo_atari_actor_{args.index_save_blueprint}_init.pth")
            torch.save(blueprint.critic.state_dict(),
                       f"vppo_agent_critic_{args.index_save_blueprint}_init.pth")

            # print("initial blueprint saved")
            return

    if args.intrinsic:
        # Exploration Bonuses
        rnd = RNDModel(envs).to(device)

        icm = ICMModel(envs).to(device)

        vae = VAEDensity().to(device)
        optimizer_vae = optim.Adam(vae.parameters(), lr=args.lr_vae)

        vae.load_state_dict(torch.load(
            f"vae_{args.index_load_blueprint}_{args.step_load_blueprint}.pth"))
        rnd.load_state_dict(torch.load(
            f"rnd_{args.index_load_blueprint}_{args.step_load_blueprint}.pth"))
        icm.load_state_dict(torch.load(
            f"icm_{args.index_load_blueprint}_{args.step_load_blueprint}.pth"))

    test_envs.set_init_state(envs.clone_full_state())
    test_envs.set_init_obs(global_numpy_obs)
    test_envs.restore_full_state()
    next_obs = global_init_obs
    blueprint_ep_ret = 0
    blueprint_step_until_dead = 0
    while True:
        with torch.no_grad():
            action, logproba, _ = blueprint.get_action(next_obs, deterministic=True)

        next_obs, rs, ds, infos = test_envs.step(action)

        blueprint_ep_ret += rs.item()

        blueprint_step_until_dead += 1

        if ds[0]:
            break

    print(f"Episodic return of blueprint: {blueprint_ep_ret}")
    if args.wandb:
        wandb.log({"Nb of finetune so far": 0, "Episode return from here": blueprint_ep_ret,
                   "Finetuned episode return so far": blueprint_ep_ret,
                   "Steps survived": blueprint_step_until_dead,
                   "Real episode step": -1})

    if args.finetune_anneal_lr:
        lr_policy = lambda f: f * args.lr_policy

    def scale(im, nR, nC):
        """Resize `im` to `nR` rows by `nC` columns using nearest-neighbour sampling.

        Each output pixel (r, c) copies the source pixel whose row/column index is the
        output index scaled by the source/target size ratio and truncated to an int.

        `im` must support `len(im)`, `len(im[0])` and `im[i, j]` indexing.
        Returns a newly allocated float array of shape (nR, nC, 3).
        """
        resized = np.zeros((nR, nC, 3))
        src_rows = len(im)
        src_cols = len(im[0])
        # Pre-compute, for every target row/column, the source index it samples from.
        row_lookup = [int(r * src_rows / nR) for r in range(nR)]
        col_lookup = [int(c * src_cols / nC) for c in range(nC)]
        for r, src_r in enumerate(row_lookup):
            for c, src_c in enumerate(col_lookup):
                resized[r, c] = im[src_r, src_c]
        return resized

    def update_blue(args, updates_per_finetune, init_obs, init_state, init_obs_numpy, init_render, inner_agent,
                    opt_policy, optimizer_value, meta_ep_step, nb_meta_up, current_meta_ret, gif_test=None):
        if args.render_ft and meta_ep_step > 0 and meta_ep_step < 30:
            gifs_finetune = []
            gifs_finetune.append([])
            first_render = True
        simul_envs.set_init_state(init_state)
        simul_envs.set_init_obs(init_obs_numpy)
        simul_envs.restore_full_state()
        samples = 0

        t0_blue = time.time()

        blueprint_ep_ret_fts = []

        for _ in range(args.nb_rollout):

            test_envs.set_init_state(init_state)
            test_envs.set_init_obs(init_obs_numpy)
            test_envs.restore_full_state()
            next_obs = init_obs
            blueprint_ep_ret_ft = 0
            blueprint_step_until_dead_ft = 0

            while True:
                # ALGO LOGIC: put action logic here
                with torch.no_grad():
                    action, logproba, _ = blueprint.get_action(next_obs, deterministic=True)

                next_obs, rs, ds, infos = test_envs.step(action)


                if gif_test is not None and blueprint_step_until_dead_ft == 0:
                    gif_test.append(test_envs.render(mode='rgb_array',
                                                     caption="test_blue").transpose(
                        2, 0, 1))

                blueprint_ep_ret_ft += rs.item()

                blueprint_step_until_dead_ft += 1

                if ds[0] or (args.episode_length>0 and blueprint_step_until_dead_ft >= args.episode_length-meta_ep_step):
                    blueprint_ep_ret_fts.append(blueprint_ep_ret_ft)
                    break

        t_blue = time.time() - t0_blue

        # print(f"ret blue from this step: {np.mean(blueprint_ep_ret_fts)}")

        next_obs = init_obs
        next_obs = next_obs.repeat(args.num_finetune_envs, 1, 1, 1)
        next_done = torch.zeros(args.num_finetune_envs).to(device)
        episode_finetune = 0
        episode_step = np.zeros(args.num_finetune_envs)


        finetune_done = False

        rets = []

        gradients_norms = []

        t_update = []
        t_simulation = []
        t_backward = []
        t_test = []

        for update in range(1, updates_per_finetune + 1):
            # print(f"update {update}")
            if args.finetune_anneal_lr:
                frac = 1.0 - (update - 1.0) / updates_per_finetune
                lrnow = lr_policy(frac)
                opt_policy.param_groups[0]['lr'] = lrnow
            else:
                assert opt_policy.param_groups[0]['lr'] == args.lr_policy

            # TRY NOT TO MODIFY: prepare the execution of the game.
            t_step_0 = time.time()
            with torch.no_grad():
                advantages = torch.zeros_like(rewards).to(device) #torch.zeros((args.num_steps_per_update, args.num_finetune_envs)).to(device)

            previous_episode = np.zeros(args.num_finetune_envs).astype(int)

            for step in range(0, args.num_steps_per_update):

                if args.render_ft and meta_ep_step > 0 and meta_ep_step < 30:
                    if first_render:
                        gifs_finetune[-1].append(init_render)
                        first_render = False
                    else:
                        gifs_finetune[-1].append(simul_envs.render(mode='rgb_array',
                                                                   caption=f"ms{meta_ep_step}up{update}ep{episode_finetune}st{episode_step}").transpose(
                            2, 0, 1))
                obs[step] = next_obs
                dones[step] = next_done

                with torch.no_grad():
                    if args.finetune_critic:
                        values[step] = inner_agent.get_value(obs[step]).flatten()
                    else:
                        values[step] = blueprint.get_value(obs[step]).flatten()
                    action, logproba, _ = inner_agent.get_action(obs[step])

                actions[step] = action
                logprobs[step] = logproba

                # TRY NOT TO MODIFY: execute the game and log data.

                next_obs, rs, ds, infos = simul_envs.step(action)
                samples += args.num_finetune_envs
                if gif_test is not None and episode_step == 0 and len(gif_test) < 100:
                    gif_test.append(simul_envs.render(mode='rgb_array',
                                                      caption=f"update {update}").transpose(
                        2, 0, 1))
                episode_step += 1
                rewards[step], next_done = rs.view(-1), torch.Tensor(ds).to(device)

                if np.array(ds).any() or (episode_step==args.finetune_horizon).any():
                    for idx, dn in enumerate(ds):
                        if episode_step[idx] == args.finetune_horizon or dn:
                            episode_finetune += 1
                            episode_step[idx] = 0
                            with torch.no_grad():
                                if args.finetune_critic and not (episode_step[idx] == args.finetune_horizon):
                                    last_value = inner_agent.get_value(next_obs[idx].unsqueeze(0).to(device)).reshape(1, -1)
                                elif args.finetune_critic and not args.finite_horizon_problem:
                                    last_value = inner_agent.get_value(next_obs[idx].unsqueeze(0).to(device)).reshape(1, -1)
                                else:
                                    last_value = blueprint.get_value(next_obs[idx].unsqueeze(0).to(device)).reshape(1, -1)
                                lastgaelam = 0
                                for t in reversed(range(previous_episode[idx], step + 1)):
                                    if t == step:
                                        nextnonterminal = 1.0 - next_done[idx].unsqueeze(0)
                                        nextvalues = last_value
                                    else:
                                        nextnonterminal = 1.0 - dones[t + 1][idx].unsqueeze(0)
                                        nextvalues = values[t + 1][idx].unsqueeze(0)
                                    delta = rewards[t][idx].unsqueeze(0) + args.gamma * nextvalues * nextnonterminal - \
                                            values[t][idx].unsqueeze(0)
                                    advantages[
                                        t][
                                        idx] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
                            if not dn:
                                simul_envs.restore_full_state_one(idx)
                                next_obs[idx] = init_obs
                            previous_episode[idx] = step+1

            t_simulation.append(time.time() - t_step_0)


            for idx in range(args.num_finetune_envs):
                if previous_episode[idx] == args.num_steps_per_update:
                    continue
                with torch.no_grad():
                    if args.finetune_critic and not (episode_step[idx] == args.finetune_horizon):
                        last_value = inner_agent.get_value(next_obs[idx].unsqueeze(0).to(device)).reshape(1, -1)
                    elif args.finetune_critic and not args.finite_horizon_problem:
                        last_value = inner_agent.get_value(next_obs[idx].unsqueeze(0).to(device)).reshape(1, -1)
                    else:
                        last_value = blueprint.get_value(next_obs[idx].unsqueeze(0).to(device)).reshape(1, -1)
                    lastgaelam = 0
                    for t in reversed(range(previous_episode[idx], args.num_steps_per_update)):
                        if t == step:
                            nextnonterminal = 1.0 - next_done[idx].unsqueeze(0)
                            nextvalues = last_value
                        else:
                            nextnonterminal = 1.0 - dones[t + 1][idx].unsqueeze(0)
                            nextvalues = values[t + 1][idx].unsqueeze(0)
                        delta = rewards[t][idx].unsqueeze(0) + args.gamma * nextvalues * nextnonterminal - values[t][idx].unsqueeze(0)
                        advantages[
                            t][
                            idx] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam

            with torch.no_grad():
                returns = advantages + values

            t_update_0 = time.time()

            # flatten the batch
            b_obs = obs.reshape((-1,) + simul_envs.observation_space.shape)
            b_logprobs = logprobs.reshape(-1)
            b_actions = actions.reshape((-1,) + simul_envs.action_space.shape)
            b_advantages = advantages.reshape(-1)
            b_returns = returns.reshape(-1)
            b_values = values.reshape(-1)

            batch_size = args.batch_size

            # Optimizaing the policy and value network
            if args.kle_rollback:
                target_agent = AgentTwoNetworksHead(envs).to(device)
            inds = np.arange(batch_size, )
            gradients_norm = []
            for i_epoch_pi in range(args.update_epochs_actor):
                np.random.shuffle(inds)
                if args.kle_rollback:
                    target_agent.load_state_dict(inner_agent.state_dict())
                for start in range(0, args.batch_size, args.minibatch_size):

                    end = start + args.minibatch_size
                    minibatch_ind = inds[start:end]

                    if len(minibatch_ind) < 2:
                        continue

                    mb_advantages = b_advantages[minibatch_ind]
                    if args.norm_adv:
                        mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

                    if args.only_head:
                        assert False
                        with torch.no_grad():
                            features = blueprint.actor.features(b_obs[minibatch_ind])
                        logits = inner_agent.actor.head(features)
                        probs = Categorical(logits=logits)
                        newlogproba = probs.log_prob(b_actions.long()[minibatch_ind])
                        entropy = probs.entropy()
                    else:
                        _, newlogproba, entropy = inner_agent.get_action(b_obs[minibatch_ind],
                                                                         b_actions.long()[minibatch_ind])

                    ratio = (newlogproba - b_logprobs[minibatch_ind]).exp()

                    # Stats
                    approx_kl = (b_logprobs[minibatch_ind] - newlogproba).mean()

                    # Policy loss
                    pg_loss1 = -mb_advantages * ratio
                    pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
                    pg_loss = torch.max(pg_loss1, pg_loss2).mean()
                    entropy_loss = entropy.mean()

                    policy_loss = pg_loss - args.ent_coef * entropy_loss

                    t_back_0 = time.time()

                    opt_policy.zero_grad()
                    try:
                        policy_loss.backward()
                    except:
                        import pdb; pdb.set_trace()
                    total_norm = 0
                    if args.only_head:
                        for p in inner_agent.actor.head.parameters():
                            if p.grad is not None:
                                param_norm = p.grad.data.norm(2)
                                total_norm += param_norm.item() ** 2
                    else:
                        for p in inner_agent.actor.parameters():
                            if p.grad is not None:
                                param_norm = p.grad.data.norm(2)
                                total_norm += param_norm.item() ** 2
                    total_norm = total_norm ** (1. / 2)
                    gradients_norm.append(total_norm)
                    nn.utils.clip_grad_norm_(inner_agent.actor.parameters(), args.max_grad_norm)
                    opt_policy.step()

                    t_backward.append(time.time() - t_back_0)

                if args.kle_stop:
                    if approx_kl > args.target_kl:
                        break
                if args.kle_rollback:
                    if (b_logprobs[minibatch_ind] -
                        inner_agent.get_action(b_obs[minibatch_ind], b_actions.long()[minibatch_ind])[
                            1]).mean() > args.target_kl:
                        inner_agent.load_state_dict(target_agent.state_dict())
                        break

            if args.finetune_critic:
                for i_epoch_pi in range(args.update_epochs_critic):
                    np.random.shuffle(inds)
                    for start in range(0, args.batch_size, args.minibatch_size):
                        end = start + args.minibatch_size
                        minibatch_ind = inds[start:end]
                        # Value loss
                        new_values = inner_agent.get_value(b_obs[minibatch_ind]).view(-1)
                        if args.clip_vloss:
                            v_loss_unclipped = ((new_values - b_returns[minibatch_ind]) ** 2)
                            v_clipped = b_values[minibatch_ind] + torch.clamp(new_values - b_values[minibatch_ind],
                                                                              -args.clip_coef, args.clip_coef)
                            v_loss_clipped = (v_clipped - b_returns[minibatch_ind]) ** 2
                            v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                            v_loss = 0.5 * v_loss_max.mean()
                        else:
                            v_loss = 0.5 * ((new_values - b_returns[minibatch_ind]) ** 2).mean()

                        value_loss = v_loss
                        optimizer_value.zero_grad()
                        value_loss.backward()
                        nn.utils.clip_grad_norm_(inner_agent.critic.parameters(), args.max_grad_norm)
                        optimizer_value.step()
            t_update.append(time.time() - t_update_0)
            gradients_norms.append(np.mean(gradients_norm))

            t_test_0 = time.time()

            current_ep_rets = []

            for _ in range(args.nb_rollout):

                next_test_obs = init_obs

                test_envs.set_init_state(init_state)
                test_envs.set_init_obs(init_obs_numpy)
                test_envs.restore_full_state()

                current_step_until_dead = 0
                current_ep_ret = 0

                while True:
                    with torch.no_grad():
                        if args.finetune_horizon > 0:
                            if current_step_until_dead < args.finetune_horizon:
                                if args.only_head:
                                    features = blueprint.actor.features(next_test_obs)
                                    logits = inner_agent.actor.head(features)
                                    probs = Categorical(logits=logits)
                                    action_test = probs.probs.argmax(dim=-1, keepdim=False)
                                else:
                                    action_test, _, _ = inner_agent.get_action(next_test_obs, deterministic=True)
                            else:
                                action_test, _, _ = blueprint.get_action(next_test_obs, deterministic=True)
                        else:
                            if args.only_head:
                                features = blueprint.actor.features(next_test_obs)
                                logits = inner_agent.actor.head(features)
                                probs = Categorical(logits=logits)
                                action_test = probs.probs.argmax(dim=-1, keepdim=False)
                            else:
                                action_test, logproba, _ = inner_agent.get_action(next_test_obs)

                    next_test_obs, rs_test, ds_test, _ = test_envs.step(action_test)

                    current_step_until_dead += 1
                    current_ep_ret += rs_test.item()

                    if ds_test[0] or (args.nb_rollout == 1 and current_ep_ret > np.mean(blueprint_ep_ret_fts)) or (args.episode_length>0 and current_step_until_dead >= args.episode_length-meta_ep_step):
                        current_ep_rets.append(current_ep_ret)

                        break


            rets.append(np.mean(current_ep_ret))

            t_test.append(time.time() - t_test_0)

            if np.mean(current_ep_rets)>np.mean(blueprint_ep_ret_fts):
                finetune_done = True

            if finetune_done:
                print(f"Fine-tuning succeed after {update} updates")
                # print(f"Return of last update: {rets}")
                log_dic = {"Nb of finetune so far": nb_meta_up + 1, "Episode return from here": np.mean(current_ep_rets),
                           "Finetuned episode return so far": current_meta_ret + np.mean(current_ep_rets),
                           "Improvement over blueprint this fintune": np.mean(current_ep_rets) - np.mean(blueprint_ep_ret_fts),
                           "Steps survived": current_step_until_dead, "Number of updates": update,
                           "Blueprint testing time": t_blue, "Finetune Success": 1,
                           "Average testing time": np.mean(t_test),
                           "Total testing time": np.sum(t_test),
                           "Average backward time": np.mean(t_backward),
                           "Total backward time": np.sum(t_backward),
                           "Average update time": np.mean(t_update),
                           "Total update time": np.sum(t_update),
                           "Samples per step": samples,
                           "Average simulation time": np.mean(t_simulation),
                           "Total simulation time": np.sum(t_simulation),
                           "Real episode step": meta_ep_step, "Norm 1st gradient": gradients_norms[0],
                           "Average norm gradient": np.mean(gradients_norms)}
                if len(gradients_norms) > 1:
                    log_dic.update({"Norm 2nd gradient": gradients_norms[1]})

                return True, log_dic
        # print("finetune failed")
        # print(f"returns: {rets}")

        log_dic = {"Nb of finetune so far": nb_meta_up + 1, "Finetune Success": 0,
                   "Real episode step": meta_ep_step, "Norm 1st gradient": gradients_norms[0],
                   "Average norm gradient": np.mean(gradients_norms),
                    "Blueprint testing time": t_blue,
                    "Average testing time": np.mean(t_test),
                    "Total testing time": np.sum(t_test),
                    "Average backward time": np.mean(t_backward),
                    "Total backward time": np.sum(t_backward),
                    "Average update time": np.mean(t_update),
                    "Total update time": np.sum(t_update),
                    "Average simulation time": np.mean(t_simulation),
                    "Total simulation time": np.sum(t_simulation),
                   "Number of updates": update,
                   "Samples per step": samples}

        if len(gradients_norms) > 1:
            log_dic.update({"Norm 2nd gradient": gradients_norms[1]})
        return False, log_dic

    def run_episode(nb_update, str_render="", gifs=None):
        """Play one episode in ``envs`` with ``agent``, fine-tuning it online along the way.

        Whenever ``nb_update > 0`` and the step counter is a multiple of a
        strictly positive ``args.finetune_frequency``, a fresh
        ``AgentTwoNetworksHead`` candidate is loaded with the blueprint weights
        and handed to ``update_blue``. If that call reports success, the
        candidate's actor and critic weights are copied into ``agent``;
        otherwise, when ``args.blue_if_fail`` is set, ``agent`` is reset to the
        blueprint weights. The agent then takes its greedy action (argmax over
        the head logits, or ``get_action(..., deterministic=True)``) for one
        environment step. Per-attempt and end-of-episode statistics are sent
        to ``wandb``.

        Args:
            nb_update: Forwarded to ``update_blue`` as its second argument and
                logged as "Max online iteration"; fine-tuning is only
                attempted when it is positive.
            str_render: Caption passed to ``envs.render`` when
                ``args.render_main`` is set.
            gifs: List of frame lists; when ``args.render_main`` is set, each
                step's rendered frame is appended to ``gifs[-1]``. It is not
                touched otherwise, so ``None`` is acceptable in that case.

        Returns:
            None. The function returns once any environment reports done or,
            when ``args.episode_length`` is positive, once that many steps
            have been played.
        """
        numpy_obs = global_numpy_obs
        next_obs = global_init_obs
        episode_step = 0
        cum_ep_ret = 0
        log = False  # set for the single step that follows a fine-tuning attempt
        nb_meta_up = 0  # number of successful fine-tunes
        nb_try = 0  # number of attempted fine-tunes
        # Frame buffer forwarded to update_blue.
        # NOTE(review): nothing in this function ever turns it from None into a list.
        gift = None
        if args.one_time_update:
            index_update = 0
        times_search = []
        total_samples = 0
        # Per-attempt statistics copied out of the dictionary returned by update_blue.
        global_stat = {
            "Blueprint testing time": [],
            "Total testing time": [],
            "Total backward time": [],
            "Total update time": [],
            "Total simulation time": [],
            "Number of updates": [],
        }
        while True:
            # The frequency must be strictly positive: a value of 0 would make
            # the modulo below raise ZeroDivisionError. Non-positive values
            # therefore disable fine-tuning altogether.
            if nb_update > 0 and args.finetune_frequency > 0 and episode_step % args.finetune_frequency == 0:
                print(f"Trying to fine-tune at step {episode_step}")
                if not args.one_time_update or index_update == args.index_update:
                    nb_try += 1
                    t_search_0 = time.time()
                    candidate = AgentTwoNetworksHead(envs).to(device)

                    if args.reload_blue:
                        candidate_optimizer_policy = optim.Adam(candidate.actor.parameters(), lr=args.lr_policy,
                                                                eps=args.eps_adam)

                        candidate_optimizer_critic = optim.Adam(candidate.critic.parameters(), lr=args.lr_value,
                                                                eps=args.eps_adam)

                    assert gift is None or len(gift) == 0
                    if not args.reload_blue:
                        # This branch is blocked by the assertion below; the two
                        # loads after it are unreachable while assertions are enabled.
                        assert False
                        candidate.actor.load_state_dict(agent.actor.state_dict())
                        candidate.critic.load_state_dict(agent.critic.state_dict())

                    else:
                        candidate.actor.load_state_dict(blueprint.actor.state_dict())
                        candidate.critic.load_state_dict(blueprint.critic.state_dict())
                        if args.fine_optimizer:
                            candidate_optimizer_policy.load_state_dict(optimizer_policy.state_dict())
                            candidate_optimizer_critic.load_state_dict(optimizer_value.state_dict())

                    success, log_d = update_blue(args, nb_update, next_obs,
                                                 envs.clone_full_state(),
                                                 numpy_obs,
                                                 envs.render(mode='rgb_array').transpose(2, 0, 1), candidate,
                                                 candidate_optimizer_policy, candidate_optimizer_critic, episode_step,
                                                 nb_meta_up, cum_ep_ret, gift)
                    total_samples += log_d["Samples per step"]
                    for key in global_stat:
                        assert key in log_d
                        global_stat[key].append(log_d[key])
                    if success:
                        nb_meta_up += 1
                        # Adopt the fine-tuned weights for acting in the real environment.
                        agent.actor.load_state_dict(candidate.actor.state_dict())
                        agent.critic.load_state_dict(candidate.critic.state_dict())

                        if args.one_time_update:
                            index_update += 1
                    elif args.blue_if_fail:
                        # Failed attempt: fall back to the blueprint weights.
                        agent.actor.load_state_dict(blueprint.actor.state_dict())
                        agent.critic.load_state_dict(blueprint.critic.state_dict())

                    log = True
                    # Drop the references to the candidate and to both of its
                    # optimizers so they can be reclaimed before the next attempt.
                    del candidate
                    del candidate_optimizer_policy
                    del candidate_optimizer_critic
                    assert gift is None or len(gift) > 0
                    times_search.append(time.time() - t_search_0)

            if args.intrinsic and log:
                with torch.no_grad():
                    # Last entry along dim 1 of the observation, kept as a size-1 dim
                    # (presumably the most recent stacked frame — TODO confirm).
                    last_frame = next_obs[:, -1, :, :].unsqueeze(1)
                    target_next_feature = rnd.target(last_frame)
                    predict_next_feature = rnd.predictor(last_frame)
                    intrinsic_reward_rnd = (target_next_feature - predict_next_feature).pow(2).sum(1) / 2
                    intrinsic_reward_rnd = intrinsic_reward_rnd.item()
                    recon_x, mu, logsigma = vae(last_frame)
                    # reduction='sum' is the supported spelling of the deprecated
                    # size_average=False. Despite its name, BCE is a summed squared error.
                    BCE = F.mse_loss(recon_x, last_frame, reduction='sum')
                    KLD = -0.5 * torch.sum(1 + 2 * logsigma - mu.pow(2) - (2 * logsigma).exp())
                    loss_vae = (BCE + KLD).mean().item()
                    recon_loss = BCE.mean().item()
            if args.render_main:
                gifs[-1].append(envs.render(mode='rgb_array', caption=str_render).transpose(2, 0, 1))

            with torch.no_grad():
                value = agent.get_value(next_obs)
                if args.only_head:
                    # Blueprint features feed the agent's head; act greedily on the logits.
                    features = blueprint.actor.features(next_obs)
                    logits = agent.actor.head(features)
                    probs = Categorical(logits=logits)
                    action = probs.probs.argmax(dim=-1, keepdim=False)
                    logproba = probs.log_prob(action)
                    entropy = probs.entropy()
                else:
                    action, logproba, entropy = agent.get_action(next_obs, deterministic=True)

            next_next_obs, rs, ds, infos = envs.step(action)
            if args.intrinsic and log:
                with torch.no_grad():
                    # One-hot encoding of the chosen action(s) for the ICM model.
                    action_onehot = torch.FloatTensor(
                        len(action), envs.action_space.n).to(
                        device)
                    action_onehot.zero_()
                    action_onehot.scatter_(1, action.view(len(action), -1), 1)

                    real_next_state_feature, pred_next_state_feature, pred_action = icm(
                        [next_obs, next_next_obs, action_onehot])
                    intrinsic_reward_icm = F.mse_loss(real_next_state_feature, pred_next_state_feature,
                                                      reduction='none').mean(-1)
                    intrinsic_reward_icm = intrinsic_reward_icm.item()

            if log:
                # Complete the fine-tuning log with this step's statistics and emit it once.
                log_d.update({"Searching time": times_search[-1]})

                if args.intrinsic:
                    log_d.update({"loss vae": loss_vae, "reconstruction loss vae": recon_loss,
                                  "loss rnd": intrinsic_reward_rnd, "loss icm": intrinsic_reward_icm})
                log_d.update({"log prob": logproba.item(), "entropy": entropy.item(), "value": value.item()})
                wandb.log(log_d)
                log = False

            next_obs = next_next_obs
            if gift is not None and len(gift) > 0:
                gift.append(envs.render(mode='rgb_array',
                                        caption="real env").transpose(
                    2, 0, 1))
                wandb.log(
                    {f"Meta Step {episode_step}": wandb.Video(np.stack(gift, axis=0), fps=4, format="gif")})
                gift = []
            numpy_obs = infos[0]["obs"]
            episode_step += 1

            cum_ep_ret += rs.item()
            # The episode ends as soon as any entry of ds is truthy.
            done = any(ds)
            if done or (args.episode_length > 0 and episode_step >= args.episode_length):
                print(f"Episodic return of fine-tuned policy: {cum_ep_ret}")
                done_log = {"Horizon finetune": args.finetune_horizon, "Frequency finetune": args.finetune_frequency,
                            "Total steps survived": episode_step, "Total cumulative reward": cum_ep_ret,
                            "Average search time per step": np.mean(times_search),
                            "Average search time (10 steps)": np.mean(times_search),
                            "Total search Time": np.sum(times_search),
                            "Max online iteration": nb_update, "Total number of successful finetune": nb_meta_up,
                            "Total number of finetune tries": nb_try,
                            "Average time per step": np.sum(times_search) / episode_step,
                            "Average number of samples per step": int(total_samples / episode_step)}

                done_log.update({f"Average {key} per step": np.mean(global_stat[key]) for key in global_stat})
                done_log.update({f"Meta {key}": np.sum(global_stat[key]) for key in global_stat})
                wandb.log(done_log)

                return

    if args.render_main:
        gifs = []

    if args.render_main:
        gifs.append([])
    agent.actor.load_state_dict(blueprint.actor.state_dict())
    agent.critic.load_state_dict(blueprint.critic.state_dict())

    run_episode(args.max_update, gifs=gifs if args.render_main else None)

    if args.wandb:
        if args.render_main:
            stacked_gif = [np.hstack([gif[t] for gif in gifs]) for t in range(min([len(g) for g in gifs]))]
            wandb.log({"render": wandb.Video(np.stack(stacked_gif, axis=0), fps=4, format="gif")})

    envs.close()
    return 1

# Script entry point: run the experiment with the module-level `args`
# (presumably the namespace parsed from the argparse parser defined at the top
# of the file — its assignment is not visible in this chunk).
main(args)