"""
An implementation of a LIBRO like operation
THIS IS NOT FINISHED YET, EXPECT WEIRD BEHAVIOR

Takes several samples for unittest generation,
looks at their evaluation trace and picks the one that most closely
resembles the issue description (picked by LLM or so)
"""
import bisect
import re
from collections import defaultdict
from typing import List, Optional, Tuple
import json
import pathlib

import fire
from cachier import cachier
from unidiff import PatchSet

from datasets import load_from_disk

from measure_coverage_patch import extract_coverages_from_eval_output, extract_changed_lines_from_patch, \
    no_lines_covered, extract_good_case_from_eval_output, log, extract_patch_from_eval_output, \
    extract_number_added_tests_from_patch, no_missed_lines_covered, no_missed_lines, load_eval_outputs, coverage_diff, \
    compute_overlap, cached_extract_coverages_from_eval_output, coverage_of_patchset, coverage_union, load_blacklisted, \
    BLACKLIST, save_div

def _parse_seeds(seeds):
    """Normalize the ``seeds`` argument into a list of ints.

    The documented form is a comma separated string such as ``"1,2,3"``.
    python-fire, however, may already have converted a command-line value like
    ``1,2,3`` into a tuple of ints (or ``1`` into a plain int) before ``main``
    is called, so those forms are accepted as well.
    """
    if isinstance(seeds, str):
        return [int(s) for s in seeds.split(",") if s.strip()]
    if isinstance(seeds, int):
        return [seeds]
    return [int(s) for s in seeds]


def _select_libro_case(instance_id, seeds, ress, libro_results_by_instance):
    """Pick the per-seed result favoured by the LIBRO inference output.

    Seeds whose inference record contains "yes" in ``full_output`` form the
    preferred cluster; the remaining seeds with a record form the fallback
    cluster. From the first non-empty cluster the seed with the shortest patch
    is chosen (ties go to the earliest seed).
    """
    preferred, fallback = [], []
    for seed in seeds:
        # Inference records are keyed "<instance_id>_seed=<seed - 1>".
        res = libro_results_by_instance.get(instance_id + "_seed=" + str(seed - 1))
        if res is None:
            continue
        if "yes" in res["full_output"].lower():
            preferred.append(seed)
        else:
            fallback.append(seed)

    candidates = preferred or fallback
    if not candidates:
        # No inference record for any seed: fall back to the first seed's entry.
        # NOTE(review): this entry may be a failure record (one with a
        # "message" key) -- confirm whether the first *usable* result was meant.
        return next(iter(ress.values()))

    def patch_len(seed):
        # Seeds without a recorded patch length sort last.
        length = ress[seed]["patch_len"]
        return float("inf") if length is None else length

    return ress[min(candidates, key=patch_len)]


def _select_ideal_case(ress):
    """Pick the best usable per-seed result (oracle / IDEAL setting).

    Preference order, most important first: ``good_case`` set, then
    ``fails_initially`` set, then ``error_initially`` not set. The criteria are
    compared lexicographically so that a lower-priority criterion can never
    displace a result that is better on a higher-priority one; ties keep the
    earliest seed. The caller guarantees at least one usable result.
    """
    usable = [res for res in ress.values() if "message" not in res]
    return max(
        usable,
        key=lambda res: (
            bool(res["good_case"]),
            bool(res["fails_initially"]),
            not res["error_initially"],
        ),
    )


@cachier()
def main(
    eval_output_dir: str = "evaluation_output/gpt-4-1106-preview__swt_bench_lite_aug1_bm25_27k_cl100k__seed=1,temperature=07__test/mode_custom",
    libro_inference_results: str = "inference_output/gpt-4-1106-preview__libro_gpt-4-1106-preview__swt_bench_lite_aug1__test__test.jsonl",
    golden_eval_output_dir: str = "evaluation_output/swt_lite_golden_test/mode_vanillafuzzy",
    dataset: str = "datasets/swt_bench_lite_aug1_bm25_diff_27k_cl100k",
    split: str = "test",
    seeds: str = "1,2,3,4,5",  # should be a comma separated list
    setting: str = "LIBRO",  # LIBRO or IDEAL
):
    """Select one generated test patch per instance out of several seeded samples.

    Args:
        eval_output_dir: Evaluation output directory of one seed. The
            directories of the other seeds are derived from it by rewriting
            the ``seed=<n>`` and ``00<n>__`` parts of the path.
        libro_inference_results: JSONL file with one inference record per line
            (fields used: ``instance_id`` and ``full_output``). Only read when
            ``setting`` is ``"LIBRO"``.
        golden_eval_output_dir: Evaluation output directory of the golden tests.
        dataset: Path passed to ``datasets.load_from_disk``.
        split: Dataset split to iterate over.
        seeds: Seeds to consider, as a comma separated string (a single int or
            an iterable of ints is accepted too).
        setting: ``"LIBRO"`` selects via the inference results; any other value
            selects the best usable result directly (IDEAL).

    Returns:
        A list with one dict per non-blacklisted instance: either the selected
        per-seed result, or a ``{"instance_id", "message"}`` record explaining
        why the instance was skipped.
    """
    # Named ``results`` so the ``log`` helper imported at module level is not shadowed.
    results = []
    seeds = _parse_seeds(seeds)
    dataset = load_from_disk(dataset)

    eval_output_by_instance = {}
    for seed in seeds:
        seeded_dir = re.sub(r"00\d+__", f"00{seed}__", re.sub(r"seed=\d+", f"seed={seed}", eval_output_dir))
        eval_output_by_instance[seed] = load_eval_outputs(seeded_dir)

    golden_eval_output_by_instance = load_eval_outputs(golden_eval_output_dir)

    libro_inference_results_by_instance = {}
    if setting == "LIBRO":
        with open(libro_inference_results) as f:
            for line in f:
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing newline) in the JSONL file.
                    continue
                record = json.loads(line)
                libro_inference_results_by_instance[record["instance_id"]] = record

    for example in dataset[split]:
        instance_id = example["instance_id"]
        if instance_id in BLACKLIST:
            continue
        eval_outputss = [eval_output_by_instance[seed].get(instance_id) for seed in seeds]
        if not any(eval_outputss):
            results.append({
                "instance_id": instance_id,
                "message": "no eval output found",
            })
            continue
        golden_coverage = cached_extract_coverages_from_eval_output(golden_eval_output_by_instance.get(instance_id))
        if golden_coverage is None or len(golden_coverage) < 4:
            results.append({
                "instance_id": instance_id,
                "message": "no golden eval output found",
            })
            continue
        golden_patch = PatchSet(example["test_patch"])
        removed_lines, added_lines = extract_changed_lines_from_patch(golden_patch)

        ress = {}
        for seed, eval_outputs in zip(seeds, eval_outputss):
            if eval_outputs is None:
                # Nothing was evaluated for this seed; never hand None to the extractors.
                coverage, patch = [], None
            else:
                coverage = extract_coverages_from_eval_output(eval_outputs)
                patch = extract_patch_from_eval_output(eval_outputs)
            if coverage is None or len(coverage) < 4:
                ress[seed] = {
                    "instance_id": instance_id,
                    "message": "coverage not found for all 3 steps",
                    "patch_len": len(patch) if patch is not None else None,
                }
                continue
            coverage_original, coverage_after_pred, coverage_after_patch, coverage_original_after_patch = coverage
            golden_coverage_original, golden_coverage_after_pred, golden_coverage_after_patch, golden_coverage_original_after_patch = golden_coverage

            # Lines of the golden patch additionally covered by the generated test,
            # before (removed lines) and after (added lines) applying the patch.
            additional_lines_covered_pre_patch = coverage_diff(coverage_of_patchset(coverage_original, removed_lines), coverage_of_patchset(coverage_after_pred, removed_lines))
            additional_lines_covered_post_patch = coverage_diff(coverage_of_patchset(coverage_original_after_patch, added_lines), coverage_of_patchset(coverage_after_patch, added_lines))
            # Reference: lines of the golden patch covered in the golden evaluation.
            golden_lines_covered_pre_patch = coverage_union(coverage_of_patchset(golden_coverage_original, removed_lines), coverage_of_patchset(golden_coverage_after_pred, removed_lines))
            golden_lines_covered_post_patch = coverage_union(coverage_of_patchset(golden_coverage_original_after_patch, added_lines), coverage_of_patchset(golden_coverage_after_patch, added_lines))
            inter_pre_patch, covered_pre_patch, golden_pre_patch = compute_overlap(additional_lines_covered_pre_patch, golden_lines_covered_pre_patch)
            inter_post_patch, covered_post_patch, golden_post_patch = compute_overlap(additional_lines_covered_post_patch, golden_lines_covered_post_patch)
            patch_executable = golden_pre_patch + golden_post_patch > 0
            recall = save_div(inter_pre_patch + inter_post_patch, golden_pre_patch + golden_post_patch, 1)
            precision = save_div(inter_pre_patch + inter_post_patch, covered_pre_patch + covered_post_patch, 1)

            unittest_patch = PatchSet(patch)
            no_added_tests = extract_number_added_tests_from_patch(unittest_patch)

            ftp, etp, fails_initially, error_initially, compilation_error = extract_good_case_from_eval_output(eval_outputs)

            ress[seed] = {
                "instance_id": instance_id,
                "good_case": int(ftp or etp),
                "recall": recall,
                "precision": precision,
                "patch_executable": patch_executable,
                "ftp": int(ftp),
                "etp": int(etp),
                "fails_initially": int(fails_initially),
                "error_initially": int(error_initially),
                "compilation_error": int(compilation_error),
                "patch_len": len(patch),
                "no_added_tests": no_added_tests,
            }

        if all("message" in x for x in ress.values()):
            results.append({
                "instance_id": instance_id,
                "message": "no eval output found for all seeds",
            })
            continue

        if setting == "LIBRO":
            # LIBRO: the initial-failure / no-error filtering has already happened
            # upstream; cluster by whether the failure relates to the issue and
            # take the shortest patch of the better cluster.
            best_case = _select_libro_case(instance_id, seeds, ress, libro_inference_results_by_instance)
        else:
            best_case = _select_ideal_case(ress)
        results.append(best_case)
    return results



        


if __name__ == "__main__":
    # Script entry point: expose ``main`` as a command-line interface via python-fire.
    fire.Fire(component=main)
