import numpy as np
from sklearn.cluster import spectral_clustering
from scipy.linalg import eigh


def discover_groups(fair_mat: np.ndarray, num_groups: int, *, random_state=None) -> np.ndarray:
    """
    Cluster nodes with spectral clustering and build the same-group indicator matrix.

    :param fair_mat: (num_nodes, num_nodes) Fairness matrix, passed as the affinity matrix to
        ``sklearn.cluster.spectral_clustering`` (which expects a symmetric, non-negative affinity)
    :param num_groups: Number of groups to discover
    :param random_state: Optional seed / ``np.random.RandomState`` forwarded to
        ``spectral_clustering`` so the clustering is reproducible; ``None`` keeps the
        previous unseeded behaviour
    :return group_fair_mat: (num_nodes, num_nodes) Modified fairness matrix corresponding to the
        groups: entry (i, j) is 1.0 if nodes i and j were assigned the same group, else 0.0
    """
    groups = spectral_clustering(fair_mat, n_clusters=num_groups, random_state=random_state)
    # Comparing labels pairwise is equivalent to one_hot @ one_hot.T but avoids
    # materialising the (num_nodes, num_groups) one-hot matrix and the matmul.
    same_group = groups[:, None] == groups[None, :]
    return same_group.astype(np.float64)


def low_rank_approx(fair_mat: np.ndarray, rank: int) -> np.ndarray:
    """
    Reconstruct fair_mat from its `rank` algebraically largest eigenpairs.

    The matrix is eigendecomposed with ``scipy.linalg.eigh``. That routine treats the matrix
    as symmetric/Hermitian and by default reads only its lower triangle. Keeping the largest
    eigenvalues gives the best rank-`rank` approximation (Frobenius / spectral norm) when
    fair_mat is positive semi-definite. For indefinite matrices, large negative eigenvalues
    are discarded.

    :param fair_mat: (num_nodes, num_nodes) Fairness matrix
    :param rank: Desired rank of the fairness matrix. 0 yields an all-zero matrix; values
        >= num_nodes keep every eigenpair.
    :return low_rank_fair_mat: (num_nodes, num_nodes) real-valued rank-`rank` approximation
        of fair_mat
    :raises ValueError: if fair_mat is not a square 2-D matrix or rank is negative
    """
    mat = np.asarray(fair_mat)
    if mat.ndim != 2 or mat.shape[0] != mat.shape[1]:
        raise ValueError(f"fair_mat must be a square 2-D matrix, got shape {mat.shape}")
    if rank < 0:
        raise ValueError(f"rank must be non-negative, got {rank}")

    num_nodes = mat.shape[0]
    # Clamp: requesting more eigenpairs than exist would give subset_by_index a
    # negative lower bound; keeping all of them is the exact reconstruction.
    rank = min(rank, num_nodes)
    if rank == 0:
        # subset_by_index=[n, n-1] is an empty/invalid range; the rank-0 approximation is zero.
        return np.zeros((num_nodes, num_nodes))

    # eigh returns eigenvalues in ascending order, so the last `rank` indices are the largest.
    val, vec = eigh(mat, subset_by_index=[num_nodes - rank, num_nodes - 1])
    # (vec * val) scales each eigenvector column by its eigenvalue, i.e. vec @ diag(val)
    # without building the diagonal matrix. conj().T is the proper adjoint for complex
    # Hermitian input and is a no-op for real input.
    approx = (vec * val) @ vec.conj().T
    return np.real(approx)
