"""
Measure how much cross-checking with SWT-Bench improves the precision of combined code-patch generation.
"""
import json

import fire

from datasets import load_from_disk

from measure_coverage_patch import log, extract_coverages_from_eval_output, extract_changed_lines_from_patch, \
    extract_good_case_from_eval_output, no_lines_covered, extract_number_added_tests_from_patch, load_eval_outputs, \
    BLACKLIST


def main(
    eval_output_dir: str = "evaluation_output_swe_agent_patches/swe-agent-demo3__swt_bench_lite__test/mode_vanilla",
    swe_bench_results: str = "results/experiments-swe-bench/20240402_sweagent_gpt4/results/results.json",
    dataset: str = "datasets/swt_bench_lite_aug1_bm25_diff_27k_cl100k",
    split: str = "test",
    log: callable = log,
):
    """Compare SWT-Bench "good cases" against SWE-Bench resolved instances.

    For every instance in ``dataset[split]`` that has evaluation output, the
    instance counts as a good case when either of the first two flags returned
    by ``extract_good_case_from_eval_output`` (``ftp`` / ``etp``) is truthy.
    Instances in ``BLACKLIST`` are excluded on both sides. The precision and
    recall of the good cases relative to the SWE-Bench resolved set are printed.

    Args:
        eval_output_dir: Directory passed to ``load_eval_outputs``.
        swe_bench_results: Path to a SWE-Bench results JSON file whose
            ``"resolved"`` key lists resolved instance ids.
        dataset: Path passed to ``datasets.load_from_disk``.
        split: Dataset split to iterate over.
        log: Accepted for interface compatibility; currently unused.
    """
    # Close the results file deterministically instead of leaking the handle.
    with open(swe_bench_results) as results_file:
        resolved_all = json.load(results_file)["resolved"]
    swe_bench_ress = [res for res in resolved_all if res not in BLACKLIST]
    resolved_set = set(swe_bench_ress)
    good_cases = []

    # Use a separate local so the ``dataset`` path parameter is not rebound
    # to a different type.
    ds = load_from_disk(dataset)
    eval_output_by_instance = load_eval_outputs(eval_output_dir)

    for example in ds[split]:
        instance_id = example["instance_id"]
        if instance_id not in eval_output_by_instance:
            continue
        eval_outputs = eval_output_by_instance[instance_id]

        # Only instances with exactly four coverage entries are considered.
        # NOTE(review): presumably a complete evaluation run yields four
        # coverage results — confirm against extract_coverages_from_eval_output.
        coverage = extract_coverages_from_eval_output(eval_outputs)
        if len(coverage) != 4:
            continue

        ftp, etp, _, _, _ = extract_good_case_from_eval_output(eval_outputs)
        if not (ftp or etp):
            continue
        if instance_id in BLACKLIST:
            print("Good case in blacklist:", instance_id)
            continue
        good_cases.append(instance_id)

    overlap = len(set(good_cases) & resolved_set)
    # Precision/recall are undefined with an empty denominator; report 0.0
    # rather than crashing with ZeroDivisionError.
    precision = overlap / len(good_cases) if good_cases else 0.0
    recall = overlap / len(swe_bench_ress) if swe_bench_ress else 0.0

    print("FTP of generated test cases on generated patches:", len(good_cases))
    print("SWE Bench resolvs:", len(swe_bench_ress))
    print("Overlap:", overlap)
    print("Precision:", precision)
    print("Recall:", recall)



        


if __name__ == "__main__":
    # Expose ``main``'s keyword parameters as command-line flags via python-fire.
    fire.Fire(component=main)
