import copy
import glob
import numpy as np
import os

import torch
import torch.nn.functional as F
from torch.optim import Adam

from algorithms.sac import SAC
from common.buffers import MultitaskReplayBuffer
from common.utils import (
    soft_update,
    to_torch,
    to_np,
    preprocess,
    FreezeParameters,
)
from models.cnns import MultiheadedCNN
from models.policies import TanhGaussianPolicy
from models.values import QNetwork


class MultitaskSAC(SAC):
    """Soft Actor-Critic agent conditioned on a one-hot task vector.

    The task one-hot is concatenated to the (optionally CNN-encoded)
    observation before it is fed to the policy and the critic, and one
    entropy temperature (``log_alpha`` entry) is kept per task.
    """

    def __init__(self, config, env, eval_env, logger):
        """Build the agent and remember the evaluation environment.

        Args:
            config: Forwarded unchanged to ``SAC.__init__``.
            env: Training environment, forwarded to ``SAC.__init__``.
            eval_env: Environment used by :meth:`evaluate`.
            logger: Forwarded unchanged to ``SAC.__init__``.
        """
        super().__init__(config, env, logger)
        self.eval_env = eval_env

    def build_models(self):
        """Create the replay buffer, networks, optimizers and temperatures."""
        obs_shape = self.env.observation_space.shape
        act_shape = self.env.action_space.shape
        num_tasks = self.env.num_tasks

        # Replay buffer
        self.buffer = MultitaskReplayBuffer(
            self.c.replay_size,
            num_tasks,
            obs_shape,
            act_shape,
            obs_type=np.uint8 if self.c.pixel_obs else np.float32,
        )

        # Encoder
        if self.c.pixel_obs:
            # Moved to the training device like the policy and the critic,
            # otherwise its weights and its inputs can live on different devices.
            self.encoder = MultiheadedCNN(
                obs_shape[2], num_tasks, self.c.repr_size
            ).to(self.device)
            self.encoder_optim = Adam(self.encoder.parameters(), lr=self.c.lr)
            obs_shape = (self.c.repr_size + num_tasks,)
        else:
            obs_shape = (obs_shape[0] + num_tasks,)

        # Policy
        self.policy = TanhGaussianPolicy(
            obs_shape,
            act_shape,
            self.c.hidden_size,
            self.env.action_space,
        ).to(self.device)
        self.policy_optim = Adam(self.policy.parameters(), lr=self.c.lr)

        # Critic
        self.critic = QNetwork(
            obs_shape, act_shape, self.c.hidden_size
        ).to(self.device)
        self.critic_optim = Adam(self.critic.parameters(), lr=self.c.lr)
        self.critic_target = copy.deepcopy(self.critic)

        # Temperature (one entry per task)
        if self.c.automatic_entropy_tuning:
            if self.c.target_entropy == "auto":
                # Target entropy is -dim(A) as given in the paper
                self.target_entropy = -torch.prod(torch.tensor(act_shape)).item()
            else:
                self.target_entropy = float(self.c.target_entropy)
            # Create the tensor directly on the device: calling ``.to(device)``
            # on a tensor that requires grad returns a non-leaf copy when the
            # device differs, and optimizers refuse non-leaf tensors.
            self.log_alpha = torch.zeros(
                (num_tasks,), requires_grad=True, device=self.device
            )
            self.alpha_optim = Adam([self.log_alpha], lr=self.c.lr)
        else:
            self.log_alpha = torch.full(
                (num_tasks,), float(self.c.alpha), device=self.device
            ).log()

    def get_log_alpha(self, task_one_hot):
        """Select the per-sample log temperature for a batch of one-hot tasks.

        Args:
            task_one_hot: 2-D tensor whose second dimension has one entry per
                task (it is matrix-multiplied with ``log_alpha`` as a column).

        Returns:
            Tensor with one row per sample and a single column.
        """
        return torch.mm(task_one_hot, self.log_alpha.unsqueeze(1))

    def select_action(self, obs, task_one_hot, evaluate=False):
        """Compute an action for a single observation.

        Args:
            obs: One unbatched observation; a batch axis is added here.
            task_one_hot: Unbatched one-hot task vector.
            evaluate: Passed to the policy as ``deterministic``.

        Returns:
            The action for this observation, converted with ``to_np``.
        """
        obs = to_torch(preprocess(obs[None]))
        task_one_hot = to_torch(task_one_hot[None])
        if self.c.pixel_obs:
            obs = self.encoder(obs, torch.argmax(task_one_hot, 1))
        obs = torch.cat((obs, task_one_hot), 1)
        action, _ = self.policy(obs, deterministic=evaluate)
        return to_np(action)[0]

    def update_parameters(self, task, obs, act, rew, next_obs, done, updates):
        """Run one gradient step on critic, policy, encoder and temperature.

        Args:
            task: Batch of one-hot task vectors.
            obs: Batch of observations.
            act: Batch of actions.
            rew: Batch of rewards (presumably shaped like the critic outputs;
                TODO confirm against the replay buffer).
            next_obs: Batch of next observations.
            done: Batch of termination flags (same shape assumption as ``rew``).
            updates: Running update counter; the target critic is soft-updated
                when it is a multiple of ``target_update_freq``.

        Returns:
            Tuple ``(q1_loss, q2_loss, policy_loss, alpha_loss)`` of Python
            scalars.
        """
        obs = to_torch(preprocess(obs))
        next_obs = to_torch(preprocess(next_obs))
        task, act, rew, done = map(to_torch, [task, act, rew, done])
        # Get task-specific log alpha
        log_alpha = self.get_log_alpha(task)
        alpha = log_alpha.exp().detach()

        if self.c.pixel_obs:
            task_idx = torch.argmax(task, 1)
            obs = self.encoder(obs, task_idx)
        obs = torch.cat((obs, task), 1)

        # Compute Q target. The next observation is only consumed inside this
        # no-grad block, so it is encoded here to avoid building an unused
        # autograd graph through the encoder.
        with torch.no_grad():
            if self.c.pixel_obs:
                next_obs = self.encoder(next_obs, task_idx)
            next_obs = torch.cat((next_obs, task), 1)
            next_act, next_logp = self.policy(next_obs)
            next_q1_target, next_q2_target = self.critic_target(next_obs, next_act)
            min_next_q_target = torch.min(next_q1_target, next_q2_target)
            q_target = rew + (1 - done) * self.c.gamma * (
                min_next_q_target - alpha * next_logp
            )

        # Compute Q loss
        q1, q2 = self.critic(obs, act)
        q1_loss = F.mse_loss(q1, q_target)
        q2_loss = F.mse_loss(q2, q_target)
        q_loss = q1_loss + q2_loss

        # Compute policy loss; the critic parameters are frozen so that the
        # policy loss does not contribute gradients to them.
        new_act, new_logp = self.policy(obs)
        with FreezeParameters(list(self.critic.parameters())):
            new_q1, new_q2 = self.critic(obs, new_act)
        min_new_q = torch.min(new_q1, new_q2)
        policy_loss = ((alpha * new_logp) - min_new_q).mean()

        self.policy_optim.zero_grad()
        self.critic_optim.zero_grad()
        if self.c.pixel_obs:
            self.encoder_optim.zero_grad()
        total_loss = q_loss + policy_loss
        total_loss.backward()
        self.policy_optim.step()
        self.critic_optim.step()
        if self.c.pixel_obs:
            self.encoder_optim.step()

        # Update alpha with dual descent
        if self.c.automatic_entropy_tuning:
            alpha_loss = -(log_alpha * (new_logp + self.target_entropy).detach()).mean()
            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()
        else:
            alpha_loss = torch.tensor(0).to(self.device)

        # Update target critic
        if updates % self.c.target_update_freq == 0:
            soft_update(self.critic_target, self.critic, self.c.tau)

        return (
            q1_loss.item(),
            q2_loss.item(),
            policy_loss.item(),
            alpha_loss.item(),
        )

    def train(self):
        """Interact with the environment and train until ``num_steps`` is reached.

        Each outer iteration is one episode. After every episode the agent is
        evaluated / checkpointed when the episode index is a multiple of the
        configured frequencies, and per-task return and success are logged.
        """
        while self.step < self.c.num_steps:
            obs = self.env.reset()
            task = self.env.task
            task_one_hot = self.env.task_one_hot
            done = False
            episode_reward = 0
            episode_success = 0
            while not done:
                # Train agent
                if len(self.buffer) > self.c.batch_size:
                    # Number of updates per environment step
                    for _ in range(self.c.updates_per_step):
                        # Update parameters of all the networks
                        batch = self.buffer.sample(self.c.batch_size)
                        (
                            critic_1_loss,
                            critic_2_loss,
                            policy_loss,
                            entropy_loss,
                        ) = self.update_parameters(*batch, self.updates)
                        self.logger.record("train/critic_1_loss", critic_1_loss)
                        self.logger.record("train/critic_2_loss", critic_2_loss)
                        self.logger.record("train/policy_loss", policy_loss)
                        self.logger.record("train/entropy_loss", entropy_loss)
                        self.logger.record("train/alpha", self.log_alpha.exp().mean().item())
                        self.updates += 1

                # Take environment step; random actions until start_step
                if self.step < self.c.start_step:
                    action = self.env.action_space.sample()
                else:
                    with torch.no_grad():
                        action = self.select_action(obs, task_one_hot)
                next_obs, reward, done, info = self.env.step(action)
                episode_reward += reward
                episode_success += info.get("success", 0)
                self.step += 1

                # Ignore done if it comes from truncation
                real_done = 0 if info.get("TimeLimit.truncated", False) else float(done)
                self.buffer.push(task_one_hot, obs, action, reward, next_obs, real_done)
                obs = next_obs

            if self.episode % self.c.eval_freq == 0:
                self.evaluate()

            if self.episode % self.c.checkpoint_freq == 0:
                self.save_checkpoint()

            self.logger.record(f"train_multitask/return_{task}", episode_reward)
            self.logger.record(f"train_multitask/success_{task}", float(episode_success > 0))
            self.logger.record("train/step", self.step)
            self.logger.dump(step=self.step)
            self.episode += 1

    def evaluate(self):
        """Run one deterministic episode per task on the evaluation environment.

        Tasks are drawn with ``sample_task(round_robin=True)``; the return and
        a 0/1 success flag are recorded for each task.
        """
        for _ in range(self.eval_env.num_tasks):
            task = self.eval_env.sample_task(round_robin=True)
            obs = self.eval_env.reset(task=task)
            task_one_hot = self.eval_env.task_one_hot
            done = False
            episode_reward = 0
            episode_success = 0
            while not done:
                with torch.no_grad():
                    action = self.select_action(obs, task_one_hot, evaluate=True)
                next_obs, reward, done, info = self.eval_env.step(action)
                episode_reward += reward
                episode_success += info.get("success", 0)
                obs = next_obs
            self.logger.record(f"test_multitask/return_{task}", episode_reward)
            self.logger.record(f"test_multitask/success_{task}", float(episode_success > 0))

    def save_checkpoint(self):
        """Write counters, networks, optimizers and temperatures to disk.

        The file is ``models_<episode>.pt`` inside ``self.logger.dir``.
        """
        ckpt_path = os.path.join(self.logger.dir, f"models_{self.episode}.pt")
        ckpt = {
            "step": self.step,
            "episode": self.episode,
            "updates": self.updates,
            "policy": self.policy.state_dict(),
            "policy_optim": self.policy_optim.state_dict(),
            "critic": self.critic.state_dict(),
            "critic_optim": self.critic_optim.state_dict(),
            # The target critic lags the critic, so it has to be stored too
            # for a resumed run to continue from the same state.
            "critic_target": self.critic_target.state_dict(),
            "log_alpha": self.log_alpha,
        }
        if self.c.automatic_entropy_tuning:
            ckpt["alpha_optim"] = self.alpha_optim.state_dict()
        if self.c.pixel_obs:
            ckpt["encoder"] = self.encoder.state_dict()
            ckpt["encoder_optim"] = self.encoder_optim.state_dict()
        torch.save(ckpt, ckpt_path)

    def load_checkpoint(self):
        """Restore the agent from the highest-numbered checkpoint, if any.

        Looks for ``models_<episode>.pt`` files in ``self.logger.dir`` and
        loads the one with the largest numeric episode. Does nothing when no
        such file exists.
        """
        prefix, suffix = "models_", ".pt"
        candidates = []
        for path in glob.glob(os.path.join(self.logger.dir, f"{prefix}*{suffix}")):
            # Parse the episode from the file name only, independent of the
            # platform's path separator.
            episode = os.path.basename(path)[len(prefix):-len(suffix)]
            if episode.isdigit():
                candidates.append((int(episode), path))
        if not candidates:
            return

        _, ckpt_path = max(candidates)
        ckpt = torch.load(ckpt_path, map_location=self.device)
        print(f"Loaded checkpoint from {ckpt_path}")

        self.step = ckpt["step"]
        self.episode = ckpt["episode"]
        self.updates = ckpt["updates"]
        self.policy.load_state_dict(ckpt["policy"])
        self.policy_optim.load_state_dict(ckpt["policy_optim"])
        self.critic.load_state_dict(ckpt["critic"])
        self.critic_optim.load_state_dict(ckpt["critic_optim"])
        if "critic_target" in ckpt:
            self.critic_target.load_state_dict(ckpt["critic_target"])
        else:
            # Checkpoints written before the target was saved: start the
            # target from the restored critic rather than from fresh weights.
            self.critic_target.load_state_dict(self.critic.state_dict())

        log_alpha = ckpt["log_alpha"].detach().to(self.device)
        if self.c.automatic_entropy_tuning:
            # Rebuild as a leaf tensor so the optimizer can own it.
            self.log_alpha = log_alpha.clone().requires_grad_(True)
            self.alpha_optim = Adam([self.log_alpha], lr=self.c.lr)
            self.alpha_optim.load_state_dict(ckpt["alpha_optim"])
        else:
            self.log_alpha = log_alpha
        if self.c.pixel_obs:
            self.encoder.load_state_dict(ckpt["encoder"])
            self.encoder_optim.load_state_dict(ckpt["encoder_optim"])
            



        