import jax.numpy as jnp




def cosine_push_loss(a, labels=None):
    """Mean squared cosine similarity between selected pairs of rows of ``a``.

    Each row of ``a`` is L2-normalised, the squared cosine similarity of every
    ordered pair of rows is computed, and the average is taken over the pairs
    selected by a mask. Minimising this loss pushes the selected rows towards
    mutual orthogonality.

    Args:
        a: 2-D array of shape ``(n, d)`` whose rows are the vectors to push
            apart. Rows are assumed non-zero; a zero row produces NaN after
            normalisation.
        labels: Optional 1-D array of length ``n``. When ``None``, every pair
            ``(i, j)`` with ``i != j`` contributes. Otherwise only pairs whose
            labels differ contribute, so rows sharing a label are not pushed
            apart from each other.

    Returns:
        Scalar loss. It is ``0.0`` when no pair is selected, for example when
        ``n == 1`` or when all labels are equal.

    Raises:
        ValueError: If ``labels`` is given and its shape is not ``(n,)``.
    """
    a = a / jnp.linalg.norm(a, axis=1, keepdims=True)
    # Squared cosine similarity for every ordered pair of rows: shape (n, n).
    sq_cos = jnp.einsum('ij,kj->ik', a, a) ** 2
    n = sq_cos.shape[0]

    if labels is None:
        # All off-diagonal pairs. Self-similarity is always 1 and carries no
        # signal, so the diagonal is excluded.
        mask = 1.0 - jnp.eye(n, dtype=sq_cos.dtype)
    else:
        labels = jnp.asarray(labels)
        if labels.shape != (n,):
            raise ValueError(
                f"labels must have shape ({n},), got {labels.shape}"
            )
        # Keep only pairs with different labels. This also drops the diagonal,
        # and being a pure broadcast comparison it stays jit-compatible.
        mask = (labels[:, None] != labels[None, :]).astype(sq_cos.dtype)

    # Clamp the denominator. With no selected pairs the numerator is 0 too, so
    # the loss becomes 0 instead of 0/0 = NaN. For any non-empty mask,
    # mask.sum() >= 1 and the result is unchanged.
    return (sq_cos * mask).sum() / jnp.maximum(mask.sum(), 1.0)


def participation_ratio(z):
    """Return the participation ratio of a batch of representations.

    Samples are taken along axis 0 and features along axis 1. The batch is
    mean-centred per feature, the (1/N-normalised) feature covariance matrix
    is built, and its eigenvalues ``lam`` are combined as
    ``(sum(lam)) ** 2 / sum(lam ** 2)``.

    Args:
        z: 2-D array of shape ``(n_samples, n_features)``.

    Returns:
        Tuple ``(pr, eigvals)``: the scalar participation ratio and the
        covariance eigenvalues as returned by ``jnp.linalg.eigvalsh``.
    """
    n_samples = z.shape[0]
    centred = z - jnp.mean(z, axis=0, keepdims=True)
    cov = jnp.einsum('ij,ik->jk', centred, centred) / n_samples

    spectrum = jnp.real(jnp.linalg.eigvalsh(cov))
    numerator = jnp.sum(spectrum) ** 2
    denominator = jnp.sum(spectrum ** 2)
    return numerator / denominator, spectrum
    