import glob
import numpy as np
import os

import torch
from torch.optim import Adam

from common.buffers import RolloutBuffer
from common.logger import Video
from common.utils import get_device, to_torch, to_np, preprocess
from models.policies import EntropyGaussianPolicy
from models.values import ValueNetwork


class PPO:
    def __init__(self, config, envs, eval_envs, logger):
        """Store the collaborators, reset the training envs and build the models.

        Args:
            config: Hyperparameter container, read by attribute throughout
                the class (e.g. ``rollout_length``, ``gamma``).
            envs: Training environments; reset here so the first rollout has
                an observation to act on.
            eval_envs: Environments used by ``evaluate``.
            logger: Metric sink providing ``record``, ``dump`` and ``dir``.
        """
        self.c, self.logger = config, logger
        self.envs, self.eval_envs = envs, eval_envs
        self.device = get_device()

        # Progress counters; both are overwritten when a checkpoint is loaded.
        self.step = self.epoch = 0
        # Latest observation from the training envs, consumed by the next rollout.
        self.obs = envs.reset()

        self.build_models()

    def build_models(self):
        """Create the rollout buffer, the policy and the value function.

        Both networks are moved to ``self.device`` and each one gets its own
        Adam optimizer with its own learning rate taken from the config.
        """
        cfg = self.c
        obs_shape = self.envs.observation_space.shape
        act_shape = self.envs.action_space.shape

        # Rollout buffer: pixel observations are stored as uint8, everything
        # else as float32.
        obs_dtype = np.uint8 if cfg.pixel_obs else np.float32
        self.rollout_buffer = RolloutBuffer(
            cfg.rollout_length,
            self.envs.num_envs,
            obs_shape,
            act_shape,
            obs_type=obs_dtype,
        )

        # Policy and its optimizer
        policy = EntropyGaussianPolicy(obs_shape, act_shape, cfg.hidden_size)
        self.policy = policy.to(self.device)
        self.policy_optim = Adam(self.policy.parameters(), lr=cfg.policy_lr)

        # Value function and its optimizer
        value_function = ValueNetwork(obs_shape, cfg.hidden_size)
        self.value_function = value_function.to(self.device)
        self.value_function_optim = Adam(
            self.value_function.parameters(), lr=cfg.value_lr
        )

    def collect_rollouts(self, num_steps):
        """Fill the rollout buffer with fresh experience from the current policy.

        Args:
            num_steps: Total number of transitions to gather, summed over all
                parallel environments. Every environment is stepped
                ``num_steps // num_envs`` times.

        Raises:
            ValueError: If ``num_steps`` is smaller than the number of
                environments, i.e. not even one vectorized step would be taken.
        """
        num_envs = self.envs.num_envs
        steps_per_env = num_steps // num_envs
        if steps_per_env < 1:
            # An empty rollout never advances ``self.step``, so the training
            # loop could not make progress; fail loudly instead.
            raise ValueError(
                f"num_steps ({num_steps}) must be at least the number of "
                f"environments ({num_envs})"
            )

        self.policy.eval()
        self.value_function.eval()

        # Clear rollout buffer
        self.rollout_buffer.reset()

        for _ in range(steps_per_env):
            # Select action
            with torch.no_grad():
                obs_tensor = to_torch(preprocess(self.obs))
                actions, log_probs, entropies = map(to_np, self.policy(obs_tensor))
                values = to_np(self.value_function(obs_tensor))

            # Take environment step
            next_obs, rewards, dones, infos = self.envs.step(actions)

            # Handle termination and truncation
            for i, done in enumerate(dones):
                if done:
                    # Record episode statistics
                    self.logger.record("train/return", infos[i]["episode_return"])
                    self.logger.record("train/success", infos[i]["episode_success"])
                    # Handle truncation by bootstrapping from the value
                    # function: the episode was cut by a time limit, so the
                    # discounted value of the terminal observation is folded
                    # into the final reward.
                    if infos[i].get("TimeLimit.truncated", False):
                        term_obs = to_torch(preprocess(infos[i]["terminal_obs"][None]))
                        with torch.no_grad():
                            term_value = to_np(self.value_function(term_obs)[0])
                        rewards[i] += self.c.gamma * term_value

            # Add transition to buffer
            self.rollout_buffer.push(
                self.obs, actions, rewards, dones, values, log_probs, entropies
            )
            self.obs = next_obs
            self.step += num_envs

        # Compute returns and advantages. ``self.obs`` is the observation that
        # follows the last stored transition, so it is used for the bootstrap
        # value rather than a variable leaked from the loop above.
        with torch.no_grad():
            last_obs_tensor = to_torch(preprocess(self.obs))
            last_values = to_np(self.value_function(last_obs_tensor))
        self.rollout_buffer.compute_returns_and_advantages(
            last_values, self.c.gamma, self.c.gae_lambda
        )

    def compute_policy_loss(self, obs, actions, old_ll, advantages):
        """Compute the clipped-surrogate policy loss with an entropy bonus.

        Args:
            obs: Batch of observations passed to ``policy.evaluate``.
            actions: Batch of actions passed to ``policy.evaluate``.
            old_ll: Log-likelihoods of ``actions`` stored with the rollout.
            advantages: Advantage estimates for the batch.

        Returns:
            Tuple ``(policy_loss, approx_kl, clip_frac, entropy)``. The last
            three are diagnostics computed without gradient tracking.
        """
        eps = self.c.clip_range
        new_ll, entropies = self.policy.evaluate(obs, actions)

        # Likelihood ratio between the current policy and the stored one
        log_ratio = new_ll - old_ll
        ratio = log_ratio.exp()

        # Element-wise minimum of the raw and the clipped surrogate objectives
        unclipped = ratio * advantages
        clipped = ratio.clamp(1 - eps, 1 + eps) * advantages
        objective = torch.min(unclipped, clipped)
        # Entropy regularization
        objective += self.c.ent_coef * entropies
        policy_loss = -objective.mean()

        # Debugging statistics
        with torch.no_grad():
            approx_kl = (ratio - 1 - log_ratio).mean()
            clip_frac = ((ratio - 1).abs() > eps).float().mean()
            entropy = entropies.mean()
        return policy_loss, approx_kl, clip_frac, entropy

    def compute_value_loss(self, obs, returns):
        """Compute the mean-squared-error loss of the value function.

        Args:
            obs: Batch of observations fed to the value function.
            returns: Regression targets for the value predictions.

        Returns:
            Tuple ``(value_loss, explained_var)``; ``explained_var`` is a
            diagnostic computed without gradient tracking.
        """
        predictions = self.value_function(obs)
        residuals = predictions - returns
        value_loss = residuals.pow(2).mean()
        # Debugging statistics
        with torch.no_grad():
            unexplained = (returns - predictions).var() / returns.var()
            explained_var = 1 - unexplained
        return value_loss, explained_var

    def update_parameters(self):
        """Run ``train_epochs`` passes of minibatch updates over the rollouts.

        For every minibatch the policy and value losses are summed, a single
        backward pass is taken, both optimizers are stepped, and the losses
        and diagnostics are recorded under the ``train/`` prefix.
        """
        self.policy.train()
        self.value_function.train()

        buffer = self.rollout_buffer
        if not buffer.ready:
            buffer.prepare_rollouts()

        for _ in range(self.c.train_epochs):
            for batch in buffer.iterate(self.c.batch_size):
                # Only observations go through ``preprocess``; the remaining
                # fields are converted to tensors as they are.
                obs = to_torch(preprocess(batch[0]))
                actions, old_ll, advantages, returns = (
                    to_torch(field) for field in batch[1:]
                )

                self.policy_optim.zero_grad()
                self.value_function_optim.zero_grad()

                policy_loss, approx_kl, clip_frac, entropy = self.compute_policy_loss(
                    obs, actions, old_ll, advantages
                )
                value_loss, explained_var = self.compute_value_loss(obs, returns)
                (policy_loss + value_loss).backward()

                self.policy_optim.step()
                self.value_function_optim.step()

                stats = (
                    ("train/policy_loss", policy_loss),
                    ("train/value_loss", value_loss),
                    ("train/approx_kl", approx_kl),
                    ("train/clip_frac", clip_frac),
                    ("train/entropy", entropy),
                    ("train/explained_var", explained_var),
                )
                for key, value in stats:
                    self.logger.record(key, value.item())

    def train(self):
        """Run the main training loop until ``num_steps`` env steps are reached.

        Resumes from the latest checkpoint if there is one. Each epoch
        collects a rollout, updates the networks, runs the periodic
        evaluation / checkpoint hooks and flushes the logger.
        """
        self.load_checkpoint()

        cfg = self.c
        # Hooks fired on epochs that are a multiple of their frequency
        periodic = (
            (cfg.eval_freq, self.evaluate),
            (cfg.checkpoint_freq, self.save_checkpoint),
        )
        while self.step < cfg.num_steps:
            self.collect_rollouts(cfg.rollout_length)
            self.update_parameters()

            for freq, hook in periodic:
                if self.epoch % freq == 0:
                    hook()

            self.logger.record("train/step", self.step)
            self.logger.dump(step=self.step)
            self.epoch += 1

    def evaluate(self):
        """Roll out the current policy in the evaluation envs and log the results.

        Runs ``num_eval_episodes // eval_envs.num_envs`` rounds (at least one).
        Each round steps all evaluation envs until every one reports done,
        then the mean episode return, the mean success flag and a video of
        all rounds are recorded under the ``test/`` prefix.

        Does nothing when ``num_eval_episodes`` is not positive.
        """
        if self.c.num_eval_episodes <= 0:
            return
        # Size the loop by the evaluation envs (the training envs may have a
        # different count) and always run at least one round, so the
        # statistics and the video below are never built from empty lists.
        num_rounds = max(1, self.c.num_eval_episodes // self.eval_envs.num_envs)

        self.policy.eval()
        self.value_function.eval()

        ep_rewards = []
        ep_successes = []
        ep_videos = []
        for _ in range(num_rounds):
            obs = self.eval_envs.reset()
            dones, infos, frames = False, [], []
            # Assume all eval_envs terminate at the same time
            while not np.all(dones):
                frames.append(self.eval_envs.render())
                with torch.no_grad():
                    obs_tensor = to_torch(preprocess(obs))
                    actions = to_np(self.policy(obs_tensor)[0])
                obs, _, dones, infos = self.eval_envs.step(actions)
            ep_rewards.extend(info["episode_return"] for info in infos)
            ep_successes.extend(info["episode_success"] for info in infos)
            # Stack the frames along a new axis 1, then move the last axis in
            # front of the two before it (presumably channels-last to
            # channels-first; confirm against what ``Video`` expects).
            ep_videos.append(np.stack(frames, 1).transpose(0, 1, 4, 2, 3))

        # Record episode statistics
        self.logger.record("test/return", np.mean(ep_rewards))
        self.logger.record("test/success", np.mean(ep_successes))

        # Record videos
        video = Video(np.concatenate(ep_videos, 0), fps=30)
        self.logger.record("test/video", video, exclude="stdout")

    def save_checkpoint(self):
        """Write the training state to ``models_<epoch>.pt`` in the logger dir.

        The checkpoint holds the step and epoch counters plus the state dicts
        of both networks and both optimizers.
        """
        ckpt = {"step": self.step, "epoch": self.epoch}
        # Networks and optimizers are stored through their state dicts
        for name in (
            "policy",
            "policy_optim",
            "value_function",
            "value_function_optim",
        ):
            ckpt[name] = getattr(self, name).state_dict()

        filename = f"models_{self.epoch}.pt"
        torch.save(ckpt, os.path.join(self.logger.dir, filename))

    def load_checkpoint(self):
        """Restore training state from the newest checkpoint, if there is one.

        Looks for ``models_<epoch>.pt`` files in the logger directory and
        loads the one with the highest numeric epoch. Files whose epoch part
        is not a plain number are ignored. When no usable checkpoint exists
        the method returns without touching any state.
        """
        prefix, suffix = "models_", ".pt"
        # Escape the directory so glob metacharacters in it are taken literally
        pattern = os.path.join(glob.escape(self.logger.dir), f"{prefix}*{suffix}")

        latest_epoch, latest_path = -1, None
        for path in glob.glob(pattern):
            # Parse the epoch from the base name so the result does not
            # depend on the platform's path separator.
            tag = os.path.basename(path)[len(prefix) : -len(suffix)]
            if tag.isdigit() and int(tag) > latest_epoch:
                latest_epoch, latest_path = int(tag), path

        if latest_path is None:
            return

        # Map tensors onto the current device so a checkpoint written on
        # another device (e.g. a GPU) can still be restored here.
        ckpt = torch.load(latest_path, map_location=self.device)
        print(f"Loaded checkpoint from {latest_path}")

        self.step = ckpt["step"]
        self.epoch = ckpt["epoch"]
        self.policy.load_state_dict(ckpt["policy"])
        self.policy_optim.load_state_dict(ckpt["policy_optim"])
        self.value_function.load_state_dict(ckpt["value_function"])
        self.value_function_optim.load_state_dict(ckpt["value_function_optim"])
