import torch
from benchmarking_utils.cifar100_superclass import CIFAR100_mapper
from tqdm import tqdm
import gc
# Module-level CIFAR100_mapper instance shared by ncc_predict_superclass: it is
# called with a superclass index to get the label values belonging to that
# superclass, and its `coarse_labels` table is indexed with per-sample labels
# to obtain per-sample superclass labels.
mapper = CIFAR100_mapper()

def compute_cdnv(feature, num_clusters, cluster_ids_x):
    """Compute the class-distance-normalized variance (CDNV) of clustered features.

    For every pair of non-empty clusters (a, b) whose means differ, the pair's
    CDNV is ``((var_a + var_b) / 2) / ||mu_a - mu_b||^2``. Here ``mu`` is the
    cluster mean and ``var`` is the mean squared distance of the cluster's
    samples to that mean. The three returned statistics are averaged over all
    pairs of non-empty clusters. Pairs whose means coincide contribute 0 but
    still count in the average.

    Args:
        feature: 2-D tensor with one column per sample, shape (D, N). A
            (N, D) tensor is also accepted when only that orientation matches
            the number of labels.
        num_clusters: clusters are the label values ``0 .. num_clusters - 1``.
        cluster_ids_x: cluster id per sample (tensor or array-like of length
            N). It is moved to ``feature``'s device before use.

    Returns:
        ``(avg_cdnv, avg_dist, avg_variance)`` as Python numbers. All three are
        0 when fewer than two clusters are non-empty.
    """
    labels = torch.as_tensor(cluster_ids_x, device=feature.device).reshape(-1)
    samples = feature.T  # documented layout is (D, N) -> one row per sample
    if samples.shape[0] != labels.shape[0] and feature.shape[0] == labels.shape[0]:
        # NOTE: callers may pass one row per sample already. Accept that
        # instead of indexing feature dimensions with sample ids.
        samples = feature
    print(f"in compute_cdnv:{samples.shape}, cluster_ids_x:{cluster_ids_x.shape}")

    means = {}
    variances = {}
    for c in range(num_clusters):
        # Boolean mask selects *all* samples of cluster c. The former
        # `.nonzero()[0]` on a tensor picked only the first one.
        h_c = samples[labels == c]
        count = h_c.shape[0]
        if count == 0:  # absent cluster: no mean/variance to compare
            continue
        mu = torch.sum(h_c, dim=0) / count
        mean_sq_norm = torch.sum(torch.square(h_c)) / count
        means[c] = mu
        # E||h||^2 - ||mu||^2 == E||h - mu||^2; abs() absorbs tiny negative
        # values caused by floating-point cancellation.
        variances[c] = abs(mean_sq_norm.item() - torch.sum(torch.square(mu)).item())

    present = list(means)
    total_num_pairs = len(present) * (len(present) - 1) / 2
    avg_cdnv, avg_dist, avg_variance = 0, 0, 0
    for i, class1 in enumerate(present):
        for class2 in present[i + 1:]:
            dist = (torch.norm(means[class1] - means[class2]) ** 2).item()
            if dist == 0:
                # CDNV is undefined for coincident means; the pair adds 0.
                continue
            variance_avg = (variances[class1] + variances[class2]) / 2
            avg_dist += dist / total_num_pairs
            avg_variance += variance_avg / total_num_pairs
            avg_cdnv += variance_avg / dist / total_num_pairs
    return avg_cdnv, avg_dist, avg_variance


def ncc_predict(feature: torch.Tensor,
                feature_bank: torch.Tensor,
                feature_labels: torch.Tensor,
                num_classes: int):
    """Nearest-class-center (NCC) accuracy plus CDNV statistics.

    Class centers are the means of the ``feature_bank`` columns carrying each
    label (``feature_bank[:, feature_labels == c]``). Each column of
    ``feature`` is assigned to the center nearest in Euclidean distance. The
    prediction is compared with ``feature_labels``.

    Args:
        feature: features to classify, one column per sample, shape (D, N).
        feature_bank: features used to build the centers, one column per
            sample, shape (D, M).
        feature_labels: label per column. It is used both to select
            ``feature_bank`` columns and as ground truth for ``feature``'s
            columns, so both tensors must have the same number of columns.
        num_classes: labels are expected in ``0 .. num_classes - 1``. A class
            with no columns gets a NaN center.

    Returns:
        ``(ncc_acc, cdnv, dist, variance)``. ``ncc_acc`` is a 0-d tensor with
        the fraction of correct predictions. The rest come from
        ``compute_cdnv`` on ``feature`` grouped by ``feature_labels``.

    Raises:
        Whatever tensor indexing raises while building class centers (e.g. a
        label/column count mismatch), after printing the bank shape.
    """
    print("beginning ncc_predict")
    feature_bank = feature_bank.cpu()
    # Labels must live on the same device as the CPU tensors they mask.
    labels = feature_labels.cpu()
    samples = feature.T.cpu()  # one row per sample
    class_means_bank = torch.zeros((num_classes, feature_bank.shape[0]))
    print("beginning class loop")
    for c in range(num_classes):
        try:
            class_means_bank[c] = feature_bank[:, labels == c].mean(dim=1)
        except Exception:
            # Keep the diagnostic, but propagate instead of killing the process.
            print(f"in except: feature_bank: {feature_bank.shape}")
            raise

    print(f"before norm, feature:{samples.shape}")
    # One vectorized distance matrix of shape (N, num_classes). The non-matmul
    # mode computes ||x - c|| directly, like the former per-row torch.norm.
    dtype = torch.promote_types(samples.dtype, class_means_bank.dtype)
    NCC_scores = torch.cdist(samples.to(dtype), class_means_bank.to(dtype),
                             compute_mode="donot_use_mm_for_euclid_dist")
    NCC_pred = torch.argmin(NCC_scores, dim=1)
    del NCC_scores
    gc.collect()
    print("after pred")

    ncc_acc = (NCC_pred == labels).sum() / len(NCC_pred)
    print("computing cdnv")
    # compute_cdnv expects one column per sample (it transposes internally),
    # so hand it the column-major layout rather than the row-per-sample copy.
    cdnv, dist, variance = compute_cdnv(samples.T, num_classes, labels)
    print("after cdnv")

    return ncc_acc, cdnv, dist, variance


def ncc_predict_superclass(feature: torch.Tensor,
                           feature_bank: torch.Tensor,
                           feature_labels: torch.Tensor,
                           num_classes: int):
    """NCC accuracy and CDNV statistics at the superclass level.

    For each of the 20 superclasses, the center is the mean of all
    ``feature_bank`` columns whose label is one of ``mapper(superclass)``.
    Each column of ``feature`` is assigned to the nearest center. The
    prediction is compared with ``mapper.coarse_labels[feature_labels]``.

    Args:
        feature: features to classify, one column per sample, shape (D, N).
        feature_bank: features used to build the centers, one column per
            sample, shape (D, M).
        feature_labels: per-column labels. They are used to select
            ``feature_bank`` columns and are mapped through
            ``mapper.coarse_labels`` to get ground truth for ``feature``.
        num_classes: unused; kept so the signature matches ``ncc_predict``.

    Returns:
        ``(ncc_acc, cdnv, dist, variance)``. ``ncc_acc`` comes from a NumPy
        comparison against the superclass labels. The rest come from
        ``compute_cdnv`` on ``feature`` grouped by superclass label. A
        superclass with no columns gets a NaN center.
    """
    num_superclasses = 20
    print("beginning ncc_predict")
    feature_bank = feature_bank.cpu()
    # Labels must live on the same device as the CPU tensors they mask.
    labels = feature_labels.cpu()
    samples = feature.T.cpu()  # one row per sample

    class_means_bank = torch.zeros((num_superclasses, feature_bank.shape[0]))
    print("beginning class loop")
    for superclass in range(num_superclasses):
        superclass_num = 0.
        for instance in mapper(superclass):
            instance_features = feature_bank[:, labels == instance]
            class_means_bank[superclass] += instance_features.sum(dim=1)
            superclass_num += instance_features.shape[1]
        # Float division: an empty superclass yields NaN instead of raising.
        class_means_bank[superclass] /= superclass_num

    print(f"before norm, feature:{samples.shape}")
    # One vectorized distance matrix of shape (N, num_superclasses); the
    # non-matmul mode computes ||x - c|| directly, like per-row torch.norm.
    dtype = torch.promote_types(samples.dtype, class_means_bank.dtype)
    NCC_scores = torch.cdist(samples.to(dtype), class_means_bank.to(dtype),
                             compute_mode="donot_use_mm_for_euclid_dist")
    NCC_pred = torch.argmin(NCC_scores, dim=1)
    del NCC_scores
    gc.collect()
    print("after pred")

    superclass_gt_labels = mapper.coarse_labels[labels]

    ncc_acc = (NCC_pred.numpy() == superclass_gt_labels).sum() / len(NCC_pred)
    # compute_cdnv expects one column per sample (it transposes internally).
    cdnv, dist, variance = compute_cdnv(samples.T, num_superclasses,
                                        torch.tensor(superclass_gt_labels))
    return ncc_acc, cdnv, dist, variance