"""
Utilities for processing a deep ensemble.
"""
import torch
from torch import nn
from torch.nn import functional as F
import torch.backends.cudnn as cudnn

def entropy_prob(probs, dim=1, eps=1e-12):
    """Compute the Shannon entropy (natural log, i.e. nats) of distributions.

    Args:
        probs: Tensor of probabilities. Each distribution lies along ``dim``.
        dim: Dimension summed over as the distribution's support. It defaults
            to 1, which is the class axis for a 2-D ``(batch, classes)`` input.
            This matches the previously hard-coded behaviour.
        eps: Small constant added inside the log. Zero probabilities then give
            ``0 * log(eps) == 0`` instead of ``0 * log(0) == nan``.

    Returns:
        Tensor with ``dim`` reduced away, holding ``-sum(p * log(p + eps))``.
    """
    log_probs = torch.log(probs + eps)
    return -torch.sum(probs * log_probs, dim=dim)


def mutual_information_prob(probs):
    """Return the entropy of the mean distribution minus the mean entropy.

    This is ``H[mean_m p_m] - mean_m H[p_m]``, the mutual information
    estimate used for ensemble uncertainty.

    Args:
        probs: Tensor of stacked per-member probabilities. The member index is
            dim 0 and the class probabilities are along dim 2, e.g.
            ``(members, batch, classes)``. Higher ranks are not supported,
            because the mean distribution's entropy is taken over its dim 1.

    Returns:
        Tensor with dims 0 and 2 of ``probs`` reduced away. Floating-point
        error and ``eps`` can make values very slightly negative.
    """
    tiny = 1e-12

    # Entropy of the ensemble-averaged distribution (total uncertainty).
    total_uncertainty = entropy_prob(probs.mean(dim=0))

    # Each member's entropy over the class axis (dim 2), averaged over members.
    member_entropy = -(probs * torch.log(probs + tiny)).sum(dim=2)
    expected_entropy = member_entropy.mean(dim=0)

    return total_uncertainty - expected_entropy

def ensemble_forward_pass(model_ensemble, data):
    """Run a single forward pass through every ensemble member.

    Each member's output goes through a softmax over dim 1. The per-member
    probabilities are stacked along a new leading member dimension and
    summarised.

    Args:
        model_ensemble: Iterable of callables (e.g. ``nn.Module`` instances).
            Each maps ``data`` to scores that are softmax-normalised over
            dim 1.
        data: Input passed unchanged to every model.

    Returns:
        Tuple ``(mean_output, predictive_entropy, mut_info)``:
            - ``mean_output``: softmax probabilities averaged over members.
            - ``predictive_entropy``: entropy of ``mean_output`` over dim 1.
            - ``mut_info``: mutual information across members, as computed
              by ``mutual_information_prob``.

    Raises:
        RuntimeError: raised by ``torch.stack`` if ``model_ensemble`` is empty.
    """
    # torch.stack adds the leading member dimension directly, which is
    # equivalent to unsqueezing each output and concatenating along dim 0.
    outputs = torch.stack(
        [F.softmax(model(data), dim=1) for model in model_ensemble], dim=0
    )
    mean_output = torch.mean(outputs, dim=0)
    predictive_entropy = entropy_prob(mean_output)
    mut_info = mutual_information_prob(outputs)

    return mean_output, predictive_entropy, mut_info