from collections import OrderedDict

import numpy as np
import torch
import torch.optim as optim
from torch import nn as nn

import rlkit.torch.pytorch_util as ptu
from rlkit.core.eval_util import create_stats_ordered_dict
from rlkit.torch.torch_rl_algorithm import TorchTrainer


class BetaStepTrainer(TorchTrainer):
    """
    Two-phase trainer: fit a policy beta' (``beta_prime_policy``) against a
    fixed policy beta (``behavior_policy``), then evaluate beta' with a
    Q-function.

    Phase 1 (the first ``policy_pretrain_steps`` calls to
    ``train_from_torch``): beta' is updated to minimize
        E_{a ~ beta'(.|s)} [ -log beta(a|s) + alpha * log beta'(a|s) ],
    using reparameterized samples of beta'. With alpha == 1 this is the
    reverse KL divergence KL(beta' || beta).

    Phase 2 (every later call): only ``qf`` is trained, by entropy-regularized
    TD regression toward Q^{beta'} on the batch's (s, a, r, s') transitions,
    with soft target-network updates.

    When automatic entropy tuning is enabled, alpha is updated on every call
    in both phases.
    """
    def __init__(
            self,
            env,
            behavior_policy,
            beta_prime_policy,
            qf,
            target_qf,

            discount=0.99,
            reward_scale=1.0,

            policy_lr=1e-4,
            qf_lr=1e-4,
            optimizer_class=optim.Adam,

            soft_target_tau=5e-3,
            target_update_period=2,

            std_scale=1.0,
            use_automatic_entropy_tuning=False,

            policy_pretrain_steps=20000,
    ):
        """
        Args:
            env: environment. Only ``env.action_space.shape`` is read, to set
                the target entropy when automatic entropy tuning is enabled.
            behavior_policy: fixed policy beta. Must expose
                ``log_prob(obs, actions)``. It is not optimized by this trainer.
            beta_prime_policy: policy beta' being trained. Called as
                ``beta_prime_policy(obs, reparameterize=..., return_log_prob=True)``
                and expected to return ``(actions, mean, log_std, log_prob, ...)``.
            qf: Q-network, called as ``qf(obs, actions)``.
            target_qf: target copy of ``qf``, called the same way.
            discount: TD discount factor.
            reward_scale: multiplier applied to rewards in the TD target.
            policy_lr: learning rate for beta' (and for log-alpha when tuning).
            qf_lr: learning rate for ``qf``.
            optimizer_class: optimizer constructor, called as
                ``optimizer_class(params, lr=...)``.
            soft_target_tau: coefficient passed to ``ptu.soft_update_from_to``.
            target_update_period: during the Q phase, the target network is
                updated on steps where ``step % target_update_period == 0``.
            std_scale: stored as ``self.std_scale``; not used by this class.
            use_automatic_entropy_tuning: if True, alpha = exp(log_alpha) is
                learned toward target entropy ``-prod(action_space.shape)``.
                Otherwise alpha is fixed at 1.
            policy_pretrain_steps: number of initial training steps that
                update beta'. After that, only the Q-function is trained.
        """
        super().__init__()
        self.env = env
        self.behavior_policy = behavior_policy
        self.beta_prime_policy = beta_prime_policy
        self.qf = qf
        self.target_qf = target_qf
        self.soft_target_tau = soft_target_tau
        self.target_update_period = target_update_period
        self.policy_pretrain_steps = policy_pretrain_steps

        self.qf_criterion = nn.MSELoss()

        # Kept so the optimizers can be rebuilt when networks are replaced
        # (see set_snapshot).
        self._optimizer_class = optimizer_class
        self._policy_lr = policy_lr
        self._qf_lr = qf_lr
        self._build_network_optimizers()

        self.use_automatic_entropy_tuning = use_automatic_entropy_tuning
        if self.use_automatic_entropy_tuning:
            self.target_entropy = -np.prod(self.env.action_space.shape).item()
            self.log_alpha = ptu.zeros(1, requires_grad=True)
            self.alpha_optimizer = optimizer_class(
                [self.log_alpha],
                lr=policy_lr,
            )

        self.discount = discount
        self.reward_scale = reward_scale
        self.eval_statistics = OrderedDict()
        self._n_train_steps_total = 0
        self._need_to_update_eval_statistics = True

        self.discrete = False

        self.std_scale = std_scale

    def _build_network_optimizers(self):
        """(Re)create the beta' and qf optimizers for the current network objects."""
        self.policy_optimizer = self._optimizer_class(
            self.beta_prime_policy.parameters(),
            lr=self._policy_lr,
        )
        self.qf_optimizer = self._optimizer_class(
            self.qf.parameters(),
            lr=self._qf_lr,
        )

    def train_from_torch(self, batch):
        """
        Run one training step on ``batch``.

        ``batch`` is read with the keys 'rewards', 'terminals',
        'observations', 'actions' and 'next_observations'. During the first
        ``policy_pretrain_steps`` calls beta' is updated; afterwards ``qf`` is
        updated. Alpha, if tuned, is updated on every call.
        """
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']

        in_policy_phase = self._n_train_steps_total < self.policy_pretrain_steps

        # One reparameterized sample from beta' serves both the alpha update
        # and the policy update. In the Q phase it is needed only for alpha.
        new_obs_actions = log_pi = None
        if in_policy_phase or self.use_automatic_entropy_tuning:
            new_obs_actions, _, _, log_pi, *_ = self.beta_prime_policy(
                obs, reparameterize=True, return_log_prob=True,
            )

        # Entropy parameter tuning
        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            # Detached: alpha is a constant coefficient in the losses below.
            # log_alpha is trained only through alpha_loss.
            alpha = self.log_alpha.exp().detach()
        else:
            alpha = 1

        if in_policy_phase:
            # Policy loss: E[-log beta(a|s) + alpha * log beta'(a|s)], a ~ beta'.
            log_likelihood = self.behavior_policy.log_prob(obs, new_obs_actions)
            entropy = -log_pi

            policy_loss = (-log_likelihood - alpha * entropy).mean()

            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()
        else:
            # QF loss: regress qf(s, a) onto the soft TD target under beta'.
            q_pred = self.qf(obs, actions)

            # The TD target is a regression constant, so build it without autograd.
            with torch.no_grad():
                next_actions, _, _, next_log_pi, *_ = self.beta_prime_policy(
                    next_obs, reparameterize=False, return_log_prob=True,
                )
                target_q_values = self.target_qf(next_obs, next_actions)
                target_q_values = target_q_values - alpha * next_log_pi
                q_target = self.reward_scale * rewards + (1. - terminals) * self.discount * target_q_values

            qf_loss = self.qf_criterion(q_pred, q_target)

            self.qf_optimizer.zero_grad()
            qf_loss.backward()
            self.qf_optimizer.step()

            # Soft target updates
            if self._n_train_steps_total % self.target_update_period == 0:
                ptu.soft_update_from_to(
                    self.qf, self.target_qf, self.soft_target_tau
                )

        # NOTE(review): statistics logging is disabled, so ``eval_statistics``
        # stays empty and ``get_diagnostics`` returns an empty dict. The
        # disabled code below logs different keys in the two phases
        # ('Policy Loss' vs 'QF Loss' / Q stats). Before re-enabling it,
        # confirm the downstream logger tolerates a key set that changes
        # mid-run.
        # if self._need_to_update_eval_statistics:
        #     self._need_to_update_eval_statistics = False
        #     # Eval should set this to None.
        #     # This way, these statistics are only computed for one batch.
        #     if in_policy_phase:
        #         self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(policy_loss))
        #     else:
        #         self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
        #         self.eval_statistics.update(create_stats_ordered_dict(
        #             'Q Predictions',
        #             ptu.get_numpy(q_pred),
        #         ))
        #         self.eval_statistics.update(create_stats_ordered_dict(
        #             'Q Targets',
        #             ptu.get_numpy(q_target),
        #         ))

        self._n_train_steps_total += 1

    def get_diagnostics(self):
        """Return the collected evaluation statistics (currently always empty)."""
        return self.eval_statistics

    def end_epoch(self, epoch):
        """Request that statistics be recorded again on the next training step."""
        self._need_to_update_eval_statistics = True

    @property
    def networks(self):
        """Networks owned by this trainer.

        NOTE(review): ``behavior_policy`` is not listed. If the caller relies
        on this list for device placement, confirm that ``behavior_policy`` is
        already on the right device.
        """
        return [
            self.beta_prime_policy,
            self.qf,
            self.target_qf,
        ]

    def get_snapshot(self):
        """Return the trainable networks keyed as 'policy', 'qf', 'target_qf'."""
        return dict(
            policy=self.beta_prime_policy,
            qf=self.qf,
            target_qf=self.target_qf,
        )

    def set_snapshot(self, snapshot):
        """Replace the networks with those in ``snapshot`` and rebind optimizers.

        The optimizers are rebuilt because the old ones hold references to the
        replaced networks' parameters. Without this, later steps would update
        orphaned tensors and leave the loaded networks untouched. Previous
        optimizer state is discarded.
        """
        self.beta_prime_policy = snapshot['policy']
        self.qf = snapshot['qf']
        self.target_qf = snapshot['target_qf']
        self._build_network_optimizers()