from tqdm import tqdm
import torch
import torch.nn.functional as F
import time
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC
from copy import deepcopy
import numpy as np

from benchmarking_utils.clustering_metrics import compute_cdnv, mapper


class NCCClassifierHandler:
    """Nearest-class-center (NCC) classifier evaluated layer by layer.

    Class means are fitted eagerly on construction from ``feature_bank`` /
    ``targets_bank`` for every layer in ``[from_layer, end_layer)``.
    ``predict_ncc`` then scores held-out features against those means and
    also records the values returned by ``compute_cdnv`` for each layer.

    Feature tensors are laid out column-wise: ``(feature_dim, num_samples)``
    (columns are selected by target mask and rows are transposed into samples
    below).
    """

    def __init__(self, feature_bank, targets_bank, from_layer=2, end_layer=5, num_classes=100):
        """Store the banks and immediately fit the per-layer class means.

        Args:
            feature_bank: indexable by layer index; each entry is a
                ``(feature_dim, num_samples)`` tensor.
            targets_bank: indexable by layer index; each entry is a 1-D tensor
                of integer class labels, one per feature column.
            from_layer: first layer index to process (inclusive).
            end_layer: last layer index to process (exclusive).
            num_classes: number of class means to fit per layer; labels are
                expected in ``range(num_classes)``.
        """
        self.feature_bank = feature_bank
        self.targets_bank = targets_bank
        self.from_layer = from_layer
        self.end_layer = end_layer
        self.lin_classifiers = {}
        # classifier name -> {layer_idx: (num_classes, feature_dim) tensor of class means}
        self.class_means_solutions = {}
        self.num_classes = num_classes

        self.compute_ncc_classifiers()

    def compute_ncc_classifiers(self):
        """Fit the "NCC_orig" classifier: one mean feature vector per class per layer.

        Results go to ``self.class_means_solutions["NCC_orig"][layer_idx]`` as a
        ``(num_classes, feature_dim)`` tensor. A class with no samples in a layer
        gets a NaN mean (torch's mean over an empty selection).

        Raises:
            ValueError: if a layer's targets do not provide one label per
                feature column.
        """
        print(f"beginning compute classifiers")
        classifier_name = "NCC_orig"
        self.class_means_solutions[classifier_name] = {}
        for layer_idx in tqdm(range(self.from_layer, self.end_layer)):
            # Timer is per layer so the message below reports this layer only.
            start = time.time()
            cur_features, cur_targets = self.feature_bank[layer_idx], self.targets_bank[layer_idx]
            print(f"computing ncc for layer: {layer_idx}, "
                  f"cur_features: {cur_features.T.shape}, cur_targets: {cur_targets.shape}")

            # The boolean mask below selects feature *columns*, so there must be
            # exactly one target per column.
            if cur_targets.shape[0] != cur_features.shape[1]:
                raise ValueError(
                    f"layer {layer_idx}: expected {cur_features.shape[1]} targets "
                    f"(one per feature column of {tuple(cur_features.shape)}), "
                    f"got targets of shape {tuple(cur_targets.shape)}"
                )

            class_means_bank = torch.zeros((self.num_classes, cur_features.shape[0]))
            print("beginning class loop")
            for c in range(self.num_classes):
                class_means_bank[c] = cur_features[:, cur_targets == c].mean(dim=1)

            self.class_means_solutions[classifier_name][layer_idx] = class_means_bank
            print(f"time for computing ncc {layer_idx}: {time.time() - start}")

    @staticmethod
    def _nearest_class_mean(sample_features, class_means):
        """Return the index of the closest (Euclidean) class mean for each sample.

        Args:
            sample_features: ``(num_samples, feature_dim)`` tensor, one row per sample.
            class_means: ``(num_classes, feature_dim)`` tensor of class centers.

        Returns:
            A 1-D ``(num_samples,)`` tensor of class indices (1-D even for a
            single sample).
        """
        class_means = class_means.to(sample_features.device)
        # One vectorized pass per class instead of a Python loop over every
        # sample; each pass is the same subtract-then-norm as a per-sample loop.
        distances = torch.stack(
            [torch.norm(sample_features - mean, dim=1) for mean in class_means], dim=1
        )
        # A class with no training samples has a NaN mean. Mask its distances
        # to +inf so a NaN can never be chosen by argmin.
        distances = distances.masked_fill(torch.isnan(distances), float("inf"))
        return torch.argmin(distances, dim=1)

    def predict_ncc(self, features, targets):
        """Score every fitted NCC classifier on held-out features, layer by layer.

        Args:
            features: indexable by layer index; each entry is a
                ``(feature_dim, num_samples)`` tensor.
            targets: indexable by layer index; each entry is a 1-D label tensor.
                For the "NCC_super" classifier the labels are remapped through
                ``mapper.coarse_labels`` before scoring.

        Returns:
            dict mapping each classifier key to ``{layer_idx: accuracy}``, plus
            "CDNV_orig", "DIST_orig" and "VARIANCE_orig" mapping to
            ``{layer_idx: value}`` from ``compute_cdnv``.
        """
        print(f"starting sklearn predict")
        result = {}
        for classifier_key, cur_classifier in self.class_means_solutions.items():
            result[classifier_key] = {}
            # Created once (not reset per classifier) so earlier entries survive.
            for metric_key in ("CDNV_orig", "DIST_orig", "VARIANCE_orig"):
                result.setdefault(metric_key, {})
            for layer_idx in tqdm(range(self.from_layer, self.end_layer)):
                start = time.time()
                cur_features, cur_targets = features[layer_idx], targets[layer_idx]
                layer_means = cur_classifier[layer_idx]
                print(f"cur_features:{cur_features.shape}, cur_classifier:{layer_means.shape}")
                print(f"classifier_key:{classifier_key}")

                # Features are column-wise; transpose so each row is one sample.
                prediction = self._nearest_class_mean(cur_features.T, layer_means)

                print(f"prediction:{prediction.shape}, cur_targets:{cur_targets.shape}")
                print(f"prediction:{type(prediction)}, cur_targets:{type(cur_targets)}")

                if classifier_key == "NCC_super":
                    cur_targets = mapper.coarse_labels[cur_targets.cpu()]
                else:
                    cur_targets = cur_targets.cpu().numpy()

                # No squeeze(): a single-sample prediction must stay 1-D.
                accuracy = np.mean(prediction.cpu().numpy() == cur_targets)
                print(f"{classifier_key}, accuracy for layer:{layer_idx}: {accuracy}")
                print(f"time for layer:{layer_idx}:{time.time() - start}")
                result[classifier_key][layer_idx] = accuracy

                avg_cdnv, avg_dist, avg_variance = compute_cdnv(cur_features, self.num_classes, cur_targets)

                result["CDNV_orig"][layer_idx] = avg_cdnv
                result["DIST_orig"][layer_idx] = avg_dist
                result["VARIANCE_orig"][layer_idx] = avg_variance

        return result
