import random
import torch
from datasets import load_dataset
from typing import Callable

class ToFU(torch.utils.data.Dataset):
    """Question/answer dataset that renders each row through a template.

    Every row is expected to expose a question column and an answer column
    (names configurable).  ``__getitem__`` turns a row into text by calling
    ``template_func(question=..., answer=...)``; depending on the
    configuration the answer is the stored answer, a randomly drawn
    "I don't know" style response, or every entry of a list-valued answer
    column.

    Args:
        name: First positional argument forwarded to ``load_dataset`` when
            ``data`` is not given.
        split: Second positional argument forwarded to ``load_dataset``
            (the configuration name in the Hugging Face API).  The
            ``'train'`` split of the loaded result is always used.
        template_func: Callable accepting ``question`` and ``answer``
            keyword arguments and returning the rendered sample.
        question_key: Column holding the question.
        answer_key: Column holding the answer.  The special value ``'idk'``
            replaces the answer with a random line from the refusal file.
        start_idx: First row of the window to keep; negative values count
            from the end.  ``-1`` is the "no window" sentinel.
        max_num: Number of rows to keep.  The window is applied only when
            BOTH ``start_idx`` and ``max_num`` differ from ``-1``.
        data: Pre-loaded rows.  Either a Hugging Face ``Dataset`` or any
            indexable sequence of dict-like rows; when given, nothing is
            downloaded.
        as_dpo: When true every sample is the pair
            ``[text with stored answer, text with refusal answer]``.
        idk_path: Text file with one refusal response per line.  Lines are
            stripped and used verbatim (they are not JSON-decoded); blank
            lines are skipped.

    Raises:
        OSError: If ``idk_path`` cannot be opened.
    """

    def __init__(
        self,
        name=None, split=None, template_func: Callable = None,
        question_key='question', answer_key='answer',
        start_idx=-1, max_num=-1, data=None, as_dpo=False,
        idk_path="data/idontknow.jsonl",
    ):
        super().__init__()

        if data is not None:
            self.data = data
        else:
            self.data = load_dataset(name, split)['train']

        # The window is only honoured when both bounds are supplied; passing
        # just one of them keeps the full dataset (original behaviour).
        if start_idx != -1 and max_num != -1:
            if start_idx < 0:
                start_idx = len(self.data) + start_idx
            window = range(start_idx, start_idx + max_num)
            if hasattr(self.data, 'select'):
                self.data = self.data.select(window)
            else:
                # Plain sequences have no ``select``; index them directly.
                self.data = [self.data[i] for i in window]

        self.template_func = template_func
        self.question_key = question_key
        self.answer_key = answer_key
        self.as_dpo = as_dpo

        # Read the refusal responses eagerly and close the file right away.
        # Blank lines are dropped so an empty string is never used as an
        # answer.
        with open(idk_path, encoding="utf-8") as idk_file:
            stripped_lines = (line.strip() for line in idk_file)
            self.idks = [line for line in stripped_lines if line]

    def __len__(self):
        """Return the number of rows (after optional windowing)."""
        return len(self.data)

    def __iter__(self):
        """Yield ``self[i]`` for every row, rendering lazily and in order."""
        for position in range(len(self)):
            yield self[position]

    def __getitem__(self, index):
        """Render the row at ``index``.

        Returns:
            The rendered text(s) from :meth:`apply_template`.  When the data
            has a ``retain_label`` column the result is the tuple
            ``(texts, label)`` with ``label`` being ``1`` for a truthy
            ``retain_label`` and ``0`` otherwise.
        """
        # Fetch the row once; each lookup may be a full row materialisation.
        item = self.data[index]

        # Datasets exposing ``features`` are checked by schema (as before);
        # plain sequences fall back to the keys of the row itself.
        features = getattr(self.data, 'features', None)
        has_label = 'retain_label' in (item if features is None else features)
        is_retain = item['retain_label'] if has_label else False

        qa_texts = self.apply_template(
            item, self.template_func, is_retain=is_retain
        )

        if has_label:
            return qa_texts, 1 if is_retain else 0
        return qa_texts

    def apply_template(self, item, template_func: Callable, is_retain=False):
        """Render one row with ``template_func``.

        Args:
            item: Mapping holding at least the question column and whichever
                answer column the active mode needs.
            template_func: Callable invoked as
                ``template_func(question=..., answer=...)``.
            is_retain: Accepted for interface compatibility; not used here.

        Returns:
            * ``as_dpo`` set: ``[stored-answer text, refusal text]``.
            * ``answer_key == 'idk'``: a single refusal text.
            * string answer column: a single text.
            * otherwise (e.g. a list of perturbed answers): a list with one
              text per entry of the answer column.
        """
        question = item[self.question_key]

        def render(answer):
            return template_func(question=question, answer=answer)

        if self.as_dpo:
            # NOTE(review): the preferred side reads the literal 'answer'
            # column rather than ``self.answer_key`` - kept as in the
            # original; confirm this is intended.
            return [render(item['answer']), render(random.choice(self.idks))]

        if self.answer_key == 'idk':
            return render(random.choice(self.idks))

        answer = item[self.answer_key]
        if isinstance(answer, str):
            return render(answer)

        # Non-string answer column: render every candidate answer.
        return [render(candidate) for candidate in answer]
