import argparse
import datetime
import sys
import time
from functools import partial

import gymnasium as gym
import torch
import yaml
from torch.utils.tensorboard import SummaryWriter

from dcrl.algos import A2CAlgo, DualAlgo, PPOAlgo
from dcrl.models import (
    ObsActorModel,
    ObsCriticModel,
    ObsCriticObsStateCriticModel,
    ObsStateCriticModel,
    StateCriticModel,
)
from dcrl.models.obs_state_actor_model import ObsStateActorModel
from dcrl.utils.common_func import data_synthesize, get_device, set_seed
from dcrl.utils.env_utils import make_minigrid_env, make_miniworld_env
from dcrl.utils.eval_utils import Evaluation
from dcrl.utils.model_utils import (
    get_representation_net_for_minigrid,
    get_representation_net_for_miniworld,
)
from dcrl.utils.storage_utils import (
    get_csv_logger,
    get_log_dir,
    get_model_dir,
    get_txt_logger,
    save_model_status,
)


def parse_args(argv=None):
    """Parse the training script's command-line options.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, so existing ``parse_args()`` callers are unaffected.

    Returns:
        An ``argparse.Namespace`` with one attribute per option defined below.

    Unrecognized options are tolerated, not rejected (``parse_known_args``). They are
    reported with a ``UserWarning`` so typos such as ``--lr-rate`` are not silently ignored.
    """
    # fmt: off
    parser = argparse.ArgumentParser(description="train")
    parser.add_argument("--method", type=str, default='DCRL',
                        choices=['DCRL', 'Recurrent Actor-Critic', 'Asymmetric Actor-Critic',
                                 'Oracle Guiding', 'Unbiased Asymmetric Actor-Critic'],
                        help="method (default: DCRL)")
    parser.add_argument("--dropout-rate", type=float, default=1/2,
                        help="fraction of --total-frames over which the Oracle Guiding state-dropout "
                             "probability anneals from 0 to 1 (default: 1/2), only used when method is Oracle Guiding")
    parser.add_argument("--algo", type=str, choices=['ppo', 'a2c'], default='a2c',
                        help="algorithm to use: a2c | ppo (default: a2c)")
    parser.add_argument("--env-name", type=str, default="MiniGrid-LavaCrossingS9N2-v0",
                        help="name of the environment to train on, only support MiniGrid and MiniWorld")
    parser.add_argument("--seed", type=int, default=0,
                        help="random seed (default: 0)")
    parser.add_argument("--num-envs", type=int, default=16,
                        help="number of parallel envs (default: 16)")
    parser.add_argument("--num-frames", type=int, default=5,
                        help="number of frames to run for each process per update (default: 5; 512 is suggested for PPO)")
    parser.add_argument("--total-frames", type=int, default=10 ** 7,
                        help="number of frames of training (default: 1e7)")
    parser.add_argument("--view-size", type=int, default=None,
                        help="minigrid agent view size (default: None), None means do not change original parameters (7)")
    parser.add_argument("--eval", action="store_true", default=False,
                        help="whether to evaluate the agent")
    parser.add_argument("--eval-episodes", type=int, default=100,
                        help="number of episodes of evaluation (default: 100)")
    parser.add_argument("--eval-interval", type=int, default=100,
                        help="number of updates between two evaluations, must be a multiple of --log-interval (default: 100)")
    parser.add_argument("--eval-num-envs", type=int, default=8,
                        help="number of parallel envs for evaluation (default: 8)")
    parser.add_argument("--run-name", type=str, default=None,
                        help="name of the experiment (default: {ENV}_{ALGO}_{METHOD}_{SEED}_{TIME})")
    parser.add_argument("--log-interval", type=int, default=1,
                        help="number of updates between two logs (default: 1)")
    parser.add_argument("--save-interval", type=int, default=0,
                        help="number of updates between two saves (default: 0, 0 means no saving)")
    parser.add_argument("--capture-video", action="store_true", default=False,
                        help="whether to capture videos of the agent performances")
    parser.add_argument("--lr", type=float, default=1e-3,
                        help="learning rate (default: 1e-3; 3e-4 is suggested for PPO)")
    parser.add_argument("--optim-eps", type=float, default=1e-5,
                        help="Adam and RMSprop optimizer epsilon (default: 1e-5)")
    parser.add_argument("--optim-alpha", type=float, default=0.99,
                        help="RMSprop optimizer alpha (default: 0.99)")
    # ``--recurrent`` defaults to True, so a separate switch is needed to turn it off.
    parser.add_argument("--recurrent", action="store_true", default=True,
                        help="add a LSTM to the model (default: enabled; use --no-recurrent to disable)")
    parser.add_argument("--no-recurrent", dest="recurrent", action="store_false",
                        help="do not add a LSTM to the model")
    parser.add_argument("--embedding-size", type=int, default=64,
                        help="The embedding size of the representation net (default: 64)")
    parser.add_argument('--hidden_size_list', nargs='+', type=int, default=[128],
                        help='The fc hidden sizes in actor / critic model (default: 128)')
    parser.add_argument('--rnn_hidden_size', type=int, default=64,
                        help='The hidden size of LSTM in actor / critic model (default: 64)')
    parser.add_argument("--gamma", type=float, default=0.99,
                        help="discount factor (default: 0.99)")
    parser.add_argument("--gae-lambda", type=float, default=1,
                        help="lambda coefficient in GAE formula, 1 means no gae (default: 1; 0.95 is suggested for PPO)")
    parser.add_argument("--entropy-coef", type=float, default=0.01,
                        help="entropy term coefficient (default: 0.01)")
    parser.add_argument("--value-loss-coef", type=float, default=0.5,
                        help="value loss term coefficient (default: 0.5)")
    parser.add_argument("--max-grad-norm", type=float, default=0.5,
                        help="maximum norm of gradient (default: 0.5)")
    parser.add_argument("--num-mini-batches", type=int, default=16,
                        help="number of mini batch for PPO (default: 16)")
    parser.add_argument("--epochs", type=int, default=4,
                        help="number of epochs for PPO (default: 4)")
    parser.add_argument("--clip-eps", type=float, default=0.2,
                        help="clipping epsilon for PPO (default: 0.2)")
    parser.add_argument("--dual-num-update", type=int, default=16,
                        help="number for dual update (default: 16)")
    parser.add_argument("--dual-epochs", type=int, default=4,
                        help="number of epochs for dual (default: 4)")
    parser.add_argument('--dual-num-mini-batches', type=int, default=2,
                        help='number of mini batch for dual (default: 2; 16 is suggested for PPO)')
    parser.add_argument("--dual-coef", type=float, default=1,
                        help="coefficient for dual (default: 1)")
    parser.add_argument("--track", action="store_true", default=False,
                        help="if toggled, this experiment will be tracked with Weights and Biases")
    parser.add_argument("--wandb-project-name", type=str, default=None,
                        help="the wandb's project name")
    parser.add_argument("--wandb-group-name", type=str, default=None,
                        help="the wandb's group name")
    parser.add_argument("--wandb-entity", type=str, default=None,
                        help="the wandb's entity (default: None, i.e. wandb's default entity)")
    # fmt: on

    args, unknown = parser.parse_known_args(argv)
    if unknown:
        import warnings

        warnings.warn(f"Ignoring unrecognized arguments: {' '.join(unknown)}", stacklevel=2)

    return args


def _make_envs(args):
    """Build the vectorized training environments selected by ``args.env_name``.

    Returns:
        ``(envs, get_representation_net_func)``. ``envs`` is a ``gym.vector.SyncVectorEnv``
        of ``args.num_envs`` sub-environments. ``get_representation_net_func`` is the matching
        representation-net factory with ``embedding_size`` already bound.

    Raises:
        ValueError: if the env name contains neither "MiniGrid" nor "MiniWorld".
    """
    if "MiniGrid" in args.env_name:
        env_fns = [
            make_minigrid_env(
                args.env_name,
                args.seed + 1000 * i,
                i,
                capture_video=args.capture_video,
                log_dir=args.full_log_dir,
                view_size=args.view_size,
            )
            for i in range(args.num_envs)
        ]
        representation_net_factory = get_representation_net_for_minigrid
    elif "MiniWorld" in args.env_name:
        env_fns = [
            make_miniworld_env(
                args.env_name,
                args.seed + 10000 * i,
                i,
                capture_video=args.capture_video,
                log_dir=args.full_log_dir,
            )
            for i in range(args.num_envs)
        ]
        representation_net_factory = get_representation_net_for_miniworld
    else:
        raise ValueError(f"Unrecognized type of env name {args.env_name}")
    envs = gym.vector.SyncVectorEnv(env_fns)
    return envs, partial(representation_net_factory, embedding_size=args.embedding_size)


def _build_models(args, envs, get_representation_net_func):
    """Instantiate the actor and critic models for ``args.method``.

    Side effect: sets ``args.use_dual`` (True only for "DCRL").

    Returns:
        ``(actor_model, critic_model, reshape_adv_fn)``. ``reshape_adv_fn`` picks one channel
        of the last axis of the advantage tensor: channel 1 for DCRL, channel 0 otherwise.

    Raises:
        ValueError: if ``args.method`` is not a supported method.
    """
    # Critic class per method. The project classes are referenced here, not at module
    # level, so importing this module does not require them to be resolvable.
    critic_classes = {
        "Recurrent Actor-Critic": ObsCriticModel,
        "Asymmetric Actor-Critic": StateCriticModel,
        "Unbiased Asymmetric Actor-Critic": ObsStateCriticModel,
        "Oracle Guiding": ObsStateCriticModel,
        "DCRL": ObsCriticObsStateCriticModel,
    }
    if args.method not in critic_classes:
        raise ValueError(f"Unrecognized type of method {args.method}")

    args.use_dual = args.method == "DCRL"
    adv_index = 1 if args.use_dual else 0

    def reshape_adv_fn(adv):
        return adv[..., adv_index]

    shared_kwargs = dict(
        hidden_size_list=args.hidden_size_list,
        rnn_hidden_size=args.rnn_hidden_size,
        get_representation_net_func=get_representation_net_func,
    )
    obs_space = envs.single_observation_space
    # Always build the actor before the critic. If the models draw random initial
    # weights, this keeps seeded runs reproducible.
    if args.method == "Oracle Guiding":
        actor_model = ObsStateActorModel(
            obs_space, envs.single_action_space, args.recurrent, has_value_head=False, **shared_kwargs
        )
    else:
        actor_model = ObsActorModel(obs_space, envs.single_action_space, args.recurrent, **shared_kwargs)
    critic_model = critic_classes[args.method](obs_space, args.recurrent, **shared_kwargs)
    return actor_model, critic_model, reshape_adv_fn


def _oracle_drop_prob(num_frames, total_frames, dropout_rate):
    """Return the Oracle Guiding state-dropout probability after ``num_frames`` frames.

    The probability grows linearly from 0 to 1 over the first ``dropout_rate`` fraction of
    ``total_frames``, then stays at 1.0. A non-positive horizon (e.g. ``dropout_rate == 0``)
    means the state is fully dropped from the start, instead of dividing by zero.
    """
    horizon = total_frames * dropout_rate
    if horizon <= 0:
        return 1.0
    return min(1.0, num_frames / horizon)


def _drop_state(obs, drop_prob):
    """Zero a random subset of ``obs["state"]`` entries and return ``obs``.

    Each element is kept with probability ``1 - drop_prob``. ``obs`` is modified in place,
    and the same dict is returned. The mask is built on ``obs["state"]``'s own device.
    """
    keep_mask = torch.bernoulli(torch.ones_like(obs["state"]) * (1 - drop_prob))
    obs["state"] = obs["state"] * keep_mask
    return obs


def _save_checkpoint(args, txt_logger, actor_model, critic_model, optimizer, num_frames, update):
    """Save model/optimizer state via ``save_model_status`` and log the target directory.

    When ``args.track`` is set, the saved ``agent.pt`` is also uploaded to wandb.
    """
    model_dir = save_model_status(
        {
            "num_frames": num_frames,
            "update": update,
            "actor_model_state": actor_model.state_dict(),
            "critic_model_state": critic_model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
        },
        args.full_log_dir,
        num_frames,
    )
    if args.track:
        import wandb

        wandb.save(f"{model_dir}/agent.pt", base_path=f"{model_dir}", policy="now")
    txt_logger.info(f"Status saved to {model_dir}\n")


def main():
    """Train an agent with the method and algorithm selected on the command line.

    Sets up text/CSV/TensorBoard (and optionally wandb) logging, then builds the environments,
    the actor/critic models and the A2C/PPO algorithm (plus the dual algorithm for DCRL).
    Resumes from ``{model_dir}/agent.pt`` when that file can be loaded. Then it repeatedly
    collects experience, updates parameters, evaluates, logs and checkpoints until
    ``--total-frames`` frames have been collected.

    Raises:
        ValueError: on inconsistent eval/log intervals or an unknown env name, method or algo.
    """
    args = parse_args()

    # Validate before creating any directories or environments. Eval stats are only
    # reported on logging updates, so every evaluation must fall on one of them.
    if args.eval and args.eval_interval % args.log_interval != 0:
        raise ValueError(
            f"--eval-interval ({args.eval_interval}) must be a multiple of "
            f"--log-interval ({args.log_interval})"
        )

    # Set run dir
    date = datetime.datetime.now().strftime("%y-%m-%d-%H-%M-%S")
    default_run_name = f"{args.env_name}_{args.algo}_{args.method}_seed{args.seed}_{date}"
    run_name = args.run_name or default_run_name
    args.full_log_dir = get_log_dir(run_name)
    args.model_dir = get_model_dir(args.full_log_dir)

    # Load loggers and Tensorboard writer
    txt_logger = get_txt_logger(args.full_log_dir)
    csv_file, csv_logger = get_csv_logger(args.full_log_dir)
    run = None
    if args.track:
        import wandb

        wandb.tensorboard.patch(root_logdir=args.full_log_dir)
        run = wandb.init(
            name=run_name,
            # None lets wandb fall back to its default entity.
            entity=getattr(args, "wandb_entity", None),
            project=args.wandb_project_name,
            group=args.wandb_group_name,
            monitor_gym=True,
            save_code=True,
            config=vars(args),
        )
    tb_writer = SummaryWriter(args.full_log_dir)

    # Log command and all script arguments
    txt_logger.info("{}\n".format(" ".join(sys.argv)))
    txt_logger.info(f"Save logs at {args.full_log_dir} \n")
    txt_logger.info(f"{args}\n")
    tb_writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )
    # Save config
    with open(file=f"{args.full_log_dir}/config.yaml", mode="w", encoding="utf-8") as f:
        yaml.dump(vars(args), f)

    device = get_device()
    txt_logger.info(f"Device: {device}\n")

    # Seed every randomness source before environments and models are created.
    set_seed(args.seed)

    envs, get_representation_net_func = _make_envs(args)
    txt_logger.info("Environments loaded\n")

    actor_model, critic_model, reshape_adv_fn = _build_models(args, envs, get_representation_net_func)
    reshape_reward_fn = None
    oracle_guiding = args.method == "Oracle Guiding"

    # Log model's parameters
    txt_logger.info("*" * 5 + "Actor Model's state_dict" + "*" * 5)
    txt_logger.info(actor_model)
    txt_logger.info("\n" + "*" * 5 + "Critic Model's state_dict" + "*" * 5)
    txt_logger.info(critic_model)
    total_params = sum(param.nelement() for model in (actor_model, critic_model) for param in model.parameters())
    txt_logger.info("\n" + "*" * 5 + f"Model's total parameters: {total_params}" + "*" * 5 + "\n")

    # Load algo
    if args.algo == "a2c":
        algo = A2CAlgo(
            envs,
            actor_model,
            critic_model,
            device,
            args.num_frames,
            args.gamma,
            args.lr,
            args.gae_lambda,
            args.entropy_coef,
            args.value_loss_coef,
            args.max_grad_norm,
            reshape_reward_fn,
            reshape_adv_fn,
            args.optim_alpha,
            args.optim_eps,
        )
    elif args.algo == "ppo":
        algo = PPOAlgo(
            envs,
            actor_model,
            critic_model,
            device,
            args.num_frames,
            args.gamma,
            args.lr,
            args.gae_lambda,
            args.entropy_coef,
            args.value_loss_coef,
            args.max_grad_norm,
            reshape_reward_fn,
            reshape_adv_fn,
            args.optim_eps,
            args.clip_eps,
            args.epochs,
            args.num_mini_batches,
        )
    else:
        raise ValueError(f"Incorrect algorithm name: {args.algo}")

    if args.use_dual:
        # Presumably sized for the experiences stored over ``dual_num_update`` updates,
        # since the dual update only runs every ``dual_num_update`` updates (see loop).
        dual_num_frames = int(args.num_frames * args.dual_num_update)
        dual_algo = DualAlgo(
            actor_model,
            critic_model,
            algo.optimizer,
            algo.parameters,
            device,
            args.gamma,
            args.max_grad_norm,
            args.num_envs,
            dual_num_frames,
            args.dual_epochs,
            args.dual_coef,
            args.dual_num_mini_batches,
            max_nlogp=100 if args.algo == "ppo" else 5,
        )

    if args.eval:
        evaluation = Evaluation(
            args.env_name,
            args.seed,
            args.eval_num_envs,
            args.view_size,
            actor_model,
            device,
            args.eval_episodes,
        )

    # Load training status; a missing/unreadable checkpoint means training from scratch.
    try:
        status = torch.load(f"{args.model_dir}/agent.pt", map_location=device)
    except OSError:
        status = {"num_frames": 0, "update": 0}
    else:
        actor_model.load_state_dict(status["actor_model_state"])
        critic_model.load_state_dict(status["critic_model_state"])
        algo.optimizer.load_state_dict(status["optimizer_state"])
        txt_logger.info(f"Training status loaded: num_frames={status['num_frames']}, update={status['update']}\n")

    # Train model
    num_frames = status["num_frames"]
    update = status["update"]
    start_time = time.time()
    # Header row most recently written to the CSV file (None until the first log).
    last_csv_header = None

    txt_logger.info("Begin training ...\n")
    while num_frames < args.total_frames:
        # Update model parameters
        update_start_time = time.time()
        if oracle_guiding:
            drop_prob = _oracle_drop_prob(num_frames, args.total_frames, args.dropout_rate)
            cur_drop_func = partial(_drop_state, drop_prob=drop_prob)
        else:
            drop_prob = None
            cur_drop_func = None
        exps, logs1, dual_exps = algo.collect_experiences(use_dual=args.use_dual, drop_func=cur_drop_func)
        logs2 = algo.update_parameters(exps)
        update_end_time = time.time()
        num_frames += logs1["num_frames"]
        update += 1

        # Store this update's dual experiences; every ``dual_num_update`` updates,
        # collect the stored batch and run a dual update on it.
        logs3 = None
        if args.use_dual:
            dual_algo.store_experiences(dual_exps)
            if update % args.dual_num_update == 0:
                batch_dual_exps = dual_algo.collect_experiences()
                logs3 = dual_algo.update_parameters(batch_dual_exps)

        # For Oracle Guiding, evaluation zeroes the whole ``state`` entry (drop_prob=1.0).
        log4 = None
        if args.eval and update % args.eval_interval == 0:
            eval_drop_func = partial(_drop_state, drop_prob=1.0) if oracle_guiding else None
            log4 = evaluation.eval(drop_func=eval_drop_func)

        # Print logs
        if update % args.log_interval == 0:
            fps = logs1["num_frames"] / (update_end_time - update_start_time)
            duration = int(time.time() - start_time)
            return_per_episode = data_synthesize(logs1["return_per_episode"])
            reshaped_return_per_episode = data_synthesize(logs1["reshaped_return_per_episode"])
            num_frames_per_episode = data_synthesize(logs1["num_frames_per_episode"])
            avg_return = data_synthesize(logs1["avg_return"])

            header = ["info/update", "info/frames", "info/FPS", "info/duration"]
            data = [update, num_frames, fps, duration]

            performance_chart_name = "train_performance" if args.eval else "performance"
            header += [f"{performance_chart_name}/avg_return_" + key for key in avg_return.keys()]
            data += avg_return.values()
            header += [f"{performance_chart_name}/num_frames_" + key for key in num_frames_per_episode.keys()]
            data += num_frames_per_episode.values()

            header += [
                f"{args.algo}/entropy",
                f"{args.algo}/value",
                f"{args.algo}/policy_loss",
                f"{args.algo}/value_loss",
                f"{args.algo}/grad_norm",
            ]
            data += [
                logs2["entropy"],
                logs2["value"],
                logs2["policy_loss"],
                logs2["value_loss"],
                logs2["grad_norm"],
            ]
            if drop_prob is not None:
                header += ["info/drop_prob"]
                data += [drop_prob]

            # The text line is formatted from the columns collected so far, so it must be
            # built before the return/dual/eval columns below are appended.
            log_str = (
                "Update {:6} | Frames {:7} | FPS {:4.0f} | Duration {} "
                "| Rew:μσmM {:1.2f} {:1.2f} {:1.2f} {:1.2f} | Num:μσmM {:2.1f} {:2.1f} {:2.1f} {:2.1f} "
                "| H {:1.3f} | V {:1.3f} | pL {:1.3f} | vL {:1.3f} | ∇ {:1.3f} "
                + ("" if drop_prob is None else "| DropProb {:1.2f}")
            )
            txt_logger.info(log_str.format(*data))

            header += [f"{performance_chart_name}/reshaped_return_" + key for key in reshaped_return_per_episode.keys()]
            data += reshaped_return_per_episode.values()
            header += [f"{performance_chart_name}/return_" + key for key in return_per_episode.keys()]
            data += return_per_episode.values()

            if logs3 is not None:
                dual_header = [
                    "dual/loss",
                    "dual/policy_loss",
                    "dual/value_loss",
                    "dual/entropy_loss",
                    "dual/num_valid",
                    "dual/grad_norm",
                ]
                dual_data = [
                    logs3["loss"],
                    logs3["policy_loss"],
                    logs3["value_loss"],
                    logs3["entropy_loss"],
                    logs3["num_valid"],
                    logs3["grad_norm"],
                ]
                dual_log_str = "[dual] L {:1.3f} | pL {:1.3f} | vL {:1.3f} | eL {:1.3f} | numV {:3} | ∇ {:1.3f}"
                txt_logger.info(dual_log_str.format(*dual_data))
                header += dual_header
                data += dual_data

            if log4 is not None:
                eval_header = ["eval/FPS", "eval/duration"]
                eval_data = [log4["FPS"], log4["duration"]]

                eval_return_per_episode = data_synthesize(log4["return_per_episode"])
                eval_header += ["performance/avg_return_" + key for key in eval_return_per_episode.keys()]
                eval_data += eval_return_per_episode.values()
                eval_num_frames_per_episode = data_synthesize(log4["num_frames_per_episode"])
                eval_header += ["performance/num_frames_" + key for key in eval_num_frames_per_episode.keys()]
                eval_data += eval_num_frames_per_episode.values()

                eval_log_str = (
                    "[Eval] FPS {:4.0f} | Duration {} "
                    "| Rew:μσmM {:1.2f} {:1.2f} {:1.2f} {:1.2f} "
                    "| Num:μσmM {:2.1f} {:2.1f} {:2.1f} {:2.1f}"
                )
                txt_logger.info(eval_log_str.format(*eval_data))
                header += eval_header
                data += eval_data

            # Columns differ between logging updates because dual/eval stats appear only on some.
            # Write a header row whenever the column layout changes, not before every row.
            if header != last_csv_header:
                csv_logger.writerow(header)
                last_csv_header = header
            csv_logger.writerow(data)
            csv_file.flush()

            for field, value in zip(header, data):
                tb_writer.add_scalar(field, value, num_frames)

        # Save status
        if args.save_interval > 0 and update % args.save_interval == 0:
            _save_checkpoint(args, txt_logger, actor_model, critic_model, algo.optimizer, num_frames, update)

    # Save final status unless the last update already triggered a periodic save.
    if args.save_interval > 0 and update % args.save_interval != 0:
        _save_checkpoint(args, txt_logger, actor_model, critic_model, algo.optimizer, num_frames, update)

    # Release resources explicitly so buffered TensorBoard events and CSV rows are flushed.
    envs.close()
    tb_writer.close()
    csv_file.close()
    if run is not None:
        run.finish()


if __name__ == "__main__":
    main()
