import argparse
import glob
import numpy as np
import os
import torch

from torch.distributions import Normal
from torch.distributions.kl import kl_divergence

from algorithms.dreamer import Dreamer, MultitaskDreamer
from algorithms.dreamer.models.utils import bottle
from common.utils import AttrDict, arg_type, set_gpu_mode, to_torch, preprocess
from common.logger import configure_logger
from environments import make_env


def get_config():
    """Build the run configuration from built-in defaults and CLI overrides.

    Every default below is exposed as a ``--<name>`` command-line flag whose
    argument type is derived from the default value by ``arg_type``.

    Returns:
        The ``argparse.Namespace`` produced by parsing ``sys.argv``.
    """
    defaults = {
        # Experiment selection and evaluation settings
        "algo": "dreamer",
        "env_id": "franka-push-red",
        "expr_name": "default",
        "seed": 0,
        "use_gpu": True,
        "pixel_obs": True,
        "ckpt_id": -1,
        "task_id": -1,
        "buffer_path": "data",
        "truncate_size": 1000000,
        # Model / training hyperparameters (the parsed config is handed to
        # the agent constructor by the caller)
        "replay_size": 100,
        "embedding_size": 1024,
        "hidden_size": 200,
        "belief_size": 200,
        "state_size": 30,
        "dense_activation_function": "elu",
        "cnn_activation_function": "relu",
        "batch_size": 50,
        "chunk_size": 50,
        "free_nats": 3,
        "gamma": 0.99,
        "gae_lambda": 0.95,
        "horizon": 15,
        "model_lr": 1e-3,
        "actor_lr": 8e-5,
        "value_lr": 8e-5,
    }
    config = AttrDict()
    for name, default in defaults.items():
        setattr(config, name, default)

    # One flag per default, in declaration order.
    parser = argparse.ArgumentParser()
    for name, default in config.items():
        parser.add_argument(f"--{name}", type=arg_type(default), default=default)
    return parser.parse_args()


def load_state_dict(agent, ckpt):
    """Restore the agent's sub-models from a checkpoint mapping.

    Args:
        agent: Object exposing ``encoder``, ``transition_model``,
            ``obs_model``, ``reward_model``, ``actor_model`` and
            ``value_model`` attributes, each with a ``load_state_dict`` method.
        ckpt: Mapping from those same attribute names to their saved state.

    Raises:
        KeyError: If ``ckpt`` lacks an entry for one of the sub-models;
            sub-models earlier in the list have already been loaded by then.
    """
    component_names = (
        "encoder",
        "transition_model",
        "obs_model",
        "reward_model",
        "actor_model",
        "value_model",
    )
    for component in component_names:
        getattr(agent, component).load_state_dict(ckpt[component])


@torch.no_grad()
def compute_dynamics_loss(agent, obs, actions, rewards, nonterms, tasks=None):
    """Compute world-model losses for one batch without tracking gradients.

    The batch is run through ``agent.transition_model.observe`` starting from
    all-zero belief and state, then three scalar losses are evaluated from the
    resulting beliefs and posterior states.

    Args:
        agent: Object exposing ``encoder``, ``transition_model``,
            ``obs_model``, ``reward_model``, ``device``, ``free_nats`` and a
            config ``c`` with ``belief_size``, ``state_size`` and ``pixel_obs``.
        obs: Observation tensor; dimension 1 is used as the batch size
            (presumably dimension 0 is time -- confirm against the buffer).
        actions: Action tensor; all but the last entry along dim 0 are used.
        rewards: Reward tensor; all but the last entry along dim 0 are used,
            with the trailing dimension squeezed.
        nonterms: Non-terminal mask; all but the last entry along dim 0 are used.
        tasks: Optional task tensor. When given, ``tasks[1:]`` is concatenated
            to the beliefs on the last dimension before the reward model.

    Returns:
        Tuple ``(obs_loss, reward_loss, kl_loss)`` of Python floats.
    """
    num_sequences = obs.shape[1]
    device = agent.device
    belief0 = torch.zeros(num_sequences, agent.c.belief_size).to(device)
    state0 = torch.zeros(num_sequences, agent.c.state_size).to(device)

    # Observations are shifted by one step relative to the actions/masks.
    next_obs = obs[1:]
    rollout = agent.transition_model.observe(
        belief0,
        state0,
        actions[:-1],
        bottle(agent.encoder, (next_obs,)),
        nonterms[:-1],
    )
    # The sampled prior states (second entry) are not needed for these losses.
    (
        beliefs,
        _prior_states,
        prior_means,
        prior_stds,
        post_states,
        post_means,
        post_stds,
    ) = rollout

    # Reconstruction loss: unit-variance Gaussian negative log-likelihood,
    # summed over the observation dimensions, averaged over dims 0 and 1.
    event_dims = (2, 3, 4) if agent.c.pixel_obs else 2
    obs_dist = Normal(bottle(agent.obs_model, (beliefs, post_states)), 1)
    obs_log_prob = obs_dist.log_prob(next_obs)
    obs_loss = -obs_log_prob.sum(event_dims).mean((0, 1))

    # Reward loss: the reward model optionally sees the task vector appended
    # to the belief.
    if tasks is None:
        reward_features = beliefs
    else:
        reward_features = torch.cat((beliefs, tasks[1:]), -1)
    reward_dist = Normal(bottle(agent.reward_model, (reward_features, post_states)), 1)
    reward_log_prob = reward_dist.log_prob(rewards[:-1].squeeze(-1))
    reward_loss = -reward_log_prob.mean((0, 1))

    # KL loss: posterior vs. prior, summed over dim 2 and bounded below by
    # the agent's free-nats value.
    posterior = Normal(post_means, post_stds)
    prior = Normal(prior_means, prior_stds)
    kl_div = kl_divergence(posterior, prior).sum(2)
    kl_loss = torch.max(kl_div, agent.free_nats).mean((0, 1))

    return obs_loss.item(), reward_loss.item(), kl_loss.item()


def latest_checkpoint_episode(ckpt_paths):
    """Return the largest episode number among ``models_<episode>.pt`` paths.

    Only the file name of each path (via ``os.path.basename``) is inspected,
    so no particular path separator is assumed. Paths whose file name does not
    match ``models_<digits>.pt`` are ignored.

    Args:
        ckpt_paths: Iterable of checkpoint file paths.

    Returns:
        The highest episode number as an ``int``, or ``None`` when no path
        matches the pattern.
    """
    prefix, suffix = "models_", ".pt"
    episodes = []
    for path in ckpt_paths:
        filename = os.path.basename(path)
        if not (filename.startswith(prefix) and filename.endswith(suffix)):
            continue
        episode = filename[len(prefix) : -len(suffix)]
        if episode.isdigit():
            episodes.append(int(episode))
    return max(episodes, default=None)


def main():
    """Evaluate a trained world model on an offline replay buffer.

    Steps: parse the config, build the logger/environment/agent, load a
    checkpoint from the logger directory, install the offline buffer into
    ``agent.buffer``, then print the average reconstruction, reward and KL
    losses over all batches yielded by ``agent.buffer.iterate``.

    Raises:
        ValueError: If ``--algo`` is not a supported algorithm, or the buffer
            holds no transitions after truncation.
        FileNotFoundError: If ``--ckpt_id`` is -1 and the logger directory
            contains no ``models_<episode>.pt`` file.
    """
    config = get_config()
    # Fail early with a clear message instead of a later NameError on `agent`.
    if config.algo not in ("dreamer", "dreamer_multitask"):
        raise ValueError(
            f"Unsupported algo {config.algo!r}; "
            "expected 'dreamer' or 'dreamer_multitask'"
        )
    multitask = config.algo == "dreamer_multitask"
    set_gpu_mode(config.use_gpu)

    # Configure logger
    logdir = os.path.join(
        "logdir", config.algo, config.env_id, config.expr_name, str(config.seed)
    )
    logger = configure_logger(logdir, ["stdout", "tensorboard", "wandb"])

    # Environment
    env = make_env(config.env_id, config.seed, config.pixel_obs)

    # Agent
    agent_cls = MultitaskDreamer if multitask else Dreamer
    agent = agent_cls(config, env, logger)

    # Load checkpoint (ckpt_id == -1 selects the highest-numbered one)
    ckpt_id = config.ckpt_id
    if ckpt_id == -1:
        ckpt_id = latest_checkpoint_episode(
            glob.glob(os.path.join(logger.dir, "models_*.pt"))
        )
        if ckpt_id is None:
            raise FileNotFoundError(
                f"No models_<episode>.pt checkpoint found in {logger.dir}"
            )
    ckpt_path = os.path.join(logger.dir, f"models_{ckpt_id}.pt")
    # map_location lets a checkpoint saved on another device (e.g. GPU) be
    # loaded when evaluating on the agent's current device.
    ckpt = torch.load(ckpt_path, map_location=agent.device)
    load_state_dict(agent, ckpt)
    print(f"Loaded checkpoint from {ckpt_path}")

    # Load offline buffer; arrays are fully read before the archive is closed.
    with np.load(config.buffer_path) as buffer:
        data_keys = ["observations", "actions", "rewards", "dones"]
        data = {k: buffer[k] for k in data_keys}
        pos, full = buffer["pos"], buffer["full"]
    if full:
        # Unroll data so that it starts at the write position
        data = {k: np.concatenate((v[pos:], v[:pos])) for k, v in data.items()}
    else:
        # Remove empty space
        data = {k: v[:pos] for k, v in data.items()}
    # Truncate buffer
    size = min(len(data["observations"]), config.truncate_size)
    if size <= 0:
        raise ValueError(
            f"Offline buffer {config.buffer_path} has no transitions to "
            "evaluate (check the buffer contents and --truncate_size)"
        )
    data = {k: v[:size] for k, v in data.items()}
    # Terminate at the end of buffer
    data["dones"][-1, :] = 1
    if multitask:
        # Create one-hot task vectors
        # NOTE(review): with the default task_id of -1 this marks the last
        # task column -- confirm that is intended.
        tasks = np.zeros((size, env.num_tasks), dtype=np.float32)
        tasks[:, config.task_id] = 1
        data["tasks"] = tasks
    for k, v in data.items():
        setattr(agent.buffer, k, v)
    agent.buffer.capacity = len(agent.buffer.observations)
    agent.buffer.pos = 0
    agent.buffer.full = True
    print(f"Loaded buffer from {config.buffer_path}")

    # Evaluate model error
    obs_losses, reward_losses, kl_losses = [], [], []
    for batch in agent.buffer.iterate(config.batch_size, config.chunk_size):
        if multitask:
            # The multitask buffer yields the task vectors first.
            tasks, obs, actions, rewards, dones = batch
            tasks = to_torch(tasks)
        else:
            obs, actions, rewards, dones = batch
            tasks = None
        obs = to_torch(preprocess(obs, config.pixel_obs))
        actions = to_torch(actions)
        rewards = to_torch(rewards)
        nonterms = to_torch(1 - dones)
        obs_loss, reward_loss, kl_loss = compute_dynamics_loss(
            agent, obs, actions, rewards, nonterms, tasks
        )
        obs_losses.append(obs_loss)
        reward_losses.append(reward_loss)
        kl_losses.append(kl_loss)
    print(f"Average reconstruction loss: {np.mean(obs_losses)}")
    print(f"Average reward loss: {np.mean(reward_losses)}")
    print(f"Average kl loss: {np.mean(kl_losses)}")


if __name__ == "__main__":
    main()
