import orjson
import os, re
import pandas as pd
import numpy as np

RESULT_DIR = "/fs/cml-projects/E2H/HMMT/item_difficulty"

# Result files (stem without extension) whose rows may carry one extra
# token; get_df drops the *leading* token of such over-long rows instead of
# the trailing one.
black_list = [
    "HMMT_Feb_2010_results_Guts",
    "HMMT_Nov_2009_results_Guts",
    "HMMT_Nov_2010_results_Guts",
]

# Result files whose rows lack the "|...|" score delimiters and need the
# first 5 numeric tokens skipped (instead of the default 2) in get_df.
black_list0 = [
    "HMMT_Feb_2011_results_AlgebraCalculus",
    "HMMT_Feb_2011_results_AlgebraCombinatorics",
    "HMMT_Feb_2011_results_AlgebraGeometry",
    "HMMT_Feb_2011_results_CalculusCombinatorics",
    "HMMT_Feb_2011_results_GeometryCalculus",
    "HMMT_Feb_2011_results_GeometryCombinatorics",
]

# Subtest name (4th "_"-separated field of a result name) -> short tag used
# in the emitted "result_name" values.
abbr_dict = {
    "Algebra": "alg",
    "Combinatorics": "comb",
    "Calculus": "calc",
    "Geometry": "geo",
    "General": "gen",
    "AlgebraCalculus": "algcalc",
    "AlgebraCombinatorics": "algcomb",
    "AlgebraGeometry": "alggeo",
    "CalculusCombinatorics": "calccomb",
    "GeometryCalculus": "calcgeo",
    "GeometryCombinatorics": "combgeo",
    "Guts": "guts",
    "Team": "team",
    "Team A": "team1",
    "Team B": "team2",
    "Theme": "thm",
}


def is_dash_line(line):
    """Return True if *line* is a single character repeated throughout.

    Used to detect table separator rows such as ``"-----"`` or ``"====="``.
    Note that a whitespace-only line (e.g. ``"   "``) also qualifies.

    Raises:
        ValueError: if *line* is empty, since there is no character to
            compare against. (An explicit raise instead of ``assert`` so the
            check survives ``python -O``.)
    """
    if not line:
        raise ValueError("is_dash_line() requires a non-empty line")
    return line.count(line[0]) == len(line)


_NUMBER_RE = re.compile(r"\d+\.?\d*")


def _parse_score_line(line, num_skip):
    """Extract the score tokens (as strings) from one result-table row.

    Every "-" is read as a 0. If the row contains exactly two "|" characters,
    only the text between them is parsed. Otherwise every number on the row
    is parsed and the first ``num_skip`` tokens are dropped (presumably
    leading non-score columns such as rank/ID -- confirm against raw files).
    """
    new_line = line.replace("-", "0")
    parts = new_line.split("|")
    if len(parts) == 3:  # exactly two "|" delimiters
        return _NUMBER_RE.findall(parts[1])
    return _NUMBER_RE.findall(new_line)[num_skip:]


def _read_score_rows(path, num_skip):
    """Return parsed score rows of one file, stopping at the first separator line."""
    with open(path, "r") as f:
        file_lines = f.read().splitlines()
    rows = []
    for line in file_lines:
        if not line:
            continue
        if is_dash_line(line):
            # NOTE: whitespace-only lines also count as separators here.
            break
        rows.append(_parse_score_line(line, num_skip))
    return rows


def _equalize_row_lengths(rows, drop_leading):
    """Trim one token from every row longer than the shortest row.

    The token is removed from the front when *drop_leading* is true, else
    from the end. NOTE(review): rows longer by two or more tokens stay
    longer than the rest; get_df then keeps only as many columns as the first
    row has.
    """
    num_q = min(len(row) for row in rows)
    return [
        row if len(row) == num_q else (row[1:] if drop_leading else row[:-1])
        for row in rows
    ]


def get_df():
    """Load every results file under ``RESULT_DIR/subtest`` into a DataFrame.

    Returns:
        dict mapping the file stem with "_results" removed (e.g.
        ``"HMMT_Feb_2011_AlgebraCalculus"``) to a DataFrame with one row per
        parsed line and columns ``q_0 .. q_{k-1}``, where k is the first row's
        length. Cells hold the raw numeric strings. Keys are sorted.

    Raises:
        ValueError: if a file yields no score rows before its first
            separator line.
    """
    subtest_dir = f"{RESULT_DIR}/subtest"
    df_dict = {}

    for filename in os.listdir(subtest_dir):
        file_true_name, _ = os.path.splitext(filename)
        num_skip = 5 if file_true_name in black_list0 else 2
        path = f"{subtest_dir}/{filename}"

        rows = _read_score_rows(path, num_skip)
        if not rows:
            raise ValueError(f"no score rows found in {path}")
        rows = _equalize_row_lengths(rows, drop_leading=file_true_name in black_list)

        num_cols = len(rows[0])
        df = pd.DataFrame({f"q_{n}": [row[n] for row in rows] for n in range(num_cols)})
        df_dict[file_true_name.replace("_results", "")] = df

    return dict(sorted(df_dict.items()))


_ZERO_TOL = 0.0001  # |score| below this is treated as a zero score


def _problem_record(result_name, score_list):
    """Summarize one problem's column of scores into a JSON-ready dict.

    Keys, in order: ``result_name``, ``r_num_tester`` (number of scores),
    ``r_num_zero_score``, ``r_num_value`` (distinct nonzero scores) and
    ``correct_percentage``.
    """
    scores = [float(s) for s in score_list]
    num_tester = len(scores)
    # Zero and nonzero are complementary, so every score lands in exactly one
    # group. The old `< tol` / `> tol` pair skipped scores exactly at tol.
    nonzero = [s for s in scores if abs(s) >= _ZERO_TOL]
    num_zero = num_tester - len(nonzero)
    num_value = len(set(nonzero))

    if num_zero == num_tester:
        correct_percentage = 0.0
    elif num_value == 1:
        # A single nonzero value: treat the problem as right/wrong.
        correct_percentage = 100 * (1.0 - num_zero / num_tester)
    else:
        # NOTE(review): this is the mean raw score, which is a percentage only
        # if scores are on a 0-100 scale -- confirm the intended normalization.
        correct_percentage = np.mean(scores)

    return {
        "result_name": result_name,
        "r_num_tester": num_tester,
        "r_num_zero_score": num_zero,
        "r_num_value": num_value,
        "correct_percentage": correct_percentage,
    }


def get_json(df_dict):
    """Write one JSON line per problem to ``RESULT_DIR/HMMT_results.jsonl``.

    Args:
        df_dict: mapping of names like ``"HMMT_<month>_<year>_<subtest>"`` to
            DataFrames whose columns are named ``q_<n>`` (0-based) and whose
            cells are numeric strings or numbers, as produced by get_df.

    Each record's ``result_name`` is ``"<month>_<year>_<abbr>_<n+1>"``, where
    ``<abbr>`` comes from ``abbr_dict``.
    """
    with open(f"{RESULT_DIR}/HMMT_results.jsonl", "w") as wf:
        for df_name, df in df_dict.items():
            df_keywords = df_name.split("_")
            month = df_keywords[1]
            year = df_keywords[2]
            subtest = abbr_dict[df_keywords[3]]
            for column, score_list in df.items():
                problem_num = int(column.replace("q_", "")) + 1
                record = _problem_record(f"{month}_{year}_{subtest}_{problem_num}", score_list)
                json_line = orjson.dumps(record, option=orjson.OPT_NAIVE_UTC | orjson.OPT_SERIALIZE_NUMPY)
                wf.write(json_line.decode("utf-8") + "\n")


def main():
    """Parse all subtest result tables and emit the per-problem JSONL summary."""
    get_json(get_df())


if __name__ == "__main__":
    main()