import argparse
import nltk
import random
import json
import time
import os

from tqdm import tqdm

from llm_compiler.compiler import llm_compiler
from utils.config import Config
from utils.logger import create_logger, display_exp_setting
from utils.loader import load_data

# Command-line options as (flag, add_argument keyword arguments) pairs.
# They are registered in exactly this order.
_CLI_OPTIONS = (
    # Run control
    ("--mode", {"default": "testset"}),
    ("--log_dir", {"default": "exp"}),
    ("--seed", {"type": int, "default": 0}),
    ("--quickrun", {"action": "store_true"}),
    # Data selection
    ("--dataset", {"type": str, "default": "testset"}),
    ("--dsl", {"type": str, "default": "autodsl"}),
    # LLM settings
    ("--engine", {"type": str, "default": "openai/gpt-3.5-turbo"}),
    ("--temperature", {"type": float, "default": 0.0}),
    ("--freq_penalty", {"type": float, "default": 0.0}),
    ("--max_tokens", {"type": int, "default": 2048}),
    ("--llm_cache_dir", {"type": str, "default": "llm_cache"}),
)

parser = argparse.ArgumentParser()
for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)

# Parsed from the process command line as soon as this module is loaded;
# `args` is a module-level name read by the script body.
args = parser.parse_args()

# Fetch the NLTK resources at module load, before the __main__ guard is
# reached (presumably required by the compiler pipeline -- confirm).
for _resource in ("punkt", "averaged_perceptron_tagger", "wordnet"):
    nltk.download(_resource)

def _run_testset(cfg, logger, test_data):
    """Compile every test example, area by area, checkpointing the results.

    For each of the four hard-coded areas a compiler is built from the
    settings on ``cfg`` and applied to the ``"procedures"`` field of every
    entry of ``test_data`` whose ``"bigAreas"`` equals that area.  After each
    example, everything accumulated so far is rewritten to
    ``{cfg.result_dir}/results.json``.
    """
    sources, predictions, times, infos = [], [], [], []

    # Built once: the four lists are filled in place, so every dump below
    # contains all examples processed up to that point.
    json_results = {
        "sources": sources,
        "predictions": predictions,
        "times": times,
        "infos": infos,
    }

    for area in ("BioEng", "Ecology", "Genetics", "Medical"):
        compiler = llm_compiler(
            area,
            cfg.dsl,
            cfg.engine,
            cfg.temperature,
            cfg.freq_penalty,
            cfg.max_tokens,
            cfg.llm_cache_dir,
        )

        area_examples = [item for item in test_data if item["bigAreas"] == area]
        logger.info(area)

        for item in tqdm(area_examples, total=len(area_examples)):
            # Per-example metadata is stored as a JSON string, not a dict.
            meta = {key: item[key] for key in ("bigAreas", "bigProb", "smallProb")}
            infos.append(json.dumps(meta))

            procedures = item["procedures"]

            started = time.time()
            # The compiler receives a slice copy while `sources` keeps the
            # object as loaded (presumably guards against the compiler
            # mutating its input -- confirm).
            prediction = compiler.compile(procedures[:])
            elapsed = time.time() - started

            sources.append(procedures)
            predictions.append(prediction)
            times.append(str(elapsed))  # seconds, stored as text

            logger.info(procedures)
            logger.info(prediction)

            # Rewrite the whole results file after every example.
            results_path = f"{cfg.result_dir}/results.json"
            with open(results_path, "w") as f:
                logger.info(f"dumping results to {results_path}")
                json.dump(json_results, f, indent=2)


if __name__ == "__main__":
    random.seed(args.seed)
    cfg = Config(args)
    logger = create_logger(os.path.join(cfg.log_dir, 'log.txt'))
    display_exp_setting(logger, cfg)
    test_data = load_data(cfg.dataset, cfg.quickrun)

    # The gate reads the raw CLI value (args.mode), not cfg; any other mode
    # falls through and nothing further runs.
    if args.mode == "testset":
        _run_testset(cfg, logger, test_data)