import os
import sys
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm
from datasets import Dataset
from torch.utils.data import DataLoader

from models import *
from models import MODELS
from data_utils import load_dataset, construct_prompts
from utils import create_model, compute_metrics

# Root directory for evaluation outputs; being a relative path, it is resolved
# against the current working directory at run time.
RESULT_PATH = "../results"
# Root directory for dataset files (not referenced in this part of the file;
# presumably consumed by the data-loading helpers -- confirm before changing).
DATASET_PATH = "../data"

def parse_arguments():
    """Parse the command-line options of a value-evaluation run.

    Options are read from ``sys.argv``. Every option is optional and falls
    back to the default listed in the table below. ``--with_definition`` is
    deliberately kept as a string (default ``"True"``), not a boolean.

    Returns:
        argparse.Namespace: one attribute per option, named after the flag
        without its leading dashes.
    """
    # (flag, default, type, help) -- registered in this exact order so the
    # generated --help text keeps the same layout.
    option_table = (
        ("--dataset", "moral_choice", str, "Dataset to evaluate (beavertails, denevil, moral_choice, value_fulcra)"),
        ("--data_version", "original", str, "Data version to use (original, pairwise, augmented)"),
        ("--model_name", "openai/gpt-3.5-turbo", str, "Model to evaluate"),
        ("--prompt_method", "vanilla", str, "Prompt method to use, vanilla means just prompts."),
        ("--with_definition", "True", str, "Whether to include the basic/prior definition in the prompt."),
        ("--few_shot_num", 3, int, "Number of few-shot examples to use"),
        ("--load_local_ckpt", None, str, "Load a local checkpoint for the model"),
        ("--train_split", "small", str, "Train split to use, small / large"),
        ("--eval_max_tokens", 100, int, "Max tokens for evaluation"),
        ("--eval_temp", 1.0, float, "Temperature for sampling"),
        ("--eval_top_p", 1.0, float, "Top-P parameter for top-p sampling"),
        ("--eval_repeats", 1, int, "Number of repeats for evaluation on one sample"),
        ("--eval_num_samples", 10, int, "Number of samples to evaluate on"),
        ("--batch_size", 8, int, "Batch size for evaluation"),
        ("--gpu_num", 8, int, "GPU number to use for evaluation"),
        ("--remove_value", None, str, "Remove a specific value from the training data"),
    )

    cli = argparse.ArgumentParser(description="Value Evaluation")
    for flag, fallback, caster, help_text in option_table:
        cli.add_argument(flag, default=fallback, type=caster, help=help_text)

    # The only boolean switch: absent -> False, present -> True.
    cli.add_argument(
        "--evaluate",
        action="store_true",
        help="Run evaluation on the dataset with the model and prompt method.",
    )
    return cli.parse_args()


def reached_sample_limit(batch_idx, batch_size, num_samples):
    """Tell whether the batches before ``batch_idx`` already cover the sample limit.

    Args:
        batch_idx: zero-based index of the batch about to be processed.
        batch_size: nominal number of prompts per batch.
        num_samples: requested number of prompts to evaluate; a value <= 0
            disables the limit.

    Returns:
        bool: True when ``batch_idx * batch_size`` prompts (the ones in all
        earlier batches) are at least ``num_samples``. The limit is therefore
        enforced at batch granularity: the last evaluated batch may overshoot
        ``num_samples`` by fewer than ``batch_size`` prompts.
    """
    return num_samples > 0 and batch_idx * batch_size >= num_samples


# RUN EVALUATION
if __name__ == "__main__":
    args = parse_arguments()
    print("Running with args: ", args)

    ### Step 1: Prepare the dataset and model
    prompts_and_labels = construct_prompts(args.prompt_method, args)
    dataset = Dataset.from_list(prompts_and_labels)  # Dataset({features: ['context', 'action', 'value', 'prompt', 'label'], num_rows: 774})
    print("Dataset loaded: ", dataset)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)

    result_path = f"{RESULT_PATH}/{args.dataset}/{args.data_version}/{args.train_split}/{args.model_name.split('/')[-1]}"
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(result_path, exist_ok=True)
    # Single CSV per (prompt method, with_definition) pair; used both as the
    # resumable checkpoint and as the final output.
    result_file = f"{result_path}/{args.prompt_method}_with_definition_{args.with_definition}.csv"

    ### Step 2: Evaluate the dataset with the model
    if args.evaluate:
        model = create_model(args.model_name, args)
        results = []
        if os.path.exists(result_file):
            # Resume from a previous run: reload the rows written so far and
            # drop the index column that to_csv added.
            results = pd.read_csv(result_file).to_dict('records')
            for r in results:
                r.pop('Unnamed: 0', None)

        resumed_rows = len(results)  # rows already on disk before this run
        expected_rows = 0            # rows all batches up to the current one should have produced

        for idx, batch in tqdm(
            enumerate(dataloader),
            total=len(dataloader),
            desc=f"Evaluating on Dataset: {args.dataset}, Model: {args.model_name}",
        ):
            # Check the limit BEFORE the resume skip so that a resumed run
            # cannot walk past it, and before processing so that no extra
            # batch is evaluated once the limit is covered.
            if reached_sample_limit(idx, args.batch_size, args.eval_num_samples):
                break

            prompts = batch["prompt"]
            # Count by the real batch length: the final batch may be shorter
            # than args.batch_size.
            expected_rows += len(prompts) * args.eval_repeats
            if expected_rows <= resumed_rows:  # skip the already evaluated samples
                continue

            for current_repeat in range(args.eval_repeats):
                responses = model.get_top_p_answer(    # return a list of responses for the current batch prompts
                    prompt_base = prompts,
                    prompt_system = [""] * len(prompts),  # one empty system prompt per actual prompt
                    max_tokens = args.eval_max_tokens,
                    temperature = args.eval_temp,
                    top_p = args.eval_top_p,
                )

                for fidx, (context, action, value, prompt, label, response) in enumerate(zip(batch["context"], batch["action"], batch["value"], prompts, batch["label"], responses)):
                    result_base = {
                        "index": idx * args.batch_size + fidx,  # position of the sample in the dataset
                        "context": context,
                        "action": action,
                        "value": value,
                        "prompt": prompt,
                        "label": label,
                        "model_id": args.model_name,
                        "eval_temperature": args.eval_temp,
                        "eval_top_p": args.eval_top_p,
                        "eval_repeats": args.eval_repeats,
                        "current_repeat": current_repeat,
                    }
                    # Each response is a mapping whose fields are merged into the row.
                    result = {**result_base, **response}

                    results.append(result)

            results_df = pd.DataFrame(results)   # save intermediate results
            results_df.to_csv(result_file)

        results_df = pd.DataFrame(results)
        results_df.to_csv(result_file)
    else:
        # No evaluation requested: score the results of an earlier run.
        results_df = pd.read_csv(result_file)
    print("Evaluation results: ", results_df)

    ### Step 3: Compute the metrics
    metrics = compute_metrics(results_df)
    print(f"Metrics for evaluation on {args.dataset} with model {args.model_name} and prompt method {args.prompt_method}:")
    for metric, value in metrics.items():
        print(f"{metric}: {value}")