from collections import defaultdict
import math
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from algorithm.modules import gae_trace
from algorithm.trainer import SampleBatch, feed_forward_generator, recurrent_generator
from utils.namedarray import recursive_apply


def get_gard_norm(it):
    """Return the global L2 norm of the gradients held by ``it``.

    Args:
        it: iterable of parameters (anything exposing ``.grad``). Entries whose
            ``.grad`` is ``None`` contribute nothing.

    Returns:
        float: ``sqrt(sum_i ||grad_i||^2)``; ``0.0`` when no gradient exists.
    """
    squared_norms = (p.grad.norm()**2 for p in it if p.grad is not None)
    return math.sqrt(sum(squared_norms))


def huber_loss(e, d):
    """Elementwise Huber loss with threshold ``d``.

    Quadratic ``e**2 / 2`` where ``|e| <= d`` and linear ``d * (|e| - d / 2)``
    where ``|e| > d``; the two pieces meet continuously at ``|e| == d``.

    Args:
        e: error tensor (uses ``.float()`` on comparison masks, so a tensor
            is required).
        d: positive threshold between the quadratic and linear regimes.

    Returns:
        Tensor with the same shape as ``e``.
    """
    abs_e = abs(e)
    quadratic = (abs_e <= d).float()
    # Must compare |e|, not e: otherwise errors below -d get zero loss/gradient.
    linear = (abs_e > d).float()
    return quadratic * e**2 / 2 + linear * d * (abs_e - d / 2)


def mse_loss(e):
    """Half squared error, applied elementwise.

    Args:
        e: error term (number or tensor).

    Returns:
        ``e ** 2 / 2`` with the same shape as ``e``.
    """
    squared_error = e**2
    return squared_error / 2


def gaussian_kl(m1, std1, m2, std2):
    r"""Compute KL(N(m1, std1^2) || N(m2, std2^2)) for diagonal Gaussians.

    Each of the ``d`` trailing dimensions has its own mean and standard
    deviation; the per-dimension divergences are summed:

    .. math::

        \sum_i \left[\log \sigma_{2i} - \log \sigma_{1i} - \frac{1}{2}
            + \frac{\sigma_{1i}^2}{2 \sigma_{2i}^2}
            + \frac{(\mu_{2i} - \mu_{1i})^2}{2 \sigma_{2i}^2}\right]

    Args:
        m1 (torch.Tensor): mean 1, shape [*, d]
        std1 (torch.Tensor): standard deviation 1, shape [*, d]
        m2 (torch.Tensor): mean 2, shape [*, d]
        std2 (torch.Tensor): standard deviation 2, shape [*, d]

    Returns:
        torch.Tensor: shape [*, 1]
    """
    part1 = std2.log() - std1.log() - 1 / 2
    part2 = std1.pow(2) / (std2.pow(2) * 2)
    part3 = (m2 - m1).pow(2) / (2 * std2.pow(2))
    # keepdim=True keeps a trailing singleton axis: [*, d] -> [*, 1].
    return (part1 + part2 + part3).sum(-1, keepdim=True)


def rbf_kernel(x, y):
    """RBF similarity ``exp(-mean((x - y)^2))`` between two tensors.

    The squared difference is averaged over *all* elements, so the result is
    a 0-dim tensor in ``(0, 1]`` (1 when ``x == y``).
    """
    mean_sq_dist = (x - y).square().mean()
    return torch.exp(-mean_sq_dist)


def getcofactor(m, i, j):
    """Return the minor of ``m`` obtained by deleting row ``i`` and column ``j``.

    ``m`` is a list of row lists; the result is built from slices, so the
    input is not modified.
    """
    kept_rows = m[:i] + m[i + 1:]
    minor = []
    for row in kept_rows:
        minor.append(row[:j] + row[j + 1:])
    return minor


def determinantOfMatrix(mat):
    """Determinant of a square matrix via recursive Laplace expansion along row 0.

    Only ``+``, ``-`` and ``*`` are applied to the entries, so it works on a
    list of lists of plain numbers or of scalar torch tensors (and in the
    latter case the result keeps their autograd history). The cost is O(n!),
    so it is only suitable for small matrices.

    Args:
        mat: square matrix as a list of row lists.

    Returns:
        The determinant; ``1`` for an empty (0x0) matrix.
    """
    n = len(mat)
    if n == 0:
        return 1
    if n == 1:
        # The expansion below would multiply by det([]) and lose the entry.
        return mat[0][0]
    if n == 2:
        return mat[0][0] * mat[1][1] - mat[1][0] * mat[0][1]
    total = 0
    for col in range(n):
        # Minor: drop row 0 and column `col`.
        minor = [row[:col] + row[col + 1:] for row in mat[1:]]
        total += (-1)**col * mat[0][col] * determinantOfMatrix(minor)
    return total


class MAPPO:
    """PPO trainer with separate actor/critic Adam optimizers and an optional
    population-diversity ("dvd") term.

    When ``dvd_coef > 0`` the actor loss additionally maximizes the log
    determinant of an RBF kernel matrix built from the action embeddings of
    every population member. The members are gathered with
    ``torch.distributed``; presumably there is one member per rank, with
    ``args.pbt_size`` equal to the world size (TODO confirm).
    """

    def __init__(self, rank, args, policy):
        """Store hyper-parameters from ``args`` and build both optimizers.

        Args:
            rank: this process's index in the population / process group.
            args: config namespace; the PPO, loss and optimizer fields read
                below must be present.
            policy: object exposing ``actor`` and ``critic`` modules plus
                ``analyze`` and the PopArt helpers used by this class.
        """
        self.rank = rank
        self.args = args

        self.policy = policy

        self.clip_param = args.clip_param
        self.ppo_epoch = args.ppo_epoch
        self.num_mini_batch = args.num_mini_batch
        self.data_chunk_length = args.data_chunk_length
        self.value_loss_coef = args.value_loss_coef
        self.entropy_coef = args.entropy_coef
        self.max_grad_norm = args.max_grad_norm
        self.huber_delta = args.huber_delta

        self._use_max_grad_norm = args.use_max_grad_norm
        self._use_clipped_value_loss = args.use_clipped_value_loss
        self._use_huber_loss = args.use_huber_loss
        self._use_value_active_masks = args.use_value_active_masks
        self._use_policy_active_masks = args.use_policy_active_masks

        self.actor_optimizer = torch.optim.Adam(
            self.policy.actor.parameters(),
            lr=args.lr,
            eps=args.opti_eps,
            weight_decay=args.weight_decay,
        )
        self.critic_optimizer = torch.optim.Adam(
            self.policy.critic.parameters(),
            lr=args.critic_lr,
            eps=args.opti_eps,
            weight_decay=args.weight_decay,
        )

    def cal_value_loss(self, values, value_preds_batch, return_batch,
                       active_masks_batch):
        """Return the scalar critic loss.

        Args:
            values: critic predictions from the current forward pass.
            value_preds_batch: predictions recorded at rollout time, or None.
                When given and ``use_clipped_value_loss`` is set, the PPO
                clipped value loss (elementwise max of clipped/unclipped) is used.
            return_batch: regression targets. If the policy has a PopArt head,
                its statistics are updated with these targets and the targets
                are normalized before use.
            active_masks_batch: mask used to average the loss when
                ``use_value_active_masks`` is set.

        Returns:
            0-dim loss tensor.
        """
        if self.policy.popart_head is not None:
            self.policy.update_popart(return_batch)
            return_batch = self.policy.normalize_value(return_batch)

        # An explicit branch is needed here: the previous one-liner
        # `lambda x: huber_loss(...) if cond else mse_loss` returned the
        # *function* mse_loss (not a loss) whenever Huber was disabled.
        if self._use_huber_loss:

            def loss_fn(x):
                return huber_loss(x, self.huber_delta)
        else:
            loss_fn = mse_loss

        error_original = return_batch - values
        value_loss_original = loss_fn(error_original)

        if value_preds_batch is not None and self._use_clipped_value_loss:
            value_pred_clipped = value_preds_batch + (
                values - value_preds_batch).clamp(-self.clip_param,
                                                  self.clip_param)
            error_clipped = return_batch - value_pred_clipped
            value_loss_clipped = loss_fn(error_clipped)
            value_loss = torch.max(value_loss_original, value_loss_clipped)
        else:
            value_loss = value_loss_original

        if self._use_value_active_masks:
            value_loss = (value_loss *
                          active_masks_batch).sum() / active_masks_batch.sum()
        else:
            value_loss = value_loss.mean()

        return value_loss

    def _grad_norm(self, parameters):
        """Clip gradients (if ``use_max_grad_norm``) and return their global norm."""
        if self._use_max_grad_norm:
            return nn.utils.clip_grad_norm_(parameters, self.max_grad_norm)
        return get_gard_norm(parameters)

    def _population_diversity(self, sample, buffer, actor_output):
        """Return the determinant of the population's RBF kernel matrix.

        Every rank re-runs its actor on ``buffer.storage``, and the resulting
        embeddings are exchanged with ``dist.all_gather`` (a collective, so all
        ranks must reach this call). This rank's entry is replaced by
        ``actor_output`` so gradients flow only through the local policy.
        Other members' embeddings are sliced to this rank's portion of the
        buffer and detached. That slice is presumably where ``put_batch``
        stored ``sample`` (TODO confirm against the buffer implementation).
        """
        _, _, _, all_action_embedding = self.policy.analyze(
            recursive_apply(buffer.storage,
                            lambda x: x.to(self.policy.device)))

        tensor_list = [
            torch.zeros_like(all_action_embedding)
            for _ in range(self.args.pbt_size)
        ]
        dist.all_gather(tensor_list, all_action_embedding)
        for i, tensor in enumerate(tensor_list):
            if i == self.rank:
                tensor_list[i] = actor_output
                assert tensor_list[i].requires_grad
            elif self.policy.num_rnn_layers > 0:
                # Recurrent layout: the batch axis is dim 1.
                bs = sample.masks.shape[1]
                tensor_list[i] = tensor[:, bs * self.rank:bs *
                                        (self.rank + 1)].detach()
            else:
                bs = sample.masks.shape[0]
                tensor_list[i] = tensor[bs * self.rank:bs *
                                        (self.rank + 1)].detach()

        n = len(tensor_list)
        # Differentiable kernel entries (for the loss) ...
        rbf_matrix = [[0 for _ in range(n)] for _ in range(n)]
        # ... and a detached float copy, used only for the debug print below.
        matrix = torch.zeros((n, n), dtype=torch.float32)
        for i in range(n):
            for j in range(i, n):
                d = rbf_kernel(tensor_list[i], tensor_list[j])
                rbf_matrix[i][j] = rbf_matrix[j][i] = d
                matrix[i, j] = matrix[j, i] = d.item()
        pd = determinantOfMatrix(rbf_matrix)
        print(pd, matrix)
        return pd

    def ppo_update(self,
                   sample: SampleBatch,
                   buffer,
                   dvd_coef,
                   update_actor=True):
        """Run one PPO gradient step on a minibatch.

        This always writes ``sample`` to ``buffer`` and calls ``dist.barrier``,
        so it must be called on every rank.

        Args:
            sample: minibatch produced by the generators used in ``train``.
            buffer: shared buffer; ``buffer.put_batch(self.rank, sample)`` is
                called before the barrier.
            dvd_coef: weight of the diversity term; ``<= 0`` disables it.
            update_actor: when False the actor optimizer is not stepped
                (the critic still is).

        Returns:
            Tuple ``(value_loss, critic_grad_norm, policy_loss, dist_entropy,
            actor_grad_norm, imp_weights, pd)``. ``pd`` is the kernel
            determinant, or None when the diversity term is disabled.
        """
        buffer.put_batch(self.rank, sample)
        dist.barrier()

        # Reshape to do in a single forward pass for all steps
        action_log_probs, values, dist_entropy, actor_output = self.policy.analyze(
            sample)

        dvd_loss = 0
        pd = None
        if dvd_coef > 0:
            pd = self._population_diversity(sample, buffer, actor_output)
            # we need to maximize the population diversity
            dvd_loss = -dvd_coef * pd.log()

        # actor update
        imp_weights = torch.exp(action_log_probs - sample.action_log_probs)

        surr1 = imp_weights * sample.advantages
        surr2 = torch.clamp(imp_weights, 1.0 - self.clip_param,
                            1.0 + self.clip_param) * sample.advantages
        assert surr1.shape[-1] == surr2.shape[-1] == 1

        if self._use_policy_active_masks:
            policy_loss = (-torch.min(surr1, surr2) * sample.active_masks
                           ).sum() / sample.active_masks.sum()
            dist_entropy = (dist_entropy * sample.active_masks
                            ).sum() / sample.active_masks.sum()
        else:
            policy_loss = -torch.min(surr1, surr2).mean()
            dist_entropy = dist_entropy.mean()

        value_loss = self.cal_value_loss(values, sample.value_preds,
                                         sample.returns, sample.active_masks)

        self.actor_optimizer.zero_grad()
        if update_actor:
            (policy_loss - dist_entropy * self.entropy_coef +
             dvd_loss).backward()
            actor_grad_norm = self._grad_norm(self.policy.actor.parameters())
            self.actor_optimizer.step()
        else:
            # Do not step: with zero-filled (rather than None) grads, Adam
            # would still move the weights via momentum and weight decay.
            actor_grad_norm = 0.0

        self.critic_optimizer.zero_grad()
        (value_loss * self.value_loss_coef).backward()
        critic_grad_norm = self._grad_norm(self.policy.critic.parameters())
        self.critic_optimizer.step()

        return (value_loss, critic_grad_norm, policy_loss, dist_entropy,
                actor_grad_norm, imp_weights, pd)

    def train(self, storage, buffer, dvd_coef, update_actor=True):
        """Run ``ppo_epoch`` passes of minibatch updates over ``storage``.

        Returns:
            Mapping from statistic name to its average over
            ``ppo_epoch * num_mini_batch`` updates, as Python floats.
        """
        train_info = defaultdict(float)

        for _ in range(self.ppo_epoch):
            if self.policy.num_rnn_layers > 0:
                data_generator = recurrent_generator(storage,
                                                     self.num_mini_batch,
                                                     self.data_chunk_length)
            else:
                data_generator = feed_forward_generator(
                    storage, self.num_mini_batch)

            for sample in data_generator:
                (value_loss, critic_grad_norm, policy_loss, dist_entropy,
                 actor_grad_norm, imp_weights,
                 pd) = self.ppo_update(sample,
                                       buffer,
                                       dvd_coef,
                                       update_actor=update_actor)

                train_info['value_loss'] += value_loss.item()
                train_info['policy_loss'] += policy_loss.item()
                train_info['dist_entropy'] += dist_entropy.item()
                # Grad norms may be 0-dim tensors (clip_grad_norm_) or floats.
                train_info['actor_grad_norm'] += float(actor_grad_norm)
                train_info['critic_grad_norm'] += float(critic_grad_norm)
                # .item() drops autograd history; accumulating the raw tensor
                # kept every minibatch's graph alive until train() returned.
                train_info['ratio'] += imp_weights.mean().item()
                if pd is not None:
                    train_info['population_diversity'] += float(pd)

        num_updates = self.ppo_epoch * self.num_mini_batch

        for k in train_info.keys():
            train_info[k] /= num_updates

        return train_info

    def prep_training(self):
        """Put the actor and critic modules in training mode."""
        self.policy.actor.train()
        self.policy.critic.train()

    def prep_rollout(self):
        """Put the actor and critic modules in evaluation mode."""
        self.policy.actor.eval()
        self.policy.critic.eval()
