import json
import math
from pathlib import Path

import numpy as np
import spacy
from sklearn.cluster import KMeans
from tqdm import tqdm

# Optional GPU setup, currently disabled. Uncomment the two lines below to
# ask spaCy to run on a GPU (NOTE(review): confirm that `spacy.util.use_gpu`
# is still the right call for the installed spaCy version before enabling).
# spacy.require_gpu()
# print("GPU is being used:", spacy.util.use_gpu())

# Load the spaCy pipeline once, at import time; the script below calls it on
# every text and uses the resulting document vector as that text's embedding.
nlp = spacy.load('en_core_web_lg')

# Domain name -> JSON file holding that domain's texts.
PROTOCOL_FILES = {
    "Ecology & Environmental Biology": "data/protocol_list/Ecology.json",
    "Molecular Biology & Genetics": "data/protocol_list/Genetics.json",
    "Biomedical & Clinical Research": "data/protocol_list/Medical.json",
    "Bioengineering & Technology": "data/protocol_list/BioEng.json",
}

# Domain name -> number of KMeans clusters, i.e. the number of texts to keep.
SAMPLE_SIZES = {
    "Ecology & Environmental Biology": 200,
    "Molecular Biology & Genetics": 650,
    "Biomedical & Clinical Research": 600,
    "Bioengineering & Technology": 420,
}

# Texts handed to spaCy per batch; kept modest so long texts do not blow up
# memory while still avoiding one pipeline call per text.
EMBED_BATCH_SIZE = 64


def embed_texts(texts):
    """Return the spaCy document vector of every text as a 2-D array.

    Args:
        texts: Sized iterable of strings.

    Returns:
        ``numpy.ndarray`` with one row per text, in input order.
    """
    print("embedding")
    # nlp.pipe processes the texts in batches instead of one call per text.
    docs = tqdm(nlp.pipe(texts, batch_size=EMBED_BATCH_SIZE), total=len(texts))
    return np.asarray([doc.vector for doc in docs])


def select_shortest_per_cluster(texts, labels, n_clusters):
    """Pick the shortest text of each cluster.

    Args:
        texts: Iterable of strings.
        labels: Cluster index for each text, aligned with ``texts``; each
            value must lie in ``range(n_clusters)``.
        n_clusters: Total number of cluster indices.

    Returns:
        List of the shortest member of every non-empty cluster, ordered by
        cluster index. Ties go to the text that appears first in ``texts``.
        Clusters that received no text are skipped, so the result may hold
        fewer than ``n_clusters`` items.
    """
    slots = [[] for _ in range(n_clusters)]
    for text, label in zip(texts, labels):
        # int() so NumPy integer labels index the list like plain ints.
        slots[int(label)].append(text)
    # An empty slot would make min() raise ValueError, hence the filter.
    return [min(members, key=len) for members in slots if members]


def main():
    """Cluster every protocol file and write one sample file per domain.

    For each entry of ``PROTOCOL_FILES`` the texts are embedded, clustered
    with KMeans, and the shortest text of every cluster is written to
    ``<stem>_sample.json`` next to the input file.
    """
    # Square-root curve through the points (53, 53) and (812, 200). It is
    # only printed next to each dataset size as a reference value; the
    # cluster counts actually used come from SAMPLE_SIZES.
    a = 147 / (math.sqrt(812) - math.sqrt(53))
    b = 53 - (a * math.sqrt(53))
    print(a, b)

    for domain, path in PROTOCOL_FILES.items():
        source = Path(path)
        with source.open("r", encoding="utf-8") as f:
            data = json.load(f)
        sub = source.stem
        print(sub, len(data), a * math.sqrt(len(data)) + b)

        # KMeans refuses n_clusters > n_samples, so cap the requested count.
        n_clusters = min(SAMPLE_SIZES[domain], len(data))
        if n_clusters:
            embeddings = embed_texts(data)
            kmeans = KMeans(n_clusters=n_clusters, random_state=0)
            kmeans.fit(embeddings)
            results = select_shortest_per_cluster(
                data, kmeans.labels_, n_clusters
            )
        else:
            # Nothing to cluster: still emit a (valid, empty) sample file.
            print("no texts in " + path + "; writing an empty sample")
            results = []

        print("complete " + sub)

        target = source.with_name(sub + "_sample.json")
        with target.open("w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)


if __name__ == "__main__":
    main()