
import os
import sys

import json
import argparse

import scipy
from tqdm import tqdm

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from PIL import Image
import requests
from transformers import Blip2Processor, Blip2ForConditionalGeneration, GPT2TokenizerFast
from transformers import AutoTokenizer, CLIPTextModelWithProjection
from transformers import AutoProcessor, CLIPVisionModelWithProjection
from diffusers import StableUnCLIPImg2ImgPipeline, DPMSolverMultistepScheduler, StableDiffusionXLImg2ImgPipeline
import torch
import matplotlib.gridspec as gridspec

from llava.datasets.fmri_vit3d_datasets import fMRIViT3dDataset
# from llava.model.fmri_encoder.vit3d import CLIPVision3dModelWithProjection
from llava.model.fmri_encoder.vit3d_decoder import ViT3dWithProjectionModel
from llava.train import DataArguments


parser = argparse.ArgumentParser()

# Command-line options, declared as (flag, add_argument keyword arguments) pairs.
# They are registered in this exact order, which is also the order shown by --help.
_CLI_OPTIONS = (
    ("--device", dict(type=str, default="cuda:0", help="device")),
    ("--seed", dict(type=int, default=42, help="the seed (for reproducible sampling)")),
    ("--dataset", dict(type=str, default="nsd")),
    ("--subject", dict(type=str, default="subj01")),
    # Optional subject filter forwarded to the dataset as `select_subject`.
    ("--select-subj", dict(type=str, default=None)),
    # One or more region names forwarded to the dataset as `select_brain_region`.
    ("--select-region", dict(nargs='+', type=str, default=None)),
    # Path of the pretrained checkpoint to load; its third-from-last path
    # component is used as the run name by the main block.
    ("--model", dict(type=str, default="")),
    ("--batch-size", dict(type=int, default=16)),
)

for _flag, _kwargs in _CLI_OPTIONS:
    parser.add_argument(_flag, **_kwargs)

# Parsed at import time; the rest of the script reads `args` and `device` as globals.
args = parser.parse_args()

device = args.device


if __name__ == '__main__':
    # Run the fMRI -> embedding model over the evaluation split. For every sample,
    # save the predicted vision embedding (and, when the model has a VAE head, the
    # predicted VAE latent) into per-subject folders. Finally, print the mean squared
    # error against the targets supplied by the dataset.

    # Shape each per-sample VAE prediction is reshaped to before saving.
    # NOTE(review): hard-coded in the original script; assumed to match the VAE
    # latent the model was trained against — confirm if the VAE changes.
    vae_latent_shape = (4, 96, 96)

    # The run name is the third-from-last component of the --model path.
    # Trailing slashes are stripped first so "<ckpt>/" and "<ckpt>" resolve to the
    # same name (previously a trailing slash shifted the chosen component).
    model_parts = args.model.rstrip('/').split('/')
    if len(model_parts) < 3:
        # The --model default is "", which used to crash here with an IndexError.
        parser.error(f'--model must be a checkpoint path with at least 3 components, got {args.model!r}')
    model_name = model_parts[-3]

    results_root = f'/mnt/NSD_dataset/datasets/{args.dataset}/results/{args.subject}'
    args.output_dir = f'{results_root}/fmri2embeds/{model_name}'
    os.makedirs(args.output_dir, exist_ok=True)
    print(args.output_dir)

    model = ViT3dWithProjectionModel.from_pretrained(args.model).to(device)

    if args.subject == 'all':
        data_path = f'/mnt/NSD_dataset/datasets/{args.dataset}/fmris/pretrain.json'
    else:
        data_path = f'/mnt/NSD_dataset/datasets/{args.dataset}/fmris/{args.subject}/pretrain.json'

    if model.with_vae:
        # Built explicitly rather than via str.replace on output_dir, which would also
        # rewrite any other "fmri2embeds" substring appearing in the path.
        args.vae_dir = f'{results_root}/fmri2vae/{model_name}'
        os.makedirs(args.vae_dir, exist_ok=True)

    dataset_val = fMRIViT3dDataset(
        data_path=data_path,
        is_train=False,
        return_images=False,
        return_subject=True,
        return_embeds=True,
        return_vae_embeds=model.with_vae,
        patch_size=model.patch_size[0],
        return_image_type='np',
        select_subject=args.select_subj,
        select_brain_region=args.select_region
    )

    # shuffle=False keeps the sample order stable, so the per-subject file indices
    # written below follow dataset order.
    loader_val = torch.utils.data.DataLoader(
        dataset_val,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=8,
        pin_memory=True,
        drop_last=False
    )

    model.eval()
    # Sums of per-batch mean MSE multiplied by batch size. Dividing by n_samples gives
    # the true per-sample mean, unbiased by a smaller final batch.
    mse_sum = 0.0
    vae_mse_sum = 0.0
    n_samples = 0
    # Next file index to write for each subject.
    subject_counter = {}
    with torch.no_grad():
        for data in tqdm(loader_val):
            fmri = data['fmri'].to(device)
            vision_embeds = data['labels'].to(device)
            outputs = model(pixel_values=fmri)
            vision_embeds_gen = outputs.vision_embeds
            batch_n = vision_embeds_gen.shape[0]
            n_samples += batch_n

            mse_sum += torch.nn.functional.mse_loss(vision_embeds_gen, vision_embeds).item() * batch_n

            if model.with_vae:
                vae_embeds = data['vae_labels'].to(device)
                vae_embeds_gen = outputs.vae_embeds
                # The model emits flattened latents, so compare against the
                # flattened targets.
                vae_mse_sum += torch.nn.functional.mse_loss(
                    vae_embeds_gen, vae_embeds.flatten(1)
                ).item() * batch_n

            subjects = data['subject']

            for j in range(batch_n):
                subject = subjects[j]
                # A tensor element hashes by identity (a new object per access), which
                # would reset the counter every sample. Use its Python value instead.
                if torch.is_tensor(subject):
                    subject = subject.item()

                emb_dir = f'{args.output_dir}/{subject}'
                index = subject_counter.get(subject)
                if index is None:
                    # First sample seen for this subject: create its output folders.
                    index = 0
                    os.makedirs(emb_dir, exist_ok=True)
                    if model.with_vae:
                        os.makedirs(f'{args.vae_dir}/{subject}', exist_ok=True)
                subject_counter[subject] = index + 1

                np.save(f'{emb_dir}/{index:06}.npy', vision_embeds_gen[j].float().cpu().numpy())

                if model.with_vae:
                    # .clone() gives the saved tensor its own compact storage. Saving
                    # the indexed/reshaped view directly would serialize the entire
                    # batch's storage into every file. The tensor keeps its device.
                    torch.save(
                        vae_embeds_gen[j].reshape(vae_latent_shape).clone(),
                        f'{args.vae_dir}/{subject}/{index:06}.pt',
                    )

    if n_samples == 0:
        print('No samples were evaluated; loss is undefined.')
    else:
        print(f'Loss: {mse_sum / n_samples}')
        if model.with_vae:
            print(f'VAE Loss: {vae_mse_sum / n_samples}')
