import json
from d5 import d5
from d5_problem import D5Problem as Problem
from validate_descriptions import get_validator_by_name
import os
from argparse import ArgumentParser


# Example usage (the lock name is derived automatically from the arguments; there is no --lock_name flag):
#   python3 src/run_d5_problems_w_locks.py --template_name detailed --d5_problem_dicts_path data/reward_analysis_d5_problems.json --validator_name claude-v1.3_loose --proposer_num_rounds 5 --num_descriptions_per_prompt 10
def parse_args(argv=None):
    """Parse command-line options for a D5 run.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = ArgumentParser()

    parser.add_argument(
        "--d5_problem_dicts_path",
        type=str,
        required=True,  # every later step needs this path; fail early with a clear message
        help="Path to D5 problem dicts",
    )
    parser.add_argument("--num_descriptions_per_prompt", type=int, default=15)
    parser.add_argument("--proposer_num_rounds", type=int, default=8)
    parser.add_argument("--template_name", type=str, default="detailed")
    parser.add_argument("--validator_name", type=str, default="claude-v1.3_loose")
    parser.add_argument("--proposer_model", type=str, default="claude-v1.3")

    return parser.parse_args(argv)


def build_lock_name(args) -> str:
    """Build the run identifier shared by lock files and the results file.

    The name combines the validator basename, proposer model, template,
    round count, descriptions-per-prompt, and the problem file's stem (the
    part of its basename before the first ".").
    """
    v_name = os.path.basename(args.validator_name)
    problem_dicts_name = os.path.basename(args.d5_problem_dicts_path).split(".")[0]
    return (
        f"{v_name}_{args.proposer_model}_{args.template_name}_"
        f"{args.proposer_num_rounds}_{args.num_descriptions_per_prompt}_{problem_dicts_name}"
    )


def try_claim_lock(lock_path: str) -> bool:
    """Atomically create an empty lock file.

    Returns True if this process created it (and therefore owns the problem),
    False if the file already existed (claimed or finished by some run).
    Mode "x" fails if the file exists, so two workers cannot both claim it.
    """
    try:
        with open(lock_path, "x"):
            pass
    except FileExistsError:
        return False
    return True


def write_result(lock_path: str, result_dict) -> None:
    """Store a finished result in the lock file.

    The JSON is written to a temporary sibling file and then moved into place
    with ``os.replace``, so readers never observe a half-written result.
    """
    tmp_path = lock_path + ".tmp"
    with open(tmp_path, "w") as f:
        json.dump(result_dict, f, indent=4)
    os.replace(tmp_path, lock_path)


def read_completed_result(lock_path: str):
    """Return the parsed result stored in ``lock_path``, or None if not ready.

    An empty or unparsable file means the problem is still being run by some
    worker; a missing file means it has not been claimed at all.
    """
    try:
        with open(lock_path, "r") as f:
            return json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        return None


def main(argv=None) -> None:
    """Run D5 on every unclaimed problem, then save all finished results.

    Lock-file protocol: an empty ``locks/d5_<lock_name>_<problem>.lock``
    means a worker has claimed the problem; once finished, the file holds
    the result as JSON. Results for every problem whose lock holds a
    finished result are collected into ``results/d5_<lock_name>.json``.
    """
    args = parse_args(argv)
    lock_name = build_lock_name(args)

    with open(args.d5_problem_dicts_path, "r") as f:
        problem_dicts = json.load(f)

    validator = get_validator_by_name(args.validator_name)

    os.makedirs("locks", exist_ok=True)
    os.makedirs("results", exist_ok=True)
    save_path = f"results/d5_{lock_name}.json"

    lock_paths = []
    for problem_dict in problem_dicts:
        problem_name = problem_dict["problem_name"]
        lock_path = f"locks/d5_{lock_name}_{problem_name}.lock"
        # Track every problem, not just skipped ones, so results computed
        # in this run are included in the final aggregate.
        lock_paths.append(lock_path)

        if not try_claim_lock(lock_path):
            print(f"Skipping {problem_name}")
            continue
        print(f"Running D5 on {problem_name}, lock path: {lock_path}")

        try:
            d5_result = d5(
                problem=Problem.from_dict(problem_dict),
                num_descriptions_per_prompt=args.num_descriptions_per_prompt,
                validator=validator,
                d5_problem_name=problem_name,
                proposer_num_rounds=args.proposer_num_rounds,
                proposer_model=args.proposer_model,
                template_name=args.template_name,
                early_stopping_significance_threshold=0.0,
            )
            write_result(lock_path, d5_result.to_dict())
        except BaseException:
            # Release the claim so a later run can retry this problem instead
            # of skipping it forever because of a stale empty lock.
            os.remove(lock_path)
            raise

    all_d5_results = []
    for lock_path in lock_paths:
        result = read_completed_result(lock_path)
        if result is None:
            # Still running in another worker (empty lock) or never claimed.
            print(f"Result not ready, omitting: {lock_path}")
            continue
        all_d5_results.append(result)

    with open(save_path, "w") as f:
        json.dump(all_d5_results, f, indent=4)


if __name__ == "__main__":
    main()
