import os
import math
import sys
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from fastchat.conversation import get_conv_template

sys.path.append('.')
sys.path.append('..')
from data_utils import construct_prompts

# Sentinel written into label positions that must not be supervised; it is
# used in TrainDataset.__getitem__ to mask the prompt tokens.  -100 is the
# default `ignore_index` of torch.nn.CrossEntropyLoss.
IGNORE_TOKEN_ID = -100

def build_prompt(text, model_name):
    """Wrap ``text`` in the prompt format expected by ``model_name``.

    The model family is detected by case-insensitive substring match on
    ``model_name``:

    * ``"llama"``   -> fastchat ``"llama-2"`` conversation template
    * ``"mistral"`` -> fastchat ``"mistral"`` conversation template
      (takes precedence over ``"llama"`` when both match)
    * ``"phi"``     -> ``<|user|>\\n{text}\\n<|assistant|>\\n``
    * ``"gpt2"``    -> ``text`` followed by a newline

    A conversation template, when one matches, takes precedence over the
    plain ``phi`` / ``gpt2`` formats.

    Args:
        text: The user message to embed in the prompt.
        model_name: Model identifier used to select the prompt format.

    Returns:
        The formatted prompt string, ending where the model's answer starts.

    Raises:
        ValueError: If ``model_name`` matches none of the supported families.
    """
    name = model_name.lower()

    # Later matches override earlier ones, so "mistral" wins over "llama".
    template_name = None
    if "llama" in name:
        template_name = "llama-2"
    if "mistral" in name:
        template_name = "mistral"

    if template_name is not None:
        conv = get_conv_template(template_name)
        # Start from an empty history so only this single turn is rendered.
        conv.messages = []
        conv.append_message(conv.roles[0], text)
        # A None assistant message leaves the prompt open for generation.
        conv.append_message(conv.roles[1], None)
        return conv.get_prompt()

    if "phi" in name:
        return f"<|user|>\n{text}\n<|assistant|>" + "\n"
    if "gpt2" in name:
        return text + "\n"

    raise ValueError(
        f"Unsupported model_name {model_name!r}: expected it to contain one of "
        "'llama', 'mistral', 'gpt2' or 'phi'."
    )

class TrainDataset(Dataset):
    """Training dataset of tokenized ``prompt + label`` sequences.

    Each item is the tokenization of the model-specific prompt (see
    ``build_prompt``) concatenated with its label text.  The label ids that
    cover the prompt are replaced by ``IGNORE_TOKEN_ID`` so that only the
    answer tokens are supervised.
    """

    def __init__(self, args, tokenizer, max_length=1024, split="train"):
        """Load the training examples.

        Args:
            args: Run configuration.  ``args.prompt_method`` is forwarded to
                ``construct_prompts`` and ``args.model_name`` selects the
                prompt format in ``__getitem__``.
            tokenizer: Tokenizer called on the text; it must return
                ``"input_ids"`` / ``"attention_mask"`` and expose
                ``pad_token_id``.
            max_length: Maximum tokenized length; longer texts are truncated.
            split: Unused; kept so existing callers keep working.
        """
        self.args = args
        self.tokenizer = tokenizer
        self.max_length = max_length

        self.all_data = construct_prompts(args.prompt_method, args, need_train=True)

    def __len__(self):
        """Return the number of training examples."""
        return len(self.all_data)

    def __getitem__(self, idx):
        """Tokenize example ``idx``.

        Returns:
            A dict with 1-D tensors ``"input_ids"``, ``"attention_mask"`` and
            ``"label"`` (a copy of ``input_ids`` with the prompt positions set
            to ``IGNORE_TOKEN_ID``).
        """
        record = self.all_data[idx]
        prompt = build_prompt(record["prompt"], self.args.model_name)
        text = prompt + record["label"]
        encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        # Drop only the batch dimension: a bare squeeze() would also collapse
        # a one-token sequence to a 0-d tensor and break size(0) in collator.
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        # NOTE(review): the prompt length is measured by tokenizing the prompt
        # alone, which assumes the tokenizer appends no trailing special
        # tokens and splits the prompt identically inside `text` -- confirm
        # for the tokenizers in use.
        prompt_len = len(self.tokenizer(prompt)["input_ids"])
        label = input_ids.clone()
        label[:prompt_len] = IGNORE_TOKEN_ID
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "label": label,
        }

    def collator(self, features):
        """Right-pad a list of items to a common length and stack them.

        ``input_ids`` are padded with the tokenizer's pad token,
        ``attention_mask`` with 0, and ``labels`` with ``IGNORE_TOKEN_ID`` so
        that padded positions are excluded from the loss.

        Args:
            features: Items as returned by ``__getitem__``.

        Returns:
            A dict of 2-D tensors ``"input_ids"``, ``"attention_mask"`` and
            ``"labels"``, each of shape ``(len(features), longest_item)``.
        """
        max_length = max(f["input_ids"].size(0) for f in features)
        pad_id = self.tokenizer.pad_token_id

        input_ids, attention_mask, labels = [], [], []
        for f in features:
            missing = max_length - f["input_ids"].size(0)
            input_ids.append(F.pad(f["input_ids"], (0, missing), value=pad_id))
            attention_mask.append(F.pad(f["attention_mask"], (0, missing)))
            # Padding must never be a training target, so it gets the ignore
            # index rather than the pad token id.
            labels.append(F.pad(f["label"], (0, missing), value=IGNORE_TOKEN_ID))

        return {
            "input_ids": torch.stack(input_ids),
            "attention_mask": torch.stack(attention_mask),
            "labels": torch.stack(labels),
        }

class EvaluationDataset(Dataset):
    """Prompt-only dataset used for evaluation.

    Items carry the tokenized prompt together with the untouched ``label``,
    ``prompt`` and ``value`` entries of the underlying example.
    """

    def __init__(self, args, tokenizer, max_length=1024, split="valid"):
        """Store the configuration and load examples via ``construct_prompts``.

        ``split`` is accepted but not used by this constructor.
        """
        self.args = args
        self.tokenizer = tokenizer
        self.max_length = max_length

        self.all_data = construct_prompts(args.prompt_method, args, need_train=False)

    def __len__(self):
        """Return the number of evaluation examples."""
        return len(self.all_data)

    def __getitem__(self, idx):
        """Build and tokenize the prompt of example ``idx``.

        Returns a dict with the squeezed ``input_ids`` / ``attention_mask``
        tensors plus the example's ``label``, formatted ``prompt`` and
        ``value``.
        """
        prompt = build_prompt(self.all_data[idx]["prompt"], self.args.model_name)
        target = self.all_data[idx]["label"]
        extra = self.all_data[idx]["value"]

        tokenized = self.tokenizer(
            prompt,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        item = {}
        item["input_ids"] = tokenized["input_ids"].squeeze()
        item["attention_mask"] = tokenized["attention_mask"].squeeze()
        item["label"] = target
        item["prompt"] = prompt
        item["value"] = extra
        return item

    def collator(self, features):
        """Left-pad the tensors of ``features`` to a common length and stack.

        ``input_ids`` are padded on the left with the tokenizer's BOS token
        id and ``attention_mask`` with zeros; ``label``, ``prompt`` and
        ``value`` entries are gathered into plain lists.
        """
        longest = max(feat["input_ids"].size(0) for feat in features)

        id_rows = []
        for feat in features:
            ids = feat["input_ids"]
            id_rows.append(
                F.pad(ids, (longest - ids.size(0), 0), value=self.tokenizer.bos_token_id)
            )

        mask_rows = []
        for feat in features:
            mask = feat["attention_mask"]
            mask_rows.append(F.pad(mask, (longest - mask.size(0), 0)))

        batch = {
            "input_ids": torch.stack(id_rows),
            "attention_mask": torch.stack(mask_rows),
        }
        batch["labels"] = [feat["label"] for feat in features]
        batch["prompts"] = [feat["prompt"] for feat in features]
        batch["values"] = [feat["value"] for feat in features]
        return batch