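"""
Statistics of the golden tests.

Parses the golden-test evaluation logs found under an output directory and,
for each statistic (test counts and status-transition counts), prints its mean
and maximum over all successfully parsed instances as ``&``-separated rows.
"""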

import pathlib

import fire

from metrics.log_parsers import MAP_REPO_TO_PARSER
from metrics.getters import get_file_name_from_lp, get_repo_from_lp, log_path_to_sms, FAIL_TO_NULL, PASS_TO_NULL, \
    NULL_TO_FAIL, NULL_TO_PASS

from measure_coverage_patch import log, extract_coverages_from_eval_output, load_eval_outputs, BLACKLIST, BLACKLIST_FULL
from metrics.getters import FAIL_TO_FAIL, FAIL_TO_PASS, PASS_TO_FAIL, PASS_TO_PASS, test_failed, test_passed, TestStatus

def convert_log_to_ftp(
    sm_before: dict, sm_after: dict
) -> dict:
    """Bucket tests by how their status changed from ``sm_before`` to ``sm_after``.

    Both arguments map a test identifier to a status value that is compared
    against ``TestStatus`` values.

    * A test present in ``sm_after`` with status PASSED or FAILED goes under
      ``PASS_TO_*`` / ``FAIL_TO_*`` according to ``test_passed`` /
      ``test_failed`` on ``sm_before``. It goes under ``NULL_TO_*`` when neither
      helper matches. Any other status in ``sm_after`` is ignored.
    * A test present only in ``sm_before`` goes under ``PASS_TO_NULL`` or
      ``FAIL_TO_NULL``. It is ignored when neither helper matches.

    NOTE(review): whether a test missing from ``sm_before`` can ever reach a
    ``NULL_TO_*`` bucket depends on how ``test_failed`` treats absent tests.
    Confirm this in ``metrics.getters``.

    Returns:
        A dict containing all eight transition keys. Each key maps to a
        (possibly empty) list of test identifiers.
    """
    transitions = {
        key: []
        for key in (
            FAIL_TO_PASS,
            FAIL_TO_FAIL,
            PASS_TO_PASS,
            PASS_TO_FAIL,
            FAIL_TO_NULL,
            PASS_TO_NULL,
            NULL_TO_FAIL,
            NULL_TO_PASS,
        )
    }

    for name, outcome in sm_after.items():
        # Pick the three candidate buckets for this "after" status, then
        # choose among them based on the "before" status.
        if outcome == TestStatus.PASSED.value:
            was_pass, was_fail, was_absent = PASS_TO_PASS, FAIL_TO_PASS, NULL_TO_PASS
        elif outcome == TestStatus.FAILED.value:
            was_pass, was_fail, was_absent = PASS_TO_FAIL, FAIL_TO_FAIL, NULL_TO_FAIL
        else:
            continue

        if test_passed(name, sm_before):
            bucket = was_pass
        elif test_failed(name, sm_before):
            bucket = was_fail
        else:
            bucket = was_absent
        transitions[bucket].append(name)

    # Tests that disappeared: present before, missing after.
    vanished = set(sm_before.keys()) - set(sm_after.keys())
    for name in vanished:
        if test_passed(name, sm_before):
            transitions[PASS_TO_NULL].append(name)
        elif test_failed(name, sm_before):
            transitions[FAIL_TO_NULL].append(name)

    return transitions

def count_tests(sm):
    """Return the number of test entries in the status map ``sm``."""
    n_entries = len(sm)
    return n_entries


def _instance_id_from_log_name(inst_file_name: str) -> str:
    """Strip the ``.golden.eval.log`` suffix from a log file name.

    A name that does not contain the suffix is returned unchanged. Slicing with
    the raw ``rfind`` result (-1) would instead silently drop the last
    character.
    """
    cut = inst_file_name.rfind(".golden.eval.log")
    return inst_file_name if cut == -1 else inst_file_name[:cut]


def _transition_stats(sms) -> dict:
    """Compute the per-instance counts from one log's four parsed status maps.

    NOTE(review): the roles of ``sms[0]``..``sms[3]`` come from the comments
    below, not from ``log_path_to_sms`` itself. Confirm them there.
    """
    # total tests in original test suite
    total_cases = count_tests(sms[0])

    # difference between adding nothing and adding the golden patch + golden test
    diff_between_addition = convert_log_to_ftp(sms[0], sms[2])

    # ftp, ftf, ptp, ptf of original test suite
    diff_orig = convert_log_to_ftp(sms[0], sms[3])

    # ftp, ftf, ptp, ptf of added test suite
    diff_added = convert_log_to_ftp(sms[1], sms[2])

    def only_added(key):
        # Tests with this transition in the added-suite diff but not in the
        # original-suite diff.
        return len(set(diff_added[key]) - set(diff_orig[key]))

    return {
        "ftp_orig": len(diff_orig[FAIL_TO_PASS]),
        "ftf_orig": len(diff_orig[FAIL_TO_FAIL]),
        "ptp_orig": len(diff_orig[PASS_TO_PASS]),
        "ptf_orig": len(diff_orig[PASS_TO_FAIL]),
        "ftp_added": only_added(FAIL_TO_PASS),
        "ftf_added": only_added(FAIL_TO_FAIL),
        "ptp_added": only_added(PASS_TO_PASS),
        "ptf_added": only_added(PASS_TO_FAIL),
        "total_cases": total_cases,
        "added_cases": len(diff_between_addition[NULL_TO_FAIL])
        + len(diff_between_addition[NULL_TO_PASS]),
        "removed_cases": len(diff_between_addition[PASS_TO_NULL])
        + len(diff_between_addition[FAIL_TO_NULL]),
    }


def main(
    eval_output_dir: str = "evaluation_output/swt_golden_test/mode_vanillafuzzy",
):
    """Collect per-instance golden-test statistics from evaluation logs.

    Recursively scans ``eval_output_dir`` for ``*.log`` files in sorted path
    order and skips instances listed in ``BLACKLIST_FULL``. Each remaining log
    is parsed into status maps with its repo-specific parser.

    Args:
        eval_output_dir: Directory searched recursively for evaluation logs.

    Returns:
        A list with one dict per non-blacklisted log. A log that does not parse
        into exactly four status maps yields ``inst_file_name`` plus a
        ``message``. Every other log yields ``inst_file_name`` followed by the
        count fields (``ftp_orig`` ... ``removed_cases``).

    A failed lookup in ``MAP_REPO_TO_PARSER`` (e.g. an unknown repo)
    propagates to the caller.
    """
    # Named ``records`` so it does not shadow the imported ``log``.
    records = []

    for path in sorted(pathlib.Path(eval_output_dir).rglob("*.log")):
        log_path = str(path)
        inst_file_name = get_file_name_from_lp(log_path)
        if _instance_id_from_log_name(inst_file_name) in BLACKLIST_FULL:
            continue

        log_parser = MAP_REPO_TO_PARSER[get_repo_from_lp(log_path)]
        sms, _ = log_path_to_sms(log_path, log_parser)
        if sms is None or len(sms) != 4:
            records.append({
                "inst_file_name": inst_file_name,
                "message": "Log file could not be parsed properly (Before, After Logs not found)"
            })
            continue

        records.append({"inst_file_name": inst_file_name, **_transition_stats(sms)})
    return records

def summary(
    eval_output_dir: str = "evaluation_output/swt_golden_test/mode_vanillafuzzy",
):
    """Print the mean and maximum of each golden-test statistic.

    Runs ``main`` on ``eval_output_dir`` and drops records that carry a
    ``"message"`` (logs that could not be parsed). It then prints one
    ``<title> & <mean> & <max>`` line per statistic.

    If no record survives the filter, it prints a notice and returns. Without
    this guard the mean would divide by zero and ``max`` would fail on an
    empty list.

    Args:
        eval_output_dir: Directory searched recursively for evaluation logs.
    """
    # Named ``records`` so it does not shadow the imported ``log``.
    records = [r for r in main(eval_output_dir) if r.get("message") is None]
    if not records:
        print("No successfully parsed golden-test logs found in", eval_output_dir)
        return

    for title in [
        "total_cases",
        "added_cases",
        "removed_cases",
        "ftp_orig",
        "ftf_orig",
        "ptp_orig",
        "ptf_orig",
        "ftp_added",
        "ftf_added",
        "ptp_added",
        "ptf_added",
    ]:
        values = [r[title] for r in records]
        print(title, "&", sum(values) / len(values), "&", max(values))


if __name__ == "__main__":
    # Expose ``summary`` (and its ``eval_output_dir`` argument) as the CLI.
    fire.Fire(component=summary)
