from torch.utils.data import Dataset
import datasets
from datasets import load_dataset
import transformers
from typing import Dict
import torch
import numpy as np
from tqdm import tqdm
import json
import random
random.seed(0)

from typing import List

# Prompt/response layout used when the model name contains "mistral".
# "<SEPARATOR>" is a placeholder marking where the prompt ends and the response
# begins: the dataset either substitutes it (retain / val examples) or splits
# on it (short-circuit examples) right before tokenization.
mistral_short_circuit_template = "{user_tag} {instruction} {assistant_tag}<SEPARATOR>{response}"
# Disabled system-prompt variants, kept for reference only. Nothing defines
# `short_circuit_template_with_sys` or `sys_prompt` while these stay commented out.
# mistral_short_circuit_template_with_sys = "{user_tag} <<SYS>>\nAlways assist with care, respect, and truth. Respond with utmost utility yet securely. Avoid harmful, unethical, prejudiced, or negative content. Ensure replies promote fairness and positivity.\n<</SYS>>\n\n{instruction} {assistant_tag}<SEPARATOR>{response}"
# mistral_sys_prompt = "{user_tag} <<SYS>>\nAlways assist with care, respect, and truth. Respond with utmost utility yet securely. Avoid harmful, unethical, prejudiced, or negative content. Ensure replies promote fairness and positivity.\n<</SYS>>\n\n"

# Prompt/response layout used when the model name contains "llama"; same
# placeholders and the same "<SEPARATOR>" convention as the Mistral template.
llama_short_circuit_template = "{user_tag}\n\n{instruction}{assistant_tag}\n\n<SEPARATOR>{response}"



# Earlier, more verbose function-calling prompt, superseded by the shorter
# FUNCTION_CALL_TEMPLATE below (kept commented out for reference).
# FUNCTION_CALL_TEMPLATE = '''You are an expert in composing functions. You are given a question and a set of possible functions.
# Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
# If none of the function can be used, point it out. If the given question lacks the parameters required by the function, also point it out. You should only return the function call in tools call sections."""

# Questions: {instruction}\nHere is a list of functions in JSON format that you can invoke:\n{functions}.
# Should you decide to return the function call(s), put it in the format of [func1(params_name=params_value, params_name2=params_value2...), func2(params)]\n
# NO other text MUST be included.'''

# User-turn text for function-calling examples. Filled with str.format using
# `instruction` (the user request) and `functions` (the available function
# specs); the result is then used as the `instruction` of a chat template above.
FUNCTION_CALL_TEMPLATE = """{instruction}
Here is a list of functions in JSON format that you can invoke:\n{functions}.
Please return the function call(s)
"""

class ShortCircuitingFunctionDataset(Dataset):
    """Index-paired retain / short-circuit examples, including function-calling data.

    Two pools of formatted strings are built at construction time:

    * ``orig_s_retain`` - examples whose behaviour should be preserved
      (UltraChat conversations, XSTest prompts labelled as full compliance,
      function-calling examples and, for Llama models, refusal examples).
    * ``short_circuit_orig`` - examples to short-circuit (harmful text
      completions mixed with function-calling examples).

    ``val_orig`` reuses the function-calling retain examples.  Item ``i`` pairs
    ``orig_s_retain[i]`` with ``short_circuit_orig[i]``, so the dataset length
    is the length of the shorter pool.

    Every stored string contains the literal ``<SEPARATOR>`` marker between
    prompt and response, except retain examples produced by the tokenizer's
    chat template; the marker is resolved in ``__getitem__``.
    """

    def __init__(self,
                tokenizer: "transformers.PreTrainedTokenizer",
                num_examples,
                lorra_args,
                model_name_or_path,
                sc_train_subset : List[str] = None,
                ):
        """Load all data sources and build the retain / short-circuit pools.

        Args:
            tokenizer: Tokenizer used for chat templating here and for
                tokenization in ``__getitem__``.  For Mistral models its
                ``chat_template`` attribute is overwritten.
            num_examples: Maximum number of UltraChat conversations to put in
                the retain pool.
            lorra_args: Object exposing ``user_tag`` and ``assistant_tag``.
            model_name_or_path: Model identifier; must contain ``llama`` or
                ``mistral`` (case-insensitive) to select the prompt template.
            sc_train_subset: Optional list of categories; when given, only
                harmful examples whose ``category`` is listed are kept.

        Raises:
            ValueError: If the model name matches neither supported family.
        """
        super().__init__()

        self.model_name_or_path = model_name_or_path.lower()
        is_llama = 'llama' in self.model_name_or_path

        if is_llama:
            print("USING LLAMA TEMPLATE")
            short_circuit_template = llama_short_circuit_template
        elif 'mistral' in self.model_name_or_path:
            print("USING MISTRAL TEMPLATE")
            short_circuit_template = mistral_short_circuit_template
            # fix spacing issue in template
            tokenizer.chat_template = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
        else:
            raise ValueError("Model name not recognized")

        self.user_tag = user_tag = lorra_args.user_tag
        self.assistant_tag = assistant_tag = lorra_args.assistant_tag

        # Formatting variants drawn per example:
        #   0 -> full prompt + response
        #   2 -> response only (empty instruction); used for Llama models only
        # NOTE: the draw uses np.random, which the module-level random.seed(0)
        # does not seed, so the variant choice is not reproducible across runs.
        switch_select = [0, 2] if is_llama else [0]

        def render(instruction, response, switch=0):
            """Fill the model's template; variant 2 blanks the instruction."""
            if switch == 2:
                instruction = ""
            return short_circuit_template.format(
                user_tag=user_tag, assistant_tag=assistant_tag,
                instruction=instruction, response=response)

        def strip_bos(text):
            """Drop the BOS token text so the tokenizer can add it itself later."""
            return text.replace(tokenizer.bos_token, "")

        def render_function_call(record):
            """Format one function-calling record with its first reference call."""
            prompt = FUNCTION_CALL_TEMPLATE.format(
                instruction=record['prompt'], functions=record['available_functions'])
            return render(prompt, record['function_calls'][0])

        # ======================= Retain ======================= #
        ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="test_sft")
        orig_s = []
        for example in ds:
            # Check before appending so at most num_examples are collected
            # (the previous `>` check after the append kept one extra).
            if len(orig_s) >= num_examples:
                break
            messages = example["messages"]
            if len(messages) < 2:
                continue

            switch = np.random.choice(switch_select)
            if switch == 2:
                formatted_input = render("", messages[1]["content"], switch)
            else:
                formatted_input = tokenizer.apply_chat_template(messages, tokenize=False)
            orig_s.append(strip_bos(formatted_input))

        self.orig_s_retain = orig_s
        random.shuffle(self.orig_s_retain)
        self._preview("orig_s_retain[0]", orig_s)
        print("Orig s length:", len(self.orig_s_retain))

        # ======================= Borderline Retain ======================= #
        import csv
        with open('/data/private_models/short_circuiting/data/xstest_v2_completions_gpt4_gpteval.csv', newline='') as f:
            data = [dict(row) for row in csv.DictReader(f)
                    if row['final_label'] == "1_full_compliance"]

        # The (small) borderline set is repeated 50 times, with a fresh
        # formatting draw each time, to up-weight it in the retain pool.
        borderline_orig_s = []
        for _ in range(50):
            for d in data:
                switch = np.random.choice(switch_select)
                borderline_orig_s.append(
                    strip_bos(render(d['prompt'], d['completion'], switch)))

        self.orig_s_retain += borderline_orig_s
        random.shuffle(self.orig_s_retain)
        self._preview("borderline_orig_s[0]", borderline_orig_s)
        print("Orig s length:", len(self.orig_s_retain))

        # ===================== Function calling retain ================= #
        data = self._load_json("/data/private_models/short_circuiting/data/function-call-retain_filtered.json")
        function_retain_s = [render_function_call(d) for d in data]
        self.orig_s_retain += function_retain_s
        random.shuffle(self.orig_s_retain)
        self._preview("function_retain_s[0]", function_retain_s)
        print("function retain length:", len(function_retain_s))

        # ======================= Refusal Retain ======================= #
        if is_llama:
            dataset = self._load_json("/data/private_models/short_circuiting/data/5kharmful_filtered_1_1_filter_harmbench.json")

            # select subset
            random.shuffle(dataset)
            dataset = dataset[:2000]

            # Two passes over the subset, each with its own formatting draw.
            refusal_retain_orig = []
            for d in tqdm(dataset * 2):
                switch = np.random.choice(switch_select)
                refusal_retain_orig.append(
                    strip_bos(render(d['prompt'], d['llama3_output'], switch)))

            self.orig_s_retain += refusal_retain_orig
            random.shuffle(self.orig_s_retain)
            self._preview("refusal_orig_s[0]", refusal_retain_orig)
            print("Orig s length:", len(self.orig_s_retain))

        # ======================= Short Circuit ======================= #
        dataset = self._load_json("/data/private_models/short_circuiting/data/5kharmful_filtered_1_1_filter_harmbench.json")
        short_circuit_orig = []
        for d in tqdm(dataset):
            # Filter first so skipped records cost neither formatting nor a draw.
            if sc_train_subset is not None and d['category'] not in sc_train_subset:
                continue
            switch = np.random.choice(switch_select)
            short_circuit_orig.append(
                strip_bos(render(d['prompt'], d['output'], switch)))

        random.shuffle(short_circuit_orig)
        self._preview("short_circuit_orig[0]", short_circuit_orig)
        print("Short circuit length:", len(short_circuit_orig))

        # ===================== Function calling SC ================= #
        data = self._load_json("/data/short-circuiting/short-circuiting/data/function-call-sc_filtered_0.1.json")
        function_sc = [render_function_call(d) for d in data]

        # Mix from the two *separate* lists. Extending the text list in place
        # first (as was done before) made it contain the function-call
        # examples too, which skewed both the total and the text half.
        self.short_circuit_orig = self._mix_short_circuit_sources(
            short_circuit_orig, function_sc, text_ratio=0.5, function_ratio=0.5)

        self._preview("function_sc[0]", function_sc)
        print("function_sc length:", len(function_sc))

        # ======================= Val ======================= #
        # Validation reuses the function-calling retain examples.
        self.val_orig = function_retain_s
        self.tokenizer = tokenizer

    @staticmethod
    def _load_json(path):
        """Return the parsed contents of the JSON file at *path*."""
        with open(path) as file:
            return json.load(file)

    @staticmethod
    def _preview(label, examples):
        """Print the first entry of *examples* under *label* (placeholder if empty)."""
        print(label, examples[0] if examples else "<empty>")

    @staticmethod
    def _mix_short_circuit_sources(text_examples, function_examples,
                                   text_ratio=0.5, function_ratio=0.5):
        """Sample a shuffled mix of text and function-calling examples.

        The combined size ``len(text_examples) + len(function_examples)`` is
        split by the two ratios, and each share is drawn *with replacement*
        from its own source.  An empty source contributes nothing.  Neither
        input list is modified.

        Returns:
            A new, shuffled list.  Shuffling matters because ``__len__`` may
            truncate the pool to the retain pool's size; without it the
            function-calling examples would all sit at the tail.
        """
        total = len(text_examples) + len(function_examples)
        mixed = []
        if text_examples:
            mixed += random.choices(text_examples, k=int(text_ratio * total))
        if function_examples:
            mixed += random.choices(function_examples, k=int(function_ratio * total))
        random.shuffle(mixed)
        return mixed

    def __len__(self):
        """Number of index-aligned (retain, short-circuit) pairs available."""
        return min(len(self.orig_s_retain), len(self.short_circuit_orig))

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Tokenize the i-th retain, validation and short-circuit examples.

        Returns a dict of ``input_ids`` / ``attention_mask`` tensors for the
        retain example, ``*_val`` for the validation example (index wraps
        around ``val_orig``) and ``*_short_circuit`` for the short-circuit
        example.  Retain and val are padded/truncated to 1024 tokens; the
        short-circuit prompt and response are each padded/truncated to 512
        and concatenated along dim 1.
        """
        orig_s_retain = self.orig_s_retain[i]
        short_circuit_orig = self.short_circuit_orig[i]
        val_orig = self.val_orig[i % len(self.val_orig)]

        target = ' ' if 'mistral' in self.model_name_or_path else '' # mistral has a space after assistant token

        self.tokenizer.padding_side = "left"
        tokenized_inputs_retain = self.tokenizer(
            orig_s_retain.replace('<SEPARATOR>', target),
            padding="max_length",
            truncation=True,
            max_length=1024,
            return_tensors="pt",
        )
        tokenized_inputs_val = self.tokenizer(
            val_orig.replace('<SEPARATOR>', target),
            padding="max_length",
            truncation=True,
            max_length=1024,
            return_tensors="pt",
        )

        # Prompt is left-padded and response right-padded so that, once
        # concatenated, the real tokens of both halves are adjacent.
        sc_parts = short_circuit_orig.split('<SEPARATOR>')
        sc_prompt, sc_response = sc_parts[0], sc_parts[1]

        self.tokenizer.padding_side = "left"
        tokenized_short_circuit = self.tokenizer(
            sc_prompt,
            padding="max_length",
            truncation=True,
            max_length=512,
            return_tensors="pt",
        )
        self.tokenizer.padding_side = "right"
        response_tokenized_short_circuit = self.tokenizer(
            sc_response,
            padding="max_length",
            truncation=True,
            max_length=512,
            return_tensors="pt",
            add_special_tokens=False
        )

        # Restore the default side for any later use of the shared tokenizer.
        self.tokenizer.padding_side = "left"
        combined_input_ids_short_circuit = torch.cat([tokenized_short_circuit["input_ids"], response_tokenized_short_circuit["input_ids"]], dim=1)
        combined_attention_mask_short_circuit = torch.cat([tokenized_short_circuit["attention_mask"], response_tokenized_short_circuit["attention_mask"]], dim=1)

        return dict(
            input_ids=tokenized_inputs_retain["input_ids"],
            attention_mask=tokenized_inputs_retain["attention_mask"],
            input_ids_val=tokenized_inputs_val["input_ids"],
            attention_mask_val=tokenized_inputs_val["attention_mask"],
            input_ids_short_circuit=combined_input_ids_short_circuit,
            attention_mask_short_circuit=combined_attention_mask_short_circuit
        )