import json
import math
import os

from tqdm import tqdm

import spacy
from sklearn.cluster import KMeans
import numpy as np


# Shared spaCy pipeline, loaded once at import time and reused by get_actions().
nlp = spacy.load('en_core_web_sm')

def get_actions(text):
    """Return the lower-cased lemmas of every verb found in ``text``.

    The text is run through the module-level spaCy pipeline ``nlp``; each
    token tagged with the part of speech ``"VERB"`` contributes its lemma.
    The result keeps document order and may contain duplicates.
    """
    verb_lemmas = []
    for tok in nlp(text):
        # Anything that is not tagged as a verb is irrelevant here.
        if tok.pos_ != "VERB":
            continue
        verb_lemmas.append(tok.lemma_.lower())
    return verb_lemmas

def list_direct_subdirectories(folder_path):
    """Return the paths of the directories located directly in ``folder_path``.

    Each returned path is ``folder_path`` joined with the entry name. Nested
    directories are not descended into, and the order is whatever
    ``os.listdir`` yields.
    """
    found = []
    for entry in os.listdir(folder_path):
        candidate = os.path.join(folder_path, entry)
        if os.path.isdir(candidate):
            found.append(candidate)
    return found

def list_direct_files(folder_path):
    """Return the paths of the regular files located directly in ``folder_path``.

    Each returned path is ``folder_path`` joined with the entry name.
    Directories are skipped, and the order is whatever ``os.listdir`` yields.
    """
    files = []
    for entry in os.listdir(folder_path):
        candidate = os.path.join(folder_path, entry)
        if os.path.isfile(candidate):
            files.append(candidate)
    return files

def domain_key(path):
    """Return the file name of ``path`` without directory or extension.

    Used to map a data file such as ``data/autodsl/Ecology.json`` to the
    short domain key (``"Ecology"``) used by the action dictionaries.
    """
    return os.path.splitext(os.path.basename(path))[0]


def overlap_ratio(set_a, set_b):
    """Return the fraction of ``set_a``'s distinct items also in ``set_b``.

    Both arguments may be any iterable of hashable items. An empty ``set_a``
    yields ``0.0`` instead of raising ``ZeroDivisionError``.
    """
    set_a = set(set_a)
    if not set_a:
        return 0.0
    return len(set_a.intersection(set_b)) / len(set_a)


def print_overlap_matrix(actions_by_domain):
    """Print the pairwise overlap ratios between the domains' action sets.

    For every domain (in sorted order) three lines are printed: the domain
    name, the sorted column names, and for each column domain the value
    ``overlap_ratio(row_actions, column_actions)``.
    """
    names = sorted(actions_by_domain)
    # Build each set once rather than once per matrix cell.
    action_sets = {name: set(actions_by_domain[name]) for name in names}
    for row in names:
        print(row)
        for col in names:
            print(col, end=", ")
        print()
        for col in names:
            print(overlap_ratio(action_sets[row], action_sets[col]), end=", ")
        print()


if __name__ == "__main__":
    # Distinct verb lemmas extracted from each domain's natural-language protocols.
    protocol_actions_list = {
        "XDL": set(),
        "Ecology": set(),
        "Genetics": set(),
        "Medical": set(),
        "BioEng": set()
    }

    # DSL action names per domain. XDL is hard-coded here; the other domains
    # are filled from the JSON files listed in ``dsl_paths``.
    dsl_actions_list = {
        "XDL": ['CleanVessel', 'Confirm', 'Wait', 'StopStir', 'PreparativeChromatography', 'StopHeatChill', 'Dry', 'EvacuateAndRefill', 'Recrystallize', 'WashSolid', 'StartPurge', 'SwitchArgon', 'Add', 'StartStir', 'Stir', 'SeparatePhases', 'HeatChillToTemp', 'CConnect', 'ResetHandling', 'Purge', 'Transfer', 'Filter', 'SwitchVacuum', 'Dissolve', 'HeatChill', 'Separate', 'AddSolid', 'Evaporate', 'FilterThrough', 'CSwitchArgon'],
        "Ecology": [],
        "Genetics": [],
        "Medical": [],
        "BioEng": []
    }

    # Natural-language protocol sources: a directory tree for XDL, one JSON
    # file per remaining domain.
    file_paths = {
        "XDL": "data/XDL_NL",
        "Ecology & Environmental Biology": "data/protocol_list/Ecology.json",
        "Molecular Biology & Genetics": "data/protocol_list/Genetics.json",
        "Biomedical & Clinical Research": "data/protocol_list/Medical.json",
        "Bioengineering & Technology": "data/protocol_list/BioEng.json"
    }

    # DSL definition files, one per non-XDL domain.
    dsl_paths = {
        "Ecology & Environmental Biology": "data/autodsl/Ecology.json",
        "Molecular Biology & Genetics": "data/autodsl/Genetics.json",
        "Biomedical & Clinical Research": "data/autodsl/Medical.json",
        "Bioengineering & Technology": "data/autodsl/BioEng.json"
    }

    # --- DSL action overlap -------------------------------------------------
    dsl_actions_list["XDL"] = [a.lower() for a in dsl_actions_list["XDL"]]
    for domain_title in dsl_paths:
        # Read the DSL file itself; the previous code looked the title up in
        # ``file_paths`` and therefore loaded the protocol texts instead.
        path = dsl_paths[domain_title]
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        sub = domain_key(path)
        # Assumes iterating the loaded JSON yields action-name strings
        # (dict keys or list items) -- TODO confirm against the file schema.
        dsl_actions_list[sub].extend(action.lower() for action in data)

    print_overlap_matrix(dsl_actions_list)

    # --- Protocol verb overlap ----------------------------------------------
    for domain_title in file_paths:
        path = file_paths[domain_title]

        if domain_title == "XDL":
            print("XDL")
            # XDL protocols live in a directory tree; every "procedure.txt"
            # found anywhere below the root is processed.
            for root, dirs, files in os.walk(path):
                if "procedure.txt" in files:
                    file_path = os.path.join(root, "procedure.txt")
                    with open(file_path, 'r', encoding='iso8859-1') as file:
                        content = file.read()
                    protocol_actions_list["XDL"].update(get_actions(content))
            print(protocol_actions_list["XDL"])
            continue

        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        sub = domain_key(path)

        # Each item of the loaded JSON is passed to get_actions() as text.
        for text in tqdm(data):
            protocol_actions_list[sub].update(get_actions(text))
        print(protocol_actions_list[sub])

    print_overlap_matrix(protocol_actions_list)

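# Sample output recorded from a run of this script. For each row domain A,
# the values are len(A & B) / len(A) against each column domain B.
# BioEng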
# BioEng, Ecology, Genetics, Medical, XDL, 
# 1.0, 0.2956567242281528, 0.798116169544741, 0.6308738880167452, 0.007953950811093667, 
# Ecology
# BioEng, Ecology, Genetics, Medical, XDL, 
# 0.6853469189713731, 1.0, 0.7814167879670063, 0.7794759825327511, 0.017709849587578846, 
# Genetics
# BioEng, Ecology, Genetics, Medical, XDL, 
# 0.38687094155844154, 0.16340300324675325, 1.0, 0.451856737012987, 0.003956980519480519, 
# Medical
# BioEng, Ecology, Genetics, Medical, XDL, 
# 0.44077215560105293, 0.23493711611582335, 0.6512869260017549, 1.0, 0.005630301257677684, 
# XDL
# BioEng, Ecology, Genetics, Medical, XDL, 
# 0.9382716049382716, 0.9012345679012346, 0.9629629629629629, 0.9506172839506173, 1.0, 