import re, os
import numpy as np
from datasets import load_dataset, Dataset
import random, itertools, argparse

# Hugging Face cache directory handed to the load_dataset calls in main().
CACHE_DIR = "/fs/cml-projects/E2H/Huggingface_cache"
# Where the generated problem sets are written: CSV tables and plain-text questionnaires.
CSV_SAVE_DIR, TXT_SAVE_DIR = "./csv", "./txt"


def gsm8k_formatter(sample):
    """Project a GSM8K row onto the fields used for display.

    Returns a new dict holding only ``question`` and ``answer`` copied
    (in that order) from *sample*; any other keys are ignored.
    """
    kept_fields = ("question", "answer")
    return {field: sample[field] for field in kept_fields}


def arc_formatter(sample):
    """Format an ARC row as a question, an inline choice string, and a target index.

    Args:
        sample: row with ``question``, ``choices`` (a mapping with parallel
            ``text`` and ``label`` lists) and ``answerKey``.

    Returns:
        dict with ``question`` (passed through unchanged), ``choices`` rendered
        as ``"(A) ...   (B) ...   ..."`` with one positional letter per option,
        and ``target``, the position of ``answerKey`` within ``choices["label"]``.

    Raises:
        ValueError: if ``answerKey`` is not one of the labels.
    """
    option_texts = sample["choices"]["text"]
    # Letters are positional (A, B, C, ...) regardless of the row's own labels,
    # and every option the row provides is rendered instead of assuming exactly
    # four (a fixed four-slot template crashes on shorter rows and truncates longer ones).
    choices = "   ".join(
        f"({chr(ord('A') + position)}) {text}"
        for position, text in enumerate(option_texts)
    )
    return {
        # The question is used verbatim (no "Question: ...\nAnswer:" template).
        "question": sample["question"],
        "choices": choices,
        "target": sample["choices"]["label"].index(sample["answerKey"]),
    }


def winogrande_formatter(sample):
    """Split a Winogrande sentence at its blank and build the two completions.

    ``question`` is the text before the first ``_``. Each choice is one option
    followed by the text after the blank, rendered as ``"(A) ...   (B) ..."``.
    ``target`` maps the answer string "1"/"2" to 0/1.

    Raises:
        ValueError: if the sentence contains no ``_``.
        KeyError: if ``answer`` is neither "1" nor "2".
    """
    sentence = sample["sentence"]
    blank_at = sentence.index("_")
    before_blank = sentence[:blank_at]
    after_blank = sentence[blank_at + 1 :]
    first_completion = sample["option1"] + after_blank
    second_completion = sample["option2"] + after_blank
    answer_to_target = {"1": 0, "2": 1}
    return {
        "question": before_blank,
        "choices": f"(A) {first_completion}   (B) {second_completion}",
        "target": answer_to_target[sample["answer"]],
    }


def hellaswag_formatter(sample):
    """Format a HellaSwag row as a cleaned context, its endings, and a target index.

    Args:
        sample: row with ``activity_label``, ``ctx_a``, ``ctx_b``, ``endings``
            (list of candidate continuations) and ``label``.

    Returns:
        dict with ``question`` ("<activity_label>: <ctx_a> <Ctx_b>", cleaned),
        ``choices`` rendered as ``"(A) ...   (B) ..."`` over the cleaned endings,
        and ``target`` (``int(label)``).
    """

    def preprocess(text):
        # Turn " [title]" markers into sentence breaks, drop any other "[...]"
        # tags, and collapse the double spaces those edits leave behind.
        return re.sub(
            r"\[.*?\]", "", text.replace(" [title]", ". ").strip()
        ).replace("  ", " ")

    ctx = sample["ctx_a"] + " " + sample["ctx_b"].capitalize()
    # ``question`` already carries the context, so ``choices`` lists only the
    # endings; repeating the context here would print it twice once
    # ``get_questions`` joins the two fields. Endings get the same cleanup.
    choices = "   ".join(
        f"({chr(ord('A') + position)}) {preprocess(ending)}"
        for position, ending in enumerate(sample["endings"])
    )
    return {
        "question": preprocess(sample["activity_label"] + ": " + ctx),
        "choices": choices,
        "target": int(sample["label"]),
    }


def get_questions(dataset_name, data_row):
    """Render one dataset row as a single line of display text.

    GSM8K rows render as just the question. Multiple-choice rows (HellaSwag,
    ARC, Winogrande) render as the question, four spaces, then the lettered
    choice string produced by the matching ``*_formatter``.

    Args:
        dataset_name: one of "GSM8K", "HellaSwag", "ARC", "Winogrande".
        data_row: a raw row of that dataset (mapping of column name to value).

    Returns:
        The rendered question text.

    Raises:
        ValueError: if ``dataset_name`` is not one of the supported names.
            Unknown names are rejected instead of being silently formatted
            as Winogrande.
    """
    if dataset_name == "GSM8K":
        return f"{gsm8k_formatter(data_row)['question']}"

    multiple_choice_formatters = {
        "HellaSwag": hellaswag_formatter,
        "ARC": arc_formatter,
        "Winogrande": winogrande_formatter,
    }
    try:
        formatter = multiple_choice_formatters[dataset_name]
    except KeyError:
        raise ValueError(
            f"Unsupported dataset_name {dataset_name!r}; expected one of "
            f"{['GSM8K', *multiple_choice_formatters]}"
        ) from None
    formatted_data = formatter(data_row)
    return f"{formatted_data['question']}    {formatted_data['choices']}"
    

def generate_text(txt_path, dataset):
    """Write a plain-text questionnaire of question pairs to ``txt_path``.

    Each row of ``dataset`` (any iterable of mappings with ``question_text_0``
    and ``question_text_1``) becomes a numbered instruction, a blank line, the
    two questions on separate lines, and another blank line. The file is
    overwritten and encoded as UTF-8.
    """
    instruction = (
        "Please identify the more challenging question from the following pair. "
        "If you encounter problems with similar difficulty and are unsure, "
        "please still make a selection based on your intuition."
    )
    # Explicit UTF-8 so non-ASCII question text does not depend on (or fail
    # under) the platform's default locale encoding.
    with open(txt_path, "w", encoding="utf-8") as f:
        for index, row in enumerate(dataset):
            f.write(
                f"{index}. {instruction}\n\n"
                f"{row['question_text_0']}\n"
                f"{row['question_text_1']}\n\n"
            )


def generate_problem_set(
    dataset_name,
    dataset,
    q_index,
    random_seed=7,
    length=10,
    *,
    num_sampled_pairs=100000,
    candidate_factor=20,
    min_acc_gap=0.1,
):
    """Build one page of question pairs whose model accuracies differ noticeably.

    Pairs of distinct row indices are drawn from ``dataset`` with a seeded RNG.
    The first ``candidate_factor * length`` of them are kept as candidates, and
    pairs whose ``model_avg_acc`` values differ by less than ``min_acc_gap`` are
    dropped. The ``q_index``-th slice of ``length`` surviving pairs is returned.

    Args:
        dataset_name: dataset name understood by ``get_questions``.
        dataset: a ``datasets.Dataset`` with ``model_avg_acc`` and
            ``sorted_index`` columns plus whatever ``get_questions`` reads.
        q_index: zero-based page index.
        random_seed: seed for pair sampling; equal seeds give equal pages.
        length: number of pairs per page.
        num_sampled_pairs: ``k`` passed to ``random.sample``, capped at the
            number of available pairs. Changing it changes which pairs a seed
            selects.
        candidate_factor: candidate pairs examined per requested pair.
        min_acc_gap: minimum absolute difference in ``model_avg_acc``.

    Returns:
        A ``Dataset`` with columns ``sorted_index_{0,1}``,
        ``model_avg_acc_{0,1}`` and ``question_text_{0,1}``.

    Raises:
        IndexError: if fewer than ``(q_index + 1) * length`` candidate pairs
            pass the accuracy-gap filter.
    """
    # A private RNG yields the same draws as seeding the global ``random``
    # module, without mutating global state.
    rng = random.Random(random_seed)
    # The full pair list is materialized on purpose: ``random.sample``'s output
    # depends on the population size and ``k``, so shrinking either would change
    # which pairs a given seed produces.
    all_pairs = list(itertools.combinations(range(dataset.num_rows), 2))
    sampled_pairs = rng.sample(all_pairs, min(num_sampled_pairs, len(all_pairs)))
    candidate_pairs = sampled_pairs[: candidate_factor * length]

    # One renamed copy of the dataset per side of each pair ("_0" / "_1"
    # suffixes), then the two are glued side by side into one row per pair.
    sides = [
        dataset.select([pair[side] for pair in candidate_pairs]).rename_columns(
            {name: f"{name}_{side}" for name in dataset.column_names}
        )
        for side in range(2)
    ]
    merge_set = sides[0]
    for column_name in sides[1].column_names:
        merge_set = merge_set.add_column(column_name, sides[1][column_name])

    def strip_side(example, side):
        """Return one side's columns with the ``_{side}`` suffix removed."""
        suffix = f"_{side}"
        return {
            name[: -len(suffix)]: value
            for name, value in example.items()
            if name.endswith(suffix)
        }

    def get_line(example):
        """Attach the rendered text of both questions to a merged row."""
        example["question_text_0"] = get_questions(dataset_name, strip_side(example, 0))
        example["question_text_1"] = get_questions(dataset_name, strip_side(example, 1))
        return example

    filtered = merge_set.filter(
        lambda example: abs(example["model_avg_acc_0"] - example["model_avg_acc_1"])
        >= min_acc_gap
    )
    start = q_index * length
    if start + length > filtered.num_rows:
        raise IndexError(
            f"{dataset_name}: page {q_index} needs {start + length} pairs with an "
            f"accuracy gap >= {min_acc_gap}, but only {filtered.num_rows} of "
            f"{len(candidate_pairs)} candidate pairs qualify"
        )

    return filtered.select(
        range(start, start + length)
    ).map(
        get_line,
        load_from_cache_file=False,
    ).select_columns(
        ["sorted_index_0", "sorted_index_1", "model_avg_acc_0", "model_avg_acc_1", "question_text_0", "question_text_1"]
    )
    

def main(q_index, random_seed=7):
    """Generate the CSV and TXT problem sets for page ``q_index`` of each dataset.

    Loads the Easy2Hard datasets (cached under ``CACHE_DIR``), builds a page of
    10 question pairs per dataset, and writes
    ``q_{q_index}_dataset_{name}.csv`` into ``CSV_SAVE_DIR`` and the matching
    ``.txt`` questionnaire into ``TXT_SAVE_DIR``.
    """
    # Create exactly the directories the writes below target, so changing the
    # module constants cannot leave the output directories missing.
    os.makedirs(TXT_SAVE_DIR, exist_ok=True)
    os.makedirs(CSV_SAVE_DIR, exist_ok=True)

    datasets_by_name = {
        "GSM8K": load_dataset("mcding-org/Easy2Hard-GSM8K", "v1", split="test", cache_dir=CACHE_DIR),
        # HellaSwag is currently disabled:
        # "HellaSwag": load_dataset("mcding-org/Easy2Hard-HellaSwag", "v1", split="validation", cache_dir=CACHE_DIR),
        "ARC": load_dataset("mcding-org/Easy2Hard-ARC", "v1", split="test", cache_dir=CACHE_DIR),
        "Winogrande": load_dataset("mcding-org/Easy2Hard-Winogrande", "v1", split="validation", cache_dir=CACHE_DIR),
    }
    for dataset_name, data in datasets_by_name.items():
        problem_set = generate_problem_set(dataset_name, data, q_index, length=10, random_seed=random_seed)
        stem = f"q_{q_index}_dataset_{dataset_name}"
        problem_set.to_csv(os.path.join(CSV_SAVE_DIR, f"{stem}.csv"))
        generate_text(os.path.join(TXT_SAVE_DIR, f"{stem}.txt"), problem_set)


if __name__ == "__main__":
    argparser = argparse.ArgumentParser(
        description="Generate paired-difficulty question sets from the Easy2Hard datasets."
    )
    argparser.add_argument(
        "--q_index",
        type=int,
        default=0,
        help="Zero-based index of the page of question pairs to generate.",
    )
    argparser.add_argument(
        "--random_seed",
        type=int,
        default=7,
        help="Seed used to sample question pairs (default: 7).",
    )
    args = argparser.parse_args()

    main(args.q_index, random_seed=args.random_seed)