from torch.utils.data import Dataset
import datasets
from datasets import load_dataset
import transformers
from typing import Dict
import torch
from PIL import Image
from tqdm import tqdm
import json
import random
# Seed the global `random` RNG at import time; the dataset constructor below
# calls random.shuffle on its pools, which draws from this RNG.
random.seed(0)

# Prompt layout used when the model name contains "mistral".
# "<SEPARATOR>" is a placeholder between the prompt part and the response;
# __getitem__ replaces it (with ' ' for mistral, '' otherwise) before tokenizing.
mistral_short_circuit_template = "{user_tag} {instruction} {assistant_tag}<SEPARATOR>{response}"

# Prompt layout used when the model name contains "llama"; same fields and the
# same "<SEPARATOR>" placeholder, with blank lines after the role tags.
llama_short_circuit_template = "{user_tag}\n\n{instruction}{assistant_tag}\n\n<SEPARATOR>{response}"

class ShortCircuitingMMDataset(Dataset):
    
    def __init__(self,
                processor,
                num_examples,
                lorra_args,
                model_name_or_path,
                ):
        """Load and format every text and multimodal pool served by the dataset.

        All data locations are hard-coded paths (plus one HuggingFace dataset),
        so construction performs a lot of file and image I/O.

        Args:
            processor: object exposing ``tokenizer`` and ``image_token_index``.
            num_examples: bound on how many ultrachat conversations are
                collected for the text retain pool (see the NOTE in the loop).
            lorra_args: object exposing ``user_tag`` and ``assistant_tag``.
            model_name_or_path: model identifier; lower-cased and searched for
                "llama" (checked first) or "mistral" to pick the template.

        Raises:
            ValueError: if the model name contains neither "llama" nor "mistral".
        """
        super(ShortCircuitingMMDataset, self).__init__()

        self.model_name_or_path = model_name_or_path.lower()
        tokenizer = processor.tokenizer

        if 'llama' in self.model_name_or_path:
            print("USING LLAMA TEMPLATE")
            short_circuit_template = llama_short_circuit_template

        elif 'mistral' in self.model_name_or_path:
            print("USING MISTRAL TEMPLATE")
            short_circuit_template = mistral_short_circuit_template
            # fix spacing issue in template
            tokenizer.chat_template = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
        else:
            raise ValueError("Model name not recognized")

        self.user_tag = user_tag = lorra_args.user_tag
        self.assistant_tag = assistant_tag = lorra_args.assistant_tag

        # NOTE(review): .squeeze() turns a single-token tag into a 0-dim tensor.
        self.user_tokens = tokenizer(self.user_tag, add_special_tokens=False, return_tensors="pt").input_ids.squeeze()
        self.assistant_tokens = tokenizer(self.assistant_tag, add_special_tokens=False, return_tensors="pt").input_ids.squeeze()
        self.image_tokens = torch.tensor([processor.image_token_index])

        def read_json(path):
            """Parse a JSON file, closing the handle as soon as it is read."""
            with open(path) as file:
                return json.load(file)

        def render(instruction, response):
            """Fill the model-specific template for one prompt/response pair."""
            return short_circuit_template.format(
                user_tag=user_tag, assistant_tag=assistant_tag,
                instruction=instruction, response=response)

        def strip_bos(text):
            """Drop the tokenizer's BOS string from already-formatted text."""
            return text.replace(tokenizer.bos_token, "")

        # ======================= Retain ======================= #
        ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="test_sft")
        orig_s = []
        for example in ds:
            messages = example["messages"]
            if len(messages) < 2:
                continue
            formatted_input = tokenizer.apply_chat_template(messages, tokenize=False)
            orig_s.append(strip_bos(formatted_input))

            # NOTE(review): because the check is `>` and runs after the append,
            # the loop stops with num_examples + 1 items. Left unchanged so the
            # sampled data (and every later shuffle) stays the same.
            if len(orig_s) > num_examples:
                break
        self.orig_s_retain = orig_s
        random.shuffle(self.orig_s_retain)
        print("orig_s_retain[0]", orig_s[0])
        print("Orig s length:", len(self.orig_s_retain))

        # ======================= MM Retain ======================= #
        dataset = read_json('/data/datasets/coco/llava_v1_5_mix665k.json')[:10000]
        random.shuffle(dataset)
        image_dir = "/data/datasets"
        retain_mm = []
        for d in dataset:
            conv = []
            # Only the first four turns of each conversation are kept.
            for item in d['conversations'][:4]:
                role = "user" if item['from'] == 'human' else "assistant"
                content = item['value']
                if "<image>" in content:
                    # Normalise to a single leading "<image>" line.
                    content = "<image>\n" + content.replace("<image>", "").strip()
                conv.append({'role': role, 'content': content})
            with Image.open(f"{image_dir}/{d['image']}") as raw_image:
                # Check the size on the lazily opened image so oversized files
                # are skipped before the full decode done by convert().
                if raw_image.width > 2048 or raw_image.height > 2048:
                    continue
                image = raw_image.convert("RGB")
            prompt = tokenizer.apply_chat_template(conv, tokenize=False)
            retain_mm.append((prompt, image))
        self.retain_mm = retain_mm
        print("MM retain[0]", self.retain_mm[0])
        print("MM retain len: ", len(self.retain_mm))

        # ======================= Borderline Retain ======================= #
        import csv
        with open('/data//short-circuiting/short-circuiting/data/xstest_v2_completions_gpt4_gpteval.csv', newline='') as f:
            data = [dict(row) for row in csv.DictReader(f)]
        data = [row for row in data if row['final_label'] == "1_full_compliance"]

        # Every fully-compliant row is repeated 50 times in the retain pool.
        borderline_orig_s = [
            strip_bos(render(d['prompt'], d['completion']))
            for d in data * 50
        ]

        self.orig_s_retain += borderline_orig_s
        random.shuffle(self.orig_s_retain)
        print("borderline_orig_s[0]", borderline_orig_s[0])
        print("Orig s length:", len(self.orig_s_retain))

        # The same file feeds both the refusal-retain pool (llama only) and the
        # short-circuit pool, so it is parsed once.
        harmful = read_json("/data/private_models/short_circuiting/data/5kharmful_filtered_1_1_filter_harmbench.json")

        # ======================= Refusal Retain ======================= #
        if 'llama' in self.model_name_or_path:
            # Uses the 'llama3_output' field as the response; each entry twice.
            refusal_retain_orig = [
                strip_bos(render(d['prompt'], d['llama3_output']))
                for d in tqdm(harmful * 2)
            ]

            self.orig_s_retain += refusal_retain_orig
            random.shuffle(self.orig_s_retain)
            print("refusal_orig_s[0]", refusal_retain_orig[0])
            print("Orig s length:", len(self.orig_s_retain))

        # ======================= Short Circuit ======================= #
        short_circuit_orig = [
            strip_bos(render(d['prompt'], d['output']))
            for d in tqdm(harmful)
        ]

        self.short_circuit_orig = short_circuit_orig
        random.shuffle(self.short_circuit_orig)
        print("short_circuit_orig[0]", short_circuit_orig[0])
        print("Short circuit length:", len(self.short_circuit_orig))

        # ======================= MM Short Circuit ======================= #
        data = read_json("/data/private_models/short_circuiting/data/5kharmful_mm_filtered_1.json")
        image_dir = "/data/datasets/coco/train2017"
        short_circuit_mm = []
        for d in data:
            # Unlike the text pools, the BOS string is not stripped here.
            formatted_input = render(f"<image>\n{d['prompt']}", d['output'])

            with Image.open(f"{image_dir}/{d['image_name']}") as raw_image:
                image = raw_image.convert("RGB")
            short_circuit_mm.append((formatted_input, image))
        self.short_circuit_mm = short_circuit_mm
        print("MM SC[0]", self.short_circuit_mm[0])
        print("MM SC len: ", len(self.short_circuit_mm))

        # ======================= Val ======================= #
        dataset = read_json("/data/private_models/short_circuiting/data/5kharmful_filtered_harmless.json")
        self.val_orig = [
            strip_bos(render(d['prompt'], d['output']))
            for d in tqdm(dataset)
        ]

        self.tokenizer = tokenizer

    def __len__(self):
        return min(len(self.orig_s_retain), len(self.short_circuit_orig), len(self.retain_mm), len(self.short_circuit_mm))\
        
    def find_start_index(self, input_ids, sequence):
        for i in range(len(input_ids) - len(sequence) + 1):
            if torch.equal(input_ids[i:i+len(sequence)], sequence):
                return i
        return 0

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Build the i-th training item from the text and multimodal pools.

        Each pool is indexed with ``i`` (the validation pool with ``i`` modulo
        its length). Every string has its ``<SEPARATOR>`` placeholder replaced
        and is then tokenized with left padding to a fixed ``max_length``.

        Returns:
            A dict with the tokenizer's ``input_ids`` / ``attention_mask`` for
            the retain, validation and short-circuit text, the same for the two
            multimodal prompts, the two images, and four ``*_loss_start_idx``
            values. Each of those is the index where the marker (user tag for
            text, image token for multimodal) is found by ``find_start_index``
            minus the padded sequence length, i.e. an offset from the end.
        """
        text_retain = self.orig_s_retain[i]
        text_sc = self.short_circuit_orig[i]
        mm_retain_prompt, retain_img = self.retain_mm[i]
        mm_sc_prompt, sc_img = self.short_circuit_mm[i]
        # Index modulo the pool size so any i maps to a validation example.
        val_text = self.val_orig[i % len(self.val_orig)]

        # mistral has a space after assistant token
        if 'mistral' in self.model_name_or_path:
            separator_fill = ' '
        else:
            separator_fill = ''

        self.tokenizer.padding_side = "left"

        def encode(text, max_length):
            """Fill in the separator and tokenize to a fixed, padded length."""
            return self.tokenizer(
                text.replace('<SEPARATOR>', separator_fill),
                max_length=max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )

        def loss_offset(encoded, marker):
            """Marker position expressed relative to the end of the sequence."""
            ids = encoded.input_ids
            return self.find_start_index(ids.squeeze(), marker) - ids.shape[1]

        val_enc = encode(val_text, 1024)

        # ======= Text ======= (marker: the user tag tokens)
        sc_enc = encode(text_sc, 1024)
        retain_enc = encode(text_retain, 1024)
        sc_loss_start_idx = loss_offset(sc_enc, self.user_tokens)
        retain_loss_start_idx = loss_offset(retain_enc, self.user_tokens)

        # ======= MM ======= (marker: the image token; retain uses 768 tokens)
        sc_mm_enc = encode(mm_sc_prompt, 1024)
        retain_mm_enc = encode(mm_retain_prompt, 768)
        sc_loss_start_idx_mm = loss_offset(sc_mm_enc, self.image_tokens)
        retain_loss_start_idx_mm = loss_offset(retain_mm_enc, self.image_tokens)

        item = {}
        # ==== Text ====
        item["input_ids"] = retain_enc["input_ids"]
        item["attention_mask"] = retain_enc["attention_mask"]
        item["input_ids_val"] = val_enc["input_ids"]
        item["attention_mask_val"] = val_enc["attention_mask"]
        item["input_ids_short_circuit"] = sc_enc["input_ids"]
        item["attention_mask_short_circuit"] = sc_enc["attention_mask"]
        item["sc_loss_start_idx"] = sc_loss_start_idx
        item["retain_loss_start_idx"] = retain_loss_start_idx
        # ===== MM =====
        item["input_ids_mm"] = retain_mm_enc["input_ids"]
        item["attention_mask_mm"] = retain_mm_enc["attention_mask"]
        item["input_ids_short_circuit_mm"] = sc_mm_enc["input_ids"]
        item["attention_mask_short_circuit_mm"] = sc_mm_enc["attention_mask"]
        item["retain_img"] = retain_img
        item["sc_img"] = sc_img
        item["sc_loss_start_idx_mm"] = sc_loss_start_idx_mm
        item["retain_loss_start_idx_mm"] = retain_loss_start_idx_mm
        return item


    # def __getitem__(self, i) -> Dict[str, torch.Tensor]:
    #     orig_s_retain = self.orig_s_retain[i]
    #     short_circuit_orig = self.short_circuit_orig[i]
    #     retain_mm, retain_img = self.retain_mm[i]
    #     mm_sc, sc_img = self.short_circuit_mm[i]

            
    #     val_orig = self.val_orig[i % len(self.val_orig)]

    #     target = ' ' if 'mistral' in self.model_name_or_path else '' # mistral has a space after assistant token

    #     self.tokenizer.padding_side = "left"
    #     tokenize_kwargs = dict(padding="max_length", truncation=True, return_tensors="pt")
    #     tokenized_inputs_retain = self.tokenizer(orig_s_retain.replace('<SEPARATOR>', target), max_length=1024, **tokenize_kwargs)
    #     tokenized_inputs_retain_mm = self.tokenizer(retain_mm.replace('<SEPARATOR>', target), max_length=1024, **tokenize_kwargs)
    #     tokenized_inputs_val = self.tokenizer(val_orig.replace('<SEPARATOR>', target), max_length=1024, **tokenize_kwargs)

    #     # ======= Text SC =======
    #     self.tokenizer.padding_side = "left"
    #     tokenized_short_circuit = self.tokenizer(short_circuit_orig.split('<SEPARATOR>')[0], max_length=512, **tokenize_kwargs)
    #     self.tokenizer.padding_side = "right"
    #     response_tokenized_short_circuit = self.tokenizer(short_circuit_orig.split('<SEPARATOR>')[1], max_length=512, add_special_tokens=False, **tokenize_kwargs)
    #     self.tokenizer.padding_side = "left"

    #     combined_input_ids_short_circuit = torch.cat([tokenized_short_circuit["input_ids"], response_tokenized_short_circuit["input_ids"]], dim=1)
    #     combined_attention_mask_short_circuit = torch.cat([tokenized_short_circuit["attention_mask"], response_tokenized_short_circuit["attention_mask"]], dim=1)

    #     # ======= MM SC =======
    #     self.tokenizer.padding_side = "left"
    #     tokenized_short_circuit_mm = self.tokenizer(mm_sc.split('<SEPARATOR>')[0], max_length=512, **tokenize_kwargs)
    #     self.tokenizer.padding_side = "right"
    #     response_tokenized_short_circuit_mm = self.tokenizer(mm_sc.split('<SEPARATOR>')[1], max_length=512, add_special_tokens=False, **tokenize_kwargs)
    #     self.tokenizer.padding_side = "left"
    #     combined_input_ids_short_circuit_mm = torch.cat([tokenized_short_circuit_mm["input_ids"], response_tokenized_short_circuit_mm["input_ids"]], dim=1)
    #     combined_attention_mask_short_circuit_mm = torch.cat([tokenized_short_circuit_mm["attention_mask"], response_tokenized_short_circuit_mm["attention_mask"]], dim=1)

    #     return dict(
    #         # ==== Text ====
    #         input_ids=tokenized_inputs_retain["input_ids"],
    #         attention_mask=tokenized_inputs_retain["attention_mask"],
    #         input_ids_val=tokenized_inputs_val["input_ids"],
    #         attention_mask_val=tokenized_inputs_val["attention_mask"],
    #         input_ids_short_circuit=combined_input_ids_short_circuit,
    #         attention_mask_short_circuit=combined_attention_mask_short_circuit,

    #         # ===== MM =====
    #         input_ids_mm=tokenized_inputs_retain_mm["input_ids"],
    #         attention_mask_mm=tokenized_inputs_retain_mm["attention_mask"],
    #         input_ids_short_circuit_mm=combined_input_ids_short_circuit_mm,
    #         attention_mask_short_circuit_mm=combined_attention_mask_short_circuit_mm,
    #         retain_img=retain_img,
    #         sc_img=sc_img,
    #     )
