# reference: https://github.com/deepmind/constrained_optidice.git
from copy import deepcopy
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
from fsrl.utils import DummyLogger, WandbLogger
from torch import distributions as pyd
from torch.distributions.beta import Beta
from torch.nn import functional as F  # noqa
from tqdm.auto import trange  # noqa

from osrl.common.net import SquashedGaussianMLPActor
from osrl.common.net import (NonnegativeEnsembleDoubleQCritic, EnsembleDoubleQCritic)

class WSAC(nn.Module):
    """
    WSAC model for offline safe (cost-constrained) reinforcement learning.

    Holds a ``SquashedGaussianMLPActor``, an ``EnsembleDoubleQCritic`` for
    rewards and a ``NonnegativeEnsembleDoubleQCritic`` for costs, plus a
    target ("old") copy of each network. Each critic is trained with a TD
    loss plus a term comparing its value at policy actions with its value at
    dataset actions. That term pushes reward values at policy actions down
    and cost values at policy actions up. The actor maximizes the reward
    value minus ``lambda__`` times the positive part of the estimated cost
    excess over ``cost_limit``.

    The argument layout was adapted from a COptiDICE implementation (see the
    reference at the top of the file). Several arguments are stored but not
    used by the methods of this class, as noted below.

    ``setup_optimizers`` must be called before any of the ``*_loss`` methods.

    Args:
        state_dim (int): dimension of the state space.
        action_dim (int): dimension of the action space.
        max_action (float): maximum action value (stored only).
        a_hidden_sizes (list): hidden layer sizes of the actor network.
        c_hidden_sizes (list): hidden layer sizes of the reward and cost
            critic networks.
        gamma (float): discount factor used in both critics' TD backups.
        alpha (float): stored only; not used by this class.
        num_q (int): ensemble size of the reward critic.
        num_qc (int): ensemble size of the cost critic.
        qc_ub (float): offset added to the cost value before comparing it
            with ``cost_limit`` in the actor loss and the ``lambda__`` update.
        num_nu (int): stored only; not used by this class.
        num_chi (int): stored only; not used by this class.
        cost_limit (int): cost threshold used in the actor penalty.
        episode_len (int): maximum episode length (used as the rollout
            horizon by the trainer).
        tau (float): currently ignored; see the note in ``__init__``.
        beta_r (float): weight passed to ``normalized_sum`` for the reward
            critic loss.
        beta_c (float): weight passed to ``normalized_sum`` for the cost
            critic loss.
        lambda_ (float): initial value of the penalty coefficient ``lambda__``.
        lambda_max (float): stored only. ``actor_loss`` clips ``lambda__``
            to [1.0, 20] regardless of this value.
        act_times (int): stored as ``act_times`` (and the legacy alias
            ``act_imes``).
        device (str): device to run the model on (e.g. 'cpu' or 'cuda:0').
    """

    def __init__(self,
                 state_dim: int,
                 action_dim: int,
                 max_action: float,
                 a_hidden_sizes: list = [128, 128],
                 c_hidden_sizes: list = [128, 128],
                 gamma: float = 0.99,
                 alpha: float = 0.5,
                 num_q: int = 1,
                 num_qc: int = 1,
                 qc_ub: float = 30,
                 num_nu: int = 1,
                 num_chi: int = 1,
                 cost_limit: int = 10,
                 episode_len: int = 300,
                 tau: float = 0.005,
                 beta_r: float = 1,
                 beta_c: float = 1,
                 lambda_: float = 20,
                 lambda_max: float = 1000,
                 act_times: int = 10,
                 device: str = "cpu"):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_action = max_action
        # Copy the size lists so the shared default lists (and caller-owned
        # lists) are never aliased by the model.
        self.a_hidden_sizes = list(a_hidden_sizes)
        self.c_hidden_sizes = list(c_hidden_sizes)
        self.gamma = gamma
        self.alpha = alpha
        self.qc_ub = qc_ub
        self.num_nu = num_nu
        self.num_chi = num_chi
        self.cost_limit = cost_limit
        self.episode_len = episode_len
        self.device = device
        self.beta_r = beta_r
        self.beta_c = beta_c
        self.num_q = num_q
        self.num_qc = num_qc
        self.lambda_ = lambda_
        # Current penalty coefficient; updated at the end of every actor_loss call.
        self.lambda__ = self.lambda_
        self.lambda_max = lambda_max
        self.act_times = act_times
        # Legacy (misspelled) attribute name, kept for backward compatibility.
        self.act_imes = act_times

        # NOTE(review): the ``tau`` argument above is not used. ``self.tau`` is
        # a 1-element tensor of ones (registered with ``tau_optim`` in
        # setup_optimizers, but that optimizer is never stepped in this class),
        # so ``sync_weight`` hard-copies the online networks into the targets.
        # Confirm whether the Polyak rate ``tau`` was meant to be used instead.
        self.tau = torch.ones(1, requires_grad=True, device=self.device)
        self.actor = SquashedGaussianMLPActor(self.state_dim, self.action_dim,
                                              self.a_hidden_sizes,
                                              nn.ReLU).to(self.device)

        self.critic = EnsembleDoubleQCritic(self.state_dim,
                                            self.action_dim,
                                            self.c_hidden_sizes,
                                            nn.ReLU,
                                            num_q=self.num_q).to(self.device)
        self.cost_critic = NonnegativeEnsembleDoubleQCritic(self.state_dim,
                                                            self.action_dim,
                                                            self.c_hidden_sizes,
                                                            nn.ReLU,
                                                            num_q=self.num_qc).to(self.device)

        # Target copies, only ever changed through sync_weight.
        self.actor_old = deepcopy(self.actor)
        self.actor_old.eval()
        self.critic_old = deepcopy(self.critic)
        self.critic_old.eval()
        self.cost_critic_old = deepcopy(self.cost_critic)
        self.cost_critic_old.eval()

    def normalized_sum(self, loss_abs, loss, ret, w):
        """
        Combine ``loss`` and ``ret`` with weight ``w``.

        Returns ``loss / w + ret`` when ``w > 1`` and ``loss + w * ret``
        otherwise. ``loss_abs`` is accepted for interface compatibility but is
        not used.
        """
        return loss / w + ret if w > 1.0 else loss + w * ret

    def _soft_update(self, tgt: nn.Module, src: nn.Module, tau: "float | torch.Tensor") -> None:
        """
        Move the parameters of ``tgt`` towards those of ``src``.

        Each target parameter becomes ``tau * src + (1 - tau) * tgt``.
        ``tau`` may be a float or a 1-element tensor. The update runs under
        ``torch.no_grad()`` so that a ``tau`` tensor with ``requires_grad=True``
        does not make autograd record the copy.
        """
        with torch.no_grad():
            for tgt_param, src_param in zip(tgt.parameters(), src.parameters()):
                tgt_param.data.copy_(tau * src_param.data + (1 - tau) * tgt_param.data)

    def _set_critics_requires_grad(self, flag: bool) -> None:
        """Toggle ``requires_grad`` on every parameter of both online critics."""
        for critic in (self.critic, self.cost_critic):
            for p in critic.parameters():
                p.requires_grad = flag

    def critic_loss(self, observations, next_observations, actions, rewards, done):
        """
        Run one optimizer step on the reward critic.

        The loss combines, via ``normalized_sum`` with weight ``beta_r``:
          * a TD loss of both Q heads against
            ``rewards + gamma * (1 - done) * Q_target(s', a')``, where ``a'`` is
            sampled from the current actor and ``Q_target`` mixes the two target
            heads as ``0.75 * min + 0.25 * max``;
          * ``mean(Q(s, pi(s)) - Q(s, a))`` on the first Q output.

        Returns:
            tuple: ``(loss tensor, stats dict)``.
        """
        _, _, q1_list, q2_list = self.critic.predict(observations, actions)
        with torch.no_grad():
            batch_size = next_observations.shape[0]
            act_targ_next, _ = self.actor(next_observations)
            q1_targ, q2_targ, _, _ = self.critic_old.predict(next_observations, act_targ_next)
            q_targ = 0.75 * torch.min(q1_targ, q2_targ) + 0.25 * torch.max(q1_targ, q2_targ)
            # One target per state (max over any extra values per state).
            q_targ = q_targ.reshape(batch_size, -1).max(1)[0]
            backup = rewards + self.gamma * (1 - done) * q_targ

        loss_td = self.critic.loss(backup, q1_list) + self.critic.loss(backup, q2_list)

        # Policy actions are only critic inputs here; no gradient reaches the actor.
        with torch.no_grad():
            act_pi, _ = self.actor(observations)
        qf1_new, _, _, _ = self.critic.predict(observations, act_pi)
        qf1_pre, _, _, _ = self.critic.predict(observations, actions)

        loss1 = (qf1_new - qf1_pre).mean()
        loss1_abs = (qf1_new - qf1_pre).abs().mean().detach()
        # NOTE: the normalization scheme may need revisiting.
        loss_critic = self.normalized_sum(loss1_abs, loss1, loss_td, self.beta_r)
        self.critic_optim.zero_grad()
        loss_critic.backward()
        self.critic_optim.step()
        stats_critic = {
            "loss/critic_loss": loss_critic.item(),
            "loss/critic_value": qf1_new.mean().item(),
        }
        return loss_critic, stats_critic

    def cost_critic_loss(self, observations, next_observations, actions, costs, done):
        """
        Run one optimizer step on the cost critic.

        Mirrors ``critic_loss``, with two differences:
          * the comparison term enters with the opposite sign (``-mean(Qc(s, pi(s)) - Qc(s, a))``)
            and is weighted by ``beta_c``;
          * the TD backup is ``costs + gamma * Qc_target(s', a')``, with no
            ``(1 - done)`` mask, so ``done`` is unused.

        NOTE(review): confirm that omitting the terminal mask is intentional.

        Returns:
            tuple: ``(loss tensor, stats dict)``.
        """
        _, _, q1_list, q2_list = self.cost_critic.predict(observations, actions)
        with torch.no_grad():
            batch_size = next_observations.shape[0]
            act_targ_next, _ = self.actor(next_observations)
            q1_targ, q2_targ, _, _ = self.cost_critic_old.predict(next_observations, act_targ_next)
            q_targ = 0.75 * torch.min(q1_targ, q2_targ) + 0.25 * torch.max(q1_targ, q2_targ)
            q_targ = q_targ.reshape(batch_size, -1).max(1)[0]
            backup = costs + self.gamma * q_targ

        loss_td = self.cost_critic.loss(backup, q1_list) + self.cost_critic.loss(backup, q2_list)

        with torch.no_grad():
            act_pi, _ = self.actor(observations)
        qf1_new, _, _, _ = self.cost_critic.predict(observations, act_pi)
        qf1_pre, _, _, _ = self.cost_critic.predict(observations, actions)
        loss1 = (qf1_new - qf1_pre).mean()
        loss1_abs = (qf1_new - qf1_pre).abs().mean().detach()

        loss_cost_critic = self.normalized_sum(loss1_abs, -loss1, loss_td, self.beta_c)
        self.cost_critic_optim.zero_grad()
        loss_cost_critic.backward()
        self.cost_critic_optim.step()
        stats_cost_critic = {
            "loss/cost_critic_loss": loss_cost_critic.item(),
            "loss/cost_critic_value": qf1_new.mean().item(),
        }
        return loss_cost_critic, stats_cost_critic

    def actor_loss(self, observations, violate):
        """
        Run one optimizer step on the actor, then update ``lambda__``.

        The actor minimizes
        ``-mean(min(Q) - lambda__ * relu(min(Qc) + qc_ub - cost_limit))``.
        Critic parameters are frozen during the step and are always unfrozen
        afterwards, even if the step raises. ``violate`` is accepted but unused.

        ``lambda__`` is then updated as an exponential moving average
        (``0.95 * old + 0.05 * (mean Qc + qc_ub - cost_limit)``), clipped to
        [1.0, 20]. NOTE(review): ``lambda_max`` is not used for this clip.

        Returns:
            tuple: ``(loss tensor, stats dict)``. The stats hold the value of
            ``lambda__`` from before the update.
        """
        self._set_critics_requires_grad(False)
        try:
            actions, _ = self.actor(observations)
            q1_pi, q2_pi, _, _ = self.critic.predict(observations, actions)
            qc1_pi, qc2_pi, _, _ = self.cost_critic.predict(observations, actions)
            q_pi = torch.min(q1_pi, q2_pi)
            qc_pi = torch.min(qc1_pi, qc2_pi)

            penalty = torch.clamp(qc_pi + self.qc_ub - self.cost_limit, min=0)
            loss_actor = -(q_pi - self.lambda__ * penalty).mean()
            self.actor_optim.zero_grad()
            loss_actor.backward()
            self.actor_optim.step()
        finally:
            self._set_critics_requires_grad(True)

        stats_actor = {
            "loss/actor_loss": loss_actor.item(),
            "loss/lambda__": self.lambda__,
        }
        violation = qc_pi.detach().mean().item() + self.qc_ub - self.cost_limit
        self.lambda__ = max(1.0, min(0.95 * self.lambda__ + 0.05 * violation, 20))

        return loss_actor, stats_actor

    def setup_optimizers(self, actor_lr, critic_lr, scalar_lr):
        """Create Adam optimizers for the actor, both critics and ``self.tau``."""
        self.actor_optim = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.tau_optim = torch.optim.Adam([self.tau], lr=scalar_lr)
        self.critic_optim = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)
        self.cost_critic_optim = torch.optim.Adam(self.cost_critic.parameters(),
                                                  lr=critic_lr)

    def sync_weight(self):
        """
        Update the target networks from the online networks with rate ``self.tau``.
        """
        self._soft_update(self.critic_old, self.critic, self.tau)
        self._soft_update(self.cost_critic_old, self.cost_critic, self.tau)
        self._soft_update(self.actor_old, self.actor, self.tau)

    def act(self,
            obs: np.ndarray,
            deterministic: bool = False,
            with_logprob: bool = False):
        """
        Compute the action (and optionally its log-probability) for one observation.

        Args:
            obs: a single observation (no batch dimension).
            deterministic: forwarded to the actor.
            with_logprob: forwarded to the actor.

        Returns:
            tuple: ``(action, logp)``. ``action`` is a numpy array with the
            batch axis removed. ``logp`` is a squeezed numpy array, or ``None``
            when the actor returns no log-probability (e.g. when
            ``with_logprob`` is False).
        """
        obs_t = torch.tensor(np.asarray(obs)[None, ...], dtype=torch.float32,
                             device=self.device)
        a, logp_a = self.actor(obs_t, deterministic, with_logprob)
        # .cpu() is a no-op for CPU tensors, so there is no need to branch on the device.
        action = a.detach().cpu().numpy()
        logp = None if logp_a is None else np.squeeze(logp_a.detach().cpu().numpy())
        return np.squeeze(action, axis=0), logp


class WSACTrainer:
    """
    Trainer for the ``WSAC`` model.

    Args:
        model (WSAC): the model to train. ``setup_optimizers`` is called on it
            in the constructor.
        env (gym.Env): environment used for evaluation rollouts.
        logger (WandbLogger or DummyLogger, optional): logger that receives the
            training statistics. A new ``DummyLogger`` is created when omitted.
        actor_lr (float): learning rate for the actor.
        critic_lr (float): learning rate for the reward and cost critics.
        scalar_lr (float): learning rate for the model's scalar (``tau``)
            optimizer.
        reward_scale (float): mean evaluation returns are divided by this factor.
        cost_scale (float): per-step costs are multiplied by this factor during
            rollouts, and the mean episode cost is divided by it in ``evaluate``.
        device (str): device string (stored only).
    """

    def __init__(self,
                 model: WSAC,
                 env: gym.Env,
                 logger: WandbLogger = None,
                 actor_lr: float = 1e-3,
                 critic_lr: float = 1e-3,
                 scalar_lr: float = 1e-3,
                 reward_scale: float = 1.0,
                 cost_scale: float = 1.0,
                 device="cpu"):
        self.model = model
        # Build the default logger per trainer rather than sharing a single
        # instance created when the function was defined.
        self.logger = DummyLogger() if logger is None else logger
        self.env = env
        self.reward_scale = reward_scale
        self.cost_scale = cost_scale
        self.device = device
        self.model.setup_optimizers(actor_lr, critic_lr, scalar_lr)

    def train_one_step(self, observations, next_observations, actions, rewards, costs,
                       done, act_times, violate, critic_iters: int = 20,
                       actor_iters: int = 1):
        """
        Run one training step on a batch.

        The reward critic and cost critic are updated alternately
        ``critic_iters`` times. The actor is then updated ``actor_iters`` times,
        and the target networks are synced once. The statistics from the last
        update of each component are sent to the logger.

        Args:
            act_times: currently unused. The actor iteration count is
                ``actor_iters`` (the use of ``act_times`` was commented out in
                an earlier revision).
            violate: forwarded to ``WSAC.actor_loss``.
            critic_iters (int): number of critic updates per step (must be >= 1).
            actor_iters (int): number of actor updates per step (must be >= 1).

        Raises:
            ValueError: if ``critic_iters`` or ``actor_iters`` is smaller than 1.
        """
        if critic_iters < 1 or actor_iters < 1:
            raise ValueError("critic_iters and actor_iters must both be >= 1")

        for _ in range(critic_iters):
            _, stats_critic = self.model.critic_loss(observations, next_observations,
                                                     actions, rewards, done)
            _, stats_cost_critic = self.model.cost_critic_loss(
                observations, next_observations, actions, costs, done)
        for _ in range(actor_iters):
            _, stats_actor = self.model.actor_loss(observations, violate)
        self.model.sync_weight()
        self.logger.store(**stats_critic)
        self.logger.store(**stats_cost_critic)
        self.logger.store(**stats_actor)

    def evaluate(self, eval_episodes):
        """
        Roll out ``eval_episodes`` episodes and return averaged statistics.

        The model is in eval mode during the rollouts and is always switched
        back to train mode afterwards, even if a rollout raises.

        Returns:
            tuple: ``(mean return / reward_scale, mean cost / cost_scale,
            mean episode length)``.
        """
        episode_rets, episode_costs, episode_lens = [], [], []
        self.model.eval()
        try:
            for _ in trange(eval_episodes, desc="Evaluating...", leave=False):
                epi_ret, epi_len, epi_cost = self.rollout()
                episode_rets.append(epi_ret)
                episode_lens.append(epi_len)
                episode_costs.append(epi_cost)
        finally:
            self.model.train()
        # NOTE(review): rollout() sums raw env rewards, so dividing by
        # reward_scale is only correct if the env already scales rewards by
        # that factor. Confirm with how the env is constructed.
        return np.mean(episode_rets) / self.reward_scale, np.mean(
            episode_costs) / self.cost_scale, np.mean(episode_lens)

    @torch.no_grad()
    def rollout(self):
        """
        Run one deterministic episode of at most ``model.episode_len`` steps.

        Returns:
            tuple: ``(episode return, episode length, episode cost)``. The cost
            is the sum of ``info["cost"] * cost_scale``.
        """
        obs, info = self.env.reset()
        episode_ret, episode_cost, episode_len = 0.0, 0.0, 0
        for _ in range(self.model.episode_len):
            act, _ = self.model.act(obs, True, True)
            obs_next, reward, terminated, truncated, info = self.env.step(act)
            cost = info["cost"] * self.cost_scale
            obs = obs_next
            episode_ret += reward
            episode_len += 1
            episode_cost += cost
            if terminated or truncated:
                break
        return episode_ret, episode_len, episode_cost
