import argparse
import json
import os

import numpy as np
import torch
from tqdm import tqdm

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.datasets.fmri_vit3d_datasets import fMRIViT3dDataset
from llava.model.builder import load_pretrained_model
from llava.train import DataArguments
from llava.utils import disable_torch_init
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path

from PIL import Image

import requests
from PIL import Image
from io import BytesIO
from transformers import TextStreamer


# Generic instruction prompts. The first eight request a brief caption of the
# image; the last eight request a detailed description. Order is significant.
questions = [
    # --- brief captions ---
    'Describe the image concisely.',
    'Provide a brief description of the given image.',
    'Offer a succinct explanation of the picture presented.',
    'Summarize the visual content of the image.',
    'Provide a brief description of the image.',
    'Describe the image briefly.',
    'Summarize the image.',
    'Give a short and clear explanation of the subsequent image.',
] + [
    # --- detailed descriptions ---
    'Describe the following image in detail.',
    'Provide a detailed description of the given image.',
    'Give an elaborate explanation of the image you see.',
    'Share a comprehensive rundown of the presented image.',
    'Offer a detailed description of the image.',
    'Describe the image in detail.',
    'Offer a thorough analysis of the image.',
    'Provide a detailed explanation of the subsequent image.',
]


def load_image(image_file):
    """Load an image from a local path or an http(s) URL and return it in RGB mode.

    Args:
        image_file: filesystem path, or a URL starting with ``http://`` / ``https://``.

    Returns:
        A ``PIL.Image.Image`` converted to ``'RGB'``.

    Raises:
        requests.HTTPError: if fetching a URL returns an error status code.
    """
    if image_file.startswith(('http://', 'https://')):
        response = requests.get(image_file)
        # Surface the HTTP failure directly instead of letting PIL fail later with
        # an opaque "cannot identify image file" on an error page body.
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = image_file
    # convert() yields an image independent of the source file, so the file can be
    # closed right away instead of relying on garbage collection.
    with Image.open(source) as img:
        return img.convert('RGB')


def _load_vqa_questions(path):
    """Return the cleaned first-turn questions of every 'complex_reasoning' item in *path*.

    *path* is read as a JSON file whose top-level ``"conversations"`` list holds items
    with a ``"conversation_type"`` string and a ``"conversations"`` list of turns, each
    turn carrying a ``"value"`` string.
    """
    with open(path, encoding="utf-8") as f:
        items = json.load(f)["conversations"]

    cleaned = []
    for item in items:
        # Keep only complex-reasoning items; other conversation types are skipped.
        if 'complex_reasoning' not in item['conversation_type']:
            continue
        text = item["conversations"][0]["value"]
        # Remove the image placeholder and EOS marker, drop line breaks/tabs, and
        # collapse double spaces (single non-overlapping pass). The image token is
        # re-inserted by _build_prompt.
        text = (
            text.replace('<image>\n', '')
            .replace("</s>", "")
            .replace("\n", "")
            .replace("\t", "")
            .replace("\r", "")
            .replace("  ", " ")
            .strip()
        )
        cleaned.append(text)
    return cleaned


def _infer_conv_mode(model_name):
    """Pick a conversation-template name from substrings of *model_name* (case-insensitive)."""
    name = model_name.lower()
    # Checked in order, first match wins: "v1.6-34b" must precede the generic "v1".
    rules = (
        ("llama-2", "llava_llama_2"),
        ("mistral", "mistral_instruct"),
        ("v1.6-34b", "chatml_direct"),
        ("v1", "llava_v1"),
        ("mpt", "mpt"),
    )
    for marker, mode in rules:
        if marker in name:
            return mode
    return "llava_v0"


def _build_prompt(question, conv_mode, use_im_start_end):
    """Build a single-turn prompt: image token(s) + newline + *question*, then an open assistant turn."""
    conv = conv_templates[conv_mode].copy()
    if use_im_start_end:
        image_prefix = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    else:
        image_prefix = DEFAULT_IMAGE_TOKEN
    conv.append_message(conv.roles[0], image_prefix + '\n' + question)
    conv.append_message(conv.roles[1], None)
    return conv.get_prompt()


def main(args):
    """Generate fMRI-conditioned answers with a LLaVA-style model and save them as JSON.

    Questions are the 'complex_reasoning' prompts of a fixed SFT test file. For every
    dataset index in the selected slice, the fMRI sample is passed as the "image" and
    the decoded generations are written to
    ``<output_dir>/<output_fname>_slice_<slice>.json`` as a list of
    ``{"id": idx, "captions": [...]}`` records.

    Args:
        args: parsed CLI namespace (see the ``__main__`` block). ``seed`` and
            ``data_path`` are optional: a non-None ``seed`` seeds torch/NumPy before
            generation, and a non-empty ``data_path`` replaces the derived
            ``pretrain.json`` location.

    Side effects:
        Sets ``args.output_dir`` and, unless ``--conv-mode`` was given, ``args.conv_mode``;
        creates the output directory and writes the result file.
    """
    vqa = _load_vqa_questions("/mnt/NSD_dataset/datasets/nsd/sft_data/all/sft_all_te.json")
    vqa_name = '_vqa'
    print(len(vqa))

    output_fname = args.model_path.split("/")[-3] + '_' + args.model_path.split("/")[-2] + vqa_name
    args.output_dir = f'/mnt/NSD_dataset/datasets/{args.dataset}/results/{args.subj}/llava_captions/{output_fname}'
    os.makedirs(args.output_dir, exist_ok=True)
    print(args.output_dir)

    disable_torch_init()

    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, _context_len = load_pretrained_model(
        args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
    if image_processor is None:
        image_processor = model.get_vision_tower().image_processor

    conv_mode = _infer_conv_mode(model_name)
    if args.conv_mode is not None and conv_mode != args.conv_mode:
        print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
    else:
        args.conv_mode = conv_mode

    if getattr(args, 'data_path', ''):
        data_path = args.data_path
    elif args.subj == 'all':
        data_path = f'/mnt/NSD_dataset/datasets/{args.dataset}/fmris/pretrain.json'
    else:
        data_path = f'/mnt/NSD_dataset/datasets/{args.dataset}/fmris/{args.subj}/pretrain.json'

    dataset_val = fMRIViT3dDataset(
        data_path=data_path,
        is_train=False,
        return_fmris=True,
        return_embeds=False,
        return_images=False,
        requires_norm=True,
        requires_padding=False,
        select_brain_region=args.select_region,
    )

    # Split the dataset into 8 contiguous slices (presumably so 8 jobs can run in
    # parallel via --slice); a negative slice processes the whole dataset.
    num_slices = 8
    sample_per_slice = len(dataset_val) // num_slices + 1
    if args.slice >= 0:
        start = args.slice * sample_per_slice
        end = min((args.slice + 1) * sample_per_slice, len(dataset_val))
    else:
        start = 0
        end = len(dataset_val)
    print(start, end)

    output_path = f'{args.output_dir}/{output_fname}_slice_{args.slice}.json'
    print(output_path)

    # `--seed` is documented as making sampling reproducible; apply it right before
    # generation so model loading cannot shift the RNG state.
    seed = getattr(args, 'seed', None)
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)

    use_im_start_end = model.config.mm_use_im_start_end
    generated_captions = []
    for idx in tqdm(range(start, end), desc="Generating embeds"):
        sample = dataset_val[idx]
        image_tensor = image_processor(torch.tensor(sample['fmri'])).to(args.device)
        image_size = image_tensor.shape

        # One per-sample VQA question; fall back to the module-level generic prompts.
        # A distinct local name keeps the global `questions` list reachable.
        # NOTE(review): assumes vqa[i] pairs with dataset_val[i] — confirm both files share ordering.
        sample_questions = [vqa[idx]] if vqa is not None else questions

        generated_caption = []
        for question in sample_questions:
            prompt = _build_prompt(question, args.conv_mode, use_im_start_end)
            input_ids = tokenizer_image_token(
                prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)

            with torch.inference_mode():
                output_ids = model.generate(
                    input_ids,
                    images=image_tensor.half(),
                    image_sizes=[image_size],
                    do_sample=True if args.temperature > 0 else False,
                    temperature=args.temperature,
                    max_new_tokens=args.max_new_tokens,
                    use_cache=True,
                    vision_embeds=None,
                )

            # Decoded without skip_special_tokens, so markers such as </s> may remain.
            outputs = tokenizer.decode(output_ids[0]).strip()
            generated_caption.append(outputs)
            if args.debug:
                print("\n", {"prompt": prompt, "outputs": outputs}, "\n")

        generated_captions.append({
            "id": idx,
            "captions": generated_caption,
        })
        print(generated_caption)

    with open(output_path, 'w') as f:
        json.dump(generated_captions, f, indent=4)


if __name__ == "__main__":
    # Command-line interface for main(); arguments are registered in the same order
    # as before so the --help listing is unchanged.
    cli = argparse.ArgumentParser()

    # Model loading and generation options.
    cli.add_argument("--model-path", type=str, default="facebook/opt-350m")
    cli.add_argument("--model-base", type=str, default=None)
    cli.add_argument("--device", type=str, default="cuda")
    cli.add_argument("--conv-mode", type=str, default=None)
    cli.add_argument("--temperature", type=float, default=0.2)
    cli.add_argument("--max-new-tokens", type=int, default=512)
    for switch in ("--load-8bit", "--load-4bit", "--debug"):
        cli.add_argument(switch, action="store_true")
    cli.add_argument("--seed", type=int, default=42, help="the seed (for reproducible sampling)")

    # Dataset selection and work-splitting options.
    for flag, kind, fallback in (
        ("--dataset", str, "nsd"),
        ("--subj", str, "subj01"),
        ("--batch-size", int, 1),
        ("--slice", int, -1),
    ):
        cli.add_argument(flag, type=kind, default=fallback)
    cli.add_argument("--select-region", nargs='+', type=str, default=None)
    cli.add_argument("--data-path", type=str, default="")

    main(cli.parse_args())
