import torch
import torch.nn.functional as F
import torch.nn as nn

def kl_divergence(p: torch.Tensor, q: torch.Tensor, *, eps: float = 1e-9) -> torch.Tensor:
    """Return the KL divergence KL(p || q), summed over the last dimension.

    Both inputs are shifted by ``eps`` before the ratio and logarithm are
    taken, so exact zeros in ``p`` or ``q`` do not produce a division by zero
    or ``log(0)``. The natural logarithm is used.

    Args:
        p: Tensor whose last dimension is reduced; broadcast against ``q``.
        q: Tensor that ``p`` is compared against.
        eps: Additive smoothing constant applied to both ``p`` and ``q``.
            The default matches the previously hard-coded value; pick a
            larger one when the tensors' dtype cannot represent ``1e-9``
            (for example half precision).

    Returns:
        Tensor with the last dimension of the broadcast inputs summed away.
    """
    smoothed_p = p + eps
    smoothed_q = q + eps
    return (smoothed_p * (smoothed_p / smoothed_q).log()).sum(dim=-1)

def jensen_shannon_divergence(p : torch.Tensor, q : torch.Tensor):
    """Return the Jensen-Shannon divergence between ``p`` and ``q``.

    The result is the mean of the two divergences, computed with
    ``kl_divergence``, of ``p`` and of ``q`` from their element-wise
    midpoint ``(p + q) / 2``.
    """
    midpoint = (p + q) * 0.5
    divergence_sum = kl_divergence(p, midpoint) + kl_divergence(q, midpoint)
    return divergence_sum * 0.5

def to_dist(logits : torch.Tensor):
    """Turn ``logits`` into probabilities with a softmax over the last dimension."""
    probabilities = F.softmax(logits, dim=-1)
    return probabilities

def entropy(probs: torch.Tensor) -> torch.Tensor:
    """Return the base-2 Shannon entropy of ``probs`` over the last dimension.

    Entries that are exactly zero contribute zero to the sum instead of
    turning the whole result into NaN.

    Args:
        probs: Tensor whose last dimension is reduced. Non-floating-point
            input is converted to the default floating dtype first.

    Returns:
        Tensor with the last dimension summed away.
    """
    if not probs.is_floating_point():
        probs = probs.to(torch.get_default_dtype())
    # log2(0) is -inf and 0 * -inf is NaN, so keep the logarithm's argument
    # strictly positive. The outer factor is still the unclamped probability,
    # which makes a zero entry evaluate to 0 * finite == 0.
    smallest_normal = torch.finfo(probs.dtype).tiny
    log_probs = torch.log2(probs.clamp_min(smallest_normal))
    return -torch.sum(probs * log_probs, dim=-1)